From cacf5a943b75dad75119c38da308957db149bbe8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 21 Aug 2023 12:01:53 +0300 Subject: [PATCH 0001/1647] Add mutex around EnsureRegistered --- appservice/intent.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/appservice/intent.go b/appservice/intent.go index af6fea37..7995f44b 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "strings" + "sync" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" @@ -23,6 +24,8 @@ type IntentAPI struct { Localpart string UserID id.UserID + registerLock sync.Mutex + IsCustomPuppet bool } @@ -53,6 +56,8 @@ func (intent *IntentAPI) Register() error { } func (intent *IntentAPI) EnsureRegistered() error { + intent.registerLock.Lock() + defer intent.registerLock.Unlock() if intent.IsCustomPuppet || intent.as.StateStore.IsRegistered(intent.UserID) { return nil } From ac5c2c22102cb54a93bd60267bb431e522914154 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 21 Aug 2023 13:48:10 +0300 Subject: [PATCH 0002/1647] Add bridge double puppeting utility --- bridge/bridge.go | 5 + bridge/bridgeconfig/config.go | 7 ++ bridge/commands/doublepuppet.go | 6 +- bridge/doublepuppet.go | 172 ++++++++++++++++++++++++++++++++ requests.go | 2 + 5 files changed, 187 insertions(+), 5 deletions(-) create mode 100644 bridge/doublepuppet.go diff --git a/bridge/bridge.go b/bridge/bridge.go index 9cfc1450..dec6103b 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -102,6 +102,7 @@ type User interface { type DoublePuppet interface { CustomIntent() *appservice.IntentAPI SwitchCustomMXID(accessToken string, userID id.UserID) error + ClearCustomMXID() } type Ghost interface { @@ -171,6 +172,8 @@ type Bridge struct { PublicHSAddress *url.URL + DoublePuppet *doublePuppetUtil + AS *appservice.AppService EventProcessor *appservice.EventProcessor CommandProcessor CommandProcessor @@ -504,6 +507,8 @@ func (br *Bridge) init() { zerolog.DefaultContextLogger = &defaultCtxLog br.Log = maulogadapt.ZeroAsMau(br.ZLog) + br.DoublePuppet = &doublePuppetUtil{br: br, log: br.ZLog.With().Str("component", "double puppet").Logger()} + err = br.validateConfig() if err != nil { br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Configuration error") diff --git a/bridge/bridgeconfig/config.go b/bridge/bridgeconfig/config.go index 7efc2ef1..be42aab3 100644 --- a/bridge/bridgeconfig/config.go +++ b/bridge/bridgeconfig/config.go @@ -162,12 +162,19 @@ type BridgeConfig interface { GetEncryptionConfig() EncryptionConfig GetCommandPrefix() string GetManagementRoomTexts() ManagementRoomTexts + GetDoublePuppetConfig() DoublePuppetConfig GetResendBridgeInfo() bool EnableMessageStatusEvents() bool EnableMessageErrorNotices() bool Validate() error } +type DoublePuppetConfig struct { + ServerMap map[string]string `yaml:"double_puppet_server_map"` + AllowDiscovery bool `yaml:"double_puppet_allow_discovery"` + SharedSecretMap map[string]string `yaml:"login_shared_secret_map"` +} + type EncryptionConfig struct { Allow bool `yaml:"allow"` Default bool `yaml:"default"` diff --git a/bridge/commands/doublepuppet.go b/bridge/commands/doublepuppet.go index 5bdcbd47..8c2e611e 100644 --- a/bridge/commands/doublepuppet.go +++ b/bridge/commands/doublepuppet.go @@ -78,10 +78,6 @@ func fnLogoutMatrix(ce *Event) { ce.Reply("You don't have double puppeting enabled.") return } - err := puppet.SwitchCustomMXID("", "") - if err != nil { - ce.Reply("Failed to disable double puppeting: %v", err) - return - } + puppet.ClearCustomMXID() ce.Reply("Successfully disabled double puppeting.") } diff --git a/bridge/doublepuppet.go b/bridge/doublepuppet.go new file mode 100644 index 00000000..7ddc1989 --- /dev/null +++ b/bridge/doublepuppet.go @@ -0,0 +1,172 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridge + +import ( + "crypto/hmac" + "crypto/sha512" + "encoding/hex" + "errors" + "fmt" + "strings" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/id" +) + +type doublePuppetUtil struct { + br *Bridge + log zerolog.Logger +} + +func (dp *doublePuppetUtil) newClient(mxid id.UserID, accessToken string) (*mautrix.Client, error) { + _, homeserver, err := mxid.Parse() + if err != nil { + return nil, err + } + homeserverURL, found := dp.br.Config.Bridge.GetDoublePuppetConfig().ServerMap[homeserver] + if !found { + if homeserver == dp.br.AS.HomeserverDomain { + homeserverURL = "" + } else if dp.br.Config.Bridge.GetDoublePuppetConfig().AllowDiscovery { + resp, err := mautrix.DiscoverClientAPI(homeserver) + if err != nil { + return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err) + } + homeserverURL = resp.Homeserver.BaseURL + dp.log.Debug(). + Str("homeserver", homeserver). + Str("url", homeserverURL). + Str("user_id", mxid.String()). + Msg("Discovered URL to enable double puppeting for user") + } else { + return nil, fmt.Errorf("double puppeting from %s is not allowed", homeserver) + } + } + return dp.br.AS.NewExternalMautrixClient(mxid, accessToken, homeserverURL) +} + +func (dp *doublePuppetUtil) newIntent(mxid id.UserID, accessToken string) (*appservice.IntentAPI, error) { + client, err := dp.newClient(mxid, accessToken) + if err != nil { + return nil, err + } + + ia := dp.br.AS.NewIntentAPI("custom") + ia.Client = client + ia.Localpart, _, _ = mxid.Parse() + ia.UserID = mxid + ia.IsCustomPuppet = true + return ia, nil +} + +func (dp *doublePuppetUtil) autoLogin(mxid id.UserID, loginSecret string) (string, error) { + dp.log.Debug().Str("user_id", mxid.String()).Msg("Logging into user account with shared secret") + client, err := dp.newClient(mxid, "") + if err != nil { + return "", fmt.Errorf("failed to create mautrix client to log in: %v", err) + } + bridgeName := fmt.Sprintf("%s Bridge", dp.br.ProtocolName) + req := mautrix.ReqLogin{ + Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(mxid)}, + DeviceID: id.DeviceID(bridgeName), + InitialDeviceDisplayName: bridgeName, + } + if loginSecret == "appservice" { + client.AccessToken = dp.br.AS.Registration.AppToken + req.Type = mautrix.AuthTypeAppservice + } else { + loginFlows, err := client.GetLoginFlows() + if err != nil { + return "", fmt.Errorf("failed to get supported login flows: %w", err) + } + mac := hmac.New(sha512.New, []byte(loginSecret)) + mac.Write([]byte(mxid)) + token := hex.EncodeToString(mac.Sum(nil)) + switch { + case loginFlows.HasFlow(mautrix.AuthTypeDevtureSharedSecret): + req.Type = mautrix.AuthTypeDevtureSharedSecret + req.Token = token + case loginFlows.HasFlow(mautrix.AuthTypePassword): + req.Type = mautrix.AuthTypePassword + req.Password = token + default: + return "", fmt.Errorf("no supported auth types for shared secret auth found") + } + } + resp, err := client.Login(&req) + if err != nil { + return "", err + } + return resp.AccessToken, nil +} + +var ( + ErrMismatchingMXID = errors.New("whoami result does not match custom mxid") + ErrNoAccessToken = errors.New("no access token provided") + ErrNoMXID = errors.New("no mxid provided") +) + +const useConfigASToken = "appservice-config" +const asTokenModePrefix = "as_token:" + +func (dp *doublePuppetUtil) Setup(mxid id.UserID, savedAccessToken string, reloginOnFail bool) (intent *appservice.IntentAPI, newAccessToken string, err error) { + if len(mxid) == 0 { + err = ErrNoMXID + return + } + _, homeserver, _ := mxid.Parse() + loginSecret, hasSecret := dp.br.Config.Bridge.GetDoublePuppetConfig().SharedSecretMap[homeserver] + // Special case appservice: prefix to not login and use it as an as_token directly. + if hasSecret && strings.HasPrefix(loginSecret, asTokenModePrefix) { + intent, err = dp.newIntent(mxid, strings.TrimPrefix(loginSecret, asTokenModePrefix)) + if err != nil { + return + } + intent.SetAppServiceUserID = true + if savedAccessToken != useConfigASToken { + var resp *mautrix.RespWhoami + resp, err = intent.Whoami() + if err == nil && resp.UserID != mxid { + err = ErrMismatchingMXID + } + } + return intent, useConfigASToken, err + } + if savedAccessToken == "" || savedAccessToken == useConfigASToken { + if reloginOnFail && hasSecret { + savedAccessToken, err = dp.autoLogin(mxid, loginSecret) + } else { + err = ErrNoAccessToken + } + if err != nil { + return + } + } + intent, err = dp.newIntent(mxid, savedAccessToken) + if err != nil { + return + } + var resp *mautrix.RespWhoami + resp, err = intent.Whoami() + if err != nil { + if reloginOnFail && hasSecret && errors.Is(err, mautrix.MUnknownToken) { + intent.AccessToken, err = dp.autoLogin(mxid, loginSecret) + if err == nil { + newAccessToken = intent.AccessToken + } + } + } else if resp.UserID != mxid { + err = ErrMismatchingMXID + } else { + newAccessToken = savedAccessToken + } + return +} diff --git a/requests.go b/requests.go index f49a9468..985c8338 100644 --- a/requests.go +++ b/requests.go @@ -23,6 +23,8 @@ const ( AuthTypeAppservice AuthType = "m.login.application_service" AuthTypeSynapseJWT AuthType = "org.matrix.login.jwt" + + AuthTypeDevtureSharedSecret AuthType = "com.devture.shared_secret_auth" ) type IdentifierType string From 57d46e6a23398b0a4f3f3f291f8ea10242e9f17f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 31 Aug 2023 17:18:09 +0300 Subject: [PATCH 0003/1647] Don't escape + in user ID localparts --- id/userid.go | 6 +++--- id/userid_test.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/id/userid.go b/id/userid.go index 0522b54c..3aae3b21 100644 --- a/id/userid.go +++ b/id/userid.go @@ -72,7 +72,7 @@ func (userID UserID) URI() *MatrixURI { } } -var ValidLocalpartRegex = regexp.MustCompile("^[0-9a-z-.=_/]+$") +var ValidLocalpartRegex = regexp.MustCompile("^[0-9a-z-.=_/+]+$") // ValidateUserLocalpart validates a Matrix user ID localpart using the grammar // in https://matrix.org/docs/spec/appendices#user-identifier @@ -132,7 +132,7 @@ func escape(buf *bytes.Buffer, b byte) { } func shouldEncode(b byte) bool { - return b != '-' && b != '.' && b != '_' && !(b >= '0' && b <= '9') && !(b >= 'a' && b <= 'z') && !(b >= 'A' && b <= 'Z') + return b != '-' && b != '.' && b != '_' && b != '+' && !(b >= '0' && b <= '9') && !(b >= 'a' && b <= 'z') && !(b >= 'A' && b <= 'Z') } func shouldEscape(b byte) bool { @@ -140,7 +140,7 @@ func shouldEscape(b byte) bool { } func isValidByte(b byte) bool { - return isValidEscapedChar(b) || (b >= '0' && b <= '9') || b == '.' || b == '=' || b == '-' + return isValidEscapedChar(b) || (b >= '0' && b <= '9') || b == '.' || b == '=' || b == '-' || b == '+' } func isValidEscapedChar(b byte) bool { diff --git a/id/userid_test.go b/id/userid_test.go index a18dd314..359bc687 100644 --- a/id/userid_test.go +++ b/id/userid_test.go @@ -66,8 +66,8 @@ func TestUserID_ParseAndValidate_NotLong(t *testing.T) { } func TestUserIDEncoding(t *testing.T) { - const inputLocalpart = "This localpart contains IlLeGaL chäracters 🚨" - const encodedLocalpart = "_this=20localpart=20contains=20_il_le_ga_l=20ch=c3=a4racters=20=f0=9f=9a=a8" + const inputLocalpart = "This local+part contains IlLeGaL chäracters 🚨" + const encodedLocalpart = "_this=20local+part=20contains=20_il_le_ga_l=20ch=c3=a4racters=20=f0=9f=9a=a8" const inputServerName = "example.com" userID := id.NewEncodedUserID(inputLocalpart, inputServerName) parsedLocalpart, parsedServerName, err := userID.ParseAndValidate() From 691b96adc291619f93c52e8364bd53318b04c0e5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 2 Sep 2023 12:26:31 +0300 Subject: [PATCH 0004/1647] Add fields for beeper galleries --- event/message.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/event/message.go b/event/message.go index 30c928b1..542cab67 100644 --- a/event/message.go +++ b/event/message.go @@ -33,6 +33,8 @@ const ( MsgFile MessageType = "m.file" MsgVerificationRequest MessageType = "m.key.verification.request" + + MsgBeeperGallery MessageType = "com.beeper.gallery" ) // Format specifies the format of the formatted_body in m.room.message events. @@ -110,7 +112,10 @@ type MessageEventContent struct { replyFallbackRemoved bool - MessageSendRetry *BeeperRetryMetadata `json:"com.beeper.message_send_retry,omitempty"` + MessageSendRetry *BeeperRetryMetadata `json:"com.beeper.message_send_retry,omitempty"` + BeeperGalleryImages []*MessageEventContent `json:"com.beeper.gallery.images,omitempty"` + BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"` + BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"` } func (content *MessageEventContent) GetRelatesTo() *RelatesTo { From 3fffe3f31ce170eee1b707c5c8eb006966713277 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 4 Sep 2023 12:20:06 +0300 Subject: [PATCH 0005/1647] Clarify parameter syntax in bridge help message --- bridge/commands/help.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bridge/commands/help.go b/bridge/commands/help.go index 874d71ef..f4891555 100644 --- a/bridge/commands/help.go +++ b/bridge/commands/help.go @@ -110,6 +110,8 @@ func FormatHelp(ce *Event) string { } _, _ = fmt.Fprintf(&output, prefixMsg, ce.Bridge.Config.Bridge.GetCommandPrefix()) output.WriteByte('\n') + output.WriteString("Parameters in [square brackets] are optional, while parameters in are required.") + output.WriteByte('\n') output.WriteByte('\n') for _, section := range sortedSections { From aafd22eee6bb9a2b1a2bdcbd0d9a31274a333026 Mon Sep 17 00:00:00 2001 From: Brad Murray Date: Tue, 5 Sep 2023 12:18:04 -0400 Subject: [PATCH 0006/1647] Add OlmPkDecryption functions (#141) * Add OlmPkDecryption functions * Trim result to the valid size --- crypto/olm/pk.go | 59 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/crypto/olm/pk.go b/crypto/olm/pk.go index 1e628745..e441ba14 100644 --- a/crypto/olm/pk.go +++ b/crypto/olm/pk.go @@ -109,3 +109,62 @@ func (p *PkSigning) SignJSON(obj interface{}) (string, error) { func (p *PkSigning) lastError() error { return convertError(C.GoString(C.olm_pk_signing_last_error((*C.OlmPkSigning)(p.int)))) } + +type PkDecryption struct { + int *C.OlmPkDecryption + mem []byte + PublicKey []byte +} + +func pkDecryptionSize() uint { + return uint(C.olm_pk_decryption_size()) +} + +func pkDecryptionPublicKeySize() uint { + return uint(C.olm_pk_key_length()) +} + +func NewPkDecryption(privateKey []byte) (*PkDecryption, error) { + memory := make([]byte, pkDecryptionSize()) + p := &PkDecryption{ + int: C.olm_pk_decryption(unsafe.Pointer(&memory[0])), + mem: memory, + } + p.Clear() + pubKey := make([]byte, pkDecryptionPublicKeySize()) + + if C.olm_pk_key_from_private((*C.OlmPkDecryption)(p.int), + unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)), + unsafe.Pointer(&privateKey[0]), C.size_t(len(privateKey))) == errorVal() { + return nil, p.lastError() + } + p.PublicKey = pubKey + + return p, nil +} + +func (p *PkDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) { + maxPlaintextLength := uint(C.olm_pk_max_plaintext_length((*C.OlmPkDecryption)(p.int), C.size_t(len(ciphertext)))) + plaintext := make([]byte, maxPlaintextLength) + + size := C.olm_pk_decrypt((*C.OlmPkDecryption)(p.int), + unsafe.Pointer(&ephemeralKey[0]), C.size_t(len(ephemeralKey)), + unsafe.Pointer(&mac[0]), C.size_t(len(mac)), + unsafe.Pointer(&ciphertext[0]), C.size_t(len(ciphertext)), + unsafe.Pointer(&plaintext[0]), C.size_t(len(plaintext))) + if size == errorVal() { + return nil, p.lastError() + } + + return plaintext[:size], nil +} + +// Clear clears the underlying memory of a PkDecryption object. +func (p *PkDecryption) Clear() { + C.olm_clear_pk_decryption((*C.OlmPkDecryption)(p.int)) +} + +// lastError returns the last error that happened in relation to this PkDecryption object. +func (p *PkDecryption) lastError() error { + return convertError(C.GoString(C.olm_pk_decryption_last_error((*C.OlmPkDecryption)(p.int)))) +} From 26b2e2e590a07e48a874ca31c3f39ca5cc915ad6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 6 Sep 2023 19:04:25 +0300 Subject: [PATCH 0007/1647] Use new retryafter utility --- client.go | 43 +++---------------------- client_internal_test.go | 70 ----------------------------------------- go.mod | 12 +++---- go.sum | 24 +++++++------- 4 files changed, 23 insertions(+), 126 deletions(-) delete mode 100644 client_internal_test.go diff --git a/client.go b/client.go index 43236ef2..17720026 100644 --- a/client.go +++ b/client.go @@ -18,6 +18,7 @@ import ( "time" "github.com/rs/zerolog" + "go.mau.fi/util/retryafter" "maunium.net/go/maulogger/v2/maulogadapt" "maunium.net/go/mautrix/event" @@ -526,36 +527,6 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { } } -// parseBackoffFromResponse extracts the backoff time specified in the Retry-After header if present. See -// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After. -func parseBackoffFromResponse(req *http.Request, res *http.Response, now time.Time, fallback time.Duration) time.Duration { - retryAfterHeaderValue := res.Header.Get("Retry-After") - if retryAfterHeaderValue == "" { - return fallback - } - - if t, err := time.Parse(http.TimeFormat, retryAfterHeaderValue); err == nil { - return t.Sub(now) - } - - if seconds, err := strconv.Atoi(retryAfterHeaderValue); err == nil { - return time.Duration(seconds) * time.Second - } - - zerolog.Ctx(req.Context()).Warn(). - Str("retry_after", retryAfterHeaderValue). - Msg("Failed to parse Retry-After header value") - - return fallback -} - -func (cli *Client) shouldRetry(res *http.Response) bool { - return res.StatusCode == http.StatusBadGateway || - res.StatusCode == http.StatusServiceUnavailable || - res.StatusCode == http.StatusGatewayTimeout || - (res.StatusCode == http.StatusTooManyRequests && !cli.IgnoreRateLimit) -} - func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler) ([]byte, error) { cli.RequestStart(req) startTime := time.Now() @@ -579,10 +550,8 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof return nil, err } - if retries > 0 && cli.shouldRetry(res) { - if res.StatusCode == http.StatusTooManyRequests { - backoff = parseBackoffFromResponse(req, res, time.Now(), backoff) - } + if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) { + backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff) return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler) } @@ -1427,10 +1396,8 @@ func (cli *Client) doMediaRequest(req *http.Request, retries int, backoff time.D return nil, err } - if retries > 0 && cli.shouldRetry(res) { - if res.StatusCode == http.StatusTooManyRequests { - backoff = parseBackoffFromResponse(req, res, time.Now(), backoff) - } + if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) { + backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff) return cli.doMediaRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff) } diff --git a/client_internal_test.go b/client_internal_test.go deleted file mode 100644 index e5d815cb..00000000 --- a/client_internal_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package mautrix - -import ( - "bytes" - "context" - "net/http" - "testing" - "time" - - "github.com/rs/zerolog" - "github.com/tidwall/gjson" -) - -func TestBackoffFromResponse(t *testing.T) { - now := time.Now().Truncate(time.Second) - - defaultBackoff := time.Duration(123) - - for name, tt := range map[string]struct { - headerValue string - expected time.Duration - expectedLog string - }{ - "AsDate": { - headerValue: now.In(time.UTC).Add(5 * time.Hour).Format(http.TimeFormat), - expected: time.Duration(5) * time.Hour, - expectedLog: "", - }, - "AsSeconds": { - headerValue: "12345", - expected: time.Duration(12345) * time.Second, - expectedLog: "", - }, - "Missing": { - headerValue: "", - expected: defaultBackoff, - expectedLog: "", - }, - "Bad": { - headerValue: "invalid", - expected: defaultBackoff, - expectedLog: `Failed to parse Retry-After header value`, - }, - } { - t.Run(name, func(t *testing.T) { - var out bytes.Buffer - c := &Client{Log: zerolog.New(&out)} - - actual := parseBackoffFromResponse( - (&http.Request{}).WithContext(c.Log.WithContext(context.Background())), - &http.Response{ - Header: http.Header{ - "Retry-After": []string{tt.headerValue}, - }, - }, - now, - time.Duration(123), - ) - - if actual != tt.expected { - t.Fatalf("Backoff duration output mismatch, expected %s, got %s", tt.expected, actual) - } - - lastLogged := gjson.GetBytes(out.Bytes(), zerolog.MessageFieldName).Str - if lastLogged != tt.expectedLog { - t.Fatalf(`Log line mismatch, expected "%s", got "%s"`, tt.expectedLog, lastLogged) - } - }) - } -} diff --git a/go.mod b/go.mod index 3240501d..a35b4f3b 100644 --- a/go.mod +++ b/go.mod @@ -11,12 +11,12 @@ require ( github.com/stretchr/testify v1.8.4 github.com/tidwall/gjson v1.16.0 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.5.5 - go.mau.fi/util v0.0.0-20230805171708-199bf3eec776 + github.com/yuin/goldmark v1.5.6 + go.mau.fi/util v0.0.0-20230906154548-ffc399173e21 go.mau.fi/zeroconfig v0.1.2 - golang.org/x/crypto v0.12.0 - golang.org/x/exp v0.0.0-20230810033253-352e893a4cad - golang.org/x/net v0.14.0 + golang.org/x/crypto v0.13.0 + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 + golang.org/x/net v0.15.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 maunium.net/go/maulogger/v2 v2.4.1 @@ -30,6 +30,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.11.0 // indirect + golang.org/x/sys v0.12.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 8f5d2690..573509f6 100644 --- a/go.sum +++ b/go.sum @@ -33,22 +33,22 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.5.5 h1:IJznPe8wOzfIKETmMkd06F8nXkmlhaHqFRM9l1hAGsU= -github.com/yuin/goldmark v1.5.5/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mau.fi/util v0.0.0-20230805171708-199bf3eec776 h1:VrxDCO/gLFHLQywGUsJzertrvt2mUEMrZPf4hEL/s18= -go.mau.fi/util v0.0.0-20230805171708-199bf3eec776/go.mod h1:AxuJUMCxpzgJ5eV9JbPWKRH8aAJJidxetNdUj7qcb84= +github.com/yuin/goldmark v1.5.6 h1:COmQAWTCcGetChm3Ig7G/t8AFAN00t+o8Mt4cf7JpwA= +github.com/yuin/goldmark v1.5.6/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.mau.fi/util v0.0.0-20230906154548-ffc399173e21 h1:dj1V9hVkB+q5Vlm1ugo4Y4rLROblswgFkbko4gl4UjQ= +go.mau.fi/util v0.0.0-20230906154548-ffc399173e21/go.mod h1:AxuJUMCxpzgJ5eV9JbPWKRH8aAJJidxetNdUj7qcb84= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= -golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= -golang.org/x/exp v0.0.0-20230810033253-352e893a4cad h1:g0bG7Z4uG+OgH2QDODnjp6ggkk1bJDsINcuWmJN1iJU= -golang.org/x/exp v0.0.0-20230810033253-352e893a4cad/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= -golang.org/x/net v0.14.0 h1:BONx9s002vGdD9umnlX1Po8vOZmrgH34qlHcD1MfK14= -golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= +golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM= -golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From 1daa22f8512e2e4e3cb201d01295f3154f9cbe5a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 9 Sep 2023 15:38:35 +0300 Subject: [PATCH 0008/1647] Update changelog --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3cb5f21b..dc0a81cc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +## unreleased + +* **Breaking change *(id)*** Updated user ID localpart encoding to not encode + `+` as per [MSC4009]. + +[MSC4009]: https://github.com/matrix-org/matrix-spec-proposals/pull/4009 + ## v0.16.0 (2023-08-16) * Bumped minimum Go version to 1.20. From b2a54de015fae3df4404ca9cca68391343245fb8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 11 Sep 2023 10:52:45 -0400 Subject: [PATCH 0009/1647] Add warning logs if AS event handling takes long --- appservice/eventprocessor.go | 37 ++++++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/appservice/eventprocessor.go b/appservice/eventprocessor.go index 437d8536..376a4fc4 100644 --- a/appservice/eventprocessor.go +++ b/appservice/eventprocessor.go @@ -9,6 +9,7 @@ package appservice import ( "encoding/json" "runtime/debug" + "time" "github.com/rs/zerolog" @@ -31,6 +32,9 @@ type DeviceListHandler = func(lists *mautrix.DeviceLists, since string) type EventProcessor struct { ExecMode ExecMode + ExecSyncWarnTime time.Duration + ExecSyncTimeout time.Duration + as *AppService stop chan struct{} handlers map[event.Type][]EventHandler @@ -46,6 +50,9 @@ func NewEventProcessor(as *AppService) *EventProcessor { stop: make(chan struct{}, 1), handlers: make(map[event.Type][]EventHandler), + ExecSyncWarnTime: 30 * time.Second, + ExecSyncTimeout: 15 * time.Minute, + otkHandlers: make([]OTKHandler, 0), deviceListHandlers: make([]DeviceListHandler, 0), } @@ -134,8 +141,34 @@ func (ep *EventProcessor) Dispatch(evt *event.Event) { } }() case Sync: - for _, handler := range handlers { - ep.callHandler(handler, evt) + if ep.ExecSyncWarnTime == 0 && ep.ExecSyncTimeout == 0 { + for _, handler := range handlers { + ep.callHandler(handler, evt) + } + return + } + doneChan := make(chan struct{}) + go func() { + for _, handler := range handlers { + ep.callHandler(handler, evt) + } + close(doneChan) + }() + select { + case <-doneChan: + return + case <-time.After(ep.ExecSyncWarnTime): + log := ep.as.Log.With(). + Str("event_id", evt.ID.String()). + Str("event_type", evt.Type.String()). + Logger() + log.Warn().Msg("Handling event in appservice transaction channel is taking long") + select { + case <-doneChan: + return + case <-time.After(ep.ExecSyncTimeout): + log.Error().Msg("Giving up waiting for event handler") + } } } } From eea14cb9a403cfa862ade75274ff6247f2d3c7f5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 11 Sep 2023 15:16:41 -0400 Subject: [PATCH 0010/1647] Include error in log if message checkpoint sending fails --- bridge/messagecheckpoint.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridge/messagecheckpoint.go b/bridge/messagecheckpoint.go index 0447a8a8..a95d2160 100644 --- a/bridge/messagecheckpoint.go +++ b/bridge/messagecheckpoint.go @@ -36,7 +36,7 @@ func (br *Bridge) SendMessageCheckpoint(evt *event.Event, step status.MessageChe func (br *Bridge) SendRawMessageCheckpoint(cp *status.MessageCheckpoint) { err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{cp}) if err != nil { - br.ZLog.Warn().Interface("message_checkpoint", cp).Msg("Error sending message checkpoint") + br.ZLog.Warn().Err(err).Interface("message_checkpoint", cp).Msg("Error sending message checkpoint") } else { br.ZLog.Debug().Interface("message_checkpoint", cp).Msg("Sent message checkpoint") } From 58a7323f0aa102754aea4c09e75e6601f2d17eda Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 16 Sep 2023 09:47:59 -0400 Subject: [PATCH 0011/1647] Bump version to v0.16.1 --- CHANGELOG.md | 10 +++++++++- go.mod | 2 +- go.sum | 4 ++-- version.go | 2 +- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dc0a81cc..75910504 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,15 @@ -## unreleased +## v0.16.1 (2023-09-16) * **Breaking change *(id)*** Updated user ID localpart encoding to not encode `+` as per [MSC4009]. +* *(bridge)* Added bridge utility to handle double puppeting logins. + * The utility supports automatic logins with all three current methods + (shared secret, legacy appservice, new appservice). +* *(appservice)* Added warning logs and timeout on appservice event handling. + * Defaults to warning after 30 seconds and timeout 15 minutes after that. + * Timeouts can be adjusted or disabled by setting `ExecSync` variables in the + `EventProcessor`. +* *(crypto/olm)* Added `PkDecryption` wrapper. [MSC4009]: https://github.com/matrix-org/matrix-spec-proposals/pull/4009 diff --git a/go.mod b/go.mod index a35b4f3b..cc410f2f 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/tidwall/gjson v1.16.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.5.6 - go.mau.fi/util v0.0.0-20230906154548-ffc399173e21 + go.mau.fi/util v0.1.0 go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.13.0 golang.org/x/exp v0.0.0-20230905200255-921286631fa9 diff --git a/go.sum b/go.sum index 573509f6..3505fa1b 100644 --- a/go.sum +++ b/go.sum @@ -35,8 +35,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.5.6 h1:COmQAWTCcGetChm3Ig7G/t8AFAN00t+o8Mt4cf7JpwA= github.com/yuin/goldmark v1.5.6/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mau.fi/util v0.0.0-20230906154548-ffc399173e21 h1:dj1V9hVkB+q5Vlm1ugo4Y4rLROblswgFkbko4gl4UjQ= -go.mau.fi/util v0.0.0-20230906154548-ffc399173e21/go.mod h1:AxuJUMCxpzgJ5eV9JbPWKRH8aAJJidxetNdUj7qcb84= +go.mau.fi/util v0.1.0 h1:BwIFWIOEeO7lsiI2eWKFkWTfc5yQmoe+0FYyOFVyaoE= +go.mau.fi/util v0.1.0/go.mod h1:AxuJUMCxpzgJ5eV9JbPWKRH8aAJJidxetNdUj7qcb84= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= diff --git a/version.go b/version.go index 742d9cf5..6da6e5db 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.16.0" +const Version = "v0.16.1" var GoModVersion = "" var Commit = "" From 4e423897f7bd5366b1bbb5e9a596faf40b4fa99b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 3 Oct 2023 21:10:04 +0300 Subject: [PATCH 0012/1647] Set global zerologger --- bridge/bridge.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bridge/bridge.go b/bridge/bridge.go index dec6103b..8a62c6b8 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -22,6 +22,7 @@ import ( "github.com/mattn/go-sqlite3" "github.com/rs/zerolog" + deflog "github.com/rs/zerolog/log" "go.mau.fi/util/configupgrade" "go.mau.fi/util/dbutil" _ "go.mau.fi/util/dbutil/litestream" @@ -505,6 +506,7 @@ func (br *Bridge) init() { zerolog.TimeFieldFormat = time.RFC3339Nano zerolog.CallerMarshalFunc = exzerolog.CallerWithFunctionName zerolog.DefaultContextLogger = &defaultCtxLog + deflog.Logger = br.ZLog.With().Bool("global_log", true).Caller().Logger() br.Log = maulogadapt.ZeroAsMau(br.ZLog) br.DoublePuppet = &doublePuppetUtil{br: br, log: br.ZLog.With().Str("component", "double puppet").Logger()} From 1b562eed160cc954e62e109b4a66c0fd0b3b65a4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 6 Oct 2023 16:03:14 +0300 Subject: [PATCH 0013/1647] Add function to reverse TextToHTML --- event/message.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/event/message.go b/event/message.go index 542cab67..99b54cd4 100644 --- a/event/message.go +++ b/event/message.go @@ -151,10 +151,18 @@ func (content *MessageEventContent) SetEdit(original id.EventID) { } } +// TextToHTML converts the given text to a HTML-safe representation by escaping HTML characters +// and replacing newlines with
tags. func TextToHTML(text string) string { return strings.ReplaceAll(html.EscapeString(text), "\n", "
") } +// ReverseTextToHTML reverses the modifications made by TextToHTML, i.e. replaces
tags with newlines +// and unescapes HTML escape codes. For actually parsing HTML, use the format package instead. +func ReverseTextToHTML(input string) string { + return html.UnescapeString(strings.ReplaceAll(input, "
", "\n")) +} + func (content *MessageEventContent) EnsureHasHTML() { if len(content.FormattedBody) == 0 || content.Format != FormatHTML { content.FormattedBody = TextToHTML(content.Body) From a274ab89a5b7b8674ec10b8cb5aca5e72e54c4a9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 10 Oct 2023 19:48:49 +0300 Subject: [PATCH 0014/1647] Add redacts field for redaction event content --- event/message.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/event/message.go b/event/message.go index 99b54cd4..009709af 100644 --- a/event/message.go +++ b/event/message.go @@ -48,12 +48,12 @@ const ( // RedactionEventContent represents the content of a m.room.redaction message event. // -// The redacted event ID is still at the top level, but will move in a future room version. -// See https://github.com/matrix-org/matrix-doc/pull/2244 and https://github.com/matrix-org/matrix-doc/pull/2174 -// -// https://spec.matrix.org/v1.2/client-server-api/#mroomredaction +// https://spec.matrix.org/v1.8/client-server-api/#mroomredaction type RedactionEventContent struct { Reason string `json:"reason,omitempty"` + + // The event ID is here as of room v11. In old servers it may only be at the top level. + Redacts id.EventID `json:"redacts,omitempty"` } // ReactionEventContent represents the content of a m.reaction message event. From 69c80b473a10aa944944790963ff4b426f5fc983 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 25 Oct 2023 12:38:35 +0300 Subject: [PATCH 0015/1647] Send checkpoint if portal isn't found --- bridge/matrix.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bridge/matrix.go b/bridge/matrix.go index 91117b9c..ff814c6e 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -592,6 +592,8 @@ func (mx *MatrixHandler) HandleMessage(evt *event.Event) { portal := mx.bridge.Child.GetIPortal(evt.RoomID) if portal != nil { portal.ReceiveMatrixEvent(user, evt) + } else { + mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepRemote, fmt.Errorf("unknown room"), true, 0) } } @@ -609,6 +611,8 @@ func (mx *MatrixHandler) HandleReaction(evt *event.Event) { portal := mx.bridge.Child.GetIPortal(evt.RoomID) if portal != nil { portal.ReceiveMatrixEvent(user, evt) + } else { + mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepRemote, fmt.Errorf("unknown room"), true, 0) } } @@ -626,6 +630,8 @@ func (mx *MatrixHandler) HandleRedaction(evt *event.Event) { portal := mx.bridge.Child.GetIPortal(evt.RoomID) if portal != nil { portal.ReceiveMatrixEvent(user, evt) + } else { + mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepRemote, fmt.Errorf("unknown room"), true, 0) } } From e3628c3b9e7ba79e9dd445c2ae42a2ef775f4f52 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 3 Nov 2023 19:55:54 +0200 Subject: [PATCH 0016/1647] Update dependencies --- go.mod | 20 ++++++++++---------- go.sum | 44 +++++++++++++++++++++++--------------------- 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/go.mod b/go.mod index cc410f2f..a526c17a 100644 --- a/go.mod +++ b/go.mod @@ -7,16 +7,16 @@ require ( github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.17 - github.com/rs/zerolog v1.30.0 + github.com/rs/zerolog v1.31.0 github.com/stretchr/testify v1.8.4 - github.com/tidwall/gjson v1.16.0 + github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.5.6 - go.mau.fi/util v0.1.0 + github.com/yuin/goldmark v1.6.0 + go.mau.fi/util v0.2.0 go.mau.fi/zeroconfig v0.1.2 - golang.org/x/crypto v0.13.0 - golang.org/x/exp v0.0.0-20230905200255-921286631fa9 - golang.org/x/net v0.15.0 + golang.org/x/crypto v0.14.0 + golang.org/x/exp v0.0.0-20231006140011-7918f672742d + golang.org/x/net v0.17.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 maunium.net/go/maulogger/v2 v2.4.1 @@ -25,11 +25,11 @@ require ( require ( github.com/coreos/go-systemd/v22 v22.5.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mattn/go-colorable v0.1.12 // indirect - github.com/mattn/go-isatty v0.0.14 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.12.0 // indirect + golang.org/x/sys v0.13.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 3505fa1b..84049e41 100644 --- a/go.sum +++ b/go.sum @@ -10,45 +10,47 @@ github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWm github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= -github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= -github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.30.0 h1:SymVODrcRsaRaSInD9yQtKbtWqwsfoPcRff/oRXLj4c= -github.com/rs/zerolog v1.30.0/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w= +github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A= +github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.16.0 h1:SyXa+dsSPpUlcwEDuKuEBJEz5vzTvOea+9rjyYodQFg= -github.com/tidwall/gjson v1.16.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= +github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.5.6 h1:COmQAWTCcGetChm3Ig7G/t8AFAN00t+o8Mt4cf7JpwA= -github.com/yuin/goldmark v1.5.6/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mau.fi/util v0.1.0 h1:BwIFWIOEeO7lsiI2eWKFkWTfc5yQmoe+0FYyOFVyaoE= -go.mau.fi/util v0.1.0/go.mod h1:AxuJUMCxpzgJ5eV9JbPWKRH8aAJJidxetNdUj7qcb84= +github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68= +github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.mau.fi/util v0.2.0 h1:AMGBEdg9Ya/smb/09dljo9wBwKr432EpfjDWF7aFQg0= +go.mau.fi/util v0.2.0/go.mod h1:AxuJUMCxpzgJ5eV9JbPWKRH8aAJJidxetNdUj7qcb84= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= -golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= -golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= -golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From 149dd3ce1488500e26537bcd8ef661f5b6767849 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 12 Nov 2023 12:46:23 +0200 Subject: [PATCH 0017/1647] Update changelog --- CHANGELOG.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 75910504..ac6dc606 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,11 @@ +## v0.16.2 (unreleased) + +* *(event)* Added `Redacts` field to `RedactionEventContent` for room v11+. +* *(event)* Added `ReverseTextToHTML` which reverses the changes made by + `TextToHTML` (i.e. unescapes HTML characters and replaces `
` with `\n`). +* *(bridge)* Added global zerologger to ensure all logs go through the bridge + logger. + ## v0.16.1 (2023-09-16) * **Breaking change *(id)*** Updated user ID localpart encoding to not encode From 37d5493a566a5c8665321b6763ef89ca27e4abc1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 12 Nov 2023 12:47:20 +0200 Subject: [PATCH 0018/1647] Send encryption error replies in thread if applicable --- bridge/matrix.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridge/matrix.go b/bridge/matrix.go index ff814c6e..b025490d 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -371,8 +371,11 @@ func (mx *MatrixHandler) sendCryptoStatusError(ctx context.Context, evt *event.E if errors.Is(err, errNoCrypto) { update.Body = "🔒 This bridge has not been configured to support encryption" } + relatable, ok := evt.Content.Parsed.(event.Relatable) if editEvent != "" { update.SetEdit(editEvent) + } else if ok && relatable.OptionalGetRelatesTo().GetThreadParent() != "" { + update.GetRelatesTo().SetThread(relatable.OptionalGetRelatesTo().GetThreadParent(), evt.ID) } resp, sendErr := mx.bridge.Bot.SendMessageEvent(evt.RoomID, event.EventMessage, &update) if sendErr != nil { From da5a51a279a7ebc1be1f0619905f17929161d9cb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 12 Nov 2023 12:49:21 +0200 Subject: [PATCH 0019/1647] Remove incorrect err in log --- crypto/decryptmegolm.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 90ed3122..eaff136a 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -79,7 +79,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event trustLevel = id.TrustStateUnknownDevice } else if len(sess.ForwardingChains) == 0 || (len(sess.ForwardingChains) == 1 && sess.ForwardingChains[0] == sess.SenderKey.String()) { if device == nil { - log.Debug().Err(err). + log.Debug(). Str("session_sender_key", sess.SenderKey.String()). Msg("Couldn't resolve trust level of session: sent by unknown device") trustLevel = id.TrustStateUnknownDevice From 09daa655758e8a841861dd0595b05fafffeaa117 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 14 Nov 2023 16:45:05 +0200 Subject: [PATCH 0020/1647] Update dependencies --- go.mod | 16 ++++++++-------- go.sum | 32 ++++++++++++++++---------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/go.mod b/go.mod index a526c17a..83f131af 100644 --- a/go.mod +++ b/go.mod @@ -3,20 +3,20 @@ module maunium.net/go/mautrix go 1.20 require ( - github.com/gorilla/mux v1.8.0 - github.com/gorilla/websocket v1.5.0 + github.com/gorilla/mux v1.8.1 + github.com/gorilla/websocket v1.5.1 github.com/lib/pq v1.10.9 - github.com/mattn/go-sqlite3 v1.14.17 + github.com/mattn/go-sqlite3 v1.14.18 github.com/rs/zerolog v1.31.0 github.com/stretchr/testify v1.8.4 github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.6.0 - go.mau.fi/util v0.2.0 + go.mau.fi/util v0.2.1-0.20231114144345-a692409c548f go.mau.fi/zeroconfig v0.1.2 - golang.org/x/crypto v0.14.0 - golang.org/x/exp v0.0.0-20231006140011-7918f672742d - golang.org/x/net v0.17.0 + golang.org/x/crypto v0.15.0 + golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa + golang.org/x/net v0.18.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 maunium.net/go/maulogger/v2 v2.4.1 @@ -30,6 +30,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.13.0 // indirect + golang.org/x/sys v0.14.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 84049e41..2bee0b27 100644 --- a/go.sum +++ b/go.sum @@ -4,10 +4,10 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= -github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= -github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= +github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -15,8 +15,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= -github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI= +github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -36,21 +36,21 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68= github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mau.fi/util v0.2.0 h1:AMGBEdg9Ya/smb/09dljo9wBwKr432EpfjDWF7aFQg0= -go.mau.fi/util v0.2.0/go.mod h1:AxuJUMCxpzgJ5eV9JbPWKRH8aAJJidxetNdUj7qcb84= +go.mau.fi/util v0.2.1-0.20231114144345-a692409c548f h1:mJhRlbk3AStG2XfKe1MuO2rmikjUPLcO0pglA+GlWuA= +go.mau.fi/util v0.2.1-0.20231114144345-a692409c548f/go.mod h1:MjlzCQEMzJ+G8RsPawHzpLB8rwTo3aPIjG5FzBvQT/c= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= -golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= -golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= -golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= +golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= +golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ= +golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= +golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg= +golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= -golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= +golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From d8e28be543f79493456f18680d7169235e30640c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 16 Nov 2023 15:11:45 +0200 Subject: [PATCH 0021/1647] Update dependencies and downgrade gorilla packages due to regressions in new versions --- go.mod | 6 +++--- go.sum | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 83f131af..7fce5164 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module maunium.net/go/mautrix go 1.20 require ( - github.com/gorilla/mux v1.8.1 - github.com/gorilla/websocket v1.5.1 + github.com/gorilla/mux v1.8.0 + github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.18 github.com/rs/zerolog v1.31.0 @@ -12,7 +12,7 @@ require ( github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.6.0 - go.mau.fi/util v0.2.1-0.20231114144345-a692409c548f + go.mau.fi/util v0.2.1 go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.15.0 golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa diff --git a/go.sum b/go.sum index 2bee0b27..ce8aac82 100644 --- a/go.sum +++ b/go.sum @@ -4,10 +4,10 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= -github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= -github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= -github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= +github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -36,8 +36,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68= github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mau.fi/util v0.2.1-0.20231114144345-a692409c548f h1:mJhRlbk3AStG2XfKe1MuO2rmikjUPLcO0pglA+GlWuA= -go.mau.fi/util v0.2.1-0.20231114144345-a692409c548f/go.mod h1:MjlzCQEMzJ+G8RsPawHzpLB8rwTo3aPIjG5FzBvQT/c= +go.mau.fi/util v0.2.1 h1:eazulhFE/UmjOFtPrGg6zkF5YfAyiDzQb8ihLMbsPWw= +go.mau.fi/util v0.2.1/go.mod h1:MjlzCQEMzJ+G8RsPawHzpLB8rwTo3aPIjG5FzBvQT/c= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= From e606259d3dac633b95e7e3f8b364544cd7b6657f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 16 Nov 2023 15:14:55 +0200 Subject: [PATCH 0022/1647] Bump version to v0.16.2 --- CHANGELOG.md | 4 +++- version.go | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ac6dc606..53dd63e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,12 @@ -## v0.16.2 (unreleased) +## v0.16.2 (2023-11-16) * *(event)* Added `Redacts` field to `RedactionEventContent` for room v11+. * *(event)* Added `ReverseTextToHTML` which reverses the changes made by `TextToHTML` (i.e. unescapes HTML characters and replaces `
` with `\n`). * *(bridge)* Added global zerologger to ensure all logs go through the bridge logger. +* *(bridge)* Changed encryption error messages to be sent in a thread if the + message that failed to decrypt was in a thread. ## v0.16.1 (2023-09-16) diff --git a/version.go b/version.go index 6da6e5db..7b0c3dbe 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.16.1" +const Version = "v0.16.2" var GoModVersion = "" var Commit = "" From f47b2ce7bdcfd154a5759ecd206e5a0ac5c2ad31 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 17 Nov 2023 13:21:44 +0200 Subject: [PATCH 0023/1647] Add context for bridge command events --- bridge/commands/event.go | 2 ++ bridge/commands/processor.go | 15 ++++++--------- bridge/matrix.go | 14 +++++++++----- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/bridge/commands/event.go b/bridge/commands/event.go index fc88ac8e..0adc9237 100644 --- a/bridge/commands/event.go +++ b/bridge/commands/event.go @@ -7,6 +7,7 @@ package commands import ( + "context" "fmt" "strings" @@ -35,6 +36,7 @@ type Event struct { Args []string RawArgs string ReplyTo id.EventID + Ctx context.Context ZLog *zerolog.Logger // Deprecated: switch to ZLog Log maulogger.Logger diff --git a/bridge/commands/processor.go b/bridge/commands/processor.go index d64277b9..904f5c40 100644 --- a/bridge/commands/processor.go +++ b/bridge/commands/processor.go @@ -7,6 +7,7 @@ package commands import ( + "context" "runtime/debug" "strings" @@ -58,14 +59,13 @@ func (proc *Processor) AddHandler(handler Handler) { } // Handle handles messages to the bridge -func (proc *Processor) Handle(roomID id.RoomID, eventID id.EventID, user bridge.User, message string, replyTo id.EventID) { +func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user bridge.User, message string, replyTo id.EventID) { defer func() { err := recover() if err != nil { - proc.log.Error(). + zerolog.Ctx(ctx).Error(). Str(zerolog.ErrorStackFieldName, string(debug.Stack())). Interface(zerolog.ErrorFieldName, err). - Str("event_id", eventID.String()). Msg("Panic in Matrix command handler") } }() @@ -75,12 +75,8 @@ func (proc *Processor) Handle(roomID id.RoomID, eventID id.EventID, user bridge. } command := strings.ToLower(args[0]) rawArgs := strings.TrimLeft(strings.TrimPrefix(message, command), " ") - log := proc.log.With(). - Str("user_id", user.GetMXID().String()). - Str("event_id", eventID.String()). - Str("room_id", roomID.String()). - Str("mx_command", command). - Logger() + log := zerolog.Ctx(ctx).With().Str("mx_command", command).Logger() + ctx = log.WithContext(ctx) ce := &Event{ Bot: proc.bridge.Bot, Bridge: proc.bridge, @@ -93,6 +89,7 @@ func (proc *Processor) Handle(roomID id.RoomID, eventID id.EventID, user bridge. Args: args[1:], RawArgs: rawArgs, ReplyTo: replyTo, + Ctx: ctx, ZLog: &log, Log: maulogadapt.ZeroAsMau(&log), } diff --git a/bridge/matrix.go b/bridge/matrix.go index b025490d..a00b34c0 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -25,7 +25,7 @@ import ( ) type CommandProcessor interface { - Handle(roomID id.RoomID, eventID id.EventID, user User, message string, replyTo id.EventID) + Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user User, message string, replyTo id.EventID) } type MatrixHandler struct { @@ -548,12 +548,16 @@ func (mx *MatrixHandler) waitLongerForSession(ctx context.Context, evt *event.Ev func (mx *MatrixHandler) HandleMessage(evt *event.Event) { defer mx.TrackEventDuration(evt.Type)() + log := mx.log.With(). + Str("event_id", evt.ID.String()). + Str("room_id", evt.RoomID.String()). + Str("sender", evt.Sender.String()). + Logger() + ctx := log.WithContext(context.Background()) if mx.shouldIgnoreEvent(evt) { return } else if !evt.Mautrix.WasEncrypted && mx.bridge.Config.Bridge.GetEncryptionConfig().Require { - log := mx.log.With().Str("event_id", evt.ID.String()).Logger() log.Warn().Msg("Dropping unencrypted event") - ctx := log.WithContext(context.Background()) mx.sendCryptoStatusError(ctx, evt, "", errMessageNotEncrypted, 0, true) return } @@ -572,7 +576,7 @@ func (mx *MatrixHandler) HandleMessage(evt *event.Event) { content.Body = strings.TrimLeft(strings.TrimPrefix(content.Body, commandPrefix), " ") } if hasCommandPrefix || evt.RoomID == user.GetManagementRoomID() { - go mx.bridge.CommandProcessor.Handle(evt.RoomID, evt.ID, user, content.Body, content.RelatesTo.GetReplyTo()) + go mx.bridge.CommandProcessor.Handle(ctx, evt.RoomID, evt.ID, user, content.Body, content.RelatesTo.GetReplyTo()) go mx.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepCommand, 0) if mx.bridge.Config.Bridge.EnableMessageStatusEvents() { statusEvent := &event.BeeperMessageStatusEventContent{ @@ -585,7 +589,7 @@ func (mx *MatrixHandler) HandleMessage(evt *event.Event) { } _, sendErr := mx.bridge.Bot.SendMessageEvent(evt.RoomID, event.BeeperMessageStatus, statusEvent) if sendErr != nil { - mx.log.Warn().Str("event_id", evt.ID.String()).Err(sendErr).Msg("Failed to send message status event for command") + log.Warn().Err(sendErr).Msg("Failed to send message status event for command") } } return From 4784d6d09fe273a6cf943c57de29c51ead4bd97b Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 17 Nov 2023 09:01:33 -0700 Subject: [PATCH 0024/1647] MembershipHandlingPortal: add full event to handlers Signed-off-by: Sumner Evans --- bridge/bridge.go | 6 +++--- bridge/matrix.go | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/bridge/bridge.go b/bridge/bridge.go index 8a62c6b8..291d6be9 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -65,9 +65,9 @@ type Portal interface { type MembershipHandlingPortal interface { Portal - HandleMatrixLeave(sender User) - HandleMatrixKick(sender User, ghost Ghost) - HandleMatrixInvite(sender User, ghost Ghost) + HandleMatrixLeave(sender User, evt *event.Event) + HandleMatrixKick(sender User, ghost Ghost, evt *event.Event) + HandleMatrixInvite(sender User, ghost Ghost, evt *event.Event) } type ReadReceiptHandlingPortal interface { diff --git a/bridge/matrix.go b/bridge/matrix.go index a00b34c0..3196af60 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -290,12 +290,12 @@ func (mx *MatrixHandler) HandleMembership(evt *event.Event) { } } if isSelf { - mhp.HandleMatrixLeave(user) + mhp.HandleMatrixLeave(user, evt) } else if ghost != nil { - mhp.HandleMatrixKick(user, ghost) + mhp.HandleMatrixKick(user, ghost, evt) } } else if content.Membership == event.MembershipInvite && !isSelf && ghost != nil { - mhp.HandleMatrixInvite(user, ghost) + mhp.HandleMatrixInvite(user, ghost, evt) } // TODO kicking/inviting non-ghost users users } From 5fae50102f42a64b0063740d1f8504e1d780cef7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 3 Dec 2023 00:12:45 +0200 Subject: [PATCH 0025/1647] Don't drop unknown events in default syncer --- sync.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sync.go b/sync.go index 5979a17b..f05e9b5f 100644 --- a/sync.go +++ b/sync.go @@ -134,7 +134,10 @@ func NewDefaultSyncer() *DefaultSyncer { globalListeners: []EventHandler{}, ParseEventContent: true, ParseErrorHandler: func(evt *event.Event, err error) bool { - return false + // By default, drop known events that can't be parsed, but let unknown events through + return errors.Is(err, event.ErrUnsupportedContentType) || + // Also allow events that had their content already parsed by some other function + errors.Is(err, event.ErrContentAlreadyParsed) }, } } From 685d5f71da580e63c0f524e103abe83452d54d11 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 3 Dec 2023 00:18:49 +0200 Subject: [PATCH 0026/1647] Update changelog --- CHANGELOG.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 53dd63e9..1961bd79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,12 @@ +## v0.17.0 (unreleased) + +* **Breaking change *(bridge)*** Added raw event to portal membership handling + functions. +* *(bridge)* Added context parameter for bridge command events. +* *(client)* Changed default syncer to not drop unknown events. + * The syncer will still drop known events if parsing the content fails. + * The behavior can be changed by changing the `ParseErrorHandler` function. + ## v0.16.2 (2023-11-16) * *(event)* Added `Redacts` field to `RedactionEventContent` for room v11+. From 1e0731dc73a4d8c9b9ff512979807c504668f35b Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 12 Dec 2023 12:04:56 -0700 Subject: [PATCH 0027/1647] versions: add constants for v1.8 and v1.9 Signed-off-by: Sumner Evans --- versions.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/versions.go b/versions.go index 2f898b73..d3dd3c67 100644 --- a/versions.go +++ b/versions.go @@ -93,6 +93,8 @@ var ( SpecV15 = MustParseSpecVersion("v1.5") SpecV16 = MustParseSpecVersion("v1.6") SpecV17 = MustParseSpecVersion("v1.7") + SpecV18 = MustParseSpecVersion("v1.8") + SpecV19 = MustParseSpecVersion("v1.9") ) func (svf SpecVersionFormat) String() string { From 9c109c97a6889d6202555d09878af7221279998e Mon Sep 17 00:00:00 2001 From: Lukas Gallandi Date: Fri, 15 Dec 2023 13:25:36 +0200 Subject: [PATCH 0028/1647] Add pure Go implementation of libolm --- crypto/goolm/.gitignore | 48 ++ crypto/goolm/LICENSE | 9 + crypto/goolm/README.md | 31 + crypto/goolm/account/account.go | 519 ++++++++++++++ crypto/goolm/account/account_test.go | 675 ++++++++++++++++++ crypto/goolm/account/account_test_data.go | 267 +++++++ crypto/goolm/base64.go | 41 ++ crypto/goolm/base64_test.go | 47 ++ crypto/goolm/cipher/aesSha256.go | 96 +++ crypto/goolm/cipher/aesSha256_test.go | 83 +++ crypto/goolm/cipher/main.go | 17 + crypto/goolm/cipher/pickle.go | 57 ++ crypto/goolm/cipher/pickle_test.go | 31 + crypto/goolm/crypto/aesCBC.go | 75 ++ crypto/goolm/crypto/aesCBC_test.go | 70 ++ crypto/goolm/crypto/curve25519.go | 184 +++++ crypto/goolm/crypto/curve25519_test.go | 185 +++++ crypto/goolm/crypto/ed25519.go | 180 +++++ crypto/goolm/crypto/ed25519_test.go | 138 ++++ crypto/goolm/crypto/hmac.go | 29 + crypto/goolm/crypto/hmac_test.go | 113 +++ crypto/goolm/crypto/main.go | 2 + crypto/goolm/crypto/oneTimeKey.go | 95 +++ crypto/goolm/errors.go | 36 + crypto/goolm/go.mod | 9 + crypto/goolm/go.sum | 20 + crypto/goolm/libolmPickle/pickle.go | 41 ++ crypto/goolm/libolmPickle/pickle_test.go | 96 +++ crypto/goolm/libolmPickle/unpickle.go | 52 ++ crypto/goolm/libolmPickle/unpickle_test.go | 104 +++ crypto/goolm/libolmVersion.md | 3 + crypto/goolm/main.go | 10 + crypto/goolm/megolm/megolm.go | 234 ++++++ crypto/goolm/megolm/megolm_test.go | 139 ++++ crypto/goolm/message/decoder.go | 71 ++ crypto/goolm/message/decoder_test.go | 82 +++ crypto/goolm/message/groupMessage.go | 144 ++++ crypto/goolm/message/groupMessage_test.go | 49 ++ crypto/goolm/message/message.go | 129 ++++ crypto/goolm/message/message_test.go | 52 ++ crypto/goolm/message/preKeyMessage.go | 120 ++++ crypto/goolm/message/preKeyMessage_test.go | 62 ++ crypto/goolm/message/sessionExport.go | 44 ++ crypto/goolm/message/sessionSharing.go | 50 ++ crypto/goolm/olm/chain.go | 257 +++++++ crypto/goolm/olm/olm.go | 432 +++++++++++ crypto/goolm/olm/olm_test.go | 185 +++++ crypto/goolm/olm/skippedMessage.go | 54 ++ crypto/goolm/pk/decryption.go | 162 +++++ crypto/goolm/pk/encryption.go | 46 ++ crypto/goolm/pk/pk_test.go | 132 ++++ crypto/goolm/pk/signing.go | 44 ++ crypto/goolm/sas/main.go | 76 ++ crypto/goolm/sas/main_test.go | 111 +++ crypto/goolm/session/main.go | 2 + crypto/goolm/session/megolmInboundSession.go | 273 +++++++ crypto/goolm/session/megolmOutboundSession.go | 168 +++++ crypto/goolm/session/melgomSession_test.go | 283 ++++++++ crypto/goolm/session/olmSession.go | 475 ++++++++++++ crypto/goolm/session/olmSession_test.go | 174 +++++ crypto/goolm/utilities/main.go | 25 + crypto/goolm/utilities/main_test.go | 14 + crypto/goolm/utilities/pickle.go | 60 ++ 63 files changed, 7512 insertions(+) create mode 100644 crypto/goolm/.gitignore create mode 100644 crypto/goolm/LICENSE create mode 100644 crypto/goolm/README.md create mode 100644 crypto/goolm/account/account.go create mode 100644 crypto/goolm/account/account_test.go create mode 100644 crypto/goolm/account/account_test_data.go create mode 100644 crypto/goolm/base64.go create mode 100644 crypto/goolm/base64_test.go create mode 100644 crypto/goolm/cipher/aesSha256.go create mode 100644 crypto/goolm/cipher/aesSha256_test.go create mode 100644 crypto/goolm/cipher/main.go create mode 100644 crypto/goolm/cipher/pickle.go create mode 100644 crypto/goolm/cipher/pickle_test.go create mode 100644 crypto/goolm/crypto/aesCBC.go create mode 100644 crypto/goolm/crypto/aesCBC_test.go create mode 100644 crypto/goolm/crypto/curve25519.go create mode 100644 crypto/goolm/crypto/curve25519_test.go create mode 100644 crypto/goolm/crypto/ed25519.go create mode 100644 crypto/goolm/crypto/ed25519_test.go create mode 100644 crypto/goolm/crypto/hmac.go create mode 100644 crypto/goolm/crypto/hmac_test.go create mode 100644 crypto/goolm/crypto/main.go create mode 100644 crypto/goolm/crypto/oneTimeKey.go create mode 100644 crypto/goolm/errors.go create mode 100644 crypto/goolm/go.mod create mode 100644 crypto/goolm/go.sum create mode 100644 crypto/goolm/libolmPickle/pickle.go create mode 100644 crypto/goolm/libolmPickle/pickle_test.go create mode 100644 crypto/goolm/libolmPickle/unpickle.go create mode 100644 crypto/goolm/libolmPickle/unpickle_test.go create mode 100644 crypto/goolm/libolmVersion.md create mode 100644 crypto/goolm/main.go create mode 100644 crypto/goolm/megolm/megolm.go create mode 100644 crypto/goolm/megolm/megolm_test.go create mode 100644 crypto/goolm/message/decoder.go create mode 100644 crypto/goolm/message/decoder_test.go create mode 100644 crypto/goolm/message/groupMessage.go create mode 100644 crypto/goolm/message/groupMessage_test.go create mode 100644 crypto/goolm/message/message.go create mode 100644 crypto/goolm/message/message_test.go create mode 100644 crypto/goolm/message/preKeyMessage.go create mode 100644 crypto/goolm/message/preKeyMessage_test.go create mode 100644 crypto/goolm/message/sessionExport.go create mode 100644 crypto/goolm/message/sessionSharing.go create mode 100644 crypto/goolm/olm/chain.go create mode 100644 crypto/goolm/olm/olm.go create mode 100644 crypto/goolm/olm/olm_test.go create mode 100644 crypto/goolm/olm/skippedMessage.go create mode 100644 crypto/goolm/pk/decryption.go create mode 100644 crypto/goolm/pk/encryption.go create mode 100644 crypto/goolm/pk/pk_test.go create mode 100644 crypto/goolm/pk/signing.go create mode 100644 crypto/goolm/sas/main.go create mode 100644 crypto/goolm/sas/main_test.go create mode 100644 crypto/goolm/session/main.go create mode 100644 crypto/goolm/session/megolmInboundSession.go create mode 100644 crypto/goolm/session/megolmOutboundSession.go create mode 100644 crypto/goolm/session/melgomSession_test.go create mode 100644 crypto/goolm/session/olmSession.go create mode 100644 crypto/goolm/session/olmSession_test.go create mode 100644 crypto/goolm/utilities/main.go create mode 100644 crypto/goolm/utilities/main_test.go create mode 100644 crypto/goolm/utilities/pickle.go diff --git a/crypto/goolm/.gitignore b/crypto/goolm/.gitignore new file mode 100644 index 00000000..8cb9a443 --- /dev/null +++ b/crypto/goolm/.gitignore @@ -0,0 +1,48 @@ +# ---> Go +# If you prefer the allow list template instead of the deny list, see community template: +# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore +# +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work + +# ---> Go.AllowList +# Allowlisting gitignore template for GO projects prevents us +# from adding various unwanted local files, such as generated +# files, developer configurations or IDE-specific files etc. +# +# Recommended: Go.AllowList.gitignore + +# Ignore everything +* + +# But not these files... +!/.gitignore + +!*.go +!go.sum +!go.mod + +!README.md +!LICENSE +!libolmVersion.md + +# !Makefile + +# ...even if they are in subdirectories +!*/ diff --git a/crypto/goolm/LICENSE b/crypto/goolm/LICENSE new file mode 100644 index 00000000..d81feaab --- /dev/null +++ b/crypto/goolm/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright (c) 2022 Lukas Gallandi + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/crypto/goolm/README.md b/crypto/goolm/README.md new file mode 100644 index 00000000..a34b27ca --- /dev/null +++ b/crypto/goolm/README.md @@ -0,0 +1,31 @@ +# goolm + +[![Please don't upload to GitHub](https://nogithub.codeberg.page/badge.svg)](https://nogithub.codeberg.page) +[![GoDoc](https://godoc.org/codeberg.org/DerLukas/goolm?status.svg)](https://godoc.org/codeberg.org/DerLukas/goolm) + +### A Go implementation of Olm and Megolm + +goolm is a pure Go implementation of libolm. Libolm is a cryptographic library used for end-to-end encryption in Matrix and wirtten in C++. +With goolm there is no need to use cgo when building Matrix clients in go. + +See the GoDoc for usage. + +This package is written to be a easily used in github.com/mautrix/go/crypto/olm. + +PR's are always welcome. + +# Features + +* Test files for most methods and functions adapted from libolm + +## Supported +* [Olm](https://matrix-org.github.io/vodozemac/vodozemac/olm/index.html) +* Pickle structs with encryption using JSON marshalling +* Pickle structs with encryption using the libolm format +* [Megolm](https://matrix-org.github.io/vodozemac/vodozemac/megolm/index.html) +* Inbound and outbound group sessions +* [SAS](https://matrix.org/docs/guides/implementing-more-advanced-e-2-ee-features-such-as-cross-signing) support + +# License + +MIT licensed. See the LICENSE file for details. diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go new file mode 100644 index 00000000..e7e3beba --- /dev/null +++ b/crypto/goolm/account/account.go @@ -0,0 +1,519 @@ +// account packages an account which stores the identity, one time keys and fallback keys. +package account + +import ( + "encoding/json" + "io" + + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/cipher" + "codeberg.org/DerLukas/goolm/crypto" + libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" + "codeberg.org/DerLukas/goolm/session" + "codeberg.org/DerLukas/goolm/utilities" + "github.com/pkg/errors" + "maunium.net/go/mautrix/id" +) + +const ( + accountPickleVersionJSON byte = 1 + accountPickleVersionLibOLM uint32 = 4 + MaxOneTimeKeys int = 100 //maximum number of stored one time keys per Account +) + +// Account stores an account for end to end encrypted messaging via the olm protocol. +// An Account can not be used to en/decrypt messages. However it can be used to contruct new olm sessions, which in turn do the en/decryption. +// There is no tracking of sessions in an account. +type Account struct { + IdKeys struct { + Ed25519 crypto.Ed25519KeyPair `json:"ed25519,omitempty"` + Curve25519 crypto.Curve25519KeyPair `json:"curve25519,omitempty"` + } `json:"identityKeys"` + OTKeys []crypto.OneTimeKey `json:"oneTimeKeys"` + CurrentFallbackKey crypto.OneTimeKey `json:"currentFallbackKey,omitempty"` + PrevFallbackKey crypto.OneTimeKey `json:"prevFallbackKey,omitempty"` + NextOneTimeKeyID uint32 `json:"nextOneTimeKeyID,omitempty"` + NumFallbackKeys uint8 `json:"numberFallbackKeys"` +} + +// AccountFromJSONPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key. +func AccountFromJSONPickled(pickled, key []byte) (*Account, error) { + if len(pickled) == 0 { + return nil, errors.Wrap(goolm.ErrEmptyInput, "accountFromPickled") + } + a := &Account{} + err := a.UnpickleAsJSON(pickled, key) + if err != nil { + return nil, err + } + return a, nil +} + +// AccountFromPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key. +func AccountFromPickled(pickled, key []byte) (*Account, error) { + if len(pickled) == 0 { + return nil, errors.Wrap(goolm.ErrEmptyInput, "accountFromPickled") + } + a := &Account{} + err := a.Unpickle(pickled, key) + if err != nil { + return nil, err + } + return a, nil +} + +// NewAccount creates a new Account. If reader is nil, crypto/rand is used for the key creation. +func NewAccount(reader io.Reader) (*Account, error) { + a := &Account{} + kPEd25519, err := crypto.Ed25519GenerateKey(reader) + if err != nil { + return nil, err + } + a.IdKeys.Ed25519 = kPEd25519 + kPCurve25519, err := crypto.Curve25519GenerateKey(reader) + if err != nil { + return nil, err + } + a.IdKeys.Curve25519 = kPCurve25519 + return a, nil +} + +// PickleAsJSON returns an Account as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. +func (a Account) PickleAsJSON(key []byte) ([]byte, error) { + return utilities.PickleAsJSON(a, accountPickleVersionJSON, key) +} + +// UnpickleAsJSON updates an Account by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. +func (a *Account) UnpickleAsJSON(pickled, key []byte) error { + return utilities.UnpickleAsJSON(a, pickled, key, accountPickleVersionJSON) +} + +// IdentityKeysJSON returns the public parts of the identity keys for the Account in a JSON string. +func (a Account) IdentityKeysJSON() ([]byte, error) { + res := struct { + Ed25519 string `json:"ed25519"` + Curve25519 string `json:"curve25519"` + }{} + ed25519, curve25519 := a.IdentityKeys() + res.Ed25519 = string(ed25519) + res.Curve25519 = string(curve25519) + return json.Marshal(res) +} + +// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity keys for the Account. +func (a Account) IdentityKeys() (id.Ed25519, id.Curve25519) { + ed25519 := id.Ed25519(goolm.Base64Encode(a.IdKeys.Ed25519.PublicKey)) + curve25519 := id.Curve25519(goolm.Base64Encode(a.IdKeys.Curve25519.PublicKey)) + return ed25519, curve25519 +} + +// Sign returns the signature of a message using the Ed25519 key for this Account. +func (a Account) Sign(message []byte) ([]byte, error) { + if len(message) == 0 { + return nil, errors.Wrap(goolm.ErrEmptyInput, "sign") + } + return goolm.Base64Encode(a.IdKeys.Ed25519.Sign(message)), nil +} + +// OneTimeKeys returns the public parts of the unpublished one time keys of the Account. +// +// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key. +func (a Account) OneTimeKeys() map[string]id.Curve25519 { + oneTimeKeys := make(map[string]id.Curve25519) + for _, curKey := range a.OTKeys { + if !curKey.Published { + oneTimeKeys[curKey.KeyIDEncoded()] = id.Curve25519(curKey.PublicKeyEncoded()) + } + } + return oneTimeKeys +} + +//OneTimeKeysJSON returns the public parts of the unpublished one time keys of the Account as a JSON string. +// +//The returned JSON is of format: +/* + { + Curve25519: { + "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo", + "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU" + } + } +*/ +func (a Account) OneTimeKeysJSON() ([]byte, error) { + res := make(map[string]map[string]id.Curve25519) + otKeys := a.OneTimeKeys() + res["Curve25519"] = otKeys + return json.Marshal(res) +} + +// MarkKeysAsPublished marks the current set of one time keys and the fallback key as being +// published. +func (a *Account) MarkKeysAsPublished() { + for keyIndex := range a.OTKeys { + if !a.OTKeys[keyIndex].Published { + a.OTKeys[keyIndex].Published = true + } + } + a.CurrentFallbackKey.Published = true +} + +// GenOneTimeKeys generates a number of new one time keys. If the total number +// of keys stored by this Account exceeds MaxOneTimeKeys then the older +// keys are discarded. If reader is nil, crypto/rand is used for the key creation. +func (a *Account) GenOneTimeKeys(reader io.Reader, num uint) error { + for i := uint(0); i < num; i++ { + key := crypto.OneTimeKey{ + Published: false, + ID: a.NextOneTimeKeyID, + } + newKP, err := crypto.Curve25519GenerateKey(reader) + if err != nil { + return err + } + key.Key = newKP + a.NextOneTimeKeyID++ + a.OTKeys = append([]crypto.OneTimeKey{key}, a.OTKeys...) + } + if len(a.OTKeys) > MaxOneTimeKeys { + a.OTKeys = a.OTKeys[:MaxOneTimeKeys] + } + return nil +} + +// NewOutboundSession creates a new outbound session to a +// given curve25519 identity Key and one time key. +func (a Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*session.OlmSession, error) { + if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { + return nil, errors.Wrap(goolm.ErrEmptyInput, "outbound session") + } + theirIdentityKeyDecoded, err := goolm.Base64Decode([]byte(theirIdentityKey)) + if err != nil { + return nil, err + } + theirOneTimeKeyDecoded, err := goolm.Base64Decode([]byte(theirOneTimeKey)) + if err != nil { + return nil, err + } + s, err := session.NewOutboundOlmSession(a.IdKeys.Curve25519, theirIdentityKeyDecoded, theirOneTimeKeyDecoded) + if err != nil { + return nil, err + } + return s, nil +} + +// NewInboundSession creates a new inbound session from an incoming PRE_KEY message. +func (a Account) NewInboundSession(theirIdentityKey *id.Curve25519, oneTimeKeyMsg []byte) (*session.OlmSession, error) { + if len(oneTimeKeyMsg) == 0 { + return nil, errors.Wrap(goolm.ErrEmptyInput, "inbound session") + } + var theirIdentityKeyDecoded *crypto.Curve25519PublicKey + var err error + if theirIdentityKey != nil { + theirIdentityKeyDecodedByte, err := goolm.Base64Decode([]byte(*theirIdentityKey)) + if err != nil { + return nil, err + } + theirIdentityKeyCurve := crypto.Curve25519PublicKey(theirIdentityKeyDecodedByte) + theirIdentityKeyDecoded = &theirIdentityKeyCurve + } + + s, err := session.NewInboundOlmSession(theirIdentityKeyDecoded, oneTimeKeyMsg, a.searchOTKForOur, a.IdKeys.Curve25519) + if err != nil { + return nil, err + } + return s, nil +} + +func (a Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneTimeKey { + for curIndex := range a.OTKeys { + if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) { + return &a.OTKeys[curIndex] + } + } + if a.NumFallbackKeys >= 1 && a.CurrentFallbackKey.Key.PublicKey.Equal(toFind) { + return &a.CurrentFallbackKey + } + if a.NumFallbackKeys >= 2 && a.PrevFallbackKey.Key.PublicKey.Equal(toFind) { + return &a.PrevFallbackKey + } + return nil +} + +// RemoveOneTimeKeys removes the one time key in this Account which matches the one time key in the session s. +func (a *Account) RemoveOneTimeKeys(s *session.OlmSession) { + toFind := s.BobOneTimeKey + for curIndex := range a.OTKeys { + if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) { + //Remove and return + a.OTKeys[curIndex] = a.OTKeys[len(a.OTKeys)-1] + a.OTKeys = a.OTKeys[:len(a.OTKeys)-1] + return + } + } + //if the key is a fallback or prevFallback, don't remove it +} + +// GenFallbackKey generates a new fallback key. The old fallback key is stored in a.PrevFallbackKey overwriting any previous PrevFallbackKey. If reader is nil, crypto/rand is used for the key creation. +func (a *Account) GenFallbackKey(reader io.Reader) error { + a.PrevFallbackKey = a.CurrentFallbackKey + key := crypto.OneTimeKey{ + Published: false, + ID: a.NextOneTimeKeyID, + } + newKP, err := crypto.Curve25519GenerateKey(reader) + if err != nil { + return err + } + key.Key = newKP + a.NextOneTimeKeyID++ + if a.NumFallbackKeys < 2 { + a.NumFallbackKeys++ + } + a.CurrentFallbackKey = key + return nil +} + +// FallbackKey returns the public part of the current fallback key of the Account. +// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key. +func (a Account) FallbackKey() map[string]id.Curve25519 { + keys := make(map[string]id.Curve25519) + if a.NumFallbackKeys >= 1 { + keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded()) + } + return keys +} + +//FallbackKeyJSON returns the public part of the current fallback key of the Account as a JSON string. +// +//The returned JSON is of format: +/* + { + curve25519: { + "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo" + } + } +*/ +func (a Account) FallbackKeyJSON() ([]byte, error) { + res := make(map[string]map[string]id.Curve25519) + fbk := a.FallbackKey() + res["curve25519"] = fbk + return json.Marshal(res) +} + +// FallbackKeyUnpublished returns the public part of the current fallback key of the Account only if it is unpublished. +// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key. +func (a Account) FallbackKeyUnpublished() map[string]id.Curve25519 { + keys := make(map[string]id.Curve25519) + if a.NumFallbackKeys >= 1 && !a.CurrentFallbackKey.Published { + keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded()) + } + return keys +} + +//FallbackKeyUnpublishedJSON returns the public part of the current fallback key, only if it is unpublished, of the Account as a JSON string. +// +//The returned JSON is of format: +/* + { + curve25519: { + "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo" + } + } +*/ +func (a Account) FallbackKeyUnpublishedJSON() ([]byte, error) { + res := make(map[string]map[string]id.Curve25519) + fbk := a.FallbackKeyUnpublished() + res["curve25519"] = fbk + return json.Marshal(res) +} + +// ForgetOldFallbackKey resets the previous fallback key in the account. +func (a *Account) ForgetOldFallbackKey() { + if a.NumFallbackKeys >= 2 { + a.NumFallbackKeys = 1 + a.PrevFallbackKey = crypto.OneTimeKey{} + } +} + +// Unpickle decodes the base64 encoded string and decrypts the result with the key. +// The decrypted value is then passed to UnpickleLibOlm. +func (a *Account) Unpickle(pickled, key []byte) error { + decrypted, err := cipher.Unpickle(key, pickled) + if err != nil { + return err + } + _, err = a.UnpickleLibOlm(decrypted) + return err +} + +// UnpickleLibOlm decodes the unencryted value and populates the Account accordingly. It returns the number of bytes read. +func (a *Account) UnpickleLibOlm(value []byte) (int, error) { + //First 4 bytes are the accountPickleVersion + pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) + if err != nil { + return 0, err + } + switch pickledVersion { + case accountPickleVersionLibOLM, 3, 2: + default: + return 0, errors.Wrap(goolm.ErrBadVersion, "unpickle account") + } + //read ed25519 key pair + readBytes, err := a.IdKeys.Ed25519.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + //read curve25519 key pair + readBytes, err = a.IdKeys.Curve25519.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + //Read number of onetimeKeys + numberOTKeys, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + //Read i one time keys + a.OTKeys = make([]crypto.OneTimeKey, numberOTKeys) + for i := uint32(0); i < numberOTKeys; i++ { + readBytes, err := a.OTKeys[i].UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + } + if pickledVersion <= 2 { + // version 2 did not have fallback keys + a.NumFallbackKeys = 0 + } else if pickledVersion == 3 { + // version 3 used the published flag to indicate how many fallback keys + // were present (we'll have to assume that the keys were published) + readBytes, err := a.CurrentFallbackKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = a.PrevFallbackKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + if a.CurrentFallbackKey.Published { + if a.PrevFallbackKey.Published { + a.NumFallbackKeys = 2 + } else { + a.NumFallbackKeys = 1 + } + } else { + a.NumFallbackKeys = 0 + } + } else { + //Read number of fallback keys + numFallbackKeys, readBytes, err := libolmpickle.UnpickleUInt8(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + a.NumFallbackKeys = numFallbackKeys + if a.NumFallbackKeys >= 1 { + readBytes, err := a.CurrentFallbackKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + if a.NumFallbackKeys >= 2 { + readBytes, err := a.PrevFallbackKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + } + } + } + //Read next onetime key id + nextOTKeyID, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + a.NextOneTimeKeyID = nextOTKeyID + return curPos, nil +} + +// Pickle returns a base64 encoded and with key encrypted pickled account using PickleLibOlm(). +func (a Account) Pickle(key []byte) ([]byte, error) { + pickeledBytes := make([]byte, a.PickleLen()) + written, err := a.PickleLibOlm(pickeledBytes) + if err != nil { + return nil, err + } + if written != len(pickeledBytes) { + return nil, errors.New("number of written bytes not correct") + } + encrypted, err := cipher.Pickle(key, pickeledBytes) + if err != nil { + return nil, err + } + return encrypted, nil +} + +// PickleLibOlm encodes the Account into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (a Account) PickleLibOlm(target []byte) (int, error) { + if len(target) < a.PickleLen() { + return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle account") + } + written := libolmpickle.PickleUInt32(accountPickleVersionLibOLM, target) + writtenEdKey, err := a.IdKeys.Ed25519.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle account") + } + written += writtenEdKey + writtenCurveKey, err := a.IdKeys.Curve25519.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle account") + } + written += writtenCurveKey + written += libolmpickle.PickleUInt32(uint32(len(a.OTKeys)), target[written:]) + for _, curOTKey := range a.OTKeys { + writtenOT, err := curOTKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle account") + } + written += writtenOT + } + written += libolmpickle.PickleUInt8(a.NumFallbackKeys, target[written:]) + if a.NumFallbackKeys >= 1 { + writtenOT, err := a.CurrentFallbackKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle account") + } + written += writtenOT + + if a.NumFallbackKeys >= 2 { + writtenOT, err := a.PrevFallbackKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle account") + } + written += writtenOT + } + } + written += libolmpickle.PickleUInt32(a.NextOneTimeKeyID, target[written:]) + return written, nil +} + +// PickleLen returns the number of bytes the pickled Account will have. +func (a Account) PickleLen() int { + length := libolmpickle.PickleUInt32Len(accountPickleVersionLibOLM) + length += a.IdKeys.Ed25519.PickleLen() + length += a.IdKeys.Curve25519.PickleLen() + length += libolmpickle.PickleUInt32Len(uint32(len(a.OTKeys))) + length += (len(a.OTKeys) * (&crypto.OneTimeKey{}).PickleLen()) + length += libolmpickle.PickleUInt8Len(a.NumFallbackKeys) + length += (int(a.NumFallbackKeys) * (&crypto.OneTimeKey{}).PickleLen()) + length += libolmpickle.PickleUInt32Len(a.NextOneTimeKeyID) + return length +} diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go new file mode 100644 index 00000000..1f2b9546 --- /dev/null +++ b/crypto/goolm/account/account_test.go @@ -0,0 +1,675 @@ +package account + +import ( + "bytes" + "errors" + "testing" + + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/utilities" + "maunium.net/go/mautrix/id" +) + +type mockRandom struct { + tag byte + current byte +} + +func (m *mockRandom) get(length int) []byte { + res := make([]byte, length) + baseIndex := 0 + for length > 32 { + res[baseIndex] = m.tag + for i := 1; i < 32; i++ { + res[baseIndex+i] = m.current + } + length -= 32 + baseIndex += 32 + m.current++ + } + if length != 0 { + res[baseIndex] = m.tag + for i := 1; i < length-1; i++ { + res[baseIndex+i] = m.current + } + m.current++ + } + return res +} + +func (m *mockRandom) Read(target []byte) (int, error) { + res := m.get(len(target)) + return copy(target, res), nil +} + +func TestAccount(t *testing.T) { + firstAccount, err := NewAccount(nil) + if err != nil { + t.Fatal(err) + } + err = firstAccount.GenFallbackKey(nil) + if err != nil { + t.Fatal(err) + } + err = firstAccount.GenOneTimeKeys(nil, 2) + if err != nil { + t.Fatal(err) + } + encryptionKey := []byte("testkey") + //now pickle account in JSON format + pickled, err := firstAccount.PickleAsJSON(encryptionKey) + if err != nil { + t.Fatal(err) + } + //now unpickle into new Account + unpickledAccount, err := AccountFromJSONPickled(pickled, encryptionKey) + if err != nil { + t.Fatal(err) + } + //check if accounts are the same + if firstAccount.NextOneTimeKeyID != unpickledAccount.NextOneTimeKeyID { + t.Fatal("NextOneTimeKeyID unequal") + } + if !firstAccount.CurrentFallbackKey.Equal(unpickledAccount.CurrentFallbackKey) { + t.Fatal("CurrentFallbackKey unequal") + } + if !firstAccount.PrevFallbackKey.Equal(unpickledAccount.PrevFallbackKey) { + t.Fatal("PrevFallbackKey unequal") + } + if len(firstAccount.OTKeys) != len(unpickledAccount.OTKeys) { + t.Fatal("OneTimeKeysunequal") + } + for i := range firstAccount.OTKeys { + if !firstAccount.OTKeys[i].Equal(unpickledAccount.OTKeys[i]) { + t.Fatalf("OneTimeKeys %d unequal", i) + } + } + if !firstAccount.IdKeys.Curve25519.PrivateKey.Equal(unpickledAccount.IdKeys.Curve25519.PrivateKey) { + t.Fatal("IdentityKeys Curve25519 private unequal") + } + if !firstAccount.IdKeys.Curve25519.PublicKey.Equal(unpickledAccount.IdKeys.Curve25519.PublicKey) { + t.Fatal("IdentityKeys Curve25519 public unequal") + } + if !firstAccount.IdKeys.Ed25519.PrivateKey.Equal(unpickledAccount.IdKeys.Ed25519.PrivateKey) { + t.Fatal("IdentityKeys Ed25519 private unequal") + } + if !firstAccount.IdKeys.Ed25519.PublicKey.Equal(unpickledAccount.IdKeys.Ed25519.PublicKey) { + t.Fatal("IdentityKeys Ed25519 public unequal") + } + + if len(firstAccount.OneTimeKeys()) != 2 { + t.Fatal("should get 2 unpublished oneTimeKeys") + } + if len(firstAccount.FallbackKeyUnpublished()) == 0 { + t.Fatal("should get fallbackKey") + } + firstAccount.MarkKeysAsPublished() + if len(firstAccount.FallbackKey()) == 0 { + t.Fatal("should get fallbackKey") + } + if len(firstAccount.FallbackKeyUnpublished()) != 0 { + t.Fatal("should get no fallbackKey") + } + if len(firstAccount.OneTimeKeys()) != 0 { + t.Fatal("should get no oneTimeKeys") + } +} + +func TestAccountPickleJSON(t *testing.T) { + key := []byte("test key") + + /* + // Generating new values when struct changed + newAccount, _ := NewAccount() + pickled, _ := newAccount.PickleAsJSON(key) + fmt.Println(string(pickled)) + jsonDataNew, _ := newAccount.IdentityKeysJSON() + fmt.Println(string(jsonDataNew)) + return + */ + + pickledData := []byte("fZG5DhZ0+uhVFEcdgo/dyWNy1BlSKo+W18D/QLBcZfvP0rByRzjgJM5yeDIO9N6jYFp2MbV1Y1DikFlDctwq7PhIRvbtLdrzxT94WoLrUdiNtQkw6NRNXvsFYo4NKoAgl1yQauttnGRBHCCPVV6e9d4kvnPVRkZNkbbANnadF0Tld/SMMWWoPI3L7dy+oiRh6nqNKvZz+upvgmOSm6gu2xV0yx9RJpkvLz8oHMDui1VQ1T2wTpfk5vdw0Cx4BXspf8WDnntdv0Ui4qBzUFmsB4lfqLviuhnAxu+qQrrKcZz/EyzbPwmI+P4Tn5KznxzEx2Nw/AjKKPxqVAKpx8+nV7rKKzlah71wX2CHyEsp2ptcNTJ1lr6tJxkOLdy8Rw285jpKw4MrgghnhqZ9Hh3y5P6KnRrq6zom9zfkCtCXs2h8BK+I0tkMPXO+JZoJKVOWzS+n7FIrC9XC9nAu19G5cnxv+tJdPb3p") + account, err := AccountFromJSONPickled(pickledData, key) + if err != nil { + t.Fatal(err) + } + expectedJSON := `{"ed25519":"qWvNB6Ztov5/AOsP073op0O32KJ8/tgSNarT7MaYgQE","curve25519":"TFUB6M6zwgyWhBEp2m1aUodl2AsnsrIuBr8l9AvwGS8"}` + jsonData, err := account.IdentityKeysJSON() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(jsonData, []byte(expectedJSON)) { + t.Fatalf("Expected '%s' but got '%s'", expectedJSON, jsonData) + } +} + +func TestSessions(t *testing.T) { + aliceAccount, err := NewAccount(nil) + if err != nil { + t.Fatal(err) + } + err = aliceAccount.GenOneTimeKeys(nil, 5) + if err != nil { + t.Fatal(err) + } + bobAccount, err := NewAccount(nil) + if err != nil { + t.Fatal(err) + } + err = bobAccount.GenOneTimeKeys(nil, 5) + if err != nil { + t.Fatal(err) + } + aliceSession, err := aliceAccount.NewOutboundSession(bobAccount.IdKeys.Curve25519.B64Encoded(), bobAccount.OTKeys[2].Key.B64Encoded()) + if err != nil { + t.Fatal(err) + } + plaintext := []byte("test message") + msgType, crypttext, err := aliceSession.Encrypt(plaintext, nil) + if err != nil { + t.Fatal(err) + } + if msgType != id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } + + bobSession, err := bobAccount.NewInboundSession(nil, crypttext) + if err != nil { + t.Fatal(err) + } + decodedText, err := bobSession.Decrypt(crypttext, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plaintext, decodedText) { + t.Fatalf("expected '%s' but got '%s'", string(plaintext), string(decodedText)) + } +} + +func TestAccountPickle(t *testing.T) { + pickleKey := []byte("secret_key") + account, err := AccountFromPickled(pickledDataFromLibOlm, pickleKey) + if err != nil { + t.Fatal(err) + } + if !expectedEd25519KeyPairPickleLibOLM.PrivateKey.Equal(account.IdKeys.Ed25519.PrivateKey) { + t.Fatal("keys not equal") + } + if !expectedEd25519KeyPairPickleLibOLM.PublicKey.Equal(account.IdKeys.Ed25519.PublicKey) { + t.Fatal("keys not equal") + } + if !expectedCurve25519KeyPairPickleLibOLM.PrivateKey.Equal(account.IdKeys.Curve25519.PrivateKey) { + t.Fatal("keys not equal") + } + if !expectedCurve25519KeyPairPickleLibOLM.PublicKey.Equal(account.IdKeys.Curve25519.PublicKey) { + t.Fatal("keys not equal") + } + if account.NextOneTimeKeyID != 42 { + t.Fatal("wrong next otKey id") + } + if len(account.OTKeys) != len(expectedOTKeysPickleLibOLM) { + t.Fatal("wrong number of otKeys") + } + if account.NumFallbackKeys != 0 { + t.Fatal("fallback keys set but not in pickle") + } + for curIndex, curValue := range account.OTKeys { + curExpected := expectedOTKeysPickleLibOLM[curIndex] + if curExpected.ID != curValue.ID { + t.Fatal("OTKey id not correct") + } + if !curExpected.Key.PublicKey.Equal(curValue.Key.PublicKey) { + t.Fatal("OTKey public key not correct") + } + if !curExpected.Key.PrivateKey.Equal(curValue.Key.PrivateKey) { + t.Fatal("OTKey private key not correct") + } + } + + targetPickled, err := account.Pickle(pickleKey) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(targetPickled, pickledDataFromLibOlm) { + t.Fatal("repickled value does not equal given value") + } +} + +func TestOldAccountPickle(t *testing.T) { + // this uses the old pickle format, which did not use enough space + // for the Ed25519 key. We should reject it. + pickled := []byte("x3h9er86ygvq56pM1yesdAxZou4ResPQC9Rszk/fhEL9JY/umtZ2N/foL/SUgVXS" + + "v0IxHHZTafYjDdzJU9xr8dQeBoOTGfV9E/lCqDGBnIlu7SZndqjEKXtzGyQr4sP4" + + "K/A/8TOu9iK2hDFszy6xETiousHnHgh2ZGbRUh4pQx+YMm8ZdNZeRnwFGLnrWyf9" + + "O5TmXua1FcU") + pickleKey := []byte("") + account, err := NewAccount(nil) + if err != nil { + t.Fatal(err) + } + err = account.Unpickle(pickled, pickleKey) + if err == nil { + t.Fatal("expected error") + } else { + if !errors.Is(err, goolm.ErrBadVersion) { + t.Fatal(err) + } + } +} + +func TestLoopback(t *testing.T) { + mockA := mockRandom{ + tag: []byte("A")[0], + current: 0x00, + } + mockB := mockRandom{ + tag: []byte("B")[0], + current: 0x80, + } + accountA, err := NewAccount(&mockA) + if err != nil { + t.Fatal(err) + } + + accountB, err := NewAccount(&mockB) + if err != nil { + t.Fatal(err) + } + err = accountB.GenOneTimeKeys(&mockB, 42) + if err != nil { + t.Fatal(err) + } + + aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), accountB.OTKeys[0].Key.B64Encoded()) + if err != nil { + t.Fatal(err) + } + + plainText := []byte("Hello, World") + msgType, message1, err := aliceSession.Encrypt(plainText, &mockA) + if err != nil { + t.Fatal(err) + } + if msgType != id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } + + bobSession, err := accountB.NewInboundSession(nil, message1) + if err != nil { + t.Fatal(err) + } + // Check that the inbound session matches the message it was created from. + sessionIsOK, err := bobSession.MatchesInboundSessionFrom(nil, message1) + if err != nil { + t.Fatal(err) + } + if !sessionIsOK { + t.Fatal("session was not detected to be valid") + } + // Check that the inbound session matches the key this message is supposed to be from. + aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded() + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&aIDKey, message1) + if err != nil { + t.Fatal(err) + } + if !sessionIsOK { + t.Fatal("session is sad to be not from a but it should") + } + // Check that the inbound session isn't from a different user. + bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded() + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&bIDKey, message1) + if err != nil { + t.Fatal(err) + } + if sessionIsOK { + t.Fatal("session is sad to be from b but is from a") + } + // Check that we can decrypt the message. + decryptedMessage, err := bobSession.Decrypt(message1, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decryptedMessage, plainText) { + t.Fatal("messages are not the same") + } + + msgTyp2, message2, err := bobSession.Encrypt(plainText, &mockB) + if err != nil { + t.Fatal(err) + } + if msgTyp2 == id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } + + decryptedMessage2, err := aliceSession.Decrypt(message2, msgTyp2) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decryptedMessage2, plainText) { + t.Fatal("messages are not the same") + } + + //decrypting again should fail, as the chain moved on + _, err = aliceSession.Decrypt(message2, msgTyp2) + if err == nil { + t.Fatal("expected error") + } + + //compare sessionIDs + if aliceSession.ID() != bobSession.ID() { + t.Fatal("sessionIDs are not equal") + } +} + +func TestMoreMessages(t *testing.T) { + mockA := mockRandom{ + tag: []byte("A")[0], + current: 0x00, + } + mockB := mockRandom{ + tag: []byte("B")[0], + current: 0x80, + } + accountA, err := NewAccount(&mockA) + if err != nil { + t.Fatal(err) + } + + accountB, err := NewAccount(&mockB) + if err != nil { + t.Fatal(err) + } + err = accountB.GenOneTimeKeys(&mockB, 42) + if err != nil { + t.Fatal(err) + } + + aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), accountB.OTKeys[0].Key.B64Encoded()) + if err != nil { + t.Fatal(err) + } + + plainText := []byte("Hello, World") + msgType, message1, err := aliceSession.Encrypt(plainText, &mockA) + if err != nil { + t.Fatal(err) + } + if msgType != id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } + + bobSession, err := accountB.NewInboundSession(nil, message1) + if err != nil { + t.Fatal(err) + } + decryptedMessage, err := bobSession.Decrypt(message1, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decryptedMessage, plainText) { + t.Fatal("messages are not the same") + } + + for i := 0; i < 8; i++ { + //alice sends, bob reveices + msgType, message, err := aliceSession.Encrypt(plainText, &mockA) + if err != nil { + t.Fatal(err) + } + if i == 0 { + //The first time should still be a preKeyMessage as bob has not yet send a message to alice + if msgType != id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } + } else { + if msgType == id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } + } + decryptedMessage, err := bobSession.Decrypt(message, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decryptedMessage, plainText) { + t.Fatal("messages are not the same") + } + + //now bob sends, alice receives + msgType, message, err = bobSession.Encrypt(plainText, &mockA) + if err != nil { + t.Fatal(err) + } + if msgType == id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } + decryptedMessage, err = aliceSession.Decrypt(message, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decryptedMessage, plainText) { + t.Fatal("messages are not the same") + } + } +} + +func TestFallbackKey(t *testing.T) { + mockA := mockRandom{ + tag: []byte("A")[0], + current: 0x00, + } + mockB := mockRandom{ + tag: []byte("B")[0], + current: 0x80, + } + accountA, err := NewAccount(&mockA) + if err != nil { + t.Fatal(err) + } + + accountB, err := NewAccount(&mockB) + if err != nil { + t.Fatal(err) + } + err = accountB.GenFallbackKey(&mockB) + if err != nil { + t.Fatal(err) + } + fallBackKeys := accountB.FallbackKeyUnpublished() + var fallbackKey id.Curve25519 + for _, fbKey := range fallBackKeys { + fallbackKey = fbKey + } + aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey) + if err != nil { + t.Fatal(err) + } + + plainText := []byte("Hello, World") + msgType, message1, err := aliceSession.Encrypt(plainText, &mockA) + if err != nil { + t.Fatal(err) + } + if msgType != id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } + + bobSession, err := accountB.NewInboundSession(nil, message1) + if err != nil { + t.Fatal(err) + } + // Check that the inbound session matches the message it was created from. + sessionIsOK, err := bobSession.MatchesInboundSessionFrom(nil, message1) + if err != nil { + t.Fatal(err) + } + if !sessionIsOK { + t.Fatal("session was not detected to be valid") + } + // Check that the inbound session matches the key this message is supposed to be from. + aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded() + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&aIDKey, message1) + if err != nil { + t.Fatal(err) + } + if !sessionIsOK { + t.Fatal("session is sad to be not from a but it should") + } + // Check that the inbound session isn't from a different user. + bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded() + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&bIDKey, message1) + if err != nil { + t.Fatal(err) + } + if sessionIsOK { + t.Fatal("session is sad to be from b but is from a") + } + // Check that we can decrypt the message. + decryptedMessage, err := bobSession.Decrypt(message1, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decryptedMessage, plainText) { + t.Fatal("messages are not the same") + } + + // create a new fallback key for B (the old fallback should still be usable) + err = accountB.GenFallbackKey(&mockB) + if err != nil { + t.Fatal(err) + } + // start another session and encrypt a message + aliceSession2, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey) + if err != nil { + t.Fatal(err) + } + + msgType2, message2, err := aliceSession2.Encrypt(plainText, &mockA) + if err != nil { + t.Fatal(err) + } + if msgType2 != id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } + // bobSession should not be valid for the message2 + // Check that the inbound session matches the message it was created from. + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(nil, message2) + if err != nil { + t.Fatal(err) + } + if sessionIsOK { + t.Fatal("session was detected to be valid but should not") + } + bobSession2, err := accountB.NewInboundSession(nil, message2) + if err != nil { + t.Fatal(err) + } + // Check that the inbound session matches the message it was created from. + sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(nil, message2) + if err != nil { + t.Fatal(err) + } + if !sessionIsOK { + t.Fatal("session was not detected to be valid") + } + // Check that the inbound session matches the key this message is supposed to be from. + sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(&aIDKey, message2) + if err != nil { + t.Fatal(err) + } + if !sessionIsOK { + t.Fatal("session is sad to be not from a but it should") + } + // Check that the inbound session isn't from a different user. + sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(&bIDKey, message2) + if err != nil { + t.Fatal(err) + } + if sessionIsOK { + t.Fatal("session is sad to be from b but is from a") + } + // Check that we can decrypt the message. + decryptedMessage2, err := bobSession2.Decrypt(message2, msgType2) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decryptedMessage2, plainText) { + t.Fatal("messages are not the same") + } + + //Forget the old fallback key -- creating a new session should fail now + accountB.ForgetOldFallbackKey() + // start another session and encrypt a message + aliceSession3, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey) + if err != nil { + t.Fatal(err) + } + msgType3, message3, err := aliceSession3.Encrypt(plainText, &mockA) + if err != nil { + t.Fatal(err) + } + if msgType3 != id.OlmMsgTypePreKey { + t.Fatal("wrong message type") + } + _, err = accountB.NewInboundSession(nil, message3) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, goolm.ErrBadMessageKeyID) { + t.Fatal(err) + } +} + +func TestOldV3AccountPickle(t *testing.T) { + pickledData := []byte("0mSqVn3duHffbhaTbFgW+4JPlcRoqT7z0x4mQ72N+g+eSAk5sgcWSoDzKpMazgcB" + + "46ItEpChthVHTGRA6PD3dly0dUs4ji7VtWTa+1tUv1UbxP92uYf1Ae3fomX0yAoH" + + "OjSrz1+RmuXr+At8jsmsf260sKvhB6LnI3qYsrw6AAtpgk5d5xZd66sLxvvYUuai" + + "+SmmcmT0bHosLTuDiiB9amBvPKkUKtKZmaEAl5ULrgnJygp1/FnwzVfSrw6PBSX6" + + "ZaUEZHZGX1iI6/WjbHqlTQeOQjtaSsPaL5XXpteS9dFsuaANAj+8ks7Ut2Hwg/JP" + + "Ih/ERYBwiMh9Mt3zSAG0NkvgUkcdipKxoSNZ6t+TkqZrN6jG6VCbx+4YpJO24iJb" + + "ShZy8n79aePIgIsxX94ycsTq1ic38sCRSkWGVbCSRkPloHW7ZssLHA") + pickleKey := []byte("") + expectedFallbackJSON := []byte("{\"curve25519\":{\"AAAAAQ\":\"dr98y6VOWt6lJaQgFVZeWY2ky76mga9MEMbdItJTdng\"}}") + expectedUnpublishedFallbackJSON := []byte("{\"curve25519\":{}}") + + account, err := AccountFromPickled(pickledData, pickleKey) + if err != nil { + t.Fatal(err) + } + fallbackJSON, err := account.FallbackKeyJSON() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(fallbackJSON, expectedFallbackJSON) { + t.Fatalf("expected not as result:\n%s\n%s\n", expectedFallbackJSON, fallbackJSON) + } + fallbackJSONUnpublished, err := account.FallbackKeyUnpublishedJSON() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(fallbackJSONUnpublished, expectedUnpublishedFallbackJSON) { + t.Fatalf("expected not as result:\n%s\n%s\n", expectedUnpublishedFallbackJSON, fallbackJSONUnpublished) + } +} + +func TestAccountSign(t *testing.T) { + mockA := mockRandom{ + tag: []byte("A")[0], + current: 0x00, + } + accountA, err := NewAccount(&mockA) + if err != nil { + t.Fatal(err) + } + plainText := []byte("Hello, World") + signature, err := accountA.Sign(plainText) + if err != nil { + t.Fatal(err) + } + verified, err := utilities.VerifySignature(plainText, accountA.IdKeys.Ed25519.B64Encoded(), signature) + if err != nil { + t.Fatal(err) + } + if !verified { + t.Fatal("signature did not verify") + } +} diff --git a/crypto/goolm/account/account_test_data.go b/crypto/goolm/account/account_test_data.go new file mode 100644 index 00000000..742b6ac2 --- /dev/null +++ b/crypto/goolm/account/account_test_data.go @@ -0,0 +1,267 @@ +package account + +import "codeberg.org/DerLukas/goolm/crypto" + +var pickledDataFromLibOlm = []byte("b3jGWBenkTv6DJt90OX+H1ecoXQwihBjhdJHkAft49wS7ubT3Z0ta46p9PCnfKs+fOHeKhJzgfFcD5yCoatcpRzMHRri6V1dG/wMIu8nYvPPMZ8Dy5YlMBRGz0cpnOAhVoUzo/HtvyN8kgoYnZLzorVYepIqQcsLZAiG6qlztXepEflwNG619Rrk/zWYae5RBtxz9Cl0KCTj8cjY5J/SEKU+SCnj4n16wa+RfYXuLK/kBlE30uSWqQBInlLLYiSqOGjr8M0x+3A0eG0gYA+Aohwl5MbjQnDniTbQeg1gh3VwWZ6kJCgRpLnT0j6oc6V4HjP0JjseHe0rBr6W9o88sl6wGmVEr2ZjlvcD6hoCK21A98UZF0GTwHrX0zV7OQtn5cmys3A1xdgcBAo/GXte1d2HzBXSmgrnXExK3Ij+BkZoQSuEFWSUCLjCUFQohK8TfraLZ5+9sOaV/5KaUxqdBTi6HUqoYymCHxzG7olo3hlh+GJ+iOy9tnofqDirISDIIL7KJ2zNJxYHWNZrAVNxHF3rPSrBw9Zl6M2Scm9PdDqnPgGZ+MSCrCnT6UrrfmWurPahnXwdvPED9rykLtcy3aKFsB+RIezeoNdq/4d8sYVwFWd0w9HTXEYG1YY/Km2LiK/exaC98agPN4GXakCUHVZSfz59IZ3bH8jQdtZw2BPkVfNCUoTuUFzIk0LS3AtudCTiUaFdul9/Phj7TlvyvIKH8GUFRiV46fxMHJ9U0HGg6VAKtDR5qkB/nB1X8SWTmmZblR+jGOQvE6VCXSEoCdSyjYK+xtwlZsMHoFoci+NN/uxnrfoMd0D+TpNOyNFwdjn9hoS8kmqvMEyhae0Q5N4O7YwHiH/jZ9ruFTCMK10TeyFN3yxKiRKhiJkgd9bGnmHz25dm9EDkR4i1pXUFuSZCO7WPUX4aLiNMwcltW01EiTdvE7e2jgaoRguXI8gimvOZ/d8kKZ8QIgKURZZHXmud93MOXL3sAy/aBU8dBMt+E5mVeoGM2fns8o5D9Yx3gZ6CgkzmzWinfj82qyc459OcyeyV/gugEt3FI28UBMRfghIV0juOGTAjkh6G3wIyZVk2G4rG0mYONrQQhmgKf06szNQXFBHQ2Pju4pY+QEAng3D2CfXFV6S2bUVeXN0fk46afsV84WPwYg77DWTuR81Ck2arbIcKGsSpMrETMY65rtEoMAcXLzmWgPsIXdo7k+aR4mWmcxjW5a10Wxc1knOi39x0M6gnYGbhmj6IxmalzVFOjG1ZtkFL5fs59nK42aP/JZ0SdtTJjJA0PkbEFL3YOmVtRUmizVtZk63JYuyCgw36XLscTb3VWVynLONYa1RPyRLaz8L5FkTVySCFb8gP9KtDipBpPdGIeD0MGRijAPLweB4iDkM6zv9Yu6dMijZgSR0g6LmjQZPcm1YfI9AK2ht86oKJfvpj+UdYkK+wKKNzJKjKN08+mIKYbsumpbgKqx13d8sawKC4EfGAJHXsadat77Kp/ECCvhh7/i6gqWBHD0+I2LGiuQTr/Vd6OxAGmtyFOzdSGsfWm0cq78Lc6og7HTg3n7TbnEfJMaQktAI7vQYqnsvZV/KnfZ1elfPubFaFiHmCJzfkuk4X4y6r5A6FxpuEltvrHRtecQ6FHLHsBSZrUg7Dei9urMonphUfEj4rsVOMlB3ZKiQ0unmWmacmFKkHm+WhpQtzLS57/iuGCdKiL8qWBiCz80bCQXQp1iwdScZ+pZQ5pwVABH0sr9YfQEz6+oMh3Kp5LiXJAy7kNEs20h4oMP1bc/gN+F5cRubHrz3sXHWpAXF/pNw862Lj7rL0PPOZdomgHSpmKybaQyJxemlxOP9eFw2r2aym/6jc4nQoR+1Mu1ijaroJ8MgZSwTKmru+xgJXnwLx8i76iRlze2F00iNOMg/pRtFQmWh/zLsKukFtIi2PgPKo8xNRQgYvB76x0jLaX3cllpu/pL4LIM2q0p+V9+FPBPCeinkDA7jQzy397vlOSbdECuPYaj2JmkH4zKbxdDlZffgfjWCLDFkqPc0Ixz8O1k24yfkqwG792anEValGM/Hnhdh3a24y4dTV28eo1SoJ6pD1yrjP2RNvgeqs2xEbKOxPywmmOjq0zc805cXBTjDOVyeSFiiY3yJ2GN33KXv67svXw/Ky0Nl9Epbk6xbSB5b76HPAjJ1gZXEUE2zeRTVVOAWzrCERUUjcccz1ozde//rOjzEt+3fxdrq/oglOMW5Ge7ddo/a5lDRs10G3KHreT1sZz9NmA0U7VOQDwwosmt1AYB1LPg3HM2mwi8ahbZf62K1o/W450RWuG072u1unrCYBNYYArN4lnE11L5LvxctFT8qwbBPD1rHAs+Pr4GKQkTWOmUM7wS59GGWE3UEDrntj1Our5NNto7bjK05h33GRF+Vge1+8EfZ+eL2aikjeq/5dU/cN2aw5v3FCkrXzXbIX4YQMw0nu+MbbqXzmPAa7ibS5z0puTYiVs/iMC0ElMSbp9l65iosfVE+kjlF5QDVJaZoW1rTJ8ACo3zOIH6tv515OPvcxDl08nqw/eH37g/bridClsFlJu8aS8Up6Pxzx7hIjNukoDG/wm9qwN/tGgOhZaotzSQEmJTHy7rWc/tamPEJvzZ0Ev99FAlu816Q/HSSk//y8ZriufU5kvYgz0jVeotTShi5LQK30EDTZtjSxE2eJjpRwAmyD6iQwQQiQkVj5inBB6FbF9u1iVhgVsoQ52YEN5id809NSFasmjJ+szJm0E0e6WMfU3tX1j4nPiIdqV7XF8E01pTIZRBmJVUIeiq7XBb5fARjFBveFwwC/Ck3XykYz7CVQJsGOfk6VqlOzzcKhDOivDsVUPbtWJZgP3qC3sJp0dZV1c1BcDHUVbi8HC81F4zobEQGUOyTFsfeiiLUOHvBsveLt87EYpYe2rwjdnEVN5IU3NG2spMc0C7DeEiovv6GCdWgpoHHwPknS42Yv1ltjoDSrlgVF6nXyFXfVWN9CkZoOsEvoUEXGrHnLMT2RPJIWwZFGWXnHs+WKkEfFJbXJvUfTMlPM+hPOIaurwkKmJqh3cCzIGezzWFfSqM7Lv9Fth9FR1QiUcyK2eg1yc7F3lpicxf86RdcbNZD8Wep1uLA0/5zQYmiHA2ZUBFgpN0KanmykSWMHirsBXF4ixwsduRvo4YqqgIMLDrQInJMt5xvHTwHhKMgSOdpReDwM8zYYBV7ukon3cUEo8pwKuXCLs++DNU40bIyP7bpB4rL2Bm0ojNxsTsQwHO/vqXIDPPsyUEtd6J/JuL364XSP9yE5lnk2g5j3rA7uaq/DpA7m1PO+dE0nZcduFH+5OMC/tMZK1PK0Uc3z7LwIaaOzLBraqm16PdJjLCo5YcQgL5GqXN20sznKrRApqq17gdNayaZoEqG1cO0HzmDcr39Ombh/KrbtYFPkBlNbs/9tjgSOgZ5/aWb/gA6lYzspsKdJN0JAD88uYYBY2+GqdJRB86OP+DRHVRrpHOuLkcWNKILmVN5WFljoQXBvtER2IML4GHZytus+E6o2HCq+5Sh5dJlCBKRHIbzs9XPlEWbcE6Z6inuR5zbVYb97VcOhaM83ZjZZGVfdyqX9QagQ7k7eP2Ifju7vZkhz+HHa1v94HeYRGEKZvPRv2nPuzD1bOauwmhu0WKzYXXoSL40XdvSMNhGlKI65uxh6O1DoIJ8fiB3cFZjK2Li1gYMBzw7Mt8R2sSks3iKuiePX0JRTl0cBoiX7RmZuOlJet1HnUhbliS9vncRGZ/aRAL6VpQ5066/FoOq9ahWK+aSmxidcuRHqyIHhqJtUJrP/44vwDWBdZljgu4ysQvv3Fd+reJ1T4BZlY2jWJK1YCtMQMR2TeEaD7c6a7FtoexX3wuWDEe7er55nE2dW8uRjrptSdpAHFXHnmHmjzhP5Mr+MMAbDcK+3YrnsKYnmSdQP1EYPQ8dyIuc07BpypOZLWFiZpRcLwx2mAp6wFULrPptCWfPCcOgzghBa9l8eVDm/zsfZcP1FP6/J8uSbAI5YUz61gcR8VG1DKCX2e5kIk2xoPB0DlRfltx8NlT2VjLz5xKyyyaq0GpE0EfUXu0q4s2Q9tpk") +var expectedEd25519KeyPairPickleLibOLM = crypto.Ed25519KeyPair{ + PublicKey: []byte{237, 217, 234, 95, 181, 217, 229, 96, 41, 51, 153, 83, 191, 158, 47, 242, 100, 163, 120, 171, 15, 117, 176, 58, 70, 181, 5, 53, 64, 26, 99, 55}, + PrivateKey: []byte{232, 245, 108, 122, 156, 40, 107, 206, 71, 27, 156, 60, 52, 126, 39, 215, 255, 217, 81, 248, 206, 228, 153, 244, 31, 114, 88, 127, 207, 250, 255, 122, 196, 207, 3, 96, 142, 227, 172, 88, 13, 230, 140, 125, 200, 220, 19, 127, 144, 79, 32, 249, 135, 238, 3, 205, 227, 73, 250, 219, 223, 248, 175, 20}, +} +var expectedCurve25519KeyPairPickleLibOLM = crypto.Curve25519KeyPair{ + PublicKey: []byte{56, 193, 217, 134, 124, 49, 9, 185, 241, 26, 246, 132, 245, 34, 222, 189, 199, 201, 136, 80, 185, 153, 132, 240, 194, 48, 30, 157, 74, 1, 243, 0}, + PrivateKey: []byte{80, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, +} +var expectedOTKeysPickleLibOLM = []crypto.OneTimeKey{ + {ID: 42, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{41, 72, 49, 87, 49, 27, 143, 250, 203, 35, 151, 49, 248, 200, 99, 225, 101, 68, 203, 251, 132, 115, 253, 59, 21, 61, 111, 58, 252, 200, 85, 61}, + PrivateKey: []byte{80, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43, 43}, + }, + }, + {ID: 41, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{123, 42, 55, 123, 233, 87, 88, 76, 17, 249, 112, 97, 226, 213, 73, 239, 49, 217, 168, 220, 180, 182, 176, 231, 77, 138, 92, 58, 62, 185, 250, 12}, + PrivateKey: []byte{80, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42}, + }, + }, + {ID: 40, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{139, 80, 115, 105, 78, 90, 82, 35, 21, 248, 232, 10, 8, 237, 95, 201, 73, 219, 244, 105, 35, 184, 225, 56, 164, 142, 79, 59, 178, 51, 150, 69}, + PrivateKey: []byte{80, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41}, + }, + }, + {ID: 39, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{176, 111, 229, 19, 195, 233, 77, 12, 228, 241, 254, 193, 139, 127, 150, 20, 182, 36, 103, 30, 207, 5, 35, 93, 60, 81, 53, 133, 216, 4, 81, 94}, + PrivateKey: []byte{80, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40}, + }, + }, + {ID: 38, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{137, 106, 140, 51, 49, 76, 42, 164, 198, 184, 58, 9, 246, 119, 84, 88, 196, 199, 189, 145, 145, 141, 209, 29, 68, 64, 171, 23, 126, 11, 220, 122}, + PrivateKey: []byte{80, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39}, + }, + }, + {ID: 37, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{38, 99, 240, 40, 17, 97, 91, 79, 105, 102, 81, 153, 12, 175, 81, 4, 132, 171, 246, 96, 10, 162, 71, 175, 241, 23, 22, 129, 38, 15, 230, 67}, + PrivateKey: []byte{80, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38, 38}, + }, + }, + {ID: 36, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{205, 28, 163, 27, 148, 116, 82, 169, 230, 7, 184, 192, 76, 177, 196, 129, 62, 32, 76, 145, 247, 56, 220, 180, 74, 193, 205, 178, 158, 209, 168, 123}, + PrivateKey: []byte{80, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37}, + }, + }, + {ID: 35, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{195, 125, 80, 132, 106, 120, 250, 0, 145, 191, 116, 179, 167, 91, 65, 10, 121, 19, 12, 51, 78, 229, 170, 110, 37, 37, 109, 65, 221, 126, 168, 5}, + PrivateKey: []byte{80, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36}, + }, + }, + {ID: 34, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{93, 0, 88, 212, 33, 219, 129, 18, 103, 142, 90, 217, 6, 84, 99, 224, 41, 78, 245, 65, 65, 70, 116, 194, 23, 28, 21, 40, 220, 202, 139, 8}, + PrivateKey: []byte{80, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35}, + }, + }, + {ID: 33, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{158, 245, 92, 234, 230, 162, 236, 226, 172, 246, 255, 113, 231, 162, 211, 19, 141, 244, 36, 127, 235, 47, 38, 209, 7, 107, 245, 147, 161, 89, 246, 53}, + PrivateKey: []byte{80, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34}, + }, + }, + {ID: 32, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{69, 20, 138, 120, 68, 160, 34, 99, 205, 177, 138, 147, 96, 118, 36, 239, 206, 11, 118, 75, 170, 216, 193, 108, 24, 65, 0, 131, 226, 73, 22, 18}, + PrivateKey: []byte{80, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33}, + }, + }, + {ID: 31, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{201, 153, 194, 8, 6, 146, 167, 134, 209, 163, 215, 61, 114, 191, 150, 68, 205, 106, 37, 144, 32, 216, 19, 210, 139, 169, 221, 28, 160, 193, 196, 71}, + PrivateKey: []byte{80, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32}, + }, + }, + {ID: 30, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{211, 29, 161, 172, 192, 112, 209, 226, 113, 120, 177, 145, 108, 134, 92, 21, 31, 29, 162, 237, 77, 179, 96, 247, 123, 246, 47, 40, 238, 242, 206, 53}, + PrivateKey: []byte{80, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31}, + }, + }, + {ID: 29, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{197, 144, 16, 124, 25, 208, 46, 163, 33, 56, 116, 172, 53, 106, 42, 217, 240, 152, 165, 10, 82, 218, 96, 237, 211, 254, 229, 209, 5, 154, 52, 21}, + PrivateKey: []byte{80, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30}, + }, + }, + {ID: 28, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{42, 188, 228, 224, 227, 132, 230, 252, 175, 213, 113, 132, 226, 151, 138, 166, 213, 151, 235, 1, 4, 81, 45, 80, 27, 140, 195, 234, 136, 163, 245, 96}, + PrivateKey: []byte{80, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29}, + }, + }, + {ID: 27, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{153, 0, 67, 133, 177, 241, 105, 219, 32, 58, 135, 239, 145, 124, 122, 32, 137, 109, 40, 177, 54, 85, 46, 69, 231, 253, 146, 150, 228, 172, 9, 66}, + PrivateKey: []byte{80, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28}, + }, + }, + {ID: 26, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{3, 79, 232, 39, 90, 120, 71, 216, 193, 102, 132, 48, 91, 225, 8, 229, 99, 206, 128, 110, 9, 161, 75, 204, 86, 250, 54, 185, 152, 163, 144, 124}, + PrivateKey: []byte{80, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27}, + }, + }, + {ID: 25, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{96, 21, 62, 175, 244, 249, 33, 134, 162, 32, 142, 56, 215, 27, 12, 30, 229, 118, 63, 40, 45, 120, 204, 134, 111, 95, 21, 150, 112, 60, 187, 111}, + PrivateKey: []byte{80, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26}, + }, + }, + {ID: 24, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{103, 239, 218, 49, 88, 55, 161, 63, 238, 39, 114, 106, 175, 158, 59, 43, 39, 112, 239, 175, 29, 174, 75, 172, 9, 84, 230, 109, 214, 77, 170, 124}, + PrivateKey: []byte{80, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25}, + }, + }, + {ID: 23, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{35, 148, 228, 98, 0, 124, 196, 15, 5, 63, 73, 127, 52, 126, 165, 175, 186, 35, 196, 89, 94, 233, 56, 60, 103, 125, 67, 47, 29, 132, 206, 13}, + PrivateKey: []byte{80, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24}, + }, + }, + {ID: 22, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{94, 143, 132, 227, 112, 122, 177, 213, 30, 87, 21, 85, 0, 193, 221, 87, 111, 100, 99, 15, 50, 68, 92, 146, 222, 179, 182, 58, 136, 235, 74, 44}, + PrivateKey: []byte{80, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23, 23}, + }, + }, + {ID: 21, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{232, 20, 27, 90, 55, 105, 146, 28, 107, 129, 73, 107, 1, 35, 70, 190, 227, 54, 169, 214, 160, 99, 150, 180, 37, 109, 115, 211, 84, 115, 91, 73}, + PrivateKey: []byte{80, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22}, + }, + }, + {ID: 20, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{245, 105, 178, 42, 165, 43, 232, 76, 48, 163, 5, 3, 42, 123, 59, 208, 74, 227, 36, 112, 77, 212, 203, 152, 81, 228, 226, 69, 45, 101, 182, 65}, + PrivateKey: []byte{80, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21}, + }, + }, + {ID: 19, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{16, 18, 85, 33, 104, 88, 95, 252, 135, 25, 55, 255, 240, 198, 30, 251, 163, 44, 150, 111, 155, 150, 143, 163, 242, 186, 142, 145, 59, 14, 161, 50}, + PrivateKey: []byte{80, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20}, + }, + }, + {ID: 18, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{32, 138, 232, 106, 32, 165, 39, 122, 146, 194, 126, 235, 84, 72, 127, 106, 83, 32, 219, 45, 201, 36, 226, 133, 201, 67, 168, 199, 112, 73, 166, 68}, + PrivateKey: []byte{80, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19}, + }, + }, + {ID: 17, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{10, 231, 214, 54, 36, 71, 42, 193, 204, 235, 148, 182, 60, 82, 228, 215, 61, 218, 146, 65, 227, 136, 233, 11, 223, 88, 95, 113, 47, 84, 169, 53}, + PrivateKey: []byte{80, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18}, + }, + }, + {ID: 16, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{226, 60, 255, 91, 122, 150, 74, 95, 227, 250, 237, 107, 205, 242, 56, 123, 52, 25, 65, 125, 69, 255, 101, 60, 201, 140, 196, 213, 196, 75, 109, 92}, + PrivateKey: []byte{80, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17}, + }, + }, + {ID: 15, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{171, 229, 71, 9, 133, 66, 150, 143, 73, 156, 11, 216, 148, 7, 153, 129, 237, 207, 228, 193, 55, 183, 156, 178, 132, 85, 154, 43, 19, 29, 170, 127}, + PrivateKey: []byte{80, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16}, + }, + }, + {ID: 14, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{53, 241, 2, 154, 223, 221, 222, 131, 114, 196, 111, 189, 26, 210, 20, 48, 39, 57, 199, 192, 2, 239, 213, 135, 232, 160, 92, 214, 18, 18, 205, 93}, + PrivateKey: []byte{80, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15}, + }, + }, + {ID: 13, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{92, 15, 157, 2, 49, 70, 253, 32, 39, 210, 54, 167, 55, 95, 255, 118, 76, 52, 184, 76, 185, 217, 31, 84, 7, 118, 1, 117, 53, 78, 216, 91}, + PrivateKey: []byte{80, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14}, + }, + }, + {ID: 12, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{67, 94, 86, 147, 61, 140, 71, 173, 0, 97, 202, 174, 242, 37, 198, 173, 214, 104, 89, 37, 204, 136, 32, 62, 166, 165, 56, 194, 242, 26, 79, 12}, + PrivateKey: []byte{80, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13}, + }, + }, + {ID: 11, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{147, 197, 91, 58, 183, 17, 72, 41, 244, 222, 191, 70, 195, 238, 110, 223, 135, 107, 108, 43, 154, 144, 50, 20, 222, 69, 42, 214, 69, 181, 0, 82}, + PrivateKey: []byte{80, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12}, + }, + }, + {ID: 10, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{116, 144, 19, 88, 33, 120, 92, 138, 174, 218, 192, 222, 96, 249, 46, 250, 4, 197, 250, 196, 243, 68, 183, 210, 218, 107, 206, 138, 121, 226, 189, 104}, + PrivateKey: []byte{80, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11}, + }, + }, + {ID: 9, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{140, 220, 222, 205, 238, 56, 126, 139, 40, 172, 222, 189, 235, 73, 50, 238, 125, 114, 73, 193, 80, 87, 86, 82, 205, 247, 206, 222, 164, 151, 1, 110}, + PrivateKey: []byte{80, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10}, + }, + }, + {ID: 8, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{129, 168, 225, 128, 194, 202, 63, 189, 162, 243, 79, 88, 251, 222, 173, 19, 132, 217, 193, 192, 171, 149, 159, 128, 244, 136, 216, 28, 2, 175, 141, 7}, + PrivateKey: []byte{80, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9}, + }, + }, + {ID: 7, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{44, 161, 77, 24, 61, 118, 178, 112, 31, 10, 14, 217, 0, 66, 161, 88, 134, 88, 53, 74, 93, 62, 211, 217, 87, 203, 122, 143, 239, 1, 24, 121}, + PrivateKey: []byte{80, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8}, + }, + }, + {ID: 6, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{167, 86, 75, 53, 54, 151, 106, 235, 48, 47, 54, 144, 180, 160, 209, 24, 78, 99, 57, 76, 109, 162, 233, 213, 170, 121, 37, 203, 178, 212, 130, 0}, + PrivateKey: []byte{80, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7}, + }, + }, + {ID: 5, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{194, 48, 135, 21, 239, 220, 32, 235, 254, 154, 245, 120, 129, 44, 108, 62, 246, 57, 62, 197, 170, 228, 107, 136, 155, 186, 29, 25, 57, 65, 172, 88}, + PrivateKey: []byte{80, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6}, + }, + }, + {ID: 4, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{85, 201, 136, 56, 1, 248, 140, 74, 234, 124, 137, 178, 244, 178, 37, 163, 73, 220, 116, 243, 236, 92, 198, 246, 111, 99, 227, 90, 106, 115, 9, 70}, + PrivateKey: []byte{80, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5}, + }, + }, + {ID: 3, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{38, 244, 146, 200, 126, 125, 48, 184, 222, 106, 254, 236, 231, 113, 26, 128, 84, 137, 162, 163, 97, 54, 213, 96, 254, 23, 55, 178, 114, 105, 93, 83}, + PrivateKey: []byte{80, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, + }, + }, + {ID: 2, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{149, 218, 194, 83, 219, 185, 224, 51, 177, 226, 224, 190, 219, 150, 131, 5, 183, 52, 226, 205, 114, 116, 219, 156, 227, 175, 66, 165, 132, 8, 24, 82}, + PrivateKey: []byte{80, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, + }, + }, + {ID: 1, + Key: crypto.Curve25519KeyPair{ + PublicKey: []byte{215, 133, 170, 227, 69, 234, 37, 45, 63, 251, 88, 239, 181, 64, 54, 203, 166, 87, 83, 33, 234, 207, 136, 145, 71, 153, 36, 239, 125, 151, 69, 106}, + PrivateKey: []byte{80, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, + }, + }, +} diff --git a/crypto/goolm/base64.go b/crypto/goolm/base64.go new file mode 100644 index 00000000..2cc0ec2b --- /dev/null +++ b/crypto/goolm/base64.go @@ -0,0 +1,41 @@ +package goolm + +import ( + "encoding/base64" + + "github.com/pkg/errors" +) + +// Base64Decode decodes the input. Padding characters ('=') will be added if needed. +func Base64Decode(input []byte) ([]byte, error) { + //pad the input to multiple of 4 + addedPadding := 0 + for len(input)%4 != 0 { + input = append(input, []byte("=")...) + addedPadding++ + } + if addedPadding >= 3 { + return nil, errors.Wrap(ErrBase64InvalidLength, "") + } + decoded := make([]byte, base64.StdEncoding.DecodedLen(len(input))) + writtenBytes, err := base64.StdEncoding.Decode(decoded, input) + if err != nil { + return nil, errors.Wrap(ErrBadBase64, err.Error()) + } + //DecodedLen returned the maximum size. However this might not be the true length. + return decoded[:writtenBytes], nil +} + +// Base64Encode encodes the input and strips all padding characters ('=') from the end. +func Base64Encode(input []byte) []byte { + encoded := make([]byte, base64.StdEncoding.EncodedLen(len(input))) + base64.StdEncoding.Encode(encoded, input) + //Remove padding = from output as libolm does so + for curIndex := len(encoded) - 1; curIndex >= 0; curIndex-- { + if string(encoded[curIndex]) != "=" { + encoded = encoded[:curIndex+1] + break + } + } + return encoded +} diff --git a/crypto/goolm/base64_test.go b/crypto/goolm/base64_test.go new file mode 100644 index 00000000..be22a7aa --- /dev/null +++ b/crypto/goolm/base64_test.go @@ -0,0 +1,47 @@ +package goolm + +import ( + "bytes" + "errors" + "testing" +) + +func TestBase64Encode(t *testing.T) { + input := []byte("Hello World") + expected := []byte("SGVsbG8gV29ybGQ") + result := Base64Encode(input) + if !bytes.Equal(result, expected) { + t.Fatalf("expected '%s' but got '%s'", string(expected), string(result)) + } +} + +func TestBase64Decode(t *testing.T) { + input := []byte("SGVsbG8gV29ybGQ") + expected := []byte("Hello World") + result, err := Base64Decode(input) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, expected) { + t.Fatalf("expected '%s' but got '%s'", string(expected), string(result)) + } + //This should fai + _, err = Base64Decode([]byte("SGVsbG8gV29ybGQab")) + if err == nil { + t.Fatal("decoded wrong input") + } + if !errors.Is(err, ErrBase64InvalidLength) { + t.Fatalf("got other error as expected: %s", err) + } +} + +func TestBase64DecodeFail(t *testing.T) { + input := []byte("SGVsbG8gV29ybGQab") + _, err := Base64Decode(input) + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, ErrBase64InvalidLength) { + t.Fatal(err) + } +} diff --git a/crypto/goolm/cipher/aesSha256.go b/crypto/goolm/cipher/aesSha256.go new file mode 100644 index 00000000..1234b14e --- /dev/null +++ b/crypto/goolm/cipher/aesSha256.go @@ -0,0 +1,96 @@ +package cipher + +import ( + "bytes" + "io" + + "codeberg.org/DerLukas/goolm/crypto" +) + +// derivedAESKeys stores the derived keys for the AESSha256 cipher +type derivedAESKeys struct { + key []byte + hmacKey []byte + iv []byte +} + +// deriveAESKeys derives three keys for the AESSha256 cipher +func deriveAESKeys(kdfInfo []byte, key []byte) (*derivedAESKeys, error) { + hkdf := crypto.HKDFSHA256(key, nil, kdfInfo) + keys := &derivedAESKeys{ + key: make([]byte, 32), + hmacKey: make([]byte, 32), + iv: make([]byte, 16), + } + if _, err := io.ReadFull(hkdf, keys.key); err != nil { + return nil, err + } + if _, err := io.ReadFull(hkdf, keys.hmacKey); err != nil { + return nil, err + } + if _, err := io.ReadFull(hkdf, keys.iv); err != nil { + return nil, err + } + return keys, nil +} + +// AESSha512BlockSize resturns the blocksize of the cipher AESSha256. +func AESSha512BlockSize() int { + return crypto.AESCBCBlocksize() +} + +// AESSha256 is a valid cipher using AES with CBC and HKDFSha256. +type AESSha256 struct { + kdfInfo []byte +} + +// NewAESSha256 returns a new AESSha256 cipher with the key derive function info (kdfInfo). +func NewAESSha256(kdfInfo []byte) *AESSha256 { + return &AESSha256{ + kdfInfo: kdfInfo, + } +} + +// Encrypt encrypts the plaintext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes). +func (c AESSha256) Encrypt(key, plaintext []byte) (ciphertext []byte, err error) { + keys, err := deriveAESKeys(c.kdfInfo, key) + if err != nil { + return nil, err + } + ciphertext, err = crypto.AESCBCEncrypt(keys.key, keys.iv, plaintext) + if err != nil { + return nil, err + } + return ciphertext, nil +} + +// Decrypt decrypts the ciphertext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes). +func (c AESSha256) Decrypt(key, ciphertext []byte) (plaintext []byte, err error) { + keys, err := deriveAESKeys(c.kdfInfo, key) + if err != nil { + return nil, err + } + plaintext, err = crypto.AESCBCDecrypt(keys.key, keys.iv, ciphertext) + if err != nil { + return nil, err + } + return plaintext, nil +} + +// MAC returns the MAC for the message using the key. The key is used to derive the actual mac key (32 bytes). +func (c AESSha256) MAC(key, message []byte) ([]byte, error) { + keys, err := deriveAESKeys(c.kdfInfo, key) + if err != nil { + return nil, err + } + return crypto.HMACSHA256(keys.hmacKey, message), nil +} + +// Verify checks the MAC of the message using the key against the givenMAC. The key is used to derive the actual mac key (32 bytes). +func (c AESSha256) Verify(key, message, givenMAC []byte) (bool, error) { + mac, err := c.MAC(key, message) + if err != nil { + return false, err + } + return bytes.Equal(givenMAC, mac[:len(givenMAC)]), nil +} diff --git a/crypto/goolm/cipher/aesSha256_test.go b/crypto/goolm/cipher/aesSha256_test.go new file mode 100644 index 00000000..e8068248 --- /dev/null +++ b/crypto/goolm/cipher/aesSha256_test.go @@ -0,0 +1,83 @@ +package cipher + +import ( + "bytes" + "crypto/aes" + "testing" +) + +func TestDeriveAESKeys(t *testing.T) { + kdfInfo := []byte("test") + key := []byte("test key") + derivedKeys, err := deriveAESKeys(kdfInfo, key) + if err != nil { + t.Fatal(err) + } + derivedKeys2, err := deriveAESKeys(kdfInfo, key) + if err != nil { + t.Fatal(err) + } + //derivedKeys and derivedKeys2 should be identical + if !bytes.Equal(derivedKeys.key, derivedKeys2.key) || + !bytes.Equal(derivedKeys.iv, derivedKeys2.iv) || + !bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) { + t.Fail() + } + //changing kdfInfo + kdfInfo = []byte("other kdf") + derivedKeys2, err = deriveAESKeys(kdfInfo, key) + if err != nil { + t.Fatal(err) + } + //derivedKeys and derivedKeys2 should now be different + if bytes.Equal(derivedKeys.key, derivedKeys2.key) || + bytes.Equal(derivedKeys.iv, derivedKeys2.iv) || + bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) { + t.Fail() + } + //changing key + key = []byte("other test key") + derivedKeys, err = deriveAESKeys(kdfInfo, key) + if err != nil { + t.Fatal(err) + } + //derivedKeys and derivedKeys2 should now be different + if bytes.Equal(derivedKeys.key, derivedKeys2.key) || + bytes.Equal(derivedKeys.iv, derivedKeys2.iv) || + bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) { + t.Fail() + } +} + +func TestCipherAESSha256(t *testing.T) { + key := []byte("test key") + cipher := NewAESSha256([]byte("testKDFinfo")) + message := []byte("this is a random message for testing the implementation") + //increase to next block size + for len(message)%aes.BlockSize != 0 { + message = append(message, []byte("-")...) + } + encrypted, err := cipher.Encrypt(key, []byte(message)) + if err != nil { + t.Fatal(err) + } + mac, err := cipher.MAC(key, encrypted) + if err != nil { + t.Fatal(err) + } + + verified, err := cipher.Verify(key, encrypted, mac[:8]) + if err != nil { + t.Fatal(err) + } + if !verified { + t.Fatal("signature verification failed") + } + resultPlainText, err := cipher.Decrypt(key, encrypted) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(message, resultPlainText) { + t.Fail() + } +} diff --git a/crypto/goolm/cipher/main.go b/crypto/goolm/cipher/main.go new file mode 100644 index 00000000..a8664702 --- /dev/null +++ b/crypto/goolm/cipher/main.go @@ -0,0 +1,17 @@ +// cipher provides the methods and structs to do encryptions for olm/megolm. +package cipher + +// Cipher defines a valid cipher. +type Cipher interface { + // Encrypt encrypts the plaintext. + Encrypt(key, plaintext []byte) (ciphertext []byte, err error) + + // Decrypt decrypts the ciphertext. + Decrypt(key, ciphertext []byte) (plaintext []byte, err error) + + //MAC returns the MAC of the message calculated with the key. + MAC(key, message []byte) ([]byte, error) + + //Verify checks the MAC of the message calculated with the key against the givenMAC. + Verify(key, message, givenMAC []byte) (bool, error) +} diff --git a/crypto/goolm/cipher/pickle.go b/crypto/goolm/cipher/pickle.go new file mode 100644 index 00000000..9faa0285 --- /dev/null +++ b/crypto/goolm/cipher/pickle.go @@ -0,0 +1,57 @@ +package cipher + +import ( + "codeberg.org/DerLukas/goolm" + "github.com/pkg/errors" +) + +const ( + kdfPickle = "Pickle" //used to derive the keys for encryption + pickleMACLength = 8 +) + +// PickleBlockSize returns the blocksize of the used cipher. +func PickleBlockSize() int { + return AESSha512BlockSize() +} + +// Pickle encrypts the input with the key and the cipher AESSha256. The result is then encoded in base64. +func Pickle(key, input []byte) ([]byte, error) { + pickleCipher := NewAESSha256([]byte(kdfPickle)) + ciphertext, err := pickleCipher.Encrypt(key, input) + if err != nil { + return nil, err + } + mac, err := pickleCipher.MAC(key, ciphertext) + if err != nil { + return nil, err + } + ciphertext = append(ciphertext, mac[:pickleMACLength]...) + encoded := goolm.Base64Encode(ciphertext) + return encoded, nil +} + +// Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSha256. +func Unpickle(key, input []byte) ([]byte, error) { + pickleCipher := NewAESSha256([]byte(kdfPickle)) + ciphertext, err := goolm.Base64Decode(input) + if err != nil { + return nil, err + } + //remove mac and check + verified, err := pickleCipher.Verify(key, ciphertext[:len(ciphertext)-pickleMACLength], ciphertext[len(ciphertext)-pickleMACLength:]) + if err != nil { + return nil, err + } + if !verified { + return nil, errors.Wrap(goolm.ErrBadMAC, "decrypt pickle") + } + //Set to next block size + targetCipherText := make([]byte, int(len(ciphertext)/PickleBlockSize())*PickleBlockSize()) + copy(targetCipherText, ciphertext) + plaintext, err := pickleCipher.Decrypt(key, targetCipherText) + if err != nil { + return nil, err + } + return plaintext, nil +} diff --git a/crypto/goolm/cipher/pickle_test.go b/crypto/goolm/cipher/pickle_test.go new file mode 100644 index 00000000..0e06fece --- /dev/null +++ b/crypto/goolm/cipher/pickle_test.go @@ -0,0 +1,31 @@ +package cipher + +import ( + "bytes" + "crypto/aes" + "testing" +) + +func TestEncoding(t *testing.T) { + key := []byte("test key") + input := []byte("test") + //pad marshaled to get block size + toEncrypt := input + if len(input)%aes.BlockSize != 0 { + padding := aes.BlockSize - len(input)%aes.BlockSize + toEncrypt = make([]byte, len(input)+padding) + copy(toEncrypt, input) + } + encoded, err := Pickle(key, toEncrypt) + if err != nil { + t.Fatal(err) + } + + decoded, err := Unpickle(key, encoded) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decoded, toEncrypt) { + t.Fatalf("Expected '%s' but got '%s'", toEncrypt, decoded) + } +} diff --git a/crypto/goolm/crypto/aesCBC.go b/crypto/goolm/crypto/aesCBC.go new file mode 100644 index 00000000..87a12a14 --- /dev/null +++ b/crypto/goolm/crypto/aesCBC.go @@ -0,0 +1,75 @@ +package crypto + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + + "codeberg.org/DerLukas/goolm" + "github.com/pkg/errors" +) + +// AESCBCBlocksize returns the blocksize of the encryption method +func AESCBCBlocksize() int { + return aes.BlockSize +} + +// AESCBCEncrypt encrypts the plaintext with the key and iv. len(iv) must be equal to the blocksize! +func AESCBCEncrypt(key, iv, plaintext []byte) ([]byte, error) { + if len(key) == 0 { + return nil, errors.Wrap(goolm.ErrNoKeyProvided, "AESCBCEncrypt") + } + if len(iv) != AESCBCBlocksize() { + return nil, errors.Wrap(goolm.ErrNotBlocksize, "iv") + } + var cipherText []byte + plaintext = pkcs5Padding(plaintext, AESCBCBlocksize()) + if len(plaintext)%AESCBCBlocksize() != 0 { + return nil, errors.Wrap(goolm.ErrNotMultipleBlocksize, "message") + } + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + cipherText = make([]byte, len(plaintext)) + cbc := cipher.NewCBCEncrypter(block, iv) + cbc.CryptBlocks(cipherText, plaintext) + return cipherText, nil +} + +// AESCBCDecrypt decrypts the ciphertext with the key and iv. len(iv) must be equal to the blocksize! +func AESCBCDecrypt(key, iv, ciphertext []byte) ([]byte, error) { + if len(key) == 0 { + return nil, errors.Wrap(goolm.ErrNoKeyProvided, "AESCBCEncrypt") + } + if len(iv) != AESCBCBlocksize() { + return nil, errors.Wrap(goolm.ErrNotBlocksize, "iv") + } + var block cipher.Block + var err error + block, err = aes.NewCipher(key) + if err != nil { + return nil, err + } + if len(ciphertext) < AESCBCBlocksize() { + return nil, errors.Wrap(goolm.ErrNotMultipleBlocksize, "ciphertext") + } + + cbc := cipher.NewCBCDecrypter(block, iv) + cbc.CryptBlocks(ciphertext, ciphertext) + return pkcs5Unpadding(ciphertext), nil +} + +// pkcs5Padding paddes the plaintext to be used in the AESCBC encryption. +func pkcs5Padding(plaintext []byte, blockSize int) []byte { + padding := (blockSize - len(plaintext)%blockSize) + padtext := bytes.Repeat([]byte{byte(padding)}, padding) + return append(plaintext, padtext...) +} + +// pkcs5Unpadding undoes the padding to the plaintext after AESCBC decryption. +func pkcs5Unpadding(plaintext []byte) []byte { + length := len(plaintext) + unpadding := int(plaintext[length-1]) + return plaintext[:(length - unpadding)] +} diff --git a/crypto/goolm/crypto/aesCBC_test.go b/crypto/goolm/crypto/aesCBC_test.go new file mode 100644 index 00000000..87819b40 --- /dev/null +++ b/crypto/goolm/crypto/aesCBC_test.go @@ -0,0 +1,70 @@ +package crypto + +import ( + "bytes" + "crypto/aes" + "crypto/rand" + "testing" +) + +func TestAESCBC(t *testing.T) { + var ciphertext, plaintext []byte + var err error + + // The key length can be 32, 24, 16 bytes (OR in bits: 128, 192 or 256) + key := make([]byte, 32) + _, err = rand.Read(key) + if err != nil { + t.Fatal(err) + } + iv := make([]byte, aes.BlockSize) + _, err = rand.Read(iv) + if err != nil { + t.Fatal(err) + } + plaintext = []byte("secret message for testing") + //increase to next block size + for len(plaintext)%8 != 0 { + plaintext = append(plaintext, []byte("-")...) + } + + if ciphertext, err = AESCBCEncrypt(key, iv, plaintext); err != nil { + t.Fatal(err) + } + + resultPlainText, err := AESCBCDecrypt(key, iv, ciphertext) + if err != nil { + t.Fatal(err) + } + + if string(resultPlainText) != string(plaintext) { + t.Fatalf("message '%s' (length %d) != '%s'", resultPlainText, len(resultPlainText), plaintext) + } +} + +func TestAESCBCCase1(t *testing.T) { + expected := []byte{ + 0xDC, 0x95, 0xC0, 0x78, 0xA2, 0x40, 0x89, 0x89, + 0xAD, 0x48, 0xA2, 0x14, 0x92, 0x84, 0x20, 0x87, + 0xF3, 0xC0, 0x03, 0xDD, 0xC4, 0xA7, 0xB8, 0xA9, + 0x4B, 0xAE, 0xDF, 0xFC, 0x3D, 0x21, 0x4C, 0x38, + } + input := make([]byte, 16) + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + encrypted, err := AESCBCEncrypt(key, iv, input) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(expected, encrypted) { + t.Fatalf("encrypted did not match expected:\n%v\n%v\n", encrypted, expected) + } + + decrypted, err := AESCBCDecrypt(key, iv, encrypted) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(input, decrypted) { + t.Fatalf("decrypted did not match expected:\n%v\n%v\n", decrypted, input) + } +} diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go new file mode 100644 index 00000000..5b39ceca --- /dev/null +++ b/crypto/goolm/crypto/curve25519.go @@ -0,0 +1,184 @@ +package crypto + +import ( + "bytes" + "crypto/rand" + "io" + + "codeberg.org/DerLukas/goolm" + libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" + "github.com/pkg/errors" + "golang.org/x/crypto/curve25519" + "maunium.net/go/mautrix/id" +) + +const ( + Curve25519KeyLength = curve25519.ScalarSize //The length of the private key. + curve25519PubKeyLength = 32 +) + +// Curve25519GenerateKey creates a new curve25519 key pair. If reader is nil, the random data is taken from crypto/rand. +func Curve25519GenerateKey(reader io.Reader) (Curve25519KeyPair, error) { + privateKeyByte := make([]byte, Curve25519KeyLength) + if reader == nil { + _, err := rand.Read(privateKeyByte) + if err != nil { + return Curve25519KeyPair{}, err + } + } else { + _, err := reader.Read(privateKeyByte) + if err != nil { + return Curve25519KeyPair{}, err + } + } + + privateKey := Curve25519PrivateKey(privateKeyByte) + + publicKey, err := privateKey.PubKey() + if err != nil { + return Curve25519KeyPair{}, err + } + return Curve25519KeyPair{ + PrivateKey: Curve25519PrivateKey(privateKey), + PublicKey: Curve25519PublicKey(publicKey), + }, nil +} + +// Curve25519GenerateFromPrivate creates a new curve25519 key pair with the private key given. +func Curve25519GenerateFromPrivate(private Curve25519PrivateKey) (Curve25519KeyPair, error) { + publicKey, err := private.PubKey() + if err != nil { + return Curve25519KeyPair{}, err + } + return Curve25519KeyPair{ + PrivateKey: private, + PublicKey: Curve25519PublicKey(publicKey), + }, nil +} + +// Curve25519KeyPair stores both parts of a curve25519 key. +type Curve25519KeyPair struct { + PrivateKey Curve25519PrivateKey `json:"private,omitempty"` + PublicKey Curve25519PublicKey `json:"public,omitempty"` +} + +// B64Encoded returns a base64 encoded string of the public key. +func (c Curve25519KeyPair) B64Encoded() id.Curve25519 { + return c.PublicKey.B64Encoded() +} + +// SharedSecret returns the shared secret between the key pair and the given public key. +func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) { + return c.PrivateKey.SharedSecret(pubKey) +} + +// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (c Curve25519KeyPair) PickleLibOlm(target []byte) (int, error) { + if len(target) < c.PickleLen() { + return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle curve25519 key pair") + } + written, err := c.PublicKey.PickleLibOlm(target) + if err != nil { + return 0, errors.Wrap(err, "pickle curve25519 key pair") + } + if len(c.PrivateKey) != Curve25519KeyLength { + written += libolmpickle.PickleBytes(make([]byte, Curve25519KeyLength), target[written:]) + } else { + written += libolmpickle.PickleBytes(c.PrivateKey, target[written:]) + } + return written, nil +} + +// UnpickleLibOlm decodes the unencryted value and populates the key pair accordingly. It returns the number of bytes read. +func (c *Curve25519KeyPair) UnpickleLibOlm(value []byte) (int, error) { + //unpickle PubKey + read, err := c.PublicKey.UnpickleLibOlm(value) + if err != nil { + return 0, err + } + //unpickle PrivateKey + privKey, readPriv, err := libolmpickle.UnpickleBytes(value[read:], Curve25519KeyLength) + if err != nil { + return read, err + } + c.PrivateKey = privKey + return read + readPriv, nil +} + +// PickleLen returns the number of bytes the pickled key pair will have. +func (c Curve25519KeyPair) PickleLen() int { + lenPublic := c.PublicKey.PickleLen() + var lenPrivate int + if len(c.PrivateKey) != Curve25519KeyLength { + lenPrivate = libolmpickle.PickleBytesLen(make([]byte, Curve25519KeyLength)) + } else { + lenPrivate = libolmpickle.PickleBytesLen(c.PrivateKey) + } + return lenPublic + lenPrivate +} + +// Curve25519PrivateKey represents the private key for curve25519 usage +type Curve25519PrivateKey []byte + +// Equal compares the private key to the given private key. +func (c Curve25519PrivateKey) Equal(x Curve25519PrivateKey) bool { + return bytes.Equal(c, x) +} + +// PubKey returns the public key derived from the private key. +func (c Curve25519PrivateKey) PubKey() (Curve25519PublicKey, error) { + publicKey, err := curve25519.X25519(c, curve25519.Basepoint) + if err != nil { + return nil, err + } + return publicKey, nil +} + +// SharedSecret returns the shared secret between the private key and the given public key. +func (c Curve25519PrivateKey) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) { + return curve25519.X25519(c, pubKey) +} + +// Curve25519PublicKey represents the public key for curve25519 usage +type Curve25519PublicKey []byte + +// Equal compares the public key to the given public key. +func (c Curve25519PublicKey) Equal(x Curve25519PublicKey) bool { + return bytes.Equal(c, x) +} + +// B64Encoded returns a base64 encoded string of the public key. +func (c Curve25519PublicKey) B64Encoded() id.Curve25519 { + return id.Curve25519(goolm.Base64Encode(c)) +} + +// PickleLibOlm encodes the public key into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (c Curve25519PublicKey) PickleLibOlm(target []byte) (int, error) { + if len(target) < c.PickleLen() { + return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle curve25519 public key") + } + if len(c) != curve25519PubKeyLength { + return libolmpickle.PickleBytes(make([]byte, curve25519PubKeyLength), target), nil + } + return libolmpickle.PickleBytes(c, target), nil +} + +// UnpickleLibOlm decodes the unencryted value and populates the public key accordingly. It returns the number of bytes read. +func (c *Curve25519PublicKey) UnpickleLibOlm(value []byte) (int, error) { + unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, curve25519PubKeyLength) + if err != nil { + return 0, err + } + *c = unpickled + return readBytes, nil +} + +// PickleLen returns the number of bytes the pickled public key will have. +func (c Curve25519PublicKey) PickleLen() int { + if len(c) != curve25519PubKeyLength { + return libolmpickle.PickleBytesLen(make([]byte, curve25519PubKeyLength)) + } + return libolmpickle.PickleBytesLen(c) +} diff --git a/crypto/goolm/crypto/curve25519_test.go b/crypto/goolm/crypto/curve25519_test.go new file mode 100644 index 00000000..e8d4090b --- /dev/null +++ b/crypto/goolm/crypto/curve25519_test.go @@ -0,0 +1,185 @@ +package crypto + +import ( + "bytes" + "testing" +) + +func TestCurve25519(t *testing.T) { + firstKeypair, err := Curve25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + secondKeypair, err := Curve25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + sharedSecretFromFirst, err := firstKeypair.SharedSecret(secondKeypair.PublicKey) + if err != nil { + t.Fatal(err) + } + sharedSecretFromSecond, err := secondKeypair.SharedSecret(firstKeypair.PublicKey) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(sharedSecretFromFirst, sharedSecretFromSecond) { + t.Fatal("shared secret not equal") + } + fromPrivate, err := Curve25519GenerateFromPrivate(firstKeypair.PrivateKey) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(fromPrivate.PublicKey, firstKeypair.PublicKey) { + t.Fatal("public keys not equal") + } +} + +func TestCurve25519Case1(t *testing.T) { + alicePrivate := []byte{ + 0x77, 0x07, 0x6D, 0x0A, 0x73, 0x18, 0xA5, 0x7D, + 0x3C, 0x16, 0xC1, 0x72, 0x51, 0xB2, 0x66, 0x45, + 0xDF, 0x4C, 0x2F, 0x87, 0xEB, 0xC0, 0x99, 0x2A, + 0xB1, 0x77, 0xFB, 0xA5, 0x1D, 0xB9, 0x2C, 0x2A, + } + alicePublic := []byte{ + 0x85, 0x20, 0xF0, 0x09, 0x89, 0x30, 0xA7, 0x54, + 0x74, 0x8B, 0x7D, 0xDC, 0xB4, 0x3E, 0xF7, 0x5A, + 0x0D, 0xBF, 0x3A, 0x0D, 0x26, 0x38, 0x1A, 0xF4, + 0xEB, 0xA4, 0xA9, 0x8E, 0xAA, 0x9B, 0x4E, 0x6A, + } + bobPrivate := []byte{ + 0x5D, 0xAB, 0x08, 0x7E, 0x62, 0x4A, 0x8A, 0x4B, + 0x79, 0xE1, 0x7F, 0x8B, 0x83, 0x80, 0x0E, 0xE6, + 0x6F, 0x3B, 0xB1, 0x29, 0x26, 0x18, 0xB6, 0xFD, + 0x1C, 0x2F, 0x8B, 0x27, 0xFF, 0x88, 0xE0, 0xEB, + } + bobPublic := []byte{ + 0xDE, 0x9E, 0xDB, 0x7D, 0x7B, 0x7D, 0xC1, 0xB4, + 0xD3, 0x5B, 0x61, 0xC2, 0xEC, 0xE4, 0x35, 0x37, + 0x3F, 0x83, 0x43, 0xC8, 0x5B, 0x78, 0x67, 0x4D, + 0xAD, 0xFC, 0x7E, 0x14, 0x6F, 0x88, 0x2B, 0x4F, + } + expectedAgreement := []byte{ + 0x4A, 0x5D, 0x9D, 0x5B, 0xA4, 0xCE, 0x2D, 0xE1, + 0x72, 0x8E, 0x3B, 0xF4, 0x80, 0x35, 0x0F, 0x25, + 0xE0, 0x7E, 0x21, 0xC9, 0x47, 0xD1, 0x9E, 0x33, + 0x76, 0xF0, 0x9B, 0x3C, 0x1E, 0x16, 0x17, 0x42, + } + aliceKeyPair := Curve25519KeyPair{ + PrivateKey: alicePrivate, + PublicKey: alicePublic, + } + bobKeyPair := Curve25519KeyPair{ + PrivateKey: bobPrivate, + PublicKey: bobPublic, + } + agreementFromAlice, err := aliceKeyPair.SharedSecret(bobKeyPair.PublicKey) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(agreementFromAlice, expectedAgreement) { + t.Fatal("expected agreement does not match agreement from Alice's view") + } + agreementFromBob, err := bobKeyPair.SharedSecret(aliceKeyPair.PublicKey) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(agreementFromBob, expectedAgreement) { + t.Fatal("expected agreement does not match agreement from Bob's view") + } +} + +func TestCurve25519Pickle(t *testing.T) { + //create keypair + keyPair, err := Curve25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + target := make([]byte, keyPair.PickleLen()) + writtenBytes, err := keyPair.PickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if writtenBytes != len(target) { + t.Fatal("written bytes not correct") + } + + unpickledKeyPair := Curve25519KeyPair{} + readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if readBytes != len(target) { + t.Fatal("read bytes not correct") + } + if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { + t.Fatal("private keys not correct") + } + if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { + t.Fatal("public keys not correct") + } +} + +func TestCurve25519PicklePubKeyOnly(t *testing.T) { + //create keypair + keyPair, err := Curve25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + //Remove privateKey + keyPair.PrivateKey = nil + target := make([]byte, keyPair.PickleLen()) + writtenBytes, err := keyPair.PickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if writtenBytes != len(target) { + t.Fatal("written bytes not correct") + } + unpickledKeyPair := Curve25519KeyPair{} + readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if readBytes != len(target) { + t.Fatal("read bytes not correct") + } + if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { + t.Fatal("private keys not correct") + } + if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { + t.Fatal("public keys not correct") + } +} + +func TestCurve25519PicklePrivKeyOnly(t *testing.T) { + //create keypair + keyPair, err := Curve25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + //Remove public + keyPair.PublicKey = nil + target := make([]byte, keyPair.PickleLen()) + writtenBytes, err := keyPair.PickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if writtenBytes != len(target) { + t.Fatal("written bytes not correct") + } + unpickledKeyPair := Curve25519KeyPair{} + readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if readBytes != len(target) { + t.Fatal("read bytes not correct") + } + if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { + t.Fatal("private keys not correct") + } + if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { + t.Fatal("public keys not correct") + } +} diff --git a/crypto/goolm/crypto/ed25519.go b/crypto/goolm/crypto/ed25519.go new file mode 100644 index 00000000..483e0041 --- /dev/null +++ b/crypto/goolm/crypto/ed25519.go @@ -0,0 +1,180 @@ +package crypto + +import ( + "bytes" + "crypto/ed25519" + "io" + + "codeberg.org/DerLukas/goolm" + libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" + "github.com/pkg/errors" + "maunium.net/go/mautrix/id" +) + +const ( + ED25519SignatureSize = ed25519.SignatureSize //The length of a signature +) + +// Ed25519GenerateKey creates a new ed25519 key pair. If reader is nil, the random data is taken from crypto/rand. +func Ed25519GenerateKey(reader io.Reader) (Ed25519KeyPair, error) { + publicKey, privateKey, err := ed25519.GenerateKey(reader) + if err != nil { + return Ed25519KeyPair{}, err + } + return Ed25519KeyPair{ + PrivateKey: Ed25519PrivateKey(privateKey), + PublicKey: Ed25519PublicKey(publicKey), + }, nil +} + +// Ed25519GenerateFromPrivate creates a new ed25519 key pair with the private key given. +func Ed25519GenerateFromPrivate(privKey Ed25519PrivateKey) Ed25519KeyPair { + return Ed25519KeyPair{ + PrivateKey: privKey, + PublicKey: privKey.PubKey(), + } +} + +// Ed25519GenerateFromSeed creates a new ed25519 key pair with a given seed. +func Ed25519GenerateFromSeed(seed []byte) Ed25519KeyPair { + privKey := Ed25519PrivateKey(ed25519.NewKeyFromSeed(seed)) + return Ed25519KeyPair{ + PrivateKey: privKey, + PublicKey: privKey.PubKey(), + } +} + +// Ed25519KeyPair stores both parts of a ed25519 key. +type Ed25519KeyPair struct { + PrivateKey Ed25519PrivateKey `json:"private,omitempty"` + PublicKey Ed25519PublicKey `json:"public,omitempty"` +} + +// B64Encoded returns a base64 encoded string of the public key. +func (c Ed25519KeyPair) B64Encoded() id.Ed25519 { + return id.Ed25519(goolm.Base64Encode(c.PublicKey)) +} + +// Sign returns the signature for the message. +func (c Ed25519KeyPair) Sign(message []byte) []byte { + return c.PrivateKey.Sign(message) +} + +// Verify checks the signature of the message against the givenSignature +func (c Ed25519KeyPair) Verify(message, givenSignature []byte) bool { + return c.PublicKey.Verify(message, givenSignature) +} + +// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (c Ed25519KeyPair) PickleLibOlm(target []byte) (int, error) { + if len(target) < c.PickleLen() { + return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle ed25519 key pair") + } + written, err := c.PublicKey.PickleLibOlm(target) + if err != nil { + return 0, errors.Wrap(err, "pickle ed25519 key pair") + } + + if len(c.PrivateKey) != ed25519.PrivateKeySize { + written += libolmpickle.PickleBytes(make([]byte, ed25519.PrivateKeySize), target[written:]) + } else { + written += libolmpickle.PickleBytes(c.PrivateKey, target[written:]) + } + return written, nil +} + +// UnpickleLibOlm decodes the unencryted value and populates the key pair accordingly. It returns the number of bytes read. +func (c *Ed25519KeyPair) UnpickleLibOlm(value []byte) (int, error) { + //unpickle PubKey + read, err := c.PublicKey.UnpickleLibOlm(value) + if err != nil { + return 0, err + } + //unpickle PrivateKey + privKey, readPriv, err := libolmpickle.UnpickleBytes(value[read:], ed25519.PrivateKeySize) + if err != nil { + return read, err + } + c.PrivateKey = privKey + return read + readPriv, nil +} + +// PickleLen returns the number of bytes the pickled key pair will have. +func (c Ed25519KeyPair) PickleLen() int { + lenPublic := c.PublicKey.PickleLen() + var lenPrivate int + if len(c.PrivateKey) != ed25519.PrivateKeySize { + lenPrivate = libolmpickle.PickleBytesLen(make([]byte, ed25519.PrivateKeySize)) + } else { + lenPrivate = libolmpickle.PickleBytesLen(c.PrivateKey) + } + return lenPublic + lenPrivate +} + +// Curve25519PrivateKey represents the private key for ed25519 usage. This is just a wrapper. +type Ed25519PrivateKey ed25519.PrivateKey + +// Equal compares the private key to the given private key. +func (c Ed25519PrivateKey) Equal(x Ed25519PrivateKey) bool { + return bytes.Equal(c, x) +} + +// PubKey returns the public key derived from the private key. +func (c Ed25519PrivateKey) PubKey() Ed25519PublicKey { + publicKey := ed25519.PrivateKey(c).Public() + return Ed25519PublicKey(publicKey.(ed25519.PublicKey)) +} + +// Sign returns the signature for the message. +func (c Ed25519PrivateKey) Sign(message []byte) []byte { + return ed25519.Sign(ed25519.PrivateKey(c), message) +} + +// Ed25519PublicKey represents the public key for ed25519 usage. This is just a wrapper. +type Ed25519PublicKey ed25519.PublicKey + +// Equal compares the public key to the given public key. +func (c Ed25519PublicKey) Equal(x Ed25519PublicKey) bool { + return bytes.Equal(c, x) +} + +// B64Encoded returns a base64 encoded string of the public key. +func (c Ed25519PublicKey) B64Encoded() id.Curve25519 { + return id.Curve25519(goolm.Base64Encode(c)) +} + +// Verify checks the signature of the message against the givenSignature +func (c Ed25519PublicKey) Verify(message, givenSignature []byte) bool { + return ed25519.Verify(ed25519.PublicKey(c), message, givenSignature) +} + +// PickleLibOlm encodes the public key into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (c Ed25519PublicKey) PickleLibOlm(target []byte) (int, error) { + if len(target) < c.PickleLen() { + return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle ed25519 public key") + } + if len(c) != ed25519.PublicKeySize { + return libolmpickle.PickleBytes(make([]byte, ed25519.PublicKeySize), target), nil + } + return libolmpickle.PickleBytes(c, target), nil +} + +// UnpickleLibOlm decodes the unencryted value and populates the public key accordingly. It returns the number of bytes read. +func (c *Ed25519PublicKey) UnpickleLibOlm(value []byte) (int, error) { + unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, ed25519.PublicKeySize) + if err != nil { + return 0, err + } + *c = unpickled + return readBytes, nil +} + +// PickleLen returns the number of bytes the pickled public key will have. +func (c Ed25519PublicKey) PickleLen() int { + if len(c) != ed25519.PublicKeySize { + return libolmpickle.PickleBytesLen(make([]byte, ed25519.PublicKeySize)) + } + return libolmpickle.PickleBytesLen(c) +} diff --git a/crypto/goolm/crypto/ed25519_test.go b/crypto/goolm/crypto/ed25519_test.go new file mode 100644 index 00000000..dbff8454 --- /dev/null +++ b/crypto/goolm/crypto/ed25519_test.go @@ -0,0 +1,138 @@ +package crypto + +import ( + "bytes" + "testing" +) + +func TestEd25519(t *testing.T) { + keypair, err := Ed25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + message := []byte("test message") + signature := keypair.Sign(message) + if !keypair.Verify(message, signature) { + t.Fail() + } +} + +func TestEd25519Case1(t *testing.T) { + //64 bytes for ed25519 package + keyPair, err := Ed25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + message := []byte("Hello, World") + + keyPair2 := Ed25519GenerateFromPrivate(keyPair.PrivateKey) + if !bytes.Equal(keyPair.PublicKey, keyPair2.PublicKey) { + t.Fatal("not equal key pairs") + } + signature := keyPair.Sign(message) + verified := keyPair.Verify(message, signature) + if !verified { + t.Fatal("message did not verify although it should") + } + //Now change the message and verify again + message = append(message, []byte("a")...) + verified = keyPair.Verify(message, signature) + if verified { + t.Fatal("message did verify although it should not") + } +} + +func TestEd25519Pickle(t *testing.T) { + //create keypair + keyPair, err := Ed25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + target := make([]byte, keyPair.PickleLen()) + writtenBytes, err := keyPair.PickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if writtenBytes != len(target) { + t.Fatal("written bytes not correct") + } + + unpickledKeyPair := Ed25519KeyPair{} + readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if readBytes != len(target) { + t.Fatal("read bytes not correct") + } + if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { + t.Fatal("private keys not correct") + } + if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { + t.Fatal("public keys not correct") + } +} + +func TestEd25519PicklePubKeyOnly(t *testing.T) { + //create keypair + keyPair, err := Ed25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + //Remove privateKey + keyPair.PrivateKey = nil + target := make([]byte, keyPair.PickleLen()) + writtenBytes, err := keyPair.PickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if writtenBytes != len(target) { + t.Fatal("written bytes not correct") + } + unpickledKeyPair := Ed25519KeyPair{} + readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if readBytes != len(target) { + t.Fatal("read bytes not correct") + } + if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { + t.Fatal("private keys not correct") + } + if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { + t.Fatal("public keys not correct") + } +} + +func TestEd25519PicklePrivKeyOnly(t *testing.T) { + //create keypair + keyPair, err := Ed25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + //Remove public + keyPair.PublicKey = nil + target := make([]byte, keyPair.PickleLen()) + writtenBytes, err := keyPair.PickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if writtenBytes != len(target) { + t.Fatal("written bytes not correct") + } + unpickledKeyPair := Ed25519KeyPair{} + readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + if err != nil { + t.Fatal(err) + } + if readBytes != len(target) { + t.Fatal("read bytes not correct") + } + if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { + t.Fatal("private keys not correct") + } + if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { + t.Fatal("public keys not correct") + } +} diff --git a/crypto/goolm/crypto/hmac.go b/crypto/goolm/crypto/hmac.go new file mode 100644 index 00000000..8542f7cb --- /dev/null +++ b/crypto/goolm/crypto/hmac.go @@ -0,0 +1,29 @@ +package crypto + +import ( + "crypto/hmac" + "crypto/sha256" + "io" + + "golang.org/x/crypto/hkdf" +) + +// HMACSHA256 returns the hash message authentication code with SHA-256 of the input with the key. +func HMACSHA256(key, input []byte) []byte { + hash := hmac.New(sha256.New, key) + hash.Write(input) + return hash.Sum(nil) +} + +// SHA256 return the SHA-256 of the value. +func SHA256(value []byte) []byte { + hash := sha256.New() + hash.Write(value) + return hash.Sum(nil) +} + +// HKDFSHA256 is the key deivation function based on HMAC and returns a reader based on input. salt and info can both be nil. +// The reader can be used to read an arbitary length of bytes which are based on all parameters. +func HKDFSHA256(input, salt, info []byte) io.Reader { + return hkdf.New(sha256.New, input, salt, info) +} diff --git a/crypto/goolm/crypto/hmac_test.go b/crypto/goolm/crypto/hmac_test.go new file mode 100644 index 00000000..2c7f1c71 --- /dev/null +++ b/crypto/goolm/crypto/hmac_test.go @@ -0,0 +1,113 @@ +package crypto + +import ( + "bytes" + "io" + "testing" + + "codeberg.org/DerLukas/goolm" +) + +func TestHMACSha256(t *testing.T) { + key := []byte("test key") + message := []byte("test message") + hash := HMACSHA256(key, message) + if !bytes.Equal(hash, HMACSHA256(key, message)) { + t.Fail() + } + str := "A4M0ovdiWHaZ5msdDFbrvtChFwZIoIaRSVGmv8bmPtc=" + result, err := goolm.Base64Decode([]byte(str)) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, hash) { + t.Fail() + } +} + +func TestHKDFSha256(t *testing.T) { + message := []byte("test content") + hkdf := HKDFSHA256(message, nil, nil) + hkdf2 := HKDFSHA256(message, nil, nil) + result := make([]byte, 32) + if _, err := io.ReadFull(hkdf, result); err != nil { + t.Fatal(err) + } + result2 := make([]byte, 32) + if _, err := io.ReadFull(hkdf2, result2); err != nil { + t.Fatal(err) + } + if !bytes.Equal(result, result2) { + t.Fail() + } +} + +func TestSha256Case1(t *testing.T) { + input := make([]byte, 0) + expected := []byte{ + 0xE3, 0xB0, 0xC4, 0x42, 0x98, 0xFC, 0x1C, 0x14, + 0x9A, 0xFB, 0xF4, 0xC8, 0x99, 0x6F, 0xB9, 0x24, + 0x27, 0xAE, 0x41, 0xE4, 0x64, 0x9B, 0x93, 0x4C, + 0xA4, 0x95, 0x99, 0x1B, 0x78, 0x52, 0xB8, 0x55, + } + result := SHA256(input) + if !bytes.Equal(expected, result) { + t.Fatalf("result not as expected:\n%v\n%v\n", result, expected) + } +} + +func TestHMACCase1(t *testing.T) { + input := make([]byte, 0) + expected := []byte{ + 0xb6, 0x13, 0x67, 0x9a, 0x08, 0x14, 0xd9, 0xec, + 0x77, 0x2f, 0x95, 0xd7, 0x78, 0xc3, 0x5f, 0xc5, + 0xff, 0x16, 0x97, 0xc4, 0x93, 0x71, 0x56, 0x53, + 0xc6, 0xc7, 0x12, 0x14, 0x42, 0x92, 0xc5, 0xad, + } + result := HMACSHA256(input, input) + if !bytes.Equal(expected, result) { + t.Fatalf("result not as expected:\n%v\n%v\n", result, expected) + } +} + +func TestHDKFCase1(t *testing.T) { + input := []byte{ + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + } + salt := []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, + } + info := []byte{ + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, + 0xf8, 0xf9, + } + expectedHMAC := []byte{ + 0x07, 0x77, 0x09, 0x36, 0x2c, 0x2e, 0x32, 0xdf, + 0x0d, 0xdc, 0x3f, 0x0d, 0xc4, 0x7b, 0xba, 0x63, + 0x90, 0xb6, 0xc7, 0x3b, 0xb5, 0x0f, 0x9c, 0x31, + 0x22, 0xec, 0x84, 0x4a, 0xd7, 0xc2, 0xb3, 0xe5, + } + result := HMACSHA256(salt, input) + if !bytes.Equal(expectedHMAC, result) { + t.Fatalf("result not as expected:\n%v\n%v\n", result, expectedHMAC) + } + expectedHDKF := []byte{ + 0x3c, 0xb2, 0x5f, 0x25, 0xfa, 0xac, 0xd5, 0x7a, + 0x90, 0x43, 0x4f, 0x64, 0xd0, 0x36, 0x2f, 0x2a, + 0x2d, 0x2d, 0x0a, 0x90, 0xcf, 0x1a, 0x5a, 0x4c, + 0x5d, 0xb0, 0x2d, 0x56, 0xec, 0xc4, 0xc5, 0xbf, + 0x34, 0x00, 0x72, 0x08, 0xd5, 0xb8, 0x87, 0x18, + 0x58, 0x65, + } + resultReader := HKDFSHA256(input, salt, info) + result = make([]byte, len(expectedHDKF)) + if _, err := io.ReadFull(resultReader, result); err != nil { + t.Fatal(err) + } + if !bytes.Equal(expectedHDKF, result) { + t.Fatalf("result not as expected:\n%v\n%v\n", result, expectedHDKF) + } +} diff --git a/crypto/goolm/crypto/main.go b/crypto/goolm/crypto/main.go new file mode 100644 index 00000000..509d44a5 --- /dev/null +++ b/crypto/goolm/crypto/main.go @@ -0,0 +1,2 @@ +// crpyto provides the nessesary encryption methods for olm/megolm +package crypto diff --git a/crypto/goolm/crypto/oneTimeKey.go b/crypto/goolm/crypto/oneTimeKey.go new file mode 100644 index 00000000..53bc1a3b --- /dev/null +++ b/crypto/goolm/crypto/oneTimeKey.go @@ -0,0 +1,95 @@ +package crypto + +import ( + "encoding/binary" + + "codeberg.org/DerLukas/goolm" + libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" + "github.com/pkg/errors" + "maunium.net/go/mautrix/id" +) + +// OneTimeKey stores the information about a one time key. +type OneTimeKey struct { + ID uint32 `json:"id"` + Published bool `json:"published"` + Key Curve25519KeyPair `json:"key,omitempty"` +} + +// Equal compares the one time key to the given one. +func (otk OneTimeKey) Equal(s OneTimeKey) bool { + if otk.ID != s.ID { + return false + } + if otk.Published != s.Published { + return false + } + if !otk.Key.PrivateKey.Equal(s.Key.PrivateKey) { + return false + } + if !otk.Key.PublicKey.Equal(s.Key.PublicKey) { + return false + } + return true +} + +// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (c OneTimeKey) PickleLibOlm(target []byte) (int, error) { + if len(target) < c.PickleLen() { + return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle one time key") + } + written := libolmpickle.PickleUInt32(uint32(c.ID), target) + written += libolmpickle.PickleBool(c.Published, target[written:]) + writtenKey, err := c.Key.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle one time key") + } + written += writtenKey + return written, nil +} + +// UnpickleLibOlm decodes the unencryted value and populates the OneTimeKey accordingly. It returns the number of bytes read. +func (c *OneTimeKey) UnpickleLibOlm(value []byte) (int, error) { + totalReadBytes := 0 + id, readBytes, err := libolmpickle.UnpickleUInt32(value) + if err != nil { + return 0, err + } + totalReadBytes += readBytes + c.ID = id + published, readBytes, err := libolmpickle.UnpickleBool(value[totalReadBytes:]) + if err != nil { + return 0, err + } + totalReadBytes += readBytes + c.Published = published + readBytes, err = c.Key.UnpickleLibOlm(value[totalReadBytes:]) + if err != nil { + return 0, err + } + totalReadBytes += readBytes + return totalReadBytes, nil +} + +// PickleLen returns the number of bytes the pickled OneTimeKey will have. +func (c OneTimeKey) PickleLen() int { + length := 0 + length += libolmpickle.PickleUInt32Len(c.ID) + length += libolmpickle.PickleBoolLen(c.Published) + length += c.Key.PickleLen() + return length +} + +// KeyIDEncoded returns the base64 encoded id. +func (c OneTimeKey) KeyIDEncoded() string { + resSlice := make([]byte, 4) + binary.BigEndian.PutUint32(resSlice, c.ID) + encoded := goolm.Base64Encode(resSlice) + return string(encoded) +} + +// PublicKeyEncoded returns the base64 encoded public key +func (c OneTimeKey) PublicKeyEncoded() id.Curve25519 { + return c.Key.PublicKey.B64Encoded() +} diff --git a/crypto/goolm/errors.go b/crypto/goolm/errors.go new file mode 100644 index 00000000..46db78f9 --- /dev/null +++ b/crypto/goolm/errors.go @@ -0,0 +1,36 @@ +package goolm + +import ( + "github.com/pkg/errors" +) + +// Those are the most common used errors +var ( + ErrNoSigningKey = errors.New("no signing key") + ErrBadSignature = errors.New("bad signature") + ErrBadMAC = errors.New("bad mac") + ErrBadMessageFormat = errors.New("bad message format") + ErrBadVerification = errors.New("bad verification") + ErrWrongProtocolVersion = errors.New("wrong protocol version") + ErrNoSessionKey = errors.New("no session key") + ErrEmptyInput = errors.New("empty input") + ErrNoKeyProvided = errors.New("no key") + ErrBadMessageKeyID = errors.New("bad message key id") + ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key") + ErrMsgIndexTooHigh = errors.New("message index too high") + ErrProtocolViolation = errors.New("not protocol message order") + ErrMessageKeyNotFound = errors.New("message key not found") + ErrChainTooHigh = errors.New("chain index too high") + ErrBadInput = errors.New("bad input") + ErrBadVersion = errors.New("wrong version") + ErrNotBlocksize = errors.New("length != blocksize") + ErrNotMultipleBlocksize = errors.New("length not a multiple of the blocksize") + ErrBase64InvalidLength = errors.New("base64 decode invalid length") + ErrWrongPickleVersion = errors.New("Wrong pickle version") + ErrSignatureNotFound = errors.New("signature not found") + ErrNotEnoughGoRandom = errors.New("Not enough random data available") + ErrValueTooShort = errors.New("value too short") + ErrInputToSmall = errors.New("input too small (truncated?)") + ErrOverflow = errors.New("overflow") + ErrBadBase64 = errors.New("Bad base64") +) diff --git a/crypto/goolm/go.mod b/crypto/goolm/go.mod new file mode 100644 index 00000000..c4b6a461 --- /dev/null +++ b/crypto/goolm/go.mod @@ -0,0 +1,9 @@ +module codeberg.org/DerLukas/goolm + +go 1.19 + +require ( + github.com/pkg/errors v0.9.1 + golang.org/x/crypto v0.3.0 + maunium.net/go/mautrix v0.12.3 +) diff --git a/crypto/goolm/go.sum b/crypto/goolm/go.sum new file mode 100644 index 00000000..0b0d6700 --- /dev/null +++ b/crypto/goolm/go.sum @@ -0,0 +1,20 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.3.0 h1:a06MkbcxBrEFc0w0QIZWXrH/9cCX6KJyWbBOIwAn+7A= +golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= +golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +maunium.net/go/mautrix v0.2.0-beta.4 h1:L7Jpc+8GGc+Qo0DdamACEeU1Ci9G1mergJpsTTgDOUA= +maunium.net/go/mautrix v0.2.0-beta.4/go.mod h1:WeTUYKrM3/4LZK2bXQ9NRIXnRWKsa+6+OA1gw0nf5G8= +maunium.net/go/mautrix v0.12.3 h1:pUeO1ThhtZxE6XibGCzDhRuxwDIFNugsreVr1yYq96k= +maunium.net/go/mautrix v0.12.3/go.mod h1:uOUjkOjm2C+nQS3mr9B5ATjqemZfnPHvjdd1kZezAwg= diff --git a/crypto/goolm/libolmPickle/pickle.go b/crypto/goolm/libolmPickle/pickle.go new file mode 100644 index 00000000..ec125a34 --- /dev/null +++ b/crypto/goolm/libolmPickle/pickle.go @@ -0,0 +1,41 @@ +package libolmpickle + +import ( + "encoding/binary" +) + +func PickleUInt8(value uint8, target []byte) int { + target[0] = value + return 1 +} +func PickleUInt8Len(value uint8) int { + return 1 +} + +func PickleBool(value bool, target []byte) int { + if value { + target[0] = 0x01 + } else { + target[0] = 0x00 + } + return 1 +} +func PickleBoolLen(value bool) int { + return 1 +} + +func PickleBytes(value, target []byte) int { + return copy(target, value) +} +func PickleBytesLen(value []byte) int { + return len(value) +} + +func PickleUInt32(value uint32, target []byte) int { + res := make([]byte, 4) //4 bytes for int32 + binary.BigEndian.PutUint32(res, value) + return copy(target, res) +} +func PickleUInt32Len(value uint32) int { + return 4 +} diff --git a/crypto/goolm/libolmPickle/pickle_test.go b/crypto/goolm/libolmPickle/pickle_test.go new file mode 100644 index 00000000..ff6062c2 --- /dev/null +++ b/crypto/goolm/libolmPickle/pickle_test.go @@ -0,0 +1,96 @@ +package libolmpickle + +import ( + "bytes" + "testing" +) + +func TestPickleUInt32(t *testing.T) { + values := []uint32{ + 0xffffffff, + 0x00ff00ff, + 0xf0000000, + 0xf00f0000, + } + expected := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + {0xf0, 0x0f, 0x00, 0x00}, + } + for curIndex := range values { + response := make([]byte, 4) + resPLen := PickleUInt32(values[curIndex], response) + if resPLen != PickleUInt32Len(values[curIndex]) { + t.Fatal("written bytes not correct") + } + if !bytes.Equal(response, expected[curIndex]) { + t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) + } + } +} + +func TestPickleBool(t *testing.T) { + values := []bool{ + true, + false, + } + expected := [][]byte{ + {0x01}, + {0x00}, + } + for curIndex := range values { + response := make([]byte, 1) + resPLen := PickleBool(values[curIndex], response) + if resPLen != PickleBoolLen(values[curIndex]) { + t.Fatal("written bytes not correct") + } + if !bytes.Equal(response, expected[curIndex]) { + t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) + } + } +} + +func TestPickleUInt8(t *testing.T) { + values := []uint8{ + 0xff, + 0x1a, + } + expected := [][]byte{ + {0xff}, + {0x1a}, + } + for curIndex := range values { + response := make([]byte, 1) + resPLen := PickleUInt8(values[curIndex], response) + if resPLen != PickleUInt8Len(values[curIndex]) { + t.Fatal("written bytes not correct") + } + if !bytes.Equal(response, expected[curIndex]) { + t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) + } + } +} + +func TestPickleBytes(t *testing.T) { + values := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + } + expected := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + } + for curIndex := range values { + response := make([]byte, len(values[curIndex])) + resPLen := PickleBytes(values[curIndex], response) + if resPLen != PickleBytesLen(values[curIndex]) { + t.Fatal("written bytes not correct") + } + if !bytes.Equal(response, expected[curIndex]) { + t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) + } + } +} diff --git a/crypto/goolm/libolmPickle/unpickle.go b/crypto/goolm/libolmPickle/unpickle.go new file mode 100644 index 00000000..7140094f --- /dev/null +++ b/crypto/goolm/libolmPickle/unpickle.go @@ -0,0 +1,52 @@ +package libolmpickle + +import ( + "codeberg.org/DerLukas/goolm" + "github.com/pkg/errors" +) + +func isZeroByteSlice(bytes []byte) bool { + b := byte(0) + for _, s := range bytes { + b |= s + } + return b == 0 +} + +func UnpickleUInt8(value []byte) (uint8, int, error) { + if len(value) < 1 { + return 0, 0, errors.Wrap(goolm.ErrValueTooShort, "unpickle uint8") + } + return value[0], 1, nil +} + +func UnpickleBool(value []byte) (bool, int, error) { + if len(value) < 1 { + return false, 0, errors.Wrap(goolm.ErrValueTooShort, "unpickle bool") + } + return value[0] != uint8(0x00), 1, nil +} + +func UnpickleBytes(value []byte, length int) ([]byte, int, error) { + if len(value) < length { + return nil, 0, errors.Wrap(goolm.ErrValueTooShort, "unpickle bytes") + } + resp := value[:length] + if isZeroByteSlice(resp) { + return nil, length, nil + } + return resp, length, nil +} + +func UnpickleUInt32(value []byte) (uint32, int, error) { + if len(value) < 4 { + return 0, 0, errors.Wrap(goolm.ErrValueTooShort, "unpickle uint32") + } + var res uint32 + count := 0 + for i := 3; i >= 0; i-- { + res |= uint32(value[count]) << (8 * i) + count++ + } + return res, 4, nil +} diff --git a/crypto/goolm/libolmPickle/unpickle_test.go b/crypto/goolm/libolmPickle/unpickle_test.go new file mode 100644 index 00000000..505f3f64 --- /dev/null +++ b/crypto/goolm/libolmPickle/unpickle_test.go @@ -0,0 +1,104 @@ +package libolmpickle + +import ( + "bytes" + "testing" +) + +func TestUnpickleUInt32(t *testing.T) { + expected := []uint32{ + 0xffffffff, + 0x00ff00ff, + 0xf0000000, + } + values := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + } + for curIndex := range values { + response, readLength, err := UnpickleUInt32(values[curIndex]) + if err != nil { + t.Fatal(err) + } + if readLength != 4 { + t.Fatal("read bytes not correct") + } + if response != expected[curIndex] { + t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) + } + } +} + +func TestUnpickleBool(t *testing.T) { + expected := []bool{ + true, + false, + true, + } + values := [][]byte{ + {0x01}, + {0x00}, + {0x02}, + } + for curIndex := range values { + response, readLength, err := UnpickleBool(values[curIndex]) + if err != nil { + t.Fatal(err) + } + if readLength != 1 { + t.Fatal("read bytes not correct") + } + if response != expected[curIndex] { + t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) + } + } +} + +func TestUnpickleUInt8(t *testing.T) { + expected := []uint8{ + 0xff, + 0x1a, + } + values := [][]byte{ + {0xff}, + {0x1a}, + } + for curIndex := range values { + response, readLength, err := UnpickleUInt8(values[curIndex]) + if err != nil { + t.Fatal(err) + } + if readLength != 1 { + t.Fatal("read bytes not correct") + } + if response != expected[curIndex] { + t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) + } + } +} + +func TestUnpickleBytes(t *testing.T) { + values := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + } + expected := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + } + for curIndex := range values { + response, readLength, err := UnpickleBytes(values[curIndex], 4) + if err != nil { + t.Fatal(err) + } + if readLength != 4 { + t.Fatal("read bytes not correct") + } + if !bytes.Equal(response, expected[curIndex]) { + t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) + } + } +} diff --git a/crypto/goolm/libolmVersion.md b/crypto/goolm/libolmVersion.md new file mode 100644 index 00000000..d91f241e --- /dev/null +++ b/crypto/goolm/libolmVersion.md @@ -0,0 +1,3 @@ +### This package is based on libolm version 3.2.14 + +Changes to the libolm implementation should be reflected in this package and this file should be updated. diff --git a/crypto/goolm/main.go b/crypto/goolm/main.go new file mode 100644 index 00000000..5e785c7b --- /dev/null +++ b/crypto/goolm/main.go @@ -0,0 +1,10 @@ +// goolm is a pure Go implementation of libolm. Libolm is a cryptographic library used for end-to-end encryption in Matrix and wirtten in C++. +// With goolm there is no need to use cgo when building Matrix clients in go. +/* +This package contains the possible errors which can occur as well as some simple functions. All the 'action' happens in the subdirectories. +*/ +package goolm + +func GetLibaryVersion() (major, minor, patch uint8) { + return 3, 2, 14 +} diff --git a/crypto/goolm/megolm/megolm.go b/crypto/goolm/megolm/megolm.go new file mode 100644 index 00000000..652d8b17 --- /dev/null +++ b/crypto/goolm/megolm/megolm.go @@ -0,0 +1,234 @@ +// megolm provides the ratchet used by the megolm protocol +package megolm + +import ( + "crypto/rand" + + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/cipher" + "codeberg.org/DerLukas/goolm/crypto" + libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" + "codeberg.org/DerLukas/goolm/message" + "codeberg.org/DerLukas/goolm/utilities" + "github.com/pkg/errors" +) + +const ( + megolmPickleVersion uint8 = 1 +) + +const ( + protocolVersion = 3 + RatchetParts = 4 // number of ratchet parts + RatchetPartLength = 256 / 8 // length of each ratchet part in bytes +) + +var RatchetCipher = cipher.NewAESSha256([]byte("MEGOLM_KEYS")) + +// hasKeySeed are the seed for the different ratchet parts +var hashKeySeeds [RatchetParts][]byte = [RatchetParts][]byte{ + {0x00}, + {0x01}, + {0x02}, + {0x03}, +} + +// Ratchet represents the megolm ratchet as described in +// +// https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/megolm.md +type Ratchet struct { + Data [RatchetParts * RatchetPartLength]byte `json:"data"` + Counter uint32 `json:"counter"` +} + +// New creates a new ratchet with counter set to counter and the ratchet data set to data. +func New(counter uint32, data [RatchetParts * RatchetPartLength]byte) (*Ratchet, error) { + m := &Ratchet{ + Counter: counter, + Data: data, + } + return m, nil +} + +// NewWithRandom creates a new ratchet with counter set to counter an the data filled with random values. +func NewWithRandom(counter uint32) (*Ratchet, error) { + var data [RatchetParts * RatchetPartLength]byte + _, err := rand.Read(data[:]) + if err != nil { + return nil, err + } + return New(counter, data) +} + +// rehashPart rehases the part of the ratchet data with the base defined as from storing into the target to. +func (m *Ratchet) rehashPart(from, to int) { + newData := crypto.HMACSHA256(m.Data[from*RatchetPartLength:from*RatchetPartLength+RatchetPartLength], hashKeySeeds[to]) + copy(m.Data[to*RatchetPartLength:], newData[:RatchetPartLength]) +} + +// Advance advances the ratchet one step. +func (m *Ratchet) Advance() { + var mask uint32 = 0x00FFFFFF + var h int + m.Counter++ + + // figure out how much we need to rekey + for h < RatchetParts { + if (m.Counter & mask) == 0 { + break + } + h++ + mask >>= 8 + } + + // now update R(h)...R(3) based on R(h) + for i := RatchetParts - 1; i >= h; i-- { + m.rehashPart(h, i) + } +} + +// AdvanceTo advances the ratchet so that the ratchet counter = target +func (m *Ratchet) AdvanceTo(target uint32) { + //starting with R0, see if we need to update each part of the hash + for j := 0; j < RatchetParts; j++ { + shift := uint32((RatchetParts - j - 1) * 8) + mask := (^uint32(0)) << shift + + // how many times do we need to rehash this part? + // '& 0xff' ensures we handle integer wraparound correctly + steps := ((target >> shift) - m.Counter>>shift) & uint32(0xff) + + if steps == 0 { + /* + deal with the edge case where m.Counter is slightly larger + than target. This should only happen for R(0), and implies + that target has wrapped around and we need to advance R(0) + 256 times. + */ + if target < m.Counter { + steps = 0x100 + } else { + continue + } + } + // for all but the last step, we can just bump R(j) without regard to R(j+1)...R(3). + for steps > 1 { + m.rehashPart(j, j) + steps-- + } + /* + on the last step we also need to bump R(j+1)...R(3). + + (Theoretically, we could skip bumping R(j+2) if we're going to bump + R(j+1) again, but the code to figure that out is a bit baroque and + doesn't save us much). + */ + for k := 3; k >= j; k-- { + m.rehashPart(j, k) + } + m.Counter = target & mask + } +} + +// Encrypt encrypts the message in a message.GroupMessage with MAC and signature. +// The output is base64 encoded. +func (r *Ratchet) Encrypt(plaintext []byte, key *crypto.Ed25519KeyPair) ([]byte, error) { + var err error + encryptedText, err := RatchetCipher.Encrypt(r.Data[:], plaintext) + if err != nil { + return nil, errors.Wrap(err, "cipher encrypt") + } + + message := &message.GroupMessage{} + message.Version = protocolVersion + message.MessageIndex = r.Counter + message.Ciphertext = encryptedText + //creating the mac and signing is done in encode + output, err := message.EncodeAndMacAndSign(r.Data[:], RatchetCipher, key) + if err != nil { + return nil, err + } + r.Advance() + return output, nil +} + +// SessionSharingMessage creates a message in the session sharing format. +func (r Ratchet) SessionSharingMessage(key crypto.Ed25519KeyPair) ([]byte, error) { + m := message.MegolmSessionSharing{} + m.Counter = r.Counter + m.RatchetData = r.Data + encoded := m.EncodeAndSign(key) + return goolm.Base64Encode(encoded), nil +} + +// SessionExportMessage creates a message in the session export format. +func (r Ratchet) SessionExportMessage(key crypto.Ed25519PublicKey) ([]byte, error) { + m := message.MegolmSessionExport{} + m.Counter = r.Counter + m.RatchetData = r.Data + m.PublicKey = key + encoded := m.Encode() + return goolm.Base64Encode(encoded), nil +} + +// Decrypt decrypts the ciphertext and verifies the MAC but not the signature. +func (r Ratchet) Decrypt(ciphertext []byte, signingkey *crypto.Ed25519PublicKey, msg *message.GroupMessage) ([]byte, error) { + //verify mac + verifiedMAC, err := msg.VerifyMACInline(r.Data[:], RatchetCipher, ciphertext) + if err != nil { + return nil, err + } + if !verifiedMAC { + return nil, errors.Wrap(goolm.ErrBadMAC, "decrypt") + } + + return RatchetCipher.Decrypt(r.Data[:], msg.Ciphertext) +} + +// PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. +func (r Ratchet) PickleAsJSON(key []byte) ([]byte, error) { + return utilities.PickleAsJSON(r, megolmPickleVersion, key) +} + +// UnpickleAsJSON updates a ratchet by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. +func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error { + return utilities.UnpickleAsJSON(r, pickled, key, megolmPickleVersion) +} + +// UnpickleLibOlm decodes the unencryted value and populates the Ratchet accordingly. It returns the number of bytes read. +func (r *Ratchet) UnpickleLibOlm(unpickled []byte) (int, error) { + //read ratchet data + curPos := 0 + ratchetData, readBytes, err := libolmpickle.UnpickleBytes(unpickled, RatchetParts*RatchetPartLength) + if err != nil { + return 0, err + } + copy(r.Data[:], ratchetData) + curPos += readBytes + //Read counter + counter, readBytes, err := libolmpickle.UnpickleUInt32(unpickled[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + r.Counter = counter + return curPos, nil +} + +// PickleLibOlm encodes the ratchet into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (r Ratchet) PickleLibOlm(target []byte) (int, error) { + if len(target) < r.PickleLen() { + return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle account") + } + written := libolmpickle.PickleBytes(r.Data[:], target) + written += libolmpickle.PickleUInt32(r.Counter, target[written:]) + return written, nil +} + +// PickleLen returns the number of bytes the pickled ratchet will have. +func (r Ratchet) PickleLen() int { + length := libolmpickle.PickleBytesLen(r.Data[:]) + length += libolmpickle.PickleUInt32Len(r.Counter) + return length +} diff --git a/crypto/goolm/megolm/megolm_test.go b/crypto/goolm/megolm/megolm_test.go new file mode 100644 index 00000000..414deb4e --- /dev/null +++ b/crypto/goolm/megolm/megolm_test.go @@ -0,0 +1,139 @@ +package megolm + +import ( + "bytes" + "testing" +) + +var startData [RatchetParts * RatchetPartLength]byte + +func init() { + startValue := []byte("0123456789ABCDEF0123456789ABCDEF") + copy(startData[:], startValue) + copy(startData[32:], startValue) + copy(startData[64:], startValue) + copy(startData[96:], startValue) +} + +func TestAdvance(t *testing.T) { + m, err := New(0, startData) + if err != nil { + t.Fatal(err) + } + + expectedData := [RatchetParts * RatchetPartLength]byte{ + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, + 0xba, 0x9c, 0xd9, 0x55, 0x74, 0x1d, 0x1c, 0x16, 0x23, 0x23, 0xec, 0x82, 0x5e, 0x7c, 0x5c, 0xe8, + 0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a, + } + m.Advance() + if !bytes.Equal(m.Data[:], expectedData[:]) { + t.Fatal("result after advancing the ratchet is not as expected") + } + + //repeat with complex advance + m.Data = startData + expectedData = [RatchetParts * RatchetPartLength]byte{ + 0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9, + 0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c, + 0x70, 0x04, 0xc0, 0x1e, 0xe4, 0x9b, 0xd6, 0xef, 0xe0, 0x07, 0x35, 0x25, 0xaf, 0x9b, 0x16, 0x32, + 0xc5, 0xbe, 0x72, 0x6d, 0x12, 0x34, 0x9c, 0xc5, 0xbd, 0x47, 0x2b, 0xdc, 0x2d, 0xf6, 0x54, 0x0f, + 0x31, 0x12, 0x59, 0x11, 0x94, 0xfd, 0xa6, 0x17, 0xe5, 0x68, 0xc6, 0x83, 0x10, 0x1e, 0xae, 0xcd, + 0x7e, 0xdd, 0xd6, 0xde, 0x1f, 0xbc, 0x07, 0x67, 0xae, 0x34, 0xda, 0x1a, 0x09, 0xa5, 0x4e, 0xab, + 0xba, 0x9c, 0xd9, 0x55, 0x74, 0x1d, 0x1c, 0x16, 0x23, 0x23, 0xec, 0x82, 0x5e, 0x7c, 0x5c, 0xe8, + 0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a, + } + m.AdvanceTo(0x1000000) + if !bytes.Equal(m.Data[:], expectedData[:]) { + t.Fatal("result after advancing the ratchet is not as expected") + } + expectedData = [RatchetParts * RatchetPartLength]byte{ + 0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9, + 0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c, + 0x55, 0x58, 0x8d, 0xf5, 0xb7, 0xa4, 0x88, 0x78, 0x42, 0x89, 0x27, 0x86, 0x81, 0x64, 0x58, 0x9f, + 0x36, 0x63, 0x44, 0x7b, 0x51, 0xed, 0xc3, 0x59, 0x5b, 0x03, 0x6c, 0xa6, 0x04, 0xc4, 0x6d, 0xcd, + 0x5c, 0x54, 0x85, 0x0b, 0xfa, 0x98, 0xa1, 0xfd, 0x79, 0xa9, 0xdf, 0x1c, 0xbe, 0x8f, 0xc5, 0x68, + 0x19, 0x37, 0xd3, 0x0c, 0x85, 0xc8, 0xc3, 0x1f, 0x7b, 0xb8, 0x28, 0x81, 0x6c, 0xf9, 0xff, 0x3b, + 0x95, 0x6c, 0xbf, 0x80, 0x7e, 0x65, 0x12, 0x6a, 0x49, 0x55, 0x8d, 0x45, 0xc8, 0x4a, 0x2e, 0x4c, + 0xd5, 0x6f, 0x03, 0xe2, 0x44, 0x16, 0xb9, 0x8e, 0x1c, 0xfd, 0x97, 0xc2, 0x06, 0xaa, 0x90, 0x7a, + } + m.AdvanceTo(0x1041506) + if !bytes.Equal(m.Data[:], expectedData[:]) { + t.Fatal("result after advancing the ratchet is not as expected") + } +} + +func TestAdvanceWraparound(t *testing.T) { + m, err := New(0xffffffff, startData) + if err != nil { + t.Fatal(err) + } + m.AdvanceTo(0x1000000) + if m.Counter != 0x1000000 { + t.Fatal("counter not correct") + } + + m2, err := New(0, startData) + if err != nil { + t.Fatal(err) + } + m2.AdvanceTo(0x2000000) + if m2.Counter != 0x2000000 { + t.Fatal("counter not correct") + } + if !bytes.Equal(m.Data[:], m2.Data[:]) { + t.Fatal("result after wrapping the ratchet is not as expected") + } +} + +func TestAdvanceOverflowByOne(t *testing.T) { + m, err := New(0xffffffff, startData) + if err != nil { + t.Fatal(err) + } + m.AdvanceTo(0x0) + if m.Counter != 0x0 { + t.Fatal("counter not correct") + } + + m2, err := New(0xffffffff, startData) + if err != nil { + t.Fatal(err) + } + m2.Advance() + if m2.Counter != 0x0 { + t.Fatal("counter not correct") + } + if !bytes.Equal(m.Data[:], m2.Data[:]) { + t.Fatal("result after wrapping the ratchet is not as expected") + } +} + +func TestAdvanceOverflow(t *testing.T) { + m, err := New(0x1, startData) + if err != nil { + t.Fatal(err) + } + m.AdvanceTo(0x80000000) + m.AdvanceTo(0x0) + if m.Counter != 0x0 { + t.Fatal("counter not correct") + } + + m2, err := New(0x1, startData) + if err != nil { + t.Fatal(err) + } + m2.AdvanceTo(0x0) + if m2.Counter != 0x0 { + t.Fatal("counter not correct") + } + if !bytes.Equal(m.Data[:], m2.Data[:]) { + t.Fatal("result after wrapping the ratchet is not as expected") + } +} diff --git a/crypto/goolm/message/decoder.go b/crypto/goolm/message/decoder.go new file mode 100644 index 00000000..37104ab3 --- /dev/null +++ b/crypto/goolm/message/decoder.go @@ -0,0 +1,71 @@ +package message + +import ( + "encoding/binary" + + "codeberg.org/DerLukas/goolm" + "github.com/pkg/errors" +) + +// checkDecodeErr checks if there was an error during decode. +func checkDecodeErr(readBytes int) error { + if readBytes == 0 { + //end reached + return errors.Wrap(goolm.ErrInputToSmall, "") + } + if readBytes < 0 { + return errors.Wrap(goolm.ErrOverflow, "") + } + return nil +} + +// decodeVarInt decodes a single big-endian encoded varint. +func decodeVarInt(input []byte) (uint32, int) { + value, readBytes := binary.Uvarint(input) + return uint32(value), readBytes +} + +// decodeVarString decodes the length of the string (varint) and returns the actual string +func decodeVarString(input []byte) ([]byte, int) { + stringLen, readBytes := decodeVarInt(input) + if readBytes <= 0 { + return nil, readBytes + } + input = input[readBytes:] + value := input[:stringLen] + readBytes += int(stringLen) + return value, readBytes +} + +// encodeVarIntByteLength returns the number of bytes needed to encode the uint32. +func encodeVarIntByteLength(input uint32) int { + result := 1 + for input >= 128 { + result++ + input >>= 7 + } + return result +} + +// encodeVarStringByteLength returns the number of bytes needed to encode the input. +func encodeVarStringByteLength(input []byte) int { + result := encodeVarIntByteLength(uint32(len(input))) + result += len(input) + return result +} + +// encodeVarInt encodes a single uint32 +func encodeVarInt(input uint32) []byte { + out := make([]byte, encodeVarIntByteLength(input)) + binary.PutUvarint(out, uint64(input)) + return out +} + +// encodeVarString encodes the length of the input (varint) and appends the actual input +func encodeVarString(input []byte) []byte { + out := make([]byte, encodeVarStringByteLength(input)) + length := encodeVarInt(uint32(len(input))) + copy(out, length) + copy(out[len(length):], input) + return out +} diff --git a/crypto/goolm/message/decoder_test.go b/crypto/goolm/message/decoder_test.go new file mode 100644 index 00000000..39503e3e --- /dev/null +++ b/crypto/goolm/message/decoder_test.go @@ -0,0 +1,82 @@ +package message + +import ( + "bytes" + "testing" +) + +func TestEncodeLengthInt(t *testing.T) { + numbers := []uint32{127, 128, 16383, 16384, 32767} + expected := []int{1, 2, 2, 3, 3} + for curIndex := range numbers { + if result := encodeVarIntByteLength(numbers[curIndex]); result != expected[curIndex] { + t.Fatalf("expected byte length of %d but got %d", expected[curIndex], result) + } + } +} + +func TestEncodeLengthString(t *testing.T) { + var strings [][]byte + var expected []int + strings = append(strings, []byte("test")) + expected = append(expected, 1+4) + strings = append(strings, []byte("this is a long message with a length of 127 so that the varint of the length is just one byte. just needs some padding---------")) + expected = append(expected, 1+127) + strings = append(strings, []byte("this is an even longer message with a length between 128 and 16383 so that the varint of the length needs two byte. just needs some padding again ---------")) + expected = append(expected, 2+155) + for curIndex := range strings { + if result := encodeVarStringByteLength(strings[curIndex]); result != expected[curIndex] { + t.Fatalf("expected byte length of %d but got %d", expected[curIndex], result) + } + } +} + +func TestEncodeInt(t *testing.T) { + var ints []uint32 + var expected [][]byte + ints = append(ints, 7) + expected = append(expected, []byte{0b00000111}) + ints = append(ints, 127) + expected = append(expected, []byte{0b01111111}) + ints = append(ints, 128) + expected = append(expected, []byte{0b10000000, 0b00000001}) + ints = append(ints, 16383) + expected = append(expected, []byte{0b11111111, 0b01111111}) + for curIndex := range ints { + if result := encodeVarInt(ints[curIndex]); !bytes.Equal(result, expected[curIndex]) { + t.Fatalf("expected byte of %b but got %b", expected[curIndex], result) + } + } +} + +func TestEncodeString(t *testing.T) { + var strings [][]byte + var expected [][]byte + curTest := []byte("test") + strings = append(strings, curTest) + res := []byte{ + 0b00000100, //varint length of string + } + res = append(res, curTest...) //Add string itself + expected = append(expected, res) + curTest = []byte("this is a long message with a length of 127 so that the varint of the length is just one byte. just needs some padding---------") + strings = append(strings, curTest) + res = []byte{ + 0b01111111, //varint length of string + } + res = append(res, curTest...) //Add string itself + expected = append(expected, res) + curTest = []byte("this is an even longer message with a length between 128 and 16383 so that the varint of the length needs two byte. just needs some padding again ---------") + strings = append(strings, curTest) + res = []byte{ + 0b10011011, //varint length of string + 0b00000001, //varint length of string + } + res = append(res, curTest...) //Add string itself + expected = append(expected, res) + for curIndex := range strings { + if result := encodeVarString(strings[curIndex]); !bytes.Equal(result, expected[curIndex]) { + t.Fatalf("expected byte of %b but got %b", expected[curIndex], result) + } + } +} diff --git a/crypto/goolm/message/groupMessage.go b/crypto/goolm/message/groupMessage.go new file mode 100644 index 00000000..2ef39e48 --- /dev/null +++ b/crypto/goolm/message/groupMessage.go @@ -0,0 +1,144 @@ +package message + +import ( + "bytes" + + "codeberg.org/DerLukas/goolm/cipher" + "codeberg.org/DerLukas/goolm/crypto" +) + +const ( + messageIndexTag = 0x08 + cipherTextTag = 0x12 + countMACBytesGroupMessage = 8 +) + +// GroupMessage represents a message in the group message format. +type GroupMessage struct { + Version byte `json:"version"` + MessageIndex uint32 `json:"index"` + Ciphertext []byte `json:"ciphertext"` + HasMessageIndex bool `json:"hasIndex"` +} + +// Decodes decodes the input and populates the corresponding fileds. MAC and signature are ignored but have to be present. +func (r *GroupMessage) Decode(input []byte) error { + r.Version = 0 + r.MessageIndex = 0 + r.Ciphertext = nil + if len(input) == 0 { + return nil + } + //first Byte is always version + r.Version = input[0] + curPos := 1 + for curPos < len(input)-countMACBytesGroupMessage-crypto.ED25519SignatureSize { + //Read Key + curKey, readBytes := decodeVarInt(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { + return err + } + curPos += readBytes + if (curKey & 0b111) == 0 { + //The value is of type varint + value, readBytes := decodeVarInt(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { + return err + } + curPos += readBytes + switch curKey { + case messageIndexTag: + r.MessageIndex = value + r.HasMessageIndex = true + } + } else if (curKey & 0b111) == 2 { + //The value is of type string + value, readBytes := decodeVarString(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { + return err + } + curPos += readBytes + switch curKey { + case cipherTextTag: + r.Ciphertext = value + } + } + } + + return nil +} + +// EncodeAndMacAndSign encodes the message, creates the mac with the key and the cipher and signs the message. +// If macKey or cipher is nil, no mac is appended. If signKey is nil, no signature is appended. +func (r *GroupMessage) EncodeAndMacAndSign(macKey []byte, cipher cipher.Cipher, signKey *crypto.Ed25519KeyPair) ([]byte, error) { + var lengthOfMessage int + lengthOfMessage += 1 //Version + lengthOfMessage += encodeVarIntByteLength(messageIndexTag) + encodeVarIntByteLength(r.MessageIndex) + lengthOfMessage += encodeVarIntByteLength(cipherTextTag) + encodeVarStringByteLength(r.Ciphertext) + out := make([]byte, lengthOfMessage) + out[0] = r.Version + curPos := 1 + encodedTag := encodeVarInt(messageIndexTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue := encodeVarInt(r.MessageIndex) + copy(out[curPos:], encodedValue) + curPos += len(encodedValue) + encodedTag = encodeVarInt(cipherTextTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue = encodeVarString(r.Ciphertext) + copy(out[curPos:], encodedValue) + curPos += len(encodedValue) + if len(macKey) != 0 && cipher != nil { + mac, err := r.MAC(macKey, cipher, out) + if err != nil { + return nil, err + } + out = append(out, mac[:countMACBytesGroupMessage]...) + } + if signKey != nil { + signature := signKey.Sign(out) + out = append(out, signature...) + } + return out, nil +} + +// MAC returns the MAC of the message calculated with cipher and key. The length of the MAC is truncated to the correct length. +func (r *GroupMessage) MAC(key []byte, cipher cipher.Cipher, message []byte) ([]byte, error) { + mac, err := cipher.MAC(key, message) + if err != nil { + return nil, err + } + return mac[:countMACBytesGroupMessage], nil +} + +// VerifySignature verifies the givenSignature to the calculated signature of the message. +func (r *GroupMessage) VerifySignature(key crypto.Ed25519PublicKey, message, givenSignature []byte) bool { + return key.Verify(message, givenSignature) +} + +// VerifySignature verifies the signature taken from the message to the calculated signature of the message. +func (r *GroupMessage) VerifySignatureInline(key crypto.Ed25519PublicKey, message []byte) bool { + signature := message[len(message)-crypto.ED25519SignatureSize:] + message = message[:len(message)-crypto.ED25519SignatureSize] + return key.Verify(message, signature) +} + +// VerifyMAC verifies the givenMAC to the calculated MAC of the message. +func (r *GroupMessage) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC []byte) (bool, error) { + checkMac, err := r.MAC(key, cipher, message) + if err != nil { + return false, err + } + return bytes.Equal(checkMac[:countMACBytesGroupMessage], givenMAC), nil +} + +// VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message. +func (r *GroupMessage) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) { + startMAC := len(message) - countMACBytesGroupMessage - crypto.ED25519SignatureSize + endMAC := startMAC + countMACBytesGroupMessage + suplMac := message[startMAC:endMAC] + message = message[:startMAC] + return r.VerifyMAC(key, cipher, message, suplMac) +} diff --git a/crypto/goolm/message/groupMessage_test.go b/crypto/goolm/message/groupMessage_test.go new file mode 100644 index 00000000..fab2f9ea --- /dev/null +++ b/crypto/goolm/message/groupMessage_test.go @@ -0,0 +1,49 @@ +package message + +import ( + "bytes" + "testing" +) + +func TestGroupMessageDecode(t *testing.T) { + messageRaw := []byte("\x03\x08\xC8\x01\x12\x0aciphertexthmacsha2") + signature := []byte("signature1234567891234567890123412345678912345678912345678901234") + messageRaw = append(messageRaw, signature...) + expectedMessageIndex := uint32(200) + expectedCipherText := []byte("ciphertext") + + msg := GroupMessage{} + err := msg.Decode(messageRaw) + if err != nil { + t.Fatal(err) + } + if msg.Version != 3 { + t.Fatalf("Expected Version to be 3 but go %d", msg.Version) + } + if msg.MessageIndex != expectedMessageIndex { + t.Fatalf("Expected message index to be %d but got %d", expectedMessageIndex, msg.MessageIndex) + } + if !bytes.Equal(msg.Ciphertext, expectedCipherText) { + t.Fatalf("expected '%s' but got '%s'", expectedCipherText, msg.Ciphertext) + } +} + +func TestGroupMessageEncode(t *testing.T) { + expectedRaw := []byte("\x03\x08\xC8\x01\x12\x0aciphertexthmacsha2signature") + hmacsha256 := []byte("hmacsha2") + sign := []byte("signature") + msg := GroupMessage{ + Version: 3, + MessageIndex: 200, + Ciphertext: []byte("ciphertext"), + } + encoded, err := msg.EncodeAndMacAndSign(nil, nil, nil) + if err != nil { + t.Fatal(err) + } + encoded = append(encoded, hmacsha256...) + encoded = append(encoded, sign...) + if !bytes.Equal(encoded, expectedRaw) { + t.Fatalf("expected '%s' but got '%s'", expectedRaw, encoded) + } +} diff --git a/crypto/goolm/message/message.go b/crypto/goolm/message/message.go new file mode 100644 index 00000000..d76c458c --- /dev/null +++ b/crypto/goolm/message/message.go @@ -0,0 +1,129 @@ +package message + +import ( + "bytes" + + "codeberg.org/DerLukas/goolm/cipher" + "codeberg.org/DerLukas/goolm/crypto" +) + +const ( + ratchetKeyTag = 0x0A + counterTag = 0x10 + cipherTextKeyTag = 0x22 + countMACBytesMessage = 8 +) + +// GroupMessage represents a message in the message format. +type Message struct { + Version byte `json:"version"` + HasCounter bool `json:"hasCounter"` + Counter uint32 `json:"counter"` + RatchetKey crypto.Curve25519PublicKey `json:"ratchetKey"` + Ciphertext []byte `json:"ciphertext"` +} + +// Decodes decodes the input and populates the corresponding fileds. MAC is ignored but has to be present. +func (r *Message) Decode(input []byte) error { + r.Version = 0 + r.HasCounter = false + r.Counter = 0 + r.RatchetKey = nil + r.Ciphertext = nil + if len(input) == 0 { + return nil + } + //first Byte is always version + r.Version = input[0] + curPos := 1 + for curPos < len(input)-countMACBytesMessage { + //Read Key + curKey, readBytes := decodeVarInt(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { + return err + } + curPos += readBytes + if (curKey & 0b111) == 0 { + //The value is of type varint + value, readBytes := decodeVarInt(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { + return err + } + curPos += readBytes + switch curKey { + case counterTag: + r.HasCounter = true + r.Counter = value + } + } else if (curKey & 0b111) == 2 { + //The value is of type string + value, readBytes := decodeVarString(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { + return err + } + curPos += readBytes + switch curKey { + case ratchetKeyTag: + r.RatchetKey = value + case cipherTextKeyTag: + r.Ciphertext = value + } + } + } + + return nil +} + +// EncodeAndMAC encodes the message and creates the MAC with the key and the cipher. +// If key or cipher is nil, no MAC is appended. +func (r *Message) EncodeAndMAC(key []byte, cipher cipher.Cipher) ([]byte, error) { + var lengthOfMessage int + lengthOfMessage += 1 //Version + lengthOfMessage += encodeVarIntByteLength(ratchetKeyTag) + encodeVarStringByteLength(r.RatchetKey) + lengthOfMessage += encodeVarIntByteLength(counterTag) + encodeVarIntByteLength(r.Counter) + lengthOfMessage += encodeVarIntByteLength(cipherTextKeyTag) + encodeVarStringByteLength(r.Ciphertext) + out := make([]byte, lengthOfMessage) + out[0] = r.Version + curPos := 1 + encodedTag := encodeVarInt(ratchetKeyTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue := encodeVarString(r.RatchetKey) + copy(out[curPos:], encodedValue) + curPos += len(encodedValue) + encodedTag = encodeVarInt(counterTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue = encodeVarInt(r.Counter) + copy(out[curPos:], encodedValue) + curPos += len(encodedValue) + encodedTag = encodeVarInt(cipherTextKeyTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue = encodeVarString(r.Ciphertext) + copy(out[curPos:], encodedValue) + curPos += len(encodedValue) + if len(key) != 0 && cipher != nil { + mac, err := cipher.MAC(key, out) + if err != nil { + return nil, err + } + out = append(out, mac[:countMACBytesMessage]...) + } + return out, nil +} + +// VerifyMAC verifies the givenMAC to the calculated MAC of the message. +func (r *Message) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC []byte) (bool, error) { + checkMAC, err := cipher.MAC(key, message) + if err != nil { + return false, err + } + return bytes.Equal(checkMAC[:countMACBytesMessage], givenMAC), nil +} + +// VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message. +func (r *Message) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) { + givenMAC := message[len(message)-countMACBytesMessage:] + return r.VerifyMAC(key, cipher, message[:len(message)-countMACBytesMessage], givenMAC) +} diff --git a/crypto/goolm/message/message_test.go b/crypto/goolm/message/message_test.go new file mode 100644 index 00000000..1e2921c7 --- /dev/null +++ b/crypto/goolm/message/message_test.go @@ -0,0 +1,52 @@ +package message + +import ( + "bytes" + "testing" +) + +func TestMessageDecode(t *testing.T) { + messageRaw := []byte("\x03\x10\x01\n\nratchetkey\"\nciphertexthmacsha2") + expectedRatchetKey := []byte("ratchetkey") + expectedCipherText := []byte("ciphertext") + + msg := Message{} + err := msg.Decode(messageRaw) + if err != nil { + t.Fatal(err) + } + if msg.Version != 3 { + t.Fatalf("Expected Version to be 3 but go %d", msg.Version) + } + if !msg.HasCounter { + t.Fatal("Expected to have counter") + } + if msg.Counter != 1 { + t.Fatalf("Expected counter to be 1 but got %d", msg.Counter) + } + if !bytes.Equal(msg.Ciphertext, expectedCipherText) { + t.Fatalf("expected '%s' but got '%s'", expectedCipherText, msg.Ciphertext) + } + if !bytes.Equal(msg.RatchetKey, expectedRatchetKey) { + t.Fatalf("expected '%s' but got '%s'", expectedRatchetKey, msg.RatchetKey) + } +} + +func TestMessageEncode(t *testing.T) { + expectedRaw := []byte("\x03\n\nratchetkey\x10\x01\"\nciphertexthmacsha2") + hmacsha256 := []byte("hmacsha2") + msg := Message{ + Version: 3, + Counter: 1, + RatchetKey: []byte("ratchetkey"), + Ciphertext: []byte("ciphertext"), + } + encoded, err := msg.EncodeAndMAC(nil, nil) + if err != nil { + t.Fatal(err) + } + encoded = append(encoded, hmacsha256...) + if !bytes.Equal(encoded, expectedRaw) { + t.Fatalf("expected '%s' but got '%s'", expectedRaw, encoded) + } +} diff --git a/crypto/goolm/message/preKeyMessage.go b/crypto/goolm/message/preKeyMessage.go new file mode 100644 index 00000000..8f1826d7 --- /dev/null +++ b/crypto/goolm/message/preKeyMessage.go @@ -0,0 +1,120 @@ +package message + +import ( + "codeberg.org/DerLukas/goolm/crypto" +) + +const ( + oneTimeKeyIdTag = 0x0A + baseKeyTag = 0x12 + identityKeyTag = 0x1A + messageTag = 0x22 +) + +type PreKeyMessage struct { + Version byte `json:"version"` + IdentityKey crypto.Curve25519PublicKey `json:"idKey"` + BaseKey crypto.Curve25519PublicKey `json:"baseKey"` + OneTimeKey crypto.Curve25519PublicKey `json:"otKey"` + Message []byte `json:"message"` +} + +// Decodes decodes the input and populates the corresponding fileds. +func (r *PreKeyMessage) Decode(input []byte) error { + r.Version = 0 + r.IdentityKey = nil + r.BaseKey = nil + r.OneTimeKey = nil + r.Message = nil + if len(input) == 0 { + return nil + } + //first Byte is always version + r.Version = input[0] + curPos := 1 + for curPos < len(input) { + //Read Key + curKey, readBytes := decodeVarInt(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { + return err + } + curPos += readBytes + if (curKey & 0b111) == 0 { + //The value is of type varint + _, readBytes := decodeVarInt(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { + return err + } + curPos += readBytes + } else if (curKey & 0b111) == 2 { + //The value is of type string + value, readBytes := decodeVarString(input[curPos:]) + if err := checkDecodeErr(readBytes); err != nil { + return err + } + curPos += readBytes + switch curKey { + case oneTimeKeyIdTag: + r.OneTimeKey = value + case baseKeyTag: + r.BaseKey = value + case identityKeyTag: + r.IdentityKey = value + case messageTag: + r.Message = value + } + } + } + + return nil +} + +// CheckField verifies the fields. If theirIdentityKey is nil, it is not compared to the key in the message. +func (r *PreKeyMessage) CheckFields(theirIdentityKey *crypto.Curve25519PublicKey) bool { + ok := true + ok = ok && (theirIdentityKey != nil || r.IdentityKey != nil) + if r.IdentityKey != nil { + ok = ok && (len(r.IdentityKey) == crypto.Curve25519KeyLength) + } + ok = ok && len(r.Message) != 0 + ok = ok && len(r.BaseKey) == crypto.Curve25519KeyLength + ok = ok && len(r.OneTimeKey) == crypto.Curve25519KeyLength + return ok +} + +// Encode encodes the message. +func (r *PreKeyMessage) Encode() ([]byte, error) { + var lengthOfMessage int + lengthOfMessage += 1 //Version + lengthOfMessage += encodeVarIntByteLength(oneTimeKeyIdTag) + encodeVarStringByteLength(r.OneTimeKey) + lengthOfMessage += encodeVarIntByteLength(identityKeyTag) + encodeVarStringByteLength(r.IdentityKey) + lengthOfMessage += encodeVarIntByteLength(baseKeyTag) + encodeVarStringByteLength(r.BaseKey) + lengthOfMessage += encodeVarIntByteLength(messageTag) + encodeVarStringByteLength(r.Message) + out := make([]byte, lengthOfMessage) + out[0] = r.Version + curPos := 1 + encodedTag := encodeVarInt(oneTimeKeyIdTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue := encodeVarString(r.OneTimeKey) + copy(out[curPos:], encodedValue) + curPos += len(encodedValue) + encodedTag = encodeVarInt(identityKeyTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue = encodeVarString(r.IdentityKey) + copy(out[curPos:], encodedValue) + curPos += len(encodedValue) + encodedTag = encodeVarInt(baseKeyTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue = encodeVarString(r.BaseKey) + copy(out[curPos:], encodedValue) + curPos += len(encodedValue) + encodedTag = encodeVarInt(messageTag) + copy(out[curPos:], encodedTag) + curPos += len(encodedTag) + encodedValue = encodeVarString(r.Message) + copy(out[curPos:], encodedValue) + return out, nil +} diff --git a/crypto/goolm/message/preKeyMessage_test.go b/crypto/goolm/message/preKeyMessage_test.go new file mode 100644 index 00000000..a244e95e --- /dev/null +++ b/crypto/goolm/message/preKeyMessage_test.go @@ -0,0 +1,62 @@ +package message + +import ( + "bytes" + "testing" + + "codeberg.org/DerLukas/goolm/crypto" +) + +func TestPreKeyMessageDecode(t *testing.T) { + //Keys are 32 bytes to pass field check + //Added a tag for an integer of 0 just for checkes + messageRaw := []byte("\x03\x0a\x20onetimeKey.-.-.-.-.-.-.-.-.-.-.-\x1a\x20idKeywithlendth32bytes-.-.-.-.-.\x12\x20baseKey-.-.-.-.-.-.-.-.-.-.-.-.-\x22\x07message\x00\x00") + expectedOneTimeKey := []byte("onetimeKey.-.-.-.-.-.-.-.-.-.-.-") + expectedIdKey := []byte("idKeywithlendth32bytes-.-.-.-.-.") + expectedbaseKey := []byte("baseKey-.-.-.-.-.-.-.-.-.-.-.-.-") + expectedmessage := []byte("message") + + msg := PreKeyMessage{} + err := msg.Decode(messageRaw) + if err != nil { + t.Fatal(err) + } + if msg.Version != 3 { + t.Fatalf("Expected Version to be 3 but go %d", msg.Version) + } + if !bytes.Equal(msg.OneTimeKey, expectedOneTimeKey) { + t.Fatalf("expected '%s' but got '%s'", expectedOneTimeKey, msg.OneTimeKey) + } + if !bytes.Equal(msg.IdentityKey, expectedIdKey) { + t.Fatalf("expected '%s' but got '%s'", expectedIdKey, msg.IdentityKey) + } + if !bytes.Equal(msg.BaseKey, expectedbaseKey) { + t.Fatalf("expected '%s' but got '%s'", expectedbaseKey, msg.BaseKey) + } + if !bytes.Equal(msg.Message, expectedmessage) { + t.Fatalf("expected '%s' but got '%s'", expectedmessage, msg.Message) + } + theirIDKey := crypto.Curve25519PublicKey(expectedIdKey) + checked := msg.CheckFields(&theirIDKey) + if !checked { + t.Fatal("field check failed") + } +} + +func TestPreKeyMessageEncode(t *testing.T) { + expectedRaw := []byte("\x03\x0a\x0aonetimeKey\x1a\x05idKey\x12\x07baseKey\x22\x07message") + msg := PreKeyMessage{ + Version: 3, + IdentityKey: []byte("idKey"), + BaseKey: []byte("baseKey"), + OneTimeKey: []byte("onetimeKey"), + Message: []byte("message"), + } + encoded, err := msg.Encode() + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(encoded, expectedRaw) { + t.Fatalf("got other than expected:\nExpected:\n%v\nGot:\n%v", expectedRaw, encoded) + } +} diff --git a/crypto/goolm/message/sessionExport.go b/crypto/goolm/message/sessionExport.go new file mode 100644 index 00000000..05814b3a --- /dev/null +++ b/crypto/goolm/message/sessionExport.go @@ -0,0 +1,44 @@ +package message + +import ( + "encoding/binary" + + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/crypto" + "github.com/pkg/errors" +) + +const ( + sessionExportVersion = 0x01 +) + +// MegolmSessionExport represents a message in the session export format. +type MegolmSessionExport struct { + Counter uint32 `json:"counter"` + RatchetData [128]byte `json:"data"` + PublicKey crypto.Ed25519PublicKey `json:"kPub"` +} + +// Encode returns the encoded message in the correct format. +func (s MegolmSessionExport) Encode() []byte { + output := make([]byte, 165) + output[0] = sessionExportVersion + binary.BigEndian.PutUint32(output[1:], s.Counter) + copy(output[5:], s.RatchetData[:]) + copy(output[133:], s.PublicKey) + return output +} + +// Decode populates the struct with the data encoded in input. +func (s *MegolmSessionExport) Decode(input []byte) error { + if len(input) != 165 { + return errors.Wrap(goolm.ErrBadInput, "decrypt") + } + if input[0] != sessionExportVersion { + return errors.Wrap(goolm.ErrBadVersion, "decrypt") + } + s.Counter = binary.BigEndian.Uint32(input[1:5]) + copy(s.RatchetData[:], input[5:133]) + s.PublicKey = input[133:] + return nil +} diff --git a/crypto/goolm/message/sessionSharing.go b/crypto/goolm/message/sessionSharing.go new file mode 100644 index 00000000..5c0cd773 --- /dev/null +++ b/crypto/goolm/message/sessionSharing.go @@ -0,0 +1,50 @@ +package message + +import ( + "encoding/binary" + + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/crypto" + "github.com/pkg/errors" +) + +const ( + sessionSharingVersion = 0x02 +) + +// MegolmSessionSharing represents a message in the session sharing format. +type MegolmSessionSharing struct { + Counter uint32 `json:"counter"` + RatchetData [128]byte `json:"data"` + PublicKey crypto.Ed25519PublicKey `json:"-"` //only used when decrypting messages +} + +// Encode returns the encoded message in the correct format with the signature by key appended. +func (s MegolmSessionSharing) EncodeAndSign(key crypto.Ed25519KeyPair) []byte { + output := make([]byte, 229) + output[0] = sessionSharingVersion + binary.BigEndian.PutUint32(output[1:], s.Counter) + copy(output[5:], s.RatchetData[:]) + copy(output[133:], key.PublicKey) + signature := key.Sign(output[:165]) + copy(output[165:], signature) + return output +} + +// VerifyAndDecode verifies the input and populates the struct with the data encoded in input. +func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error { + if len(input) != 229 { + return errors.Wrap(goolm.ErrBadInput, "verify") + } + publicKey := crypto.Ed25519PublicKey(input[133:165]) + if !publicKey.Verify(input[:165], input[165:]) { + return errors.Wrap(goolm.ErrBadVerification, "verify") + } + s.PublicKey = publicKey + if input[0] != sessionSharingVersion { + return errors.Wrap(goolm.ErrBadVersion, "verify") + } + s.Counter = binary.BigEndian.Uint32(input[1:5]) + copy(s.RatchetData[:], input[5:133]) + return nil +} diff --git a/crypto/goolm/olm/chain.go b/crypto/goolm/olm/chain.go new file mode 100644 index 00000000..187f5f6d --- /dev/null +++ b/crypto/goolm/olm/chain.go @@ -0,0 +1,257 @@ +package olm + +import ( + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/crypto" + libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" + "github.com/pkg/errors" +) + +const ( + chainKeySeed = 0x02 + messageKeyLength = 32 +) + +// chainKey wraps the index and the public key +type chainKey struct { + Index uint32 `json:"index"` + Key crypto.Curve25519PublicKey `json:"key"` +} + +// advance advances the chain +func (c *chainKey) advance() { + c.Key = crypto.HMACSHA256(c.Key, []byte{chainKeySeed}) + c.Index++ +} + +// UnpickleLibOlm decodes the unencryted value and populates the chain key accordingly. It returns the number of bytes read. +func (r *chainKey) UnpickleLibOlm(value []byte) (int, error) { + curPos := 0 + readBytes, err := r.Key.UnpickleLibOlm(value) + if err != nil { + return 0, err + } + curPos += readBytes + r.Index, readBytes, err = libolmpickle.UnpickleUInt32(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + return curPos, nil +} + +// PickleLibOlm encodes the chain key into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (r chainKey) PickleLibOlm(target []byte) (int, error) { + if len(target) < r.PickleLen() { + return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle chain key") + } + written, err := r.Key.PickleLibOlm(target) + if err != nil { + return 0, errors.Wrap(err, "pickle chain key") + } + written += libolmpickle.PickleUInt32(r.Index, target[written:]) + return written, nil +} + +// PickleLen returns the number of bytes the pickled chain key will have. +func (r chainKey) PickleLen() int { + length := r.Key.PickleLen() + length += libolmpickle.PickleUInt32Len(r.Index) + return length +} + +// senderChain is a chain for sending messages +type senderChain struct { + RKey crypto.Curve25519KeyPair `json:"ratchetKey"` + CKey chainKey `json:"chainKey"` + IsSet bool `json:"set"` +} + +// newSenderChain returns a sender chain initialized with chainKey and ratchet key pair. +func newSenderChain(key crypto.Curve25519PublicKey, ratchet crypto.Curve25519KeyPair) *senderChain { + return &senderChain{ + RKey: ratchet, + CKey: chainKey{ + Index: 0, + Key: key, + }, + IsSet: true, + } +} + +// advance advances the chain +func (s *senderChain) advance() { + s.CKey.advance() +} + +// ratchetKey returns the ratchet key pair. +func (s senderChain) ratchetKey() crypto.Curve25519KeyPair { + return s.RKey +} + +// chainKey returns the current chainKey. +func (s senderChain) chainKey() chainKey { + return s.CKey +} + +// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read. +func (r *senderChain) UnpickleLibOlm(value []byte) (int, error) { + curPos := 0 + readBytes, err := r.RKey.UnpickleLibOlm(value) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = r.CKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + return curPos, nil +} + +// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (r senderChain) PickleLibOlm(target []byte) (int, error) { + if len(target) < r.PickleLen() { + return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle sender chain") + } + written, err := r.RKey.PickleLibOlm(target) + if err != nil { + return 0, errors.Wrap(err, "pickle sender chain") + } + writtenChain, err := r.CKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle sender chain") + } + written += writtenChain + return written, nil +} + +// PickleLen returns the number of bytes the pickled chain will have. +func (r senderChain) PickleLen() int { + length := r.RKey.PickleLen() + length += r.CKey.PickleLen() + return length +} + +// senderChain is a chain for receiving messages +type receiverChain struct { + RKey crypto.Curve25519PublicKey `json:"ratchetKey"` + CKey chainKey `json:"chainKey"` +} + +// newReceiverChain returns a receiver chain initialized with chainKey and ratchet public key. +func newReceiverChain(chain crypto.Curve25519PublicKey, ratchet crypto.Curve25519PublicKey) *receiverChain { + return &receiverChain{ + RKey: ratchet, + CKey: chainKey{ + Index: 0, + Key: chain, + }, + } +} + +// advance advances the chain +func (s *receiverChain) advance() { + s.CKey.advance() +} + +// ratchetKey returns the ratchet public key. +func (s receiverChain) ratchetKey() crypto.Curve25519PublicKey { + return s.RKey +} + +// chainKey returns the current chainKey. +func (s receiverChain) chainKey() chainKey { + return s.CKey +} + +// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read. +func (r *receiverChain) UnpickleLibOlm(value []byte) (int, error) { + curPos := 0 + readBytes, err := r.RKey.UnpickleLibOlm(value) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = r.CKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + return curPos, nil +} + +// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (r receiverChain) PickleLibOlm(target []byte) (int, error) { + if len(target) < r.PickleLen() { + return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle sender chain") + } + written, err := r.RKey.PickleLibOlm(target) + if err != nil { + return 0, errors.Wrap(err, "pickle sender chain") + } + writtenChain, err := r.CKey.PickleLibOlm(target) + if err != nil { + return 0, errors.Wrap(err, "pickle sender chain") + } + written += writtenChain + return written, nil +} + +// PickleLen returns the number of bytes the pickled chain will have. +func (r receiverChain) PickleLen() int { + length := r.RKey.PickleLen() + length += r.CKey.PickleLen() + return length +} + +// messageKey wraps the index and the key of a message +type messageKey struct { + Index uint32 `json:"index"` + Key []byte `json:"key"` +} + +// UnpickleLibOlm decodes the unencryted value and populates the message key accordingly. It returns the number of bytes read. +func (m *messageKey) UnpickleLibOlm(value []byte) (int, error) { + curPos := 0 + ratchetKey, readBytes, err := libolmpickle.UnpickleBytes(value, messageKeyLength) + if err != nil { + return 0, err + } + m.Key = ratchetKey + curPos += readBytes + keyID, readBytes, err := libolmpickle.UnpickleUInt32(value[:curPos]) + if err != nil { + return 0, err + } + curPos += readBytes + m.Index = keyID + return curPos, nil +} + +// PickleLibOlm encodes the message key into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (m messageKey) PickleLibOlm(target []byte) (int, error) { + if len(target) < m.PickleLen() { + return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle message key") + } + written := 0 + if len(m.Key) != messageKeyLength { + written += libolmpickle.PickleBytes(make([]byte, messageKeyLength), target) + } else { + written += libolmpickle.PickleBytes(m.Key, target) + } + written += libolmpickle.PickleUInt32(m.Index, target[written:]) + return written, nil +} + +// PickleLen returns the number of bytes the pickled message key will have. +func (r messageKey) PickleLen() int { + length := libolmpickle.PickleBytesLen(make([]byte, messageKeyLength)) + length += libolmpickle.PickleUInt32Len(r.Index) + return length +} diff --git a/crypto/goolm/olm/olm.go b/crypto/goolm/olm/olm.go new file mode 100644 index 00000000..7e1e3e56 --- /dev/null +++ b/crypto/goolm/olm/olm.go @@ -0,0 +1,432 @@ +// olm provides the ratchet used by the olm protocol +package olm + +import ( + "io" + + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/cipher" + "codeberg.org/DerLukas/goolm/crypto" + libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" + "codeberg.org/DerLukas/goolm/message" + "codeberg.org/DerLukas/goolm/utilities" + "github.com/pkg/errors" +) + +const ( + olmPickleVersion uint8 = 1 +) + +const ( + maxReceiverChains = 5 + maxSkippedMessageKeys = 40 + protocolVersion = 3 + messageKeySeed = 0x01 + + maxMessageGap = 2000 + sharedKeyLength = 32 +) + +// KdfInfo has the infos used for the kdf +var KdfInfo = struct { + Root []byte + Ratchet []byte +}{ + Root: []byte("OLM_ROOT"), + Ratchet: []byte("OLM_RATCHET"), +} + +var RatchetCipher = cipher.NewAESSha256([]byte("OLM_KEYS")) + +// Ratchet represents the olm ratchet as described in +// +// https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/olm.md +type Ratchet struct { + // The root key is used to generate chain keys from the ephemeral keys. + // A new root_key is derived each time a new chain is started. + RootKey crypto.Curve25519PublicKey `json:"rootKey"` + + // The sender chain is used to send messages. Each time a new ephemeral + // key is received from the remote server we generate a new sender chain + // with a new ephemeral key when we next send a message. + SenderChains senderChain `json:"senderChain"` + + // The receiver chain is used to decrypt received messages. We store the + // last few chains so we can decrypt any out of order messages we haven't + // received yet. + // New chains are prepended for easier access. + ReceiverChains []receiverChain `json:"receiverChains"` + + // Storing the keys of missed messages for future use. + // The order of the elements is not important. + SkippedMessageKeys []skippedMessageKey `json:"skippedMessageKeys"` +} + +// New creates a new ratchet, setting the kdfInfos and cipher. +func New() *Ratchet { + r := &Ratchet{} + return r +} + +// InitialiseAsBob initialises this ratchet from a receiving point of view (only first message). +func (r *Ratchet) InitialiseAsBob(sharedSecret []byte, theirRatchetKey crypto.Curve25519PublicKey) error { + derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root) + derivedSecrets := make([]byte, 2*sharedKeyLength) + if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil { + return err + } + r.RootKey = derivedSecrets[0:sharedKeyLength] + newReceiverChain := newReceiverChain(derivedSecrets[sharedKeyLength:], theirRatchetKey) + r.ReceiverChains = append([]receiverChain{*newReceiverChain}, r.ReceiverChains...) + return nil +} + +// InitialiseAsAlice initialises this ratchet from a sending point of view (only first message). +func (r *Ratchet) InitialiseAsAlice(sharedSecret []byte, ourRatchetKey crypto.Curve25519KeyPair) error { + derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root) + derivedSecrets := make([]byte, 2*sharedKeyLength) + if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil { + return err + } + r.RootKey = derivedSecrets[0:sharedKeyLength] + newSenderChain := newSenderChain(derivedSecrets[sharedKeyLength:], ourRatchetKey) + r.SenderChains = *newSenderChain + return nil +} + +// Encrypt encrypts the message in a message.Message with MAC. If reader is nil, crypto/rand is used for key generations. +func (r *Ratchet) Encrypt(plaintext []byte, reader io.Reader) ([]byte, error) { + var err error + if !r.SenderChains.IsSet { + newRatchetKey, err := crypto.Curve25519GenerateKey(reader) + if err != nil { + return nil, err + } + newChainKey, err := r.advanceRootKey(newRatchetKey, r.ReceiverChains[0].ratchetKey()) + if err != nil { + return nil, err + } + newSenderChain := newSenderChain(newChainKey, newRatchetKey) + r.SenderChains = *newSenderChain + } + + messageKey := r.createMessageKeys(r.SenderChains.chainKey()) + r.SenderChains.advance() + + encryptedText, err := RatchetCipher.Encrypt(messageKey.Key, plaintext) + if err != nil { + return nil, errors.Wrap(err, "cipher encrypt") + } + + message := &message.Message{} + message.Version = protocolVersion + message.Counter = messageKey.Index + message.RatchetKey = r.SenderChains.ratchetKey().PublicKey + message.Ciphertext = encryptedText + //creating the mac is done in encode + output, err := message.EncodeAndMAC(messageKey.Key, RatchetCipher) + if err != nil { + return nil, err + } + + return output, nil +} + +// Decrypt decrypts the ciphertext and verifies the MAC. If reader is nil, crypto/rand is used for key generations. +func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { + message := &message.Message{} + //The mac is not verified here, as we do not know the key yet + err := message.Decode(input) + if err != nil { + return nil, err + } + if message.Version != protocolVersion { + return nil, errors.Wrap(goolm.ErrWrongProtocolVersion, "decrypt") + } + if !message.HasCounter || len(message.RatchetKey) == 0 || len(message.Ciphertext) == 0 { + return nil, errors.Wrap(goolm.ErrBadMessageFormat, "decrypt") + } + var receiverChainFromMessage *receiverChain + for curChainIndex := range r.ReceiverChains { + if r.ReceiverChains[curChainIndex].ratchetKey().Equal(message.RatchetKey) { + receiverChainFromMessage = &r.ReceiverChains[curChainIndex] + break + } + } + var result []byte + if receiverChainFromMessage == nil { + //Advancing the chain is done in this method + result, err = r.decryptForNewChain(message, input) + if err != nil { + return nil, err + } + } else if receiverChainFromMessage.chainKey().Index > message.Counter { + // No need to advance the chain + // Chain already advanced beyond the key for this message + // Check if the message keys are in the skipped key list. + foundSkippedKey := false + for curSkippedIndex := range r.SkippedMessageKeys { + if message.Counter == r.SkippedMessageKeys[curSkippedIndex].MKey.Index { + // Found the key for this message. Check the MAC. + verified, err := message.VerifyMACInline(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, RatchetCipher, input) + if err != nil { + return nil, err + } + if !verified { + return nil, errors.Wrap(goolm.ErrBadMAC, "decrypt from skipped message keys") + } + result, err = RatchetCipher.Decrypt(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, message.Ciphertext) + if err != nil { + return nil, errors.Wrap(err, "cipher decrypt") + } + if len(result) != 0 { + // Remove the key from the skipped keys now that we've + // decoded the message it corresponds to. + r.SkippedMessageKeys[curSkippedIndex] = r.SkippedMessageKeys[len(r.SkippedMessageKeys)-1] + r.SkippedMessageKeys = r.SkippedMessageKeys[:len(r.SkippedMessageKeys)-1] + } + foundSkippedKey = true + } + } + if !foundSkippedKey { + return nil, errors.Wrap(goolm.ErrMessageKeyNotFound, "decrypt") + } + } else { + //Advancing the chain is done in this method + result, err = r.decryptForExistingChain(receiverChainFromMessage, message, input) + if err != nil { + return nil, err + } + } + + return result, nil +} + +// advanceRootKey created the next root key and returns the next chainKey +func (r *Ratchet) advanceRootKey(newRatchetKey crypto.Curve25519KeyPair, oldRatchetKey crypto.Curve25519PublicKey) (crypto.Curve25519PublicKey, error) { + sharedSecret, err := newRatchetKey.SharedSecret(oldRatchetKey) + if err != nil { + return nil, err + } + derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, r.RootKey, KdfInfo.Ratchet) + derivedSecrets := make([]byte, 2*sharedKeyLength) + if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil { + return nil, err + } + r.RootKey = derivedSecrets[:sharedKeyLength] + return derivedSecrets[sharedKeyLength:], nil +} + +// createMessageKeys returns the messageKey derived from the chainKey +func (r Ratchet) createMessageKeys(chainKey chainKey) messageKey { + res := messageKey{} + res.Key = crypto.HMACSHA256(chainKey.Key, []byte{messageKeySeed}) + res.Index = chainKey.Index + return res +} + +// decryptForExistingChain returns the decrypted message by using the chain. The MAC of the rawMessage is verified. +func (r *Ratchet) decryptForExistingChain(chain *receiverChain, message *message.Message, rawMessage []byte) ([]byte, error) { + if message.Counter < chain.CKey.Index { + return nil, errors.Wrap(goolm.ErrChainTooHigh, "decrypt") + } + // Limit the number of hashes we're prepared to compute + if message.Counter-chain.CKey.Index > maxMessageGap { + return nil, errors.Wrap(goolm.ErrMsgIndexTooHigh, "decrypt from existing chain") + } + for chain.CKey.Index < message.Counter { + messageKey := r.createMessageKeys(chain.chainKey()) + skippedKey := skippedMessageKey{ + MKey: messageKey, + RKey: chain.ratchetKey(), + } + r.SkippedMessageKeys = append(r.SkippedMessageKeys, skippedKey) + chain.advance() + } + messageKey := r.createMessageKeys(chain.chainKey()) + chain.advance() + verified, err := message.VerifyMACInline(messageKey.Key, RatchetCipher, rawMessage) + if err != nil { + return nil, err + } + if !verified { + return nil, errors.Wrap(goolm.ErrBadMAC, "decrypt from existing chain") + } + return RatchetCipher.Decrypt(messageKey.Key, message.Ciphertext) +} + +// decryptForNewChain returns the decrypted message by creating a new chain and advancing the root key. +func (r *Ratchet) decryptForNewChain(message *message.Message, rawMessage []byte) ([]byte, error) { + // They shouldn't move to a new chain until we've sent them a message + // acknowledging the last one + if !r.SenderChains.IsSet { + return nil, errors.Wrap(goolm.ErrProtocolViolation, "decrypt for new chain") + } + // Limit the number of hashes we're prepared to compute + if message.Counter > maxMessageGap { + return nil, errors.Wrap(goolm.ErrMsgIndexTooHigh, "decrypt for new chain") + } + + newChainKey, err := r.advanceRootKey(r.SenderChains.ratchetKey(), message.RatchetKey) + if err != nil { + return nil, err + } + newChain := newReceiverChain(newChainKey, message.RatchetKey) + r.ReceiverChains = append([]receiverChain{*newChain}, r.ReceiverChains...) + /* + They have started using a new ephemeral ratchet key. + We needed to derive a new set of chain keys. + We can discard our previous ephemeral ratchet key. + We will generate a new key when we send the next message. + */ + r.SenderChains = senderChain{} + + decrypted, err := r.decryptForExistingChain(&r.ReceiverChains[0], message, rawMessage) + if err != nil { + return nil, err + } + return decrypted, nil +} + +// PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. +func (r Ratchet) PickleAsJSON(key []byte) ([]byte, error) { + return utilities.PickleAsJSON(r, olmPickleVersion, key) +} + +// UnpickleAsJSON updates a ratchet by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. +func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error { + return utilities.UnpickleAsJSON(r, pickled, key, olmPickleVersion) +} + +// UnpickleLibOlm decodes the unencryted value and populates the Ratchet accordingly. It returns the number of bytes read. +func (r *Ratchet) UnpickleLibOlm(value []byte, includesChainIndex bool) (int, error) { + //read ratchet data + curPos := 0 + readBytes, err := r.RootKey.UnpickleLibOlm(value) + if err != nil { + return 0, err + } + curPos += readBytes + countSenderChains, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of sender chain + if err != nil { + return 0, err + } + curPos += readBytes + for i := uint32(0); i < countSenderChains; i++ { + if i == 0 { + //only first is stored + readBytes, err := r.SenderChains.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + r.SenderChains.IsSet = true + } else { + dummy := senderChain{} + readBytes, err := dummy.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + } + } + countReceivChains, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of recevier chain + if err != nil { + return 0, err + } + curPos += readBytes + r.ReceiverChains = make([]receiverChain, countReceivChains) + for i := uint32(0); i < countReceivChains; i++ { + readBytes, err := r.ReceiverChains[i].UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + } + countSkippedMessageKeys, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of skippedMessageKeys + if err != nil { + return 0, err + } + curPos += readBytes + r.SkippedMessageKeys = make([]skippedMessageKey, countSkippedMessageKeys) + for i := uint32(0); i < countSkippedMessageKeys; i++ { + readBytes, err := r.SkippedMessageKeys[i].UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + } + // pickle v 0x80000001 includes a chain index; pickle v1 does not. + if includesChainIndex { + _, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + } + return curPos, nil +} + +// PickleLibOlm encodes the ratchet into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (r Ratchet) PickleLibOlm(target []byte) (int, error) { + if len(target) < r.PickleLen() { + return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle ratchet") + } + written, err := r.RootKey.PickleLibOlm(target) + if err != nil { + return 0, errors.Wrap(err, "pickle ratchet") + } + if r.SenderChains.IsSet { + written += libolmpickle.PickleUInt32(1, target[written:]) //Length of sender chain, always 1 + writtenSender, err := r.SenderChains.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle ratchet") + } + written += writtenSender + } else { + written += libolmpickle.PickleUInt32(0, target[written:]) //Length of sender chain + } + written += libolmpickle.PickleUInt32(uint32(len(r.ReceiverChains)), target[written:]) + for _, curChain := range r.ReceiverChains { + writtenChain, err := curChain.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle ratchet") + } + written += writtenChain + } + written += libolmpickle.PickleUInt32(uint32(len(r.SkippedMessageKeys)), target[written:]) + for _, curChain := range r.SkippedMessageKeys { + writtenChain, err := curChain.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle ratchet") + } + written += writtenChain + } + return written, nil +} + +// PickleLen returns the actual number of bytes the pickled ratchet will have. +func (r Ratchet) PickleLen() int { + length := r.RootKey.PickleLen() + if r.SenderChains.IsSet { + length += libolmpickle.PickleUInt32Len(1) + length += r.SenderChains.PickleLen() + } else { + length += libolmpickle.PickleUInt32Len(0) + } + length += libolmpickle.PickleUInt32Len(uint32(len(r.ReceiverChains))) + length += len(r.ReceiverChains) * receiverChain{}.PickleLen() + length += libolmpickle.PickleUInt32Len(uint32(len(r.SkippedMessageKeys))) + length += len(r.SkippedMessageKeys) * skippedMessageKey{}.PickleLen() + return length +} + +// PickleLen returns the minimum number of bytes the pickled ratchet must have. +func (r Ratchet) PickleLenMin() int { + length := r.RootKey.PickleLen() + length += libolmpickle.PickleUInt32Len(0) + length += libolmpickle.PickleUInt32Len(0) + length += libolmpickle.PickleUInt32Len(0) + return length +} diff --git a/crypto/goolm/olm/olm_test.go b/crypto/goolm/olm/olm_test.go new file mode 100644 index 00000000..4f70ae81 --- /dev/null +++ b/crypto/goolm/olm/olm_test.go @@ -0,0 +1,185 @@ +package olm + +import ( + "bytes" + "encoding/json" + "testing" + + "codeberg.org/DerLukas/goolm/cipher" + "codeberg.org/DerLukas/goolm/crypto" +) + +var ( + sharedSecret = []byte("A secret") +) + +func initializeRatchets() (*Ratchet, *Ratchet, error) { + KdfInfo = struct { + Root []byte + Ratchet []byte + }{ + Root: []byte("Olm"), + Ratchet: []byte("OlmRatchet"), + } + RatchetCipher = cipher.NewAESSha256([]byte("OlmMessageKeys")) + aliceRatchet := New() + bobRatchet := New() + + aliceKey, err := crypto.Curve25519GenerateKey(nil) + if err != nil { + return nil, nil, err + } + + aliceRatchet.InitialiseAsAlice(sharedSecret, aliceKey) + bobRatchet.InitialiseAsBob(sharedSecret, aliceKey.PublicKey) + return aliceRatchet, bobRatchet, nil +} + +func TestSendReceive(t *testing.T) { + aliceRatchet, bobRatchet, err := initializeRatchets() + if err != nil { + t.Fatal(err) + } + + plainText := []byte("Hello Bob") + + //Alice sends Bob a message + encryptedMessage, err := aliceRatchet.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + + decrypted, err := bobRatchet.Decrypt(encryptedMessage) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plainText, decrypted) { + t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) + } + + //Bob sends Alice a message + plainText = []byte("Hello Alice") + encryptedMessage, err = bobRatchet.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + decrypted, err = aliceRatchet.Decrypt(encryptedMessage) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plainText, decrypted) { + t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) + } +} + +func TestOutOfOrder(t *testing.T) { + aliceRatchet, bobRatchet, err := initializeRatchets() + if err != nil { + t.Fatal(err) + } + + plainText1 := []byte("First Message") + plainText2 := []byte("Second Messsage. A bit longer than the first.") + + /* Alice sends Bob two messages and they arrive out of order */ + message1Encrypted, err := aliceRatchet.Encrypt(plainText1, nil) + if err != nil { + t.Fatal(err) + } + message2Encrypted, err := aliceRatchet.Encrypt(plainText2, nil) + if err != nil { + t.Fatal(err) + } + + decrypted2, err := bobRatchet.Decrypt(message2Encrypted) + if err != nil { + t.Fatal(err) + } + decrypted1, err := bobRatchet.Decrypt(message1Encrypted) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plainText1, decrypted1) { + t.Fatalf("expected '%v' from decryption but got '%v'", plainText1, decrypted1) + } + if !bytes.Equal(plainText2, decrypted2) { + t.Fatalf("expected '%v' from decryption but got '%v'", plainText2, decrypted2) + } +} + +func TestMoreMessages(t *testing.T) { + aliceRatchet, bobRatchet, err := initializeRatchets() + if err != nil { + t.Fatal(err) + } + plainText := []byte("These 15 bytes") + for i := 0; i < 8; i++ { + messageEncrypted, err := aliceRatchet.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + decrypted, err := bobRatchet.Decrypt(messageEncrypted) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plainText, decrypted) { + t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) + } + } + for i := 0; i < 8; i++ { + messageEncrypted, err := bobRatchet.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + decrypted, err := aliceRatchet.Decrypt(messageEncrypted) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plainText, decrypted) { + t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) + } + } + messageEncrypted, err := aliceRatchet.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + decrypted, err := bobRatchet.Decrypt(messageEncrypted) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plainText, decrypted) { + t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) + } +} + +func TestJSONEncoding(t *testing.T) { + aliceRatchet, bobRatchet, err := initializeRatchets() + if err != nil { + t.Fatal(err) + } + marshaled, err := json.Marshal(aliceRatchet) + if err != nil { + t.Fatal(err) + } + + newRatcher := Ratchet{} + err = json.Unmarshal(marshaled, &newRatcher) + if err != nil { + t.Fatal(err) + } + + plainText := []byte("These 15 bytes") + + messageEncrypted, err := newRatcher.Encrypt(plainText, nil) + if err != nil { + t.Fatal(err) + } + decrypted, err := bobRatchet.Decrypt(messageEncrypted) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plainText, decrypted) { + t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) + } + +} diff --git a/crypto/goolm/olm/skippedMessage.go b/crypto/goolm/olm/skippedMessage.go new file mode 100644 index 00000000..893d548d --- /dev/null +++ b/crypto/goolm/olm/skippedMessage.go @@ -0,0 +1,54 @@ +package olm + +import ( + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/crypto" + "github.com/pkg/errors" +) + +// skippedMessageKey stores a skipped message key +type skippedMessageKey struct { + RKey crypto.Curve25519PublicKey `json:"ratchetKey"` + MKey messageKey `json:"messageKey"` +} + +// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read. +func (r *skippedMessageKey) UnpickleLibOlm(value []byte) (int, error) { + curPos := 0 + readBytes, err := r.RKey.UnpickleLibOlm(value) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = r.MKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + return curPos, nil +} + +// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (r skippedMessageKey) PickleLibOlm(target []byte) (int, error) { + if len(target) < r.PickleLen() { + return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle sender chain") + } + written, err := r.RKey.PickleLibOlm(target) + if err != nil { + return 0, errors.Wrap(err, "pickle sender chain") + } + writtenChain, err := r.MKey.PickleLibOlm(target) + if err != nil { + return 0, errors.Wrap(err, "pickle sender chain") + } + written += writtenChain + return written, nil +} + +// PickleLen returns the number of bytes the pickled chain will have. +func (r skippedMessageKey) PickleLen() int { + length := r.RKey.PickleLen() + length += r.MKey.PickleLen() + return length +} diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go new file mode 100644 index 00000000..e530b037 --- /dev/null +++ b/crypto/goolm/pk/decryption.go @@ -0,0 +1,162 @@ +package pk + +import ( + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/cipher" + "codeberg.org/DerLukas/goolm/crypto" + libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" + "codeberg.org/DerLukas/goolm/utilities" + "github.com/pkg/errors" + "maunium.net/go/mautrix/id" +) + +const ( + decryptionPickleVersionJSON uint8 = 1 + decryptionPickleVersionLibOlm uint32 = 1 +) + +// Decription is used to decrypt pk messages +type Decription struct { + KeyPair crypto.Curve25519KeyPair `json:"keyPair"` +} + +// NewDecription returns a new Decription with a new generated key pair. +func NewDecription() (*Decription, error) { + keyPair, err := crypto.Curve25519GenerateKey(nil) + if err != nil { + return nil, err + } + return &Decription{ + KeyPair: keyPair, + }, nil +} + +// NewDescriptionFromPrivate resturns a new Decription with the private key fixed. +func NewDecriptionFromPrivate(privateKey crypto.Curve25519PrivateKey) (*Decription, error) { + s := &Decription{} + keyPair, err := crypto.Curve25519GenerateFromPrivate(privateKey) + if err != nil { + return nil, err + } + s.KeyPair = keyPair + return s, nil +} + +// PubKey returns the public key base 64 encoded. +func (s Decription) PubKey() id.Curve25519 { + return s.KeyPair.B64Encoded() +} + +// PrivateKey returns the private key. +func (s Decription) PrivateKey() crypto.Curve25519PrivateKey { + return s.KeyPair.PrivateKey +} + +// Decrypt decrypts the ciphertext and verifies the MAC. The base64 encoded key is used to construct the shared secret. +func (s Decription) Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error) { + keyDecoded, err := goolm.Base64Decode([]byte(key)) + if err != nil { + return nil, err + } + sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded) + if err != nil { + return nil, err + } + decodedMAC, err := goolm.Base64Decode(mac) + if err != nil { + return nil, err + } + cipher := cipher.NewAESSha256(nil) + verified, err := cipher.Verify(sharedSecret, ciphertext, decodedMAC) + if err != nil { + return nil, err + } + if !verified { + return nil, errors.Wrap(goolm.ErrBadMAC, "decrypt") + } + plaintext, err := cipher.Decrypt(sharedSecret, ciphertext) + if err != nil { + return nil, err + } + return plaintext, nil +} + +// PickleAsJSON returns an Decription as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. +func (a Decription) PickleAsJSON(key []byte) ([]byte, error) { + return utilities.PickleAsJSON(a, decryptionPickleVersionJSON, key) +} + +// UnpickleAsJSON updates an Decription by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. +func (a *Decription) UnpickleAsJSON(pickled, key []byte) error { + return utilities.UnpickleAsJSON(a, pickled, key, decryptionPickleVersionJSON) +} + +// Unpickle decodes the base64 encoded string and decrypts the result with the key. +// The decrypted value is then passed to UnpickleLibOlm. +func (a *Decription) Unpickle(pickled, key []byte) error { + decrypted, err := cipher.Unpickle(key, pickled) + if err != nil { + return err + } + _, err = a.UnpickleLibOlm(decrypted) + return err +} + +// UnpickleLibOlm decodes the unencryted value and populates the Decription accordingly. It returns the number of bytes read. +func (a *Decription) UnpickleLibOlm(value []byte) (int, error) { + //First 4 bytes are the accountPickleVersion + pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) + if err != nil { + return 0, err + } + switch pickledVersion { + case decryptionPickleVersionLibOlm: + default: + return 0, errors.Wrap(goolm.ErrBadVersion, "unpickle olmSession") + } + readBytes, err := a.KeyPair.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + return curPos, nil +} + +// Pickle returns a base64 encoded and with key encrypted pickled Decription using PickleLibOlm(). +func (a Decription) Pickle(key []byte) ([]byte, error) { + pickeledBytes := make([]byte, a.PickleLen()) + written, err := a.PickleLibOlm(pickeledBytes) + if err != nil { + return nil, err + } + if written != len(pickeledBytes) { + return nil, errors.New("number of written bytes not correct") + } + encrypted, err := cipher.Pickle(key, pickeledBytes) + if err != nil { + return nil, err + } + return encrypted, nil +} + +// PickleLibOlm encodes the Decription into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (a Decription) PickleLibOlm(target []byte) (int, error) { + if len(target) < a.PickleLen() { + return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle Decription") + } + written := libolmpickle.PickleUInt32(decryptionPickleVersionLibOlm, target) + writtenKey, err := a.KeyPair.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle Decription") + } + written += writtenKey + return written, nil +} + +// PickleLen returns the number of bytes the pickled Decription will have. +func (a Decription) PickleLen() int { + length := libolmpickle.PickleUInt32Len(decryptionPickleVersionLibOlm) + length += a.KeyPair.PickleLen() + return length +} diff --git a/crypto/goolm/pk/encryption.go b/crypto/goolm/pk/encryption.go new file mode 100644 index 00000000..2f64b7b2 --- /dev/null +++ b/crypto/goolm/pk/encryption.go @@ -0,0 +1,46 @@ +package pk + +import ( + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/cipher" + "codeberg.org/DerLukas/goolm/crypto" + "maunium.net/go/mautrix/id" +) + +// Encryption is used to encrypt pk messages +type Encryption struct { + RecipientKey crypto.Curve25519PublicKey +} + +// NewEncryption returns a new Encryption with the base64 encoded public key of the recipient +func NewEncryption(pubKey id.Curve25519) (*Encryption, error) { + pubKeyDecoded, err := goolm.Base64Decode([]byte(pubKey)) + if err != nil { + return nil, err + } + return &Encryption{ + RecipientKey: pubKeyDecoded, + }, nil +} + +// Encrypt encrypts the plaintext with the privateKey and returns the ciphertext and base64 encoded MAC. +func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519PrivateKey) (ciphertext, mac []byte, err error) { + keyPair, err := crypto.Curve25519GenerateFromPrivate(privateKey) + if err != nil { + return nil, nil, err + } + sharedSecret, err := keyPair.SharedSecret(e.RecipientKey) + if err != nil { + return nil, nil, err + } + cipher := cipher.NewAESSha256(nil) + ciphertext, err = cipher.Encrypt(sharedSecret, plaintext) + if err != nil { + return nil, nil, err + } + mac, err = cipher.MAC(sharedSecret, ciphertext) + if err != nil { + return nil, nil, err + } + return ciphertext, goolm.Base64Encode(mac), nil +} diff --git a/crypto/goolm/pk/pk_test.go b/crypto/goolm/pk/pk_test.go new file mode 100644 index 00000000..58918bea --- /dev/null +++ b/crypto/goolm/pk/pk_test.go @@ -0,0 +1,132 @@ +package pk + +import ( + "bytes" + "testing" + + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/crypto" + "maunium.net/go/mautrix/id" +) + +func TestEncryptionDecryption(t *testing.T) { + alicePrivate := []byte{ + 0x77, 0x07, 0x6D, 0x0A, 0x73, 0x18, 0xA5, 0x7D, + 0x3C, 0x16, 0xC1, 0x72, 0x51, 0xB2, 0x66, 0x45, + 0xDF, 0x4C, 0x2F, 0x87, 0xEB, 0xC0, 0x99, 0x2A, + 0xB1, 0x77, 0xFB, 0xA5, 0x1D, 0xB9, 0x2C, 0x2A, + } + alicePublic := []byte("hSDwCYkwp1R0i33ctD73Wg2/Og0mOBr066SpjqqbTmo") + bobPrivate := []byte{ + 0x5D, 0xAB, 0x08, 0x7E, 0x62, 0x4A, 0x8A, 0x4B, + 0x79, 0xE1, 0x7F, 0x8B, 0x83, 0x80, 0x0E, 0xE6, + 0x6F, 0x3B, 0xB1, 0x29, 0x26, 0x18, 0xB6, 0xFD, + 0x1C, 0x2F, 0x8B, 0x27, 0xFF, 0x88, 0xE0, 0xEB, + } + bobPublic := []byte("3p7bfXt9wbTTW2HC7OQ1Nz+DQ8hbeGdNrfx+FG+IK08") + decryption, err := NewDecriptionFromPrivate(alicePrivate) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal([]byte(decryption.PubKey()), alicePublic) { + t.Fatal("public key not correct") + } + if !bytes.Equal(decryption.PrivateKey(), alicePrivate) { + t.Fatal("private key not correct") + } + + encryption, err := NewEncryption(decryption.PubKey()) + if err != nil { + t.Fatal(err) + } + plaintext := []byte("This is a test") + + ciphertext, mac, err := encryption.Encrypt(plaintext, bobPrivate) + if err != nil { + t.Fatal(err) + } + + decrypted, err := decryption.Decrypt(ciphertext, mac, id.Curve25519(bobPublic)) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(decrypted, plaintext) { + t.Fatal("message not equal") + } +} + +func TestSigning(t *testing.T) { + seed := []byte{ + 0x77, 0x07, 0x6D, 0x0A, 0x73, 0x18, 0xA5, 0x7D, + 0x3C, 0x16, 0xC1, 0x72, 0x51, 0xB2, 0x66, 0x45, + 0xDF, 0x4C, 0x2F, 0x87, 0xEB, 0xC0, 0x99, 0x2A, + 0xB1, 0x77, 0xFB, 0xA5, 0x1D, 0xB9, 0x2C, 0x2A, + } + message := []byte("We hold these truths to be self-evident, that all men are created equal, that they are endowed by their Creator with certain unalienable Rights, that among these are Life, Liberty and the pursuit of Happiness.") + signing, _ := NewSigningFromSeed(seed) + signature := signing.Sign(message) + signatureDecoded, err := goolm.Base64Decode(signature) + if err != nil { + t.Fatal(err) + } + pubKeyEncoded := signing.PublicKey() + pubKeyDecoded, err := goolm.Base64Decode([]byte(pubKeyEncoded)) + if err != nil { + t.Fatal(err) + } + pubKey := crypto.Ed25519PublicKey(pubKeyDecoded) + + verified := pubKey.Verify(message, signatureDecoded) + if !verified { + t.Fatal("signature did not verify") + } + copy(signatureDecoded[0:], []byte("m")) + verified = pubKey.Verify(message, signatureDecoded) + if verified { + t.Fatal("signature did verify") + } +} + +func TestDecryptionPickling(t *testing.T) { + alicePrivate := []byte{ + 0x77, 0x07, 0x6D, 0x0A, 0x73, 0x18, 0xA5, 0x7D, + 0x3C, 0x16, 0xC1, 0x72, 0x51, 0xB2, 0x66, 0x45, + 0xDF, 0x4C, 0x2F, 0x87, 0xEB, 0xC0, 0x99, 0x2A, + 0xB1, 0x77, 0xFB, 0xA5, 0x1D, 0xB9, 0x2C, 0x2A, + } + alicePublic := []byte("hSDwCYkwp1R0i33ctD73Wg2/Og0mOBr066SpjqqbTmo") + decryption, err := NewDecriptionFromPrivate(alicePrivate) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal([]byte(decryption.PubKey()), alicePublic) { + t.Fatal("public key not correct") + } + if !bytes.Equal(decryption.PrivateKey(), alicePrivate) { + t.Fatal("private key not correct") + } + pickleKey := []byte("secret_key") + expectedPickle := []byte("qx37WTQrjZLz5tId/uBX9B3/okqAbV1ofl9UnHKno1eipByCpXleAAlAZoJgYnCDOQZDQWzo3luTSfkF9pU1mOILCbbouubs6TVeDyPfgGD9i86J8irHjA") + pickled, err := decryption.Pickle(pickleKey) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(expectedPickle, pickled) { + t.Fatalf("pickle not as expected:\n%v\n%v\n", pickled, expectedPickle) + } + + newDecription, err := NewDecription() + if err != nil { + t.Fatal(err) + } + err = newDecription.Unpickle(pickled, pickleKey) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal([]byte(newDecription.PubKey()), alicePublic) { + t.Fatal("public key not correct") + } + if !bytes.Equal(newDecription.PrivateKey(), alicePrivate) { + t.Fatal("private key not correct") + } +} diff --git a/crypto/goolm/pk/signing.go b/crypto/goolm/pk/signing.go new file mode 100644 index 00000000..a80cc5c8 --- /dev/null +++ b/crypto/goolm/pk/signing.go @@ -0,0 +1,44 @@ +package pk + +import ( + "crypto/rand" + + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/crypto" + "maunium.net/go/mautrix/id" +) + +// Signing is used for signing a pk +type Signing struct { + KeyPair crypto.Ed25519KeyPair `json:"keyPair"` + Seed []byte `json:"seed"` +} + +// NewSigningFromSeed constructs a new Signing based on a seed. +func NewSigningFromSeed(seed []byte) (*Signing, error) { + s := &Signing{} + s.Seed = seed + s.KeyPair = crypto.Ed25519GenerateFromSeed(seed) + return s, nil +} + +// NewSigning returns a Signing based on a random seed +func NewSigning() (*Signing, error) { + seed := make([]byte, 32) + _, err := rand.Read(seed) + if err != nil { + return nil, err + } + return NewSigningFromSeed(seed) +} + +// Sign returns the signature of the message base64 encoded. +func (s Signing) Sign(message []byte) []byte { + signature := s.KeyPair.Sign(message) + return goolm.Base64Encode(signature) +} + +// PublicKey returns the public key of the key pair base 64 encoded. +func (s Signing) PublicKey() id.Ed25519 { + return s.KeyPair.B64Encoded() +} diff --git a/crypto/goolm/sas/main.go b/crypto/goolm/sas/main.go new file mode 100644 index 00000000..b1f55069 --- /dev/null +++ b/crypto/goolm/sas/main.go @@ -0,0 +1,76 @@ +// sas provides the means to do SAS between keys +package sas + +import ( + "io" + + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/crypto" +) + +// SAS contains the key pair and secret for SAS. +type SAS struct { + KeyPair crypto.Curve25519KeyPair + Secret []byte +} + +// New creates a new SAS with a new key pair. +func New() (*SAS, error) { + kp, err := crypto.Curve25519GenerateKey(nil) + if err != nil { + return nil, err + } + s := &SAS{ + KeyPair: kp, + } + return s, nil +} + +// GetPubkey returns the public key of the key pair base64 encoded +func (s SAS) GetPubkey() []byte { + return goolm.Base64Encode(s.KeyPair.PublicKey) +} + +// SetTheirKey sets the key of the other party and computes the shared secret. +func (s *SAS) SetTheirKey(key []byte) error { + keyDecoded, err := goolm.Base64Decode(key) + if err != nil { + return err + } + sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded) + if err != nil { + return err + } + s.Secret = sharedSecret + return nil +} + +// GenerateBytes creates length bytes from the shared secret and info. +func (s SAS) GenerateBytes(info []byte, length uint) ([]byte, error) { + byteReader := crypto.HKDFSHA256(s.Secret, nil, info) + output := make([]byte, length) + if _, err := io.ReadFull(byteReader, output); err != nil { + return nil, err + } + return output, nil +} + +// calculateMAC returns a base64 encoded MAC of input. +func (s *SAS) calculateMAC(input, info []byte, length uint) ([]byte, error) { + key, err := s.GenerateBytes(info, length) + if err != nil { + return nil, err + } + mac := crypto.HMACSHA256(key, input) + return goolm.Base64Encode(mac), nil +} + +// CalculateMACFixes returns a base64 encoded, 32 byte long MAC of input. +func (s SAS) CalculateMAC(input, info []byte) ([]byte, error) { + return s.calculateMAC(input, info, 32) +} + +// CalculateMACLongKDF returns a base64 encoded, 256 byte long MAC of input. +func (s SAS) CalculateMACLongKDF(input, info []byte) ([]byte, error) { + return s.calculateMAC(input, info, 256) +} diff --git a/crypto/goolm/sas/main_test.go b/crypto/goolm/sas/main_test.go new file mode 100644 index 00000000..6e897a32 --- /dev/null +++ b/crypto/goolm/sas/main_test.go @@ -0,0 +1,111 @@ +package sas + +import ( + "bytes" + "testing" + + "codeberg.org/DerLukas/goolm/crypto" +) + +func initSAS() (*SAS, *SAS, error) { + alicePrivate := crypto.Curve25519PrivateKey([]byte{ + 0x77, 0x07, 0x6D, 0x0A, 0x73, 0x18, 0xA5, 0x7D, + 0x3C, 0x16, 0xC1, 0x72, 0x51, 0xB2, 0x66, 0x45, + 0xDF, 0x4C, 0x2F, 0x87, 0xEB, 0xC0, 0x99, 0x2A, + 0xB1, 0x77, 0xFB, 0xA5, 0x1D, 0xB9, 0x2C, 0x2A, + }) + bobPrivate := crypto.Curve25519PrivateKey([]byte{ + 0x5D, 0xAB, 0x08, 0x7E, 0x62, 0x4A, 0x8A, 0x4B, + 0x79, 0xE1, 0x7F, 0x8B, 0x83, 0x80, 0x0E, 0xE6, + 0x6F, 0x3B, 0xB1, 0x29, 0x26, 0x18, 0xB6, 0xFD, + 0x1C, 0x2F, 0x8B, 0x27, 0xFF, 0x88, 0xE0, 0xEB, + }) + + aliceSAS, err := New() + if err != nil { + return nil, nil, err + } + aliceSAS.KeyPair.PrivateKey = alicePrivate + aliceSAS.KeyPair.PublicKey, err = alicePrivate.PubKey() + if err != nil { + return nil, nil, err + } + + bobSAS, err := New() + if err != nil { + return nil, nil, err + } + bobSAS.KeyPair.PrivateKey = bobPrivate + bobSAS.KeyPair.PublicKey, err = bobPrivate.PubKey() + if err != nil { + return nil, nil, err + } + return aliceSAS, bobSAS, nil +} + +func TestGenerateBytes(t *testing.T) { + aliceSAS, bobSAS, err := initSAS() + if err != nil { + t.Fatal(err) + } + alicePublicEncoded := []byte("hSDwCYkwp1R0i33ctD73Wg2/Og0mOBr066SpjqqbTmo") + bobPublicEncoded := []byte("3p7bfXt9wbTTW2HC7OQ1Nz+DQ8hbeGdNrfx+FG+IK08") + + if !bytes.Equal(aliceSAS.GetPubkey(), alicePublicEncoded) { + t.Fatal("public keys not equal") + } + if !bytes.Equal(bobSAS.GetPubkey(), bobPublicEncoded) { + t.Fatal("public keys not equal") + } + + err = aliceSAS.SetTheirKey(bobSAS.GetPubkey()) + if err != nil { + t.Fatal(err) + } + err = bobSAS.SetTheirKey(aliceSAS.GetPubkey()) + if err != nil { + t.Fatal(err) + } + + aliceBytes, err := aliceSAS.GenerateBytes([]byte("SAS"), 6) + if err != nil { + t.Fatal(err) + } + bobBytes, err := bobSAS.GenerateBytes([]byte("SAS"), 6) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(aliceBytes, bobBytes) { + t.Fatal("results are not equal") + } +} + +func TestSASMac(t *testing.T) { + aliceSAS, bobSAS, err := initSAS() + if err != nil { + t.Fatal(err) + } + err = aliceSAS.SetTheirKey(bobSAS.GetPubkey()) + if err != nil { + t.Fatal(err) + } + err = bobSAS.SetTheirKey(aliceSAS.GetPubkey()) + if err != nil { + t.Fatal(err) + } + + plainText := []byte("Hello world!") + info := []byte("MAC") + + aliceMac, err := aliceSAS.CalculateMAC(plainText, info) + if err != nil { + t.Fatal(err) + } + bobMac, err := bobSAS.CalculateMAC(plainText, info) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(aliceMac, bobMac) { + t.Fatal("results are not equal") + } +} diff --git a/crypto/goolm/session/main.go b/crypto/goolm/session/main.go new file mode 100644 index 00000000..0caf8045 --- /dev/null +++ b/crypto/goolm/session/main.go @@ -0,0 +1,2 @@ +// session provides the different types of sessions for en/decrypting of messages +package session diff --git a/crypto/goolm/session/megolmInboundSession.go b/crypto/goolm/session/megolmInboundSession.go new file mode 100644 index 00000000..5f1d1c8f --- /dev/null +++ b/crypto/goolm/session/megolmInboundSession.go @@ -0,0 +1,273 @@ +package session + +import ( + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/cipher" + "codeberg.org/DerLukas/goolm/crypto" + libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" + "codeberg.org/DerLukas/goolm/megolm" + "codeberg.org/DerLukas/goolm/message" + "codeberg.org/DerLukas/goolm/utilities" + "github.com/pkg/errors" + "maunium.net/go/mautrix/id" +) + +const ( + megolmInboundSessionPickleVersionJSON byte = 1 + megolmInboundSessionPickleVersionLibOlm uint32 = 2 +) + +// MegolmInboundSession stores information about the sessions of receive. +type MegolmInboundSession struct { + Ratchet megolm.Ratchet `json:"ratchet"` + SigningKey crypto.Ed25519PublicKey `json:"signingKey"` + InitialRatchet megolm.Ratchet `json:"initalRatchet"` + SigningKeyVerified bool `json:"signingKeyVerified"` //not used for now +} + +// NewMegolmInboundSession creates a new MegolmInboundSession from a base64 encoded session sharing message. +func NewMegolmInboundSession(input []byte) (*MegolmInboundSession, error) { + var err error + input, err = goolm.Base64Decode(input) + if err != nil { + return nil, err + } + msg := message.MegolmSessionSharing{} + err = msg.VerifyAndDecode(input) + if err != nil { + return nil, err + } + o := &MegolmInboundSession{} + o.SigningKey = msg.PublicKey + o.SigningKeyVerified = true + ratchet, err := megolm.New(msg.Counter, msg.RatchetData) + if err != nil { + return nil, err + } + o.Ratchet = *ratchet + o.InitialRatchet = *ratchet + return o, nil +} + +// NewMegolmInboundSessionFromExport creates a new MegolmInboundSession from a base64 encoded session export message. +func NewMegolmInboundSessionFromExport(input []byte) (*MegolmInboundSession, error) { + var err error + input, err = goolm.Base64Decode(input) + if err != nil { + return nil, err + } + msg := message.MegolmSessionExport{} + err = msg.Decode(input) + if err != nil { + return nil, err + } + o := &MegolmInboundSession{} + o.SigningKey = msg.PublicKey + ratchet, err := megolm.New(msg.Counter, msg.RatchetData) + if err != nil { + return nil, err + } + o.Ratchet = *ratchet + o.InitialRatchet = *ratchet + return o, nil +} + +// MegolmInboundSessionFromPickled loads the MegolmInboundSession details from a pickled base64 string. The input is decrypted with the supplied key. +func MegolmInboundSessionFromPickled(pickled, key []byte) (*MegolmInboundSession, error) { + if len(pickled) == 0 { + return nil, errors.Wrap(goolm.ErrEmptyInput, "megolmInboundSessionFromPickled") + } + a := &MegolmInboundSession{} + err := a.Unpickle(pickled, key) + if err != nil { + return nil, err + } + return a, nil +} + +// getRatchet tries to find the correct ratchet for a messageIndex. +func (o MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, error) { + // pick a megolm instance to use. if we are at or beyond the latest ratchet value, use that + if (messageIndex - o.Ratchet.Counter) < uint32(1<<31) { + o.Ratchet.AdvanceTo(messageIndex) + return &o.Ratchet, nil + } + if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) { + // the counter is before our initial ratchet - we can't decode this + return nil, errors.Wrap(goolm.ErrRatchetNotAvailable, "decrypt") + } + // otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet + copiedRatchet := o.InitialRatchet + copiedRatchet.AdvanceTo(messageIndex) + return &copiedRatchet, nil + +} + +// Decrypt decrypts a base64 encoded group message. +func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error) { + if o.SigningKey == nil { + return nil, 0, errors.Wrap(goolm.ErrBadMessageFormat, "decrypt") + } + decoded, err := goolm.Base64Decode(ciphertext) + if err != nil { + return nil, 0, err + } + msg := &message.GroupMessage{} + err = msg.Decode(decoded) + if err != nil { + return nil, 0, err + } + if msg.Version != protocolVersion { + return nil, 0, errors.Wrap(goolm.ErrWrongProtocolVersion, "decrypt") + } + if msg.Ciphertext == nil || !msg.HasMessageIndex { + return nil, 0, errors.Wrap(goolm.ErrBadMessageFormat, "decrypt") + } + + // verify signature + verifiedSignature := msg.VerifySignatureInline(o.SigningKey, decoded) + if !verifiedSignature { + return nil, 0, errors.Wrap(goolm.ErrBadSignature, "decrypt") + } + + targetRatch, err := o.getRatchet(msg.MessageIndex) + if err != nil { + return nil, 0, err + } + + decrypted, err := targetRatch.Decrypt(decoded, &o.SigningKey, msg) + if err != nil { + return nil, 0, err + } + o.SigningKeyVerified = true + return decrypted, msg.MessageIndex, nil + +} + +// SessionID returns the base64 endoded signing key +func (o MegolmInboundSession) SessionID() id.SessionID { + return id.SessionID(goolm.Base64Encode(o.SigningKey)) +} + +// PickleAsJSON returns an MegolmInboundSession as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. +func (o MegolmInboundSession) PickleAsJSON(key []byte) ([]byte, error) { + return utilities.PickleAsJSON(o, megolmInboundSessionPickleVersionJSON, key) +} + +// UnpickleAsJSON updates an MegolmInboundSession by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. +func (o *MegolmInboundSession) UnpickleAsJSON(pickled, key []byte) error { + return utilities.UnpickleAsJSON(o, pickled, key, megolmInboundSessionPickleVersionJSON) +} + +// SessionExportMessage creates an base64 encoded export of the session. +func (o MegolmInboundSession) SessionExportMessage(messageIndex uint32) ([]byte, error) { + ratchet, err := o.getRatchet(messageIndex) + if err != nil { + return nil, err + } + return ratchet.SessionExportMessage(o.SigningKey) +} + +// Unpickle decodes the base64 encoded string and decrypts the result with the key. +// The decrypted value is then passed to UnpickleLibOlm. +func (o *MegolmInboundSession) Unpickle(pickled, key []byte) error { + decrypted, err := cipher.Unpickle(key, pickled) + if err != nil { + return err + } + _, err = o.UnpickleLibOlm(decrypted) + return err +} + +// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read. +func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) (int, error) { + //First 4 bytes are the accountPickleVersion + pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) + if err != nil { + return 0, err + } + switch pickledVersion { + case megolmInboundSessionPickleVersionLibOlm, 1: + default: + return 0, errors.Wrap(goolm.ErrBadVersion, "unpickle MegolmInboundSession") + } + readBytes, err := o.InitialRatchet.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = o.Ratchet.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = o.SigningKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + if pickledVersion == 1 { + // pickle v1 had no signing_key_verified field (all keyshares were verified at import time) + o.SigningKeyVerified = true + } else { + o.SigningKeyVerified, readBytes, err = libolmpickle.UnpickleBool(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + } + return curPos, nil +} + +// Pickle returns a base64 encoded and with key encrypted pickled MegolmInboundSession using PickleLibOlm(). +func (o MegolmInboundSession) Pickle(key []byte) ([]byte, error) { + pickeledBytes := make([]byte, o.PickleLen()) + written, err := o.PickleLibOlm(pickeledBytes) + if err != nil { + return nil, err + } + if written != len(pickeledBytes) { + return nil, errors.New("number of written bytes not correct") + } + encrypted, err := cipher.Pickle(key, pickeledBytes) + if err != nil { + return nil, err + } + return encrypted, nil +} + +// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (o MegolmInboundSession) PickleLibOlm(target []byte) (int, error) { + if len(target) < o.PickleLen() { + return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle MegolmInboundSession") + } + written := libolmpickle.PickleUInt32(megolmInboundSessionPickleVersionLibOlm, target) + writtenInitRatchet, err := o.InitialRatchet.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle MegolmInboundSession") + } + written += writtenInitRatchet + writtenRatchet, err := o.Ratchet.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle MegolmInboundSession") + } + written += writtenRatchet + writtenPubKey, err := o.SigningKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle MegolmInboundSession") + } + written += writtenPubKey + written += libolmpickle.PickleBool(o.SigningKeyVerified, target[written:]) + return written, nil +} + +// PickleLen returns the number of bytes the pickled session will have. +func (o MegolmInboundSession) PickleLen() int { + length := libolmpickle.PickleUInt32Len(megolmInboundSessionPickleVersionLibOlm) + length += o.InitialRatchet.PickleLen() + length += o.Ratchet.PickleLen() + length += o.SigningKey.PickleLen() + length += libolmpickle.PickleBoolLen(o.SigningKeyVerified) + return length +} diff --git a/crypto/goolm/session/megolmOutboundSession.go b/crypto/goolm/session/megolmOutboundSession.go new file mode 100644 index 00000000..23deb43f --- /dev/null +++ b/crypto/goolm/session/megolmOutboundSession.go @@ -0,0 +1,168 @@ +package session + +import ( + "math/rand" + + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/cipher" + "codeberg.org/DerLukas/goolm/crypto" + libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" + "codeberg.org/DerLukas/goolm/megolm" + "codeberg.org/DerLukas/goolm/utilities" + "github.com/pkg/errors" + "maunium.net/go/mautrix/id" +) + +const ( + megolmOutboundSessionPickleVersion byte = 1 + megolmOutboundSessionPickleVersionLibOlm uint32 = 1 +) + +// MegolmOutboundSession stores information about the sessions to send. +type MegolmOutboundSession struct { + Ratchet megolm.Ratchet `json:"ratchet"` + SigningKey crypto.Ed25519KeyPair `json:"signingKey"` +} + +// NewMegolmOutboundSession creates a new MegolmOutboundSession. +func NewMegolmOutboundSession() (*MegolmOutboundSession, error) { + o := &MegolmOutboundSession{} + var err error + o.SigningKey, err = crypto.Ed25519GenerateKey(nil) + if err != nil { + return nil, err + } + var randomData [megolm.RatchetParts * megolm.RatchetPartLength]byte + _, err = rand.Read(randomData[:]) + if err != nil { + return nil, err + } + ratchet, err := megolm.New(0, randomData) + if err != nil { + return nil, err + } + o.Ratchet = *ratchet + return o, nil +} + +// MegolmOutboundSessionFromPickled loads the MegolmOutboundSession details from a pickled base64 string. The input is decrypted with the supplied key. +func MegolmOutboundSessionFromPickled(pickled, key []byte) (*MegolmOutboundSession, error) { + if len(pickled) == 0 { + return nil, errors.Wrap(goolm.ErrEmptyInput, "megolmOutboundSessionFromPickled") + } + a := &MegolmOutboundSession{} + err := a.Unpickle(pickled, key) + if err != nil { + return nil, err + } + return a, nil +} + +// Encrypt encrypts the plaintext as a base64 encoded group message. +func (o *MegolmOutboundSession) Encrypt(plaintext []byte) ([]byte, error) { + encrypted, err := o.Ratchet.Encrypt(plaintext, &o.SigningKey) + if err != nil { + return nil, err + } + return goolm.Base64Encode(encrypted), nil +} + +// SessionID returns the base64 endoded public signing key +func (o MegolmOutboundSession) SessionID() id.SessionID { + return id.SessionID(goolm.Base64Encode(o.SigningKey.PublicKey)) +} + +// PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. +func (o MegolmOutboundSession) PickleAsJSON(key []byte) ([]byte, error) { + return utilities.PickleAsJSON(o, megolmOutboundSessionPickleVersion, key) +} + +// UnpickleAsJSON updates an Session by a base64 encrypted string with the key. The unencrypted representation has to be in JSON format. +func (o *MegolmOutboundSession) UnpickleAsJSON(pickled, key []byte) error { + return utilities.UnpickleAsJSON(o, pickled, key, megolmOutboundSessionPickleVersion) +} + +// Unpickle decodes the base64 encoded string and decrypts the result with the key. +// The decrypted value is then passed to UnpickleLibOlm. +func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error { + decrypted, err := cipher.Unpickle(key, pickled) + if err != nil { + return err + } + _, err = o.UnpickleLibOlm(decrypted) + return err +} + +// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read. +func (o *MegolmOutboundSession) UnpickleLibOlm(value []byte) (int, error) { + //First 4 bytes are the accountPickleVersion + pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) + if err != nil { + return 0, err + } + switch pickledVersion { + case megolmOutboundSessionPickleVersionLibOlm: + default: + return 0, errors.Wrap(goolm.ErrBadVersion, "unpickle MegolmInboundSession") + } + readBytes, err := o.Ratchet.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = o.SigningKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + return curPos, nil +} + +// Pickle returns a base64 encoded and with key encrypted pickled MegolmOutboundSession using PickleLibOlm(). +func (o MegolmOutboundSession) Pickle(key []byte) ([]byte, error) { + pickeledBytes := make([]byte, o.PickleLen()) + written, err := o.PickleLibOlm(pickeledBytes) + if err != nil { + return nil, err + } + if written != len(pickeledBytes) { + return nil, errors.New("number of written bytes not correct") + } + encrypted, err := cipher.Pickle(key, pickeledBytes) + if err != nil { + return nil, err + } + return encrypted, nil +} + +// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (o MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) { + if len(target) < o.PickleLen() { + return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle MegolmOutboundSession") + } + written := libolmpickle.PickleUInt32(megolmOutboundSessionPickleVersionLibOlm, target) + writtenRatchet, err := o.Ratchet.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle MegolmOutboundSession") + } + written += writtenRatchet + writtenPubKey, err := o.SigningKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle MegolmOutboundSession") + } + written += writtenPubKey + return written, nil +} + +// PickleLen returns the number of bytes the pickled session will have. +func (o MegolmOutboundSession) PickleLen() int { + length := libolmpickle.PickleUInt32Len(megolmOutboundSessionPickleVersionLibOlm) + length += o.Ratchet.PickleLen() + length += o.SigningKey.PickleLen() + return length +} + +func (o MegolmOutboundSession) SessionSharingMessage() ([]byte, error) { + return o.Ratchet.SessionSharingMessage(o.SigningKey) +} diff --git a/crypto/goolm/session/melgomSession_test.go b/crypto/goolm/session/melgomSession_test.go new file mode 100644 index 00000000..55f57ad4 --- /dev/null +++ b/crypto/goolm/session/melgomSession_test.go @@ -0,0 +1,283 @@ +package session + +import ( + "bytes" + "errors" + "math/rand" + "testing" + + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/crypto" + "codeberg.org/DerLukas/goolm/megolm" +) + +func TestOutboundPickleJSON(t *testing.T) { + pickleKey := []byte("secretKey") + session, err := NewMegolmOutboundSession() + if err != nil { + t.Fatal(err) + } + kp, err := crypto.Ed25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + session.SigningKey = kp + pickled, err := session.PickleAsJSON(pickleKey) + if err != nil { + t.Fatal(err) + } + + newSession := MegolmOutboundSession{} + err = newSession.UnpickleAsJSON(pickled, pickleKey) + if err != nil { + t.Fatal(err) + } + if session.SessionID() != newSession.SessionID() { + t.Fatal("session ids not equal") + } + if !bytes.Equal(session.SigningKey.PrivateKey, newSession.SigningKey.PrivateKey) { + t.Fatal("private keys not equal") + } + if !bytes.Equal(session.Ratchet.Data[:], newSession.Ratchet.Data[:]) { + t.Fatal("ratchet data not equal") + } + if session.Ratchet.Counter != newSession.Ratchet.Counter { + t.Fatal("ratchet counter not equal") + } +} + +func TestInboundPickleJSON(t *testing.T) { + pickleKey := []byte("secretKey") + session := MegolmInboundSession{} + kp, err := crypto.Ed25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + session.SigningKey = kp.PublicKey + var randomData [megolm.RatchetParts * megolm.RatchetPartLength]byte + _, err = rand.Read(randomData[:]) + if err != nil { + t.Fatal(err) + } + ratchet, err := megolm.New(0, randomData) + if err != nil { + t.Fatal(err) + } + session.Ratchet = *ratchet + pickled, err := session.PickleAsJSON(pickleKey) + if err != nil { + t.Fatal(err) + } + + newSession := MegolmInboundSession{} + err = newSession.UnpickleAsJSON(pickled, pickleKey) + if err != nil { + t.Fatal(err) + } + if session.SessionID() != newSession.SessionID() { + t.Fatal("session ids not equal") + } + if !bytes.Equal(session.SigningKey, newSession.SigningKey) { + t.Fatal("private keys not equal") + } + if !bytes.Equal(session.Ratchet.Data[:], newSession.Ratchet.Data[:]) { + t.Fatal("ratchet data not equal") + } + if session.Ratchet.Counter != newSession.Ratchet.Counter { + t.Fatal("ratchet counter not equal") + } +} + +func TestGroupSendReceive(t *testing.T) { + randomData := []byte( + "0123456789ABDEF0123456789ABCDEF" + + "0123456789ABDEF0123456789ABCDEF" + + "0123456789ABDEF0123456789ABCDEF" + + "0123456789ABDEF0123456789ABCDEF" + + "0123456789ABDEF0123456789ABCDEF" + + "0123456789ABDEF0123456789ABCDEF", + ) + + outboundSession, err := NewMegolmOutboundSession() + if err != nil { + t.Fatal(err) + } + copy(outboundSession.Ratchet.Data[:], randomData) + if outboundSession.Ratchet.Counter != 0 { + t.Fatal("ratchet counter is not correkt") + } + sessionSharing, err := outboundSession.SessionSharingMessage() + if err != nil { + t.Fatal(err) + } + plainText := []byte("Message") + ciphertext, err := outboundSession.Encrypt(plainText) + if err != nil { + t.Fatal(err) + } + if outboundSession.Ratchet.Counter != 1 { + t.Fatal("ratchet counter is not correkt") + } + + //build inbound session + inboundSession, err := NewMegolmInboundSession(sessionSharing) + if err != nil { + t.Fatal(err) + } + if !inboundSession.SigningKeyVerified { + t.Fatal("key not verified") + } + if inboundSession.SessionID() != outboundSession.SessionID() { + t.Fatal("session ids not equal") + } + + //decode message + decoded, _, err := inboundSession.Decrypt(ciphertext) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plainText, decoded) { + t.Fatal("messages not equal") + } +} + +func TestGroupSessionExportImport(t *testing.T) { + plaintext := []byte("Message") + sessionKey := []byte( + "AgAAAAAwMTIzNDU2Nzg5QUJERUYwMTIzNDU2Nzg5QUJDREVGMDEyMzQ1Njc4OUFCREVGM" + + "DEyMzQ1Njc4OUFCQ0RFRjAxMjM0NTY3ODlBQkRFRjAxMjM0NTY3ODlBQkNERUYwMTIzND" + + "U2Nzg5QUJERUYwMTIzNDU2Nzg5QUJDREVGMDEyMw0bdg1BDq4Px/slBow06q8n/B9WBfw" + + "WYyNOB8DlUmXGGwrFmaSb9bR/eY8xgERrxmP07hFmD9uqA2p8PMHdnV5ysmgufE6oLZ5+" + + "8/mWQOW3VVTnDIlnwd8oHUYRuk8TCQ", + ) + message := []byte( + "AwgAEhAcbh6UpbByoyZxufQ+h2B+8XHMjhR69G8F4+qjMaFlnIXusJZX3r8LnRORG9T3D" + + "XFdbVuvIWrLyRfm4i8QRbe8VPwGRFG57B1CtmxanuP8bHtnnYqlwPsD", + ) + + //init inbound + inboundSession, err := NewMegolmInboundSession(sessionKey) + if err != nil { + t.Fatal(err) + } + if !inboundSession.SigningKeyVerified { + t.Fatal("signing key not verified") + } + + decrypted, _, err := inboundSession.Decrypt(message) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plaintext, decrypted) { + t.Fatal("message is not correct") + } + + //Export the keys + exported, err := inboundSession.SessionExportMessage(0) + if err != nil { + t.Fatal(err) + } + + secondInboundSession, err := NewMegolmInboundSessionFromExport(exported) + if err != nil { + t.Fatal(err) + } + if secondInboundSession.SigningKeyVerified { + t.Fatal("signing key is verified") + } + //decrypt with new session + decrypted, _, err = secondInboundSession.Decrypt(message) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plaintext, decrypted) { + t.Fatal("message is not correct") + } + if !secondInboundSession.SigningKeyVerified { + t.Fatal("signing key not verified") + } +} + +func TestBadSignatureGroupMessage(t *testing.T) { + plaintext := []byte("Message") + sessionKey := []byte( + "AgAAAAAwMTIzNDU2Nzg5QUJERUYwMTIzNDU2Nzg5QUJDREVGMDEyMzQ1Njc4OUFCREVGM" + + "DEyMzQ1Njc4OUFCQ0RFRjAxMjM0NTY3ODlBQkRFRjAxMjM0NTY3ODlBQkNERUYwMTIzND" + + "U2Nzg5QUJERUYwMTIzNDU2Nzg5QUJDREVGMDEyMztqJ7zOtqQtYqOo0CpvDXNlMhV3HeJ" + + "DpjrASKGLWdop4lx1cSN3Xv1TgfLPW8rhGiW+hHiMxd36nRuxscNv9k4oJA/KP+o0mi1w" + + "v44StrEJ1wwx9WZHBUIWkQbaBSuBDw", + ) + message := []byte( + "AwgAEhAcbh6UpbByoyZxufQ+h2B+8XHMjhR69G8nP4pNZGl/3QMgrzCZPmP+F2aPLyKPz" + + "xRPBMUkeXRJ6Iqm5NeOdx2eERgTW7P20CM+lL3Xpk+ZUOOPvsSQNaAL", + ) + + //init inbound + inboundSession, err := NewMegolmInboundSession(sessionKey) + if err != nil { + t.Fatal(err) + } + if !inboundSession.SigningKeyVerified { + t.Fatal("signing key not verified") + } + + decrypted, _, err := inboundSession.Decrypt(message) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plaintext, decrypted) { + t.Fatal("message is not correct") + } + + //Now twiddle the signature + copy(message[len(message)-1:], []byte("E")) + _, _, err = inboundSession.Decrypt(message) + if err == nil { + t.Fatal("Signature was changed but did not cause an error") + } + if !errors.Is(err, goolm.ErrBadSignature) { + t.Fatalf("wrong error %s", err.Error()) + } +} + +func TestOutbountPickle(t *testing.T) { + pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItUO3TiOp5I+6PnQka6n8eHTyIEh3tCetilD+BKnHvtakE0eHHvG6pjEsMNN/vs7lkB5rV6XkoUKHLTE1dAfFunYEeHEZuKQpbG385dBwaMJXt4JrC0hU5jnv6jWNqAA0Ud9GxRDvkp04") + pickleKey := []byte("secret_key") + session, err := MegolmOutboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) + if err != nil { + t.Fatal(err) + } + newPickled, err := session.Pickle(pickleKey) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(pickledDataFromLibOlm, newPickled) { + t.Fatal("pickled version does not equal libolm version") + } + pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...) + _, err = MegolmOutboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) + if err == nil { + t.Fatal("should have gotten an error") + } +} + +func TestInbountPickle(t *testing.T) { + pickledDataFromLibOlm := []byte("1/IPCdtUoQxMba5XT7sjjUW0Hrs7no9duGFnhsEmxzFX2H3qtRc4eaFBRZYXxOBRTGZ6eMgy3IiSrgAQ1gUlSZf5Q4AVKeBkhvN4LZ6hdhQFv91mM+C2C55/4B9/gDjJEbDGiRgLoMqbWPDV+y0F4h0KaR1V1PiTCC7zCi4WdxJQ098nJLgDL4VSsDbnaLcSMO60FOYgRN4KsLaKUGkXiiUBWp4boFMCiuTTOiyH8XlH0e9uWc0vMLyGNUcO8kCbpAnx3v1JTIVan3WGsnGv4K8Qu4M8GAkZewpexrsb2BSNNeLclOV9/cR203Y5KlzXcpiWNXSs8XoB3TLEtHYMnjuakMQfyrcXKIQntg4xPD/+wvfqkcMg9i7pcplQh7X2OK5ylrMZQrZkJ1fAYBGbBz1tykWOjfrZ") + pickleKey := []byte("secret_key") + session, err := MegolmInboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) + if err != nil { + t.Fatal(err) + } + newPickled, err := session.Pickle(pickleKey) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(pickledDataFromLibOlm, newPickled) { + t.Fatal("pickled version does not equal libolm version") + } + pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...) + _, err = MegolmInboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) + if err == nil { + t.Fatal("should have gotten an error") + } +} diff --git a/crypto/goolm/session/olmSession.go b/crypto/goolm/session/olmSession.go new file mode 100644 index 00000000..20cd2af5 --- /dev/null +++ b/crypto/goolm/session/olmSession.go @@ -0,0 +1,475 @@ +package session + +import ( + "bytes" + "fmt" + "io" + + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/cipher" + "codeberg.org/DerLukas/goolm/crypto" + libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" + "codeberg.org/DerLukas/goolm/message" + "codeberg.org/DerLukas/goolm/olm" + "codeberg.org/DerLukas/goolm/utilities" + "github.com/pkg/errors" + "maunium.net/go/mautrix/id" +) + +const ( + olmSessionPickleVersionJSON uint8 = 1 + olmSessionPickleVersionLibOlm uint32 = 1 +) + +const ( + protocolVersion = 0x3 +) + +// OlmSession stores all information for an olm session +type OlmSession struct { + RecievedMessage bool `json:"recievedMessage"` + AliceIdKey crypto.Curve25519PublicKey `json:"aliceIdKey"` + AliceBaseKey crypto.Curve25519PublicKey `json:"aliceBaseKey"` + BobOneTimeKey crypto.Curve25519PublicKey `json:"bobOnTimeKey"` + Ratchet olm.Ratchet `json:"ratchet"` +} + +// used to retrieve a crypto.OneTimeKey from a public key. +type SearchOTKFunc = func(crypto.Curve25519PublicKey) *crypto.OneTimeKey + +// OlmSessionFromJSONPickled loads an OlmSession from a pickled base64 string. Decrypts +// the Session using the supplied key. +func OlmSessionFromJSONPickled(pickled, key []byte) (*OlmSession, error) { + if len(pickled) == 0 { + return nil, errors.Wrap(goolm.ErrEmptyInput, "sessionFromPickled") + } + a := &OlmSession{} + err := a.UnpickleAsJSON(pickled, key) + if err != nil { + return nil, err + } + return a, nil +} + +// OlmSessionFromPickled loads the OlmSession details from a pickled base64 string. The input is decrypted with the supplied key. +func OlmSessionFromPickled(pickled, key []byte) (*OlmSession, error) { + if len(pickled) == 0 { + return nil, errors.Wrap(goolm.ErrEmptyInput, "sessionFromPickled") + } + a := &OlmSession{} + err := a.Unpickle(pickled, key) + if err != nil { + return nil, err + } + return a, nil +} + +// NewSession creates a new Session. +func NewOlmSession() *OlmSession { + s := &OlmSession{} + s.Ratchet = *olm.New() + return s +} + +// NewOutboundSession creates a new outbound session for sending the first message to a +// given curve25519 identityKey and oneTimeKey. +func NewOutboundOlmSession(identityKeyAlice crypto.Curve25519KeyPair, identityKeyBob crypto.Curve25519PublicKey, oneTimeKeyBob crypto.Curve25519PublicKey) (*OlmSession, error) { + s := NewOlmSession() + //generate E_A + baseKey, err := crypto.Curve25519GenerateKey(nil) + if err != nil { + return nil, err + } + //generate T_0 + ratchetKey, err := crypto.Curve25519GenerateKey(nil) + if err != nil { + return nil, err + } + + //Calculate shared secret via Triple Diffie-Hellman + var secret []byte + //ECDH(I_A,E_B) + idSecret, err := identityKeyAlice.SharedSecret(oneTimeKeyBob) + if err != nil { + return nil, err + } + //ECDH(E_A,I_B) + baseIdSecret, err := baseKey.SharedSecret(identityKeyBob) + if err != nil { + return nil, err + } + //ECDH(E_A,E_B) + baseOneTimeSecret, err := baseKey.SharedSecret(oneTimeKeyBob) + if err != nil { + return nil, err + } + secret = append(secret, idSecret...) + secret = append(secret, baseIdSecret...) + secret = append(secret, baseOneTimeSecret...) + //Init Ratchet + s.Ratchet.InitialiseAsAlice(secret, ratchetKey) + s.AliceIdKey = identityKeyAlice.PublicKey + s.AliceBaseKey = baseKey.PublicKey + s.BobOneTimeKey = oneTimeKeyBob + return s, nil +} + +// NewInboundOlmSession creates a new inbound session from receiving the first message. +func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, receivedOTKMsg []byte, searchBobOTK SearchOTKFunc, identityKeyBob crypto.Curve25519KeyPair) (*OlmSession, error) { + decodedOTKMsg, err := goolm.Base64Decode(receivedOTKMsg) + if err != nil { + return nil, err + } + s := NewOlmSession() + + //decode OneTimeKeyMessage + oneTimeMsg := message.PreKeyMessage{} + err = oneTimeMsg.Decode(decodedOTKMsg) + if err != nil { + return nil, errors.Wrap(err, "OneTimeKeyMessage decode") + } + if !oneTimeMsg.CheckFields(identityKeyAlice) { + return nil, errors.Wrap(goolm.ErrBadMessageFormat, "OneTimeKeyMessage check fields") + } + + //Either the identityKeyAlice is set and/or the oneTimeMsg.IdentityKey is set, which is checked + // by oneTimeMsg.CheckFields + if identityKeyAlice != nil && len(oneTimeMsg.IdentityKey) != 0 { + //if both are set, compare them + if !identityKeyAlice.Equal(oneTimeMsg.IdentityKey) { + return nil, errors.Wrap(goolm.ErrBadMessageKeyID, "OneTimeKeyMessage identity keys") + } + } + if identityKeyAlice == nil { + //for downstream use set + identityKeyAlice = &oneTimeMsg.IdentityKey + } + + oneTimeKeyBob := searchBobOTK(oneTimeMsg.OneTimeKey) + if oneTimeKeyBob == nil { + return nil, errors.Wrap(goolm.ErrBadMessageKeyID, "ourOneTimeKey") + } + + //Calculate shared secret via Triple Diffie-Hellman + var secret []byte + //ECDH(E_B,I_A) + idSecret, err := oneTimeKeyBob.Key.SharedSecret(*identityKeyAlice) + if err != nil { + return nil, err + } + //ECDH(I_B,E_A) + baseIdSecret, err := identityKeyBob.SharedSecret(oneTimeMsg.BaseKey) + if err != nil { + return nil, err + } + //ECDH(E_B,E_A) + baseOneTimeSecret, err := oneTimeKeyBob.Key.SharedSecret(oneTimeMsg.BaseKey) + if err != nil { + return nil, err + } + secret = append(secret, idSecret...) + secret = append(secret, baseIdSecret...) + secret = append(secret, baseOneTimeSecret...) + //decode message + msg := message.Message{} + err = msg.Decode(oneTimeMsg.Message) + if err != nil { + return nil, errors.Wrap(err, "Message decode") + } + + if len(msg.RatchetKey) == 0 { + return nil, errors.Wrap(goolm.ErrBadMessageFormat, "Message missing ratchet key") + } + //Init Ratchet + s.Ratchet.InitialiseAsBob(secret, msg.RatchetKey) + s.AliceBaseKey = oneTimeMsg.BaseKey + s.AliceIdKey = oneTimeMsg.IdentityKey + s.BobOneTimeKey = oneTimeKeyBob.Key.PublicKey + + //https://gitlab.matrix.org/matrix-org/olm/blob/master/docs/olm.md states to remove the oneTimeKey + //this is done via the account itself + return s, nil +} + +// PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. +func (a OlmSession) PickleAsJSON(key []byte) ([]byte, error) { + return utilities.PickleAsJSON(a, olmSessionPickleVersionJSON, key) +} + +// UnpickleAsJSON updates an Session by a base64 encrypted string with the key. The unencrypted representation has to be in JSON format. +func (a *OlmSession) UnpickleAsJSON(pickled, key []byte) error { + return utilities.UnpickleAsJSON(a, pickled, key, olmSessionPickleVersionJSON) +} + +// ID returns an identifier for this Session. Will be the same for both ends of the conversation. +// Generated by hashing the public keys used to create the session. +func (s OlmSession) ID() id.SessionID { + message := make([]byte, 3*crypto.Curve25519KeyLength) + copy(message, s.AliceIdKey) + copy(message[crypto.Curve25519KeyLength:], s.AliceBaseKey) + copy(message[2*crypto.Curve25519KeyLength:], s.BobOneTimeKey) + hash := crypto.SHA256(message) + res := id.SessionID(goolm.Base64Encode(hash)) + return res +} + +// HasReceivedMessage returns true if this session has received any message. +func (s OlmSession) HasReceivedMessage() bool { + return s.RecievedMessage +} + +// MatchesInboundSessionFrom checks if the oneTimeKeyMsg message is set for this inbound +// Session. This can happen if multiple messages are sent to this Account +// before this Account sends a message in reply. Returns true if the session +// matches. Returns false if the session does not match. +func (s OlmSession) MatchesInboundSessionFrom(theirIdentityKeyEncoded *id.Curve25519, receivedOTKMsg []byte) (bool, error) { + if len(receivedOTKMsg) == 0 { + return false, errors.Wrap(goolm.ErrEmptyInput, "inbound match") + } + decodedOTKMsg, err := goolm.Base64Decode(receivedOTKMsg) + if err != nil { + return false, err + } + + var theirIdentityKey *crypto.Curve25519PublicKey + if theirIdentityKeyEncoded != nil { + decodedKey, err := goolm.Base64Decode([]byte(*theirIdentityKeyEncoded)) + if err != nil { + return false, err + } + theirIdentityKeyByte := crypto.Curve25519PublicKey(decodedKey) + theirIdentityKey = &theirIdentityKeyByte + } + + msg := message.PreKeyMessage{} + err = msg.Decode(decodedOTKMsg) + if err != nil { + return false, err + } + if !msg.CheckFields(theirIdentityKey) { + return false, nil + } + + same := true + if msg.IdentityKey != nil { + same = same && msg.IdentityKey.Equal(s.AliceIdKey) + } + if theirIdentityKey != nil { + same = same && theirIdentityKey.Equal(s.AliceIdKey) + } + same = same && bytes.Equal(msg.BaseKey, s.AliceBaseKey) + same = same && bytes.Equal(msg.OneTimeKey, s.BobOneTimeKey) + return same, nil +} + +// EncryptMsgType returns the type of the next message that Encrypt will +// return. Returns MsgTypePreKey if the message will be a oneTimeKeyMsg. +// Returns MsgTypeMsg if the message will be a normal message. +func (s OlmSession) EncryptMsgType() id.OlmMsgType { + if s.RecievedMessage { + return id.OlmMsgTypeMsg + } + return id.OlmMsgTypePreKey +} + +// Encrypt encrypts a message using the Session. Returns the encrypted message base64 encoded. If reader is nil, crypto/rand is used for key generations. +func (s *OlmSession) Encrypt(plaintext []byte, reader io.Reader) (id.OlmMsgType, []byte, error) { + if len(plaintext) == 0 { + return 0, nil, errors.Wrap(goolm.ErrEmptyInput, "encrypt") + } + messageType := s.EncryptMsgType() + encrypted, err := s.Ratchet.Encrypt(plaintext, reader) + if err != nil { + return 0, nil, err + } + result := encrypted + if !s.RecievedMessage { + msg := message.PreKeyMessage{} + msg.Version = protocolVersion + msg.OneTimeKey = s.BobOneTimeKey + msg.IdentityKey = s.AliceIdKey + msg.BaseKey = s.AliceBaseKey + msg.Message = encrypted + + var err error + messageBody, err := msg.Encode() + if err != nil { + return 0, nil, err + } + result = messageBody + } + + return messageType, goolm.Base64Encode(result), nil +} + +// Decrypt decrypts a base64 encoded message using the Session. +func (s *OlmSession) Decrypt(crypttext []byte, msgType id.OlmMsgType) ([]byte, error) { + if len(crypttext) == 0 { + return nil, errors.Wrap(goolm.ErrEmptyInput, "decrypt") + } + decodedCrypttext, err := goolm.Base64Decode(crypttext) + if err != nil { + return nil, err + } + msgBody := decodedCrypttext + if msgType != id.OlmMsgTypeMsg { + //Pre-Key Message + msg := message.PreKeyMessage{} + err := msg.Decode(decodedCrypttext) + if err != nil { + return nil, err + } + msgBody = msg.Message + } + plaintext, err := s.Ratchet.Decrypt(msgBody) + if err != nil { + return nil, err + } + s.RecievedMessage = true + return plaintext, nil +} + +// Unpickle decodes the base64 encoded string and decrypts the result with the key. +// The decrypted value is then passed to UnpickleLibOlm. +func (o *OlmSession) Unpickle(pickled, key []byte) error { + decrypted, err := cipher.Unpickle(key, pickled) + if err != nil { + return err + } + _, err = o.UnpickleLibOlm(decrypted) + return err +} + +// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read. +func (o *OlmSession) UnpickleLibOlm(value []byte) (int, error) { + //First 4 bytes are the accountPickleVersion + pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) + if err != nil { + return 0, err + } + includesChainIndex := true + switch pickledVersion { + case olmSessionPickleVersionLibOlm: + includesChainIndex = false + case uint32(0x80000001): + includesChainIndex = true + default: + return 0, errors.Wrap(goolm.ErrBadVersion, "unpickle olmSession") + } + var readBytes int + o.RecievedMessage, readBytes, err = libolmpickle.UnpickleBool(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = o.AliceIdKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = o.AliceBaseKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = o.BobOneTimeKey.UnpickleLibOlm(value[curPos:]) + if err != nil { + return 0, err + } + curPos += readBytes + readBytes, err = o.Ratchet.UnpickleLibOlm(value[curPos:], includesChainIndex) + if err != nil { + return 0, err + } + curPos += readBytes + return curPos, nil +} + +// Pickle returns a base64 encoded and with key encrypted pickled olmSession using PickleLibOlm(). +func (o OlmSession) Pickle(key []byte) ([]byte, error) { + pickeledBytes := make([]byte, o.PickleLen()) + written, err := o.PickleLibOlm(pickeledBytes) + if err != nil { + return nil, err + } + if written != len(pickeledBytes) { + return nil, errors.New("number of written bytes not correct") + } + encrypted, err := cipher.Pickle(key, pickeledBytes) + if err != nil { + return nil, err + } + return encrypted, nil +} + +// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. +// It returns the number of bytes written. +func (o OlmSession) PickleLibOlm(target []byte) (int, error) { + if len(target) < o.PickleLen() { + return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle MegolmOutboundSession") + } + written := libolmpickle.PickleUInt32(olmSessionPickleVersionLibOlm, target) + written += libolmpickle.PickleBool(o.RecievedMessage, target[written:]) + writtenRatchet, err := o.AliceIdKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle MegolmOutboundSession") + } + written += writtenRatchet + writtenRatchet, err = o.AliceBaseKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle MegolmOutboundSession") + } + written += writtenRatchet + writtenRatchet, err = o.BobOneTimeKey.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle MegolmOutboundSession") + } + written += writtenRatchet + writtenRatchet, err = o.Ratchet.PickleLibOlm(target[written:]) + if err != nil { + return 0, errors.Wrap(err, "pickle MegolmOutboundSession") + } + written += writtenRatchet + return written, nil +} + +// PickleLen returns the actual number of bytes the pickled session will have. +func (o OlmSession) PickleLen() int { + length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm) + length += libolmpickle.PickleBoolLen(o.RecievedMessage) + length += o.AliceIdKey.PickleLen() + length += o.AliceBaseKey.PickleLen() + length += o.BobOneTimeKey.PickleLen() + length += o.Ratchet.PickleLen() + return length +} + +// PickleLenMin returns the minimum number of bytes the pickled session must have. +func (o OlmSession) PickleLenMin() int { + length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm) + length += libolmpickle.PickleBoolLen(o.RecievedMessage) + length += o.AliceIdKey.PickleLen() + length += o.AliceBaseKey.PickleLen() + length += o.BobOneTimeKey.PickleLen() + length += o.Ratchet.PickleLenMin() + return length +} + +// Describe returns a string describing the current state of the session for debugging. +func (o OlmSession) Describe() string { + var res string + if o.Ratchet.SenderChains.IsSet { + res += fmt.Sprintf("sender chain index: %d ", o.Ratchet.SenderChains.CKey.Index) + } else { + res += "sender chain index: " + } + res += "receiver chain indicies:" + for _, curChain := range o.Ratchet.ReceiverChains { + res += fmt.Sprintf(" %d", curChain.CKey.Index) + } + res += " skipped message keys:" + for _, curSkip := range o.Ratchet.SkippedMessageKeys { + res += fmt.Sprintf(" %d", curSkip.MKey.Index) + } + return res +} diff --git a/crypto/goolm/session/olmSession_test.go b/crypto/goolm/session/olmSession_test.go new file mode 100644 index 00000000..3fbdb569 --- /dev/null +++ b/crypto/goolm/session/olmSession_test.go @@ -0,0 +1,174 @@ +package session + +import ( + "bytes" + "errors" + "testing" + + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/crypto" + "maunium.net/go/mautrix/id" +) + +func TestOlmSession(t *testing.T) { + pickleKey := []byte("secretKey") + aliceKeyPair, err := crypto.Curve25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + bobKeyPair, err := crypto.Curve25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + bobOneTimeKey, err := crypto.Curve25519GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + aliceSession, err := NewOutboundOlmSession(aliceKeyPair, bobKeyPair.PublicKey, bobOneTimeKey.PublicKey) + if err != nil { + t.Fatal(err) + } + //create a message so that there are more keys to marshal + plaintext := []byte("Test message from Alice to Bob") + msgType, message, err := aliceSession.Encrypt(plaintext, nil) + if err != nil { + t.Fatal(err) + } + if msgType != id.OlmMsgTypePreKey { + t.Fatal("Wrong message type") + } + + searchFunc := func(target crypto.Curve25519PublicKey) *crypto.OneTimeKey { + if target.Equal(bobOneTimeKey.PublicKey) { + return &crypto.OneTimeKey{ + Key: bobOneTimeKey, + Published: false, + ID: 1, + } + } + return nil + } + //bob receives message + bobSession, err := NewInboundOlmSession(nil, message, searchFunc, bobKeyPair) + if err != nil { + t.Fatal(err) + } + decryptedMsg, err := bobSession.Decrypt(message, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plaintext, decryptedMsg) { + t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg) + } + + // Alice pickles session + pickled, err := aliceSession.PickleAsJSON(pickleKey) + if err != nil { + t.Fatal(err) + } + + //bob sends a message + plaintext = []byte("A message from Bob to Alice") + msgType, message, err = bobSession.Encrypt(plaintext, nil) + if err != nil { + t.Fatal(err) + } + if msgType != id.OlmMsgTypeMsg { + t.Fatal("Wrong message type") + } + + //Alice unpickles session + newAliceSession, err := OlmSessionFromJSONPickled(pickled, pickleKey) + if err != nil { + t.Fatal(err) + } + + //Alice receives message + decryptedMsg, err = newAliceSession.Decrypt(message, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plaintext, decryptedMsg) { + t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg) + } + + //Alice receives message again + _, err = newAliceSession.Decrypt(message, msgType) + if err == nil { + t.Fatal("should have gotten an error") + } + + //Alice sends another message + plaintext = []byte("A second message to Bob") + msgType, message, err = newAliceSession.Encrypt(plaintext, nil) + if err != nil { + t.Fatal(err) + } + if msgType != id.OlmMsgTypeMsg { + t.Fatal("Wrong message type") + } + //bob receives message + decryptedMsg, err = bobSession.Decrypt(message, msgType) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plaintext, decryptedMsg) { + t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg) + } +} + +func TestSessionPickle(t *testing.T) { + pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItVKR4ro0O9EAk6LLxJtSnRu5elSUk7YXT") + pickleKey := []byte("secret_key") + session, err := OlmSessionFromPickled(pickledDataFromLibOlm, pickleKey) + if err != nil { + t.Fatal(err) + } + newPickled, err := session.Pickle(pickleKey) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(pickledDataFromLibOlm, newPickled) { + t.Fatal("pickled version does not equal libolm version") + } + pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...) + _, err = OlmSessionFromPickled(pickledDataFromLibOlm, pickleKey) + if err == nil { + t.Fatal("should have gotten an error") + } +} + +func TestDecrypts(t *testing.T) { + messages := [][]byte{ + {0x41, 0x77, 0x6F}, + {0x7f, 0xff, 0x6f, 0x01, 0x01, 0x34, 0x6d, 0x67, 0x12, 0x01}, + {0xee, 0x77, 0x6f, 0x41, 0x49, 0x6f, 0x67, 0x41, 0x77, 0x80, 0x41, 0x77, 0x77, 0x80, 0x41, 0x77, 0x6f, 0x67, 0x16, 0x67, 0x0a, 0x67, 0x7d, 0x6f, 0x67, 0x0a, 0x67, 0xc2, 0x67, 0x7d}, + {0xe9, 0xe9, 0xc9, 0xc1, 0xe9, 0xe9, 0xc9, 0xe9, 0xc9, 0xc1, 0xe9, 0xe9, 0xc9, 0xc1}, + } + expectedErr := []error{ + goolm.ErrInputToSmall, + goolm.ErrBadBase64, + goolm.ErrBadBase64, + goolm.ErrBadBase64, + } + sessionPickled := []byte("E0p44KO2y2pzp9FIjv0rud2wIvWDi2dx367kP4Fz/9JCMrH+aG369HGymkFtk0+PINTLB9lQRt" + + "ohea5d7G/UXQx3r5y4IWuyh1xaRnojEZQ9a5HRZSNtvmZ9NY1f1gutYa4UtcZcbvczN8b/5Bqg" + + "e16cPUH1v62JKLlhoAJwRkH1wU6fbyOudERg5gdXA971btR+Q2V8GKbVbO5fGKL5phmEPVXyMs" + + "rfjLdzQrgjOTxN8Pf6iuP+WFPvfnR9lDmNCFxJUVAdLIMnLuAdxf1TGcS+zzCzEE8btIZ99mHF" + + "dGvPXeH8qLeNZA") + pickleKey := []byte("") + session, err := OlmSessionFromPickled(sessionPickled, pickleKey) + if err != nil { + t.Fatal(err) + } + for curIndex, curMessage := range messages { + _, err := session.Decrypt(curMessage, id.OlmMsgTypePreKey) + if err != nil { + if !errors.Is(err, expectedErr[curIndex]) { + t.Fatal(err) + } + } else { + t.Fatal("error expected") + } + } +} diff --git a/crypto/goolm/utilities/main.go b/crypto/goolm/utilities/main.go new file mode 100644 index 00000000..48ff37aa --- /dev/null +++ b/crypto/goolm/utilities/main.go @@ -0,0 +1,25 @@ +package utilities + +import ( + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/crypto" + "maunium.net/go/mautrix/id" +) + +func Sha256(value []byte) []byte { + return goolm.Base64Encode(crypto.SHA256((value))) +} + +// VerifySignature verifies an ed25519 signature. +func VerifySignature(message []byte, key id.Ed25519, signature []byte) (ok bool, err error) { + keyDecoded, err := goolm.Base64Decode([]byte(key)) + if err != nil { + return false, err + } + signatureDecoded, err := goolm.Base64Decode(signature) + if err != nil { + return false, err + } + publicKey := crypto.Ed25519PublicKey(keyDecoded) + return publicKey.Verify(message, signatureDecoded), nil +} diff --git a/crypto/goolm/utilities/main_test.go b/crypto/goolm/utilities/main_test.go new file mode 100644 index 00000000..9019338b --- /dev/null +++ b/crypto/goolm/utilities/main_test.go @@ -0,0 +1,14 @@ +package utilities + +import ( + "bytes" + "testing" +) + +func TestSHA256(t *testing.T) { + plainText := []byte("Hello, World") + expected := []byte("A2daxT/5zRU1zMffzfosRYxSGDcfQY3BNvLRmsH76KU") + if !bytes.Equal(Sha256(plainText), expected) { + t.Fatal("sha256 failed") + } +} diff --git a/crypto/goolm/utilities/pickle.go b/crypto/goolm/utilities/pickle.go new file mode 100644 index 00000000..6bfbcbc2 --- /dev/null +++ b/crypto/goolm/utilities/pickle.go @@ -0,0 +1,60 @@ +package utilities + +import ( + "encoding/json" + + "codeberg.org/DerLukas/goolm" + "codeberg.org/DerLukas/goolm/cipher" + "github.com/pkg/errors" +) + +// PickleAsJSON returns an object as a base64 string encrypted using the supplied key. The unencrypted representation of the object is in JSON format. +func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, errors.Wrap(goolm.ErrNoKeyProvided, "pickle") + } + marshaled, err := json.Marshal(object) + if err != nil { + return nil, errors.Wrap(err, "pickle marshal") + } + marshaled = append([]byte{pickleVersion}, marshaled...) + toEncrypt := make([]byte, len(marshaled)) + copy(toEncrypt, marshaled) + //pad marshaled to get block size + if len(marshaled)%cipher.PickleBlockSize() != 0 { + padding := cipher.PickleBlockSize() - len(marshaled)%cipher.PickleBlockSize() + toEncrypt = make([]byte, len(marshaled)+padding) + copy(toEncrypt, marshaled) + } + encrypted, err := cipher.Pickle(key, toEncrypt) + if err != nil { + return nil, errors.Wrap(err, "pickle encrypt") + } + return encrypted, nil +} + +// UnpickleAsJSON updates the object by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. +func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error { + if len(key) == 0 { + return errors.Wrap(goolm.ErrNoKeyProvided, "unpickle") + } + decrypted, err := cipher.Unpickle(key, pickled) + if err != nil { + return errors.Wrap(err, "unpickle decrypt") + } + //unpad decrypted so unmarshal works + for i := len(decrypted) - 1; i >= 0; i-- { + if decrypted[i] != 0 { + decrypted = decrypted[:i+1] + break + } + } + if decrypted[0] != pickleVersion { + return errors.Wrap(goolm.ErrWrongPickleVersion, "unpickle") + } + err = json.Unmarshal(decrypted[1:], object) + if err != nil { + return errors.Wrap(err, "unpickle unmarshal") + } + return nil +} From ab39495bc67ee0e9a3578588a4c81cc2b6da7ce5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 15 Dec 2023 13:27:29 +0200 Subject: [PATCH 0029/1647] Clean up goolm and update import path --- crypto/goolm/.gitignore | 48 ------------- crypto/goolm/README.md | 34 ++-------- crypto/goolm/account/account.go | 51 +++++++------- ...ount_test_data.go => account_data_test.go} | 4 +- crypto/goolm/account/account_test.go | 38 ++++++----- crypto/goolm/base64.go | 33 ++------- crypto/goolm/base64_test.go | 47 ------------- .../cipher/{aesSha256.go => aes_sha256.go} | 26 +++---- .../{aesSha256_test.go => aes_sha256_test.go} | 2 +- crypto/goolm/cipher/pickle.go | 15 +++-- crypto/goolm/cipher/pickle_test.go | 8 ++- crypto/goolm/crypto/{aesCBC.go => aes_cbc.go} | 16 ++--- .../{aesCBC_test.go => aes_cbc_test.go} | 12 ++-- crypto/goolm/crypto/curve25519.go | 16 +++-- crypto/goolm/crypto/curve25519_test.go | 26 +++---- crypto/goolm/crypto/ed25519.go | 17 ++--- crypto/goolm/crypto/ed25519_test.go | 22 +++--- crypto/goolm/crypto/hmac_test.go | 25 +++---- .../crypto/{oneTimeKey.go => one_time_key.go} | 14 ++-- crypto/goolm/errors.go | 2 +- crypto/goolm/go.mod | 9 --- crypto/goolm/go.sum | 20 ------ crypto/goolm/libolmVersion.md | 3 - .../{libolmPickle => libolmpickle}/pickle.go | 0 .../pickle_test.go | 20 +++--- .../unpickle.go | 13 ++-- .../unpickle_test.go | 12 ++-- crypto/goolm/megolm/megolm.go | 22 +++--- crypto/goolm/megolm/megolm_test.go | 26 +++---- crypto/goolm/message/decoder.go | 7 +- .../{groupMessage.go => group_message.go} | 4 +- ...pMessage_test.go => group_message_test.go} | 8 ++- crypto/goolm/message/message.go | 4 +- crypto/goolm/message/message_test.go | 8 ++- .../{preKeyMessage.go => prekey_message.go} | 2 +- ...Message_test.go => prekey_message_test.go} | 9 +-- .../{sessionExport.go => session_export.go} | 10 +-- .../{sessionSharing.go => session_sharing.go} | 12 ++-- crypto/goolm/olm/chain.go | 27 ++++---- crypto/goolm/olm/olm.go | 48 ++++++------- crypto/goolm/olm/olm_test.go | 19 +++--- .../{skippedMessage.go => skipped_message.go} | 13 ++-- crypto/goolm/pk/decryption.go | 27 ++++---- crypto/goolm/pk/encryption.go | 13 ++-- crypto/goolm/pk/pk_test.go | 20 +++--- crypto/goolm/pk/signing.go | 4 +- crypto/goolm/sas/main.go | 4 +- crypto/goolm/sas/main_test.go | 11 +-- ...ndSession.go => megolm_inbound_session.go} | 43 ++++++------ ...dSession.go => megolm_outbound_session.go} | 29 ++++---- ...Session_test.go => megolm_session_test.go} | 67 ++++++++++--------- .../session/{olmSession.go => olm_session.go} | 53 ++++++++------- ...olmSession_test.go => olm_session_test.go} | 31 +++++---- crypto/goolm/utilities/main.go | 12 ++-- crypto/goolm/utilities/main_test.go | 14 ---- crypto/goolm/utilities/pickle.go | 20 +++--- 56 files changed, 482 insertions(+), 618 deletions(-) delete mode 100644 crypto/goolm/.gitignore rename crypto/goolm/account/{account_test_data.go => account_data_test.go} (99%) delete mode 100644 crypto/goolm/base64_test.go rename crypto/goolm/cipher/{aesSha256.go => aes_sha256.go} (74%) rename crypto/goolm/cipher/{aesSha256_test.go => aes_sha256_test.go} (97%) rename crypto/goolm/crypto/{aesCBC.go => aes_cbc.go} (80%) rename crypto/goolm/crypto/{aesCBC_test.go => aes_cbc_test.go} (81%) rename crypto/goolm/crypto/{oneTimeKey.go => one_time_key.go} (89%) delete mode 100644 crypto/goolm/go.mod delete mode 100644 crypto/goolm/go.sum delete mode 100644 crypto/goolm/libolmVersion.md rename crypto/goolm/{libolmPickle => libolmpickle}/pickle.go (100%) rename crypto/goolm/{libolmPickle => libolmpickle}/pickle_test.go (74%) rename crypto/goolm/{libolmPickle => libolmpickle}/unpickle.go (69%) rename crypto/goolm/{libolmPickle => libolmpickle}/unpickle_test.go (82%) rename crypto/goolm/message/{groupMessage.go => group_message.go} (98%) rename crypto/goolm/message/{groupMessage_test.go => group_message_test.go} (91%) rename crypto/goolm/message/{preKeyMessage.go => prekey_message.go} (98%) rename crypto/goolm/message/{preKeyMessage_test.go => prekey_message_test.go} (91%) rename crypto/goolm/message/{sessionExport.go => session_export.go} (82%) rename crypto/goolm/message/{sessionSharing.go => session_sharing.go} (83%) rename crypto/goolm/olm/{skippedMessage.go => skipped_message.go} (82%) rename crypto/goolm/session/{megolmInboundSession.go => megolm_inbound_session.go} (86%) rename crypto/goolm/session/{megolmOutboundSession.go => megolm_outbound_session.go} (86%) rename crypto/goolm/session/{melgomSession_test.go => megolm_session_test.go} (78%) rename crypto/goolm/session/{olmSession.go => olm_session.go} (89%) rename crypto/goolm/session/{olmSession_test.go => olm_session_test.go} (82%) delete mode 100644 crypto/goolm/utilities/main_test.go diff --git a/crypto/goolm/.gitignore b/crypto/goolm/.gitignore deleted file mode 100644 index 8cb9a443..00000000 --- a/crypto/goolm/.gitignore +++ /dev/null @@ -1,48 +0,0 @@ -# ---> Go -# If you prefer the allow list template instead of the deny list, see community template: -# https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore -# -# Binaries for programs and plugins -*.exe -*.exe~ -*.dll -*.so -*.dylib - -# Test binary, built with `go test -c` -*.test - -# Output of the go coverage tool, specifically when used with LiteIDE -*.out - -# Dependency directories (remove the comment below to include it) -# vendor/ - -# Go workspace file -go.work - -# ---> Go.AllowList -# Allowlisting gitignore template for GO projects prevents us -# from adding various unwanted local files, such as generated -# files, developer configurations or IDE-specific files etc. -# -# Recommended: Go.AllowList.gitignore - -# Ignore everything -* - -# But not these files... -!/.gitignore - -!*.go -!go.sum -!go.mod - -!README.md -!LICENSE -!libolmVersion.md - -# !Makefile - -# ...even if they are in subdirectories -!*/ diff --git a/crypto/goolm/README.md b/crypto/goolm/README.md index a34b27ca..c5eaa0af 100644 --- a/crypto/goolm/README.md +++ b/crypto/goolm/README.md @@ -1,31 +1,5 @@ -# goolm +# go-olm +This is a fork of [DerLukas's goolm](https://codeberg.org/DerLukas/goolm), +a pure Go implementation of libolm. -[![Please don't upload to GitHub](https://nogithub.codeberg.page/badge.svg)](https://nogithub.codeberg.page) -[![GoDoc](https://godoc.org/codeberg.org/DerLukas/goolm?status.svg)](https://godoc.org/codeberg.org/DerLukas/goolm) - -### A Go implementation of Olm and Megolm - -goolm is a pure Go implementation of libolm. Libolm is a cryptographic library used for end-to-end encryption in Matrix and wirtten in C++. -With goolm there is no need to use cgo when building Matrix clients in go. - -See the GoDoc for usage. - -This package is written to be a easily used in github.com/mautrix/go/crypto/olm. - -PR's are always welcome. - -# Features - -* Test files for most methods and functions adapted from libolm - -## Supported -* [Olm](https://matrix-org.github.io/vodozemac/vodozemac/olm/index.html) -* Pickle structs with encryption using JSON marshalling -* Pickle structs with encryption using the libolm format -* [Megolm](https://matrix-org.github.io/vodozemac/vodozemac/megolm/index.html) -* Inbound and outbound group sessions -* [SAS](https://matrix.org/docs/guides/implementing-more-advanced-e-2-ee-features-such-as-cross-signing) support - -# License - -MIT licensed. See the LICENSE file for details. +The original project is licensed under the MIT license. diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index e7e3beba..168818bb 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -2,17 +2,20 @@ package account import ( + "encoding/base64" "encoding/json" + "errors" + "fmt" "io" - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/cipher" - "codeberg.org/DerLukas/goolm/crypto" - libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" - "codeberg.org/DerLukas/goolm/session" - "codeberg.org/DerLukas/goolm/utilities" - "github.com/pkg/errors" "maunium.net/go/mautrix/id" + + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/goolm/utilities" ) const ( @@ -39,7 +42,7 @@ type Account struct { // AccountFromJSONPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key. func AccountFromJSONPickled(pickled, key []byte) (*Account, error) { if len(pickled) == 0 { - return nil, errors.Wrap(goolm.ErrEmptyInput, "accountFromPickled") + return nil, fmt.Errorf("accountFromPickled: %w", goolm.ErrEmptyInput) } a := &Account{} err := a.UnpickleAsJSON(pickled, key) @@ -52,7 +55,7 @@ func AccountFromJSONPickled(pickled, key []byte) (*Account, error) { // AccountFromPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key. func AccountFromPickled(pickled, key []byte) (*Account, error) { if len(pickled) == 0 { - return nil, errors.Wrap(goolm.ErrEmptyInput, "accountFromPickled") + return nil, fmt.Errorf("accountFromPickled: %w", goolm.ErrEmptyInput) } a := &Account{} err := a.Unpickle(pickled, key) @@ -102,15 +105,15 @@ func (a Account) IdentityKeysJSON() ([]byte, error) { // IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity keys for the Account. func (a Account) IdentityKeys() (id.Ed25519, id.Curve25519) { - ed25519 := id.Ed25519(goolm.Base64Encode(a.IdKeys.Ed25519.PublicKey)) - curve25519 := id.Curve25519(goolm.Base64Encode(a.IdKeys.Curve25519.PublicKey)) + ed25519 := id.Ed25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.PublicKey)) + curve25519 := id.Curve25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Curve25519.PublicKey)) return ed25519, curve25519 } // Sign returns the signature of a message using the Ed25519 key for this Account. func (a Account) Sign(message []byte) ([]byte, error) { if len(message) == 0 { - return nil, errors.Wrap(goolm.ErrEmptyInput, "sign") + return nil, fmt.Errorf("sign: %w", goolm.ErrEmptyInput) } return goolm.Base64Encode(a.IdKeys.Ed25519.Sign(message)), nil } @@ -184,13 +187,13 @@ func (a *Account) GenOneTimeKeys(reader io.Reader, num uint) error { // given curve25519 identity Key and one time key. func (a Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*session.OlmSession, error) { if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { - return nil, errors.Wrap(goolm.ErrEmptyInput, "outbound session") + return nil, fmt.Errorf("outbound session: %w", goolm.ErrEmptyInput) } - theirIdentityKeyDecoded, err := goolm.Base64Decode([]byte(theirIdentityKey)) + theirIdentityKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(theirIdentityKey)) if err != nil { return nil, err } - theirOneTimeKeyDecoded, err := goolm.Base64Decode([]byte(theirOneTimeKey)) + theirOneTimeKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(theirOneTimeKey)) if err != nil { return nil, err } @@ -204,12 +207,12 @@ func (a Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25 // NewInboundSession creates a new inbound session from an incoming PRE_KEY message. func (a Account) NewInboundSession(theirIdentityKey *id.Curve25519, oneTimeKeyMsg []byte) (*session.OlmSession, error) { if len(oneTimeKeyMsg) == 0 { - return nil, errors.Wrap(goolm.ErrEmptyInput, "inbound session") + return nil, fmt.Errorf("inbound session: %w", goolm.ErrEmptyInput) } var theirIdentityKeyDecoded *crypto.Curve25519PublicKey var err error if theirIdentityKey != nil { - theirIdentityKeyDecodedByte, err := goolm.Base64Decode([]byte(*theirIdentityKey)) + theirIdentityKeyDecodedByte, err := base64.RawStdEncoding.DecodeString(string(*theirIdentityKey)) if err != nil { return nil, err } @@ -356,7 +359,7 @@ func (a *Account) UnpickleLibOlm(value []byte) (int, error) { switch pickledVersion { case accountPickleVersionLibOLM, 3, 2: default: - return 0, errors.Wrap(goolm.ErrBadVersion, "unpickle account") + return 0, fmt.Errorf("unpickle account: %w", goolm.ErrBadVersion) } //read ed25519 key pair readBytes, err := a.IdKeys.Ed25519.UnpickleLibOlm(value[curPos:]) @@ -464,24 +467,24 @@ func (a Account) Pickle(key []byte) ([]byte, error) { // It returns the number of bytes written. func (a Account) PickleLibOlm(target []byte) (int, error) { if len(target) < a.PickleLen() { - return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle account") + return 0, fmt.Errorf("pickle account: %w", goolm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(accountPickleVersionLibOLM, target) writtenEdKey, err := a.IdKeys.Ed25519.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle account") + return 0, fmt.Errorf("pickle account: %w", err) } written += writtenEdKey writtenCurveKey, err := a.IdKeys.Curve25519.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle account") + return 0, fmt.Errorf("pickle account: %w", err) } written += writtenCurveKey written += libolmpickle.PickleUInt32(uint32(len(a.OTKeys)), target[written:]) for _, curOTKey := range a.OTKeys { writtenOT, err := curOTKey.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle account") + return 0, fmt.Errorf("pickle account: %w", err) } written += writtenOT } @@ -489,14 +492,14 @@ func (a Account) PickleLibOlm(target []byte) (int, error) { if a.NumFallbackKeys >= 1 { writtenOT, err := a.CurrentFallbackKey.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle account") + return 0, fmt.Errorf("pickle account: %w", err) } written += writtenOT if a.NumFallbackKeys >= 2 { writtenOT, err := a.PrevFallbackKey.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle account") + return 0, fmt.Errorf("pickle account: %w", err) } written += writtenOT } diff --git a/crypto/goolm/account/account_test_data.go b/crypto/goolm/account/account_data_test.go similarity index 99% rename from crypto/goolm/account/account_test_data.go rename to crypto/goolm/account/account_data_test.go index 742b6ac2..739421ff 100644 --- a/crypto/goolm/account/account_test_data.go +++ b/crypto/goolm/account/account_data_test.go @@ -1,6 +1,6 @@ -package account +package account_test -import "codeberg.org/DerLukas/goolm/crypto" +import "maunium.net/go/mautrix/crypto/goolm/crypto" var pickledDataFromLibOlm = []byte("b3jGWBenkTv6DJt90OX+H1ecoXQwihBjhdJHkAft49wS7ubT3Z0ta46p9PCnfKs+fOHeKhJzgfFcD5yCoatcpRzMHRri6V1dG/wMIu8nYvPPMZ8Dy5YlMBRGz0cpnOAhVoUzo/HtvyN8kgoYnZLzorVYepIqQcsLZAiG6qlztXepEflwNG619Rrk/zWYae5RBtxz9Cl0KCTj8cjY5J/SEKU+SCnj4n16wa+RfYXuLK/kBlE30uSWqQBInlLLYiSqOGjr8M0x+3A0eG0gYA+Aohwl5MbjQnDniTbQeg1gh3VwWZ6kJCgRpLnT0j6oc6V4HjP0JjseHe0rBr6W9o88sl6wGmVEr2ZjlvcD6hoCK21A98UZF0GTwHrX0zV7OQtn5cmys3A1xdgcBAo/GXte1d2HzBXSmgrnXExK3Ij+BkZoQSuEFWSUCLjCUFQohK8TfraLZ5+9sOaV/5KaUxqdBTi6HUqoYymCHxzG7olo3hlh+GJ+iOy9tnofqDirISDIIL7KJ2zNJxYHWNZrAVNxHF3rPSrBw9Zl6M2Scm9PdDqnPgGZ+MSCrCnT6UrrfmWurPahnXwdvPED9rykLtcy3aKFsB+RIezeoNdq/4d8sYVwFWd0w9HTXEYG1YY/Km2LiK/exaC98agPN4GXakCUHVZSfz59IZ3bH8jQdtZw2BPkVfNCUoTuUFzIk0LS3AtudCTiUaFdul9/Phj7TlvyvIKH8GUFRiV46fxMHJ9U0HGg6VAKtDR5qkB/nB1X8SWTmmZblR+jGOQvE6VCXSEoCdSyjYK+xtwlZsMHoFoci+NN/uxnrfoMd0D+TpNOyNFwdjn9hoS8kmqvMEyhae0Q5N4O7YwHiH/jZ9ruFTCMK10TeyFN3yxKiRKhiJkgd9bGnmHz25dm9EDkR4i1pXUFuSZCO7WPUX4aLiNMwcltW01EiTdvE7e2jgaoRguXI8gimvOZ/d8kKZ8QIgKURZZHXmud93MOXL3sAy/aBU8dBMt+E5mVeoGM2fns8o5D9Yx3gZ6CgkzmzWinfj82qyc459OcyeyV/gugEt3FI28UBMRfghIV0juOGTAjkh6G3wIyZVk2G4rG0mYONrQQhmgKf06szNQXFBHQ2Pju4pY+QEAng3D2CfXFV6S2bUVeXN0fk46afsV84WPwYg77DWTuR81Ck2arbIcKGsSpMrETMY65rtEoMAcXLzmWgPsIXdo7k+aR4mWmcxjW5a10Wxc1knOi39x0M6gnYGbhmj6IxmalzVFOjG1ZtkFL5fs59nK42aP/JZ0SdtTJjJA0PkbEFL3YOmVtRUmizVtZk63JYuyCgw36XLscTb3VWVynLONYa1RPyRLaz8L5FkTVySCFb8gP9KtDipBpPdGIeD0MGRijAPLweB4iDkM6zv9Yu6dMijZgSR0g6LmjQZPcm1YfI9AK2ht86oKJfvpj+UdYkK+wKKNzJKjKN08+mIKYbsumpbgKqx13d8sawKC4EfGAJHXsadat77Kp/ECCvhh7/i6gqWBHD0+I2LGiuQTr/Vd6OxAGmtyFOzdSGsfWm0cq78Lc6og7HTg3n7TbnEfJMaQktAI7vQYqnsvZV/KnfZ1elfPubFaFiHmCJzfkuk4X4y6r5A6FxpuEltvrHRtecQ6FHLHsBSZrUg7Dei9urMonphUfEj4rsVOMlB3ZKiQ0unmWmacmFKkHm+WhpQtzLS57/iuGCdKiL8qWBiCz80bCQXQp1iwdScZ+pZQ5pwVABH0sr9YfQEz6+oMh3Kp5LiXJAy7kNEs20h4oMP1bc/gN+F5cRubHrz3sXHWpAXF/pNw862Lj7rL0PPOZdomgHSpmKybaQyJxemlxOP9eFw2r2aym/6jc4nQoR+1Mu1ijaroJ8MgZSwTKmru+xgJXnwLx8i76iRlze2F00iNOMg/pRtFQmWh/zLsKukFtIi2PgPKo8xNRQgYvB76x0jLaX3cllpu/pL4LIM2q0p+V9+FPBPCeinkDA7jQzy397vlOSbdECuPYaj2JmkH4zKbxdDlZffgfjWCLDFkqPc0Ixz8O1k24yfkqwG792anEValGM/Hnhdh3a24y4dTV28eo1SoJ6pD1yrjP2RNvgeqs2xEbKOxPywmmOjq0zc805cXBTjDOVyeSFiiY3yJ2GN33KXv67svXw/Ky0Nl9Epbk6xbSB5b76HPAjJ1gZXEUE2zeRTVVOAWzrCERUUjcccz1ozde//rOjzEt+3fxdrq/oglOMW5Ge7ddo/a5lDRs10G3KHreT1sZz9NmA0U7VOQDwwosmt1AYB1LPg3HM2mwi8ahbZf62K1o/W450RWuG072u1unrCYBNYYArN4lnE11L5LvxctFT8qwbBPD1rHAs+Pr4GKQkTWOmUM7wS59GGWE3UEDrntj1Our5NNto7bjK05h33GRF+Vge1+8EfZ+eL2aikjeq/5dU/cN2aw5v3FCkrXzXbIX4YQMw0nu+MbbqXzmPAa7ibS5z0puTYiVs/iMC0ElMSbp9l65iosfVE+kjlF5QDVJaZoW1rTJ8ACo3zOIH6tv515OPvcxDl08nqw/eH37g/bridClsFlJu8aS8Up6Pxzx7hIjNukoDG/wm9qwN/tGgOhZaotzSQEmJTHy7rWc/tamPEJvzZ0Ev99FAlu816Q/HSSk//y8ZriufU5kvYgz0jVeotTShi5LQK30EDTZtjSxE2eJjpRwAmyD6iQwQQiQkVj5inBB6FbF9u1iVhgVsoQ52YEN5id809NSFasmjJ+szJm0E0e6WMfU3tX1j4nPiIdqV7XF8E01pTIZRBmJVUIeiq7XBb5fARjFBveFwwC/Ck3XykYz7CVQJsGOfk6VqlOzzcKhDOivDsVUPbtWJZgP3qC3sJp0dZV1c1BcDHUVbi8HC81F4zobEQGUOyTFsfeiiLUOHvBsveLt87EYpYe2rwjdnEVN5IU3NG2spMc0C7DeEiovv6GCdWgpoHHwPknS42Yv1ltjoDSrlgVF6nXyFXfVWN9CkZoOsEvoUEXGrHnLMT2RPJIWwZFGWXnHs+WKkEfFJbXJvUfTMlPM+hPOIaurwkKmJqh3cCzIGezzWFfSqM7Lv9Fth9FR1QiUcyK2eg1yc7F3lpicxf86RdcbNZD8Wep1uLA0/5zQYmiHA2ZUBFgpN0KanmykSWMHirsBXF4ixwsduRvo4YqqgIMLDrQInJMt5xvHTwHhKMgSOdpReDwM8zYYBV7ukon3cUEo8pwKuXCLs++DNU40bIyP7bpB4rL2Bm0ojNxsTsQwHO/vqXIDPPsyUEtd6J/JuL364XSP9yE5lnk2g5j3rA7uaq/DpA7m1PO+dE0nZcduFH+5OMC/tMZK1PK0Uc3z7LwIaaOzLBraqm16PdJjLCo5YcQgL5GqXN20sznKrRApqq17gdNayaZoEqG1cO0HzmDcr39Ombh/KrbtYFPkBlNbs/9tjgSOgZ5/aWb/gA6lYzspsKdJN0JAD88uYYBY2+GqdJRB86OP+DRHVRrpHOuLkcWNKILmVN5WFljoQXBvtER2IML4GHZytus+E6o2HCq+5Sh5dJlCBKRHIbzs9XPlEWbcE6Z6inuR5zbVYb97VcOhaM83ZjZZGVfdyqX9QagQ7k7eP2Ifju7vZkhz+HHa1v94HeYRGEKZvPRv2nPuzD1bOauwmhu0WKzYXXoSL40XdvSMNhGlKI65uxh6O1DoIJ8fiB3cFZjK2Li1gYMBzw7Mt8R2sSks3iKuiePX0JRTl0cBoiX7RmZuOlJet1HnUhbliS9vncRGZ/aRAL6VpQ5066/FoOq9ahWK+aSmxidcuRHqyIHhqJtUJrP/44vwDWBdZljgu4ysQvv3Fd+reJ1T4BZlY2jWJK1YCtMQMR2TeEaD7c6a7FtoexX3wuWDEe7er55nE2dW8uRjrptSdpAHFXHnmHmjzhP5Mr+MMAbDcK+3YrnsKYnmSdQP1EYPQ8dyIuc07BpypOZLWFiZpRcLwx2mAp6wFULrPptCWfPCcOgzghBa9l8eVDm/zsfZcP1FP6/J8uSbAI5YUz61gcR8VG1DKCX2e5kIk2xoPB0DlRfltx8NlT2VjLz5xKyyyaq0GpE0EfUXu0q4s2Q9tpk") var expectedEd25519KeyPairPickleLibOLM = crypto.Ed25519KeyPair{ diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go index 1f2b9546..eebf4f0b 100644 --- a/crypto/goolm/account/account_test.go +++ b/crypto/goolm/account/account_test.go @@ -1,13 +1,15 @@ -package account +package account_test import ( "bytes" "errors" "testing" - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/utilities" "maunium.net/go/mautrix/id" + + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/account" + "maunium.net/go/mautrix/crypto/goolm/utilities" ) type mockRandom struct { @@ -43,7 +45,7 @@ func (m *mockRandom) Read(target []byte) (int, error) { } func TestAccount(t *testing.T) { - firstAccount, err := NewAccount(nil) + firstAccount, err := account.NewAccount(nil) if err != nil { t.Fatal(err) } @@ -62,7 +64,7 @@ func TestAccount(t *testing.T) { t.Fatal(err) } //now unpickle into new Account - unpickledAccount, err := AccountFromJSONPickled(pickled, encryptionKey) + unpickledAccount, err := account.AccountFromJSONPickled(pickled, encryptionKey) if err != nil { t.Fatal(err) } @@ -129,7 +131,7 @@ func TestAccountPickleJSON(t *testing.T) { */ pickledData := []byte("fZG5DhZ0+uhVFEcdgo/dyWNy1BlSKo+W18D/QLBcZfvP0rByRzjgJM5yeDIO9N6jYFp2MbV1Y1DikFlDctwq7PhIRvbtLdrzxT94WoLrUdiNtQkw6NRNXvsFYo4NKoAgl1yQauttnGRBHCCPVV6e9d4kvnPVRkZNkbbANnadF0Tld/SMMWWoPI3L7dy+oiRh6nqNKvZz+upvgmOSm6gu2xV0yx9RJpkvLz8oHMDui1VQ1T2wTpfk5vdw0Cx4BXspf8WDnntdv0Ui4qBzUFmsB4lfqLviuhnAxu+qQrrKcZz/EyzbPwmI+P4Tn5KznxzEx2Nw/AjKKPxqVAKpx8+nV7rKKzlah71wX2CHyEsp2ptcNTJ1lr6tJxkOLdy8Rw285jpKw4MrgghnhqZ9Hh3y5P6KnRrq6zom9zfkCtCXs2h8BK+I0tkMPXO+JZoJKVOWzS+n7FIrC9XC9nAu19G5cnxv+tJdPb3p") - account, err := AccountFromJSONPickled(pickledData, key) + account, err := account.AccountFromJSONPickled(pickledData, key) if err != nil { t.Fatal(err) } @@ -144,7 +146,7 @@ func TestAccountPickleJSON(t *testing.T) { } func TestSessions(t *testing.T) { - aliceAccount, err := NewAccount(nil) + aliceAccount, err := account.NewAccount(nil) if err != nil { t.Fatal(err) } @@ -152,7 +154,7 @@ func TestSessions(t *testing.T) { if err != nil { t.Fatal(err) } - bobAccount, err := NewAccount(nil) + bobAccount, err := account.NewAccount(nil) if err != nil { t.Fatal(err) } @@ -188,7 +190,7 @@ func TestSessions(t *testing.T) { func TestAccountPickle(t *testing.T) { pickleKey := []byte("secret_key") - account, err := AccountFromPickled(pickledDataFromLibOlm, pickleKey) + account, err := account.AccountFromPickled(pickledDataFromLibOlm, pickleKey) if err != nil { t.Fatal(err) } @@ -243,7 +245,7 @@ func TestOldAccountPickle(t *testing.T) { "K/A/8TOu9iK2hDFszy6xETiousHnHgh2ZGbRUh4pQx+YMm8ZdNZeRnwFGLnrWyf9" + "O5TmXua1FcU") pickleKey := []byte("") - account, err := NewAccount(nil) + account, err := account.NewAccount(nil) if err != nil { t.Fatal(err) } @@ -266,12 +268,12 @@ func TestLoopback(t *testing.T) { tag: []byte("B")[0], current: 0x80, } - accountA, err := NewAccount(&mockA) + accountA, err := account.NewAccount(&mockA) if err != nil { t.Fatal(err) } - accountB, err := NewAccount(&mockB) + accountB, err := account.NewAccount(&mockB) if err != nil { t.Fatal(err) } @@ -370,12 +372,12 @@ func TestMoreMessages(t *testing.T) { tag: []byte("B")[0], current: 0x80, } - accountA, err := NewAccount(&mockA) + accountA, err := account.NewAccount(&mockA) if err != nil { t.Fatal(err) } - accountB, err := NewAccount(&mockB) + accountB, err := account.NewAccount(&mockB) if err != nil { t.Fatal(err) } @@ -461,12 +463,12 @@ func TestFallbackKey(t *testing.T) { tag: []byte("B")[0], current: 0x80, } - accountA, err := NewAccount(&mockA) + accountA, err := account.NewAccount(&mockA) if err != nil { t.Fatal(err) } - accountB, err := NewAccount(&mockB) + accountB, err := account.NewAccount(&mockB) if err != nil { t.Fatal(err) } @@ -631,7 +633,7 @@ func TestOldV3AccountPickle(t *testing.T) { expectedFallbackJSON := []byte("{\"curve25519\":{\"AAAAAQ\":\"dr98y6VOWt6lJaQgFVZeWY2ky76mga9MEMbdItJTdng\"}}") expectedUnpublishedFallbackJSON := []byte("{\"curve25519\":{}}") - account, err := AccountFromPickled(pickledData, pickleKey) + account, err := account.AccountFromPickled(pickledData, pickleKey) if err != nil { t.Fatal(err) } @@ -656,7 +658,7 @@ func TestAccountSign(t *testing.T) { tag: []byte("A")[0], current: 0x00, } - accountA, err := NewAccount(&mockA) + accountA, err := account.NewAccount(&mockA) if err != nil { t.Fatal(err) } diff --git a/crypto/goolm/base64.go b/crypto/goolm/base64.go index 2cc0ec2b..229008cf 100644 --- a/crypto/goolm/base64.go +++ b/crypto/goolm/base64.go @@ -2,40 +2,21 @@ package goolm import ( "encoding/base64" - - "github.com/pkg/errors" ) -// Base64Decode decodes the input. Padding characters ('=') will be added if needed. +// Deprecated: base64.RawStdEncoding should be used directly func Base64Decode(input []byte) ([]byte, error) { - //pad the input to multiple of 4 - addedPadding := 0 - for len(input)%4 != 0 { - input = append(input, []byte("=")...) - addedPadding++ - } - if addedPadding >= 3 { - return nil, errors.Wrap(ErrBase64InvalidLength, "") - } - decoded := make([]byte, base64.StdEncoding.DecodedLen(len(input))) - writtenBytes, err := base64.StdEncoding.Decode(decoded, input) + decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input))) + writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input) if err != nil { - return nil, errors.Wrap(ErrBadBase64, err.Error()) + return nil, err } - //DecodedLen returned the maximum size. However this might not be the true length. return decoded[:writtenBytes], nil } -// Base64Encode encodes the input and strips all padding characters ('=') from the end. +// Deprecated: base64.RawStdEncoding should be used directly func Base64Encode(input []byte) []byte { - encoded := make([]byte, base64.StdEncoding.EncodedLen(len(input))) - base64.StdEncoding.Encode(encoded, input) - //Remove padding = from output as libolm does so - for curIndex := len(encoded) - 1; curIndex >= 0; curIndex-- { - if string(encoded[curIndex]) != "=" { - encoded = encoded[:curIndex+1] - break - } - } + encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input))) + base64.RawStdEncoding.Encode(encoded, input) return encoded } diff --git a/crypto/goolm/base64_test.go b/crypto/goolm/base64_test.go deleted file mode 100644 index be22a7aa..00000000 --- a/crypto/goolm/base64_test.go +++ /dev/null @@ -1,47 +0,0 @@ -package goolm - -import ( - "bytes" - "errors" - "testing" -) - -func TestBase64Encode(t *testing.T) { - input := []byte("Hello World") - expected := []byte("SGVsbG8gV29ybGQ") - result := Base64Encode(input) - if !bytes.Equal(result, expected) { - t.Fatalf("expected '%s' but got '%s'", string(expected), string(result)) - } -} - -func TestBase64Decode(t *testing.T) { - input := []byte("SGVsbG8gV29ybGQ") - expected := []byte("Hello World") - result, err := Base64Decode(input) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(result, expected) { - t.Fatalf("expected '%s' but got '%s'", string(expected), string(result)) - } - //This should fai - _, err = Base64Decode([]byte("SGVsbG8gV29ybGQab")) - if err == nil { - t.Fatal("decoded wrong input") - } - if !errors.Is(err, ErrBase64InvalidLength) { - t.Fatalf("got other error as expected: %s", err) - } -} - -func TestBase64DecodeFail(t *testing.T) { - input := []byte("SGVsbG8gV29ybGQab") - _, err := Base64Decode(input) - if err == nil { - t.Fatal("expected error") - } - if !errors.Is(err, ErrBase64InvalidLength) { - t.Fatal(err) - } -} diff --git a/crypto/goolm/cipher/aesSha256.go b/crypto/goolm/cipher/aes_sha256.go similarity index 74% rename from crypto/goolm/cipher/aesSha256.go rename to crypto/goolm/cipher/aes_sha256.go index 1234b14e..1155949b 100644 --- a/crypto/goolm/cipher/aesSha256.go +++ b/crypto/goolm/cipher/aes_sha256.go @@ -4,17 +4,17 @@ import ( "bytes" "io" - "codeberg.org/DerLukas/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/crypto" ) -// derivedAESKeys stores the derived keys for the AESSha256 cipher +// derivedAESKeys stores the derived keys for the AESSHA256 cipher type derivedAESKeys struct { key []byte hmacKey []byte iv []byte } -// deriveAESKeys derives three keys for the AESSha256 cipher +// deriveAESKeys derives three keys for the AESSHA256 cipher func deriveAESKeys(kdfInfo []byte, key []byte) (*derivedAESKeys, error) { hkdf := crypto.HKDFSHA256(key, nil, kdfInfo) keys := &derivedAESKeys{ @@ -34,25 +34,25 @@ func deriveAESKeys(kdfInfo []byte, key []byte) (*derivedAESKeys, error) { return keys, nil } -// AESSha512BlockSize resturns the blocksize of the cipher AESSha256. +// AESSha512BlockSize resturns the blocksize of the cipher AESSHA256. func AESSha512BlockSize() int { return crypto.AESCBCBlocksize() } -// AESSha256 is a valid cipher using AES with CBC and HKDFSha256. -type AESSha256 struct { +// AESSHA256 is a valid cipher using AES with CBC and HKDFSha256. +type AESSHA256 struct { kdfInfo []byte } -// NewAESSha256 returns a new AESSha256 cipher with the key derive function info (kdfInfo). -func NewAESSha256(kdfInfo []byte) *AESSha256 { - return &AESSha256{ +// NewAESSHA256 returns a new AESSHA256 cipher with the key derive function info (kdfInfo). +func NewAESSHA256(kdfInfo []byte) *AESSHA256 { + return &AESSHA256{ kdfInfo: kdfInfo, } } // Encrypt encrypts the plaintext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes). -func (c AESSha256) Encrypt(key, plaintext []byte) (ciphertext []byte, err error) { +func (c AESSHA256) Encrypt(key, plaintext []byte) (ciphertext []byte, err error) { keys, err := deriveAESKeys(c.kdfInfo, key) if err != nil { return nil, err @@ -65,7 +65,7 @@ func (c AESSha256) Encrypt(key, plaintext []byte) (ciphertext []byte, err error) } // Decrypt decrypts the ciphertext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes). -func (c AESSha256) Decrypt(key, ciphertext []byte) (plaintext []byte, err error) { +func (c AESSHA256) Decrypt(key, ciphertext []byte) (plaintext []byte, err error) { keys, err := deriveAESKeys(c.kdfInfo, key) if err != nil { return nil, err @@ -78,7 +78,7 @@ func (c AESSha256) Decrypt(key, ciphertext []byte) (plaintext []byte, err error) } // MAC returns the MAC for the message using the key. The key is used to derive the actual mac key (32 bytes). -func (c AESSha256) MAC(key, message []byte) ([]byte, error) { +func (c AESSHA256) MAC(key, message []byte) ([]byte, error) { keys, err := deriveAESKeys(c.kdfInfo, key) if err != nil { return nil, err @@ -87,7 +87,7 @@ func (c AESSha256) MAC(key, message []byte) ([]byte, error) { } // Verify checks the MAC of the message using the key against the givenMAC. The key is used to derive the actual mac key (32 bytes). -func (c AESSha256) Verify(key, message, givenMAC []byte) (bool, error) { +func (c AESSHA256) Verify(key, message, givenMAC []byte) (bool, error) { mac, err := c.MAC(key, message) if err != nil { return false, err diff --git a/crypto/goolm/cipher/aesSha256_test.go b/crypto/goolm/cipher/aes_sha256_test.go similarity index 97% rename from crypto/goolm/cipher/aesSha256_test.go rename to crypto/goolm/cipher/aes_sha256_test.go index e8068248..d2f49cb1 100644 --- a/crypto/goolm/cipher/aesSha256_test.go +++ b/crypto/goolm/cipher/aes_sha256_test.go @@ -51,7 +51,7 @@ func TestDeriveAESKeys(t *testing.T) { func TestCipherAESSha256(t *testing.T) { key := []byte("test key") - cipher := NewAESSha256([]byte("testKDFinfo")) + cipher := NewAESSHA256([]byte("testKDFinfo")) message := []byte("this is a random message for testing the implementation") //increase to next block size for len(message)%aes.BlockSize != 0 { diff --git a/crypto/goolm/cipher/pickle.go b/crypto/goolm/cipher/pickle.go index 9faa0285..670ff6ff 100644 --- a/crypto/goolm/cipher/pickle.go +++ b/crypto/goolm/cipher/pickle.go @@ -1,8 +1,9 @@ package cipher import ( - "codeberg.org/DerLukas/goolm" - "github.com/pkg/errors" + "fmt" + + "maunium.net/go/mautrix/crypto/goolm" ) const ( @@ -15,9 +16,9 @@ func PickleBlockSize() int { return AESSha512BlockSize() } -// Pickle encrypts the input with the key and the cipher AESSha256. The result is then encoded in base64. +// Pickle encrypts the input with the key and the cipher AESSHA256. The result is then encoded in base64. func Pickle(key, input []byte) ([]byte, error) { - pickleCipher := NewAESSha256([]byte(kdfPickle)) + pickleCipher := NewAESSHA256([]byte(kdfPickle)) ciphertext, err := pickleCipher.Encrypt(key, input) if err != nil { return nil, err @@ -31,9 +32,9 @@ func Pickle(key, input []byte) ([]byte, error) { return encoded, nil } -// Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSha256. +// Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256. func Unpickle(key, input []byte) ([]byte, error) { - pickleCipher := NewAESSha256([]byte(kdfPickle)) + pickleCipher := NewAESSHA256([]byte(kdfPickle)) ciphertext, err := goolm.Base64Decode(input) if err != nil { return nil, err @@ -44,7 +45,7 @@ func Unpickle(key, input []byte) ([]byte, error) { return nil, err } if !verified { - return nil, errors.Wrap(goolm.ErrBadMAC, "decrypt pickle") + return nil, fmt.Errorf("decrypt pickle: %w", goolm.ErrBadMAC) } //Set to next block size targetCipherText := make([]byte, int(len(ciphertext)/PickleBlockSize())*PickleBlockSize()) diff --git a/crypto/goolm/cipher/pickle_test.go b/crypto/goolm/cipher/pickle_test.go index 0e06fece..b47bf3ea 100644 --- a/crypto/goolm/cipher/pickle_test.go +++ b/crypto/goolm/cipher/pickle_test.go @@ -1,9 +1,11 @@ -package cipher +package cipher_test import ( "bytes" "crypto/aes" "testing" + + "maunium.net/go/mautrix/crypto/goolm/cipher" ) func TestEncoding(t *testing.T) { @@ -16,12 +18,12 @@ func TestEncoding(t *testing.T) { toEncrypt = make([]byte, len(input)+padding) copy(toEncrypt, input) } - encoded, err := Pickle(key, toEncrypt) + encoded, err := cipher.Pickle(key, toEncrypt) if err != nil { t.Fatal(err) } - decoded, err := Unpickle(key, encoded) + decoded, err := cipher.Unpickle(key, encoded) if err != nil { t.Fatal(err) } diff --git a/crypto/goolm/crypto/aesCBC.go b/crypto/goolm/crypto/aes_cbc.go similarity index 80% rename from crypto/goolm/crypto/aesCBC.go rename to crypto/goolm/crypto/aes_cbc.go index 87a12a14..10434ab7 100644 --- a/crypto/goolm/crypto/aesCBC.go +++ b/crypto/goolm/crypto/aes_cbc.go @@ -4,9 +4,9 @@ import ( "bytes" "crypto/aes" "crypto/cipher" + "fmt" - "codeberg.org/DerLukas/goolm" - "github.com/pkg/errors" + "maunium.net/go/mautrix/crypto/goolm" ) // AESCBCBlocksize returns the blocksize of the encryption method @@ -17,15 +17,15 @@ func AESCBCBlocksize() int { // AESCBCEncrypt encrypts the plaintext with the key and iv. len(iv) must be equal to the blocksize! func AESCBCEncrypt(key, iv, plaintext []byte) ([]byte, error) { if len(key) == 0 { - return nil, errors.Wrap(goolm.ErrNoKeyProvided, "AESCBCEncrypt") + return nil, fmt.Errorf("AESCBCEncrypt: %w", goolm.ErrNoKeyProvided) } if len(iv) != AESCBCBlocksize() { - return nil, errors.Wrap(goolm.ErrNotBlocksize, "iv") + return nil, fmt.Errorf("iv: %w", goolm.ErrNotBlocksize) } var cipherText []byte plaintext = pkcs5Padding(plaintext, AESCBCBlocksize()) if len(plaintext)%AESCBCBlocksize() != 0 { - return nil, errors.Wrap(goolm.ErrNotMultipleBlocksize, "message") + return nil, fmt.Errorf("message: %w", goolm.ErrNotMultipleBlocksize) } block, err := aes.NewCipher(key) if err != nil { @@ -40,10 +40,10 @@ func AESCBCEncrypt(key, iv, plaintext []byte) ([]byte, error) { // AESCBCDecrypt decrypts the ciphertext with the key and iv. len(iv) must be equal to the blocksize! func AESCBCDecrypt(key, iv, ciphertext []byte) ([]byte, error) { if len(key) == 0 { - return nil, errors.Wrap(goolm.ErrNoKeyProvided, "AESCBCEncrypt") + return nil, fmt.Errorf("AESCBCEncrypt: %w", goolm.ErrNoKeyProvided) } if len(iv) != AESCBCBlocksize() { - return nil, errors.Wrap(goolm.ErrNotBlocksize, "iv") + return nil, fmt.Errorf("iv: %w", goolm.ErrNotBlocksize) } var block cipher.Block var err error @@ -52,7 +52,7 @@ func AESCBCDecrypt(key, iv, ciphertext []byte) ([]byte, error) { return nil, err } if len(ciphertext) < AESCBCBlocksize() { - return nil, errors.Wrap(goolm.ErrNotMultipleBlocksize, "ciphertext") + return nil, fmt.Errorf("ciphertext: %w", goolm.ErrNotMultipleBlocksize) } cbc := cipher.NewCBCDecrypter(block, iv) diff --git a/crypto/goolm/crypto/aesCBC_test.go b/crypto/goolm/crypto/aes_cbc_test.go similarity index 81% rename from crypto/goolm/crypto/aesCBC_test.go rename to crypto/goolm/crypto/aes_cbc_test.go index 87819b40..c64e4a5d 100644 --- a/crypto/goolm/crypto/aesCBC_test.go +++ b/crypto/goolm/crypto/aes_cbc_test.go @@ -1,10 +1,12 @@ -package crypto +package crypto_test import ( "bytes" "crypto/aes" "crypto/rand" "testing" + + "maunium.net/go/mautrix/crypto/goolm/crypto" ) func TestAESCBC(t *testing.T) { @@ -28,11 +30,11 @@ func TestAESCBC(t *testing.T) { plaintext = append(plaintext, []byte("-")...) } - if ciphertext, err = AESCBCEncrypt(key, iv, plaintext); err != nil { + if ciphertext, err = crypto.AESCBCEncrypt(key, iv, plaintext); err != nil { t.Fatal(err) } - resultPlainText, err := AESCBCDecrypt(key, iv, ciphertext) + resultPlainText, err := crypto.AESCBCDecrypt(key, iv, ciphertext) if err != nil { t.Fatal(err) } @@ -52,7 +54,7 @@ func TestAESCBCCase1(t *testing.T) { input := make([]byte, 16) key := make([]byte, 32) iv := make([]byte, aes.BlockSize) - encrypted, err := AESCBCEncrypt(key, iv, input) + encrypted, err := crypto.AESCBCEncrypt(key, iv, input) if err != nil { t.Fatal(err) } @@ -60,7 +62,7 @@ func TestAESCBCCase1(t *testing.T) { t.Fatalf("encrypted did not match expected:\n%v\n%v\n", encrypted, expected) } - decrypted, err := AESCBCDecrypt(key, iv, encrypted) + decrypted, err := crypto.AESCBCDecrypt(key, iv, encrypted) if err != nil { t.Fatal(err) } diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go index 5b39ceca..125e1bfd 100644 --- a/crypto/goolm/crypto/curve25519.go +++ b/crypto/goolm/crypto/curve25519.go @@ -3,12 +3,14 @@ package crypto import ( "bytes" "crypto/rand" + "encoding/base64" + "fmt" "io" - "codeberg.org/DerLukas/goolm" - libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" - "github.com/pkg/errors" "golang.org/x/crypto/curve25519" + + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/id" ) @@ -76,11 +78,11 @@ func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, err // It returns the number of bytes written. func (c Curve25519KeyPair) PickleLibOlm(target []byte) (int, error) { if len(target) < c.PickleLen() { - return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle curve25519 key pair") + return 0, fmt.Errorf("pickle curve25519 key pair: %w", goolm.ErrValueTooShort) } written, err := c.PublicKey.PickleLibOlm(target) if err != nil { - return 0, errors.Wrap(err, "pickle curve25519 key pair") + return 0, fmt.Errorf("pickle curve25519 key pair: %w", err) } if len(c.PrivateKey) != Curve25519KeyLength { written += libolmpickle.PickleBytes(make([]byte, Curve25519KeyLength), target[written:]) @@ -150,14 +152,14 @@ func (c Curve25519PublicKey) Equal(x Curve25519PublicKey) bool { // B64Encoded returns a base64 encoded string of the public key. func (c Curve25519PublicKey) B64Encoded() id.Curve25519 { - return id.Curve25519(goolm.Base64Encode(c)) + return id.Curve25519(base64.RawStdEncoding.EncodeToString(c)) } // PickleLibOlm encodes the public key into target. target has to have a size of at least PickleLen() and is written to from index 0. // It returns the number of bytes written. func (c Curve25519PublicKey) PickleLibOlm(target []byte) (int, error) { if len(target) < c.PickleLen() { - return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle curve25519 public key") + return 0, fmt.Errorf("pickle curve25519 public key: %w", goolm.ErrValueTooShort) } if len(c) != curve25519PubKeyLength { return libolmpickle.PickleBytes(make([]byte, curve25519PubKeyLength), target), nil diff --git a/crypto/goolm/crypto/curve25519_test.go b/crypto/goolm/crypto/curve25519_test.go index e8d4090b..f7df5edc 100644 --- a/crypto/goolm/crypto/curve25519_test.go +++ b/crypto/goolm/crypto/curve25519_test.go @@ -1,16 +1,18 @@ -package crypto +package crypto_test import ( "bytes" "testing" + + "maunium.net/go/mautrix/crypto/goolm/crypto" ) func TestCurve25519(t *testing.T) { - firstKeypair, err := Curve25519GenerateKey(nil) + firstKeypair, err := crypto.Curve25519GenerateKey(nil) if err != nil { t.Fatal(err) } - secondKeypair, err := Curve25519GenerateKey(nil) + secondKeypair, err := crypto.Curve25519GenerateKey(nil) if err != nil { t.Fatal(err) } @@ -25,7 +27,7 @@ func TestCurve25519(t *testing.T) { if !bytes.Equal(sharedSecretFromFirst, sharedSecretFromSecond) { t.Fatal("shared secret not equal") } - fromPrivate, err := Curve25519GenerateFromPrivate(firstKeypair.PrivateKey) + fromPrivate, err := crypto.Curve25519GenerateFromPrivate(firstKeypair.PrivateKey) if err != nil { t.Fatal(err) } @@ -65,11 +67,11 @@ func TestCurve25519Case1(t *testing.T) { 0xE0, 0x7E, 0x21, 0xC9, 0x47, 0xD1, 0x9E, 0x33, 0x76, 0xF0, 0x9B, 0x3C, 0x1E, 0x16, 0x17, 0x42, } - aliceKeyPair := Curve25519KeyPair{ + aliceKeyPair := crypto.Curve25519KeyPair{ PrivateKey: alicePrivate, PublicKey: alicePublic, } - bobKeyPair := Curve25519KeyPair{ + bobKeyPair := crypto.Curve25519KeyPair{ PrivateKey: bobPrivate, PublicKey: bobPublic, } @@ -91,7 +93,7 @@ func TestCurve25519Case1(t *testing.T) { func TestCurve25519Pickle(t *testing.T) { //create keypair - keyPair, err := Curve25519GenerateKey(nil) + keyPair, err := crypto.Curve25519GenerateKey(nil) if err != nil { t.Fatal(err) } @@ -104,7 +106,7 @@ func TestCurve25519Pickle(t *testing.T) { t.Fatal("written bytes not correct") } - unpickledKeyPair := Curve25519KeyPair{} + unpickledKeyPair := crypto.Curve25519KeyPair{} readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) if err != nil { t.Fatal(err) @@ -122,7 +124,7 @@ func TestCurve25519Pickle(t *testing.T) { func TestCurve25519PicklePubKeyOnly(t *testing.T) { //create keypair - keyPair, err := Curve25519GenerateKey(nil) + keyPair, err := crypto.Curve25519GenerateKey(nil) if err != nil { t.Fatal(err) } @@ -136,7 +138,7 @@ func TestCurve25519PicklePubKeyOnly(t *testing.T) { if writtenBytes != len(target) { t.Fatal("written bytes not correct") } - unpickledKeyPair := Curve25519KeyPair{} + unpickledKeyPair := crypto.Curve25519KeyPair{} readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) if err != nil { t.Fatal(err) @@ -154,7 +156,7 @@ func TestCurve25519PicklePubKeyOnly(t *testing.T) { func TestCurve25519PicklePrivKeyOnly(t *testing.T) { //create keypair - keyPair, err := Curve25519GenerateKey(nil) + keyPair, err := crypto.Curve25519GenerateKey(nil) if err != nil { t.Fatal(err) } @@ -168,7 +170,7 @@ func TestCurve25519PicklePrivKeyOnly(t *testing.T) { if writtenBytes != len(target) { t.Fatal("written bytes not correct") } - unpickledKeyPair := Curve25519KeyPair{} + unpickledKeyPair := crypto.Curve25519KeyPair{} readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) if err != nil { t.Fatal(err) diff --git a/crypto/goolm/crypto/ed25519.go b/crypto/goolm/crypto/ed25519.go index 483e0041..bc21300c 100644 --- a/crypto/goolm/crypto/ed25519.go +++ b/crypto/goolm/crypto/ed25519.go @@ -3,11 +3,12 @@ package crypto import ( "bytes" "crypto/ed25519" + "encoding/base64" + "fmt" "io" - "codeberg.org/DerLukas/goolm" - libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" - "github.com/pkg/errors" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/id" ) @@ -52,7 +53,7 @@ type Ed25519KeyPair struct { // B64Encoded returns a base64 encoded string of the public key. func (c Ed25519KeyPair) B64Encoded() id.Ed25519 { - return id.Ed25519(goolm.Base64Encode(c.PublicKey)) + return id.Ed25519(base64.RawStdEncoding.EncodeToString(c.PublicKey)) } // Sign returns the signature for the message. @@ -69,11 +70,11 @@ func (c Ed25519KeyPair) Verify(message, givenSignature []byte) bool { // It returns the number of bytes written. func (c Ed25519KeyPair) PickleLibOlm(target []byte) (int, error) { if len(target) < c.PickleLen() { - return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle ed25519 key pair") + return 0, fmt.Errorf("pickle ed25519 key pair: %w", goolm.ErrValueTooShort) } written, err := c.PublicKey.PickleLibOlm(target) if err != nil { - return 0, errors.Wrap(err, "pickle ed25519 key pair") + return 0, fmt.Errorf("pickle ed25519 key pair: %w", err) } if len(c.PrivateKey) != ed25519.PrivateKeySize { @@ -141,7 +142,7 @@ func (c Ed25519PublicKey) Equal(x Ed25519PublicKey) bool { // B64Encoded returns a base64 encoded string of the public key. func (c Ed25519PublicKey) B64Encoded() id.Curve25519 { - return id.Curve25519(goolm.Base64Encode(c)) + return id.Curve25519(base64.RawStdEncoding.EncodeToString(c)) } // Verify checks the signature of the message against the givenSignature @@ -153,7 +154,7 @@ func (c Ed25519PublicKey) Verify(message, givenSignature []byte) bool { // It returns the number of bytes written. func (c Ed25519PublicKey) PickleLibOlm(target []byte) (int, error) { if len(target) < c.PickleLen() { - return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle ed25519 public key") + return 0, fmt.Errorf("pickle ed25519 public key: %w", goolm.ErrValueTooShort) } if len(c) != ed25519.PublicKeySize { return libolmpickle.PickleBytes(make([]byte, ed25519.PublicKeySize), target), nil diff --git a/crypto/goolm/crypto/ed25519_test.go b/crypto/goolm/crypto/ed25519_test.go index dbff8454..391de912 100644 --- a/crypto/goolm/crypto/ed25519_test.go +++ b/crypto/goolm/crypto/ed25519_test.go @@ -1,12 +1,14 @@ -package crypto +package crypto_test import ( "bytes" "testing" + + "maunium.net/go/mautrix/crypto/goolm/crypto" ) func TestEd25519(t *testing.T) { - keypair, err := Ed25519GenerateKey(nil) + keypair, err := crypto.Ed25519GenerateKey(nil) if err != nil { t.Fatal(err) } @@ -19,13 +21,13 @@ func TestEd25519(t *testing.T) { func TestEd25519Case1(t *testing.T) { //64 bytes for ed25519 package - keyPair, err := Ed25519GenerateKey(nil) + keyPair, err := crypto.Ed25519GenerateKey(nil) if err != nil { t.Fatal(err) } message := []byte("Hello, World") - keyPair2 := Ed25519GenerateFromPrivate(keyPair.PrivateKey) + keyPair2 := crypto.Ed25519GenerateFromPrivate(keyPair.PrivateKey) if !bytes.Equal(keyPair.PublicKey, keyPair2.PublicKey) { t.Fatal("not equal key pairs") } @@ -44,7 +46,7 @@ func TestEd25519Case1(t *testing.T) { func TestEd25519Pickle(t *testing.T) { //create keypair - keyPair, err := Ed25519GenerateKey(nil) + keyPair, err := crypto.Ed25519GenerateKey(nil) if err != nil { t.Fatal(err) } @@ -57,7 +59,7 @@ func TestEd25519Pickle(t *testing.T) { t.Fatal("written bytes not correct") } - unpickledKeyPair := Ed25519KeyPair{} + unpickledKeyPair := crypto.Ed25519KeyPair{} readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) if err != nil { t.Fatal(err) @@ -75,7 +77,7 @@ func TestEd25519Pickle(t *testing.T) { func TestEd25519PicklePubKeyOnly(t *testing.T) { //create keypair - keyPair, err := Ed25519GenerateKey(nil) + keyPair, err := crypto.Ed25519GenerateKey(nil) if err != nil { t.Fatal(err) } @@ -89,7 +91,7 @@ func TestEd25519PicklePubKeyOnly(t *testing.T) { if writtenBytes != len(target) { t.Fatal("written bytes not correct") } - unpickledKeyPair := Ed25519KeyPair{} + unpickledKeyPair := crypto.Ed25519KeyPair{} readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) if err != nil { t.Fatal(err) @@ -107,7 +109,7 @@ func TestEd25519PicklePubKeyOnly(t *testing.T) { func TestEd25519PicklePrivKeyOnly(t *testing.T) { //create keypair - keyPair, err := Ed25519GenerateKey(nil) + keyPair, err := crypto.Ed25519GenerateKey(nil) if err != nil { t.Fatal(err) } @@ -121,7 +123,7 @@ func TestEd25519PicklePrivKeyOnly(t *testing.T) { if writtenBytes != len(target) { t.Fatal("written bytes not correct") } - unpickledKeyPair := Ed25519KeyPair{} + unpickledKeyPair := crypto.Ed25519KeyPair{} readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) if err != nil { t.Fatal(err) diff --git a/crypto/goolm/crypto/hmac_test.go b/crypto/goolm/crypto/hmac_test.go index 2c7f1c71..95c0bfd5 100644 --- a/crypto/goolm/crypto/hmac_test.go +++ b/crypto/goolm/crypto/hmac_test.go @@ -1,22 +1,23 @@ -package crypto +package crypto_test import ( "bytes" + "encoding/base64" "io" "testing" - "codeberg.org/DerLukas/goolm" + "maunium.net/go/mautrix/crypto/goolm/crypto" ) func TestHMACSha256(t *testing.T) { key := []byte("test key") message := []byte("test message") - hash := HMACSHA256(key, message) - if !bytes.Equal(hash, HMACSHA256(key, message)) { + hash := crypto.HMACSHA256(key, message) + if !bytes.Equal(hash, crypto.HMACSHA256(key, message)) { t.Fail() } - str := "A4M0ovdiWHaZ5msdDFbrvtChFwZIoIaRSVGmv8bmPtc=" - result, err := goolm.Base64Decode([]byte(str)) + str := "A4M0ovdiWHaZ5msdDFbrvtChFwZIoIaRSVGmv8bmPtc" + result, err := base64.RawStdEncoding.DecodeString(str) if err != nil { t.Fatal(err) } @@ -27,8 +28,8 @@ func TestHMACSha256(t *testing.T) { func TestHKDFSha256(t *testing.T) { message := []byte("test content") - hkdf := HKDFSHA256(message, nil, nil) - hkdf2 := HKDFSHA256(message, nil, nil) + hkdf := crypto.HKDFSHA256(message, nil, nil) + hkdf2 := crypto.HKDFSHA256(message, nil, nil) result := make([]byte, 32) if _, err := io.ReadFull(hkdf, result); err != nil { t.Fatal(err) @@ -50,7 +51,7 @@ func TestSha256Case1(t *testing.T) { 0x27, 0xAE, 0x41, 0xE4, 0x64, 0x9B, 0x93, 0x4C, 0xA4, 0x95, 0x99, 0x1B, 0x78, 0x52, 0xB8, 0x55, } - result := SHA256(input) + result := crypto.SHA256(input) if !bytes.Equal(expected, result) { t.Fatalf("result not as expected:\n%v\n%v\n", result, expected) } @@ -64,7 +65,7 @@ func TestHMACCase1(t *testing.T) { 0xff, 0x16, 0x97, 0xc4, 0x93, 0x71, 0x56, 0x53, 0xc6, 0xc7, 0x12, 0x14, 0x42, 0x92, 0xc5, 0xad, } - result := HMACSHA256(input, input) + result := crypto.HMACSHA256(input, input) if !bytes.Equal(expected, result) { t.Fatalf("result not as expected:\n%v\n%v\n", result, expected) } @@ -90,7 +91,7 @@ func TestHDKFCase1(t *testing.T) { 0x90, 0xb6, 0xc7, 0x3b, 0xb5, 0x0f, 0x9c, 0x31, 0x22, 0xec, 0x84, 0x4a, 0xd7, 0xc2, 0xb3, 0xe5, } - result := HMACSHA256(salt, input) + result := crypto.HMACSHA256(salt, input) if !bytes.Equal(expectedHMAC, result) { t.Fatalf("result not as expected:\n%v\n%v\n", result, expectedHMAC) } @@ -102,7 +103,7 @@ func TestHDKFCase1(t *testing.T) { 0x34, 0x00, 0x72, 0x08, 0xd5, 0xb8, 0x87, 0x18, 0x58, 0x65, } - resultReader := HKDFSHA256(input, salt, info) + resultReader := crypto.HKDFSHA256(input, salt, info) result = make([]byte, len(expectedHDKF)) if _, err := io.ReadFull(resultReader, result); err != nil { t.Fatal(err) diff --git a/crypto/goolm/crypto/oneTimeKey.go b/crypto/goolm/crypto/one_time_key.go similarity index 89% rename from crypto/goolm/crypto/oneTimeKey.go rename to crypto/goolm/crypto/one_time_key.go index 53bc1a3b..67465563 100644 --- a/crypto/goolm/crypto/oneTimeKey.go +++ b/crypto/goolm/crypto/one_time_key.go @@ -1,11 +1,12 @@ package crypto import ( + "encoding/base64" "encoding/binary" + "fmt" - "codeberg.org/DerLukas/goolm" - libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" - "github.com/pkg/errors" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/id" ) @@ -37,13 +38,13 @@ func (otk OneTimeKey) Equal(s OneTimeKey) bool { // It returns the number of bytes written. func (c OneTimeKey) PickleLibOlm(target []byte) (int, error) { if len(target) < c.PickleLen() { - return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle one time key") + return 0, fmt.Errorf("pickle one time key: %w", goolm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(uint32(c.ID), target) written += libolmpickle.PickleBool(c.Published, target[written:]) writtenKey, err := c.Key.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle one time key") + return 0, fmt.Errorf("pickle one time key: %w", err) } written += writtenKey return written, nil @@ -85,8 +86,7 @@ func (c OneTimeKey) PickleLen() int { func (c OneTimeKey) KeyIDEncoded() string { resSlice := make([]byte, 4) binary.BigEndian.PutUint32(resSlice, c.ID) - encoded := goolm.Base64Encode(resSlice) - return string(encoded) + return base64.RawStdEncoding.EncodeToString(resSlice) } // PublicKeyEncoded returns the base64 encoded public key diff --git a/crypto/goolm/errors.go b/crypto/goolm/errors.go index 46db78f9..7df906d9 100644 --- a/crypto/goolm/errors.go +++ b/crypto/goolm/errors.go @@ -1,7 +1,7 @@ package goolm import ( - "github.com/pkg/errors" + "errors" ) // Those are the most common used errors diff --git a/crypto/goolm/go.mod b/crypto/goolm/go.mod deleted file mode 100644 index c4b6a461..00000000 --- a/crypto/goolm/go.mod +++ /dev/null @@ -1,9 +0,0 @@ -module codeberg.org/DerLukas/goolm - -go 1.19 - -require ( - github.com/pkg/errors v0.9.1 - golang.org/x/crypto v0.3.0 - maunium.net/go/mautrix v0.12.3 -) diff --git a/crypto/goolm/go.sum b/crypto/goolm/go.sum deleted file mode 100644 index 0b0d6700..00000000 --- a/crypto/goolm/go.sum +++ /dev/null @@ -1,20 +0,0 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= -github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.3.0 h1:a06MkbcxBrEFc0w0QIZWXrH/9cCX6KJyWbBOIwAn+7A= -golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= -golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -maunium.net/go/mautrix v0.2.0-beta.4 h1:L7Jpc+8GGc+Qo0DdamACEeU1Ci9G1mergJpsTTgDOUA= -maunium.net/go/mautrix v0.2.0-beta.4/go.mod h1:WeTUYKrM3/4LZK2bXQ9NRIXnRWKsa+6+OA1gw0nf5G8= -maunium.net/go/mautrix v0.12.3 h1:pUeO1ThhtZxE6XibGCzDhRuxwDIFNugsreVr1yYq96k= -maunium.net/go/mautrix v0.12.3/go.mod h1:uOUjkOjm2C+nQS3mr9B5ATjqemZfnPHvjdd1kZezAwg= diff --git a/crypto/goolm/libolmVersion.md b/crypto/goolm/libolmVersion.md deleted file mode 100644 index d91f241e..00000000 --- a/crypto/goolm/libolmVersion.md +++ /dev/null @@ -1,3 +0,0 @@ -### This package is based on libolm version 3.2.14 - -Changes to the libolm implementation should be reflected in this package and this file should be updated. diff --git a/crypto/goolm/libolmPickle/pickle.go b/crypto/goolm/libolmpickle/pickle.go similarity index 100% rename from crypto/goolm/libolmPickle/pickle.go rename to crypto/goolm/libolmpickle/pickle.go diff --git a/crypto/goolm/libolmPickle/pickle_test.go b/crypto/goolm/libolmpickle/pickle_test.go similarity index 74% rename from crypto/goolm/libolmPickle/pickle_test.go rename to crypto/goolm/libolmpickle/pickle_test.go index ff6062c2..ce118428 100644 --- a/crypto/goolm/libolmPickle/pickle_test.go +++ b/crypto/goolm/libolmpickle/pickle_test.go @@ -1,8 +1,10 @@ -package libolmpickle +package libolmpickle_test import ( "bytes" "testing" + + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) func TestPickleUInt32(t *testing.T) { @@ -20,8 +22,8 @@ func TestPickleUInt32(t *testing.T) { } for curIndex := range values { response := make([]byte, 4) - resPLen := PickleUInt32(values[curIndex], response) - if resPLen != PickleUInt32Len(values[curIndex]) { + resPLen := libolmpickle.PickleUInt32(values[curIndex], response) + if resPLen != libolmpickle.PickleUInt32Len(values[curIndex]) { t.Fatal("written bytes not correct") } if !bytes.Equal(response, expected[curIndex]) { @@ -41,8 +43,8 @@ func TestPickleBool(t *testing.T) { } for curIndex := range values { response := make([]byte, 1) - resPLen := PickleBool(values[curIndex], response) - if resPLen != PickleBoolLen(values[curIndex]) { + resPLen := libolmpickle.PickleBool(values[curIndex], response) + if resPLen != libolmpickle.PickleBoolLen(values[curIndex]) { t.Fatal("written bytes not correct") } if !bytes.Equal(response, expected[curIndex]) { @@ -62,8 +64,8 @@ func TestPickleUInt8(t *testing.T) { } for curIndex := range values { response := make([]byte, 1) - resPLen := PickleUInt8(values[curIndex], response) - if resPLen != PickleUInt8Len(values[curIndex]) { + resPLen := libolmpickle.PickleUInt8(values[curIndex], response) + if resPLen != libolmpickle.PickleUInt8Len(values[curIndex]) { t.Fatal("written bytes not correct") } if !bytes.Equal(response, expected[curIndex]) { @@ -85,8 +87,8 @@ func TestPickleBytes(t *testing.T) { } for curIndex := range values { response := make([]byte, len(values[curIndex])) - resPLen := PickleBytes(values[curIndex], response) - if resPLen != PickleBytesLen(values[curIndex]) { + resPLen := libolmpickle.PickleBytes(values[curIndex], response) + if resPLen != libolmpickle.PickleBytesLen(values[curIndex]) { t.Fatal("written bytes not correct") } if !bytes.Equal(response, expected[curIndex]) { diff --git a/crypto/goolm/libolmPickle/unpickle.go b/crypto/goolm/libolmpickle/unpickle.go similarity index 69% rename from crypto/goolm/libolmPickle/unpickle.go rename to crypto/goolm/libolmpickle/unpickle.go index 7140094f..9a6a4b62 100644 --- a/crypto/goolm/libolmPickle/unpickle.go +++ b/crypto/goolm/libolmpickle/unpickle.go @@ -1,8 +1,9 @@ package libolmpickle import ( - "codeberg.org/DerLukas/goolm" - "github.com/pkg/errors" + "fmt" + + "maunium.net/go/mautrix/crypto/goolm" ) func isZeroByteSlice(bytes []byte) bool { @@ -15,21 +16,21 @@ func isZeroByteSlice(bytes []byte) bool { func UnpickleUInt8(value []byte) (uint8, int, error) { if len(value) < 1 { - return 0, 0, errors.Wrap(goolm.ErrValueTooShort, "unpickle uint8") + return 0, 0, fmt.Errorf("unpickle uint8: %w", goolm.ErrValueTooShort) } return value[0], 1, nil } func UnpickleBool(value []byte) (bool, int, error) { if len(value) < 1 { - return false, 0, errors.Wrap(goolm.ErrValueTooShort, "unpickle bool") + return false, 0, fmt.Errorf("unpickle bool: %w", goolm.ErrValueTooShort) } return value[0] != uint8(0x00), 1, nil } func UnpickleBytes(value []byte, length int) ([]byte, int, error) { if len(value) < length { - return nil, 0, errors.Wrap(goolm.ErrValueTooShort, "unpickle bytes") + return nil, 0, fmt.Errorf("unpickle bytes: %w", goolm.ErrValueTooShort) } resp := value[:length] if isZeroByteSlice(resp) { @@ -40,7 +41,7 @@ func UnpickleBytes(value []byte, length int) ([]byte, int, error) { func UnpickleUInt32(value []byte) (uint32, int, error) { if len(value) < 4 { - return 0, 0, errors.Wrap(goolm.ErrValueTooShort, "unpickle uint32") + return 0, 0, fmt.Errorf("unpickle uint32: %w", goolm.ErrValueTooShort) } var res uint32 count := 0 diff --git a/crypto/goolm/libolmPickle/unpickle_test.go b/crypto/goolm/libolmpickle/unpickle_test.go similarity index 82% rename from crypto/goolm/libolmPickle/unpickle_test.go rename to crypto/goolm/libolmpickle/unpickle_test.go index 505f3f64..937630e5 100644 --- a/crypto/goolm/libolmPickle/unpickle_test.go +++ b/crypto/goolm/libolmpickle/unpickle_test.go @@ -1,8 +1,10 @@ -package libolmpickle +package libolmpickle_test import ( "bytes" "testing" + + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) func TestUnpickleUInt32(t *testing.T) { @@ -17,7 +19,7 @@ func TestUnpickleUInt32(t *testing.T) { {0xf0, 0x00, 0x00, 0x00}, } for curIndex := range values { - response, readLength, err := UnpickleUInt32(values[curIndex]) + response, readLength, err := libolmpickle.UnpickleUInt32(values[curIndex]) if err != nil { t.Fatal(err) } @@ -42,7 +44,7 @@ func TestUnpickleBool(t *testing.T) { {0x02}, } for curIndex := range values { - response, readLength, err := UnpickleBool(values[curIndex]) + response, readLength, err := libolmpickle.UnpickleBool(values[curIndex]) if err != nil { t.Fatal(err) } @@ -65,7 +67,7 @@ func TestUnpickleUInt8(t *testing.T) { {0x1a}, } for curIndex := range values { - response, readLength, err := UnpickleUInt8(values[curIndex]) + response, readLength, err := libolmpickle.UnpickleUInt8(values[curIndex]) if err != nil { t.Fatal(err) } @@ -90,7 +92,7 @@ func TestUnpickleBytes(t *testing.T) { {0xf0, 0x00, 0x00, 0x00}, } for curIndex := range values { - response, readLength, err := UnpickleBytes(values[curIndex], 4) + response, readLength, err := libolmpickle.UnpickleBytes(values[curIndex], 4) if err != nil { t.Fatal(err) } diff --git a/crypto/goolm/megolm/megolm.go b/crypto/goolm/megolm/megolm.go index 652d8b17..c3493f7b 100644 --- a/crypto/goolm/megolm/megolm.go +++ b/crypto/goolm/megolm/megolm.go @@ -3,14 +3,14 @@ package megolm import ( "crypto/rand" + "fmt" - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/cipher" - "codeberg.org/DerLukas/goolm/crypto" - libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" - "codeberg.org/DerLukas/goolm/message" - "codeberg.org/DerLukas/goolm/utilities" - "github.com/pkg/errors" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" + "maunium.net/go/mautrix/crypto/goolm/message" + "maunium.net/go/mautrix/crypto/goolm/utilities" ) const ( @@ -23,7 +23,7 @@ const ( RatchetPartLength = 256 / 8 // length of each ratchet part in bytes ) -var RatchetCipher = cipher.NewAESSha256([]byte("MEGOLM_KEYS")) +var RatchetCipher = cipher.NewAESSHA256([]byte("MEGOLM_KEYS")) // hasKeySeed are the seed for the different ratchet parts var hashKeySeeds [RatchetParts][]byte = [RatchetParts][]byte{ @@ -136,7 +136,7 @@ func (r *Ratchet) Encrypt(plaintext []byte, key *crypto.Ed25519KeyPair) ([]byte, var err error encryptedText, err := RatchetCipher.Encrypt(r.Data[:], plaintext) if err != nil { - return nil, errors.Wrap(err, "cipher encrypt") + return nil, fmt.Errorf("cipher encrypt: %w", err) } message := &message.GroupMessage{} @@ -179,7 +179,7 @@ func (r Ratchet) Decrypt(ciphertext []byte, signingkey *crypto.Ed25519PublicKey, return nil, err } if !verifiedMAC { - return nil, errors.Wrap(goolm.ErrBadMAC, "decrypt") + return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMAC) } return RatchetCipher.Decrypt(r.Data[:], msg.Ciphertext) @@ -219,7 +219,7 @@ func (r *Ratchet) UnpickleLibOlm(unpickled []byte) (int, error) { // It returns the number of bytes written. func (r Ratchet) PickleLibOlm(target []byte) (int, error) { if len(target) < r.PickleLen() { - return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle account") + return 0, fmt.Errorf("pickle account: %w", goolm.ErrValueTooShort) } written := libolmpickle.PickleBytes(r.Data[:], target) written += libolmpickle.PickleUInt32(r.Counter, target[written:]) diff --git a/crypto/goolm/megolm/megolm_test.go b/crypto/goolm/megolm/megolm_test.go index 414deb4e..40289eaf 100644 --- a/crypto/goolm/megolm/megolm_test.go +++ b/crypto/goolm/megolm/megolm_test.go @@ -1,11 +1,13 @@ -package megolm +package megolm_test import ( "bytes" "testing" + + "maunium.net/go/mautrix/crypto/goolm/megolm" ) -var startData [RatchetParts * RatchetPartLength]byte +var startData [megolm.RatchetParts * megolm.RatchetPartLength]byte func init() { startValue := []byte("0123456789ABCDEF0123456789ABCDEF") @@ -16,12 +18,12 @@ func init() { } func TestAdvance(t *testing.T) { - m, err := New(0, startData) + m, err := megolm.New(0, startData) if err != nil { t.Fatal(err) } - expectedData := [RatchetParts * RatchetPartLength]byte{ + expectedData := [megolm.RatchetParts * megolm.RatchetPartLength]byte{ 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, @@ -38,7 +40,7 @@ func TestAdvance(t *testing.T) { //repeat with complex advance m.Data = startData - expectedData = [RatchetParts * RatchetPartLength]byte{ + expectedData = [megolm.RatchetParts * megolm.RatchetPartLength]byte{ 0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9, 0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c, 0x70, 0x04, 0xc0, 0x1e, 0xe4, 0x9b, 0xd6, 0xef, 0xe0, 0x07, 0x35, 0x25, 0xaf, 0x9b, 0x16, 0x32, @@ -52,7 +54,7 @@ func TestAdvance(t *testing.T) { if !bytes.Equal(m.Data[:], expectedData[:]) { t.Fatal("result after advancing the ratchet is not as expected") } - expectedData = [RatchetParts * RatchetPartLength]byte{ + expectedData = [megolm.RatchetParts * megolm.RatchetPartLength]byte{ 0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9, 0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c, 0x55, 0x58, 0x8d, 0xf5, 0xb7, 0xa4, 0x88, 0x78, 0x42, 0x89, 0x27, 0x86, 0x81, 0x64, 0x58, 0x9f, @@ -69,7 +71,7 @@ func TestAdvance(t *testing.T) { } func TestAdvanceWraparound(t *testing.T) { - m, err := New(0xffffffff, startData) + m, err := megolm.New(0xffffffff, startData) if err != nil { t.Fatal(err) } @@ -78,7 +80,7 @@ func TestAdvanceWraparound(t *testing.T) { t.Fatal("counter not correct") } - m2, err := New(0, startData) + m2, err := megolm.New(0, startData) if err != nil { t.Fatal(err) } @@ -92,7 +94,7 @@ func TestAdvanceWraparound(t *testing.T) { } func TestAdvanceOverflowByOne(t *testing.T) { - m, err := New(0xffffffff, startData) + m, err := megolm.New(0xffffffff, startData) if err != nil { t.Fatal(err) } @@ -101,7 +103,7 @@ func TestAdvanceOverflowByOne(t *testing.T) { t.Fatal("counter not correct") } - m2, err := New(0xffffffff, startData) + m2, err := megolm.New(0xffffffff, startData) if err != nil { t.Fatal(err) } @@ -115,7 +117,7 @@ func TestAdvanceOverflowByOne(t *testing.T) { } func TestAdvanceOverflow(t *testing.T) { - m, err := New(0x1, startData) + m, err := megolm.New(0x1, startData) if err != nil { t.Fatal(err) } @@ -125,7 +127,7 @@ func TestAdvanceOverflow(t *testing.T) { t.Fatal("counter not correct") } - m2, err := New(0x1, startData) + m2, err := megolm.New(0x1, startData) if err != nil { t.Fatal(err) } diff --git a/crypto/goolm/message/decoder.go b/crypto/goolm/message/decoder.go index 37104ab3..ba49f011 100644 --- a/crypto/goolm/message/decoder.go +++ b/crypto/goolm/message/decoder.go @@ -3,18 +3,17 @@ package message import ( "encoding/binary" - "codeberg.org/DerLukas/goolm" - "github.com/pkg/errors" + "maunium.net/go/mautrix/crypto/goolm" ) // checkDecodeErr checks if there was an error during decode. func checkDecodeErr(readBytes int) error { if readBytes == 0 { //end reached - return errors.Wrap(goolm.ErrInputToSmall, "") + return goolm.ErrInputToSmall } if readBytes < 0 { - return errors.Wrap(goolm.ErrOverflow, "") + return goolm.ErrOverflow } return nil } diff --git a/crypto/goolm/message/groupMessage.go b/crypto/goolm/message/group_message.go similarity index 98% rename from crypto/goolm/message/groupMessage.go rename to crypto/goolm/message/group_message.go index 2ef39e48..176214f6 100644 --- a/crypto/goolm/message/groupMessage.go +++ b/crypto/goolm/message/group_message.go @@ -3,8 +3,8 @@ package message import ( "bytes" - "codeberg.org/DerLukas/goolm/cipher" - "codeberg.org/DerLukas/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/crypto" ) const ( diff --git a/crypto/goolm/message/groupMessage_test.go b/crypto/goolm/message/group_message_test.go similarity index 91% rename from crypto/goolm/message/groupMessage_test.go rename to crypto/goolm/message/group_message_test.go index fab2f9ea..4ae1f830 100644 --- a/crypto/goolm/message/groupMessage_test.go +++ b/crypto/goolm/message/group_message_test.go @@ -1,8 +1,10 @@ -package message +package message_test import ( "bytes" "testing" + + "maunium.net/go/mautrix/crypto/goolm/message" ) func TestGroupMessageDecode(t *testing.T) { @@ -12,7 +14,7 @@ func TestGroupMessageDecode(t *testing.T) { expectedMessageIndex := uint32(200) expectedCipherText := []byte("ciphertext") - msg := GroupMessage{} + msg := message.GroupMessage{} err := msg.Decode(messageRaw) if err != nil { t.Fatal(err) @@ -32,7 +34,7 @@ func TestGroupMessageEncode(t *testing.T) { expectedRaw := []byte("\x03\x08\xC8\x01\x12\x0aciphertexthmacsha2signature") hmacsha256 := []byte("hmacsha2") sign := []byte("signature") - msg := GroupMessage{ + msg := message.GroupMessage{ Version: 3, MessageIndex: 200, Ciphertext: []byte("ciphertext"), diff --git a/crypto/goolm/message/message.go b/crypto/goolm/message/message.go index d76c458c..d5c15b1a 100644 --- a/crypto/goolm/message/message.go +++ b/crypto/goolm/message/message.go @@ -3,8 +3,8 @@ package message import ( "bytes" - "codeberg.org/DerLukas/goolm/cipher" - "codeberg.org/DerLukas/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/crypto" ) const ( diff --git a/crypto/goolm/message/message_test.go b/crypto/goolm/message/message_test.go index 1e2921c7..4a9f29fb 100644 --- a/crypto/goolm/message/message_test.go +++ b/crypto/goolm/message/message_test.go @@ -1,8 +1,10 @@ -package message +package message_test import ( "bytes" "testing" + + "maunium.net/go/mautrix/crypto/goolm/message" ) func TestMessageDecode(t *testing.T) { @@ -10,7 +12,7 @@ func TestMessageDecode(t *testing.T) { expectedRatchetKey := []byte("ratchetkey") expectedCipherText := []byte("ciphertext") - msg := Message{} + msg := message.Message{} err := msg.Decode(messageRaw) if err != nil { t.Fatal(err) @@ -35,7 +37,7 @@ func TestMessageDecode(t *testing.T) { func TestMessageEncode(t *testing.T) { expectedRaw := []byte("\x03\n\nratchetkey\x10\x01\"\nciphertexthmacsha2") hmacsha256 := []byte("hmacsha2") - msg := Message{ + msg := message.Message{ Version: 3, Counter: 1, RatchetKey: []byte("ratchetkey"), diff --git a/crypto/goolm/message/preKeyMessage.go b/crypto/goolm/message/prekey_message.go similarity index 98% rename from crypto/goolm/message/preKeyMessage.go rename to crypto/goolm/message/prekey_message.go index 8f1826d7..9df3f9fa 100644 --- a/crypto/goolm/message/preKeyMessage.go +++ b/crypto/goolm/message/prekey_message.go @@ -1,7 +1,7 @@ package message import ( - "codeberg.org/DerLukas/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/crypto" ) const ( diff --git a/crypto/goolm/message/preKeyMessage_test.go b/crypto/goolm/message/prekey_message_test.go similarity index 91% rename from crypto/goolm/message/preKeyMessage_test.go rename to crypto/goolm/message/prekey_message_test.go index a244e95e..431d27d5 100644 --- a/crypto/goolm/message/preKeyMessage_test.go +++ b/crypto/goolm/message/prekey_message_test.go @@ -1,10 +1,11 @@ -package message +package message_test import ( "bytes" "testing" - "codeberg.org/DerLukas/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/message" ) func TestPreKeyMessageDecode(t *testing.T) { @@ -16,7 +17,7 @@ func TestPreKeyMessageDecode(t *testing.T) { expectedbaseKey := []byte("baseKey-.-.-.-.-.-.-.-.-.-.-.-.-") expectedmessage := []byte("message") - msg := PreKeyMessage{} + msg := message.PreKeyMessage{} err := msg.Decode(messageRaw) if err != nil { t.Fatal(err) @@ -45,7 +46,7 @@ func TestPreKeyMessageDecode(t *testing.T) { func TestPreKeyMessageEncode(t *testing.T) { expectedRaw := []byte("\x03\x0a\x0aonetimeKey\x1a\x05idKey\x12\x07baseKey\x22\x07message") - msg := PreKeyMessage{ + msg := message.PreKeyMessage{ Version: 3, IdentityKey: []byte("idKey"), BaseKey: []byte("baseKey"), diff --git a/crypto/goolm/message/sessionExport.go b/crypto/goolm/message/session_export.go similarity index 82% rename from crypto/goolm/message/sessionExport.go rename to crypto/goolm/message/session_export.go index 05814b3a..5c4487e3 100644 --- a/crypto/goolm/message/sessionExport.go +++ b/crypto/goolm/message/session_export.go @@ -2,10 +2,10 @@ package message import ( "encoding/binary" + "fmt" - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/crypto" - "github.com/pkg/errors" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/crypto" ) const ( @@ -32,10 +32,10 @@ func (s MegolmSessionExport) Encode() []byte { // Decode populates the struct with the data encoded in input. func (s *MegolmSessionExport) Decode(input []byte) error { if len(input) != 165 { - return errors.Wrap(goolm.ErrBadInput, "decrypt") + return fmt.Errorf("decrypt: %w", goolm.ErrBadInput) } if input[0] != sessionExportVersion { - return errors.Wrap(goolm.ErrBadVersion, "decrypt") + return fmt.Errorf("decrypt: %w", goolm.ErrBadVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/message/sessionSharing.go b/crypto/goolm/message/session_sharing.go similarity index 83% rename from crypto/goolm/message/sessionSharing.go rename to crypto/goolm/message/session_sharing.go index 5c0cd773..c5393f50 100644 --- a/crypto/goolm/message/sessionSharing.go +++ b/crypto/goolm/message/session_sharing.go @@ -2,10 +2,10 @@ package message import ( "encoding/binary" + "fmt" - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/crypto" - "github.com/pkg/errors" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/crypto" ) const ( @@ -34,15 +34,15 @@ func (s MegolmSessionSharing) EncodeAndSign(key crypto.Ed25519KeyPair) []byte { // VerifyAndDecode verifies the input and populates the struct with the data encoded in input. func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error { if len(input) != 229 { - return errors.Wrap(goolm.ErrBadInput, "verify") + return fmt.Errorf("verify: %w", goolm.ErrBadInput) } publicKey := crypto.Ed25519PublicKey(input[133:165]) if !publicKey.Verify(input[:165], input[165:]) { - return errors.Wrap(goolm.ErrBadVerification, "verify") + return fmt.Errorf("verify: %w", goolm.ErrBadVerification) } s.PublicKey = publicKey if input[0] != sessionSharingVersion { - return errors.Wrap(goolm.ErrBadVersion, "verify") + return fmt.Errorf("verify: %w", goolm.ErrBadVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/olm/chain.go b/crypto/goolm/olm/chain.go index 187f5f6d..76db1eaa 100644 --- a/crypto/goolm/olm/chain.go +++ b/crypto/goolm/olm/chain.go @@ -1,10 +1,11 @@ package olm import ( - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/crypto" - libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" - "github.com/pkg/errors" + "fmt" + + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) const ( @@ -44,11 +45,11 @@ func (r *chainKey) UnpickleLibOlm(value []byte) (int, error) { // It returns the number of bytes written. func (r chainKey) PickleLibOlm(target []byte) (int, error) { if len(target) < r.PickleLen() { - return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle chain key") + return 0, fmt.Errorf("pickle chain key: %w", goolm.ErrValueTooShort) } written, err := r.Key.PickleLibOlm(target) if err != nil { - return 0, errors.Wrap(err, "pickle chain key") + return 0, fmt.Errorf("pickle chain key: %w", err) } written += libolmpickle.PickleUInt32(r.Index, target[written:]) return written, nil @@ -115,15 +116,15 @@ func (r *senderChain) UnpickleLibOlm(value []byte) (int, error) { // It returns the number of bytes written. func (r senderChain) PickleLibOlm(target []byte) (int, error) { if len(target) < r.PickleLen() { - return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle sender chain") + return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort) } written, err := r.RKey.PickleLibOlm(target) if err != nil { - return 0, errors.Wrap(err, "pickle sender chain") + return 0, fmt.Errorf("pickle sender chain: %w", err) } writtenChain, err := r.CKey.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle sender chain") + return 0, fmt.Errorf("pickle sender chain: %w", err) } written += writtenChain return written, nil @@ -188,15 +189,15 @@ func (r *receiverChain) UnpickleLibOlm(value []byte) (int, error) { // It returns the number of bytes written. func (r receiverChain) PickleLibOlm(target []byte) (int, error) { if len(target) < r.PickleLen() { - return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle sender chain") + return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort) } written, err := r.RKey.PickleLibOlm(target) if err != nil { - return 0, errors.Wrap(err, "pickle sender chain") + return 0, fmt.Errorf("pickle sender chain: %w", err) } writtenChain, err := r.CKey.PickleLibOlm(target) if err != nil { - return 0, errors.Wrap(err, "pickle sender chain") + return 0, fmt.Errorf("pickle sender chain: %w", err) } written += writtenChain return written, nil @@ -237,7 +238,7 @@ func (m *messageKey) UnpickleLibOlm(value []byte) (int, error) { // It returns the number of bytes written. func (m messageKey) PickleLibOlm(target []byte) (int, error) { if len(target) < m.PickleLen() { - return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle message key") + return 0, fmt.Errorf("pickle message key: %w", goolm.ErrValueTooShort) } written := 0 if len(m.Key) != messageKeyLength { diff --git a/crypto/goolm/olm/olm.go b/crypto/goolm/olm/olm.go index 7e1e3e56..2d8542fd 100644 --- a/crypto/goolm/olm/olm.go +++ b/crypto/goolm/olm/olm.go @@ -2,15 +2,15 @@ package olm import ( + "fmt" "io" - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/cipher" - "codeberg.org/DerLukas/goolm/crypto" - libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" - "codeberg.org/DerLukas/goolm/message" - "codeberg.org/DerLukas/goolm/utilities" - "github.com/pkg/errors" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" + "maunium.net/go/mautrix/crypto/goolm/message" + "maunium.net/go/mautrix/crypto/goolm/utilities" ) const ( @@ -36,7 +36,7 @@ var KdfInfo = struct { Ratchet: []byte("OLM_RATCHET"), } -var RatchetCipher = cipher.NewAESSha256([]byte("OLM_KEYS")) +var RatchetCipher = cipher.NewAESSHA256([]byte("OLM_KEYS")) // Ratchet represents the olm ratchet as described in // @@ -115,7 +115,7 @@ func (r *Ratchet) Encrypt(plaintext []byte, reader io.Reader) ([]byte, error) { encryptedText, err := RatchetCipher.Encrypt(messageKey.Key, plaintext) if err != nil { - return nil, errors.Wrap(err, "cipher encrypt") + return nil, fmt.Errorf("cipher encrypt: %w", err) } message := &message.Message{} @@ -141,10 +141,10 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { return nil, err } if message.Version != protocolVersion { - return nil, errors.Wrap(goolm.ErrWrongProtocolVersion, "decrypt") + return nil, fmt.Errorf("decrypt: %w", goolm.ErrWrongProtocolVersion) } if !message.HasCounter || len(message.RatchetKey) == 0 || len(message.Ciphertext) == 0 { - return nil, errors.Wrap(goolm.ErrBadMessageFormat, "decrypt") + return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat) } var receiverChainFromMessage *receiverChain for curChainIndex := range r.ReceiverChains { @@ -173,11 +173,11 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { return nil, err } if !verified { - return nil, errors.Wrap(goolm.ErrBadMAC, "decrypt from skipped message keys") + return nil, fmt.Errorf("decrypt from skipped message keys: %w", goolm.ErrBadMAC) } result, err = RatchetCipher.Decrypt(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, message.Ciphertext) if err != nil { - return nil, errors.Wrap(err, "cipher decrypt") + return nil, fmt.Errorf("cipher decrypt: %w", err) } if len(result) != 0 { // Remove the key from the skipped keys now that we've @@ -189,7 +189,7 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { } } if !foundSkippedKey { - return nil, errors.Wrap(goolm.ErrMessageKeyNotFound, "decrypt") + return nil, fmt.Errorf("decrypt: %w", goolm.ErrMessageKeyNotFound) } } else { //Advancing the chain is done in this method @@ -228,11 +228,11 @@ func (r Ratchet) createMessageKeys(chainKey chainKey) messageKey { // decryptForExistingChain returns the decrypted message by using the chain. The MAC of the rawMessage is verified. func (r *Ratchet) decryptForExistingChain(chain *receiverChain, message *message.Message, rawMessage []byte) ([]byte, error) { if message.Counter < chain.CKey.Index { - return nil, errors.Wrap(goolm.ErrChainTooHigh, "decrypt") + return nil, fmt.Errorf("decrypt: %w", goolm.ErrChainTooHigh) } // Limit the number of hashes we're prepared to compute if message.Counter-chain.CKey.Index > maxMessageGap { - return nil, errors.Wrap(goolm.ErrMsgIndexTooHigh, "decrypt from existing chain") + return nil, fmt.Errorf("decrypt from existing chain: %w", goolm.ErrMsgIndexTooHigh) } for chain.CKey.Index < message.Counter { messageKey := r.createMessageKeys(chain.chainKey()) @@ -250,7 +250,7 @@ func (r *Ratchet) decryptForExistingChain(chain *receiverChain, message *message return nil, err } if !verified { - return nil, errors.Wrap(goolm.ErrBadMAC, "decrypt from existing chain") + return nil, fmt.Errorf("decrypt from existing chain: %w", goolm.ErrBadMAC) } return RatchetCipher.Decrypt(messageKey.Key, message.Ciphertext) } @@ -260,11 +260,11 @@ func (r *Ratchet) decryptForNewChain(message *message.Message, rawMessage []byte // They shouldn't move to a new chain until we've sent them a message // acknowledging the last one if !r.SenderChains.IsSet { - return nil, errors.Wrap(goolm.ErrProtocolViolation, "decrypt for new chain") + return nil, fmt.Errorf("decrypt for new chain: %w", goolm.ErrProtocolViolation) } // Limit the number of hashes we're prepared to compute if message.Counter > maxMessageGap { - return nil, errors.Wrap(goolm.ErrMsgIndexTooHigh, "decrypt for new chain") + return nil, fmt.Errorf("decrypt for new chain: %w", goolm.ErrMsgIndexTooHigh) } newChainKey, err := r.advanceRootKey(r.SenderChains.ratchetKey(), message.RatchetKey) @@ -371,17 +371,17 @@ func (r *Ratchet) UnpickleLibOlm(value []byte, includesChainIndex bool) (int, er // It returns the number of bytes written. func (r Ratchet) PickleLibOlm(target []byte) (int, error) { if len(target) < r.PickleLen() { - return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle ratchet") + return 0, fmt.Errorf("pickle ratchet: %w", goolm.ErrValueTooShort) } written, err := r.RootKey.PickleLibOlm(target) if err != nil { - return 0, errors.Wrap(err, "pickle ratchet") + return 0, fmt.Errorf("pickle ratchet: %w", err) } if r.SenderChains.IsSet { written += libolmpickle.PickleUInt32(1, target[written:]) //Length of sender chain, always 1 writtenSender, err := r.SenderChains.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle ratchet") + return 0, fmt.Errorf("pickle ratchet: %w", err) } written += writtenSender } else { @@ -391,7 +391,7 @@ func (r Ratchet) PickleLibOlm(target []byte) (int, error) { for _, curChain := range r.ReceiverChains { writtenChain, err := curChain.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle ratchet") + return 0, fmt.Errorf("pickle ratchet: %w", err) } written += writtenChain } @@ -399,7 +399,7 @@ func (r Ratchet) PickleLibOlm(target []byte) (int, error) { for _, curChain := range r.SkippedMessageKeys { writtenChain, err := curChain.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle ratchet") + return 0, fmt.Errorf("pickle ratchet: %w", err) } written += writtenChain } diff --git a/crypto/goolm/olm/olm_test.go b/crypto/goolm/olm/olm_test.go index 4f70ae81..f97a0aeb 100644 --- a/crypto/goolm/olm/olm_test.go +++ b/crypto/goolm/olm/olm_test.go @@ -1,29 +1,30 @@ -package olm +package olm_test import ( "bytes" "encoding/json" "testing" - "codeberg.org/DerLukas/goolm/cipher" - "codeberg.org/DerLukas/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/olm" ) var ( sharedSecret = []byte("A secret") ) -func initializeRatchets() (*Ratchet, *Ratchet, error) { - KdfInfo = struct { +func initializeRatchets() (*olm.Ratchet, *olm.Ratchet, error) { + olm.KdfInfo = struct { Root []byte Ratchet []byte }{ Root: []byte("Olm"), Ratchet: []byte("OlmRatchet"), } - RatchetCipher = cipher.NewAESSha256([]byte("OlmMessageKeys")) - aliceRatchet := New() - bobRatchet := New() + olm.RatchetCipher = cipher.NewAESSHA256([]byte("OlmMessageKeys")) + aliceRatchet := olm.New() + bobRatchet := olm.New() aliceKey, err := crypto.Curve25519GenerateKey(nil) if err != nil { @@ -162,7 +163,7 @@ func TestJSONEncoding(t *testing.T) { t.Fatal(err) } - newRatcher := Ratchet{} + newRatcher := olm.Ratchet{} err = json.Unmarshal(marshaled, &newRatcher) if err != nil { t.Fatal(err) diff --git a/crypto/goolm/olm/skippedMessage.go b/crypto/goolm/olm/skipped_message.go similarity index 82% rename from crypto/goolm/olm/skippedMessage.go rename to crypto/goolm/olm/skipped_message.go index 893d548d..93d7c283 100644 --- a/crypto/goolm/olm/skippedMessage.go +++ b/crypto/goolm/olm/skipped_message.go @@ -1,9 +1,10 @@ package olm import ( - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/crypto" - "github.com/pkg/errors" + "fmt" + + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/crypto" ) // skippedMessageKey stores a skipped message key @@ -32,15 +33,15 @@ func (r *skippedMessageKey) UnpickleLibOlm(value []byte) (int, error) { // It returns the number of bytes written. func (r skippedMessageKey) PickleLibOlm(target []byte) (int, error) { if len(target) < r.PickleLen() { - return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle sender chain") + return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort) } written, err := r.RKey.PickleLibOlm(target) if err != nil { - return 0, errors.Wrap(err, "pickle sender chain") + return 0, fmt.Errorf("pickle sender chain: %w", err) } writtenChain, err := r.MKey.PickleLibOlm(target) if err != nil { - return 0, errors.Wrap(err, "pickle sender chain") + return 0, fmt.Errorf("pickle sender chain: %w", err) } written += writtenChain return written, nil diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index e530b037..c94bfd80 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -1,12 +1,15 @@ package pk import ( - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/cipher" - "codeberg.org/DerLukas/goolm/crypto" - libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" - "codeberg.org/DerLukas/goolm/utilities" - "github.com/pkg/errors" + "encoding/base64" + "errors" + "fmt" + + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" + "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/id" ) @@ -54,7 +57,7 @@ func (s Decription) PrivateKey() crypto.Curve25519PrivateKey { // Decrypt decrypts the ciphertext and verifies the MAC. The base64 encoded key is used to construct the shared secret. func (s Decription) Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error) { - keyDecoded, err := goolm.Base64Decode([]byte(key)) + keyDecoded, err := base64.RawStdEncoding.DecodeString(string(key)) if err != nil { return nil, err } @@ -66,13 +69,13 @@ func (s Decription) Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, if err != nil { return nil, err } - cipher := cipher.NewAESSha256(nil) + cipher := cipher.NewAESSHA256(nil) verified, err := cipher.Verify(sharedSecret, ciphertext, decodedMAC) if err != nil { return nil, err } if !verified { - return nil, errors.Wrap(goolm.ErrBadMAC, "decrypt") + return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMAC) } plaintext, err := cipher.Decrypt(sharedSecret, ciphertext) if err != nil { @@ -112,7 +115,7 @@ func (a *Decription) UnpickleLibOlm(value []byte) (int, error) { switch pickledVersion { case decryptionPickleVersionLibOlm: default: - return 0, errors.Wrap(goolm.ErrBadVersion, "unpickle olmSession") + return 0, fmt.Errorf("unpickle olmSession: %w", goolm.ErrBadVersion) } readBytes, err := a.KeyPair.UnpickleLibOlm(value[curPos:]) if err != nil { @@ -143,12 +146,12 @@ func (a Decription) Pickle(key []byte) ([]byte, error) { // It returns the number of bytes written. func (a Decription) PickleLibOlm(target []byte) (int, error) { if len(target) < a.PickleLen() { - return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle Decription") + return 0, fmt.Errorf("pickle Decription: %w", goolm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(decryptionPickleVersionLibOlm, target) writtenKey, err := a.KeyPair.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle Decription") + return 0, fmt.Errorf("pickle Decription: %w", err) } written += writtenKey return written, nil diff --git a/crypto/goolm/pk/encryption.go b/crypto/goolm/pk/encryption.go index 2f64b7b2..19d9688a 100644 --- a/crypto/goolm/pk/encryption.go +++ b/crypto/goolm/pk/encryption.go @@ -1,10 +1,13 @@ package pk import ( - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/cipher" - "codeberg.org/DerLukas/goolm/crypto" + "encoding/base64" + "maunium.net/go/mautrix/id" + + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/crypto" ) // Encryption is used to encrypt pk messages @@ -14,7 +17,7 @@ type Encryption struct { // NewEncryption returns a new Encryption with the base64 encoded public key of the recipient func NewEncryption(pubKey id.Curve25519) (*Encryption, error) { - pubKeyDecoded, err := goolm.Base64Decode([]byte(pubKey)) + pubKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(pubKey)) if err != nil { return nil, err } @@ -33,7 +36,7 @@ func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519Privat if err != nil { return nil, nil, err } - cipher := cipher.NewAESSha256(nil) + cipher := cipher.NewAESSHA256(nil) ciphertext, err = cipher.Encrypt(sharedSecret, plaintext) if err != nil { return nil, nil, err diff --git a/crypto/goolm/pk/pk_test.go b/crypto/goolm/pk/pk_test.go index 58918bea..72a48767 100644 --- a/crypto/goolm/pk/pk_test.go +++ b/crypto/goolm/pk/pk_test.go @@ -1,11 +1,13 @@ -package pk +package pk_test import ( "bytes" + "encoding/base64" "testing" - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/pk" "maunium.net/go/mautrix/id" ) @@ -24,7 +26,7 @@ func TestEncryptionDecryption(t *testing.T) { 0x1C, 0x2F, 0x8B, 0x27, 0xFF, 0x88, 0xE0, 0xEB, } bobPublic := []byte("3p7bfXt9wbTTW2HC7OQ1Nz+DQ8hbeGdNrfx+FG+IK08") - decryption, err := NewDecriptionFromPrivate(alicePrivate) + decryption, err := pk.NewDecriptionFromPrivate(alicePrivate) if err != nil { t.Fatal(err) } @@ -35,7 +37,7 @@ func TestEncryptionDecryption(t *testing.T) { t.Fatal("private key not correct") } - encryption, err := NewEncryption(decryption.PubKey()) + encryption, err := pk.NewEncryption(decryption.PubKey()) if err != nil { t.Fatal(err) } @@ -63,14 +65,14 @@ func TestSigning(t *testing.T) { 0xB1, 0x77, 0xFB, 0xA5, 0x1D, 0xB9, 0x2C, 0x2A, } message := []byte("We hold these truths to be self-evident, that all men are created equal, that they are endowed by their Creator with certain unalienable Rights, that among these are Life, Liberty and the pursuit of Happiness.") - signing, _ := NewSigningFromSeed(seed) + signing, _ := pk.NewSigningFromSeed(seed) signature := signing.Sign(message) signatureDecoded, err := goolm.Base64Decode(signature) if err != nil { t.Fatal(err) } pubKeyEncoded := signing.PublicKey() - pubKeyDecoded, err := goolm.Base64Decode([]byte(pubKeyEncoded)) + pubKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(pubKeyEncoded)) if err != nil { t.Fatal(err) } @@ -95,7 +97,7 @@ func TestDecryptionPickling(t *testing.T) { 0xB1, 0x77, 0xFB, 0xA5, 0x1D, 0xB9, 0x2C, 0x2A, } alicePublic := []byte("hSDwCYkwp1R0i33ctD73Wg2/Og0mOBr066SpjqqbTmo") - decryption, err := NewDecriptionFromPrivate(alicePrivate) + decryption, err := pk.NewDecriptionFromPrivate(alicePrivate) if err != nil { t.Fatal(err) } @@ -115,7 +117,7 @@ func TestDecryptionPickling(t *testing.T) { t.Fatalf("pickle not as expected:\n%v\n%v\n", pickled, expectedPickle) } - newDecription, err := NewDecription() + newDecription, err := pk.NewDecription() if err != nil { t.Fatal(err) } diff --git a/crypto/goolm/pk/signing.go b/crypto/goolm/pk/signing.go index a80cc5c8..493705f6 100644 --- a/crypto/goolm/pk/signing.go +++ b/crypto/goolm/pk/signing.go @@ -3,8 +3,8 @@ package pk import ( "crypto/rand" - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/id" ) diff --git a/crypto/goolm/sas/main.go b/crypto/goolm/sas/main.go index b1f55069..7337d5f9 100644 --- a/crypto/goolm/sas/main.go +++ b/crypto/goolm/sas/main.go @@ -4,8 +4,8 @@ package sas import ( "io" - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/crypto" ) // SAS contains the key pair and secret for SAS. diff --git a/crypto/goolm/sas/main_test.go b/crypto/goolm/sas/main_test.go index 6e897a32..c0acec70 100644 --- a/crypto/goolm/sas/main_test.go +++ b/crypto/goolm/sas/main_test.go @@ -1,13 +1,14 @@ -package sas +package sas_test import ( "bytes" "testing" - "codeberg.org/DerLukas/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/sas" ) -func initSAS() (*SAS, *SAS, error) { +func initSAS() (*sas.SAS, *sas.SAS, error) { alicePrivate := crypto.Curve25519PrivateKey([]byte{ 0x77, 0x07, 0x6D, 0x0A, 0x73, 0x18, 0xA5, 0x7D, 0x3C, 0x16, 0xC1, 0x72, 0x51, 0xB2, 0x66, 0x45, @@ -21,7 +22,7 @@ func initSAS() (*SAS, *SAS, error) { 0x1C, 0x2F, 0x8B, 0x27, 0xFF, 0x88, 0xE0, 0xEB, }) - aliceSAS, err := New() + aliceSAS, err := sas.New() if err != nil { return nil, nil, err } @@ -31,7 +32,7 @@ func initSAS() (*SAS, *SAS, error) { return nil, nil, err } - bobSAS, err := New() + bobSAS, err := sas.New() if err != nil { return nil, nil, err } diff --git a/crypto/goolm/session/megolmInboundSession.go b/crypto/goolm/session/megolm_inbound_session.go similarity index 86% rename from crypto/goolm/session/megolmInboundSession.go rename to crypto/goolm/session/megolm_inbound_session.go index 5f1d1c8f..8214aefc 100644 --- a/crypto/goolm/session/megolmInboundSession.go +++ b/crypto/goolm/session/megolm_inbound_session.go @@ -1,14 +1,17 @@ package session import ( - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/cipher" - "codeberg.org/DerLukas/goolm/crypto" - libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" - "codeberg.org/DerLukas/goolm/megolm" - "codeberg.org/DerLukas/goolm/message" - "codeberg.org/DerLukas/goolm/utilities" - "github.com/pkg/errors" + "encoding/base64" + "errors" + "fmt" + + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" + "maunium.net/go/mautrix/crypto/goolm/megolm" + "maunium.net/go/mautrix/crypto/goolm/message" + "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/id" ) @@ -75,7 +78,7 @@ func NewMegolmInboundSessionFromExport(input []byte) (*MegolmInboundSession, err // MegolmInboundSessionFromPickled loads the MegolmInboundSession details from a pickled base64 string. The input is decrypted with the supplied key. func MegolmInboundSessionFromPickled(pickled, key []byte) (*MegolmInboundSession, error) { if len(pickled) == 0 { - return nil, errors.Wrap(goolm.ErrEmptyInput, "megolmInboundSessionFromPickled") + return nil, fmt.Errorf("megolmInboundSessionFromPickled: %w", goolm.ErrEmptyInput) } a := &MegolmInboundSession{} err := a.Unpickle(pickled, key) @@ -94,7 +97,7 @@ func (o MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, } if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) { // the counter is before our initial ratchet - we can't decode this - return nil, errors.Wrap(goolm.ErrRatchetNotAvailable, "decrypt") + return nil, fmt.Errorf("decrypt: %w", goolm.ErrRatchetNotAvailable) } // otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet copiedRatchet := o.InitialRatchet @@ -106,7 +109,7 @@ func (o MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, // Decrypt decrypts a base64 encoded group message. func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error) { if o.SigningKey == nil { - return nil, 0, errors.Wrap(goolm.ErrBadMessageFormat, "decrypt") + return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat) } decoded, err := goolm.Base64Decode(ciphertext) if err != nil { @@ -118,16 +121,16 @@ func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error return nil, 0, err } if msg.Version != protocolVersion { - return nil, 0, errors.Wrap(goolm.ErrWrongProtocolVersion, "decrypt") + return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrWrongProtocolVersion) } if msg.Ciphertext == nil || !msg.HasMessageIndex { - return nil, 0, errors.Wrap(goolm.ErrBadMessageFormat, "decrypt") + return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat) } // verify signature verifiedSignature := msg.VerifySignatureInline(o.SigningKey, decoded) if !verifiedSignature { - return nil, 0, errors.Wrap(goolm.ErrBadSignature, "decrypt") + return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadSignature) } targetRatch, err := o.getRatchet(msg.MessageIndex) @@ -146,7 +149,7 @@ func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error // SessionID returns the base64 endoded signing key func (o MegolmInboundSession) SessionID() id.SessionID { - return id.SessionID(goolm.Base64Encode(o.SigningKey)) + return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey)) } // PickleAsJSON returns an MegolmInboundSession as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. @@ -189,7 +192,7 @@ func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) (int, error) { switch pickledVersion { case megolmInboundSessionPickleVersionLibOlm, 1: default: - return 0, errors.Wrap(goolm.ErrBadVersion, "unpickle MegolmInboundSession") + return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", goolm.ErrBadVersion) } readBytes, err := o.InitialRatchet.UnpickleLibOlm(value[curPos:]) if err != nil { @@ -240,22 +243,22 @@ func (o MegolmInboundSession) Pickle(key []byte) ([]byte, error) { // It returns the number of bytes written. func (o MegolmInboundSession) PickleLibOlm(target []byte) (int, error) { if len(target) < o.PickleLen() { - return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle MegolmInboundSession") + return 0, fmt.Errorf("pickle MegolmInboundSession: %w", goolm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(megolmInboundSessionPickleVersionLibOlm, target) writtenInitRatchet, err := o.InitialRatchet.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle MegolmInboundSession") + return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err) } written += writtenInitRatchet writtenRatchet, err := o.Ratchet.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle MegolmInboundSession") + return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err) } written += writtenRatchet writtenPubKey, err := o.SigningKey.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle MegolmInboundSession") + return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err) } written += writtenPubKey written += libolmpickle.PickleBool(o.SigningKeyVerified, target[written:]) diff --git a/crypto/goolm/session/megolmOutboundSession.go b/crypto/goolm/session/megolm_outbound_session.go similarity index 86% rename from crypto/goolm/session/megolmOutboundSession.go rename to crypto/goolm/session/megolm_outbound_session.go index 23deb43f..11aadb00 100644 --- a/crypto/goolm/session/megolmOutboundSession.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -1,16 +1,19 @@ package session import ( + "encoding/base64" + "errors" + "fmt" "math/rand" - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/cipher" - "codeberg.org/DerLukas/goolm/crypto" - libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" - "codeberg.org/DerLukas/goolm/megolm" - "codeberg.org/DerLukas/goolm/utilities" - "github.com/pkg/errors" "maunium.net/go/mautrix/id" + + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" + "maunium.net/go/mautrix/crypto/goolm/megolm" + "maunium.net/go/mautrix/crypto/goolm/utilities" ) const ( @@ -48,7 +51,7 @@ func NewMegolmOutboundSession() (*MegolmOutboundSession, error) { // MegolmOutboundSessionFromPickled loads the MegolmOutboundSession details from a pickled base64 string. The input is decrypted with the supplied key. func MegolmOutboundSessionFromPickled(pickled, key []byte) (*MegolmOutboundSession, error) { if len(pickled) == 0 { - return nil, errors.Wrap(goolm.ErrEmptyInput, "megolmOutboundSessionFromPickled") + return nil, fmt.Errorf("megolmOutboundSessionFromPickled: %w", goolm.ErrEmptyInput) } a := &MegolmOutboundSession{} err := a.Unpickle(pickled, key) @@ -69,7 +72,7 @@ func (o *MegolmOutboundSession) Encrypt(plaintext []byte) ([]byte, error) { // SessionID returns the base64 endoded public signing key func (o MegolmOutboundSession) SessionID() id.SessionID { - return id.SessionID(goolm.Base64Encode(o.SigningKey.PublicKey)) + return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey.PublicKey)) } // PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. @@ -103,7 +106,7 @@ func (o *MegolmOutboundSession) UnpickleLibOlm(value []byte) (int, error) { switch pickledVersion { case megolmOutboundSessionPickleVersionLibOlm: default: - return 0, errors.Wrap(goolm.ErrBadVersion, "unpickle MegolmInboundSession") + return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", goolm.ErrBadVersion) } readBytes, err := o.Ratchet.UnpickleLibOlm(value[curPos:]) if err != nil { @@ -139,17 +142,17 @@ func (o MegolmOutboundSession) Pickle(key []byte) ([]byte, error) { // It returns the number of bytes written. func (o MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) { if len(target) < o.PickleLen() { - return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle MegolmOutboundSession") + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(megolmOutboundSessionPickleVersionLibOlm, target) writtenRatchet, err := o.Ratchet.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle MegolmOutboundSession") + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) } written += writtenRatchet writtenPubKey, err := o.SigningKey.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle MegolmOutboundSession") + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) } written += writtenPubKey return written, nil diff --git a/crypto/goolm/session/melgomSession_test.go b/crypto/goolm/session/megolm_session_test.go similarity index 78% rename from crypto/goolm/session/melgomSession_test.go rename to crypto/goolm/session/megolm_session_test.go index 55f57ad4..93eec7eb 100644 --- a/crypto/goolm/session/melgomSession_test.go +++ b/crypto/goolm/session/megolm_session_test.go @@ -1,4 +1,4 @@ -package session +package session_test import ( "bytes" @@ -6,14 +6,15 @@ import ( "math/rand" "testing" - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/crypto" - "codeberg.org/DerLukas/goolm/megolm" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/megolm" + "maunium.net/go/mautrix/crypto/goolm/session" ) func TestOutboundPickleJSON(t *testing.T) { pickleKey := []byte("secretKey") - session, err := NewMegolmOutboundSession() + sess, err := session.NewMegolmOutboundSession() if err != nil { t.Fatal(err) } @@ -21,39 +22,39 @@ func TestOutboundPickleJSON(t *testing.T) { if err != nil { t.Fatal(err) } - session.SigningKey = kp - pickled, err := session.PickleAsJSON(pickleKey) + sess.SigningKey = kp + pickled, err := sess.PickleAsJSON(pickleKey) if err != nil { t.Fatal(err) } - newSession := MegolmOutboundSession{} + newSession := session.MegolmOutboundSession{} err = newSession.UnpickleAsJSON(pickled, pickleKey) if err != nil { t.Fatal(err) } - if session.SessionID() != newSession.SessionID() { + if sess.SessionID() != newSession.SessionID() { t.Fatal("session ids not equal") } - if !bytes.Equal(session.SigningKey.PrivateKey, newSession.SigningKey.PrivateKey) { + if !bytes.Equal(sess.SigningKey.PrivateKey, newSession.SigningKey.PrivateKey) { t.Fatal("private keys not equal") } - if !bytes.Equal(session.Ratchet.Data[:], newSession.Ratchet.Data[:]) { + if !bytes.Equal(sess.Ratchet.Data[:], newSession.Ratchet.Data[:]) { t.Fatal("ratchet data not equal") } - if session.Ratchet.Counter != newSession.Ratchet.Counter { + if sess.Ratchet.Counter != newSession.Ratchet.Counter { t.Fatal("ratchet counter not equal") } } func TestInboundPickleJSON(t *testing.T) { pickleKey := []byte("secretKey") - session := MegolmInboundSession{} + sess := session.MegolmInboundSession{} kp, err := crypto.Ed25519GenerateKey(nil) if err != nil { t.Fatal(err) } - session.SigningKey = kp.PublicKey + sess.SigningKey = kp.PublicKey var randomData [megolm.RatchetParts * megolm.RatchetPartLength]byte _, err = rand.Read(randomData[:]) if err != nil { @@ -63,27 +64,27 @@ func TestInboundPickleJSON(t *testing.T) { if err != nil { t.Fatal(err) } - session.Ratchet = *ratchet - pickled, err := session.PickleAsJSON(pickleKey) + sess.Ratchet = *ratchet + pickled, err := sess.PickleAsJSON(pickleKey) if err != nil { t.Fatal(err) } - newSession := MegolmInboundSession{} + newSession := session.MegolmInboundSession{} err = newSession.UnpickleAsJSON(pickled, pickleKey) if err != nil { t.Fatal(err) } - if session.SessionID() != newSession.SessionID() { - t.Fatal("session ids not equal") + if sess.SessionID() != newSession.SessionID() { + t.Fatal("sess ids not equal") } - if !bytes.Equal(session.SigningKey, newSession.SigningKey) { + if !bytes.Equal(sess.SigningKey, newSession.SigningKey) { t.Fatal("private keys not equal") } - if !bytes.Equal(session.Ratchet.Data[:], newSession.Ratchet.Data[:]) { + if !bytes.Equal(sess.Ratchet.Data[:], newSession.Ratchet.Data[:]) { t.Fatal("ratchet data not equal") } - if session.Ratchet.Counter != newSession.Ratchet.Counter { + if sess.Ratchet.Counter != newSession.Ratchet.Counter { t.Fatal("ratchet counter not equal") } } @@ -98,7 +99,7 @@ func TestGroupSendReceive(t *testing.T) { "0123456789ABDEF0123456789ABCDEF", ) - outboundSession, err := NewMegolmOutboundSession() + outboundSession, err := session.NewMegolmOutboundSession() if err != nil { t.Fatal(err) } @@ -120,7 +121,7 @@ func TestGroupSendReceive(t *testing.T) { } //build inbound session - inboundSession, err := NewMegolmInboundSession(sessionSharing) + inboundSession, err := session.NewMegolmInboundSession(sessionSharing) if err != nil { t.Fatal(err) } @@ -156,7 +157,7 @@ func TestGroupSessionExportImport(t *testing.T) { ) //init inbound - inboundSession, err := NewMegolmInboundSession(sessionKey) + inboundSession, err := session.NewMegolmInboundSession(sessionKey) if err != nil { t.Fatal(err) } @@ -178,7 +179,7 @@ func TestGroupSessionExportImport(t *testing.T) { t.Fatal(err) } - secondInboundSession, err := NewMegolmInboundSessionFromExport(exported) + secondInboundSession, err := session.NewMegolmInboundSessionFromExport(exported) if err != nil { t.Fatal(err) } @@ -213,7 +214,7 @@ func TestBadSignatureGroupMessage(t *testing.T) { ) //init inbound - inboundSession, err := NewMegolmInboundSession(sessionKey) + inboundSession, err := session.NewMegolmInboundSession(sessionKey) if err != nil { t.Fatal(err) } @@ -243,11 +244,11 @@ func TestBadSignatureGroupMessage(t *testing.T) { func TestOutbountPickle(t *testing.T) { pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItUO3TiOp5I+6PnQka6n8eHTyIEh3tCetilD+BKnHvtakE0eHHvG6pjEsMNN/vs7lkB5rV6XkoUKHLTE1dAfFunYEeHEZuKQpbG385dBwaMJXt4JrC0hU5jnv6jWNqAA0Ud9GxRDvkp04") pickleKey := []byte("secret_key") - session, err := MegolmOutboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) + sess, err := session.MegolmOutboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) if err != nil { t.Fatal(err) } - newPickled, err := session.Pickle(pickleKey) + newPickled, err := sess.Pickle(pickleKey) if err != nil { t.Fatal(err) } @@ -255,7 +256,7 @@ func TestOutbountPickle(t *testing.T) { t.Fatal("pickled version does not equal libolm version") } pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...) - _, err = MegolmOutboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) + _, err = session.MegolmOutboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) if err == nil { t.Fatal("should have gotten an error") } @@ -264,11 +265,11 @@ func TestOutbountPickle(t *testing.T) { func TestInbountPickle(t *testing.T) { pickledDataFromLibOlm := []byte("1/IPCdtUoQxMba5XT7sjjUW0Hrs7no9duGFnhsEmxzFX2H3qtRc4eaFBRZYXxOBRTGZ6eMgy3IiSrgAQ1gUlSZf5Q4AVKeBkhvN4LZ6hdhQFv91mM+C2C55/4B9/gDjJEbDGiRgLoMqbWPDV+y0F4h0KaR1V1PiTCC7zCi4WdxJQ098nJLgDL4VSsDbnaLcSMO60FOYgRN4KsLaKUGkXiiUBWp4boFMCiuTTOiyH8XlH0e9uWc0vMLyGNUcO8kCbpAnx3v1JTIVan3WGsnGv4K8Qu4M8GAkZewpexrsb2BSNNeLclOV9/cR203Y5KlzXcpiWNXSs8XoB3TLEtHYMnjuakMQfyrcXKIQntg4xPD/+wvfqkcMg9i7pcplQh7X2OK5ylrMZQrZkJ1fAYBGbBz1tykWOjfrZ") pickleKey := []byte("secret_key") - session, err := MegolmInboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) + sess, err := session.MegolmInboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) if err != nil { t.Fatal(err) } - newPickled, err := session.Pickle(pickleKey) + newPickled, err := sess.Pickle(pickleKey) if err != nil { t.Fatal(err) } @@ -276,7 +277,7 @@ func TestInbountPickle(t *testing.T) { t.Fatal("pickled version does not equal libolm version") } pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...) - _, err = MegolmInboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) + _, err = session.MegolmInboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) if err == nil { t.Fatal("should have gotten an error") } diff --git a/crypto/goolm/session/olmSession.go b/crypto/goolm/session/olm_session.go similarity index 89% rename from crypto/goolm/session/olmSession.go rename to crypto/goolm/session/olm_session.go index 20cd2af5..b5189c59 100644 --- a/crypto/goolm/session/olmSession.go +++ b/crypto/goolm/session/olm_session.go @@ -2,17 +2,18 @@ package session import ( "bytes" + "encoding/base64" + "errors" "fmt" "io" - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/cipher" - "codeberg.org/DerLukas/goolm/crypto" - libolmpickle "codeberg.org/DerLukas/goolm/libolmPickle" - "codeberg.org/DerLukas/goolm/message" - "codeberg.org/DerLukas/goolm/olm" - "codeberg.org/DerLukas/goolm/utilities" - "github.com/pkg/errors" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" + "maunium.net/go/mautrix/crypto/goolm/message" + "maunium.net/go/mautrix/crypto/goolm/olm" + "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/id" ) @@ -41,7 +42,7 @@ type SearchOTKFunc = func(crypto.Curve25519PublicKey) *crypto.OneTimeKey // the Session using the supplied key. func OlmSessionFromJSONPickled(pickled, key []byte) (*OlmSession, error) { if len(pickled) == 0 { - return nil, errors.Wrap(goolm.ErrEmptyInput, "sessionFromPickled") + return nil, fmt.Errorf("sessionFromPickled: %w", goolm.ErrEmptyInput) } a := &OlmSession{} err := a.UnpickleAsJSON(pickled, key) @@ -54,7 +55,7 @@ func OlmSessionFromJSONPickled(pickled, key []byte) (*OlmSession, error) { // OlmSessionFromPickled loads the OlmSession details from a pickled base64 string. The input is decrypted with the supplied key. func OlmSessionFromPickled(pickled, key []byte) (*OlmSession, error) { if len(pickled) == 0 { - return nil, errors.Wrap(goolm.ErrEmptyInput, "sessionFromPickled") + return nil, fmt.Errorf("sessionFromPickled: %w", goolm.ErrEmptyInput) } a := &OlmSession{} err := a.Unpickle(pickled, key) @@ -126,10 +127,10 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received oneTimeMsg := message.PreKeyMessage{} err = oneTimeMsg.Decode(decodedOTKMsg) if err != nil { - return nil, errors.Wrap(err, "OneTimeKeyMessage decode") + return nil, fmt.Errorf("OneTimeKeyMessage decode: %w", err) } if !oneTimeMsg.CheckFields(identityKeyAlice) { - return nil, errors.Wrap(goolm.ErrBadMessageFormat, "OneTimeKeyMessage check fields") + return nil, fmt.Errorf("OneTimeKeyMessage check fields: %w", goolm.ErrBadMessageFormat) } //Either the identityKeyAlice is set and/or the oneTimeMsg.IdentityKey is set, which is checked @@ -137,7 +138,7 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received if identityKeyAlice != nil && len(oneTimeMsg.IdentityKey) != 0 { //if both are set, compare them if !identityKeyAlice.Equal(oneTimeMsg.IdentityKey) { - return nil, errors.Wrap(goolm.ErrBadMessageKeyID, "OneTimeKeyMessage identity keys") + return nil, fmt.Errorf("OneTimeKeyMessage identity keys: %w", goolm.ErrBadMessageKeyID) } } if identityKeyAlice == nil { @@ -147,7 +148,7 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received oneTimeKeyBob := searchBobOTK(oneTimeMsg.OneTimeKey) if oneTimeKeyBob == nil { - return nil, errors.Wrap(goolm.ErrBadMessageKeyID, "ourOneTimeKey") + return nil, fmt.Errorf("ourOneTimeKey: %w", goolm.ErrBadMessageKeyID) } //Calculate shared secret via Triple Diffie-Hellman @@ -174,11 +175,11 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received msg := message.Message{} err = msg.Decode(oneTimeMsg.Message) if err != nil { - return nil, errors.Wrap(err, "Message decode") + return nil, fmt.Errorf("Message decode: %w", err) } if len(msg.RatchetKey) == 0 { - return nil, errors.Wrap(goolm.ErrBadMessageFormat, "Message missing ratchet key") + return nil, fmt.Errorf("Message missing ratchet key: %w", goolm.ErrBadMessageFormat) } //Init Ratchet s.Ratchet.InitialiseAsBob(secret, msg.RatchetKey) @@ -224,7 +225,7 @@ func (s OlmSession) HasReceivedMessage() bool { // matches. Returns false if the session does not match. func (s OlmSession) MatchesInboundSessionFrom(theirIdentityKeyEncoded *id.Curve25519, receivedOTKMsg []byte) (bool, error) { if len(receivedOTKMsg) == 0 { - return false, errors.Wrap(goolm.ErrEmptyInput, "inbound match") + return false, fmt.Errorf("inbound match: %w", goolm.ErrEmptyInput) } decodedOTKMsg, err := goolm.Base64Decode(receivedOTKMsg) if err != nil { @@ -233,7 +234,7 @@ func (s OlmSession) MatchesInboundSessionFrom(theirIdentityKeyEncoded *id.Curve2 var theirIdentityKey *crypto.Curve25519PublicKey if theirIdentityKeyEncoded != nil { - decodedKey, err := goolm.Base64Decode([]byte(*theirIdentityKeyEncoded)) + decodedKey, err := base64.RawStdEncoding.DecodeString(string(*theirIdentityKeyEncoded)) if err != nil { return false, err } @@ -275,7 +276,7 @@ func (s OlmSession) EncryptMsgType() id.OlmMsgType { // Encrypt encrypts a message using the Session. Returns the encrypted message base64 encoded. If reader is nil, crypto/rand is used for key generations. func (s *OlmSession) Encrypt(plaintext []byte, reader io.Reader) (id.OlmMsgType, []byte, error) { if len(plaintext) == 0 { - return 0, nil, errors.Wrap(goolm.ErrEmptyInput, "encrypt") + return 0, nil, fmt.Errorf("encrypt: %w", goolm.ErrEmptyInput) } messageType := s.EncryptMsgType() encrypted, err := s.Ratchet.Encrypt(plaintext, reader) @@ -305,7 +306,7 @@ func (s *OlmSession) Encrypt(plaintext []byte, reader io.Reader) (id.OlmMsgType, // Decrypt decrypts a base64 encoded message using the Session. func (s *OlmSession) Decrypt(crypttext []byte, msgType id.OlmMsgType) ([]byte, error) { if len(crypttext) == 0 { - return nil, errors.Wrap(goolm.ErrEmptyInput, "decrypt") + return nil, fmt.Errorf("decrypt: %w", goolm.ErrEmptyInput) } decodedCrypttext, err := goolm.Base64Decode(crypttext) if err != nil { @@ -354,7 +355,7 @@ func (o *OlmSession) UnpickleLibOlm(value []byte) (int, error) { case uint32(0x80000001): includesChainIndex = true default: - return 0, errors.Wrap(goolm.ErrBadVersion, "unpickle olmSession") + return 0, fmt.Errorf("unpickle olmSession: %w", goolm.ErrBadVersion) } var readBytes int o.RecievedMessage, readBytes, err = libolmpickle.UnpickleBool(value[curPos:]) @@ -406,28 +407,28 @@ func (o OlmSession) Pickle(key []byte) ([]byte, error) { // It returns the number of bytes written. func (o OlmSession) PickleLibOlm(target []byte) (int, error) { if len(target) < o.PickleLen() { - return 0, errors.Wrap(goolm.ErrValueTooShort, "pickle MegolmOutboundSession") + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(olmSessionPickleVersionLibOlm, target) written += libolmpickle.PickleBool(o.RecievedMessage, target[written:]) writtenRatchet, err := o.AliceIdKey.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle MegolmOutboundSession") + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) } written += writtenRatchet writtenRatchet, err = o.AliceBaseKey.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle MegolmOutboundSession") + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) } written += writtenRatchet writtenRatchet, err = o.BobOneTimeKey.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle MegolmOutboundSession") + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) } written += writtenRatchet writtenRatchet, err = o.Ratchet.PickleLibOlm(target[written:]) if err != nil { - return 0, errors.Wrap(err, "pickle MegolmOutboundSession") + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) } written += writtenRatchet return written, nil diff --git a/crypto/goolm/session/olmSession_test.go b/crypto/goolm/session/olm_session_test.go similarity index 82% rename from crypto/goolm/session/olmSession_test.go rename to crypto/goolm/session/olm_session_test.go index 3fbdb569..11b13c32 100644 --- a/crypto/goolm/session/olmSession_test.go +++ b/crypto/goolm/session/olm_session_test.go @@ -1,12 +1,14 @@ -package session +package session_test import ( "bytes" + "encoding/base64" "errors" "testing" - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/session" "maunium.net/go/mautrix/id" ) @@ -24,7 +26,7 @@ func TestOlmSession(t *testing.T) { if err != nil { t.Fatal(err) } - aliceSession, err := NewOutboundOlmSession(aliceKeyPair, bobKeyPair.PublicKey, bobOneTimeKey.PublicKey) + aliceSession, err := session.NewOutboundOlmSession(aliceKeyPair, bobKeyPair.PublicKey, bobOneTimeKey.PublicKey) if err != nil { t.Fatal(err) } @@ -49,7 +51,7 @@ func TestOlmSession(t *testing.T) { return nil } //bob receives message - bobSession, err := NewInboundOlmSession(nil, message, searchFunc, bobKeyPair) + bobSession, err := session.NewInboundOlmSession(nil, message, searchFunc, bobKeyPair) if err != nil { t.Fatal(err) } @@ -78,7 +80,7 @@ func TestOlmSession(t *testing.T) { } //Alice unpickles session - newAliceSession, err := OlmSessionFromJSONPickled(pickled, pickleKey) + newAliceSession, err := session.OlmSessionFromJSONPickled(pickled, pickleKey) if err != nil { t.Fatal(err) } @@ -120,11 +122,11 @@ func TestOlmSession(t *testing.T) { func TestSessionPickle(t *testing.T) { pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItVKR4ro0O9EAk6LLxJtSnRu5elSUk7YXT") pickleKey := []byte("secret_key") - session, err := OlmSessionFromPickled(pickledDataFromLibOlm, pickleKey) + sess, err := session.OlmSessionFromPickled(pickledDataFromLibOlm, pickleKey) if err != nil { t.Fatal(err) } - newPickled, err := session.Pickle(pickleKey) + newPickled, err := sess.Pickle(pickleKey) if err != nil { t.Fatal(err) } @@ -132,7 +134,7 @@ func TestSessionPickle(t *testing.T) { t.Fatal("pickled version does not equal libolm version") } pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...) - _, err = OlmSessionFromPickled(pickledDataFromLibOlm, pickleKey) + _, err = session.OlmSessionFromPickled(pickledDataFromLibOlm, pickleKey) if err == nil { t.Fatal("should have gotten an error") } @@ -147,9 +149,10 @@ func TestDecrypts(t *testing.T) { } expectedErr := []error{ goolm.ErrInputToSmall, - goolm.ErrBadBase64, - goolm.ErrBadBase64, - goolm.ErrBadBase64, + // Why are these being tested 🤔 + base64.CorruptInputError(0), + base64.CorruptInputError(0), + base64.CorruptInputError(0), } sessionPickled := []byte("E0p44KO2y2pzp9FIjv0rud2wIvWDi2dx367kP4Fz/9JCMrH+aG369HGymkFtk0+PINTLB9lQRt" + "ohea5d7G/UXQx3r5y4IWuyh1xaRnojEZQ9a5HRZSNtvmZ9NY1f1gutYa4UtcZcbvczN8b/5Bqg" + @@ -157,12 +160,12 @@ func TestDecrypts(t *testing.T) { "rfjLdzQrgjOTxN8Pf6iuP+WFPvfnR9lDmNCFxJUVAdLIMnLuAdxf1TGcS+zzCzEE8btIZ99mHF" + "dGvPXeH8qLeNZA") pickleKey := []byte("") - session, err := OlmSessionFromPickled(sessionPickled, pickleKey) + sess, err := session.OlmSessionFromPickled(sessionPickled, pickleKey) if err != nil { t.Fatal(err) } for curIndex, curMessage := range messages { - _, err := session.Decrypt(curMessage, id.OlmMsgTypePreKey) + _, err := sess.Decrypt(curMessage, id.OlmMsgTypePreKey) if err != nil { if !errors.Is(err, expectedErr[curIndex]) { t.Fatal(err) diff --git a/crypto/goolm/utilities/main.go b/crypto/goolm/utilities/main.go index 48ff37aa..c5b5c2d5 100644 --- a/crypto/goolm/utilities/main.go +++ b/crypto/goolm/utilities/main.go @@ -1,18 +1,16 @@ package utilities import ( - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/crypto" + "encoding/base64" + + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/id" ) -func Sha256(value []byte) []byte { - return goolm.Base64Encode(crypto.SHA256((value))) -} - // VerifySignature verifies an ed25519 signature. func VerifySignature(message []byte, key id.Ed25519, signature []byte) (ok bool, err error) { - keyDecoded, err := goolm.Base64Decode([]byte(key)) + keyDecoded, err := base64.RawStdEncoding.DecodeString(string(key)) if err != nil { return false, err } diff --git a/crypto/goolm/utilities/main_test.go b/crypto/goolm/utilities/main_test.go deleted file mode 100644 index 9019338b..00000000 --- a/crypto/goolm/utilities/main_test.go +++ /dev/null @@ -1,14 +0,0 @@ -package utilities - -import ( - "bytes" - "testing" -) - -func TestSHA256(t *testing.T) { - plainText := []byte("Hello, World") - expected := []byte("A2daxT/5zRU1zMffzfosRYxSGDcfQY3BNvLRmsH76KU") - if !bytes.Equal(Sha256(plainText), expected) { - t.Fatal("sha256 failed") - } -} diff --git a/crypto/goolm/utilities/pickle.go b/crypto/goolm/utilities/pickle.go index 6bfbcbc2..993366c8 100644 --- a/crypto/goolm/utilities/pickle.go +++ b/crypto/goolm/utilities/pickle.go @@ -2,20 +2,20 @@ package utilities import ( "encoding/json" + "fmt" - "codeberg.org/DerLukas/goolm" - "codeberg.org/DerLukas/goolm/cipher" - "github.com/pkg/errors" + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/cipher" ) // PickleAsJSON returns an object as a base64 string encrypted using the supplied key. The unencrypted representation of the object is in JSON format. func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) { if len(key) == 0 { - return nil, errors.Wrap(goolm.ErrNoKeyProvided, "pickle") + return nil, fmt.Errorf("pickle: %w", goolm.ErrNoKeyProvided) } marshaled, err := json.Marshal(object) if err != nil { - return nil, errors.Wrap(err, "pickle marshal") + return nil, fmt.Errorf("pickle marshal: %w", err) } marshaled = append([]byte{pickleVersion}, marshaled...) toEncrypt := make([]byte, len(marshaled)) @@ -28,7 +28,7 @@ func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) { } encrypted, err := cipher.Pickle(key, toEncrypt) if err != nil { - return nil, errors.Wrap(err, "pickle encrypt") + return nil, fmt.Errorf("pickle encrypt: %w", err) } return encrypted, nil } @@ -36,11 +36,11 @@ func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) { // UnpickleAsJSON updates the object by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error { if len(key) == 0 { - return errors.Wrap(goolm.ErrNoKeyProvided, "unpickle") + return fmt.Errorf("unpickle: %w", goolm.ErrNoKeyProvided) } decrypted, err := cipher.Unpickle(key, pickled) if err != nil { - return errors.Wrap(err, "unpickle decrypt") + return fmt.Errorf("unpickle decrypt: %w", err) } //unpad decrypted so unmarshal works for i := len(decrypted) - 1; i >= 0; i-- { @@ -50,11 +50,11 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error { } } if decrypted[0] != pickleVersion { - return errors.Wrap(goolm.ErrWrongPickleVersion, "unpickle") + return fmt.Errorf("unpickle: %w", goolm.ErrWrongPickleVersion) } err = json.Unmarshal(decrypted[1:], object) if err != nil { - return errors.Wrap(err, "unpickle unmarshal") + return fmt.Errorf("unpickle unmarshal: %w", err) } return nil } From 3e4cb751d0e495af47fdb5671deaed2bdccc3b5d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 15 Dec 2023 13:30:11 +0200 Subject: [PATCH 0030/1647] Remove special module licenses --- README.md | 5 - crypto/canonicaljson/LICENSE | 177 --------------------------------- crypto/canonicaljson/README.md | 2 + crypto/goolm/LICENSE | 9 -- crypto/olm/LICENSE | 177 --------------------------------- crypto/olm/README.md | 2 + 6 files changed, 4 insertions(+), 368 deletions(-) delete mode 100644 crypto/canonicaljson/LICENSE delete mode 100644 crypto/goolm/LICENSE delete mode 100644 crypto/olm/LICENSE diff --git a/README.md b/README.md index 04fdc0e9..d45860e7 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,3 @@ In addition to the basic client API features the original project has, this fram * Structs for parsing event content * Helpers for parsing and generating Matrix HTML * Helpers for handling push rules - -This project contains modules that are licensed under Apache 2.0: - -* [maunium.net/go/mautrix/crypto/canonicaljson](crypto/canonicaljson) -* [maunium.net/go/mautrix/crypto/olm](crypto/olm) diff --git a/crypto/canonicaljson/LICENSE b/crypto/canonicaljson/LICENSE deleted file mode 100644 index f433b1a5..00000000 --- a/crypto/canonicaljson/LICENSE +++ /dev/null @@ -1,177 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS diff --git a/crypto/canonicaljson/README.md b/crypto/canonicaljson/README.md index 7f5d3da7..da9d71ff 100644 --- a/crypto/canonicaljson/README.md +++ b/crypto/canonicaljson/README.md @@ -2,3 +2,5 @@ This is a Go package to produce Matrix [Canonical JSON](https://matrix.org/docs/spec/appendices#canonical-json). It is essentially just [json.go](https://github.com/matrix-org/gomatrixserverlib/blob/master/json.go) from gomatrixserverlib without all the other files that are completely useless for non-server use cases. + +The original project is licensed under the Apache 2.0 license. diff --git a/crypto/goolm/LICENSE b/crypto/goolm/LICENSE deleted file mode 100644 index d81feaab..00000000 --- a/crypto/goolm/LICENSE +++ /dev/null @@ -1,9 +0,0 @@ -MIT License - -Copyright (c) 2022 Lukas Gallandi - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/crypto/olm/LICENSE b/crypto/olm/LICENSE deleted file mode 100644 index f433b1a5..00000000 --- a/crypto/olm/LICENSE +++ /dev/null @@ -1,177 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS diff --git a/crypto/olm/README.md b/crypto/olm/README.md index 1b27f639..7d8086c0 100644 --- a/crypto/olm/README.md +++ b/crypto/olm/README.md @@ -1,2 +1,4 @@ # Go olm bindings Based on [Dhole/go-olm](https://github.com/Dhole/go-olm) + +The original project is licensed under the Apache 2.0 license. From 39efee17e0d6886d07f98b4060338103c4f2d21c Mon Sep 17 00:00:00 2001 From: Lukas Gallandi Date: Fri, 15 Dec 2023 13:35:25 +0200 Subject: [PATCH 0031/1647] Add build tag to use goolm in crypto/olm Merges/closes #106 --- crypto/keyexport.go | 4 +- crypto/olm/account.go | 24 +-- crypto/olm/account_goolm.go | 192 ++++++++++++++++++++++ crypto/olm/error.go | 12 +- crypto/olm/error_goolm.go | 23 +++ crypto/olm/inboundgroupsession.go | 20 +-- crypto/olm/inboundgroupsession_goolm.go | 194 +++++++++++++++++++++++ crypto/olm/olm.go | 2 + crypto/olm/olm_goolm.go | 23 +++ crypto/olm/outboundgroupsession.go | 14 +- crypto/olm/outboundgroupsession_goolm.go | 156 ++++++++++++++++++ crypto/olm/pk.go | 4 +- crypto/olm/pk_goolm.go | 71 +++++++++ crypto/olm/session.go | 22 +-- crypto/olm/session_goolm.go | 155 ++++++++++++++++++ crypto/olm/utility.go | 8 +- crypto/olm/utility_goolm.go | 92 +++++++++++ crypto/olm/verification.go | 4 +- crypto/olm/verification_goolm.go | 23 +++ 19 files changed, 995 insertions(+), 48 deletions(-) create mode 100644 crypto/olm/account_goolm.go create mode 100644 crypto/olm/error_goolm.go create mode 100644 crypto/olm/inboundgroupsession_goolm.go create mode 100644 crypto/olm/olm_goolm.go create mode 100644 crypto/olm/outboundgroupsession_goolm.go create mode 100644 crypto/olm/pk_goolm.go create mode 100644 crypto/olm/session_goolm.go create mode 100644 crypto/olm/utility_goolm.go create mode 100644 crypto/olm/verification_goolm.go diff --git a/crypto/keyexport.go b/crypto/keyexport.go index 91bfb6c6..d5a37702 100644 --- a/crypto/keyexport.go +++ b/crypto/keyexport.go @@ -69,7 +69,7 @@ func makeExportIV() []byte { iv := make([]byte, 16) _, err := rand.Read(iv) if err != nil { - panic(olm.NotEnoughGoRandom) + panic(olm.ErrNotEnoughGoRandom) } // Set bit 63 to zero iv[7] &= 0b11111110 @@ -80,7 +80,7 @@ func makeExportKeys(passphrase string) (encryptionKey, hashKey, salt, iv []byte) salt = make([]byte, 16) _, err := rand.Read(salt) if err != nil { - panic(olm.NotEnoughGoRandom) + panic(olm.ErrNotEnoughGoRandom) } encryptionKey, hashKey = computeKey(passphrase, salt, defaultPassphraseRounds) diff --git a/crypto/olm/account.go b/crypto/olm/account.go index c3d80263..d3298d6e 100644 --- a/crypto/olm/account.go +++ b/crypto/olm/account.go @@ -1,3 +1,5 @@ +//go:build !goolm + package olm // #cgo LDFLAGS: -lolm -lstdc++ @@ -30,7 +32,7 @@ type Account struct { // "INVALID_BASE64". func AccountFromPickled(pickled, key []byte) (*Account, error) { if len(pickled) == 0 { - return nil, EmptyInput + return nil, ErrEmptyInput } a := NewBlankAccount() return a, a.Unpickle(pickled, key) @@ -50,7 +52,7 @@ func NewAccount() *Account { random := make([]byte, a.createRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(NotEnoughGoRandom) + panic(ErrNotEnoughGoRandom) } r := C.olm_create_account( (*C.OlmAccount)(a.int), @@ -124,7 +126,7 @@ func (a *Account) genOneTimeKeysRandomLen(num uint) uint { // supplied key. func (a *Account) Pickle(key []byte) []byte { if len(key) == 0 { - panic(NoKeyProvided) + panic(ErrNoKeyProvided) } pickled := make([]byte, a.pickleLen()) r := C.olm_pickle_account( @@ -141,7 +143,7 @@ func (a *Account) Pickle(key []byte) []byte { func (a *Account) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return NoKeyProvided + return ErrNoKeyProvided } r := C.olm_unpickle_account( (*C.OlmAccount)(a.int), @@ -184,7 +186,7 @@ func (a *Account) MarshalJSON() ([]byte, error) { func (a *Account) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return InputNotJSONString + return ErrInputNotJSONString } if a.int == nil { *a = *NewBlankAccount() @@ -218,7 +220,7 @@ func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519) { // Account. func (a *Account) Sign(message []byte) []byte { if len(message) == 0 { - panic(EmptyInput) + panic(ErrEmptyInput) } signature := make([]byte, a.signatureLen()) r := C.olm_account_sign( @@ -296,7 +298,7 @@ func (a *Account) GenOneTimeKeys(num uint) { random := make([]byte, a.genOneTimeKeysRandomLen(num)+1) _, err := rand.Read(random) if err != nil { - panic(NotEnoughGoRandom) + panic(ErrNotEnoughGoRandom) } r := C.olm_account_generate_one_time_keys( (*C.OlmAccount)(a.int), @@ -313,13 +315,13 @@ func (a *Account) GenOneTimeKeys(num uint) { // keys couldn't be decoded as base64 then the error will be "INVALID_BASE64" func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*Session, error) { if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { - return nil, EmptyInput + return nil, ErrEmptyInput } s := NewBlankSession() random := make([]byte, s.createOutboundRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(NotEnoughGoRandom) + panic(ErrNotEnoughGoRandom) } r := C.olm_create_outbound_session( (*C.OlmSession)(s.int), @@ -345,7 +347,7 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2 // time key then the error will be "BAD_MESSAGE_KEY_ID". func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) { if len(oneTimeKeyMsg) == 0 { - return nil, EmptyInput + return nil, ErrEmptyInput } s := NewBlankSession() r := C.olm_create_inbound_session( @@ -368,7 +370,7 @@ func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) { // time key then the error will be "BAD_MESSAGE_KEY_ID". func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeKeyMsg string) (*Session, error) { if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { - return nil, EmptyInput + return nil, ErrEmptyInput } s := NewBlankSession() r := C.olm_create_inbound_session_from( diff --git a/crypto/olm/account_goolm.go b/crypto/olm/account_goolm.go new file mode 100644 index 00000000..ca448322 --- /dev/null +++ b/crypto/olm/account_goolm.go @@ -0,0 +1,192 @@ +//go:build goolm + +package olm + +import ( + "encoding/base64" + "encoding/json" + + "github.com/tidwall/sjson" + + "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/crypto/goolm/account" + "maunium.net/go/mautrix/id" +) + +// Account stores a device account for end to end encrypted messaging. +type Account struct { + account.Account +} + +// NewAccount creates a new Account. +func NewAccount() *Account { + a, err := account.NewAccount(nil) + if err != nil { + panic(err) + } + ac := &Account{} + ac.Account = *a + return ac +} + +func NewBlankAccount() *Account { + return &Account{} +} + +// Clear clears the memory used to back this Account. +func (a *Account) Clear() error { + a.Account = account.Account{} + return nil +} + +// Pickle returns an Account as a base64 string. Encrypts the Account using the +// supplied key. +func (a *Account) Pickle(key []byte) []byte { + if len(key) == 0 { + panic(ErrNoKeyProvided) + } + pickled, err := a.Account.Pickle(key) + if err != nil { + panic(err) + } + return pickled +} + +func (a *Account) GobEncode() ([]byte, error) { + pickled, err := a.Account.Pickle(pickleKey) + if err != nil { + return nil, err + } + length := base64.RawStdEncoding.DecodedLen(len(pickled)) + rawPickled := make([]byte, length) + _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) + return rawPickled, err +} + +func (a *Account) GobDecode(rawPickled []byte) error { + length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) + pickled := make([]byte, length) + base64.RawStdEncoding.Encode(pickled, rawPickled) + return a.Unpickle(pickled, pickleKey) +} + +func (a *Account) MarshalJSON() ([]byte, error) { + pickled, err := a.Account.Pickle(pickleKey) + if err != nil { + return nil, err + } + quotes := make([]byte, len(pickled)+2) + quotes[0] = '"' + quotes[len(quotes)-1] = '"' + copy(quotes[1:len(quotes)-1], pickled) + return quotes, nil +} + +func (a *Account) UnmarshalJSON(data []byte) error { + if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { + return ErrInputNotJSONString + } + return a.Unpickle(data[1:len(data)-1], pickleKey) +} + +// IdentityKeysJSON returns the public parts of the identity keys for the Account. +func (a *Account) IdentityKeysJSON() []byte { + identityKeys, err := a.Account.IdentityKeysJSON() + if err != nil { + panic(err) + } + return identityKeys +} + +// Sign returns the signature of a message using the ed25519 key for this +// Account. +func (a *Account) Sign(message []byte) []byte { + if len(message) == 0 { + panic(ErrEmptyInput) + } + signature, err := a.Account.Sign(message) + if err != nil { + panic(err) + } + return signature +} + +// SignJSON signs the given JSON object following the Matrix specification: +// https://matrix.org/docs/spec/appendices#signing-json +func (a *Account) SignJSON(obj interface{}) (string, error) { + objJSON, err := json.Marshal(obj) + if err != nil { + return "", err + } + objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") + objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") + return string(a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))), nil +} + +// MaxNumberOfOneTimeKeys returns the largest number of one time keys this +// Account can store. +func (a *Account) MaxNumberOfOneTimeKeys() uint { + return uint(account.MaxOneTimeKeys) +} + +// GenOneTimeKeys generates a number of new one time keys. If the total number +// of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old +// keys are discarded. +func (a *Account) GenOneTimeKeys(num uint) { + err := a.Account.GenOneTimeKeys(nil, num) + if err != nil { + panic(err) + } +} + +// NewOutboundSession creates a new out-bound session for sending messages to a +// given curve25519 identityKey and oneTimeKey. Returns error on failure. +func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*Session, error) { + if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { + return nil, ErrEmptyInput + } + s := &Session{} + newSession, err := a.Account.NewOutboundSession(theirIdentityKey, theirOneTimeKey) + if err != nil { + return nil, err + } + s.OlmSession = *newSession + return s, nil +} + +// NewInboundSession creates a new in-bound session for sending/receiving +// messages from an incoming PRE_KEY message. Returns error on failure. +func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) { + if len(oneTimeKeyMsg) == 0 { + return nil, ErrEmptyInput + } + s := &Session{} + newSession, err := a.Account.NewInboundSession(nil, []byte(oneTimeKeyMsg)) + if err != nil { + return nil, err + } + s.OlmSession = *newSession + return s, nil +} + +// NewInboundSessionFrom creates a new in-bound session for sending/receiving +// messages from an incoming PRE_KEY message. Returns error on failure. +func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeKeyMsg string) (*Session, error) { + if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { + return nil, ErrEmptyInput + } + s := &Session{} + newSession, err := a.Account.NewInboundSession(&theirIdentityKey, []byte(oneTimeKeyMsg)) + if err != nil { + return nil, err + } + s.OlmSession = *newSession + return s, nil +} + +// RemoveOneTimeKeys removes the one time keys that the session used from the +// Account. Returns error on failure. +func (a *Account) RemoveOneTimeKeys(s *Session) error { + a.Account.RemoveOneTimeKeys(&s.OlmSession) + return nil +} diff --git a/crypto/olm/error.go b/crypto/olm/error.go index 70a32c7d..b15af138 100644 --- a/crypto/olm/error.go +++ b/crypto/olm/error.go @@ -1,3 +1,5 @@ +//go:build !goolm + package olm import ( @@ -7,11 +9,11 @@ import ( // Error codes from go-olm var ( - EmptyInput = errors.New("empty input") - NoKeyProvided = errors.New("no pickle key provided") - NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") - SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") - InputNotJSONString = errors.New("input doesn't look like a JSON string") + ErrEmptyInput = errors.New("empty input") + ErrNoKeyProvided = errors.New("no pickle key provided") + ErrNotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") + ErrSignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") + ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") ) // Error codes from olm code diff --git a/crypto/olm/error_goolm.go b/crypto/olm/error_goolm.go new file mode 100644 index 00000000..e4e2cf70 --- /dev/null +++ b/crypto/olm/error_goolm.go @@ -0,0 +1,23 @@ +//go:build goolm + +package olm + +import ( + "errors" + + "maunium.net/go/mautrix/crypto/goolm" +) + +// Error codes from go-olm +var ( + ErrEmptyInput = errors.New("empty input") + ErrNoKeyProvided = errors.New("no pickle key provided") + ErrNotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") + ErrSignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") + ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") +) + +// Error codes from olm code +var ( + UnknownMessageIndex = goolm.ErrRatchetNotAvailable +) diff --git a/crypto/olm/inboundgroupsession.go b/crypto/olm/inboundgroupsession.go index 93e54ff4..14d2e226 100644 --- a/crypto/olm/inboundgroupsession.go +++ b/crypto/olm/inboundgroupsession.go @@ -1,3 +1,5 @@ +//go:build !goolm + package olm // #cgo LDFLAGS: -lolm -lstdc++ @@ -25,7 +27,7 @@ type InboundGroupSession struct { // base64 couldn't be decoded then the error will be "INVALID_BASE64". func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { if len(pickled) == 0 { - return nil, EmptyInput + return nil, ErrEmptyInput } lenKey := len(key) if lenKey == 0 { @@ -42,7 +44,7 @@ func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, // "OLM_BAD_SESSION_KEY". func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, EmptyInput + return nil, ErrEmptyInput } s := NewBlankInboundGroupSession() r := C.olm_init_inbound_group_session( @@ -61,7 +63,7 @@ func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { // error will be "OLM_BAD_SESSION_KEY". func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, EmptyInput + return nil, ErrEmptyInput } s := NewBlankInboundGroupSession() r := C.olm_import_inbound_group_session( @@ -114,7 +116,7 @@ func (s *InboundGroupSession) pickleLen() uint { // InboundGroupSession using the supplied key. func (s *InboundGroupSession) Pickle(key []byte) []byte { if len(key) == 0 { - panic(NoKeyProvided) + panic(ErrNoKeyProvided) } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_inbound_group_session( @@ -131,9 +133,9 @@ func (s *InboundGroupSession) Pickle(key []byte) []byte { func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return NoKeyProvided + return ErrNoKeyProvided } else if len(pickled) == 0 { - return EmptyInput + return ErrEmptyInput } r := C.olm_unpickle_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), @@ -176,7 +178,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return InputNotJSONString + return ErrInputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankInboundGroupSession() @@ -199,7 +201,7 @@ func clone(original []byte) []byte { // will be "BAD_MESSAGE_FORMAT". func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) { if len(message) == 0 { - return 0, EmptyInput + return 0, ErrEmptyInput } // olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it message = clone(message) @@ -224,7 +226,7 @@ func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, erro // was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX". func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { if len(message) == 0 { - return nil, 0, EmptyInput + return nil, 0, ErrEmptyInput } decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message) if err != nil { diff --git a/crypto/olm/inboundgroupsession_goolm.go b/crypto/olm/inboundgroupsession_goolm.go new file mode 100644 index 00000000..3ece360b --- /dev/null +++ b/crypto/olm/inboundgroupsession_goolm.go @@ -0,0 +1,194 @@ +//go:build goolm + +package olm + +import ( + "encoding/base64" + + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/id" +) + +// InboundGroupSession stores an inbound encrypted messaging session for a +// group. +type InboundGroupSession struct { + session.MegolmInboundSession +} + +// InboundGroupSessionFromPickled loads an InboundGroupSession from a pickled +// base64 string. Decrypts the InboundGroupSession using the supplied key. +// Returns error on failure. +func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { + if len(pickled) == 0 { + return nil, ErrEmptyInput + } + lenKey := len(key) + if lenKey == 0 { + key = []byte(" ") + } + megolmSession, err := session.MegolmInboundSessionFromPickled(pickled, key) + if err != nil { + return nil, err + } + return &InboundGroupSession{ + MegolmInboundSession: *megolmSession, + }, nil +} + +// NewInboundGroupSession creates a new inbound group session from a key +// exported from OutboundGroupSession.Key(). Returns error on failure. +func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { + if len(sessionKey) == 0 { + return nil, ErrEmptyInput + } + megolmSession, err := session.NewMegolmInboundSession(sessionKey) + if err != nil { + return nil, err + } + return &InboundGroupSession{ + MegolmInboundSession: *megolmSession, + }, nil +} + +// InboundGroupSessionImport imports an inbound group session from a previous +// export. Returns error on failure. +func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { + if len(sessionKey) == 0 { + return nil, ErrEmptyInput + } + megolmSession, err := session.NewMegolmInboundSessionFromExport(sessionKey) + if err != nil { + return nil, err + } + return &InboundGroupSession{ + MegolmInboundSession: *megolmSession, + }, nil +} + +func NewBlankInboundGroupSession() *InboundGroupSession { + return &InboundGroupSession{} +} + +// Clear clears the memory used to back this InboundGroupSession. +func (s *InboundGroupSession) Clear() error { + s.MegolmInboundSession = session.MegolmInboundSession{} + return nil +} + +// Pickle returns an InboundGroupSession as a base64 string. Encrypts the +// InboundGroupSession using the supplied key. +func (s *InboundGroupSession) Pickle(key []byte) []byte { + if len(key) == 0 { + panic(ErrNoKeyProvided) + } + pickled, err := s.MegolmInboundSession.Pickle(key) + if err != nil { + panic(err) + } + return pickled +} + +func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return ErrNoKeyProvided + } else if len(pickled) == 0 { + return ErrEmptyInput + } + sOlm, err := session.MegolmInboundSessionFromPickled(pickled, key) + if err != nil { + return err + } + s.MegolmInboundSession = *sOlm + return nil +} + +func (s *InboundGroupSession) GobEncode() ([]byte, error) { + pickled, err := s.MegolmInboundSession.Pickle(pickleKey) + if err != nil { + return nil, err + } + length := base64.RawStdEncoding.DecodedLen(len(pickled)) + rawPickled := make([]byte, length) + _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) + return rawPickled, err +} + +func (s *InboundGroupSession) GobDecode(rawPickled []byte) error { + if s == nil { + *s = *NewBlankInboundGroupSession() + } + length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) + pickled := make([]byte, length) + base64.RawStdEncoding.Encode(pickled, rawPickled) + return s.Unpickle(pickled, pickleKey) +} + +func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { + pickled, err := s.MegolmInboundSession.Pickle(pickleKey) + if err != nil { + return nil, err + } + quotes := make([]byte, len(pickled)+2) + quotes[0] = '"' + quotes[len(quotes)-1] = '"' + copy(quotes[1:len(quotes)-1], pickled) + return quotes, nil +} + +func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { + if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { + return ErrInputNotJSONString + } + if s == nil { + *s = *NewBlankInboundGroupSession() + } + return s.Unpickle(data[1:len(data)-1], pickleKey) +} + +// Decrypt decrypts a message using the InboundGroupSession. Returns the the +// plain-text and message index on success. Returns error on failure. +func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { + if len(message) == 0 { + return nil, 0, ErrEmptyInput + } + plaintext, messageIndex, err := s.MegolmInboundSession.Decrypt(message) + if err != nil { + return nil, 0, err + } + return plaintext, uint(messageIndex), nil +} + +// ID returns a base64-encoded identifier for this session. +func (s *InboundGroupSession) ID() id.SessionID { + return s.MegolmInboundSession.SessionID() +} + +// FirstKnownIndex returns the first message index we know how to decrypt. +func (s *InboundGroupSession) FirstKnownIndex() uint32 { + return s.MegolmInboundSession.InitialRatchet.Counter +} + +// IsVerified check if the session has been verified as a valid session. (A +// session is verified either because the original session share was signed, or +// because we have subsequently successfully decrypted a message.) +func (s *InboundGroupSession) IsVerified() uint { + if s.MegolmInboundSession.SigningKeyVerified { + return 1 + } + return 0 +} + +// Export returns the base64-encoded ratchet key for this session, at the given +// index, in a format which can be used by +// InboundGroupSession.InboundGroupSessionImport(). Encrypts the +// InboundGroupSession using the supplied key. Returns error on failure. +// if we do not have a session key corresponding to the given index (ie, it was +// sent before the session key was shared with us) the error will be +// returned. +func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) { + res, err := s.MegolmInboundSession.SessionExportMessage(messageIndex) + if err != nil { + return nil, err + } + return res, nil +} diff --git a/crypto/olm/olm.go b/crypto/olm/olm.go index feb46e5d..685e1b6b 100644 --- a/crypto/olm/olm.go +++ b/crypto/olm/olm.go @@ -1,3 +1,5 @@ +//go:build !goolm + package olm // #cgo LDFLAGS: -lolm -lstdc++ diff --git a/crypto/olm/olm_goolm.go b/crypto/olm/olm_goolm.go new file mode 100644 index 00000000..9acd1d8d --- /dev/null +++ b/crypto/olm/olm_goolm.go @@ -0,0 +1,23 @@ +//go:build goolm + +package olm + +import ( + "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/id" +) + +// Signatures is the data structure used to sign JSON objects. +type Signatures map[id.UserID]map[id.DeviceKeyID]string + +// Version returns the version number of the olm library. +func Version() (major, minor, patch uint8) { + return goolm.GetLibaryVersion() +} + +var pickleKey = []byte("maunium.net/go/mautrix/crypto/olm") + +// SetPickleKey sets the global pickle key used when encoding structs with Gob or JSON. +func SetPickleKey(key []byte) { + pickleKey = key +} diff --git a/crypto/olm/outboundgroupsession.go b/crypto/olm/outboundgroupsession.go index 5f7da05b..98024cee 100644 --- a/crypto/olm/outboundgroupsession.go +++ b/crypto/olm/outboundgroupsession.go @@ -1,3 +1,5 @@ +//go:build !goolm + package olm // #cgo LDFLAGS: -lolm -lstdc++ @@ -26,7 +28,7 @@ type OutboundGroupSession struct { // base64 couldn't be decoded then the error will be "INVALID_BASE64". func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession, error) { if len(pickled) == 0 { - return nil, EmptyInput + return nil, ErrEmptyInput } s := NewBlankOutboundGroupSession() return s, s.Unpickle(pickled, key) @@ -38,7 +40,7 @@ func NewOutboundGroupSession() *OutboundGroupSession { random := make([]byte, s.createRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(NotEnoughGoRandom) + panic(ErrNotEnoughGoRandom) } r := C.olm_init_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), @@ -91,7 +93,7 @@ func (s *OutboundGroupSession) pickleLen() uint { // OutboundGroupSession using the supplied key. func (s *OutboundGroupSession) Pickle(key []byte) []byte { if len(key) == 0 { - panic(NoKeyProvided) + panic(ErrNoKeyProvided) } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_outbound_group_session( @@ -108,7 +110,7 @@ func (s *OutboundGroupSession) Pickle(key []byte) []byte { func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return NoKeyProvided + return ErrNoKeyProvided } r := C.olm_unpickle_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), @@ -151,7 +153,7 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return InputNotJSONString + return ErrInputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankOutboundGroupSession() @@ -175,7 +177,7 @@ func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint { // as base64. func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte { if len(plaintext) == 0 { - panic(EmptyInput) + panic(ErrEmptyInput) } message := make([]byte, s.encryptMsgLen(len(plaintext))) r := C.olm_group_encrypt( diff --git a/crypto/olm/outboundgroupsession_goolm.go b/crypto/olm/outboundgroupsession_goolm.go new file mode 100644 index 00000000..46c520b4 --- /dev/null +++ b/crypto/olm/outboundgroupsession_goolm.go @@ -0,0 +1,156 @@ +//go:build goolm + +package olm + +import ( + "encoding/base64" + + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/id" +) + +// OutboundGroupSession stores an outbound encrypted messaging session for a +// group. +type OutboundGroupSession struct { + session.MegolmOutboundSession +} + +// OutboundGroupSessionFromPickled loads an OutboundGroupSession from a pickled +// base64 string. Decrypts the OutboundGroupSession using the supplied key. +// Returns error on failure. If the key doesn't match the one used to encrypt +// the OutboundGroupSession then the error will be "BAD_SESSION_KEY". If the +// base64 couldn't be decoded then the error will be "INVALID_BASE64". +func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession, error) { + if len(pickled) == 0 { + return nil, ErrEmptyInput + } + lenKey := len(key) + if lenKey == 0 { + key = []byte(" ") + } + megolmSession, err := session.MegolmOutboundSessionFromPickled(pickled, key) + if err != nil { + return nil, err + } + return &OutboundGroupSession{ + MegolmOutboundSession: *megolmSession, + }, nil +} + +// NewOutboundGroupSession creates a new outbound group session. +func NewOutboundGroupSession() *OutboundGroupSession { + megolmSession, err := session.NewMegolmOutboundSession() + if err != nil { + panic(err) + } + return &OutboundGroupSession{ + MegolmOutboundSession: *megolmSession, + } +} + +// newOutboundGroupSession initialises an empty OutboundGroupSession. +func NewBlankOutboundGroupSession() *OutboundGroupSession { + return &OutboundGroupSession{} +} + +// Clear clears the memory used to back this OutboundGroupSession. +func (s *OutboundGroupSession) Clear() error { + s.MegolmOutboundSession = session.MegolmOutboundSession{} + return nil +} + +// Pickle returns an OutboundGroupSession as a base64 string. Encrypts the +// OutboundGroupSession using the supplied key. +func (s *OutboundGroupSession) Pickle(key []byte) []byte { + if len(key) == 0 { + panic(ErrNoKeyProvided) + } + pickled, err := s.MegolmOutboundSession.Pickle(key) + if err != nil { + panic(err) + } + return pickled +} + +func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return ErrNoKeyProvided + } + return s.MegolmOutboundSession.Unpickle(pickled, key) +} + +func (s *OutboundGroupSession) GobEncode() ([]byte, error) { + pickled, err := s.MegolmOutboundSession.Pickle(pickleKey) + if err != nil { + return nil, err + } + length := base64.RawStdEncoding.DecodedLen(len(pickled)) + rawPickled := make([]byte, length) + _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) + return rawPickled, err +} + +func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error { + if s == nil { + *s = *NewBlankOutboundGroupSession() + } + length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) + pickled := make([]byte, length) + base64.RawStdEncoding.Encode(pickled, rawPickled) + return s.Unpickle(pickled, pickleKey) +} + +func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { + pickled, err := s.MegolmOutboundSession.Pickle(pickleKey) + if err != nil { + return nil, err + } + quotes := make([]byte, len(pickled)+2) + quotes[0] = '"' + quotes[len(quotes)-1] = '"' + copy(quotes[1:len(quotes)-1], pickled) + return quotes, nil +} + +func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { + if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { + return ErrInputNotJSONString + } + if s == nil { + *s = *NewBlankOutboundGroupSession() + } + return s.Unpickle(data[1:len(data)-1], pickleKey) +} + +// Encrypt encrypts a message using the Session. Returns the encrypted message +// as base64. +func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte { + if len(plaintext) == 0 { + panic(ErrEmptyInput) + } + message, err := s.MegolmOutboundSession.Encrypt(plaintext) + if err != nil { + panic(err) + } + return message +} + +// ID returns a base64-encoded identifier for this session. +func (s *OutboundGroupSession) ID() id.SessionID { + return s.MegolmOutboundSession.SessionID() +} + +// MessageIndex returns the message index for this session. Each message is +// sent with an increasing index; this returns the index for the next message. +func (s *OutboundGroupSession) MessageIndex() uint { + return uint(s.MegolmOutboundSession.Ratchet.Counter) +} + +// Key returns the base64-encoded current ratchet key for this session. +func (s *OutboundGroupSession) Key() string { + message, err := s.MegolmOutboundSession.SessionSharingMessage() + if err != nil { + panic(err) + } + return string(message) +} diff --git a/crypto/olm/pk.go b/crypto/olm/pk.go index e441ba14..6d6d3c16 100644 --- a/crypto/olm/pk.go +++ b/crypto/olm/pk.go @@ -1,3 +1,5 @@ +//go:build !goolm + package olm // #cgo LDFLAGS: -lolm -lstdc++ @@ -74,7 +76,7 @@ func NewPkSigning() (*PkSigning, error) { seed := make([]byte, pkSigningSeedLength()) _, err := rand.Read(seed) if err != nil { - panic(NotEnoughGoRandom) + panic(ErrNotEnoughGoRandom) } pk, err := NewPkSigningFromSeed(seed) return pk, err diff --git a/crypto/olm/pk_goolm.go b/crypto/olm/pk_goolm.go new file mode 100644 index 00000000..9659e918 --- /dev/null +++ b/crypto/olm/pk_goolm.go @@ -0,0 +1,71 @@ +//go:build goolm + +package olm + +import ( + "encoding/json" + + "github.com/tidwall/sjson" + + "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/crypto/goolm/pk" + "maunium.net/go/mautrix/id" +) + +// PkSigning stores a key pair for signing messages. +type PkSigning struct { + pk.Signing + PublicKey id.Ed25519 + Seed []byte +} + +// Clear clears the underlying memory of a PkSigning object. +func (p *PkSigning) Clear() { + p.Signing = pk.Signing{} +} + +// NewPkSigningFromSeed creates a new PkSigning object using the given seed. +func NewPkSigningFromSeed(seed []byte) (*PkSigning, error) { + p := &PkSigning{} + signing, err := pk.NewSigningFromSeed(seed) + if err != nil { + return nil, err + } + p.Signing = *signing + p.Seed = seed + p.PublicKey = p.Signing.PublicKey() + return p, nil +} + +// NewPkSigning creates a new PkSigning object, containing a key pair for signing messages. +func NewPkSigning() (*PkSigning, error) { + p := &PkSigning{} + signing, err := pk.NewSigning() + if err != nil { + return nil, err + } + p.Signing = *signing + p.Seed = signing.Seed + p.PublicKey = p.Signing.PublicKey() + return p, err +} + +// Sign creates a signature for the given message using this key. +func (p *PkSigning) Sign(message []byte) ([]byte, error) { + return p.Signing.Sign(message), nil +} + +// SignJSON creates a signature for the given object after encoding it to canonical JSON. +func (p *PkSigning) SignJSON(obj interface{}) (string, error) { + objJSON, err := json.Marshal(obj) + if err != nil { + return "", err + } + objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") + objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") + signature, err := p.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON)) + if err != nil { + return "", err + } + return string(signature), nil +} diff --git a/crypto/olm/session.go b/crypto/olm/session.go index 625a16c3..e2920682 100644 --- a/crypto/olm/session.go +++ b/crypto/olm/session.go @@ -1,3 +1,5 @@ +//go:build !goolm + package olm // #cgo LDFLAGS: -lolm -lstdc++ @@ -40,7 +42,7 @@ func sessionSize() uint { // "INVALID_BASE64". func SessionFromPickled(pickled, key []byte) (*Session, error) { if len(pickled) == 0 { - return nil, EmptyInput + return nil, ErrEmptyInput } s := NewBlankSession() return s, s.Unpickle(pickled, key) @@ -107,7 +109,7 @@ func (s *Session) encryptMsgLen(plainTextLen int) uint { // will be "BAD_MESSAGE_FORMAT". func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) { if len(message) == 0 { - return 0, EmptyInput + return 0, ErrEmptyInput } r := C.olm_decrypt_max_plaintext_length( (*C.OlmSession)(s.int), @@ -124,7 +126,7 @@ func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) // supplied key. func (s *Session) Pickle(key []byte) []byte { if len(key) == 0 { - panic(NoKeyProvided) + panic(ErrNoKeyProvided) } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_session( @@ -141,7 +143,7 @@ func (s *Session) Pickle(key []byte) []byte { func (s *Session) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return NoKeyProvided + return ErrNoKeyProvided } r := C.olm_unpickle_session( (*C.OlmSession)(s.int), @@ -184,7 +186,7 @@ func (s *Session) MarshalJSON() ([]byte, error) { func (s *Session) UnmarshalJSON(data []byte) error { if len(data) == 0 || len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return InputNotJSONString + return ErrInputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankSession() @@ -226,7 +228,7 @@ func (s *Session) HasReceivedMessage() bool { // decoded then then the error will be "BAD_MESSAGE_FORMAT". func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { if len(oneTimeKeyMsg) == 0 { - return false, EmptyInput + return false, ErrEmptyInput } r := C.olm_matches_inbound_session( (*C.OlmSession)(s.int), @@ -251,7 +253,7 @@ func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { // decoded then then the error will be "BAD_MESSAGE_FORMAT". func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { - return false, EmptyInput + return false, ErrEmptyInput } r := C.olm_matches_inbound_session_from( (*C.OlmSession)(s.int), @@ -287,13 +289,13 @@ func (s *Session) EncryptMsgType() id.OlmMsgType { // as base64. func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { if len(plaintext) == 0 { - panic(EmptyInput) + panic(ErrEmptyInput) } // Make the slice be at least length 1 random := make([]byte, s.encryptRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(NotEnoughGoRandom) + panic(ErrNotEnoughGoRandom) } messageType := s.EncryptMsgType() message := make([]byte, s.encryptMsgLen(len(plaintext))) @@ -320,7 +322,7 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { // "BAD_MESSAGE_MAC". func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { if len(message) == 0 { - return nil, EmptyInput + return nil, ErrEmptyInput } decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType) if err != nil { diff --git a/crypto/olm/session_goolm.go b/crypto/olm/session_goolm.go new file mode 100644 index 00000000..5291cef4 --- /dev/null +++ b/crypto/olm/session_goolm.go @@ -0,0 +1,155 @@ +//go:build goolm + +package olm + +import ( + "encoding/base64" + + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/id" +) + +// Session stores an end to end encrypted messaging session. +type Session struct { + session.OlmSession +} + +// SessionFromPickled loads a Session from a pickled base64 string. Decrypts +// the Session using the supplied key. Returns error on failure. +func SessionFromPickled(pickled, key []byte) (*Session, error) { + if len(pickled) == 0 { + return nil, ErrEmptyInput + } + s := NewBlankSession() + return s, s.Unpickle(pickled, key) +} + +func NewBlankSession() *Session { + return &Session{} +} + +// Clear clears the memory used to back this Session. +func (s *Session) Clear() error { + s.OlmSession = session.OlmSession{} + return nil +} + +// Pickle returns a Session as a base64 string. Encrypts the Session using the +// supplied key. +func (s *Session) Pickle(key []byte) []byte { + if len(key) == 0 { + panic(ErrNoKeyProvided) + } + pickled, err := s.OlmSession.Pickle(key) + if err != nil { + panic(err) + } + return pickled +} + +func (s *Session) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return ErrNoKeyProvided + } else if len(pickled) == 0 { + return ErrEmptyInput + } + sOlm, err := session.OlmSessionFromPickled(pickled, key) + if err != nil { + return err + } + s.OlmSession = *sOlm + return nil +} + +func (s *Session) GobEncode() ([]byte, error) { + pickled, err := s.OlmSession.Pickle(pickleKey) + if err != nil { + return nil, err + } + length := base64.RawStdEncoding.DecodedLen(len(pickled)) + rawPickled := make([]byte, length) + _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) + return rawPickled, err +} + +func (s *Session) GobDecode(rawPickled []byte) error { + if s == nil { + *s = *NewBlankSession() + } + length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) + pickled := make([]byte, length) + base64.RawStdEncoding.Encode(pickled, rawPickled) + return s.Unpickle(pickled, pickleKey) +} + +func (s *Session) MarshalJSON() ([]byte, error) { + pickled, err := s.OlmSession.Pickle(pickleKey) + if err != nil { + return nil, err + } + quotes := make([]byte, len(pickled)+2) + quotes[0] = '"' + quotes[len(quotes)-1] = '"' + copy(quotes[1:len(quotes)-1], pickled) + return quotes, nil +} + +func (s *Session) UnmarshalJSON(data []byte) error { + if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { + return ErrInputNotJSONString + } + if s == nil { + *s = *NewBlankSession() + } + return s.Unpickle(data[1:len(data)-1], pickleKey) +} + +// MatchesInboundSession checks if the PRE_KEY message is for this in-bound +// Session. This can happen if multiple messages are sent to this Account +// before this Account sends a message in reply. Returns true if the session +// matches. Returns false if the session does not match. Returns error on +// failure. +func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { + return s.MatchesInboundSessionFrom("", oneTimeKeyMsg) +} + +// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound +// Session. This can happen if multiple messages are sent to this Account +// before this Account sends a message in reply. Returns true if the session +// matches. Returns false if the session does not match. Returns error on +// failure. +func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { + if theirIdentityKey != "" { + theirKey := id.Curve25519(theirIdentityKey) + return s.OlmSession.MatchesInboundSessionFrom(&theirKey, []byte(oneTimeKeyMsg)) + } + return s.OlmSession.MatchesInboundSessionFrom(nil, []byte(oneTimeKeyMsg)) + +} + +// Encrypt encrypts a message using the Session. Returns the encrypted message +// as base64. +func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { + if len(plaintext) == 0 { + panic(ErrEmptyInput) + } + messageType, message, err := s.OlmSession.Encrypt(plaintext, nil) + if err != nil { + panic(err) + } + return messageType, message +} + +// Decrypt decrypts a message using the Session. Returns the the plain-text on +// success. Returns error on failure. +func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { + if len(message) == 0 { + return nil, ErrEmptyInput + } + return s.OlmSession.Decrypt([]byte(message), msgType) +} + +// Describe generates a string describing the internal state of an olm session for debugging and logging purposes. +func (s *Session) Describe() string { + return s.OlmSession.Describe() +} diff --git a/crypto/olm/utility.go b/crypto/olm/utility.go index 8e868deb..8d544dbb 100644 --- a/crypto/olm/utility.go +++ b/crypto/olm/utility.go @@ -1,3 +1,5 @@ +//go:build !goolm + package olm // #cgo LDFLAGS: -lolm -lstdc++ @@ -61,7 +63,7 @@ func NewUtility() *Utility { // Sha256 calculates the SHA-256 hash of the input and encodes it as base64. func (u *Utility) Sha256(input string) string { if len(input) == 0 { - panic(EmptyInput) + panic(ErrEmptyInput) } output := make([]byte, u.sha256Len()) r := C.olm_sha256( @@ -81,7 +83,7 @@ func (u *Utility) Sha256(input string) string { // small then the error will be "INVALID_BASE64". func (u *Utility) VerifySignature(message string, key id.Ed25519, signature string) (ok bool, err error) { if len(message) == 0 || len(key) == 0 || len(signature) == 0 { - return false, EmptyInput + return false, ErrEmptyInput } r := C.olm_ed25519_verify( (*C.OlmUtility)(u.int), @@ -117,7 +119,7 @@ func (u *Utility) VerifySignatureJSON(obj interface{}, userID id.UserID, keyName } sig := gjson.GetBytes(objJSON, exgjson.Path("signatures", string(userID), fmt.Sprintf("ed25519:%s", keyName))) if !sig.Exists() || sig.Type != gjson.String { - return false, SignatureNotFound + return false, ErrSignatureNotFound } objJSON, err = sjson.DeleteBytes(objJSON, "unsigned") if err != nil { diff --git a/crypto/olm/utility_goolm.go b/crypto/olm/utility_goolm.go new file mode 100644 index 00000000..31299e63 --- /dev/null +++ b/crypto/olm/utility_goolm.go @@ -0,0 +1,92 @@ +//go:build goolm + +package olm + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.mau.fi/util/exgjson" + + "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/crypto/goolm/utilities" + "maunium.net/go/mautrix/id" +) + +// Utility stores the necessary state to perform hash and signature +// verification operations. +type Utility struct{} + +// Clear clears the memory used to back this utility. +func (u *Utility) Clear() error { + return nil +} + +// NewUtility creates a new utility. +func NewUtility() *Utility { + return &Utility{} +} + +// Sha256 calculates the SHA-256 hash of the input and encodes it as base64. +func (u *Utility) Sha256(input string) string { + if len(input) == 0 { + panic(ErrEmptyInput) + } + hash := sha256.Sum256([]byte(input)) + return base64.RawStdEncoding.EncodeToString(hash[:]) +} + +// VerifySignature verifies an ed25519 signature. Returns true if the verification +// suceeds or false otherwise. Returns error on failure. If the key was too +// small then the error will be "INVALID_BASE64". +func (u *Utility) VerifySignature(message string, key id.Ed25519, signature string) (ok bool, err error) { + if len(message) == 0 || len(key) == 0 || len(signature) == 0 { + return false, ErrEmptyInput + } + return utilities.VerifySignature([]byte(message), key, []byte(signature)) +} + +// VerifySignatureJSON verifies the signature in the JSON object _obj following +// the Matrix specification: +// https://matrix.org/speculator/spec/drafts%2Fe2e/appendices.html#signing-json +// If the _obj is a struct, the `json` tags will be honored. +func (u *Utility) VerifySignatureJSON(obj interface{}, userID id.UserID, keyName string, key id.Ed25519) (bool, error) { + var err error + objJSON, ok := obj.(json.RawMessage) + if !ok { + objJSON, err = json.Marshal(obj) + if err != nil { + return false, err + } + } + sig := gjson.GetBytes(objJSON, exgjson.Path("signatures", string(userID), fmt.Sprintf("ed25519:%s", keyName))) + if !sig.Exists() || sig.Type != gjson.String { + return false, ErrSignatureNotFound + } + objJSON, err = sjson.DeleteBytes(objJSON, "unsigned") + if err != nil { + return false, err + } + objJSON, err = sjson.DeleteBytes(objJSON, "signatures") + if err != nil { + return false, err + } + objJSONString := string(canonicaljson.CanonicalJSONAssumeValid(objJSON)) + return u.VerifySignature(objJSONString, key, sig.Str) +} + +// VerifySignatureJSON verifies the signature in the JSON object _obj following +// the Matrix specification: +// https://matrix.org/speculator/spec/drafts%2Fe2e/appendices.html#signing-json +// This function is a wrapper over Utility.VerifySignatureJSON that creates and +// destroys the Utility object transparently. +// If the _obj is a struct, the `json` tags will be honored. +func VerifySignatureJSON(obj interface{}, userID id.UserID, keyName string, key id.Ed25519) (bool, error) { + u := NewUtility() + defer u.Clear() + return u.VerifySignatureJSON(obj, userID, keyName, key) +} diff --git a/crypto/olm/verification.go b/crypto/olm/verification.go index 4a363a0a..abdf9a4e 100644 --- a/crypto/olm/verification.go +++ b/crypto/olm/verification.go @@ -1,3 +1,5 @@ +//go:build !nosas && !goolm + package olm // #cgo LDFLAGS: -lolm -lstdc++ @@ -41,7 +43,7 @@ func NewSAS() *SAS { random := make([]byte, sas.sasRandomLength()+1) _, err := rand.Read(random) if err != nil { - panic(NotEnoughGoRandom) + panic(ErrNotEnoughGoRandom) } r := C.olm_create_sas( (*C.OlmSAS)(sas.int), diff --git a/crypto/olm/verification_goolm.go b/crypto/olm/verification_goolm.go new file mode 100644 index 00000000..fab51e5c --- /dev/null +++ b/crypto/olm/verification_goolm.go @@ -0,0 +1,23 @@ +//go:build !nosas && goolm + +package olm + +import ( + "maunium.net/go/mautrix/crypto/goolm/sas" +) + +// SAS stores an Olm Short Authentication String (SAS) object. +type SAS struct { + sas.SAS +} + +// NewSAS creates a new SAS object. +func NewSAS() *SAS { + newSAS, err := sas.New() + if err != nil { + panic(err) + } + return &SAS{ + SAS: *newSAS, + } +} From fde9c645f47bf61f24a05ac76318dc06d6b48469 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 15 Dec 2023 14:11:58 +0200 Subject: [PATCH 0032/1647] Revert changes to error names --- crypto/keyexport.go | 4 ++-- crypto/olm/account.go | 22 +++++++++++----------- crypto/olm/account_goolm.go | 12 ++++++------ crypto/olm/error.go | 10 +++++----- crypto/olm/error_goolm.go | 10 +++++----- crypto/olm/inboundgroupsession.go | 18 +++++++++--------- crypto/olm/inboundgroupsession_goolm.go | 16 ++++++++-------- crypto/olm/outboundgroupsession.go | 12 ++++++------ crypto/olm/outboundgroupsession_goolm.go | 10 +++++----- crypto/olm/pk.go | 2 +- crypto/olm/session.go | 20 ++++++++++---------- crypto/olm/session_goolm.go | 14 +++++++------- crypto/olm/utility.go | 6 +++--- crypto/olm/utility_goolm.go | 6 +++--- crypto/olm/verification.go | 2 +- 15 files changed, 82 insertions(+), 82 deletions(-) diff --git a/crypto/keyexport.go b/crypto/keyexport.go index d5a37702..91bfb6c6 100644 --- a/crypto/keyexport.go +++ b/crypto/keyexport.go @@ -69,7 +69,7 @@ func makeExportIV() []byte { iv := make([]byte, 16) _, err := rand.Read(iv) if err != nil { - panic(olm.ErrNotEnoughGoRandom) + panic(olm.NotEnoughGoRandom) } // Set bit 63 to zero iv[7] &= 0b11111110 @@ -80,7 +80,7 @@ func makeExportKeys(passphrase string) (encryptionKey, hashKey, salt, iv []byte) salt = make([]byte, 16) _, err := rand.Read(salt) if err != nil { - panic(olm.ErrNotEnoughGoRandom) + panic(olm.NotEnoughGoRandom) } encryptionKey, hashKey = computeKey(passphrase, salt, defaultPassphraseRounds) diff --git a/crypto/olm/account.go b/crypto/olm/account.go index d3298d6e..86487ac3 100644 --- a/crypto/olm/account.go +++ b/crypto/olm/account.go @@ -32,7 +32,7 @@ type Account struct { // "INVALID_BASE64". func AccountFromPickled(pickled, key []byte) (*Account, error) { if len(pickled) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } a := NewBlankAccount() return a, a.Unpickle(pickled, key) @@ -52,7 +52,7 @@ func NewAccount() *Account { random := make([]byte, a.createRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(ErrNotEnoughGoRandom) + panic(NotEnoughGoRandom) } r := C.olm_create_account( (*C.OlmAccount)(a.int), @@ -126,7 +126,7 @@ func (a *Account) genOneTimeKeysRandomLen(num uint) uint { // supplied key. func (a *Account) Pickle(key []byte) []byte { if len(key) == 0 { - panic(ErrNoKeyProvided) + panic(NoKeyProvided) } pickled := make([]byte, a.pickleLen()) r := C.olm_pickle_account( @@ -143,7 +143,7 @@ func (a *Account) Pickle(key []byte) []byte { func (a *Account) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return ErrNoKeyProvided + return NoKeyProvided } r := C.olm_unpickle_account( (*C.OlmAccount)(a.int), @@ -186,7 +186,7 @@ func (a *Account) MarshalJSON() ([]byte, error) { func (a *Account) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return ErrInputNotJSONString + return InputNotJSONString } if a.int == nil { *a = *NewBlankAccount() @@ -220,7 +220,7 @@ func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519) { // Account. func (a *Account) Sign(message []byte) []byte { if len(message) == 0 { - panic(ErrEmptyInput) + panic(EmptyInput) } signature := make([]byte, a.signatureLen()) r := C.olm_account_sign( @@ -298,7 +298,7 @@ func (a *Account) GenOneTimeKeys(num uint) { random := make([]byte, a.genOneTimeKeysRandomLen(num)+1) _, err := rand.Read(random) if err != nil { - panic(ErrNotEnoughGoRandom) + panic(NotEnoughGoRandom) } r := C.olm_account_generate_one_time_keys( (*C.OlmAccount)(a.int), @@ -315,13 +315,13 @@ func (a *Account) GenOneTimeKeys(num uint) { // keys couldn't be decoded as base64 then the error will be "INVALID_BASE64" func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*Session, error) { if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } s := NewBlankSession() random := make([]byte, s.createOutboundRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(ErrNotEnoughGoRandom) + panic(NotEnoughGoRandom) } r := C.olm_create_outbound_session( (*C.OlmSession)(s.int), @@ -347,7 +347,7 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2 // time key then the error will be "BAD_MESSAGE_KEY_ID". func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) { if len(oneTimeKeyMsg) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } s := NewBlankSession() r := C.olm_create_inbound_session( @@ -370,7 +370,7 @@ func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) { // time key then the error will be "BAD_MESSAGE_KEY_ID". func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeKeyMsg string) (*Session, error) { if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } s := NewBlankSession() r := C.olm_create_inbound_session_from( diff --git a/crypto/olm/account_goolm.go b/crypto/olm/account_goolm.go index ca448322..6260f779 100644 --- a/crypto/olm/account_goolm.go +++ b/crypto/olm/account_goolm.go @@ -43,7 +43,7 @@ func (a *Account) Clear() error { // supplied key. func (a *Account) Pickle(key []byte) []byte { if len(key) == 0 { - panic(ErrNoKeyProvided) + panic(NoKeyProvided) } pickled, err := a.Account.Pickle(key) if err != nil { @@ -84,7 +84,7 @@ func (a *Account) MarshalJSON() ([]byte, error) { func (a *Account) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return ErrInputNotJSONString + return InputNotJSONString } return a.Unpickle(data[1:len(data)-1], pickleKey) } @@ -102,7 +102,7 @@ func (a *Account) IdentityKeysJSON() []byte { // Account. func (a *Account) Sign(message []byte) []byte { if len(message) == 0 { - panic(ErrEmptyInput) + panic(EmptyInput) } signature, err := a.Account.Sign(message) if err != nil { @@ -143,7 +143,7 @@ func (a *Account) GenOneTimeKeys(num uint) { // given curve25519 identityKey and oneTimeKey. Returns error on failure. func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*Session, error) { if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } s := &Session{} newSession, err := a.Account.NewOutboundSession(theirIdentityKey, theirOneTimeKey) @@ -158,7 +158,7 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2 // messages from an incoming PRE_KEY message. Returns error on failure. func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) { if len(oneTimeKeyMsg) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } s := &Session{} newSession, err := a.Account.NewInboundSession(nil, []byte(oneTimeKeyMsg)) @@ -173,7 +173,7 @@ func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) { // messages from an incoming PRE_KEY message. Returns error on failure. func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeKeyMsg string) (*Session, error) { if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } s := &Session{} newSession, err := a.Account.NewInboundSession(&theirIdentityKey, []byte(oneTimeKeyMsg)) diff --git a/crypto/olm/error.go b/crypto/olm/error.go index b15af138..63352e20 100644 --- a/crypto/olm/error.go +++ b/crypto/olm/error.go @@ -9,11 +9,11 @@ import ( // Error codes from go-olm var ( - ErrEmptyInput = errors.New("empty input") - ErrNoKeyProvided = errors.New("no pickle key provided") - ErrNotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") - ErrSignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") - ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") + EmptyInput = errors.New("empty input") + NoKeyProvided = errors.New("no pickle key provided") + NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") + SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") + InputNotJSONString = errors.New("input doesn't look like a JSON string") ) // Error codes from olm code diff --git a/crypto/olm/error_goolm.go b/crypto/olm/error_goolm.go index e4e2cf70..0e54e566 100644 --- a/crypto/olm/error_goolm.go +++ b/crypto/olm/error_goolm.go @@ -10,11 +10,11 @@ import ( // Error codes from go-olm var ( - ErrEmptyInput = errors.New("empty input") - ErrNoKeyProvided = errors.New("no pickle key provided") - ErrNotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") - ErrSignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") - ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") + EmptyInput = goolm.ErrEmptyInput + NoKeyProvided = goolm.ErrNoKeyProvided + NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") + SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") + InputNotJSONString = errors.New("input doesn't look like a JSON string") ) // Error codes from olm code diff --git a/crypto/olm/inboundgroupsession.go b/crypto/olm/inboundgroupsession.go index 14d2e226..0d825b1e 100644 --- a/crypto/olm/inboundgroupsession.go +++ b/crypto/olm/inboundgroupsession.go @@ -27,7 +27,7 @@ type InboundGroupSession struct { // base64 couldn't be decoded then the error will be "INVALID_BASE64". func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { if len(pickled) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } lenKey := len(key) if lenKey == 0 { @@ -44,7 +44,7 @@ func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, // "OLM_BAD_SESSION_KEY". func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } s := NewBlankInboundGroupSession() r := C.olm_init_inbound_group_session( @@ -63,7 +63,7 @@ func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { // error will be "OLM_BAD_SESSION_KEY". func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } s := NewBlankInboundGroupSession() r := C.olm_import_inbound_group_session( @@ -116,7 +116,7 @@ func (s *InboundGroupSession) pickleLen() uint { // InboundGroupSession using the supplied key. func (s *InboundGroupSession) Pickle(key []byte) []byte { if len(key) == 0 { - panic(ErrNoKeyProvided) + panic(NoKeyProvided) } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_inbound_group_session( @@ -133,9 +133,9 @@ func (s *InboundGroupSession) Pickle(key []byte) []byte { func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return ErrNoKeyProvided + return NoKeyProvided } else if len(pickled) == 0 { - return ErrEmptyInput + return EmptyInput } r := C.olm_unpickle_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), @@ -178,7 +178,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return ErrInputNotJSONString + return InputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankInboundGroupSession() @@ -201,7 +201,7 @@ func clone(original []byte) []byte { // will be "BAD_MESSAGE_FORMAT". func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) { if len(message) == 0 { - return 0, ErrEmptyInput + return 0, EmptyInput } // olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it message = clone(message) @@ -226,7 +226,7 @@ func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, erro // was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX". func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { if len(message) == 0 { - return nil, 0, ErrEmptyInput + return nil, 0, EmptyInput } decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message) if err != nil { diff --git a/crypto/olm/inboundgroupsession_goolm.go b/crypto/olm/inboundgroupsession_goolm.go index 3ece360b..56dc6418 100644 --- a/crypto/olm/inboundgroupsession_goolm.go +++ b/crypto/olm/inboundgroupsession_goolm.go @@ -20,7 +20,7 @@ type InboundGroupSession struct { // Returns error on failure. func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { if len(pickled) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } lenKey := len(key) if lenKey == 0 { @@ -39,7 +39,7 @@ func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, // exported from OutboundGroupSession.Key(). Returns error on failure. func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } megolmSession, err := session.NewMegolmInboundSession(sessionKey) if err != nil { @@ -54,7 +54,7 @@ func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { // export. Returns error on failure. func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } megolmSession, err := session.NewMegolmInboundSessionFromExport(sessionKey) if err != nil { @@ -79,7 +79,7 @@ func (s *InboundGroupSession) Clear() error { // InboundGroupSession using the supplied key. func (s *InboundGroupSession) Pickle(key []byte) []byte { if len(key) == 0 { - panic(ErrNoKeyProvided) + panic(NoKeyProvided) } pickled, err := s.MegolmInboundSession.Pickle(key) if err != nil { @@ -90,9 +90,9 @@ func (s *InboundGroupSession) Pickle(key []byte) []byte { func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return ErrNoKeyProvided + return NoKeyProvided } else if len(pickled) == 0 { - return ErrEmptyInput + return EmptyInput } sOlm, err := session.MegolmInboundSessionFromPickled(pickled, key) if err != nil { @@ -137,7 +137,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return ErrInputNotJSONString + return InputNotJSONString } if s == nil { *s = *NewBlankInboundGroupSession() @@ -149,7 +149,7 @@ func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { // plain-text and message index on success. Returns error on failure. func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { if len(message) == 0 { - return nil, 0, ErrEmptyInput + return nil, 0, EmptyInput } plaintext, messageIndex, err := s.MegolmInboundSession.Decrypt(message) if err != nil { diff --git a/crypto/olm/outboundgroupsession.go b/crypto/olm/outboundgroupsession.go index 98024cee..c0866bd5 100644 --- a/crypto/olm/outboundgroupsession.go +++ b/crypto/olm/outboundgroupsession.go @@ -28,7 +28,7 @@ type OutboundGroupSession struct { // base64 couldn't be decoded then the error will be "INVALID_BASE64". func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession, error) { if len(pickled) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } s := NewBlankOutboundGroupSession() return s, s.Unpickle(pickled, key) @@ -40,7 +40,7 @@ func NewOutboundGroupSession() *OutboundGroupSession { random := make([]byte, s.createRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(ErrNotEnoughGoRandom) + panic(NotEnoughGoRandom) } r := C.olm_init_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), @@ -93,7 +93,7 @@ func (s *OutboundGroupSession) pickleLen() uint { // OutboundGroupSession using the supplied key. func (s *OutboundGroupSession) Pickle(key []byte) []byte { if len(key) == 0 { - panic(ErrNoKeyProvided) + panic(NoKeyProvided) } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_outbound_group_session( @@ -110,7 +110,7 @@ func (s *OutboundGroupSession) Pickle(key []byte) []byte { func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return ErrNoKeyProvided + return NoKeyProvided } r := C.olm_unpickle_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), @@ -153,7 +153,7 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return ErrInputNotJSONString + return InputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankOutboundGroupSession() @@ -177,7 +177,7 @@ func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint { // as base64. func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte { if len(plaintext) == 0 { - panic(ErrEmptyInput) + panic(EmptyInput) } message := make([]byte, s.encryptMsgLen(len(plaintext))) r := C.olm_group_encrypt( diff --git a/crypto/olm/outboundgroupsession_goolm.go b/crypto/olm/outboundgroupsession_goolm.go index 46c520b4..06966f48 100644 --- a/crypto/olm/outboundgroupsession_goolm.go +++ b/crypto/olm/outboundgroupsession_goolm.go @@ -22,7 +22,7 @@ type OutboundGroupSession struct { // base64 couldn't be decoded then the error will be "INVALID_BASE64". func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession, error) { if len(pickled) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } lenKey := len(key) if lenKey == 0 { @@ -63,7 +63,7 @@ func (s *OutboundGroupSession) Clear() error { // OutboundGroupSession using the supplied key. func (s *OutboundGroupSession) Pickle(key []byte) []byte { if len(key) == 0 { - panic(ErrNoKeyProvided) + panic(NoKeyProvided) } pickled, err := s.MegolmOutboundSession.Pickle(key) if err != nil { @@ -74,7 +74,7 @@ func (s *OutboundGroupSession) Pickle(key []byte) []byte { func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return ErrNoKeyProvided + return NoKeyProvided } return s.MegolmOutboundSession.Unpickle(pickled, key) } @@ -114,7 +114,7 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return ErrInputNotJSONString + return InputNotJSONString } if s == nil { *s = *NewBlankOutboundGroupSession() @@ -126,7 +126,7 @@ func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { // as base64. func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte { if len(plaintext) == 0 { - panic(ErrEmptyInput) + panic(EmptyInput) } message, err := s.MegolmOutboundSession.Encrypt(plaintext) if err != nil { diff --git a/crypto/olm/pk.go b/crypto/olm/pk.go index 6d6d3c16..ba390afe 100644 --- a/crypto/olm/pk.go +++ b/crypto/olm/pk.go @@ -76,7 +76,7 @@ func NewPkSigning() (*PkSigning, error) { seed := make([]byte, pkSigningSeedLength()) _, err := rand.Read(seed) if err != nil { - panic(ErrNotEnoughGoRandom) + panic(NotEnoughGoRandom) } pk, err := NewPkSigningFromSeed(seed) return pk, err diff --git a/crypto/olm/session.go b/crypto/olm/session.go index e2920682..bd6d7431 100644 --- a/crypto/olm/session.go +++ b/crypto/olm/session.go @@ -42,7 +42,7 @@ func sessionSize() uint { // "INVALID_BASE64". func SessionFromPickled(pickled, key []byte) (*Session, error) { if len(pickled) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } s := NewBlankSession() return s, s.Unpickle(pickled, key) @@ -109,7 +109,7 @@ func (s *Session) encryptMsgLen(plainTextLen int) uint { // will be "BAD_MESSAGE_FORMAT". func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) { if len(message) == 0 { - return 0, ErrEmptyInput + return 0, EmptyInput } r := C.olm_decrypt_max_plaintext_length( (*C.OlmSession)(s.int), @@ -126,7 +126,7 @@ func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) // supplied key. func (s *Session) Pickle(key []byte) []byte { if len(key) == 0 { - panic(ErrNoKeyProvided) + panic(NoKeyProvided) } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_session( @@ -143,7 +143,7 @@ func (s *Session) Pickle(key []byte) []byte { func (s *Session) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return ErrNoKeyProvided + return NoKeyProvided } r := C.olm_unpickle_session( (*C.OlmSession)(s.int), @@ -186,7 +186,7 @@ func (s *Session) MarshalJSON() ([]byte, error) { func (s *Session) UnmarshalJSON(data []byte) error { if len(data) == 0 || len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return ErrInputNotJSONString + return InputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankSession() @@ -228,7 +228,7 @@ func (s *Session) HasReceivedMessage() bool { // decoded then then the error will be "BAD_MESSAGE_FORMAT". func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { if len(oneTimeKeyMsg) == 0 { - return false, ErrEmptyInput + return false, EmptyInput } r := C.olm_matches_inbound_session( (*C.OlmSession)(s.int), @@ -253,7 +253,7 @@ func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { // decoded then then the error will be "BAD_MESSAGE_FORMAT". func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { - return false, ErrEmptyInput + return false, EmptyInput } r := C.olm_matches_inbound_session_from( (*C.OlmSession)(s.int), @@ -289,13 +289,13 @@ func (s *Session) EncryptMsgType() id.OlmMsgType { // as base64. func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { if len(plaintext) == 0 { - panic(ErrEmptyInput) + panic(EmptyInput) } // Make the slice be at least length 1 random := make([]byte, s.encryptRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(ErrNotEnoughGoRandom) + panic(NotEnoughGoRandom) } messageType := s.EncryptMsgType() message := make([]byte, s.encryptMsgLen(len(plaintext))) @@ -322,7 +322,7 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { // "BAD_MESSAGE_MAC". func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { if len(message) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType) if err != nil { diff --git a/crypto/olm/session_goolm.go b/crypto/olm/session_goolm.go index 5291cef4..c2684c11 100644 --- a/crypto/olm/session_goolm.go +++ b/crypto/olm/session_goolm.go @@ -18,7 +18,7 @@ type Session struct { // the Session using the supplied key. Returns error on failure. func SessionFromPickled(pickled, key []byte) (*Session, error) { if len(pickled) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } s := NewBlankSession() return s, s.Unpickle(pickled, key) @@ -38,7 +38,7 @@ func (s *Session) Clear() error { // supplied key. func (s *Session) Pickle(key []byte) []byte { if len(key) == 0 { - panic(ErrNoKeyProvided) + panic(NoKeyProvided) } pickled, err := s.OlmSession.Pickle(key) if err != nil { @@ -49,9 +49,9 @@ func (s *Session) Pickle(key []byte) []byte { func (s *Session) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return ErrNoKeyProvided + return NoKeyProvided } else if len(pickled) == 0 { - return ErrEmptyInput + return EmptyInput } sOlm, err := session.OlmSessionFromPickled(pickled, key) if err != nil { @@ -96,7 +96,7 @@ func (s *Session) MarshalJSON() ([]byte, error) { func (s *Session) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return ErrInputNotJSONString + return InputNotJSONString } if s == nil { *s = *NewBlankSession() @@ -131,7 +131,7 @@ func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg stri // as base64. func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { if len(plaintext) == 0 { - panic(ErrEmptyInput) + panic(EmptyInput) } messageType, message, err := s.OlmSession.Encrypt(plaintext, nil) if err != nil { @@ -144,7 +144,7 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { // success. Returns error on failure. func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { if len(message) == 0 { - return nil, ErrEmptyInput + return nil, EmptyInput } return s.OlmSession.Decrypt([]byte(message), msgType) } diff --git a/crypto/olm/utility.go b/crypto/olm/utility.go index 8d544dbb..87055fb3 100644 --- a/crypto/olm/utility.go +++ b/crypto/olm/utility.go @@ -63,7 +63,7 @@ func NewUtility() *Utility { // Sha256 calculates the SHA-256 hash of the input and encodes it as base64. func (u *Utility) Sha256(input string) string { if len(input) == 0 { - panic(ErrEmptyInput) + panic(EmptyInput) } output := make([]byte, u.sha256Len()) r := C.olm_sha256( @@ -83,7 +83,7 @@ func (u *Utility) Sha256(input string) string { // small then the error will be "INVALID_BASE64". func (u *Utility) VerifySignature(message string, key id.Ed25519, signature string) (ok bool, err error) { if len(message) == 0 || len(key) == 0 || len(signature) == 0 { - return false, ErrEmptyInput + return false, EmptyInput } r := C.olm_ed25519_verify( (*C.OlmUtility)(u.int), @@ -119,7 +119,7 @@ func (u *Utility) VerifySignatureJSON(obj interface{}, userID id.UserID, keyName } sig := gjson.GetBytes(objJSON, exgjson.Path("signatures", string(userID), fmt.Sprintf("ed25519:%s", keyName))) if !sig.Exists() || sig.Type != gjson.String { - return false, ErrSignatureNotFound + return false, SignatureNotFound } objJSON, err = sjson.DeleteBytes(objJSON, "unsigned") if err != nil { diff --git a/crypto/olm/utility_goolm.go b/crypto/olm/utility_goolm.go index 31299e63..926b5404 100644 --- a/crypto/olm/utility_goolm.go +++ b/crypto/olm/utility_goolm.go @@ -34,7 +34,7 @@ func NewUtility() *Utility { // Sha256 calculates the SHA-256 hash of the input and encodes it as base64. func (u *Utility) Sha256(input string) string { if len(input) == 0 { - panic(ErrEmptyInput) + panic(EmptyInput) } hash := sha256.Sum256([]byte(input)) return base64.RawStdEncoding.EncodeToString(hash[:]) @@ -45,7 +45,7 @@ func (u *Utility) Sha256(input string) string { // small then the error will be "INVALID_BASE64". func (u *Utility) VerifySignature(message string, key id.Ed25519, signature string) (ok bool, err error) { if len(message) == 0 || len(key) == 0 || len(signature) == 0 { - return false, ErrEmptyInput + return false, EmptyInput } return utilities.VerifySignature([]byte(message), key, []byte(signature)) } @@ -65,7 +65,7 @@ func (u *Utility) VerifySignatureJSON(obj interface{}, userID id.UserID, keyName } sig := gjson.GetBytes(objJSON, exgjson.Path("signatures", string(userID), fmt.Sprintf("ed25519:%s", keyName))) if !sig.Exists() || sig.Type != gjson.String { - return false, ErrSignatureNotFound + return false, SignatureNotFound } objJSON, err = sjson.DeleteBytes(objJSON, "unsigned") if err != nil { diff --git a/crypto/olm/verification.go b/crypto/olm/verification.go index abdf9a4e..bb0db7be 100644 --- a/crypto/olm/verification.go +++ b/crypto/olm/verification.go @@ -43,7 +43,7 @@ func NewSAS() *SAS { random := make([]byte, sas.sasRandomLength()+1) _, err := rand.Read(random) if err != nil { - panic(ErrNotEnoughGoRandom) + panic(NotEnoughGoRandom) } r := C.olm_create_sas( (*C.OlmSAS)(sas.int), From ca03f1df17fc2f5cc46462f32fa359e44a5f521d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 15 Dec 2023 14:17:49 +0200 Subject: [PATCH 0033/1647] Remove unused error constants --- crypto/goolm/errors.go | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/crypto/goolm/errors.go b/crypto/goolm/errors.go index 7df906d9..4dec9849 100644 --- a/crypto/goolm/errors.go +++ b/crypto/goolm/errors.go @@ -6,13 +6,11 @@ import ( // Those are the most common used errors var ( - ErrNoSigningKey = errors.New("no signing key") ErrBadSignature = errors.New("bad signature") ErrBadMAC = errors.New("bad mac") ErrBadMessageFormat = errors.New("bad message format") ErrBadVerification = errors.New("bad verification") ErrWrongProtocolVersion = errors.New("wrong protocol version") - ErrNoSessionKey = errors.New("no session key") ErrEmptyInput = errors.New("empty input") ErrNoKeyProvided = errors.New("no key") ErrBadMessageKeyID = errors.New("bad message key id") @@ -25,12 +23,8 @@ var ( ErrBadVersion = errors.New("wrong version") ErrNotBlocksize = errors.New("length != blocksize") ErrNotMultipleBlocksize = errors.New("length not a multiple of the blocksize") - ErrBase64InvalidLength = errors.New("base64 decode invalid length") - ErrWrongPickleVersion = errors.New("Wrong pickle version") - ErrSignatureNotFound = errors.New("signature not found") - ErrNotEnoughGoRandom = errors.New("Not enough random data available") + ErrWrongPickleVersion = errors.New("wrong pickle version") ErrValueTooShort = errors.New("value too short") ErrInputToSmall = errors.New("input too small (truncated?)") ErrOverflow = errors.New("overflow") - ErrBadBase64 = errors.New("Bad base64") ) From c44f7f24c232570c3ca1509c90bf95938e84a830 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 15 Dec 2023 14:36:41 +0200 Subject: [PATCH 0034/1647] Deprecate gob/json encoding olm structs --- crypto/olm/account.go | 4 +++ crypto/olm/account_goolm.go | 38 -------------------- crypto/olm/inboundgroupsession.go | 4 +++ crypto/olm/inboundgroupsession_goolm.go | 45 ------------------------ crypto/olm/olm_goolm.go | 4 +-- crypto/olm/outboundgroupsession.go | 4 +++ crypto/olm/outboundgroupsession_goolm.go | 45 ------------------------ crypto/olm/session.go | 4 +++ crypto/olm/session_goolm.go | 45 ------------------------ 9 files changed, 17 insertions(+), 176 deletions(-) diff --git a/crypto/olm/account.go b/crypto/olm/account.go index 86487ac3..37458d1b 100644 --- a/crypto/olm/account.go +++ b/crypto/olm/account.go @@ -157,6 +157,7 @@ func (a *Account) Unpickle(pickled, key []byte) error { return nil } +// Deprecated func (a *Account) GobEncode() ([]byte, error) { pickled := a.Pickle(pickleKey) length := base64.RawStdEncoding.DecodedLen(len(pickled)) @@ -165,6 +166,7 @@ func (a *Account) GobEncode() ([]byte, error) { return rawPickled, err } +// Deprecated func (a *Account) GobDecode(rawPickled []byte) error { if a.int == nil { *a = *NewBlankAccount() @@ -175,6 +177,7 @@ func (a *Account) GobDecode(rawPickled []byte) error { return a.Unpickle(pickled, pickleKey) } +// Deprecated func (a *Account) MarshalJSON() ([]byte, error) { pickled := a.Pickle(pickleKey) quotes := make([]byte, len(pickled)+2) @@ -184,6 +187,7 @@ func (a *Account) MarshalJSON() ([]byte, error) { return quotes, nil } +// Deprecated func (a *Account) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { return InputNotJSONString diff --git a/crypto/olm/account_goolm.go b/crypto/olm/account_goolm.go index 6260f779..eeff54f9 100644 --- a/crypto/olm/account_goolm.go +++ b/crypto/olm/account_goolm.go @@ -3,7 +3,6 @@ package olm import ( - "encoding/base64" "encoding/json" "github.com/tidwall/sjson" @@ -52,43 +51,6 @@ func (a *Account) Pickle(key []byte) []byte { return pickled } -func (a *Account) GobEncode() ([]byte, error) { - pickled, err := a.Account.Pickle(pickleKey) - if err != nil { - return nil, err - } - length := base64.RawStdEncoding.DecodedLen(len(pickled)) - rawPickled := make([]byte, length) - _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) - return rawPickled, err -} - -func (a *Account) GobDecode(rawPickled []byte) error { - length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) - pickled := make([]byte, length) - base64.RawStdEncoding.Encode(pickled, rawPickled) - return a.Unpickle(pickled, pickleKey) -} - -func (a *Account) MarshalJSON() ([]byte, error) { - pickled, err := a.Account.Pickle(pickleKey) - if err != nil { - return nil, err - } - quotes := make([]byte, len(pickled)+2) - quotes[0] = '"' - quotes[len(quotes)-1] = '"' - copy(quotes[1:len(quotes)-1], pickled) - return quotes, nil -} - -func (a *Account) UnmarshalJSON(data []byte) error { - if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return InputNotJSONString - } - return a.Unpickle(data[1:len(data)-1], pickleKey) -} - // IdentityKeysJSON returns the public parts of the identity keys for the Account. func (a *Account) IdentityKeysJSON() []byte { identityKeys, err := a.Account.IdentityKeysJSON() diff --git a/crypto/olm/inboundgroupsession.go b/crypto/olm/inboundgroupsession.go index 0d825b1e..a3bd3b65 100644 --- a/crypto/olm/inboundgroupsession.go +++ b/crypto/olm/inboundgroupsession.go @@ -149,6 +149,7 @@ func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { return nil } +// Deprecated func (s *InboundGroupSession) GobEncode() ([]byte, error) { pickled := s.Pickle(pickleKey) length := base64.RawStdEncoding.DecodedLen(len(pickled)) @@ -157,6 +158,7 @@ func (s *InboundGroupSession) GobEncode() ([]byte, error) { return rawPickled, err } +// Deprecated func (s *InboundGroupSession) GobDecode(rawPickled []byte) error { if s == nil || s.int == nil { *s = *NewBlankInboundGroupSession() @@ -167,6 +169,7 @@ func (s *InboundGroupSession) GobDecode(rawPickled []byte) error { return s.Unpickle(pickled, pickleKey) } +// Deprecated func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { pickled := s.Pickle(pickleKey) quotes := make([]byte, len(pickled)+2) @@ -176,6 +179,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { return quotes, nil } +// Deprecated func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { return InputNotJSONString diff --git a/crypto/olm/inboundgroupsession_goolm.go b/crypto/olm/inboundgroupsession_goolm.go index 56dc6418..4e561cf7 100644 --- a/crypto/olm/inboundgroupsession_goolm.go +++ b/crypto/olm/inboundgroupsession_goolm.go @@ -3,8 +3,6 @@ package olm import ( - "encoding/base64" - "maunium.net/go/mautrix/crypto/goolm/session" "maunium.net/go/mautrix/id" ) @@ -102,49 +100,6 @@ func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { return nil } -func (s *InboundGroupSession) GobEncode() ([]byte, error) { - pickled, err := s.MegolmInboundSession.Pickle(pickleKey) - if err != nil { - return nil, err - } - length := base64.RawStdEncoding.DecodedLen(len(pickled)) - rawPickled := make([]byte, length) - _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) - return rawPickled, err -} - -func (s *InboundGroupSession) GobDecode(rawPickled []byte) error { - if s == nil { - *s = *NewBlankInboundGroupSession() - } - length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) - pickled := make([]byte, length) - base64.RawStdEncoding.Encode(pickled, rawPickled) - return s.Unpickle(pickled, pickleKey) -} - -func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { - pickled, err := s.MegolmInboundSession.Pickle(pickleKey) - if err != nil { - return nil, err - } - quotes := make([]byte, len(pickled)+2) - quotes[0] = '"' - quotes[len(quotes)-1] = '"' - copy(quotes[1:len(quotes)-1], pickled) - return quotes, nil -} - -func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { - if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return InputNotJSONString - } - if s == nil { - *s = *NewBlankInboundGroupSession() - } - return s.Unpickle(data[1:len(data)-1], pickleKey) -} - // Decrypt decrypts a message using the InboundGroupSession. Returns the the // plain-text and message index on success. Returns error on failure. func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { diff --git a/crypto/olm/olm_goolm.go b/crypto/olm/olm_goolm.go index 9acd1d8d..df18aa0e 100644 --- a/crypto/olm/olm_goolm.go +++ b/crypto/olm/olm_goolm.go @@ -15,9 +15,7 @@ func Version() (major, minor, patch uint8) { return goolm.GetLibaryVersion() } -var pickleKey = []byte("maunium.net/go/mautrix/crypto/olm") - // SetPickleKey sets the global pickle key used when encoding structs with Gob or JSON. func SetPickleKey(key []byte) { - pickleKey = key + panic("gob and json encoding is deprecated and not supported with goolm") } diff --git a/crypto/olm/outboundgroupsession.go b/crypto/olm/outboundgroupsession.go index c0866bd5..b6a33d36 100644 --- a/crypto/olm/outboundgroupsession.go +++ b/crypto/olm/outboundgroupsession.go @@ -124,6 +124,7 @@ func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { return nil } +// Deprecated func (s *OutboundGroupSession) GobEncode() ([]byte, error) { pickled := s.Pickle(pickleKey) length := base64.RawStdEncoding.DecodedLen(len(pickled)) @@ -132,6 +133,7 @@ func (s *OutboundGroupSession) GobEncode() ([]byte, error) { return rawPickled, err } +// Deprecated func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error { if s == nil || s.int == nil { *s = *NewBlankOutboundGroupSession() @@ -142,6 +144,7 @@ func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error { return s.Unpickle(pickled, pickleKey) } +// Deprecated func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { pickled := s.Pickle(pickleKey) quotes := make([]byte, len(pickled)+2) @@ -151,6 +154,7 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { return quotes, nil } +// Deprecated func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { return InputNotJSONString diff --git a/crypto/olm/outboundgroupsession_goolm.go b/crypto/olm/outboundgroupsession_goolm.go index 06966f48..7c201213 100644 --- a/crypto/olm/outboundgroupsession_goolm.go +++ b/crypto/olm/outboundgroupsession_goolm.go @@ -3,8 +3,6 @@ package olm import ( - "encoding/base64" - "maunium.net/go/mautrix/crypto/goolm/session" "maunium.net/go/mautrix/id" ) @@ -79,49 +77,6 @@ func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { return s.MegolmOutboundSession.Unpickle(pickled, key) } -func (s *OutboundGroupSession) GobEncode() ([]byte, error) { - pickled, err := s.MegolmOutboundSession.Pickle(pickleKey) - if err != nil { - return nil, err - } - length := base64.RawStdEncoding.DecodedLen(len(pickled)) - rawPickled := make([]byte, length) - _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) - return rawPickled, err -} - -func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error { - if s == nil { - *s = *NewBlankOutboundGroupSession() - } - length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) - pickled := make([]byte, length) - base64.RawStdEncoding.Encode(pickled, rawPickled) - return s.Unpickle(pickled, pickleKey) -} - -func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { - pickled, err := s.MegolmOutboundSession.Pickle(pickleKey) - if err != nil { - return nil, err - } - quotes := make([]byte, len(pickled)+2) - quotes[0] = '"' - quotes[len(quotes)-1] = '"' - copy(quotes[1:len(quotes)-1], pickled) - return quotes, nil -} - -func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { - if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return InputNotJSONString - } - if s == nil { - *s = *NewBlankOutboundGroupSession() - } - return s.Unpickle(data[1:len(data)-1], pickleKey) -} - // Encrypt encrypts a message using the Session. Returns the encrypted message // as base64. func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte { diff --git a/crypto/olm/session.go b/crypto/olm/session.go index bd6d7431..185e0b3d 100644 --- a/crypto/olm/session.go +++ b/crypto/olm/session.go @@ -157,6 +157,7 @@ func (s *Session) Unpickle(pickled, key []byte) error { return nil } +// Deprecated func (s *Session) GobEncode() ([]byte, error) { pickled := s.Pickle(pickleKey) length := base64.RawStdEncoding.DecodedLen(len(pickled)) @@ -165,6 +166,7 @@ func (s *Session) GobEncode() ([]byte, error) { return rawPickled, err } +// Deprecated func (s *Session) GobDecode(rawPickled []byte) error { if s == nil || s.int == nil { *s = *NewBlankSession() @@ -175,6 +177,7 @@ func (s *Session) GobDecode(rawPickled []byte) error { return s.Unpickle(pickled, pickleKey) } +// Deprecated func (s *Session) MarshalJSON() ([]byte, error) { pickled := s.Pickle(pickleKey) quotes := make([]byte, len(pickled)+2) @@ -184,6 +187,7 @@ func (s *Session) MarshalJSON() ([]byte, error) { return quotes, nil } +// Deprecated func (s *Session) UnmarshalJSON(data []byte) error { if len(data) == 0 || len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { return InputNotJSONString diff --git a/crypto/olm/session_goolm.go b/crypto/olm/session_goolm.go index c2684c11..c77efaa2 100644 --- a/crypto/olm/session_goolm.go +++ b/crypto/olm/session_goolm.go @@ -3,8 +3,6 @@ package olm import ( - "encoding/base64" - "maunium.net/go/mautrix/crypto/goolm/session" "maunium.net/go/mautrix/id" ) @@ -61,49 +59,6 @@ func (s *Session) Unpickle(pickled, key []byte) error { return nil } -func (s *Session) GobEncode() ([]byte, error) { - pickled, err := s.OlmSession.Pickle(pickleKey) - if err != nil { - return nil, err - } - length := base64.RawStdEncoding.DecodedLen(len(pickled)) - rawPickled := make([]byte, length) - _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) - return rawPickled, err -} - -func (s *Session) GobDecode(rawPickled []byte) error { - if s == nil { - *s = *NewBlankSession() - } - length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) - pickled := make([]byte, length) - base64.RawStdEncoding.Encode(pickled, rawPickled) - return s.Unpickle(pickled, pickleKey) -} - -func (s *Session) MarshalJSON() ([]byte, error) { - pickled, err := s.OlmSession.Pickle(pickleKey) - if err != nil { - return nil, err - } - quotes := make([]byte, len(pickled)+2) - quotes[0] = '"' - quotes[len(quotes)-1] = '"' - copy(quotes[1:len(quotes)-1], pickled) - return quotes, nil -} - -func (s *Session) UnmarshalJSON(data []byte) error { - if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return InputNotJSONString - } - if s == nil { - *s = *NewBlankSession() - } - return s.Unpickle(data[1:len(data)-1], pickleKey) -} - // MatchesInboundSession checks if the PRE_KEY message is for this in-bound // Session. This can happen if multiple messages are sent to this Account // before this Account sends a message in reply. Returns true if the session From b892a26d6f5de146bbc7500ee3fefc63bb5a2fd4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 15 Dec 2023 14:51:22 +0200 Subject: [PATCH 0035/1647] Remove mock random --- crypto/goolm/account/account_test.go | 98 ++++++---------------------- 1 file changed, 19 insertions(+), 79 deletions(-) diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go index eebf4f0b..d74b9bdb 100644 --- a/crypto/goolm/account/account_test.go +++ b/crypto/goolm/account/account_test.go @@ -12,38 +12,6 @@ import ( "maunium.net/go/mautrix/crypto/goolm/utilities" ) -type mockRandom struct { - tag byte - current byte -} - -func (m *mockRandom) get(length int) []byte { - res := make([]byte, length) - baseIndex := 0 - for length > 32 { - res[baseIndex] = m.tag - for i := 1; i < 32; i++ { - res[baseIndex+i] = m.current - } - length -= 32 - baseIndex += 32 - m.current++ - } - if length != 0 { - res[baseIndex] = m.tag - for i := 1; i < length-1; i++ { - res[baseIndex+i] = m.current - } - m.current++ - } - return res -} - -func (m *mockRandom) Read(target []byte) (int, error) { - res := m.get(len(target)) - return copy(target, res), nil -} - func TestAccount(t *testing.T) { firstAccount, err := account.NewAccount(nil) if err != nil { @@ -260,24 +228,16 @@ func TestOldAccountPickle(t *testing.T) { } func TestLoopback(t *testing.T) { - mockA := mockRandom{ - tag: []byte("A")[0], - current: 0x00, - } - mockB := mockRandom{ - tag: []byte("B")[0], - current: 0x80, - } - accountA, err := account.NewAccount(&mockA) + accountA, err := account.NewAccount(nil) if err != nil { t.Fatal(err) } - accountB, err := account.NewAccount(&mockB) + accountB, err := account.NewAccount(nil) if err != nil { t.Fatal(err) } - err = accountB.GenOneTimeKeys(&mockB, 42) + err = accountB.GenOneTimeKeys(nil, 42) if err != nil { t.Fatal(err) } @@ -288,7 +248,7 @@ func TestLoopback(t *testing.T) { } plainText := []byte("Hello, World") - msgType, message1, err := aliceSession.Encrypt(plainText, &mockA) + msgType, message1, err := aliceSession.Encrypt(plainText, nil) if err != nil { t.Fatal(err) } @@ -335,7 +295,7 @@ func TestLoopback(t *testing.T) { t.Fatal("messages are not the same") } - msgTyp2, message2, err := bobSession.Encrypt(plainText, &mockB) + msgTyp2, message2, err := bobSession.Encrypt(plainText, nil) if err != nil { t.Fatal(err) } @@ -364,24 +324,16 @@ func TestLoopback(t *testing.T) { } func TestMoreMessages(t *testing.T) { - mockA := mockRandom{ - tag: []byte("A")[0], - current: 0x00, - } - mockB := mockRandom{ - tag: []byte("B")[0], - current: 0x80, - } - accountA, err := account.NewAccount(&mockA) + accountA, err := account.NewAccount(nil) if err != nil { t.Fatal(err) } - accountB, err := account.NewAccount(&mockB) + accountB, err := account.NewAccount(nil) if err != nil { t.Fatal(err) } - err = accountB.GenOneTimeKeys(&mockB, 42) + err = accountB.GenOneTimeKeys(nil, 42) if err != nil { t.Fatal(err) } @@ -392,7 +344,7 @@ func TestMoreMessages(t *testing.T) { } plainText := []byte("Hello, World") - msgType, message1, err := aliceSession.Encrypt(plainText, &mockA) + msgType, message1, err := aliceSession.Encrypt(plainText, nil) if err != nil { t.Fatal(err) } @@ -414,7 +366,7 @@ func TestMoreMessages(t *testing.T) { for i := 0; i < 8; i++ { //alice sends, bob reveices - msgType, message, err := aliceSession.Encrypt(plainText, &mockA) + msgType, message, err := aliceSession.Encrypt(plainText, nil) if err != nil { t.Fatal(err) } @@ -437,7 +389,7 @@ func TestMoreMessages(t *testing.T) { } //now bob sends, alice receives - msgType, message, err = bobSession.Encrypt(plainText, &mockA) + msgType, message, err = bobSession.Encrypt(plainText, nil) if err != nil { t.Fatal(err) } @@ -455,24 +407,16 @@ func TestMoreMessages(t *testing.T) { } func TestFallbackKey(t *testing.T) { - mockA := mockRandom{ - tag: []byte("A")[0], - current: 0x00, - } - mockB := mockRandom{ - tag: []byte("B")[0], - current: 0x80, - } - accountA, err := account.NewAccount(&mockA) + accountA, err := account.NewAccount(nil) if err != nil { t.Fatal(err) } - accountB, err := account.NewAccount(&mockB) + accountB, err := account.NewAccount(nil) if err != nil { t.Fatal(err) } - err = accountB.GenFallbackKey(&mockB) + err = accountB.GenFallbackKey(nil) if err != nil { t.Fatal(err) } @@ -487,7 +431,7 @@ func TestFallbackKey(t *testing.T) { } plainText := []byte("Hello, World") - msgType, message1, err := aliceSession.Encrypt(plainText, &mockA) + msgType, message1, err := aliceSession.Encrypt(plainText, nil) if err != nil { t.Fatal(err) } @@ -535,7 +479,7 @@ func TestFallbackKey(t *testing.T) { } // create a new fallback key for B (the old fallback should still be usable) - err = accountB.GenFallbackKey(&mockB) + err = accountB.GenFallbackKey(nil) if err != nil { t.Fatal(err) } @@ -545,7 +489,7 @@ func TestFallbackKey(t *testing.T) { t.Fatal(err) } - msgType2, message2, err := aliceSession2.Encrypt(plainText, &mockA) + msgType2, message2, err := aliceSession2.Encrypt(plainText, nil) if err != nil { t.Fatal(err) } @@ -605,7 +549,7 @@ func TestFallbackKey(t *testing.T) { if err != nil { t.Fatal(err) } - msgType3, message3, err := aliceSession3.Encrypt(plainText, &mockA) + msgType3, message3, err := aliceSession3.Encrypt(plainText, nil) if err != nil { t.Fatal(err) } @@ -654,11 +598,7 @@ func TestOldV3AccountPickle(t *testing.T) { } func TestAccountSign(t *testing.T) { - mockA := mockRandom{ - tag: []byte("A")[0], - current: 0x00, - } - accountA, err := account.NewAccount(&mockA) + accountA, err := account.NewAccount(nil) if err != nil { t.Fatal(err) } From 893afc725981674f2e50d3be0bbab07bbaf91898 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 15 Dec 2023 15:12:15 +0200 Subject: [PATCH 0036/1647] Fix typos and JSON field names --- crypto/goolm/account/account.go | 12 ++-- crypto/goolm/account/account_test.go | 2 +- crypto/goolm/main.go | 6 +- crypto/goolm/message/group_message.go | 2 +- crypto/goolm/message/message.go | 4 +- crypto/goolm/message/prekey_message.go | 6 +- crypto/goolm/message/session_export.go | 2 +- crypto/goolm/olm/chain.go | 8 +-- crypto/goolm/olm/olm.go | 16 +++--- crypto/goolm/olm/olm_test.go | 4 +- crypto/goolm/olm/skipped_message.go | 4 +- crypto/goolm/pk/decryption.go | 54 +++++++++--------- crypto/goolm/pk/encryption.go | 2 +- crypto/goolm/pk/pk_test.go | 6 +- crypto/goolm/pk/signing.go | 2 +- .../goolm/session/megolm_inbound_session.go | 6 +- .../goolm/session/megolm_outbound_session.go | 2 +- crypto/goolm/session/olm_session.go | 56 +++++++++---------- crypto/olm/olm_goolm.go | 3 +- 19 files changed, 96 insertions(+), 101 deletions(-) diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index 168818bb..7896f849 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -31,12 +31,12 @@ type Account struct { IdKeys struct { Ed25519 crypto.Ed25519KeyPair `json:"ed25519,omitempty"` Curve25519 crypto.Curve25519KeyPair `json:"curve25519,omitempty"` - } `json:"identityKeys"` - OTKeys []crypto.OneTimeKey `json:"oneTimeKeys"` - CurrentFallbackKey crypto.OneTimeKey `json:"currentFallbackKey,omitempty"` - PrevFallbackKey crypto.OneTimeKey `json:"prevFallbackKey,omitempty"` - NextOneTimeKeyID uint32 `json:"nextOneTimeKeyID,omitempty"` - NumFallbackKeys uint8 `json:"numberFallbackKeys"` + } `json:"identity_keys"` + OTKeys []crypto.OneTimeKey `json:"one_time_keys"` + CurrentFallbackKey crypto.OneTimeKey `json:"current_fallback_key,omitempty"` + PrevFallbackKey crypto.OneTimeKey `json:"prev_fallback_key,omitempty"` + NextOneTimeKeyID uint32 `json:"next_one_time_key_id,omitempty"` + NumFallbackKeys uint8 `json:"number_fallback_keys"` } // AccountFromJSONPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key. diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go index d74b9bdb..a18840b1 100644 --- a/crypto/goolm/account/account_test.go +++ b/crypto/goolm/account/account_test.go @@ -98,7 +98,7 @@ func TestAccountPickleJSON(t *testing.T) { return */ - pickledData := []byte("fZG5DhZ0+uhVFEcdgo/dyWNy1BlSKo+W18D/QLBcZfvP0rByRzjgJM5yeDIO9N6jYFp2MbV1Y1DikFlDctwq7PhIRvbtLdrzxT94WoLrUdiNtQkw6NRNXvsFYo4NKoAgl1yQauttnGRBHCCPVV6e9d4kvnPVRkZNkbbANnadF0Tld/SMMWWoPI3L7dy+oiRh6nqNKvZz+upvgmOSm6gu2xV0yx9RJpkvLz8oHMDui1VQ1T2wTpfk5vdw0Cx4BXspf8WDnntdv0Ui4qBzUFmsB4lfqLviuhnAxu+qQrrKcZz/EyzbPwmI+P4Tn5KznxzEx2Nw/AjKKPxqVAKpx8+nV7rKKzlah71wX2CHyEsp2ptcNTJ1lr6tJxkOLdy8Rw285jpKw4MrgghnhqZ9Hh3y5P6KnRrq6zom9zfkCtCXs2h8BK+I0tkMPXO+JZoJKVOWzS+n7FIrC9XC9nAu19G5cnxv+tJdPb3p") + pickledData := []byte("6POkBWwbNl20fwvZWsOu0jgbHy4jkA5h0Ji+XCag59+ifWIRPDrqtgQi9HmkLiSF6wUhhYaV4S73WM+Hh+dlCuZRuXhTQr8yGPTifjcjq8birdAhObbEqHrYEdqaQkrgBLr/rlS5sibXeDqbkhVu4LslvootU9DkcCbd4b/0Flh7iugxqkcCs5GDndTEx9IzTVJzmK82Y0Q1Z1Z9Vuc2Iw746PtBJLtZjite6fSMp2NigPX/ZWWJ3OnwcJo0Vvjy8hgptZEWkamOHdWbUtelbHyjDIZlvxOC25D3rFif0zzPkF9qdpBPqVCWPPzGFmgnqKau6CHrnPfq7GLsM3BrprD7sHN1Js28ex14gXQPjBT7KTUo6H0e4gQMTMRp4qb8btNXDeId8xIFIElTh2SXZBTDmSq/ziVNJinEvYV8mGPvJZjDQQU+SyoS/HZ8uMc41tH0BOGDbFMHbfLMiz61E429gOrx2klu5lqyoyet7//HKi0ed5w2dQ") account, err := account.AccountFromJSONPickled(pickledData, key) if err != nil { t.Fatal(err) diff --git a/crypto/goolm/main.go b/crypto/goolm/main.go index 5e785c7b..55674305 100644 --- a/crypto/goolm/main.go +++ b/crypto/goolm/main.go @@ -1,10 +1,6 @@ -// goolm is a pure Go implementation of libolm. Libolm is a cryptographic library used for end-to-end encryption in Matrix and wirtten in C++. +// Package goolm is a pure Go implementation of libolm. Libolm is a cryptographic library used for end-to-end encryption in Matrix and written in C++. // With goolm there is no need to use cgo when building Matrix clients in go. /* This package contains the possible errors which can occur as well as some simple functions. All the 'action' happens in the subdirectories. */ package goolm - -func GetLibaryVersion() (major, minor, patch uint8) { - return 3, 2, 14 -} diff --git a/crypto/goolm/message/group_message.go b/crypto/goolm/message/group_message.go index 176214f6..ebd5b77e 100644 --- a/crypto/goolm/message/group_message.go +++ b/crypto/goolm/message/group_message.go @@ -18,7 +18,7 @@ type GroupMessage struct { Version byte `json:"version"` MessageIndex uint32 `json:"index"` Ciphertext []byte `json:"ciphertext"` - HasMessageIndex bool `json:"hasIndex"` + HasMessageIndex bool `json:"has_index"` } // Decodes decodes the input and populates the corresponding fileds. MAC and signature are ignored but have to be present. diff --git a/crypto/goolm/message/message.go b/crypto/goolm/message/message.go index d5c15b1a..8b721aeb 100644 --- a/crypto/goolm/message/message.go +++ b/crypto/goolm/message/message.go @@ -17,9 +17,9 @@ const ( // GroupMessage represents a message in the message format. type Message struct { Version byte `json:"version"` - HasCounter bool `json:"hasCounter"` + HasCounter bool `json:"has_counter"` Counter uint32 `json:"counter"` - RatchetKey crypto.Curve25519PublicKey `json:"ratchetKey"` + RatchetKey crypto.Curve25519PublicKey `json:"ratchet_key"` Ciphertext []byte `json:"ciphertext"` } diff --git a/crypto/goolm/message/prekey_message.go b/crypto/goolm/message/prekey_message.go index 9df3f9fa..6e007e06 100644 --- a/crypto/goolm/message/prekey_message.go +++ b/crypto/goolm/message/prekey_message.go @@ -13,9 +13,9 @@ const ( type PreKeyMessage struct { Version byte `json:"version"` - IdentityKey crypto.Curve25519PublicKey `json:"idKey"` - BaseKey crypto.Curve25519PublicKey `json:"baseKey"` - OneTimeKey crypto.Curve25519PublicKey `json:"otKey"` + IdentityKey crypto.Curve25519PublicKey `json:"id_key"` + BaseKey crypto.Curve25519PublicKey `json:"base_key"` + OneTimeKey crypto.Curve25519PublicKey `json:"one_time_key"` Message []byte `json:"message"` } diff --git a/crypto/goolm/message/session_export.go b/crypto/goolm/message/session_export.go index 5c4487e3..f539cce5 100644 --- a/crypto/goolm/message/session_export.go +++ b/crypto/goolm/message/session_export.go @@ -16,7 +16,7 @@ const ( type MegolmSessionExport struct { Counter uint32 `json:"counter"` RatchetData [128]byte `json:"data"` - PublicKey crypto.Ed25519PublicKey `json:"kPub"` + PublicKey crypto.Ed25519PublicKey `json:"public_key"` } // Encode returns the encoded message in the correct format. diff --git a/crypto/goolm/olm/chain.go b/crypto/goolm/olm/chain.go index 76db1eaa..403637a4 100644 --- a/crypto/goolm/olm/chain.go +++ b/crypto/goolm/olm/chain.go @@ -64,8 +64,8 @@ func (r chainKey) PickleLen() int { // senderChain is a chain for sending messages type senderChain struct { - RKey crypto.Curve25519KeyPair `json:"ratchetKey"` - CKey chainKey `json:"chainKey"` + RKey crypto.Curve25519KeyPair `json:"ratchet_key"` + CKey chainKey `json:"chain_key"` IsSet bool `json:"set"` } @@ -139,8 +139,8 @@ func (r senderChain) PickleLen() int { // senderChain is a chain for receiving messages type receiverChain struct { - RKey crypto.Curve25519PublicKey `json:"ratchetKey"` - CKey chainKey `json:"chainKey"` + RKey crypto.Curve25519PublicKey `json:"ratchet_key"` + CKey chainKey `json:"chain_key"` } // newReceiverChain returns a receiver chain initialized with chainKey and ratchet public key. diff --git a/crypto/goolm/olm/olm.go b/crypto/goolm/olm/olm.go index 2d8542fd..299ec7c4 100644 --- a/crypto/goolm/olm/olm.go +++ b/crypto/goolm/olm/olm.go @@ -44,22 +44,22 @@ var RatchetCipher = cipher.NewAESSHA256([]byte("OLM_KEYS")) type Ratchet struct { // The root key is used to generate chain keys from the ephemeral keys. // A new root_key is derived each time a new chain is started. - RootKey crypto.Curve25519PublicKey `json:"rootKey"` + RootKey crypto.Curve25519PublicKey `json:"root_key"` // The sender chain is used to send messages. Each time a new ephemeral // key is received from the remote server we generate a new sender chain // with a new ephemeral key when we next send a message. - SenderChains senderChain `json:"senderChain"` + SenderChains senderChain `json:"sender_chain"` // The receiver chain is used to decrypt received messages. We store the // last few chains so we can decrypt any out of order messages we haven't // received yet. // New chains are prepended for easier access. - ReceiverChains []receiverChain `json:"receiverChains"` + ReceiverChains []receiverChain `json:"receiver_chains"` // Storing the keys of missed messages for future use. // The order of the elements is not important. - SkippedMessageKeys []skippedMessageKey `json:"skippedMessageKeys"` + SkippedMessageKeys []skippedMessageKey `json:"skipped_message_keys"` } // New creates a new ratchet, setting the kdfInfos and cipher. @@ -68,8 +68,8 @@ func New() *Ratchet { return r } -// InitialiseAsBob initialises this ratchet from a receiving point of view (only first message). -func (r *Ratchet) InitialiseAsBob(sharedSecret []byte, theirRatchetKey crypto.Curve25519PublicKey) error { +// InitializeAsBob initializes this ratchet from a receiving point of view (only first message). +func (r *Ratchet) InitializeAsBob(sharedSecret []byte, theirRatchetKey crypto.Curve25519PublicKey) error { derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root) derivedSecrets := make([]byte, 2*sharedKeyLength) if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil { @@ -81,8 +81,8 @@ func (r *Ratchet) InitialiseAsBob(sharedSecret []byte, theirRatchetKey crypto.Cu return nil } -// InitialiseAsAlice initialises this ratchet from a sending point of view (only first message). -func (r *Ratchet) InitialiseAsAlice(sharedSecret []byte, ourRatchetKey crypto.Curve25519KeyPair) error { +// InitializeAsAlice initializes this ratchet from a sending point of view (only first message). +func (r *Ratchet) InitializeAsAlice(sharedSecret []byte, ourRatchetKey crypto.Curve25519KeyPair) error { derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root) derivedSecrets := make([]byte, 2*sharedKeyLength) if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil { diff --git a/crypto/goolm/olm/olm_test.go b/crypto/goolm/olm/olm_test.go index f97a0aeb..974ffc5e 100644 --- a/crypto/goolm/olm/olm_test.go +++ b/crypto/goolm/olm/olm_test.go @@ -31,8 +31,8 @@ func initializeRatchets() (*olm.Ratchet, *olm.Ratchet, error) { return nil, nil, err } - aliceRatchet.InitialiseAsAlice(sharedSecret, aliceKey) - bobRatchet.InitialiseAsBob(sharedSecret, aliceKey.PublicKey) + aliceRatchet.InitializeAsAlice(sharedSecret, aliceKey) + bobRatchet.InitializeAsBob(sharedSecret, aliceKey.PublicKey) return aliceRatchet, bobRatchet, nil } diff --git a/crypto/goolm/olm/skipped_message.go b/crypto/goolm/olm/skipped_message.go index 93d7c283..944337f6 100644 --- a/crypto/goolm/olm/skipped_message.go +++ b/crypto/goolm/olm/skipped_message.go @@ -9,8 +9,8 @@ import ( // skippedMessageKey stores a skipped message key type skippedMessageKey struct { - RKey crypto.Curve25519PublicKey `json:"ratchetKey"` - MKey messageKey `json:"messageKey"` + RKey crypto.Curve25519PublicKey `json:"ratchet_key"` + MKey messageKey `json:"message_key"` } // UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read. diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index c94bfd80..3fb3c2a5 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -18,25 +18,25 @@ const ( decryptionPickleVersionLibOlm uint32 = 1 ) -// Decription is used to decrypt pk messages -type Decription struct { - KeyPair crypto.Curve25519KeyPair `json:"keyPair"` +// Decryption is used to decrypt pk messages +type Decryption struct { + KeyPair crypto.Curve25519KeyPair `json:"key_pair"` } -// NewDecription returns a new Decription with a new generated key pair. -func NewDecription() (*Decription, error) { +// NewDecryption returns a new Decryption with a new generated key pair. +func NewDecryption() (*Decryption, error) { keyPair, err := crypto.Curve25519GenerateKey(nil) if err != nil { return nil, err } - return &Decription{ + return &Decryption{ KeyPair: keyPair, }, nil } -// NewDescriptionFromPrivate resturns a new Decription with the private key fixed. -func NewDecriptionFromPrivate(privateKey crypto.Curve25519PrivateKey) (*Decription, error) { - s := &Decription{} +// NewDescriptionFromPrivate resturns a new Decryption with the private key fixed. +func NewDecryptionFromPrivate(privateKey crypto.Curve25519PrivateKey) (*Decryption, error) { + s := &Decryption{} keyPair, err := crypto.Curve25519GenerateFromPrivate(privateKey) if err != nil { return nil, err @@ -46,17 +46,17 @@ func NewDecriptionFromPrivate(privateKey crypto.Curve25519PrivateKey) (*Decripti } // PubKey returns the public key base 64 encoded. -func (s Decription) PubKey() id.Curve25519 { +func (s Decryption) PubKey() id.Curve25519 { return s.KeyPair.B64Encoded() } // PrivateKey returns the private key. -func (s Decription) PrivateKey() crypto.Curve25519PrivateKey { +func (s Decryption) PrivateKey() crypto.Curve25519PrivateKey { return s.KeyPair.PrivateKey } // Decrypt decrypts the ciphertext and verifies the MAC. The base64 encoded key is used to construct the shared secret. -func (s Decription) Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error) { +func (s Decryption) Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error) { keyDecoded, err := base64.RawStdEncoding.DecodeString(string(key)) if err != nil { return nil, err @@ -84,19 +84,19 @@ func (s Decription) Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, return plaintext, nil } -// PickleAsJSON returns an Decription as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. -func (a Decription) PickleAsJSON(key []byte) ([]byte, error) { +// PickleAsJSON returns an Decryption as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. +func (a Decryption) PickleAsJSON(key []byte) ([]byte, error) { return utilities.PickleAsJSON(a, decryptionPickleVersionJSON, key) } -// UnpickleAsJSON updates an Decription by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. -func (a *Decription) UnpickleAsJSON(pickled, key []byte) error { +// UnpickleAsJSON updates an Decryption by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. +func (a *Decryption) UnpickleAsJSON(pickled, key []byte) error { return utilities.UnpickleAsJSON(a, pickled, key, decryptionPickleVersionJSON) } // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. -func (a *Decription) Unpickle(pickled, key []byte) error { +func (a *Decryption) Unpickle(pickled, key []byte) error { decrypted, err := cipher.Unpickle(key, pickled) if err != nil { return err @@ -105,8 +105,8 @@ func (a *Decription) Unpickle(pickled, key []byte) error { return err } -// UnpickleLibOlm decodes the unencryted value and populates the Decription accordingly. It returns the number of bytes read. -func (a *Decription) UnpickleLibOlm(value []byte) (int, error) { +// UnpickleLibOlm decodes the unencryted value and populates the Decryption accordingly. It returns the number of bytes read. +func (a *Decryption) UnpickleLibOlm(value []byte) (int, error) { //First 4 bytes are the accountPickleVersion pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) if err != nil { @@ -125,8 +125,8 @@ func (a *Decription) UnpickleLibOlm(value []byte) (int, error) { return curPos, nil } -// Pickle returns a base64 encoded and with key encrypted pickled Decription using PickleLibOlm(). -func (a Decription) Pickle(key []byte) ([]byte, error) { +// Pickle returns a base64 encoded and with key encrypted pickled Decryption using PickleLibOlm(). +func (a Decryption) Pickle(key []byte) ([]byte, error) { pickeledBytes := make([]byte, a.PickleLen()) written, err := a.PickleLibOlm(pickeledBytes) if err != nil { @@ -142,23 +142,23 @@ func (a Decription) Pickle(key []byte) ([]byte, error) { return encrypted, nil } -// PickleLibOlm encodes the Decription into target. target has to have a size of at least PickleLen() and is written to from index 0. +// PickleLibOlm encodes the Decryption into target. target has to have a size of at least PickleLen() and is written to from index 0. // It returns the number of bytes written. -func (a Decription) PickleLibOlm(target []byte) (int, error) { +func (a Decryption) PickleLibOlm(target []byte) (int, error) { if len(target) < a.PickleLen() { - return 0, fmt.Errorf("pickle Decription: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle Decryption: %w", goolm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(decryptionPickleVersionLibOlm, target) writtenKey, err := a.KeyPair.PickleLibOlm(target[written:]) if err != nil { - return 0, fmt.Errorf("pickle Decription: %w", err) + return 0, fmt.Errorf("pickle Decryption: %w", err) } written += writtenKey return written, nil } -// PickleLen returns the number of bytes the pickled Decription will have. -func (a Decription) PickleLen() int { +// PickleLen returns the number of bytes the pickled Decryption will have. +func (a Decryption) PickleLen() int { length := libolmpickle.PickleUInt32Len(decryptionPickleVersionLibOlm) length += a.KeyPair.PickleLen() return length diff --git a/crypto/goolm/pk/encryption.go b/crypto/goolm/pk/encryption.go index 19d9688a..dc50a6bb 100644 --- a/crypto/goolm/pk/encryption.go +++ b/crypto/goolm/pk/encryption.go @@ -12,7 +12,7 @@ import ( // Encryption is used to encrypt pk messages type Encryption struct { - RecipientKey crypto.Curve25519PublicKey + RecipientKey crypto.Curve25519PublicKey `json:"recipient_key"` } // NewEncryption returns a new Encryption with the base64 encoded public key of the recipient diff --git a/crypto/goolm/pk/pk_test.go b/crypto/goolm/pk/pk_test.go index 72a48767..91bab5b9 100644 --- a/crypto/goolm/pk/pk_test.go +++ b/crypto/goolm/pk/pk_test.go @@ -26,7 +26,7 @@ func TestEncryptionDecryption(t *testing.T) { 0x1C, 0x2F, 0x8B, 0x27, 0xFF, 0x88, 0xE0, 0xEB, } bobPublic := []byte("3p7bfXt9wbTTW2HC7OQ1Nz+DQ8hbeGdNrfx+FG+IK08") - decryption, err := pk.NewDecriptionFromPrivate(alicePrivate) + decryption, err := pk.NewDecryptionFromPrivate(alicePrivate) if err != nil { t.Fatal(err) } @@ -97,7 +97,7 @@ func TestDecryptionPickling(t *testing.T) { 0xB1, 0x77, 0xFB, 0xA5, 0x1D, 0xB9, 0x2C, 0x2A, } alicePublic := []byte("hSDwCYkwp1R0i33ctD73Wg2/Og0mOBr066SpjqqbTmo") - decryption, err := pk.NewDecriptionFromPrivate(alicePrivate) + decryption, err := pk.NewDecryptionFromPrivate(alicePrivate) if err != nil { t.Fatal(err) } @@ -117,7 +117,7 @@ func TestDecryptionPickling(t *testing.T) { t.Fatalf("pickle not as expected:\n%v\n%v\n", pickled, expectedPickle) } - newDecription, err := pk.NewDecription() + newDecription, err := pk.NewDecryption() if err != nil { t.Fatal(err) } diff --git a/crypto/goolm/pk/signing.go b/crypto/goolm/pk/signing.go index 493705f6..046838ff 100644 --- a/crypto/goolm/pk/signing.go +++ b/crypto/goolm/pk/signing.go @@ -10,7 +10,7 @@ import ( // Signing is used for signing a pk type Signing struct { - KeyPair crypto.Ed25519KeyPair `json:"keyPair"` + KeyPair crypto.Ed25519KeyPair `json:"key_pair"` Seed []byte `json:"seed"` } diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go index 8214aefc..165f7f16 100644 --- a/crypto/goolm/session/megolm_inbound_session.go +++ b/crypto/goolm/session/megolm_inbound_session.go @@ -23,9 +23,9 @@ const ( // MegolmInboundSession stores information about the sessions of receive. type MegolmInboundSession struct { Ratchet megolm.Ratchet `json:"ratchet"` - SigningKey crypto.Ed25519PublicKey `json:"signingKey"` - InitialRatchet megolm.Ratchet `json:"initalRatchet"` - SigningKeyVerified bool `json:"signingKeyVerified"` //not used for now + SigningKey crypto.Ed25519PublicKey `json:"signing_key"` + InitialRatchet megolm.Ratchet `json:"initial_ratchet"` + SigningKeyVerified bool `json:"signing_key_verified"` //not used for now } // NewMegolmInboundSession creates a new MegolmInboundSession from a base64 encoded session sharing message. diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index 11aadb00..8964a68d 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -24,7 +24,7 @@ const ( // MegolmOutboundSession stores information about the sessions to send. type MegolmOutboundSession struct { Ratchet megolm.Ratchet `json:"ratchet"` - SigningKey crypto.Ed25519KeyPair `json:"signingKey"` + SigningKey crypto.Ed25519KeyPair `json:"signing_key"` } // NewMegolmOutboundSession creates a new MegolmOutboundSession. diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go index b5189c59..6655e0a5 100644 --- a/crypto/goolm/session/olm_session.go +++ b/crypto/goolm/session/olm_session.go @@ -28,14 +28,14 @@ const ( // OlmSession stores all information for an olm session type OlmSession struct { - RecievedMessage bool `json:"recievedMessage"` - AliceIdKey crypto.Curve25519PublicKey `json:"aliceIdKey"` - AliceBaseKey crypto.Curve25519PublicKey `json:"aliceBaseKey"` - BobOneTimeKey crypto.Curve25519PublicKey `json:"bobOnTimeKey"` - Ratchet olm.Ratchet `json:"ratchet"` + ReceivedMessage bool `json:"received_message"` + AliceIdentityKey crypto.Curve25519PublicKey `json:"alice_id_key"` + AliceBaseKey crypto.Curve25519PublicKey `json:"alice_base_key"` + BobOneTimeKey crypto.Curve25519PublicKey `json:"bob_one_time_key"` + Ratchet olm.Ratchet `json:"ratchet"` } -// used to retrieve a crypto.OneTimeKey from a public key. +// SearchOTKFunc is used to retrieve a crypto.OneTimeKey from a public key. type SearchOTKFunc = func(crypto.Curve25519PublicKey) *crypto.OneTimeKey // OlmSessionFromJSONPickled loads an OlmSession from a pickled base64 string. Decrypts @@ -65,14 +65,14 @@ func OlmSessionFromPickled(pickled, key []byte) (*OlmSession, error) { return a, nil } -// NewSession creates a new Session. +// NewOlmSession creates a new Session. func NewOlmSession() *OlmSession { s := &OlmSession{} s.Ratchet = *olm.New() return s } -// NewOutboundSession creates a new outbound session for sending the first message to a +// NewOutboundOlmSession creates a new outbound session for sending the first message to a // given curve25519 identityKey and oneTimeKey. func NewOutboundOlmSession(identityKeyAlice crypto.Curve25519KeyPair, identityKeyBob crypto.Curve25519PublicKey, oneTimeKeyBob crypto.Curve25519PublicKey) (*OlmSession, error) { s := NewOlmSession() @@ -108,8 +108,8 @@ func NewOutboundOlmSession(identityKeyAlice crypto.Curve25519KeyPair, identityKe secret = append(secret, baseIdSecret...) secret = append(secret, baseOneTimeSecret...) //Init Ratchet - s.Ratchet.InitialiseAsAlice(secret, ratchetKey) - s.AliceIdKey = identityKeyAlice.PublicKey + s.Ratchet.InitializeAsAlice(secret, ratchetKey) + s.AliceIdentityKey = identityKeyAlice.PublicKey s.AliceBaseKey = baseKey.PublicKey s.BobOneTimeKey = oneTimeKeyBob return s, nil @@ -182,9 +182,9 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received return nil, fmt.Errorf("Message missing ratchet key: %w", goolm.ErrBadMessageFormat) } //Init Ratchet - s.Ratchet.InitialiseAsBob(secret, msg.RatchetKey) + s.Ratchet.InitializeAsBob(secret, msg.RatchetKey) s.AliceBaseKey = oneTimeMsg.BaseKey - s.AliceIdKey = oneTimeMsg.IdentityKey + s.AliceIdentityKey = oneTimeMsg.IdentityKey s.BobOneTimeKey = oneTimeKeyBob.Key.PublicKey //https://gitlab.matrix.org/matrix-org/olm/blob/master/docs/olm.md states to remove the oneTimeKey @@ -206,7 +206,7 @@ func (a *OlmSession) UnpickleAsJSON(pickled, key []byte) error { // Generated by hashing the public keys used to create the session. func (s OlmSession) ID() id.SessionID { message := make([]byte, 3*crypto.Curve25519KeyLength) - copy(message, s.AliceIdKey) + copy(message, s.AliceIdentityKey) copy(message[crypto.Curve25519KeyLength:], s.AliceBaseKey) copy(message[2*crypto.Curve25519KeyLength:], s.BobOneTimeKey) hash := crypto.SHA256(message) @@ -216,7 +216,7 @@ func (s OlmSession) ID() id.SessionID { // HasReceivedMessage returns true if this session has received any message. func (s OlmSession) HasReceivedMessage() bool { - return s.RecievedMessage + return s.ReceivedMessage } // MatchesInboundSessionFrom checks if the oneTimeKeyMsg message is set for this inbound @@ -253,10 +253,10 @@ func (s OlmSession) MatchesInboundSessionFrom(theirIdentityKeyEncoded *id.Curve2 same := true if msg.IdentityKey != nil { - same = same && msg.IdentityKey.Equal(s.AliceIdKey) + same = same && msg.IdentityKey.Equal(s.AliceIdentityKey) } if theirIdentityKey != nil { - same = same && theirIdentityKey.Equal(s.AliceIdKey) + same = same && theirIdentityKey.Equal(s.AliceIdentityKey) } same = same && bytes.Equal(msg.BaseKey, s.AliceBaseKey) same = same && bytes.Equal(msg.OneTimeKey, s.BobOneTimeKey) @@ -267,7 +267,7 @@ func (s OlmSession) MatchesInboundSessionFrom(theirIdentityKeyEncoded *id.Curve2 // return. Returns MsgTypePreKey if the message will be a oneTimeKeyMsg. // Returns MsgTypeMsg if the message will be a normal message. func (s OlmSession) EncryptMsgType() id.OlmMsgType { - if s.RecievedMessage { + if s.ReceivedMessage { return id.OlmMsgTypeMsg } return id.OlmMsgTypePreKey @@ -284,11 +284,11 @@ func (s *OlmSession) Encrypt(plaintext []byte, reader io.Reader) (id.OlmMsgType, return 0, nil, err } result := encrypted - if !s.RecievedMessage { + if !s.ReceivedMessage { msg := message.PreKeyMessage{} msg.Version = protocolVersion msg.OneTimeKey = s.BobOneTimeKey - msg.IdentityKey = s.AliceIdKey + msg.IdentityKey = s.AliceIdentityKey msg.BaseKey = s.AliceBaseKey msg.Message = encrypted @@ -326,7 +326,7 @@ func (s *OlmSession) Decrypt(crypttext []byte, msgType id.OlmMsgType) ([]byte, e if err != nil { return nil, err } - s.RecievedMessage = true + s.ReceivedMessage = true return plaintext, nil } @@ -358,12 +358,12 @@ func (o *OlmSession) UnpickleLibOlm(value []byte) (int, error) { return 0, fmt.Errorf("unpickle olmSession: %w", goolm.ErrBadVersion) } var readBytes int - o.RecievedMessage, readBytes, err = libolmpickle.UnpickleBool(value[curPos:]) + o.ReceivedMessage, readBytes, err = libolmpickle.UnpickleBool(value[curPos:]) if err != nil { return 0, err } curPos += readBytes - readBytes, err = o.AliceIdKey.UnpickleLibOlm(value[curPos:]) + readBytes, err = o.AliceIdentityKey.UnpickleLibOlm(value[curPos:]) if err != nil { return 0, err } @@ -410,8 +410,8 @@ func (o OlmSession) PickleLibOlm(target []byte) (int, error) { return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(olmSessionPickleVersionLibOlm, target) - written += libolmpickle.PickleBool(o.RecievedMessage, target[written:]) - writtenRatchet, err := o.AliceIdKey.PickleLibOlm(target[written:]) + written += libolmpickle.PickleBool(o.ReceivedMessage, target[written:]) + writtenRatchet, err := o.AliceIdentityKey.PickleLibOlm(target[written:]) if err != nil { return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) } @@ -437,8 +437,8 @@ func (o OlmSession) PickleLibOlm(target []byte) (int, error) { // PickleLen returns the actual number of bytes the pickled session will have. func (o OlmSession) PickleLen() int { length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm) - length += libolmpickle.PickleBoolLen(o.RecievedMessage) - length += o.AliceIdKey.PickleLen() + length += libolmpickle.PickleBoolLen(o.ReceivedMessage) + length += o.AliceIdentityKey.PickleLen() length += o.AliceBaseKey.PickleLen() length += o.BobOneTimeKey.PickleLen() length += o.Ratchet.PickleLen() @@ -448,8 +448,8 @@ func (o OlmSession) PickleLen() int { // PickleLenMin returns the minimum number of bytes the pickled session must have. func (o OlmSession) PickleLenMin() int { length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm) - length += libolmpickle.PickleBoolLen(o.RecievedMessage) - length += o.AliceIdKey.PickleLen() + length += libolmpickle.PickleBoolLen(o.ReceivedMessage) + length += o.AliceIdentityKey.PickleLen() length += o.AliceBaseKey.PickleLen() length += o.BobOneTimeKey.PickleLen() length += o.Ratchet.PickleLenMin() diff --git a/crypto/olm/olm_goolm.go b/crypto/olm/olm_goolm.go index df18aa0e..dbe12a76 100644 --- a/crypto/olm/olm_goolm.go +++ b/crypto/olm/olm_goolm.go @@ -3,7 +3,6 @@ package olm import ( - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/id" ) @@ -12,7 +11,7 @@ type Signatures map[id.UserID]map[id.DeviceKeyID]string // Version returns the version number of the olm library. func Version() (major, minor, patch uint8) { - return goolm.GetLibaryVersion() + return 3, 2, 15 } // SetPickleKey sets the global pickle key used when encoding structs with Gob or JSON. From 89fcc0eb3520331c1d93ac2fb113f2aebfc36d49 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 15 Dec 2023 15:57:14 +0200 Subject: [PATCH 0037/1647] Update example go.mod --- example/go.mod | 21 ++++++++++---------- example/go.sum | 52 +++++++++++++++++++++++++++----------------------- 2 files changed, 39 insertions(+), 34 deletions(-) diff --git a/example/go.mod b/example/go.mod index 7d234c5f..96cfba90 100644 --- a/example/go.mod +++ b/example/go.mod @@ -4,22 +4,23 @@ go 1.19 require ( github.com/chzyer/readline v1.5.1 - github.com/mattn/go-sqlite3 v1.14.16 - github.com/rs/zerolog v1.29.1 - maunium.net/go/mautrix v0.15.2 + github.com/mattn/go-sqlite3 v1.14.18 + github.com/rs/zerolog v1.31.0 + maunium.net/go/mautrix v0.16.3-0.20231215135638-893afc725981 ) require ( - github.com/mattn/go-colorable v0.1.12 // indirect - github.com/mattn/go-isatty v0.0.14 // indirect - github.com/tidwall/gjson v1.14.4 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/tidwall/gjson v1.17.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/sjson v1.2.5 // indirect - golang.org/x/crypto v0.9.0 // indirect - golang.org/x/exp v0.0.0-20230510235704-dd950f8aeaea // indirect - golang.org/x/net v0.10.0 // indirect - golang.org/x/sys v0.8.0 // indirect + go.mau.fi/util v0.2.1 // indirect + golang.org/x/crypto v0.15.0 // indirect + golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect + golang.org/x/net v0.18.0 // indirect + golang.org/x/sys v0.14.0 // indirect maunium.net/go/maulogger/v2 v2.4.1 // indirect ) diff --git a/example/go.sum b/example/go.sum index 510e0ced..7ceec40a 100644 --- a/example/go.sum +++ b/example/go.sum @@ -8,40 +8,44 @@ github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38 github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= -github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= -github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= -github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= -github.com/mattn/go-sqlite3 v1.14.16/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI= +github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc= -github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU= -github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A= +github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= -github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= +github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= -golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= -golang.org/x/exp v0.0.0-20230510235704-dd950f8aeaea h1:vLCWI/yYrdEHyN2JzIzPO3aaQJHQdp89IZBA/+azVC4= -golang.org/x/exp v0.0.0-20230510235704-dd950f8aeaea/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= -golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +go.mau.fi/util v0.2.1 h1:eazulhFE/UmjOFtPrGg6zkF5YfAyiDzQb8ihLMbsPWw= +go.mau.fi/util v0.2.1/go.mod h1:MjlzCQEMzJ+G8RsPawHzpLB8rwTo3aPIjG5FzBvQT/c= +golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= +golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= +golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ= +golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= +golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg= +golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= +golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= -maunium.net/go/mautrix v0.15.2 h1:fUiVajeoOR92uJoSShHbCvh7uG6lDY4ZO4Mvt90LbjU= -maunium.net/go/mautrix v0.15.2/go.mod h1:h4NwfKqE4YxGTLSgn/gawKzXAb2sF4qx8agL6QEFtGg= +maunium.net/go/mautrix v0.16.3-0.20231215135638-893afc725981 h1:KrMb5QyuMh2ZqfRLFiKRhYRV95aPCpaAepOful5EHvg= +maunium.net/go/mautrix v0.16.3-0.20231215135638-893afc725981/go.mod h1:YL4l4rZB46/vj/ifRMEjcibbvHjgxHftOF1SgmruLu4= From 7f78c3280c91bd5068374f891a0f11f5a330ac95 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 15 Dec 2023 15:57:26 +0200 Subject: [PATCH 0038/1647] Update changelog --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1961bd79..c2fcb7f6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,11 +2,16 @@ * **Breaking change *(bridge)*** Added raw event to portal membership handling functions. +* *(crypto)* Added `goolm` build tag to use a pure Go implementation of Olm + instead of using libolm via cgo. Thanks to [@DerLukas15] in [#106]. * *(bridge)* Added context parameter for bridge command events. * *(client)* Changed default syncer to not drop unknown events. * The syncer will still drop known events if parsing the content fails. * The behavior can be changed by changing the `ParseErrorHandler` function. +[@DerLukas15]: https://github.com/DerLukas15 +[#106]: https://github.com/mautrix/go/pull/106 + ## v0.16.2 (2023-11-16) * *(event)* Added `Redacts` field to `RedactionEventContent` for room v11+. From 753cdb2e1cb0a74deaf0295a932b072cf6c01400 Mon Sep 17 00:00:00 2001 From: Joakim Recht Date: Fri, 15 Dec 2023 15:23:31 +0100 Subject: [PATCH 0039/1647] Add context parameter to all client and bridge API functions (#144) --- appservice/appservice_test.go | 3 +- appservice/intent.go | 171 ++++----- bridge/bridge.go | 34 +- bridge/commands/admin.go | 3 +- bridge/commands/doublepuppet.go | 4 +- bridge/commands/event.go | 8 +- bridge/commands/handler.go | 4 +- bridge/crypto.go | 23 +- bridge/doublepuppet.go | 30 +- bridge/matrix.go | 54 +-- client.go | 515 ++++++++++++++-------------- crypto/cross_sign_key.go | 5 +- crypto/cross_sign_pubkey.go | 9 +- crypto/cross_sign_signing.go | 23 +- crypto/cross_sign_ssss.go | 31 +- crypto/cross_sign_validation.go | 2 +- crypto/cryptohelper/cryptohelper.go | 15 +- crypto/devicelist.go | 2 +- crypto/encryptmegolm.go | 6 +- crypto/encryptolm.go | 2 +- crypto/keysharing.go | 24 +- crypto/machine.go | 16 +- crypto/sql_store.go | 20 +- crypto/ssss/client.go | 35 +- crypto/store_test.go | 4 +- crypto/verification.go | 162 ++++----- crypto/verification_in_room.go | 77 +++-- synapseadmin/register.go | 6 +- synapseadmin/userapi.go | 12 +- syncstore.go | 32 +- 30 files changed, 669 insertions(+), 663 deletions(-) diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go index 50cc6fc4..eace1668 100644 --- a/appservice/appservice_test.go +++ b/appservice/appservice_test.go @@ -1,6 +1,7 @@ package appservice import ( + "context" "fmt" "net" "net/http" @@ -35,7 +36,7 @@ func TestClient_UnixSocket(t *testing.T) { err = as.SetHomeserverURL(fmt.Sprintf("unix://%s", socket)) assert.NoError(t, err) client := as.Client("user1") - resp, err := client.Whoami() + resp, err := client.Whoami(context.Background()) assert.NoError(t, err) assert.Equal(t, "@joe:example.org", string(resp.UserID)) } diff --git a/appservice/intent.go b/appservice/intent.go index 7995f44b..348eee2a 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -7,6 +7,7 @@ package appservice import ( + "context" "errors" "fmt" "strings" @@ -46,8 +47,8 @@ func (as *AppService) NewIntentAPI(localpart string) *IntentAPI { } } -func (intent *IntentAPI) Register() error { - _, _, err := intent.Client.Register(&mautrix.ReqRegister{ +func (intent *IntentAPI) Register(ctx context.Context) error { + _, _, err := intent.Client.Register(ctx, &mautrix.ReqRegister{ Username: intent.Localpart, Type: mautrix.AuthTypeAppservice, InhibitLogin: true, @@ -55,14 +56,14 @@ func (intent *IntentAPI) Register() error { return err } -func (intent *IntentAPI) EnsureRegistered() error { +func (intent *IntentAPI) EnsureRegistered(ctx context.Context) error { intent.registerLock.Lock() defer intent.registerLock.Unlock() if intent.IsCustomPuppet || intent.as.StateStore.IsRegistered(intent.UserID) { return nil } - err := intent.Register() + err := intent.Register(ctx) if err != nil && !errors.Is(err, mautrix.MUserInUse) { return fmt.Errorf("failed to ensure registered: %w", err) } @@ -75,7 +76,7 @@ type EnsureJoinedParams struct { BotOverride *mautrix.Client } -func (intent *IntentAPI) EnsureJoined(roomID id.RoomID, extra ...EnsureJoinedParams) error { +func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, extra ...EnsureJoinedParams) error { var params EnsureJoinedParams if len(extra) > 1 { panic("invalid number of extra parameters") @@ -86,11 +87,11 @@ func (intent *IntentAPI) EnsureJoined(roomID id.RoomID, extra ...EnsureJoinedPar return nil } - if err := intent.EnsureRegistered(); err != nil { + if err := intent.EnsureRegistered(ctx); err != nil { return fmt.Errorf("failed to ensure joined: %w", err) } - resp, err := intent.JoinRoomByID(roomID) + resp, err := intent.JoinRoomByID(ctx, roomID) if err != nil { bot := intent.bot if params.BotOverride != nil { @@ -99,13 +100,13 @@ func (intent *IntentAPI) EnsureJoined(roomID id.RoomID, extra ...EnsureJoinedPar if !errors.Is(err, mautrix.MForbidden) || bot == nil { return fmt.Errorf("failed to ensure joined: %w", err) } - _, inviteErr := bot.InviteUser(roomID, &mautrix.ReqInviteUser{ + _, inviteErr := bot.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{ UserID: intent.UserID, }) if inviteErr != nil { return fmt.Errorf("failed to invite in ensure joined: %w", inviteErr) } - resp, err = intent.JoinRoomByID(roomID) + resp, err = intent.JoinRoomByID(ctx, roomID) if err != nil { return fmt.Errorf("failed to ensure joined after invite: %w", err) } @@ -151,55 +152,55 @@ func (intent *IntentAPI) AddDoublePuppetValue(into interface{}) interface{} { } } -func (intent *IntentAPI) SendMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(roomID); err != nil { +func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendMessageEvent(roomID, eventType, contentJSON) + return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON) } -func (intent *IntentAPI) SendMassagedMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(roomID); err != nil { +func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendMessageEvent(roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) + return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) } -func (intent *IntentAPI) SendStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) { +func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) { if eventType != event.StateMember || stateKey != string(intent.UserID) { - if err := intent.EnsureJoined(roomID); err != nil { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendStateEvent(roomID, eventType, stateKey, contentJSON) + return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON) } -func (intent *IntentAPI) SendMassagedStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(roomID); err != nil { +func (intent *IntentAPI) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendMassagedStateEvent(roomID, eventType, stateKey, contentJSON, ts) + return intent.Client.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ts) } -func (intent *IntentAPI) StateEvent(roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error { - if err := intent.EnsureJoined(roomID); err != nil { +func (intent *IntentAPI) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return err } - return intent.Client.StateEvent(roomID, eventType, stateKey, outContent) + return intent.Client.StateEvent(ctx, roomID, eventType, stateKey, outContent) } -func (intent *IntentAPI) State(roomID id.RoomID) (mautrix.RoomStateMap, error) { - if err := intent.EnsureJoined(roomID); err != nil { +func (intent *IntentAPI) State(ctx context.Context, roomID id.RoomID) (mautrix.RoomStateMap, error) { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } - return intent.Client.State(roomID) + return intent.Client.State(ctx, roomID) } -func (intent *IntentAPI) SendCustomMembershipEvent(roomID id.RoomID, target id.UserID, membership event.Membership, reason string, extraContent ...map[string]interface{}) (*mautrix.RespSendEvent, error) { +func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID id.RoomID, target id.UserID, membership event.Membership, reason string, extraContent ...map[string]interface{}) (*mautrix.RespSendEvent, error) { content := &event.MemberEventContent{ Membership: membership, Reason: reason, @@ -211,7 +212,7 @@ func (intent *IntentAPI) SendCustomMembershipEvent(roomID id.RoomID, target id.U ok = memberContent != nil } if !ok { - profile, err := intent.GetProfile(target) + profile, err := intent.GetProfile(ctx, target) if err != nil { intent.Log.Debug().Err(err). Str("target_user_id", target.String()). @@ -231,21 +232,21 @@ func (intent *IntentAPI) SendCustomMembershipEvent(roomID id.RoomID, target id.U if len(extraContent) > 0 { extra = extraContent[0] } - return intent.SendStateEvent(roomID, event.StateMember, target.String(), &event.Content{ + return intent.SendStateEvent(ctx, roomID, event.StateMember, target.String(), &event.Content{ Parsed: content, Raw: extra, }) } -func (intent *IntentAPI) JoinRoomByID(roomID id.RoomID, extraContent ...map[string]interface{}) (resp *mautrix.RespJoinRoom, err error) { +func (intent *IntentAPI) JoinRoomByID(ctx context.Context, roomID id.RoomID, extraContent ...map[string]interface{}) (resp *mautrix.RespJoinRoom, err error) { if intent.IsCustomPuppet || len(extraContent) > 0 { - _, err = intent.SendCustomMembershipEvent(roomID, intent.UserID, event.MembershipJoin, "", extraContent...) + _, err = intent.SendCustomMembershipEvent(ctx, roomID, intent.UserID, event.MembershipJoin, "", extraContent...) return &mautrix.RespJoinRoom{}, err } - return intent.Client.JoinRoomByID(roomID) + return intent.Client.JoinRoomByID(ctx, roomID) } -func (intent *IntentAPI) LeaveRoom(roomID id.RoomID, extra ...interface{}) (resp *mautrix.RespLeaveRoom, err error) { +func (intent *IntentAPI) LeaveRoom(ctx context.Context, roomID id.RoomID, extra ...interface{}) (resp *mautrix.RespLeaveRoom, err error) { var extraContent map[string]interface{} leaveReq := &mautrix.ReqLeave{} for _, item := range extra { @@ -257,94 +258,94 @@ func (intent *IntentAPI) LeaveRoom(roomID id.RoomID, extra ...interface{}) (resp } } if intent.IsCustomPuppet || extraContent != nil { - _, err = intent.SendCustomMembershipEvent(roomID, intent.UserID, event.MembershipLeave, leaveReq.Reason, extraContent) + _, err = intent.SendCustomMembershipEvent(ctx, roomID, intent.UserID, event.MembershipLeave, leaveReq.Reason, extraContent) return &mautrix.RespLeaveRoom{}, err } - return intent.Client.LeaveRoom(roomID, leaveReq) + return intent.Client.LeaveRoom(ctx, roomID, leaveReq) } -func (intent *IntentAPI) InviteUser(roomID id.RoomID, req *mautrix.ReqInviteUser, extraContent ...map[string]interface{}) (resp *mautrix.RespInviteUser, err error) { +func (intent *IntentAPI) InviteUser(ctx context.Context, roomID id.RoomID, req *mautrix.ReqInviteUser, extraContent ...map[string]interface{}) (resp *mautrix.RespInviteUser, err error) { if intent.IsCustomPuppet || len(extraContent) > 0 { - _, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipInvite, req.Reason, extraContent...) + _, err = intent.SendCustomMembershipEvent(ctx, roomID, req.UserID, event.MembershipInvite, req.Reason, extraContent...) return &mautrix.RespInviteUser{}, err } - return intent.Client.InviteUser(roomID, req) + return intent.Client.InviteUser(ctx, roomID, req) } -func (intent *IntentAPI) KickUser(roomID id.RoomID, req *mautrix.ReqKickUser, extraContent ...map[string]interface{}) (resp *mautrix.RespKickUser, err error) { +func (intent *IntentAPI) KickUser(ctx context.Context, roomID id.RoomID, req *mautrix.ReqKickUser, extraContent ...map[string]interface{}) (resp *mautrix.RespKickUser, err error) { if intent.IsCustomPuppet || len(extraContent) > 0 { - _, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipLeave, req.Reason, extraContent...) + _, err = intent.SendCustomMembershipEvent(ctx, roomID, req.UserID, event.MembershipLeave, req.Reason, extraContent...) return &mautrix.RespKickUser{}, err } - return intent.Client.KickUser(roomID, req) + return intent.Client.KickUser(ctx, roomID, req) } -func (intent *IntentAPI) BanUser(roomID id.RoomID, req *mautrix.ReqBanUser, extraContent ...map[string]interface{}) (resp *mautrix.RespBanUser, err error) { +func (intent *IntentAPI) BanUser(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBanUser, extraContent ...map[string]interface{}) (resp *mautrix.RespBanUser, err error) { if intent.IsCustomPuppet || len(extraContent) > 0 { - _, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipBan, req.Reason, extraContent...) + _, err = intent.SendCustomMembershipEvent(ctx, roomID, req.UserID, event.MembershipBan, req.Reason, extraContent...) return &mautrix.RespBanUser{}, err } - return intent.Client.BanUser(roomID, req) + return intent.Client.BanUser(ctx, roomID, req) } -func (intent *IntentAPI) UnbanUser(roomID id.RoomID, req *mautrix.ReqUnbanUser, extraContent ...map[string]interface{}) (resp *mautrix.RespUnbanUser, err error) { +func (intent *IntentAPI) UnbanUser(ctx context.Context, roomID id.RoomID, req *mautrix.ReqUnbanUser, extraContent ...map[string]interface{}) (resp *mautrix.RespUnbanUser, err error) { if intent.IsCustomPuppet || len(extraContent) > 0 { - _, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipLeave, req.Reason, extraContent...) + _, err = intent.SendCustomMembershipEvent(ctx, roomID, req.UserID, event.MembershipLeave, req.Reason, extraContent...) return &mautrix.RespUnbanUser{}, err } - return intent.Client.UnbanUser(roomID, req) + return intent.Client.UnbanUser(ctx, roomID, req) } -func (intent *IntentAPI) Member(roomID id.RoomID, userID id.UserID) *event.MemberEventContent { +func (intent *IntentAPI) Member(ctx context.Context, roomID id.RoomID, userID id.UserID) *event.MemberEventContent { member, ok := intent.as.StateStore.TryGetMember(roomID, userID) if !ok { - _ = intent.StateEvent(roomID, event.StateMember, string(userID), &member) + _ = intent.StateEvent(ctx, roomID, event.StateMember, string(userID), &member) } return member } -func (intent *IntentAPI) PowerLevels(roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) { +func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) { pl = intent.as.StateStore.GetPowerLevels(roomID) if pl == nil { pl = &event.PowerLevelsEventContent{} - err = intent.StateEvent(roomID, event.StatePowerLevels, "", pl) + err = intent.StateEvent(ctx, roomID, event.StatePowerLevels, "", pl) } return } -func (intent *IntentAPI) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) (resp *mautrix.RespSendEvent, err error) { - return intent.SendStateEvent(roomID, event.StatePowerLevels, "", &levels) +func (intent *IntentAPI) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) (resp *mautrix.RespSendEvent, err error) { + return intent.SendStateEvent(ctx, roomID, event.StatePowerLevels, "", &levels) } -func (intent *IntentAPI) SetPowerLevel(roomID id.RoomID, userID id.UserID, level int) (*mautrix.RespSendEvent, error) { - pl, err := intent.PowerLevels(roomID) +func (intent *IntentAPI) SetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, level int) (*mautrix.RespSendEvent, error) { + pl, err := intent.PowerLevels(ctx, roomID) if err != nil { return nil, err } if pl.GetUserLevel(userID) != level { pl.SetUserLevel(userID, level) - return intent.SendStateEvent(roomID, event.StatePowerLevels, "", &pl) + return intent.SendStateEvent(ctx, roomID, event.StatePowerLevels, "", &pl) } return nil, nil } -func (intent *IntentAPI) SendText(roomID id.RoomID, text string) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(roomID); err != nil { +func (intent *IntentAPI) SendText(ctx context.Context, roomID id.RoomID, text string) (*mautrix.RespSendEvent, error) { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } - return intent.Client.SendText(roomID, text) + return intent.Client.SendText(ctx, roomID, text) } -func (intent *IntentAPI) SendNotice(roomID id.RoomID, text string) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(roomID); err != nil { +func (intent *IntentAPI) SendNotice(ctx context.Context, roomID id.RoomID, text string) (*mautrix.RespSendEvent, error) { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } - return intent.Client.SendNotice(roomID, text) + return intent.Client.SendNotice(ctx, roomID, text) } -func (intent *IntentAPI) RedactEvent(roomID id.RoomID, eventID id.EventID, extra ...mautrix.ReqRedact) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(roomID); err != nil { +func (intent *IntentAPI) RedactEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID, extra ...mautrix.ReqRedact) (*mautrix.RespSendEvent, error) { + if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } var req mautrix.ReqRedact @@ -352,65 +353,65 @@ func (intent *IntentAPI) RedactEvent(roomID id.RoomID, eventID id.EventID, extra req = extra[0] } intent.AddDoublePuppetValue(&req.Extra) - return intent.Client.RedactEvent(roomID, eventID, req) + return intent.Client.RedactEvent(ctx, roomID, eventID, req) } -func (intent *IntentAPI) SetRoomName(roomID id.RoomID, roomName string) (*mautrix.RespSendEvent, error) { - return intent.SendStateEvent(roomID, event.StateRoomName, "", map[string]interface{}{ +func (intent *IntentAPI) SetRoomName(ctx context.Context, roomID id.RoomID, roomName string) (*mautrix.RespSendEvent, error) { + return intent.SendStateEvent(ctx, roomID, event.StateRoomName, "", map[string]interface{}{ "name": roomName, }) } -func (intent *IntentAPI) SetRoomAvatar(roomID id.RoomID, avatarURL id.ContentURI) (*mautrix.RespSendEvent, error) { - return intent.SendStateEvent(roomID, event.StateRoomAvatar, "", map[string]interface{}{ +func (intent *IntentAPI) SetRoomAvatar(ctx context.Context, roomID id.RoomID, avatarURL id.ContentURI) (*mautrix.RespSendEvent, error) { + return intent.SendStateEvent(ctx, roomID, event.StateRoomAvatar, "", map[string]interface{}{ "url": avatarURL.String(), }) } -func (intent *IntentAPI) SetRoomTopic(roomID id.RoomID, topic string) (*mautrix.RespSendEvent, error) { - return intent.SendStateEvent(roomID, event.StateTopic, "", map[string]interface{}{ +func (intent *IntentAPI) SetRoomTopic(ctx context.Context, roomID id.RoomID, topic string) (*mautrix.RespSendEvent, error) { + return intent.SendStateEvent(ctx, roomID, event.StateTopic, "", map[string]interface{}{ "topic": topic, }) } -func (intent *IntentAPI) SetDisplayName(displayName string) error { - if err := intent.EnsureRegistered(); err != nil { +func (intent *IntentAPI) SetDisplayName(ctx context.Context, displayName string) error { + if err := intent.EnsureRegistered(ctx); err != nil { return err } - resp, err := intent.Client.GetOwnDisplayName() + resp, err := intent.Client.GetOwnDisplayName(ctx) if err != nil { return fmt.Errorf("failed to check current displayname: %w", err) } else if resp.DisplayName == displayName { // No need to update return nil } - return intent.Client.SetDisplayName(displayName) + return intent.Client.SetDisplayName(ctx, displayName) } -func (intent *IntentAPI) SetAvatarURL(avatarURL id.ContentURI) error { - if err := intent.EnsureRegistered(); err != nil { +func (intent *IntentAPI) SetAvatarURL(ctx context.Context, avatarURL id.ContentURI) error { + if err := intent.EnsureRegistered(ctx); err != nil { return err } - resp, err := intent.Client.GetOwnAvatarURL() + resp, err := intent.Client.GetOwnAvatarURL(ctx) if err != nil { return fmt.Errorf("failed to check current avatar URL: %w", err) } else if resp.FileID == avatarURL.FileID && resp.Homeserver == avatarURL.Homeserver { // No need to update return nil } - return intent.Client.SetAvatarURL(avatarURL) + return intent.Client.SetAvatarURL(ctx, avatarURL) } -func (intent *IntentAPI) Whoami() (*mautrix.RespWhoami, error) { - if err := intent.EnsureRegistered(); err != nil { +func (intent *IntentAPI) Whoami(ctx context.Context) (*mautrix.RespWhoami, error) { + if err := intent.EnsureRegistered(ctx); err != nil { return nil, err } - return intent.Client.Whoami() + return intent.Client.Whoami(ctx) } -func (intent *IntentAPI) EnsureInvited(roomID id.RoomID, userID id.UserID) error { +func (intent *IntentAPI) EnsureInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) error { if !intent.as.StateStore.IsInvited(roomID, userID) { - _, err := intent.InviteUser(roomID, &mautrix.ReqInviteUser{ + _, err := intent.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{ UserID: userID, }) if httpErr, ok := err.(mautrix.HTTPError); ok && diff --git a/bridge/bridge.go b/bridge/bridge.go index 291d6be9..763cb4e0 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -217,7 +217,7 @@ type Crypto interface { Decrypt(*event.Event) (*event.Event, error) Encrypt(id.RoomID, event.Type, *event.Content) error WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool - RequestSession(id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) + RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) ResetSession(id.RoomID) Init() error Start() @@ -287,9 +287,9 @@ func (br *Bridge) InitVersion(tag, commit, buildTime string) { var MinSpecVersion = mautrix.SpecV11 -func (br *Bridge) ensureConnection() { +func (br *Bridge) ensureConnection(ctx context.Context) { for { - versions, err := br.Bot.Versions() + versions, err := br.Bot.Versions(ctx) if err != nil { br.ZLog.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...") time.Sleep(10 * time.Second) @@ -315,7 +315,7 @@ func (br *Bridge) ensureConnection() { } } - resp, err := br.Bot.Whoami() + resp, err := br.Bot.Whoami(ctx) if err != nil { if errors.Is(err, mautrix.MUnknownToken) { br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?") @@ -346,7 +346,7 @@ func (br *Bridge) ensureConnection() { const maxRetries = 6 for { txnID = br.Bot.TxnID() - pingResp, err = br.Bot.AppservicePing(br.Config.AppService.ID, txnID) + pingResp, err = br.Bot.AppservicePing(ctx, br.Config.AppService.ID, txnID) if err == nil { break } @@ -385,8 +385,8 @@ func (br *Bridge) ensureConnection() { Msg("Homeserver -> bridge connection works") } -func (br *Bridge) fetchMediaConfig() { - cfg, err := br.Bot.GetMediaConfig() +func (br *Bridge) fetchMediaConfig(ctx context.Context) { + cfg, err := br.Bot.GetMediaConfig(ctx) if err != nil { br.ZLog.Warn().Err(err).Msg("Failed to fetch media config") } else { @@ -394,25 +394,25 @@ func (br *Bridge) fetchMediaConfig() { } } -func (br *Bridge) UpdateBotProfile() { +func (br *Bridge) UpdateBotProfile(ctx context.Context) { br.ZLog.Debug().Msg("Updating bot profile") botConfig := &br.Config.AppService.Bot var err error var mxc id.ContentURI if botConfig.Avatar == "remove" { - err = br.Bot.SetAvatarURL(mxc) + err = br.Bot.SetAvatarURL(ctx, mxc) } else if !botConfig.ParsedAvatar.IsEmpty() { - err = br.Bot.SetAvatarURL(botConfig.ParsedAvatar) + err = br.Bot.SetAvatarURL(ctx, botConfig.ParsedAvatar) } if err != nil { br.ZLog.Warn().Err(err).Msg("Failed to update bot avatar") } if botConfig.Displayname == "remove" { - err = br.Bot.SetDisplayName("") + err = br.Bot.SetDisplayName(ctx, "") } else if len(botConfig.Displayname) > 0 { - err = br.Bot.SetDisplayName(botConfig.Displayname) + err = br.Bot.SetDisplayName(ctx, botConfig.Displayname) } if err != nil { br.ZLog.Warn().Err(err).Msg("Failed to update bot displayname") @@ -420,7 +420,7 @@ func (br *Bridge) UpdateBotProfile() { if br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) && br.BeeperNetworkName != "" { br.ZLog.Debug().Msg("Setting contact info on the appservice bot") - br.Bot.BeeperUpdateProfile(map[string]any{ + br.Bot.BeeperUpdateProfile(ctx, map[string]any{ "com.beeper.bridge.service": br.BeeperServiceName, "com.beeper.bridge.network": br.BeeperNetworkName, "com.beeper.bridge.is_bridge_bot": true, @@ -633,8 +633,10 @@ func (br *Bridge) start() { os.Exit(23) } br.ZLog.Debug().Msg("Checking connection to homeserver") - br.ensureConnection() - go br.fetchMediaConfig() + + ctx := context.Background() + br.ensureConnection(ctx) + go br.fetchMediaConfig(ctx) if br.Crypto != nil { err = br.Crypto.Init() @@ -647,7 +649,7 @@ func (br *Bridge) start() { br.ZLog.Debug().Msg("Starting event processor") br.EventProcessor.Start() - go br.UpdateBotProfile() + go br.UpdateBotProfile(ctx) if br.Crypto != nil { go br.Crypto.Start() } diff --git a/bridge/commands/admin.go b/bridge/commands/admin.go index dde97de7..d07ada1a 100644 --- a/bridge/commands/admin.go +++ b/bridge/commands/admin.go @@ -7,6 +7,7 @@ package commands import ( + "context" "strconv" "maunium.net/go/mautrix/id" @@ -57,7 +58,7 @@ func fnSetPowerLevel(ce *Event) { ce.Reply("**Usage:** `set-pl [user] `") return } - _, err = ce.Portal.MainIntent().SetPowerLevel(ce.RoomID, userID, level) + _, err = ce.Portal.MainIntent().SetPowerLevel(context.Background(), ce.RoomID, userID, level) if err != nil { ce.Reply("Failed to set power levels: %v", err) } diff --git a/bridge/commands/doublepuppet.go b/bridge/commands/doublepuppet.go index 8c2e611e..9501d01f 100644 --- a/bridge/commands/doublepuppet.go +++ b/bridge/commands/doublepuppet.go @@ -6,6 +6,8 @@ package commands +import "context" + var CommandLoginMatrix = &FullHandler{ Func: fnLoginMatrix, Name: "login-matrix", @@ -54,7 +56,7 @@ func fnPingMatrix(ce *Event) { ce.Reply("You are not logged in with your Matrix account.") return } - resp, err := puppet.CustomIntent().Whoami() + resp, err := puppet.CustomIntent().Whoami(context.Background()) if err != nil { ce.Reply("Failed to validate Matrix login: %v", err) } else { diff --git a/bridge/commands/event.go b/bridge/commands/event.go index 0adc9237..24cf2eb9 100644 --- a/bridge/commands/event.go +++ b/bridge/commands/event.go @@ -67,7 +67,7 @@ func (ce *Event) Reply(msg string, args ...interface{}) { func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { content := format.RenderMarkdown(msg, allowMarkdown, allowHTML) content.MsgType = event.MsgNotice - _, err := ce.MainIntent().SendMessageEvent(ce.RoomID, event.EventMessage, content) + _, err := ce.MainIntent().SendMessageEvent(context.Background(), ce.RoomID, event.EventMessage, content) if err != nil { ce.ZLog.Error().Err(err).Msgf("Failed to reply to command") } @@ -75,7 +75,7 @@ func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { // React sends a reaction to the command. func (ce *Event) React(key string) { - _, err := ce.MainIntent().SendReaction(ce.RoomID, ce.EventID, key) + _, err := ce.MainIntent().SendReaction(context.Background(), ce.RoomID, ce.EventID, key) if err != nil { ce.ZLog.Error().Err(err).Msgf("Failed to react to command") } @@ -83,7 +83,7 @@ func (ce *Event) React(key string) { // Redact redacts the command. func (ce *Event) Redact(req ...mautrix.ReqRedact) { - _, err := ce.MainIntent().RedactEvent(ce.RoomID, ce.EventID, req...) + _, err := ce.MainIntent().RedactEvent(context.Background(), ce.RoomID, ce.EventID, req...) if err != nil { ce.ZLog.Error().Err(err).Msgf("Failed to redact command") } @@ -91,7 +91,7 @@ func (ce *Event) Redact(req ...mautrix.ReqRedact) { // MarkRead marks the command event as read. func (ce *Event) MarkRead() { - err := ce.MainIntent().SendReceipt(ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil) + err := ce.MainIntent().SendReceipt(context.Background(), ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil) if err != nil { ce.ZLog.Error().Err(err).Msgf("Failed to mark command as read") } diff --git a/bridge/commands/handler.go b/bridge/commands/handler.go index d158191a..cfed683b 100644 --- a/bridge/commands/handler.go +++ b/bridge/commands/handler.go @@ -7,6 +7,8 @@ package commands import ( + "context" + "maunium.net/go/mautrix/bridge" "maunium.net/go/mautrix/bridge/bridgeconfig" "maunium.net/go/mautrix/event" @@ -76,7 +78,7 @@ func (fh *FullHandler) ShowInHelp(ce *Event) bool { } func (fh *FullHandler) userHasRoomPermission(ce *Event) bool { - levels, err := ce.MainIntent().PowerLevels(ce.RoomID) + levels, err := ce.MainIntent().PowerLevels(context.Background(), ce.RoomID) if err != nil { ce.ZLog.Warn().Err(err).Msg("Failed to check room power levels") ce.Reply("Failed to get room power levels to see if you're allowed to use that command") diff --git a/bridge/crypto.go b/bridge/crypto.go index 065bc017..73e5dbf8 100644 --- a/bridge/crypto.go +++ b/bridge/crypto.go @@ -137,8 +137,9 @@ func (helper *CryptoHelper) Init() error { } func (helper *CryptoHelper) resyncEncryptionInfo() { + ctx := context.Background() log := helper.log.With().Str("action", "resync encryption event").Logger() - rows, err := helper.bridge.DB.Query(`SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) + rows, err := helper.bridge.DB.QueryContext(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) if err != nil { log.Err(err).Msg("Failed to query rooms for resync") return @@ -158,10 +159,10 @@ func (helper *CryptoHelper) resyncEncryptionInfo() { log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms") for _, roomID := range roomIDs { var evt event.EncryptionEventContent - err = helper.client.StateEvent(roomID, event.StateEncryption, "", &evt) + err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt) if err != nil { log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event") - _, err = helper.bridge.DB.Exec(` + _, err = helper.bridge.DB.ExecContext(ctx, ` UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}' `, roomID) if err != nil { @@ -182,7 +183,7 @@ func (helper *CryptoHelper) resyncEncryptionInfo() { Int("max_messages", maxMessages). Interface("content", &evt). Msg("Resynced encryption event") - _, err = helper.bridge.DB.Exec(` + _, err = helper.bridge.DB.ExecContext(ctx, ` UPDATE crypto_megolm_inbound_session SET max_age=$1, max_messages=$2 WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL @@ -223,6 +224,7 @@ func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device } func (helper *CryptoHelper) loginBot() (*mautrix.Client, bool, error) { + ctx := context.Background() deviceID := helper.store.FindDeviceID() if len(deviceID) > 0 { helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database") @@ -230,13 +232,13 @@ func (helper *CryptoHelper) loginBot() (*mautrix.Client, bool, error) { // Create a new client instance with the default AS settings (including as_token), // the Login call will then override the access token in the client. client := helper.bridge.AS.NewMautrixClient(helper.bridge.AS.BotMXID()) - flows, err := client.GetLoginFlows() + flows, err := client.GetLoginFlows(ctx) if err != nil { return nil, deviceID != "", fmt.Errorf("failed to get supported login flows: %w", err) } else if !flows.HasFlow(mautrix.AuthTypeAppservice) { return nil, deviceID != "", fmt.Errorf("homeserver does not support appservice login") } - resp, err := client.Login(&mautrix.ReqLogin{ + resp, err := client.Login(ctx, &mautrix.ReqLogin{ Type: mautrix.AuthTypeAppservice, Identifier: mautrix.UserIdentifier{ Type: mautrix.IdentifierTypeUser, @@ -255,8 +257,9 @@ func (helper *CryptoHelper) loginBot() (*mautrix.Client, bool, error) { } func (helper *CryptoHelper) verifyKeysAreOnServer() { + ctx := context.Background() helper.log.Debug().Msg("Making sure keys are still on server") - resp, err := helper.client.QueryKeys(&mautrix.ReqQueryKeys{ + resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ DeviceKeys: map[id.UserID]mautrix.DeviceIDList{ helper.client.UserID: {helper.client.DeviceID}, }, @@ -333,7 +336,7 @@ func (helper *CryptoHelper) Reset(startAfterReset bool) { helper.log.Debug().Msg("Crypto syncer stopped, clearing database") helper.clearDatabase() helper.log.Debug().Msg("Crypto database cleared, logging out of all sessions") - _, err := helper.client.LogoutAll() + _, err := helper.client.LogoutAll(context.Background()) if err != nil { helper.log.Warn().Err(err).Msg("Failed to log out all devices") } @@ -395,13 +398,13 @@ func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.Sender return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout) } -func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { +func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { helper.lock.RLock() defer helper.lock.RUnlock() if deviceID == "" { deviceID = "*" } - err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}}) + err := helper.mach.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}}) if err != nil { helper.log.Warn().Err(err). Str("user_id", userID.String()). diff --git a/bridge/doublepuppet.go b/bridge/doublepuppet.go index 7ddc1989..35903efd 100644 --- a/bridge/doublepuppet.go +++ b/bridge/doublepuppet.go @@ -7,6 +7,7 @@ package bridge import ( + "context" "crypto/hmac" "crypto/sha512" "encoding/hex" @@ -26,7 +27,7 @@ type doublePuppetUtil struct { log zerolog.Logger } -func (dp *doublePuppetUtil) newClient(mxid id.UserID, accessToken string) (*mautrix.Client, error) { +func (dp *doublePuppetUtil) newClient(ctx context.Context, mxid id.UserID, accessToken string) (*mautrix.Client, error) { _, homeserver, err := mxid.Parse() if err != nil { return nil, err @@ -36,7 +37,7 @@ func (dp *doublePuppetUtil) newClient(mxid id.UserID, accessToken string) (*maut if homeserver == dp.br.AS.HomeserverDomain { homeserverURL = "" } else if dp.br.Config.Bridge.GetDoublePuppetConfig().AllowDiscovery { - resp, err := mautrix.DiscoverClientAPI(homeserver) + resp, err := mautrix.DiscoverClientAPI(ctx, homeserver) if err != nil { return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err) } @@ -53,8 +54,8 @@ func (dp *doublePuppetUtil) newClient(mxid id.UserID, accessToken string) (*maut return dp.br.AS.NewExternalMautrixClient(mxid, accessToken, homeserverURL) } -func (dp *doublePuppetUtil) newIntent(mxid id.UserID, accessToken string) (*appservice.IntentAPI, error) { - client, err := dp.newClient(mxid, accessToken) +func (dp *doublePuppetUtil) newIntent(ctx context.Context, mxid id.UserID, accessToken string) (*appservice.IntentAPI, error) { + client, err := dp.newClient(ctx, mxid, accessToken) if err != nil { return nil, err } @@ -67,9 +68,9 @@ func (dp *doublePuppetUtil) newIntent(mxid id.UserID, accessToken string) (*apps return ia, nil } -func (dp *doublePuppetUtil) autoLogin(mxid id.UserID, loginSecret string) (string, error) { +func (dp *doublePuppetUtil) autoLogin(ctx context.Context, mxid id.UserID, loginSecret string) (string, error) { dp.log.Debug().Str("user_id", mxid.String()).Msg("Logging into user account with shared secret") - client, err := dp.newClient(mxid, "") + client, err := dp.newClient(ctx, mxid, "") if err != nil { return "", fmt.Errorf("failed to create mautrix client to log in: %v", err) } @@ -83,7 +84,7 @@ func (dp *doublePuppetUtil) autoLogin(mxid id.UserID, loginSecret string) (strin client.AccessToken = dp.br.AS.Registration.AppToken req.Type = mautrix.AuthTypeAppservice } else { - loginFlows, err := client.GetLoginFlows() + loginFlows, err := client.GetLoginFlows(ctx) if err != nil { return "", fmt.Errorf("failed to get supported login flows: %w", err) } @@ -101,7 +102,7 @@ func (dp *doublePuppetUtil) autoLogin(mxid id.UserID, loginSecret string) (strin return "", fmt.Errorf("no supported auth types for shared secret auth found") } } - resp, err := client.Login(&req) + resp, err := client.Login(ctx, &req) if err != nil { return "", err } @@ -122,18 +123,19 @@ func (dp *doublePuppetUtil) Setup(mxid id.UserID, savedAccessToken string, relog err = ErrNoMXID return } + ctx := context.Background() _, homeserver, _ := mxid.Parse() loginSecret, hasSecret := dp.br.Config.Bridge.GetDoublePuppetConfig().SharedSecretMap[homeserver] // Special case appservice: prefix to not login and use it as an as_token directly. if hasSecret && strings.HasPrefix(loginSecret, asTokenModePrefix) { - intent, err = dp.newIntent(mxid, strings.TrimPrefix(loginSecret, asTokenModePrefix)) + intent, err = dp.newIntent(ctx, mxid, strings.TrimPrefix(loginSecret, asTokenModePrefix)) if err != nil { return } intent.SetAppServiceUserID = true if savedAccessToken != useConfigASToken { var resp *mautrix.RespWhoami - resp, err = intent.Whoami() + resp, err = intent.Whoami(ctx) if err == nil && resp.UserID != mxid { err = ErrMismatchingMXID } @@ -142,7 +144,7 @@ func (dp *doublePuppetUtil) Setup(mxid id.UserID, savedAccessToken string, relog } if savedAccessToken == "" || savedAccessToken == useConfigASToken { if reloginOnFail && hasSecret { - savedAccessToken, err = dp.autoLogin(mxid, loginSecret) + savedAccessToken, err = dp.autoLogin(ctx, mxid, loginSecret) } else { err = ErrNoAccessToken } @@ -150,15 +152,15 @@ func (dp *doublePuppetUtil) Setup(mxid id.UserID, savedAccessToken string, relog return } } - intent, err = dp.newIntent(mxid, savedAccessToken) + intent, err = dp.newIntent(ctx, mxid, savedAccessToken) if err != nil { return } var resp *mautrix.RespWhoami - resp, err = intent.Whoami() + resp, err = intent.Whoami(ctx) if err != nil { if reloginOnFail && hasSecret && errors.Is(err, mautrix.MUnknownToken) { - intent.AccessToken, err = dp.autoLogin(mxid, loginSecret) + intent.AccessToken, err = dp.autoLogin(ctx, mxid, loginSecret) if err == nil { newAccessToken = intent.AccessToken } diff --git a/bridge/matrix.go b/bridge/matrix.go index 3196af60..f9a86d80 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -87,7 +87,7 @@ func (mx *MatrixHandler) HandleEncryption(evt *event.Event) { Msg("Encryption was enabled in room") portal.MarkEncrypted() if portal.IsPrivateChat() { - err := mx.as.BotIntent().EnsureJoined(evt.RoomID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client}) + err := mx.as.BotIntent().EnsureJoined(context.Background(), evt.RoomID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client}) if err != nil { mx.log.Err(err). Str("room_id", evt.RoomID.String()). @@ -99,32 +99,32 @@ func (mx *MatrixHandler) HandleEncryption(evt *event.Event) { func (mx *MatrixHandler) joinAndCheckMembers(ctx context.Context, evt *event.Event, intent *appservice.IntentAPI) *mautrix.RespJoinedMembers { log := zerolog.Ctx(ctx) - resp, err := intent.JoinRoomByID(evt.RoomID) + resp, err := intent.JoinRoomByID(ctx, evt.RoomID) if err != nil { log.Warn().Err(err).Msg("Failed to join room with invite") return nil } - members, err := intent.JoinedMembers(resp.RoomID) + members, err := intent.JoinedMembers(ctx, resp.RoomID) if err != nil { log.Warn().Err(err).Msg("Failed to get members in room after accepting invite, leaving room") - _, _ = intent.LeaveRoom(resp.RoomID) + _, _ = intent.LeaveRoom(ctx, resp.RoomID) return nil } if len(members.Joined) < 2 { log.Debug().Msg("Leaving empty room after accepting invite") - _, _ = intent.LeaveRoom(resp.RoomID) + _, _ = intent.LeaveRoom(ctx, resp.RoomID) return nil } return members } -func (mx *MatrixHandler) sendNoticeWithMarkdown(roomID id.RoomID, message string) (*mautrix.RespSendEvent, error) { +func (mx *MatrixHandler) sendNoticeWithMarkdown(ctx context.Context, roomID id.RoomID, message string) (*mautrix.RespSendEvent, error) { intent := mx.as.BotIntent() content := format.RenderMarkdown(message, true, false) content.MsgType = event.MsgNotice - return intent.SendMessageEvent(roomID, event.EventMessage, content) + return intent.SendMessageEvent(ctx, roomID, event.EventMessage, content) } func (mx *MatrixHandler) HandleBotInvite(ctx context.Context, evt *event.Event) { @@ -141,31 +141,31 @@ func (mx *MatrixHandler) HandleBotInvite(ctx context.Context, evt *event.Event) } if user.GetPermissionLevel() < bridgeconfig.PermissionLevelUser { - _, _ = intent.SendNotice(evt.RoomID, "You are not whitelisted to use this bridge.\n"+ + _, _ = intent.SendNotice(ctx, evt.RoomID, "You are not whitelisted to use this bridge.\n"+ "If you're the owner of this bridge, see the bridge.permissions section in your config file.") - _, _ = intent.LeaveRoom(evt.RoomID) + _, _ = intent.LeaveRoom(ctx, evt.RoomID) return } texts := mx.bridge.Config.Bridge.GetManagementRoomTexts() - _, _ = mx.sendNoticeWithMarkdown(evt.RoomID, texts.Welcome) + _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.Welcome) if len(members.Joined) == 2 && (len(user.GetManagementRoomID()) == 0 || evt.Content.AsMember().IsDirect) { user.SetManagementRoom(evt.RoomID) - _, _ = intent.SendNotice(user.GetManagementRoomID(), "This room has been registered as your bridge management/status room.") + _, _ = intent.SendNotice(ctx, user.GetManagementRoomID(), "This room has been registered as your bridge management/status room.") zerolog.Ctx(ctx).Debug().Msg("Registered room as management room with inviter") } if evt.RoomID == user.GetManagementRoomID() { if user.IsLoggedIn() { - _, _ = mx.sendNoticeWithMarkdown(evt.RoomID, texts.WelcomeConnected) + _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.WelcomeConnected) } else { - _, _ = mx.sendNoticeWithMarkdown(evt.RoomID, texts.WelcomeUnconnected) + _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.WelcomeUnconnected) } additionalHelp := texts.AdditionalHelp if len(additionalHelp) > 0 { - _, _ = mx.sendNoticeWithMarkdown(evt.RoomID, additionalHelp) + _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, additionalHelp) } } } @@ -176,7 +176,7 @@ func (mx *MatrixHandler) HandleGhostInvite(ctx context.Context, evt *event.Event if inviter.GetPermissionLevel() < bridgeconfig.PermissionLevelUser { log.Debug().Msg("Rejecting invite: inviter is not whitelisted") - _, err := intent.LeaveRoom(evt.RoomID, &mautrix.ReqLeave{ + _, err := intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ Reason: "You're not whitelisted to use this bridge", }) if err != nil { @@ -185,7 +185,7 @@ func (mx *MatrixHandler) HandleGhostInvite(ctx context.Context, evt *event.Event return } else if !inviter.IsLoggedIn() { log.Debug().Msg("Rejecting invite: inviter is not logged in") - _, err := intent.LeaveRoom(evt.RoomID, &mautrix.ReqLeave{ + _, err := intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ Reason: "You're not logged into this bridge", }) if err != nil { @@ -199,11 +199,11 @@ func (mx *MatrixHandler) HandleGhostInvite(ctx context.Context, evt *event.Event return } var createEvent event.CreateEventContent - if err := intent.StateEvent(evt.RoomID, event.StateCreate, "", &createEvent); err != nil { + if err := intent.StateEvent(ctx, evt.RoomID, event.StateCreate, "", &createEvent); err != nil { log.Warn().Err(err).Msg("Failed to check m.room.create event in room") } else if createEvent.Type != "" { log.Warn().Str("room_type", string(createEvent.Type)).Msg("Non-standard room type, leaving room") - _, err = intent.LeaveRoom(evt.RoomID, &mautrix.ReqLeave{ + _, err = intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ Reason: "Unsupported room type", }) if err != nil { @@ -225,10 +225,10 @@ func (mx *MatrixHandler) HandleGhostInvite(ctx context.Context, evt *event.Event mx.bridge.Child.CreatePrivatePortal(evt.RoomID, inviter, ghost) } else if !hasBridgeBot { log.Debug().Msg("Leaving multi-user room after accepting invite") - _, _ = intent.SendNotice(evt.RoomID, "Please invite the bridge bot first if you want to bridge to a remote chat.") - _, _ = intent.LeaveRoom(evt.RoomID) + _, _ = intent.SendNotice(ctx, evt.RoomID, "Please invite the bridge bot first if you want to bridge to a remote chat.") + _, _ = intent.LeaveRoom(ctx, evt.RoomID) } else { - _, _ = intent.SendNotice(evt.RoomID, "This puppet will remain inactive until this room is bridged to a remote chat.") + _, _ = intent.SendNotice(ctx, evt.RoomID, "This puppet will remain inactive until this room is bridged to a remote chat.") } } @@ -237,12 +237,12 @@ func (mx *MatrixHandler) HandleMembership(evt *event.Event) { return } defer mx.TrackEventDuration(evt.Type)() + ctx := context.Background() if mx.bridge.Crypto != nil { mx.bridge.Crypto.HandleMemberEvent(evt) } - ctx := context.Background() log := mx.log.With(). Str("sender", evt.Sender.String()). Str("target", evt.GetStateKey()). @@ -358,7 +358,7 @@ func (mx *MatrixHandler) sendCryptoStatusError(ctx context.Context, evt *event.E if !isFinal { statusEvent.Status = event.MessageStatusPending } - _, sendErr := mx.bridge.Bot.SendMessageEvent(evt.RoomID, event.BeeperMessageStatus, statusEvent) + _, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, statusEvent) if sendErr != nil { zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to send message status event") } @@ -377,7 +377,7 @@ func (mx *MatrixHandler) sendCryptoStatusError(ctx context.Context, evt *event.E } else if ok && relatable.OptionalGetRelatesTo().GetThreadParent() != "" { update.GetRelatesTo().SetThread(relatable.OptionalGetRelatesTo().GetThreadParent(), evt.ID) } - resp, sendErr := mx.bridge.Bot.SendMessageEvent(evt.RoomID, event.EventMessage, &update) + resp, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.EventMessage, &update) if sendErr != nil { zerolog.Ctx(ctx).Error().Err(sendErr).Msg("Failed to send decryption error notice") } else if resp != nil { @@ -471,7 +471,7 @@ func (mx *MatrixHandler) postDecrypt(ctx context.Context, original, decrypted *e decrypted.Mautrix.DecryptionDuration = duration mx.bridge.EventProcessor.Dispatch(decrypted) if errorEventID != "" { - _, _ = mx.bridge.Bot.RedactEvent(decrypted.RoomID, errorEventID) + _, _ = mx.bridge.Bot.RedactEvent(ctx, decrypted.RoomID, errorEventID) } } @@ -526,7 +526,7 @@ func (mx *MatrixHandler) waitLongerForSession(ctx context.Context, evt *event.Ev Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())). Msg("Couldn't find session, requesting keys and waiting longer...") - go mx.bridge.Crypto.RequestSession(evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) + go mx.bridge.Crypto.RequestSession(context.Background(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) errorEventID := mx.sendCryptoStatusError(ctx, evt, "", fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), 1, false) if !mx.bridge.Crypto.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { @@ -587,7 +587,7 @@ func (mx *MatrixHandler) HandleMessage(evt *event.Event) { }, Status: event.MessageStatusSuccess, } - _, sendErr := mx.bridge.Bot.SendMessageEvent(evt.RoomID, event.BeeperMessageStatus, statusEvent) + _, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, statusEvent) if sendErr != nil { log.Warn().Err(sendErr).Msg("Failed to send message status event for command") } diff --git a/client.go b/client.go index 17720026..0aff8734 100644 --- a/client.go +++ b/client.go @@ -30,7 +30,7 @@ type CryptoHelper interface { Encrypt(id.RoomID, event.Type, any) (*event.EncryptedEventContent, error) Decrypt(*event.Event) (*event.Event, error) WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool - RequestSession(id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) + RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) Init() error } @@ -100,14 +100,14 @@ type IdentityServerInfo struct { // DiscoverClientAPI resolves the client API URL from a Matrix server name. // Use ParseUserID to extract the server name from a user ID. // https://spec.matrix.org/v1.2/client-server-api/#server-discovery -func DiscoverClientAPI(serverName string) (*ClientWellKnown, error) { +func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown, error) { wellKnownURL := url.URL{ Scheme: "https", Host: serverName, Path: "/.well-known/matrix/client", } - req, err := http.NewRequest("GET", wellKnownURL.String(), nil) + req, err := http.NewRequestWithContext(ctx, "GET", wellKnownURL.String(), nil) if err != nil { return nil, err } @@ -174,16 +174,16 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { // We will keep syncing until the syncing state changes. Either because // Sync is called or StopSync is called. syncingID := cli.incrementSyncingID() - nextBatch := cli.Store.LoadNextBatch(cli.UserID) - filterID := cli.Store.LoadFilterID(cli.UserID) + nextBatch := cli.Store.LoadNextBatch(ctx, cli.UserID) + filterID := cli.Store.LoadFilterID(ctx, cli.UserID) if filterID == "" { filterJSON := cli.Syncer.GetFilterJSON(cli.UserID) - resFilter, err := cli.CreateFilter(filterJSON) + resFilter, err := cli.CreateFilter(ctx, filterJSON) if err != nil { return err } filterID = resFilter.FilterID - cli.Store.SaveFilterID(cli.UserID, filterID) + cli.Store.SaveFilterID(ctx, cli.UserID, filterID) } lastSuccessfulSync := time.Now().Add(-cli.StreamSyncMinAge - 1*time.Hour) for { @@ -192,13 +192,12 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { cli.Log.Debug().Msg("Last sync is old, will stream next response") streamResp = true } - resSync, err := cli.FullSyncRequest(ReqSync{ + resSync, err := cli.FullSyncRequest(ctx, ReqSync{ Timeout: 30000, Since: nextBatch, FilterID: filterID, FullState: false, SetPresence: cli.SyncPresence, - Context: ctx, StreamResponse: streamResp, }) if err != nil { @@ -228,7 +227,7 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { // Save the token now *before* processing it. This means it's possible // to not process some events, but it means that we won't get constantly stuck processing // a malformed/buggy event which keeps making us panic. - cli.Store.SaveNextBatch(cli.UserID, resSync.NextBatch) + cli.Store.SaveNextBatch(ctx, cli.UserID, resSync.NextBatch) if err = cli.Syncer.ProcessResponse(resSync, nextBatch); err != nil { return err } @@ -306,8 +305,8 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er } } -func (cli *Client) MakeRequest(method string, httpURL string, reqBody interface{}, resBody interface{}) ([]byte, error) { - return cli.MakeFullRequest(FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody}) +func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL string, reqBody interface{}, resBody interface{}) ([]byte, error) { + return cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody}) } type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) @@ -321,7 +320,6 @@ type FullRequest struct { RequestBody io.Reader RequestLength int64 ResponseJSON interface{} - Context context.Context MaxAttempts int SensitiveContent bool Handler ClientResponseHandler @@ -331,12 +329,9 @@ type FullRequest struct { var requestID int32 var logSensitiveContent = os.Getenv("MAUTRIX_LOG_SENSITIVE_CONTENT") == "yes" -func (params *FullRequest) compileRequest() (*http.Request, error) { +func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, error) { var logBody any reqBody := params.RequestBody - if params.Context == nil { - params.Context = context.Background() - } if params.RequestJSON != nil { jsonStr, err := json.Marshal(params.RequestJSON) if err != nil { @@ -363,7 +358,6 @@ func (params *FullRequest) compileRequest() (*http.Request, error) { reqBody = bytes.NewReader([]byte("{}")) } reqID := atomic.AddInt32(&requestID, 1) - ctx := params.Context logger := zerolog.Ctx(ctx) if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger { logger = params.Logger @@ -398,14 +392,14 @@ func (params *FullRequest) compileRequest() (*http.Request, error) { // Returns the HTTP body as bytes on 2xx with a nil error. Returns an error if the response is not 2xx along // with the HTTP body bytes if it got that far. This error is an HTTPError which includes the returned // HTTP status code and possibly a RespError as the WrappedError, if the HTTP body could be decoded as a RespError. -func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) { +func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]byte, error) { if params.MaxAttempts == 0 { params.MaxAttempts = 1 + cli.DefaultHTTPRetries } if params.Logger == nil { params.Logger = &cli.Log } - req, err := params.compileRequest() + req, err := params.compileRequest(ctx) if err != nil { return nil, err } @@ -567,39 +561,37 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof } // Whoami gets the user ID of the current user. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3accountwhoami -func (cli *Client) Whoami() (resp *RespWhoami, err error) { +func (cli *Client) Whoami(ctx context.Context) (resp *RespWhoami, err error) { + urlPath := cli.BuildClientURL("v3", "account", "whoami") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } // CreateFilter makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter -func (cli *Client) CreateFilter(filter *Filter) (resp *RespCreateFilter, err error) { +func (cli *Client) CreateFilter(ctx context.Context, filter *Filter) (resp *RespCreateFilter, err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "filter") - _, err = cli.MakeRequest("POST", urlPath, filter, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, filter, &resp) return } // SyncRequest makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3sync -func (cli *Client) SyncRequest(timeout int, since, filterID string, fullState bool, setPresence event.Presence, ctx context.Context) (resp *RespSync, err error) { - return cli.FullSyncRequest(ReqSync{ +func (cli *Client) SyncRequest(ctx context.Context, timeout int, since, filterID string, fullState bool, setPresence event.Presence) (resp *RespSync, err error) { + return cli.FullSyncRequest(ctx, ReqSync{ Timeout: timeout, Since: since, FilterID: filterID, FullState: fullState, SetPresence: setPresence, - Context: ctx, }) } type ReqSync struct { - Timeout int - Since string - FilterID string - FullState bool - SetPresence event.Presence - - Context context.Context + Timeout int + Since string + FilterID string + FullState bool + SetPresence event.Presence StreamResponse bool } @@ -623,13 +615,12 @@ func (req *ReqSync) BuildQuery() map[string]string { } // FullSyncRequest makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3sync -func (cli *Client) FullSyncRequest(req ReqSync) (resp *RespSync, err error) { +func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *RespSync, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "sync"}, req.BuildQuery()) fullReq := FullRequest{ Method: http.MethodGet, URL: urlPath, ResponseJSON: &resp, - Context: req.Context, // We don't want automatic retries for SyncRequest, the Sync() wrapper handles those. MaxAttempts: 1, } @@ -637,7 +628,7 @@ func (cli *Client) FullSyncRequest(req ReqSync) (resp *RespSync, err error) { fullReq.Handler = streamResponse } start := time.Now() - _, err = cli.MakeFullRequest(fullReq) + _, err = cli.MakeFullRequest(ctx, fullReq) duration := time.Now().Sub(start) timeout := time.Duration(req.Timeout) * time.Millisecond buffer := 10 * time.Second @@ -645,7 +636,7 @@ func (cli *Client) FullSyncRequest(req ReqSync) (resp *RespSync, err error) { buffer = 1 * time.Minute } if err == nil && duration > timeout+buffer { - cli.cliOrContextLog(fullReq.Context).Warn(). + cli.cliOrContextLog(ctx).Warn(). Str("since", req.Since). Dur("duration", duration). Dur("timeout", timeout). @@ -676,18 +667,18 @@ func (cli *Client) FullSyncRequest(req ReqSync) (resp *RespSync, err error) { // } else { // // Username is available // } -func (cli *Client) RegisterAvailable(username string) (resp *RespRegisterAvailable, err error) { +func (cli *Client) RegisterAvailable(ctx context.Context, username string) (resp *RespRegisterAvailable, err error) { u := cli.BuildURLWithQuery(ClientURLPath{"v3", "register", "available"}, map[string]string{"username": username}) - _, err = cli.MakeRequest(http.MethodGet, u, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) if err == nil && !resp.Available { err = fmt.Errorf(`request returned OK status without "available": true`) } return } -func (cli *Client) register(url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { +func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { var bodyBytes []byte - bodyBytes, err = cli.MakeFullRequest(FullRequest{ + bodyBytes, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: url, RequestJSON: req, @@ -709,21 +700,21 @@ func (cli *Client) register(url string, req *ReqRegister) (resp *RespRegister, u // Register makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register // // Registers with kind=user. For kind=guest, see RegisterGuest. -func (cli *Client) Register(req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { +func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { u := cli.BuildClientURL("v3", "register") - return cli.register(u, req) + return cli.register(ctx, u, req) } // RegisterGuest makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register // with kind=guest. // // For kind=user, see Register. -func (cli *Client) RegisterGuest(req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { +func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { query := map[string]string{ "kind": "guest", } u := cli.BuildURLWithQuery(ClientURLPath{"v3", "register"}, query) - return cli.register(u, req) + return cli.register(ctx, u, req) } // RegisterDummy performs m.login.dummy registration according to https://spec.matrix.org/v1.2/client-server-api/#dummy-auth @@ -741,8 +732,8 @@ func (cli *Client) RegisterGuest(req *ReqRegister) (*RespRegister, *RespUserInte // panic(err) // } // token := res.AccessToken -func (cli *Client) RegisterDummy(req *ReqRegister) (*RespRegister, error) { - res, uia, err := cli.Register(req) +func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRegister, error) { + res, uia, err := cli.Register(ctx, req) if err != nil && uia == nil { return nil, err } else if uia == nil { @@ -751,7 +742,7 @@ func (cli *Client) RegisterDummy(req *ReqRegister) (*RespRegister, error) { return nil, errors.New("server does not support m.login.dummy") } req.Auth = BaseAuthData{Type: AuthTypeDummy, Session: uia.Session} - res, _, err = cli.Register(req) + res, _, err = cli.Register(ctx, req) if err != nil { return nil, err } @@ -759,15 +750,15 @@ func (cli *Client) RegisterDummy(req *ReqRegister) (*RespRegister, error) { } // GetLoginFlows fetches the login flows that the homeserver supports using https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3login -func (cli *Client) GetLoginFlows() (resp *RespLoginFlows, err error) { +func (cli *Client) GetLoginFlows(ctx context.Context) (resp *RespLoginFlows, err error) { urlPath := cli.BuildClientURL("v3", "login") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } // Login a user to the homeserver according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3login -func (cli *Client) Login(req *ReqLogin) (resp *RespLogin, err error) { - _, err = cli.MakeFullRequest(FullRequest{ +func (cli *Client) Login(ctx context.Context, req *ReqLogin) (resp *RespLogin, err error) { + _, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v3", "login"), RequestJSON: req, @@ -803,31 +794,31 @@ func (cli *Client) Login(req *ReqLogin) (resp *RespLogin, err error) { // Logout the current user. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3logout // This does not clear the credentials from the client instance. See ClearCredentials() instead. -func (cli *Client) Logout() (resp *RespLogout, err error) { +func (cli *Client) Logout(ctx context.Context) (resp *RespLogout, err error) { urlPath := cli.BuildClientURL("v3", "logout") - _, err = cli.MakeRequest("POST", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, nil, &resp) return } // LogoutAll logs out all the devices of the current user. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3logoutall // This does not clear the credentials from the client instance. See ClearCredentials() instead. -func (cli *Client) LogoutAll() (resp *RespLogout, err error) { +func (cli *Client) LogoutAll(ctx context.Context) (resp *RespLogout, err error) { urlPath := cli.BuildClientURL("v3", "logout", "all") - _, err = cli.MakeRequest("POST", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, nil, &resp) return } // Versions returns the list of supported Matrix versions on this homeserver. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientversions -func (cli *Client) Versions() (resp *RespVersions, err error) { +func (cli *Client) Versions(ctx context.Context) (resp *RespVersions, err error) { urlPath := cli.BuildClientURL("versions") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } // Capabilities returns capabilities on this homeserver. See https://spec.matrix.org/v1.3/client-server-api/#capabilities-negotiation -func (cli *Client) Capabilities() (resp *RespCapabilities, err error) { +func (cli *Client) Capabilities(ctx context.Context) (resp *RespCapabilities, err error) { urlPath := cli.BuildClientURL("v3", "capabilities") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } @@ -835,7 +826,7 @@ func (cli *Client) Capabilities() (resp *RespCapabilities, err error) { // // If serverName is specified, this will be added as a query param to instruct the homeserver to join via that server. If content is specified, it will // be JSON encoded and used as the request body. -func (cli *Client) JoinRoom(roomIDorAlias, serverName string, content interface{}) (resp *RespJoinRoom, err error) { +func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName string, content interface{}) (resp *RespJoinRoom, err error) { var urlPath string if serverName != "" { urlPath = cli.BuildURLWithQuery(ClientURLPath{"v3", "join", roomIDorAlias}, map[string]string{ @@ -844,7 +835,7 @@ func (cli *Client) JoinRoom(roomIDorAlias, serverName string, content interface{ } else { urlPath = cli.BuildClientURL("v3", "join", roomIDorAlias) } - _, err = cli.MakeRequest("POST", urlPath, content, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, content, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin) } @@ -855,50 +846,50 @@ func (cli *Client) JoinRoom(roomIDorAlias, serverName string, content interface{ // // Unlike JoinRoom, this method can only be used to join rooms that the server already knows about. // It's mostly intended for bridges and other things where it's already certain that the server is in the room. -func (cli *Client) JoinRoomByID(roomID id.RoomID) (resp *RespJoinRoom, err error) { - _, err = cli.MakeRequest("POST", cli.BuildClientURL("v3", "rooms", roomID, "join"), nil, &resp) +func (cli *Client) JoinRoomByID(ctx context.Context, roomID id.RoomID) (resp *RespJoinRoom, err error) { + _, err = cli.MakeRequest(ctx, "POST", cli.BuildClientURL("v3", "rooms", roomID, "join"), nil, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin) } return } -func (cli *Client) GetProfile(mxid id.UserID) (resp *RespUserProfile, err error) { +func (cli *Client) GetProfile(ctx context.Context, mxid id.UserID) (resp *RespUserProfile, err error) { urlPath := cli.BuildClientURL("v3", "profile", mxid) - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } // GetDisplayName returns the display name of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname -func (cli *Client) GetDisplayName(mxid id.UserID) (resp *RespUserDisplayName, err error) { +func (cli *Client) GetDisplayName(ctx context.Context, mxid id.UserID) (resp *RespUserDisplayName, err error) { urlPath := cli.BuildClientURL("v3", "profile", mxid, "displayname") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } // GetOwnDisplayName returns the user's display name. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname -func (cli *Client) GetOwnDisplayName() (resp *RespUserDisplayName, err error) { - return cli.GetDisplayName(cli.UserID) +func (cli *Client) GetOwnDisplayName(ctx context.Context) (resp *RespUserDisplayName, err error) { + return cli.GetDisplayName(ctx, cli.UserID) } // SetDisplayName sets the user's profile display name. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3profileuseriddisplayname -func (cli *Client) SetDisplayName(displayName string) (err error) { +func (cli *Client) SetDisplayName(ctx context.Context, displayName string) (err error) { urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, "displayname") s := struct { DisplayName string `json:"displayname"` }{displayName} - _, err = cli.MakeRequest("PUT", urlPath, &s, nil) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, &s, nil) return } // GetAvatarURL gets the avatar URL of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseridavatar_url -func (cli *Client) GetAvatarURL(mxid id.UserID) (url id.ContentURI, err error) { +func (cli *Client) GetAvatarURL(ctx context.Context, mxid id.UserID) (url id.ContentURI, err error) { urlPath := cli.BuildClientURL("v3", "profile", mxid, "avatar_url") s := struct { AvatarURL id.ContentURI `json:"avatar_url"` }{} - _, err = cli.MakeRequest("GET", urlPath, nil, &s) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &s) if err != nil { return } @@ -907,17 +898,17 @@ func (cli *Client) GetAvatarURL(mxid id.UserID) (url id.ContentURI, err error) { } // GetOwnAvatarURL gets the user's avatar URL. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseridavatar_url -func (cli *Client) GetOwnAvatarURL() (url id.ContentURI, err error) { - return cli.GetAvatarURL(cli.UserID) +func (cli *Client) GetOwnAvatarURL(ctx context.Context) (url id.ContentURI, err error) { + return cli.GetAvatarURL(ctx, cli.UserID) } // SetAvatarURL sets the user's avatar URL. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3profileuseridavatar_url -func (cli *Client) SetAvatarURL(url id.ContentURI) (err error) { +func (cli *Client) SetAvatarURL(ctx context.Context, url id.ContentURI) (err error) { urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, "avatar_url") s := struct { AvatarURL string `json:"avatar_url"` }{url.String()} - _, err = cli.MakeRequest("PUT", urlPath, &s, nil) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, &s, nil) if err != nil { return err } @@ -926,23 +917,23 @@ func (cli *Client) SetAvatarURL(url id.ContentURI) (err error) { } // BeeperUpdateProfile sets custom fields in the user's profile. -func (cli *Client) BeeperUpdateProfile(data map[string]any) (err error) { +func (cli *Client) BeeperUpdateProfile(ctx context.Context, data map[string]any) (err error) { urlPath := cli.BuildClientURL("v3", "profile", cli.UserID) - _, err = cli.MakeRequest("PATCH", urlPath, &data, nil) + _, err = cli.MakeRequest(ctx, "PATCH", urlPath, &data, nil) return } // GetAccountData gets the user's account data of this type. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3useruseridaccount_datatype -func (cli *Client) GetAccountData(name string, output interface{}) (err error) { +func (cli *Client) GetAccountData(ctx context.Context, name string, output interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "account_data", name) - _, err = cli.MakeRequest("GET", urlPath, nil, output) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, output) return } // SetAccountData sets the user's account data of this type. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridaccount_datatype -func (cli *Client) SetAccountData(name string, data interface{}) (err error) { +func (cli *Client) SetAccountData(ctx context.Context, name string, data interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "account_data", name) - _, err = cli.MakeRequest("PUT", urlPath, data, nil) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, data, nil) if err != nil { return err } @@ -951,16 +942,16 @@ func (cli *Client) SetAccountData(name string, data interface{}) (err error) { } // GetRoomAccountData gets the user's account data of this type in a specific room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridaccount_datatype -func (cli *Client) GetRoomAccountData(roomID id.RoomID, name string, output interface{}) (err error) { +func (cli *Client) GetRoomAccountData(ctx context.Context, roomID id.RoomID, name string, output interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "account_data", name) - _, err = cli.MakeRequest("GET", urlPath, nil, output) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, output) return } // SetRoomAccountData sets the user's account data of this type in a specific room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridroomsroomidaccount_datatype -func (cli *Client) SetRoomAccountData(roomID id.RoomID, name string, data interface{}) (err error) { +func (cli *Client) SetRoomAccountData(ctx context.Context, roomID id.RoomID, name string, data interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "account_data", name) - _, err = cli.MakeRequest("PUT", urlPath, data, nil) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, data, nil) if err != nil { return err } @@ -979,7 +970,7 @@ type ReqSendEvent struct { // SendMessageEvent sends a message event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidsendeventtypetxnid // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. -func (cli *Client) SendMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { +func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { var req ReqSendEvent if len(extra) > 0 { req = extra[0] @@ -1011,15 +1002,15 @@ func (cli *Client) SendMessageEvent(roomID id.RoomID, eventType event.Type, cont urlData := ClientURLPath{"v3", "rooms", roomID, "send", eventType.String(), txnID} urlPath := cli.BuildURLWithQuery(urlData, queryParams) - _, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp) return } // SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. -func (cli *Client) SendStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) { +func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) - _, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp) if err == nil && cli.StateStore != nil { cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON) } @@ -1028,11 +1019,11 @@ func (cli *Client) SendStateEvent(roomID id.RoomID, eventType event.Type, stateK // SendMassagedStateEvent sends a state event into a room with a custom timestamp. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. -func (cli *Client) SendMassagedStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) { +func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{ "ts": strconv.FormatInt(ts, 10), }) - _, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp) if err == nil && cli.StateStore != nil { cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON) } @@ -1041,8 +1032,8 @@ func (cli *Client) SendMassagedStateEvent(roomID id.RoomID, eventType event.Type // SendText sends an m.room.message event into the given room with a msgtype of m.text // See https://spec.matrix.org/v1.2/client-server-api/#mtext -func (cli *Client) SendText(roomID id.RoomID, text string) (*RespSendEvent, error) { - return cli.SendMessageEvent(roomID, event.EventMessage, &event.MessageEventContent{ +func (cli *Client) SendText(ctx context.Context, roomID id.RoomID, text string) (*RespSendEvent, error) { + return cli.SendMessageEvent(ctx, roomID, event.EventMessage, &event.MessageEventContent{ MsgType: event.MsgText, Body: text, }) @@ -1050,15 +1041,15 @@ func (cli *Client) SendText(roomID id.RoomID, text string) (*RespSendEvent, erro // SendNotice sends an m.room.message event into the given room with a msgtype of m.notice // See https://spec.matrix.org/v1.2/client-server-api/#mnotice -func (cli *Client) SendNotice(roomID id.RoomID, text string) (*RespSendEvent, error) { - return cli.SendMessageEvent(roomID, event.EventMessage, &event.MessageEventContent{ +func (cli *Client) SendNotice(ctx context.Context, roomID id.RoomID, text string) (*RespSendEvent, error) { + return cli.SendMessageEvent(ctx, roomID, event.EventMessage, &event.MessageEventContent{ MsgType: event.MsgNotice, Body: text, }) } -func (cli *Client) SendReaction(roomID id.RoomID, eventID id.EventID, reaction string) (*RespSendEvent, error) { - return cli.SendMessageEvent(roomID, event.EventReaction, &event.ReactionEventContent{ +func (cli *Client) SendReaction(ctx context.Context, roomID id.RoomID, eventID id.EventID, reaction string) (*RespSendEvent, error) { + return cli.SendMessageEvent(ctx, roomID, event.EventReaction, &event.ReactionEventContent{ RelatesTo: event.RelatesTo{ EventID: eventID, Type: event.RelAnnotation, @@ -1068,7 +1059,7 @@ func (cli *Client) SendReaction(roomID id.RoomID, eventID id.EventID, reaction s } // RedactEvent redacts the given event. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidredacteventidtxnid -func (cli *Client) RedactEvent(roomID id.RoomID, eventID id.EventID, extra ...ReqRedact) (resp *RespSendEvent, err error) { +func (cli *Client) RedactEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID, extra ...ReqRedact) (resp *RespSendEvent, err error) { req := ReqRedact{} if len(extra) > 0 { req = extra[0] @@ -1086,7 +1077,7 @@ func (cli *Client) RedactEvent(roomID id.RoomID, eventID id.EventID, extra ...Re txnID = cli.TxnID() } urlPath := cli.BuildClientURL("v3", "rooms", roomID, "redact", eventID, txnID) - _, err = cli.MakeRequest("PUT", urlPath, req.Extra, &resp) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, req.Extra, &resp) return } @@ -1096,9 +1087,9 @@ func (cli *Client) RedactEvent(roomID id.RoomID, eventID id.EventID, extra ...Re // Preset: "public_chat", // }) // fmt.Println("Room:", resp.RoomID) -func (cli *Client) CreateRoom(req *ReqCreateRoom) (resp *RespCreateRoom, err error) { +func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *RespCreateRoom, err error) { urlPath := cli.BuildClientURL("v3", "createRoom") - _, err = cli.MakeRequest("POST", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin) for _, evt := range req.InitialState { @@ -1119,7 +1110,7 @@ func (cli *Client) CreateRoom(req *ReqCreateRoom) (resp *RespCreateRoom, err err } // LeaveRoom leaves the given room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidleave -func (cli *Client) LeaveRoom(roomID id.RoomID, optionalReq ...*ReqLeave) (resp *RespLeaveRoom, err error) { +func (cli *Client) LeaveRoom(ctx context.Context, roomID id.RoomID, optionalReq ...*ReqLeave) (resp *RespLeaveRoom, err error) { req := &ReqLeave{} if len(optionalReq) == 1 { req = optionalReq[0] @@ -1127,7 +1118,7 @@ func (cli *Client) LeaveRoom(roomID id.RoomID, optionalReq ...*ReqLeave) (resp * panic("invalid number of arguments to LeaveRoom") } u := cli.BuildClientURL("v3", "rooms", roomID, "leave") - _, err = cli.MakeRequest("POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.SetMembership(roomID, cli.UserID, event.MembershipLeave) } @@ -1135,16 +1126,16 @@ func (cli *Client) LeaveRoom(roomID id.RoomID, optionalReq ...*ReqLeave) (resp * } // ForgetRoom forgets a room entirely. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidforget -func (cli *Client) ForgetRoom(roomID id.RoomID) (resp *RespForgetRoom, err error) { +func (cli *Client) ForgetRoom(ctx context.Context, roomID id.RoomID) (resp *RespForgetRoom, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "forget") - _, err = cli.MakeRequest("POST", u, struct{}{}, &resp) + _, err = cli.MakeRequest(ctx, "POST", u, struct{}{}, &resp) return } // InviteUser invites a user to a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite -func (cli *Client) InviteUser(roomID id.RoomID, req *ReqInviteUser) (resp *RespInviteUser, err error) { +func (cli *Client) InviteUser(ctx context.Context, roomID id.RoomID, req *ReqInviteUser) (resp *RespInviteUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "invite") - _, err = cli.MakeRequest("POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipInvite) } @@ -1152,16 +1143,16 @@ func (cli *Client) InviteUser(roomID id.RoomID, req *ReqInviteUser) (resp *RespI } // InviteUserByThirdParty invites a third-party identifier to a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite-1 -func (cli *Client) InviteUserByThirdParty(roomID id.RoomID, req *ReqInvite3PID) (resp *RespInviteUser, err error) { +func (cli *Client) InviteUserByThirdParty(ctx context.Context, roomID id.RoomID, req *ReqInvite3PID) (resp *RespInviteUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "invite") - _, err = cli.MakeRequest("POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) return } // KickUser kicks a user from a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidkick -func (cli *Client) KickUser(roomID id.RoomID, req *ReqKickUser) (resp *RespKickUser, err error) { +func (cli *Client) KickUser(ctx context.Context, roomID id.RoomID, req *ReqKickUser) (resp *RespKickUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "kick") - _, err = cli.MakeRequest("POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave) } @@ -1169,9 +1160,9 @@ func (cli *Client) KickUser(roomID id.RoomID, req *ReqKickUser) (resp *RespKickU } // BanUser bans a user from a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidban -func (cli *Client) BanUser(roomID id.RoomID, req *ReqBanUser) (resp *RespBanUser, err error) { +func (cli *Client) BanUser(ctx context.Context, roomID id.RoomID, req *ReqBanUser) (resp *RespBanUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "ban") - _, err = cli.MakeRequest("POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipBan) } @@ -1179,9 +1170,9 @@ func (cli *Client) BanUser(roomID id.RoomID, req *ReqBanUser) (resp *RespBanUser } // UnbanUser unbans a user from a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidunban -func (cli *Client) UnbanUser(roomID id.RoomID, req *ReqUnbanUser) (resp *RespUnbanUser, err error) { +func (cli *Client) UnbanUser(ctx context.Context, roomID id.RoomID, req *ReqUnbanUser) (resp *RespUnbanUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "unban") - _, err = cli.MakeRequest("POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave) } @@ -1189,30 +1180,30 @@ func (cli *Client) UnbanUser(roomID id.RoomID, req *ReqUnbanUser) (resp *RespUnb } // UserTyping sets the typing status of the user. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidtypinguserid -func (cli *Client) UserTyping(roomID id.RoomID, typing bool, timeout time.Duration) (resp *RespTyping, err error) { +func (cli *Client) UserTyping(ctx context.Context, roomID id.RoomID, typing bool, timeout time.Duration) (resp *RespTyping, err error) { req := ReqTyping{Typing: typing, Timeout: timeout.Milliseconds()} u := cli.BuildClientURL("v3", "rooms", roomID, "typing", cli.UserID) - _, err = cli.MakeRequest("PUT", u, req, &resp) + _, err = cli.MakeRequest(ctx, "PUT", u, req, &resp) return } // GetPresence gets the presence of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3presenceuseridstatus -func (cli *Client) GetPresence(userID id.UserID) (resp *RespPresence, err error) { +func (cli *Client) GetPresence(ctx context.Context, userID id.UserID) (resp *RespPresence, err error) { resp = new(RespPresence) u := cli.BuildClientURL("v3", "presence", userID, "status") - _, err = cli.MakeRequest("GET", u, nil, resp) + _, err = cli.MakeRequest(ctx, "GET", u, nil, resp) return } // GetOwnPresence gets the user's presence. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3presenceuseridstatus -func (cli *Client) GetOwnPresence() (resp *RespPresence, err error) { - return cli.GetPresence(cli.UserID) +func (cli *Client) GetOwnPresence(ctx context.Context) (resp *RespPresence, err error) { + return cli.GetPresence(ctx, cli.UserID) } -func (cli *Client) SetPresence(status event.Presence) (err error) { +func (cli *Client) SetPresence(ctx context.Context, status event.Presence) (err error) { req := ReqPresence{Presence: status} u := cli.BuildClientURL("v3", "presence", cli.UserID, "status") - _, err = cli.MakeRequest("PUT", u, req, nil) + _, err = cli.MakeRequest(ctx, "PUT", u, req, nil) return } @@ -1252,9 +1243,9 @@ func (cli *Client) updateStoreWithOutgoingEvent(roomID id.RoomID, eventType even // StateEvent gets a single state event in a room. It will attempt to JSON unmarshal into the given "outContent" struct with // the HTTP response body, or return an error. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstateeventtypestatekey -func (cli *Client) StateEvent(roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) (err error) { +func (cli *Client) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) (err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) - _, err = cli.MakeRequest("GET", u, nil, outContent) + _, err = cli.MakeRequest(ctx, "GET", u, nil, outContent) if err == nil && cli.StateStore != nil { cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, outContent) } @@ -1302,8 +1293,8 @@ func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON inter // State gets all state in a room. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate -func (cli *Client) State(roomID id.RoomID) (stateMap RoomStateMap, err error) { - _, err = cli.MakeFullRequest(FullRequest{ +func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) { + _, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodGet, URL: cli.BuildClientURL("v3", "rooms", roomID, "state"), ResponseJSON: &stateMap, @@ -1321,34 +1312,35 @@ func (cli *Client) State(roomID id.RoomID) (stateMap RoomStateMap, err error) { } // GetMediaConfig fetches the configuration of the content repository, such as upload limitations. -func (cli *Client) GetMediaConfig() (resp *RespMediaConfig, err error) { +func (cli *Client) GetMediaConfig(ctx context.Context) (resp *RespMediaConfig, err error) { u := cli.BuildURL(MediaURLPath{"v3", "config"}) - _, err = cli.MakeRequest("GET", u, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) return } // UploadLink uploads an HTTP URL and then returns an MXC URI. -func (cli *Client) UploadLink(link string) (*RespMediaUpload, error) { - res, err := cli.Client.Get(link) +func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUpload, error) { + req, err := http.NewRequestWithContext(ctx, "GET", link, nil) + if err != nil { + return nil, err + } + + res, err := cli.Client.Do(req) if res != nil { defer res.Body.Close() } if err != nil { return nil, err } - return cli.Upload(res.Body, res.Header.Get("Content-Type"), res.ContentLength) + return cli.Upload(ctx, res.Body, res.Header.Get("Content-Type"), res.ContentLength) } func (cli *Client) GetDownloadURL(mxcURL id.ContentURI) string { return cli.BuildURLWithQuery(MediaURLPath{"v3", "download", mxcURL.Homeserver, mxcURL.FileID}, map[string]string{"allow_redirect": "true"}) } -func (cli *Client) Download(mxcURL id.ContentURI) (io.ReadCloser, error) { - return cli.DownloadContext(context.Background(), mxcURL) -} - -func (cli *Client) DownloadContext(ctx context.Context, mxcURL id.ContentURI) (io.ReadCloser, error) { - resp, err := cli.downloadContext(ctx, mxcURL) +func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (io.ReadCloser, error) { + resp, err := cli.download(ctx, mxcURL) if err != nil { return nil, err } @@ -1411,7 +1403,7 @@ func (cli *Client) doMediaRequest(req *http.Request, retries int, backoff time.D return res, err } -func (cli *Client) downloadContext(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) { +func (cli *Client) download(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) { ctxLog := zerolog.Ctx(ctx) if ctxLog.GetLevel() == zerolog.Disabled || ctxLog == zerolog.DefaultContextLogger { ctx = cli.Log.WithContext(ctx) @@ -1424,12 +1416,8 @@ func (cli *Client) downloadContext(ctx context.Context, mxcURL id.ContentURI) (* return cli.doMediaRequest(req, cli.DefaultHTTPRetries, 4*time.Second) } -func (cli *Client) DownloadBytes(mxcURL id.ContentURI) ([]byte, error) { - return cli.DownloadBytesContext(context.Background(), mxcURL) -} - -func (cli *Client) DownloadBytesContext(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) { - resp, err := cli.downloadContext(ctx, mxcURL) +func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) { + resp, err := cli.download(ctx, mxcURL) if err != nil { return nil, err } @@ -1440,10 +1428,10 @@ func (cli *Client) DownloadBytesContext(ctx context.Context, mxcURL id.ContentUR // CreateMXC creates a blank Matrix content URI to allow uploading the content asynchronously later. // // See https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav1create -func (cli *Client) CreateMXC() (*RespCreateMXC, error) { +func (cli *Client) CreateMXC(ctx context.Context) (*RespCreateMXC, error) { u, _ := url.Parse(cli.BuildURL(MediaURLPath{"v1", "create"})) var m RespCreateMXC - _, err := cli.MakeFullRequest(FullRequest{ + _, err := cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: u.String(), ResponseJSON: &m, @@ -1456,15 +1444,15 @@ func (cli *Client) CreateMXC() (*RespCreateMXC, error) { // // See https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav1create // and https://spec.matrix.org/v1.7/client-server-api/#put_matrixmediav3uploadservernamemediaid -func (cli *Client) UploadAsync(req ReqUploadMedia) (*RespCreateMXC, error) { - resp, err := cli.CreateMXC() +func (cli *Client) UploadAsync(ctx context.Context, req ReqUploadMedia) (*RespCreateMXC, error) { + resp, err := cli.CreateMXC(ctx) if err != nil { return nil, err } req.MXC = resp.ContentURI req.UnstableUploadURL = resp.UnstableUploadURL go func() { - _, err = cli.UploadMedia(req) + _, err = cli.UploadMedia(ctx, req) if err != nil { cli.Log.Error().Str("mxc", req.MXC.String()).Err(err).Msg("Async upload of media failed") } @@ -1472,12 +1460,12 @@ func (cli *Client) UploadAsync(req ReqUploadMedia) (*RespCreateMXC, error) { return resp, nil } -func (cli *Client) UploadBytes(data []byte, contentType string) (*RespMediaUpload, error) { - return cli.UploadBytesWithName(data, contentType, "") +func (cli *Client) UploadBytes(ctx context.Context, data []byte, contentType string) (*RespMediaUpload, error) { + return cli.UploadBytesWithName(ctx, data, contentType, "") } -func (cli *Client) UploadBytesWithName(data []byte, contentType, fileName string) (*RespMediaUpload, error) { - return cli.UploadMedia(ReqUploadMedia{ +func (cli *Client) UploadBytesWithName(ctx context.Context, data []byte, contentType, fileName string) (*RespMediaUpload, error) { + return cli.UploadMedia(ctx, ReqUploadMedia{ ContentBytes: data, ContentType: contentType, FileName: fileName, @@ -1487,8 +1475,8 @@ func (cli *Client) UploadBytesWithName(data []byte, contentType, fileName string // Upload uploads the given data to the content repository and returns an MXC URI. // // Deprecated: UploadMedia should be used instead. -func (cli *Client) Upload(content io.Reader, contentType string, contentLength int64) (*RespMediaUpload, error) { - return cli.UploadMedia(ReqUploadMedia{ +func (cli *Client) Upload(ctx context.Context, content io.Reader, contentType string, contentLength int64) (*RespMediaUpload, error) { + return cli.UploadMedia(ctx, ReqUploadMedia{ Content: content, ContentLength: contentLength, ContentType: contentType, @@ -1511,9 +1499,9 @@ type ReqUploadMedia struct { UnstableUploadURL string } -func (cli *Client) tryUploadMediaToURL(url, contentType string, content io.Reader) (*http.Response, error) { +func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType string, content io.Reader) (*http.Response, error) { cli.Log.Debug().Str("url", url).Msg("Uploading media to external URL") - req, err := http.NewRequest(http.MethodPut, url, content) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, content) if err != nil { return nil, err } @@ -1523,7 +1511,7 @@ func (cli *Client) tryUploadMediaToURL(url, contentType string, content io.Reade return http.DefaultClient.Do(req) } -func (cli *Client) uploadMediaToURL(data ReqUploadMedia) (*RespMediaUpload, error) { +func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (*RespMediaUpload, error) { retries := cli.DefaultHTTPRetries if data.ContentBytes == nil { // Can't retry with a reader @@ -1536,7 +1524,7 @@ func (cli *Client) uploadMediaToURL(data ReqUploadMedia) (*RespMediaUpload, erro } else { data.Content = nil } - resp, err := cli.tryUploadMediaToURL(data.UnstableUploadURL, data.ContentType, reader) + resp, err := cli.tryUploadMediaToURL(ctx, data.UnstableUploadURL, data.ContentType, reader) if err == nil { if resp.StatusCode >= 200 && resp.StatusCode < 300 { // Everything is fine @@ -1562,7 +1550,7 @@ func (cli *Client) uploadMediaToURL(data ReqUploadMedia) (*RespMediaUpload, erro notifyURL := cli.BuildURLWithQuery(MediaURLPath{"unstable", "com.beeper.msc3870", "upload", data.MXC.Homeserver, data.MXC.FileID, "complete"}, query) var m *RespMediaUpload - _, err := cli.MakeFullRequest(FullRequest{ + _, err := cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: notifyURL, ResponseJSON: m, @@ -1576,12 +1564,12 @@ func (cli *Client) uploadMediaToURL(data ReqUploadMedia) (*RespMediaUpload, erro // UploadMedia uploads the given data to the content repository and returns an MXC URI. // See https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav3upload -func (cli *Client) UploadMedia(data ReqUploadMedia) (*RespMediaUpload, error) { +func (cli *Client) UploadMedia(ctx context.Context, data ReqUploadMedia) (*RespMediaUpload, error) { if data.UnstableUploadURL != "" { if data.MXC.IsEmpty() { return nil, errors.New("MXC must also be set when uploading to external URL") } - return cli.uploadMediaToURL(data) + return cli.uploadMediaToURL(ctx, data) } u, _ := url.Parse(cli.BuildURL(MediaURLPath{"v3", "upload"})) method := http.MethodPost @@ -1601,7 +1589,7 @@ func (cli *Client) UploadMedia(data ReqUploadMedia) (*RespMediaUpload, error) { } var m RespMediaUpload - _, err := cli.MakeFullRequest(FullRequest{ + _, err := cli.MakeFullRequest(ctx, FullRequest{ Method: method, URL: u.String(), Headers: headers, @@ -1616,12 +1604,12 @@ func (cli *Client) UploadMedia(data ReqUploadMedia) (*RespMediaUpload, error) { // GetURLPreview asks the homeserver to fetch a preview for a given URL. // // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3preview_url -func (cli *Client) GetURLPreview(url string) (*RespPreviewURL, error) { +func (cli *Client) GetURLPreview(ctx context.Context, url string) (*RespPreviewURL, error) { reqURL := cli.BuildURLWithQuery(MediaURLPath{"v3", "preview_url"}, map[string]string{ "url": url, }) var output RespPreviewURL - _, err := cli.MakeRequest(http.MethodGet, reqURL, nil, &output) + _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &output) return &output, err } @@ -1629,9 +1617,9 @@ func (cli *Client) GetURLPreview(url string) (*RespPreviewURL, error) { // // In general, usage of this API is discouraged in favour of /sync, as calling this API can race with incoming membership changes. // This API is primarily designed for application services which may want to efficiently look up joined members in a room. -func (cli *Client) JoinedMembers(roomID id.RoomID) (resp *RespJoinedMembers, err error) { +func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *RespJoinedMembers, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members") - _, err = cli.MakeRequest("GET", u, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) if err == nil && cli.StateStore != nil { cli.StateStore.ClearCachedMembers(roomID, event.MembershipJoin) for userID, member := range resp.Joined { @@ -1645,7 +1633,7 @@ func (cli *Client) JoinedMembers(roomID id.RoomID) (resp *RespJoinedMembers, err return } -func (cli *Client) Members(roomID id.RoomID, req ...ReqMembers) (resp *RespMembers, err error) { +func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMembers) (resp *RespMembers, err error) { var extra ReqMembers if len(req) > 0 { extra = req[0] @@ -1661,7 +1649,7 @@ func (cli *Client) Members(roomID id.RoomID, req ...ReqMembers) (resp *RespMembe query["not_membership"] = string(extra.NotMembership) } u := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "members"}, query) - _, err = cli.MakeRequest("GET", u, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) if err == nil && cli.StateStore != nil { var clearMemberships []event.Membership if extra.Membership != "" { @@ -1681,9 +1669,9 @@ func (cli *Client) Members(roomID id.RoomID, req ...ReqMembers) (resp *RespMembe // // In general, usage of this API is discouraged in favour of /sync, as calling this API can race with incoming membership changes. // This API is primarily designed for application services which may want to efficiently look up joined rooms. -func (cli *Client) JoinedRooms() (resp *RespJoinedRooms, err error) { +func (cli *Client) JoinedRooms(ctx context.Context) (resp *RespJoinedRooms, err error) { u := cli.BuildClientURL("v3", "joined_rooms") - _, err = cli.MakeRequest("GET", u, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) return } @@ -1693,16 +1681,16 @@ func (cli *Client) JoinedRooms() (resp *RespJoinedRooms, err error) { // when it encounters another space as a child it recurses into that space before returning non-space children. // // The second function parameter specifies query parameters to limit the response. No query parameters will be added if it's nil. -func (cli *Client) Hierarchy(roomID id.RoomID, req *ReqHierarchy) (resp *RespHierarchy, err error) { +func (cli *Client) Hierarchy(ctx context.Context, roomID id.RoomID, req *ReqHierarchy) (resp *RespHierarchy, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v1", "rooms", roomID, "hierarchy"}, req.Query()) - _, err = cli.MakeRequest(http.MethodGet, urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } // Messages returns a list of message and state events for a room. It uses // pagination query parameters to paginate history in the room. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidmessages -func (cli *Client) Messages(roomID id.RoomID, from, to string, dir Direction, filter *FilterPart, limit int) (resp *RespMessages, err error) { +func (cli *Client) Messages(ctx context.Context, roomID id.RoomID, from, to string, dir Direction, filter *FilterPart, limit int) (resp *RespMessages, err error) { query := map[string]string{ "from": from, "dir": string(dir), @@ -1722,20 +1710,20 @@ func (cli *Client) Messages(roomID id.RoomID, from, to string, dir Direction, fi } urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "messages"}, query) - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } // TimestampToEvent finds the ID of the event closest to the given timestamp. // // See https://spec.matrix.org/v1.6/client-server-api/#get_matrixclientv1roomsroomidtimestamp_to_event -func (cli *Client) TimestampToEvent(roomID id.RoomID, timestamp time.Time, dir Direction) (resp *RespTimestampToEvent, err error) { +func (cli *Client) TimestampToEvent(ctx context.Context, roomID id.RoomID, timestamp time.Time, dir Direction) (resp *RespTimestampToEvent, err error) { query := map[string]string{ "ts": strconv.FormatInt(timestamp.UnixMilli(), 10), "dir": string(dir), } urlPath := cli.BuildURLWithQuery(ClientURLPath{"v1", "rooms", roomID, "timestamp_to_event"}, query) - _, err = cli.MakeRequest(http.MethodGet, urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } @@ -1743,7 +1731,7 @@ func (cli *Client) TimestampToEvent(roomID id.RoomID, timestamp time.Time, dir D // specified event. It use pagination query parameters to paginate history in // the room. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidcontexteventid -func (cli *Client) Context(roomID id.RoomID, eventID id.EventID, filter *FilterPart, limit int) (resp *RespContext, err error) { +func (cli *Client) Context(ctx context.Context, roomID id.RoomID, eventID id.EventID, filter *FilterPart, limit int) (resp *RespContext, err error) { query := map[string]string{} if filter != nil { filterJSON, err := json.Marshal(filter) @@ -1757,173 +1745,173 @@ func (cli *Client) Context(roomID id.RoomID, eventID id.EventID, filter *FilterP } urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "context", eventID}, query) - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } -func (cli *Client) GetEvent(roomID id.RoomID, eventID id.EventID) (resp *event.Event, err error) { +func (cli *Client) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (resp *event.Event, err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "event", eventID) - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } -func (cli *Client) MarkRead(roomID id.RoomID, eventID id.EventID) (err error) { - return cli.SendReceipt(roomID, eventID, event.ReceiptTypeRead, nil) +func (cli *Client) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID) (err error) { + return cli.SendReceipt(ctx, roomID, eventID, event.ReceiptTypeRead, nil) } // MarkReadWithContent sends a read receipt including custom data. // // Deprecated: Use SendReceipt instead. -func (cli *Client) MarkReadWithContent(roomID id.RoomID, eventID id.EventID, content interface{}) (err error) { - return cli.SendReceipt(roomID, eventID, event.ReceiptTypeRead, content) +func (cli *Client) MarkReadWithContent(ctx context.Context, roomID id.RoomID, eventID id.EventID, content interface{}) (err error) { + return cli.SendReceipt(ctx, roomID, eventID, event.ReceiptTypeRead, content) } // SendReceipt sends a receipt, usually specifically a read receipt. // // Passing nil as the content is safe, the library will automatically replace it with an empty JSON object. // To mark a message in a specific thread as read, use pass a ReqSendReceipt as the content. -func (cli *Client) SendReceipt(roomID id.RoomID, eventID id.EventID, receiptType event.ReceiptType, content interface{}) (err error) { +func (cli *Client) SendReceipt(ctx context.Context, roomID id.RoomID, eventID id.EventID, receiptType event.ReceiptType, content interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "receipt", receiptType, eventID) - _, err = cli.MakeRequest("POST", urlPath, content, nil) + _, err = cli.MakeRequest(ctx, "POST", urlPath, content, nil) return } -func (cli *Client) SetReadMarkers(roomID id.RoomID, content interface{}) (err error) { +func (cli *Client) SetReadMarkers(ctx context.Context, roomID id.RoomID, content interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "read_markers") - _, err = cli.MakeRequest("POST", urlPath, content, nil) + _, err = cli.MakeRequest(ctx, "POST", urlPath, content, nil) return } -func (cli *Client) AddTag(roomID id.RoomID, tag string, order float64) error { +func (cli *Client) AddTag(ctx context.Context, roomID id.RoomID, tag string, order float64) error { var tagData event.Tag if order == order { tagData.Order = json.Number(strconv.FormatFloat(order, 'e', -1, 64)) } - return cli.AddTagWithCustomData(roomID, tag, tagData) + return cli.AddTagWithCustomData(ctx, roomID, tag, tagData) } -func (cli *Client) AddTagWithCustomData(roomID id.RoomID, tag string, data interface{}) (err error) { +func (cli *Client) AddTagWithCustomData(ctx context.Context, roomID id.RoomID, tag string, data interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags", tag) - _, err = cli.MakeRequest("PUT", urlPath, data, nil) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, data, nil) return } -func (cli *Client) GetTags(roomID id.RoomID) (tags event.TagEventContent, err error) { - err = cli.GetTagsWithCustomData(roomID, &tags) +func (cli *Client) GetTags(ctx context.Context, roomID id.RoomID) (tags event.TagEventContent, err error) { + err = cli.GetTagsWithCustomData(ctx, roomID, &tags) return } -func (cli *Client) GetTagsWithCustomData(roomID id.RoomID, resp interface{}) (err error) { +func (cli *Client) GetTagsWithCustomData(ctx context.Context, roomID id.RoomID, resp interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } -func (cli *Client) RemoveTag(roomID id.RoomID, tag string) (err error) { +func (cli *Client) RemoveTag(ctx context.Context, roomID id.RoomID, tag string) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags", tag) - _, err = cli.MakeRequest("DELETE", urlPath, nil, nil) + _, err = cli.MakeRequest(ctx, "DELETE", urlPath, nil, nil) return } // Deprecated: Synapse may not handle setting m.tag directly properly, so you should use the Add/RemoveTag methods instead. -func (cli *Client) SetTags(roomID id.RoomID, tags event.Tags) (err error) { - return cli.SetRoomAccountData(roomID, "m.tag", map[string]event.Tags{ +func (cli *Client) SetTags(ctx context.Context, roomID id.RoomID, tags event.Tags) (err error) { + return cli.SetRoomAccountData(ctx, roomID, "m.tag", map[string]event.Tags{ "tags": tags, }) } // TurnServer returns turn server details and credentials for the client to use when initiating calls. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3voipturnserver -func (cli *Client) TurnServer() (resp *RespTurnServer, err error) { +func (cli *Client) TurnServer(ctx context.Context) (resp *RespTurnServer, err error) { urlPath := cli.BuildClientURL("v3", "voip", "turnServer") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } -func (cli *Client) CreateAlias(alias id.RoomAlias, roomID id.RoomID) (resp *RespAliasCreate, err error) { +func (cli *Client) CreateAlias(ctx context.Context, alias id.RoomAlias, roomID id.RoomID) (resp *RespAliasCreate, err error) { urlPath := cli.BuildClientURL("v3", "directory", "room", alias) - _, err = cli.MakeRequest("PUT", urlPath, &ReqAliasCreate{RoomID: roomID}, &resp) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, &ReqAliasCreate{RoomID: roomID}, &resp) return } -func (cli *Client) ResolveAlias(alias id.RoomAlias) (resp *RespAliasResolve, err error) { +func (cli *Client) ResolveAlias(ctx context.Context, alias id.RoomAlias) (resp *RespAliasResolve, err error) { urlPath := cli.BuildClientURL("v3", "directory", "room", alias) - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } -func (cli *Client) DeleteAlias(alias id.RoomAlias) (resp *RespAliasDelete, err error) { +func (cli *Client) DeleteAlias(ctx context.Context, alias id.RoomAlias) (resp *RespAliasDelete, err error) { urlPath := cli.BuildClientURL("v3", "directory", "room", alias) - _, err = cli.MakeRequest("DELETE", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "DELETE", urlPath, nil, &resp) return } -func (cli *Client) GetAliases(roomID id.RoomID) (resp *RespAliasList, err error) { +func (cli *Client) GetAliases(ctx context.Context, roomID id.RoomID) (resp *RespAliasList, err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "aliases") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } -func (cli *Client) UploadKeys(req *ReqUploadKeys) (resp *RespUploadKeys, err error) { +func (cli *Client) UploadKeys(ctx context.Context, req *ReqUploadKeys) (resp *RespUploadKeys, err error) { urlPath := cli.BuildClientURL("v3", "keys", "upload") - _, err = cli.MakeRequest("POST", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) return } -func (cli *Client) QueryKeys(req *ReqQueryKeys) (resp *RespQueryKeys, err error) { +func (cli *Client) QueryKeys(ctx context.Context, req *ReqQueryKeys) (resp *RespQueryKeys, err error) { urlPath := cli.BuildClientURL("v3", "keys", "query") - _, err = cli.MakeRequest("POST", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) return } -func (cli *Client) ClaimKeys(req *ReqClaimKeys) (resp *RespClaimKeys, err error) { +func (cli *Client) ClaimKeys(ctx context.Context, req *ReqClaimKeys) (resp *RespClaimKeys, err error) { urlPath := cli.BuildClientURL("v3", "keys", "claim") - _, err = cli.MakeRequest("POST", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) return } -func (cli *Client) GetKeyChanges(from, to string) (resp *RespKeyChanges, err error) { +func (cli *Client) GetKeyChanges(ctx context.Context, from, to string) (resp *RespKeyChanges, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "keys", "changes"}, map[string]string{ "from": from, "to": to, }) - _, err = cli.MakeRequest("POST", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, nil, &resp) return } -func (cli *Client) SendToDevice(eventType event.Type, req *ReqSendToDevice) (resp *RespSendToDevice, err error) { +func (cli *Client) SendToDevice(ctx context.Context, eventType event.Type, req *ReqSendToDevice) (resp *RespSendToDevice, err error) { urlPath := cli.BuildClientURL("v3", "sendToDevice", eventType.String(), cli.TxnID()) - _, err = cli.MakeRequest("PUT", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, "PUT", urlPath, req, &resp) return } -func (cli *Client) GetDevicesInfo() (resp *RespDevicesInfo, err error) { +func (cli *Client) GetDevicesInfo(ctx context.Context) (resp *RespDevicesInfo, err error) { urlPath := cli.BuildClientURL("v3", "devices") - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } -func (cli *Client) GetDeviceInfo(deviceID id.DeviceID) (resp *RespDeviceInfo, err error) { +func (cli *Client) GetDeviceInfo(ctx context.Context, deviceID id.DeviceID) (resp *RespDeviceInfo, err error) { urlPath := cli.BuildClientURL("v3", "devices", deviceID) - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) return } -func (cli *Client) SetDeviceInfo(deviceID id.DeviceID, req *ReqDeviceInfo) error { +func (cli *Client) SetDeviceInfo(ctx context.Context, deviceID id.DeviceID, req *ReqDeviceInfo) error { urlPath := cli.BuildClientURL("v3", "devices", deviceID) - _, err := cli.MakeRequest("PUT", urlPath, req, nil) + _, err := cli.MakeRequest(ctx, "PUT", urlPath, req, nil) return err } -func (cli *Client) DeleteDevice(deviceID id.DeviceID, req *ReqDeleteDevice) error { +func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice) error { urlPath := cli.BuildClientURL("v3", "devices", deviceID) - _, err := cli.MakeRequest("DELETE", urlPath, req, nil) + _, err := cli.MakeRequest(ctx, "DELETE", urlPath, req, nil) return err } -func (cli *Client) DeleteDevices(req *ReqDeleteDevices) error { +func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices) error { urlPath := cli.BuildClientURL("v3", "delete_devices") - _, err := cli.MakeRequest("DELETE", urlPath, req, nil) + _, err := cli.MakeRequest(ctx, "DELETE", urlPath, req, nil) return err } @@ -1932,8 +1920,8 @@ type UIACallback = func(*RespUserInteractive) interface{} // UploadCrossSigningKeys uploads the given cross-signing keys to the server. // Because the endpoint requires user-interactive authentication a callback must be provided that, // given the UI auth parameters, produces the required result (or nil to end the flow). -func (cli *Client) UploadCrossSigningKeys(keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error { - content, err := cli.MakeFullRequest(FullRequest{ +func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error { + content, err := cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v3", "keys", "device_signing", "upload"), RequestJSON: keys, @@ -1948,48 +1936,48 @@ func (cli *Client) UploadCrossSigningKeys(keys *UploadCrossSigningKeysReq, uiaCa auth := uiaCallback(&uiAuthResp) if auth != nil { keys.Auth = auth - return cli.UploadCrossSigningKeys(keys, uiaCallback) + return cli.UploadCrossSigningKeys(ctx, keys, uiaCallback) } } return err } -func (cli *Client) UploadSignatures(req *ReqUploadSignatures) (resp *RespUploadSignatures, err error) { +func (cli *Client) UploadSignatures(ctx context.Context, req *ReqUploadSignatures) (resp *RespUploadSignatures, err error) { urlPath := cli.BuildClientURL("v3", "keys", "signatures", "upload") - _, err = cli.MakeRequest("POST", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) return } // GetPushRules returns the push notification rules for the global scope. -func (cli *Client) GetPushRules() (*pushrules.PushRuleset, error) { - return cli.GetScopedPushRules("global") +func (cli *Client) GetPushRules(ctx context.Context) (*pushrules.PushRuleset, error) { + return cli.GetScopedPushRules(ctx, "global") } // GetScopedPushRules returns the push notification rules for the given scope. -func (cli *Client) GetScopedPushRules(scope string) (resp *pushrules.PushRuleset, err error) { +func (cli *Client) GetScopedPushRules(ctx context.Context, scope string) (resp *pushrules.PushRuleset, err error) { u, _ := url.Parse(cli.BuildClientURL("v3", "pushrules", scope)) // client.BuildURL returns the URL without a trailing slash, but the pushrules endpoint requires the slash. u.Path += "/" - _, err = cli.MakeRequest("GET", u.String(), nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", u.String(), nil, &resp) return } -func (cli *Client) GetPushRule(scope string, kind pushrules.PushRuleType, ruleID string) (resp *pushrules.PushRule, err error) { +func (cli *Client) GetPushRule(ctx context.Context, scope string, kind pushrules.PushRuleType, ruleID string) (resp *pushrules.PushRule, err error) { urlPath := cli.BuildClientURL("v3", "pushrules", scope, kind, ruleID) - _, err = cli.MakeRequest("GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) if resp != nil { resp.Type = kind } return } -func (cli *Client) DeletePushRule(scope string, kind pushrules.PushRuleType, ruleID string) error { +func (cli *Client) DeletePushRule(ctx context.Context, scope string, kind pushrules.PushRuleType, ruleID string) error { urlPath := cli.BuildClientURL("v3", "pushrules", scope, kind, ruleID) - _, err := cli.MakeRequest("DELETE", urlPath, nil, nil) + _, err := cli.MakeRequest(ctx, "DELETE", urlPath, nil, nil) return err } -func (cli *Client) PutPushRule(scope string, kind pushrules.PushRuleType, ruleID string, req *ReqPutPushRule) error { +func (cli *Client) PutPushRule(ctx context.Context, scope string, kind pushrules.PushRuleType, ruleID string, req *ReqPutPushRule) error { query := make(map[string]string) if len(req.After) > 0 { query["after"] = req.After @@ -1998,14 +1986,14 @@ func (cli *Client) PutPushRule(scope string, kind pushrules.PushRuleType, ruleID query["before"] = req.Before } urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "pushrules", scope, kind, ruleID}, query) - _, err := cli.MakeRequest("PUT", urlPath, req, nil) + _, err := cli.MakeRequest(ctx, "PUT", urlPath, req, nil) return err } // BatchSend sends a batch of historical events into a room. This is only available for appservices. // // Deprecated: MSC2716 has been abandoned, so this is now Beeper-specific. BeeperBatchSend should be used instead. -func (cli *Client) BatchSend(roomID id.RoomID, req *ReqBatchSend) (resp *RespBatchSend, err error) { +func (cli *Client) BatchSend(ctx context.Context, roomID id.RoomID, req *ReqBatchSend) (resp *RespBatchSend, err error) { path := ClientURLPath{"unstable", "org.matrix.msc2716", "rooms", roomID, "batch_send"} query := map[string]string{ "prev_event_id": req.PrevEventID.String(), @@ -2019,12 +2007,12 @@ func (cli *Client) BatchSend(roomID id.RoomID, req *ReqBatchSend) (resp *RespBat if len(req.BatchID) > 0 { query["batch_id"] = req.BatchID.String() } - _, err = cli.MakeRequest("POST", cli.BuildURLWithQuery(path, query), req, &resp) + _, err = cli.MakeRequest(ctx, "POST", cli.BuildURLWithQuery(path, query), req, &resp) return } -func (cli *Client) AppservicePing(id, txnID string) (resp *RespAppservicePing, err error) { - _, err = cli.MakeFullRequest(FullRequest{ +func (cli *Client) AppservicePing(ctx context.Context, id, txnID string) (resp *RespAppservicePing, err error) { + _, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v1", "appservice", id, "ping"), RequestJSON: &ReqAppservicePing{TxnID: txnID}, @@ -2035,27 +2023,26 @@ func (cli *Client) AppservicePing(id, txnID string) (resp *RespAppservicePing, e return } -func (cli *Client) BeeperBatchSend(roomID id.RoomID, req *ReqBeeperBatchSend) (resp *RespBeeperBatchSend, err error) { +func (cli *Client) BeeperBatchSend(ctx context.Context, roomID id.RoomID, req *ReqBeeperBatchSend) (resp *RespBeeperBatchSend, err error) { u := cli.BuildClientURL("unstable", "com.beeper.backfill", "rooms", roomID, "batch_send") - _, err = cli.MakeRequest(http.MethodPost, u, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, u, req, &resp) return } -func (cli *Client) BeeperMergeRooms(req *ReqBeeperMergeRoom) (resp *RespBeeperMergeRoom, err error) { +func (cli *Client) BeeperMergeRooms(ctx context.Context, req *ReqBeeperMergeRoom) (resp *RespBeeperMergeRoom, err error) { urlPath := cli.BuildClientURL("unstable", "com.beeper.chatmerging", "merge") - _, err = cli.MakeRequest(http.MethodPost, urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) return } -func (cli *Client) BeeperSplitRoom(req *ReqBeeperSplitRoom) (resp *RespBeeperSplitRoom, err error) { +func (cli *Client) BeeperSplitRoom(ctx context.Context, req *ReqBeeperSplitRoom) (resp *RespBeeperSplitRoom, err error) { urlPath := cli.BuildClientURL("unstable", "com.beeper.chatmerging", "rooms", req.RoomID, "split") - _, err = cli.MakeRequest(http.MethodPost, urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) return } - -func (cli *Client) BeeperDeleteRoom(roomID id.RoomID) (err error) { +func (cli *Client) BeeperDeleteRoom(ctx context.Context, roomID id.RoomID) (err error) { urlPath := cli.BuildClientURL("unstable", "com.beeper.yeet", "rooms", roomID, "delete") - _, err = cli.MakeRequest(http.MethodPost, urlPath, nil, nil) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, nil, nil) return } diff --git a/crypto/cross_sign_key.go b/crypto/cross_sign_key.go index d38df8f3..4528ae02 100644 --- a/crypto/cross_sign_key.go +++ b/crypto/cross_sign_key.go @@ -8,6 +8,7 @@ package crypto import ( + "context" "fmt" "maunium.net/go/mautrix" @@ -89,7 +90,7 @@ func (mach *OlmMachine) GenerateCrossSigningKeys() (*CrossSigningKeysCache, erro } // PublishCrossSigningKeys signs and uploads the public keys of the given cross-signing keys to the server. -func (mach *OlmMachine) PublishCrossSigningKeys(keys *CrossSigningKeysCache, uiaCallback mautrix.UIACallback) error { +func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *CrossSigningKeysCache, uiaCallback mautrix.UIACallback) error { userID := mach.Client.UserID masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey.String()) masterKey := mautrix.CrossSigningKeys{ @@ -134,7 +135,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(keys *CrossSigningKeysCache, uia }, } - err = mach.Client.UploadCrossSigningKeys(&mautrix.UploadCrossSigningKeysReq{ + err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{ Master: masterKey, SelfSigning: selfKey, UserSigning: userKey, diff --git a/crypto/cross_sign_pubkey.go b/crypto/cross_sign_pubkey.go index 3067753a..9f4f3583 100644 --- a/crypto/cross_sign_pubkey.go +++ b/crypto/cross_sign_pubkey.go @@ -7,6 +7,7 @@ package crypto import ( + "context" "fmt" "maunium.net/go/mautrix" @@ -19,7 +20,7 @@ type CrossSigningPublicKeysCache struct { UserSigningKey id.Ed25519 } -func (mach *OlmMachine) GetOwnCrossSigningPublicKeys() *CrossSigningPublicKeysCache { +func (mach *OlmMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *CrossSigningPublicKeysCache { if mach.crossSigningPubkeys != nil { return mach.crossSigningPubkeys } @@ -30,7 +31,7 @@ func (mach *OlmMachine) GetOwnCrossSigningPublicKeys() *CrossSigningPublicKeysCa if mach.crossSigningPubkeysFetched { return nil } - cspk, err := mach.GetCrossSigningPublicKeys(mach.Client.UserID) + cspk, err := mach.GetCrossSigningPublicKeys(ctx, mach.Client.UserID) if err != nil { mach.Log.Error().Err(err).Msg("Failed to get own cross-signing public keys") return nil @@ -40,7 +41,7 @@ func (mach *OlmMachine) GetOwnCrossSigningPublicKeys() *CrossSigningPublicKeysCa return mach.crossSigningPubkeys } -func (mach *OlmMachine) GetCrossSigningPublicKeys(userID id.UserID) (*CrossSigningPublicKeysCache, error) { +func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id.UserID) (*CrossSigningPublicKeysCache, error) { dbKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID) if err != nil { return nil, fmt.Errorf("failed to get keys from database: %w", err) @@ -58,7 +59,7 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(userID id.UserID) (*CrossSigni } } - keys, err := mach.Client.QueryKeys(&mautrix.ReqQueryKeys{ + keys, err := mach.Client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ DeviceKeys: mautrix.DeviceKeysRequest{ userID: mautrix.DeviceIDList{}, }, diff --git a/crypto/cross_sign_signing.go b/crypto/cross_sign_signing.go index 62c41b38..1a5a0233 100644 --- a/crypto/cross_sign_signing.go +++ b/crypto/cross_sign_signing.go @@ -8,6 +8,7 @@ package crypto import ( + "context" "errors" "fmt" @@ -59,7 +60,7 @@ func (mach *OlmMachine) fetchMasterKey(device *id.Device, content *event.Verific } // SignUser creates a cross-signing signature for a user, stores it and uploads it to the server. -func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error { +func (mach *OlmMachine) SignUser(ctx context.Context, userID id.UserID, masterKey id.Ed25519) error { if userID == mach.Client.UserID { return ErrCantSignOwnMasterKey } else if mach.CrossSigningKeys == nil || mach.CrossSigningKeys.UserSigningKey == nil { @@ -74,7 +75,7 @@ func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error { }, } - signature, err := mach.signAndUpload(masterKeyObj, userID, masterKey.String(), mach.CrossSigningKeys.UserSigningKey) + signature, err := mach.signAndUpload(ctx, masterKeyObj, userID, masterKey.String(), mach.CrossSigningKeys.UserSigningKey) if err != nil { return err } @@ -92,7 +93,7 @@ func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error { } // SignOwnMasterKey uses the current account for signing the current user's master key and uploads the signature. -func (mach *OlmMachine) SignOwnMasterKey() error { +func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error { if mach.CrossSigningKeys == nil { return ErrCrossSigningKeysNotCached } else if mach.account == nil { @@ -124,7 +125,7 @@ func (mach *OlmMachine) SignOwnMasterKey() error { Str("signature", signature). Msg("Signed own master key with own device key") - resp, err := mach.Client.UploadSignatures(&mautrix.ReqUploadSignatures{ + resp, err := mach.Client.UploadSignatures(ctx, &mautrix.ReqUploadSignatures{ userID: map[string]mautrix.ReqKeysSignatures{ masterKey.String(): masterKeyObj, }, @@ -144,14 +145,14 @@ func (mach *OlmMachine) SignOwnMasterKey() error { } // SignOwnDevice creates a cross-signing signature for a device belonging to the current user and uploads it to the server. -func (mach *OlmMachine) SignOwnDevice(device *id.Device) error { +func (mach *OlmMachine) SignOwnDevice(ctx context.Context, device *id.Device) error { if device.UserID != mach.Client.UserID { return ErrCantSignOtherDevice } else if mach.CrossSigningKeys == nil || mach.CrossSigningKeys.SelfSigningKey == nil { return ErrSelfSigningKeyNotCached } - deviceKeys, err := mach.getFullDeviceKeys(device) + deviceKeys, err := mach.getFullDeviceKeys(ctx, device) if err != nil { return err } @@ -166,7 +167,7 @@ func (mach *OlmMachine) SignOwnDevice(device *id.Device) error { deviceKeyObj.Keys[id.KeyID(keyID)] = key } - signature, err := mach.signAndUpload(deviceKeyObj, device.UserID, device.DeviceID.String(), mach.CrossSigningKeys.SelfSigningKey) + signature, err := mach.signAndUpload(ctx, deviceKeyObj, device.UserID, device.DeviceID.String(), mach.CrossSigningKeys.SelfSigningKey) if err != nil { return err } @@ -186,8 +187,8 @@ func (mach *OlmMachine) SignOwnDevice(device *id.Device) error { // getFullDeviceKeys gets the full device keys object for the given device. // This is used because we don't cache some of the details like list of algorithms and unsupported key types. -func (mach *OlmMachine) getFullDeviceKeys(device *id.Device) (*mautrix.DeviceKeys, error) { - devicesKeys, err := mach.Client.QueryKeys(&mautrix.ReqQueryKeys{ +func (mach *OlmMachine) getFullDeviceKeys(ctx context.Context, device *id.Device) (*mautrix.DeviceKeys, error) { + devicesKeys, err := mach.Client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ DeviceKeys: mautrix.DeviceKeysRequest{ device.UserID: mautrix.DeviceIDList{device.DeviceID}, }, @@ -208,7 +209,7 @@ func (mach *OlmMachine) getFullDeviceKeys(device *id.Device) (*mautrix.DeviceKey } // signAndUpload signs the given key signatures object and uploads it to the server. -func (mach *OlmMachine) signAndUpload(req mautrix.ReqKeysSignatures, userID id.UserID, signedThing string, key *olm.PkSigning) (string, error) { +func (mach *OlmMachine) signAndUpload(ctx context.Context, req mautrix.ReqKeysSignatures, userID id.UserID, signedThing string, key *olm.PkSigning) (string, error) { signature, err := key.SignJSON(req) if err != nil { return "", fmt.Errorf("failed to sign JSON: %w", err) @@ -219,7 +220,7 @@ func (mach *OlmMachine) signAndUpload(req mautrix.ReqKeysSignatures, userID id.U }, } - resp, err := mach.Client.UploadSignatures(&mautrix.ReqUploadSignatures{ + resp, err := mach.Client.UploadSignatures(ctx, &mautrix.ReqUploadSignatures{ userID: map[string]mautrix.ReqKeysSignatures{ signedThing: req, }, diff --git a/crypto/cross_sign_ssss.go b/crypto/cross_sign_ssss.go index b8ca71cb..ef8a0ad3 100644 --- a/crypto/cross_sign_ssss.go +++ b/crypto/cross_sign_ssss.go @@ -7,6 +7,7 @@ package crypto import ( + "context" "fmt" "maunium.net/go/mautrix" @@ -16,16 +17,16 @@ import ( ) // FetchCrossSigningKeysFromSSSS fetches all the cross-signing keys from SSSS, decrypts them using the given key and stores them in the olm machine. -func (mach *OlmMachine) FetchCrossSigningKeysFromSSSS(key *ssss.Key) error { - masterKey, err := mach.retrieveDecryptXSigningKey(event.AccountDataCrossSigningMaster, key) +func (mach *OlmMachine) FetchCrossSigningKeysFromSSSS(ctx context.Context, key *ssss.Key) error { + masterKey, err := mach.retrieveDecryptXSigningKey(ctx, event.AccountDataCrossSigningMaster, key) if err != nil { return err } - selfSignKey, err := mach.retrieveDecryptXSigningKey(event.AccountDataCrossSigningSelf, key) + selfSignKey, err := mach.retrieveDecryptXSigningKey(ctx, event.AccountDataCrossSigningSelf, key) if err != nil { return err } - userSignKey, err := mach.retrieveDecryptXSigningKey(event.AccountDataCrossSigningUser, key) + userSignKey, err := mach.retrieveDecryptXSigningKey(ctx, event.AccountDataCrossSigningUser, key) if err != nil { return err } @@ -38,12 +39,12 @@ func (mach *OlmMachine) FetchCrossSigningKeysFromSSSS(key *ssss.Key) error { } // retrieveDecryptXSigningKey retrieves the requested cross-signing key from SSSS and decrypts it using the given SSSS key. -func (mach *OlmMachine) retrieveDecryptXSigningKey(keyName event.Type, key *ssss.Key) ([utils.AESCTRKeyLength]byte, error) { +func (mach *OlmMachine) retrieveDecryptXSigningKey(ctx context.Context, keyName event.Type, key *ssss.Key) ([utils.AESCTRKeyLength]byte, error) { var decryptedKey [utils.AESCTRKeyLength]byte var encData ssss.EncryptedAccountDataEventContent // retrieve and parse the account data for this key type from SSSS - err := mach.Client.GetAccountData(keyName.Type, &encData) + err := mach.Client.GetAccountData(ctx, keyName.Type, &encData) if err != nil { return decryptedKey, err } @@ -62,8 +63,8 @@ func (mach *OlmMachine) retrieveDecryptXSigningKey(keyName event.Type, key *ssss // is used. The base58-formatted recovery key is the first return parameter. // // The account password of the user is required for uploading keys to the server. -func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(userPassword, passphrase string) (string, error) { - key, err := mach.SSSS.GenerateAndUploadKey(passphrase) +func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, userPassword, passphrase string) (string, error) { + key, err := mach.SSSS.GenerateAndUploadKey(ctx, passphrase) if err != nil { return "", fmt.Errorf("failed to generate and upload SSSS key: %w", err) } @@ -77,12 +78,12 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(userPassword, passphra recoveryKey := key.RecoveryKey() // Store the private keys in SSSS - if err := mach.UploadCrossSigningKeysToSSSS(key, keysCache); err != nil { + if err := mach.UploadCrossSigningKeysToSSSS(ctx, key, keysCache); err != nil { return recoveryKey, fmt.Errorf("failed to upload cross-signing keys to SSSS: %w", err) } // Publish cross-signing keys - err = mach.PublishCrossSigningKeys(keysCache, func(uiResp *mautrix.RespUserInteractive) interface{} { + err = mach.PublishCrossSigningKeys(ctx, keysCache, func(uiResp *mautrix.RespUserInteractive) interface{} { return &mautrix.ReqUIAuthLogin{ BaseAuthData: mautrix.BaseAuthData{ Type: mautrix.AuthTypePassword, @@ -96,7 +97,7 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(userPassword, passphra return recoveryKey, fmt.Errorf("failed to publish cross-signing keys: %w", err) } - err = mach.SSSS.SetDefaultKeyID(key.ID) + err = mach.SSSS.SetDefaultKeyID(ctx, key.ID) if err != nil { return recoveryKey, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) } @@ -105,14 +106,14 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(userPassword, passphra } // UploadCrossSigningKeysToSSSS stores the given cross-signing keys on the server encrypted with the given key. -func (mach *OlmMachine) UploadCrossSigningKeysToSSSS(key *ssss.Key, keys *CrossSigningKeysCache) error { - if err := mach.SSSS.SetEncryptedAccountData(event.AccountDataCrossSigningMaster, keys.MasterKey.Seed, key); err != nil { +func (mach *OlmMachine) UploadCrossSigningKeysToSSSS(ctx context.Context, key *ssss.Key, keys *CrossSigningKeysCache) error { + if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningMaster, keys.MasterKey.Seed, key); err != nil { return err } - if err := mach.SSSS.SetEncryptedAccountData(event.AccountDataCrossSigningSelf, keys.SelfSigningKey.Seed, key); err != nil { + if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningSelf, keys.SelfSigningKey.Seed, key); err != nil { return err } - if err := mach.SSSS.SetEncryptedAccountData(event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed, key); err != nil { + if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed, key); err != nil { return err } return nil diff --git a/crypto/cross_sign_validation.go b/crypto/cross_sign_validation.go index e8a7d79a..27afeb73 100644 --- a/crypto/cross_sign_validation.go +++ b/crypto/cross_sign_validation.go @@ -89,7 +89,7 @@ func (mach *OlmMachine) IsDeviceTrusted(device *id.Device) bool { // IsUserTrusted returns whether a user has been determined to be trusted by our user-signing key having signed their master key. // In the case the user ID is our own and we have successfully retrieved our cross-signing keys, we trust our own user. func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bool, error) { - csPubkeys := mach.GetOwnCrossSigningPublicKeys() + csPubkeys := mach.GetOwnCrossSigningPublicKeys(ctx) if csPubkeys == nil { return false, nil } diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index 35dadac8..9d071ba9 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -124,6 +124,7 @@ func (helper *CryptoHelper) Init() error { } else { stateStore = helper.client.StateStore.(crypto.StateStore) } + ctx := context.Background() var cryptoStore crypto.Store if helper.unmanagedCryptoStore == nil { managedCryptoStore := crypto.NewSQLCryptoStore(helper.dbForManagedStores, dbutil.ZeroLogger(helper.log.With().Str("db_section", "crypto").Logger()), helper.DBAccountID, helper.client.DeviceID, helper.pickleKey) @@ -146,7 +147,7 @@ func (helper *CryptoHelper) Init() error { Str("username", helper.LoginAs.Identifier.User). Str("device_id", helper.LoginAs.DeviceID.String()). Msg("Logging in") - _, err = helper.client.Login(helper.LoginAs) + _, err = helper.client.Login(ctx, helper.LoginAs) if err != nil { return err } @@ -170,7 +171,7 @@ func (helper *CryptoHelper) Init() error { err := helper.mach.Load() if err != nil { return fmt.Errorf("failed to load olm account: %w", err) - } else if err = helper.verifyDeviceKeysOnServer(); err != nil { + } else if err = helper.verifyDeviceKeysOnServer(ctx); err != nil { return err } @@ -204,9 +205,9 @@ func (helper *CryptoHelper) Machine() *crypto.OlmMachine { return helper.mach } -func (helper *CryptoHelper) verifyDeviceKeysOnServer() error { +func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error { helper.log.Debug().Msg("Making sure our device has the expected keys on the server") - resp, err := helper.client.QueryKeys(&mautrix.ReqQueryKeys{ + resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ DeviceKeys: map[id.UserID]mautrix.DeviceIDList{ helper.client.UserID: {helper.client.DeviceID}, }, @@ -278,7 +279,7 @@ func (helper *CryptoHelper) postDecrypt(src mautrix.EventSource, decrypted *even helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(src|mautrix.EventSourceDecrypted, decrypted) } -func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { +func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { if helper == nil { return } @@ -294,7 +295,7 @@ func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.Sender Str("device_id", deviceID.String()). Str("room_id", roomID.String()). Logger() - err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{ + err := helper.mach.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{ userID: {deviceID}, helper.client.UserID: {"*"}, }) @@ -309,7 +310,7 @@ func (helper *CryptoHelper) waitLongerForSession(log zerolog.Logger, src mautrix content := evt.Content.AsEncrypted() log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...") - go helper.RequestSession(evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) + go helper.RequestSession(context.Background(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) if !helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { log.Debug().Msg("Didn't get session, giving up") diff --git a/crypto/devicelist.go b/crypto/devicelist.go index 7f779259..8514275c 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -108,7 +108,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT req.DeviceKeys[userID] = mautrix.DeviceIDList{} } log.Debug().Strs("users", strishArray(users)).Msg("Querying keys for users") - resp, err := mach.Client.QueryKeys(req) + resp, err := mach.Client.QueryKeys(ctx, req) if err != nil { log.Error().Err(err).Msg("Failed to query keys") return diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 80aef710..078ef518 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -280,11 +280,11 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, Int("user_count", len(toDeviceWithheld.Messages)). Msg("Sending to-device messages to report withheld key") // TODO remove the next 4 lines once clients support m.room_key.withheld - _, err = mach.Client.SendToDevice(event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld) + _, err = mach.Client.SendToDevice(ctx, event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld) if err != nil { log.Warn().Err(err).Msg("Failed to report withheld keys (legacy event type)") } - _, err = mach.Client.SendToDevice(event.ToDeviceRoomKeyWithheld, toDeviceWithheld) + _, err = mach.Client.SendToDevice(ctx, event.ToDeviceRoomKeyWithheld, toDeviceWithheld) if err != nil { log.Warn().Err(err).Msg("Failed to report withheld keys") } @@ -327,7 +327,7 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session Int("device_count", deviceCount). Int("user_count", len(toDevice.Messages)). Msg("Sending to-device messages to share group session") - _, err := mach.Client.SendToDevice(event.ToDeviceEncrypted, toDevice) + _, err := mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, toDevice) return err } diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go index 8ae3ba1a..f21ecd02 100644 --- a/crypto/encryptolm.go +++ b/crypto/encryptolm.go @@ -83,7 +83,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id if len(request) == 0 { return nil } - resp, err := mach.Client.ClaimKeys(&mautrix.ReqClaimKeys{ + resp, err := mach.Client.ClaimKeys(ctx, &mautrix.ReqClaimKeys{ OneTimeKeys: request, Timeout: 10 * 1000, }) diff --git a/crypto/keysharing.go b/crypto/keysharing.go index 1cbc41bd..9b8eef7e 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -48,7 +48,7 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to keyResponseReceived := make(chan struct{}) mach.roomKeyRequestFilled.Store(sessionID, keyResponseReceived) - err := mach.SendRoomKeyRequest(roomID, senderKey, sessionID, requestID, map[id.UserID][]id.DeviceID{toUser: {toDevice}}) + err := mach.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, requestID, map[id.UserID][]id.DeviceID{toUser: {toDevice}}) if err != nil { return nil, err } @@ -85,7 +85,7 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to }, } - mach.Client.SendToDevice(event.ToDeviceRoomKeyRequest, toDeviceCancel) + mach.Client.SendToDevice(ctx, event.ToDeviceRoomKeyRequest, toDeviceCancel) }() return resChan, nil } @@ -99,7 +99,7 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to // to the specific key request, but currently it only supports a single target device and is therefore deprecated. // A future function may properly support multiple targets and automatically canceling the other requests when receiving // the first response. -func (mach *OlmMachine) SendRoomKeyRequest(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, requestID string, users map[id.UserID][]id.DeviceID) error { +func (mach *OlmMachine) SendRoomKeyRequest(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, requestID string, users map[id.UserID][]id.DeviceID) error { if len(requestID) == 0 { requestID = mach.Client.TxnID() } @@ -126,7 +126,7 @@ func (mach *OlmMachine) SendRoomKeyRequest(roomID id.RoomID, senderKey id.Sender toDeviceReq.Messages[user][device] = requestEvent } } - _, err := mach.Client.SendToDevice(event.ToDeviceRoomKeyRequest, toDeviceReq) + _, err := mach.Client.SendToDevice(ctx, event.ToDeviceRoomKeyRequest, toDeviceReq) return err } @@ -188,7 +188,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt return true } -func (mach *OlmMachine) rejectKeyRequest(rejection KeyShareRejection, device *id.Device, request event.RequestedKeyInfo) { +func (mach *OlmMachine) rejectKeyRequest(ctx context.Context, rejection KeyShareRejection, device *id.Device, request event.RequestedKeyInfo) { if rejection.Code == "" { // If the rejection code is empty, it means don't share keys, but also don't tell the requester. return @@ -201,7 +201,7 @@ func (mach *OlmMachine) rejectKeyRequest(rejection KeyShareRejection, device *id Code: rejection.Code, Reason: rejection.Reason, } - err := mach.sendToOneDevice(device.UserID, device.DeviceID, event.ToDeviceRoomKeyWithheld, &content) + err := mach.sendToOneDevice(ctx, device.UserID, device.DeviceID, event.ToDeviceRoomKeyWithheld, &content) if err != nil { mach.Log.Warn().Err(err). Str("code", string(rejection.Code)). @@ -209,7 +209,7 @@ func (mach *OlmMachine) rejectKeyRequest(rejection KeyShareRejection, device *id Str("device_id", device.DeviceID.String()). Msg("Failed to send key share rejection") } - err = mach.sendToOneDevice(device.UserID, device.DeviceID, event.ToDeviceOrgMatrixRoomKeyWithheld, &content) + err = mach.sendToOneDevice(ctx, device.UserID, device.DeviceID, event.ToDeviceOrgMatrixRoomKeyWithheld, &content) if err != nil { mach.Log.Warn().Err(err). Str("code", string(rejection.Code)). @@ -270,7 +270,7 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User rejection := mach.AllowKeyShare(ctx, device, content.Body) if rejection != nil { - mach.rejectKeyRequest(*rejection, device, content.Body) + mach.rejectKeyRequest(ctx, *rejection, device, content.Body) return } @@ -278,15 +278,15 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User if err != nil { if errors.Is(err, ErrGroupSessionWithheld) { log.Debug().Err(err).Msg("Requested group session not available") - mach.rejectKeyRequest(KeyShareRejectUnavailable, device, content.Body) + mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) } else { log.Error().Err(err).Msg("Failed to get group session to forward") - mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body) + mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body) } return } else if igs == nil { log.Error().Msg("Didn't find group session to forward") - mach.rejectKeyRequest(KeyShareRejectUnavailable, device, content.Body) + mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) return } if internalID := igs.ID(); internalID != content.Body.SessionID { @@ -299,7 +299,7 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User exportedKey, err := igs.Internal.Export(firstKnownIndex) if err != nil { log.Error().Err(err).Msg("Failed to export group session to forward") - mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body) + mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body) return } diff --git a/crypto/machine.go b/crypto/machine.go index 2c9b63c9..37a21da3 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -381,17 +381,17 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) { mach.handleBeeperRoomKeyAck(ctx, evt.Sender, content) // verification cases case *event.VerificationStartEventContent: - mach.handleVerificationStart(evt.Sender, content, content.TransactionID, 10*time.Minute, "") + mach.handleVerificationStart(ctx, evt.Sender, content, content.TransactionID, 10*time.Minute, "") case *event.VerificationAcceptEventContent: - mach.handleVerificationAccept(evt.Sender, content, content.TransactionID) + mach.handleVerificationAccept(ctx, evt.Sender, content, content.TransactionID) case *event.VerificationKeyEventContent: - mach.handleVerificationKey(evt.Sender, content, content.TransactionID) + mach.handleVerificationKey(ctx, evt.Sender, content, content.TransactionID) case *event.VerificationMacEventContent: - mach.handleVerificationMAC(evt.Sender, content, content.TransactionID) + mach.handleVerificationMAC(ctx, evt.Sender, content, content.TransactionID) case *event.VerificationCancelEventContent: mach.handleVerificationCancel(evt.Sender, content, content.TransactionID) case *event.VerificationRequestEventContent: - mach.handleVerificationRequest(evt.Sender, content, content.TransactionID, "") + mach.handleVerificationRequest(ctx, evt.Sender, content, content.TransactionID, "") case *event.RoomKeyWithheldEventContent: mach.handleRoomKeyWithheld(ctx, content) default: @@ -473,7 +473,7 @@ func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.De Str("to_identity_key", device.IdentityKey.String()). Str("olm_session_id", olmSess.ID().String()). Msg("Sending encrypted to-device event") - _, err = mach.Client.SendToDevice(event.ToDeviceEncrypted, + _, err = mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, &mautrix.ReqSendToDevice{ Messages: map[id.UserID]map[id.DeviceID]*event.Content{ device.UserID: { @@ -624,7 +624,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro defer mach.otkUploadLock.Unlock() if mach.lastOTKUpload.Add(1*time.Minute).After(start) || currentOTKCount < 0 { log.Debug().Msg("Checking OTK count from server due to suspiciously close share keys requests or negative OTK count") - resp, err := mach.Client.UploadKeys(&mautrix.ReqUploadKeys{}) + resp, err := mach.Client.UploadKeys(ctx, &mautrix.ReqUploadKeys{}) if err != nil { return fmt.Errorf("failed to check current OTK counts: %w", err) } @@ -649,7 +649,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro OneTimeKeys: oneTimeKeys, } log.Debug().Int("count", len(oneTimeKeys)).Msg("Uploading one-time keys") - _, err := mach.Client.UploadKeys(req) + _, err := mach.Client.UploadKeys(ctx, req) if err != nil { return err } diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 15709bdb..c73a859a 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -67,17 +67,17 @@ func (store *SQLCryptoStore) Flush() error { } // PutNextBatch stores the next sync batch token for the current account. -func (store *SQLCryptoStore) PutNextBatch(nextBatch string) error { +func (store *SQLCryptoStore) PutNextBatch(ctx context.Context, nextBatch string) error { store.SyncToken = nextBatch - _, err := store.DB.Exec(`UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2`, store.SyncToken, store.AccountID) + _, err := store.DB.ExecContext(ctx, `UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2`, store.SyncToken, store.AccountID) return err } // GetNextBatch retrieves the next sync batch token for the current account. -func (store *SQLCryptoStore) GetNextBatch() (string, error) { +func (store *SQLCryptoStore) GetNextBatch(ctx context.Context) (string, error) { if store.SyncToken == "" { err := store.DB. - QueryRow("SELECT sync_token FROM crypto_account WHERE account_id=$1", store.AccountID). + QueryRowContext(ctx, "SELECT sync_token FROM crypto_account WHERE account_id=$1", store.AccountID). Scan(&store.SyncToken) if !errors.Is(err, sql.ErrNoRows) { return "", err @@ -88,18 +88,18 @@ func (store *SQLCryptoStore) GetNextBatch() (string, error) { var _ mautrix.SyncStore = (*SQLCryptoStore)(nil) -func (store *SQLCryptoStore) SaveFilterID(_ id.UserID, _ string) {} -func (store *SQLCryptoStore) LoadFilterID(_ id.UserID) string { return "" } +func (store *SQLCryptoStore) SaveFilterID(ctx context.Context, _ id.UserID, _ string) {} +func (store *SQLCryptoStore) LoadFilterID(ctx context.Context, _ id.UserID) string { return "" } -func (store *SQLCryptoStore) SaveNextBatch(_ id.UserID, nextBatchToken string) { - err := store.PutNextBatch(nextBatchToken) +func (store *SQLCryptoStore) SaveNextBatch(ctx context.Context, _ id.UserID, nextBatchToken string) { + err := store.PutNextBatch(ctx, nextBatchToken) if err != nil { // TODO handle error } } -func (store *SQLCryptoStore) LoadNextBatch(_ id.UserID) string { - nb, err := store.GetNextBatch() +func (store *SQLCryptoStore) LoadNextBatch(ctx context.Context, _ id.UserID) string { + nb, err := store.GetNextBatch(ctx) if err != nil { // TODO handle error } diff --git a/crypto/ssss/client.go b/crypto/ssss/client.go index b74deca1..2dac30e1 100644 --- a/crypto/ssss/client.go +++ b/crypto/ssss/client.go @@ -7,6 +7,7 @@ package ssss import ( + "context" "fmt" "maunium.net/go/mautrix" @@ -29,9 +30,9 @@ type DefaultSecretStorageKeyContent struct { } // GetDefaultKeyID retrieves the default key ID for this account from SSSS. -func (mach *Machine) GetDefaultKeyID() (string, error) { +func (mach *Machine) GetDefaultKeyID(ctx context.Context) (string, error) { var data DefaultSecretStorageKeyContent - err := mach.Client.GetAccountData(event.AccountDataSecretStorageDefaultKey.Type, &data) + err := mach.Client.GetAccountData(ctx, event.AccountDataSecretStorageDefaultKey.Type, &data) if err != nil { if httpErr, ok := err.(mautrix.HTTPError); ok && httpErr.RespError != nil && httpErr.RespError.ErrCode == "M_NOT_FOUND" { return "", ErrNoDefaultKeyAccountDataEvent @@ -45,36 +46,36 @@ func (mach *Machine) GetDefaultKeyID() (string, error) { } // SetDefaultKeyID sets the default key ID for this account on the server. -func (mach *Machine) SetDefaultKeyID(keyID string) error { - return mach.Client.SetAccountData(event.AccountDataSecretStorageDefaultKey.Type, &DefaultSecretStorageKeyContent{keyID}) +func (mach *Machine) SetDefaultKeyID(ctx context.Context, keyID string) error { + return mach.Client.SetAccountData(ctx, event.AccountDataSecretStorageDefaultKey.Type, &DefaultSecretStorageKeyContent{keyID}) } // GetKeyData gets the details about the given key ID. -func (mach *Machine) GetKeyData(keyID string) (keyData *KeyMetadata, err error) { +func (mach *Machine) GetKeyData(ctx context.Context, keyID string) (keyData *KeyMetadata, err error) { keyData = &KeyMetadata{id: keyID} - err = mach.Client.GetAccountData(fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData) + err = mach.Client.GetAccountData(ctx, fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData) return } // SetKeyData stores SSSS key metadata on the server. -func (mach *Machine) SetKeyData(keyID string, keyData *KeyMetadata) error { - return mach.Client.SetAccountData(fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData) +func (mach *Machine) SetKeyData(ctx context.Context, keyID string, keyData *KeyMetadata) error { + return mach.Client.SetAccountData(ctx, fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData) } // GetDefaultKeyData gets the details about the default key ID (see GetDefaultKeyID). -func (mach *Machine) GetDefaultKeyData() (keyID string, keyData *KeyMetadata, err error) { - keyID, err = mach.GetDefaultKeyID() +func (mach *Machine) GetDefaultKeyData(ctx context.Context) (keyID string, keyData *KeyMetadata, err error) { + keyID, err = mach.GetDefaultKeyID(ctx) if err != nil { return } - keyData, err = mach.GetKeyData(keyID) + keyData, err = mach.GetKeyData(ctx, keyID) return } // GetDecryptedAccountData gets the account data event with the given event type and decrypts it using the given key. -func (mach *Machine) GetDecryptedAccountData(eventType event.Type, key *Key) ([]byte, error) { +func (mach *Machine) GetDecryptedAccountData(ctx context.Context, eventType event.Type, key *Key) ([]byte, error) { var encData EncryptedAccountDataEventContent - err := mach.Client.GetAccountData(eventType.Type, &encData) + err := mach.Client.GetAccountData(ctx, eventType.Type, &encData) if err != nil { return nil, err } @@ -82,7 +83,7 @@ func (mach *Machine) GetDecryptedAccountData(eventType event.Type, key *Key) ([] } // SetEncryptedAccountData encrypts the given data with the given keys and stores it on the server. -func (mach *Machine) SetEncryptedAccountData(eventType event.Type, data []byte, keys ...*Key) error { +func (mach *Machine) SetEncryptedAccountData(ctx context.Context, eventType event.Type, data []byte, keys ...*Key) error { if len(keys) == 0 { return ErrNoKeyGiven } @@ -90,17 +91,17 @@ func (mach *Machine) SetEncryptedAccountData(eventType event.Type, data []byte, for _, key := range keys { encrypted[key.ID] = key.Encrypt(eventType.Type, data) } - return mach.Client.SetAccountData(eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted}) + return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted}) } // GenerateAndUploadKey generates a new SSSS key and stores the metadata on the server. -func (mach *Machine) GenerateAndUploadKey(passphrase string) (key *Key, err error) { +func (mach *Machine) GenerateAndUploadKey(ctx context.Context, passphrase string) (key *Key, err error) { key, err = NewKey(passphrase) if err != nil { return nil, fmt.Errorf("failed to generate new key: %w", err) } - err = mach.SetKeyData(key.ID, key.Metadata) + err = mach.SetKeyData(ctx, key.ID, key.Metadata) if err != nil { err = fmt.Errorf("failed to upload key: %w", err) } diff --git a/crypto/store_test.go b/crypto/store_test.go index 9062d70d..ebeef393 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -54,8 +54,8 @@ func getCryptoStores(t *testing.T) map[string]Store { func TestPutNextBatch(t *testing.T) { stores := getCryptoStores(t) store := stores["sql"].(*SQLCryptoStore) - store.PutNextBatch("batch1") - if batch, _ := store.GetNextBatch(); batch != "batch1" { + store.PutNextBatch(context.Background(), "batch1") + if batch, _ := store.GetNextBatch(context.Background()); batch != "batch1" { t.Errorf("Expected batch1, got %v", batch) } } diff --git a/crypto/verification.go b/crypto/verification.go index 4925fed6..be246874 100644 --- a/crypto/verification.go +++ b/crypto/verification.go @@ -54,8 +54,8 @@ const ( ) // sendToOneDevice sends a to-device event to a single device. -func (mach *OlmMachine) sendToOneDevice(userID id.UserID, deviceID id.DeviceID, eventType event.Type, content interface{}) error { - _, err := mach.Client.SendToDevice(eventType, &mautrix.ReqSendToDevice{ +func (mach *OlmMachine) sendToOneDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID, eventType event.Type, content interface{}) error { + _, err := mach.Client.SendToDevice(ctx, eventType, &mautrix.ReqSendToDevice{ Messages: map[id.UserID]map[id.DeviceID]*event.Content{ userID: { deviceID: { @@ -118,19 +118,19 @@ type verificationState struct { } // getTransactionState retrieves the given transaction's state, or cancels the transaction if it cannot be found or there is a mismatch. -func (mach *OlmMachine) getTransactionState(transactionID string, userID id.UserID) (*verificationState, error) { +func (mach *OlmMachine) getTransactionState(ctx context.Context, transactionID string, userID id.UserID) (*verificationState, error) { verStateInterface, ok := mach.keyVerificationTransactionState.Load(userID.String() + ":" + transactionID) if !ok { - _ = mach.SendSASVerificationCancel(userID, id.DeviceID("*"), transactionID, "Unknown transaction: "+transactionID, event.VerificationCancelUnknownTransaction) + _ = mach.SendSASVerificationCancel(ctx, userID, id.DeviceID("*"), transactionID, "Unknown transaction: "+transactionID, event.VerificationCancelUnknownTransaction) return nil, ErrUnknownTransaction } verState := verStateInterface.(*verificationState) if verState.otherDevice.UserID != userID { reason := fmt.Sprintf("Unknown user for transaction %v: %v", transactionID, userID) if verState.inRoomID == "" { - _ = mach.SendSASVerificationCancel(userID, id.DeviceID("*"), transactionID, reason, event.VerificationCancelUserMismatch) + _ = mach.SendSASVerificationCancel(ctx, userID, id.DeviceID("*"), transactionID, reason, event.VerificationCancelUserMismatch) } else { - _ = mach.SendInRoomSASVerificationCancel(verState.inRoomID, userID, transactionID, reason, event.VerificationCancelUserMismatch) + _ = mach.SendInRoomSASVerificationCancel(ctx, verState.inRoomID, userID, transactionID, reason, event.VerificationCancelUserMismatch) } mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) return nil, fmt.Errorf("%w %s: %s", ErrUnknownUserForTransaction, transactionID, userID) @@ -140,9 +140,9 @@ func (mach *OlmMachine) getTransactionState(transactionID string, userID id.User // handleVerificationStart handles an incoming m.key.verification.start message. // It initializes the state for this SAS verification process and stores it. -func (mach *OlmMachine) handleVerificationStart(userID id.UserID, content *event.VerificationStartEventContent, transactionID string, timeout time.Duration, inRoomID id.RoomID) { +func (mach *OlmMachine) handleVerificationStart(ctx context.Context, userID id.UserID, content *event.VerificationStartEventContent, transactionID string, timeout time.Duration, inRoomID id.RoomID) { mach.Log.Debug().Msgf("Received verification start from %v", content.FromDevice) - otherDevice, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice) + otherDevice, err := mach.GetOrFetchDevice(ctx, userID, content.FromDevice) if err != nil { mach.Log.Error().Msgf("Could not find device %v of user %v", content.FromDevice, userID) return @@ -150,9 +150,9 @@ func (mach *OlmMachine) handleVerificationStart(userID id.UserID, content *event warnAndCancel := func(logReason, cancelReason string) { mach.Log.Warn().Msgf("Canceling verification transaction %v as it %s", transactionID, logReason) if inRoomID == "" { - _ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, cancelReason, event.VerificationCancelUnknownMethod) + _ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, cancelReason, event.VerificationCancelUnknownMethod) } else { - _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, cancelReason, event.VerificationCancelUnknownMethod) + _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, cancelReason, event.VerificationCancelUnknownMethod) } } switch { @@ -168,21 +168,21 @@ func (mach *OlmMachine) handleVerificationStart(userID id.UserID, content *event case !content.SupportsSASMethod(event.SASDecimal): warnAndCancel("does not support decimal SAS", "Decimal SAS method must be supported") default: - mach.actuallyStartVerification(userID, content, otherDevice, transactionID, timeout, inRoomID) + mach.actuallyStartVerification(ctx, userID, content, otherDevice, transactionID, timeout, inRoomID) } } -func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *event.VerificationStartEventContent, otherDevice *id.Device, transactionID string, timeout time.Duration, inRoomID id.RoomID) { +func (mach *OlmMachine) actuallyStartVerification(ctx context.Context, userID id.UserID, content *event.VerificationStartEventContent, otherDevice *id.Device, transactionID string, timeout time.Duration, inRoomID id.RoomID) { if inRoomID != "" && transactionID != "" { - verState, err := mach.getTransactionState(transactionID, userID) + verState, err := mach.getTransactionState(ctx, transactionID, userID) if err != nil { mach.Log.Error().Msgf("Failed to get transaction state for in-room verification %s start: %v", transactionID, err) - _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Internal state error in gomuks :(", "net.maunium.internal_error") + _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Internal state error in gomuks :(", "net.maunium.internal_error") return } - mach.timeoutAfter(verState, transactionID, timeout) + mach.timeoutAfter(ctx, verState, transactionID, timeout) sasMethods := commonSASMethods(verState.hooks, content.ShortAuthenticationString) - err = mach.SendInRoomSASVerificationAccept(inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods) + err = mach.SendInRoomSASVerificationAccept(ctx, inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods) if err != nil { mach.Log.Error().Msgf("Error accepting in-room SAS verification: %v", err) } @@ -196,9 +196,9 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve if len(sasMethods) == 0 { mach.Log.Error().Msgf("No common SAS methods: %v", content.ShortAuthenticationString) if inRoomID == "" { - _ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod) + _ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod) } else { - _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod) + _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod) } return } @@ -221,20 +221,20 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve // transaction already exists mach.Log.Error().Msgf("Transaction %v already exists, canceling", transactionID) if inRoomID == "" { - _ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage) + _ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage) } else { - _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage) + _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage) } return } - mach.timeoutAfter(verState, transactionID, timeout) + mach.timeoutAfter(ctx, verState, transactionID, timeout) var err error if inRoomID == "" { - err = mach.SendSASVerificationAccept(userID, content, verState.sas.GetPubkey(), sasMethods) + err = mach.SendSASVerificationAccept(ctx, userID, content, verState.sas.GetPubkey(), sasMethods) } else { - err = mach.SendInRoomSASVerificationAccept(inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods) + err = mach.SendInRoomSASVerificationAccept(ctx, inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods) } if err != nil { mach.Log.Error().Msgf("Error accepting SAS verification: %v", err) @@ -243,9 +243,9 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve mach.Log.Debug().Msgf("Not accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) var err error if inRoomID == "" { - err = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser) + err = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser) } else { - err = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser) + err = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser) } if err != nil { mach.Log.Error().Msgf("Error canceling SAS verification: %v", err) @@ -255,8 +255,8 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve } } -func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID string, timeout time.Duration) { - timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), timeout) +func (mach *OlmMachine) timeoutAfter(ctx context.Context, verState *verificationState, transactionID string, timeout time.Duration) { + timeoutCtx, timeoutCancel := context.WithTimeout(ctx, timeout) verState.extendTimeout = timeoutCancel go func() { mapKey := verState.otherDevice.UserID.String() + ":" + transactionID @@ -272,7 +272,7 @@ func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID if timeoutCtx.Err() == context.DeadlineExceeded { // if deadline exceeded cancel due to timeout mach.keyVerificationTransactionState.Delete(mapKey) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "Timed out", event.VerificationCancelByTimeout) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Timed out", event.VerificationCancelByTimeout) mach.Log.Warn().Msgf("Verification transaction %v is canceled due to timing out", transactionID) verState.lock.Unlock() return @@ -288,9 +288,9 @@ func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID // handleVerificationAccept handles an incoming m.key.verification.accept message. // It continues the SAS verification process by sending the SAS key message to the other device. -func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *event.VerificationAcceptEventContent, transactionID string) { +func (mach *OlmMachine) handleVerificationAccept(ctx context.Context, userID id.UserID, content *event.VerificationAcceptEventContent, transactionID string) { mach.Log.Debug().Msgf("Received verification accept for transaction %v", transactionID) - verState, err := mach.getTransactionState(transactionID, userID) + verState, err := mach.getTransactionState(ctx, transactionID, userID) if err != nil { mach.Log.Error().Msgf("Error getting transaction state: %v", err) return @@ -303,7 +303,7 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even // unexpected accept at this point mach.Log.Warn().Msgf("Unexpected verification accept message for transaction %v", transactionID) mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected accept message", event.VerificationCancelUnexpectedMessage) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Unexpected accept message", event.VerificationCancelUnexpectedMessage) return } @@ -315,7 +315,7 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even mach.Log.Warn().Msgf("Canceling verification transaction %v due to unknown parameter", transactionID) mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "Verification uses unknown method", event.VerificationCancelUnknownMethod) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Verification uses unknown method", event.VerificationCancelUnknownMethod) return } @@ -325,9 +325,9 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even verState.verificationStarted = true if verState.inRoomID == "" { - err = mach.SendSASVerificationKey(userID, verState.otherDevice.DeviceID, transactionID, string(key)) + err = mach.SendSASVerificationKey(ctx, userID, verState.otherDevice.DeviceID, transactionID, string(key)) } else { - err = mach.SendInRoomSASVerificationKey(verState.inRoomID, userID, transactionID, string(key)) + err = mach.SendInRoomSASVerificationKey(ctx, verState.inRoomID, userID, transactionID, string(key)) } if err != nil { mach.Log.Error().Msgf("Error sending SAS key to other device: %v", err) @@ -337,9 +337,9 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even // handleVerificationKey handles an incoming m.key.verification.key message. // It stores the other device's public key in order to acquire the SAS shared secret. -func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.VerificationKeyEventContent, transactionID string) { +func (mach *OlmMachine) handleVerificationKey(ctx context.Context, userID id.UserID, content *event.VerificationKeyEventContent, transactionID string) { mach.Log.Debug().Msgf("Got verification key for transaction %v: %v", transactionID, content.Key) - verState, err := mach.getTransactionState(transactionID, userID) + verState, err := mach.getTransactionState(ctx, transactionID, userID) if err != nil { mach.Log.Error().Msgf("Error getting transaction state: %v", err) return @@ -354,7 +354,7 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V // unexpected key at this point mach.Log.Warn().Msgf("Unexpected verification key message for transaction %v", transactionID) mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected key message", event.VerificationCancelUnexpectedMessage) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Unexpected key message", event.VerificationCancelUnexpectedMessage) return } @@ -372,7 +372,7 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V if expectedCommitment != verState.commitment { mach.Log.Warn().Msgf("Canceling verification transaction %v due to commitment mismatch", transactionID) mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "Commitment mismatch", event.VerificationCancelCommitmentMismatch) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Commitment mismatch", event.VerificationCancelCommitmentMismatch) return } } else { @@ -380,9 +380,9 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V key := verState.sas.GetPubkey() if verState.inRoomID == "" { - err = mach.SendSASVerificationKey(userID, device.DeviceID, transactionID, string(key)) + err = mach.SendSASVerificationKey(ctx, userID, device.DeviceID, transactionID, string(key)) } else { - err = mach.SendInRoomSASVerificationKey(verState.inRoomID, userID, transactionID, string(key)) + err = mach.SendInRoomSASVerificationKey(ctx, verState.inRoomID, userID, transactionID, string(key)) } if err != nil { mach.Log.Error().Msgf("Error sending SAS key to other device: %v", err) @@ -419,13 +419,13 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V mach.Log.Debug().Msgf("Generated SAS (%v): %v", sasMethod.Type(), sas) go func() { result := verState.hooks.VerifySASMatch(device, sas) - mach.sasCompared(result, transactionID, verState) + mach.sasCompared(ctx, result, transactionID, verState) }() } // sasCompared is called asynchronously. It waits for the SAS to be compared for the verification to proceed. // If the SAS match, then our MAC is sent out. Otherwise the transaction is canceled. -func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verState *verificationState) { +func (mach *OlmMachine) sasCompared(ctx context.Context, didMatch bool, transactionID string, verState *verificationState) { verState.lock.Lock() defer verState.lock.Unlock() verState.extendTimeout() @@ -433,9 +433,9 @@ func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verStat verState.sasMatched <- true var err error if verState.inRoomID == "" { - err = mach.SendSASVerificationMAC(verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas) + err = mach.SendSASVerificationMAC(ctx, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas) } else { - err = mach.SendInRoomSASVerificationMAC(verState.inRoomID, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas) + err = mach.SendInRoomSASVerificationMAC(ctx, verState.inRoomID, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas) } if err != nil { mach.Log.Error().Msgf("Error sending verification MAC to other device: %v", err) @@ -447,9 +447,9 @@ func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verStat // handleVerificationMAC handles an incoming m.key.verification.mac message. // It verifies the other device's MAC and if the MAC is valid it marks the device as trusted. -func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.VerificationMacEventContent, transactionID string) { +func (mach *OlmMachine) handleVerificationMAC(ctx context.Context, userID id.UserID, content *event.VerificationMacEventContent, transactionID string) { mach.Log.Debug().Msgf("Got MAC for verification %v: %v, MAC for keys: %v", transactionID, content.Mac, content.Keys) - verState, err := mach.getTransactionState(transactionID, userID) + verState, err := mach.getTransactionState(ctx, transactionID, userID) if err != nil { mach.Log.Error().Msgf("Error getting transaction state: %v", err) return @@ -466,7 +466,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V if !verState.verificationStarted || !verState.keyReceived { // unexpected MAC at this point mach.Log.Warn().Msgf("Unexpected MAC message for transaction %v", transactionID) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected MAC message", event.VerificationCancelUnexpectedMessage) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Unexpected MAC message", event.VerificationCancelUnexpectedMessage) return } @@ -478,7 +478,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V if !matched { mach.Log.Warn().Msgf("SAS do not match! Canceling transaction %v", transactionID) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "SAS do not match", event.VerificationCancelSASMismatch) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "SAS do not match", event.VerificationCancelSASMismatch) return } @@ -494,14 +494,14 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V mach.Log.Debug().Msgf("Expected %s keys MAC, got %s", expectedKeysMAC, content.Keys) if content.Keys != expectedKeysMAC { mach.Log.Warn().Msgf("Canceling verification transaction %v due to mismatched keys MAC", transactionID) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "Mismatched keys MACs", event.VerificationCancelKeyMismatch) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Mismatched keys MACs", event.VerificationCancelKeyMismatch) return } mach.Log.Debug().Msgf("Expected %s PK MAC, got %s", expectedPKMAC, content.Mac[keyID]) if content.Mac[keyID] != expectedPKMAC { mach.Log.Warn().Msgf("Canceling verification transaction %v due to mismatched PK MAC", transactionID) - _ = mach.callbackAndCancelSASVerification(verState, transactionID, "Mismatched PK MACs", event.VerificationCancelKeyMismatch) + _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Mismatched PK MACs", event.VerificationCancelKeyMismatch) return } @@ -514,7 +514,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V if mach.CrossSigningKeys != nil { if device.UserID == mach.Client.UserID { - err := mach.SignOwnDevice(device) + err := mach.SignOwnDevice(ctx, device) if err != nil { mach.Log.Error().Msgf("Failed to cross-sign own device %s: %v", device.DeviceID, err) } else { @@ -525,7 +525,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V if err != nil { mach.Log.Warn().Msgf("Failed to fetch %s's master key: %v", device.UserID, err) } else { - if err := mach.SignUser(device.UserID, masterKey); err != nil { + if err := mach.SignUser(ctx, device.UserID, masterKey); err != nil { mach.Log.Error().Msgf("Failed to cross-sign master key of %s: %v", device.UserID, err) } else { mach.Log.Debug().Msgf("Cross-signed master key of %v after SAS verification", device.UserID) @@ -559,9 +559,9 @@ func (mach *OlmMachine) handleVerificationCancel(userID id.UserID, content *even } // handleVerificationRequest handles an incoming m.key.verification.request message. -func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *event.VerificationRequestEventContent, transactionID string, inRoomID id.RoomID) { +func (mach *OlmMachine) handleVerificationRequest(ctx context.Context, userID id.UserID, content *event.VerificationRequestEventContent, transactionID string, inRoomID id.RoomID) { mach.Log.Debug().Msgf("Received verification request from %v", content.FromDevice) - otherDevice, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice) + otherDevice, err := mach.GetOrFetchDevice(ctx, userID, content.FromDevice) if err != nil { mach.Log.Error().Msgf("Could not find device %v of user %v", content.FromDevice, userID) return @@ -569,9 +569,9 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve if !content.SupportsVerificationMethod(event.VerificationMethodSAS) { mach.Log.Warn().Msgf("Canceling verification transaction %v as SAS is not supported", transactionID) if inRoomID == "" { - _ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod) + _ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod) } else { - _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod) + _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod) } return } @@ -579,14 +579,14 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve if resp == AcceptRequest { mach.Log.Debug().Msgf("Accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) if inRoomID == "" { - _, err = mach.NewSASVerificationWith(otherDevice, hooks, transactionID, mach.DefaultSASTimeout) + _, err = mach.NewSASVerificationWith(ctx, otherDevice, hooks, transactionID, mach.DefaultSASTimeout) } else { - if err := mach.SendInRoomSASVerificationReady(inRoomID, transactionID); err != nil { + if err := mach.SendInRoomSASVerificationReady(ctx, inRoomID, transactionID); err != nil { mach.Log.Error().Msgf("Error sending in-room SAS verification ready: %v", err) } if mach.Client.UserID < otherDevice.UserID { // up to us to send the start message - _, err = mach.newInRoomSASVerificationWithInner(inRoomID, otherDevice, hooks, transactionID, mach.DefaultSASTimeout) + _, err = mach.newInRoomSASVerificationWithInner(ctx, inRoomID, otherDevice, hooks, transactionID, mach.DefaultSASTimeout) } } if err != nil { @@ -595,9 +595,9 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve } else if resp == RejectRequest { mach.Log.Debug().Msgf("Rejecting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) if inRoomID == "" { - _ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser) + _ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser) } else { - _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser) + _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser) } } else { mach.Log.Debug().Msgf("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) @@ -606,14 +606,14 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve // NewSimpleSASVerificationWith starts the SAS verification process with another device with a default timeout, // a generated transaction ID and support for both emoji and decimal SAS methods. -func (mach *OlmMachine) NewSimpleSASVerificationWith(device *id.Device, hooks VerificationHooks) (string, error) { - return mach.NewSASVerificationWith(device, hooks, "", mach.DefaultSASTimeout) +func (mach *OlmMachine) NewSimpleSASVerificationWith(ctx context.Context, device *id.Device, hooks VerificationHooks) (string, error) { + return mach.NewSASVerificationWith(ctx, device, hooks, "", mach.DefaultSASTimeout) } // NewSASVerificationWith starts the SAS verification process with another device. // If the other device accepts the verification transaction, the methods in `hooks` will be used to verify the SAS match and to complete the transaction.. // If the transaction ID is empty, a new one is generated. -func (mach *OlmMachine) NewSASVerificationWith(device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) { +func (mach *OlmMachine) NewSASVerificationWith(ctx context.Context, device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) { if transactionID == "" { transactionID = strconv.Itoa(rand.Int()) } @@ -631,7 +631,7 @@ func (mach *OlmMachine) NewSASVerificationWith(device *id.Device, hooks Verifica verState.lock.Lock() defer verState.lock.Unlock() - startEvent, err := mach.SendSASVerificationStart(device.UserID, device.DeviceID, transactionID, hooks.VerificationMethods()) + startEvent, err := mach.SendSASVerificationStart(ctx, device.UserID, device.DeviceID, transactionID, hooks.VerificationMethods()) if err != nil { return "", err } @@ -651,13 +651,13 @@ func (mach *OlmMachine) NewSASVerificationWith(device *id.Device, hooks Verifica return "", ErrTransactionAlreadyExists } - mach.timeoutAfter(verState, transactionID, timeout) + mach.timeoutAfter(ctx, verState, transactionID, timeout) return transactionID, nil } // CancelSASVerification is used by the user to cancel a SAS verification process with the given reason. -func (mach *OlmMachine) CancelSASVerification(userID id.UserID, transactionID, reason string) error { +func (mach *OlmMachine) CancelSASVerification(ctx context.Context, userID id.UserID, transactionID, reason string) error { mapKey := userID.String() + ":" + transactionID verStateInterface, ok := mach.keyVerificationTransactionState.Load(mapKey) if !ok { @@ -668,21 +668,21 @@ func (mach *OlmMachine) CancelSASVerification(userID id.UserID, transactionID, r defer verState.lock.Unlock() mach.Log.Trace().Msgf("User canceled verification transaction %v with reason: %v", transactionID, reason) mach.keyVerificationTransactionState.Delete(mapKey) - return mach.callbackAndCancelSASVerification(verState, transactionID, reason, event.VerificationCancelByUser) + return mach.callbackAndCancelSASVerification(ctx, verState, transactionID, reason, event.VerificationCancelByUser) } // SendSASVerificationCancel is used to manually send a SAS cancel message process with the given reason and cancellation code. -func (mach *OlmMachine) SendSASVerificationCancel(userID id.UserID, deviceID id.DeviceID, transactionID string, reason string, code event.VerificationCancelCode) error { +func (mach *OlmMachine) SendSASVerificationCancel(ctx context.Context, userID id.UserID, deviceID id.DeviceID, transactionID string, reason string, code event.VerificationCancelCode) error { content := &event.VerificationCancelEventContent{ TransactionID: transactionID, Reason: reason, Code: code, } - return mach.sendToOneDevice(userID, deviceID, event.ToDeviceVerificationCancel, content) + return mach.sendToOneDevice(ctx, userID, deviceID, event.ToDeviceVerificationCancel, content) } // SendSASVerificationStart is used to manually send the SAS verification start message to another device. -func (mach *OlmMachine) SendSASVerificationStart(toUserID id.UserID, toDeviceID id.DeviceID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) { +func (mach *OlmMachine) SendSASVerificationStart(ctx context.Context, toUserID id.UserID, toDeviceID id.DeviceID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) { sasMethods := make([]event.SASMethod, len(methods)) for i, method := range methods { sasMethods[i] = method.Type() @@ -696,14 +696,14 @@ func (mach *OlmMachine) SendSASVerificationStart(toUserID id.UserID, toDeviceID MessageAuthenticationCodes: []event.MACMethod{event.HKDFHMACSHA256}, ShortAuthenticationString: sasMethods, } - return content, mach.sendToOneDevice(toUserID, toDeviceID, event.ToDeviceVerificationStart, content) + return content, mach.sendToOneDevice(ctx, toUserID, toDeviceID, event.ToDeviceVerificationStart, content) } // SendSASVerificationAccept is used to manually send an accept for a SAS verification process from a received m.key.verification.start event. -func (mach *OlmMachine) SendSASVerificationAccept(fromUser id.UserID, startEvent *event.VerificationStartEventContent, publicKey []byte, methods []VerificationMethod) error { +func (mach *OlmMachine) SendSASVerificationAccept(ctx context.Context, fromUser id.UserID, startEvent *event.VerificationStartEventContent, publicKey []byte, methods []VerificationMethod) error { if startEvent.Method != event.VerificationMethodSAS { reason := "Unknown verification method: " + string(startEvent.Method) - if err := mach.SendSASVerificationCancel(fromUser, startEvent.FromDevice, startEvent.TransactionID, reason, event.VerificationCancelUnknownMethod); err != nil { + if err := mach.SendSASVerificationCancel(ctx, fromUser, startEvent.FromDevice, startEvent.TransactionID, reason, event.VerificationCancelUnknownMethod); err != nil { return err } return ErrUnknownVerificationMethod @@ -730,25 +730,25 @@ func (mach *OlmMachine) SendSASVerificationAccept(fromUser id.UserID, startEvent ShortAuthenticationString: sasMethods, Commitment: hash, } - return mach.sendToOneDevice(fromUser, startEvent.FromDevice, event.ToDeviceVerificationAccept, content) + return mach.sendToOneDevice(ctx, fromUser, startEvent.FromDevice, event.ToDeviceVerificationAccept, content) } -func (mach *OlmMachine) callbackAndCancelSASVerification(verState *verificationState, transactionID, reason string, code event.VerificationCancelCode) error { +func (mach *OlmMachine) callbackAndCancelSASVerification(ctx context.Context, verState *verificationState, transactionID, reason string, code event.VerificationCancelCode) error { go verState.hooks.OnCancel(true, reason, code) - return mach.SendSASVerificationCancel(verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, reason, code) + return mach.SendSASVerificationCancel(ctx, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, reason, code) } // SendSASVerificationKey sends the ephemeral public key for a device to the partner device. -func (mach *OlmMachine) SendSASVerificationKey(userID id.UserID, deviceID id.DeviceID, transactionID string, key string) error { +func (mach *OlmMachine) SendSASVerificationKey(ctx context.Context, userID id.UserID, deviceID id.DeviceID, transactionID string, key string) error { content := &event.VerificationKeyEventContent{ TransactionID: transactionID, Key: key, } - return mach.sendToOneDevice(userID, deviceID, event.ToDeviceVerificationKey, content) + return mach.sendToOneDevice(ctx, userID, deviceID, event.ToDeviceVerificationKey, content) } // SendSASVerificationMAC is use the MAC of a device's key to the partner device. -func (mach *OlmMachine) SendSASVerificationMAC(userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error { +func (mach *OlmMachine) SendSASVerificationMAC(ctx context.Context, userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error { keyID := id.NewKeyID(id.KeyAlgorithmEd25519, mach.Client.DeviceID.String()) signingKey := mach.account.SigningKey() @@ -784,7 +784,7 @@ func (mach *OlmMachine) SendSASVerificationMAC(userID id.UserID, deviceID id.Dev Mac: macMap, } - return mach.sendToOneDevice(userID, deviceID, event.ToDeviceVerificationMAC, content) + return mach.sendToOneDevice(ctx, userID, deviceID, event.ToDeviceVerificationMAC, content) } func commonSASMethods(hooks VerificationHooks, otherDeviceMethods []event.SASMethod) []VerificationMethod { diff --git a/crypto/verification_in_room.go b/crypto/verification_in_room.go index cc9b9212..325b45ba 100644 --- a/crypto/verification_in_room.go +++ b/crypto/verification_in_room.go @@ -38,6 +38,7 @@ func (mach *OlmMachine) ProcessInRoomVerification(evt *event.Event) error { return ErrNoRelatesTo } + ctx := context.Background() switch content := evt.Content.Parsed.(type) { case *event.MessageEventContent: if content.MsgType == event.MsgVerificationRequest { @@ -54,18 +55,18 @@ func (mach *OlmMachine) ProcessInRoomVerification(evt *event.Event) error { Timestamp: evt.Timestamp, TransactionID: evt.ID.String(), } - mach.handleVerificationRequest(evt.Sender, newContent, evt.ID.String(), evt.RoomID) + mach.handleVerificationRequest(ctx, evt.Sender, newContent, evt.ID.String(), evt.RoomID) } case *event.VerificationStartEventContent: - mach.handleVerificationStart(evt.Sender, content, content.RelatesTo.EventID.String(), 10*time.Minute, evt.RoomID) + mach.handleVerificationStart(ctx, evt.Sender, content, content.RelatesTo.EventID.String(), 10*time.Minute, evt.RoomID) case *event.VerificationReadyEventContent: - mach.handleInRoomVerificationReady(evt.Sender, evt.RoomID, content, content.RelatesTo.EventID.String()) + mach.handleInRoomVerificationReady(ctx, evt.Sender, evt.RoomID, content, content.RelatesTo.EventID.String()) case *event.VerificationAcceptEventContent: - mach.handleVerificationAccept(evt.Sender, content, content.RelatesTo.EventID.String()) + mach.handleVerificationAccept(ctx, evt.Sender, content, content.RelatesTo.EventID.String()) case *event.VerificationKeyEventContent: - mach.handleVerificationKey(evt.Sender, content, content.RelatesTo.EventID.String()) + mach.handleVerificationKey(ctx, evt.Sender, content, content.RelatesTo.EventID.String()) case *event.VerificationMacEventContent: - mach.handleVerificationMAC(evt.Sender, content, content.RelatesTo.EventID.String()) + mach.handleVerificationMAC(ctx, evt.Sender, content, content.RelatesTo.EventID.String()) case *event.VerificationCancelEventContent: mach.handleVerificationCancel(evt.Sender, content, content.RelatesTo.EventID.String()) } @@ -73,7 +74,7 @@ func (mach *OlmMachine) ProcessInRoomVerification(evt *event.Event) error { } // SendInRoomSASVerificationCancel is used to manually send an in-room SAS cancel message process with the given reason and cancellation code. -func (mach *OlmMachine) SendInRoomSASVerificationCancel(roomID id.RoomID, userID id.UserID, transactionID string, reason string, code event.VerificationCancelCode) error { +func (mach *OlmMachine) SendInRoomSASVerificationCancel(ctx context.Context, roomID id.RoomID, userID id.UserID, transactionID string, reason string, code event.VerificationCancelCode) error { content := &event.VerificationCancelEventContent{ RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)}, Reason: reason, @@ -81,16 +82,16 @@ func (mach *OlmMachine) SendInRoomSASVerificationCancel(roomID id.RoomID, userID To: userID, } - encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationCancel, content) + encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationCancel, content) if err != nil { return err } - _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) + _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) return err } // SendInRoomSASVerificationRequest is used to manually send an in-room SAS verification request message to another user. -func (mach *OlmMachine) SendInRoomSASVerificationRequest(roomID id.RoomID, toUserID id.UserID, methods []VerificationMethod) (string, error) { +func (mach *OlmMachine) SendInRoomSASVerificationRequest(ctx context.Context, roomID id.RoomID, toUserID id.UserID, methods []VerificationMethod) (string, error) { content := &event.MessageEventContent{ MsgType: event.MsgVerificationRequest, FromDevice: mach.Client.DeviceID, @@ -98,11 +99,11 @@ func (mach *OlmMachine) SendInRoomSASVerificationRequest(roomID id.RoomID, toUse To: toUserID, } - encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.EventMessage, content) + encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.EventMessage, content) if err != nil { return "", err } - resp, err := mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) + resp, err := mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) if err != nil { return "", err } @@ -110,23 +111,23 @@ func (mach *OlmMachine) SendInRoomSASVerificationRequest(roomID id.RoomID, toUse } // SendInRoomSASVerificationReady is used to manually send an in-room SAS verification ready message to another user. -func (mach *OlmMachine) SendInRoomSASVerificationReady(roomID id.RoomID, transactionID string) error { +func (mach *OlmMachine) SendInRoomSASVerificationReady(ctx context.Context, roomID id.RoomID, transactionID string) error { content := &event.VerificationReadyEventContent{ FromDevice: mach.Client.DeviceID, Methods: []event.VerificationMethod{event.VerificationMethodSAS}, RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)}, } - encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationReady, content) + encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationReady, content) if err != nil { return err } - _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) + _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) return err } // SendInRoomSASVerificationStart is used to manually send the in-room SAS verification start message to another user. -func (mach *OlmMachine) SendInRoomSASVerificationStart(roomID id.RoomID, toUserID id.UserID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) { +func (mach *OlmMachine) SendInRoomSASVerificationStart(ctx context.Context, roomID id.RoomID, toUserID id.UserID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) { sasMethods := make([]event.SASMethod, len(methods)) for i, method := range methods { sasMethods[i] = method.Type() @@ -142,19 +143,19 @@ func (mach *OlmMachine) SendInRoomSASVerificationStart(roomID id.RoomID, toUserI To: toUserID, } - encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationStart, content) + encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationStart, content) if err != nil { return nil, err } - _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) + _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) return content, err } // SendInRoomSASVerificationAccept is used to manually send an accept for an in-room SAS verification process from a received m.key.verification.start event. -func (mach *OlmMachine) SendInRoomSASVerificationAccept(roomID id.RoomID, fromUser id.UserID, startEvent *event.VerificationStartEventContent, transactionID string, publicKey []byte, methods []VerificationMethod) error { +func (mach *OlmMachine) SendInRoomSASVerificationAccept(ctx context.Context, roomID id.RoomID, fromUser id.UserID, startEvent *event.VerificationStartEventContent, transactionID string, publicKey []byte, methods []VerificationMethod) error { if startEvent.Method != event.VerificationMethodSAS { reason := "Unknown verification method: " + string(startEvent.Method) - if err := mach.SendInRoomSASVerificationCancel(roomID, fromUser, transactionID, reason, event.VerificationCancelUnknownMethod); err != nil { + if err := mach.SendInRoomSASVerificationCancel(ctx, roomID, fromUser, transactionID, reason, event.VerificationCancelUnknownMethod); err != nil { return err } return ErrUnknownVerificationMethod @@ -183,32 +184,32 @@ func (mach *OlmMachine) SendInRoomSASVerificationAccept(roomID id.RoomID, fromUs To: fromUser, } - encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationAccept, content) + encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationAccept, content) if err != nil { return err } - _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) + _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) return err } // SendInRoomSASVerificationKey sends the ephemeral public key for a device to the partner device for an in-room verification. -func (mach *OlmMachine) SendInRoomSASVerificationKey(roomID id.RoomID, userID id.UserID, transactionID string, key string) error { +func (mach *OlmMachine) SendInRoomSASVerificationKey(ctx context.Context, roomID id.RoomID, userID id.UserID, transactionID string, key string) error { content := &event.VerificationKeyEventContent{ RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)}, Key: key, To: userID, } - encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationKey, content) + encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationKey, content) if err != nil { return err } - _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) + _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) return err } // SendInRoomSASVerificationMAC sends the MAC of a device's key to the partner device for an in-room verification. -func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error { +func (mach *OlmMachine) SendInRoomSASVerificationMAC(ctx context.Context, roomID id.RoomID, userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error { keyID := id.NewKeyID(id.KeyAlgorithmEd25519, mach.Client.DeviceID.String()) signingKey := mach.account.SigningKey() @@ -245,28 +246,28 @@ func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id To: userID, } - encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationMAC, content) + encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationMAC, content) if err != nil { return err } - _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) + _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) return err } // NewInRoomSASVerificationWith starts the in-room SAS verification process with another user in the given room. // It returns the generated transaction ID. -func (mach *OlmMachine) NewInRoomSASVerificationWith(inRoomID id.RoomID, userID id.UserID, hooks VerificationHooks, timeout time.Duration) (string, error) { - return mach.newInRoomSASVerificationWithInner(inRoomID, &id.Device{UserID: userID}, hooks, "", timeout) +func (mach *OlmMachine) NewInRoomSASVerificationWith(ctx context.Context, inRoomID id.RoomID, userID id.UserID, hooks VerificationHooks, timeout time.Duration) (string, error) { + return mach.newInRoomSASVerificationWithInner(ctx, inRoomID, &id.Device{UserID: userID}, hooks, "", timeout) } -func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) { +func (mach *OlmMachine) newInRoomSASVerificationWithInner(ctx context.Context, inRoomID id.RoomID, device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) { mach.Log.Debug().Msgf("Starting new in-room verification transaction user %v", device.UserID) request := transactionID == "" if request { var err error // get new transaction ID from the request message event ID - transactionID, err = mach.SendInRoomSASVerificationRequest(inRoomID, device.UserID, hooks.VerificationMethods()) + transactionID, err = mach.SendInRoomSASVerificationRequest(ctx, inRoomID, device.UserID, hooks.VerificationMethods()) if err != nil { return "", err } @@ -286,7 +287,7 @@ func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, de if !request { // start in-room verification - startEvent, err := mach.SendInRoomSASVerificationStart(inRoomID, device.UserID, transactionID, hooks.VerificationMethods()) + startEvent, err := mach.SendInRoomSASVerificationStart(ctx, inRoomID, device.UserID, transactionID, hooks.VerificationMethods()) if err != nil { return "", err } @@ -305,19 +306,19 @@ func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, de mach.keyVerificationTransactionState.Store(device.UserID.String()+":"+transactionID, verState) - mach.timeoutAfter(verState, transactionID, timeout) + mach.timeoutAfter(ctx, verState, transactionID, timeout) return transactionID, nil } -func (mach *OlmMachine) handleInRoomVerificationReady(userID id.UserID, roomID id.RoomID, content *event.VerificationReadyEventContent, transactionID string) { - device, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice) +func (mach *OlmMachine) handleInRoomVerificationReady(ctx context.Context, userID id.UserID, roomID id.RoomID, content *event.VerificationReadyEventContent, transactionID string) { + device, err := mach.GetOrFetchDevice(ctx, userID, content.FromDevice) if err != nil { mach.Log.Error().Msgf("Error fetching device %v of user %v: %v", content.FromDevice, userID, err) return } - verState, err := mach.getTransactionState(transactionID, userID) + verState, err := mach.getTransactionState(ctx, transactionID, userID) if err != nil { mach.Log.Error().Msgf("Error getting transaction state: %v", err) return @@ -327,7 +328,7 @@ func (mach *OlmMachine) handleInRoomVerificationReady(userID id.UserID, roomID i if mach.Client.UserID < userID { // up to us to send the start message verState.lock.Lock() - mach.newInRoomSASVerificationWithInner(roomID, device, verState.hooks, transactionID, 10*time.Minute) + mach.newInRoomSASVerificationWithInner(ctx, roomID, device, verState.hooks, transactionID, 10*time.Minute) verState.lock.Unlock() } } diff --git a/synapseadmin/register.go b/synapseadmin/register.go index 36b310a9..d7a94f6f 100644 --- a/synapseadmin/register.go +++ b/synapseadmin/register.go @@ -73,11 +73,10 @@ func (req *ReqSharedSecretRegister) Sign(secret string) string { // This does not need to be called manually as SharedSecretRegister will automatically call this if no nonce is provided. func (cli *Client) GetRegisterNonce(ctx context.Context) (string, error) { var resp respGetRegisterNonce - _, err := cli.MakeFullRequest(mautrix.FullRequest{ + _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ Method: http.MethodGet, URL: cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), ResponseJSON: &resp, - Context: ctx, }) if err != nil { return "", err @@ -98,12 +97,11 @@ func (cli *Client) SharedSecretRegister(ctx context.Context, sharedSecret string } req.SHA1Checksum = req.Sign(sharedSecret) var resp mautrix.RespRegister - _, err = cli.MakeFullRequest(mautrix.FullRequest{ + _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ Method: http.MethodPost, URL: cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), RequestJSON: req, ResponseJSON: &resp, - Context: ctx, }) if err != nil { return nil, err diff --git a/synapseadmin/userapi.go b/synapseadmin/userapi.go index ee457abc..aa1ce2a7 100644 --- a/synapseadmin/userapi.go +++ b/synapseadmin/userapi.go @@ -33,11 +33,10 @@ type ReqResetPassword struct { // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#reset-password func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) error { reqURL := cli.BuildAdminURL("v1", "reset_password", req.UserID) - _, err := cli.MakeFullRequest(mautrix.FullRequest{ + _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ Method: http.MethodPost, URL: reqURL, RequestJSON: &req, - Context: ctx, }) return err } @@ -50,11 +49,10 @@ func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) erro // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#check-username-availability func (cli *Client) UsernameAvailable(ctx context.Context, username string) (resp *mautrix.RespRegisterAvailable, err error) { u := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "username_available"}, map[string]string{"username": username}) - _, err = cli.MakeFullRequest(mautrix.FullRequest{ + _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ Method: http.MethodGet, URL: u, ResponseJSON: &resp, - Context: ctx, }) if err == nil && !resp.Available { err = fmt.Errorf(`request returned OK status without "available": true`) @@ -76,11 +74,10 @@ type RespListDevices struct { // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#list-all-devices func (cli *Client) ListDevices(ctx context.Context, userID id.UserID) (resp *RespListDevices, err error) { - _, err = cli.MakeFullRequest(mautrix.FullRequest{ + _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ Method: http.MethodGet, URL: cli.BuildAdminURL("v2", "users", userID, "devices"), ResponseJSON: &resp, - Context: ctx, }) return } @@ -105,11 +102,10 @@ type RespUserInfo struct { // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#query-user-account func (cli *Client) GetUserInfo(ctx context.Context, userID id.UserID) (resp *RespUserInfo, err error) { - _, err = cli.MakeFullRequest(mautrix.FullRequest{ + _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ Method: http.MethodGet, URL: cli.BuildAdminURL("v2", "users", userID), ResponseJSON: &resp, - Context: ctx, }) return } diff --git a/syncstore.go b/syncstore.go index d5fe2db4..8b5b3a55 100644 --- a/syncstore.go +++ b/syncstore.go @@ -1,21 +1,25 @@ package mautrix import ( + "context" "errors" "maunium.net/go/mautrix/id" ) +var _ SyncStore = (*MemorySyncStore)(nil) +var _ SyncStore = (*AccountDataStore)(nil) + // SyncStore is an interface which must be satisfied to store client data. // // You can either write a struct which persists this data to disk, or you can use the // provided "MemorySyncStore" which just keeps data around in-memory which is lost on // restarts. type SyncStore interface { - SaveFilterID(userID id.UserID, filterID string) - LoadFilterID(userID id.UserID) string - SaveNextBatch(userID id.UserID, nextBatchToken string) - LoadNextBatch(userID id.UserID) string + SaveFilterID(ctx context.Context, userID id.UserID, filterID string) + LoadFilterID(ctx context.Context, userID id.UserID) string + SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) + LoadNextBatch(ctx context.Context, userID id.UserID) string } // Deprecated: renamed to SyncStore @@ -32,22 +36,22 @@ type MemorySyncStore struct { } // SaveFilterID to memory. -func (s *MemorySyncStore) SaveFilterID(userID id.UserID, filterID string) { +func (s *MemorySyncStore) SaveFilterID(ctx context.Context, userID id.UserID, filterID string) { s.Filters[userID] = filterID } // LoadFilterID from memory. -func (s *MemorySyncStore) LoadFilterID(userID id.UserID) string { +func (s *MemorySyncStore) LoadFilterID(ctx context.Context, userID id.UserID) string { return s.Filters[userID] } // SaveNextBatch to memory. -func (s *MemorySyncStore) SaveNextBatch(userID id.UserID, nextBatchToken string) { +func (s *MemorySyncStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) { s.NextBatch[userID] = nextBatchToken } // LoadNextBatch from memory. -func (s *MemorySyncStore) LoadNextBatch(userID id.UserID) string { +func (s *MemorySyncStore) LoadNextBatch(ctx context.Context, userID id.UserID) string { return s.NextBatch[userID] } @@ -72,21 +76,21 @@ type accountData struct { NextBatch string `json:"next_batch"` } -func (s *AccountDataStore) SaveFilterID(userID id.UserID, filterID string) { +func (s *AccountDataStore) SaveFilterID(ctx context.Context, userID id.UserID, filterID string) { if userID.String() != s.client.UserID.String() { panic("AccountDataStore must only be used with a single account") } s.FilterID = filterID } -func (s *AccountDataStore) LoadFilterID(userID id.UserID) string { +func (s *AccountDataStore) LoadFilterID(ctx context.Context, userID id.UserID) string { if userID.String() != s.client.UserID.String() { panic("AccountDataStore must only be used with a single account") } return s.FilterID } -func (s *AccountDataStore) SaveNextBatch(userID id.UserID, nextBatchToken string) { +func (s *AccountDataStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) { if userID.String() != s.client.UserID.String() { panic("AccountDataStore must only be used with a single account") } else if nextBatchToken == s.nextBatch { @@ -97,7 +101,7 @@ func (s *AccountDataStore) SaveNextBatch(userID id.UserID, nextBatchToken string NextBatch: nextBatchToken, } - err := s.client.SetAccountData(s.EventType, data) + err := s.client.SetAccountData(ctx, s.EventType, data) if err != nil { s.client.Log.Warn().Err(err).Msg("Failed to save next batch token to account data") } else { @@ -109,14 +113,14 @@ func (s *AccountDataStore) SaveNextBatch(userID id.UserID, nextBatchToken string } } -func (s *AccountDataStore) LoadNextBatch(userID id.UserID) string { +func (s *AccountDataStore) LoadNextBatch(ctx context.Context, userID id.UserID) string { if userID.String() != s.client.UserID.String() { panic("AccountDataStore must only be used with a single account") } data := &accountData{} - err := s.client.GetAccountData(s.EventType, data) + err := s.client.GetAccountData(ctx, s.EventType, data) if err != nil { if errors.Is(err, MNotFound) { s.client.Log.Debug().Msg("No next batch token found in account data") From 02b780e808257044eefcbb9112144848ec4c241f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 15 Dec 2023 16:23:52 +0200 Subject: [PATCH 0040/1647] Update changelog --- CHANGELOG.md | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c2fcb7f6..ff29d6bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,15 +2,20 @@ * **Breaking change *(bridge)*** Added raw event to portal membership handling functions. -* *(crypto)* Added `goolm` build tag to use a pure Go implementation of Olm - instead of using libolm via cgo. Thanks to [@DerLukas15] in [#106]. +* **Breaking change *(client)*** Added context parameters to all functions ++ (thanks to [@recht] in [#144]). +* *(crypto)* Added experimental pure Go Olm implementation to replace libolm + Thanks to [@DerLukas15] in [#106]. + * You can use the `goolm` build tag to the new implementation. * *(bridge)* Added context parameter for bridge command events. * *(client)* Changed default syncer to not drop unknown events. * The syncer will still drop known events if parsing the content fails. * The behavior can be changed by changing the `ParseErrorHandler` function. [@DerLukas15]: https://github.com/DerLukas15 +[@recht]: https://github.com/recht [#106]: https://github.com/mautrix/go/pull/106 +[#144]: https://github.com/mautrix/go/pull/144 ## v0.16.2 (2023-11-16) From 96b40011068f3b378afe460a4839fdd8c62a0d9b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 15 Dec 2023 16:25:28 +0200 Subject: [PATCH 0041/1647] Update example again --- example/go.mod | 4 ++-- example/go.sum | 4 ++-- example/main.go | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/example/go.mod b/example/go.mod index 96cfba90..01790dc4 100644 --- a/example/go.mod +++ b/example/go.mod @@ -1,12 +1,12 @@ module maunium.net/go/mautrix/example -go 1.19 +go 1.20 require ( github.com/chzyer/readline v1.5.1 github.com/mattn/go-sqlite3 v1.14.18 github.com/rs/zerolog v1.31.0 - maunium.net/go/mautrix v0.16.3-0.20231215135638-893afc725981 + maunium.net/go/mautrix v0.16.3-0.20231215142331-753cdb2e1cb0 ) require ( diff --git a/example/go.sum b/example/go.sum index 7ceec40a..c292f262 100644 --- a/example/go.sum +++ b/example/go.sum @@ -47,5 +47,5 @@ golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= -maunium.net/go/mautrix v0.16.3-0.20231215135638-893afc725981 h1:KrMb5QyuMh2ZqfRLFiKRhYRV95aPCpaAepOful5EHvg= -maunium.net/go/mautrix v0.16.3-0.20231215135638-893afc725981/go.mod h1:YL4l4rZB46/vj/ifRMEjcibbvHjgxHftOF1SgmruLu4= +maunium.net/go/mautrix v0.16.3-0.20231215142331-753cdb2e1cb0 h1:2ZWtBcTScQfMwpcoGeY4mLYXC6OmYN/4Qh2yhBiVNV4= +maunium.net/go/mautrix v0.16.3-0.20231215142331-753cdb2e1cb0/go.mod h1:YL4l4rZB46/vj/ifRMEjcibbvHjgxHftOF1SgmruLu4= diff --git a/example/main.go b/example/main.go index f6b30602..aaa94a3c 100644 --- a/example/main.go +++ b/example/main.go @@ -74,7 +74,7 @@ func main() { }) syncer.OnEventType(event.StateMember, func(source mautrix.EventSource, evt *event.Event) { if evt.GetStateKey() == client.UserID.String() && evt.Content.AsMember().Membership == event.MembershipInvite { - _, err := client.JoinRoomByID(evt.RoomID) + _, err := client.JoinRoomByID(context.TODO(), evt.RoomID) if err == nil { lastRoomID = evt.RoomID rl.SetPrompt(fmt.Sprintf("%s> ", lastRoomID)) @@ -137,7 +137,7 @@ func main() { log.Error().Msg("Wait for an incoming message before sending messages") continue } - resp, err := client.SendText(lastRoomID, line) + resp, err := client.SendText(context.TODO(), lastRoomID, line) if err != nil { log.Error().Err(err).Msg("Failed to send event") } else { From 728018c9d050d4ba2ac262cb0a18074e6d572c64 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 15 Dec 2023 16:26:51 +0200 Subject: [PATCH 0042/1647] Update changelog again --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ff29d6bc..e6b61ede 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,9 +3,9 @@ * **Breaking change *(bridge)*** Added raw event to portal membership handling functions. * **Breaking change *(client)*** Added context parameters to all functions -+ (thanks to [@recht] in [#144]). + (thanks to [@recht] in [#144]). * *(crypto)* Added experimental pure Go Olm implementation to replace libolm - Thanks to [@DerLukas15] in [#106]. + (thanks to [@DerLukas15] in [#106]). * You can use the `goolm` build tag to the new implementation. * *(bridge)* Added context parameter for bridge command events. * *(client)* Changed default syncer to not drop unknown events. From db66b4f5d0f0923488a2f7a3b0cfb3fbc0aff336 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 19 Dec 2023 15:40:32 +0200 Subject: [PATCH 0043/1647] Don't copy Python log configs as-is --- bridge/bridgeconfig/config.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bridge/bridgeconfig/config.go b/bridge/bridgeconfig/config.go index be42aab3..5f578008 100644 --- a/bridge/bridgeconfig/config.go +++ b/bridge/bridgeconfig/config.go @@ -264,6 +264,10 @@ func doUpgrade(helper *up.Helper) { if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "print_level") != nil || helper.GetNode("logging", "file_name_format") != nil) { _, _ = fmt.Fprintln(os.Stderr, "Migrating legacy log config") migrateLegacyLogConfig(helper) + } else if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "handlers") != nil) { + _, _ = fmt.Fprintln(os.Stderr, "Migrating Python log config is not currently supported") + // TODO implement? + //migratePythonLogConfig(helper) } else { helper.Copy(up.Map, "logging") } From 02e4140236c2868427097bff076b492a67c244ac Mon Sep 17 00:00:00 2001 From: Joakim Recht Date: Fri, 22 Dec 2023 12:56:45 +0100 Subject: [PATCH 0044/1647] Make funcs in the SyncStore interface return errors This should have been done in #144, but I forgot it. When context is being propagated, the context might be cancelled at any point, which will result in an error that needs to be handled. --- client.go | 15 +++++++++++--- crypto/sql_store.go | 27 +++++++++++++++---------- syncstore.go | 49 +++++++++++++++++++++++++-------------------- 3 files changed, 55 insertions(+), 36 deletions(-) diff --git a/client.go b/client.go index 0aff8734..1c2b9683 100644 --- a/client.go +++ b/client.go @@ -174,8 +174,15 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { // We will keep syncing until the syncing state changes. Either because // Sync is called or StopSync is called. syncingID := cli.incrementSyncingID() - nextBatch := cli.Store.LoadNextBatch(ctx, cli.UserID) - filterID := cli.Store.LoadFilterID(ctx, cli.UserID) + nextBatch, err := cli.Store.LoadNextBatch(ctx, cli.UserID) + if err != nil { + return err + } + filterID, err := cli.Store.LoadFilterID(ctx, cli.UserID) + if err != nil { + return err + } + if filterID == "" { filterJSON := cli.Syncer.GetFilterJSON(cli.UserID) resFilter, err := cli.CreateFilter(ctx, filterJSON) @@ -183,7 +190,9 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { return err } filterID = resFilter.FilterID - cli.Store.SaveFilterID(ctx, cli.UserID, filterID) + if err := cli.Store.SaveFilterID(ctx, cli.UserID, filterID); err != nil { + return err + } } lastSuccessfulSync := time.Now().Add(-cli.StreamSyncMinAge - 1*time.Hour) for { diff --git a/crypto/sql_store.go b/crypto/sql_store.go index c73a859a..64d62bc4 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -88,22 +88,27 @@ func (store *SQLCryptoStore) GetNextBatch(ctx context.Context) (string, error) { var _ mautrix.SyncStore = (*SQLCryptoStore)(nil) -func (store *SQLCryptoStore) SaveFilterID(ctx context.Context, _ id.UserID, _ string) {} -func (store *SQLCryptoStore) LoadFilterID(ctx context.Context, _ id.UserID) string { return "" } - -func (store *SQLCryptoStore) SaveNextBatch(ctx context.Context, _ id.UserID, nextBatchToken string) { - err := store.PutNextBatch(ctx, nextBatchToken) - if err != nil { - // TODO handle error - } +func (store *SQLCryptoStore) SaveFilterID(ctx context.Context, _ id.UserID, _ string) error { + return nil +} +func (store *SQLCryptoStore) LoadFilterID(ctx context.Context, _ id.UserID) (string, error) { + return "", nil } -func (store *SQLCryptoStore) LoadNextBatch(ctx context.Context, _ id.UserID) string { +func (store *SQLCryptoStore) SaveNextBatch(ctx context.Context, _ id.UserID, nextBatchToken string) error { + err := store.PutNextBatch(ctx, nextBatchToken) + if err != nil { + return fmt.Errorf("unable to store batch: %w", err) + } + return nil +} + +func (store *SQLCryptoStore) LoadNextBatch(ctx context.Context, _ id.UserID) (string, error) { nb, err := store.GetNextBatch(ctx) if err != nil { - // TODO handle error + return "", fmt.Errorf("unable to load batch: %w", err) } - return nb + return nb, nil } func (store *SQLCryptoStore) FindDeviceID() (deviceID id.DeviceID) { diff --git a/syncstore.go b/syncstore.go index 8b5b3a55..427941b3 100644 --- a/syncstore.go +++ b/syncstore.go @@ -3,6 +3,7 @@ package mautrix import ( "context" "errors" + "fmt" "maunium.net/go/mautrix/id" ) @@ -16,10 +17,10 @@ var _ SyncStore = (*AccountDataStore)(nil) // provided "MemorySyncStore" which just keeps data around in-memory which is lost on // restarts. type SyncStore interface { - SaveFilterID(ctx context.Context, userID id.UserID, filterID string) - LoadFilterID(ctx context.Context, userID id.UserID) string - SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) - LoadNextBatch(ctx context.Context, userID id.UserID) string + SaveFilterID(ctx context.Context, userID id.UserID, filterID string) error + LoadFilterID(ctx context.Context, userID id.UserID) (string, error) + SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error + LoadNextBatch(ctx context.Context, userID id.UserID) (string, error) } // Deprecated: renamed to SyncStore @@ -36,23 +37,25 @@ type MemorySyncStore struct { } // SaveFilterID to memory. -func (s *MemorySyncStore) SaveFilterID(ctx context.Context, userID id.UserID, filterID string) { +func (s *MemorySyncStore) SaveFilterID(ctx context.Context, userID id.UserID, filterID string) error { s.Filters[userID] = filterID + return nil } // LoadFilterID from memory. -func (s *MemorySyncStore) LoadFilterID(ctx context.Context, userID id.UserID) string { - return s.Filters[userID] +func (s *MemorySyncStore) LoadFilterID(ctx context.Context, userID id.UserID) (string, error) { + return s.Filters[userID], nil } // SaveNextBatch to memory. -func (s *MemorySyncStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) { +func (s *MemorySyncStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error { s.NextBatch[userID] = nextBatchToken + return nil } // LoadNextBatch from memory. -func (s *MemorySyncStore) LoadNextBatch(ctx context.Context, userID id.UserID) string { - return s.NextBatch[userID] +func (s *MemorySyncStore) LoadNextBatch(ctx context.Context, userID id.UserID) (string, error) { + return s.NextBatch[userID], nil } // NewMemorySyncStore constructs a new MemorySyncStore. @@ -76,25 +79,26 @@ type accountData struct { NextBatch string `json:"next_batch"` } -func (s *AccountDataStore) SaveFilterID(ctx context.Context, userID id.UserID, filterID string) { +func (s *AccountDataStore) SaveFilterID(ctx context.Context, userID id.UserID, filterID string) error { if userID.String() != s.client.UserID.String() { panic("AccountDataStore must only be used with a single account") } s.FilterID = filterID + return nil } -func (s *AccountDataStore) LoadFilterID(ctx context.Context, userID id.UserID) string { +func (s *AccountDataStore) LoadFilterID(ctx context.Context, userID id.UserID) (string, error) { if userID.String() != s.client.UserID.String() { panic("AccountDataStore must only be used with a single account") } - return s.FilterID + return s.FilterID, nil } -func (s *AccountDataStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) { +func (s *AccountDataStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error { if userID.String() != s.client.UserID.String() { - panic("AccountDataStore must only be used with a single account") + return fmt.Errorf("AccountDataStore must only be used with a single account") } else if nextBatchToken == s.nextBatch { - return + return nil } data := accountData{ @@ -103,7 +107,7 @@ func (s *AccountDataStore) SaveNextBatch(ctx context.Context, userID id.UserID, err := s.client.SetAccountData(ctx, s.EventType, data) if err != nil { - s.client.Log.Warn().Err(err).Msg("Failed to save next batch token to account data") + return fmt.Errorf("failed to save next batch token to account data: %w", err) } else { s.client.Log.Debug(). Str("old_token", s.nextBatch). @@ -111,11 +115,12 @@ func (s *AccountDataStore) SaveNextBatch(ctx context.Context, userID id.UserID, Msg("Saved next batch token") s.nextBatch = nextBatchToken } + return nil } -func (s *AccountDataStore) LoadNextBatch(ctx context.Context, userID id.UserID) string { +func (s *AccountDataStore) LoadNextBatch(ctx context.Context, userID id.UserID) (string, error) { if userID.String() != s.client.UserID.String() { - panic("AccountDataStore must only be used with a single account") + return "", fmt.Errorf("AccountDataStore must only be used with a single account") } data := &accountData{} @@ -124,15 +129,15 @@ func (s *AccountDataStore) LoadNextBatch(ctx context.Context, userID id.UserID) if err != nil { if errors.Is(err, MNotFound) { s.client.Log.Debug().Msg("No next batch token found in account data") + return "", nil } else { - s.client.Log.Warn().Err(err).Msg("Failed to load next batch token from account data") + return "", fmt.Errorf("failed to load next batch token from account data: %w", err) } - return "" } s.nextBatch = data.NextBatch s.client.Log.Debug().Str("next_batch", data.NextBatch).Msg("Loaded next batch token from account data") - return s.nextBatch + return s.nextBatch, nil } // NewAccountDataStore returns a new AccountDataStore, which stores From 39844af48e42cd3e2410aa60fb1ae5ae7fbd8dca Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 28 Dec 2023 17:05:31 +0100 Subject: [PATCH 0045/1647] Log SQL line when bridge DB upgrade fails --- bridge/bridge.go | 10 +++++++--- go.mod | 4 ++-- go.sum | 8 ++++---- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/bridge/bridge.go b/bridge/bridge.go index 763cb4e0..9feea333 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -586,10 +586,14 @@ func (br *Bridge) init() { } func (br *Bridge) LogDBUpgradeErrorAndExit(name string, err error) { - br.ZLog.WithLevel(zerolog.FatalLevel). + logEvt := br.ZLog.WithLevel(zerolog.FatalLevel). Err(err). - Str("db_section", name). - Msg("Failed to initialize database") + Str("db_section", name) + var errWithLine *dbutil.PQErrorWithLine + if errors.As(err, &errWithLine) { + logEvt.Str("sql_line", errWithLine.Line) + } + logEvt.Msg("Failed to initialize database") if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { os.Exit(18) } else if errors.Is(err, dbutil.ErrForeignTables) { diff --git a/go.mod b/go.mod index 7fce5164..eb276954 100644 --- a/go.mod +++ b/go.mod @@ -12,10 +12,10 @@ require ( github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.6.0 - go.mau.fi/util v0.2.1 + go.mau.fi/util v0.2.2-0.20231228160422-22fdd4bbddeb go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.15.0 - golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa + golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 golang.org/x/net v0.18.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 diff --git a/go.sum b/go.sum index ce8aac82..cefb17f4 100644 --- a/go.sum +++ b/go.sum @@ -36,14 +36,14 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68= github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mau.fi/util v0.2.1 h1:eazulhFE/UmjOFtPrGg6zkF5YfAyiDzQb8ihLMbsPWw= -go.mau.fi/util v0.2.1/go.mod h1:MjlzCQEMzJ+G8RsPawHzpLB8rwTo3aPIjG5FzBvQT/c= +go.mau.fi/util v0.2.2-0.20231228160422-22fdd4bbddeb h1:Is+6vDKgINRy9KHodvi7NElxoDaWA8sc2S3cF3+QWjs= +go.mau.fi/util v0.2.2-0.20231228160422-22fdd4bbddeb/go.mod h1:tiBX6nxVSOjU89jVQ7wBh3P8KjM26Lv1k7/I5QdSvBw= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= -golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ= -golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= +golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 h1:+iq7lrkxmFNBM7xx+Rae2W6uyPfhPeDWD+n+JgppptE= +golang.org/x/exp v0.0.0-20231219180239-dc181d75b848/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg= golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From 30b4ccfdb67162047c8f7f8c0a31febb0b54b4e6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 28 Dec 2023 17:06:42 +0100 Subject: [PATCH 0046/1647] Update actions --- .github/workflows/go.yml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index c6f6e522..602d5ece 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -6,10 +6,10 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: go-version: "1.21" cache: true @@ -31,14 +31,15 @@ jobs: build: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: go-version: ["1.20", "1.21"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go ${{ matrix.go-version }} - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} cache: true From 2a83ec587623d5f8b2c8107b60a02a4086a96954 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 28 Dec 2023 17:09:20 +0100 Subject: [PATCH 0047/1647] Update dependencies --- go.mod | 12 ++++++------ go.sum | 26 +++++++++++++------------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/go.mod b/go.mod index eb276954..07e3efdc 100644 --- a/go.mod +++ b/go.mod @@ -6,17 +6,17 @@ require ( github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 - github.com/mattn/go-sqlite3 v1.14.18 + github.com/mattn/go-sqlite3 v1.14.19 github.com/rs/zerolog v1.31.0 github.com/stretchr/testify v1.8.4 github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.6.0 - go.mau.fi/util v0.2.2-0.20231228160422-22fdd4bbddeb + go.mau.fi/util v0.2.2-0.20231228160822-a6d40c214e80 go.mau.fi/zeroconfig v0.1.2 - golang.org/x/crypto v0.15.0 - golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 - golang.org/x/net v0.18.0 + golang.org/x/crypto v0.17.0 + golang.org/x/exp v0.0.0-20231226003508-02704c960a9b + golang.org/x/net v0.19.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 maunium.net/go/maulogger/v2 v2.4.1 @@ -30,6 +30,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.14.0 // indirect + golang.org/x/sys v0.15.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index cefb17f4..f52dc4cd 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,4 @@ -github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= +github.com/DATA-DOG/go-sqlmock v1.5.1 h1:FK6RCIUSfmbnI/imIICmboyQBkOckutaa6R5YYlLZyo= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -15,8 +15,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI= -github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI= +github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -36,21 +36,21 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68= github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mau.fi/util v0.2.2-0.20231228160422-22fdd4bbddeb h1:Is+6vDKgINRy9KHodvi7NElxoDaWA8sc2S3cF3+QWjs= -go.mau.fi/util v0.2.2-0.20231228160422-22fdd4bbddeb/go.mod h1:tiBX6nxVSOjU89jVQ7wBh3P8KjM26Lv1k7/I5QdSvBw= +go.mau.fi/util v0.2.2-0.20231228160822-a6d40c214e80 h1:zcfIxHgzZpgGSJv/FUVbOjO4ZWa12En4TGhxgUI/QH0= +go.mau.fi/util v0.2.2-0.20231228160822-a6d40c214e80/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= -golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= -golang.org/x/exp v0.0.0-20231219180239-dc181d75b848 h1:+iq7lrkxmFNBM7xx+Rae2W6uyPfhPeDWD+n+JgppptE= -golang.org/x/exp v0.0.0-20231219180239-dc181d75b848/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= -golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg= -golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= +golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= +golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/exp v0.0.0-20231226003508-02704c960a9b h1:kLiC65FbiHWFAOu+lxwNPujcsl8VYyTYYEZnsOO1WK4= +golang.org/x/exp v0.0.0-20231226003508-02704c960a9b/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= -golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From 370913378ad1f40ecba81be8c9d95b0043754b2a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 29 Dec 2023 21:16:57 +0100 Subject: [PATCH 0048/1647] Log full pq error details in bridge upgrades --- bridge/bridge.go | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/bridge/bridge.go b/bridge/bridge.go index 9feea333..6313b968 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -20,6 +20,7 @@ import ( "syscall" "time" + "github.com/lib/pq" "github.com/mattn/go-sqlite3" "github.com/rs/zerolog" deflog "github.com/rs/zerolog/log" @@ -585,6 +586,37 @@ func (br *Bridge) init() { br.Child.Init() } +type zerologPQError pq.Error + +func (zpe *zerologPQError) MarshalZerologObject(evt *zerolog.Event) { + maybeStr := func(field, value string) { + if value != "" { + evt.Str(field, value) + } + } + maybeStr("severity", zpe.Severity) + if name := zpe.Code.Name(); name != "" { + evt.Str("code", name) + } else if zpe.Code != "" { + evt.Str("code", string(zpe.Code)) + } + //maybeStr("message", zpe.Message) + maybeStr("detail", zpe.Detail) + maybeStr("hint", zpe.Hint) + maybeStr("position", zpe.Position) + maybeStr("internal_position", zpe.InternalPosition) + maybeStr("internal_query", zpe.InternalQuery) + maybeStr("where", zpe.Where) + maybeStr("schema", zpe.Schema) + maybeStr("table", zpe.Table) + maybeStr("column", zpe.Column) + maybeStr("data_type_name", zpe.DataTypeName) + maybeStr("constraint", zpe.Constraint) + maybeStr("file", zpe.File) + maybeStr("line", zpe.Line) + maybeStr("routine", zpe.Routine) +} + func (br *Bridge) LogDBUpgradeErrorAndExit(name string, err error) { logEvt := br.ZLog.WithLevel(zerolog.FatalLevel). Err(err). @@ -593,6 +625,10 @@ func (br *Bridge) LogDBUpgradeErrorAndExit(name string, err error) { if errors.As(err, &errWithLine) { logEvt.Str("sql_line", errWithLine.Line) } + var pqe *pq.Error + if errors.As(err, &pqe) { + logEvt.Object("pq_error", (*zerologPQError)(pqe)) + } logEvt.Msg("Failed to initialize database") if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { os.Exit(18) From d014d56e852009c332c5ad353404e5bd6fa61f5d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 3 Jan 2024 14:52:42 +0200 Subject: [PATCH 0049/1647] Add blurhash to file info struct --- event/message.go | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/event/message.go b/event/message.go index 009709af..6512f9be 100644 --- a/event/message.go +++ b/event/message.go @@ -199,10 +199,14 @@ type FileInfo struct { ThumbnailInfo *FileInfo `json:"thumbnail_info,omitempty"` ThumbnailURL id.ContentURIString `json:"thumbnail_url,omitempty"` ThumbnailFile *EncryptedFileInfo `json:"thumbnail_file,omitempty"` - Width int `json:"-"` - Height int `json:"-"` - Duration int `json:"-"` - Size int `json:"-"` + + Blurhash string `json:"blurhash,omitempty"` + AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"` + + Width int `json:"-"` + Height int `json:"-"` + Duration int `json:"-"` + Size int `json:"-"` } type serializableFileInfo struct { @@ -211,6 +215,9 @@ type serializableFileInfo struct { ThumbnailURL id.ContentURIString `json:"thumbnail_url,omitempty"` ThumbnailFile *EncryptedFileInfo `json:"thumbnail_file,omitempty"` + Blurhash string `json:"blurhash,omitempty"` + AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"` + Width json.Number `json:"w,omitempty"` Height json.Number `json:"h,omitempty"` Duration json.Number `json:"duration,omitempty"` @@ -226,6 +233,9 @@ func (sfi *serializableFileInfo) CopyFrom(fileInfo *FileInfo) *serializableFileI ThumbnailURL: fileInfo.ThumbnailURL, ThumbnailInfo: (&serializableFileInfo{}).CopyFrom(fileInfo.ThumbnailInfo), ThumbnailFile: fileInfo.ThumbnailFile, + + Blurhash: fileInfo.Blurhash, + AnoaBlurhash: fileInfo.AnoaBlurhash, } if fileInfo.Width > 0 { sfi.Width = json.Number(strconv.Itoa(fileInfo.Width)) @@ -252,6 +262,8 @@ func (sfi *serializableFileInfo) CopyTo(fileInfo *FileInfo) { MimeType: sfi.MimeType, ThumbnailURL: sfi.ThumbnailURL, ThumbnailFile: sfi.ThumbnailFile, + Blurhash: sfi.Blurhash, + AnoaBlurhash: sfi.AnoaBlurhash, } if sfi.ThumbnailInfo != nil { fileInfo.ThumbnailInfo = &FileInfo{} From 88631708a41b224bc8163ace132edcc5a8ddcea4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 4 Jan 2024 14:57:37 +0200 Subject: [PATCH 0050/1647] Add context to UpdateBridgeInfo --- bridge/bridge.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bridge/bridge.go b/bridge/bridge.go index 6313b968..960dce9a 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -61,7 +61,7 @@ type Portal interface { MainIntent() *appservice.IntentAPI ReceiveMatrixEvent(user User, evt *event.Event) - UpdateBridgeInfo() + UpdateBridgeInfo(ctx context.Context) } type MembershipHandlingPortal interface { @@ -720,7 +720,7 @@ func (br *Bridge) ResendBridgeInfo() { } br.ZLog.Info().Msg("Re-sending bridge info state event to all portals") for _, portal := range br.Child.GetAllIPortals() { - portal.UpdateBridgeInfo() + portal.UpdateBridgeInfo(context.TODO()) } br.ZLog.Info().Msg("Finished re-sending bridge info state events") } From c5f5135c966733d6e883f797e1368a62a4438c86 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 7 Jan 2024 15:35:10 +0200 Subject: [PATCH 0051/1647] Try downloading avatar before setting it for bot --- appservice/intent.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/appservice/intent.go b/appservice/intent.go index 348eee2a..f5f066d1 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -399,6 +399,13 @@ func (intent *IntentAPI) SetAvatarURL(ctx context.Context, avatarURL id.ContentU // No need to update return nil } + if !avatarURL.IsEmpty() { + // Some homeservers require the avatar to be downloaded before setting it + body, _ := intent.Client.Download(ctx, avatarURL) + if body != nil { + _ = body.Close() + } + } return intent.Client.SetAvatarURL(ctx, avatarURL) } From 48bfc596f048f373ed815e93ab511fa5e19f30cb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 7 Jan 2024 15:35:58 +0200 Subject: [PATCH 0052/1647] Fix incorrect context.Backgrounds --- appservice/http.go | 1 + bridge/commands/admin.go | 3 +-- bridge/commands/doublepuppet.go | 4 +--- bridge/commands/event.go | 8 ++++---- bridge/commands/handler.go | 4 +--- bridge/crypto.go | 17 +++++++---------- bridge/doublepuppet.go | 3 +-- bridge/matrix.go | 10 +++++----- crypto/cryptohelper/cryptohelper.go | 4 ++-- crypto/decryptolm.go | 2 +- crypto/machine.go | 6 +++--- crypto/verification_in_room.go | 2 +- 12 files changed, 28 insertions(+), 36 deletions(-) diff --git a/appservice/http.go b/appservice/http.go index e73bf621..2219687a 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -134,6 +134,7 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { return } log := as.Log.With().Str("transaction_id", txnID).Logger() + // Don't use request context, handling shouldn't be stopped even if the request times out ctx := context.Background() ctx = log.WithContext(ctx) if as.txnIDC.IsProcessed(txnID) { diff --git a/bridge/commands/admin.go b/bridge/commands/admin.go index d07ada1a..cf38d6c5 100644 --- a/bridge/commands/admin.go +++ b/bridge/commands/admin.go @@ -7,7 +7,6 @@ package commands import ( - "context" "strconv" "maunium.net/go/mautrix/id" @@ -58,7 +57,7 @@ func fnSetPowerLevel(ce *Event) { ce.Reply("**Usage:** `set-pl [user] `") return } - _, err = ce.Portal.MainIntent().SetPowerLevel(context.Background(), ce.RoomID, userID, level) + _, err = ce.Portal.MainIntent().SetPowerLevel(ce.Ctx, ce.RoomID, userID, level) if err != nil { ce.Reply("Failed to set power levels: %v", err) } diff --git a/bridge/commands/doublepuppet.go b/bridge/commands/doublepuppet.go index 9501d01f..3f074951 100644 --- a/bridge/commands/doublepuppet.go +++ b/bridge/commands/doublepuppet.go @@ -6,8 +6,6 @@ package commands -import "context" - var CommandLoginMatrix = &FullHandler{ Func: fnLoginMatrix, Name: "login-matrix", @@ -56,7 +54,7 @@ func fnPingMatrix(ce *Event) { ce.Reply("You are not logged in with your Matrix account.") return } - resp, err := puppet.CustomIntent().Whoami(context.Background()) + resp, err := puppet.CustomIntent().Whoami(ce.Ctx) if err != nil { ce.Reply("Failed to validate Matrix login: %v", err) } else { diff --git a/bridge/commands/event.go b/bridge/commands/event.go index 24cf2eb9..f1443d63 100644 --- a/bridge/commands/event.go +++ b/bridge/commands/event.go @@ -67,7 +67,7 @@ func (ce *Event) Reply(msg string, args ...interface{}) { func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { content := format.RenderMarkdown(msg, allowMarkdown, allowHTML) content.MsgType = event.MsgNotice - _, err := ce.MainIntent().SendMessageEvent(context.Background(), ce.RoomID, event.EventMessage, content) + _, err := ce.MainIntent().SendMessageEvent(ce.Ctx, ce.RoomID, event.EventMessage, content) if err != nil { ce.ZLog.Error().Err(err).Msgf("Failed to reply to command") } @@ -75,7 +75,7 @@ func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { // React sends a reaction to the command. func (ce *Event) React(key string) { - _, err := ce.MainIntent().SendReaction(context.Background(), ce.RoomID, ce.EventID, key) + _, err := ce.MainIntent().SendReaction(ce.Ctx, ce.RoomID, ce.EventID, key) if err != nil { ce.ZLog.Error().Err(err).Msgf("Failed to react to command") } @@ -83,7 +83,7 @@ func (ce *Event) React(key string) { // Redact redacts the command. func (ce *Event) Redact(req ...mautrix.ReqRedact) { - _, err := ce.MainIntent().RedactEvent(context.Background(), ce.RoomID, ce.EventID, req...) + _, err := ce.MainIntent().RedactEvent(ce.Ctx, ce.RoomID, ce.EventID, req...) if err != nil { ce.ZLog.Error().Err(err).Msgf("Failed to redact command") } @@ -91,7 +91,7 @@ func (ce *Event) Redact(req ...mautrix.ReqRedact) { // MarkRead marks the command event as read. func (ce *Event) MarkRead() { - err := ce.MainIntent().SendReceipt(context.Background(), ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil) + err := ce.MainIntent().SendReceipt(ce.Ctx, ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil) if err != nil { ce.ZLog.Error().Err(err).Msgf("Failed to mark command as read") } diff --git a/bridge/commands/handler.go b/bridge/commands/handler.go index cfed683b..ab6899c0 100644 --- a/bridge/commands/handler.go +++ b/bridge/commands/handler.go @@ -7,8 +7,6 @@ package commands import ( - "context" - "maunium.net/go/mautrix/bridge" "maunium.net/go/mautrix/bridge/bridgeconfig" "maunium.net/go/mautrix/event" @@ -78,7 +76,7 @@ func (fh *FullHandler) ShowInHelp(ce *Event) bool { } func (fh *FullHandler) userHasRoomPermission(ce *Event) bool { - levels, err := ce.MainIntent().PowerLevels(context.Background(), ce.RoomID) + levels, err := ce.MainIntent().PowerLevels(ce.Ctx, ce.RoomID) if err != nil { ce.ZLog.Warn().Err(err).Msg("Failed to check room power levels") ce.Reply("Failed to get room power levels to see if you're allowed to use that command") diff --git a/bridge/crypto.go b/bridge/crypto.go index 73e5dbf8..a1a76ebd 100644 --- a/bridge/crypto.go +++ b/bridge/crypto.go @@ -81,7 +81,7 @@ func (helper *CryptoHelper) Init() error { } var isExistingDevice bool - helper.client, isExistingDevice, err = helper.loginBot() + helper.client, isExistingDevice, err = helper.loginBot(context.TODO()) if err != nil { return err } @@ -128,16 +128,15 @@ func (helper *CryptoHelper) Init() error { return err } if isExistingDevice { - helper.verifyKeysAreOnServer() + helper.verifyKeysAreOnServer(context.TODO()) } - go helper.resyncEncryptionInfo() + go helper.resyncEncryptionInfo(context.TODO()) return nil } -func (helper *CryptoHelper) resyncEncryptionInfo() { - ctx := context.Background() +func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { log := helper.log.With().Str("action", "resync encryption event").Logger() rows, err := helper.bridge.DB.QueryContext(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) if err != nil { @@ -223,8 +222,7 @@ func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device } } -func (helper *CryptoHelper) loginBot() (*mautrix.Client, bool, error) { - ctx := context.Background() +func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool, error) { deviceID := helper.store.FindDeviceID() if len(deviceID) > 0 { helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database") @@ -256,8 +254,7 @@ func (helper *CryptoHelper) loginBot() (*mautrix.Client, bool, error) { return client, deviceID != "", nil } -func (helper *CryptoHelper) verifyKeysAreOnServer() { - ctx := context.Background() +func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) { helper.log.Debug().Msg("Making sure keys are still on server") resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ DeviceKeys: map[id.UserID]mautrix.DeviceIDList{ @@ -336,7 +333,7 @@ func (helper *CryptoHelper) Reset(startAfterReset bool) { helper.log.Debug().Msg("Crypto syncer stopped, clearing database") helper.clearDatabase() helper.log.Debug().Msg("Crypto database cleared, logging out of all sessions") - _, err := helper.client.LogoutAll(context.Background()) + _, err := helper.client.LogoutAll(context.TODO()) if err != nil { helper.log.Warn().Err(err).Msg("Failed to log out all devices") } diff --git a/bridge/doublepuppet.go b/bridge/doublepuppet.go index 35903efd..265d3d5c 100644 --- a/bridge/doublepuppet.go +++ b/bridge/doublepuppet.go @@ -118,12 +118,11 @@ var ( const useConfigASToken = "appservice-config" const asTokenModePrefix = "as_token:" -func (dp *doublePuppetUtil) Setup(mxid id.UserID, savedAccessToken string, reloginOnFail bool) (intent *appservice.IntentAPI, newAccessToken string, err error) { +func (dp *doublePuppetUtil) Setup(ctx context.Context, mxid id.UserID, savedAccessToken string, reloginOnFail bool) (intent *appservice.IntentAPI, newAccessToken string, err error) { if len(mxid) == 0 { err = ErrNoMXID return } - ctx := context.Background() _, homeserver, _ := mxid.Parse() loginSecret, hasSecret := dp.br.Config.Bridge.GetDoublePuppetConfig().SharedSecretMap[homeserver] // Special case appservice: prefix to not login and use it as an as_token directly. diff --git a/bridge/matrix.go b/bridge/matrix.go index f9a86d80..90453c13 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -87,7 +87,7 @@ func (mx *MatrixHandler) HandleEncryption(evt *event.Event) { Msg("Encryption was enabled in room") portal.MarkEncrypted() if portal.IsPrivateChat() { - err := mx.as.BotIntent().EnsureJoined(context.Background(), evt.RoomID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client}) + err := mx.as.BotIntent().EnsureJoined(context.TODO(), evt.RoomID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client}) if err != nil { mx.log.Err(err). Str("room_id", evt.RoomID.String()). @@ -237,7 +237,7 @@ func (mx *MatrixHandler) HandleMembership(evt *event.Event) { return } defer mx.TrackEventDuration(evt.Type)() - ctx := context.Background() + ctx := context.TODO() if mx.bridge.Crypto != nil { mx.bridge.Crypto.HandleMemberEvent(evt) @@ -481,7 +481,7 @@ func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) { return } content := evt.Content.AsEncrypted() - ctx := context.Background() + ctx := context.TODO() log := mx.log.With(). Str("event_id", evt.ID.String()). Str("session_id", content.SessionID.String()). @@ -526,7 +526,7 @@ func (mx *MatrixHandler) waitLongerForSession(ctx context.Context, evt *event.Ev Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())). Msg("Couldn't find session, requesting keys and waiting longer...") - go mx.bridge.Crypto.RequestSession(context.Background(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) + go mx.bridge.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) errorEventID := mx.sendCryptoStatusError(ctx, evt, "", fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), 1, false) if !mx.bridge.Crypto.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { @@ -553,7 +553,7 @@ func (mx *MatrixHandler) HandleMessage(evt *event.Event) { Str("room_id", evt.RoomID.String()). Str("sender", evt.Sender.String()). Logger() - ctx := log.WithContext(context.Background()) + ctx := log.WithContext(context.TODO()) if mx.shouldIgnoreEvent(evt) { return } else if !evt.Mautrix.WasEncrypted && mx.bridge.Config.Bridge.GetEncryptionConfig().Require { diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index 9d071ba9..293166fe 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -124,7 +124,7 @@ func (helper *CryptoHelper) Init() error { } else { stateStore = helper.client.StateStore.(crypto.StateStore) } - ctx := context.Background() + ctx := context.TODO() var cryptoStore crypto.Store if helper.unmanagedCryptoStore == nil { managedCryptoStore := crypto.NewSQLCryptoStore(helper.dbForManagedStores, dbutil.ZeroLogger(helper.log.With().Str("db_section", "crypto").Logger()), helper.DBAccountID, helper.client.DeviceID, helper.pickleKey) @@ -310,7 +310,7 @@ func (helper *CryptoHelper) waitLongerForSession(log zerolog.Logger, src mautrix content := evt.Content.AsEncrypted() log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...") - go helper.RequestSession(context.Background(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) + go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) if !helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { log.Debug().Msg("Didn't get session, giving up") diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index e5394e5f..57b39f0b 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -228,7 +228,7 @@ const MinUnwedgeInterval = 1 * time.Hour func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, senderKey id.SenderKey) { log = log.With().Str("action", "unwedge olm session").Logger() - ctx := log.WithContext(context.Background()) + ctx := log.WithContext(context.TODO()) mach.recentlyUnwedgedLock.Lock() prevUnwedge, ok := mach.recentlyUnwedged[senderKey] delta := time.Now().Sub(prevUnwedge) diff --git a/crypto/machine.go b/crypto/machine.go index 37a21da3..35f8c121 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -243,7 +243,7 @@ func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) { if otkCount.SignedCurve25519 < int(minCount) { traceID := time.Now().Format("15:04:05.000000") log := mach.Log.With().Str("trace_id", traceID).Logger() - ctx := log.WithContext(context.Background()) + ctx := log.WithContext(context.TODO()) log.Debug(). Int("keys_left", otkCount.Curve25519). Msg("Sync response said we have less than 50 signed curve25519 keys left, sharing new ones...") @@ -334,7 +334,7 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) { Str("sender", evt.Sender.String()). Str("type", evt.Type.Type). Logger() - ctx := log.WithContext(context.Background()) + ctx := log.WithContext(context.TODO()) if evt.Type != event.ToDeviceEncrypted { log.Debug().Msg("Starting handling to-device event") } @@ -344,7 +344,7 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) { Str("sender_key", content.SenderKey.String()). Logger() log.Debug().Msg("Handling encrypted to-device event") - ctx = log.WithContext(context.Background()) + ctx = log.WithContext(ctx) decryptedEvt, err := mach.decryptOlmEvent(ctx, evt) if err != nil { log.Error().Err(err).Msg("Failed to decrypt to-device event") diff --git a/crypto/verification_in_room.go b/crypto/verification_in_room.go index 325b45ba..240c52b2 100644 --- a/crypto/verification_in_room.go +++ b/crypto/verification_in_room.go @@ -38,7 +38,7 @@ func (mach *OlmMachine) ProcessInRoomVerification(evt *event.Event) error { return ErrNoRelatesTo } - ctx := context.Background() + ctx := context.TODO() switch content := evt.Content.Parsed.(type) { case *event.MessageEventContent: if content.MsgType == event.MsgVerificationRequest { From 25bc36bc7ae79afe8b5e5f053fbd8bc8fa68acbc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 7 Jan 2024 22:44:06 +0200 Subject: [PATCH 0053/1647] Add more contexts everywhere --- CHANGELOG.md | 4 +- appservice/appservice.go | 10 +- appservice/http.go | 2 +- appservice/intent.go | 53 +- appservice/registration.go | 3 +- bridge/bridge.go | 18 +- bridge/commands/admin.go | 2 +- bridge/crypto.go | 91 ++-- bridge/cryptostore.go | 6 +- bridge/matrix.go | 10 +- client.go | 113 +++-- crypto/account.go | 2 +- crypto/cross_sign_pubkey.go | 4 +- crypto/cross_sign_signing.go | 12 +- crypto/cross_sign_store.go | 10 +- crypto/cross_sign_test.go | 32 +- crypto/cross_sign_validation.go | 14 +- crypto/cryptohelper/cryptohelper.go | 44 +- crypto/decryptmegolm.go | 12 +- crypto/decryptolm.go | 10 +- crypto/devicelist.go | 33 +- crypto/encryptmegolm.go | 22 +- crypto/encryptolm.go | 12 +- crypto/keyimport.go | 16 +- crypto/keysharing.go | 15 +- crypto/machine.go | 54 +- crypto/machine_test.go | 22 +- crypto/sql_store.go | 465 ++++++++---------- crypto/sql_store_upgrade/upgrade.go | 3 +- crypto/store.go | 138 +++--- crypto/store_test.go | 34 +- crypto/verification.go | 4 +- go.mod | 2 +- go.sum | 4 +- sqlstatestore/statestore.go | 309 ++++++------ .../v05-mark-encryption-state-resync.go | 7 +- statestore.go | 134 ++--- 37 files changed, 886 insertions(+), 840 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e6b61ede..7abbe587 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,8 +2,8 @@ * **Breaking change *(bridge)*** Added raw event to portal membership handling functions. -* **Breaking change *(client)*** Added context parameters to all functions - (thanks to [@recht] in [#144]). +* **Breaking change *(everything)*** Added context parameters to all functions + (started by [@recht] in [#144]). * *(crypto)* Added experimental pure Go Olm implementation to replace libolm (thanks to [@DerLukas15] in [#106]). * You can use the `goolm` build tag to the new implementation. diff --git a/appservice/appservice.go b/appservice/appservice.go index 98d1463f..dc5e82be 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -93,12 +93,12 @@ type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{}) type StateStore interface { mautrix.StateStore - IsRegistered(userID id.UserID) bool - MarkRegistered(userID id.UserID) + IsRegistered(ctx context.Context, userID id.UserID) (bool, error) + MarkRegistered(ctx context.Context, userID id.UserID) error - GetPowerLevel(roomID id.RoomID, userID id.UserID) int - GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int - HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool + GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) + GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) + HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) } // AppService is the main config for all appservices. diff --git a/appservice/http.go b/appservice/http.go index 2219687a..1d4c7f22 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -236,7 +236,7 @@ func (as *AppService) handleEvents(ctx context.Context, evts []*event.Event, def } if evt.Type.IsState() { - mautrix.UpdateStateStore(as.StateStore, evt) + mautrix.UpdateStateStore(ctx, as.StateStore, evt) } var ch chan *event.Event if evt.Type.Class == event.ToDeviceEventType { diff --git a/appservice/intent.go b/appservice/intent.go index f5f066d1..bdf0f066 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -13,6 +13,8 @@ import ( "strings" "sync" + "github.com/rs/zerolog" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -57,17 +59,26 @@ func (intent *IntentAPI) Register(ctx context.Context) error { } func (intent *IntentAPI) EnsureRegistered(ctx context.Context) error { + if intent.IsCustomPuppet { + return nil + } intent.registerLock.Lock() defer intent.registerLock.Unlock() - if intent.IsCustomPuppet || intent.as.StateStore.IsRegistered(intent.UserID) { + isRegistered, err := intent.as.StateStore.IsRegistered(ctx, intent.UserID) + if err != nil { + return fmt.Errorf("failed to check if user is registered: %w", err) + } else if isRegistered { return nil } - err := intent.Register(ctx) + err = intent.Register(ctx) if err != nil && !errors.Is(err, mautrix.MUserInUse) { return fmt.Errorf("failed to ensure registered: %w", err) } - intent.as.StateStore.MarkRegistered(intent.UserID) + err = intent.as.StateStore.MarkRegistered(ctx, intent.UserID) + if err != nil { + return fmt.Errorf("failed to mark user as registered in state store: %w", err) + } return nil } @@ -83,7 +94,7 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext } else if len(extra) == 1 { params = extra[0] } - if intent.as.StateStore.IsInRoom(roomID, intent.UserID) && !params.IgnoreCache { + if intent.as.StateStore.IsInRoom(ctx, roomID, intent.UserID) && !params.IgnoreCache { return nil } @@ -111,7 +122,10 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext return fmt.Errorf("failed to ensure joined after invite: %w", err) } } - intent.as.StateStore.SetMembership(resp.RoomID, intent.UserID, event.MembershipJoin) + err = intent.as.StateStore.SetMembership(ctx, resp.RoomID, intent.UserID, event.MembershipJoin) + if err != nil { + return fmt.Errorf("failed to set membership in state store: %w", err) + } return nil } @@ -205,13 +219,14 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i Membership: membership, Reason: reason, } - memberContent, ok := intent.as.StateStore.TryGetMember(roomID, target) - if !ok { + memberContent, err := intent.as.StateStore.TryGetMember(ctx, roomID, target) + if err != nil { + return nil, fmt.Errorf("failed to get old member content from state store: %w", err) + } else if memberContent == nil { if intent.as.GetProfile != nil { memberContent = intent.as.GetProfile(target, roomID) - ok = memberContent != nil } - if !ok { + if memberContent == nil { profile, err := intent.GetProfile(ctx, target) if err != nil { intent.Log.Debug().Err(err). @@ -224,7 +239,7 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i } } } - if ok && memberContent != nil { + if memberContent != nil { content.Displayname = memberContent.Displayname content.AvatarURL = memberContent.AvatarURL } @@ -297,15 +312,25 @@ func (intent *IntentAPI) UnbanUser(ctx context.Context, roomID id.RoomID, req *m } func (intent *IntentAPI) Member(ctx context.Context, roomID id.RoomID, userID id.UserID) *event.MemberEventContent { - member, ok := intent.as.StateStore.TryGetMember(roomID, userID) - if !ok { + member, err := intent.as.StateStore.TryGetMember(ctx, roomID, userID) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err). + Str("room_id", roomID.String()). + Str("user_id", userID.String()). + Msg("Failed to get member from state store") + } + if member == nil { _ = intent.StateEvent(ctx, roomID, event.StateMember, string(userID), &member) } return member } func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) { - pl = intent.as.StateStore.GetPowerLevels(roomID) + pl, err = intent.as.StateStore.GetPowerLevels(ctx, roomID) + if err != nil { + err = fmt.Errorf("failed to get cached power levels: %w", err) + return + } if pl == nil { pl = &event.PowerLevelsEventContent{} err = intent.StateEvent(ctx, roomID, event.StatePowerLevels, "", pl) @@ -417,7 +442,7 @@ func (intent *IntentAPI) Whoami(ctx context.Context) (*mautrix.RespWhoami, error } func (intent *IntentAPI) EnsureInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) error { - if !intent.as.StateStore.IsInvited(roomID, userID) { + if !intent.as.StateStore.IsInvited(ctx, roomID, userID) { _, err := intent.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{ UserID: userID, }) diff --git a/appservice/registration.go b/appservice/registration.go index 464ea1d6..b11bd84b 100644 --- a/appservice/registration.go +++ b/appservice/registration.go @@ -10,9 +10,8 @@ import ( "os" "regexp" - "gopkg.in/yaml.v3" - "go.mau.fi/util/random" + "gopkg.in/yaml.v3" ) // Registration contains the data in a Matrix appservice registration. diff --git a/bridge/bridge.go b/bridge/bridge.go index 960dce9a..6ad19720 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -215,15 +215,15 @@ type Bridge struct { type Crypto interface { HandleMemberEvent(*event.Event) - Decrypt(*event.Event) (*event.Event, error) - Encrypt(id.RoomID, event.Type, *event.Content) error - WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool + Decrypt(context.Context, *event.Event) (*event.Event, error) + Encrypt(context.Context, id.RoomID, event.Type, *event.Content) error + WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) - ResetSession(id.RoomID) - Init() error + ResetSession(context.Context, id.RoomID) + Init(ctx context.Context) error Start() Stop() - Reset(startAfterReset bool) + Reset(ctx context.Context, startAfterReset bool) Client() *mautrix.Client ShareKeys(context.Context) error } @@ -650,10 +650,10 @@ func (br *Bridge) WaitWebsocketConnected() { func (br *Bridge) start() { br.ZLog.Debug().Msg("Running database upgrades") - err := br.DB.Upgrade() + err := br.DB.Upgrade(br.ZLog.With().Str("db_section", "main").Logger().WithContext(context.TODO())) if err != nil { br.LogDBUpgradeErrorAndExit("main", err) - } else if err = br.StateStore.Upgrade(); err != nil { + } else if err = br.StateStore.Upgrade(br.ZLog.With().Str("db_section", "matrix_state").Logger().WithContext(context.TODO())); err != nil { br.LogDBUpgradeErrorAndExit("matrix_state", err) } @@ -679,7 +679,7 @@ func (br *Bridge) start() { go br.fetchMediaConfig(ctx) if br.Crypto != nil { - err = br.Crypto.Init() + err = br.Crypto.Init(ctx) if err != nil { br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Error initializing end-to-bridge encryption") os.Exit(19) diff --git a/bridge/commands/admin.go b/bridge/commands/admin.go index cf38d6c5..ff3340e3 100644 --- a/bridge/commands/admin.go +++ b/bridge/commands/admin.go @@ -17,7 +17,7 @@ var CommandDiscardMegolmSession = &FullHandler{ if ce.Bridge.Crypto == nil { ce.Reply("This bridge instance doesn't have end-to-bridge encryption enabled") } else { - ce.Bridge.Crypto.ResetSession(ce.RoomID) + ce.Bridge.Crypto.ResetSession(ce.Ctx, ce.RoomID) ce.Reply("Successfully reset Megolm session in this room. New decryption keys will be shared the next time a message is sent from the remote network.") } }, diff --git a/bridge/crypto.go b/bridge/crypto.go index a1a76ebd..872bf8a6 100644 --- a/bridge/crypto.go +++ b/bridge/crypto.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -61,7 +61,7 @@ func NewCryptoHelper(bridge *Bridge) Crypto { } } -func (helper *CryptoHelper) Init() error { +func (helper *CryptoHelper) Init(ctx context.Context) error { if len(helper.bridge.CryptoPickleKey) == 0 { panic("CryptoPickleKey not set") } @@ -75,13 +75,13 @@ func (helper *CryptoHelper) Init() error { helper.bridge.CryptoPickleKey, ) - err := helper.store.DB.Upgrade() + err := helper.store.DB.Upgrade(ctx) if err != nil { helper.bridge.LogDBUpgradeErrorAndExit("crypto", err) } var isExistingDevice bool - helper.client, isExistingDevice, err = helper.loginBot(context.TODO()) + helper.client, isExistingDevice, err = helper.loginBot(ctx) if err != nil { return err } @@ -111,7 +111,7 @@ func (helper *CryptoHelper) Init() error { } if encryptionConfig.DeleteKeys.DeleteOutdatedInbound { - deleted, err := helper.store.RedactOutdatedGroupSessions() + deleted, err := helper.store.RedactOutdatedGroupSessions(ctx) if err != nil { return err } @@ -123,12 +123,12 @@ func (helper *CryptoHelper) Init() error { helper.client.Syncer = &cryptoSyncer{helper.mach} helper.client.Store = helper.store - err = helper.mach.Load() + err = helper.mach.Load(ctx) if err != nil { return err } if isExistingDevice { - helper.verifyKeysAreOnServer(context.TODO()) + helper.verifyKeysAreOnServer(ctx) } go helper.resyncEncryptionInfo(context.TODO()) @@ -138,22 +138,16 @@ func (helper *CryptoHelper) Init() error { func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { log := helper.log.With().Str("action", "resync encryption event").Logger() - rows, err := helper.bridge.DB.QueryContext(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) + rows, err := helper.bridge.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) if err != nil { log.Err(err).Msg("Failed to query rooms for resync") return } - var roomIDs []id.RoomID - for rows.Next() { - var roomID id.RoomID - err = rows.Scan(&roomID) - if err != nil { - log.Err(err).Msg("Failed to scan room ID") - continue - } - roomIDs = append(roomIDs, roomID) + roomIDs, err := dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList() + if err != nil { + log.Err(err).Msg("Failed to scan rooms for resync") + return } - _ = rows.Close() if len(roomIDs) > 0 { log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms") for _, roomID := range roomIDs { @@ -161,7 +155,7 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt) if err != nil { log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event") - _, err = helper.bridge.DB.ExecContext(ctx, ` + _, err = helper.bridge.DB.Exec(ctx, ` UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}' `, roomID) if err != nil { @@ -182,7 +176,7 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { Int("max_messages", maxMessages). Interface("content", &evt). Msg("Resynced encryption event") - _, err = helper.bridge.DB.ExecContext(ctx, ` + _, err = helper.bridge.DB.Exec(ctx, ` UPDATE crypto_megolm_inbound_session SET max_age=$1, max_messages=$2 WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL @@ -223,8 +217,10 @@ func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device } func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool, error) { - deviceID := helper.store.FindDeviceID() - if len(deviceID) > 0 { + deviceID, err := helper.store.FindDeviceID(ctx) + if err != nil { + return nil, false, fmt.Errorf("failed to find existing device ID: %w", err) + } else if len(deviceID) > 0 { helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database") } // Create a new client instance with the default AS settings (including as_token), @@ -270,7 +266,7 @@ func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) { return } helper.log.Warn().Msg("Existing device doesn't have keys on server, resetting crypto") - helper.Reset(false) + helper.Reset(ctx, false) } func (helper *CryptoHelper) Start() { @@ -306,16 +302,16 @@ func (helper *CryptoHelper) Stop() { helper.syncDone.Wait() } -func (helper *CryptoHelper) clearDatabase() { - _, err := helper.store.DB.Exec("DELETE FROM crypto_account") +func (helper *CryptoHelper) clearDatabase(ctx context.Context) { + _, err := helper.store.DB.Exec(ctx, "DELETE FROM crypto_account") if err != nil { helper.log.Warn().Err(err).Msg("Failed to clear crypto_account table") } - _, err = helper.store.DB.Exec("DELETE FROM crypto_olm_session") + _, err = helper.store.DB.Exec(ctx, "DELETE FROM crypto_olm_session") if err != nil { helper.log.Warn().Err(err).Msg("Failed to clear crypto_olm_session table") } - _, err = helper.store.DB.Exec("DELETE FROM crypto_megolm_outbound_session") + _, err = helper.store.DB.Exec(ctx, "DELETE FROM crypto_megolm_outbound_session") if err != nil { helper.log.Warn().Err(err).Msg("Failed to clear crypto_megolm_outbound_session table") } @@ -325,22 +321,22 @@ func (helper *CryptoHelper) clearDatabase() { //_, _ = helper.store.DB.Exec("DELETE FROM crypto_cross_signing_signatures") } -func (helper *CryptoHelper) Reset(startAfterReset bool) { +func (helper *CryptoHelper) Reset(ctx context.Context, startAfterReset bool) { helper.lock.Lock() defer helper.lock.Unlock() helper.log.Info().Msg("Resetting end-to-bridge encryption device") helper.Stop() helper.log.Debug().Msg("Crypto syncer stopped, clearing database") - helper.clearDatabase() + helper.clearDatabase(ctx) helper.log.Debug().Msg("Crypto database cleared, logging out of all sessions") - _, err := helper.client.LogoutAll(context.TODO()) + _, err := helper.client.LogoutAll(ctx) if err != nil { helper.log.Warn().Err(err).Msg("Failed to log out all devices") } helper.client = nil helper.store = nil helper.mach = nil - err = helper.Init() + err = helper.Init(ctx) if err != nil { helper.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Error reinitializing end-to-bridge encryption") os.Exit(50) @@ -355,25 +351,24 @@ func (helper *CryptoHelper) Client() *mautrix.Client { return helper.client } -func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) { - return helper.mach.DecryptMegolmEvent(context.TODO(), evt) +func (helper *CryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) { + return helper.mach.DecryptMegolmEvent(ctx, evt) } -func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content *event.Content) (err error) { +func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content *event.Content) (err error) { helper.lock.RLock() defer helper.lock.RUnlock() var encrypted *event.EncryptedEventContent - ctx := context.TODO() encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content) if err != nil { - if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession { + if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) { return } helper.log.Debug().Err(err). Str("room_id", roomID.String()). Msg("Got error while encrypting event for room, sharing group session and trying again...") var users []id.UserID - users, err = helper.store.GetRoomJoinedOrInvitedMembers(roomID) + users, err = helper.store.GetRoomJoinedOrInvitedMembers(ctx, roomID) if err != nil { err = fmt.Errorf("failed to get room member list: %w", err) } else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil { @@ -389,10 +384,10 @@ func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, conten return } -func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { +func (helper *CryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { helper.lock.RLock() defer helper.lock.RUnlock() - return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout) + return helper.mach.WaitForSession(ctx, roomID, senderKey, sessionID, timeout) } func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { @@ -419,10 +414,10 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID } } -func (helper *CryptoHelper) ResetSession(roomID id.RoomID) { +func (helper *CryptoHelper) ResetSession(ctx context.Context, roomID id.RoomID) { helper.lock.RLock() defer helper.lock.RUnlock() - err := helper.mach.CryptoStore.RemoveOutboundGroupSession(roomID) + err := helper.mach.CryptoStore.RemoveOutboundGroupSession(ctx, roomID) if err != nil { helper.log.Debug().Err(err). Str("room_id", roomID.String()). @@ -499,18 +494,18 @@ type cryptoStateStore struct { var _ crypto.StateStore = (*cryptoStateStore)(nil) -func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool { +func (c *cryptoStateStore) IsEncrypted(ctx context.Context, id id.RoomID) (bool, error) { portal := c.bridge.Child.GetIPortal(id) if portal != nil { - return portal.IsEncrypted() + return portal.IsEncrypted(), nil } - return c.bridge.StateStore.IsEncrypted(id) + return c.bridge.StateStore.IsEncrypted(ctx, id) } -func (c *cryptoStateStore) FindSharedRooms(id id.UserID) []id.RoomID { - return c.bridge.StateStore.FindSharedRooms(id) +func (c *cryptoStateStore) FindSharedRooms(ctx context.Context, id id.UserID) ([]id.RoomID, error) { + return c.bridge.StateStore.FindSharedRooms(ctx, id) } -func (c *cryptoStateStore) GetEncryptionEvent(id id.RoomID) *event.EncryptionEventContent { - return c.bridge.StateStore.GetEncryptionEvent(id) +func (c *cryptoStateStore) GetEncryptionEvent(ctx context.Context, id id.RoomID) (*event.EncryptionEventContent, error) { + return c.bridge.StateStore.GetEncryptionEvent(ctx, id) } diff --git a/bridge/cryptostore.go b/bridge/cryptostore.go index e199f5a4..dde48a25 100644 --- a/bridge/cryptostore.go +++ b/bridge/cryptostore.go @@ -9,6 +9,8 @@ package bridge import ( + "context" + "github.com/lib/pq" "go.mau.fi/util/dbutil" @@ -36,9 +38,9 @@ func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, userID id } } -func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (members []id.UserID, err error) { +func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) (members []id.UserID, err error) { var rows dbutil.Rows - rows, err = store.DB.Query(` + rows, err = store.DB.Query(ctx, ` SELECT user_id FROM mx_user_profile WHERE room_id=$1 AND (membership='join' OR membership='invite') diff --git a/bridge/matrix.go b/bridge/matrix.go index 90453c13..00994dd2 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -494,7 +494,7 @@ func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) { log.Debug().Msg("Decrypting received event") decryptionStart := time.Now() - decrypted, err := mx.bridge.Crypto.Decrypt(evt) + decrypted, err := mx.bridge.Crypto.Decrypt(ctx, evt) decryptionRetryCount := 0 if errors.Is(err, NoSessionFound) { decryptionRetryCount = 1 @@ -502,9 +502,9 @@ func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) { Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). Msg("Couldn't find session, waiting for keys to arrive...") mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepDecrypted, err, false, 0) - if mx.bridge.Crypto.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { + if mx.bridge.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { log.Debug().Msg("Got keys after waiting, trying to decrypt event again") - decrypted, err = mx.bridge.Crypto.Decrypt(evt) + decrypted, err = mx.bridge.Crypto.Decrypt(ctx, evt) } else { go mx.waitLongerForSession(ctx, evt, decryptionStart) return @@ -529,14 +529,14 @@ func (mx *MatrixHandler) waitLongerForSession(ctx context.Context, evt *event.Ev go mx.bridge.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) errorEventID := mx.sendCryptoStatusError(ctx, evt, "", fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), 1, false) - if !mx.bridge.Crypto.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { + if !mx.bridge.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { log.Debug().Msg("Didn't get session, giving up trying to decrypt event") mx.sendCryptoStatusError(ctx, evt, errorEventID, errNoDecryptionKeys, 2, true) return } log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again") - decrypted, err := mx.bridge.Crypto.Decrypt(evt) + decrypted, err := mx.bridge.Crypto.Decrypt(ctx, evt) if err != nil { log.Error().Err(err).Msg("Failed to decrypt event") mx.sendCryptoStatusError(ctx, evt, errorEventID, err, 2, true) diff --git a/client.go b/client.go index 1c2b9683..d1a6d8f0 100644 --- a/client.go +++ b/client.go @@ -27,11 +27,11 @@ import ( ) type CryptoHelper interface { - Encrypt(id.RoomID, event.Type, any) (*event.EncryptedEventContent, error) - Decrypt(*event.Event) (*event.Event, error) - WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool + Encrypt(context.Context, id.RoomID, event.Type, any) (*event.EncryptedEventContent, error) + Decrypt(context.Context, *event.Event) (*event.Event, error) + WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) - Init() error + Init(context.Context) error } // Deprecated: switch to zerolog @@ -846,7 +846,10 @@ func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName strin } _, err = cli.MakeRequest(ctx, "POST", urlPath, content, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin) + err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin) + if err != nil { + err = fmt.Errorf("failed to update state store: %w", err) + } } return } @@ -858,7 +861,10 @@ func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName strin func (cli *Client) JoinRoomByID(ctx context.Context, roomID id.RoomID) (resp *RespJoinRoom, err error) { _, err = cli.MakeRequest(ctx, "POST", cli.BuildClientURL("v3", "rooms", roomID, "join"), nil, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin) + err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin) + if err != nil { + err = fmt.Errorf("failed to update state store: %w", err) + } } return } @@ -1000,13 +1006,20 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event queryParams["fi.mau.event_id"] = req.MeowEventID.String() } - if !req.DontEncrypt && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted && cli.StateStore.IsEncrypted(roomID) { - contentJSON, err = cli.Crypto.Encrypt(roomID, eventType, contentJSON) + if !req.DontEncrypt && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted { + var isEncrypted bool + isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID) if err != nil { - err = fmt.Errorf("failed to encrypt event: %w", err) + err = fmt.Errorf("failed to check if room is encrypted: %w", err) return } - eventType = event.EventEncrypted + if isEncrypted { + if contentJSON, err = cli.Crypto.Encrypt(ctx, roomID, eventType, contentJSON); err != nil { + err = fmt.Errorf("failed to encrypt event: %w", err) + return + } + eventType = event.EventEncrypted + } } urlData := ClientURLPath{"v3", "rooms", roomID, "send", eventType.String(), txnID} @@ -1021,7 +1034,7 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy urlPath := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) _, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp) if err == nil && cli.StateStore != nil { - cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON) + cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) } return } @@ -1034,7 +1047,7 @@ func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, }) _, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp) if err == nil && cli.StateStore != nil { - cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON) + cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) } return } @@ -1100,19 +1113,29 @@ func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *Re urlPath := cli.BuildClientURL("v3", "createRoom") _, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin) + storeErr := cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin) + if storeErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(storeErr). + Stringer("creator_user_id", cli.UserID). + Msg("Failed to update creator membership in state store after creating room") + } for _, evt := range req.InitialState { - UpdateStateStore(cli.StateStore, evt) + UpdateStateStore(ctx, cli.StateStore, evt) } inviteMembership := event.MembershipInvite if req.BeeperAutoJoinInvites { inviteMembership = event.MembershipJoin } for _, invitee := range req.Invite { - cli.StateStore.SetMembership(resp.RoomID, invitee, inviteMembership) + storeErr = cli.StateStore.SetMembership(ctx, resp.RoomID, invitee, inviteMembership) + if storeErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(storeErr). + Stringer("invitee_user_id", invitee). + Msg("Failed to update membership in state store after creating room") + } } for _, evt := range req.InitialState { - cli.updateStoreWithOutgoingEvent(resp.RoomID, evt.Type, evt.GetStateKey(), &evt.Content) + cli.updateStoreWithOutgoingEvent(ctx, resp.RoomID, evt.Type, evt.GetStateKey(), &evt.Content) } } return @@ -1129,7 +1152,10 @@ func (cli *Client) LeaveRoom(ctx context.Context, roomID id.RoomID, optionalReq u := cli.BuildClientURL("v3", "rooms", roomID, "leave") _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.SetMembership(roomID, cli.UserID, event.MembershipLeave) + err = cli.StateStore.SetMembership(ctx, roomID, cli.UserID, event.MembershipLeave) + if err != nil { + err = fmt.Errorf("failed to update membership in state store: %w", err) + } } return } @@ -1146,7 +1172,10 @@ func (cli *Client) InviteUser(ctx context.Context, roomID id.RoomID, req *ReqInv u := cli.BuildClientURL("v3", "rooms", roomID, "invite") _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipInvite) + err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipInvite) + if err != nil { + err = fmt.Errorf("failed to update membership in state store: %w", err) + } } return } @@ -1163,7 +1192,10 @@ func (cli *Client) KickUser(ctx context.Context, roomID id.RoomID, req *ReqKickU u := cli.BuildClientURL("v3", "rooms", roomID, "kick") _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave) + err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipLeave) + if err != nil { + err = fmt.Errorf("failed to update membership in state store: %w", err) + } } return } @@ -1173,7 +1205,10 @@ func (cli *Client) BanUser(ctx context.Context, roomID id.RoomID, req *ReqBanUse u := cli.BuildClientURL("v3", "rooms", roomID, "ban") _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipBan) + err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipBan) + if err != nil { + err = fmt.Errorf("failed to update membership in state store: %w", err) + } } return } @@ -1183,7 +1218,10 @@ func (cli *Client) UnbanUser(ctx context.Context, roomID id.RoomID, req *ReqUnba u := cli.BuildClientURL("v3", "rooms", roomID, "unban") _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave) + err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipLeave) + if err != nil { + err = fmt.Errorf("failed to update membership in state store: %w", err) + } } return } @@ -1216,7 +1254,7 @@ func (cli *Client) SetPresence(ctx context.Context, status event.Presence) (err return } -func (cli *Client) updateStoreWithOutgoingEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) { +func (cli *Client) updateStoreWithOutgoingEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) { if cli.StateStore == nil { return } @@ -1246,7 +1284,7 @@ func (cli *Client) updateStoreWithOutgoingEvent(roomID id.RoomID, eventType even } return } - UpdateStateStore(cli.StateStore, fakeEvt) + UpdateStateStore(ctx, cli.StateStore, fakeEvt) } // StateEvent gets a single state event in a room. It will attempt to JSON unmarshal into the given "outContent" struct with @@ -1256,7 +1294,7 @@ func (cli *Client) StateEvent(ctx context.Context, roomID id.RoomID, eventType e u := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) _, err = cli.MakeRequest(ctx, "GET", u, nil, outContent) if err == nil && cli.StateStore != nil { - cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, outContent) + cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, outContent) } return } @@ -1310,10 +1348,13 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt Handler: parseRoomStateArray, }) if err == nil && cli.StateStore != nil { - cli.StateStore.ClearCachedMembers(roomID) + clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID) + cli.cliOrContextLog(ctx).Warn().Err(clearErr). + Stringer("room_id", roomID). + Msg("Failed to clear cached member list after fetching state") for _, evts := range stateMap { for _, evt := range evts { - UpdateStateStore(cli.StateStore, evt) + UpdateStateStore(ctx, cli.StateStore, evt) } } } @@ -1630,13 +1671,22 @@ func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *R u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members") _, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) if err == nil && cli.StateStore != nil { - cli.StateStore.ClearCachedMembers(roomID, event.MembershipJoin) + clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, event.MembershipJoin) + cli.cliOrContextLog(ctx).Warn().Err(clearErr). + Stringer("room_id", roomID). + Msg("Failed to clear cached member list after fetching joined members") for userID, member := range resp.Joined { - cli.StateStore.SetMember(roomID, userID, &event.MemberEventContent{ + updateErr := cli.StateStore.SetMember(ctx, roomID, userID, &event.MemberEventContent{ Membership: event.MembershipJoin, AvatarURL: id.ContentURIString(member.AvatarURL), Displayname: member.DisplayName, }) + if updateErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(clearErr). + Stringer("room_id", roomID). + Stringer("user_id", userID). + Msg("Failed to update membership in state store after fetching joined members") + } } } return @@ -1665,10 +1715,13 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb clearMemberships = append(clearMemberships, extra.Membership) } if extra.NotMembership == "" { - cli.StateStore.ClearCachedMembers(roomID, clearMemberships...) + clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, clearMemberships...) + cli.cliOrContextLog(ctx).Warn().Err(clearErr). + Stringer("room_id", roomID). + Msg("Failed to clear cached member list after fetching joined members") } for _, evt := range resp.Chunk { - UpdateStateStore(cli.StateStore, evt) + UpdateStateStore(ctx, cli.StateStore, evt) } } return diff --git a/crypto/account.go b/crypto/account.go index a667825c..0eb18a24 100644 --- a/crypto/account.go +++ b/crypto/account.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this diff --git a/crypto/cross_sign_pubkey.go b/crypto/cross_sign_pubkey.go index 9f4f3583..77efab5b 100644 --- a/crypto/cross_sign_pubkey.go +++ b/crypto/cross_sign_pubkey.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -42,7 +42,7 @@ func (mach *OlmMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *Cross } func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id.UserID) (*CrossSigningPublicKeysCache, error) { - dbKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID) + dbKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID) if err != nil { return nil, fmt.Errorf("failed to get keys from database: %w", err) } diff --git a/crypto/cross_sign_signing.go b/crypto/cross_sign_signing.go index 1a5a0233..f6c37a9f 100644 --- a/crypto/cross_sign_signing.go +++ b/crypto/cross_sign_signing.go @@ -1,5 +1,5 @@ // Copyright (c) 2020 Nikos Filippakis -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -34,8 +34,8 @@ var ( ErrMismatchingMasterKeyMAC = errors.New("mismatching cross-signing master key MAC") ) -func (mach *OlmMachine) fetchMasterKey(device *id.Device, content *event.VerificationMacEventContent, verState *verificationState, transactionID string) (id.Ed25519, error) { - crossSignKeys, err := mach.CryptoStore.GetCrossSigningKeys(device.UserID) +func (mach *OlmMachine) fetchMasterKey(ctx context.Context, device *id.Device, content *event.VerificationMacEventContent, verState *verificationState, transactionID string) (id.Ed25519, error) { + crossSignKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, device.UserID) if err != nil { return "", fmt.Errorf("failed to fetch cross-signing keys: %w", err) } @@ -85,7 +85,7 @@ func (mach *OlmMachine) SignUser(ctx context.Context, userID id.UserID, masterKe Str("signature", signature). Msg("Signed master key of user with our user-signing key") - if err := mach.CryptoStore.PutSignature(userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil { + if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil { return fmt.Errorf("error storing signature in crypto store: %w", err) } @@ -137,7 +137,7 @@ func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error { return fmt.Errorf("%w: %+v", ErrSignatureUploadFail, resp.Failures) } - if err := mach.CryptoStore.PutSignature(userID, masterKey, userID, mach.account.SigningKey(), signature); err != nil { + if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, userID, mach.account.SigningKey(), signature); err != nil { return fmt.Errorf("error storing signature in crypto store: %w", err) } @@ -178,7 +178,7 @@ func (mach *OlmMachine) SignOwnDevice(ctx context.Context, device *id.Device) er Str("signature", signature). Msg("Signed own device key with self-signing key") - if err := mach.CryptoStore.PutSignature(device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil { + if err := mach.CryptoStore.PutSignature(ctx, device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil { return fmt.Errorf("error storing signature in crypto store: %w", err) } diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go index f1008ebd..88fcd0ed 100644 --- a/crypto/cross_sign_store.go +++ b/crypto/cross_sign_store.go @@ -1,5 +1,5 @@ // Copyright (c) 2020 Nikos Filippakis -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -19,7 +19,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK log := mach.machOrContextLog(ctx) for userID, userKeys := range crossSigningKeys { log := log.With().Str("user_id", userID.String()).Logger() - currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID) + currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID) if err != nil { log.Error().Err(err). Msg("Error fetching current cross-signing keys of user") @@ -32,7 +32,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK if newKeyUsage == curKeyUsage { if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok { // old key is not in the new key map, so we drop signatures made by it - if count, err := mach.CryptoStore.DropSignaturesByKey(userID, curKey.Key); err != nil { + if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil { log.Error().Err(err).Msg("Error deleting old signatures made by user") } else { log.Debug(). @@ -50,7 +50,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK log := log.With().Str("key", key.String()).Strs("usages", strishArray(userKeys.Usage)).Logger() for _, usage := range userKeys.Usage { log.Debug().Str("usage", string(usage)).Msg("Storing cross-signing key") - if err = mach.CryptoStore.PutCrossSigningKey(userID, usage, key); err != nil { + if err = mach.CryptoStore.PutCrossSigningKey(ctx, userID, usage, key); err != nil { log.Error().Err(err).Msg("Error storing cross-signing key") } } @@ -85,7 +85,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK } else { if verified { log.Debug().Err(err).Msg("Cross-signing key signature verified") - err = mach.CryptoStore.PutSignature(userID, key, signUserID, signingKey, signature) + err = mach.CryptoStore.PutSignature(ctx, userID, key, signUserID, signingKey, signature) if err != nil { log.Error().Err(err).Msg("Error storing cross-signing key signature") } diff --git a/crypto/cross_sign_test.go b/crypto/cross_sign_test.go index 847c87f4..b53da102 100644 --- a/crypto/cross_sign_test.go +++ b/crypto/cross_sign_test.go @@ -32,7 +32,7 @@ func getOlmMachine(t *testing.T) *OlmMachine { t.Fatalf("Error opening db: %v", err) } sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test")) - if err = sqlStore.DB.Upgrade(); err != nil { + if err = sqlStore.DB.Upgrade(context.TODO()); err != nil { t.Fatalf("Error creating tables: %v", err) } @@ -41,9 +41,9 @@ func getOlmMachine(t *testing.T) *OlmMachine { ssk, _ := olm.NewPkSigning() usk, _ := olm.NewPkSigning() - sqlStore.PutCrossSigningKey(userID, id.XSUsageMaster, mk.PublicKey) - sqlStore.PutCrossSigningKey(userID, id.XSUsageSelfSigning, ssk.PublicKey) - sqlStore.PutCrossSigningKey(userID, id.XSUsageUserSigning, usk.PublicKey) + sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageMaster, mk.PublicKey) + sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageSelfSigning, ssk.PublicKey) + sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageUserSigning, usk.PublicKey) return &OlmMachine{ CryptoStore: sqlStore, @@ -70,9 +70,9 @@ func TestTrustOwnDevice(t *testing.T) { t.Error("Own device trusted while it shouldn't be") } - m.CryptoStore.PutSignature(ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, + m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1") - m.CryptoStore.PutSignature(ownDevice.UserID, ownDevice.SigningKey, + m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, ownDevice.SigningKey, ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, "sig2") if trusted, _ := m.IsUserTrusted(context.TODO(), ownDevice.UserID); !trusted { @@ -91,20 +91,20 @@ func TestTrustOtherUser(t *testing.T) { } theirMasterKey, _ := olm.NewPkSigning() - m.CryptoStore.PutCrossSigningKey(otherUser, id.XSUsageMaster, theirMasterKey.PublicKey) + m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey) - m.CryptoStore.PutSignature(m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, + m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1") // sign them with self-signing instead of user-signing key - m.CryptoStore.PutSignature(otherUser, theirMasterKey.PublicKey, + m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey, m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, "invalid_sig") if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { t.Error("Other user trusted before their master key has been signed with our user-signing key") } - m.CryptoStore.PutSignature(otherUser, theirMasterKey.PublicKey, + m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey, m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, "sig2") if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted { @@ -128,27 +128,27 @@ func TestTrustOtherDevice(t *testing.T) { } theirMasterKey, _ := olm.NewPkSigning() - m.CryptoStore.PutCrossSigningKey(otherUser, id.XSUsageMaster, theirMasterKey.PublicKey) + m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey) theirSSK, _ := olm.NewPkSigning() - m.CryptoStore.PutCrossSigningKey(otherUser, id.XSUsageSelfSigning, theirSSK.PublicKey) + m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageSelfSigning, theirSSK.PublicKey) - m.CryptoStore.PutSignature(m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, + m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1") - m.CryptoStore.PutSignature(otherUser, theirMasterKey.PublicKey, + m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey, m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, "sig2") if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted { t.Error("Other user not trusted while they should be") } - m.CryptoStore.PutSignature(otherUser, theirSSK.PublicKey, + m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey, otherUser, theirMasterKey.PublicKey, "sig3") if m.IsDeviceTrusted(theirDevice) { t.Error("Other device trusted before it has been signed with user's SSK") } - m.CryptoStore.PutSignature(otherUser, theirDevice.SigningKey, + m.CryptoStore.PutSignature(context.TODO(), otherUser, theirDevice.SigningKey, otherUser, theirSSK.PublicKey, "sig4") if !m.IsDeviceTrusted(theirDevice) { diff --git a/crypto/cross_sign_validation.go b/crypto/cross_sign_validation.go index 27afeb73..ff2452ec 100644 --- a/crypto/cross_sign_validation.go +++ b/crypto/cross_sign_validation.go @@ -1,5 +1,5 @@ // Copyright (c) 2020 Nikos Filippakis -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -23,7 +23,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi if device.Trust == id.TrustStateVerified || device.Trust == id.TrustStateBlacklisted { return device.Trust, nil } - theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(device.UserID) + theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, device.UserID) if err != nil { mach.machOrContextLog(ctx).Error().Err(err). Str("user_id", device.UserID.String()). @@ -44,7 +44,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi Msg("Self-signing key of user not found") return id.TrustStateUnset, nil } - sskSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, theirSSK.Key, device.UserID, theirMSK.Key) + sskSigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, device.UserID, theirSSK.Key, device.UserID, theirMSK.Key) if err != nil { mach.machOrContextLog(ctx).Error().Err(err). Str("user_id", device.UserID.String()). @@ -57,7 +57,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi Msg("Self-signing key of user is not signed by their master key") return id.TrustStateUnset, nil } - deviceSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, device.SigningKey, device.UserID, theirSSK.Key) + deviceSigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, device.UserID, device.SigningKey, device.UserID, theirSSK.Key) if err != nil { mach.machOrContextLog(ctx).Error().Err(err). Str("user_id", device.UserID.String()). @@ -97,14 +97,14 @@ func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bo return true, nil } // first we verify our user-signing key - ourUserSigningKeyTrusted, err := mach.CryptoStore.IsKeySignedBy(mach.Client.UserID, csPubkeys.UserSigningKey, mach.Client.UserID, csPubkeys.MasterKey) + ourUserSigningKeyTrusted, err := mach.CryptoStore.IsKeySignedBy(ctx, mach.Client.UserID, csPubkeys.UserSigningKey, mach.Client.UserID, csPubkeys.MasterKey) if err != nil { mach.machOrContextLog(ctx).Error().Err(err).Msg("Error retrieving our self-signing key signatures from database") return false, err } else if !ourUserSigningKeyTrusted { return false, nil } - theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID) + theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID) if err != nil { mach.machOrContextLog(ctx).Error().Err(err). Str("user_id", userID.String()). @@ -118,7 +118,7 @@ func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bo Msg("Master key of user not found") return false, nil } - sigExists, err := mach.CryptoStore.IsKeySignedBy(userID, theirMskKey.Key, mach.Client.UserID, csPubkeys.UserSigningKey) + sigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, userID, theirMskKey.Key, mach.Client.UserID, csPubkeys.UserSigningKey) if err != nil { mach.machOrContextLog(ctx).Error().Err(err). Str("user_id", userID.String()). diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index 293166fe..eb7d7a77 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -105,7 +105,7 @@ func NewCryptoHelper(cli *mautrix.Client, pickleKey []byte, store any) (*CryptoH }, nil } -func (helper *CryptoHelper) Init() error { +func (helper *CryptoHelper) Init(ctx context.Context) error { if helper == nil { return fmt.Errorf("crypto helper is nil") } @@ -116,7 +116,7 @@ func (helper *CryptoHelper) Init() error { var stateStore crypto.StateStore if helper.managedStateStore != nil { - err := helper.managedStateStore.Upgrade() + err := helper.managedStateStore.Upgrade(ctx) if err != nil { return fmt.Errorf("failed to upgrade client state store: %w", err) } @@ -124,7 +124,6 @@ func (helper *CryptoHelper) Init() error { } else { stateStore = helper.client.StateStore.(crypto.StateStore) } - ctx := context.TODO() var cryptoStore crypto.Store if helper.unmanagedCryptoStore == nil { managedCryptoStore := crypto.NewSQLCryptoStore(helper.dbForManagedStores, dbutil.ZeroLogger(helper.log.With().Str("db_section", "crypto").Logger()), helper.DBAccountID, helper.client.DeviceID, helper.pickleKey) @@ -133,11 +132,14 @@ func (helper *CryptoHelper) Init() error { } else if _, isMemory := helper.client.Store.(*mautrix.MemorySyncStore); isMemory { helper.client.Store = managedCryptoStore } - err := managedCryptoStore.DB.Upgrade() + err := managedCryptoStore.DB.Upgrade(ctx) if err != nil { return fmt.Errorf("failed to upgrade crypto state store: %w", err) } - storedDeviceID := managedCryptoStore.FindDeviceID() + storedDeviceID, err := managedCryptoStore.FindDeviceID(ctx) + if err != nil { + return fmt.Errorf("failed to find existing device ID: %w", err) + } if helper.LoginAs != nil { if storedDeviceID != "" { helper.LoginAs.DeviceID = storedDeviceID @@ -168,7 +170,7 @@ func (helper *CryptoHelper) Init() error { return fmt.Errorf("the client must be logged in") } helper.mach = crypto.NewOlmMachine(helper.client, &helper.log, cryptoStore, stateStore) - err := helper.mach.Load() + err := helper.mach.Load(ctx) if err != nil { return fmt.Errorf("failed to load olm account: %w", err) } else if err = helper.verifyDeviceKeysOnServer(ctx); err != nil { @@ -253,17 +255,18 @@ func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event. Str("session_id", content.SessionID.String()). Logger() log.Debug().Msg("Decrypting received event") + ctx := log.WithContext(context.TODO()) - decrypted, err := helper.Decrypt(evt) + decrypted, err := helper.Decrypt(ctx, evt) if errors.Is(err, NoSessionFound) { log.Debug(). Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). Msg("Couldn't find session, waiting for keys to arrive...") - if helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { + if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { log.Debug().Msg("Got keys after waiting, trying to decrypt event again") - decrypted, err = helper.Decrypt(evt) + decrypted, err = helper.Decrypt(ctx, evt) } else { - go helper.waitLongerForSession(log, src, evt) + go helper.waitLongerForSession(ctx, log, src, evt) return } } @@ -306,20 +309,20 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID } } -func (helper *CryptoHelper) waitLongerForSession(log zerolog.Logger, src mautrix.EventSource, evt *event.Event) { +func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, src mautrix.EventSource, evt *event.Event) { content := evt.Content.AsEncrypted() log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...") go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) - if !helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { + if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { log.Debug().Msg("Didn't get session, giving up") helper.DecryptErrorCallback(evt, NoSessionFound) return } log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again") - decrypted, err := helper.Decrypt(evt) + decrypted, err := helper.Decrypt(ctx, evt) if err != nil { log.Error().Err(err).Msg("Failed to decrypt event") helper.DecryptErrorCallback(evt, err) @@ -329,32 +332,31 @@ func (helper *CryptoHelper) waitLongerForSession(log zerolog.Logger, src mautrix helper.postDecrypt(src, decrypted) } -func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { +func (helper *CryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { if helper == nil { return false } helper.lock.RLock() defer helper.lock.RUnlock() - return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout) + return helper.mach.WaitForSession(ctx, roomID, senderKey, sessionID, timeout) } -func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) { +func (helper *CryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) { if helper == nil { return nil, fmt.Errorf("crypto helper is nil") } - return helper.mach.DecryptMegolmEvent(context.TODO(), evt) + return helper.mach.DecryptMegolmEvent(ctx, evt) } -func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) { +func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) { if helper == nil { return nil, fmt.Errorf("crypto helper is nil") } helper.lock.RLock() defer helper.lock.RUnlock() - ctx := context.TODO() encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content) if err != nil { - if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession { + if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) { return } helper.log.Debug(). @@ -362,7 +364,7 @@ func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, conten Str("room_id", roomID.String()). Msg("Got session error while encrypting event, sharing group session and trying again") var users []id.UserID - users, err = helper.client.StateStore.GetRoomJoinedOrInvitedMembers(roomID) + users, err = helper.client.StateStore.GetRoomJoinedOrInvitedMembers(ctx, roomID) if err != nil { err = fmt.Errorf("failed to get room member list: %w", err) } else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil { diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index eaff136a..540f99ca 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -91,7 +91,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event } else { forwardedKeys = true lastChainItem := sess.ForwardingChains[len(sess.ForwardingChains)-1] - device, _ = mach.CryptoStore.FindDeviceByKey(evt.Sender, id.IdentityKey(lastChainItem)) + device, _ = mach.CryptoStore.FindDeviceByKey(ctx, evt.Sender, id.IdentityKey(lastChainItem)) if device != nil { trustLevel = mach.ResolveTrust(device) } else { @@ -188,7 +188,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve mach.megolmDecryptLock.Lock() defer mach.megolmDecryptLock.Unlock() - sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID) + sess, err := mach.CryptoStore.GetGroupSession(ctx, encryptionRoomID, content.SenderKey, content.SessionID) if err != nil { return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err) } else if sess == nil { @@ -250,7 +250,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve Int("max_messages", sess.MaxMessages). Logger() if sess.MaxMessages > 0 && int(ratchetTargetIndex) >= sess.MaxMessages && len(sess.RatchetSafety.MissedIndices) == 0 && mach.DeleteFullyUsedKeysOnDecrypt { - err = mach.CryptoStore.RedactGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached") + err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached") if err != nil { log.Err(err).Msg("Failed to delete fully used session") return sess, plaintext, messageIndex, RatchetError @@ -261,14 +261,14 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve if err = sess.RatchetTo(ratchetTargetIndex); err != nil { log.Err(err).Msg("Failed to ratchet session") return sess, plaintext, messageIndex, RatchetError - } else if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil { + } else if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil { log.Err(err).Msg("Failed to store ratcheted session") return sess, plaintext, messageIndex, RatchetError } else { log.Info().Msg("Ratcheted session forward") } } else if didModify { - if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil { + if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil { log.Err(err).Msg("Failed to store updated ratchet safety data") return sess, plaintext, messageIndex, RatchetError } else { diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index 57b39f0b..f99c7dbe 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -159,7 +159,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U } endTimeTrace = mach.timeTrace(ctx, "updating new session in database", time.Second) - err = mach.CryptoStore.UpdateSession(senderKey, session) + err = mach.CryptoStore.UpdateSession(ctx, senderKey, session) endTimeTrace() if err != nil { log.Warn().Err(err).Msg("Failed to update new olm session in crypto store after decrypting") @@ -170,7 +170,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) { log := *zerolog.Ctx(ctx) endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second) - sessions, err := mach.CryptoStore.GetSessions(senderKey) + sessions, err := mach.CryptoStore.GetSessions(ctx, senderKey) endTimeTrace() if err != nil { return nil, fmt.Errorf("failed to get session for %s: %w", senderKey, err) @@ -199,7 +199,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C } } else { endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second) - err = mach.CryptoStore.UpdateSession(senderKey, session) + err = mach.CryptoStore.UpdateSession(ctx, senderKey, session) endTimeTrace() if err != nil { log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting") @@ -217,7 +217,7 @@ func (mach *OlmMachine) createInboundSession(ctx context.Context, senderKey id.S return nil, err } mach.saveAccount() - err = mach.CryptoStore.AddSession(senderKey, session) + err = mach.CryptoStore.AddSession(ctx, senderKey, session) if err != nil { zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to store created inbound session") } diff --git a/crypto/devicelist.go b/crypto/devicelist.go index 8514275c..e554480d 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -53,7 +53,7 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id Str("signed_device_id", deviceID.String()). Str("signature", signature). Msg("Verified self-signing signature") - err = mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, pubKey, signature) + err = mach.CryptoStore.PutSignature(ctx, userID, id.Ed25519(signKey), signerUserID, pubKey, signature) if err != nil { log.Warn().Err(err). Str("signer_user_id", signerUserID.String()). @@ -74,7 +74,7 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id } // save signature of device made by its own device signing key if signKey, ok := deviceKeys.Keys[id.DeviceKeyID(signerKey)]; ok { - err := mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, id.Ed25519(signKey), signature) + err := mach.CryptoStore.PutSignature(ctx, userID, id.Ed25519(signKey), signerUserID, id.Ed25519(signKey), signature) if err != nil { log.Warn().Err(err). Str("signer_user_id", signerUserID.String()). @@ -96,7 +96,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT log := mach.machOrContextLog(ctx) if !includeUntracked { var err error - users, err = mach.CryptoStore.FilterTrackedUsers(users) + users, err = mach.CryptoStore.FilterTrackedUsers(ctx, users) if err != nil { log.Warn().Err(err).Msg("Failed to filter tracked user list") } @@ -123,7 +123,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT delete(req.DeviceKeys, userID) newDevices := make(map[id.DeviceID]*id.Device) - existingDevices, err := mach.CryptoStore.GetDevices(userID) + existingDevices, err := mach.CryptoStore.GetDevices(ctx, userID) if err != nil { log.Warn().Err(err).Msg("Failed to get existing devices for user") existingDevices = make(map[id.DeviceID]*id.Device) @@ -151,7 +151,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT } } log.Trace().Int("new_device_count", len(newDevices)).Msg("Storing new device list") - err = mach.CryptoStore.PutDevices(userID, newDevices) + err = mach.CryptoStore.PutDevices(ctx, userID, newDevices) if err != nil { log.Warn().Err(err).Msg("Failed to update device list") } @@ -169,7 +169,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT Str("identity_key", device.IdentityKey.String()). Str("signing_key", device.SigningKey.String()). Logger() - sessionIDs, err := mach.CryptoStore.RedactGroupSessions("", device.IdentityKey, "device removed") + sessionIDs, err := mach.CryptoStore.RedactGroupSessions(ctx, "", device.IdentityKey, "device removed") if err != nil { log.Err(err).Msg("Failed to redact megolm sessions from deleted device") } else { @@ -179,7 +179,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT } } } - mach.OnDevicesChanged(userID) + mach.OnDevicesChanged(ctx, userID) } } for userID := range req.DeviceKeys { @@ -197,18 +197,25 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT // // This is called automatically whenever a device list change is noticed in ProcessSyncResponse and usually does // not need to be called manually. -func (mach *OlmMachine) OnDevicesChanged(userID id.UserID) { +func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID) { if mach.DisableDeviceChangeKeyRotation { return } - for _, roomID := range mach.StateStore.FindSharedRooms(userID) { - mach.Log.Debug(). + rooms, err := mach.StateStore.FindSharedRooms(ctx, userID) + if err != nil { + mach.machOrContextLog(ctx).Err(err). + Stringer("with_user_id", userID). + Msg("Failed to find shared rooms to invalidate group sessions") + return + } + for _, roomID := range rooms { + mach.machOrContextLog(ctx).Debug(). Str("user_id", userID.String()). Str("room_id", roomID.String()). Msg("Invalidating group session in room due to device change notification") - err := mach.CryptoStore.RemoveOutboundGroupSession(roomID) + err = mach.CryptoStore.RemoveOutboundGroupSession(ctx, roomID) if err != nil { - mach.Log.Warn().Err(err). + mach.machOrContextLog(ctx).Err(err). Str("user_id", userID.String()). Str("room_id", roomID.String()). Msg("Failed to invalidate outbound group session") diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 078ef518..1eee2fec 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -84,7 +84,7 @@ func parseMessageIndex(ciphertext []byte) (uint, error) { func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) { mach.megolmEncryptLock.Lock() defer mach.megolmEncryptLock.Unlock() - session, err := mach.CryptoStore.GetOutboundGroupSession(roomID) + session, err := mach.CryptoStore.GetOutboundGroupSession(ctx, roomID) if err != nil { return nil, fmt.Errorf("failed to get outbound group session: %w", err) } else if session == nil { @@ -116,7 +116,7 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID log = log.With().Uint("message_index", idx).Logger() } log.Debug().Msg("Encrypted event successfully") - err = mach.CryptoStore.UpdateOutboundGroupSession(session) + err = mach.CryptoStore.UpdateOutboundGroupSession(ctx, session) if err != nil { log.Warn().Err(err).Msg("Failed to update megolm session in crypto store after encrypting") } @@ -137,7 +137,13 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID } func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) *OutboundGroupSession { - session := NewOutboundGroupSession(roomID, mach.StateStore.GetEncryptionEvent(roomID)) + encryptionEvent, err := mach.StateStore.GetEncryptionEvent(ctx, roomID) + if err != nil { + mach.machOrContextLog(ctx).Err(err). + Stringer("room_id", roomID). + Msg("Failed to get encryption event in room") + } + session := NewOutboundGroupSession(roomID, encryptionEvent) if !mach.DontStoreOutboundKeys { signingKey, idKey := mach.account.Keys() mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false) @@ -165,7 +171,7 @@ func strishArray[T ~string](arr []T) []string { func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, users []id.UserID) error { mach.megolmEncryptLock.Lock() defer mach.megolmEncryptLock.Unlock() - session, err := mach.CryptoStore.GetOutboundGroupSession(roomID) + session, err := mach.CryptoStore.GetOutboundGroupSession(ctx, roomID) if err != nil { return fmt.Errorf("failed to get previous outbound group session: %w", err) } else if session != nil && session.Shared && !session.Expired() { @@ -192,7 +198,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, for _, userID := range users { log := log.With().Str("target_user_id", userID.String()).Logger() - devices, err := mach.CryptoStore.GetDevices(userID) + devices, err := mach.CryptoStore.GetDevices(ctx, userID) if err != nil { log.Error().Err(err).Msg("Failed to get devices of user") } else if devices == nil { @@ -292,7 +298,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, log.Debug().Msg("Group session successfully shared") session.Shared = true - return mach.CryptoStore.AddOutboundGroupSession(session) + return mach.CryptoStore.AddOutboundGroupSession(ctx, session) } func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error { @@ -367,7 +373,7 @@ func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *Out Reason: "This device does not encrypt messages for unverified devices", }} session.Users[userKey] = OGSIgnored - } else if deviceSession, err := mach.CryptoStore.GetLatestSession(device.IdentityKey); err != nil { + } else if deviceSession, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey); err != nil { log.Error().Err(err).Msg("Failed to get olm session to encrypt group session") } else if deviceSession == nil { log.Warn().Err(err).Msg("Didn't find olm session to encrypt group session") diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go index f21ecd02..3b1d40d3 100644 --- a/crypto/encryptolm.go +++ b/crypto/encryptolm.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -38,7 +38,7 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession Str("olm_session_description", session.Describe()). Msg("Encrypting olm message") msgType, ciphertext := session.Encrypt(plaintext) - err = mach.CryptoStore.UpdateSession(recipient.IdentityKey, session) + err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session) if err != nil { log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting") } @@ -54,8 +54,8 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession } } -func (mach *OlmMachine) shouldCreateNewSession(identityKey id.IdentityKey) bool { - if !mach.CryptoStore.HasSession(identityKey) { +func (mach *OlmMachine) shouldCreateNewSession(ctx context.Context, identityKey id.IdentityKey) bool { + if !mach.CryptoStore.HasSession(ctx, identityKey) { return true } mach.devicesToUnwedgeLock.Lock() @@ -72,7 +72,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id for userID, devices := range input { request[userID] = make(map[id.DeviceID]id.KeyAlgorithm) for deviceID, identity := range devices { - if mach.shouldCreateNewSession(identity.IdentityKey) { + if mach.shouldCreateNewSession(ctx, identity.IdentityKey) { request[userID][deviceID] = id.KeyAlgorithmSignedCurve25519 } } @@ -117,7 +117,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id log.Error().Err(err).Msg("Failed to create outbound session with claimed one-time key") } else { wrapped := wrapSession(sess) - err = mach.CryptoStore.AddSession(identity.IdentityKey, wrapped) + err = mach.CryptoStore.AddSession(ctx, identity.IdentityKey, wrapped) if err != nil { log.Error().Err(err).Msg("Failed to store created outbound session") } else { diff --git a/crypto/keyimport.go b/crypto/keyimport.go index ed66f23b..2d9f3486 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,6 +8,7 @@ package crypto import ( "bytes" + "context" "crypto/aes" "crypto/cipher" "crypto/hmac" @@ -91,7 +92,7 @@ func decryptKeyExport(passphrase string, exportData []byte) ([]ExportedSession, return sessionsJSON, nil } -func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, error) { +func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session ExportedSession) (bool, error) { if session.Algorithm != id.AlgorithmMegolmV1 { return false, ErrInvalidExportedAlgorithm } @@ -112,12 +113,12 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er ReceivedAt: time.Now().UTC(), } - existingIGS, _ := mach.CryptoStore.GetGroupSession(igs.RoomID, igs.SenderKey, igs.ID()) + existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID()) if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() { // We already have an equivalent or better session in the store, so don't override it. return false, nil } - err = mach.CryptoStore.PutGroupSession(igs.RoomID, igs.SenderKey, igs.ID(), igs) + err = mach.CryptoStore.PutGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID(), igs) if err != nil { return false, fmt.Errorf("failed to store imported session: %w", err) } @@ -127,7 +128,7 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er // ImportKeys imports data that was exported with the format specified in the Matrix spec. // See https://spec.matrix.org/v1.2/client-server-api/#key-exports -func (mach *OlmMachine) ImportKeys(passphrase string, data []byte) (int, int, error) { +func (mach *OlmMachine) ImportKeys(ctx context.Context, passphrase string, data []byte) (int, int, error) { exportData, err := decodeKeyExport(data) if err != nil { return 0, 0, err @@ -143,8 +144,11 @@ func (mach *OlmMachine) ImportKeys(passphrase string, data []byte) (int, int, er Str("room_id", session.RoomID.String()). Str("session_id", session.SessionID.String()). Logger() - imported, err := mach.importExportedRoomKey(session) + imported, err := mach.importExportedRoomKey(ctx, session) if err != nil { + if ctx.Err() != nil { + return count, len(sessions), ctx.Err() + } log.Error().Err(err).Msg("Failed to import Megolm session from file") } else if imported { log.Debug().Msg("Imported Megolm session from file") diff --git a/crypto/keysharing.go b/crypto/keysharing.go index 9b8eef7e..8cf15d35 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -1,5 +1,5 @@ // Copyright (c) 2020 Nikos Filippakis -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -152,7 +152,10 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt Msg("Mismatched session ID while creating inbound group session from forward") return false } - config := mach.StateStore.GetEncryptionEvent(content.RoomID) + config, err := mach.StateStore.GetEncryptionEvent(ctx, content.RoomID) + if err != nil { + log.Error().Err(err).Msg("Failed to get encryption event for room") + } var maxAge time.Duration var maxMessages int if config != nil { @@ -178,7 +181,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt MaxMessages: maxMessages, IsScheduled: content.IsScheduled, } - err = mach.CryptoStore.PutGroupSession(content.RoomID, content.SenderKey, content.SessionID, igs) + err = mach.CryptoStore.PutGroupSession(ctx, content.RoomID, content.SenderKey, content.SessionID, igs) if err != nil { log.Error().Err(err).Msg("Failed to store new inbound group session") return false @@ -274,7 +277,7 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User return } - igs, err := mach.CryptoStore.GetGroupSession(content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID) + igs, err := mach.CryptoStore.GetGroupSession(ctx, content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID) if err != nil { if errors.Is(err, ErrGroupSessionWithheld) { log.Debug().Err(err).Msg("Requested group session not available") @@ -331,7 +334,7 @@ func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.Us Int("first_message_index", content.FirstMessageIndex). Logger() - sess, err := mach.CryptoStore.GetGroupSession(content.RoomID, "", content.SessionID) + sess, err := mach.CryptoStore.GetGroupSession(ctx, content.RoomID, "", content.SessionID) if err != nil { if errors.Is(err, ErrGroupSessionWithheld) { log.Debug().Err(err).Msg("Acked group session was already redacted") @@ -351,7 +354,7 @@ func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.Us isInbound := sess.SenderKey == mach.OwnIdentity().IdentityKey if isInbound && mach.DeleteOutboundKeysOnAck && content.FirstMessageIndex == 0 { log.Debug().Msg("Redacting inbound copy of outbound group session after ack") - err = mach.CryptoStore.RedactGroupSession(content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked") + err = mach.CryptoStore.RedactGroupSession(ctx, content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked") if err != nil { log.Err(err).Msg("Failed to redact group session") } diff --git a/crypto/machine.go b/crypto/machine.go index 35f8c121..da78eaf7 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -80,11 +80,11 @@ type OlmMachine struct { // StateStore is used by OlmMachine to get room state information that's needed for encryption. type StateStore interface { // IsEncrypted returns whether a room is encrypted. - IsEncrypted(id.RoomID) bool + IsEncrypted(context.Context, id.RoomID) (bool, error) // GetEncryptionEvent returns the encryption event's content for an encrypted room. - GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent + GetEncryptionEvent(context.Context, id.RoomID) (*event.EncryptionEventContent, error) // FindSharedRooms returns the encrypted rooms that another user is also in for a user ID. - FindSharedRooms(id.UserID) []id.RoomID + FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, error) } // NewOlmMachine creates an OlmMachine with the given client, logger and stores. @@ -131,8 +131,8 @@ func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger { // Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created. // This must be called before using the machine. -func (mach *OlmMachine) Load() (err error) { - mach.account, err = mach.CryptoStore.GetAccount() +func (mach *OlmMachine) Load(ctx context.Context) (err error) { + mach.account, err = mach.CryptoStore.GetAccount(ctx) if err != nil { return } @@ -143,15 +143,15 @@ func (mach *OlmMachine) Load() (err error) { } func (mach *OlmMachine) saveAccount() { - err := mach.CryptoStore.PutAccount(mach.account) + err := mach.CryptoStore.PutAccount(context.TODO(), mach.account) if err != nil { mach.Log.Error().Err(err).Msg("Failed to save account") } } // FlushStore calls the Flush method of the CryptoStore. -func (mach *OlmMachine) FlushStore() error { - return mach.CryptoStore.Flush() +func (mach *OlmMachine) FlushStore(ctx context.Context) error { + return mach.CryptoStore.Flush(ctx) } func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() { @@ -284,7 +284,12 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string // // client.Syncer.(mautrix.ExtensibleSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent) func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Event) { - if !mach.StateStore.IsEncrypted(evt.RoomID) { + ctx := context.TODO() + if isEncrypted, err := mach.StateStore.IsEncrypted(ctx, evt.RoomID); err != nil { + mach.machOrContextLog(ctx).Err(err).Stringer("room_id", evt.RoomID). + Msg("Failed to check if room is encrypted to handle member event") + return + } else if !isEncrypted { return } content := evt.Content.AsMember() @@ -311,7 +316,7 @@ func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Even Str("prev_membership", string(prevContent.Membership)). Str("new_membership", string(content.Membership)). Msg("Got membership state change, invalidating group session in room") - err := mach.CryptoStore.RemoveOutboundGroupSession(evt.RoomID) + err := mach.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID) if err != nil { mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session") } @@ -405,7 +410,7 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) { // GetOrFetchDevice attempts to retrieve the device identity for the given device from the store // and if it's not found it asks the server for it. func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) { - device, err := mach.CryptoStore.GetDevice(userID, deviceID) + device, err := mach.CryptoStore.GetDevice(ctx, userID, deviceID) if err != nil { return nil, fmt.Errorf("failed to get sender device from store: %w", err) } else if device != nil { @@ -425,7 +430,7 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, // store and if it's not found it asks the server for it. This returns nil if the server doesn't return a device with // the given identity key. func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) { - deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(userID, identityKey) + deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(ctx, userID, identityKey) if err != nil || deviceIdentity != nil { return deviceIdentity, err } @@ -455,7 +460,7 @@ func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.De mach.olmLock.Lock() defer mach.olmLock.Unlock() - olmSess, err := mach.CryptoStore.GetLatestSession(device.IdentityKey) + olmSess, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey) if err != nil { return err } @@ -499,7 +504,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen Msg("Mismatched session ID while creating inbound group session") return } - err = mach.CryptoStore.PutGroupSession(roomID, senderKey, sessionID, igs) + err = mach.CryptoStore.PutGroupSession(ctx, roomID, senderKey, sessionID, igs) if err != nil { log.Error().Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session") return @@ -525,7 +530,7 @@ func (mach *OlmMachine) markSessionReceived(id id.SessionID) { } // WaitForSession waits for the given Megolm session to arrive. -func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { +func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { mach.keyWaitersLock.Lock() ch, ok := mach.keyWaiters[sessionID] if !ok { @@ -534,7 +539,7 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, } mach.keyWaitersLock.Unlock() // Handle race conditions where a session appears between the failed decryption and WaitForSession call. - sess, err := mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID) + sess, err := mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID) if sess != nil || errors.Is(err, ErrGroupSessionWithheld) { return true } @@ -542,10 +547,12 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, case <-ch: return true case <-time.After(timeout): - sess, err = mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID) + sess, err = mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID) // Check if the session somehow appeared in the store without telling us // We accept withheld sessions as received, as then the decryption attempt will show the error. return sess != nil || errors.Is(err, ErrGroupSessionWithheld) + case <-ctx.Done(): + return false } } @@ -568,7 +575,10 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve return } - config := mach.StateStore.GetEncryptionEvent(content.RoomID) + config, err := mach.StateStore.GetEncryptionEvent(ctx, content.RoomID) + if err != nil { + log.Error().Err(err).Msg("Failed to get encryption event for room") + } var maxAge time.Duration var maxMessages int if config != nil { @@ -589,7 +599,7 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve } if mach.DeletePreviousKeysOnReceive && !content.IsScheduled { log.Debug().Msg("Redacting previous megolm sessions from sender in room") - sessionIDs, err := mach.CryptoStore.RedactGroupSessions(content.RoomID, evt.SenderKey, "received new key from device") + sessionIDs, err := mach.CryptoStore.RedactGroupSessions(ctx, content.RoomID, evt.SenderKey, "received new key from device") if err != nil { log.Err(err).Msg("Failed to redact previous megolm sessions") } else { @@ -606,7 +616,7 @@ func (mach *OlmMachine) handleRoomKeyWithheld(ctx context.Context, content *even zerolog.Ctx(ctx).Debug().Interface("content", content).Msg("Non-megolm room key withheld event") return } - err := mach.CryptoStore.PutWithheldGroupSession(*content) + err := mach.CryptoStore.PutWithheldGroupSession(ctx, *content) if err != nil { zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to save room key withheld event") } @@ -662,7 +672,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) { log := mach.Log.With().Str("action", "redact expired sessions").Logger() for { - sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions() + sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions(ctx) if err != nil { log.Err(err).Msg("Failed to redact expired megolm sessions") } else if len(sessionIDs) > 0 { diff --git a/crypto/machine_test.go b/crypto/machine_test.go index 0271104e..f1d00ebb 100644 --- a/crypto/machine_test.go +++ b/crypto/machine_test.go @@ -20,18 +20,18 @@ import ( type mockStateStore struct{} -func (mockStateStore) IsEncrypted(id.RoomID) bool { - return true +func (mockStateStore) IsEncrypted(context.Context, id.RoomID) (bool, error) { + return true, nil } -func (mockStateStore) GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent { +func (mockStateStore) GetEncryptionEvent(context.Context, id.RoomID) (*event.EncryptionEventContent, error) { return &event.EncryptionEventContent{ RotationPeriodMessages: 3, - } + }, nil } -func (mockStateStore) FindSharedRooms(id.UserID) []id.RoomID { - return []id.RoomID{"room1"} +func (mockStateStore) FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, error) { + return []id.RoomID{"room1"}, nil } func newMachine(t *testing.T, userID id.UserID) *OlmMachine { @@ -47,7 +47,7 @@ func newMachine(t *testing.T, userID id.UserID) *OlmMachine { } machine := NewOlmMachine(client, nil, gobStore, mockStateStore{}) - if err := machine.Load(); err != nil { + if err := machine.Load(context.TODO()); err != nil { t.Fatalf("Error creating account: %v", err) } @@ -57,7 +57,7 @@ func newMachine(t *testing.T, userID id.UserID) *OlmMachine { func TestRatchetMegolmSession(t *testing.T) { mach := newMachine(t, "user1") outSess := mach.newOutboundGroupSession(context.TODO(), "meow") - inSess, err := mach.CryptoStore.GetGroupSession("meow", mach.OwnIdentity().IdentityKey, outSess.ID()) + inSess, err := mach.CryptoStore.GetGroupSession(context.TODO(), "meow", mach.OwnIdentity().IdentityKey, outSess.ID()) require.NoError(t, err) assert.Equal(t, uint32(0), inSess.Internal.FirstKnownIndex()) err = inSess.RatchetTo(10) @@ -85,7 +85,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { } // store sender device identity in receiving machine store - machineIn.CryptoStore.PutDevices("user1", map[id.DeviceID]*id.Device{ + machineIn.CryptoStore.PutDevices(context.TODO(), "user1", map[id.DeviceID]*id.Device{ "device1": { UserID: "user1", DeviceID: "device1", @@ -97,7 +97,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { // create & store outbound megolm session for sending the event later megolmOutSession := machineOut.newOutboundGroupSession(context.TODO(), "room1") megolmOutSession.Shared = true - machineOut.CryptoStore.AddOutboundGroupSession(megolmOutSession) + machineOut.CryptoStore.AddOutboundGroupSession(context.TODO(), megolmOutSession) // encrypt m.room_key event with olm session deviceIdentity := &id.Device{ @@ -125,7 +125,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { if err != nil { t.Errorf("Error creating inbound megolm session: %v", err) } - if err = machineIn.CryptoStore.PutGroupSession("room1", senderKey, igs.ID(), igs); err != nil { + if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), "room1", senderKey, igs.ID(), igs); err != nil { t.Errorf("Error storing inbound megolm session: %v", err) } } diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 64d62bc4..8c85f6de 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -27,7 +27,7 @@ import ( "maunium.net/go/mautrix/id" ) -var PostgresArrayWrapper func(interface{}) interface { +var PostgresArrayWrapper func(any) interface { driver.Valuer sql.Scanner } @@ -62,21 +62,21 @@ func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, accountID } // Flush does nothing for this implementation as data is already persisted in the database. -func (store *SQLCryptoStore) Flush() error { +func (store *SQLCryptoStore) Flush(_ context.Context) error { return nil } // PutNextBatch stores the next sync batch token for the current account. func (store *SQLCryptoStore) PutNextBatch(ctx context.Context, nextBatch string) error { store.SyncToken = nextBatch - _, err := store.DB.ExecContext(ctx, `UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2`, store.SyncToken, store.AccountID) + _, err := store.DB.Exec(ctx, `UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2`, store.SyncToken, store.AccountID) return err } // GetNextBatch retrieves the next sync batch token for the current account. func (store *SQLCryptoStore) GetNextBatch(ctx context.Context) (string, error) { if store.SyncToken == "" { - err := store.DB. + err := store.DB.Conn(ctx). QueryRowContext(ctx, "SELECT sync_token FROM crypto_account WHERE account_id=$1", store.AccountID). Scan(&store.SyncToken) if !errors.Is(err, sql.ErrNoRows) { @@ -111,20 +111,19 @@ func (store *SQLCryptoStore) LoadNextBatch(ctx context.Context, _ id.UserID) (st return nb, nil } -func (store *SQLCryptoStore) FindDeviceID() (deviceID id.DeviceID) { - err := store.DB.QueryRow("SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - // TODO return error - store.DB.Log.Warn("Failed to scan device ID: %v", err) +func (store *SQLCryptoStore) FindDeviceID(ctx context.Context) (deviceID id.DeviceID, err error) { + err = store.DB.QueryRow(ctx, "SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID) + if errors.Is(err, sql.ErrNoRows) { + err = nil } return } // PutAccount stores an OlmAccount in the database. -func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error { +func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount) error { store.Account = account bytes := account.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec(` + _, err := store.DB.Exec(ctx, ` INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token, account=excluded.account, account_id=excluded.account_id @@ -133,9 +132,9 @@ func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error { } // GetAccount retrieves an OlmAccount from the database. -func (store *SQLCryptoStore) GetAccount() (*OlmAccount, error) { +func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) { if store.Account == nil { - row := store.DB.QueryRow("SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID) + row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID) acc := &OlmAccount{Internal: *olm.NewBlankAccount()} var accountBytes []byte err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes) @@ -154,7 +153,7 @@ func (store *SQLCryptoStore) GetAccount() (*OlmAccount, error) { } // HasSession returns whether there is an Olm session for the given sender key. -func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool { +func (store *SQLCryptoStore) HasSession(ctx context.Context, key id.SenderKey) bool { store.olmSessionCacheLock.Lock() cache, ok := store.olmSessionCache[key] store.olmSessionCacheLock.Unlock() @@ -162,17 +161,17 @@ func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool { return true } var sessionID id.SessionID - err := store.DB.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 LIMIT 1", + err := store.DB.QueryRow(ctx, "SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 LIMIT 1", key, store.AccountID).Scan(&sessionID) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return false } return len(sessionID) > 0 } // GetSessions returns all the known Olm sessions for a sender key. -func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (OlmSessionList, error) { - rows, err := store.DB.Query("SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC", +func (store *SQLCryptoStore) GetSessions(ctx context.Context, key id.SenderKey) (OlmSessionList, error) { + rows, err := store.DB.Query(ctx, "SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC", key, store.AccountID) if err != nil { return nil, err @@ -212,11 +211,11 @@ func (store *SQLCryptoStore) getOlmSessionCache(key id.SenderKey) map[id.Session } // GetLatestSession retrieves the Olm session for a given sender key from the database that has the largest ID. -func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, error) { +func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.SenderKey) (*OlmSession, error) { store.olmSessionCacheLock.Lock() defer store.olmSessionCacheLock.Unlock() - row := store.DB.QueryRow("SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1", + row := store.DB.QueryRow(ctx, "SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1", key, store.AccountID) sess := OlmSession{Internal: *olm.NewBlankSession()} @@ -224,7 +223,7 @@ func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, er var sessionID id.SessionID err := row.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.LastEncryptedTime, &sess.LastDecryptedTime) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { return nil, err @@ -242,20 +241,20 @@ func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, er } // AddSession persists an Olm session for a sender in the database. -func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *OlmSession) error { +func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, session *OlmSession) error { store.olmSessionCacheLock.Lock() defer store.olmSessionCacheLock.Unlock() sessionBytes := session.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)", + _, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)", session.ID(), key, sessionBytes, session.CreationTime, session.LastEncryptedTime, session.LastDecryptedTime, store.AccountID) store.getOlmSessionCache(key)[session.ID()] = session return err } // UpdateSession replaces the Olm session for a sender in the database. -func (store *SQLCryptoStore) UpdateSession(_ id.SenderKey, session *OlmSession) error { +func (store *SQLCryptoStore) UpdateSession(ctx context.Context, _ id.SenderKey, session *OlmSession) error { sessionBytes := session.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec("UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5", + _, err := store.DB.Exec(ctx, "UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5", sessionBytes, session.LastEncryptedTime, session.LastDecryptedTime, session.ID(), store.AccountID) return err } @@ -275,14 +274,14 @@ func datePtr(t time.Time) *time.Time { } // PutGroupSession stores an inbound Megolm group session for a room, sender and session. -func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error { +func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error { sessionBytes := session.Internal.Pickle(store.PickleKey) forwardingChains := strings.Join(session.ForwardingChains, ",") ratchetSafety, err := json.Marshal(&session.RatchetSafety) if err != nil { return fmt.Errorf("failed to marshal ratchet safety info: %w", err) } - _, err = store.DB.Exec(` + _, err = store.DB.Exec(ctx, ` INSERT INTO crypto_megolm_inbound_session ( session_id, sender_key, signing_key, room_id, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id @@ -301,19 +300,19 @@ func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.Send } // GetGroupSession retrieves an inbound Megolm group session for a room, sender and session. -func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) { +func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) { var senderKeyDB, signingKey, forwardingChains, withheldCode, withheldReason sql.NullString var sessionBytes, ratchetSafetyBytes []byte var receivedAt sql.NullTime var maxAge, maxMessages sql.NullInt64 var isScheduled bool - err := store.DB.QueryRow(` + err := store.DB.QueryRow(ctx, ` SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled FROM crypto_megolm_inbound_session WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`, roomID, senderKey, sessionID, store.AccountID, ).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { return nil, err @@ -327,22 +326,7 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send Reason: withheldReason.String, } } - igs := olm.NewBlankInboundGroupSession() - err = igs.Unpickle(sessionBytes, store.PickleKey) - if err != nil { - return nil, err - } - var chains []string - if forwardingChains.String != "" { - chains = strings.Split(forwardingChains.String, ",") - } - var rs RatchetSafety - if len(ratchetSafetyBytes) > 0 { - err = json.Unmarshal(ratchetSafetyBytes, &rs) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err) - } - } + igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String) if senderKey == "" { senderKey = id.Curve25519(senderKeyDB.String) } @@ -360,8 +344,8 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send }, nil } -func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error { - _, err := store.DB.Exec(` +func (store *SQLCryptoStore) RedactGroupSession(ctx context.Context, _ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error { + _, err := store.DB.Exec(ctx, ` UPDATE crypto_megolm_inbound_session SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL @@ -369,27 +353,24 @@ func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, ses return err } -func (store *SQLCryptoStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) { +func (store *SQLCryptoStore) RedactGroupSessions(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) { if roomID == "" && senderKey == "" { return nil, fmt.Errorf("room ID or sender key must be provided for redacting sessions") } - res, err := store.DB.Query(` + res, err := store.DB.Query(ctx, ` UPDATE crypto_megolm_inbound_session SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5 AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL RETURNING session_id `, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, roomID, senderKey, store.AccountID) - var sessionIDs []id.SessionID - for res.Next() { - var sessionID id.SessionID - _ = res.Scan(&sessionID) - sessionIDs = append(sessionIDs, sessionID) + if err != nil { + return nil, err } - return sessionIDs, err + return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() } -func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error) { +func (store *SQLCryptoStore) RedactExpiredGroupSessions(ctx context.Context) ([]id.SessionID, error) { var query string switch store.DB.Dialect { case dbutil.Postgres: @@ -413,46 +394,40 @@ func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error default: return nil, fmt.Errorf("unsupported dialect") } - res, err := store.DB.Query(query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID) - var sessionIDs []id.SessionID - for res.Next() { - var sessionID id.SessionID - _ = res.Scan(&sessionID) - sessionIDs = append(sessionIDs, sessionID) + res, err := store.DB.Query(ctx, query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID) + if err != nil { + return nil, err } - return sessionIDs, err + return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() } -func (store *SQLCryptoStore) RedactOutdatedGroupSessions() ([]id.SessionID, error) { - res, err := store.DB.Query(` +func (store *SQLCryptoStore) RedactOutdatedGroupSessions(ctx context.Context) ([]id.SessionID, error) { + res, err := store.DB.Query(ctx, ` UPDATE crypto_megolm_inbound_session SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL RETURNING session_id `, event.RoomKeyWithheldBeeperRedacted, "Session redacted: outdated", store.AccountID) - var sessionIDs []id.SessionID - for res.Next() { - var sessionID id.SessionID - _ = res.Scan(&sessionID) - sessionIDs = append(sessionIDs, sessionID) + if err != nil { + return nil, err } - return sessionIDs, err + return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() } -func (store *SQLCryptoStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error { - _, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)", +func (store *SQLCryptoStore) PutWithheldGroupSession(ctx context.Context, content event.RoomKeyWithheldEventContent) error { + _, err := store.DB.Exec(ctx, "INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)", content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID) return err } -func (store *SQLCryptoStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { +func (store *SQLCryptoStore) GetWithheldGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { var code, reason sql.NullString - err := store.DB.QueryRow(` + err := store.DB.QueryRow(ctx, ` SELECT withheld_code, withheld_reason FROM crypto_megolm_inbound_session WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`, roomID, senderKey, sessionID, store.AccountID, ).Scan(&code, &reason) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil || !code.Valid { return nil, err @@ -467,82 +442,79 @@ func (store *SQLCryptoStore) GetWithheldGroupSession(roomID id.RoomID, senderKey }, nil } -func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*InboundGroupSession, err error) { - for rows.Next() { - var roomID id.RoomID - var signingKey, senderKey, forwardingChains sql.NullString - var sessionBytes, ratchetSafetyBytes []byte - var receivedAt sql.NullTime - var maxAge, maxMessages sql.NullInt64 - var isScheduled bool - err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled) +func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs *olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) { + igs = olm.NewBlankInboundGroupSession() + err = igs.Unpickle(sessionBytes, store.PickleKey) + if err != nil { + return + } + if forwardingChains != "" { + chains = strings.Split(forwardingChains, ",") + } + var rs RatchetSafety + if len(ratchetSafetyBytes) > 0 { + err = json.Unmarshal(ratchetSafetyBytes, &rs) if err != nil { - return + err = fmt.Errorf("failed to unmarshal ratchet safety info: %w", err) } - igs := olm.NewBlankInboundGroupSession() - err = igs.Unpickle(sessionBytes, store.PickleKey) - if err != nil { - return - } - var chains []string - if forwardingChains.String != "" { - chains = strings.Split(forwardingChains.String, ",") - } - var rs RatchetSafety - if len(ratchetSafetyBytes) > 0 { - err = json.Unmarshal(ratchetSafetyBytes, &rs) - if err != nil { - return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err) - } - } - result = append(result, &InboundGroupSession{ - Internal: *igs, - SigningKey: id.Ed25519(signingKey.String), - SenderKey: id.Curve25519(senderKey.String), - RoomID: roomID, - ForwardingChains: chains, - RatchetSafety: rs, - ReceivedAt: receivedAt.Time, - MaxAge: maxAge.Int64, - MaxMessages: int(maxMessages.Int64), - IsScheduled: isScheduled, - }) } return } -func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) { - rows, err := store.DB.Query(` - SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled +func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*InboundGroupSession, error) { + var roomID id.RoomID + var signingKey, senderKey, forwardingChains sql.NullString + var sessionBytes, ratchetSafetyBytes []byte + var receivedAt sql.NullTime + var maxAge, maxMessages sql.NullInt64 + var isScheduled bool + err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled) + if err != nil { + return nil, err + } + igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String) + return &InboundGroupSession{ + Internal: *igs, + SigningKey: id.Ed25519(signingKey.String), + SenderKey: id.Curve25519(senderKey.String), + RoomID: roomID, + ForwardingChains: chains, + RatchetSafety: rs, + ReceivedAt: receivedAt.Time, + MaxAge: maxAge.Int64, + MaxMessages: int(maxMessages.Int64), + IsScheduled: isScheduled, + }, nil +} + +func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) { + rows, err := store.DB.Query(ctx, ` + SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`, roomID, store.AccountID, ) - if err == sql.ErrNoRows { - return []*InboundGroupSession{}, nil - } else if err != nil { + if err != nil { return nil, err } - return store.scanGroupSessionList(rows) + return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList() } -func (store *SQLCryptoStore) GetAllGroupSessions() ([]*InboundGroupSession, error) { - rows, err := store.DB.Query(` - SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled +func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) ([]*InboundGroupSession, error) { + rows, err := store.DB.Query(ctx, ` + SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`, store.AccountID, ) - if err == sql.ErrNoRows { - return []*InboundGroupSession{}, nil - } else if err != nil { + if err != nil { return nil, err } - return store.scanGroupSessionList(rows) + return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList() } // AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices. -func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSession) error { +func (store *SQLCryptoStore) AddOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error { sessionBytes := session.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec(` + _, err := store.DB.Exec(ctx, ` INSERT INTO crypto_megolm_outbound_session (room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) @@ -556,24 +528,24 @@ func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSessi } // UpdateOutboundGroupSession replaces an outbound Megolm session with for same room and session ID. -func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *OutboundGroupSession) error { +func (store *SQLCryptoStore) UpdateOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error { sessionBytes := session.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec("UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6", + _, err := store.DB.Exec(ctx, "UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6", sessionBytes, session.MessageCount, session.LastEncryptedTime, session.RoomID, session.ID(), store.AccountID) return err } // GetOutboundGroupSession retrieves the outbound Megolm session for the given room ID. -func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) { +func (store *SQLCryptoStore) GetOutboundGroupSession(ctx context.Context, roomID id.RoomID) (*OutboundGroupSession, error) { var ogs OutboundGroupSession var sessionBytes []byte var maxAgeMS int64 - err := store.DB.QueryRow(` + err := store.DB.QueryRow(ctx, ` SELECT session, shared, max_messages, message_count, max_age, created_at, last_used FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2`, roomID, store.AccountID, ).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &maxAgeMS, &ogs.CreationTime, &ogs.LastEncryptedTime) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { return nil, err @@ -590,8 +562,8 @@ func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*Outboun } // RemoveOutboundGroupSession removes the outbound Megolm session for the given room ID. -func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error { - _, err := store.DB.Exec("DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2", +func (store *SQLCryptoStore) RemoveOutboundGroupSession(ctx context.Context, roomID id.RoomID) error { + _, err := store.DB.Exec(ctx, "DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2", roomID, store.AccountID) return err } @@ -608,7 +580,7 @@ func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey ` var expectedEventID id.EventID var expectedTimestamp int64 - err := store.DB.QueryRowContext(ctx, validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp) + err := store.DB.QueryRow(ctx, validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp) if err != nil { return false, err } else if expectedEventID != eventID || expectedTimestamp != timestamp { @@ -623,69 +595,58 @@ func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey return true, nil } +func scanDevice(rows dbutil.Scannable) (*id.Device, error) { + var device id.Device + err := rows.Scan(&device.UserID, &device.DeviceID, &device.IdentityKey, &device.SigningKey, &device.Trust, &device.Deleted, &device.Name) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } else if err != nil { + return nil, err + } + return &device, nil +} + // GetDevices returns a map of device IDs to device identities, including the identity and signing keys, for a given user ID. -func (store *SQLCryptoStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) { +func (store *SQLCryptoStore) GetDevices(ctx context.Context, userID id.UserID) (map[id.DeviceID]*id.Device, error) { var ignore id.UserID - err := store.DB.QueryRow("SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore) - if err == sql.ErrNoRows { + err := store.DB.QueryRow(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore) + if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { return nil, err } - rows, err := store.DB.Query("SELECT device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND deleted=false", userID) + rows, err := store.DB.Query(ctx, "SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND deleted=false", userID) if err != nil { return nil, err } data := make(map[id.DeviceID]*id.Device) - for rows.Next() { - var identity id.Device - err := rows.Scan(&identity.DeviceID, &identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name) - if err != nil { - return nil, err - } - identity.UserID = userID - data[identity.DeviceID] = &identity + err = dbutil.NewRowIter(rows, scanDevice).Iter(func(device *id.Device) (bool, error) { + data[device.DeviceID] = device + return true, nil + }) + if err != nil { + return nil, err } return data, nil } // GetDevice returns the device dentity for a given user and device ID. -func (store *SQLCryptoStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) { - var identity id.Device - err := store.DB.QueryRow(` - SELECT identity_key, signing_key, trust, deleted, name +func (store *SQLCryptoStore) GetDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) { + return scanDevice(store.DB.QueryRow(ctx, ` + SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND device_id=$2`, userID, deviceID, - ).Scan(&identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, err - } - identity.UserID = userID - identity.DeviceID = deviceID - return &identity, nil + )) } // FindDeviceByKey finds a specific device by its sender key. -func (store *SQLCryptoStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) { - var identity id.Device - err := store.DB.QueryRow(` - SELECT device_id, signing_key, trust, deleted, name +func (store *SQLCryptoStore) FindDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) { + return scanDevice(store.DB.QueryRow(ctx, ` + SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND identity_key=$2`, userID, identityKey, - ).Scan(&identity.DeviceID, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } - return nil, err - } - identity.UserID = userID - identity.IdentityKey = identityKey - return &identity, nil + )) } const deviceInsertQuery = ` @@ -698,106 +659,84 @@ ON CONFLICT (user_id, device_id) DO UPDATE var deviceMassInsertTemplate = strings.ReplaceAll(deviceInsertQuery, "($1, $2, $3, $4, $5, $6, $7)", "%s") // PutDevice stores a single device for a user, replacing it if it exists already. -func (store *SQLCryptoStore) PutDevice(userID id.UserID, device *id.Device) error { - _, err := store.DB.Exec(deviceInsertQuery, +func (store *SQLCryptoStore) PutDevice(ctx context.Context, userID id.UserID, device *id.Device) error { + _, err := store.DB.Exec(ctx, deviceInsertQuery, userID, device.DeviceID, device.IdentityKey, device.SigningKey, device.Trust, device.Deleted, device.Name) return err } // PutDevices stores the device identity information for the given user ID. -func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error { - tx, err := store.DB.Begin() - if err != nil { - return err - } - - _, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) - if err != nil { - return fmt.Errorf("failed to add user to tracked users list: %w", err) - } - - _, err = tx.Exec("UPDATE crypto_device SET deleted=true WHERE user_id=$1", userID) - if err != nil { - _ = tx.Rollback() - return fmt.Errorf("failed to delete old devices: %w", err) - } - if len(devices) == 0 { - err = tx.Commit() +func (store *SQLCryptoStore) PutDevices(ctx context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error { + return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + _, err := store.DB.Exec(ctx, "INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) if err != nil { - return fmt.Errorf("failed to commit changes (no devices added): %w", err) + return fmt.Errorf("failed to add user to tracked users list: %w", err) + } + + _, err = store.DB.Exec(ctx, "UPDATE crypto_device SET deleted=true WHERE user_id=$1", userID) + if err != nil { + return fmt.Errorf("failed to delete old devices: %w", err) + } + if len(devices) == 0 { + return nil + } + deviceBatchLen := 5 // how many devices will be inserted per query + deviceIDs := make([]id.DeviceID, 0, len(devices)) + for deviceID := range devices { + deviceIDs = append(deviceIDs, deviceID) + } + const valueStringFormat = "($1, $%d, $%d, $%d, $%d, $%d, $%d)" + for batchDeviceIdx := 0; batchDeviceIdx < len(deviceIDs); batchDeviceIdx += deviceBatchLen { + var batchDevices []id.DeviceID + if batchDeviceIdx+deviceBatchLen < len(deviceIDs) { + batchDevices = deviceIDs[batchDeviceIdx : batchDeviceIdx+deviceBatchLen] + } else { + batchDevices = deviceIDs[batchDeviceIdx:] + } + values := make([]interface{}, 1, len(devices)*6+1) + values[0] = userID + valueStrings := make([]string, 0, len(devices)) + i := 2 + for _, deviceID := range batchDevices { + identity := devices[deviceID] + values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name) + valueStrings = append(valueStrings, fmt.Sprintf(valueStringFormat, i, i+1, i+2, i+3, i+4, i+5)) + i += 6 + } + valueString := strings.Join(valueStrings, ",") + _, err = store.DB.Exec(ctx, fmt.Sprintf(deviceMassInsertTemplate, valueString), values...) + if err != nil { + return fmt.Errorf("failed to insert new devices: %w", err) + } } return nil - } - deviceBatchLen := 5 // how many devices will be inserted per query - deviceIDs := make([]id.DeviceID, 0, len(devices)) - for deviceID := range devices { - deviceIDs = append(deviceIDs, deviceID) - } - const valueStringFormat = "($1, $%d, $%d, $%d, $%d, $%d, $%d)" - for batchDeviceIdx := 0; batchDeviceIdx < len(deviceIDs); batchDeviceIdx += deviceBatchLen { - var batchDevices []id.DeviceID - if batchDeviceIdx+deviceBatchLen < len(deviceIDs) { - batchDevices = deviceIDs[batchDeviceIdx : batchDeviceIdx+deviceBatchLen] - } else { - batchDevices = deviceIDs[batchDeviceIdx:] - } - values := make([]interface{}, 1, len(devices)*6+1) - values[0] = userID - valueStrings := make([]string, 0, len(devices)) - i := 2 - for _, deviceID := range batchDevices { - identity := devices[deviceID] - values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name) - valueStrings = append(valueStrings, fmt.Sprintf(valueStringFormat, i, i+1, i+2, i+3, i+4, i+5)) - i += 6 - } - valueString := strings.Join(valueStrings, ",") - _, err = tx.Exec(fmt.Sprintf(deviceMassInsertTemplate, valueString), values...) - if err != nil { - _ = tx.Rollback() - return fmt.Errorf("failed to insert new devices: %w", err) - } - } - err = tx.Commit() - if err != nil { - return fmt.Errorf("failed to commit changes: %w", err) - } - return nil + }) } // FilterTrackedUsers finds all the user IDs out of the given ones for which the database contains identity information. -func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) { +func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id.UserID) ([]id.UserID, error) { var rows dbutil.Rows var err error if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil { - rows, err = store.DB.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users)) + rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users)) } else { queryString := make([]string, len(users)) params := make([]interface{}, len(users)) for i, user := range users { - queryString[i] = fmt.Sprintf("$%d", i+1) + queryString[i] = fmt.Sprintf("?%d", i+1) params[i] = user } - rows, err = store.DB.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...) + rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...) } if err != nil { return users, err } - var ptr int - for rows.Next() { - err = rows.Scan(&users[ptr]) - if err != nil { - return users, err - } else { - ptr++ - } - } - return users[:ptr], nil + return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList() } // PutCrossSigningKey stores a cross-signing key of some user along with its usage. -func (store *SQLCryptoStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error { - _, err := store.DB.Exec(` +func (store *SQLCryptoStore) PutCrossSigningKey(ctx context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error { + _, err := store.DB.Exec(ctx, ` INSERT INTO crypto_cross_signing_keys (user_id, usage, key, first_seen_key) VALUES ($1, $2, $3, $4) ON CONFLICT (user_id, usage) DO UPDATE SET key=excluded.key `, userID, usage, key, key) @@ -805,8 +744,8 @@ func (store *SQLCryptoStore) PutCrossSigningKey(userID id.UserID, usage id.Cross } // GetCrossSigningKeys retrieves a user's stored cross-signing keys. -func (store *SQLCryptoStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) { - rows, err := store.DB.Query("SELECT usage, key, first_seen_key FROM crypto_cross_signing_keys WHERE user_id=$1", userID) +func (store *SQLCryptoStore) GetCrossSigningKeys(ctx context.Context, userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) { + rows, err := store.DB.Query(ctx, "SELECT usage, key, first_seen_key FROM crypto_cross_signing_keys WHERE user_id=$1", userID) if err != nil { return nil, err } @@ -825,8 +764,8 @@ func (store *SQLCryptoStore) GetCrossSigningKeys(userID id.UserID) (map[id.Cross } // PutSignature stores a signature of a cross-signing or device key along with the signer's user ID and key. -func (store *SQLCryptoStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error { - _, err := store.DB.Exec(` +func (store *SQLCryptoStore) PutSignature(ctx context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error { + _, err := store.DB.Exec(ctx, ` INSERT INTO crypto_cross_signing_signatures (signed_user_id, signed_key, signer_user_id, signer_key, signature) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (signed_user_id, signed_key, signer_user_id, signer_key) DO UPDATE SET signature=excluded.signature `, signedUserID, signedKey, signerUserID, signerKey, signature) @@ -834,8 +773,8 @@ func (store *SQLCryptoStore) PutSignature(signedUserID id.UserID, signedKey id.E } // GetSignaturesForKeyBy retrieves the stored signatures for a given cross-signing or device key, by the given signer. -func (store *SQLCryptoStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) { - rows, err := store.DB.Query("SELECT signer_key, signature FROM crypto_cross_signing_signatures WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3", userID, key, signerID) +func (store *SQLCryptoStore) GetSignaturesForKeyBy(ctx context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) { + rows, err := store.DB.Query(ctx, "SELECT signer_key, signature FROM crypto_cross_signing_signatures WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3", userID, key, signerID) if err != nil { return nil, err } @@ -854,18 +793,18 @@ func (store *SQLCryptoStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25 } // IsKeySignedBy returns whether a cross-signing or device key is signed by the given signer. -func (store *SQLCryptoStore) IsKeySignedBy(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519) (isSigned bool, err error) { +func (store *SQLCryptoStore) IsKeySignedBy(ctx context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519) (isSigned bool, err error) { q := `SELECT EXISTS( SELECT 1 FROM crypto_cross_signing_signatures WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3 AND signer_key=$4 )` - err = store.DB.QueryRow(q, signedUserID, signedKey, signerUserID, signerKey).Scan(&isSigned) + err = store.DB.QueryRow(ctx, q, signedUserID, signedKey, signerUserID, signerKey).Scan(&isSigned) return } // DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted. -func (store *SQLCryptoStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) { - res, err := store.DB.Exec("DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2", userID, key) +func (store *SQLCryptoStore) DropSignaturesByKey(ctx context.Context, userID id.UserID, key id.Ed25519) (int64, error) { + res, err := store.DB.Exec(ctx, "DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2", userID, key) if err != nil { return 0, err } diff --git a/crypto/sql_store_upgrade/upgrade.go b/crypto/sql_store_upgrade/upgrade.go index c9541b91..08c995da 100644 --- a/crypto/sql_store_upgrade/upgrade.go +++ b/crypto/sql_store_upgrade/upgrade.go @@ -7,6 +7,7 @@ package sql_store_upgrade import ( + "context" "embed" "fmt" @@ -21,7 +22,7 @@ const VersionTableName = "crypto_version" var fs embed.FS func init() { - Table.Register(-1, 3, 0, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error { + Table.Register(-1, 3, 0, "Unsupported version", false, func(ctx context.Context, database *dbutil.Database) error { return fmt.Errorf("upgrading from versions 1 and 2 of the crypto store is no longer supported in mautrix-go v0.12+") }) Table.RegisterFS(fs) diff --git a/crypto/store.go b/crypto/store.go index 99e464d2..09393a51 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -26,64 +26,64 @@ var ErrGroupSessionWithheld error = &event.RoomKeyWithheldEventContent{} type Store interface { // Flush ensures that everything in the store is persisted to disk. // This doesn't have to do anything, e.g. for database-backed implementations that persist everything immediately. - Flush() error + Flush(context.Context) error // PutAccount updates the OlmAccount in the store. - PutAccount(*OlmAccount) error + PutAccount(context.Context, *OlmAccount) error // GetAccount returns the OlmAccount in the store that was previously inserted with PutAccount. - GetAccount() (*OlmAccount, error) + GetAccount(ctx context.Context) (*OlmAccount, error) // AddSession inserts an Olm session into the store. - AddSession(id.SenderKey, *OlmSession) error + AddSession(context.Context, id.SenderKey, *OlmSession) error // HasSession returns whether or not the store has an Olm session with the given sender key. - HasSession(id.SenderKey) bool + HasSession(context.Context, id.SenderKey) bool // GetSessions returns all Olm sessions in the store with the given sender key. - GetSessions(id.SenderKey) (OlmSessionList, error) + GetSessions(context.Context, id.SenderKey) (OlmSessionList, error) // GetLatestSession returns the session with the highest session ID (lexiographically sorting). // It's usually safe to return the most recently added session if sorting by session ID is too difficult. - GetLatestSession(id.SenderKey) (*OlmSession, error) + GetLatestSession(context.Context, id.SenderKey) (*OlmSession, error) // UpdateSession updates a session that has previously been inserted with AddSession. - UpdateSession(id.SenderKey, *OlmSession) error + UpdateSession(context.Context, id.SenderKey, *OlmSession) error // PutGroupSession inserts an inbound Megolm session into the store. If an earlier withhold event has been inserted // with PutWithheldGroupSession, this call should replace that. However, PutWithheldGroupSession must not replace // sessions inserted with this call. - PutGroupSession(id.RoomID, id.SenderKey, id.SessionID, *InboundGroupSession) error + PutGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, *InboundGroupSession) error // GetGroupSession gets an inbound Megolm session from the store. If the group session has been withheld // (i.e. a room key withheld event has been saved with PutWithheldGroupSession), this should return the // ErrGroupSessionWithheld error. The caller may use GetWithheldGroupSession to find more details. - GetGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error) + GetGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error) // RedactGroupSession removes the session data for the given inbound Megolm session from the store. - RedactGroupSession(id.RoomID, id.SenderKey, id.SessionID, string) error + RedactGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, string) error // RedactGroupSessions removes the session data for all inbound Megolm sessions from a specific device and/or in a specific room. - RedactGroupSessions(id.RoomID, id.SenderKey, string) ([]id.SessionID, error) + RedactGroupSessions(context.Context, id.RoomID, id.SenderKey, string) ([]id.SessionID, error) // RedactExpiredGroupSessions removes the session data for all inbound Megolm sessions that have expired. - RedactExpiredGroupSessions() ([]id.SessionID, error) + RedactExpiredGroupSessions(context.Context) ([]id.SessionID, error) // RedactOutdatedGroupSessions removes the session data for all inbound Megolm sessions that are lacking the expiration metadata. - RedactOutdatedGroupSessions() ([]id.SessionID, error) + RedactOutdatedGroupSessions(context.Context) ([]id.SessionID, error) // PutWithheldGroupSession tells the store that a specific Megolm session was withheld. - PutWithheldGroupSession(event.RoomKeyWithheldEventContent) error + PutWithheldGroupSession(context.Context, event.RoomKeyWithheldEventContent) error // GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession. - GetWithheldGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*event.RoomKeyWithheldEventContent, error) + GetWithheldGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID) (*event.RoomKeyWithheldEventContent, error) // GetGroupSessionsForRoom gets all the inbound Megolm sessions for a specific room. This is used for creating key // export files. Unlike GetGroupSession, this should not return any errors about withheld keys. - GetGroupSessionsForRoom(id.RoomID) ([]*InboundGroupSession, error) + GetGroupSessionsForRoom(context.Context, id.RoomID) ([]*InboundGroupSession, error) // GetAllGroupSessions gets all the inbound Megolm sessions in the store. This is used for creating key export // files. Unlike GetGroupSession, this should not return any errors about withheld keys. - GetAllGroupSessions() ([]*InboundGroupSession, error) + GetAllGroupSessions(context.Context) ([]*InboundGroupSession, error) // AddOutboundGroupSession inserts the given outbound Megolm session into the store. // // The store should index inserted sessions by the RoomID field to support getting and removing sessions. // There will only be one outbound session per room ID at a time. - AddOutboundGroupSession(*OutboundGroupSession) error + AddOutboundGroupSession(context.Context, *OutboundGroupSession) error // UpdateOutboundGroupSession updates the given outbound Megolm session in the store. - UpdateOutboundGroupSession(*OutboundGroupSession) error + UpdateOutboundGroupSession(context.Context, *OutboundGroupSession) error // GetOutboundGroupSession gets the stored outbound Megolm session for the given room ID from the store. - GetOutboundGroupSession(id.RoomID) (*OutboundGroupSession, error) + GetOutboundGroupSession(context.Context, id.RoomID) (*OutboundGroupSession, error) // RemoveOutboundGroupSession removes the stored outbound Megolm session for the given room ID. - RemoveOutboundGroupSession(id.RoomID) error + RemoveOutboundGroupSession(context.Context, id.RoomID) error // ValidateMessageIndex validates that the given message details aren't from a replay attack. // @@ -96,29 +96,29 @@ type Store interface { ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) // GetDevices returns a map from device ID to id.Device struct containing all devices of a given user. - GetDevices(id.UserID) (map[id.DeviceID]*id.Device, error) + GetDevices(context.Context, id.UserID) (map[id.DeviceID]*id.Device, error) // GetDevice returns a specific device of a given user. - GetDevice(id.UserID, id.DeviceID) (*id.Device, error) + GetDevice(context.Context, id.UserID, id.DeviceID) (*id.Device, error) // PutDevice stores a single device for a user, replacing it if it exists already. - PutDevice(id.UserID, *id.Device) error + PutDevice(context.Context, id.UserID, *id.Device) error // PutDevices overrides the stored device list for the given user with the given list. - PutDevices(id.UserID, map[id.DeviceID]*id.Device) error + PutDevices(context.Context, id.UserID, map[id.DeviceID]*id.Device) error // FindDeviceByKey finds a specific device by its identity key. - FindDeviceByKey(id.UserID, id.IdentityKey) (*id.Device, error) + FindDeviceByKey(context.Context, id.UserID, id.IdentityKey) (*id.Device, error) // FilterTrackedUsers returns a filtered version of the given list that only includes user IDs whose device lists // have been stored with PutDevices. A user is considered tracked even if the PutDevices list was empty. - FilterTrackedUsers([]id.UserID) ([]id.UserID, error) + FilterTrackedUsers(context.Context, []id.UserID) ([]id.UserID, error) // PutCrossSigningKey stores a cross-signing key of some user along with its usage. - PutCrossSigningKey(id.UserID, id.CrossSigningUsage, id.Ed25519) error + PutCrossSigningKey(context.Context, id.UserID, id.CrossSigningUsage, id.Ed25519) error // GetCrossSigningKeys retrieves a user's stored cross-signing keys. - GetCrossSigningKeys(id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) + GetCrossSigningKeys(context.Context, id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) // PutSignature stores a signature of a cross-signing or device key along with the signer's user ID and key. - PutSignature(signedUser id.UserID, signedKey id.Ed25519, signerUser id.UserID, signerKey id.Ed25519, signature string) error + PutSignature(ctx context.Context, signedUser id.UserID, signedKey id.Ed25519, signerUser id.UserID, signerKey id.Ed25519, signature string) error // IsKeySignedBy returns whether a cross-signing or device key is signed by the given signer. - IsKeySignedBy(userID id.UserID, key id.Ed25519, signedByUser id.UserID, signedByKey id.Ed25519) (bool, error) + IsKeySignedBy(ctx context.Context, userID id.UserID, key id.Ed25519, signedByUser id.UserID, signedByKey id.Ed25519) (bool, error) // DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted. - DropSignaturesByKey(id.UserID, id.Ed25519) (int64, error) + DropSignaturesByKey(context.Context, id.UserID, id.Ed25519) (int64, error) } type messageIndexKey struct { @@ -170,18 +170,18 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore { } } -func (gs *MemoryStore) Flush() error { +func (gs *MemoryStore) Flush(_ context.Context) error { gs.lock.Lock() err := gs.save() gs.lock.Unlock() return err } -func (gs *MemoryStore) GetAccount() (*OlmAccount, error) { +func (gs *MemoryStore) GetAccount(_ context.Context) (*OlmAccount, error) { return gs.Account, nil } -func (gs *MemoryStore) PutAccount(account *OlmAccount) error { +func (gs *MemoryStore) PutAccount(_ context.Context, account *OlmAccount) error { gs.lock.Lock() gs.Account = account err := gs.save() @@ -189,7 +189,7 @@ func (gs *MemoryStore) PutAccount(account *OlmAccount) error { return err } -func (gs *MemoryStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, error) { +func (gs *MemoryStore) GetSessions(_ context.Context, senderKey id.SenderKey) (OlmSessionList, error) { gs.lock.Lock() sessions, ok := gs.Sessions[senderKey] if !ok { @@ -200,7 +200,7 @@ func (gs *MemoryStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, erro return sessions, nil } -func (gs *MemoryStore) AddSession(senderKey id.SenderKey, session *OlmSession) error { +func (gs *MemoryStore) AddSession(_ context.Context, senderKey id.SenderKey, session *OlmSession) error { gs.lock.Lock() sessions, _ := gs.Sessions[senderKey] gs.Sessions[senderKey] = append(sessions, session) @@ -210,19 +210,19 @@ func (gs *MemoryStore) AddSession(senderKey id.SenderKey, session *OlmSession) e return err } -func (gs *MemoryStore) UpdateSession(_ id.SenderKey, _ *OlmSession) error { +func (gs *MemoryStore) UpdateSession(_ context.Context, _ id.SenderKey, _ *OlmSession) error { // we don't need to do anything here because the session is a pointer and already stored in our map return gs.save() } -func (gs *MemoryStore) HasSession(senderKey id.SenderKey) bool { +func (gs *MemoryStore) HasSession(_ context.Context, senderKey id.SenderKey) bool { gs.lock.RLock() sessions, ok := gs.Sessions[senderKey] gs.lock.RUnlock() return ok && len(sessions) > 0 && !sessions[0].Expired() } -func (gs *MemoryStore) GetLatestSession(senderKey id.SenderKey) (*OlmSession, error) { +func (gs *MemoryStore) GetLatestSession(_ context.Context, senderKey id.SenderKey) (*OlmSession, error) { gs.lock.RLock() sessions, ok := gs.Sessions[senderKey] gs.lock.RUnlock() @@ -246,7 +246,7 @@ func (gs *MemoryStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey return sender } -func (gs *MemoryStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error { +func (gs *MemoryStore) PutGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error { gs.lock.Lock() gs.getGroupSessions(roomID, senderKey)[sessionID] = igs err := gs.save() @@ -254,7 +254,7 @@ func (gs *MemoryStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, return err } -func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) { +func (gs *MemoryStore) GetGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) { gs.lock.Lock() session, ok := gs.getGroupSessions(roomID, senderKey)[sessionID] if !ok { @@ -269,7 +269,7 @@ func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, return session, nil } -func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error { +func (gs *MemoryStore) RedactGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error { gs.lock.Lock() delete(gs.getGroupSessions(roomID, senderKey), sessionID) err := gs.save() @@ -277,7 +277,7 @@ func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderK return err } -func (gs *MemoryStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) { +func (gs *MemoryStore) RedactGroupSessions(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) { gs.lock.Lock() var sessionIDs []id.SessionID if roomID != "" && senderKey != "" { @@ -315,11 +315,11 @@ func (gs *MemoryStore) RedactGroupSessions(roomID id.RoomID, senderKey id.Sender return sessionIDs, err } -func (gs *MemoryStore) RedactExpiredGroupSessions() ([]id.SessionID, error) { +func (gs *MemoryStore) RedactExpiredGroupSessions(_ context.Context) ([]id.SessionID, error) { return nil, fmt.Errorf("not implemented") } -func (gs *MemoryStore) RedactOutdatedGroupSessions() ([]id.SessionID, error) { +func (gs *MemoryStore) RedactOutdatedGroupSessions(_ context.Context) ([]id.SessionID, error) { return nil, fmt.Errorf("not implemented") } @@ -337,7 +337,7 @@ func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.S return sender } -func (gs *MemoryStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error { +func (gs *MemoryStore) PutWithheldGroupSession(_ context.Context, content event.RoomKeyWithheldEventContent) error { gs.lock.Lock() gs.getWithheldGroupSessions(content.RoomID, content.SenderKey)[content.SessionID] = &content err := gs.save() @@ -345,7 +345,7 @@ func (gs *MemoryStore) PutWithheldGroupSession(content event.RoomKeyWithheldEven return err } -func (gs *MemoryStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { +func (gs *MemoryStore) GetWithheldGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { gs.lock.Lock() session, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID] gs.lock.Unlock() @@ -355,7 +355,7 @@ func (gs *MemoryStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.Se return session, nil } -func (gs *MemoryStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) { +func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) { gs.lock.Lock() defer gs.lock.Unlock() room, ok := gs.GroupSessions[roomID] @@ -371,7 +371,7 @@ func (gs *MemoryStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGrou return result, nil } -func (gs *MemoryStore) GetAllGroupSessions() ([]*InboundGroupSession, error) { +func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) ([]*InboundGroupSession, error) { gs.lock.Lock() var result []*InboundGroupSession for _, room := range gs.GroupSessions { @@ -385,7 +385,7 @@ func (gs *MemoryStore) GetAllGroupSessions() ([]*InboundGroupSession, error) { return result, nil } -func (gs *MemoryStore) AddOutboundGroupSession(session *OutboundGroupSession) error { +func (gs *MemoryStore) AddOutboundGroupSession(_ context.Context, session *OutboundGroupSession) error { gs.lock.Lock() gs.OutGroupSessions[session.RoomID] = session err := gs.save() @@ -393,12 +393,12 @@ func (gs *MemoryStore) AddOutboundGroupSession(session *OutboundGroupSession) er return err } -func (gs *MemoryStore) UpdateOutboundGroupSession(_ *OutboundGroupSession) error { +func (gs *MemoryStore) UpdateOutboundGroupSession(_ context.Context, _ *OutboundGroupSession) error { // we don't need to do anything here because the session is a pointer and already stored in our map return gs.save() } -func (gs *MemoryStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) { +func (gs *MemoryStore) GetOutboundGroupSession(_ context.Context, roomID id.RoomID) (*OutboundGroupSession, error) { gs.lock.RLock() session, ok := gs.OutGroupSessions[roomID] gs.lock.RUnlock() @@ -408,7 +408,7 @@ func (gs *MemoryStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroup return session, nil } -func (gs *MemoryStore) RemoveOutboundGroupSession(roomID id.RoomID) error { +func (gs *MemoryStore) RemoveOutboundGroupSession(_ context.Context, roomID id.RoomID) error { gs.lock.Lock() session, ok := gs.OutGroupSessions[roomID] if !ok || session == nil { @@ -443,7 +443,7 @@ func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.Send return true, nil } -func (gs *MemoryStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) { +func (gs *MemoryStore) GetDevices(_ context.Context, userID id.UserID) (map[id.DeviceID]*id.Device, error) { gs.lock.RLock() devices, ok := gs.Devices[userID] if !ok { @@ -453,7 +453,7 @@ func (gs *MemoryStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, return devices, nil } -func (gs *MemoryStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) { +func (gs *MemoryStore) GetDevice(_ context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) { gs.lock.RLock() defer gs.lock.RUnlock() devices, ok := gs.Devices[userID] @@ -467,7 +467,7 @@ func (gs *MemoryStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.De return device, nil } -func (gs *MemoryStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) { +func (gs *MemoryStore) FindDeviceByKey(_ context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) { gs.lock.RLock() defer gs.lock.RUnlock() devices, ok := gs.Devices[userID] @@ -482,7 +482,7 @@ func (gs *MemoryStore) FindDeviceByKey(userID id.UserID, identityKey id.Identity return nil, nil } -func (gs *MemoryStore) PutDevice(userID id.UserID, device *id.Device) error { +func (gs *MemoryStore) PutDevice(_ context.Context, userID id.UserID, device *id.Device) error { gs.lock.Lock() devices, ok := gs.Devices[userID] if !ok { @@ -495,7 +495,7 @@ func (gs *MemoryStore) PutDevice(userID id.UserID, device *id.Device) error { return err } -func (gs *MemoryStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error { +func (gs *MemoryStore) PutDevices(_ context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error { gs.lock.Lock() gs.Devices[userID] = devices err := gs.save() @@ -503,7 +503,7 @@ func (gs *MemoryStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id. return err } -func (gs *MemoryStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) { +func (gs *MemoryStore) FilterTrackedUsers(_ context.Context, users []id.UserID) ([]id.UserID, error) { gs.lock.RLock() var ptr int for _, userID := range users { @@ -517,7 +517,7 @@ func (gs *MemoryStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error return users[:ptr], nil } -func (gs *MemoryStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error { +func (gs *MemoryStore) PutCrossSigningKey(_ context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error { gs.lock.RLock() userKeys, ok := gs.CrossSigningKeys[userID] if !ok { @@ -539,7 +539,7 @@ func (gs *MemoryStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSignin return err } -func (gs *MemoryStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) { +func (gs *MemoryStore) GetCrossSigningKeys(_ context.Context, userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) { gs.lock.RLock() defer gs.lock.RUnlock() keys, ok := gs.CrossSigningKeys[userID] @@ -549,7 +549,7 @@ func (gs *MemoryStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSignin return keys, nil } -func (gs *MemoryStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error { +func (gs *MemoryStore) PutSignature(_ context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error { gs.lock.RLock() signedUserSigs, ok := gs.KeySignatures[signedUserID] if !ok { @@ -572,7 +572,7 @@ func (gs *MemoryStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519 return err } -func (gs *MemoryStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) { +func (gs *MemoryStore) GetSignaturesForKeyBy(_ context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) { gs.lock.RLock() defer gs.lock.RUnlock() userKeys, ok := gs.KeySignatures[userID] @@ -590,8 +590,8 @@ func (gs *MemoryStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, s return sigsBySigner, nil } -func (gs *MemoryStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) { - sigs, err := gs.GetSignaturesForKeyBy(userID, key, signerID) +func (gs *MemoryStore) IsKeySignedBy(ctx context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) { + sigs, err := gs.GetSignaturesForKeyBy(ctx, userID, key, signerID) if err != nil { return false, err } @@ -599,7 +599,7 @@ func (gs *MemoryStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID return ok, nil } -func (gs *MemoryStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) { +func (gs *MemoryStore) DropSignaturesByKey(_ context.Context, userID id.UserID, key id.Ed25519) (int64, error) { var count int64 gs.lock.RLock() for _, userSigs := range gs.KeySignatures { diff --git a/crypto/store_test.go b/crypto/store_test.go index ebeef393..665e3ef9 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -36,7 +36,7 @@ func getCryptoStores(t *testing.T) map[string]Store { t.Fatalf("Error opening db: %v", err) } sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test")) - if err = sqlStore.DB.Upgrade(); err != nil { + if err = sqlStore.DB.Upgrade(context.TODO()); err != nil { t.Fatalf("Error creating tables: %v", err) } @@ -65,8 +65,8 @@ func TestPutAccount(t *testing.T) { for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { acc := NewOlmAccount() - store.PutAccount(acc) - retrieved, err := store.GetAccount() + store.PutAccount(context.TODO(), acc) + retrieved, err := store.GetAccount(context.TODO()) if err != nil { t.Fatalf("Error retrieving account: %v", err) } @@ -105,7 +105,7 @@ func TestStoreOlmSession(t *testing.T) { stores := getCryptoStores(t) for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { - if store.HasSession(olmSessID) { + if store.HasSession(context.TODO(), olmSessID) { t.Error("Found Olm session before inserting it") } olmInternal, err := olm.SessionFromPickled([]byte(olmPickled), []byte("test")) @@ -117,15 +117,15 @@ func TestStoreOlmSession(t *testing.T) { id: olmSessID, Internal: *olmInternal, } - err = store.AddSession(olmSessID, &olmSess) + err = store.AddSession(context.TODO(), olmSessID, &olmSess) if err != nil { t.Errorf("Error storing Olm session: %v", err) } - if !store.HasSession(olmSessID) { + if !store.HasSession(context.TODO(), olmSessID) { t.Error("Not found Olm session after inserting it") } - retrieved, err := store.GetLatestSession(olmSessID) + retrieved, err := store.GetLatestSession(context.TODO(), olmSessID) if err != nil { t.Errorf("Failed retrieving Olm session: %v", err) } @@ -158,12 +158,12 @@ func TestStoreMegolmSession(t *testing.T) { RoomID: "room1", } - err = store.PutGroupSession("room1", acc.IdentityKey(), igs.ID(), igs) + err = store.PutGroupSession(context.TODO(), "room1", acc.IdentityKey(), igs.ID(), igs) if err != nil { t.Errorf("Error storing inbound group session: %v", err) } - retrieved, err := store.GetGroupSession("room1", acc.IdentityKey(), igs.ID()) + retrieved, err := store.GetGroupSession(context.TODO(), "room1", acc.IdentityKey(), igs.ID()) if err != nil { t.Errorf("Error retrieving inbound group session: %v", err) } @@ -179,7 +179,7 @@ func TestStoreOutboundMegolmSession(t *testing.T) { stores := getCryptoStores(t) for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { - sess, err := store.GetOutboundGroupSession("room1") + sess, err := store.GetOutboundGroupSession(context.TODO(), "room1") if sess != nil { t.Error("Got outbound session before inserting") } @@ -188,12 +188,12 @@ func TestStoreOutboundMegolmSession(t *testing.T) { } outbound := NewOutboundGroupSession("room1", nil) - err = store.AddOutboundGroupSession(outbound) + err = store.AddOutboundGroupSession(context.TODO(), outbound) if err != nil { t.Errorf("Error inserting outbound session: %v", err) } - sess, err = store.GetOutboundGroupSession("room1") + sess, err = store.GetOutboundGroupSession(context.TODO(), "room1") if sess == nil { t.Error("Did not get outbound session after inserting") } @@ -201,12 +201,12 @@ func TestStoreOutboundMegolmSession(t *testing.T) { t.Errorf("Error retrieving outbound session: %v", err) } - err = store.RemoveOutboundGroupSession("room1") + err = store.RemoveOutboundGroupSession(context.TODO(), "room1") if err != nil { t.Errorf("Error deleting outbound session: %v", err) } - sess, err = store.GetOutboundGroupSession("room1") + sess, err = store.GetOutboundGroupSession(context.TODO(), "room1") if sess != nil { t.Error("Got outbound session after deleting") } @@ -232,11 +232,11 @@ func TestStoreDevices(t *testing.T) { SigningKey: acc.SigningKey(), } } - err := store.PutDevices("user1", deviceMap) + err := store.PutDevices(context.TODO(), "user1", deviceMap) if err != nil { t.Errorf("Error string devices: %v", err) } - devs, err := store.GetDevices("user1") + devs, err := store.GetDevices(context.TODO(), "user1") if err != nil { t.Errorf("Error getting devices: %v", err) } @@ -250,7 +250,7 @@ func TestStoreDevices(t *testing.T) { t.Errorf("Last device identity key does not match") } - filtered, err := store.FilterTrackedUsers([]id.UserID{"user0", "user1", "user2"}) + filtered, err := store.FilterTrackedUsers(context.TODO(), []id.UserID{"user0", "user1", "user2"}) if err != nil { t.Errorf("Error filtering tracked users: %v", err) } else if len(filtered) != 1 || filtered[0] != "user1" { diff --git a/crypto/verification.go b/crypto/verification.go index be246874..31608bfa 100644 --- a/crypto/verification.go +++ b/crypto/verification.go @@ -507,7 +507,7 @@ func (mach *OlmMachine) handleVerificationMAC(ctx context.Context, userID id.Use // we can finally trust this device device.Trust = id.TrustStateVerified - err = mach.CryptoStore.PutDevice(device.UserID, device) + err = mach.CryptoStore.PutDevice(ctx, device.UserID, device) if err != nil { mach.Log.Warn().Msgf("Failed to put device after verifying: %v", err) } @@ -521,7 +521,7 @@ func (mach *OlmMachine) handleVerificationMAC(ctx context.Context, userID id.Use mach.Log.Debug().Msgf("Cross-signed own device %v after SAS verification", device.DeviceID) } } else { - masterKey, err := mach.fetchMasterKey(device, content, verState, transactionID) + masterKey, err := mach.fetchMasterKey(ctx, device, content, verState, transactionID) if err != nil { mach.Log.Warn().Msgf("Failed to fetch %s's master key: %v", device.UserID, err) } else { diff --git a/go.mod b/go.mod index 07e3efdc..8484acc3 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.6.0 - go.mau.fi/util v0.2.2-0.20231228160822-a6d40c214e80 + go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894 go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.17.0 golang.org/x/exp v0.0.0-20231226003508-02704c960a9b diff --git a/go.sum b/go.sum index f52dc4cd..d923c7b1 100644 --- a/go.sum +++ b/go.sum @@ -36,8 +36,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68= github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mau.fi/util v0.2.2-0.20231228160822-a6d40c214e80 h1:zcfIxHgzZpgGSJv/FUVbOjO4ZWa12En4TGhxgUI/QH0= -go.mau.fi/util v0.2.2-0.20231228160822-a6d40c214e80/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= +go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894 h1:CuR5LDSxBQLETorfwJ9vRtySeLHjMvJ7//lnCMw7Dy8= +go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index 531b71e4..cd94215d 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -7,6 +7,7 @@ package sqlstatestore import ( + "context" "database/sql" "embed" "encoding/json" @@ -15,6 +16,7 @@ import ( "strconv" "strings" + "github.com/rs/zerolog" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/event" @@ -44,26 +46,28 @@ func NewSQLStateStore(db *dbutil.Database, log dbutil.DatabaseLogger, isBridge b } } -func (store *SQLStateStore) IsRegistered(userID id.UserID) bool { +func (store *SQLStateStore) IsRegistered(ctx context.Context, userID id.UserID) (bool, error) { var isRegistered bool err := store. - QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID). + QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID). Scan(&isRegistered) - if err != nil { - store.Log.Warn("Failed to scan registration existence for %s: %v", userID, err) + if errors.Is(err, sql.ErrNoRows) { + err = nil } - return isRegistered + return isRegistered, err } -func (store *SQLStateStore) MarkRegistered(userID id.UserID) { - _, err := store.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) - if err != nil { - store.Log.Warn("Failed to mark %s as registered: %v", userID, err) - } +func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID) error { + _, err := store.Exec(ctx, "INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) + return err } -func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID, memberships ...event.Membership) map[id.UserID]*event.MemberEventContent { - members := make(map[id.UserID]*event.MemberEventContent) +type Member struct { + id.UserID + event.MemberEventContent +} + +func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) (map[id.UserID]*event.MemberEventContent, error) { args := make([]any, len(memberships)+1) args[0] = roomID query := "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1" @@ -75,25 +79,26 @@ func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID, memberships ...even } query = fmt.Sprintf("%s AND membership IN (%s)", query, strings.Join(placeholders, ",")) } - rows, err := store.Query(query, args...) + rows, err := store.Query(ctx, query, args...) if err != nil { - return members + return nil, err } - var userID id.UserID - var member event.MemberEventContent - for rows.Next() { - err = rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL) - if err != nil { - store.Log.Warn("Failed to scan member in %s: %v", roomID, err) - } else { - members[userID] = &member - } - } - return members + members := make(map[id.UserID]*event.MemberEventContent) + return members, dbutil.NewRowIter(rows, func(row dbutil.Scannable) (ret Member, err error) { + err = row.Scan(&ret.UserID, &ret.Membership, &ret.Displayname, &ret.AvatarURL) + return + }).Iter(func(m Member) (bool, error) { + members[m.UserID] = &m.MemberEventContent + return true, nil + }) } -func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (members []id.UserID, err error) { - memberMap := store.GetRoomMembers(roomID, event.MembershipJoin, event.MembershipInvite) +func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) (members []id.UserID, err error) { + var memberMap map[id.UserID]*event.MemberEventContent + memberMap, err = store.GetRoomMembers(ctx, roomID, event.MembershipJoin, event.MembershipInvite) + if err != nil { + return + } members = make([]id.UserID, len(memberMap)) i := 0 for userID := range memberMap { @@ -103,37 +108,39 @@ func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (mem return } -func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership { - membership := event.MembershipLeave - err := store. - QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID). +func (store *SQLStateStore) GetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID) (membership event.Membership, err error) { + err = store. + QueryRow(ctx, "SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID). Scan(&membership) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - store.Log.Warn("Failed to scan membership of %s in %s: %v", userID, roomID, err) + if errors.Is(err, sql.ErrNoRows) { + membership = event.MembershipLeave + err = nil } - return membership + return } -func (store *SQLStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent { - member, ok := store.TryGetMember(roomID, userID) - if !ok { - member.Membership = event.MembershipLeave +func (store *SQLStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) { + member, err := store.TryGetMember(ctx, roomID, userID) + if member == nil && err == nil { + member = &event.MemberEventContent{Membership: event.MembershipLeave} } - return member + return member, err } -func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) { +func (store *SQLStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) { var member event.MemberEventContent err := store. - QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID). + QueryRow(ctx, "SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID). Scan(&member.Membership, &member.Displayname, &member.AvatarURL) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - store.Log.Warn("Failed to scan member info of %s in %s: %v", userID, roomID, err) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } else if err != nil { + return nil, err } - return &member, err == nil + return &member, nil } -func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) { +func (store *SQLStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) ([]id.RoomID, error) { query := ` SELECT room_id FROM mx_user_profile LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id @@ -141,38 +148,32 @@ func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID ` if !store.IsBridge { query = ` - SELECT mx_user_profile.room_id FROM mx_user_profile - LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id - WHERE mx_user_profile.user_id=$1 AND mx_room_state.encryption IS NOT NULL - ` + SELECT mx_user_profile.room_id FROM mx_user_profile + LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id + WHERE mx_user_profile.user_id=$1 AND mx_room_state.encryption IS NOT NULL + ` } - rows, err := store.Query(query, userID) + rows, err := store.Query(ctx, query, userID) if err != nil { - store.Log.Warn("Failed to query shared rooms with %s: %v", userID, err) - return + return nil, err } - for rows.Next() { - var roomID id.RoomID - err = rows.Scan(&roomID) - if err != nil { - store.Log.Warn("Failed to scan room ID: %v", err) - } else { - rooms = append(rooms, roomID) - } + return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList() +} + +func (store *SQLStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { + return store.IsMembership(ctx, roomID, userID, "join") +} + +func (store *SQLStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { + return store.IsMembership(ctx, roomID, userID, "join", "invite") +} + +func (store *SQLStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool { + membership, err := store.GetMembership(ctx, roomID, userID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get membership") + return false } - return -} - -func (store *SQLStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool { - return store.IsMembership(roomID, userID, "join") -} - -func (store *SQLStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool { - return store.IsMembership(roomID, userID, "join", "invite") -} - -func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool { - membership := store.GetMembership(roomID, userID) for _, allowedMembership := range allowedMemberships { if allowedMembership == membership { return true @@ -181,27 +182,23 @@ func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, all return false } -func (store *SQLStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) { - _, err := store.Exec(` +func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error { + _, err := store.Exec(ctx, ` INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, '', '') ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership `, roomID, userID, membership) - if err != nil { - store.Log.Warn("Failed to set membership of %s in %s to %s: %v", userID, roomID, membership, err) - } + return err } -func (store *SQLStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) { - _, err := store.Exec(` +func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error { + _, err := store.Exec(ctx, ` INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url `, roomID, userID, member.Membership, member.Displayname, member.AvatarURL) - if err != nil { - store.Log.Warn("Failed to set membership of %s in %s to %s: %v", userID, roomID, member, err) - } + return err } -func (store *SQLStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) { +func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error { query := "DELETE FROM mx_user_profile WHERE room_id=$1" params := make([]any, len(memberships)+1) params[0] = roomID @@ -213,109 +210,85 @@ func (store *SQLStateStore) ClearCachedMembers(roomID id.RoomID, memberships ... } query += fmt.Sprintf(" AND membership IN (%s)", strings.Join(placeholders, ",")) } - _, err := store.Exec(query, params...) - if err != nil { - store.Log.Warn("Failed to clear cached members of %s: %v", roomID, err) - } + _, err := store.Exec(ctx, query, params...) + return err } -func (store *SQLStateStore) SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) { +func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { contentBytes, err := json.Marshal(content) if err != nil { - store.Log.Warn("Failed to marshal encryption config of %s: %v", roomID, err) - return + return fmt.Errorf("failed to marshal content JSON: %w", err) } - _, err = store.Exec(` + _, err = store.Exec(ctx, ` INSERT INTO mx_room_state (room_id, encryption) VALUES ($1, $2) ON CONFLICT (room_id) DO UPDATE SET encryption=excluded.encryption `, roomID, contentBytes) - if err != nil { - store.Log.Warn("Failed to store encryption config of %s: %v", roomID, err) - } + return err } -func (store *SQLStateStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent { +func (store *SQLStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) { var data []byte err := store. - QueryRow("SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID). + QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID). Scan(&data) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - store.Log.Warn("Failed to scan encryption config of %s: %v", roomID, err) - } - return nil + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } else if err != nil { + return nil, err } else if data == nil { - return nil + return nil, nil } - content := &event.EncryptionEventContent{} - err = json.Unmarshal(data, content) + var content event.EncryptionEventContent + err = json.Unmarshal(data, &content) if err != nil { - store.Log.Warn("Failed to parse encryption config of %s: %v", roomID, err) - return nil + return nil, fmt.Errorf("failed to parse content JSON: %w", err) } - return content + return &content, nil } -func (store *SQLStateStore) IsEncrypted(roomID id.RoomID) bool { - cfg := store.GetEncryptionEvent(roomID) - return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1 +func (store *SQLStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) { + cfg, err := store.GetEncryptionEvent(ctx, roomID) + return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err } -func (store *SQLStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) { - levelsBytes, err := json.Marshal(levels) - if err != nil { - store.Log.Warn("Failed to marshal power levels of %s: %v", roomID, err) - return - } - _, err = store.Exec(` +func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error { + _, err := store.Exec(ctx, ` INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2) ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels - `, roomID, levelsBytes) - if err != nil { - store.Log.Warn("Failed to store power levels of %s: %v", roomID, err) - } + `, roomID, dbutil.JSON{Data: levels}) + return err } -func (store *SQLStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) { - var data []byte - err := store. - QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID). - Scan(&data) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - store.Log.Warn("Failed to scan power levels of %s: %v", roomID, err) - } - return - } else if data == nil { - return - } - levels = &event.PowerLevelsEventContent{} - err = json.Unmarshal(data, levels) - if err != nil { - store.Log.Warn("Failed to parse power levels of %s: %v", roomID, err) - return nil +func (store *SQLStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) { + err = store. + QueryRow(ctx, "SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID). + Scan(&dbutil.JSON{Data: &levels}) + if errors.Is(err, sql.ErrNoRows) { + err = nil } return } -func (store *SQLStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int { +func (store *SQLStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) { if store.Dialect == dbutil.Postgres { var powerLevel int err := store. - QueryRow(` + QueryRow(ctx, ` SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0) FROM mx_room_state WHERE room_id=$1 `, roomID, userID). Scan(&powerLevel) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - store.Log.Warn("Failed to scan power level of %s in %s: %v", userID, roomID, err) + return powerLevel, err + } else { + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return 0, err } - return powerLevel + return levels.GetUserLevel(userID), nil } - return store.GetPowerLevels(roomID).GetUserLevel(userID) } -func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int { +func (store *SQLStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) { if store.Dialect == dbutil.Postgres { defaultType := "events_default" defaultValue := 0 @@ -325,23 +298,26 @@ func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType } var powerLevel int err := store. - QueryRow(` + QueryRow(ctx, ` SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4) FROM mx_room_state WHERE room_id=$1 `, roomID, eventType.Type, defaultType, defaultValue). Scan(&powerLevel) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - store.Log.Warn("Failed to scan power level for %s in %s: %v", eventType, roomID, err) - } - return defaultValue + if errors.Is(err, sql.ErrNoRows) { + err = nil + powerLevel = defaultValue } - return powerLevel + return powerLevel, err + } else { + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return 0, err + } + return levels.GetEventLevel(eventType), nil } - return store.GetPowerLevels(roomID).GetEventLevel(eventType) } -func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool { +func (store *SQLStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) { if store.Dialect == dbutil.Postgres { defaultType := "events_default" defaultValue := 0 @@ -351,19 +327,22 @@ func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, ev } var hasPower bool err := store. - QueryRow(`SELECT + QueryRow(ctx, `SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0) >= COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5) FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue). Scan(&hasPower) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - store.Log.Warn("Failed to scan power level for %s in %s: %v", eventType, roomID, err) - } - return defaultValue == 0 + if errors.Is(err, sql.ErrNoRows) { + err = nil + hasPower = defaultValue == 0 } - return hasPower + return hasPower, err + } else { + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return false, err + } + return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil } - return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType) } diff --git a/sqlstatestore/v05-mark-encryption-state-resync.go b/sqlstatestore/v05-mark-encryption-state-resync.go index d66a9e98..bf44d308 100644 --- a/sqlstatestore/v05-mark-encryption-state-resync.go +++ b/sqlstatestore/v05-mark-encryption-state-resync.go @@ -1,19 +1,20 @@ package sqlstatestore import ( + "context" "fmt" "go.mau.fi/util/dbutil" ) func init() { - UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", true, func(tx dbutil.Execable, db *dbutil.Database) error { - portalExists, err := db.TableExists(tx, "portal") + UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", true, func(ctx context.Context, db *dbutil.Database) error { + portalExists, err := db.TableExists(ctx, "portal") if err != nil { return fmt.Errorf("failed to check if portal table exists") } if portalExists { - _, err = tx.Exec(` + _, err = db.Exec(ctx, ` INSERT INTO mx_room_state (room_id, encryption) SELECT portal.mxid, '{"resync":true}' FROM portal WHERE portal.encrypted=true AND portal.mxid IS NOT NULL ON CONFLICT (room_id) DO UPDATE diff --git a/statestore.go b/statestore.go index 2c0a8fd4..63a5bfb4 100644 --- a/statestore.go +++ b/statestore.go @@ -7,33 +7,37 @@ package mautrix import ( + "context" "sync" + "github.com/rs/zerolog" + "go.mau.fi/util/exerrors" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) // StateStore is an interface for storing basic room state information. type StateStore interface { - IsInRoom(roomID id.RoomID, userID id.UserID) bool - IsInvited(roomID id.RoomID, userID id.UserID) bool - IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool - GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent - TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) - SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) - SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) - ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) + IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool + IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool + IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool + GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) + TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) + SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error + SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error + ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error - SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) - GetPowerLevels(roomID id.RoomID) *event.PowerLevelsEventContent + SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error + GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error) - SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) - IsEncrypted(roomID id.RoomID) bool + SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error + IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) - GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ([]id.UserID, error) + GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) } -func UpdateStateStore(store StateStore, evt *event.Event) { +func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) { if store == nil || evt == nil || evt.StateKey == nil { return } @@ -41,13 +45,20 @@ func UpdateStateStore(store StateStore, evt *event.Event) { if evt.Type != event.StateMember && evt.GetStateKey() != "" { return } + var err error switch content := evt.Content.Parsed.(type) { case *event.MemberEventContent: - store.SetMember(evt.RoomID, id.UserID(evt.GetStateKey()), content) + err = store.SetMember(ctx, evt.RoomID, id.UserID(evt.GetStateKey()), content) case *event.PowerLevelsEventContent: - store.SetPowerLevels(evt.RoomID, content) + err = store.SetPowerLevels(ctx, evt.RoomID, content) case *event.EncryptionEventContent: - store.SetEncryptionEvent(evt.RoomID, content) + err = store.SetEncryptionEvent(ctx, evt.RoomID, content) + } + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err). + Stringer("event_id", evt.ID). + Str("event_type", evt.Type.Type). + Msg("Failed to update state store") } } @@ -57,7 +68,7 @@ func UpdateStateStore(store StateStore, evt *event.Event) { // // DefaultSyncer.ParseEventContent must also be true for this to work (which it is by default). func (cli *Client) StateStoreSyncHandler(_ EventSource, evt *event.Event) { - UpdateStateStore(cli.StateStore, evt) + UpdateStateStore(cli.Log.WithContext(context.TODO()), cli.StateStore, evt) } type MemoryStateStore struct { @@ -81,20 +92,21 @@ func NewMemoryStateStore() StateStore { } } -func (store *MemoryStateStore) IsRegistered(userID id.UserID) bool { +func (store *MemoryStateStore) IsRegistered(_ context.Context, userID id.UserID) (bool, error) { store.registrationsLock.RLock() defer store.registrationsLock.RUnlock() registered, ok := store.Registrations[userID] - return ok && registered + return ok && registered, nil } -func (store *MemoryStateStore) MarkRegistered(userID id.UserID) { +func (store *MemoryStateStore) MarkRegistered(_ context.Context, userID id.UserID) error { store.registrationsLock.Lock() defer store.registrationsLock.Unlock() store.Registrations[userID] = true + return nil } -func (store *MemoryStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent { +func (store *MemoryStateStore) GetRoomMembers(_ context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { store.membersLock.RLock() members, ok := store.Members[roomID] store.membersLock.RUnlock() @@ -104,11 +116,14 @@ func (store *MemoryStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*e store.Members[roomID] = members store.membersLock.Unlock() } - return members + return members, nil } -func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ([]id.UserID, error) { - members := store.GetRoomMembers(roomID) +func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) { + members, err := store.GetRoomMembers(ctx, roomID) + if err != nil { + return nil, err + } ids := make([]id.UserID, 0, len(members)) for id := range members { ids = append(ids, id) @@ -116,39 +131,39 @@ func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ( return ids, nil } -func (store *MemoryStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership { - return store.GetMember(roomID, userID).Membership +func (store *MemoryStateStore) GetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID) (event.Membership, error) { + return exerrors.Must(store.GetMember(ctx, roomID, userID)).Membership, nil } -func (store *MemoryStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent { - member, ok := store.TryGetMember(roomID, userID) - if !ok { +func (store *MemoryStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) { + member, err := store.TryGetMember(ctx, roomID, userID) + if member == nil && err == nil { member = &event.MemberEventContent{Membership: event.MembershipLeave} } - return member + return member, err } -func (store *MemoryStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, ok bool) { +func (store *MemoryStateStore) TryGetMember(_ context.Context, roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, err error) { store.membersLock.RLock() defer store.membersLock.RUnlock() members, membersOk := store.Members[roomID] if !membersOk { return } - member, ok = members[userID] + member = members[userID] return } -func (store *MemoryStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool { - return store.IsMembership(roomID, userID, "join") +func (store *MemoryStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { + return store.IsMembership(ctx, roomID, userID, "join") } -func (store *MemoryStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool { - return store.IsMembership(roomID, userID, "join", "invite") +func (store *MemoryStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { + return store.IsMembership(ctx, roomID, userID, "join", "invite") } -func (store *MemoryStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool { - membership := store.GetMembership(roomID, userID) +func (store *MemoryStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool { + membership := exerrors.Must(store.GetMembership(ctx, roomID, userID)) for _, allowedMembership := range allowedMemberships { if allowedMembership == membership { return true @@ -157,7 +172,7 @@ func (store *MemoryStateStore) IsMembership(roomID id.RoomID, userID id.UserID, return false } -func (store *MemoryStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) { +func (store *MemoryStateStore) SetMembership(_ context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error { store.membersLock.Lock() members, ok := store.Members[roomID] if !ok { @@ -175,9 +190,10 @@ func (store *MemoryStateStore) SetMembership(roomID id.RoomID, userID id.UserID, } store.Members[roomID] = members store.membersLock.Unlock() + return nil } -func (store *MemoryStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) { +func (store *MemoryStateStore) SetMember(_ context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error { store.membersLock.Lock() members, ok := store.Members[roomID] if !ok { @@ -189,14 +205,15 @@ func (store *MemoryStateStore) SetMember(roomID id.RoomID, userID id.UserID, mem } store.Members[roomID] = members store.membersLock.Unlock() + return nil } -func (store *MemoryStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) { +func (store *MemoryStateStore) ClearCachedMembers(_ context.Context, roomID id.RoomID, memberships ...event.Membership) error { store.membersLock.Lock() defer store.membersLock.Unlock() members, ok := store.Members[roomID] if !ok { - return + return nil } for userID, member := range members { for _, membership := range memberships { @@ -206,46 +223,49 @@ func (store *MemoryStateStore) ClearCachedMembers(roomID id.RoomID, memberships } } } + return nil } -func (store *MemoryStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) { +func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error { store.powerLevelsLock.Lock() store.PowerLevels[roomID] = levels store.powerLevelsLock.Unlock() + return nil } -func (store *MemoryStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) { +func (store *MemoryStateStore) GetPowerLevels(_ context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) { store.powerLevelsLock.RLock() levels = store.PowerLevels[roomID] store.powerLevelsLock.RUnlock() return } -func (store *MemoryStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int { - return store.GetPowerLevels(roomID).GetUserLevel(userID) +func (store *MemoryStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) { + return exerrors.Must(store.GetPowerLevels(ctx, roomID)).GetUserLevel(userID), nil } -func (store *MemoryStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int { - return store.GetPowerLevels(roomID).GetEventLevel(eventType) +func (store *MemoryStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) { + return exerrors.Must(store.GetPowerLevels(ctx, roomID)).GetEventLevel(eventType), nil } -func (store *MemoryStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool { - return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType) +func (store *MemoryStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) { + return exerrors.Must(store.GetPowerLevel(ctx, roomID, userID)) >= exerrors.Must(store.GetPowerLevelRequirement(ctx, roomID, eventType)), nil } -func (store *MemoryStateStore) SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) { +func (store *MemoryStateStore) SetEncryptionEvent(_ context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { store.encryptionLock.Lock() store.Encryption[roomID] = content store.encryptionLock.Unlock() + return nil } -func (store *MemoryStateStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent { +func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) { store.encryptionLock.RLock() defer store.encryptionLock.RUnlock() - return store.Encryption[roomID] + return store.Encryption[roomID], nil } -func (store *MemoryStateStore) IsEncrypted(roomID id.RoomID) bool { - cfg := store.GetEncryptionEvent(roomID) - return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1 +func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) { + cfg, err := store.GetEncryptionEvent(ctx, roomID) + return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err } From 8da3a1740282b06bdee4daa6e0535657529c3367 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Wed, 10 Jan 2024 11:37:00 +0200 Subject: [PATCH 0054/1647] Add context to OLM machine LoadDevices As there's a side effect of going to the crypto store we want the context to at least exist for now. --- crypto/devicelist.go | 5 ++--- crypto/machine.go | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/crypto/devicelist.go b/crypto/devicelist.go index e554480d..bbe06aae 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -27,9 +27,8 @@ var ( InvalidKeySignature = errors.New("invalid signature on device keys") ) -func (mach *OlmMachine) LoadDevices(user id.UserID) map[id.DeviceID]*id.Device { - // TODO proper context? - return mach.fetchKeys(context.TODO(), []id.UserID{user}, "", true)[user] +func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) map[id.DeviceID]*id.Device { + return mach.fetchKeys(ctx, []id.UserID{user}, "", true)[user] } func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id.UserID, deviceID id.DeviceID, resp *mautrix.RespQueryKeys) { diff --git a/crypto/machine.go b/crypto/machine.go index da78eaf7..fc0f1742 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -438,7 +438,7 @@ func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.Use Str("user_id", userID.String()). Str("identity_key", identityKey.String()). Msg("Didn't find identity in crypto store, fetching from server") - devices := mach.LoadDevices(userID) + devices := mach.LoadDevices(ctx, userID) for _, device := range devices { if device.IdentityKey == identityKey { return device, nil From b3910eb6994d2aeef27046916f2fa251bd341d2f Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 11 Jan 2024 17:40:04 -0700 Subject: [PATCH 0055/1647] pre-commit: specify maunium.net/go/mautrix as local import Signed-off-by: Sumner Evans --- .pre-commit-config.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1ef386ea..a656f0a8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,4 +12,8 @@ repos: rev: v1.0.0-rc.1 hooks: - id: go-imports-repo + args: + - "-local" + - "maunium.net/go/mautrix" + - "-w" - id: go-vet-repo-mod From a3883fcf6fba331dda53f63a1c70c2a57d6d2fb6 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Wed, 10 Jan 2024 16:03:15 +0200 Subject: [PATCH 0056/1647] Allow disabling automatic key fetching for Olm machine Many crypto operations in the Olm machine have a possible side effect of fetching keys from the server if they are missing. This may be undesired in some special cases. To tracking which users need key fetching, CryptoStore now exposes APIs to mark and query the status. --- crypto/devicelist.go | 24 ++++++----- crypto/encryptmegolm.go | 16 +++++--- crypto/machine.go | 18 +++++--- crypto/sql_store.go | 35 +++++++++++++++- .../sql_store_upgrade/00-latest-revision.sql | 5 ++- .../sql_store_upgrade/11-outdated-devices.sql | 2 + crypto/store.go | 30 ++++++++++++++ crypto/store_test.go | 41 ++++++++++++++++++- 8 files changed, 144 insertions(+), 27 deletions(-) create mode 100644 crypto/sql_store_upgrade/11-outdated-devices.sql diff --git a/crypto/devicelist.go b/crypto/devicelist.go index bbe06aae..f5c07cd3 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -27,8 +27,16 @@ var ( InvalidKeySignature = errors.New("invalid signature on device keys") ) -func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) map[id.DeviceID]*id.Device { - return mach.fetchKeys(ctx, []id.UserID{user}, "", true)[user] +func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) { + log := zerolog.Ctx(ctx) + + if keys, err := mach.FetchKeys(ctx, []id.UserID{user}, true); err != nil { + log.Err(err).Msg("Failed to load devices") + } else if keys != nil { + return keys[user] + } + + return nil } func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id.UserID, deviceID id.DeviceID, resp *mautrix.RespQueryKeys) { @@ -85,19 +93,16 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id } } -func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceToken string, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device) { - // TODO this function should probably return errors +func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device, err error) { req := &mautrix.ReqQueryKeys{ DeviceKeys: mautrix.DeviceKeysRequest{}, Timeout: 10 * 1000, - Token: sinceToken, } log := mach.machOrContextLog(ctx) if !includeUntracked { - var err error users, err = mach.CryptoStore.FilterTrackedUsers(ctx, users) if err != nil { - log.Warn().Err(err).Msg("Failed to filter tracked user list") + return nil, fmt.Errorf("failed to filter tracked user list: %w", err) } } if len(users) == 0 { @@ -109,8 +114,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT log.Debug().Strs("users", strishArray(users)).Msg("Querying keys for users") resp, err := mach.Client.QueryKeys(ctx, req) if err != nil { - log.Error().Err(err).Msg("Failed to query keys") - return + return nil, fmt.Errorf("failed to query keys: %w", err) } for server, err := range resp.Failures { log.Warn().Interface("query_error", err).Str("server", server).Msg("Query keys failure for server") @@ -189,7 +193,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT mach.storeCrossSigningKeys(ctx, resp.SelfSigningKeys, resp.DeviceKeys) mach.storeCrossSigningKeys(ctx, resp.UserSigningKeys, resp.DeviceKeys) - return data + return data, nil } // OnDevicesChanged finds all shared rooms with the given user and invalidates outbound sessions in those rooms. diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 1eee2fec..dcd36dc1 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -229,12 +229,16 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, if len(fetchKeys) > 0 { log.Debug().Strs("users", strishArray(fetchKeys)).Msg("Fetching missing keys") - for userID, devices := range mach.fetchKeys(ctx, fetchKeys, "", true) { - log.Debug(). - Int("device_count", len(devices)). - Str("target_user_id", userID.String()). - Msg("Got device keys for user") - missingSessions[userID] = devices + if keys, err := mach.FetchKeys(ctx, fetchKeys, true); err != nil { + log.Err(err).Strs("users", strishArray(fetchKeys)).Msg("Failed to fetch missing keys") + } else if keys != nil { + for userID, devices := range keys { + log.Debug(). + Int("device_count", len(devices)). + Str("target_user_id", userID.String()). + Msg("Got device keys for user") + missingSessions[userID] = devices + } } } diff --git a/crypto/machine.go b/crypto/machine.go index fc0f1742..fa0c50dc 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -33,6 +33,9 @@ type OlmMachine struct { PlaintextMentions bool + // Never ask the server for keys automatically as a side effect. + DisableKeyFetching bool + SendKeysMinTrust id.TrustState ShareKeysMinTrust id.TrustState @@ -224,7 +227,11 @@ func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string) Str("trace_id", traceID). Interface("changes", dl.Changed). Msg("Device list changes in /sync") - mach.fetchKeys(context.TODO(), dl.Changed, since, false) + if mach.DisableKeyFetching { + mach.CryptoStore.MarkTrackedUsersOutdated(context.TODO(), dl.Changed) + } else { + mach.FetchKeys(context.TODO(), dl.Changed, false) + } mach.Log.Debug().Str("trace_id", traceID).Msg("Finished handling device list changes") } } @@ -413,11 +420,12 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, device, err := mach.CryptoStore.GetDevice(ctx, userID, deviceID) if err != nil { return nil, fmt.Errorf("failed to get sender device from store: %w", err) - } else if device != nil { + } else if device != nil || mach.DisableKeyFetching { return device, nil } - usersToDevices := mach.fetchKeys(ctx, []id.UserID{userID}, "", true) - if devices, ok := usersToDevices[userID]; ok { + if usersToDevices, err := mach.FetchKeys(ctx, []id.UserID{userID}, true); err != nil { + return nil, fmt.Errorf("failed to fetch keys: %w", err) + } else if devices, ok := usersToDevices[userID]; ok { if device, ok = devices[deviceID]; ok { return device, nil } @@ -431,7 +439,7 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, // the given identity key. func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) { deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(ctx, userID, identityKey) - if err != nil || deviceIdentity != nil { + if err != nil || deviceIdentity != nil || mach.DisableKeyFetching { return deviceIdentity, err } mach.machOrContextLog(ctx).Debug(). diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 8c85f6de..99a94f0e 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -665,12 +665,19 @@ func (store *SQLCryptoStore) PutDevice(ctx context.Context, userID id.UserID, de return err } +const trackedUserUpsertQuery = ` +INSERT INTO crypto_tracked_user (user_id, devices_outdated) +VALUES ($1, false) +ON CONFLICT (user_id) DO UPDATE + SET devices_outdated = EXCLUDED.devices_outdated +` + // PutDevices stores the device identity information for the given user ID. func (store *SQLCryptoStore) PutDevices(ctx context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error { return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error { - _, err := store.DB.Exec(ctx, "INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) + _, err := store.DB.Exec(ctx, trackedUserUpsertQuery, userID) if err != nil { - return fmt.Errorf("failed to add user to tracked users list: %w", err) + return fmt.Errorf("failed to upsert user to tracked users list: %w", err) } _, err = store.DB.Exec(ctx, "UPDATE crypto_device SET deleted=true WHERE user_id=$1", userID) @@ -734,6 +741,30 @@ func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id. return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList() } +// MarkTrackedUsersOutdated flags that the device list for given users are outdated. +func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users []id.UserID) error { + return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + // TODO refactor to use a single query + for _, userID := range users { + _, err := store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = $1", userID) + if err != nil { + return fmt.Errorf("failed to update user in the tracked users list: %w", err) + } + } + + return nil + }) +} + +// GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated. +func (store *SQLCryptoStore) GetOutdatedTrackedUsers(ctx context.Context) ([]id.UserID, error) { + rows, err := store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE devices_outdated = TRUE") + if err != nil { + return nil, err + } + return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList() +} + // PutCrossSigningKey stores a cross-signing key of some user along with its usage. func (store *SQLCryptoStore) PutCrossSigningKey(ctx context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error { _, err := store.DB.Exec(ctx, ` diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql index bd8f7942..90d7d31c 100644 --- a/crypto/sql_store_upgrade/00-latest-revision.sql +++ b/crypto/sql_store_upgrade/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v10: Latest revision +-- v0 -> v11: Latest revision CREATE TABLE IF NOT EXISTS crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, @@ -17,7 +17,8 @@ CREATE TABLE IF NOT EXISTS crypto_message_index ( ); CREATE TABLE IF NOT EXISTS crypto_tracked_user ( - user_id TEXT PRIMARY KEY + user_id TEXT PRIMARY KEY, + devices_outdated BOOLEAN NOT NULL DEFAULT FALSE ); CREATE TABLE IF NOT EXISTS crypto_device ( diff --git a/crypto/sql_store_upgrade/11-outdated-devices.sql b/crypto/sql_store_upgrade/11-outdated-devices.sql new file mode 100644 index 00000000..f0f0ba5b --- /dev/null +++ b/crypto/sql_store_upgrade/11-outdated-devices.sql @@ -0,0 +1,2 @@ +-- v11: Add devices_outdated field to crypto_tracked_user +ALTER TABLE crypto_tracked_user ADD COLUMN devices_outdated BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/crypto/store.go b/crypto/store.go index 09393a51..fb3d5b96 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -108,6 +108,10 @@ type Store interface { // FilterTrackedUsers returns a filtered version of the given list that only includes user IDs whose device lists // have been stored with PutDevices. A user is considered tracked even if the PutDevices list was empty. FilterTrackedUsers(context.Context, []id.UserID) ([]id.UserID, error) + // MarkTrackedUsersOutdated flags that the device list for given users are outdated. + MarkTrackedUsersOutdated(context.Context, []id.UserID) error + // GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated. + GetOutdatedTrackedUsers(context.Context) ([]id.UserID, error) // PutCrossSigningKey stores a cross-signing key of some user along with its usage. PutCrossSigningKey(context.Context, id.UserID, id.CrossSigningUsage, id.Ed25519) error @@ -148,6 +152,7 @@ type MemoryStore struct { Devices map[id.UserID]map[id.DeviceID]*id.Device CrossSigningKeys map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey KeySignatures map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string + OutdatedUsers map[id.UserID]struct{} } var _ Store = (*MemoryStore)(nil) @@ -167,6 +172,7 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore { Devices: make(map[id.UserID]map[id.DeviceID]*id.Device), CrossSigningKeys: make(map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey), KeySignatures: make(map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string), + OutdatedUsers: make(map[id.UserID]struct{}), } } @@ -499,6 +505,9 @@ func (gs *MemoryStore) PutDevices(_ context.Context, userID id.UserID, devices m gs.lock.Lock() gs.Devices[userID] = devices err := gs.save() + if err == nil { + delete(gs.OutdatedUsers, userID) + } gs.lock.Unlock() return err } @@ -517,6 +526,27 @@ func (gs *MemoryStore) FilterTrackedUsers(_ context.Context, users []id.UserID) return users[:ptr], nil } +func (gs *MemoryStore) MarkTrackedUsersOutdated(_ context.Context, users []id.UserID) error { + gs.lock.Lock() + for _, userID := range users { + if _, ok := gs.Devices[userID]; ok { + gs.OutdatedUsers[userID] = struct{}{} + } + } + gs.lock.Unlock() + return nil +} + +func (gs *MemoryStore) GetOutdatedTrackedUsers(_ context.Context) ([]id.UserID, error) { + gs.lock.RLock() + users := make([]id.UserID, 0, len(gs.OutdatedUsers)) + for userID := range gs.OutdatedUsers { + users = append(users, userID) + } + gs.lock.RUnlock() + return users, nil +} + func (gs *MemoryStore) PutCrossSigningKey(_ context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error { gs.lock.RLock() userKeys, ok := gs.CrossSigningKeys[userID] diff --git a/crypto/store_test.go b/crypto/store_test.go index 665e3ef9..bbadef28 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -221,6 +221,13 @@ func TestStoreDevices(t *testing.T) { stores := getCryptoStores(t) for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { + outdated, err := store.GetOutdatedTrackedUsers(context.TODO()) + if err != nil { + t.Errorf("Error filtering tracked users: %v", err) + } + if len(outdated) > 0 { + t.Errorf("Got %d outdated tracked users when expected none", len(outdated)) + } deviceMap := make(map[id.DeviceID]*id.Device) for i := 0; i < 17; i++ { iStr := strconv.Itoa(i) @@ -232,9 +239,9 @@ func TestStoreDevices(t *testing.T) { SigningKey: acc.SigningKey(), } } - err := store.PutDevices(context.TODO(), "user1", deviceMap) + err = store.PutDevices(context.TODO(), "user1", deviceMap) if err != nil { - t.Errorf("Error string devices: %v", err) + t.Errorf("Error storing devices: %v", err) } devs, err := store.GetDevices(context.TODO(), "user1") if err != nil { @@ -256,6 +263,36 @@ func TestStoreDevices(t *testing.T) { } else if len(filtered) != 1 || filtered[0] != "user1" { t.Errorf("Expected to get 'user1' from filter, got %v", filtered) } + + outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) + if err != nil { + t.Errorf("Error filtering tracked users: %v", err) + } + if len(outdated) > 0 { + t.Errorf("Got %d outdated tracked users when expected none", len(outdated)) + } + err = store.MarkTrackedUsersOutdated(context.TODO(), []id.UserID{"user0", "user1"}) + if err != nil { + t.Errorf("Error marking tracked users outdated: %v", err) + } + outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) + if err != nil { + t.Errorf("Error filtering tracked users: %v", err) + } + if len(outdated) != 1 || outdated[0] != id.UserID("user1") { + t.Errorf("Got outdated tracked users %v when expected 'user1'", outdated) + } + err = store.PutDevices(context.TODO(), "user1", deviceMap) + if err != nil { + t.Errorf("Error storing devices: %v", err) + } + outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) + if err != nil { + t.Errorf("Error filtering tracked users: %v", err) + } + if len(outdated) > 0 { + t.Errorf("Got outdated tracked users %v when expected none", outdated) + } }) } } From 308e3583b06f03da67da38a5ff4d711cd5fa02d1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 13 Jan 2024 18:56:12 +0200 Subject: [PATCH 0057/1647] Add contexts to event handlers --- CHANGELOG.md | 1 + appservice/eventprocessor.go | 55 +++++------ bridge/bridge.go | 8 +- bridge/crypto.go | 8 +- bridge/matrix.go | 35 ++++--- client.go | 7 +- crypto/cryptohelper/cryptohelper.go | 18 ++-- crypto/machine.go | 32 +++---- event/events.go | 2 + event/eventsource.go | 72 ++++++++++++++ statestore.go | 4 +- sync.go | 140 +++++++--------------------- 12 files changed, 196 insertions(+), 186 deletions(-) create mode 100644 event/eventsource.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 7abbe587..a04fbff4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ functions. * **Breaking change *(everything)*** Added context parameters to all functions (started by [@recht] in [#144]). +* *(client)* Moved `EventSource` to `event.Source`. * *(crypto)* Added experimental pure Go Olm implementation to replace libolm (thanks to [@DerLukas15] in [#106]). * You can use the `goolm` build tag to the new implementation. diff --git a/appservice/eventprocessor.go b/appservice/eventprocessor.go index 376a4fc4..4cd2ce4e 100644 --- a/appservice/eventprocessor.go +++ b/appservice/eventprocessor.go @@ -7,6 +7,7 @@ package appservice import ( + "context" "encoding/json" "runtime/debug" "time" @@ -25,9 +26,9 @@ const ( Sync ) -type EventHandler = func(evt *event.Event) -type OTKHandler = func(otk *mautrix.OTKCount) -type DeviceListHandler = func(lists *mautrix.DeviceLists, since string) +type EventHandler = func(ctx context.Context, evt *event.Event) +type OTKHandler = func(ctx context.Context, otk *mautrix.OTKCount) +type DeviceListHandler = func(ctx context.Context, lists *mautrix.DeviceLists, since string) type EventProcessor struct { ExecMode ExecMode @@ -97,34 +98,34 @@ func (ep *EventProcessor) recoverFunc(data interface{}) { } } -func (ep *EventProcessor) callHandler(handler EventHandler, evt *event.Event) { +func (ep *EventProcessor) callHandler(ctx context.Context, handler EventHandler, evt *event.Event) { defer ep.recoverFunc(evt) - handler(evt) + handler(ctx, evt) } -func (ep *EventProcessor) callOTKHandler(handler OTKHandler, otk *mautrix.OTKCount) { +func (ep *EventProcessor) callOTKHandler(ctx context.Context, handler OTKHandler, otk *mautrix.OTKCount) { defer ep.recoverFunc(otk) - handler(otk) + handler(ctx, otk) } -func (ep *EventProcessor) callDeviceListHandler(handler DeviceListHandler, dl *mautrix.DeviceLists) { +func (ep *EventProcessor) callDeviceListHandler(ctx context.Context, handler DeviceListHandler, dl *mautrix.DeviceLists) { defer ep.recoverFunc(dl) - handler(dl, "") + handler(ctx, dl, "") } -func (ep *EventProcessor) DispatchOTK(otk *mautrix.OTKCount) { +func (ep *EventProcessor) DispatchOTK(ctx context.Context, otk *mautrix.OTKCount) { for _, handler := range ep.otkHandlers { - go ep.callOTKHandler(handler, otk) + go ep.callOTKHandler(ctx, handler, otk) } } -func (ep *EventProcessor) DispatchDeviceList(dl *mautrix.DeviceLists) { +func (ep *EventProcessor) DispatchDeviceList(ctx context.Context, dl *mautrix.DeviceLists) { for _, handler := range ep.deviceListHandlers { - go ep.callDeviceListHandler(handler, dl) + go ep.callDeviceListHandler(ctx, handler, dl) } } -func (ep *EventProcessor) Dispatch(evt *event.Event) { +func (ep *EventProcessor) Dispatch(ctx context.Context, evt *event.Event) { handlers, ok := ep.handlers[evt.Type] if !ok { return @@ -132,25 +133,25 @@ func (ep *EventProcessor) Dispatch(evt *event.Event) { switch ep.ExecMode { case AsyncHandlers: for _, handler := range handlers { - go ep.callHandler(handler, evt) + go ep.callHandler(ctx, handler, evt) } case AsyncLoop: go func() { for _, handler := range handlers { - ep.callHandler(handler, evt) + ep.callHandler(ctx, handler, evt) } }() case Sync: if ep.ExecSyncWarnTime == 0 && ep.ExecSyncTimeout == 0 { for _, handler := range handlers { - ep.callHandler(handler, evt) + ep.callHandler(ctx, handler, evt) } return } doneChan := make(chan struct{}) go func() { for _, handler := range handlers { - ep.callHandler(handler, evt) + ep.callHandler(ctx, handler, evt) } close(doneChan) }() @@ -172,35 +173,35 @@ func (ep *EventProcessor) Dispatch(evt *event.Event) { } } } -func (ep *EventProcessor) startEvents() { +func (ep *EventProcessor) startEvents(ctx context.Context) { for { select { case evt := <-ep.as.Events: - ep.Dispatch(evt) + ep.Dispatch(ctx, evt) case <-ep.stop: return } } } -func (ep *EventProcessor) startEncryption() { +func (ep *EventProcessor) startEncryption(ctx context.Context) { for { select { case evt := <-ep.as.ToDeviceEvents: - ep.Dispatch(evt) + ep.Dispatch(ctx, evt) case otk := <-ep.as.OTKCounts: - ep.DispatchOTK(otk) + ep.DispatchOTK(ctx, otk) case dl := <-ep.as.DeviceLists: - ep.DispatchDeviceList(dl) + ep.DispatchDeviceList(ctx, dl) case <-ep.stop: return } } } -func (ep *EventProcessor) Start() { - go ep.startEvents() - go ep.startEncryption() +func (ep *EventProcessor) Start(ctx context.Context) { + go ep.startEvents(ctx) + go ep.startEncryption(ctx) } func (ep *EventProcessor) Stop() { diff --git a/bridge/bridge.go b/bridge/bridge.go index 6ad19720..7d5333ce 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -214,7 +214,7 @@ type Bridge struct { } type Crypto interface { - HandleMemberEvent(*event.Event) + HandleMemberEvent(context.Context, *event.Event) Decrypt(context.Context, *event.Event) (*event.Event, error) Encrypt(context.Context, id.RoomID, event.Type, *event.Content) error WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool @@ -321,7 +321,7 @@ func (br *Bridge) ensureConnection(ctx context.Context) { if errors.Is(err, mautrix.MUnknownToken) { br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?") } else if errors.Is(err, mautrix.MExclusive) { - br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was accepted, but the /register request was not. Are the homeserver domain and username template in the config correct, and do they match the values in the registration?") + br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was accepted, but the /register request was not. Are the homeserver domain, bot username and username template in the config correct, and do they match the values in the registration?") } else { br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("/whoami request failed with unknown error") } @@ -674,7 +674,7 @@ func (br *Bridge) start() { } br.ZLog.Debug().Msg("Checking connection to homeserver") - ctx := context.Background() + ctx := br.ZLog.WithContext(context.Background()) br.ensureConnection(ctx) go br.fetchMediaConfig(ctx) @@ -687,7 +687,7 @@ func (br *Bridge) start() { } br.ZLog.Debug().Msg("Starting event processor") - br.EventProcessor.Start() + br.EventProcessor.Start(ctx) go br.UpdateBotProfile(ctx) if br.Crypto != nil { diff --git a/bridge/crypto.go b/bridge/crypto.go index 872bf8a6..f0b90056 100644 --- a/bridge/crypto.go +++ b/bridge/crypto.go @@ -425,10 +425,10 @@ func (helper *CryptoHelper) ResetSession(ctx context.Context, roomID id.RoomID) } } -func (helper *CryptoHelper) HandleMemberEvent(evt *event.Event) { +func (helper *CryptoHelper) HandleMemberEvent(ctx context.Context, evt *event.Event) { helper.lock.RLock() defer helper.lock.RUnlock() - helper.mach.HandleMemberEvent(0, evt) + helper.mach.HandleMemberEvent(ctx, evt) } // ShareKeys uploads the given number of one-time-keys to the server. @@ -440,7 +440,7 @@ type cryptoSyncer struct { *crypto.OlmMachine } -func (syncer *cryptoSyncer) ProcessResponse(resp *mautrix.RespSync, since string) error { +func (syncer *cryptoSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { done := make(chan struct{}) go func() { defer func() { @@ -454,7 +454,7 @@ func (syncer *cryptoSyncer) ProcessResponse(resp *mautrix.RespSync, since string done <- struct{}{} }() syncer.Log.Trace().Str("since", since).Msg("Starting sync response handling") - syncer.ProcessSyncResponse(resp, since) + syncer.ProcessSyncResponse(ctx, resp, since) syncer.Log.Trace().Str("since", since).Msg("Successfully handled sync response") }() select { diff --git a/bridge/matrix.go b/bridge/matrix.go index 00994dd2..5aa457fa 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -68,13 +68,13 @@ func NewMatrixHandler(br *Bridge) *MatrixHandler { return handler } -func (mx *MatrixHandler) sendBridgeCheckpoint(evt *event.Event) { +func (mx *MatrixHandler) sendBridgeCheckpoint(_ context.Context, evt *event.Event) { if !evt.Mautrix.CheckpointSent { go mx.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepBridge, 0) } } -func (mx *MatrixHandler) HandleEncryption(evt *event.Event) { +func (mx *MatrixHandler) HandleEncryption(ctx context.Context, evt *event.Event) { defer mx.TrackEventDuration(evt.Type)() if evt.Content.AsEncryption().Algorithm != id.AlgorithmMegolmV1 { return @@ -87,7 +87,7 @@ func (mx *MatrixHandler) HandleEncryption(evt *event.Event) { Msg("Encryption was enabled in room") portal.MarkEncrypted() if portal.IsPrivateChat() { - err := mx.as.BotIntent().EnsureJoined(context.TODO(), evt.RoomID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client}) + err := mx.as.BotIntent().EnsureJoined(ctx, evt.RoomID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client}) if err != nil { mx.log.Err(err). Str("room_id", evt.RoomID.String()). @@ -232,15 +232,14 @@ func (mx *MatrixHandler) HandleGhostInvite(ctx context.Context, evt *event.Event } } -func (mx *MatrixHandler) HandleMembership(evt *event.Event) { +func (mx *MatrixHandler) HandleMembership(ctx context.Context, evt *event.Event) { if evt.Sender == mx.bridge.Bot.UserID || mx.bridge.Child.IsGhost(evt.Sender) { return } defer mx.TrackEventDuration(evt.Type)() - ctx := context.TODO() if mx.bridge.Crypto != nil { - mx.bridge.Crypto.HandleMemberEvent(evt) + mx.bridge.Crypto.HandleMemberEvent(ctx, evt) } log := mx.log.With(). @@ -300,7 +299,7 @@ func (mx *MatrixHandler) HandleMembership(evt *event.Event) { // TODO kicking/inviting non-ghost users users } -func (mx *MatrixHandler) HandleRoomMetadata(evt *event.Event) { +func (mx *MatrixHandler) HandleRoomMetadata(ctx context.Context, evt *event.Event) { defer mx.TrackEventDuration(evt.Type)() if mx.shouldIgnoreEvent(evt) { return @@ -469,20 +468,20 @@ func (mx *MatrixHandler) postDecrypt(ctx context.Context, original, decrypted *e mx.bridge.SendMessageSuccessCheckpoint(decrypted, status.MsgStepDecrypted, retryCount) decrypted.Mautrix.CheckpointSent = true decrypted.Mautrix.DecryptionDuration = duration - mx.bridge.EventProcessor.Dispatch(decrypted) + decrypted.Mautrix.EventSource |= event.SourceDecrypted + mx.bridge.EventProcessor.Dispatch(ctx, decrypted) if errorEventID != "" { _, _ = mx.bridge.Bot.RedactEvent(ctx, decrypted.RoomID, errorEventID) } } -func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) { +func (mx *MatrixHandler) HandleEncrypted(ctx context.Context, evt *event.Event) { defer mx.TrackEventDuration(evt.Type)() if mx.shouldIgnoreEvent(evt) { return } content := evt.Content.AsEncrypted() - ctx := context.TODO() - log := mx.log.With(). + log := zerolog.Ctx(ctx).With(). Str("event_id", evt.ID.String()). Str("session_id", content.SessionID.String()). Logger() @@ -546,14 +545,14 @@ func (mx *MatrixHandler) waitLongerForSession(ctx context.Context, evt *event.Ev mx.postDecrypt(ctx, evt, decrypted, 2, errorEventID, time.Since(decryptionStart)) } -func (mx *MatrixHandler) HandleMessage(evt *event.Event) { +func (mx *MatrixHandler) HandleMessage(ctx context.Context, evt *event.Event) { defer mx.TrackEventDuration(evt.Type)() - log := mx.log.With(). + log := zerolog.Ctx(ctx).With(). Str("event_id", evt.ID.String()). Str("room_id", evt.RoomID.String()). Str("sender", evt.Sender.String()). Logger() - ctx := log.WithContext(context.TODO()) + ctx = log.WithContext(ctx) if mx.shouldIgnoreEvent(evt) { return } else if !evt.Mautrix.WasEncrypted && mx.bridge.Config.Bridge.GetEncryptionConfig().Require { @@ -604,7 +603,7 @@ func (mx *MatrixHandler) HandleMessage(evt *event.Event) { } } -func (mx *MatrixHandler) HandleReaction(evt *event.Event) { +func (mx *MatrixHandler) HandleReaction(_ context.Context, evt *event.Event) { defer mx.TrackEventDuration(evt.Type)() if mx.shouldIgnoreEvent(evt) { return @@ -623,7 +622,7 @@ func (mx *MatrixHandler) HandleReaction(evt *event.Event) { } } -func (mx *MatrixHandler) HandleRedaction(evt *event.Event) { +func (mx *MatrixHandler) HandleRedaction(_ context.Context, evt *event.Event) { defer mx.TrackEventDuration(evt.Type)() if mx.shouldIgnoreEvent(evt) { return @@ -642,7 +641,7 @@ func (mx *MatrixHandler) HandleRedaction(evt *event.Event) { } } -func (mx *MatrixHandler) HandleReceipt(evt *event.Event) { +func (mx *MatrixHandler) HandleReceipt(_ context.Context, evt *event.Event) { portal := mx.bridge.Child.GetIPortal(evt.RoomID) if portal == nil { return @@ -676,7 +675,7 @@ func (mx *MatrixHandler) HandleReceipt(evt *event.Event) { } } -func (mx *MatrixHandler) HandleTyping(evt *event.Event) { +func (mx *MatrixHandler) HandleTyping(_ context.Context, evt *event.Event) { portal := mx.bridge.Child.GetIPortal(evt.RoomID) if portal == nil { return diff --git a/client.go b/client.go index d1a6d8f0..dfef7231 100644 --- a/client.go +++ b/client.go @@ -236,8 +236,11 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { // Save the token now *before* processing it. This means it's possible // to not process some events, but it means that we won't get constantly stuck processing // a malformed/buggy event which keeps making us panic. - cli.Store.SaveNextBatch(ctx, cli.UserID, resSync.NextBatch) - if err = cli.Syncer.ProcessResponse(resSync, nextBatch); err != nil { + err = cli.Store.SaveNextBatch(ctx, cli.UserID, resSync.NextBatch) + if err != nil { + return err + } + if err = cli.Syncer.ProcessResponse(ctx, resSync, nextBatch); err != nil { return err } diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index eb7d7a77..a0065012 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -245,17 +245,18 @@ var NoSessionFound = crypto.NoSessionFound const initialSessionWaitTimeout = 3 * time.Second const extendedSessionWaitTimeout = 22 * time.Second -func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event.Event) { +func (helper *CryptoHelper) HandleEncrypted(ctx context.Context, evt *event.Event) { if helper == nil { return } content := evt.Content.AsEncrypted() + // TODO use context log instead of helper? log := helper.log.With(). Str("event_id", evt.ID.String()). Str("session_id", content.SessionID.String()). Logger() log.Debug().Msg("Decrypting received event") - ctx := log.WithContext(context.TODO()) + ctx = log.WithContext(ctx) decrypted, err := helper.Decrypt(ctx, evt) if errors.Is(err, NoSessionFound) { @@ -266,7 +267,7 @@ func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event. log.Debug().Msg("Got keys after waiting, trying to decrypt event again") decrypted, err = helper.Decrypt(ctx, evt) } else { - go helper.waitLongerForSession(ctx, log, src, evt) + go helper.waitLongerForSession(ctx, log, evt) return } } @@ -275,11 +276,12 @@ func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event. helper.DecryptErrorCallback(evt, err) return } - helper.postDecrypt(src, decrypted) + helper.postDecrypt(ctx, decrypted) } -func (helper *CryptoHelper) postDecrypt(src mautrix.EventSource, decrypted *event.Event) { - helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(src|mautrix.EventSourceDecrypted, decrypted) +func (helper *CryptoHelper) postDecrypt(ctx context.Context, decrypted *event.Event) { + decrypted.Mautrix.EventSource |= event.SourceDecrypted + helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(ctx, decrypted) } func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { @@ -309,7 +311,7 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID } } -func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, src mautrix.EventSource, evt *event.Event) { +func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, evt *event.Event) { content := evt.Content.AsEncrypted() log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...") @@ -329,7 +331,7 @@ func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolo return } - helper.postDecrypt(src, decrypted) + helper.postDecrypt(ctx, decrypted) } func (helper *CryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { diff --git a/crypto/machine.go b/crypto/machine.go index fa0c50dc..9892536a 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -197,9 +197,9 @@ func (mach *OlmMachine) OwnIdentity() *id.Device { } type asEventProcessor interface { - On(evtType event.Type, handler func(evt *event.Event)) - OnOTK(func(otk *mautrix.OTKCount)) - OnDeviceList(func(lists *mautrix.DeviceLists, since string)) + On(evtType event.Type, handler func(ctx context.Context, evt *event.Event)) + OnOTK(func(ctx context.Context, otk *mautrix.OTKCount)) + OnDeviceList(func(ctx context.Context, lists *mautrix.DeviceLists, since string)) } func (mach *OlmMachine) AddAppserviceListener(ep asEventProcessor) { @@ -220,7 +220,7 @@ func (mach *OlmMachine) AddAppserviceListener(ep asEventProcessor) { mach.Log.Debug().Msg("Added listeners for encryption data coming from appservice transactions") } -func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string) { +func (mach *OlmMachine) HandleDeviceLists(ctx context.Context, dl *mautrix.DeviceLists, since string) { if len(dl.Changed) > 0 { traceID := time.Now().Format("15:04:05.000000") mach.Log.Debug(). @@ -228,15 +228,15 @@ func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string) Interface("changes", dl.Changed). Msg("Device list changes in /sync") if mach.DisableKeyFetching { - mach.CryptoStore.MarkTrackedUsersOutdated(context.TODO(), dl.Changed) + mach.CryptoStore.MarkTrackedUsersOutdated(ctx, dl.Changed) } else { - mach.FetchKeys(context.TODO(), dl.Changed, false) + mach.FetchKeys(ctx, dl.Changed, false) } mach.Log.Debug().Str("trace_id", traceID).Msg("Finished handling device list changes") } } -func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) { +func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.OTKCount) { if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) { // TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions mach.Log.Warn(). @@ -250,7 +250,7 @@ func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) { if otkCount.SignedCurve25519 < int(minCount) { traceID := time.Now().Format("15:04:05.000000") log := mach.Log.With().Str("trace_id", traceID).Logger() - ctx := log.WithContext(context.TODO()) + ctx = log.WithContext(ctx) log.Debug(). Int("keys_left", otkCount.Curve25519). Msg("Sync response said we have less than 50 signed curve25519 keys left, sharing new ones...") @@ -268,8 +268,8 @@ func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) { // This can be easily registered into a mautrix client using .OnSync(): // // client.Syncer.(mautrix.ExtensibleSyncer).OnSync(c.crypto.ProcessSyncResponse) -func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string) bool { - mach.HandleDeviceLists(&resp.DeviceLists, since) +func (mach *OlmMachine) ProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) bool { + mach.HandleDeviceLists(ctx, &resp.DeviceLists, since) for _, evt := range resp.ToDevice.Events { evt.Type.Class = event.ToDeviceEventType @@ -278,10 +278,10 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string mach.Log.Warn().Str("event_type", evt.Type.Type).Err(err).Msg("Failed to parse to-device event") continue } - mach.HandleToDeviceEvent(evt) + mach.HandleToDeviceEvent(ctx, evt) } - mach.HandleOTKCounts(&resp.DeviceOTKCount) + mach.HandleOTKCounts(ctx, &resp.DeviceOTKCount) return true } @@ -290,8 +290,7 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string // Currently this is not automatically called, so you must add a listener yourself: // // client.Syncer.(mautrix.ExtensibleSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent) -func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Event) { - ctx := context.TODO() +func (mach *OlmMachine) HandleMemberEvent(ctx context.Context, evt *event.Event) { if isEncrypted, err := mach.StateStore.IsEncrypted(ctx, evt.RoomID); err != nil { mach.machOrContextLog(ctx).Err(err).Stringer("room_id", evt.RoomID). Msg("Failed to check if room is encrypted to handle member event") @@ -331,7 +330,7 @@ func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Even // HandleToDeviceEvent handles a single to-device event. This is automatically called by ProcessSyncResponse, so you // don't need to add any custom handlers if you use that method. -func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) { +func (mach *OlmMachine) HandleToDeviceEvent(ctx context.Context, evt *event.Event) { if len(evt.ToUserID) > 0 && (evt.ToUserID != mach.Client.UserID || evt.ToDeviceID != mach.Client.DeviceID) { // TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions mach.Log.Debug(). @@ -341,12 +340,13 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) { return } traceID := time.Now().Format("15:04:05.000000") + // TODO use context log? log := mach.Log.With(). Str("trace_id", traceID). Str("sender", evt.Sender.String()). Str("type", evt.Type.Type). Logger() - ctx := log.WithContext(context.TODO()) + ctx = log.WithContext(ctx) if evt.Type != event.ToDeviceEncrypted { log.Debug().Msg("Starting handling to-device event") } diff --git a/event/events.go b/event/events.go index 57611221..f7b4d4d6 100644 --- a/event/events.go +++ b/event/events.go @@ -105,6 +105,8 @@ func (evt *Event) MarshalJSON() ([]byte, error) { } type MautrixInfo struct { + EventSource Source + TrustState id.TrustState ForwardedKeys bool WasEncrypted bool diff --git a/event/eventsource.go b/event/eventsource.go new file mode 100644 index 00000000..86c1cebe --- /dev/null +++ b/event/eventsource.go @@ -0,0 +1,72 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package event + +import ( + "fmt" +) + +// Source represents the part of the sync response that an event came from. +type Source int + +const ( + SourcePresence Source = 1 << iota + SourceJoin + SourceInvite + SourceLeave + SourceAccountData + SourceTimeline + SourceState + SourceEphemeral + SourceToDevice + SourceDecrypted +) + +const primaryTypes = SourcePresence | SourceAccountData | SourceToDevice | SourceTimeline | SourceState +const roomSections = SourceJoin | SourceInvite | SourceLeave +const roomableTypes = SourceAccountData | SourceTimeline | SourceState +const encryptableTypes = roomableTypes | SourceToDevice + +func (es Source) String() string { + var typeName string + switch es & primaryTypes { + case SourcePresence: + typeName = "presence" + case SourceAccountData: + typeName = "account data" + case SourceToDevice: + typeName = "to-device" + case SourceTimeline: + typeName = "timeline" + case SourceState: + typeName = "state" + default: + return fmt.Sprintf("unknown (%d)", es) + } + if es&roomableTypes != 0 { + switch es & roomSections { + case SourceJoin: + typeName = "joined room " + typeName + case SourceInvite: + typeName = "invited room " + typeName + case SourceLeave: + typeName = "left room " + typeName + default: + return fmt.Sprintf("unknown (%s+%d)", typeName, es) + } + es &^= roomSections + } + if es&encryptableTypes != 0 && es&SourceDecrypted != 0 { + typeName += " (decrypted)" + es &^= SourceDecrypted + } + es &^= primaryTypes + if es != 0 { + return fmt.Sprintf("unknown (%s+%d)", typeName, es) + } + return typeName +} diff --git a/statestore.go b/statestore.go index 63a5bfb4..8fe5f8b3 100644 --- a/statestore.go +++ b/statestore.go @@ -67,8 +67,8 @@ func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) { // client.Syncer.(mautrix.ExtensibleSyncer).OnEvent(client.StateStoreSyncHandler) // // DefaultSyncer.ParseEventContent must also be true for this to work (which it is by default). -func (cli *Client) StateStoreSyncHandler(_ EventSource, evt *event.Event) { - UpdateStateStore(cli.Log.WithContext(context.TODO()), cli.StateStore, evt) +func (cli *Client) StateStoreSyncHandler(ctx context.Context, evt *event.Event) { + UpdateStateStore(ctx, cli.StateStore, evt) } type MemoryStateStore struct { diff --git a/sync.go b/sync.go index f05e9b5f..d4208404 100644 --- a/sync.go +++ b/sync.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -7,6 +7,7 @@ package mautrix import ( + "context" "errors" "fmt" "runtime/debug" @@ -16,78 +17,17 @@ import ( "maunium.net/go/mautrix/id" ) -// EventSource represents the part of the sync response that an event came from. -type EventSource int - -const ( - EventSourcePresence EventSource = 1 << iota - EventSourceJoin - EventSourceInvite - EventSourceLeave - EventSourceAccountData - EventSourceTimeline - EventSourceState - EventSourceEphemeral - EventSourceToDevice - EventSourceDecrypted -) - -const primaryTypes = EventSourcePresence | EventSourceAccountData | EventSourceToDevice | EventSourceTimeline | EventSourceState -const roomSections = EventSourceJoin | EventSourceInvite | EventSourceLeave -const roomableTypes = EventSourceAccountData | EventSourceTimeline | EventSourceState -const encryptableTypes = roomableTypes | EventSourceToDevice - -func (es EventSource) String() string { - var typeName string - switch es & primaryTypes { - case EventSourcePresence: - typeName = "presence" - case EventSourceAccountData: - typeName = "account data" - case EventSourceToDevice: - typeName = "to-device" - case EventSourceTimeline: - typeName = "timeline" - case EventSourceState: - typeName = "state" - default: - return fmt.Sprintf("unknown (%d)", es) - } - if es&roomableTypes != 0 { - switch es & roomSections { - case EventSourceJoin: - typeName = "joined room " + typeName - case EventSourceInvite: - typeName = "invited room " + typeName - case EventSourceLeave: - typeName = "left room " + typeName - default: - return fmt.Sprintf("unknown (%s+%d)", typeName, es) - } - es &^= roomSections - } - if es&encryptableTypes != 0 && es&EventSourceDecrypted != 0 { - typeName += " (decrypted)" - es &^= EventSourceDecrypted - } - es &^= primaryTypes - if es != 0 { - return fmt.Sprintf("unknown (%s+%d)", typeName, es) - } - return typeName -} - // EventHandler handles a single event from a sync response. -type EventHandler func(source EventSource, evt *event.Event) +type EventHandler func(ctx context.Context, evt *event.Event) // SyncHandler handles a whole sync response. If the return value is false, handling will be stopped completely. -type SyncHandler func(resp *RespSync, since string) bool +type SyncHandler func(ctx context.Context, resp *RespSync, since string) bool // Syncer is an interface that must be satisfied in order to do /sync requests on a client. type Syncer interface { // ProcessResponse processes the /sync response. The since parameter is the since= value that was used to produce the response. // This is useful for detecting the very first sync (since=""). If an error is return, Syncing will be stopped permanently. - ProcessResponse(resp *RespSync, since string) error + ProcessResponse(ctx context.Context, resp *RespSync, since string) error // OnFailedSync returns either the time to wait before retrying or an error to stop syncing permanently. OnFailedSync(res *RespSync, err error) (time.Duration, error) // GetFilterJSON for the given user ID. NOT the filter ID. @@ -101,7 +41,7 @@ type ExtensibleSyncer interface { } type DispatchableSyncer interface { - Dispatch(source EventSource, evt *event.Event) + Dispatch(ctx context.Context, evt *event.Event) } // DefaultSyncer is the default syncing implementation. You can either write your own syncer, or selectively @@ -144,7 +84,7 @@ func NewDefaultSyncer() *DefaultSyncer { // ProcessResponse processes the /sync response in a way suitable for bots. "Suitable for bots" means a stream of // unrepeating events. Returns a fatal error if a listener panics. -func (s *DefaultSyncer) ProcessResponse(res *RespSync, since string) (err error) { +func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, since string) (err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("ProcessResponse panicked! since=%s panic=%s\n%s", since, r, debug.Stack()) @@ -152,38 +92,38 @@ func (s *DefaultSyncer) ProcessResponse(res *RespSync, since string) (err error) }() for _, listener := range s.syncListeners { - if !listener(res, since) { + if !listener(ctx, res, since) { return } } - s.processSyncEvents("", res.ToDevice.Events, EventSourceToDevice) - s.processSyncEvents("", res.Presence.Events, EventSourcePresence) - s.processSyncEvents("", res.AccountData.Events, EventSourceAccountData) + s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice) + s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence) + s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData) for roomID, roomData := range res.Rooms.Join { - s.processSyncEvents(roomID, roomData.State.Events, EventSourceJoin|EventSourceState) - s.processSyncEvents(roomID, roomData.Timeline.Events, EventSourceJoin|EventSourceTimeline) - s.processSyncEvents(roomID, roomData.Ephemeral.Events, EventSourceJoin|EventSourceEphemeral) - s.processSyncEvents(roomID, roomData.AccountData.Events, EventSourceJoin|EventSourceAccountData) + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState) + s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline) + s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral) + s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData) } for roomID, roomData := range res.Rooms.Invite { - s.processSyncEvents(roomID, roomData.State.Events, EventSourceInvite|EventSourceState) + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState) } for roomID, roomData := range res.Rooms.Leave { - s.processSyncEvents(roomID, roomData.State.Events, EventSourceLeave|EventSourceState) - s.processSyncEvents(roomID, roomData.Timeline.Events, EventSourceLeave|EventSourceTimeline) + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState) + s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline) } return } -func (s *DefaultSyncer) processSyncEvents(roomID id.RoomID, events []*event.Event, source EventSource) { +func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source) { for _, evt := range events { - s.processSyncEvent(roomID, evt, source) + s.processSyncEvent(ctx, roomID, evt, source) } } -func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, source EventSource) { +func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source) { evt.RoomID = roomID // Ensure the type class is correct. It's safe to mutate the class since the event type is not a pointer. @@ -191,11 +131,11 @@ func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, sou switch { case evt.StateKey != nil: evt.Type.Class = event.StateEventType - case source == EventSourcePresence, source&EventSourceEphemeral != 0: + case source == event.SourcePresence, source&event.SourceEphemeral != 0: evt.Type.Class = event.EphemeralEventType - case source&EventSourceAccountData != 0: + case source&event.SourceAccountData != 0: evt.Type.Class = event.AccountDataEventType - case source == EventSourceToDevice: + case source == event.SourceToDevice: evt.Type.Class = event.ToDeviceEventType default: evt.Type.Class = event.MessageEventType @@ -208,17 +148,18 @@ func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, sou } } - s.Dispatch(source, evt) + evt.Mautrix.EventSource = source + s.Dispatch(ctx, evt) } -func (s *DefaultSyncer) Dispatch(source EventSource, evt *event.Event) { +func (s *DefaultSyncer) Dispatch(ctx context.Context, evt *event.Event) { for _, fn := range s.globalListeners { - fn(source, evt) + fn(ctx, evt) } listeners, exists := s.listeners[evt.Type] if exists { for _, fn := range listeners { - fn(source, evt) + fn(ctx, evt) } } } @@ -266,31 +207,18 @@ func (s *DefaultSyncer) GetFilterJSON(userID id.UserID) *Filter { return s.FilterJSON } -// OldEventIgnorer is a utility struct for bots to ignore events from before the bot joined the room. -// -// Deprecated: Use Client.DontProcessOldEvents instead. -type OldEventIgnorer struct { - UserID id.UserID -} - -func (oei *OldEventIgnorer) Register(syncer ExtensibleSyncer) { - syncer.OnSync(oei.DontProcessOldEvents) -} - -func (oei *OldEventIgnorer) DontProcessOldEvents(resp *RespSync, since string) bool { - return dontProcessOldEvents(oei.UserID, resp, since) -} - // DontProcessOldEvents is a sync handler that removes rooms that the user just joined. // It's meant for bots to ignore events from before the bot joined the room. // // To use it, register it with your Syncer, e.g.: // // cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.DontProcessOldEvents) -func (cli *Client) DontProcessOldEvents(resp *RespSync, since string) bool { +func (cli *Client) DontProcessOldEvents(_ context.Context, resp *RespSync, since string) bool { return dontProcessOldEvents(cli.UserID, resp, since) } +var _ SyncHandler = (*Client)(nil).DontProcessOldEvents + func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool { if since == "" { return false @@ -327,7 +255,7 @@ func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool { // To use it, register it with your Syncer, e.g.: // // cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.MoveInviteState) -func (cli *Client) MoveInviteState(resp *RespSync, _ string) bool { +func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string) bool { for _, meta := range resp.Rooms.Invite { var inviteState []event.StrippedState var inviteEvt *event.Event @@ -352,3 +280,5 @@ func (cli *Client) MoveInviteState(resp *RespSync, _ string) bool { } return true } + +var _ SyncHandler = (*Client)(nil).MoveInviteState From 0d04e346fe45b78da515978bd6d2fbb20506854b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 13 Jan 2024 18:57:52 +0200 Subject: [PATCH 0058/1647] Update example --- example/go.mod | 14 +++++++------- example/go.sum | 30 +++++++++++++++--------------- example/main.go | 8 ++++---- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/example/go.mod b/example/go.mod index 01790dc4..f78b3fa0 100644 --- a/example/go.mod +++ b/example/go.mod @@ -4,9 +4,9 @@ go 1.20 require ( github.com/chzyer/readline v1.5.1 - github.com/mattn/go-sqlite3 v1.14.18 + github.com/mattn/go-sqlite3 v1.14.19 github.com/rs/zerolog v1.31.0 - maunium.net/go/mautrix v0.16.3-0.20231215142331-753cdb2e1cb0 + maunium.net/go/mautrix v0.16.3-0.20240113165612-308e3583b06f ) require ( @@ -16,11 +16,11 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/sjson v1.2.5 // indirect - go.mau.fi/util v0.2.1 // indirect - golang.org/x/crypto v0.15.0 // indirect - golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect - golang.org/x/net v0.18.0 // indirect - golang.org/x/sys v0.14.0 // indirect + go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894 // indirect + golang.org/x/crypto v0.17.0 // indirect + golang.org/x/exp v0.0.0-20231226003508-02704c960a9b // indirect + golang.org/x/net v0.19.0 // indirect + golang.org/x/sys v0.15.0 // indirect maunium.net/go/maulogger/v2 v2.4.1 // indirect ) diff --git a/example/go.sum b/example/go.sum index c292f262..0a3092ed 100644 --- a/example/go.sum +++ b/example/go.sum @@ -1,4 +1,4 @@ -github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= +github.com/DATA-DOG/go-sqlmock v1.5.1 h1:FK6RCIUSfmbnI/imIICmboyQBkOckutaa6R5YYlLZyo= github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= @@ -13,8 +13,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI= -github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI= +github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= @@ -30,22 +30,22 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -go.mau.fi/util v0.2.1 h1:eazulhFE/UmjOFtPrGg6zkF5YfAyiDzQb8ihLMbsPWw= -go.mau.fi/util v0.2.1/go.mod h1:MjlzCQEMzJ+G8RsPawHzpLB8rwTo3aPIjG5FzBvQT/c= -golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= -golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= -golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ= -golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= -golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg= -golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= +go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894 h1:CuR5LDSxBQLETorfwJ9vRtySeLHjMvJ7//lnCMw7Dy8= +go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= +golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= +golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/exp v0.0.0-20231226003508-02704c960a9b h1:kLiC65FbiHWFAOu+lxwNPujcsl8VYyTYYEZnsOO1WK4= +golang.org/x/exp v0.0.0-20231226003508-02704c960a9b/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= -golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= -maunium.net/go/mautrix v0.16.3-0.20231215142331-753cdb2e1cb0 h1:2ZWtBcTScQfMwpcoGeY4mLYXC6OmYN/4Qh2yhBiVNV4= -maunium.net/go/mautrix v0.16.3-0.20231215142331-753cdb2e1cb0/go.mod h1:YL4l4rZB46/vj/ifRMEjcibbvHjgxHftOF1SgmruLu4= +maunium.net/go/mautrix v0.16.3-0.20240113165612-308e3583b06f h1:6uzyAxrjqGv2SbTAnIK3LI6mo1fILWOga6uNyId+6yM= +maunium.net/go/mautrix v0.16.3-0.20240113165612-308e3583b06f/go.mod h1:eRQu5ED1ODsP+xq1K9l1AOD+O9FMkAhodd/RVc3Bkqg= diff --git a/example/main.go b/example/main.go index aaa94a3c..f799409c 100644 --- a/example/main.go +++ b/example/main.go @@ -62,7 +62,7 @@ func main() { var lastRoomID id.RoomID syncer := client.Syncer.(*mautrix.DefaultSyncer) - syncer.OnEventType(event.EventMessage, func(source mautrix.EventSource, evt *event.Event) { + syncer.OnEventType(event.EventMessage, func(ctx context.Context, evt *event.Event) { lastRoomID = evt.RoomID rl.SetPrompt(fmt.Sprintf("%s> ", lastRoomID)) log.Info(). @@ -72,9 +72,9 @@ func main() { Str("body", evt.Content.AsMessage().Body). Msg("Received message") }) - syncer.OnEventType(event.StateMember, func(source mautrix.EventSource, evt *event.Event) { + syncer.OnEventType(event.StateMember, func(ctx context.Context, evt *event.Event) { if evt.GetStateKey() == client.UserID.String() && evt.Content.AsMember().Membership == event.MembershipInvite { - _, err := client.JoinRoomByID(context.TODO(), evt.RoomID) + _, err := client.JoinRoomByID(ctx, evt.RoomID) if err == nil { lastRoomID = evt.RoomID rl.SetPrompt(fmt.Sprintf("%s> ", lastRoomID)) @@ -108,7 +108,7 @@ func main() { } // If you want to use multiple clients with the same DB, you should set a distinct database account ID for each one. //cryptoHelper.DBAccountID = "" - err = cryptoHelper.Init() + err = cryptoHelper.Init(context.TODO()) if err != nil { panic(err) } From d7c1cf6b64bf890db9ebc8211da6411c1308f29f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 13 Jan 2024 19:01:52 +0200 Subject: [PATCH 0059/1647] Update dependencies --- go.mod | 10 +++++----- go.sum | 20 ++++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index 8484acc3..64f5d49e 100644 --- a/go.mod +++ b/go.mod @@ -12,11 +12,11 @@ require ( github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.6.0 - go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894 + go.mau.fi/util v0.2.2-0.20240112154312-b89d6e13ae53 go.mau.fi/zeroconfig v0.1.2 - golang.org/x/crypto v0.17.0 - golang.org/x/exp v0.0.0-20231226003508-02704c960a9b - golang.org/x/net v0.19.0 + golang.org/x/crypto v0.18.0 + golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3 + golang.org/x/net v0.20.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 maunium.net/go/maulogger/v2 v2.4.1 @@ -30,6 +30,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.15.0 // indirect + golang.org/x/sys v0.16.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index d923c7b1..765b5490 100644 --- a/go.sum +++ b/go.sum @@ -36,21 +36,21 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68= github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894 h1:CuR5LDSxBQLETorfwJ9vRtySeLHjMvJ7//lnCMw7Dy8= -go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= +go.mau.fi/util v0.2.2-0.20240112154312-b89d6e13ae53 h1:1RbC2484wnz5paT/254/Hj+2HOKb+2cqpxaUbsV08jc= +go.mau.fi/util v0.2.2-0.20240112154312-b89d6e13ae53/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/exp v0.0.0-20231226003508-02704c960a9b h1:kLiC65FbiHWFAOu+lxwNPujcsl8VYyTYYEZnsOO1WK4= -golang.org/x/exp v0.0.0-20231226003508-02704c960a9b/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= +golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= +golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= +golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3 h1:hNQpMuAJe5CtcUqCXaWga3FHu+kQvCqcsoVaQgSV60o= +golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= +golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= +golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From fe88d047680516ec918946e278ac925e58b63a97 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 13 Jan 2024 19:42:56 +0200 Subject: [PATCH 0060/1647] Remove Token field in ReqQueryKeys It was removed in v1.7 of the spec --- requests.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/requests.go b/requests.go index 985c8338..6e00346a 100644 --- a/requests.go +++ b/requests.go @@ -287,9 +287,7 @@ type Signatures map[id.UserID]map[id.KeyID]string type ReqQueryKeys struct { DeviceKeys DeviceKeysRequest `json:"device_keys"` - - Timeout int64 `json:"timeout,omitempty"` - Token string `json:"token,omitempty"` + Timeout int64 `json:"timeout,omitempty"` } type DeviceKeysRequest map[id.UserID]DeviceIDList From 5b0d4ba08608c4182c357e7ea7e3d3e231dba0c1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 13 Jan 2024 19:43:43 +0200 Subject: [PATCH 0061/1647] Update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a04fbff4..851790d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ * **Breaking change *(everything)*** Added context parameters to all functions (started by [@recht] in [#144]). * *(client)* Moved `EventSource` to `event.Source`. +* *(client)* Removed deprecated `OldEventIgnorer`. The non-deprecated version + (`Client.DontProcessOldEvents`) is still available. * *(crypto)* Added experimental pure Go Olm implementation to replace libolm (thanks to [@DerLukas15] in [#106]). * You can use the `goolm` build tag to the new implementation. From 970ba1a907f5da4fbbbc2cf5c06ce275570305da Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Mon, 15 Jan 2024 09:41:55 +0200 Subject: [PATCH 0062/1647] Store own device keys on init --- crypto/decryptolm.go | 2 +- crypto/machine.go | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index f99c7dbe..68eaa875 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -216,7 +216,7 @@ func (mach *OlmMachine) createInboundSession(ctx context.Context, senderKey id.S if err != nil { return nil, err } - mach.saveAccount() + mach.saveAccount(ctx) err = mach.CryptoStore.AddSession(ctx, senderKey, session) if err != nil { zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to store created inbound session") diff --git a/crypto/machine.go b/crypto/machine.go index 9892536a..b7c41ab0 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -145,8 +145,8 @@ func (mach *OlmMachine) Load(ctx context.Context) (err error) { return nil } -func (mach *OlmMachine) saveAccount() { - err := mach.CryptoStore.PutAccount(context.TODO(), mach.account) +func (mach *OlmMachine) saveAccount(ctx context.Context) { + err := mach.CryptoStore.PutAccount(ctx, mach.account) if err != nil { mach.Log.Error().Err(err).Msg("Failed to save account") } @@ -655,6 +655,15 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro var deviceKeys *mautrix.DeviceKeys if !mach.account.Shared { deviceKeys = mach.account.getInitialKeys(mach.Client.UserID, mach.Client.DeviceID) + err := mach.CryptoStore.PutDevice(ctx, mach.Client.UserID, &id.Device{ + UserID: mach.Client.UserID, + DeviceID: mach.Client.DeviceID, + IdentityKey: deviceKeys.Keys.GetCurve25519(mach.Client.DeviceID), + SigningKey: deviceKeys.Keys.GetEd25519(mach.Client.DeviceID), + }) + if err != nil { + return fmt.Errorf("failed to save initial keys: %w", err) + } log.Debug().Msg("Going to upload initial account keys") } oneTimeKeys := mach.account.getOneTimeKeys(mach.Client.UserID, mach.Client.DeviceID, currentOTKCount) @@ -673,7 +682,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro } mach.lastOTKUpload = time.Now() mach.account.Shared = true - mach.saveAccount() + mach.saveAccount(ctx) return nil } From ac69c357b9854a80fe54c5a9bd7851c27c077044 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 15 Jan 2024 10:38:05 -0700 Subject: [PATCH 0063/1647] crypto/utils: use crypto/rand instead of math/rand Signed-off-by: Sumner Evans --- crypto/utils/utils.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/utils/utils.go b/crypto/utils/utils.go index 414d83bd..382db02f 100644 --- a/crypto/utils/utils.go +++ b/crypto/utils/utils.go @@ -10,10 +10,10 @@ import ( "crypto/aes" "crypto/cipher" "crypto/hmac" + "crypto/rand" "crypto/sha256" "crypto/sha512" "encoding/base64" - "math/rand" "strings" "go.mau.fi/util/base58" From a0b92fd1851e9710be93a0effba7611adf98e23f Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 15 Jan 2024 15:14:33 -0700 Subject: [PATCH 0064/1647] crypto/goolm/session: use crypto/rand instead of math/rand Signed-off-by: Sumner Evans --- crypto/goolm/session/megolm_outbound_session.go | 2 +- crypto/goolm/session/megolm_session_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index 8964a68d..e594258d 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -1,10 +1,10 @@ package session import ( + "crypto/rand" "encoding/base64" "errors" "fmt" - "math/rand" "maunium.net/go/mautrix/id" diff --git a/crypto/goolm/session/megolm_session_test.go b/crypto/goolm/session/megolm_session_test.go index 93eec7eb..9b3f56b5 100644 --- a/crypto/goolm/session/megolm_session_test.go +++ b/crypto/goolm/session/megolm_session_test.go @@ -2,8 +2,8 @@ package session_test import ( "bytes" + "crypto/rand" "errors" - "math/rand" "testing" "maunium.net/go/mautrix/crypto/goolm" From ff6bd01335a97eccef51c61e424631305ff6da25 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Jan 2024 16:02:12 +0200 Subject: [PATCH 0065/1647] Add support for custom validators for entire bridge config --- CHANGELOG.md | 2 ++ bridge/bridge.go | 15 ++++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 851790d3..c09ea319 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,9 +11,11 @@ (thanks to [@DerLukas15] in [#106]). * You can use the `goolm` build tag to the new implementation. * *(bridge)* Added context parameter for bridge command events. +* *(bridge)* Added method to allow custom validation for the entire config. * *(client)* Changed default syncer to not drop unknown events. * The syncer will still drop known events if parsing the content fails. * The behavior can be changed by changing the `ParseErrorHandler` function. +* *(crypto)* Fixed some places using math/rand instead of crypto/rand. [@DerLukas15]: https://github.com/DerLukas15 [@recht]: https://github.com/recht diff --git a/bridge/bridge.go b/bridge/bridge.go index 7d5333ce..c41fba27 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -135,6 +135,11 @@ type ChildOverride interface { CreatePrivatePortal(id.RoomID, User, Ghost) } +type ConfigValidatingBridge interface { + ChildOverride + ValidateConfig() error +} + type FlagHandlingBridge interface { ChildOverride HandleFlags() bool @@ -469,7 +474,15 @@ func (br *Bridge) validateConfig() error { case br.Config.AppService.Database.URI == "postgres://user:password@host/database?sslmode=disable": return errors.New("appservice.database not configured") default: - return br.Config.Bridge.Validate() + err := br.Config.Bridge.Validate() + if err != nil { + return err + } + validator, ok := br.Child.(ConfigValidatingBridge) + if ok { + return validator.ValidateConfig() + } + return nil } } From f37c2d8d740ed59527320a66c96e70eac99cf6ec Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Jan 2024 16:06:54 +0200 Subject: [PATCH 0066/1647] Bump version to v0.17.0 --- CHANGELOG.md | 4 ++-- go.mod | 2 +- go.sum | 4 ++-- version.go | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c09ea319..d840bcbe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,10 @@ -## v0.17.0 (unreleased) +## v0.17.0 (2024-01-16) * **Breaking change *(bridge)*** Added raw event to portal membership handling functions. * **Breaking change *(everything)*** Added context parameters to all functions (started by [@recht] in [#144]). -* *(client)* Moved `EventSource` to `event.Source`. +* **Breaking change *(client)*** Moved `EventSource` to `event.Source`. * *(client)* Removed deprecated `OldEventIgnorer`. The non-deprecated version (`Client.DontProcessOldEvents`) is still available. * *(crypto)* Added experimental pure Go Olm implementation to replace libolm diff --git a/go.mod b/go.mod index 64f5d49e..48ff59e0 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.6.0 - go.mau.fi/util v0.2.2-0.20240112154312-b89d6e13ae53 + go.mau.fi/util v0.3.0 go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.18.0 golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3 diff --git a/go.sum b/go.sum index 765b5490..9061a651 100644 --- a/go.sum +++ b/go.sum @@ -36,8 +36,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68= github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mau.fi/util v0.2.2-0.20240112154312-b89d6e13ae53 h1:1RbC2484wnz5paT/254/Hj+2HOKb+2cqpxaUbsV08jc= -go.mau.fi/util v0.2.2-0.20240112154312-b89d6e13ae53/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= +go.mau.fi/util v0.3.0 h1:Lt3lbRXP6ZBqTINK0EieRWor3zEwwwrDT14Z5N8RUCs= +go.mau.fi/util v0.3.0/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= diff --git a/version.go b/version.go index 7b0c3dbe..d92a7977 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.16.2" +const Version = "v0.17.0" var GoModVersion = "" var Commit = "" From d151ec47114a8a8b8b682ad58ecfc04f1040e226 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 15 Jan 2024 17:02:02 +0200 Subject: [PATCH 0067/1647] Drop support for unprefixed appservice paths --- CHANGELOG.md | 5 +++++ appservice/appservice.go | 4 ---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d840bcbe..4c89c495 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +## unreleased + +* *(appservice)* Dropped support for legacy non-prefixed appservice paths + (e.g. `/transactions` instead of `/_matrix/app/v1/transactions`). + ## v0.17.0 (2024-01-16) * **Breaking change *(bridge)*** Added raw event to portal membership handling diff --git a/appservice/appservice.go b/appservice/appservice.go index dc5e82be..76d2f786 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -58,14 +58,10 @@ func Create() *AppService { QueryHandler: &QueryHandlerStub{}, } - as.Router.HandleFunc("/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut) - as.Router.HandleFunc("/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet) - as.Router.HandleFunc("/users/{userID}", as.GetUser).Methods(http.MethodGet) as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut) as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet) as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet) as.Router.HandleFunc("/_matrix/app/v1/ping", as.PostPing).Methods(http.MethodPost) - as.Router.HandleFunc("/_matrix/app/unstable/fi.mau.msc2659/ping", as.PostPing).Methods(http.MethodPost) as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet) as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet) From c8e9998e7f73bfbb633ff689cc08689c43b0bfd8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 15 Jan 2024 17:09:13 +0200 Subject: [PATCH 0068/1647] Drop support for legacy query param auth for appservices --- CHANGELOG.md | 3 +++ appservice/http.go | 25 +++++++++---------------- bridge/bridge.go | 2 +- 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4c89c495..c17e5b16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ * *(appservice)* Dropped support for legacy non-prefixed appservice paths (e.g. `/transactions` instead of `/_matrix/app/v1/transactions`). +* *(appservice)* Dropped support for legacy `access_token` authorization in + appservice endpoints. +* *(bridge)* Bumped minimum Matrix spec version to v1.4. ## v0.17.0 (2024-01-16) diff --git a/appservice/http.go b/appservice/http.go index 1d4c7f22..38bcecf8 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -82,27 +82,20 @@ func (as *AppService) Stop() { // CheckServerToken checks if the given request originated from the Matrix homeserver. func (as *AppService) CheckServerToken(w http.ResponseWriter, r *http.Request) (isValid bool) { authHeader := r.Header.Get("Authorization") - if len(authHeader) > 0 && strings.HasPrefix(authHeader, "Bearer ") { - isValid = authHeader[len("Bearer "):] == as.Registration.ServerToken - } else { - queryToken := r.URL.Query().Get("access_token") - if len(queryToken) > 0 { - isValid = queryToken == as.Registration.ServerToken - } else { - Error{ - ErrorCode: ErrUnknownToken, - HTTPStatus: http.StatusForbidden, - Message: "Missing access token", - }.Write(w) - return - } - } - if !isValid { + if !strings.HasPrefix(authHeader, "Bearer ") { + Error{ + ErrorCode: ErrUnknownToken, + HTTPStatus: http.StatusForbidden, + Message: "Missing access token", + }.Write(w) + } else if authHeader[len("Bearer "):] != as.Registration.ServerToken { Error{ ErrorCode: ErrUnknownToken, HTTPStatus: http.StatusForbidden, Message: "Incorrect access token", }.Write(w) + } else { + isValid = true } return } diff --git a/bridge/bridge.go b/bridge/bridge.go index c41fba27..134582a2 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -291,7 +291,7 @@ func (br *Bridge) InitVersion(tag, commit, buildTime string) { br.BuildTime = buildTime } -var MinSpecVersion = mautrix.SpecV11 +var MinSpecVersion = mautrix.SpecV14 func (br *Bridge) ensureConnection(ctx context.Context) { for { From 66cfa6389e72838b99a6e6545b99f34030a488b8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Jan 2024 21:26:04 +0200 Subject: [PATCH 0069/1647] Fix RawArgs when using command state function --- CHANGELOG.md | 1 + bridge/commands/processor.go | 1 + 2 files changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c17e5b16..c0d3bc8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ * *(appservice)* Dropped support for legacy `access_token` authorization in appservice endpoints. * *(bridge)* Bumped minimum Matrix spec version to v1.4. +* *(bridge)* Fixed `RawArgs` field in command events of command state callbacks. ## v0.17.0 (2024-01-16) diff --git a/bridge/commands/processor.go b/bridge/commands/processor.go index 904f5c40..70dd16e9 100644 --- a/bridge/commands/processor.go +++ b/bridge/commands/processor.go @@ -110,6 +110,7 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id. } if state != nil && state.Next != nil { ce.Command = "" + ce.RawArgs = message ce.Args = args ce.Handler = state.Next state.Next.Run(ce) From 6ac759c8ff4c1041b6f4edbdc323cc1e12324deb Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Wed, 17 Jan 2024 09:26:13 +0200 Subject: [PATCH 0070/1647] Only skip fetching keys during Megolm decryption if disabled Blanket disabling caused a lot of side effects which were hard to deal with without major refactoring. This should probably be an argument to DecryptMegolm instead of a flag. --- crypto/decryptmegolm.go | 6 +++++- crypto/machine.go | 14 +++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 540f99ca..abe01871 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -72,7 +72,11 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event if sess.SigningKey == ownSigningKey && sess.SenderKey == ownIdentityKey && len(sess.ForwardingChains) == 0 { trustLevel = id.TrustStateVerified } else { - device, err = mach.GetOrFetchDeviceByKey(ctx, evt.Sender, sess.SenderKey) + if mach.DisableDecryptKeyFetching { + device, err = mach.CryptoStore.FindDeviceByKey(ctx, evt.Sender, sess.SenderKey) + } else { + device, err = mach.GetOrFetchDeviceByKey(ctx, evt.Sender, sess.SenderKey) + } if err != nil { // We don't want to throw these errors as the message can still be decrypted. log.Debug().Err(err).Msg("Failed to get device to verify session") diff --git a/crypto/machine.go b/crypto/machine.go index b7c41ab0..77d99a8f 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -33,8 +33,8 @@ type OlmMachine struct { PlaintextMentions bool - // Never ask the server for keys automatically as a side effect. - DisableKeyFetching bool + // Never ask the server for keys automatically as a side effect during Megolm decryption. + DisableDecryptKeyFetching bool SendKeysMinTrust id.TrustState ShareKeysMinTrust id.TrustState @@ -227,11 +227,7 @@ func (mach *OlmMachine) HandleDeviceLists(ctx context.Context, dl *mautrix.Devic Str("trace_id", traceID). Interface("changes", dl.Changed). Msg("Device list changes in /sync") - if mach.DisableKeyFetching { - mach.CryptoStore.MarkTrackedUsersOutdated(ctx, dl.Changed) - } else { - mach.FetchKeys(ctx, dl.Changed, false) - } + mach.FetchKeys(ctx, dl.Changed, false) mach.Log.Debug().Str("trace_id", traceID).Msg("Finished handling device list changes") } } @@ -420,7 +416,7 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, device, err := mach.CryptoStore.GetDevice(ctx, userID, deviceID) if err != nil { return nil, fmt.Errorf("failed to get sender device from store: %w", err) - } else if device != nil || mach.DisableKeyFetching { + } else if device != nil { return device, nil } if usersToDevices, err := mach.FetchKeys(ctx, []id.UserID{userID}, true); err != nil { @@ -439,7 +435,7 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, // the given identity key. func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) { deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(ctx, userID, identityKey) - if err != nil || deviceIdentity != nil || mach.DisableKeyFetching { + if err != nil || deviceIdentity != nil { return deviceIdentity, err } mach.machOrContextLog(ctx).Debug(). From 9f12b80726b7ce24e6aeafcaa6e6a061cba40a95 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Wed, 17 Jan 2024 11:29:32 +0200 Subject: [PATCH 0071/1647] Open up OlmMachine event handlers --- crypto/keysharing.go | 4 +-- crypto/machine.go | 81 +++++++++++++++++++++++--------------------- 2 files changed, 45 insertions(+), 40 deletions(-) diff --git a/crypto/keysharing.go b/crypto/keysharing.go index 8cf15d35..09da1d1a 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -248,7 +248,7 @@ func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Dev } } -func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.UserID, content *event.RoomKeyRequestEventContent) { +func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.UserID, content *event.RoomKeyRequestEventContent) { log := zerolog.Ctx(ctx).With(). Str("request_id", content.RequestID). Str("device_id", content.RequestingDeviceID.String()). @@ -327,7 +327,7 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User } } -func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.UserID, content *event.BeeperRoomKeyAckEventContent) { +func (mach *OlmMachine) HandleBeeperRoomKeyAck(ctx context.Context, sender id.UserID, content *event.BeeperRoomKeyAckEventContent) { log := mach.machOrContextLog(ctx).With(). Str("room_id", content.RoomID.String()). Str("session_id", content.SessionID.String()). diff --git a/crypto/machine.go b/crypto/machine.go index 77d99a8f..b1ecd754 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -324,6 +324,44 @@ func (mach *OlmMachine) HandleMemberEvent(ctx context.Context, evt *event.Event) } } +func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Event) { + if _, ok := evt.Content.Parsed.(*event.EncryptedEventContent); !ok { + mach.machOrContextLog(ctx).Warn().Msg("Passed invalid event to encrypted handler") + return + } + + decryptedEvt, err := mach.decryptOlmEvent(ctx, evt) + if err != nil { + mach.machOrContextLog(ctx).Error().Err(err).Msg("Failed to decrypt to-device event") + return + } + + log := mach.machOrContextLog(ctx).With(). + Str("decrypted_type", decryptedEvt.Type.Type). + Str("sender_device", decryptedEvt.SenderDevice.String()). + Str("sender_signing_key", decryptedEvt.Keys.Ed25519.String()). + Logger() + log.Trace().Msg("Successfully decrypted to-device event") + + switch decryptedContent := decryptedEvt.Content.Parsed.(type) { + case *event.RoomKeyEventContent: + mach.receiveRoomKey(ctx, decryptedEvt, decryptedContent) + log.Trace().Msg("Handled room key event") + case *event.ForwardedRoomKeyEventContent: + if mach.importForwardedRoomKey(ctx, decryptedEvt, decryptedContent) { + if ch, ok := mach.roomKeyRequestFilled.Load(decryptedContent.SessionID); ok { + // close channel to notify listener that the key was received + close(ch.(chan struct{})) + } + } + log.Trace().Msg("Handled forwarded room key event") + case *event.DummyEventContent: + log.Debug().Msg("Received encrypted dummy event") + default: + log.Debug().Msg("Unhandled encrypted to-device event") + } +} + // HandleToDeviceEvent handles a single to-device event. This is automatically called by ProcessSyncResponse, so you // don't need to add any custom handlers if you use that method. func (mach *OlmMachine) HandleToDeviceEvent(ctx context.Context, evt *event.Event) { @@ -348,45 +386,12 @@ func (mach *OlmMachine) HandleToDeviceEvent(ctx context.Context, evt *event.Even } switch content := evt.Content.Parsed.(type) { case *event.EncryptedEventContent: - log = log.With(). - Str("sender_key", content.SenderKey.String()). - Logger() - log.Debug().Msg("Handling encrypted to-device event") - ctx = log.WithContext(ctx) - decryptedEvt, err := mach.decryptOlmEvent(ctx, evt) - if err != nil { - log.Error().Err(err).Msg("Failed to decrypt to-device event") - return - } - log = log.With(). - Str("decrypted_type", decryptedEvt.Type.Type). - Str("sender_device", decryptedEvt.SenderDevice.String()). - Str("sender_signing_key", decryptedEvt.Keys.Ed25519.String()). - Logger() - log.Trace().Msg("Successfully decrypted to-device event") - - switch decryptedContent := decryptedEvt.Content.Parsed.(type) { - case *event.RoomKeyEventContent: - mach.receiveRoomKey(ctx, decryptedEvt, decryptedContent) - log.Trace().Msg("Handled room key event") - case *event.ForwardedRoomKeyEventContent: - if mach.importForwardedRoomKey(ctx, decryptedEvt, decryptedContent) { - if ch, ok := mach.roomKeyRequestFilled.Load(decryptedContent.SessionID); ok { - // close channel to notify listener that the key was received - close(ch.(chan struct{})) - } - } - log.Trace().Msg("Handled forwarded room key event") - case *event.DummyEventContent: - log.Debug().Msg("Received encrypted dummy event") - default: - log.Debug().Msg("Unhandled encrypted to-device event") - } + mach.HandleEncryptedEvent(ctx, evt) return case *event.RoomKeyRequestEventContent: - go mach.handleRoomKeyRequest(ctx, evt.Sender, content) + go mach.HandleRoomKeyRequest(ctx, evt.Sender, content) case *event.BeeperRoomKeyAckEventContent: - mach.handleBeeperRoomKeyAck(ctx, evt.Sender, content) + mach.HandleBeeperRoomKeyAck(ctx, evt.Sender, content) // verification cases case *event.VerificationStartEventContent: mach.handleVerificationStart(ctx, evt.Sender, content, content.TransactionID, 10*time.Minute, "") @@ -401,7 +406,7 @@ func (mach *OlmMachine) HandleToDeviceEvent(ctx context.Context, evt *event.Even case *event.VerificationRequestEventContent: mach.handleVerificationRequest(ctx, evt.Sender, content, content.TransactionID, "") case *event.RoomKeyWithheldEventContent: - mach.handleRoomKeyWithheld(ctx, content) + mach.HandleRoomKeyWithheld(ctx, content) default: deviceID, _ := evt.Content.Raw["device_id"].(string) log.Debug().Str("maybe_device_id", deviceID).Msg("Unhandled to-device event") @@ -615,7 +620,7 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve mach.createGroupSession(ctx, evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, maxAge, maxMessages, content.IsScheduled) } -func (mach *OlmMachine) handleRoomKeyWithheld(ctx context.Context, content *event.RoomKeyWithheldEventContent) { +func (mach *OlmMachine) HandleRoomKeyWithheld(ctx context.Context, content *event.RoomKeyWithheldEventContent) { if content.Algorithm != id.AlgorithmMegolmV1 { zerolog.Ctx(ctx).Debug().Interface("content", content).Msg("Non-megolm room key withheld event") return From 4020e9c2ea685dbbab9458827c35f248a6f1a69c Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 10 Jan 2024 23:19:51 -0700 Subject: [PATCH 0072/1647] client: add key backup functions Signed-off-by: Sumner Evans --- client.go | 148 +++++++++++++++++++++++++++++++++++++++++++++++++++ requests.go | 16 ++++-- responses.go | 8 +-- 3 files changed, 163 insertions(+), 9 deletions(-) diff --git a/client.go b/client.go index dfef7231..3fabb7af 100644 --- a/client.go +++ b/client.go @@ -1944,6 +1944,154 @@ func (cli *Client) GetKeyChanges(ctx context.Context, from, to string) (resp *Re return } +// GetKeyBackup retrieves the keys from the backup. +// +// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeys +func (cli *Client) GetKeyBackup(ctx context.Context, version string) (resp *RespRoomKeys, err error) { + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{ + "version": version, + }) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + +// PutKeysInBackup stores several keys in the backup. +// +// See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeys +func (cli *Client) PutKeysInBackup(ctx context.Context, version string, req *ReqKeyBackup) (resp *RespRoomKeysUpdate, err error) { + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{ + "version": version, + }) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp) + return +} + +// DeleteKeyBackup deletes all keys from the backup. +// +// See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeys +func (cli *Client) DeleteKeyBackup(ctx context.Context, version string) (resp *RespRoomKeysUpdate, err error) { + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{ + "version": version, + }) + _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp) + return +} + +// GetKeyBackupForRoom retrieves the keys from the backup for the given room. +// +// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeysroomid +func (cli *Client) GetKeyBackupForRoom(ctx context.Context, version string, roomID id.RoomID) (resp *RespRoomKeyBackup, err error) { + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{ + "version": version, + }) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + +// PutKeysInBackupForRoom stores several keys in the backup for the given room. +// +// See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeysroomid +func (cli *Client) PutKeysInBackupForRoom(ctx context.Context, version string, roomID id.RoomID, req *ReqRoomKeyBackup) (resp *RespRoomKeysUpdate, err error) { + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{ + "version": version, + }) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp) + return +} + +// DeleteKeysFromBackupForRoom deletes all the keys in the backup for the given +// room. +// +// See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeysroomid +func (cli *Client) DeleteKeysFromBackupForRoom(ctx context.Context, version string, roomID id.RoomID) (resp *RespRoomKeysUpdate, err error) { + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{ + "version": version, + }) + _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp) + return +} + +// GetKeyBackupForRoomAndSession retrieves a key from the backup. +// +// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid +func (cli *Client) GetKeyBackupForRoomAndSession(ctx context.Context, version string, roomID id.RoomID, sessionID id.SessionID) (resp *RespKeyBackupData, err error) { + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{ + "version": version, + }) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + +// PutKeysInBackupForRoomAndSession stores a key in the backup. +// +// See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeysroomidsessionid +func (cli *Client) PutKeysInBackupForRoomAndSession(ctx context.Context, version string, roomID id.RoomID, sessionID id.SessionID, req *ReqKeyBackupData) (resp *RespRoomKeysUpdate, err error) { + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{ + "version": version, + }) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp) + return +} + +// DeleteKeysInBackupForRoomAndSession deletes a key from the backup. +// +// See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeysroomidsessionid +func (cli *Client) DeleteKeysInBackupForRoomAndSession(ctx context.Context, version string, roomID id.RoomID, sessionID id.SessionID) (resp *RespRoomKeysUpdate, err error) { + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{ + "version": version, + }) + _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp) + return +} + +// GetKeyBackupLatestVersion returns information about the latest backup version. +// +// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keysversion +func (cli *Client) GetKeyBackupLatestVersion(ctx context.Context) (resp *RespRoomKeysVersion, err error) { + urlPath := cli.BuildClientURL("v3", "room_keys", "version") + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + +// CreateKeyBackupVersion creates a new key backup. +// +// See: https://spec.matrix.org/v1.9/client-server-api/#post_matrixclientv3room_keysversion +func (cli *Client) CreateKeyBackupVersion(ctx context.Context, req *ReqRoomKeysVersionCreate) (resp *RespRoomKeysVersionCreate, err error) { + urlPath := cli.BuildClientURL("v3", "room_keys", "version") + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) + return +} + +// GetKeyBackupVersion returns information about an existing key backup. +// +// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keysversionversion +func (cli *Client) GetKeyBackupVersion(ctx context.Context, version string) (resp *RespRoomKeysVersion, err error) { + urlPath := cli.BuildClientURL("v3", "room_keys", "version", version) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + +// UpdateKeyBackupVersion updates information about an existing key backup. Only +// the auth_data can be modified. +// +// See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keysversionversion +func (cli *Client) UpdateKeyBackupVersion(ctx context.Context, version string, req *ReqRoomKeysVersionUpdate) error { + urlPath := cli.BuildClientURL("v3", "room_keys", "version", version) + _, err := cli.MakeRequest(ctx, http.MethodPut, urlPath, nil, nil) + return err +} + +// DeleteKeyBackupVersion deletes an existing key backup. Both the information +// about the backup, as well as all key data related to the backup will be +// deleted. +// +// See: https://spec.matrix.org/v1.1/client-server-api/#delete_matrixclientv3room_keysversionversion +func (cli *Client) DeleteKeyBackupVersion(ctx context.Context, version string) error { + urlPath := cli.BuildClientURL("v3", "room_keys", "version", version) + _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil) + return err +} + func (cli *Client) SendToDevice(ctx context.Context, eventType event.Type, req *ReqSendToDevice) (resp *RespSendToDevice, err error) { urlPath := cli.BuildClientURL("v3", "sendToDevice", eventType.String(), cli.TxnID()) _, err = cli.MakeRequest(ctx, "PUT", urlPath, req, &resp) diff --git a/requests.go b/requests.go index 6e00346a..82f98c65 100644 --- a/requests.go +++ b/requests.go @@ -434,15 +434,21 @@ type ReqRoomKeysVersionCreate struct { AuthData json.RawMessage `json:"auth_data"` } -type ReqRoomKeysUpdate struct { - Rooms map[id.RoomID]ReqRoomKeysRoomUpdate `json:"rooms"` +type ReqRoomKeysVersionUpdate struct { + Algorithm string `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` + Version string `json:"version,omitempty"` } -type ReqRoomKeysRoomUpdate struct { - Sessions map[id.SessionID]ReqRoomKeysSessionUpdate `json:"sessions"` +type ReqKeyBackup struct { + Rooms map[id.RoomID]ReqRoomKeyBackup `json:"rooms"` } -type ReqRoomKeysSessionUpdate struct { +type ReqRoomKeyBackup struct { + Sessions map[id.SessionID]ReqKeyBackupData `json:"sessions"` +} + +type ReqKeyBackupData struct { FirstMessageIndex int `json:"first_message_index"` ForwardedCount int `json:"forwarded_count"` IsVerified bool `json:"is_verified"` diff --git a/responses.go b/responses.go index 69eb4b8f..f5060d60 100644 --- a/responses.go +++ b/responses.go @@ -605,14 +605,14 @@ type RespRoomKeysVersion struct { } type RespRoomKeys struct { - Rooms map[id.RoomID]RespRoomKeysRoom `json:"rooms"` + Rooms map[id.RoomID]RespRoomKeyBackup `json:"rooms"` } -type RespRoomKeysRoom struct { - Sessions map[id.SessionID]RespRoomKeysSession `json:"sessions"` +type RespRoomKeyBackup struct { + Sessions map[id.SessionID]RespKeyBackupData `json:"sessions"` } -type RespRoomKeysSession struct { +type RespKeyBackupData struct { FirstMessageIndex int `json:"first_message_index"` ForwardedCount int `json:"forwarded_count"` IsVerified bool `json:"is_verified"` From 96d1d162268857df169a7f24e0d0e5f447ba87e4 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 11 Jan 2024 11:02:55 -0700 Subject: [PATCH 0073/1647] id/crypto: add KeyBackupAlgorithm Signed-off-by: Sumner Evans --- id/crypto.go | 6 ++++++ responses.go | 10 +++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/id/crypto.go b/id/crypto.go index 84fcd67f..e920a301 100644 --- a/id/crypto.go +++ b/id/crypto.go @@ -44,6 +44,12 @@ const ( XSUsageUserSigning CrossSigningUsage = "user_signing" ) +type KeyBackupAlgorithm string + +const ( + KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2" +) + // A SessionID is an arbitrary string that identifies an Olm or Megolm session. type SessionID string diff --git a/responses.go b/responses.go index f5060d60..4599fb19 100644 --- a/responses.go +++ b/responses.go @@ -597,11 +597,11 @@ type RespRoomKeysVersionCreate struct { } type RespRoomKeysVersion struct { - Algorithm string `json:"algorithm"` - AuthData json.RawMessage `json:"auth_data"` - Count int `json:"count"` - ETag string `json:"etag"` - Version string `json:"version"` + Algorithm id.KeyBackupAlgorithm `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` + Count int `json:"count"` + ETag string `json:"etag"` + Version string `json:"version"` } type RespRoomKeys struct { From 18bb31e1c7f03b753c4165c2fc524980442be663 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 11 Jan 2024 11:55:44 -0700 Subject: [PATCH 0074/1647] crypto/signatures: move Signatures type alias to separate package and add helper method Signed-off-by: Sumner Evans --- crypto/account.go | 13 +++---------- crypto/cross_sign_key.go | 13 +++---------- crypto/cross_sign_signing.go | 13 +++---------- crypto/signatures/signatures.go | 17 +++++++++++++++++ requests.go | 25 ++++++++++++------------- 5 files changed, 38 insertions(+), 43 deletions(-) create mode 100644 crypto/signatures/signatures.go diff --git a/crypto/account.go b/crypto/account.go index 0eb18a24..78fbfa5f 100644 --- a/crypto/account.go +++ b/crypto/account.go @@ -9,6 +9,7 @@ package crypto import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/id" ) @@ -62,11 +63,7 @@ func (account *OlmAccount) getInitialKeys(userID id.UserID, deviceID id.DeviceID panic(err) } - deviceKeys.Signatures = mautrix.Signatures{ - userID: { - id.NewKeyID(id.KeyAlgorithmEd25519, deviceID.String()): signature, - }, - } + deviceKeys.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, deviceID.String(), signature) return deviceKeys } @@ -79,11 +76,7 @@ func (account *OlmAccount) getOneTimeKeys(userID id.UserID, deviceID id.DeviceID for keyID, key := range account.Internal.OneTimeKeys() { key := mautrix.OneTimeKey{Key: key} signature, _ := account.Internal.SignJSON(key) - key.Signatures = mautrix.Signatures{ - userID: { - id.NewKeyID(id.KeyAlgorithmEd25519, deviceID.String()): signature, - }, - } + key.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, deviceID.String(), signature) key.IsSigned = true oneTimeKeys[id.NewKeyID(id.KeyAlgorithmSignedCurve25519, keyID)] = key } diff --git a/crypto/cross_sign_key.go b/crypto/cross_sign_key.go index 4528ae02..005a05fb 100644 --- a/crypto/cross_sign_key.go +++ b/crypto/cross_sign_key.go @@ -13,6 +13,7 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/id" ) @@ -112,11 +113,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross if err != nil { return fmt.Errorf("failed to sign self-signing key: %w", err) } - selfKey.Signatures = map[id.UserID]map[id.KeyID]string{ - userID: { - masterKeyID: selfSig, - }, - } + selfKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey.String(), selfSig) userKey := mautrix.CrossSigningKeys{ UserID: userID, @@ -129,11 +126,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross if err != nil { return fmt.Errorf("failed to sign user-signing key: %w", err) } - userKey.Signatures = map[id.UserID]map[id.KeyID]string{ - userID: { - masterKeyID: userSig, - }, - } + userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey.String(), userSig) err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{ Master: masterKey, diff --git a/crypto/cross_sign_signing.go b/crypto/cross_sign_signing.go index f6c37a9f..616fef4a 100644 --- a/crypto/cross_sign_signing.go +++ b/crypto/cross_sign_signing.go @@ -14,6 +14,7 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -115,11 +116,7 @@ func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to sign JSON: %w", err) } - masterKeyObj.Signatures = mautrix.Signatures{ - userID: map[id.KeyID]string{ - id.NewKeyID(id.KeyAlgorithmEd25519, deviceID.String()): signature, - }, - } + masterKeyObj.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, deviceID.String(), signature) mach.Log.Debug(). Str("device_id", deviceID.String()). Str("signature", signature). @@ -214,11 +211,7 @@ func (mach *OlmMachine) signAndUpload(ctx context.Context, req mautrix.ReqKeysSi if err != nil { return "", fmt.Errorf("failed to sign JSON: %w", err) } - req.Signatures = mautrix.Signatures{ - mach.Client.UserID: map[id.KeyID]string{ - id.NewKeyID(id.KeyAlgorithmEd25519, key.PublicKey.String()): signature, - }, - } + req.Signatures = signatures.NewSingleSignature(mach.Client.UserID, id.KeyAlgorithmEd25519, key.PublicKey.String(), signature) resp, err := mach.Client.UploadSignatures(ctx, &mautrix.ReqUploadSignatures{ userID: map[string]mautrix.ReqKeysSignatures{ diff --git a/crypto/signatures/signatures.go b/crypto/signatures/signatures.go new file mode 100644 index 00000000..dae34ef3 --- /dev/null +++ b/crypto/signatures/signatures.go @@ -0,0 +1,17 @@ +package signatures + +import "maunium.net/go/mautrix/id" + +// Signatures represents a set of signatures for some data from multiple users +// and keys. +type Signatures map[id.UserID]map[id.KeyID]string + +// NewSingleSignature creates a new [Signatures] object with a single +// signature. +func NewSingleSignature(userID id.UserID, algorithm id.KeyAlgorithm, keyID string, signature string) Signatures { + return Signatures{ + userID: { + id.NewKeyID(algorithm, keyID): signature, + }, + } +} diff --git a/requests.go b/requests.go index 82f98c65..1551e63b 100644 --- a/requests.go +++ b/requests.go @@ -4,6 +4,7 @@ import ( "encoding/json" "strconv" + "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/pushrules" @@ -184,11 +185,11 @@ type ReqAliasCreate struct { } type OneTimeKey struct { - Key id.Curve25519 `json:"key"` - Fallback bool `json:"fallback,omitempty"` - Signatures Signatures `json:"signatures,omitempty"` - Unsigned map[string]any `json:"unsigned,omitempty"` - IsSigned bool `json:"-"` + Key id.Curve25519 `json:"key"` + Fallback bool `json:"fallback,omitempty"` + Signatures signatures.Signatures `json:"signatures,omitempty"` + Unsigned map[string]any `json:"unsigned,omitempty"` + IsSigned bool `json:"-"` // Raw data in the one-time key. This must be used for signature verification to ensure unrecognized fields // aren't thrown away (because that would invalidate the signature). @@ -230,7 +231,7 @@ type ReqKeysSignatures struct { Algorithms []id.Algorithm `json:"algorithms,omitempty"` Usage []id.CrossSigningUsage `json:"usage,omitempty"` Keys map[id.KeyID]string `json:"keys"` - Signatures Signatures `json:"signatures"` + Signatures signatures.Signatures `json:"signatures"` } type ReqUploadSignatures map[id.UserID]map[string]ReqKeysSignatures @@ -240,15 +241,15 @@ type DeviceKeys struct { DeviceID id.DeviceID `json:"device_id"` Algorithms []id.Algorithm `json:"algorithms"` Keys KeyMap `json:"keys"` - Signatures Signatures `json:"signatures"` + Signatures signatures.Signatures `json:"signatures"` Unsigned map[string]interface{} `json:"unsigned,omitempty"` } type CrossSigningKeys struct { - UserID id.UserID `json:"user_id"` - Usage []id.CrossSigningUsage `json:"usage"` - Keys map[id.KeyID]id.Ed25519 `json:"keys"` - Signatures map[id.UserID]map[id.KeyID]string `json:"signatures,omitempty"` + UserID id.UserID `json:"user_id"` + Usage []id.CrossSigningUsage `json:"usage"` + Keys map[id.KeyID]id.Ed25519 `json:"keys"` + Signatures signatures.Signatures `json:"signatures,omitempty"` } func (csk *CrossSigningKeys) FirstKey() id.Ed25519 { @@ -283,8 +284,6 @@ func (km KeyMap) GetCurve25519(deviceID id.DeviceID) id.Curve25519 { return id.Curve25519(val) } -type Signatures map[id.UserID]map[id.KeyID]string - type ReqQueryKeys struct { DeviceKeys DeviceKeysRequest `json:"device_keys"` Timeout int64 `json:"timeout,omitempty"` From d9c3b42564e0f48f1fd406665a784caae604c6ef Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 11 Jan 2024 12:09:03 -0700 Subject: [PATCH 0075/1647] crypto/backup: add structs for m.megolm_backup.v1.curve25519-aes-sha2 backup algorithm Signed-off-by: Sumner Evans --- crypto/backup/megolmbackup.go | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 crypto/backup/megolmbackup.go diff --git a/crypto/backup/megolmbackup.go b/crypto/backup/megolmbackup.go new file mode 100644 index 00000000..2ec425c0 --- /dev/null +++ b/crypto/backup/megolmbackup.go @@ -0,0 +1,29 @@ +package backup + +import ( + "maunium.net/go/mautrix/crypto/signatures" + "maunium.net/go/mautrix/id" +) + +// MegolmAuthData is the auth_data when the key backup is created with +// the [id.KeyBackupAlgorithmMegolmBackupV1] algorithm as defined in +// [Section 11.12.3.2.2 of the Spec]. +// +// [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 +type MegolmAuthData struct { + PublicKey id.Ed25519 `json:"public_key"` + Signatures signatures.Signatures `json:"signatures"` +} + +// MegolmSessionData is the decrypted session_data when the key backup is created +// with the [id.KeyBackupAlgorithmMegolmBackupV1] algorithm as defined in +// [Section 11.12.3.2.2 of the Spec]. +// +// [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 +type MegolmSessionData struct { + Algorithm id.Algorithm `json:"algorithm"` + ForwardingKeyChain []string `json:"forwarding_curve25519_key_chain"` + SenderClaimedKeys map[id.KeyAlgorithm]string `json:"sender_claimed_keys"` + SenderKey id.SenderKey `json:"sender_key"` + SessionKey []byte `json:"session_key"` +} From 066db534f683d5cc1290ec94635bb118c64c208f Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 11 Jan 2024 18:52:07 -0700 Subject: [PATCH 0076/1647] crypto/pkcs7: add package for (un)padding data using PKCS#7 Signed-off-by: Sumner Evans --- crypto/pkcs7/pkcs7.go | 24 ++++++++++++++++++++ crypto/pkcs7/pkcs7_test.go | 45 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 crypto/pkcs7/pkcs7.go create mode 100644 crypto/pkcs7/pkcs7_test.go diff --git a/crypto/pkcs7/pkcs7.go b/crypto/pkcs7/pkcs7.go new file mode 100644 index 00000000..1018e52b --- /dev/null +++ b/crypto/pkcs7/pkcs7.go @@ -0,0 +1,24 @@ +package pkcs7 + +import "bytes" + +// Pad implements PKCS#7 padding as defined in [RFC2315]. It pads the plaintext +// to the given blockSize in the range [1, 255]. This is normally used in +// AES-CBC encryption. +// +// [RFC2315]: https://www.ietf.org/rfc/rfc2315.txt +func Pad(plaintext []byte, blockSize int) []byte { + padding := blockSize - len(plaintext)%blockSize + return append(plaintext, bytes.Repeat([]byte{byte(padding)}, padding)...) +} + +// Unpad implements PKCS#7 unpadding as defined in [RFC2315]. It unpads the +// plaintext by reading the padding amount from the last byte of the plaintext. +// This is normally used in AES-CBC decryption. +// +// [RFC2315]: https://www.ietf.org/rfc/rfc2315.txt +func Unpad(plaintext []byte) []byte { + length := len(plaintext) + unpadding := int(plaintext[length-1]) + return plaintext[:length-unpadding] +} diff --git a/crypto/pkcs7/pkcs7_test.go b/crypto/pkcs7/pkcs7_test.go new file mode 100644 index 00000000..6ef835c0 --- /dev/null +++ b/crypto/pkcs7/pkcs7_test.go @@ -0,0 +1,45 @@ +package pkcs7_test + +import ( + "bytes" + "crypto/aes" + "testing" + + "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/crypto/pkcs7" +) + +func TestPKCS7(t *testing.T) { + testCases := []struct { + input []byte + blockLen int + expected []byte + }{ + {[]byte("test"), 4, []byte("test\x04\x04\x04\x04")}, + {[]byte("test"), 8, []byte("test\x04\x04\x04\x04")}, + {[]byte("test1"), 8, []byte("test1\x03\x03\x03")}, + {bytes.Repeat([]byte("test1"), 6), aes.BlockSize, append(bytes.Repeat([]byte("test1"), 6), 0x02, 0x02)}, + } + for _, tc := range testCases { + t.Run(string(tc.input), func(t *testing.T) { + // Test pad + padded := pkcs7.Pad(tc.input, tc.blockLen) + assert.Equal(t, tc.expected, padded) + assert.Zero(t, len(padded)%tc.blockLen, "padded length is not a multiple of block size") + + // Test unpad + assert.Equal(t, tc.input, pkcs7.Unpad(tc.expected)) + }) + } +} + +func TestPKCS7_RoundtripWithAESBlockSize(t *testing.T) { + for i := 0; i < 1024; i++ { + input := bytes.Repeat([]byte{byte(i)}, i) + padded := pkcs7.Pad(input, aes.BlockSize) + assert.Zero(t, len(padded)%aes.BlockSize, "padded length is not a multiple of the AES block size") + unpadded := pkcs7.Unpad(padded) + assert.Equal(t, bytes.Repeat([]byte{byte(i)}, i), unpadded) + } +} From e74304d0220e0fcb16b1ef151285505bc68dfa32 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 11 Jan 2024 19:13:44 -0700 Subject: [PATCH 0077/1647] crypto/aescbc: move to its own module Signed-off-by: Sumner Evans --- crypto/aescbc/aes_cbc.go | 54 +++++++++++++ .../{goolm/crypto => aescbc}/aes_cbc_test.go | 12 +-- crypto/aescbc/errors.go | 9 +++ crypto/goolm/cipher/aes_sha256.go | 8 +- crypto/goolm/crypto/aes_cbc.go | 75 ------------------- crypto/goolm/errors.go | 2 - 6 files changed, 74 insertions(+), 86 deletions(-) create mode 100644 crypto/aescbc/aes_cbc.go rename crypto/{goolm/crypto => aescbc}/aes_cbc_test.go (81%) create mode 100644 crypto/aescbc/errors.go delete mode 100644 crypto/goolm/crypto/aes_cbc.go diff --git a/crypto/aescbc/aes_cbc.go b/crypto/aescbc/aes_cbc.go new file mode 100644 index 00000000..f1fdc84d --- /dev/null +++ b/crypto/aescbc/aes_cbc.go @@ -0,0 +1,54 @@ +package aescbc + +import ( + "crypto/aes" + "crypto/cipher" + + "maunium.net/go/mautrix/crypto/pkcs7" +) + +// Encrypt encrypts the plaintext with the key and IV. The IV length must be +// equal to the AES block size. +// +// This function might mutate the plaintext. +func Encrypt(key, iv, plaintext []byte) ([]byte, error) { + if len(key) == 0 { + return nil, ErrNoKeyProvided + } + if len(iv) != aes.BlockSize { + return nil, ErrIVNotBlockSize + } + plaintext = pkcs7.Pad(plaintext, aes.BlockSize) + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + cipher.NewCBCEncrypter(block, iv).CryptBlocks(plaintext, plaintext) + return plaintext, nil +} + +// Decrypt decrypts the ciphertext with the key and IV. The IV length must be +// equal to the block size. +// +// This function mutates the ciphertext. +func Decrypt(key, iv, ciphertext []byte) ([]byte, error) { + if len(key) == 0 { + return nil, ErrNoKeyProvided + } + if len(iv) != aes.BlockSize { + return nil, ErrIVNotBlockSize + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + if len(ciphertext) < aes.BlockSize { + return nil, ErrNotMultipleBlockSize + } + + cipher.NewCBCDecrypter(block, iv).CryptBlocks(ciphertext, ciphertext) + return pkcs7.Unpad(ciphertext), nil +} diff --git a/crypto/goolm/crypto/aes_cbc_test.go b/crypto/aescbc/aes_cbc_test.go similarity index 81% rename from crypto/goolm/crypto/aes_cbc_test.go rename to crypto/aescbc/aes_cbc_test.go index c64e4a5d..06dcee0d 100644 --- a/crypto/goolm/crypto/aes_cbc_test.go +++ b/crypto/aescbc/aes_cbc_test.go @@ -1,4 +1,4 @@ -package crypto_test +package aescbc_test import ( "bytes" @@ -6,7 +6,7 @@ import ( "crypto/rand" "testing" - "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/aescbc" ) func TestAESCBC(t *testing.T) { @@ -30,11 +30,11 @@ func TestAESCBC(t *testing.T) { plaintext = append(plaintext, []byte("-")...) } - if ciphertext, err = crypto.AESCBCEncrypt(key, iv, plaintext); err != nil { + if ciphertext, err = aescbc.Encrypt(key, iv, plaintext); err != nil { t.Fatal(err) } - resultPlainText, err := crypto.AESCBCDecrypt(key, iv, ciphertext) + resultPlainText, err := aescbc.Decrypt(key, iv, ciphertext) if err != nil { t.Fatal(err) } @@ -54,7 +54,7 @@ func TestAESCBCCase1(t *testing.T) { input := make([]byte, 16) key := make([]byte, 32) iv := make([]byte, aes.BlockSize) - encrypted, err := crypto.AESCBCEncrypt(key, iv, input) + encrypted, err := aescbc.Encrypt(key, iv, input) if err != nil { t.Fatal(err) } @@ -62,7 +62,7 @@ func TestAESCBCCase1(t *testing.T) { t.Fatalf("encrypted did not match expected:\n%v\n%v\n", encrypted, expected) } - decrypted, err := crypto.AESCBCDecrypt(key, iv, encrypted) + decrypted, err := aescbc.Decrypt(key, iv, encrypted) if err != nil { t.Fatal(err) } diff --git a/crypto/aescbc/errors.go b/crypto/aescbc/errors.go new file mode 100644 index 00000000..542c3450 --- /dev/null +++ b/crypto/aescbc/errors.go @@ -0,0 +1,9 @@ +package aescbc + +import "errors" + +var ( + ErrNoKeyProvided = errors.New("no key") + ErrIVNotBlockSize = errors.New("IV length does not match AES block size") + ErrNotMultipleBlockSize = errors.New("ciphertext length is not a multiple of the AES block size") +) diff --git a/crypto/goolm/cipher/aes_sha256.go b/crypto/goolm/cipher/aes_sha256.go index 1155949b..2d2d58d5 100644 --- a/crypto/goolm/cipher/aes_sha256.go +++ b/crypto/goolm/cipher/aes_sha256.go @@ -2,8 +2,10 @@ package cipher import ( "bytes" + "crypto/aes" "io" + "maunium.net/go/mautrix/crypto/aescbc" "maunium.net/go/mautrix/crypto/goolm/crypto" ) @@ -36,7 +38,7 @@ func deriveAESKeys(kdfInfo []byte, key []byte) (*derivedAESKeys, error) { // AESSha512BlockSize resturns the blocksize of the cipher AESSHA256. func AESSha512BlockSize() int { - return crypto.AESCBCBlocksize() + return aes.BlockSize } // AESSHA256 is a valid cipher using AES with CBC and HKDFSha256. @@ -57,7 +59,7 @@ func (c AESSHA256) Encrypt(key, plaintext []byte) (ciphertext []byte, err error) if err != nil { return nil, err } - ciphertext, err = crypto.AESCBCEncrypt(keys.key, keys.iv, plaintext) + ciphertext, err = aescbc.Encrypt(keys.key, keys.iv, plaintext) if err != nil { return nil, err } @@ -70,7 +72,7 @@ func (c AESSHA256) Decrypt(key, ciphertext []byte) (plaintext []byte, err error) if err != nil { return nil, err } - plaintext, err = crypto.AESCBCDecrypt(keys.key, keys.iv, ciphertext) + plaintext, err = aescbc.Decrypt(keys.key, keys.iv, ciphertext) if err != nil { return nil, err } diff --git a/crypto/goolm/crypto/aes_cbc.go b/crypto/goolm/crypto/aes_cbc.go deleted file mode 100644 index 10434ab7..00000000 --- a/crypto/goolm/crypto/aes_cbc.go +++ /dev/null @@ -1,75 +0,0 @@ -package crypto - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "fmt" - - "maunium.net/go/mautrix/crypto/goolm" -) - -// AESCBCBlocksize returns the blocksize of the encryption method -func AESCBCBlocksize() int { - return aes.BlockSize -} - -// AESCBCEncrypt encrypts the plaintext with the key and iv. len(iv) must be equal to the blocksize! -func AESCBCEncrypt(key, iv, plaintext []byte) ([]byte, error) { - if len(key) == 0 { - return nil, fmt.Errorf("AESCBCEncrypt: %w", goolm.ErrNoKeyProvided) - } - if len(iv) != AESCBCBlocksize() { - return nil, fmt.Errorf("iv: %w", goolm.ErrNotBlocksize) - } - var cipherText []byte - plaintext = pkcs5Padding(plaintext, AESCBCBlocksize()) - if len(plaintext)%AESCBCBlocksize() != 0 { - return nil, fmt.Errorf("message: %w", goolm.ErrNotMultipleBlocksize) - } - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - cipherText = make([]byte, len(plaintext)) - cbc := cipher.NewCBCEncrypter(block, iv) - cbc.CryptBlocks(cipherText, plaintext) - return cipherText, nil -} - -// AESCBCDecrypt decrypts the ciphertext with the key and iv. len(iv) must be equal to the blocksize! -func AESCBCDecrypt(key, iv, ciphertext []byte) ([]byte, error) { - if len(key) == 0 { - return nil, fmt.Errorf("AESCBCEncrypt: %w", goolm.ErrNoKeyProvided) - } - if len(iv) != AESCBCBlocksize() { - return nil, fmt.Errorf("iv: %w", goolm.ErrNotBlocksize) - } - var block cipher.Block - var err error - block, err = aes.NewCipher(key) - if err != nil { - return nil, err - } - if len(ciphertext) < AESCBCBlocksize() { - return nil, fmt.Errorf("ciphertext: %w", goolm.ErrNotMultipleBlocksize) - } - - cbc := cipher.NewCBCDecrypter(block, iv) - cbc.CryptBlocks(ciphertext, ciphertext) - return pkcs5Unpadding(ciphertext), nil -} - -// pkcs5Padding paddes the plaintext to be used in the AESCBC encryption. -func pkcs5Padding(plaintext []byte, blockSize int) []byte { - padding := (blockSize - len(plaintext)%blockSize) - padtext := bytes.Repeat([]byte{byte(padding)}, padding) - return append(plaintext, padtext...) -} - -// pkcs5Unpadding undoes the padding to the plaintext after AESCBC decryption. -func pkcs5Unpadding(plaintext []byte) []byte { - length := len(plaintext) - unpadding := int(plaintext[length-1]) - return plaintext[:(length - unpadding)] -} diff --git a/crypto/goolm/errors.go b/crypto/goolm/errors.go index 4dec9849..6539b0f1 100644 --- a/crypto/goolm/errors.go +++ b/crypto/goolm/errors.go @@ -21,8 +21,6 @@ var ( ErrChainTooHigh = errors.New("chain index too high") ErrBadInput = errors.New("bad input") ErrBadVersion = errors.New("wrong version") - ErrNotBlocksize = errors.New("length != blocksize") - ErrNotMultipleBlocksize = errors.New("length not a multiple of the blocksize") ErrWrongPickleVersion = errors.New("wrong pickle version") ErrValueTooShort = errors.New("value too short") ErrInputToSmall = errors.New("input too small (truncated?)") From e0d1c1de3329d5baa6367d74869a5cbeb550d482 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 12 Jan 2024 10:06:16 -0700 Subject: [PATCH 0078/1647] crypto/backup: add EphemeralKey struct with JSON (de)serialization Signed-off-by: Sumner Evans --- crypto/backup/ephemeralkey.go | 35 ++++++++++++++++++++ crypto/backup/ephemeralkey_test.go | 51 ++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+) create mode 100644 crypto/backup/ephemeralkey.go create mode 100644 crypto/backup/ephemeralkey_test.go diff --git a/crypto/backup/ephemeralkey.go b/crypto/backup/ephemeralkey.go new file mode 100644 index 00000000..d0ee03a6 --- /dev/null +++ b/crypto/backup/ephemeralkey.go @@ -0,0 +1,35 @@ +package backup + +import ( + "crypto/ecdh" + "encoding/base64" + "encoding/json" +) + +// EphemeralKey is a wrapper around an ECDH X25519 public key that implements +// JSON marshalling and unmarshalling. +type EphemeralKey struct { + *ecdh.PublicKey +} + +func (k *EphemeralKey) MarshalJSON() ([]byte, error) { + if k == nil || k.PublicKey == nil { + return json.Marshal(nil) + } + return json.Marshal(base64.RawStdEncoding.EncodeToString(k.Bytes())) +} + +func (k *EphemeralKey) UnmarshalJSON(data []byte) error { + var keyStr string + err := json.Unmarshal(data, &keyStr) + if err != nil { + return err + } + + keyBytes, err := base64.RawStdEncoding.DecodeString(keyStr) + if err != nil { + return err + } + k.PublicKey, err = ecdh.X25519().NewPublicKey(keyBytes) + return err +} diff --git a/crypto/backup/ephemeralkey_test.go b/crypto/backup/ephemeralkey_test.go new file mode 100644 index 00000000..93d24563 --- /dev/null +++ b/crypto/backup/ephemeralkey_test.go @@ -0,0 +1,51 @@ +package backup_test + +import ( + "crypto/ecdh" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/crypto/backup" +) + +type testStruct struct { + EphemeralKey *backup.EphemeralKey `json:"ephemeral"` +} + +func TestEphemeralKey_UnmarshalJSON(t *testing.T) { + testCases := []string{ + "o43y/Mck1DExWdHr0+qbPJbjzO97+RH1mw6phLhYQj0", + } + + testJSONTemplate := `{"ephemeral": "%s"}` + + for _, tc := range testCases { + t.Run(tc, func(t *testing.T) { + var test testStruct + jsonInput := fmt.Sprintf(testJSONTemplate, tc) + err := json.Unmarshal([]byte(jsonInput), &test) + require.NoError(t, err) + expected, err := base64.RawStdEncoding.DecodeString(tc) + require.NoError(t, err) + assert.Equal(t, expected, test.EphemeralKey.Bytes()) + }) + } +} + +func TestEphemeralKey_MarshallJSON(t *testing.T) { + key, err := ecdh.X25519().GenerateKey(rand.Reader) + require.NoError(t, err) + + test := &backup.EphemeralKey{key.PublicKey()} + marshalled, err := json.Marshal(test) + require.NoError(t, err) + assert.EqualValues(t, '"', marshalled[0]) + assert.Len(t, marshalled, 45) + assert.EqualValues(t, '"', marshalled[44]) +} From 6681e40debab6b6053a06f62a1adeed7f310d201 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 12 Jan 2024 10:06:44 -0700 Subject: [PATCH 0079/1647] crypto/backup: add EncryptedSessionData struct and encrypt/decrypt methods Signed-off-by: Sumner Evans --- crypto/backup/encryptedsessiondata.go | 149 +++++++++++++++++++++ crypto/backup/encryptedsessiondata_test.go | 102 ++++++++++++++ crypto/backup/megolmbackupkey.go | 28 ++++ 3 files changed, 279 insertions(+) create mode 100644 crypto/backup/encryptedsessiondata.go create mode 100644 crypto/backup/encryptedsessiondata_test.go create mode 100644 crypto/backup/megolmbackupkey.go diff --git a/crypto/backup/encryptedsessiondata.go b/crypto/backup/encryptedsessiondata.go new file mode 100644 index 00000000..ccaea0c4 --- /dev/null +++ b/crypto/backup/encryptedsessiondata.go @@ -0,0 +1,149 @@ +package backup + +import ( + "bytes" + "crypto/ecdh" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + + "golang.org/x/crypto/hkdf" + + "maunium.net/go/mautrix/crypto/aescbc" +) + +var ErrInvalidMAC = errors.New("invalid MAC") + +// UnpaddedBytes is a byte slice that is encoded and decoded using +// [base64.RawStdEncoding]. +type UnpaddedBytes []byte + +func (b UnpaddedBytes) MarshalJSON() ([]byte, error) { + return json.Marshal(base64.RawStdEncoding.EncodeToString(b)) +} + +func (b *UnpaddedBytes) UnmarshalJSON(data []byte) error { + var b64str string + err := json.Unmarshal(data, &b64str) + if err != nil { + return err + } + *b, err = base64.RawStdEncoding.DecodeString(b64str) + return err +} + +// EncryptedSessionData is the encrypted session_data field of a key backup as +// defined in [Section 11.12.3.2.2 of the Spec]. +// +// The type parameter T represents the format of the session data contained in +// the encrypted payload. +// +// [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 +type EncryptedSessionData[T any] struct { + Ciphertext UnpaddedBytes `json:"ciphertext"` + Ephemeral EphemeralKey `json:"ephemeral"` + MAC UnpaddedBytes `json:"mac"` +} + +func calculateEncryptionParameters(sharedSecret []byte) (key, macKey, iv []byte, err error) { + hkdfReader := hkdf.New(sha256.New, sharedSecret, nil, nil) + encryptionParams := make([]byte, 80) + _, err = hkdfReader.Read(encryptionParams) + if err != nil { + return nil, nil, nil, err + } + + return encryptionParams[:32], encryptionParams[32:64], encryptionParams[64:], nil +} + +// calculateCompatMAC calculates the MAC for compatibility with Olm and +// Vodozemac which do not actually write the ciphertext when computing the MAC. +// +// Deprecated: Use [calculateMAC] instead. +func calculateCompatMAC(macKey []byte) []byte { + hash := hmac.New(sha256.New, macKey) + return hash.Sum(nil)[:8] +} + +// calculateMAC calculates the MAC as described in step 5 of according to +// [Section 11.12.3.2.2] of the Spec. +// +// [Section 11.12.3.2.2]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 +func calculateMAC(macKey, ciphertext []byte) []byte { + hash := hmac.New(sha256.New, macKey) + _, err := hash.Write(ciphertext) + if err != nil { + panic(err) + } + return hash.Sum(nil)[:8] +} + +// EncryptSessionData encrypts the given session data with the given recovery +// key as defined in [Section 11.12.3.2.2 of the Spec]. +// +// [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 +func EncryptSessionData[T any](backupKey *MegolmBackupKey, sessionData T) (*EncryptedSessionData[T], error) { + sessionJSON, err := json.Marshal(sessionData) + if err != nil { + return nil, err + } + + ephemeralKey, err := ecdh.X25519().GenerateKey(rand.Reader) + if err != nil { + return nil, err + } + + sharedSecret, err := ephemeralKey.ECDH(backupKey.PublicKey()) + if err != nil { + return nil, err + } + + key, macKey, iv, err := calculateEncryptionParameters(sharedSecret) + if err != nil { + return nil, err + } + + ciphertext, err := aescbc.Encrypt(key, iv, sessionJSON) + if err != nil { + return nil, err + } + + return &EncryptedSessionData[T]{ + Ciphertext: ciphertext, + Ephemeral: EphemeralKey{ephemeralKey.PublicKey()}, + MAC: calculateCompatMAC(macKey), + }, nil +} + +// Decrypt decrypts the [EncryptedSessionData] into a *T using the recovery key +// by reversing the process described in [Section 11.12.3.2.2 of the Spec]. +// +// [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 +func (esd *EncryptedSessionData[T]) Decrypt(backupKey *MegolmBackupKey) (*T, error) { + sharedSecret, err := backupKey.ECDH(esd.Ephemeral.PublicKey) + if err != nil { + return nil, err + } + + key, macKey, iv, err := calculateEncryptionParameters(sharedSecret) + if err != nil { + return nil, err + } + + // Verify the MAC before decrypting. + if !bytes.Equal(calculateCompatMAC(macKey), esd.MAC) { + return nil, ErrInvalidMAC + } + + plaintext, err := aescbc.Decrypt(key, iv, esd.Ciphertext) + if err != nil { + return nil, err + } + + var sessionData T + err = json.Unmarshal(plaintext, &sessionData) + return &sessionData, err +} diff --git a/crypto/backup/encryptedsessiondata_test.go b/crypto/backup/encryptedsessiondata_test.go new file mode 100644 index 00000000..8aab1390 --- /dev/null +++ b/crypto/backup/encryptedsessiondata_test.go @@ -0,0 +1,102 @@ +package backup_test + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/crypto/backup" + "maunium.net/go/mautrix/id" +) + +func TestEncryptedSessionData_Decrypt(t *testing.T) { + testCases := []struct { + encryptedJSON []byte + expectedJSON string + }{ + { + []byte(` + { + "ciphertext": "hDCjEbyi2uMXt3RBWe9mRdeqhcoraPR84/cq5ll16LIIIICJ8ZLmiWG5IwmGqDFmd3Jw20cNo49b38LH3oBJUl5DG44VdjoI4nlgAzaMSLwMZ7JFGt0Enu1Csfgpvgt1qksTP6QB7YDwITD33iL7ucco1iOl7ABGzhyjCi2iZ3A6Xmx3RsAmHhmU5gJWE6/lIoI6/lh7dZFSfp4RTGfxQ8ToCCIsrgdx1weViv4I4ArXfcrdnaprPzP4cH77Ej1Wg1/bUHtB4C8nOiX+cYnOG29NbTHbtQF14zJpA+2XM2JngiLkss+NQj96PQzgPNhAMEFOLLy5ckY1WvS4sMMeCVzAyt5dwEGDcyxLTC4oJ/RrvLcHCHW0aOygPSlNoMRyDgC0f92+mPQGAmFv4GhfDFXfaauBxBdRAPjXj7Onn2B4UdfwQXGLT3RAihba8i9usOX5hLxqQqvtA3SUuV8hPrzHhpPEeRvx+PgZsXwV+gM7Aw3Mza6hwmILdngJh7NNQTINsCRqff9Ck3Kh7aSOoHsHvz7Ot+T514ObDwWYYCBMmS/6EG4XjSya6R98ggRWGrO9l21YYUvzBTv7OLtMck0Za3151Zqi/5LRKP95QIU", + "ephemeral": "o43y/Mck1DExWdHr0+qbPJbjzO97+RH1mw6phLhYQj0", + "mac": "Mnt8eXwFfjw" + } + `), + ` + { + "algorithm": "m.megolm.v1.aes-sha2", + "sender_key": "JUUfV6vErSATm3rIOU9DML+IX1SlYxnAAS824xhbhC4", + "session_key": "AQAAAABc1O9JP2/HXS22iLN1uScFv2UyL33/L3L0sysPKcovQFI0lwKTuutrVeww2SNOU9b2J62kV/QXEw7+N2I9klrvqqr9kdo1ywqFtZOnp8DlgR2+OhOnUYmj5YmJhmApPle9xnVVwZv57Q0REsmSAovHBLH4Kf3GEHPJ9WXEEnLINT9Gzit9qjIZ1fKKacLtvsZ+hbnTPvP5Df3ENalB+03E", + "sender_claimed_keys": {"ed25519":"R2UJWSfgGr64iPENthl/98WGqBtnNlYuP12d6TEuGo4"}, + "forwarding_curve25519_key_chain": [] + } + `, + }, + { + []byte(` + { + "ciphertext": "vdLkqNTijkM1L7HmbxdZs1EHygC7GFG0wPTAaLqpOCoir3K6tNYbjIJs36vzrwawdmfPxZvA9p/k3bZIhZDP7IivGYe69+4pWiIzrwYkHCidigKXkYD8KxKWvakBquO9vWUssXC05xdkQjHMNJK3zSJgtkbMhoY28i1VUdmIjts4xU0cIT40F52Uyx3iu1UrqywUREEE5vhoSbeWxW3Vo5lqPi6rnyvMGZhVzAOv6re2O7wPWnSp0YJUsPaEj6Q9QpLr8BB9vJ++3kwmP5vxfjJLUsXuNEHWIKP5QyhpmGCgwjNpjnU6VhCqBzqs2M/KKX8zxZMGTIRidc3gx2i8KtDwRHRzh3FsSJEaC0sfCfGijpH5g9Pa+2P6b1GxvGQ4TF5X6ayLiV6FyNilpZ4z3kYsy63fP06uinHkX0TUClMQgLLmn0BAiOxKWtLNSLxgFdSYFm5oU/rpOBXWQKbzQ3cvlJZxBtxnaAhJnt3+t/3pJahlKAOxrQbKZAPL/KbO4nF9dsHpMkfMs25pVLDoHLKEXSBhagEFDbPKL5Uv55kca1C1XGrx+8fYUDBRQtYSLBSbAtF3UMv+hIMdRnmyQntwOy2hKRRs2UxnIlExk0Q", + "ephemeral": "24PxRUfQDyYNZcTq0HT8pS3Gq+zkfsAcXHFJ3nZ56W4", + "mac": "T7xq9qHm4Js" + } + `), + ` + { + "algorithm": "m.megolm.v1.aes-sha2", + "sender_key": "JUUfV6vErSATm3rIOU9DML+IX1SlYxnAAS824xhbhC4", + "session_key": "AQAAAAB6cP1PrdPeIG/B0ZRHNUc65ujvIzOxKhW1HN25efyZFaq9xsLvCngm4WO56gEuUhS16E4m0pAa9B/KyRz3AnSOVcHYh1bYxm9qf6zU5PFm255n6FR2lGN0vrgUM7Xu2GNUDCWoNI4m4QsiBor9eCj2ZJRay75dZ4nkhNf3GxBKOkhzPCreKabLxVsseGGIkq8rf01b0CWIcp5ISQISLdza", + "sender_claimed_keys": {"ed25519":"R2UJWSfgGr64iPENthl/98WGqBtnNlYuP12d6TEuGo4"}, + "forwarding_curve25519_key_chain": [] + } + `, + }, + } + + keyBytes, err := base64.RawStdEncoding.DecodeString("ReSMMZeRtDSdrwXzu2OvN0B73KUXkYPt3kaYfFIkw10") + require.NoError(t, err) + backupKey, err := backup.MegolmBackupKeyFromBytes(keyBytes) + require.NoError(t, err) + + for i, tc := range testCases { + t.Run(fmt.Sprintf("test case %d", i+1), func(t *testing.T) { + var esd backup.EncryptedSessionData[backup.MegolmSessionData] + err := json.Unmarshal([]byte(tc.encryptedJSON), &esd) + assert.NoError(t, err) + + sessionData, err := esd.Decrypt(backupKey) + require.NoError(t, err) + + sessionDataJSON, err := json.Marshal(sessionData) + require.NoError(t, err) + assert.JSONEq(t, string(tc.expectedJSON), string(sessionDataJSON)) + }) + } +} + +func TestEncryptedSessionData_Roundtrip(t *testing.T) { + backupKey, err := backup.NewMegolmBackupKey() + require.NoError(t, err) + + sessionData := backup.MegolmSessionData{ + Algorithm: id.AlgorithmMegolmV1, + } + + encrypted, err := backup.EncryptSessionData(backupKey, sessionData) + require.NoError(t, err) + + encryptedJSON, err := json.Marshal(encrypted) + require.NoError(t, err) + + var roundTrippedEncryptedSessionData backup.EncryptedSessionData[backup.MegolmSessionData] + err = json.Unmarshal(encryptedJSON, &roundTrippedEncryptedSessionData) + require.NoError(t, err) + + decrypted, err := roundTrippedEncryptedSessionData.Decrypt(backupKey) + require.NoError(t, err) + + assert.Equal(t, id.AlgorithmMegolmV1, decrypted.Algorithm) +} diff --git a/crypto/backup/megolmbackupkey.go b/crypto/backup/megolmbackupkey.go new file mode 100644 index 00000000..8a57b4cf --- /dev/null +++ b/crypto/backup/megolmbackupkey.go @@ -0,0 +1,28 @@ +package backup + +import ( + "crypto/ecdh" + "crypto/rand" +) + +// MegolmBackupKey is a wrapper around an ECDH X25519 private key that is used +// to decrypt a megolm key backup. +type MegolmBackupKey struct { + *ecdh.PrivateKey +} + +func NewMegolmBackupKey() (*MegolmBackupKey, error) { + key, err := ecdh.X25519().GenerateKey(rand.Reader) + if err != nil { + return nil, err + } + return &MegolmBackupKey{key}, nil +} + +func MegolmBackupKeyFromBytes(bytes []byte) (*MegolmBackupKey, error) { + key, err := ecdh.X25519().NewPrivateKey(bytes) + if err != nil { + return nil, err + } + return &MegolmBackupKey{key}, nil +} From 385449c9cc062e6d8b5525ea6b72e4b7c5313900 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 11 Jan 2024 20:27:01 -0700 Subject: [PATCH 0080/1647] responses: make key backup structs generic over {session,auth}_data Also update the Client to specify the types Signed-off-by: Sumner Evans --- client.go | 15 ++++++++++----- responses.go | 22 +++++++++++----------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/client.go b/client.go index 3fabb7af..18f2c019 100644 --- a/client.go +++ b/client.go @@ -21,6 +21,7 @@ import ( "go.mau.fi/util/retryafter" "maunium.net/go/maulogger/v2/maulogadapt" + "maunium.net/go/mautrix/crypto/backup" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/pushrules" @@ -1947,7 +1948,7 @@ func (cli *Client) GetKeyChanges(ctx context.Context, from, to string) (resp *Re // GetKeyBackup retrieves the keys from the backup. // // See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeys -func (cli *Client) GetKeyBackup(ctx context.Context, version string) (resp *RespRoomKeys, err error) { +func (cli *Client) GetKeyBackup(ctx context.Context, version string) (resp *RespRoomKeys[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{ "version": version, }) @@ -1980,7 +1981,9 @@ func (cli *Client) DeleteKeyBackup(ctx context.Context, version string) (resp *R // GetKeyBackupForRoom retrieves the keys from the backup for the given room. // // See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeysroomid -func (cli *Client) GetKeyBackupForRoom(ctx context.Context, version string, roomID id.RoomID) (resp *RespRoomKeyBackup, err error) { +func (cli *Client) GetKeyBackupForRoom( + ctx context.Context, version string, roomID id.RoomID, +) (resp *RespRoomKeyBackup[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{ "version": version, }) @@ -2014,7 +2017,9 @@ func (cli *Client) DeleteKeysFromBackupForRoom(ctx context.Context, version stri // GetKeyBackupForRoomAndSession retrieves a key from the backup. // // See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid -func (cli *Client) GetKeyBackupForRoomAndSession(ctx context.Context, version string, roomID id.RoomID, sessionID id.SessionID) (resp *RespKeyBackupData, err error) { +func (cli *Client) GetKeyBackupForRoomAndSession( + ctx context.Context, version string, roomID id.RoomID, sessionID id.SessionID, +) (resp *RespKeyBackupData[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{ "version": version, }) @@ -2047,7 +2052,7 @@ func (cli *Client) DeleteKeysInBackupForRoomAndSession(ctx context.Context, vers // GetKeyBackupLatestVersion returns information about the latest backup version. // // See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keysversion -func (cli *Client) GetKeyBackupLatestVersion(ctx context.Context) (resp *RespRoomKeysVersion, err error) { +func (cli *Client) GetKeyBackupLatestVersion(ctx context.Context) (resp *RespRoomKeysVersion[backup.MegolmAuthData], err error) { urlPath := cli.BuildClientURL("v3", "room_keys", "version") _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return @@ -2065,7 +2070,7 @@ func (cli *Client) CreateKeyBackupVersion(ctx context.Context, req *ReqRoomKeysV // GetKeyBackupVersion returns information about an existing key backup. // // See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keysversionversion -func (cli *Client) GetKeyBackupVersion(ctx context.Context, version string) (resp *RespRoomKeysVersion, err error) { +func (cli *Client) GetKeyBackupVersion(ctx context.Context, version string) (resp *RespRoomKeysVersion[backup.MegolmAuthData], err error) { urlPath := cli.BuildClientURL("v3", "room_keys", "version", version) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return diff --git a/responses.go b/responses.go index 4599fb19..b8552b58 100644 --- a/responses.go +++ b/responses.go @@ -596,27 +596,27 @@ type RespRoomKeysVersionCreate struct { Version string `json:"version"` } -type RespRoomKeysVersion struct { +type RespRoomKeysVersion[A any] struct { Algorithm id.KeyBackupAlgorithm `json:"algorithm"` - AuthData json.RawMessage `json:"auth_data"` + AuthData A `json:"auth_data"` Count int `json:"count"` ETag string `json:"etag"` Version string `json:"version"` } -type RespRoomKeys struct { - Rooms map[id.RoomID]RespRoomKeyBackup `json:"rooms"` +type RespRoomKeys[S any] struct { + Rooms map[id.RoomID]RespRoomKeyBackup[S] `json:"rooms"` } -type RespRoomKeyBackup struct { - Sessions map[id.SessionID]RespKeyBackupData `json:"sessions"` +type RespRoomKeyBackup[S any] struct { + Sessions map[id.SessionID]RespKeyBackupData[S] `json:"sessions"` } -type RespKeyBackupData struct { - FirstMessageIndex int `json:"first_message_index"` - ForwardedCount int `json:"forwarded_count"` - IsVerified bool `json:"is_verified"` - SessionData json.RawMessage `json:"session_data"` +type RespKeyBackupData[S any] struct { + FirstMessageIndex int `json:"first_message_index"` + ForwardedCount int `json:"forwarded_count"` + IsVerified bool `json:"is_verified"` + SessionData S `json:"session_data"` } type RespRoomKeysUpdate struct { From 950ad6bc7e2403aae73d24b14461cf97e025fc63 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 15 Jan 2024 19:16:26 -0700 Subject: [PATCH 0081/1647] crypto/ssss: use errors.Is Signed-off-by: Sumner Evans --- crypto/ssss/client.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crypto/ssss/client.go b/crypto/ssss/client.go index 2dac30e1..0cfdd24f 100644 --- a/crypto/ssss/client.go +++ b/crypto/ssss/client.go @@ -8,6 +8,7 @@ package ssss import ( "context" + "errors" "fmt" "maunium.net/go/mautrix" @@ -34,7 +35,7 @@ func (mach *Machine) GetDefaultKeyID(ctx context.Context) (string, error) { var data DefaultSecretStorageKeyContent err := mach.Client.GetAccountData(ctx, event.AccountDataSecretStorageDefaultKey.Type, &data) if err != nil { - if httpErr, ok := err.(mautrix.HTTPError); ok && httpErr.RespError != nil && httpErr.RespError.ErrCode == "M_NOT_FOUND" { + if httpErr, ok := err.(mautrix.HTTPError); ok && errors.Is(httpErr.RespError, mautrix.MNotFound) { return "", ErrNoDefaultKeyAccountDataEvent } return "", fmt.Errorf("failed to get default key account data from server: %w", err) From c77d6fe1f9554284b166fa55f77c14d171dfe068 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 15 Jan 2024 22:28:11 -0700 Subject: [PATCH 0082/1647] event/type: add megolm backup key account data type Signed-off-by: Sumner Evans --- event/type.go | 1 + 1 file changed, 1 insertion(+) diff --git a/event/type.go b/event/type.go index 2f4f4f94..c741e8ed 100644 --- a/event/type.go +++ b/event/type.go @@ -238,6 +238,7 @@ var ( AccountDataCrossSigningMaster = Type{"m.cross_signing.master", AccountDataEventType} AccountDataCrossSigningUser = Type{"m.cross_signing.user_signing", AccountDataEventType} AccountDataCrossSigningSelf = Type{"m.cross_signing.self_signing", AccountDataEventType} + AccountDataMegolmBackupKey = Type{"m.megolm_backup.v1", AccountDataEventType} ) // Device-to-device events From 65be59bfed6e83d04611d3cb73c8968bd95fba33 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 16 Jan 2024 11:40:46 -0700 Subject: [PATCH 0083/1647] crypto: refactor to remove need for Utility struct This also removes all dependence on libolm for the functions that were provided by the Utility struct. The crypto/signatures package should be used for all signature verification operations, and for the occasional situation where a base64-encoded SHA-256 hash is required, the olm.SHA256B64 function should be used. Signed-off-by: Sumner Evans --- crypto/cross_sign_store.go | 4 +- crypto/devicelist.go | 6 +- crypto/encryptolm.go | 4 +- crypto/goolm/account/account.go | 5 +- crypto/goolm/account/account_test.go | 29 +++--- crypto/goolm/utilities/main.go | 23 ----- crypto/olm/sha256b64.go | 15 +++ crypto/olm/utility.go | 146 --------------------------- crypto/olm/utility_goolm.go | 92 ----------------- crypto/signatures/signatures.go | 73 +++++++++++++- crypto/verification.go | 4 +- crypto/verification_in_room.go | 2 +- 12 files changed, 114 insertions(+), 289 deletions(-) delete mode 100644 crypto/goolm/utilities/main.go create mode 100644 crypto/olm/sha256b64.go delete mode 100644 crypto/olm/utility.go delete mode 100644 crypto/olm/utility_goolm.go diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go index 88fcd0ed..28d0bad0 100644 --- a/crypto/cross_sign_store.go +++ b/crypto/cross_sign_store.go @@ -11,7 +11,7 @@ import ( "context" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/id" ) @@ -80,7 +80,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK } log.Debug().Msg("Verifying cross-signing key signature") - if verified, err := olm.VerifySignatureJSON(userKeys, signUserID, signKeyName, signingKey); err != nil { + if verified, err := signatures.VerifySignatureJSON(userKeys, signUserID, signKeyName, signingKey); err != nil { log.Warn().Err(err).Msg("Error verifying cross-signing key signature") } else { if verified { diff --git a/crypto/devicelist.go b/crypto/devicelist.go index f5c07cd3..16c4164e 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -14,7 +14,7 @@ import ( "github.com/rs/zerolog" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/id" ) @@ -52,7 +52,7 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id } else if _, ok := selfSigs[id.NewKeyID(id.KeyAlgorithmEd25519, pubKey.String())]; !ok { continue } - if verified, err := olm.VerifySignatureJSON(deviceKeys, signerUserID, pubKey.String(), pubKey); verified { + if verified, err := signatures.VerifySignatureJSON(deviceKeys, signerUserID, pubKey.String(), pubKey); verified { if signKey, ok := deviceKeys.Keys[id.DeviceKeyID(signerKey)]; ok { signature := deviceKeys.Signatures[signerUserID][id.NewKeyID(id.KeyAlgorithmEd25519, pubKey.String())] log.Trace().Err(err). @@ -245,7 +245,7 @@ func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, d return existing, fmt.Errorf("%w (expected %s, got %s)", MismatchingSigningKey, existing.SigningKey, signingKey) } - ok, err := olm.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey) + ok, err := signatures.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey) if err != nil { return existing, fmt.Errorf("failed to verify signature: %w", err) } else if !ok { diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go index 3b1d40d3..15e9df29 100644 --- a/crypto/encryptolm.go +++ b/crypto/encryptolm.go @@ -12,7 +12,7 @@ import ( "fmt" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -109,7 +109,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id continue } identity := input[userID][deviceID] - if ok, err := olm.VerifySignatureJSON(oneTimeKey.RawData, userID, deviceID.String(), identity.SigningKey); err != nil { + if ok, err := signatures.VerifySignatureJSON(oneTimeKey.RawData, userID, deviceID.String(), identity.SigningKey); err != nil { log.Error().Err(err).Msg("Failed to verify signature of one-time key") } else if !ok { log.Warn().Msg("One-time key has invalid signature from device") diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index 7896f849..4057543a 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -110,12 +110,13 @@ func (a Account) IdentityKeys() (id.Ed25519, id.Curve25519) { return ed25519, curve25519 } -// Sign returns the signature of a message using the Ed25519 key for this Account. +// Sign returns the base64-encoded signature of a message using the Ed25519 key +// for this Account. func (a Account) Sign(message []byte) ([]byte, error) { if len(message) == 0 { return nil, fmt.Errorf("sign: %w", goolm.ErrEmptyInput) } - return goolm.Base64Encode(a.IdKeys.Ed25519.Sign(message)), nil + return []byte(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.Sign(message))), nil } // OneTimeKeys returns the public parts of the unpublished one time keys of the Account. diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go index a18840b1..943d8570 100644 --- a/crypto/goolm/account/account_test.go +++ b/crypto/goolm/account/account_test.go @@ -2,14 +2,18 @@ package account_test import ( "bytes" + "encoding/base64" "errors" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/account" - "maunium.net/go/mautrix/crypto/goolm/utilities" + "maunium.net/go/mautrix/crypto/signatures" ) func TestAccount(t *testing.T) { @@ -599,19 +603,14 @@ func TestOldV3AccountPickle(t *testing.T) { func TestAccountSign(t *testing.T) { accountA, err := account.NewAccount(nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) plainText := []byte("Hello, World") - signature, err := accountA.Sign(plainText) - if err != nil { - t.Fatal(err) - } - verified, err := utilities.VerifySignature(plainText, accountA.IdKeys.Ed25519.B64Encoded(), signature) - if err != nil { - t.Fatal(err) - } - if !verified { - t.Fatal("signature did not verify") - } + signatureB64, err := accountA.Sign(plainText) + require.NoError(t, err) + signature, err := base64.RawStdEncoding.DecodeString(string(signatureB64)) + require.NoError(t, err) + + verified, err := signatures.VerifySignature(plainText, accountA.IdKeys.Ed25519.B64Encoded(), signature) + assert.NoError(t, err) + assert.True(t, verified) } diff --git a/crypto/goolm/utilities/main.go b/crypto/goolm/utilities/main.go deleted file mode 100644 index c5b5c2d5..00000000 --- a/crypto/goolm/utilities/main.go +++ /dev/null @@ -1,23 +0,0 @@ -package utilities - -import ( - "encoding/base64" - - "maunium.net/go/mautrix/crypto/goolm" - "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/id" -) - -// VerifySignature verifies an ed25519 signature. -func VerifySignature(message []byte, key id.Ed25519, signature []byte) (ok bool, err error) { - keyDecoded, err := base64.RawStdEncoding.DecodeString(string(key)) - if err != nil { - return false, err - } - signatureDecoded, err := goolm.Base64Decode(signature) - if err != nil { - return false, err - } - publicKey := crypto.Ed25519PublicKey(keyDecoded) - return publicKey.Verify(message, signatureDecoded), nil -} diff --git a/crypto/olm/sha256b64.go b/crypto/olm/sha256b64.go new file mode 100644 index 00000000..711c9454 --- /dev/null +++ b/crypto/olm/sha256b64.go @@ -0,0 +1,15 @@ +package olm + +import ( + "crypto/sha256" + "encoding/base64" +) + +// SHA256B64 calculates the SHA-256 hash of the input and encodes it as base64. +func SHA256B64(input []byte) string { + if len(input) == 0 { + panic(EmptyInput) + } + hash := sha256.Sum256([]byte(input)) + return base64.RawStdEncoding.EncodeToString(hash[:]) +} diff --git a/crypto/olm/utility.go b/crypto/olm/utility.go deleted file mode 100644 index 87055fb3..00000000 --- a/crypto/olm/utility.go +++ /dev/null @@ -1,146 +0,0 @@ -//go:build !goolm - -package olm - -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -import "C" - -import ( - "encoding/json" - "fmt" - "unsafe" - - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - "go.mau.fi/util/exgjson" - - "maunium.net/go/mautrix/crypto/canonicaljson" - "maunium.net/go/mautrix/id" -) - -// Utility stores the necessary state to perform hash and signature -// verification operations. -type Utility struct { - int *C.OlmUtility - mem []byte -} - -// utilitySize returns the size of a utility object in bytes. -func utilitySize() uint { - return uint(C.olm_utility_size()) -} - -// sha256Len returns the length of the buffer needed to hold the SHA-256 hash. -func (u *Utility) sha256Len() uint { - return uint(C.olm_sha256_length((*C.OlmUtility)(u.int))) -} - -// lastError returns an error describing the most recent error to happen to a -// utility. -func (u *Utility) lastError() error { - return convertError(C.GoString(C.olm_utility_last_error((*C.OlmUtility)(u.int)))) -} - -// Clear clears the memory used to back this utility. -func (u *Utility) Clear() error { - r := C.olm_clear_utility((*C.OlmUtility)(u.int)) - if r == errorVal() { - return u.lastError() - } - return nil -} - -// NewUtility creates a new utility. -func NewUtility() *Utility { - memory := make([]byte, utilitySize()) - return &Utility{ - int: C.olm_utility(unsafe.Pointer(&memory[0])), - mem: memory, - } -} - -// Sha256 calculates the SHA-256 hash of the input and encodes it as base64. -func (u *Utility) Sha256(input string) string { - if len(input) == 0 { - panic(EmptyInput) - } - output := make([]byte, u.sha256Len()) - r := C.olm_sha256( - (*C.OlmUtility)(u.int), - unsafe.Pointer(&([]byte(input)[0])), - C.size_t(len(input)), - unsafe.Pointer(&(output[0])), - C.size_t(len(output))) - if r == errorVal() { - panic(u.lastError()) - } - return string(output) -} - -// VerifySignature verifies an ed25519 signature. Returns true if the verification -// suceeds or false otherwise. Returns error on failure. If the key was too -// small then the error will be "INVALID_BASE64". -func (u *Utility) VerifySignature(message string, key id.Ed25519, signature string) (ok bool, err error) { - if len(message) == 0 || len(key) == 0 || len(signature) == 0 { - return false, EmptyInput - } - r := C.olm_ed25519_verify( - (*C.OlmUtility)(u.int), - unsafe.Pointer(&([]byte(key)[0])), - C.size_t(len(key)), - unsafe.Pointer(&([]byte(message)[0])), - C.size_t(len(message)), - unsafe.Pointer(&([]byte(signature)[0])), - C.size_t(len(signature))) - if r == errorVal() { - err = u.lastError() - if err == BadMessageMAC { - err = nil - } - } else { - ok = true - } - return ok, err -} - -// VerifySignatureJSON verifies the signature in the JSON object _obj following -// the Matrix specification: -// https://matrix.org/speculator/spec/drafts%2Fe2e/appendices.html#signing-json -// If the _obj is a struct, the `json` tags will be honored. -func (u *Utility) VerifySignatureJSON(obj interface{}, userID id.UserID, keyName string, key id.Ed25519) (bool, error) { - var err error - objJSON, ok := obj.(json.RawMessage) - if !ok { - objJSON, err = json.Marshal(obj) - if err != nil { - return false, err - } - } - sig := gjson.GetBytes(objJSON, exgjson.Path("signatures", string(userID), fmt.Sprintf("ed25519:%s", keyName))) - if !sig.Exists() || sig.Type != gjson.String { - return false, SignatureNotFound - } - objJSON, err = sjson.DeleteBytes(objJSON, "unsigned") - if err != nil { - return false, err - } - objJSON, err = sjson.DeleteBytes(objJSON, "signatures") - if err != nil { - return false, err - } - objJSONString := string(canonicaljson.CanonicalJSONAssumeValid(objJSON)) - return u.VerifySignature(objJSONString, key, sig.Str) -} - -// VerifySignatureJSON verifies the signature in the JSON object _obj following -// the Matrix specification: -// https://matrix.org/speculator/spec/drafts%2Fe2e/appendices.html#signing-json -// This function is a wrapper over Utility.VerifySignatureJSON that creates and -// destroys the Utility object transparently. -// If the _obj is a struct, the `json` tags will be honored. -func VerifySignatureJSON(obj interface{}, userID id.UserID, keyName string, key id.Ed25519) (bool, error) { - u := NewUtility() - defer u.Clear() - return u.VerifySignatureJSON(obj, userID, keyName, key) -} diff --git a/crypto/olm/utility_goolm.go b/crypto/olm/utility_goolm.go deleted file mode 100644 index 926b5404..00000000 --- a/crypto/olm/utility_goolm.go +++ /dev/null @@ -1,92 +0,0 @@ -//go:build goolm - -package olm - -import ( - "crypto/sha256" - "encoding/base64" - "encoding/json" - "fmt" - - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - "go.mau.fi/util/exgjson" - - "maunium.net/go/mautrix/crypto/canonicaljson" - "maunium.net/go/mautrix/crypto/goolm/utilities" - "maunium.net/go/mautrix/id" -) - -// Utility stores the necessary state to perform hash and signature -// verification operations. -type Utility struct{} - -// Clear clears the memory used to back this utility. -func (u *Utility) Clear() error { - return nil -} - -// NewUtility creates a new utility. -func NewUtility() *Utility { - return &Utility{} -} - -// Sha256 calculates the SHA-256 hash of the input and encodes it as base64. -func (u *Utility) Sha256(input string) string { - if len(input) == 0 { - panic(EmptyInput) - } - hash := sha256.Sum256([]byte(input)) - return base64.RawStdEncoding.EncodeToString(hash[:]) -} - -// VerifySignature verifies an ed25519 signature. Returns true if the verification -// suceeds or false otherwise. Returns error on failure. If the key was too -// small then the error will be "INVALID_BASE64". -func (u *Utility) VerifySignature(message string, key id.Ed25519, signature string) (ok bool, err error) { - if len(message) == 0 || len(key) == 0 || len(signature) == 0 { - return false, EmptyInput - } - return utilities.VerifySignature([]byte(message), key, []byte(signature)) -} - -// VerifySignatureJSON verifies the signature in the JSON object _obj following -// the Matrix specification: -// https://matrix.org/speculator/spec/drafts%2Fe2e/appendices.html#signing-json -// If the _obj is a struct, the `json` tags will be honored. -func (u *Utility) VerifySignatureJSON(obj interface{}, userID id.UserID, keyName string, key id.Ed25519) (bool, error) { - var err error - objJSON, ok := obj.(json.RawMessage) - if !ok { - objJSON, err = json.Marshal(obj) - if err != nil { - return false, err - } - } - sig := gjson.GetBytes(objJSON, exgjson.Path("signatures", string(userID), fmt.Sprintf("ed25519:%s", keyName))) - if !sig.Exists() || sig.Type != gjson.String { - return false, SignatureNotFound - } - objJSON, err = sjson.DeleteBytes(objJSON, "unsigned") - if err != nil { - return false, err - } - objJSON, err = sjson.DeleteBytes(objJSON, "signatures") - if err != nil { - return false, err - } - objJSONString := string(canonicaljson.CanonicalJSONAssumeValid(objJSON)) - return u.VerifySignature(objJSONString, key, sig.Str) -} - -// VerifySignatureJSON verifies the signature in the JSON object _obj following -// the Matrix specification: -// https://matrix.org/speculator/spec/drafts%2Fe2e/appendices.html#signing-json -// This function is a wrapper over Utility.VerifySignatureJSON that creates and -// destroys the Utility object transparently. -// If the _obj is a struct, the `json` tags will be honored. -func VerifySignatureJSON(obj interface{}, userID id.UserID, keyName string, key id.Ed25519) (bool, error) { - u := NewUtility() - defer u.Clear() - return u.VerifySignatureJSON(obj, userID, keyName, key) -} diff --git a/crypto/signatures/signatures.go b/crypto/signatures/signatures.go index dae34ef3..7ad19316 100644 --- a/crypto/signatures/signatures.go +++ b/crypto/signatures/signatures.go @@ -1,6 +1,24 @@ package signatures -import "maunium.net/go/mautrix/id" +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.mau.fi/util/exgjson" + + "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/id" +) + +var ( + ErrEmptyInput = errors.New("empty input") + ErrSignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") +) // Signatures represents a set of signatures for some data from multiple users // and keys. @@ -15,3 +33,56 @@ func NewSingleSignature(userID id.UserID, algorithm id.KeyAlgorithm, keyID strin }, } } + +// VerifySignature verifies an Ed25519 signature. +func VerifySignature(message []byte, key id.Ed25519, signature []byte) (ok bool, err error) { + if len(message) == 0 || len(key) == 0 || len(signature) == 0 { + return false, ErrEmptyInput + } + keyDecoded, err := base64.RawStdEncoding.DecodeString(key.String()) + if err != nil { + return false, err + } + publicKey := crypto.Ed25519PublicKey(keyDecoded) + return publicKey.Verify(message, signature), nil +} + +// VerifySignatureJSON verifies the signature in the given JSON object "obj" +// as described in [Appendix 3] of the Matrix Spec. +// +// This function is a wrapper over [Utility.VerifySignatureJSON] that creates +// and destroys the [Utility] object transparently. +// +// If the "obj" is not already a [json.RawMessage], it will re-encoded as JSON +// for the verification, so "json" tags will be honored. +// +// [Appendix 3]: https://spec.matrix.org/v1.9/appendices/#signing-json +func VerifySignatureJSON(obj any, userID id.UserID, keyName string, key id.Ed25519) (bool, error) { + var err error + objJSON, ok := obj.(json.RawMessage) + if !ok { + objJSON, err = json.Marshal(obj) + if err != nil { + return false, err + } + } + + sig := gjson.GetBytes(objJSON, exgjson.Path("signatures", string(userID), fmt.Sprintf("ed25519:%s", keyName))) + if !sig.Exists() || sig.Type != gjson.String { + return false, ErrSignatureNotFound + } + objJSON, err = sjson.DeleteBytes(objJSON, "unsigned") + if err != nil { + return false, err + } + objJSON, err = sjson.DeleteBytes(objJSON, "signatures") + if err != nil { + return false, err + } + objJSONString := canonicaljson.CanonicalJSONAssumeValid(objJSON) + sigBytes, err := base64.RawStdEncoding.DecodeString(sig.Str) + if err != nil { + return false, err + } + return VerifySignature(objJSONString, key, sigBytes) +} diff --git a/crypto/verification.go b/crypto/verification.go index 31608bfa..78906f89 100644 --- a/crypto/verification.go +++ b/crypto/verification.go @@ -367,7 +367,7 @@ func (mach *OlmMachine) handleVerificationKey(ctx context.Context, userID id.Use if verState.initiatedByUs { // verify commitment string from accept message now - expectedCommitment := olm.NewUtility().Sha256(content.Key + verState.startEventCanonical) + expectedCommitment := olm.SHA256B64([]byte(content.Key + verState.startEventCanonical)) mach.Log.Debug().Msgf("Received commitment: %v Expected: %v", verState.commitment, expectedCommitment) if expectedCommitment != verState.commitment { mach.Log.Warn().Msgf("Canceling verification transaction %v due to commitment mismatch", transactionID) @@ -716,7 +716,7 @@ func (mach *OlmMachine) SendSASVerificationAccept(ctx context.Context, fromUser if err != nil { return err } - hash := olm.NewUtility().Sha256(string(publicKey) + string(canonical)) + hash := olm.SHA256B64(append(publicKey, canonical...)) sasMethods := make([]event.SASMethod, len(methods)) for i, method := range methods { sasMethods[i] = method.Type() diff --git a/crypto/verification_in_room.go b/crypto/verification_in_room.go index 240c52b2..a01f0216 100644 --- a/crypto/verification_in_room.go +++ b/crypto/verification_in_room.go @@ -168,7 +168,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationAccept(ctx context.Context, roo if err != nil { return err } - hash := olm.NewUtility().Sha256(string(publicKey) + string(canonical)) + hash := olm.SHA256B64(append(publicKey, canonical...)) sasMethods := make([]event.SASMethod, len(methods)) for i, method := range methods { sasMethods[i] = method.Type() From d8ae0dda1470b173ea5fb4c720c426024ccec7d3 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 15 Jan 2024 22:28:32 -0700 Subject: [PATCH 0084/1647] crypto/keybackup: add function to download and store latest key backup Signed-off-by: Sumner Evans --- crypto/keybackup.go | 149 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 crypto/keybackup.go diff --git a/crypto/keybackup.go b/crypto/keybackup.go new file mode 100644 index 00000000..d52e7003 --- /dev/null +++ b/crypto/keybackup.go @@ -0,0 +1,149 @@ +package crypto + +import ( + "context" + "fmt" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/crypto/backup" + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/signatures" + "maunium.net/go/mautrix/id" +) + +func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, megolmBackupKey *backup.MegolmBackupKey) error { + log := mach.machOrContextLog(ctx).With(). + Str("action", "download and store latest key backup"). + Logger() + versionInfo, err := mach.Client.GetKeyBackupLatestVersion(ctx) + if err != nil { + return err + } + + if versionInfo.Algorithm != id.KeyBackupAlgorithmMegolmBackupV1 { + return fmt.Errorf("unsupported key backup algorithm: %s", versionInfo.Algorithm) + } + + log = log.With(). + Int("count", versionInfo.Count). + Str("etag", versionInfo.ETag). + Str("key_backup_version", versionInfo.Version). + Logger() + + if versionInfo.Count == 0 { + log.Debug().Msg("No keys found in key backup") + return nil + } + + userSignatures, ok := versionInfo.AuthData.Signatures[mach.Client.UserID] + if !ok { + return fmt.Errorf("no signature from user %s found in key backup", mach.Client.UserID) + } + + crossSigningPubkeys := mach.GetOwnCrossSigningPublicKeys(ctx) + + signatureVerified := false + for keyID := range userSignatures { + keyAlg, keyName := keyID.Parse() + if keyAlg != id.KeyAlgorithmEd25519 { + continue + } + log := log.With().Str("key_name", keyName).Logger() + + var key id.Ed25519 + if keyName == crossSigningPubkeys.MasterKey.String() { + key = crossSigningPubkeys.MasterKey + } else if device, err := mach.GetOrFetchDevice(ctx, mach.Client.UserID, id.DeviceID(keyName)); err != nil { + log.Warn().Err(err).Msg("Failed to fetch device") + continue + } else if !mach.IsDeviceTrusted(device) { + log.Warn().Err(err).Msg("Device is not trusted") + continue + } else { + key = device.SigningKey + } + + ok, err = signatures.VerifySignatureJSON(versionInfo.AuthData, mach.Client.UserID, keyName, key) + if err != nil || !ok { + log.Warn().Err(err).Stringer("key_id", keyID).Msg("Signature verification failed") + continue + } else { + // One of the signatures is valid, break from the loop. + signatureVerified = true + break + } + } + if !signatureVerified { + return fmt.Errorf("no valid signature from user %s found in key backup", mach.Client.UserID) + } + + keys, err := mach.Client.GetKeyBackup(ctx, versionInfo.Version) + if err != nil { + return err + } + + for roomID, backup := range keys.Rooms { + for sessionID, keyBackupData := range backup.Sessions { + sessionData, err := keyBackupData.SessionData.Decrypt(megolmBackupKey) + if err != nil { + return err + } + + err = mach.importRoomKeyFromBackup(ctx, roomID, sessionID, sessionData) + if err != nil { + return err + } + } + } + + return nil +} + +func (mach *OlmMachine) importRoomKeyFromBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) error { + log := zerolog.Ctx(ctx).With(). + Str("room_id", roomID.String()). + Str("session_id", sessionID.String()). + Logger() + if keyBackupData.Algorithm != id.AlgorithmMegolmV1 { + return fmt.Errorf("ignoring room key in backup with weird algorithm %s", keyBackupData.Algorithm) + } + + igsInternal, err := olm.InboundGroupSessionImport([]byte(keyBackupData.SessionKey)) + if err != nil { + return fmt.Errorf("failed to import inbound group session: %w", err) + } else if igsInternal.ID() != sessionID { + log.Warn(). + Stringer("actual_session_id", igsInternal.ID()). + Msg("Mismatched session ID while creating inbound group session from key backup") + return fmt.Errorf("mismatched session ID while creating inbound group session from key backup") + } + + var maxAge time.Duration + var maxMessages int + if config, err := mach.StateStore.GetEncryptionEvent(ctx, roomID); err != nil { + log.Error().Err(err).Msg("Failed to get encryption event for room") + } else if config != nil { + maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond + maxMessages = config.RotationPeriodMessages + } + + igs := &InboundGroupSession{ + Internal: *igsInternal, + SenderKey: keyBackupData.SenderKey, + RoomID: roomID, + ForwardingChains: append(keyBackupData.ForwardingKeyChain, keyBackupData.SenderKey.String()), + id: sessionID, + + ReceivedAt: time.Now().UTC(), + MaxAge: maxAge.Milliseconds(), + MaxMessages: maxMessages, + } + err = mach.CryptoStore.PutGroupSession(ctx, roomID, keyBackupData.SenderKey, sessionID, igs) + if err != nil { + return fmt.Errorf("failed to store new inbound group session: %w", err) + } + mach.markSessionReceived(sessionID) + return nil +} From 4ef20ba9cc9e4b010269bdd09c16d01caee31379 Mon Sep 17 00:00:00 2001 From: Adam Van Ymeren Date: Wed, 17 Jan 2024 12:48:19 -0800 Subject: [PATCH 0085/1647] base64 encode the session before passing to libolm/goolm also log the number of imported sessions for fun Signed-off-by: Sumner Evans --- crypto/keybackup.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/crypto/keybackup.go b/crypto/keybackup.go index d52e7003..4b2be50a 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -2,6 +2,7 @@ package crypto import ( "context" + "encoding/base64" "fmt" "time" @@ -84,6 +85,8 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg return err } + var count int + for roomID, backup := range keys.Rooms { for sessionID, keyBackupData := range backup.Sessions { sessionData, err := keyBackupData.SessionData.Decrypt(megolmBackupKey) @@ -95,9 +98,12 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg if err != nil { return err } + count++ } } + log.Info().Int("count", count).Msg("successfully imported sessions from backup") + return nil } @@ -110,7 +116,7 @@ func (mach *OlmMachine) importRoomKeyFromBackup(ctx context.Context, roomID id.R return fmt.Errorf("ignoring room key in backup with weird algorithm %s", keyBackupData.Algorithm) } - igsInternal, err := olm.InboundGroupSessionImport([]byte(keyBackupData.SessionKey)) + igsInternal, err := olm.InboundGroupSessionImport([]byte(base64.RawStdEncoding.EncodeToString(keyBackupData.SessionKey))) if err != nil { return fmt.Errorf("failed to import inbound group session: %w", err) } else if igsInternal.ID() != sessionID { From 591e3a002a4990d38e5b6e165c1e1147c3626f38 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Jan 2024 12:28:10 +0200 Subject: [PATCH 0086/1647] Add function for creating a ready-to-use AppService instance --- CHANGELOG.md | 2 ++ appservice/appservice.go | 43 +++++++++++++++++++++++++++++++++++ bridge/bridge.go | 19 +++++++++++++--- bridge/bridgeconfig/config.go | 1 - 4 files changed, 61 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c0d3bc8e..57518395 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ appservice endpoints. * *(bridge)* Bumped minimum Matrix spec version to v1.4. * *(bridge)* Fixed `RawArgs` field in command events of command state callbacks. +* *(appservice)* Added `CreateFull` helper function for creating an `AppService` + instance with all the mandatory fields set. ## v0.17.0 (2024-01-16) diff --git a/appservice/appservice.go b/appservice/appservice.go index 76d2f786..f6ca540c 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -56,6 +56,8 @@ func Create() *AppService { OTKCounts: make(chan *mautrix.OTKCount, OTKChannelSize), DeviceLists: make(chan *mautrix.DeviceLists, EventChannelSize), QueryHandler: &QueryHandlerStub{}, + + DefaultHTTPRetries: 4, } as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut) @@ -68,6 +70,47 @@ func Create() *AppService { return as } +// CreateOpts contains the options for initializing a new [AppService] instance. +type CreateOpts struct { + // Required, the registration file data for this appservice. + Registration *Registration + // Required, the homeserver's server_name. + HomeserverDomain string + // Required, the homeserver URL to connect to. May be an unix:/path/to/socket URL. + HomeserverURL string + // Required if you want to use the standard HTTP server, optional for websockets (non-standard) + HostConfig HostConfig + // Optional, defaults to a memory state store + StateStore StateStore +} + +// CreateFull creates a fully configured appservice instance that can be [Start]ed and used directly. +func CreateFull(opts CreateOpts) (*AppService, error) { + if opts.HomeserverDomain == "" { + return nil, fmt.Errorf("missing homeserver domain") + } else if opts.HomeserverURL == "" { + return nil, fmt.Errorf("missing homeserver URL") + } else if opts.Registration == nil { + return nil, fmt.Errorf("missing registration") + } + as := Create() + as.HomeserverDomain = opts.HomeserverDomain + as.Host = opts.HostConfig + as.Registration = opts.Registration + err := as.SetHomeserverURL(opts.HomeserverURL) + if err != nil { + return nil, err + } + if opts.StateStore != nil { + as.StateStore = opts.StateStore + } else { + as.StateStore = mautrix.NewMemoryStateStore().(StateStore) + } + return as, nil +} + +var _ StateStore = (*mautrix.MemoryStateStore)(nil) + // QueryHandler handles room alias and user ID queries from the homeserver. type QueryHandler interface { QueryAlias(alias string) bool diff --git a/bridge/bridge.go b/bridge/bridge.go index 134582a2..767c28cf 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -567,11 +567,24 @@ func (br *Bridge) init() { br.ZLog.Debug().Msg("Initializing state store") br.StateStore = sqlstatestore.NewSQLStateStore(br.DB, dbutil.ZeroLogger(br.ZLog.With().Str("db_section", "matrix_state").Logger()), true) - br.AS = br.Config.MakeAppService() + br.AS, err = appservice.CreateFull(appservice.CreateOpts{ + Registration: br.Config.AppService.GetRegistration(), + HomeserverDomain: br.Config.Homeserver.Domain, + HomeserverURL: br.Config.Homeserver.Address, + HostConfig: appservice.HostConfig{ + Hostname: br.Config.AppService.Hostname, + Port: br.Config.AppService.Port, + }, + StateStore: br.StateStore, + }) + if err != nil { + br.ZLog.WithLevel(zerolog.FatalLevel).Err(err). + Msg("Failed to initialize appservice") + os.Exit(15) + } + br.AS.Log = *br.ZLog br.AS.DoublePuppetValue = br.Name br.AS.GetProfile = br.getProfile - br.AS.Log = *br.ZLog - br.AS.StateStore = br.StateStore br.Bot = br.AS.BotIntent() br.ZLog.Debug().Msg("Initializing Matrix event processor") diff --git a/bridge/bridgeconfig/config.go b/bridge/bridgeconfig/config.go index 5f578008..2e8548b5 100644 --- a/bridge/bridgeconfig/config.go +++ b/bridge/bridgeconfig/config.go @@ -106,7 +106,6 @@ func (config *BaseConfig) MakeAppService() *appservice.AppService { _ = as.SetHomeserverURL(config.Homeserver.Address) as.Host.Hostname = config.AppService.Hostname as.Host.Port = config.AppService.Port - as.DefaultHTTPRetries = 4 as.Registration = config.AppService.GetRegistration() return as } From 1045d29e074c3b24014d598e29d9147612f1004f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Jan 2024 12:31:14 +0200 Subject: [PATCH 0087/1647] Add some more godocs for appservices --- appservice/appservice.go | 42 +++++++++++++++++++++++++++++++++++----- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/appservice/appservice.go b/appservice/appservice.go index f6ca540c..b64a84a1 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -35,7 +35,7 @@ import ( var EventChannelSize = 64 var OTKChannelSize = 4 -// Create a blank appservice instance. +// Create creates a blank appservice instance. func Create() *AppService { jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) as := &AppService{ @@ -76,7 +76,7 @@ type CreateOpts struct { Registration *Registration // Required, the homeserver's server_name. HomeserverDomain string - // Required, the homeserver URL to connect to. May be an unix:/path/to/socket URL. + // Required, the homeserver URL to connect to. Should be either https://address or unix:path HomeserverURL string // Required if you want to use the standard HTTP server, optional for websockets (non-standard) HostConfig HostConfig @@ -217,10 +217,13 @@ func (as *AppService) PrepareWebsocket() { // HostConfig contains info about how to host the appservice. type HostConfig struct { + // Hostname can be an IP address or an absolute path for a unix socket. Hostname string `yaml:"hostname"` - Port uint16 `yaml:"port"` - TLSKey string `yaml:"tls_key,omitempty"` - TLSCert string `yaml:"tls_cert,omitempty"` + // Port is required when Hostname is an IP address, optional for unix sockets + Port uint16 `yaml:"port"` + + TLSKey string `yaml:"tls_key,omitempty"` + TLSCert string `yaml:"tls_cert,omitempty"` } // Address gets the whole address of the Appservice. @@ -254,6 +257,7 @@ func (as *AppService) YAML() (string, error) { return string(data), nil } +// BotMXID returns the user ID corresponding to the appservice's sender_localpart func (as *AppService) BotMXID() id.UserID { return id.NewUserID(as.Registration.SenderLocalpart, as.HomeserverDomain) } @@ -290,6 +294,12 @@ func (as *AppService) makeIntent(userID id.UserID) *IntentAPI { return intent } +// Intent returns an [IntentAPI] object for the given user ID. +// +// This will return nil if the given user ID has an empty localpart, +// or if the server name is not the same as the appservice's HomeserverDomain. +// It does not currently validate that the given user ID is actually in the +// appservice's namespace. Validation may be added later. func (as *AppService) Intent(userID id.UserID) *IntentAPI { as.intentsLock.RLock() intent, ok := as.intents[userID] @@ -300,6 +310,7 @@ func (as *AppService) Intent(userID id.UserID) *IntentAPI { return intent } +// BotIntent returns an [IntentAPI] object for the appservice's sender_localpart user. func (as *AppService) BotIntent() *IntentAPI { if as.botIntent == nil { as.botIntent = as.makeIntent(as.BotMXID()) @@ -307,6 +318,10 @@ func (as *AppService) BotIntent() *IntentAPI { return as.botIntent } +// SetHomeserverURL updates the appservice's homeserver URL. +// +// Note that this does not affect already-created [IntentAPI] or [mautrix.Client] objects, +// so it should not be called after Intent or Client are called. func (as *AppService) SetHomeserverURL(homeserverURL string) error { parsedURL, err := url.Parse(homeserverURL) if err != nil { @@ -335,6 +350,10 @@ func (as *AppService) SetHomeserverURL(homeserverURL string) error { return nil } +// NewMautrixClient creates a new [mautrix.Client] instance for the given user ID. +// +// This does not do any validation, and it does not cache the client. +// Usually you should prefer [AppService.Client] or [AppService.Intent] over this method. func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client { client := &mautrix.Client{ HomeserverURL: as.hsURLForClient, @@ -351,6 +370,11 @@ func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client { return client } +// NewExternalMautrixClient creates a new [mautrix.Client] instance for an external user, +// with a token and homeserver URL separate from the main appservice. +// +// This is primarily meant to facilitate double puppeting in bridges, and is used by [bridge.doublePuppetUtil]. +// Non-bridge appservices will likely not need this. func (as *AppService) NewExternalMautrixClient(userID id.UserID, token string, homeserverURL string) (*mautrix.Client, error) { client := as.NewMautrixClient(userID) client.AccessToken = token @@ -378,6 +402,11 @@ func (as *AppService) makeClient(userID id.UserID) *mautrix.Client { return client } +// Client returns the [mautrix.Client] instance for the given user ID. +// +// Unlike [AppService.Intent], this does not do any validation, and will always return a value. +// Usually you should prefer creating intents and using intent methods over direct client methods. +// You can always access the client inside an intent with [IntentAPI.Client]. func (as *AppService) Client(userID id.UserID) *mautrix.Client { as.clientsLock.RLock() client, ok := as.clients[userID] @@ -388,6 +417,9 @@ func (as *AppService) Client(userID id.UserID) *mautrix.Client { return client } +// BotClient returns the [mautrix.Client] instance for the appservice's sender_localpart user. +// +// Like with the generic Client method, [AppService.BotIntent] should be preferred over this. func (as *AppService) BotClient() *mautrix.Client { if as.botClient == nil { as.botClient = as.makeClient(as.BotMXID()) From 3756213bae8f75d8ed334d007752d1b7e1559368 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Jan 2024 13:49:51 +0200 Subject: [PATCH 0088/1647] Set type class for decrypted olm events correctly --- crypto/decryptolm.go | 9 +++++---- crypto/machine_test.go | 6 ++++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index 68eaa875..55614b76 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -57,7 +57,7 @@ func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) ( if !ok { return nil, NotEncryptedForMe } - decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt.Sender, content.SenderKey, ownContent.Type, ownContent.Body) + decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt, content.SenderKey, ownContent.Type, ownContent.Body) if err != nil { return nil, err } @@ -69,13 +69,13 @@ type OlmEventKeys struct { Ed25519 id.Ed25519 `json:"ed25519"` } -func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) { +func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *event.Event, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) { if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg { return nil, UnsupportedOlmMessageType } endTimeTrace := mach.timeTrace(ctx, "decrypting olm ciphertext", 5*time.Second) - plaintext, err := mach.tryDecryptOlmCiphertext(ctx, sender, senderKey, olmType, ciphertext) + plaintext, err := mach.tryDecryptOlmCiphertext(ctx, evt.Sender, senderKey, olmType, ciphertext) endTimeTrace() if err != nil { return nil, err @@ -88,7 +88,8 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, sender if err != nil { return nil, fmt.Errorf("failed to parse olm payload: %w", err) } - if sender != olmEvt.Sender { + olmEvt.Type.Class = evt.Type.Class + if evt.Sender != olmEvt.Sender { return nil, SenderMismatch } else if mach.Client.UserID != olmEvt.Recipient { return nil, RecipientMismatch diff --git a/crypto/machine_test.go b/crypto/machine_test.go index f1d00ebb..d3750d34 100644 --- a/crypto/machine_test.go +++ b/crypto/machine_test.go @@ -114,12 +114,14 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { for _, content := range content.OlmCiphertext { // decrypt olm ciphertext - decrypted, err := machineIn.decryptAndParseOlmCiphertext(context.TODO(), "user1", senderKey, content.Type, content.Body) + decrypted, err := machineIn.decryptAndParseOlmCiphertext(context.TODO(), &event.Event{ + Type: event.ToDeviceEncrypted, + Sender: "user1", + }, senderKey, content.Type, content.Body) if err != nil { t.Errorf("Error decrypting olm content: %v", err) } // store room key in new inbound group session - decrypted.Content.ParseRaw(event.ToDeviceRoomKey) roomKeyEvt := decrypted.Content.AsRoomKey() igs, err := NewInboundGroupSession(senderKey, signingKey, "room1", roomKeyEvt.SessionKey, 0, 0, false) if err != nil { From 38b67b622dc07479c0ffbed3e123cb39a7106a23 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Jan 2024 18:09:49 +0200 Subject: [PATCH 0089/1647] Fix base64 in SSSS keys (#159) --- crypto/ssss/key.go | 25 +++++++++++++++---------- crypto/ssss/meta.go | 15 ++++++++++----- crypto/ssss/meta_test.go | 4 ++-- crypto/ssss/types.go | 2 ++ crypto/utils/utils.go | 4 ++-- crypto/utils/utils_test.go | 2 +- 6 files changed, 32 insertions(+), 20 deletions(-) diff --git a/crypto/ssss/key.go b/crypto/ssss/key.go index 37a53aa2..3c38d3cd 100644 --- a/crypto/ssss/key.go +++ b/crypto/ssss/key.go @@ -67,12 +67,12 @@ func NewKey(passphrase string) (*Key, error) { if _, err := rand.Read(ivBytes[:]); err != nil { return nil, fmt.Errorf("failed to get random bytes for IV: %w", err) } - keyData.IV = base64.StdEncoding.EncodeToString(ivBytes[:]) + keyData.IV = base64.RawStdEncoding.EncodeToString(ivBytes[:]) keyData.MAC = keyData.calculateHash(ssssKey) return &Key{ Key: ssssKey, - ID: base64.StdEncoding.EncodeToString(keyIDBytes), + ID: base64.RawStdEncoding.EncodeToString(keyIDBytes), Metadata: &keyData, }, nil } @@ -87,13 +87,18 @@ func (key *Key) Encrypt(eventType string, data []byte) EncryptedKeyData { aesKey, hmacKey := utils.DeriveKeysSHA256(key.Key, eventType) iv := utils.GenA256CTRIV() - payload := make([]byte, base64.StdEncoding.EncodedLen(len(data))) - base64.StdEncoding.Encode(payload, data) + // For some reason, keys in secret storage are base64 encoded before encrypting. + // Even more confusingly, it's a part of each key type's spec rather than the secrets spec. + // Key backup (`m.megolm_backup.v1`): https://spec.matrix.org/v1.9/client-server-api/#recovery-key + // Cross-signing (master, etc): https://spec.matrix.org/v1.9/client-server-api/#cross-signing (the very last paragraph) + // It's also not clear whether unpadded base64 is the right option, but assuming it is because everything else is unpadded. + payload := make([]byte, base64.RawStdEncoding.EncodedLen(len(data))) + base64.RawStdEncoding.Encode(payload, data) utils.XorA256CTR(payload, aesKey, iv) return EncryptedKeyData{ - Ciphertext: base64.StdEncoding.EncodeToString(payload), - IV: base64.StdEncoding.EncodeToString(iv[:]), + Ciphertext: base64.RawStdEncoding.EncodeToString(payload), + IV: base64.RawStdEncoding.EncodeToString(iv[:]), MAC: utils.HMACSHA256B64(payload, hmacKey), } } @@ -101,10 +106,10 @@ func (key *Key) Encrypt(eventType string, data []byte) EncryptedKeyData { // Decrypt decrypts the given encrypted data with this key. func (key *Key) Decrypt(eventType string, data EncryptedKeyData) ([]byte, error) { var ivBytes [utils.AESCTRIVLength]byte - decodedIV, _ := base64.StdEncoding.DecodeString(data.IV) + decodedIV, _ := base64.RawStdEncoding.DecodeString(strings.TrimRight(data.IV, "=")) copy(ivBytes[:], decodedIV) - payload, err := base64.StdEncoding.DecodeString(data.Ciphertext) + payload, err := base64.RawStdEncoding.DecodeString(strings.TrimRight(data.Ciphertext, "=")) if err != nil { return nil, err } @@ -114,11 +119,11 @@ func (key *Key) Decrypt(eventType string, data EncryptedKeyData) ([]byte, error) // compare the stored MAC with the one we calculated from the ciphertext calcMac := utils.HMACSHA256B64(payload, hmacKey) - if strings.ReplaceAll(data.MAC, "=", "") != strings.ReplaceAll(calcMac, "=", "") { + if strings.TrimRight(data.MAC, "=") != calcMac { return nil, ErrKeyDataMACMismatch } utils.XorA256CTR(payload, aesKey, ivBytes) - decryptedDecoded, err := base64.StdEncoding.DecodeString(string(payload)) + decryptedDecoded, err := base64.RawStdEncoding.DecodeString(strings.TrimRight(string(payload), "=")) return decryptedDecoded, err } diff --git a/crypto/ssss/meta.go b/crypto/ssss/meta.go index 345b1c77..e752cf0c 100644 --- a/crypto/ssss/meta.go +++ b/crypto/ssss/meta.go @@ -19,9 +19,14 @@ import ( type KeyMetadata struct { id string - Algorithm Algorithm `json:"algorithm"` - IV string `json:"iv"` - MAC string `json:"mac"` + Name string `json:"name"` + Algorithm Algorithm `json:"algorithm"` + + // Note: as per https://spec.matrix.org/v1.9/client-server-api/#msecret_storagev1aes-hmac-sha2, + // these fields are "maybe padded" base64, so both unpadded and padded values must be supported. + IV string `json:"iv"` + MAC string `json:"mac"` + Passphrase *PassphraseMetadata `json:"passphrase,omitempty"` } @@ -59,7 +64,7 @@ func (kd *KeyMetadata) VerifyRecoveryKey(recoverKey string) (*Key, error) { // VerifyKey verifies the SSSS key is valid by calculating and comparing its MAC. func (kd *KeyMetadata) VerifyKey(key []byte) bool { - return strings.ReplaceAll(kd.MAC, "=", "") == strings.ReplaceAll(kd.calculateHash(key), "=", "") + return strings.TrimRight(kd.MAC, "=") == kd.calculateHash(key) } // calculateHash calculates the hash used for checking if the key is entered correctly as described @@ -68,7 +73,7 @@ func (kd *KeyMetadata) calculateHash(key []byte) string { aesKey, hmacKey := utils.DeriveKeysSHA256(key, "") var ivBytes [utils.AESCTRIVLength]byte - _, _ = base64.StdEncoding.Decode(ivBytes[:], []byte(kd.IV)) + _, _ = base64.RawStdEncoding.Decode(ivBytes[:], []byte(strings.TrimRight(kd.IV, "="))) cipher := utils.XorA256CTR(make([]byte, utils.AESCTRKeyLength), aesKey, ivBytes) diff --git a/crypto/ssss/meta_test.go b/crypto/ssss/meta_test.go index e8b41180..2ad8f62a 100644 --- a/crypto/ssss/meta_test.go +++ b/crypto/ssss/meta_test.go @@ -24,8 +24,8 @@ const key1Meta = ` "iterations": 500000, "salt": "y863BOoqOadgDp8S3FtHXikDJEalsQ7d" }, - "iv": "xxkTK0L4UzxgAFkQ6XPwsw==", - "mac": "MEhooO0ZhFJNxUhvRMSxBnJfL20wkLgle3ocY0ee/eA=" + "iv": "xxkTK0L4UzxgAFkQ6XPwsw", + "mac": "MEhooO0ZhFJNxUhvRMSxBnJfL20wkLgle3ocY0ee/eA" } ` const key1ID = "gEJqbfSEMnP5JXXcukpXEX1l0aI3MDs0" diff --git a/crypto/ssss/types.go b/crypto/ssss/types.go index ce5b2df6..ef175928 100644 --- a/crypto/ssss/types.go +++ b/crypto/ssss/types.go @@ -47,6 +47,8 @@ const ( ) type EncryptedKeyData struct { + // Note: as per https://spec.matrix.org/v1.9/client-server-api/#msecret_storagev1aes-hmac-sha2-1, + // these fields are "maybe padded" base64, so both unpadded and padded values must be supported. Ciphertext string `json:"ciphertext"` IV string `json:"iv"` MAC string `json:"mac"` diff --git a/crypto/utils/utils.go b/crypto/utils/utils.go index 382db02f..e2f8a19c 100644 --- a/crypto/utils/utils.go +++ b/crypto/utils/utils.go @@ -124,9 +124,9 @@ func EncodeBase58RecoveryKey(key []byte) string { return spacedKey } -// HMACSHA256B64 calculates the base64 of the SHA256 hmac of the input with the given key. +// HMACSHA256B64 calculates the unpadded base64 of the SHA256 hmac of the input with the given key. func HMACSHA256B64(input []byte, hmacKey [HMACKeyLength]byte) string { h := hmac.New(sha256.New, hmacKey[:]) h.Write(input) - return base64.StdEncoding.EncodeToString(h.Sum(nil)) + return base64.RawStdEncoding.EncodeToString(h.Sum(nil)) } diff --git a/crypto/utils/utils_test.go b/crypto/utils/utils_test.go index eb700b2a..c4f01a68 100644 --- a/crypto/utils/utils_test.go +++ b/crypto/utils/utils_test.go @@ -74,7 +74,7 @@ func TestKeyDerivationAndHMAC(t *testing.T) { } calcMac := HMACSHA256B64(ciphertextBytes, hmacKey) - expectedMac := "0DABPNIZsP9iTOh1o6EM0s7BfHHXb96dN7Eca88jq2E=" + expectedMac := "0DABPNIZsP9iTOh1o6EM0s7BfHHXb96dN7Eca88jq2E" if calcMac != expectedMac { t.Errorf("Expected MAC `%v`, got `%v`", expectedMac, calcMac) } From 97d19484a396eedf5b88d4209ce39ce9418f39e6 Mon Sep 17 00:00:00 2001 From: saces Date: Fri, 2 Dec 2022 23:50:44 +0100 Subject: [PATCH 0090/1647] add missing declarations for m.key.verification.ready Signed-off-by: saces --- event/content.go | 1 + event/type.go | 1 + event/verification.go | 2 ++ 3 files changed, 4 insertions(+) diff --git a/event/content.go b/event/content.go index 24c1c193..86e8da58 100644 --- a/event/content.go +++ b/event/content.go @@ -77,6 +77,7 @@ var TypeMap = map[Type]reflect.Type{ ToDeviceVerificationMAC: reflect.TypeOf(VerificationMacEventContent{}), ToDeviceVerificationCancel: reflect.TypeOf(VerificationCancelEventContent{}), ToDeviceVerificationRequest: reflect.TypeOf(VerificationRequestEventContent{}), + ToDeviceVerificationReady: reflect.TypeOf(VerificationReadyEventContent{}), ToDeviceOrgMatrixRoomKeyWithheld: reflect.TypeOf(RoomKeyWithheldEventContent{}), diff --git a/event/type.go b/event/type.go index c741e8ed..18560588 100644 --- a/event/type.go +++ b/event/type.go @@ -255,6 +255,7 @@ var ( ToDeviceVerificationKey = Type{"m.key.verification.key", ToDeviceEventType} ToDeviceVerificationMAC = Type{"m.key.verification.mac", ToDeviceEventType} ToDeviceVerificationCancel = Type{"m.key.verification.cancel", ToDeviceEventType} + ToDeviceVerificationReady = Type{"m.key.verification.ready", ToDeviceEventType} ToDeviceOrgMatrixRoomKeyWithheld = Type{"org.matrix.room_key.withheld", ToDeviceEventType} diff --git a/event/verification.go b/event/verification.go index 8410904d..66d5abec 100644 --- a/event/verification.go +++ b/event/verification.go @@ -141,6 +141,8 @@ func (vsec *VerificationStartEventContent) SetRelatesTo(rel *RelatesTo) { type VerificationReadyEventContent struct { // The device ID which accepted the process. FromDevice id.DeviceID `json:"from_device"` + // An opaque identifier for the verification process. Must be unique with respect to the devices involved. + TransactionID string `json:"transaction_id,omitempty"` // The verification methods supported by the sender. Methods []VerificationMethod `json:"methods"` // Original event ID for in-room verification. From 94664f1c8a733e05a9d7a046142976193a16049e Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Fri, 19 Jan 2024 20:32:28 +0200 Subject: [PATCH 0091/1647] Secret sharing implementation --- crypto/machine.go | 12 ++ crypto/sharing.go | 173 ++++++++++++++++++ crypto/sql_store.go | 28 +++ .../sql_store_upgrade/00-latest-revision.sql | 7 +- crypto/sql_store_upgrade/12-secrets.sql | 5 + crypto/store.go | 30 +++ event/content.go | 2 + event/encryption.go | 23 +++ event/type.go | 10 +- id/crypto.go | 14 ++ 10 files changed, 300 insertions(+), 4 deletions(-) create mode 100644 crypto/sharing.go create mode 100644 crypto/sql_store_upgrade/12-secrets.sql diff --git a/crypto/machine.go b/crypto/machine.go index b1ecd754..7621ca2f 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -78,6 +78,9 @@ type OlmMachine struct { DeleteKeysOnDeviceDelete bool DisableDeviceChangeKeyRotation bool + + secretLock sync.Mutex + secretListeners map[string]chan<- string } // StateStore is used by OlmMachine to get room state information that's needed for encryption. @@ -119,6 +122,7 @@ func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Stor devicesToUnwedge: make(map[id.IdentityKey]bool), recentlyUnwedged: make(map[id.IdentityKey]time.Time), + secretListeners: make(map[string]chan<- string), } mach.AllowKeyShare = mach.defaultAllowKeyShare return mach @@ -357,6 +361,9 @@ func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Eve log.Trace().Msg("Handled forwarded room key event") case *event.DummyEventContent: log.Debug().Msg("Received encrypted dummy event") + case *event.SecretSendEventContent: + mach.receiveSecret(ctx, decryptedEvt, decryptedContent) + log.Trace().Msg("Handled secret send event") default: log.Debug().Msg("Unhandled encrypted to-device event") } @@ -407,6 +414,11 @@ func (mach *OlmMachine) HandleToDeviceEvent(ctx context.Context, evt *event.Even mach.handleVerificationRequest(ctx, evt.Sender, content, content.TransactionID, "") case *event.RoomKeyWithheldEventContent: mach.HandleRoomKeyWithheld(ctx, content) + case *event.SecretRequestEventContent: + if content.Action == event.SecretRequestRequest { + mach.HandleSecretRequest(ctx, evt.Sender, content) + log.Trace().Msg("Handled secret request event") + } default: deviceID, _ := evt.Content.Raw["device_id"].(string) log.Debug().Str("maybe_device_id", deviceID).Msg("Unhandled to-device event") diff --git a/crypto/sharing.go b/crypto/sharing.go new file mode 100644 index 00000000..cf14499f --- /dev/null +++ b/crypto/sharing.go @@ -0,0 +1,173 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package crypto + +import ( + "context" + "time" + + "go.mau.fi/util/random" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, timeout time.Duration) (secret string, err error) { + secret, err = mach.CryptoStore.GetSecret(ctx, name) + if err != nil || secret != "" { + return + } + + requestID, secretChan := random.String(64), make(chan string, 1) + mach.secretLock.Lock() + mach.secretListeners[requestID] = secretChan + mach.secretLock.Unlock() + defer func() { + mach.secretLock.Lock() + delete(mach.secretListeners, requestID) + mach.secretLock.Unlock() + }() + + // request secret from any device + err = mach.sendToOneDevice(ctx, mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{ + Action: event.SecretRequestRequest, + RequestID: requestID, + Name: name, + RequestingDeviceID: mach.Client.DeviceID, + }) + if err != nil { + return + } + + select { + case <-ctx.Done(): + err = ctx.Err() + case <-time.After(timeout): + case secret = <-secretChan: + } + + if secret != "" { + err = mach.CryptoStore.PutSecret(ctx, name, secret) + } + return +} + +func (mach *OlmMachine) HandleSecretRequest(ctx context.Context, userID id.UserID, content *event.SecretRequestEventContent) { + log := mach.machOrContextLog(ctx).With(). + Stringer("user_id", userID). + Stringer("requesting_device_id", content.RequestingDeviceID). + Stringer("action", content.Action). + Str("request_id", content.RequestID). + Stringer("secret", content.Name). + Logger() + + log.Trace().Msg("Handling secret request") + + if content.Action == event.SecretRequestCancellation { + log.Trace().Msg("Secret request cancellation is unimplemented, ignoring") + return + } else if content.Action != event.SecretRequestRequest { + log.Warn().Msg("Ignoring unknown secret request action") + return + } + + // immediately ignore requests from other users + if userID != mach.Client.UserID || content.RequestingDeviceID == "" { + log.Debug().Msg("Secret request was not from our own device, ignoring") + return + } + + if content.RequestingDeviceID == mach.Client.DeviceID { + log.Debug().Msg("Secret request was from this device, ignoring") + return + } + + keys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, mach.Client.UserID) + if err != nil { + log.Err(err).Msg("Failed to get cross signing keys from crypto store") + return + } + + crossSigningKey, ok := keys[id.XSUsageSelfSigning] + if !ok { + log.Warn().Msg("Couldn't find self signing key to verify requesting device") + return + } + + device, err := mach.GetOrFetchDevice(ctx, mach.Client.UserID, content.RequestingDeviceID) + if err != nil { + log.Err(err).Msg("Failed to get or fetch requesting device") + return + } + + verified, err := mach.CryptoStore.IsKeySignedBy(ctx, mach.Client.UserID, device.SigningKey, mach.Client.UserID, crossSigningKey.Key) + if err != nil { + log.Err(err).Msg("Failed to check if requesting device is verified") + return + } + + if !verified { + log.Warn().Msg("Requesting device is not verified, ignoring request") + return + } + + secret, err := mach.CryptoStore.GetSecret(ctx, content.Name) + if err != nil { + log.Err(err).Msg("Failed to get secret from store") + return + } else if secret != "" { + log.Debug().Msg("Responding to secret request") + mach.sendToOneDevice(ctx, mach.Client.UserID, content.RequestingDeviceID, event.ToDeviceSecretRequest, &event.SecretSendEventContent{ + RequestID: content.RequestID, + Secret: secret, + }) + } else { + log.Debug().Msg("No stored secret found, secret request ignored") + } +} + +func (mach *OlmMachine) receiveSecret(ctx context.Context, evt *DecryptedOlmEvent, content *event.SecretSendEventContent) { + log := mach.machOrContextLog(ctx).With(). + Stringer("sender", evt.Sender). + Stringer("sender_device", evt.SenderDevice). + Str("request_id", content.RequestID). + Logger() + + log.Trace().Msg("Handling secret send request") + + // immediately ignore secrets from other users + if evt.Sender != mach.Client.UserID { + log.Warn().Msg("Secret send was not from our own device") + return + } else if content.Secret == "" { + log.Warn().Msg("We were sent an empty secret") + return + } + + mach.secretLock.Lock() + secretChan := mach.secretListeners[content.RequestID] + mach.secretLock.Unlock() + + if secretChan == nil { + log.Warn().Msg("We were sent a secret we didn't request") + return + } + + select { + case secretChan <- content.Secret: + default: + } + + // best effort cancel this for all other targets + go func() { + mach.sendToOneDevice(ctx, mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{ + Action: event.SecretRequestCancellation, + RequestID: content.RequestID, + RequestingDeviceID: mach.Client.DeviceID, + }) + }() +} diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 99a94f0e..cf46c606 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -21,6 +21,7 @@ import ( "go.mau.fi/util/dbutil" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/sql_store_upgrade" "maunium.net/go/mautrix/event" @@ -845,3 +846,30 @@ func (store *SQLCryptoStore) DropSignaturesByKey(ctx context.Context, userID id. } return count, nil } + +func (store *SQLCryptoStore) PutSecret(ctx context.Context, name id.Secret, value string) error { + bytes, err := cipher.Pickle(store.PickleKey, []byte(value)) + if err != nil { + return err + } + _, err = store.DB.Exec(ctx, ` + INSERT INTO crypto_secrets (name, secret) VALUES ($1, $2) + ON CONFLICT (name) DO UPDATE SET secret=excluded.secret + `, name, bytes) + return err +} + +func (store *SQLCryptoStore) GetSecret(ctx context.Context, name id.Secret) (value string, err error) { + var bytes []byte + err = store.DB.QueryRow(ctx, `SELECT secret FROM crypto_secrets WHERE name=$1`, name).Scan(&bytes) + if errors.Is(err, sql.ErrNoRows) { + return "", nil + } + bytes, err = cipher.Unpickle(store.PickleKey, bytes) + return string(bytes), err +} + +func (store *SQLCryptoStore) DeleteSecret(ctx context.Context, name id.Secret) (err error) { + _, err = store.DB.Exec(ctx, "DELETE FROM crypto_secrets WHERE name=$1", name) + return +} diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql index 90d7d31c..4e72b3ae 100644 --- a/crypto/sql_store_upgrade/00-latest-revision.sql +++ b/crypto/sql_store_upgrade/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v11: Latest revision +-- v0 -> v12 (compatible with v9+): Latest revision CREATE TABLE IF NOT EXISTS crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, @@ -93,3 +93,8 @@ CREATE TABLE IF NOT EXISTS crypto_cross_signing_signatures ( signature CHAR(88) NOT NULL, PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key) ); + +CREATE TABLE IF NOT EXISTS crypto_secrets ( + name TEXT PRIMARY KEY NOT NULL, + secret bytea NOT NULL +); diff --git a/crypto/sql_store_upgrade/12-secrets.sql b/crypto/sql_store_upgrade/12-secrets.sql new file mode 100644 index 00000000..d9f30ee7 --- /dev/null +++ b/crypto/sql_store_upgrade/12-secrets.sql @@ -0,0 +1,5 @@ +-- v12 (compatible with v9+): Add crypto_secrets table +CREATE TABLE IF NOT EXISTS crypto_secrets ( + name TEXT PRIMARY KEY NOT NULL, + secret bytea NOT NULL +); diff --git a/crypto/store.go b/crypto/store.go index fb3d5b96..5b0f8638 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -123,6 +123,13 @@ type Store interface { IsKeySignedBy(ctx context.Context, userID id.UserID, key id.Ed25519, signedByUser id.UserID, signedByKey id.Ed25519) (bool, error) // DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted. DropSignaturesByKey(context.Context, id.UserID, id.Ed25519) (int64, error) + + // PutSecret stores a named secret, replacing it if it exists already. + PutSecret(context.Context, id.Secret, string) error + // GetSecret returns a named secret. + GetSecret(context.Context, id.Secret) (string, error) + // DeleteSecret removes a named secret. + DeleteSecret(context.Context, id.Secret) error } type messageIndexKey struct { @@ -153,6 +160,7 @@ type MemoryStore struct { CrossSigningKeys map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey KeySignatures map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string OutdatedUsers map[id.UserID]struct{} + Secrets map[id.Secret]string } var _ Store = (*MemoryStore)(nil) @@ -173,6 +181,7 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore { CrossSigningKeys: make(map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey), KeySignatures: make(map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string), OutdatedUsers: make(map[id.UserID]struct{}), + Secrets: make(map[id.Secret]string), } } @@ -645,3 +654,24 @@ func (gs *MemoryStore) DropSignaturesByKey(_ context.Context, userID id.UserID, gs.lock.RUnlock() return count, nil } + +func (gs *MemoryStore) PutSecret(_ context.Context, name id.Secret, value string) error { + gs.lock.Lock() + gs.Secrets[name] = value + gs.lock.Unlock() + return nil +} + +func (gs *MemoryStore) GetSecret(_ context.Context, name id.Secret) (value string, _ error) { + gs.lock.RLock() + value = gs.Secrets[name] + gs.lock.RUnlock() + return +} + +func (gs *MemoryStore) DeleteSecret(_ context.Context, name id.Secret) error { + gs.lock.Lock() + delete(gs.Secrets, name) + gs.lock.Unlock() + return nil +} diff --git a/event/content.go b/event/content.go index 86e8da58..6462fce2 100644 --- a/event/content.go +++ b/event/content.go @@ -69,6 +69,8 @@ var TypeMap = map[Type]reflect.Type{ ToDeviceRoomKeyRequest: reflect.TypeOf(RoomKeyRequestEventContent{}), ToDeviceEncrypted: reflect.TypeOf(EncryptedEventContent{}), ToDeviceRoomKeyWithheld: reflect.TypeOf(RoomKeyWithheldEventContent{}), + ToDeviceSecretRequest: reflect.TypeOf(SecretRequestEventContent{}), + ToDeviceSecretSend: reflect.TypeOf(SecretSendEventContent{}), ToDeviceDummy: reflect.TypeOf(DummyEventContent{}), ToDeviceVerificationStart: reflect.TypeOf(VerificationStartEventContent{}), diff --git a/event/encryption.go b/event/encryption.go index fa1ac2dd..cf9c2814 100644 --- a/event/encryption.go +++ b/event/encryption.go @@ -176,4 +176,27 @@ func (withheld *RoomKeyWithheldEventContent) Is(other error) bool { return withheld.Code == "" || otherWithheld.Code == "" || withheld.Code == otherWithheld.Code } +type SecretRequestAction string + +func (a SecretRequestAction) String() string { + return string(a) +} + +const ( + SecretRequestRequest = "request" + SecretRequestCancellation = "request_cancellation" +) + +type SecretRequestEventContent struct { + Name id.Secret `json:"name,omitempty"` + Action SecretRequestAction `json:"action"` + RequestingDeviceID id.DeviceID `json:"requesting_device_id"` + RequestID string `json:"request_id"` +} + +type SecretSendEventContent struct { + RequestID string `json:"request_id"` + Secret string `json:"secret"` +} + type DummyEventContent struct{} diff --git a/event/type.go b/event/type.go index 18560588..b60d3f08 100644 --- a/event/type.go +++ b/event/type.go @@ -10,6 +10,8 @@ import ( "encoding/json" "fmt" "strings" + + "maunium.net/go/mautrix/id" ) type RoomType string @@ -235,9 +237,9 @@ var ( AccountDataSecretStorageDefaultKey = Type{"m.secret_storage.default_key", AccountDataEventType} AccountDataSecretStorageKey = Type{"m.secret_storage.key", AccountDataEventType} - AccountDataCrossSigningMaster = Type{"m.cross_signing.master", AccountDataEventType} - AccountDataCrossSigningUser = Type{"m.cross_signing.user_signing", AccountDataEventType} - AccountDataCrossSigningSelf = Type{"m.cross_signing.self_signing", AccountDataEventType} + AccountDataCrossSigningMaster = Type{string(id.SecretXSMaster), AccountDataEventType} + AccountDataCrossSigningUser = Type{string(id.SecretXSUserSigning), AccountDataEventType} + AccountDataCrossSigningSelf = Type{string(id.SecretXSSelfSigning), AccountDataEventType} AccountDataMegolmBackupKey = Type{"m.megolm_backup.v1", AccountDataEventType} ) @@ -248,6 +250,8 @@ var ( ToDeviceForwardedRoomKey = Type{"m.forwarded_room_key", ToDeviceEventType} ToDeviceEncrypted = Type{"m.room.encrypted", ToDeviceEventType} ToDeviceRoomKeyWithheld = Type{"m.room_key.withheld", ToDeviceEventType} + ToDeviceSecretRequest = Type{"m.secret.request", ToDeviceEventType} + ToDeviceSecretSend = Type{"m.secret.send", ToDeviceEventType} ToDeviceDummy = Type{"m.dummy", ToDeviceEventType} ToDeviceVerificationRequest = Type{"m.key.verification.request", ToDeviceEventType} ToDeviceVerificationStart = Type{"m.key.verification.start", ToDeviceEventType} diff --git a/id/crypto.go b/id/crypto.go index e920a301..f28e3d88 100644 --- a/id/crypto.go +++ b/id/crypto.go @@ -153,3 +153,17 @@ type CrossSigningKey struct { Key Ed25519 First Ed25519 } + +// Secret storage keys +type Secret string + +func (s Secret) String() string { + return string(s) +} + +const ( + SecretXSMaster Secret = "m.cross_signing.master" + SecretXSSelfSigning Secret = "m.cross_signing.self_signing" + SecretXSUserSigning Secret = "m.cross_signing.user_signing" + SecretMegolmBackupV1 Secret = "m.megolm_backup.v1" +) From 8dc80b3178b1a4e829aa0e6869fb142616c18aed Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Wed, 24 Jan 2024 12:08:03 +0200 Subject: [PATCH 0092/1647] Share room keys to known devices on request If we have shared a session with a device once, allow asking for it again. --- crypto/encryptmegolm.go | 11 ++++ crypto/keysharing.go | 14 ++++- crypto/machine.go | 3 + crypto/sql_store.go | 14 +++++ .../sql_store_upgrade/00-latest-revision.sql | 10 +++- .../13-megolm-session-sharing.sql | 9 +++ crypto/store.go | 41 ++++++++++++++ crypto/store_test.go | 56 +++++++++++++++++++ 8 files changed, 154 insertions(+), 4 deletions(-) create mode 100644 crypto/sql_store_upgrade/13-megolm-session-sharing.sql diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index dcd36dc1..62bcc044 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -330,6 +330,17 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session Str("target_user_id", userID.String()). Str("target_device_id", deviceID.String()). Msg("Encrypted group session for device") + if !mach.DisableSharedGroupSessionTracking { + err := mach.CryptoStore.MarkOutboundGroupSessionShared(ctx, userID, device.identity.IdentityKey, session.id) + if err != nil { + log.Warn(). + Err(err). + Str("target_user_id", userID.String()). + Str("target_device_id", deviceID.String()). + Stringer("target_session_id", session.id). + Msg("Failed to mark outbound group session shared") + } + } } } diff --git a/crypto/keysharing.go b/crypto/keysharing.go index 09da1d1a..bc9bc61a 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -222,11 +222,19 @@ func (mach *OlmMachine) rejectKeyRequest(ctx context.Context, rejection KeyShare } } -func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Device, _ event.RequestedKeyInfo) *KeyShareRejection { +func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Device, evt event.RequestedKeyInfo) *KeyShareRejection { log := mach.machOrContextLog(ctx) if mach.Client.UserID != device.UserID { - log.Debug().Msg("Rejecting key request from a different user") - return &KeyShareRejectOtherUser + isShared, err := mach.CryptoStore.IsOutboundGroupSessionShared(ctx, device.UserID, device.IdentityKey, evt.SessionID) + if err != nil { + log.Err(err).Msg("Rejecting key request due to internal error when checking session sharing") + return &KeyShareRejectNoResponse + } else if !isShared { + log.Debug().Msg("Rejecting key request for unshared session") + return &KeyShareRejectOtherUser + } + log.Debug().Msg("Accepting key request for shared session") + return nil } else if mach.Client.DeviceID == device.DeviceID { log.Debug().Msg("Ignoring key request from ourselves") return &KeyShareRejectNoResponse diff --git a/crypto/machine.go b/crypto/machine.go index 7621ca2f..180e05f0 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -36,6 +36,9 @@ type OlmMachine struct { // Never ask the server for keys automatically as a side effect during Megolm decryption. DisableDecryptKeyFetching bool + // Don't mark outbound Olm sessions as shared for devices they were initially sent to. + DisableSharedGroupSessionTracking bool + SendKeysMinTrust id.TrustState ShareKeysMinTrust id.TrustState diff --git a/crypto/sql_store.go b/crypto/sql_store.go index cf46c606..cb9d621a 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -569,6 +569,20 @@ func (store *SQLCryptoStore) RemoveOutboundGroupSession(ctx context.Context, roo return err } +func (store *SQLCryptoStore) MarkOutboundGroupSessionShared(ctx context.Context, userID id.UserID, identityKey id.IdentityKey, sessionID id.SessionID) error { + _, err := store.DB.Exec(ctx, "INSERT INTO crypto_megolm_outbound_session_shared (user_id, identity_key, session_id) VALUES ($1, $2, $3)", userID, identityKey, sessionID) + return err +} + +func (store *SQLCryptoStore) IsOutboundGroupSessionShared(ctx context.Context, userID id.UserID, identityKey id.IdentityKey, sessionID id.SessionID) (shared bool, err error) { + err = store.DB.QueryRow(ctx, `SELECT TRUE FROM crypto_megolm_outbound_session_shared WHERE user_id=$1 AND identity_key=$2 AND session_id=$3`, + userID, identityKey, sessionID).Scan(&shared) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + // ValidateMessageIndex returns whether the given event information match the ones stored in the database // for the given sender key, session ID and index. If the index hasn't been stored, this will store it. func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) { diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql index 4e72b3ae..a8c31153 100644 --- a/crypto/sql_store_upgrade/00-latest-revision.sql +++ b/crypto/sql_store_upgrade/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v12 (compatible with v9+): Latest revision +-- v0 -> v13 (compatible with v9+): Latest revision CREATE TABLE IF NOT EXISTS crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, @@ -75,6 +75,14 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session ( PRIMARY KEY (account_id, room_id) ); +CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session_shared ( + user_id TEXT NOT NULL, + identity_key CHAR(43) NOT NULL, + session_id CHAR(43) NOT NULL, + + PRIMARY KEY (user_id, identity_key, session_id) +); + CREATE TABLE IF NOT EXISTS crypto_cross_signing_keys ( user_id TEXT, usage TEXT, diff --git a/crypto/sql_store_upgrade/13-megolm-session-sharing.sql b/crypto/sql_store_upgrade/13-megolm-session-sharing.sql new file mode 100644 index 00000000..ea69f3cf --- /dev/null +++ b/crypto/sql_store_upgrade/13-megolm-session-sharing.sql @@ -0,0 +1,9 @@ +-- v13 (compatible with v9+): Add crypto_megolm_outbound_session_shared table + +CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session_shared ( + user_id TEXT NOT NULL, + identity_key CHAR(43) NOT NULL, + session_id CHAR(43) NOT NULL, + + PRIMARY KEY (user_id, identity_key, session_id) +); diff --git a/crypto/store.go b/crypto/store.go index 5b0f8638..f900a3fa 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -84,6 +84,10 @@ type Store interface { GetOutboundGroupSession(context.Context, id.RoomID) (*OutboundGroupSession, error) // RemoveOutboundGroupSession removes the stored outbound Megolm session for the given room ID. RemoveOutboundGroupSession(context.Context, id.RoomID) error + // MarkOutboutGroupSessionShared flags that the currently known device has been shared the keys for the specified session. + MarkOutboundGroupSessionShared(context.Context, id.UserID, id.IdentityKey, id.SessionID) error + // IsOutboutGroupSessionShared checks if the specified session has been shared with the device. + IsOutboundGroupSessionShared(context.Context, id.UserID, id.IdentityKey, id.SessionID) (bool, error) // ValidateMessageIndex validates that the given message details aren't from a replay attack. // @@ -155,6 +159,7 @@ type MemoryStore struct { GroupSessions map[id.RoomID]map[id.SenderKey]map[id.SessionID]*InboundGroupSession WithheldGroupSessions map[id.RoomID]map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent OutGroupSessions map[id.RoomID]*OutboundGroupSession + SharedGroupSessions map[id.UserID]map[id.IdentityKey]map[id.SessionID]struct{} MessageIndices map[messageIndexKey]messageIndexValue Devices map[id.UserID]map[id.DeviceID]*id.Device CrossSigningKeys map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey @@ -176,6 +181,7 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore { GroupSessions: make(map[id.RoomID]map[id.SenderKey]map[id.SessionID]*InboundGroupSession), WithheldGroupSessions: make(map[id.RoomID]map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent), OutGroupSessions: make(map[id.RoomID]*OutboundGroupSession), + SharedGroupSessions: make(map[id.UserID]map[id.IdentityKey]map[id.SessionID]struct{}), MessageIndices: make(map[messageIndexKey]messageIndexValue), Devices: make(map[id.UserID]map[id.DeviceID]*id.Device), CrossSigningKeys: make(map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey), @@ -435,6 +441,41 @@ func (gs *MemoryStore) RemoveOutboundGroupSession(_ context.Context, roomID id.R return nil } +func (gs *MemoryStore) MarkOutboundGroupSessionShared(_ context.Context, userID id.UserID, identityKey id.IdentityKey, sessionID id.SessionID) error { + gs.lock.Lock() + + if _, ok := gs.SharedGroupSessions[userID]; !ok { + gs.SharedGroupSessions[userID] = make(map[id.IdentityKey]map[id.SessionID]struct{}) + } + identities := gs.SharedGroupSessions[userID] + + if _, ok := identities[identityKey]; !ok { + identities[identityKey] = make(map[id.SessionID]struct{}) + } + + identities[identityKey][sessionID] = struct{}{} + + gs.lock.Unlock() + return nil +} + +func (gs *MemoryStore) IsOutboundGroupSessionShared(_ context.Context, userID id.UserID, identityKey id.IdentityKey, sessionID id.SessionID) (isShared bool, err error) { + gs.lock.Lock() + defer gs.lock.Unlock() + + if _, ok := gs.SharedGroupSessions[userID]; !ok { + return + } + identities := gs.SharedGroupSessions[userID] + + if _, ok := identities[identityKey]; !ok { + return + } + + _, isShared = identities[identityKey][sessionID] + return +} + func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) { gs.lock.Lock() defer gs.lock.Unlock() diff --git a/crypto/store_test.go b/crypto/store_test.go index bbadef28..b3a3b2b7 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -217,6 +217,62 @@ func TestStoreOutboundMegolmSession(t *testing.T) { } } +func TestStoreOutboundMegolmSessionSharing(t *testing.T) { + stores := getCryptoStores(t) + + resetDevice := func() *id.Device { + acc := NewOlmAccount() + return &id.Device{ + UserID: "user1", + DeviceID: id.DeviceID("dev1"), + IdentityKey: acc.IdentityKey(), + SigningKey: acc.SigningKey(), + } + } + + for storeName, store := range stores { + t.Run(storeName, func(t *testing.T) { + device := resetDevice() + err := store.PutDevice(context.TODO(), "user1", device) + if err != nil { + t.Errorf("Error storing devices: %v", err) + } + + shared, err := store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") + if err != nil { + t.Errorf("Error checking if outbound group session is shared: %v", err) + } else if shared { + t.Errorf("Outbound group session shared when it shouldn't") + } + + err = store.MarkOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") + if err != nil { + t.Errorf("Error marking outbound group session as shared: %v", err) + } + + shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") + if err != nil { + t.Errorf("Error checking if outbound group session is shared: %v", err) + } else if !shared { + t.Errorf("Outbound group session not shared when it should") + } + + device = resetDevice() + err = store.PutDevice(context.TODO(), "user1", device) + if err != nil { + t.Errorf("Error storing devices: %v", err) + } + + shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") + if err != nil { + t.Errorf("Error checking if outbound group session is shared: %v", err) + } else if shared { + t.Errorf("Outbound group session shared when it shouldn't") + } + }) + } +} + func TestStoreDevices(t *testing.T) { stores := getCryptoStores(t) for storeName, store := range stores { From 1f2e75ee94927d52f86b675388293796160d0c49 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Fri, 26 Jan 2024 14:12:59 +0200 Subject: [PATCH 0093/1647] Import group session signing key from backup --- crypto/backup/megolmbackup.go | 14 +++++++++----- crypto/keybackup.go | 1 + 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/crypto/backup/megolmbackup.go b/crypto/backup/megolmbackup.go index 2ec425c0..dea3e704 100644 --- a/crypto/backup/megolmbackup.go +++ b/crypto/backup/megolmbackup.go @@ -15,15 +15,19 @@ type MegolmAuthData struct { Signatures signatures.Signatures `json:"signatures"` } +type SenderClaimedKeys struct { + Ed25519 id.Ed25519 `json:"ed25519"` +} + // MegolmSessionData is the decrypted session_data when the key backup is created // with the [id.KeyBackupAlgorithmMegolmBackupV1] algorithm as defined in // [Section 11.12.3.2.2 of the Spec]. // // [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 type MegolmSessionData struct { - Algorithm id.Algorithm `json:"algorithm"` - ForwardingKeyChain []string `json:"forwarding_curve25519_key_chain"` - SenderClaimedKeys map[id.KeyAlgorithm]string `json:"sender_claimed_keys"` - SenderKey id.SenderKey `json:"sender_key"` - SessionKey []byte `json:"session_key"` + Algorithm id.Algorithm `json:"algorithm"` + ForwardingKeyChain []string `json:"forwarding_curve25519_key_chain"` + SenderClaimedKeys SenderClaimedKeys `json:"sender_claimed_keys"` + SenderKey id.SenderKey `json:"sender_key"` + SessionKey []byte `json:"session_key"` } diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 4b2be50a..8e7f1de2 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -137,6 +137,7 @@ func (mach *OlmMachine) importRoomKeyFromBackup(ctx context.Context, roomID id.R igs := &InboundGroupSession{ Internal: *igsInternal, + SigningKey: keyBackupData.SenderClaimedKeys.Ed25519, SenderKey: keyBackupData.SenderKey, RoomID: roomID, ForwardingChains: append(keyBackupData.ForwardingKeyChain, keyBackupData.SenderKey.String()), From 1027a3acde90a63df40fcffe546b8ae3f52a37df Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 26 Jan 2024 10:35:12 -0700 Subject: [PATCH 0094/1647] keybackup: skip bad keys in backup instead of erroring Signed-off-by: Sumner Evans --- crypto/keybackup.go | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 8e7f1de2..36ee7930 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -85,24 +85,31 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg return err } - var count int + var count, failedCount int for roomID, backup := range keys.Rooms { for sessionID, keyBackupData := range backup.Sessions { sessionData, err := keyBackupData.SessionData.Decrypt(megolmBackupKey) if err != nil { - return err + log.Warn().Err(err).Msg("Failed to decrypt session data") + failedCount++ + continue } err = mach.importRoomKeyFromBackup(ctx, roomID, sessionID, sessionData) if err != nil { - return err + log.Warn().Err(err).Msg("Failed to import room key from backup") + failedCount++ + continue } count++ } } - log.Info().Int("count", count).Msg("successfully imported sessions from backup") + log.Info(). + Int("count", count). + Int("failed_count", failedCount). + Msg("successfully imported sessions from backup") return nil } From 0ecb22e6f0e8764f8bf84d06100666e0c16a2e13 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 26 Jan 2024 19:55:58 +0200 Subject: [PATCH 0095/1647] Add m.fully_read to event type class guesser --- event/type.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/event/type.go b/event/type.go index b60d3f08..e39a3695 100644 --- a/event/type.go +++ b/event/type.go @@ -118,7 +118,8 @@ func (et *Type) GuessClass() TypeClass { return EphemeralEventType case AccountDataDirectChats.Type, AccountDataPushRules.Type, AccountDataRoomTags.Type, AccountDataSecretStorageKey.Type, AccountDataSecretStorageDefaultKey.Type, - AccountDataCrossSigningMaster.Type, AccountDataCrossSigningSelf.Type, AccountDataCrossSigningUser.Type: + AccountDataCrossSigningMaster.Type, AccountDataCrossSigningSelf.Type, AccountDataCrossSigningUser.Type, + AccountDataFullyRead.Type: return AccountDataEventType case EventRedaction.Type, EventMessage.Type, EventEncrypted.Type, EventReaction.Type, EventSticker.Type, InRoomVerificationStart.Type, InRoomVerificationReady.Type, InRoomVerificationAccept.Type, From cfe3ce5f7bf8301ed39c53fa083f099c99a288ae Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Fri, 26 Jan 2024 14:28:16 +0200 Subject: [PATCH 0096/1647] Add test for secret store --- crypto/store_test.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/crypto/store_test.go b/crypto/store_test.go index b3a3b2b7..e6969e3e 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -352,3 +352,23 @@ func TestStoreDevices(t *testing.T) { }) } } + +func TestStoreSecrets(t *testing.T) { + stores := getCryptoStores(t) + for storeName, store := range stores { + t.Run(storeName, func(t *testing.T) { + storedSecret := "trustno1" + err := store.PutSecret(context.TODO(), id.SecretMegolmBackupV1, storedSecret) + if err != nil { + t.Errorf("Error storing secret: %v", err) + } + + secret, err := store.GetSecret(context.TODO(), id.SecretMegolmBackupV1) + if err != nil { + t.Errorf("Error storing secret: %v", err) + } else if secret != storedSecret { + t.Errorf("Stored secret did not match: '%s' != '%s'", secret, storedSecret) + } + }) + } +} From dd925e228bc6a031854064776b9fed1c9fc4a7d6 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Fri, 26 Jan 2024 14:27:57 +0200 Subject: [PATCH 0097/1647] Allow restoring single sessions from backups --- crypto/keybackup.go | 39 ++++++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 36ee7930..680a348b 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -8,6 +8,7 @@ import ( "github.com/rs/zerolog" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/backup" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures" @@ -18,16 +19,30 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg log := mach.machOrContextLog(ctx).With(). Str("action", "download and store latest key backup"). Logger() - versionInfo, err := mach.Client.GetKeyBackupLatestVersion(ctx) + + ctx = log.WithContext(ctx) + + versionInfo, err := mach.GetAndVerifyLatestKeyBackupVersion(ctx) if err != nil { return err + } else if versionInfo == nil { + return nil + } + + return mach.GetAndStoreKeyBackup(ctx, versionInfo.Version, megolmBackupKey) +} + +func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) (*mautrix.RespRoomKeysVersion[backup.MegolmAuthData], error) { + versionInfo, err := mach.Client.GetKeyBackupLatestVersion(ctx) + if err != nil { + return nil, err } if versionInfo.Algorithm != id.KeyBackupAlgorithmMegolmBackupV1 { - return fmt.Errorf("unsupported key backup algorithm: %s", versionInfo.Algorithm) + return nil, fmt.Errorf("unsupported key backup algorithm: %s", versionInfo.Algorithm) } - log = log.With(). + log := mach.machOrContextLog(ctx).With(). Int("count", versionInfo.Count). Str("etag", versionInfo.ETag). Str("key_backup_version", versionInfo.Version). @@ -35,12 +50,12 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg if versionInfo.Count == 0 { log.Debug().Msg("No keys found in key backup") - return nil + return nil, nil } userSignatures, ok := versionInfo.AuthData.Signatures[mach.Client.UserID] if !ok { - return fmt.Errorf("no signature from user %s found in key backup", mach.Client.UserID) + return nil, fmt.Errorf("no signature from user %s found in key backup", mach.Client.UserID) } crossSigningPubkeys := mach.GetOwnCrossSigningPublicKeys(ctx) @@ -77,14 +92,20 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg } } if !signatureVerified { - return fmt.Errorf("no valid signature from user %s found in key backup", mach.Client.UserID) + return nil, fmt.Errorf("no valid signature from user %s found in key backup", mach.Client.UserID) } - keys, err := mach.Client.GetKeyBackup(ctx, versionInfo.Version) + return versionInfo, nil +} + +func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version string, megolmBackupKey *backup.MegolmBackupKey) error { + keys, err := mach.Client.GetKeyBackup(ctx, version) if err != nil { return err } + log := zerolog.Ctx(ctx) + var count, failedCount int for roomID, backup := range keys.Rooms { @@ -96,7 +117,7 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg continue } - err = mach.importRoomKeyFromBackup(ctx, roomID, sessionID, sessionData) + err = mach.ImportRoomKeyFromBackup(ctx, roomID, sessionID, sessionData) if err != nil { log.Warn().Err(err).Msg("Failed to import room key from backup") failedCount++ @@ -114,7 +135,7 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg return nil } -func (mach *OlmMachine) importRoomKeyFromBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) error { +func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) error { log := zerolog.Ctx(ctx).With(). Str("room_id", roomID.String()). Str("session_id", sessionID.String()). From a36dc59187ce6562594442192c7da73fcae5afb4 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Thu, 1 Feb 2024 10:04:50 +0200 Subject: [PATCH 0098/1647] Plumb event.AccountDataMegolmBackupKey properly --- crypto/ssss/types.go | 1 + event/type.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/crypto/ssss/types.go b/crypto/ssss/types.go index ef175928..60852c55 100644 --- a/crypto/ssss/types.go +++ b/crypto/ssss/types.go @@ -74,4 +74,5 @@ func init() { event.TypeMap[event.AccountDataCrossSigningUser] = encryptedContent event.TypeMap[event.AccountDataSecretStorageDefaultKey] = reflect.TypeOf(&DefaultSecretStorageKeyContent{}) event.TypeMap[event.AccountDataSecretStorageKey] = reflect.TypeOf(&KeyMetadata{}) + event.TypeMap[event.AccountDataMegolmBackupKey] = reflect.TypeOf(&EncryptedAccountDataEventContent{}) } diff --git a/event/type.go b/event/type.go index e39a3695..e7d47818 100644 --- a/event/type.go +++ b/event/type.go @@ -119,7 +119,7 @@ func (et *Type) GuessClass() TypeClass { case AccountDataDirectChats.Type, AccountDataPushRules.Type, AccountDataRoomTags.Type, AccountDataSecretStorageKey.Type, AccountDataSecretStorageDefaultKey.Type, AccountDataCrossSigningMaster.Type, AccountDataCrossSigningSelf.Type, AccountDataCrossSigningUser.Type, - AccountDataFullyRead.Type: + AccountDataFullyRead.Type, AccountDataMegolmBackupKey.Type: return AccountDataEventType case EventRedaction.Type, EventMessage.Type, EventEncrypted.Type, EventReaction.Type, EventSticker.Type, InRoomVerificationStart.Type, InRoomVerificationReady.Type, InRoomVerificationAccept.Type, From 75bb8452fa60a91b370490afb1f7d35f020670c2 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Thu, 1 Feb 2024 12:41:52 +0200 Subject: [PATCH 0099/1647] Always verify and return key backup version Otherwise we can't save the known backup version if the remote one is empty. --- crypto/keybackup.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 680a348b..328309dc 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -48,11 +48,6 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) Str("key_backup_version", versionInfo.Version). Logger() - if versionInfo.Count == 0 { - log.Debug().Msg("No keys found in key backup") - return nil, nil - } - userSignatures, ok := versionInfo.AuthData.Signatures[mach.Client.UserID] if !ok { return nil, fmt.Errorf("no signature from user %s found in key backup", mach.Client.UserID) From e08ed2384508f7a330e66cf34801da3abd6e7826 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Thu, 1 Feb 2024 12:42:32 +0200 Subject: [PATCH 0100/1647] Fix SQL CryptoStore GetSecret error handling --- crypto/sql_store.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crypto/sql_store.go b/crypto/sql_store.go index cb9d621a..1d86fec9 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -878,6 +878,8 @@ func (store *SQLCryptoStore) GetSecret(ctx context.Context, name id.Secret) (val err = store.DB.QueryRow(ctx, `SELECT secret FROM crypto_secrets WHERE name=$1`, name).Scan(&bytes) if errors.Is(err, sql.ErrNoRows) { return "", nil + } else if err != nil { + return "", err } bytes, err = cipher.Unpickle(store.PickleKey, bytes) return string(bytes), err From 11c2907f2e5ae8600714ec1dec3b01a5087cfc3a Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Tue, 30 Jan 2024 14:48:52 +0200 Subject: [PATCH 0101/1647] Database level support for key backup versioning This doesn't plumb anything in yet but adds the columns and types for an external implementation. Key backup version is now typed. --- client.go | 42 ++++++------- crypto/account.go | 9 +-- crypto/keybackup.go | 24 ++++---- crypto/machine.go | 15 ++++- crypto/sessions.go | 9 +-- crypto/sql_store.go | 59 +++++++++++-------- .../sql_store_upgrade/00-latest-revision.sql | 42 ++++++------- .../14-account-key-backup-version.sql | 4 ++ crypto/store.go | 34 ++++++++--- go.mod | 2 +- go.sum | 4 +- id/crypto.go | 7 +++ requests.go | 10 ++-- responses.go | 4 +- 14 files changed, 160 insertions(+), 105 deletions(-) create mode 100644 crypto/sql_store_upgrade/14-account-key-backup-version.sql diff --git a/client.go b/client.go index 18f2c019..2712789b 100644 --- a/client.go +++ b/client.go @@ -1948,9 +1948,9 @@ func (cli *Client) GetKeyChanges(ctx context.Context, from, to string) (resp *Re // GetKeyBackup retrieves the keys from the backup. // // See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeys -func (cli *Client) GetKeyBackup(ctx context.Context, version string) (resp *RespRoomKeys[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) { +func (cli *Client) GetKeyBackup(ctx context.Context, version id.KeyBackupVersion) (resp *RespRoomKeys[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return @@ -1959,9 +1959,9 @@ func (cli *Client) GetKeyBackup(ctx context.Context, version string) (resp *Resp // PutKeysInBackup stores several keys in the backup. // // See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeys -func (cli *Client) PutKeysInBackup(ctx context.Context, version string, req *ReqKeyBackup) (resp *RespRoomKeysUpdate, err error) { +func (cli *Client) PutKeysInBackup(ctx context.Context, version id.KeyBackupVersion, req *ReqKeyBackup) (resp *RespRoomKeysUpdate, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp) return @@ -1970,9 +1970,9 @@ func (cli *Client) PutKeysInBackup(ctx context.Context, version string, req *Req // DeleteKeyBackup deletes all keys from the backup. // // See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeys -func (cli *Client) DeleteKeyBackup(ctx context.Context, version string) (resp *RespRoomKeysUpdate, err error) { +func (cli *Client) DeleteKeyBackup(ctx context.Context, version id.KeyBackupVersion) (resp *RespRoomKeysUpdate, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp) return @@ -1982,10 +1982,10 @@ func (cli *Client) DeleteKeyBackup(ctx context.Context, version string) (resp *R // // See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeysroomid func (cli *Client) GetKeyBackupForRoom( - ctx context.Context, version string, roomID id.RoomID, + ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, ) (resp *RespRoomKeyBackup[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return @@ -1994,9 +1994,9 @@ func (cli *Client) GetKeyBackupForRoom( // PutKeysInBackupForRoom stores several keys in the backup for the given room. // // See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeysroomid -func (cli *Client) PutKeysInBackupForRoom(ctx context.Context, version string, roomID id.RoomID, req *ReqRoomKeyBackup) (resp *RespRoomKeysUpdate, err error) { +func (cli *Client) PutKeysInBackupForRoom(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, req *ReqRoomKeyBackup) (resp *RespRoomKeysUpdate, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp) return @@ -2006,9 +2006,9 @@ func (cli *Client) PutKeysInBackupForRoom(ctx context.Context, version string, r // room. // // See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeysroomid -func (cli *Client) DeleteKeysFromBackupForRoom(ctx context.Context, version string, roomID id.RoomID) (resp *RespRoomKeysUpdate, err error) { +func (cli *Client) DeleteKeysFromBackupForRoom(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID) (resp *RespRoomKeysUpdate, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp) return @@ -2018,10 +2018,10 @@ func (cli *Client) DeleteKeysFromBackupForRoom(ctx context.Context, version stri // // See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid func (cli *Client) GetKeyBackupForRoomAndSession( - ctx context.Context, version string, roomID id.RoomID, sessionID id.SessionID, + ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, ) (resp *RespKeyBackupData[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return @@ -2030,9 +2030,9 @@ func (cli *Client) GetKeyBackupForRoomAndSession( // PutKeysInBackupForRoomAndSession stores a key in the backup. // // See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeysroomidsessionid -func (cli *Client) PutKeysInBackupForRoomAndSession(ctx context.Context, version string, roomID id.RoomID, sessionID id.SessionID, req *ReqKeyBackupData) (resp *RespRoomKeysUpdate, err error) { +func (cli *Client) PutKeysInBackupForRoomAndSession(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, req *ReqKeyBackupData) (resp *RespRoomKeysUpdate, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp) return @@ -2041,9 +2041,9 @@ func (cli *Client) PutKeysInBackupForRoomAndSession(ctx context.Context, version // DeleteKeysInBackupForRoomAndSession deletes a key from the backup. // // See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeysroomidsessionid -func (cli *Client) DeleteKeysInBackupForRoomAndSession(ctx context.Context, version string, roomID id.RoomID, sessionID id.SessionID) (resp *RespRoomKeysUpdate, err error) { +func (cli *Client) DeleteKeysInBackupForRoomAndSession(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID) (resp *RespRoomKeysUpdate, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{ - "version": version, + "version": string(version), }) _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp) return @@ -2070,7 +2070,7 @@ func (cli *Client) CreateKeyBackupVersion(ctx context.Context, req *ReqRoomKeysV // GetKeyBackupVersion returns information about an existing key backup. // // See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keysversionversion -func (cli *Client) GetKeyBackupVersion(ctx context.Context, version string) (resp *RespRoomKeysVersion[backup.MegolmAuthData], err error) { +func (cli *Client) GetKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) (resp *RespRoomKeysVersion[backup.MegolmAuthData], err error) { urlPath := cli.BuildClientURL("v3", "room_keys", "version", version) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return @@ -2080,7 +2080,7 @@ func (cli *Client) GetKeyBackupVersion(ctx context.Context, version string) (res // the auth_data can be modified. // // See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keysversionversion -func (cli *Client) UpdateKeyBackupVersion(ctx context.Context, version string, req *ReqRoomKeysVersionUpdate) error { +func (cli *Client) UpdateKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion, req *ReqRoomKeysVersionUpdate) error { urlPath := cli.BuildClientURL("v3", "room_keys", "version", version) _, err := cli.MakeRequest(ctx, http.MethodPut, urlPath, nil, nil) return err @@ -2091,7 +2091,7 @@ func (cli *Client) UpdateKeyBackupVersion(ctx context.Context, version string, r // deleted. // // See: https://spec.matrix.org/v1.1/client-server-api/#delete_matrixclientv3room_keysversionversion -func (cli *Client) DeleteKeyBackupVersion(ctx context.Context, version string) error { +func (cli *Client) DeleteKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) error { urlPath := cli.BuildClientURL("v3", "room_keys", "version", version) _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil) return err diff --git a/crypto/account.go b/crypto/account.go index 78fbfa5f..d242df6f 100644 --- a/crypto/account.go +++ b/crypto/account.go @@ -14,10 +14,11 @@ import ( ) type OlmAccount struct { - Internal olm.Account - signingKey id.SigningKey - identityKey id.IdentityKey - Shared bool + Internal olm.Account + signingKey id.SigningKey + identityKey id.IdentityKey + Shared bool + KeyBackupVersion id.KeyBackupVersion } func NewOlmAccount() *OlmAccount { diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 328309dc..9090e76c 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -15,7 +15,7 @@ import ( "maunium.net/go/mautrix/id" ) -func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, megolmBackupKey *backup.MegolmBackupKey) error { +func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, megolmBackupKey *backup.MegolmBackupKey) (id.KeyBackupVersion, error) { log := mach.machOrContextLog(ctx).With(). Str("action", "download and store latest key backup"). Logger() @@ -24,12 +24,13 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg versionInfo, err := mach.GetAndVerifyLatestKeyBackupVersion(ctx) if err != nil { - return err + return "", err } else if versionInfo == nil { - return nil + return "", nil } - return mach.GetAndStoreKeyBackup(ctx, versionInfo.Version, megolmBackupKey) + err = mach.GetAndStoreKeyBackup(ctx, versionInfo.Version, megolmBackupKey) + return versionInfo.Version, err } func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) (*mautrix.RespRoomKeysVersion[backup.MegolmAuthData], error) { @@ -45,7 +46,7 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) log := mach.machOrContextLog(ctx).With(). Int("count", versionInfo.Count). Str("etag", versionInfo.ETag). - Str("key_backup_version", versionInfo.Version). + Stringer("key_backup_version", versionInfo.Version). Logger() userSignatures, ok := versionInfo.AuthData.Signatures[mach.Client.UserID] @@ -93,7 +94,7 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) return versionInfo, nil } -func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version string, megolmBackupKey *backup.MegolmBackupKey) error { +func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version id.KeyBackupVersion, megolmBackupKey *backup.MegolmBackupKey) error { keys, err := mach.Client.GetKeyBackup(ctx, version) if err != nil { return err @@ -112,7 +113,7 @@ func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version string continue } - err = mach.ImportRoomKeyFromBackup(ctx, roomID, sessionID, sessionData) + err = mach.ImportRoomKeyFromBackup(ctx, version, roomID, sessionID, sessionData) if err != nil { log.Warn().Err(err).Msg("Failed to import room key from backup") failedCount++ @@ -130,7 +131,7 @@ func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version string return nil } -func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) error { +func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) error { log := zerolog.Ctx(ctx).With(). Str("room_id", roomID.String()). Str("session_id", sessionID.String()). @@ -166,9 +167,10 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, roomID id.R ForwardingChains: append(keyBackupData.ForwardingKeyChain, keyBackupData.SenderKey.String()), id: sessionID, - ReceivedAt: time.Now().UTC(), - MaxAge: maxAge.Milliseconds(), - MaxMessages: maxMessages, + ReceivedAt: time.Now().UTC(), + MaxAge: maxAge.Milliseconds(), + MaxMessages: maxMessages, + KeyBackupVersion: version, } err = mach.CryptoStore.PutGroupSession(ctx, roomID, keyBackupData.SenderKey, sessionID, igs) if err != nil { diff --git a/crypto/machine.go b/crypto/machine.go index 180e05f0..e5058ed8 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -152,11 +152,21 @@ func (mach *OlmMachine) Load(ctx context.Context) (err error) { return nil } -func (mach *OlmMachine) saveAccount(ctx context.Context) { +func (mach *OlmMachine) saveAccount(ctx context.Context) error { err := mach.CryptoStore.PutAccount(ctx, mach.account) if err != nil { mach.Log.Error().Err(err).Msg("Failed to save account") } + return err +} + +func (mach *OlmMachine) KeyBackupVersion() id.KeyBackupVersion { + return mach.account.KeyBackupVersion +} + +func (mach *OlmMachine) SetKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) error { + mach.account.KeyBackupVersion = version + return mach.saveAccount(ctx) } // FlushStore calls the Flush method of the CryptoStore. @@ -698,8 +708,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro } mach.lastOTKUpload = time.Now() mach.account.Shared = true - mach.saveAccount(ctx) - return nil + return mach.saveAccount(ctx) } func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) { diff --git a/crypto/sessions.go b/crypto/sessions.go index ad8c2ae8..045af933 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -105,10 +105,11 @@ type InboundGroupSession struct { ForwardingChains []string RatchetSafety RatchetSafety - ReceivedAt time.Time - MaxAge int64 - MaxMessages int - IsScheduled bool + ReceivedAt time.Time + MaxAge int64 + MaxMessages int + IsScheduled bool + KeyBackupVersion id.KeyBackupVersion id id.SessionID } diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 1d86fec9..ef1be25b 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -125,20 +125,21 @@ func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount store.Account = account bytes := account.Internal.Pickle(store.PickleKey) _, err := store.DB.Exec(ctx, ` - INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id) VALUES ($1, $2, $3, $4, $5) + INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id, key_backup_version) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token, - account=excluded.account, account_id=excluded.account_id - `, store.DeviceID, account.Shared, store.SyncToken, bytes, store.AccountID) + account=excluded.account, account_id=excluded.account_id, + key_backup_version=excluded.key_backup_version + `, store.DeviceID, account.Shared, store.SyncToken, bytes, store.AccountID, account.KeyBackupVersion) return err } // GetAccount retrieves an OlmAccount from the database. func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) { if store.Account == nil { - row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID) + row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account, key_backup_version FROM crypto_account WHERE account_id=$1", store.AccountID) acc := &OlmAccount{Internal: *olm.NewBlankAccount()} var accountBytes []byte - err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes) + err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes, &acc.KeyBackupVersion) if err == sql.ErrNoRows { return nil, nil } else if err != nil { @@ -285,17 +286,18 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.Room _, err = store.DB.Exec(ctx, ` INSERT INTO crypto_megolm_inbound_session ( session_id, sender_key, signing_key, room_id, session, forwarding_chains, - ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, account_id + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) ON CONFLICT (session_id, account_id) DO UPDATE SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains, ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at, - max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled + max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled, + key_backup_version=excluded.key_backup_version `, sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains, ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages), - session.IsScheduled, store.AccountID, + session.IsScheduled, session.KeyBackupVersion, store.AccountID, ) return err } @@ -307,12 +309,13 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room var receivedAt sql.NullTime var maxAge, maxMessages sql.NullInt64 var isScheduled bool + var version id.KeyBackupVersion err := store.DB.QueryRow(ctx, ` - SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled + SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`, roomID, senderKey, sessionID, store.AccountID, - ).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled) + ).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { @@ -342,6 +345,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room MaxAge: maxAge.Int64, MaxMessages: int(maxMessages.Int64), IsScheduled: isScheduled, + KeyBackupVersion: version, }, nil } @@ -469,7 +473,8 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In var receivedAt sql.NullTime var maxAge, maxMessages sql.NullInt64 var isScheduled bool - err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled) + var version id.KeyBackupVersion + err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) if err != nil { return nil, err } @@ -485,31 +490,35 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In MaxAge: maxAge.Int64, MaxMessages: int(maxMessages.Int64), IsScheduled: isScheduled, + KeyBackupVersion: version, }, nil } -func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) { +func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled + SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`, roomID, store.AccountID, ) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList() + return dbutil.NewRowIterWithError(rows, store.scanInboundGroupSession, err) } -func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) ([]*InboundGroupSession, error) { +func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled - FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`, + SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL`, store.AccountID, ) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList() + return dbutil.NewRowIterWithError(rows, store.scanInboundGroupSession, err) +} + +func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] { + rows, err := store.DB.Query(ctx, ` + SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL AND key_backup_version != $2`, + store.AccountID, version, + ) + return dbutil.NewRowIterWithError(rows, store.scanInboundGroupSession, err) } // AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices. diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql index a8c31153..06aea750 100644 --- a/crypto/sql_store_upgrade/00-latest-revision.sql +++ b/crypto/sql_store_upgrade/00-latest-revision.sql @@ -1,10 +1,11 @@ --- v0 -> v13 (compatible with v9+): Latest revision +-- v0 -> v14 (compatible with v9+): Latest revision CREATE TABLE IF NOT EXISTS crypto_account ( - account_id TEXT PRIMARY KEY, - device_id TEXT NOT NULL, - shared BOOLEAN NOT NULL, - sync_token TEXT NOT NULL, - account bytea NOT NULL + account_id TEXT PRIMARY KEY, + device_id TEXT NOT NULL, + shared BOOLEAN NOT NULL, + sync_token TEXT NOT NULL, + account bytea NOT NULL, + key_backup_version TEXT NOT NULL DEFAULT '' ); CREATE TABLE IF NOT EXISTS crypto_message_index ( @@ -44,20 +45,21 @@ CREATE TABLE IF NOT EXISTS crypto_olm_session ( ); CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( - account_id TEXT, - session_id CHAR(43), - sender_key CHAR(43) NOT NULL, - signing_key CHAR(43), - room_id TEXT NOT NULL, - session bytea, - forwarding_chains bytea, - withheld_code TEXT, - withheld_reason TEXT, - ratchet_safety jsonb, - received_at timestamp, - max_age BIGINT, - max_messages INTEGER, - is_scheduled BOOLEAN NOT NULL DEFAULT false, + account_id TEXT, + session_id CHAR(43), + sender_key CHAR(43) NOT NULL, + signing_key CHAR(43), + room_id TEXT NOT NULL, + session bytea, + forwarding_chains bytea, + withheld_code TEXT, + withheld_reason TEXT, + ratchet_safety jsonb, + received_at timestamp, + max_age BIGINT, + max_messages INTEGER, + is_scheduled BOOLEAN NOT NULL DEFAULT false, + key_backup_version TEXT NOT NULL DEFAULT '', PRIMARY KEY (account_id, session_id) ); diff --git a/crypto/sql_store_upgrade/14-account-key-backup-version.sql b/crypto/sql_store_upgrade/14-account-key-backup-version.sql new file mode 100644 index 00000000..e5236b62 --- /dev/null +++ b/crypto/sql_store_upgrade/14-account-key-backup-version.sql @@ -0,0 +1,4 @@ +-- v14 (compatible with v9+): Add key_backup_version column to account and igs + +ALTER TABLE crypto_account ADD COLUMN key_backup_version TEXT NOT NULL DEFAULT ''; +ALTER TABLE crypto_megolm_inbound_session ADD COLUMN key_backup_version TEXT NOT NULL DEFAULT ''; diff --git a/crypto/store.go b/crypto/store.go index f900a3fa..3b6e6564 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -12,6 +12,8 @@ import ( "sort" "sync" + "go.mau.fi/util/dbutil" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -68,10 +70,12 @@ type Store interface { // GetGroupSessionsForRoom gets all the inbound Megolm sessions for a specific room. This is used for creating key // export files. Unlike GetGroupSession, this should not return any errors about withheld keys. - GetGroupSessionsForRoom(context.Context, id.RoomID) ([]*InboundGroupSession, error) + GetGroupSessionsForRoom(context.Context, id.RoomID) dbutil.RowIter[*InboundGroupSession] // GetAllGroupSessions gets all the inbound Megolm sessions in the store. This is used for creating key export // files. Unlike GetGroupSession, this should not return any errors about withheld keys. - GetAllGroupSessions(context.Context) ([]*InboundGroupSession, error) + GetAllGroupSessions(context.Context) dbutil.RowIter[*InboundGroupSession] + // GetGroupSessionsWithoutKeyBackupVersion gets all the inbound Megolm sessions in the store that do not match given key backup version. + GetGroupSessionsWithoutKeyBackupVersion(context.Context, id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] // AddOutboundGroupSession inserts the given outbound Megolm session into the store. // @@ -376,12 +380,12 @@ func (gs *MemoryStore) GetWithheldGroupSession(_ context.Context, roomID id.Room return session, nil } -func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) { +func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] { gs.lock.Lock() defer gs.lock.Unlock() room, ok := gs.GroupSessions[roomID] if !ok { - return []*InboundGroupSession{}, nil + return nil } var result []*InboundGroupSession for _, sessions := range room { @@ -389,10 +393,10 @@ func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.Room result = append(result, session) } } - return result, nil + return dbutil.NewSliceIter[*InboundGroupSession](result) } -func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) ([]*InboundGroupSession, error) { +func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) dbutil.RowIter[*InboundGroupSession] { gs.lock.Lock() var result []*InboundGroupSession for _, room := range gs.GroupSessions { @@ -403,7 +407,23 @@ func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) ([]*InboundGroupSe } } gs.lock.Unlock() - return result, nil + return dbutil.NewSliceIter[*InboundGroupSession](result) +} + +func (gs *MemoryStore) GetGroupSessionsWithoutKeyBackupVersion(_ context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] { + gs.lock.Lock() + var result []*InboundGroupSession + for _, room := range gs.GroupSessions { + for _, sessions := range room { + for _, session := range sessions { + if session.KeyBackupVersion != version { + result = append(result, session) + } + } + } + } + gs.lock.Unlock() + return dbutil.NewSliceIter[*InboundGroupSession](result) } func (gs *MemoryStore) AddOutboundGroupSession(_ context.Context, session *OutboundGroupSession) error { diff --git a/go.mod b/go.mod index 48ff59e0..08eb341b 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.6.0 - go.mau.fi/util v0.3.0 + go.mau.fi/util v0.3.1-0.20240131162106-bcac615a2941 go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.18.0 golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3 diff --git a/go.sum b/go.sum index 9061a651..64186d85 100644 --- a/go.sum +++ b/go.sum @@ -36,8 +36,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68= github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mau.fi/util v0.3.0 h1:Lt3lbRXP6ZBqTINK0EieRWor3zEwwwrDT14Z5N8RUCs= -go.mau.fi/util v0.3.0/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= +go.mau.fi/util v0.3.1-0.20240131162106-bcac615a2941 h1:F9ySn0OM0uFqcGDQM2WUqlFJh4UCBYNfeSxzJd0kknM= +go.mau.fi/util v0.3.1-0.20240131162106-bcac615a2941/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= diff --git a/id/crypto.go b/id/crypto.go index f28e3d88..9334198e 100644 --- a/id/crypto.go +++ b/id/crypto.go @@ -50,6 +50,13 @@ const ( KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2" ) +// BackupVersion is an arbitrary string that identifies a server side key backup. +type KeyBackupVersion string + +func (version KeyBackupVersion) String() string { + return string(version) +} + // A SessionID is an arbitrary string that identifies an Olm or Megolm session. type SessionID string diff --git a/requests.go b/requests.go index 1551e63b..61fe8a55 100644 --- a/requests.go +++ b/requests.go @@ -429,14 +429,14 @@ type ReqBeeperSplitRoom struct { } type ReqRoomKeysVersionCreate struct { - Algorithm string `json:"algorithm"` - AuthData json.RawMessage `json:"auth_data"` + Algorithm id.KeyBackupAlgorithm `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` } type ReqRoomKeysVersionUpdate struct { - Algorithm string `json:"algorithm"` - AuthData json.RawMessage `json:"auth_data"` - Version string `json:"version,omitempty"` + Algorithm id.KeyBackupAlgorithm `json:"algorithm"` + AuthData json.RawMessage `json:"auth_data"` + Version id.KeyBackupVersion `json:"version,omitempty"` } type ReqKeyBackup struct { diff --git a/responses.go b/responses.go index b8552b58..e182a722 100644 --- a/responses.go +++ b/responses.go @@ -593,7 +593,7 @@ type RespTimestampToEvent struct { } type RespRoomKeysVersionCreate struct { - Version string `json:"version"` + Version id.KeyBackupVersion `json:"version"` } type RespRoomKeysVersion[A any] struct { @@ -601,7 +601,7 @@ type RespRoomKeysVersion[A any] struct { AuthData A `json:"auth_data"` Count int `json:"count"` ETag string `json:"etag"` - Version string `json:"version"` + Version id.KeyBackupVersion `json:"version"` } type RespRoomKeys[S any] struct { From b131dab9de88ed96adc75b5eeb5d32a60f008b6b Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Fri, 2 Feb 2024 14:44:22 +0200 Subject: [PATCH 0102/1647] Allow any UI auth for uploading cross signing keys Fix endless loop with UI auth causing 401 when uploading keys. Use any type for key backup setup request auth data so that unmarshaled objects can be used that have signatures embedded. Generating keys will now also return them if we also want to setup key backup without storing the keys in-between. --- client.go | 8 ++-- crypto/cross_sign_ssss.go | 77 +++++++++++++++++++++++---------------- requests.go | 13 ++++--- 3 files changed, 57 insertions(+), 41 deletions(-) diff --git a/client.go b/client.go index 2712789b..6562559f 100644 --- a/client.go +++ b/client.go @@ -2061,7 +2061,7 @@ func (cli *Client) GetKeyBackupLatestVersion(ctx context.Context) (resp *RespRoo // CreateKeyBackupVersion creates a new key backup. // // See: https://spec.matrix.org/v1.9/client-server-api/#post_matrixclientv3room_keysversion -func (cli *Client) CreateKeyBackupVersion(ctx context.Context, req *ReqRoomKeysVersionCreate) (resp *RespRoomKeysVersionCreate, err error) { +func (cli *Client) CreateKeyBackupVersion(ctx context.Context, req *ReqRoomKeysVersionCreate[backup.MegolmAuthData]) (resp *RespRoomKeysVersionCreate, err error) { urlPath := cli.BuildClientURL("v3", "room_keys", "version") _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) return @@ -2080,7 +2080,7 @@ func (cli *Client) GetKeyBackupVersion(ctx context.Context, version id.KeyBackup // the auth_data can be modified. // // See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keysversionversion -func (cli *Client) UpdateKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion, req *ReqRoomKeysVersionUpdate) error { +func (cli *Client) UpdateKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion, req *ReqRoomKeysVersionUpdate[backup.MegolmAuthData]) error { urlPath := cli.BuildClientURL("v3", "room_keys", "version", version) _, err := cli.MakeRequest(ctx, http.MethodPut, urlPath, nil, nil) return err @@ -2145,7 +2145,7 @@ func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCross RequestJSON: keys, SensitiveContent: keys.Auth != nil, }) - if respErr, ok := err.(HTTPError); ok && respErr.IsStatus(http.StatusUnauthorized) { + if respErr, ok := err.(HTTPError); ok && respErr.IsStatus(http.StatusUnauthorized) && uiaCallback != nil { // try again with UI auth var uiAuthResp RespUserInteractive if err := json.Unmarshal(content, &uiAuthResp); err != nil { @@ -2154,7 +2154,7 @@ func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCross auth := uiaCallback(&uiAuthResp) if auth != nil { keys.Auth = auth - return cli.UploadCrossSigningKeys(ctx, keys, uiaCallback) + return cli.UploadCrossSigningKeys(ctx, keys, nil) } } return err diff --git a/crypto/cross_sign_ssss.go b/crypto/cross_sign_ssss.go index ef8a0ad3..a87f87de 100644 --- a/crypto/cross_sign_ssss.go +++ b/crypto/cross_sign_ssss.go @@ -14,6 +14,7 @@ import ( "maunium.net/go/mautrix/crypto/ssss" "maunium.net/go/mautrix/crypto/utils" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" ) // FetchCrossSigningKeysFromSSSS fetches all the cross-signing keys from SSSS, decrypts them using the given key and stores them in the olm machine. @@ -57,33 +58,8 @@ func (mach *OlmMachine) retrieveDecryptXSigningKey(ctx context.Context, keyName return decryptedKey, nil } -// GenerateAndUploadCrossSigningKeys generates a new key with all corresponding cross-signing keys. -// -// A passphrase can be provided to generate the SSSS key. If the passphrase is empty, a random key -// is used. The base58-formatted recovery key is the first return parameter. -// -// The account password of the user is required for uploading keys to the server. -func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, userPassword, passphrase string) (string, error) { - key, err := mach.SSSS.GenerateAndUploadKey(ctx, passphrase) - if err != nil { - return "", fmt.Errorf("failed to generate and upload SSSS key: %w", err) - } - - // generate the three cross-signing keys - keysCache, err := mach.GenerateCrossSigningKeys() - if err != nil { - return "", err - } - - recoveryKey := key.RecoveryKey() - - // Store the private keys in SSSS - if err := mach.UploadCrossSigningKeysToSSSS(ctx, key, keysCache); err != nil { - return recoveryKey, fmt.Errorf("failed to upload cross-signing keys to SSSS: %w", err) - } - - // Publish cross-signing keys - err = mach.PublishCrossSigningKeys(ctx, keysCache, func(uiResp *mautrix.RespUserInteractive) interface{} { +func (mach *OlmMachine) GenerateAndUploadCrossSigningKeysWithPassword(ctx context.Context, userPassword, passphrase string) (string, *CrossSigningKeysCache, error) { + return mach.GenerateAndUploadCrossSigningKeys(ctx, func(uiResp *mautrix.RespUserInteractive) interface{} { return &mautrix.ReqUIAuthLogin{ BaseAuthData: mautrix.BaseAuthData{ Type: mautrix.AuthTypePassword, @@ -92,17 +68,44 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, u User: mach.Client.UserID.String(), Password: userPassword, } - }) + }, passphrase) +} + +// GenerateAndUploadCrossSigningKeys generates a new key with all corresponding cross-signing keys. +// +// A passphrase can be provided to generate the SSSS key. If the passphrase is empty, a random key +// is used. The base58-formatted recovery key is the first return parameter. +// +// The account password of the user is required for uploading keys to the server. +func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, uiaCallback mautrix.UIACallback, passphrase string) (string, *CrossSigningKeysCache, error) { + key, err := mach.SSSS.GenerateAndUploadKey(ctx, passphrase) if err != nil { - return recoveryKey, fmt.Errorf("failed to publish cross-signing keys: %w", err) + return "", nil, fmt.Errorf("failed to generate and upload SSSS key: %w", err) + } + + // generate the three cross-signing keys + keysCache, err := mach.GenerateCrossSigningKeys() + if err != nil { + return "", nil, err + } + + // Store the private keys in SSSS + if err := mach.UploadCrossSigningKeysToSSSS(ctx, key, keysCache); err != nil { + return "", nil, fmt.Errorf("failed to upload cross-signing keys to SSSS: %w", err) + } + + // Publish cross-signing keys + err = mach.PublishCrossSigningKeys(ctx, keysCache, uiaCallback) + if err != nil { + return "", nil, fmt.Errorf("failed to publish cross-signing keys: %w", err) } err = mach.SSSS.SetDefaultKeyID(ctx, key.ID) if err != nil { - return recoveryKey, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) + return "", nil, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) } - return recoveryKey, nil + return key.RecoveryKey(), keysCache, nil } // UploadCrossSigningKeysToSSSS stores the given cross-signing keys on the server encrypted with the given key. @@ -116,5 +119,17 @@ func (mach *OlmMachine) UploadCrossSigningKeysToSSSS(ctx context.Context, key *s if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed, key); err != nil { return err } + + // Also store these locally + if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageMaster, keys.MasterKey.PublicKey); err != nil { + return err + } + if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageSelfSigning, keys.SelfSigningKey.PublicKey); err != nil { + return err + } + if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageUserSigning, keys.UserSigningKey.PublicKey); err != nil { + return err + } + return nil } diff --git a/requests.go b/requests.go index 61fe8a55..cdf020a0 100644 --- a/requests.go +++ b/requests.go @@ -97,8 +97,9 @@ type ReqUIAuthFallback struct { type ReqUIAuthLogin struct { BaseAuthData - User string `json:"user"` - Password string `json:"password"` + User string `json:"user,omitempty"` + Password string `json:"password,omitempty"` + Token string `json:"token,omitempty"` } // ReqCreateRoom is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom @@ -428,14 +429,14 @@ type ReqBeeperSplitRoom struct { Parts []BeeperSplitRoomPart `json:"parts"` } -type ReqRoomKeysVersionCreate struct { +type ReqRoomKeysVersionCreate[A any] struct { Algorithm id.KeyBackupAlgorithm `json:"algorithm"` - AuthData json.RawMessage `json:"auth_data"` + AuthData A `json:"auth_data"` } -type ReqRoomKeysVersionUpdate struct { +type ReqRoomKeysVersionUpdate[A any] struct { Algorithm id.KeyBackupAlgorithm `json:"algorithm"` - AuthData json.RawMessage `json:"auth_data"` + AuthData A `json:"auth_data"` Version id.KeyBackupVersion `json:"version,omitempty"` } From a94162cde540625a53bdd6f4c6e4c9cd92eea35e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 8 Feb 2024 10:22:14 +0200 Subject: [PATCH 0103/1647] Update dependencies and bump minimum Go version to 1.21 --- .github/workflows/go.yml | 4 ++-- go.mod | 18 +++++++++--------- go.sum | 35 ++++++++++++++++++----------------- 3 files changed, 29 insertions(+), 28 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 602d5ece..5de7470c 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -11,7 +11,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.21" + go-version: "1.22" cache: true - name: Install libolm @@ -33,7 +33,7 @@ jobs: strategy: fail-fast: false matrix: - go-version: ["1.20", "1.21"] + go-version: ["1.21", "1.22"] steps: - uses: actions/checkout@v4 diff --git a/go.mod b/go.mod index 08eb341b..ce2a08a1 100644 --- a/go.mod +++ b/go.mod @@ -1,22 +1,22 @@ module maunium.net/go/mautrix -go 1.20 +go 1.21 require ( github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 - github.com/mattn/go-sqlite3 v1.14.19 - github.com/rs/zerolog v1.31.0 + github.com/mattn/go-sqlite3 v1.14.22 + github.com/rs/zerolog v1.32.0 github.com/stretchr/testify v1.8.4 github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.6.0 - go.mau.fi/util v0.3.1-0.20240131162106-bcac615a2941 + github.com/yuin/goldmark v1.7.0 + go.mau.fi/util v0.3.1-0.20240208081958-5f386f84d7e2 go.mau.fi/zeroconfig v0.1.2 - golang.org/x/crypto v0.18.0 - golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3 - golang.org/x/net v0.20.0 + golang.org/x/crypto v0.19.0 + golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 + golang.org/x/net v0.21.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 maunium.net/go/maulogger/v2 v2.4.1 @@ -30,6 +30,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.16.0 // indirect + golang.org/x/sys v0.17.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 64186d85..8ae63c70 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,5 @@ -github.com/DATA-DOG/go-sqlmock v1.5.1 h1:FK6RCIUSfmbnI/imIICmboyQBkOckutaa6R5YYlLZyo= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -15,14 +16,14 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI= -github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A= -github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0= +github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -34,23 +35,23 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68= -github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.mau.fi/util v0.3.1-0.20240131162106-bcac615a2941 h1:F9ySn0OM0uFqcGDQM2WUqlFJh4UCBYNfeSxzJd0kknM= -go.mau.fi/util v0.3.1-0.20240131162106-bcac615a2941/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= +github.com/yuin/goldmark v1.7.0 h1:EfOIvIMZIzHdB/R/zVrikYLPPwJlfMcNczJFMs1m6sA= +github.com/yuin/goldmark v1.7.0/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +go.mau.fi/util v0.3.1-0.20240208081958-5f386f84d7e2 h1:jh3ji7lyshRySDa/uiv+wFH35VvBREpeO39WJ0gR+Zk= +go.mau.fi/util v0.3.1-0.20240208081958-5f386f84d7e2/go.mod h1:rRypwgXVEPILomtFPyQcnbOeuRqf+nRN84vh/CICq4w= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= -golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3 h1:hNQpMuAJe5CtcUqCXaWga3FHu+kQvCqcsoVaQgSV60o= -golang.org/x/exp v0.0.0-20240112132812-db7319d0e0e3/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= -golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= -golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= +golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 h1:/RIbNt/Zr7rVhIkQhooTxCxFcdWLGIKnZA4IXNFSrvo= +golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= -golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From 2d1786ced444250bc6495537ddd4a88b8a33cc20 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 8 Feb 2024 10:56:32 +0200 Subject: [PATCH 0104/1647] Configure zerolog globals with exzerolog --- bridge/bridge.go | 7 +------ go.mod | 2 +- go.sum | 4 ++-- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/bridge/bridge.go b/bridge/bridge.go index 767c28cf..8841cc37 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -23,7 +23,6 @@ import ( "github.com/lib/pq" "github.com/mattn/go-sqlite3" "github.com/rs/zerolog" - deflog "github.com/rs/zerolog/log" "go.mau.fi/util/configupgrade" "go.mau.fi/util/dbutil" _ "go.mau.fi/util/dbutil/litestream" @@ -516,11 +515,7 @@ func (br *Bridge) init() { _, _ = fmt.Fprintln(os.Stderr, "Failed to initialize logger:", err) os.Exit(12) } - defaultCtxLog := br.ZLog.With().Bool("default_context_log", true).Caller().Logger() - zerolog.TimeFieldFormat = time.RFC3339Nano - zerolog.CallerMarshalFunc = exzerolog.CallerWithFunctionName - zerolog.DefaultContextLogger = &defaultCtxLog - deflog.Logger = br.ZLog.With().Bool("global_log", true).Caller().Logger() + exzerolog.SetupDefaults(br.ZLog) br.Log = maulogadapt.ZeroAsMau(br.ZLog) br.DoublePuppet = &doublePuppetUtil{br: br, log: br.ZLog.With().Str("component", "double puppet").Logger()} diff --git a/go.mod b/go.mod index ce2a08a1..b36e8d3c 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/tidwall/gjson v1.17.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.0 - go.mau.fi/util v0.3.1-0.20240208081958-5f386f84d7e2 + go.mau.fi/util v0.3.1-0.20240208085450-32294da153ab go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.19.0 golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 diff --git a/go.sum b/go.sum index 8ae63c70..abe1986c 100644 --- a/go.sum +++ b/go.sum @@ -37,8 +37,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.0 h1:EfOIvIMZIzHdB/R/zVrikYLPPwJlfMcNczJFMs1m6sA= github.com/yuin/goldmark v1.7.0/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.3.1-0.20240208081958-5f386f84d7e2 h1:jh3ji7lyshRySDa/uiv+wFH35VvBREpeO39WJ0gR+Zk= -go.mau.fi/util v0.3.1-0.20240208081958-5f386f84d7e2/go.mod h1:rRypwgXVEPILomtFPyQcnbOeuRqf+nRN84vh/CICq4w= +go.mau.fi/util v0.3.1-0.20240208085450-32294da153ab h1:XZ8W5vHWlXSGmHn1U+Fvbh+xZr9wuHTvbY+qV7aybDY= +go.mau.fi/util v0.3.1-0.20240208085450-32294da153ab/go.mod h1:rRypwgXVEPILomtFPyQcnbOeuRqf+nRN84vh/CICq4w= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= From 719df9ba820324e6463f772cb081c383ba86f1f9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 8 Feb 2024 10:58:15 +0200 Subject: [PATCH 0105/1647] Use exzerolog defaults in example --- example/go.mod | 20 +++++++++++--------- example/go.sum | 39 ++++++++++++++++++++++----------------- example/main.go | 2 ++ 3 files changed, 35 insertions(+), 26 deletions(-) diff --git a/example/go.mod b/example/go.mod index f78b3fa0..60583640 100644 --- a/example/go.mod +++ b/example/go.mod @@ -1,12 +1,15 @@ module maunium.net/go/mautrix/example -go 1.20 +go 1.21 + +toolchain go1.22.0 require ( github.com/chzyer/readline v1.5.1 - github.com/mattn/go-sqlite3 v1.14.19 - github.com/rs/zerolog v1.31.0 - maunium.net/go/mautrix v0.16.3-0.20240113165612-308e3583b06f + github.com/mattn/go-sqlite3 v1.14.22 + github.com/rs/zerolog v1.32.0 + go.mau.fi/util v0.3.1-0.20240208085450-32294da153ab + maunium.net/go/mautrix v0.17.1-0.20240208085632-2d1786ced444 ) require ( @@ -16,11 +19,10 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/sjson v1.2.5 // indirect - go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894 // indirect - golang.org/x/crypto v0.17.0 // indirect - golang.org/x/exp v0.0.0-20231226003508-02704c960a9b // indirect - golang.org/x/net v0.19.0 // indirect - golang.org/x/sys v0.15.0 // indirect + golang.org/x/crypto v0.19.0 // indirect + golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 // indirect + golang.org/x/net v0.21.0 // indirect + golang.org/x/sys v0.17.0 // indirect maunium.net/go/maulogger/v2 v2.4.1 // indirect ) diff --git a/example/go.sum b/example/go.sum index 0a3092ed..f81f31c2 100644 --- a/example/go.sum +++ b/example/go.sum @@ -1,4 +1,5 @@ -github.com/DATA-DOG/go-sqlmock v1.5.1 h1:FK6RCIUSfmbnI/imIICmboyQBkOckutaa6R5YYlLZyo= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= @@ -7,20 +8,23 @@ github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI= -github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A= -github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0= +github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -30,22 +34,23 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894 h1:CuR5LDSxBQLETorfwJ9vRtySeLHjMvJ7//lnCMw7Dy8= -go.mau.fi/util v0.2.2-0.20240107143939-48dfc4dc3894/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs= -golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= -golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/exp v0.0.0-20231226003508-02704c960a9b h1:kLiC65FbiHWFAOu+lxwNPujcsl8VYyTYYEZnsOO1WK4= -golang.org/x/exp v0.0.0-20231226003508-02704c960a9b/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= +go.mau.fi/util v0.3.1-0.20240208085450-32294da153ab h1:XZ8W5vHWlXSGmHn1U+Fvbh+xZr9wuHTvbY+qV7aybDY= +go.mau.fi/util v0.3.1-0.20240208085450-32294da153ab/go.mod h1:rRypwgXVEPILomtFPyQcnbOeuRqf+nRN84vh/CICq4w= +golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 h1:/RIbNt/Zr7rVhIkQhooTxCxFcdWLGIKnZA4IXNFSrvo= +golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= -maunium.net/go/mautrix v0.16.3-0.20240113165612-308e3583b06f h1:6uzyAxrjqGv2SbTAnIK3LI6mo1fILWOga6uNyId+6yM= -maunium.net/go/mautrix v0.16.3-0.20240113165612-308e3583b06f/go.mod h1:eRQu5ED1ODsP+xq1K9l1AOD+O9FMkAhodd/RVc3Bkqg= +maunium.net/go/mautrix v0.17.1-0.20240208085632-2d1786ced444 h1:PkpCzQotFakHkGKAatiQdb+XjP/HLQM40xuiy2JtHes= +maunium.net/go/mautrix v0.17.1-0.20240208085632-2d1786ced444/go.mod h1:tMIBWuMXrtjXAqMtaD1VHiT0B3TCxraYlqtncLIyKF0= diff --git a/example/main.go b/example/main.go index f799409c..d8006d46 100644 --- a/example/main.go +++ b/example/main.go @@ -20,6 +20,7 @@ import ( "github.com/chzyer/readline" _ "github.com/mattn/go-sqlite3" "github.com/rs/zerolog" + "go.mau.fi/util/exzerolog" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/cryptohelper" @@ -57,6 +58,7 @@ func main() { if !*debug { log = log.Level(zerolog.InfoLevel) } + exzerolog.SetupDefaults(&log) client.Log = log var lastRoomID id.RoomID From f0f3e84acd66d24930de07e875970ed99ef580ea Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Fri, 9 Feb 2024 13:00:42 +0200 Subject: [PATCH 0106/1647] Sign cross-signing master with device key --- crypto/cross_sign_key.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/crypto/cross_sign_key.go b/crypto/cross_sign_key.go index 005a05fb..45e56b4b 100644 --- a/crypto/cross_sign_key.go +++ b/crypto/cross_sign_key.go @@ -101,6 +101,11 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross masterKeyID: keys.MasterKey.PublicKey, }, } + masterSig, err := mach.account.Internal.SignJSON(masterKey) + if err != nil { + return fmt.Errorf("failed to sign master key: %w", err) + } + masterKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, mach.Client.DeviceID.String(), masterSig) selfKey := mautrix.CrossSigningKeys{ UserID: userID, From 8bfa59b5d31bf8e84cb85162424366b193a31d87 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 9 Feb 2024 13:44:25 +0200 Subject: [PATCH 0107/1647] Add Go-version-independent names for actions --- .github/workflows/go.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 5de7470c..488e4dd5 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -5,6 +5,7 @@ on: [push, pull_request] jobs: lint: runs-on: ubuntu-latest + name: Lint (latest) steps: - uses: actions/checkout@v4 @@ -34,6 +35,7 @@ jobs: fail-fast: false matrix: go-version: ["1.21", "1.22"] + name: Build ${{ matrix.go-version == '1.22' && '(latest)' || '(old)' }} steps: - uses: actions/checkout@v4 From b369efbc06b2777140e1faec5d2fdf82fdee432f Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 26 Jan 2024 09:19:56 -0700 Subject: [PATCH 0108/1647] goolm: rename a couple files Signed-off-by: Sumner Evans --- crypto/goolm/cipher/{main.go => cipher.go} | 3 ++- crypto/goolm/crypto/doc.go | 2 ++ crypto/goolm/crypto/main.go | 2 -- crypto/goolm/sas/{main.go => sas.go} | 2 +- crypto/goolm/sas/{main_test.go => sas_test.go} | 0 crypto/goolm/session/doc.go | 3 +++ crypto/goolm/session/main.go | 2 -- 7 files changed, 8 insertions(+), 6 deletions(-) rename crypto/goolm/cipher/{main.go => cipher.go} (85%) create mode 100644 crypto/goolm/crypto/doc.go delete mode 100644 crypto/goolm/crypto/main.go rename crypto/goolm/sas/{main.go => sas.go} (97%) rename crypto/goolm/sas/{main_test.go => sas_test.go} (100%) create mode 100644 crypto/goolm/session/doc.go delete mode 100644 crypto/goolm/session/main.go diff --git a/crypto/goolm/cipher/main.go b/crypto/goolm/cipher/cipher.go similarity index 85% rename from crypto/goolm/cipher/main.go rename to crypto/goolm/cipher/cipher.go index a8664702..43580b0b 100644 --- a/crypto/goolm/cipher/main.go +++ b/crypto/goolm/cipher/cipher.go @@ -1,4 +1,5 @@ -// cipher provides the methods and structs to do encryptions for olm/megolm. +// Package cipher provides the methods and structs to do encryptions for +// olm/megolm. package cipher // Cipher defines a valid cipher. diff --git a/crypto/goolm/crypto/doc.go b/crypto/goolm/crypto/doc.go new file mode 100644 index 00000000..5bdb01d8 --- /dev/null +++ b/crypto/goolm/crypto/doc.go @@ -0,0 +1,2 @@ +// Package crpyto provides the nessesary encryption methods for olm/megolm +package crypto diff --git a/crypto/goolm/crypto/main.go b/crypto/goolm/crypto/main.go deleted file mode 100644 index 509d44a5..00000000 --- a/crypto/goolm/crypto/main.go +++ /dev/null @@ -1,2 +0,0 @@ -// crpyto provides the nessesary encryption methods for olm/megolm -package crypto diff --git a/crypto/goolm/sas/main.go b/crypto/goolm/sas/sas.go similarity index 97% rename from crypto/goolm/sas/main.go rename to crypto/goolm/sas/sas.go index 7337d5f9..e34ba41c 100644 --- a/crypto/goolm/sas/main.go +++ b/crypto/goolm/sas/sas.go @@ -1,4 +1,4 @@ -// sas provides the means to do SAS between keys +// Package sas provides the means to do SAS between keys package sas import ( diff --git a/crypto/goolm/sas/main_test.go b/crypto/goolm/sas/sas_test.go similarity index 100% rename from crypto/goolm/sas/main_test.go rename to crypto/goolm/sas/sas_test.go diff --git a/crypto/goolm/session/doc.go b/crypto/goolm/session/doc.go new file mode 100644 index 00000000..bc2e8f46 --- /dev/null +++ b/crypto/goolm/session/doc.go @@ -0,0 +1,3 @@ +// Package session provides the different types of sessions for en/decrypting +// of messages +package session diff --git a/crypto/goolm/session/main.go b/crypto/goolm/session/main.go deleted file mode 100644 index 0caf8045..00000000 --- a/crypto/goolm/session/main.go +++ /dev/null @@ -1,2 +0,0 @@ -// session provides the different types of sessions for en/decrypting of messages -package session From 6bfa468ee7cf68e8792587da1070a227d61d4ea8 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 19 Jan 2024 13:53:24 -0700 Subject: [PATCH 0109/1647] crypto: remove old verification code Signed-off-by: Sumner Evans --- crypto/cross_sign_signing.go | 26 - crypto/keysharing.go | 15 + crypto/machine.go | 23 - crypto/verification.go | 801 ----------------------------- crypto/verification_in_room.go | 334 ------------ crypto/verification_sas_methods.go | 201 -------- 6 files changed, 15 insertions(+), 1385 deletions(-) delete mode 100644 crypto/verification.go delete mode 100644 crypto/verification_in_room.go delete mode 100644 crypto/verification_sas_methods.go diff --git a/crypto/cross_sign_signing.go b/crypto/cross_sign_signing.go index 616fef4a..c0efd54e 100644 --- a/crypto/cross_sign_signing.go +++ b/crypto/cross_sign_signing.go @@ -15,7 +15,6 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures" - "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -35,31 +34,6 @@ var ( ErrMismatchingMasterKeyMAC = errors.New("mismatching cross-signing master key MAC") ) -func (mach *OlmMachine) fetchMasterKey(ctx context.Context, device *id.Device, content *event.VerificationMacEventContent, verState *verificationState, transactionID string) (id.Ed25519, error) { - crossSignKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, device.UserID) - if err != nil { - return "", fmt.Errorf("failed to fetch cross-signing keys: %w", err) - } - masterKey, ok := crossSignKeys[id.XSUsageMaster] - if !ok { - return "", ErrCrossSigningMasterKeyNotFound - } - masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, masterKey.Key.String()) - masterKeyMAC, ok := content.Mac[masterKeyID] - if !ok { - return masterKey.Key, ErrMasterKeyMACNotFound - } - expectedMasterKeyMAC, _, err := mach.getPKAndKeysMAC(verState.sas, device.UserID, device.DeviceID, - mach.Client.UserID, mach.Client.DeviceID, transactionID, masterKey.Key, masterKeyID, content.Mac) - if err != nil { - return masterKey.Key, fmt.Errorf("failed to calculate expected MAC for master key: %w", err) - } - if masterKeyMAC != expectedMasterKeyMAC { - err = fmt.Errorf("%w: expected %s, got %s", ErrMismatchingMasterKeyMAC, expectedMasterKeyMAC, masterKeyMAC) - } - return masterKey.Key, err -} - // SignUser creates a cross-signing signature for a user, stores it and uploads it to the server. func (mach *OlmMachine) SignUser(ctx context.Context, userID id.UserID, masterKey id.Ed25519) error { if userID == mach.Client.UserID { diff --git a/crypto/keysharing.go b/crypto/keysharing.go index bc9bc61a..2e8947f6 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -222,6 +222,21 @@ func (mach *OlmMachine) rejectKeyRequest(ctx context.Context, rejection KeyShare } } +// sendToOneDevice sends a to-device event to a single device. +func (mach *OlmMachine) sendToOneDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID, eventType event.Type, content interface{}) error { + _, err := mach.Client.SendToDevice(ctx, eventType, &mautrix.ReqSendToDevice{ + Messages: map[id.UserID]map[id.DeviceID]*event.Content{ + userID: { + deviceID: { + Parsed: content, + }, + }, + }, + }) + + return err +} + func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Device, evt event.RequestedKeyInfo) *KeyShareRejection { log := mach.machOrContextLog(ctx) if mach.Client.UserID != device.UserID { diff --git a/crypto/machine.go b/crypto/machine.go index e5058ed8..4a691166 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -44,10 +44,6 @@ type OlmMachine struct { AllowKeyShare func(context.Context, *id.Device, event.RequestedKeyInfo) *KeyShareRejection - DefaultSASTimeout time.Duration - // AcceptVerificationFrom determines whether the machine will accept verification requests from this device. - AcceptVerificationFrom func(string, *id.Device, id.RoomID) (VerificationRequestResponse, VerificationHooks) - account *OlmAccount roomKeyRequestFilled *sync.Map @@ -112,12 +108,6 @@ func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Stor SendKeysMinTrust: id.TrustStateUnset, ShareKeysMinTrust: id.TrustStateCrossSignedTOFU, - DefaultSASTimeout: 10 * time.Minute, - AcceptVerificationFrom: func(string, *id.Device, id.RoomID) (VerificationRequestResponse, VerificationHooks) { - // Reject requests by default. Users need to override this to return appropriate verification hooks. - return RejectRequest, nil - }, - roomKeyRequestFilled: &sync.Map{}, keyVerificationTransactionState: &sync.Map{}, @@ -412,19 +402,6 @@ func (mach *OlmMachine) HandleToDeviceEvent(ctx context.Context, evt *event.Even go mach.HandleRoomKeyRequest(ctx, evt.Sender, content) case *event.BeeperRoomKeyAckEventContent: mach.HandleBeeperRoomKeyAck(ctx, evt.Sender, content) - // verification cases - case *event.VerificationStartEventContent: - mach.handleVerificationStart(ctx, evt.Sender, content, content.TransactionID, 10*time.Minute, "") - case *event.VerificationAcceptEventContent: - mach.handleVerificationAccept(ctx, evt.Sender, content, content.TransactionID) - case *event.VerificationKeyEventContent: - mach.handleVerificationKey(ctx, evt.Sender, content, content.TransactionID) - case *event.VerificationMacEventContent: - mach.handleVerificationMAC(ctx, evt.Sender, content, content.TransactionID) - case *event.VerificationCancelEventContent: - mach.handleVerificationCancel(evt.Sender, content, content.TransactionID) - case *event.VerificationRequestEventContent: - mach.handleVerificationRequest(ctx, evt.Sender, content, content.TransactionID, "") case *event.RoomKeyWithheldEventContent: mach.HandleRoomKeyWithheld(ctx, content) case *event.SecretRequestEventContent: diff --git a/crypto/verification.go b/crypto/verification.go deleted file mode 100644 index 78906f89..00000000 --- a/crypto/verification.go +++ /dev/null @@ -1,801 +0,0 @@ -// Copyright (c) 2020 Nikos Filippakis -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package crypto - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "math/rand" - "sort" - "strconv" - "strings" - "sync" - "time" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto/canonicaljson" - "maunium.net/go/mautrix/crypto/olm" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -var ( - ErrUnknownUserForTransaction = errors.New("unknown user for transaction") - ErrTransactionAlreadyExists = errors.New("transaction already exists") - // ErrUnknownTransaction is returned when a key verification message is received with an unknown transaction ID. - ErrUnknownTransaction = errors.New("unknown transaction") - // ErrUnknownVerificationMethod is returned when the verification method in a received m.key.verification.start is unknown. - ErrUnknownVerificationMethod = errors.New("unknown verification method") -) - -type VerificationHooks interface { - // VerifySASMatch receives the generated SAS and its method, as well as the device that is being verified. - // It returns whether the given SAS match with the SAS displayed on other device. - VerifySASMatch(otherDevice *id.Device, sas SASData) bool - // VerificationMethods returns the list of supported verification methods in order of preference. - // It must contain at least the decimal method. - VerificationMethods() []VerificationMethod - OnCancel(cancelledByUs bool, reason string, reasonCode event.VerificationCancelCode) - OnSuccess() -} - -type VerificationRequestResponse int - -const ( - AcceptRequest VerificationRequestResponse = iota - RejectRequest - IgnoreRequest -) - -// sendToOneDevice sends a to-device event to a single device. -func (mach *OlmMachine) sendToOneDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID, eventType event.Type, content interface{}) error { - _, err := mach.Client.SendToDevice(ctx, eventType, &mautrix.ReqSendToDevice{ - Messages: map[id.UserID]map[id.DeviceID]*event.Content{ - userID: { - deviceID: { - Parsed: content, - }, - }, - }, - }) - - return err -} - -func (mach *OlmMachine) getPKAndKeysMAC(sas *olm.SAS, sendingUser id.UserID, sendingDevice id.DeviceID, receivingUser id.UserID, receivingDevice id.DeviceID, - transactionID string, signingKey id.SigningKey, mainKeyID id.KeyID, keys map[id.KeyID]string) (string, string, error) { - sasInfo := "MATRIX_KEY_VERIFICATION_MAC" + - sendingUser.String() + sendingDevice.String() + - receivingUser.String() + receivingDevice.String() + - transactionID - - // get key IDs from key map - keyIDStrings := make([]string, len(keys)) - i := 0 - for keyID := range keys { - keyIDStrings[i] = keyID.String() - i++ - } - sort.Sort(sort.StringSlice(keyIDStrings)) - keyIDString := strings.Join(keyIDStrings, ",") - - pubKeyMac, err := sas.CalculateMAC([]byte(signingKey), []byte(sasInfo+mainKeyID.String())) - if err != nil { - return "", "", err - } - mach.Log.Trace().Msgf("sas.CalculateMAC(\"%s\", \"%s\") -> \"%s\"", signingKey, sasInfo+mainKeyID.String(), string(pubKeyMac)) - - keysMac, err := sas.CalculateMAC([]byte(keyIDString), []byte(sasInfo+"KEY_IDS")) - if err != nil { - return "", "", err - } - mach.Log.Trace().Msgf("sas.CalculateMAC(\"%s\", \"%s\") -> \"%s\"", keyIDString, sasInfo+"KEY_IDS", string(keysMac)) - - return string(pubKeyMac), string(keysMac), nil -} - -// verificationState holds all the information needed for the state of a SAS verification with another device. -type verificationState struct { - sas *olm.SAS - otherDevice *id.Device - initiatedByUs bool - verificationStarted bool - keyReceived bool - sasMatched chan bool - commitment string - startEventCanonical string - chosenSASMethod VerificationMethod - hooks VerificationHooks - extendTimeout context.CancelFunc - inRoomID id.RoomID - lock sync.Mutex -} - -// getTransactionState retrieves the given transaction's state, or cancels the transaction if it cannot be found or there is a mismatch. -func (mach *OlmMachine) getTransactionState(ctx context.Context, transactionID string, userID id.UserID) (*verificationState, error) { - verStateInterface, ok := mach.keyVerificationTransactionState.Load(userID.String() + ":" + transactionID) - if !ok { - _ = mach.SendSASVerificationCancel(ctx, userID, id.DeviceID("*"), transactionID, "Unknown transaction: "+transactionID, event.VerificationCancelUnknownTransaction) - return nil, ErrUnknownTransaction - } - verState := verStateInterface.(*verificationState) - if verState.otherDevice.UserID != userID { - reason := fmt.Sprintf("Unknown user for transaction %v: %v", transactionID, userID) - if verState.inRoomID == "" { - _ = mach.SendSASVerificationCancel(ctx, userID, id.DeviceID("*"), transactionID, reason, event.VerificationCancelUserMismatch) - } else { - _ = mach.SendInRoomSASVerificationCancel(ctx, verState.inRoomID, userID, transactionID, reason, event.VerificationCancelUserMismatch) - } - mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) - return nil, fmt.Errorf("%w %s: %s", ErrUnknownUserForTransaction, transactionID, userID) - } - return verState, nil -} - -// handleVerificationStart handles an incoming m.key.verification.start message. -// It initializes the state for this SAS verification process and stores it. -func (mach *OlmMachine) handleVerificationStart(ctx context.Context, userID id.UserID, content *event.VerificationStartEventContent, transactionID string, timeout time.Duration, inRoomID id.RoomID) { - mach.Log.Debug().Msgf("Received verification start from %v", content.FromDevice) - otherDevice, err := mach.GetOrFetchDevice(ctx, userID, content.FromDevice) - if err != nil { - mach.Log.Error().Msgf("Could not find device %v of user %v", content.FromDevice, userID) - return - } - warnAndCancel := func(logReason, cancelReason string) { - mach.Log.Warn().Msgf("Canceling verification transaction %v as it %s", transactionID, logReason) - if inRoomID == "" { - _ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, cancelReason, event.VerificationCancelUnknownMethod) - } else { - _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, cancelReason, event.VerificationCancelUnknownMethod) - } - } - switch { - case content.Method != event.VerificationMethodSAS: - warnAndCancel("is not SAS", "Only SAS method is supported") - case !content.SupportsKeyAgreementProtocol(event.KeyAgreementCurve25519HKDFSHA256): - warnAndCancel("does not support key agreement protocol curve25519-hkdf-sha256", - "Only curve25519-hkdf-sha256 key agreement protocol is supported") - case !content.SupportsHashMethod(event.VerificationHashSHA256): - warnAndCancel("does not support SHA256 hashing", "Only SHA256 hashing is supported") - case !content.SupportsMACMethod(event.HKDFHMACSHA256): - warnAndCancel("does not support MAC method hkdf-hmac-sha256", "Only hkdf-hmac-sha256 MAC method is supported") - case !content.SupportsSASMethod(event.SASDecimal): - warnAndCancel("does not support decimal SAS", "Decimal SAS method must be supported") - default: - mach.actuallyStartVerification(ctx, userID, content, otherDevice, transactionID, timeout, inRoomID) - } -} - -func (mach *OlmMachine) actuallyStartVerification(ctx context.Context, userID id.UserID, content *event.VerificationStartEventContent, otherDevice *id.Device, transactionID string, timeout time.Duration, inRoomID id.RoomID) { - if inRoomID != "" && transactionID != "" { - verState, err := mach.getTransactionState(ctx, transactionID, userID) - if err != nil { - mach.Log.Error().Msgf("Failed to get transaction state for in-room verification %s start: %v", transactionID, err) - _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Internal state error in gomuks :(", "net.maunium.internal_error") - return - } - mach.timeoutAfter(ctx, verState, transactionID, timeout) - sasMethods := commonSASMethods(verState.hooks, content.ShortAuthenticationString) - err = mach.SendInRoomSASVerificationAccept(ctx, inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods) - if err != nil { - mach.Log.Error().Msgf("Error accepting in-room SAS verification: %v", err) - } - verState.chosenSASMethod = sasMethods[0] - verState.verificationStarted = true - return - } - resp, hooks := mach.AcceptVerificationFrom(transactionID, otherDevice, inRoomID) - if resp == AcceptRequest { - sasMethods := commonSASMethods(hooks, content.ShortAuthenticationString) - if len(sasMethods) == 0 { - mach.Log.Error().Msgf("No common SAS methods: %v", content.ShortAuthenticationString) - if inRoomID == "" { - _ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod) - } else { - _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod) - } - return - } - verState := &verificationState{ - sas: olm.NewSAS(), - otherDevice: otherDevice, - initiatedByUs: false, - verificationStarted: true, - keyReceived: false, - sasMatched: make(chan bool, 1), - hooks: hooks, - chosenSASMethod: sasMethods[0], - inRoomID: inRoomID, - } - verState.lock.Lock() - defer verState.lock.Unlock() - - _, loaded := mach.keyVerificationTransactionState.LoadOrStore(userID.String()+":"+transactionID, verState) - if loaded { - // transaction already exists - mach.Log.Error().Msgf("Transaction %v already exists, canceling", transactionID) - if inRoomID == "" { - _ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage) - } else { - _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage) - } - return - } - - mach.timeoutAfter(ctx, verState, transactionID, timeout) - - var err error - if inRoomID == "" { - err = mach.SendSASVerificationAccept(ctx, userID, content, verState.sas.GetPubkey(), sasMethods) - } else { - err = mach.SendInRoomSASVerificationAccept(ctx, inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods) - } - if err != nil { - mach.Log.Error().Msgf("Error accepting SAS verification: %v", err) - } - } else if resp == RejectRequest { - mach.Log.Debug().Msgf("Not accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) - var err error - if inRoomID == "" { - err = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser) - } else { - err = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser) - } - if err != nil { - mach.Log.Error().Msgf("Error canceling SAS verification: %v", err) - } - } else { - mach.Log.Debug().Msgf("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) - } -} - -func (mach *OlmMachine) timeoutAfter(ctx context.Context, verState *verificationState, transactionID string, timeout time.Duration) { - timeoutCtx, timeoutCancel := context.WithTimeout(ctx, timeout) - verState.extendTimeout = timeoutCancel - go func() { - mapKey := verState.otherDevice.UserID.String() + ":" + transactionID - for { - <-timeoutCtx.Done() - // when timeout context is done - verState.lock.Lock() - // if transaction not active anymore, return - if _, ok := mach.keyVerificationTransactionState.Load(mapKey); !ok { - verState.lock.Unlock() - return - } - if timeoutCtx.Err() == context.DeadlineExceeded { - // if deadline exceeded cancel due to timeout - mach.keyVerificationTransactionState.Delete(mapKey) - _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Timed out", event.VerificationCancelByTimeout) - mach.Log.Warn().Msgf("Verification transaction %v is canceled due to timing out", transactionID) - verState.lock.Unlock() - return - } - // otherwise the cancel func was called, so the timeout is reset - mach.Log.Debug().Msgf("Extending timeout for transaction %v", transactionID) - timeoutCtx, timeoutCancel = context.WithTimeout(context.Background(), timeout) - verState.extendTimeout = timeoutCancel - verState.lock.Unlock() - } - }() -} - -// handleVerificationAccept handles an incoming m.key.verification.accept message. -// It continues the SAS verification process by sending the SAS key message to the other device. -func (mach *OlmMachine) handleVerificationAccept(ctx context.Context, userID id.UserID, content *event.VerificationAcceptEventContent, transactionID string) { - mach.Log.Debug().Msgf("Received verification accept for transaction %v", transactionID) - verState, err := mach.getTransactionState(ctx, transactionID, userID) - if err != nil { - mach.Log.Error().Msgf("Error getting transaction state: %v", err) - return - } - verState.lock.Lock() - defer verState.lock.Unlock() - verState.extendTimeout() - - if !verState.initiatedByUs || verState.verificationStarted { - // unexpected accept at this point - mach.Log.Warn().Msgf("Unexpected verification accept message for transaction %v", transactionID) - mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) - _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Unexpected accept message", event.VerificationCancelUnexpectedMessage) - return - } - - sasMethods := commonSASMethods(verState.hooks, content.ShortAuthenticationString) - if content.KeyAgreementProtocol != event.KeyAgreementCurve25519HKDFSHA256 || - content.Hash != event.VerificationHashSHA256 || - content.MessageAuthenticationCode != event.HKDFHMACSHA256 || - len(sasMethods) == 0 { - - mach.Log.Warn().Msgf("Canceling verification transaction %v due to unknown parameter", transactionID) - mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) - _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Verification uses unknown method", event.VerificationCancelUnknownMethod) - return - } - - key := verState.sas.GetPubkey() - verState.commitment = content.Commitment - verState.chosenSASMethod = sasMethods[0] - verState.verificationStarted = true - - if verState.inRoomID == "" { - err = mach.SendSASVerificationKey(ctx, userID, verState.otherDevice.DeviceID, transactionID, string(key)) - } else { - err = mach.SendInRoomSASVerificationKey(ctx, verState.inRoomID, userID, transactionID, string(key)) - } - if err != nil { - mach.Log.Error().Msgf("Error sending SAS key to other device: %v", err) - return - } -} - -// handleVerificationKey handles an incoming m.key.verification.key message. -// It stores the other device's public key in order to acquire the SAS shared secret. -func (mach *OlmMachine) handleVerificationKey(ctx context.Context, userID id.UserID, content *event.VerificationKeyEventContent, transactionID string) { - mach.Log.Debug().Msgf("Got verification key for transaction %v: %v", transactionID, content.Key) - verState, err := mach.getTransactionState(ctx, transactionID, userID) - if err != nil { - mach.Log.Error().Msgf("Error getting transaction state: %v", err) - return - } - verState.lock.Lock() - defer verState.lock.Unlock() - verState.extendTimeout() - - device := verState.otherDevice - - if !verState.verificationStarted || verState.keyReceived { - // unexpected key at this point - mach.Log.Warn().Msgf("Unexpected verification key message for transaction %v", transactionID) - mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) - _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Unexpected key message", event.VerificationCancelUnexpectedMessage) - return - } - - if err := verState.sas.SetTheirKey([]byte(content.Key)); err != nil { - mach.Log.Error().Msgf("Error setting other device's key: %v", err) - return - } - - verState.keyReceived = true - - if verState.initiatedByUs { - // verify commitment string from accept message now - expectedCommitment := olm.SHA256B64([]byte(content.Key + verState.startEventCanonical)) - mach.Log.Debug().Msgf("Received commitment: %v Expected: %v", verState.commitment, expectedCommitment) - if expectedCommitment != verState.commitment { - mach.Log.Warn().Msgf("Canceling verification transaction %v due to commitment mismatch", transactionID) - mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) - _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Commitment mismatch", event.VerificationCancelCommitmentMismatch) - return - } - } else { - // if verification was initiated by other device, send out our key now - key := verState.sas.GetPubkey() - - if verState.inRoomID == "" { - err = mach.SendSASVerificationKey(ctx, userID, device.DeviceID, transactionID, string(key)) - } else { - err = mach.SendInRoomSASVerificationKey(ctx, verState.inRoomID, userID, transactionID, string(key)) - } - if err != nil { - mach.Log.Error().Msgf("Error sending SAS key to other device: %v", err) - return - } - } - - // compare the SAS keys in a new goroutine and, when the verification is complete, send out the MAC - var initUserID, acceptUserID id.UserID - var initDeviceID, acceptDeviceID id.DeviceID - var initKey, acceptKey string - if verState.initiatedByUs { - initUserID = mach.Client.UserID - initDeviceID = mach.Client.DeviceID - initKey = string(verState.sas.GetPubkey()) - acceptUserID = device.UserID - acceptDeviceID = device.DeviceID - acceptKey = content.Key - } else { - initUserID = device.UserID - initDeviceID = device.DeviceID - initKey = content.Key - acceptUserID = mach.Client.UserID - acceptDeviceID = mach.Client.DeviceID - acceptKey = string(verState.sas.GetPubkey()) - } - // use the prefered SAS method to generate a SAS - sasMethod := verState.chosenSASMethod - sas, err := sasMethod.GetVerificationSAS(initUserID, initDeviceID, initKey, acceptUserID, acceptDeviceID, acceptKey, transactionID, verState.sas) - if err != nil { - mach.Log.Error().Msgf("Error generating SAS (method %v): %v", sasMethod.Type(), err) - return - } - mach.Log.Debug().Msgf("Generated SAS (%v): %v", sasMethod.Type(), sas) - go func() { - result := verState.hooks.VerifySASMatch(device, sas) - mach.sasCompared(ctx, result, transactionID, verState) - }() -} - -// sasCompared is called asynchronously. It waits for the SAS to be compared for the verification to proceed. -// If the SAS match, then our MAC is sent out. Otherwise the transaction is canceled. -func (mach *OlmMachine) sasCompared(ctx context.Context, didMatch bool, transactionID string, verState *verificationState) { - verState.lock.Lock() - defer verState.lock.Unlock() - verState.extendTimeout() - if didMatch { - verState.sasMatched <- true - var err error - if verState.inRoomID == "" { - err = mach.SendSASVerificationMAC(ctx, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas) - } else { - err = mach.SendInRoomSASVerificationMAC(ctx, verState.inRoomID, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas) - } - if err != nil { - mach.Log.Error().Msgf("Error sending verification MAC to other device: %v", err) - } - } else { - verState.sasMatched <- false - } -} - -// handleVerificationMAC handles an incoming m.key.verification.mac message. -// It verifies the other device's MAC and if the MAC is valid it marks the device as trusted. -func (mach *OlmMachine) handleVerificationMAC(ctx context.Context, userID id.UserID, content *event.VerificationMacEventContent, transactionID string) { - mach.Log.Debug().Msgf("Got MAC for verification %v: %v, MAC for keys: %v", transactionID, content.Mac, content.Keys) - verState, err := mach.getTransactionState(ctx, transactionID, userID) - if err != nil { - mach.Log.Error().Msgf("Error getting transaction state: %v", err) - return - } - verState.lock.Lock() - defer verState.lock.Unlock() - verState.extendTimeout() - - device := verState.otherDevice - - // we are done with this SAS verification in all cases so we forget about it - mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) - - if !verState.verificationStarted || !verState.keyReceived { - // unexpected MAC at this point - mach.Log.Warn().Msgf("Unexpected MAC message for transaction %v", transactionID) - _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Unexpected MAC message", event.VerificationCancelUnexpectedMessage) - return - } - - // do this in another goroutine as the match result might take a long time to arrive - go func() { - matched := <-verState.sasMatched - verState.lock.Lock() - defer verState.lock.Unlock() - - if !matched { - mach.Log.Warn().Msgf("SAS do not match! Canceling transaction %v", transactionID) - _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "SAS do not match", event.VerificationCancelSASMismatch) - return - } - - keyID := id.NewKeyID(id.KeyAlgorithmEd25519, device.DeviceID.String()) - - expectedPKMAC, expectedKeysMAC, err := mach.getPKAndKeysMAC(verState.sas, device.UserID, device.DeviceID, - mach.Client.UserID, mach.Client.DeviceID, transactionID, device.SigningKey, keyID, content.Mac) - if err != nil { - mach.Log.Error().Msgf("Error generating MAC to match with received MAC: %v", err) - return - } - - mach.Log.Debug().Msgf("Expected %s keys MAC, got %s", expectedKeysMAC, content.Keys) - if content.Keys != expectedKeysMAC { - mach.Log.Warn().Msgf("Canceling verification transaction %v due to mismatched keys MAC", transactionID) - _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Mismatched keys MACs", event.VerificationCancelKeyMismatch) - return - } - - mach.Log.Debug().Msgf("Expected %s PK MAC, got %s", expectedPKMAC, content.Mac[keyID]) - if content.Mac[keyID] != expectedPKMAC { - mach.Log.Warn().Msgf("Canceling verification transaction %v due to mismatched PK MAC", transactionID) - _ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Mismatched PK MACs", event.VerificationCancelKeyMismatch) - return - } - - // we can finally trust this device - device.Trust = id.TrustStateVerified - err = mach.CryptoStore.PutDevice(ctx, device.UserID, device) - if err != nil { - mach.Log.Warn().Msgf("Failed to put device after verifying: %v", err) - } - - if mach.CrossSigningKeys != nil { - if device.UserID == mach.Client.UserID { - err := mach.SignOwnDevice(ctx, device) - if err != nil { - mach.Log.Error().Msgf("Failed to cross-sign own device %s: %v", device.DeviceID, err) - } else { - mach.Log.Debug().Msgf("Cross-signed own device %v after SAS verification", device.DeviceID) - } - } else { - masterKey, err := mach.fetchMasterKey(ctx, device, content, verState, transactionID) - if err != nil { - mach.Log.Warn().Msgf("Failed to fetch %s's master key: %v", device.UserID, err) - } else { - if err := mach.SignUser(ctx, device.UserID, masterKey); err != nil { - mach.Log.Error().Msgf("Failed to cross-sign master key of %s: %v", device.UserID, err) - } else { - mach.Log.Debug().Msgf("Cross-signed master key of %v after SAS verification", device.UserID) - } - } - } - } else { - // TODO ask user to unlock cross-signing keys? - mach.Log.Debug().Msgf("Cross-signing keys not cached, not signing %s/%s", device.UserID, device.DeviceID) - } - - mach.Log.Debug().Msgf("Device %v of user %v verified successfully!", device.DeviceID, device.UserID) - - verState.hooks.OnSuccess() - }() -} - -// handleVerificationCancel handles an incoming m.key.verification.cancel message. -// It cancels the verification process for the given reason. -func (mach *OlmMachine) handleVerificationCancel(userID id.UserID, content *event.VerificationCancelEventContent, transactionID string) { - // make sure to not reply with a cancel to not cause a loop of cancel messages - // this verification will get canceled even if the senders do not match - verStateInterface, ok := mach.keyVerificationTransactionState.Load(userID.String() + ":" + transactionID) - if ok { - go verStateInterface.(*verificationState).hooks.OnCancel(false, content.Reason, content.Code) - } - - mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) - mach.Log.Warn().Msgf("SAS verification %v was canceled by %v with reason: %v (%v)", - transactionID, userID, content.Reason, content.Code) -} - -// handleVerificationRequest handles an incoming m.key.verification.request message. -func (mach *OlmMachine) handleVerificationRequest(ctx context.Context, userID id.UserID, content *event.VerificationRequestEventContent, transactionID string, inRoomID id.RoomID) { - mach.Log.Debug().Msgf("Received verification request from %v", content.FromDevice) - otherDevice, err := mach.GetOrFetchDevice(ctx, userID, content.FromDevice) - if err != nil { - mach.Log.Error().Msgf("Could not find device %v of user %v", content.FromDevice, userID) - return - } - if !content.SupportsVerificationMethod(event.VerificationMethodSAS) { - mach.Log.Warn().Msgf("Canceling verification transaction %v as SAS is not supported", transactionID) - if inRoomID == "" { - _ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod) - } else { - _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod) - } - return - } - resp, hooks := mach.AcceptVerificationFrom(transactionID, otherDevice, inRoomID) - if resp == AcceptRequest { - mach.Log.Debug().Msgf("Accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) - if inRoomID == "" { - _, err = mach.NewSASVerificationWith(ctx, otherDevice, hooks, transactionID, mach.DefaultSASTimeout) - } else { - if err := mach.SendInRoomSASVerificationReady(ctx, inRoomID, transactionID); err != nil { - mach.Log.Error().Msgf("Error sending in-room SAS verification ready: %v", err) - } - if mach.Client.UserID < otherDevice.UserID { - // up to us to send the start message - _, err = mach.newInRoomSASVerificationWithInner(ctx, inRoomID, otherDevice, hooks, transactionID, mach.DefaultSASTimeout) - } - } - if err != nil { - mach.Log.Error().Msgf("Error accepting SAS verification request: %v", err) - } - } else if resp == RejectRequest { - mach.Log.Debug().Msgf("Rejecting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) - if inRoomID == "" { - _ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser) - } else { - _ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser) - } - } else { - mach.Log.Debug().Msgf("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) - } -} - -// NewSimpleSASVerificationWith starts the SAS verification process with another device with a default timeout, -// a generated transaction ID and support for both emoji and decimal SAS methods. -func (mach *OlmMachine) NewSimpleSASVerificationWith(ctx context.Context, device *id.Device, hooks VerificationHooks) (string, error) { - return mach.NewSASVerificationWith(ctx, device, hooks, "", mach.DefaultSASTimeout) -} - -// NewSASVerificationWith starts the SAS verification process with another device. -// If the other device accepts the verification transaction, the methods in `hooks` will be used to verify the SAS match and to complete the transaction.. -// If the transaction ID is empty, a new one is generated. -func (mach *OlmMachine) NewSASVerificationWith(ctx context.Context, device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) { - if transactionID == "" { - transactionID = strconv.Itoa(rand.Int()) - } - mach.Log.Debug().Msgf("Starting new verification transaction %v with device %v of user %v", transactionID, device.DeviceID, device.UserID) - - verState := &verificationState{ - sas: olm.NewSAS(), - otherDevice: device, - initiatedByUs: true, - verificationStarted: false, - keyReceived: false, - sasMatched: make(chan bool, 1), - hooks: hooks, - } - verState.lock.Lock() - defer verState.lock.Unlock() - - startEvent, err := mach.SendSASVerificationStart(ctx, device.UserID, device.DeviceID, transactionID, hooks.VerificationMethods()) - if err != nil { - return "", err - } - - payload, err := json.Marshal(startEvent) - if err != nil { - return "", err - } - canonical, err := canonicaljson.CanonicalJSON(payload) - if err != nil { - return "", err - } - - verState.startEventCanonical = string(canonical) - _, loaded := mach.keyVerificationTransactionState.LoadOrStore(device.UserID.String()+":"+transactionID, verState) - if loaded { - return "", ErrTransactionAlreadyExists - } - - mach.timeoutAfter(ctx, verState, transactionID, timeout) - - return transactionID, nil -} - -// CancelSASVerification is used by the user to cancel a SAS verification process with the given reason. -func (mach *OlmMachine) CancelSASVerification(ctx context.Context, userID id.UserID, transactionID, reason string) error { - mapKey := userID.String() + ":" + transactionID - verStateInterface, ok := mach.keyVerificationTransactionState.Load(mapKey) - if !ok { - return ErrUnknownTransaction - } - verState := verStateInterface.(*verificationState) - verState.lock.Lock() - defer verState.lock.Unlock() - mach.Log.Trace().Msgf("User canceled verification transaction %v with reason: %v", transactionID, reason) - mach.keyVerificationTransactionState.Delete(mapKey) - return mach.callbackAndCancelSASVerification(ctx, verState, transactionID, reason, event.VerificationCancelByUser) -} - -// SendSASVerificationCancel is used to manually send a SAS cancel message process with the given reason and cancellation code. -func (mach *OlmMachine) SendSASVerificationCancel(ctx context.Context, userID id.UserID, deviceID id.DeviceID, transactionID string, reason string, code event.VerificationCancelCode) error { - content := &event.VerificationCancelEventContent{ - TransactionID: transactionID, - Reason: reason, - Code: code, - } - return mach.sendToOneDevice(ctx, userID, deviceID, event.ToDeviceVerificationCancel, content) -} - -// SendSASVerificationStart is used to manually send the SAS verification start message to another device. -func (mach *OlmMachine) SendSASVerificationStart(ctx context.Context, toUserID id.UserID, toDeviceID id.DeviceID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) { - sasMethods := make([]event.SASMethod, len(methods)) - for i, method := range methods { - sasMethods[i] = method.Type() - } - content := &event.VerificationStartEventContent{ - FromDevice: mach.Client.DeviceID, - TransactionID: transactionID, - Method: event.VerificationMethodSAS, - KeyAgreementProtocols: []event.KeyAgreementProtocol{event.KeyAgreementCurve25519HKDFSHA256}, - Hashes: []event.VerificationHashMethod{event.VerificationHashSHA256}, - MessageAuthenticationCodes: []event.MACMethod{event.HKDFHMACSHA256}, - ShortAuthenticationString: sasMethods, - } - return content, mach.sendToOneDevice(ctx, toUserID, toDeviceID, event.ToDeviceVerificationStart, content) -} - -// SendSASVerificationAccept is used to manually send an accept for a SAS verification process from a received m.key.verification.start event. -func (mach *OlmMachine) SendSASVerificationAccept(ctx context.Context, fromUser id.UserID, startEvent *event.VerificationStartEventContent, publicKey []byte, methods []VerificationMethod) error { - if startEvent.Method != event.VerificationMethodSAS { - reason := "Unknown verification method: " + string(startEvent.Method) - if err := mach.SendSASVerificationCancel(ctx, fromUser, startEvent.FromDevice, startEvent.TransactionID, reason, event.VerificationCancelUnknownMethod); err != nil { - return err - } - return ErrUnknownVerificationMethod - } - payload, err := json.Marshal(startEvent) - if err != nil { - return err - } - canonical, err := canonicaljson.CanonicalJSON(payload) - if err != nil { - return err - } - hash := olm.SHA256B64(append(publicKey, canonical...)) - sasMethods := make([]event.SASMethod, len(methods)) - for i, method := range methods { - sasMethods[i] = method.Type() - } - content := &event.VerificationAcceptEventContent{ - TransactionID: startEvent.TransactionID, - Method: event.VerificationMethodSAS, - KeyAgreementProtocol: event.KeyAgreementCurve25519HKDFSHA256, - Hash: event.VerificationHashSHA256, - MessageAuthenticationCode: event.HKDFHMACSHA256, - ShortAuthenticationString: sasMethods, - Commitment: hash, - } - return mach.sendToOneDevice(ctx, fromUser, startEvent.FromDevice, event.ToDeviceVerificationAccept, content) -} - -func (mach *OlmMachine) callbackAndCancelSASVerification(ctx context.Context, verState *verificationState, transactionID, reason string, code event.VerificationCancelCode) error { - go verState.hooks.OnCancel(true, reason, code) - return mach.SendSASVerificationCancel(ctx, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, reason, code) -} - -// SendSASVerificationKey sends the ephemeral public key for a device to the partner device. -func (mach *OlmMachine) SendSASVerificationKey(ctx context.Context, userID id.UserID, deviceID id.DeviceID, transactionID string, key string) error { - content := &event.VerificationKeyEventContent{ - TransactionID: transactionID, - Key: key, - } - return mach.sendToOneDevice(ctx, userID, deviceID, event.ToDeviceVerificationKey, content) -} - -// SendSASVerificationMAC is use the MAC of a device's key to the partner device. -func (mach *OlmMachine) SendSASVerificationMAC(ctx context.Context, userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error { - keyID := id.NewKeyID(id.KeyAlgorithmEd25519, mach.Client.DeviceID.String()) - - signingKey := mach.account.SigningKey() - keyIDsMap := map[id.KeyID]string{keyID: ""} - macMap := make(map[id.KeyID]string) - - if mach.CrossSigningKeys != nil { - masterKey := mach.CrossSigningKeys.MasterKey.PublicKey - masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, masterKey.String()) - // add master key ID to key map - keyIDsMap[masterKeyID] = "" - masterKeyMAC, _, err := mach.getPKAndKeysMAC(sas, mach.Client.UserID, mach.Client.DeviceID, - userID, deviceID, transactionID, masterKey, masterKeyID, keyIDsMap) - if err != nil { - mach.Log.Error().Msgf("Error generating master key MAC: %v", err) - } else { - mach.Log.Debug().Msgf("Generated master key `%v` MAC: %v", masterKey, masterKeyMAC) - macMap[masterKeyID] = masterKeyMAC - } - } - - pubKeyMac, keysMac, err := mach.getPKAndKeysMAC(sas, mach.Client.UserID, mach.Client.DeviceID, userID, deviceID, transactionID, signingKey, keyID, keyIDsMap) - if err != nil { - return err - } - mach.Log.Debug().Msgf("MAC of key %s is: %s", signingKey, pubKeyMac) - mach.Log.Debug().Msgf("MAC of key ID(s) %s is: %s", keyID, keysMac) - macMap[keyID] = pubKeyMac - - content := &event.VerificationMacEventContent{ - TransactionID: transactionID, - Keys: keysMac, - Mac: macMap, - } - - return mach.sendToOneDevice(ctx, userID, deviceID, event.ToDeviceVerificationMAC, content) -} - -func commonSASMethods(hooks VerificationHooks, otherDeviceMethods []event.SASMethod) []VerificationMethod { - methods := make([]VerificationMethod, 0) - for _, hookMethod := range hooks.VerificationMethods() { - for _, otherMethod := range otherDeviceMethods { - if hookMethod.Type() == otherMethod { - methods = append(methods, hookMethod) - break - } - } - } - return methods -} diff --git a/crypto/verification_in_room.go b/crypto/verification_in_room.go deleted file mode 100644 index a01f0216..00000000 --- a/crypto/verification_in_room.go +++ /dev/null @@ -1,334 +0,0 @@ -// Copyright (c) 2020 Nikos Filippakis -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package crypto - -import ( - "context" - "encoding/json" - "errors" - "time" - - "maunium.net/go/mautrix/crypto/canonicaljson" - "maunium.net/go/mautrix/crypto/olm" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -var ( - ErrNoVerificationFromDevice = errors.New("from_device field is empty") - ErrNoVerificationMethods = errors.New("verification method list is empty") - ErrNoRelatesTo = errors.New("missing m.relates_to info") -) - -// ProcessInRoomVerification is a callback that is to be called when a client receives a message -// related to in-room verification. -// -// Currently this is not automatically called, so you must add the listener yourself. -// Note that in-room verification events are wrapped in m.room.encrypted, but this expects the decrypted events. -func (mach *OlmMachine) ProcessInRoomVerification(evt *event.Event) error { - if evt.Sender == mach.Client.UserID { - // nothing to do if the message is our own - return nil - } - if relatable, ok := evt.Content.Parsed.(event.Relatable); !ok || relatable.OptionalGetRelatesTo() == nil { - return ErrNoRelatesTo - } - - ctx := context.TODO() - switch content := evt.Content.Parsed.(type) { - case *event.MessageEventContent: - if content.MsgType == event.MsgVerificationRequest { - if content.FromDevice == "" { - return ErrNoVerificationFromDevice - } - if content.Methods == nil { - return ErrNoVerificationMethods - } - - newContent := &event.VerificationRequestEventContent{ - FromDevice: content.FromDevice, - Methods: content.Methods, - Timestamp: evt.Timestamp, - TransactionID: evt.ID.String(), - } - mach.handleVerificationRequest(ctx, evt.Sender, newContent, evt.ID.String(), evt.RoomID) - } - case *event.VerificationStartEventContent: - mach.handleVerificationStart(ctx, evt.Sender, content, content.RelatesTo.EventID.String(), 10*time.Minute, evt.RoomID) - case *event.VerificationReadyEventContent: - mach.handleInRoomVerificationReady(ctx, evt.Sender, evt.RoomID, content, content.RelatesTo.EventID.String()) - case *event.VerificationAcceptEventContent: - mach.handleVerificationAccept(ctx, evt.Sender, content, content.RelatesTo.EventID.String()) - case *event.VerificationKeyEventContent: - mach.handleVerificationKey(ctx, evt.Sender, content, content.RelatesTo.EventID.String()) - case *event.VerificationMacEventContent: - mach.handleVerificationMAC(ctx, evt.Sender, content, content.RelatesTo.EventID.String()) - case *event.VerificationCancelEventContent: - mach.handleVerificationCancel(evt.Sender, content, content.RelatesTo.EventID.String()) - } - return nil -} - -// SendInRoomSASVerificationCancel is used to manually send an in-room SAS cancel message process with the given reason and cancellation code. -func (mach *OlmMachine) SendInRoomSASVerificationCancel(ctx context.Context, roomID id.RoomID, userID id.UserID, transactionID string, reason string, code event.VerificationCancelCode) error { - content := &event.VerificationCancelEventContent{ - RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)}, - Reason: reason, - Code: code, - To: userID, - } - - encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationCancel, content) - if err != nil { - return err - } - _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) - return err -} - -// SendInRoomSASVerificationRequest is used to manually send an in-room SAS verification request message to another user. -func (mach *OlmMachine) SendInRoomSASVerificationRequest(ctx context.Context, roomID id.RoomID, toUserID id.UserID, methods []VerificationMethod) (string, error) { - content := &event.MessageEventContent{ - MsgType: event.MsgVerificationRequest, - FromDevice: mach.Client.DeviceID, - Methods: []event.VerificationMethod{event.VerificationMethodSAS}, - To: toUserID, - } - - encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.EventMessage, content) - if err != nil { - return "", err - } - resp, err := mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) - if err != nil { - return "", err - } - return resp.EventID.String(), nil -} - -// SendInRoomSASVerificationReady is used to manually send an in-room SAS verification ready message to another user. -func (mach *OlmMachine) SendInRoomSASVerificationReady(ctx context.Context, roomID id.RoomID, transactionID string) error { - content := &event.VerificationReadyEventContent{ - FromDevice: mach.Client.DeviceID, - Methods: []event.VerificationMethod{event.VerificationMethodSAS}, - RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)}, - } - - encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationReady, content) - if err != nil { - return err - } - _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) - return err -} - -// SendInRoomSASVerificationStart is used to manually send the in-room SAS verification start message to another user. -func (mach *OlmMachine) SendInRoomSASVerificationStart(ctx context.Context, roomID id.RoomID, toUserID id.UserID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) { - sasMethods := make([]event.SASMethod, len(methods)) - for i, method := range methods { - sasMethods[i] = method.Type() - } - content := &event.VerificationStartEventContent{ - FromDevice: mach.Client.DeviceID, - RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)}, - Method: event.VerificationMethodSAS, - KeyAgreementProtocols: []event.KeyAgreementProtocol{event.KeyAgreementCurve25519HKDFSHA256}, - Hashes: []event.VerificationHashMethod{event.VerificationHashSHA256}, - MessageAuthenticationCodes: []event.MACMethod{event.HKDFHMACSHA256}, - ShortAuthenticationString: sasMethods, - To: toUserID, - } - - encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationStart, content) - if err != nil { - return nil, err - } - _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) - return content, err -} - -// SendInRoomSASVerificationAccept is used to manually send an accept for an in-room SAS verification process from a received m.key.verification.start event. -func (mach *OlmMachine) SendInRoomSASVerificationAccept(ctx context.Context, roomID id.RoomID, fromUser id.UserID, startEvent *event.VerificationStartEventContent, transactionID string, publicKey []byte, methods []VerificationMethod) error { - if startEvent.Method != event.VerificationMethodSAS { - reason := "Unknown verification method: " + string(startEvent.Method) - if err := mach.SendInRoomSASVerificationCancel(ctx, roomID, fromUser, transactionID, reason, event.VerificationCancelUnknownMethod); err != nil { - return err - } - return ErrUnknownVerificationMethod - } - payload, err := json.Marshal(startEvent) - if err != nil { - return err - } - canonical, err := canonicaljson.CanonicalJSON(payload) - if err != nil { - return err - } - hash := olm.SHA256B64(append(publicKey, canonical...)) - sasMethods := make([]event.SASMethod, len(methods)) - for i, method := range methods { - sasMethods[i] = method.Type() - } - content := &event.VerificationAcceptEventContent{ - RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)}, - Method: event.VerificationMethodSAS, - KeyAgreementProtocol: event.KeyAgreementCurve25519HKDFSHA256, - Hash: event.VerificationHashSHA256, - MessageAuthenticationCode: event.HKDFHMACSHA256, - ShortAuthenticationString: sasMethods, - Commitment: hash, - To: fromUser, - } - - encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationAccept, content) - if err != nil { - return err - } - _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) - return err -} - -// SendInRoomSASVerificationKey sends the ephemeral public key for a device to the partner device for an in-room verification. -func (mach *OlmMachine) SendInRoomSASVerificationKey(ctx context.Context, roomID id.RoomID, userID id.UserID, transactionID string, key string) error { - content := &event.VerificationKeyEventContent{ - RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)}, - Key: key, - To: userID, - } - - encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationKey, content) - if err != nil { - return err - } - _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) - return err -} - -// SendInRoomSASVerificationMAC sends the MAC of a device's key to the partner device for an in-room verification. -func (mach *OlmMachine) SendInRoomSASVerificationMAC(ctx context.Context, roomID id.RoomID, userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error { - keyID := id.NewKeyID(id.KeyAlgorithmEd25519, mach.Client.DeviceID.String()) - - signingKey := mach.account.SigningKey() - keyIDsMap := map[id.KeyID]string{keyID: ""} - macMap := make(map[id.KeyID]string) - - if mach.CrossSigningKeys != nil { - masterKey := mach.CrossSigningKeys.MasterKey.PublicKey - masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, masterKey.String()) - // add master key ID to key map - keyIDsMap[masterKeyID] = "" - masterKeyMAC, _, err := mach.getPKAndKeysMAC(sas, mach.Client.UserID, mach.Client.DeviceID, - userID, deviceID, transactionID, masterKey, masterKeyID, keyIDsMap) - if err != nil { - mach.Log.Error().Msgf("Error generating master key MAC: %v", err) - } else { - mach.Log.Debug().Msgf("Generated master key `%v` MAC: %v", masterKey, masterKeyMAC) - macMap[masterKeyID] = masterKeyMAC - } - } - - pubKeyMac, keysMac, err := mach.getPKAndKeysMAC(sas, mach.Client.UserID, mach.Client.DeviceID, userID, deviceID, transactionID, signingKey, keyID, keyIDsMap) - if err != nil { - return err - } - mach.Log.Debug().Msgf("MAC of key %s is: %s", signingKey, pubKeyMac) - mach.Log.Debug().Msgf("MAC of key ID(s) %s is: %s", keyID, keysMac) - macMap[keyID] = pubKeyMac - - content := &event.VerificationMacEventContent{ - RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)}, - Keys: keysMac, - Mac: macMap, - To: userID, - } - - encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationMAC, content) - if err != nil { - return err - } - _, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted) - return err -} - -// NewInRoomSASVerificationWith starts the in-room SAS verification process with another user in the given room. -// It returns the generated transaction ID. -func (mach *OlmMachine) NewInRoomSASVerificationWith(ctx context.Context, inRoomID id.RoomID, userID id.UserID, hooks VerificationHooks, timeout time.Duration) (string, error) { - return mach.newInRoomSASVerificationWithInner(ctx, inRoomID, &id.Device{UserID: userID}, hooks, "", timeout) -} - -func (mach *OlmMachine) newInRoomSASVerificationWithInner(ctx context.Context, inRoomID id.RoomID, device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) { - mach.Log.Debug().Msgf("Starting new in-room verification transaction user %v", device.UserID) - - request := transactionID == "" - if request { - var err error - // get new transaction ID from the request message event ID - transactionID, err = mach.SendInRoomSASVerificationRequest(ctx, inRoomID, device.UserID, hooks.VerificationMethods()) - if err != nil { - return "", err - } - } - verState := &verificationState{ - sas: olm.NewSAS(), - otherDevice: device, - initiatedByUs: true, - verificationStarted: false, - keyReceived: false, - sasMatched: make(chan bool, 1), - hooks: hooks, - inRoomID: inRoomID, - } - verState.lock.Lock() - defer verState.lock.Unlock() - - if !request { - // start in-room verification - startEvent, err := mach.SendInRoomSASVerificationStart(ctx, inRoomID, device.UserID, transactionID, hooks.VerificationMethods()) - if err != nil { - return "", err - } - - payload, err := json.Marshal(startEvent) - if err != nil { - return "", err - } - canonical, err := canonicaljson.CanonicalJSON(payload) - if err != nil { - return "", err - } - - verState.startEventCanonical = string(canonical) - } - - mach.keyVerificationTransactionState.Store(device.UserID.String()+":"+transactionID, verState) - - mach.timeoutAfter(ctx, verState, transactionID, timeout) - - return transactionID, nil -} - -func (mach *OlmMachine) handleInRoomVerificationReady(ctx context.Context, userID id.UserID, roomID id.RoomID, content *event.VerificationReadyEventContent, transactionID string) { - device, err := mach.GetOrFetchDevice(ctx, userID, content.FromDevice) - if err != nil { - mach.Log.Error().Msgf("Error fetching device %v of user %v: %v", content.FromDevice, userID, err) - return - } - - verState, err := mach.getTransactionState(ctx, transactionID, userID) - if err != nil { - mach.Log.Error().Msgf("Error getting transaction state: %v", err) - return - } - //mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) - - if mach.Client.UserID < userID { - // up to us to send the start message - verState.lock.Lock() - mach.newInRoomSASVerificationWithInner(ctx, roomID, device, verState.hooks, transactionID, 10*time.Minute) - verState.lock.Unlock() - } -} diff --git a/crypto/verification_sas_methods.go b/crypto/verification_sas_methods.go deleted file mode 100644 index 2d847303..00000000 --- a/crypto/verification_sas_methods.go +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright (c) 2020 Nikos Filippakis -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package crypto - -import ( - "fmt" - - "maunium.net/go/mautrix/crypto/olm" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -// SASData contains the data that users need to verify. -type SASData interface { - Type() event.SASMethod -} - -// VerificationMethod describes a method for generating a SAS. -type VerificationMethod interface { - // GetVerificationSAS uses the user, device ID and key of the user who initiated the verification transaction, - // the user, device ID and key of the user who accepted, the transaction ID and the SAS object to generate a SAS. - // The SAS can be any type, such as an array of numbers or emojis. - GetVerificationSAS(initUserID id.UserID, initDeviceID id.DeviceID, initKey string, - acceptUserID id.UserID, acceptDeviceID id.DeviceID, acceptKey string, - transactionID string, sas *olm.SAS) (SASData, error) - // Type returns the type of this SAS method - Type() event.SASMethod -} - -const sasInfoFormat = "MATRIX_KEY_VERIFICATION_SAS|%s|%s|%s|%s|%s|%s|%s" - -// VerificationMethodDecimal describes the decimal SAS method. -type VerificationMethodDecimal struct{} - -// DecimalSASData contains the verification numbers for the decimal SAS method. -type DecimalSASData [3]uint - -// Type returns the decimal SAS method type. -func (DecimalSASData) Type() event.SASMethod { - return event.SASDecimal -} - -// GetVerificationSAS generates the three numbers that need to match with the other device for a verification to be valid. -func (VerificationMethodDecimal) GetVerificationSAS(initUserID id.UserID, initDeviceID id.DeviceID, initKey string, - acceptUserID id.UserID, acceptDeviceID id.DeviceID, acceptKey string, - transactionID string, sas *olm.SAS) (SASData, error) { - - sasInfo := fmt.Sprintf(sasInfoFormat, - initUserID, initDeviceID, initKey, - acceptUserID, acceptDeviceID, acceptKey, - transactionID) - - sasBytes, err := sas.GenerateBytes([]byte(sasInfo), 5) - if err != nil { - return DecimalSASData{0, 0, 0}, err - } - - numbers := DecimalSASData{ - (uint(sasBytes[0])<<5 | uint(sasBytes[1])>>3) + 1000, - (uint(sasBytes[1]&0x7)<<10 | uint(sasBytes[2])<<2 | uint(sasBytes[3]>>6)) + 1000, - (uint(sasBytes[3]&0x3F)<<7 | uint(sasBytes[4])>>1) + 1000, - } - - return numbers, nil -} - -// Type returns the decimal SAS method type. -func (VerificationMethodDecimal) Type() event.SASMethod { - return event.SASDecimal -} - -var allEmojis = [...]VerificationEmoji{ - {'🐶', "Dog"}, - {'🐱', "Cat"}, - {'🦁', "Lion"}, - {'🐎', "Horse"}, - {'🦄', "Unicorn"}, - {'🐷', "Pig"}, - {'🐘', "Elephant"}, - {'🐰', "Rabbit"}, - {'🐼', "Panda"}, - {'🐓', "Rooster"}, - {'🐧', "Penguin"}, - {'🐢', "Turtle"}, - {'🐟', "Fish"}, - {'🐙', "Octopus"}, - {'🦋', "Butterfly"}, - {'🌷', "Flower"}, - {'🌳', "Tree"}, - {'🌵', "Cactus"}, - {'🍄', "Mushroom"}, - {'🌏', "Globe"}, - {'🌙', "Moon"}, - {'☁', "Cloud"}, - {'🔥', "Fire"}, - {'🍌', "Banana"}, - {'🍎', "Apple"}, - {'🍓', "Strawberry"}, - {'🌽', "Corn"}, - {'🍕', "Pizza"}, - {'🎂', "Cake"}, - {'❤', "Heart"}, - {'😀', "Smiley"}, - {'🤖', "Robot"}, - {'🎩', "Hat"}, - {'👓', "Glasses"}, - {'🔧', "Spanner"}, - {'🎅', "Santa"}, - {'👍', "Thumbs Up"}, - {'☂', "Umbrella"}, - {'⌛', "Hourglass"}, - {'⏰', "Clock"}, - {'🎁', "Gift"}, - {'💡', "Light Bulb"}, - {'📕', "Book"}, - {'✏', "Pencil"}, - {'📎', "Paperclip"}, - {'✂', "Scissors"}, - {'🔒', "Lock"}, - {'🔑', "Key"}, - {'🔨', "Hammer"}, - {'☎', "Telephone"}, - {'🏁', "Flag"}, - {'🚂', "Train"}, - {'🚲', "Bicycle"}, - {'✈', "Aeroplane"}, - {'🚀', "Rocket"}, - {'🏆', "Trophy"}, - {'⚽', "Ball"}, - {'🎸', "Guitar"}, - {'🎺', "Trumpet"}, - {'🔔', "Bell"}, - {'⚓', "Anchor"}, - {'🎧', "Headphones"}, - {'📁', "Folder"}, - {'📌', "Pin"}, -} - -// VerificationEmoji describes an emoji that might be sent for verifying devices. -type VerificationEmoji struct { - Emoji rune - Description string -} - -func (vm VerificationEmoji) GetEmoji() rune { - return vm.Emoji -} - -func (vm VerificationEmoji) GetDescription() string { - return vm.Description -} - -// EmojiSASData contains the verification emojis for the emoji SAS method. -type EmojiSASData [7]VerificationEmoji - -// Type returns the emoji SAS method type. -func (EmojiSASData) Type() event.SASMethod { - return event.SASEmoji -} - -// VerificationMethodEmoji describes the emoji SAS method. -type VerificationMethodEmoji struct{} - -// GetVerificationSAS generates the three numbers that need to match with the other device for a verification to be valid. -func (VerificationMethodEmoji) GetVerificationSAS(initUserID id.UserID, initDeviceID id.DeviceID, initKey string, - acceptUserID id.UserID, acceptDeviceID id.DeviceID, acceptKey string, - transactionID string, sas *olm.SAS) (SASData, error) { - - sasInfo := fmt.Sprintf(sasInfoFormat, - initUserID, initDeviceID, initKey, - acceptUserID, acceptDeviceID, acceptKey, - transactionID) - - var emojis EmojiSASData - sasBytes, err := sas.GenerateBytes([]byte(sasInfo), 6) - - if err != nil { - return emojis, err - } - - sasNum := uint64(sasBytes[0])<<40 | uint64(sasBytes[1])<<32 | uint64(sasBytes[2])<<24 | - uint64(sasBytes[3])<<16 | uint64(sasBytes[4])<<8 | uint64(sasBytes[5]) - - for i := 0; i < len(emojis); i++ { - // take nth group of 6 bits - emojiIdx := (sasNum >> uint(48-(i+1)*6)) & 0x3F - emoji := allEmojis[emojiIdx] - emojis[i] = emoji - } - - return emojis, nil -} - -// Type returns the emoji SAS method type. -func (VerificationMethodEmoji) Type() event.SASMethod { - return event.SASEmoji -} From 6aa214ad1a8701d799c3cb608d8606580a13ed6d Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 19 Jan 2024 14:37:05 -0700 Subject: [PATCH 0110/1647] events: reorder and add done event This commit reorders the events to match what's in the spec, and adds the missing m.key.verification.done event. Signed-off-by: Sumner Evans --- event/content.go | 18 +++++--- event/type.go | 37 +++++++++------ event/verification.go | 104 +++++++++++++++++++++++------------------- 3 files changed, 91 insertions(+), 68 deletions(-) diff --git a/event/content.go b/event/content.go index 6462fce2..0439d9a2 100644 --- a/event/content.go +++ b/event/content.go @@ -57,12 +57,14 @@ var TypeMap = map[Type]reflect.Type{ EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}), EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}), - InRoomVerificationStart: reflect.TypeOf(VerificationStartEventContent{}), InRoomVerificationReady: reflect.TypeOf(VerificationReadyEventContent{}), + InRoomVerificationStart: reflect.TypeOf(VerificationStartEventContent{}), + InRoomVerificationDone: reflect.TypeOf(VerificationDoneEventContent{}), + InRoomVerificationCancel: reflect.TypeOf(VerificationCancelEventContent{}), + InRoomVerificationAccept: reflect.TypeOf(VerificationAcceptEventContent{}), InRoomVerificationKey: reflect.TypeOf(VerificationKeyEventContent{}), InRoomVerificationMAC: reflect.TypeOf(VerificationMacEventContent{}), - InRoomVerificationCancel: reflect.TypeOf(VerificationCancelEventContent{}), ToDeviceRoomKey: reflect.TypeOf(RoomKeyEventContent{}), ToDeviceForwardedRoomKey: reflect.TypeOf(ForwardedRoomKeyEventContent{}), @@ -73,13 +75,15 @@ var TypeMap = map[Type]reflect.Type{ ToDeviceSecretSend: reflect.TypeOf(SecretSendEventContent{}), ToDeviceDummy: reflect.TypeOf(DummyEventContent{}), - ToDeviceVerificationStart: reflect.TypeOf(VerificationStartEventContent{}), - ToDeviceVerificationAccept: reflect.TypeOf(VerificationAcceptEventContent{}), - ToDeviceVerificationKey: reflect.TypeOf(VerificationKeyEventContent{}), - ToDeviceVerificationMAC: reflect.TypeOf(VerificationMacEventContent{}), - ToDeviceVerificationCancel: reflect.TypeOf(VerificationCancelEventContent{}), ToDeviceVerificationRequest: reflect.TypeOf(VerificationRequestEventContent{}), ToDeviceVerificationReady: reflect.TypeOf(VerificationReadyEventContent{}), + ToDeviceVerificationStart: reflect.TypeOf(VerificationStartEventContent{}), + ToDeviceVerificationDone: reflect.TypeOf(VerificationDoneEventContent{}), + ToDeviceVerificationCancel: reflect.TypeOf(VerificationCancelEventContent{}), + + ToDeviceVerificationAccept: reflect.TypeOf(VerificationAcceptEventContent{}), + ToDeviceVerificationKey: reflect.TypeOf(VerificationKeyEventContent{}), + ToDeviceVerificationMAC: reflect.TypeOf(VerificationMacEventContent{}), ToDeviceOrgMatrixRoomKeyWithheld: reflect.TypeOf(RoomKeyWithheldEventContent{}), diff --git a/event/type.go b/event/type.go index e7d47818..a4b36392 100644 --- a/event/type.go +++ b/event/type.go @@ -203,12 +203,15 @@ var ( EventReaction = Type{"m.reaction", MessageEventType} EventSticker = Type{"m.sticker", MessageEventType} - InRoomVerificationStart = Type{"m.key.verification.start", MessageEventType} InRoomVerificationReady = Type{"m.key.verification.ready", MessageEventType} + InRoomVerificationStart = Type{"m.key.verification.start", MessageEventType} + InRoomVerificationDone = Type{"m.key.verification.done", MessageEventType} + InRoomVerificationCancel = Type{"m.key.verification.cancel", MessageEventType} + + // SAS Verification Events InRoomVerificationAccept = Type{"m.key.verification.accept", MessageEventType} InRoomVerificationKey = Type{"m.key.verification.key", MessageEventType} InRoomVerificationMAC = Type{"m.key.verification.mac", MessageEventType} - InRoomVerificationCancel = Type{"m.key.verification.cancel", MessageEventType} CallInvite = Type{"m.call.invite", MessageEventType} CallCandidates = Type{"m.call.candidates", MessageEventType} @@ -246,21 +249,25 @@ var ( // Device-to-device events var ( - ToDeviceRoomKey = Type{"m.room_key", ToDeviceEventType} - ToDeviceRoomKeyRequest = Type{"m.room_key_request", ToDeviceEventType} - ToDeviceForwardedRoomKey = Type{"m.forwarded_room_key", ToDeviceEventType} - ToDeviceEncrypted = Type{"m.room.encrypted", ToDeviceEventType} - ToDeviceRoomKeyWithheld = Type{"m.room_key.withheld", ToDeviceEventType} - ToDeviceSecretRequest = Type{"m.secret.request", ToDeviceEventType} - ToDeviceSecretSend = Type{"m.secret.send", ToDeviceEventType} - ToDeviceDummy = Type{"m.dummy", ToDeviceEventType} + ToDeviceRoomKey = Type{"m.room_key", ToDeviceEventType} + ToDeviceRoomKeyRequest = Type{"m.room_key_request", ToDeviceEventType} + ToDeviceForwardedRoomKey = Type{"m.forwarded_room_key", ToDeviceEventType} + ToDeviceEncrypted = Type{"m.room.encrypted", ToDeviceEventType} + ToDeviceRoomKeyWithheld = Type{"m.room_key.withheld", ToDeviceEventType} + ToDeviceSecretRequest = Type{"m.secret.request", ToDeviceEventType} + ToDeviceSecretSend = Type{"m.secret.send", ToDeviceEventType} + ToDeviceDummy = Type{"m.dummy", ToDeviceEventType} + ToDeviceVerificationRequest = Type{"m.key.verification.request", ToDeviceEventType} - ToDeviceVerificationStart = Type{"m.key.verification.start", ToDeviceEventType} - ToDeviceVerificationAccept = Type{"m.key.verification.accept", ToDeviceEventType} - ToDeviceVerificationKey = Type{"m.key.verification.key", ToDeviceEventType} - ToDeviceVerificationMAC = Type{"m.key.verification.mac", ToDeviceEventType} - ToDeviceVerificationCancel = Type{"m.key.verification.cancel", ToDeviceEventType} ToDeviceVerificationReady = Type{"m.key.verification.ready", ToDeviceEventType} + ToDeviceVerificationStart = Type{"m.key.verification.start", ToDeviceEventType} + ToDeviceVerificationDone = Type{"m.key.verification.done", ToDeviceEventType} + ToDeviceVerificationCancel = Type{"m.key.verification.cancel", ToDeviceEventType} + + // SAS Verification Events + ToDeviceVerificationAccept = Type{"m.key.verification.accept", ToDeviceEventType} + ToDeviceVerificationKey = Type{"m.key.verification.key", ToDeviceEventType} + ToDeviceVerificationMAC = Type{"m.key.verification.mac", ToDeviceEventType} ToDeviceOrgMatrixRoomKeyWithheld = Type{"org.matrix.room_key.withheld", ToDeviceEventType} diff --git a/event/verification.go b/event/verification.go index 66d5abec..9fa8a592 100644 --- a/event/verification.go +++ b/event/verification.go @@ -166,6 +166,64 @@ func (vrec *VerificationReadyEventContent) SetRelatesTo(rel *RelatesTo) { vrec.RelatesTo = rel } +// VerificationDoneEventContent represents the content of a +// m.key.verification.done event as described in [Section 11.12.2.1] of the +// Matrix Spec. +// +// [Section 11.12.2.1]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationdone +type VerificationDoneEventContent struct { + // The opaque identifier for the verification process/request. + TransactionID string `json:"transaction_id,omitempty"` + // Original event ID for in-room verification. + RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` +} + +type VerificationCancelCode string + +const ( + VerificationCancelByUser VerificationCancelCode = "m.user" + VerificationCancelByTimeout VerificationCancelCode = "m.timeout" + VerificationCancelUnknownTransaction VerificationCancelCode = "m.unknown_transaction" + VerificationCancelUnknownMethod VerificationCancelCode = "m.unknown_method" + VerificationCancelUnexpectedMessage VerificationCancelCode = "m.unexpected_message" + VerificationCancelKeyMismatch VerificationCancelCode = "m.key_mismatch" + VerificationCancelUserMismatch VerificationCancelCode = "m.user_mismatch" + VerificationCancelInvalidMessage VerificationCancelCode = "m.invalid_message" + VerificationCancelAccepted VerificationCancelCode = "m.accepted" + VerificationCancelSASMismatch VerificationCancelCode = "m.mismatched_sas" + VerificationCancelCommitmentMismatch VerificationCancelCode = "m.mismatched_commitment" +) + +// VerificationCancelEventContent represents the content of a m.key.verification.cancel to_device event. +// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationcancel +type VerificationCancelEventContent struct { + // The opaque identifier for the verification process/request. + TransactionID string `json:"transaction_id,omitempty"` + // A human readable description of the code. The client should only rely on this string if it does not understand the code. + Reason string `json:"reason"` + // The error code for why the process/request was cancelled by the user. + Code VerificationCancelCode `json:"code"` + // The user that the event is sent to for in-room verification. + To id.UserID `json:"to,omitempty"` + // Original event ID for in-room verification. + RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` +} + +func (vcec *VerificationCancelEventContent) GetRelatesTo() *RelatesTo { + if vcec.RelatesTo == nil { + vcec.RelatesTo = &RelatesTo{} + } + return vcec.RelatesTo +} + +func (vcec *VerificationCancelEventContent) OptionalGetRelatesTo() *RelatesTo { + return vcec.RelatesTo +} + +func (vcec *VerificationCancelEventContent) SetRelatesTo(rel *RelatesTo) { + vcec.RelatesTo = rel +} + // VerificationAcceptEventContent represents the content of a m.key.verification.accept to_device event. // https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationaccept type VerificationAcceptEventContent struct { @@ -261,49 +319,3 @@ func (vmec *VerificationMacEventContent) OptionalGetRelatesTo() *RelatesTo { func (vmec *VerificationMacEventContent) SetRelatesTo(rel *RelatesTo) { vmec.RelatesTo = rel } - -type VerificationCancelCode string - -const ( - VerificationCancelByUser VerificationCancelCode = "m.user" - VerificationCancelByTimeout VerificationCancelCode = "m.timeout" - VerificationCancelUnknownTransaction VerificationCancelCode = "m.unknown_transaction" - VerificationCancelUnknownMethod VerificationCancelCode = "m.unknown_method" - VerificationCancelUnexpectedMessage VerificationCancelCode = "m.unexpected_message" - VerificationCancelKeyMismatch VerificationCancelCode = "m.key_mismatch" - VerificationCancelUserMismatch VerificationCancelCode = "m.user_mismatch" - VerificationCancelInvalidMessage VerificationCancelCode = "m.invalid_message" - VerificationCancelAccepted VerificationCancelCode = "m.accepted" - VerificationCancelSASMismatch VerificationCancelCode = "m.mismatched_sas" - VerificationCancelCommitmentMismatch VerificationCancelCode = "m.mismatched_commitment" -) - -// VerificationCancelEventContent represents the content of a m.key.verification.cancel to_device event. -// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationcancel -type VerificationCancelEventContent struct { - // The opaque identifier for the verification process/request. - TransactionID string `json:"transaction_id,omitempty"` - // A human readable description of the code. The client should only rely on this string if it does not understand the code. - Reason string `json:"reason"` - // The error code for why the process/request was cancelled by the user. - Code VerificationCancelCode `json:"code"` - // The user that the event is sent to for in-room verification. - To id.UserID `json:"to,omitempty"` - // Original event ID for in-room verification. - RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` -} - -func (vcec *VerificationCancelEventContent) GetRelatesTo() *RelatesTo { - if vcec.RelatesTo == nil { - vcec.RelatesTo = &RelatesTo{} - } - return vcec.RelatesTo -} - -func (vcec *VerificationCancelEventContent) OptionalGetRelatesTo() *RelatesTo { - return vcec.RelatesTo -} - -func (vcec *VerificationCancelEventContent) SetRelatesTo(rel *RelatesTo) { - vcec.RelatesTo = rel -} From ef5eb04ff81b559efd1f94c080470277d9f71fc1 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 19 Jan 2024 14:49:59 -0700 Subject: [PATCH 0111/1647] event/verification: add QR code verification method constants Signed-off-by: Sumner Evans --- event/verification.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/event/verification.go b/event/verification.go index 9fa8a592..8540c737 100644 --- a/event/verification.go +++ b/event/verification.go @@ -12,7 +12,13 @@ import ( type VerificationMethod string -const VerificationMethodSAS VerificationMethod = "m.sas.v1" +const ( + VerificationMethodSAS VerificationMethod = "m.sas.v1" + + VerificationMethodReciprocate VerificationMethod = "m.reciprocate.v1" + VerificationMethodQRCodeShow VerificationMethod = "m.qr_code.show.v1" + VerificationMethodQRCodeScan VerificationMethod = "m.qr_code.scan.v1" +) // VerificationRequestEventContent represents the content of a m.key.verification.request to_device event. // https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationrequest From 3aec0a3a6ffca8225df3023bbb1545eace6d5bd3 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 19 Jan 2024 16:03:50 -0700 Subject: [PATCH 0112/1647] event/verification: improve documentation and other cleanup Signed-off-by: Sumner Evans --- crypto/backup/encryptedsessiondata.go | 26 +- event/verification.go | 476 +++++++++++++------------- 2 files changed, 242 insertions(+), 260 deletions(-) diff --git a/crypto/backup/encryptedsessiondata.go b/crypto/backup/encryptedsessiondata.go index ccaea0c4..8ac74151 100644 --- a/crypto/backup/encryptedsessiondata.go +++ b/crypto/backup/encryptedsessiondata.go @@ -6,10 +6,10 @@ import ( "crypto/hmac" "crypto/rand" "crypto/sha256" - "encoding/base64" "encoding/json" "errors" + "go.mau.fi/util/jsonbytes" "golang.org/x/crypto/hkdf" "maunium.net/go/mautrix/crypto/aescbc" @@ -17,24 +17,6 @@ import ( var ErrInvalidMAC = errors.New("invalid MAC") -// UnpaddedBytes is a byte slice that is encoded and decoded using -// [base64.RawStdEncoding]. -type UnpaddedBytes []byte - -func (b UnpaddedBytes) MarshalJSON() ([]byte, error) { - return json.Marshal(base64.RawStdEncoding.EncodeToString(b)) -} - -func (b *UnpaddedBytes) UnmarshalJSON(data []byte) error { - var b64str string - err := json.Unmarshal(data, &b64str) - if err != nil { - return err - } - *b, err = base64.RawStdEncoding.DecodeString(b64str) - return err -} - // EncryptedSessionData is the encrypted session_data field of a key backup as // defined in [Section 11.12.3.2.2 of the Spec]. // @@ -43,9 +25,9 @@ func (b *UnpaddedBytes) UnmarshalJSON(data []byte) error { // // [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 type EncryptedSessionData[T any] struct { - Ciphertext UnpaddedBytes `json:"ciphertext"` - Ephemeral EphemeralKey `json:"ephemeral"` - MAC UnpaddedBytes `json:"mac"` + Ciphertext jsonbytes.UnpaddedBytes `json:"ciphertext"` + Ephemeral EphemeralKey `json:"ephemeral"` + MAC jsonbytes.UnpaddedBytes `json:"mac"` } func calculateEncryptionParameters(sharedSecret []byte) (key, macKey, iv []byte, err error) { diff --git a/event/verification.go b/event/verification.go index 8540c737..60fcb9d4 100644 --- a/event/verification.go +++ b/event/verification.go @@ -7,6 +7,10 @@ package event import ( + "go.mau.fi/util/jsonbytes" + "go.mau.fi/util/jsontime" + "golang.org/x/exp/slices" + "maunium.net/go/mautrix/id" ) @@ -20,308 +24,304 @@ const ( VerificationMethodQRCodeScan VerificationMethod = "m.qr_code.scan.v1" ) -// VerificationRequestEventContent represents the content of a m.key.verification.request to_device event. -// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationrequest -type VerificationRequestEventContent struct { - // The device ID which is initiating the request. - FromDevice id.DeviceID `json:"from_device"` - // An opaque identifier for the verification request. Must be unique with respect to the devices involved. - TransactionID string `json:"transaction_id,omitempty"` - // The verification methods supported by the sender. - Methods []VerificationMethod `json:"methods"` - // The POSIX timestamp in milliseconds for when the request was made. - Timestamp int64 `json:"timestamp,omitempty"` - // The user that the event is sent to for in-room verification. - To id.UserID `json:"to,omitempty"` - // Original event ID for in-room verification. +type VerificationTransactionable interface { + GetTransactionID() id.VerificationTransactionID + SetTransactionID(id.VerificationTransactionID) +} + +// ToDeviceVerificationEvent contains the fields common to all to-device +// verification events. +type ToDeviceVerificationEvent struct { + // TransactionID is an opaque identifier for the verification request. Must + // be unique with respect to the devices involved. + TransactionID id.VerificationTransactionID `json:"transaction_id,omitempty"` +} + +var _ VerificationTransactionable = (*ToDeviceVerificationEvent)(nil) + +func (ve *ToDeviceVerificationEvent) GetTransactionID() id.VerificationTransactionID { + return ve.TransactionID +} + +func (ve *ToDeviceVerificationEvent) SetTransactionID(id id.VerificationTransactionID) { + ve.TransactionID = id +} + +// InRoomVerificationEvent contains the fields common to all in-room +// verification events. +type InRoomVerificationEvent struct { + // RelatesTo indicates the m.key.verification.request that this message is + // related to. Note that for encrypted messages, this property should be in + // the unencrypted portion of the event. RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` } -func (vrec *VerificationRequestEventContent) SupportsVerificationMethod(meth VerificationMethod) bool { - for _, supportedMeth := range vrec.Methods { - if supportedMeth == meth { - return true - } +var _ Relatable = (*InRoomVerificationEvent)(nil) + +func (ve *InRoomVerificationEvent) GetRelatesTo() *RelatesTo { + if ve.RelatesTo == nil { + ve.RelatesTo = &RelatesTo{} } - return false + return ve.RelatesTo +} + +func (ve *InRoomVerificationEvent) OptionalGetRelatesTo() *RelatesTo { + return ve.RelatesTo +} + +func (ve *InRoomVerificationEvent) SetRelatesTo(rel *RelatesTo) { + ve.RelatesTo = rel +} + +// VerificationRequestEventContent represents the content of an +// [m.key.verification.request] to-device event as described in [Section +// 11.12.2.1] of the Spec. +// +// For the in-room version, use a standard [MessageEventContent] struct. +// +// [m.key.verification.request]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationrequest +// [Section 11.12.2.1]: https://spec.matrix.org/v1.9/client-server-api/#key-verification-framework +type VerificationRequestEventContent struct { + ToDeviceVerificationEvent + // FromDevice is the device ID which is initiating the request. + FromDevice id.DeviceID `json:"from_device"` + // Methods is a list of the verification methods supported by the sender. + Methods []VerificationMethod `json:"methods"` + // Timestamp is the time at which the request was made. + Timestamp jsontime.UnixMilli `json:"timestamp,omitempty"` +} + +// VerificationRequestEventContentFromMessage converts an in-room verification +// request message event to a [VerificationRequestEventContent]. +func VerificationRequestEventContentFromMessage(evt *Event) *VerificationRequestEventContent { + content := evt.Content.AsMessage() + return &VerificationRequestEventContent{ + ToDeviceVerificationEvent: ToDeviceVerificationEvent{ + TransactionID: id.VerificationTransactionID(evt.ID), + }, + Timestamp: jsontime.UMInt(evt.Timestamp), + FromDevice: content.FromDevice, + Methods: content.Methods, + } +} + +// SupportsVerificationMethod returns whether the given verification method is +// supported by the sender. +func (vrec *VerificationRequestEventContent) SupportsVerificationMethod(method VerificationMethod) bool { + return slices.Contains(vrec.Methods, method) +} + +// VerificationReadyEventContent represents the content of an +// [m.key.verification.ready] event (both the to-device and the in-room +// version) as described in [Section 11.12.2.1] of the Spec. +// +// [m.key.verification.ready]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationready +// [Section 11.12.2.1]: https://spec.matrix.org/v1.9/client-server-api/#key-verification-framework +type VerificationReadyEventContent struct { + ToDeviceVerificationEvent + InRoomVerificationEvent + + // FromDevice is the device ID which is initiating the request. + FromDevice id.DeviceID `json:"from_device"` + // Methods is a list of the verification methods supported by the sender. + Methods []VerificationMethod `json:"methods"` } type KeyAgreementProtocol string const ( - KeyAgreementCurve25519 KeyAgreementProtocol = "curve25519" - KeyAgreementCurve25519HKDFSHA256 KeyAgreementProtocol = "curve25519-hkdf-sha256" + KeyAgreementProtocolCurve25519 KeyAgreementProtocol = "curve25519" + KeyAgreementProtocolCurve25519HKDFSHA256 KeyAgreementProtocol = "curve25519-hkdf-sha256" ) type VerificationHashMethod string -const VerificationHashSHA256 VerificationHashMethod = "sha256" +const VerificationHashMethodSHA256 VerificationHashMethod = "sha256" type MACMethod string -const HKDFHMACSHA256 MACMethod = "hkdf-hmac-sha256" +const ( + MACMethodHKDFHMACSHA256 MACMethod = "hkdf-hmac-sha256" + MACMethodHKDFHMACSHA256V2 MACMethod = "hkdf-hmac-sha256.v2" +) type SASMethod string const ( - SASDecimal SASMethod = "decimal" - SASEmoji SASMethod = "emoji" + SASMethodDecimal SASMethod = "decimal" + SASMethodEmoji SASMethod = "emoji" ) -// VerificationStartEventContent represents the content of a m.key.verification.start to_device event. -// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationstartmsasv1 +// VerificationStartEventContent represents the content of an +// [m.key.verification.start] event (both the to-device and the in-room +// version) as described in [Section 11.12.2.1] of the Spec. +// +// This struct also contains the fields for an [m.key.verification.start] event +// using the [VerificationMethodSAS] method as described in [Section +// 11.12.2.2.2] and an [m.key.verification.start] using +// [VerificationMethodReciprocate] as described in [Section 11.12.2.4.2]. +// +// [m.key.verification.start]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationstart +// [Section 11.12.2.1]: https://spec.matrix.org/v1.9/client-server-api/#key-verification-framework +// [Section 11.12.2.2.2]: https://spec.matrix.org/v1.9/client-server-api/#verification-messages-specific-to-sas +// [Section 11.12.2.4.2]: https://spec.matrix.org/v1.9/client-server-api/#verification-messages-specific-to-qr-codes type VerificationStartEventContent struct { - // The device ID which is initiating the process. + ToDeviceVerificationEvent + InRoomVerificationEvent + + // FromDevice is the device ID which is initiating the request. FromDevice id.DeviceID `json:"from_device"` - // An opaque identifier for the verification process. Must be unique with respect to the devices involved. - TransactionID string `json:"transaction_id,omitempty"` - // The verification method to use. + // Method is the verification method to use. Method VerificationMethod `json:"method"` - // The key agreement protocols the sending device understands. - KeyAgreementProtocols []KeyAgreementProtocol `json:"key_agreement_protocols"` - // The hash methods the sending device understands. - Hashes []VerificationHashMethod `json:"hashes"` - // The message authentication codes that the sending device understands. + // NextMethod is an optional method to use to verify the other user's key. + // Applicable when the method chosen only verifies one user’s key. This + // field will never be present if the method verifies keys both ways. + NextMethod VerificationMethod `json:"next_method,omitempty"` + + // Hashes are the hash methods the sending device understands. This field + // is only applicable when the method is m.sas.v1. + Hashes []VerificationHashMethod `json:"hashes,omitempty"` + // KeyAgreementProtocols is the list of key agreement protocols the sending + // device understands. This field is only applicable when the method is + // m.sas.v1. + KeyAgreementProtocols []KeyAgreementProtocol `json:"key_agreement_protocols,omitempty"` + // MessageAuthenticationCodes is a list of the MAC methods that the sending + // device understands. This field is only applicable when the method is + // m.sas.v1. MessageAuthenticationCodes []MACMethod `json:"message_authentication_codes"` - // The SAS methods the sending device (and the sending device's user) understands. + // ShortAuthenticationString is a list of SAS methods the sending device + // (and the sending device's user) understands. This field is only + // applicable when the method is m.sas.v1. ShortAuthenticationString []SASMethod `json:"short_authentication_string"` - // The user that the event is sent to for in-room verification. - To id.UserID `json:"to,omitempty"` - // Original event ID for in-room verification. - RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` + + // Secret is the shared secret from the QR code. This field is only + // applicable when the method is m.reciprocate.v1. + Secret jsonbytes.UnpaddedBytes `json:"secret,omitempty"` } func (vsec *VerificationStartEventContent) SupportsKeyAgreementProtocol(proto KeyAgreementProtocol) bool { - for _, supportedProto := range vsec.KeyAgreementProtocols { - if supportedProto == proto { - return true - } - } - return false + return slices.Contains(vsec.KeyAgreementProtocols, proto) } func (vsec *VerificationStartEventContent) SupportsHashMethod(alg VerificationHashMethod) bool { - for _, supportedAlg := range vsec.Hashes { - if supportedAlg == alg { - return true - } - } - return false + return slices.Contains(vsec.Hashes, alg) } func (vsec *VerificationStartEventContent) SupportsMACMethod(meth MACMethod) bool { - for _, supportedMeth := range vsec.MessageAuthenticationCodes { - if supportedMeth == meth { - return true - } - } - return false + return slices.Contains(vsec.MessageAuthenticationCodes, meth) } func (vsec *VerificationStartEventContent) SupportsSASMethod(meth SASMethod) bool { - for _, supportedMeth := range vsec.ShortAuthenticationString { - if supportedMeth == meth { - return true - } - } - return false + return slices.Contains(vsec.ShortAuthenticationString, meth) } -func (vsec *VerificationStartEventContent) GetRelatesTo() *RelatesTo { - if vsec.RelatesTo == nil { - vsec.RelatesTo = &RelatesTo{} - } - return vsec.RelatesTo -} - -func (vsec *VerificationStartEventContent) OptionalGetRelatesTo() *RelatesTo { - return vsec.RelatesTo -} - -func (vsec *VerificationStartEventContent) SetRelatesTo(rel *RelatesTo) { - vsec.RelatesTo = rel -} - -// VerificationReadyEventContent represents the content of a m.key.verification.ready event. -// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationready -type VerificationReadyEventContent struct { - // The device ID which accepted the process. - FromDevice id.DeviceID `json:"from_device"` - // An opaque identifier for the verification process. Must be unique with respect to the devices involved. - TransactionID string `json:"transaction_id,omitempty"` - // The verification methods supported by the sender. - Methods []VerificationMethod `json:"methods"` - // Original event ID for in-room verification. - RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` -} - -var _ Relatable = (*VerificationReadyEventContent)(nil) - -func (vrec *VerificationReadyEventContent) GetRelatesTo() *RelatesTo { - if vrec.RelatesTo == nil { - vrec.RelatesTo = &RelatesTo{} - } - return vrec.RelatesTo -} - -func (vrec *VerificationReadyEventContent) OptionalGetRelatesTo() *RelatesTo { - return vrec.RelatesTo -} - -func (vrec *VerificationReadyEventContent) SetRelatesTo(rel *RelatesTo) { - vrec.RelatesTo = rel -} - -// VerificationDoneEventContent represents the content of a -// m.key.verification.done event as described in [Section 11.12.2.1] of the -// Matrix Spec. +// VerificationDoneEventContent represents the content of an +// [m.key.verification.done] event (both the to-device and the in-room version) +// as described in [Section 11.12.2.1] of the Spec. // +// This type is an alias for [VerificationRelatable] since there are no +// additional fields defined by the spec. +// +// [m.key.verification.done]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationdone // [Section 11.12.2.1]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationdone type VerificationDoneEventContent struct { - // The opaque identifier for the verification process/request. - TransactionID string `json:"transaction_id,omitempty"` - // Original event ID for in-room verification. - RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` + ToDeviceVerificationEvent + InRoomVerificationEvent } type VerificationCancelCode string const ( - VerificationCancelByUser VerificationCancelCode = "m.user" - VerificationCancelByTimeout VerificationCancelCode = "m.timeout" - VerificationCancelUnknownTransaction VerificationCancelCode = "m.unknown_transaction" - VerificationCancelUnknownMethod VerificationCancelCode = "m.unknown_method" - VerificationCancelUnexpectedMessage VerificationCancelCode = "m.unexpected_message" - VerificationCancelKeyMismatch VerificationCancelCode = "m.key_mismatch" - VerificationCancelUserMismatch VerificationCancelCode = "m.user_mismatch" - VerificationCancelInvalidMessage VerificationCancelCode = "m.invalid_message" - VerificationCancelAccepted VerificationCancelCode = "m.accepted" - VerificationCancelSASMismatch VerificationCancelCode = "m.mismatched_sas" - VerificationCancelCommitmentMismatch VerificationCancelCode = "m.mismatched_commitment" + VerificationCancelCodeUser VerificationCancelCode = "m.user" + VerificationCancelCodeTimeout VerificationCancelCode = "m.timeout" + VerificationCancelCodeUnknownTransaction VerificationCancelCode = "m.unknown_transaction" + VerificationCancelCodeUnknownMethod VerificationCancelCode = "m.unknown_method" + VerificationCancelCodeUnexpectedMessage VerificationCancelCode = "m.unexpected_message" + VerificationCancelCodeKeyMismatch VerificationCancelCode = "m.key_mismatch" + VerificationCancelCodeUserMismatch VerificationCancelCode = "m.user_mismatch" + VerificationCancelCodeInvalidMessage VerificationCancelCode = "m.invalid_message" + VerificationCancelCodeAccepted VerificationCancelCode = "m.accepted" + VerificationCancelCodeSASMismatch VerificationCancelCode = "m.mismatched_sas" + VerificationCancelCodeCommitmentMismatch VerificationCancelCode = "m.mismatched_commitment" ) -// VerificationCancelEventContent represents the content of a m.key.verification.cancel to_device event. -// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationcancel +// VerificationCancelEventContent represents the content of an +// [m.key.verification.cancel] event (both the to-device and the in-room +// version) as described in [Section 11.12.2.1] of the Spec. +// +// [m.key.verification.cancel]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationcancel +// [Section 11.12.2.1]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationdone type VerificationCancelEventContent struct { - // The opaque identifier for the verification process/request. - TransactionID string `json:"transaction_id,omitempty"` - // A human readable description of the code. The client should only rely on this string if it does not understand the code. - Reason string `json:"reason"` - // The error code for why the process/request was cancelled by the user. + ToDeviceVerificationEvent + InRoomVerificationEvent + + // Code is the error code for why the process/request was cancelled by the + // user. Code VerificationCancelCode `json:"code"` - // The user that the event is sent to for in-room verification. - To id.UserID `json:"to,omitempty"` - // Original event ID for in-room verification. - RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` + // Reason is a human readable description of the code. The client should + // only rely on this string if it does not understand the code. + Reason string `json:"reason"` } -func (vcec *VerificationCancelEventContent) GetRelatesTo() *RelatesTo { - if vcec.RelatesTo == nil { - vcec.RelatesTo = &RelatesTo{} - } - return vcec.RelatesTo -} - -func (vcec *VerificationCancelEventContent) OptionalGetRelatesTo() *RelatesTo { - return vcec.RelatesTo -} - -func (vcec *VerificationCancelEventContent) SetRelatesTo(rel *RelatesTo) { - vcec.RelatesTo = rel -} - -// VerificationAcceptEventContent represents the content of a m.key.verification.accept to_device event. -// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationaccept +// VerificationAcceptEventContent represents the content of an +// [m.key.verification.accept] event (both the to-device and the in-room +// version) as described in [Section 11.12.2.2.2] of the Spec. +// +// [m.key.verification.accept]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationaccept +// [Section 11.12.2.2.2]: https://spec.matrix.org/v1.9/client-server-api/#verification-messages-specific-to-sas type VerificationAcceptEventContent struct { - // An opaque identifier for the verification process. Must be the same as the one used for the m.key.verification.start message. - TransactionID string `json:"transaction_id,omitempty"` - // The verification method to use. - Method VerificationMethod `json:"method"` - // The key agreement protocol the device is choosing to use, out of the options in the m.key.verification.start message. - KeyAgreementProtocol KeyAgreementProtocol `json:"key_agreement_protocol"` - // The hash method the device is choosing to use, out of the options in the m.key.verification.start message. + ToDeviceVerificationEvent + InRoomVerificationEvent + + // Commitment is the hash of the concatenation of the device's ephemeral + // public key (encoded as unpadded base64) and the canonical JSON + // representation of the m.key.verification.start message. + Commitment jsonbytes.UnpaddedBytes `json:"commitment"` + // Hash is the hash method the device is choosing to use, out of the + // options in the m.key.verification.start message. Hash VerificationHashMethod `json:"hash"` - // The message authentication code the device is choosing to use, out of the options in the m.key.verification.start message. + // KeyAgreementProtocol is the key agreement protocol the device is + // choosing to use, out of the options in the m.key.verification.start + // message. + KeyAgreementProtocol KeyAgreementProtocol `json:"key_agreement_protocol"` + // MessageAuthenticationCode is the message authentication code the device + // is choosing to use, out of the options in the m.key.verification.start + // message. MessageAuthenticationCode MACMethod `json:"message_authentication_code"` - // The SAS methods both devices involved in the verification process understand. Must be a subset of the options in the m.key.verification.start message. + // ShortAuthenticationString is a list of SAS methods both devices involved + // in the verification process understand. Must be a subset of the options + // in the m.key.verification.start message. ShortAuthenticationString []SASMethod `json:"short_authentication_string"` - // The hash (encoded as unpadded base64) of the concatenation of the device's ephemeral public key (encoded as unpadded base64) and the canonical JSON representation of the m.key.verification.start message. - Commitment string `json:"commitment"` - // The user that the event is sent to for in-room verification. - To id.UserID `json:"to,omitempty"` - // Original event ID for in-room verification. - RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` } -func (vaec *VerificationAcceptEventContent) GetRelatesTo() *RelatesTo { - if vaec.RelatesTo == nil { - vaec.RelatesTo = &RelatesTo{} - } - return vaec.RelatesTo -} - -func (vaec *VerificationAcceptEventContent) OptionalGetRelatesTo() *RelatesTo { - return vaec.RelatesTo -} - -func (vaec *VerificationAcceptEventContent) SetRelatesTo(rel *RelatesTo) { - vaec.RelatesTo = rel -} - -// VerificationKeyEventContent represents the content of a m.key.verification.key to_device event. -// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationkey +// VerificationKeyEventContent represents the content of an +// [m.key.verification.key] event (both the to-device and the in-room version) +// as described in [Section 11.12.2.2.2] of the Spec. +// +// [m.key.verification.key]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationkey +// [Section 11.12.2.2.2]: https://spec.matrix.org/v1.9/client-server-api/#verification-messages-specific-to-sas type VerificationKeyEventContent struct { - // An opaque identifier for the verification process. Must be the same as the one used for the m.key.verification.start message. - TransactionID string `json:"transaction_id,omitempty"` - // The device's ephemeral public key, encoded as unpadded base64. - Key string `json:"key"` - // The user that the event is sent to for in-room verification. - To id.UserID `json:"to,omitempty"` - // Original event ID for in-room verification. - RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` + ToDeviceVerificationEvent + InRoomVerificationEvent + + // Key is the device’s ephemeral public key. + Key jsonbytes.UnpaddedBytes `json:"key"` } -func (vkec *VerificationKeyEventContent) GetRelatesTo() *RelatesTo { - if vkec.RelatesTo == nil { - vkec.RelatesTo = &RelatesTo{} - } - return vkec.RelatesTo -} - -func (vkec *VerificationKeyEventContent) OptionalGetRelatesTo() *RelatesTo { - return vkec.RelatesTo -} - -func (vkec *VerificationKeyEventContent) SetRelatesTo(rel *RelatesTo) { - vkec.RelatesTo = rel -} - -// VerificationMacEventContent represents the content of a m.key.verification.mac to_device event. -// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationmac +// VerificationMacEventContent represents the content of an +// [m.key.verification.mac] event (both the to-device and the in-room version) +// as described in [Section 11.12.2.2.2] of the Spec. +// +// [m.key.verification.mac]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationmac +// [Section 11.12.2.2.2]: https://spec.matrix.org/v1.9/client-server-api/#verification-messages-specific-to-sas type VerificationMacEventContent struct { - // An opaque identifier for the verification process. Must be the same as the one used for the m.key.verification.start message. - TransactionID string `json:"transaction_id,omitempty"` - // A map of the key ID to the MAC of the key, using the algorithm in the verification process. The MAC is encoded as unpadded base64. - Mac map[id.KeyID]string `json:"mac"` - // The MAC of the comma-separated, sorted, list of key IDs given in the mac property, encoded as unpadded base64. - Keys string `json:"keys"` - // The user that the event is sent to for in-room verification. - To id.UserID `json:"to,omitempty"` - // Original event ID for in-room verification. - RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` -} + ToDeviceVerificationEvent + InRoomVerificationEvent -func (vmec *VerificationMacEventContent) GetRelatesTo() *RelatesTo { - if vmec.RelatesTo == nil { - vmec.RelatesTo = &RelatesTo{} - } - return vmec.RelatesTo -} - -func (vmec *VerificationMacEventContent) OptionalGetRelatesTo() *RelatesTo { - return vmec.RelatesTo -} - -func (vmec *VerificationMacEventContent) SetRelatesTo(rel *RelatesTo) { - vmec.RelatesTo = rel + // Keys is the MAC of the comma-separated, sorted, list of key IDs given in + // the MAC property. + Keys jsonbytes.UnpaddedBytes `json:"keys"` + // MAC is a map of the key ID to the MAC of the key, using the algorithm in + // the verification process. + MAC map[id.KeyID]jsonbytes.UnpaddedBytes `json:"mac"` } From 7469dcf9190efbc49123dbf075a84170c8f0a272 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 19 Jan 2024 16:26:50 -0700 Subject: [PATCH 0113/1647] id/crypto: add VerificationTransactionID Signed-off-by: Sumner Evans --- id/crypto.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/id/crypto.go b/id/crypto.go index 9334198e..48a63e78 100644 --- a/id/crypto.go +++ b/id/crypto.go @@ -9,6 +9,8 @@ package id import ( "fmt" "strings" + + "go.mau.fi/util/random" ) // OlmMsgType is an Olm message type @@ -174,3 +176,15 @@ const ( SecretXSUserSigning Secret = "m.cross_signing.user_signing" SecretMegolmBackupV1 Secret = "m.megolm_backup.v1" ) + +// VerificationTransactionID is a unique identifier for a verification +// transaction. +type VerificationTransactionID string + +func NewVerificationTransactionID() VerificationTransactionID { + return VerificationTransactionID(random.String(32)) +} + +func (t VerificationTransactionID) String() string { + return string(t) +} From f46d2d349a80d31218f305c9094e26e27fbe1b04 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 23 Jan 2024 14:49:00 -0700 Subject: [PATCH 0114/1647] verificationhelper/qrcode: add (en|de)coder Signed-off-by: Sumner Evans --- crypto/verificationhelper/qrcode.go | 98 ++++++++++++++++++++++++ crypto/verificationhelper/qrcode_test.go | 58 ++++++++++++++ 2 files changed, 156 insertions(+) create mode 100644 crypto/verificationhelper/qrcode.go create mode 100644 crypto/verificationhelper/qrcode_test.go diff --git a/crypto/verificationhelper/qrcode.go b/crypto/verificationhelper/qrcode.go new file mode 100644 index 00000000..a28d8fc3 --- /dev/null +++ b/crypto/verificationhelper/qrcode.go @@ -0,0 +1,98 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package verificationhelper + +import ( + "bytes" + "encoding/binary" + "errors" + + "go.mau.fi/util/random" + + "maunium.net/go/mautrix/id" +) + +var ( + ErrInvalidQRCodeHeader = errors.New("invalid QR code header") + ErrUnknownQRCodeVersion = errors.New("invalid QR code version") + ErrInvalidQRCodeMode = errors.New("invalid QR code mode") +) + +type QRCodeMode byte + +const ( + QRCodeModeCrossSigning QRCodeMode = 0x00 + QRCodeModeSelfVerifyingMasterKeyTrusted QRCodeMode = 0x01 + QRCodeModeSelfVerifyingMasterKeyUntrusted QRCodeMode = 0x02 +) + +type QRCode struct { + Mode QRCodeMode + TransactionID id.VerificationTransactionID + Key1, Key2 [32]byte + SharedSecret []byte +} + +func NewQRCode(mode QRCodeMode, txnID id.VerificationTransactionID, key1, key2 [32]byte) *QRCode { + return &QRCode{ + Mode: mode, + TransactionID: txnID, + Key1: key1, + Key2: key2, + SharedSecret: random.Bytes(16), + } +} + +// NewQRCodeFromBytes parses the bytes from a QR code scan as defined in +// [Section 11.12.2.4.1] of the Spec. +// +// [Section 11.12.2.4.1]: https://spec.matrix.org/v1.9/client-server-api/#qr-code-format +func NewQRCodeFromBytes(data []byte) (*QRCode, error) { + if !bytes.HasPrefix(data, []byte("MATRIX")) { + return nil, ErrInvalidQRCodeHeader + } + if data[6] != 0x02 { + return nil, ErrUnknownQRCodeVersion + } + if data[7] != 0x00 && data[7] != 0x01 && data[7] != 0x02 { + return nil, ErrInvalidQRCodeMode + } + transactionIDLength := binary.BigEndian.Uint16(data[8:10]) + transactionID := data[10 : 10+transactionIDLength] + + var key1, key2 [32]byte + copy(key1[:], data[10+transactionIDLength:10+transactionIDLength+32]) + copy(key2[:], data[10+transactionIDLength+32:10+transactionIDLength+64]) + + return &QRCode{ + Mode: QRCodeMode(data[7]), + TransactionID: id.VerificationTransactionID(transactionID), + Key1: key1, + Key2: key2, + SharedSecret: data[10+transactionIDLength+64:], + }, nil +} + +// Bytes returns the bytes that need to be encoded in the QR code as defined in +// [Section 11.12.2.4.1] of the Spec. +// +// [Section 11.12.2.4.1]: https://spec.matrix.org/v1.9/client-server-api/#qr-code-format +func (q *QRCode) Bytes() []byte { + var buf bytes.Buffer + buf.WriteString("MATRIX") // Header + buf.WriteByte(0x02) // Version + buf.WriteByte(byte(q.Mode)) // Mode + + // Transaction ID length + Transaction ID + buf.Write(binary.BigEndian.AppendUint16(nil, uint16(len(q.TransactionID.String())))) + buf.WriteString(q.TransactionID.String()) + + buf.Write(q.Key1[:]) // Key 1 + buf.Write(q.Key2[:]) // Key 2 + buf.Write(q.SharedSecret) // Shared secret + return buf.Bytes() +} diff --git a/crypto/verificationhelper/qrcode_test.go b/crypto/verificationhelper/qrcode_test.go new file mode 100644 index 00000000..d2767734 --- /dev/null +++ b/crypto/verificationhelper/qrcode_test.go @@ -0,0 +1,58 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package verificationhelper_test + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/crypto/verificationhelper" +) + +func TestQRCode_Roundtrip(t *testing.T) { + var key1, key2 [32]byte + copy(key1[:], bytes.Repeat([]byte{0x01}, 32)) + copy(key2[:], bytes.Repeat([]byte{0x02}, 32)) + qrCode := verificationhelper.NewQRCode(verificationhelper.QRCodeModeCrossSigning, "test", key1, key2) + + encoded := qrCode.Bytes() + decoded, err := verificationhelper.NewQRCodeFromBytes(encoded) + require.NoError(t, err) + + assert.Equal(t, verificationhelper.QRCodeModeCrossSigning, decoded.Mode) + assert.EqualValues(t, "test", decoded.TransactionID) + assert.Equal(t, key1, decoded.Key1) + assert.Equal(t, key2, decoded.Key2) +} + +func TestQRCodeDecode(t *testing.T) { + qrcodeData := []byte{ + 0x4d, 0x41, 0x54, 0x52, 0x49, 0x58, 0x02, 0x01, 0x00, 0x20, 0x47, 0x6e, 0x41, 0x65, 0x43, 0x76, + 0x74, 0x57, 0x6a, 0x7a, 0x4d, 0x4f, 0x56, 0x57, 0x51, 0x54, 0x6b, 0x74, 0x33, 0x35, 0x59, 0x52, + 0x55, 0x72, 0x75, 0x6a, 0x6d, 0x52, 0x50, 0x63, 0x38, 0x61, 0x18, 0x32, 0x7c, 0xc3, 0x8c, 0xc2, + 0xa6, 0xc2, 0xb5, 0xc2, 0xa7, 0x50, 0x57, 0x67, 0x19, 0x5e, 0xc3, 0xaf, 0xc2, 0xa0, 0xc2, 0x98, + 0xc2, 0x9d, 0x36, 0xc3, 0xad, 0x7a, 0x10, 0x2e, 0x18, 0x3e, 0x4e, 0xc3, 0x84, 0xc3, 0x81, 0x45, + 0x0c, 0xc2, 0xae, 0x19, 0x78, 0xc2, 0x99, 0x06, 0xc2, 0x92, 0xc2, 0x94, 0xc2, 0x8e, 0xc2, 0xb7, + 0x59, 0xc2, 0x96, 0xc2, 0xad, 0xc3, 0xbd, 0x70, 0x6a, 0x11, 0xc2, 0xba, 0xc2, 0xa9, 0x29, 0xc3, + 0x8f, 0x0d, 0xc2, 0xb8, 0xc2, 0x88, 0x67, 0x5b, 0xc3, 0xb3, 0x01, 0xc2, 0xb0, 0x63, 0x2e, 0xc2, + 0xa5, 0xc3, 0xb3, 0x60, 0xc3, 0x82, 0x04, 0xc3, 0xa3, 0x72, 0x7d, 0x7c, 0x1d, 0xc2, 0xb6, 0xc2, + 0xba, 0xc2, 0x81, 0x1e, 0xc2, 0x99, 0xc2, 0xb8, 0x7f, 0x0a, + } + decoded, err := verificationhelper.NewQRCodeFromBytes(qrcodeData) + require.NoError(t, err) + assert.Equal(t, verificationhelper.QRCodeModeSelfVerifyingMasterKeyTrusted, decoded.Mode) + assert.EqualValues(t, "GnAeCvtWjzMOVWQTkt35YRUrujmRPc8a", decoded.TransactionID) + assert.Equal(t, + [32]byte{0x18, 0x32, 0x7c, 0xc3, 0x8c, 0xc2, 0xa6, 0xc2, 0xb5, 0xc2, 0xa7, 0x50, 0x57, 0x67, 0x19, 0x5e, 0xc3, 0xaf, 0xc2, 0xa0, 0xc2, 0x98, 0xc2, 0x9d, 0x36, 0xc3, 0xad, 0x7a, 0x10, 0x2e, 0x18, 0x3e}, + decoded.Key1) + assert.Equal(t, + [32]byte{0x4e, 0xc3, 0x84, 0xc3, 0x81, 0x45, 0xc, 0xc2, 0xae, 0x19, 0x78, 0xc2, 0x99, 0x6, 0xc2, 0x92, 0xc2, 0x94, 0xc2, 0x8e, 0xc2, 0xb7, 0x59, 0xc2, 0x96, 0xc2, 0xad, 0xc3, 0xbd, 0x70, 0x6a, 0x11}, + decoded.Key2) +} From 582ce5de49872ca46ed0bd4ae7f57ecd5202ab84 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 22 Jan 2024 15:22:02 -0700 Subject: [PATCH 0115/1647] verificationhelper/qrcode: begin implementing flow Signed-off-by: Sumner Evans --- client.go | 14 + crypto/verificationhelper/reciprocate.go | 257 +++++++ .../verificationhelper/verificationhelper.go | 655 ++++++++++++++++++ event/content.go | 35 + id/crypto.go | 13 + 5 files changed, 974 insertions(+) create mode 100644 crypto/verificationhelper/reciprocate.go create mode 100644 crypto/verificationhelper/verificationhelper.go diff --git a/client.go b/client.go index 6562559f..c9523190 100644 --- a/client.go +++ b/client.go @@ -35,6 +35,19 @@ type CryptoHelper interface { Init(context.Context) error } +type VerificationHelper interface { + Init(context.Context) error + StartVerification(ctx context.Context, to id.UserID) (id.VerificationTransactionID, error) + StartInRoomVerification(ctx context.Context, roomID id.RoomID, to id.UserID) (id.VerificationTransactionID, error) + AcceptVerification(ctx context.Context, txnID id.VerificationTransactionID) error + + HandleScannedQRData(ctx context.Context, data []byte) error + ConfirmQRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) error + + StartSAS(ctx context.Context, txnID id.VerificationTransactionID) error + ConfirmSAS(ctx context.Context, txnID id.VerificationTransactionID) error +} + // Deprecated: switch to zerolog type Logger interface { Debugfln(message string, args ...interface{}) @@ -58,6 +71,7 @@ type Client struct { Store SyncStore // The thing which can store tokens/ids StateStore StateStore Crypto CryptoHelper + Verification VerificationHelper Log zerolog.Logger // Deprecated: switch to the zerolog instance in Log diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go new file mode 100644 index 00000000..e1c5d403 --- /dev/null +++ b/crypto/verificationhelper/reciprocate.go @@ -0,0 +1,257 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package verificationhelper + +import ( + "bytes" + "context" + "encoding/base64" + "fmt" + + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +// HandleScannedQRData verifies the keys from a scanned QR code and if +// successful, sends the m.key.verification.start event and +// m.key.verification.done event. +func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []byte) error { + qrCode, err := NewQRCodeFromBytes(data) + if err != nil { + return err + } + log := vh.getLog(ctx).With(). + Str("verification_action", "handle scanned QR data"). + Stringer("transaction_id", qrCode.TransactionID). + Int("mode", int(qrCode.Mode)). + Logger() + + txn, ok := vh.activeTransactions[qrCode.TransactionID] + if !ok { + log.Warn().Msg("Ignoring QR code scan for an unknown transaction") + return nil + } else if txn.VerificationStep != verificationStepReady { + log.Warn().Msg("Ignoring QR code scan for a transaction that is not in the ready state") + return nil + } + + // Verify the keys + log.Info().Msg("Verifying keys from QR code") + + switch qrCode.Mode { + case QRCodeModeCrossSigning: + // TODO + panic("unimplemented") + // TODO sign their master key + case QRCodeModeSelfVerifyingMasterKeyTrusted: + // The QR was created by a device that trusts the master key, which + // means that we don't trust the key. Key1 is the master key public + // key, and Key2 is what the other device thinks our device key is. + + if vh.client.UserID != txn.TheirUser { + return fmt.Errorf("mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) + } + + // Verify the master key is correct + crossSigningPubkeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) + crossSigningMasterKeyBytes, err := base64.RawStdEncoding.DecodeString(crossSigningPubkeys.MasterKey.String()) + if err != nil { + return err + } + if bytes.Equal(crossSigningMasterKeyBytes, qrCode.Key1[:]) { + log.Info().Msg("Verified that the other device has the same master key") + } else { + return fmt.Errorf("the master key does not match") + } + + // Verify that the device key that the other device things we have is + // correct. + myDevice := vh.mach.OwnIdentity() + myDeviceKeyBytes, err := base64.RawStdEncoding.DecodeString(myDevice.IdentityKey.String()) + if err != nil { + return err + } + if bytes.Equal(myDeviceKeyBytes, qrCode.Key2[:]) { + log.Info().Msg("Verified that the other device has the correct key for this device") + } else { + return fmt.Errorf("the other device has the wrong key for this device") + } + + case QRCodeModeSelfVerifyingMasterKeyUntrusted: + // The QR was created by a device that does not trust the master key, + // which means that we do trust the master key. Key1 is the other + // device's device key, and Key2 is what the other device thinks the + // master key is. + + if vh.client.UserID != txn.TheirUser { + return fmt.Errorf("mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) + } + + // Get their device + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + if err != nil { + return err + } + + // Verify that the other device's key is what we expect. + myDeviceKeyBytes, err := base64.RawStdEncoding.DecodeString(theirDevice.IdentityKey.String()) + if err != nil { + return err + } + if bytes.Equal(myDeviceKeyBytes, qrCode.Key1[:]) { + log.Info().Msg("Verified that the other device key is what we expected") + } else { + return fmt.Errorf("the other device's key is not what we expected") + } + + // Verify that what they think the master key is is correct. + crossSigningPubkeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) + crossSigningMasterKeyBytes, err := base64.RawStdEncoding.DecodeString(crossSigningPubkeys.MasterKey.String()) + if err != nil { + return err + } + if bytes.Equal(crossSigningMasterKeyBytes, qrCode.Key2[:]) { + log.Info().Msg("Verified that the other device has the correct master key") + } else { + return fmt.Errorf("the master key does not match") + } + + // Trust their device + theirDevice.Trust = id.TrustStateVerified + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) + if err != nil { + return fmt.Errorf("failed to update device trust state after verifying: %w", err) + } + + // TODO Cross-sign their device with the cross-signing key + default: + return fmt.Errorf("unknown QR code mode %d", qrCode.Mode) + } + + // Send a m.key.verification.start event with the secret + startEvt := &event.VerificationStartEventContent{ + FromDevice: vh.client.DeviceID, + Method: event.VerificationMethodReciprocate, + Secret: qrCode.SharedSecret, + } + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationStart, startEvt) + if err != nil { + return err + } + + // Immediately send the m.key.verification.done event, as our side of the + // transaction is done. + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) + if err != nil { + return err + } + + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + delete(vh.activeTransactions, txn.TransactionID) + + // Broadcast that the verification is complete. + vh.verificationDone(ctx, txn.TransactionID) + // TODO do we need to also somehow broadcast that we are now a trusted + // device? + + return nil +} + +func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) error { + log := vh.getLog(ctx).With(). + Str("verification_action", "confirm QR code scanned"). + Stringer("transaction_id", txnID). + Logger() + + txn, ok := vh.activeTransactions[txnID] + if !ok { + log.Warn().Msg("Ignoring QR code scan confirmation for an unknown transaction") + return nil + } else if txn.VerificationStep != verificationStepStarted { + log.Warn().Msg("Ignoring QR code scan confirmation for a transaction that is not in the started state") + return nil + } + log.Info().Msg("Confirming QR code scanned") + + // TODO trust the keys somehow + + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) + if err != nil { + return err + } + + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + delete(vh.activeTransactions, txn.TransactionID) + + // Broadcast that the verification is complete. + vh.verificationDone(ctx, txn.TransactionID) + // TODO do we need to also somehow broadcast that we are now a trusted + // device? + + return nil +} + +func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *verificationTransaction) error { + log := vh.getLog(ctx).With(). + Str("verification_action", "generate and show QR code"). + Stringer("transaction_id", txn.TransactionID). + Logger() + if vh.showQRCode == nil || !slices.Contains(txn.SupportedMethods, event.VerificationMethodQRCodeShow) { + log.Warn().Msg("Ignoring QR code generation request as showing a QR code is not enabled") + return nil + } + + ownCrossSigningPublicKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) + + mode := QRCodeModeCrossSigning + if vh.client.UserID == txn.TheirUser { + // This is a self-signing situation. + // TODO determine if it's trusted or not. + mode = QRCodeModeSelfVerifyingMasterKeyUntrusted + } + + var key1, key2 []byte + switch mode { + case QRCodeModeCrossSigning: + // Key 1 is the current user's master signing key. + key1 = ownCrossSigningPublicKeys.MasterKey.Bytes() + + // Key 2 is the other user's master signing key. + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) + if err != nil { + return err + } + key2 = theirSigningKeys.MasterKey.Bytes() + case QRCodeModeSelfVerifyingMasterKeyTrusted: + // Key 1 is the current user's master signing key. + key1 = ownCrossSigningPublicKeys.MasterKey.Bytes() + + // Key 2 is the other device's key. + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + if err != nil { + return err + } + key2 = theirDevice.IdentityKey.Bytes() + case QRCodeModeSelfVerifyingMasterKeyUntrusted: + // Key 1 is the current device's key + key1 = vh.mach.OwnIdentity().IdentityKey.Bytes() + + // Key 2 is the master signing key. + key2 = ownCrossSigningPublicKeys.MasterKey.Bytes() + default: + log.Fatal().Str("mode", string(mode)).Msg("Unknown QR code mode") + } + + qrCode := NewQRCode(mode, txn.TransactionID, [32]byte(key1), [32]byte(key2)) + txn.QRCodeSharedSecret = qrCode.SharedSecret + vh.showQRCode(ctx, txn.TransactionID, qrCode) + return nil +} diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go new file mode 100644 index 00000000..6dee8b7b --- /dev/null +++ b/crypto/verificationhelper/verificationhelper.go @@ -0,0 +1,655 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package verificationhelper + +import ( + "bytes" + "context" + "crypto/ecdh" + "errors" + "fmt" + + "github.com/rs/zerolog" + "go.mau.fi/util/jsontime" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type verificationStep int + +const ( + verificationStepRequested verificationStep = iota + verificationStepReady + verificationStepStarted +) + +func (step verificationStep) String() string { + switch step { + case verificationStepRequested: + return "requested" + case verificationStepReady: + return "ready" + case verificationStepStarted: + return "started" + default: + return fmt.Sprintf("verificationStep(%d)", step) + } +} + +type verificationTransaction struct { + // RoomID is the room ID if the verification is happening in a room or + // empty if it is a to-device verification. + RoomID id.RoomID + + // VerificationStep is the current step of the verification flow. + VerificationStep verificationStep + // TransactionID is the ID of the verification transaction. + TransactionID id.VerificationTransactionID + + // TheirDevice is the device ID of the device that either made the initial + // request or accepted our request. + TheirDevice id.DeviceID + + // TheirUser is the user ID of the other user. + TheirUser id.UserID + + // SentToDeviceIDs is a list of devices which the initial request was sent + // to. This is only used for to-device verification requests, and is meant + // to be used to send cancellation requests to all other devices when a + // verification request is accepted via a m.key.verification.ready event. + SentToDeviceIDs []id.DeviceID + + // SupportedMethods is a list of verification methods that the other device + // supports. + SupportedMethods []event.VerificationMethod + + // QRCodeSharedSecret is the shared secret that was encoded in the QR code + // that we showed. + QRCodeSharedSecret []byte + + StartedByUs bool // Whether the verification was started by us + StartEventContent *event.VerificationStartEventContent // The m.key.verification.start event content + Commitment []byte // The commitment from the m.key.verification.accept event + EphemeralKey *ecdh.PrivateKey // The ephemeral key + EphemeralPublicKeyShared bool // Whether this device's ephemeral public key has been shared + OtherPublicKey *ecdh.PublicKey // The other device's ephemeral public key +} + +// RequiredCallbacks is an interface representing the callbacks required for +// the [VerificationHelper]. +type RequiredCallbacks interface { + // VerificationRequested is called when a verification request is received + // from another device. + VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) + + // VerificationError is called when an error occurs during the verification + // process. + VerificationError(ctx context.Context, txnID id.VerificationTransactionID, err error) + + // VerificationCancelled is called when the verification is cancelled. + VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) + + // VerificationDone is called when the verification is done. + VerificationDone(ctx context.Context, txnID id.VerificationTransactionID) +} + +type showSASCallbacks interface { + // ShowSAS is called when the SAS verification has generated a short + // authentication string to show. It is guaranteed that either the emojis + // list, or the decimals list, or both will be present. + ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) +} + +type showQRCodeCallbacks interface { + // ShowQRCode is called when the verification has been accepted and a QR + // code should be shown to the user. + ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode) + + // QRCodeScanned is called when the other user has scanned the QR code and + // sent the m.key.verification.start event. + QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) +} + +type VerificationHelper struct { + client *mautrix.Client + mach *crypto.OlmMachine + + activeTransactions map[id.VerificationTransactionID]*verificationTransaction + activeTransactionsLock sync.Mutex + + supportedMethods []event.VerificationMethod + verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) + verificationError func(ctx context.Context, txnID id.VerificationTransactionID, err error) + verificationCancelled func(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) + verificationDone func(ctx context.Context, txnID id.VerificationTransactionID) + + showSAS func(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) + + showQRCode func(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode) + qrCodeScaned func(ctx context.Context, txnID id.VerificationTransactionID) +} + +var _ mautrix.VerificationHelper = (*VerificationHelper)(nil) + +func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, callbacks any, supportsScan bool) *VerificationHelper { + if client.Crypto == nil { + panic("client.Crypto is nil") + } + + helper := VerificationHelper{ + client: client, + mach: mach, + activeTransactions: map[id.VerificationTransactionID]*verificationTransaction{}, + } + + if c, ok := callbacks.(RequiredCallbacks); !ok { + panic("callbacks must implement VerificationRequested") + } else { + helper.verificationRequested = c.VerificationRequested + helper.verificationError = func(ctx context.Context, txnID id.VerificationTransactionID, err error) { + zerolog.Ctx(ctx).Err(err).Msg("Verification error") + c.VerificationError(ctx, txnID, err) + } + helper.verificationCancelled = c.VerificationCancelled + helper.verificationDone = c.VerificationDone + } + + if c, ok := callbacks.(showEmojiCallbacks); ok { + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS) + helper.showEmojis = c.ShowEmojis + } + if c, ok := callbacks.(showDecimalCallbacks); ok { + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS) + helper.showDecimal = c.ShowDecimal + } + if c, ok := callbacks.(showQRCodeCallbacks); ok { + helper.supportedMethods = append(helper.supportedMethods, + event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate) + helper.showQRCode = c.ShowQRCode + helper.qrCodeScaned = c.QRCodeScanned + } + if supportsScan { + helper.supportedMethods = append(helper.supportedMethods, + event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate) + } + + slices.Sort(helper.supportedMethods) + helper.supportedMethods = slices.Compact(helper.supportedMethods) + return &helper +} + +func (vh *VerificationHelper) getLog(ctx context.Context) *zerolog.Logger { + logger := vh.client.Log.With(). + Any("supported_methods", vh.supportedMethods). + Str("component", "verification"). + Logger() + return &logger +} + +// Init initializes the verification helper by adding the necessary event +// handlers to the syncer. +func (vh *VerificationHelper) Init(ctx context.Context) error { + if vh == nil { + return fmt.Errorf("verification helper is nil") + } + syncer, ok := vh.client.Syncer.(mautrix.ExtensibleSyncer) + if !ok { + return fmt.Errorf("the client syncer must implement ExtensibleSyncer") + } + + // Event handlers for verification requests. These are special since we do + // not need to check that the transaction ID is known. + syncer.OnEventType(event.ToDeviceVerificationRequest, vh.onVerificationRequest) + syncer.OnEventType(event.EventMessage, func(ctx context.Context, evt *event.Event) { + if evt.Content.AsMessage().MsgType == event.MsgVerificationRequest { + vh.onVerificationRequest(ctx, evt) + } + }) + + // Wrapper for the event handlers to check that the transaction ID is known + // and ignore the event if it isn't. + wrapHandler := func(callback func(context.Context, *verificationTransaction, *event.Event)) func(context.Context, *event.Event) { + return func(ctx context.Context, evt *event.Event) { + log := vh.getLog(ctx).With(). + Str("verification_action", "check transaction ID"). + Stringer("sender", evt.Sender). + Stringer("room_id", evt.RoomID). + Stringer("event_id", evt.ID). + Logger() + + var transactionID id.VerificationTransactionID + if evt.ID != "" { + transactionID = id.VerificationTransactionID(evt.ID) + } else { + txnID, ok := evt.Content.Raw["transaction_id"].(string) + if !ok { + log.Warn().Msg("Ignoring verification event without a transaction ID") + return + } + transactionID = id.VerificationTransactionID(txnID) + } + vh.activeTransactionsLock.Lock() + txn, ok := vh.activeTransactions[transactionID] + vh.activeTransactionsLock.Unlock() + if !ok { + log.Warn(). + Stringer("transaction_id", transactionID). + Msg("Ignoring verification event for an unknown transaction") + + txn = &verificationTransaction{ + RoomID: evt.RoomID, + TheirUser: evt.Sender, + } + txn.TransactionID = evt.Content.Parsed.(event.VerificationTransactionable).GetTransactionID() + if txn.TransactionID == "" { + txn.TransactionID = id.VerificationTransactionID(evt.ID) + } + if fromDevice, ok := evt.Content.Raw["from_device"]; ok { + txn.TheirDevice = id.DeviceID(fromDevice.(string)) + } + cancelEvt := event.VerificationCancelEventContent{ + Code: event.VerificationCancelCodeUnknownTransaction, + Reason: "The transaction ID was not recognized.", + } + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, &cancelEvt) + if err != nil { + log.Err(err).Msg("Failed to send cancellation event") + } + vh.verificationCancelled(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) + return + } + + logCtx := vh.getLog(ctx).With(). + Stringer("transaction_id", transactionID). + Stringer("transaction_step", txn.VerificationStep). + Stringer("sender", evt.Sender) + if evt.RoomID != "" { + logCtx = logCtx. + Stringer("room_id", evt.RoomID). + Stringer("event_id", evt.ID) + } + callback(logCtx.Logger().WithContext(ctx), txn, evt) + } + } + + // Event handlers for the to-device verification events. + syncer.OnEventType(event.ToDeviceVerificationReady, wrapHandler(vh.onVerificationReady)) + syncer.OnEventType(event.ToDeviceVerificationStart, wrapHandler(vh.onVerificationStart)) + syncer.OnEventType(event.ToDeviceVerificationDone, wrapHandler(vh.onVerificationDone)) + syncer.OnEventType(event.ToDeviceVerificationCancel, wrapHandler(vh.onVerificationCancel)) + syncer.OnEventType(event.ToDeviceVerificationAccept, wrapHandler(vh.onVerificationAccept)) // SAS + syncer.OnEventType(event.ToDeviceVerificationKey, wrapHandler(vh.onVerificationKey)) // SAS + syncer.OnEventType(event.ToDeviceVerificationMAC, wrapHandler(vh.onVerificationMAC)) // SAS + + // Event handlers for the in-room verification events. + syncer.OnEventType(event.InRoomVerificationReady, wrapHandler(vh.onVerificationReady)) + syncer.OnEventType(event.InRoomVerificationStart, wrapHandler(vh.onVerificationStart)) + syncer.OnEventType(event.InRoomVerificationDone, wrapHandler(vh.onVerificationDone)) + syncer.OnEventType(event.InRoomVerificationCancel, wrapHandler(vh.onVerificationCancel)) + syncer.OnEventType(event.InRoomVerificationAccept, wrapHandler(vh.onVerificationAccept)) // SAS + syncer.OnEventType(event.InRoomVerificationKey, wrapHandler(vh.onVerificationKey)) // SAS + syncer.OnEventType(event.InRoomVerificationMAC, wrapHandler(vh.onVerificationMAC)) // SAS + + return nil +} + +// StartVerification starts an interactive verification flow with the given +// user via a to-device event. +func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserID) (id.VerificationTransactionID, error) { + txnID := id.NewVerificationTransactionID() + + devices, err := vh.mach.CryptoStore.GetDevices(ctx, to) + if err != nil { + return "", fmt.Errorf("failed to get devices for user: %w", err) + } + + vh.getLog(ctx).Info(). + Str("verification_action", "start verification"). + Stringer("transaction_id", txnID). + Stringer("to", to). + Any("device_ids", maps.Keys(devices)). + Msg("Sending verification request") + + content := &event.Content{ + Parsed: &event.VerificationRequestEventContent{ + ToDeviceVerificationEvent: event.ToDeviceVerificationEvent{TransactionID: txnID}, + FromDevice: vh.client.DeviceID, + Methods: vh.supportedMethods, + Timestamp: jsontime.UnixMilliNow(), + }, + } + + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{to: {}}} + for deviceID := range devices { + if deviceID == vh.client.DeviceID { + // Don't ever send the event to the current device. We are likely + // trying to send a verification request to our other devices. + continue + } + + req.Messages[to][deviceID] = content + } + _, err = vh.client.SendToDevice(ctx, event.ToDeviceVerificationRequest, &req) + if err != nil { + return "", fmt.Errorf("failed to send verification request: %w", err) + } + + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + vh.activeTransactions[txnID] = &verificationTransaction{ + VerificationStep: verificationStepRequested, + TransactionID: txnID, + TheirUser: to, + SentToDeviceIDs: maps.Keys(devices), + } + return txnID, nil +} + +// StartVerification starts an interactive verification flow with the given +// user in the given room. +func (vh *VerificationHelper) StartInRoomVerification(ctx context.Context, roomID id.RoomID, to id.UserID) (id.VerificationTransactionID, error) { + log := vh.getLog(ctx).With(). + Str("verification_action", "start in-room verification"). + Stringer("room_id", roomID). + Stringer("to", to). + Logger() + + log.Info().Msg("Sending verification request") + content := event.MessageEventContent{ + MsgType: event.MsgVerificationRequest, + Body: "Alice is requesting to verify your device, but your client does not support verification, so you may need to use a different verification method.", + FromDevice: vh.client.DeviceID, + Methods: vh.supportedMethods, + To: to, + } + encryptedContent, err := vh.client.Crypto.Encrypt(ctx, roomID, event.EventMessage, &content) + if err != nil { + return "", fmt.Errorf("failed to encrypt verification request: %w", err) + } + resp, err := vh.client.SendMessageEvent(ctx, roomID, event.EventMessage, encryptedContent) + if err != nil { + return "", fmt.Errorf("failed to send verification request: %w", err) + } + + txnID := id.VerificationTransactionID(resp.EventID) + log.Info().Stringer("transaction_id", txnID).Msg("Got a transaction ID for the verification request") + + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + vh.activeTransactions[txnID] = &verificationTransaction{ + RoomID: roomID, + VerificationStep: verificationStepRequested, + TransactionID: txnID, + TheirUser: to, + } + return txnID, nil +} + +// AcceptVerification accepts a verification request. The transaction ID should +// be the transaction ID of a verification request that was received via the +// [RequiredCallbacks.VerificationRequested] callback. +func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.VerificationTransactionID) error { + log := vh.getLog(ctx).With(). + Str("verification_action", "accept verification"). + Stringer("transaction_id", txnID). + Logger() + + txn, ok := vh.activeTransactions[txnID] + if !ok { + return fmt.Errorf("unknown transaction ID") + } + + log.Info().Msg("Sending ready event") + readyEvt := &event.VerificationReadyEventContent{ + FromDevice: vh.client.DeviceID, + Methods: vh.supportedMethods, + } + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationReady, readyEvt) + if err != nil { + return err + } + txn.VerificationStep = verificationStepReady + + return vh.generateAndShowQRCode(ctx, txn) +} + +// sendVerificationEvent sends a verification event to the other user's device +// setting the m.relates_to or transaction ID as necessary. +// +// Notes: +// +// - "content" must implement [event.Relatable] and +// [event.VerificationTransactionable]. +// - evtType can be either the to-device or in-room version of the event type +// as it is always stringified. +func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn *verificationTransaction, evtType event.Type, content any) error { + if txn.RoomID != "" { + content.(event.Relatable).SetRelatesTo(&event.RelatesTo{Type: event.RelReference, EventID: id.EventID(txn.TransactionID)}) + _, err := vh.client.SendMessageEvent(ctx, txn.RoomID, evtType, &event.Content{ + Parsed: content, + }) + if err != nil { + return fmt.Errorf("failed to send start event: %w", err) + } + } else { + content.(event.VerificationTransactionable).SetTransactionID(txn.TransactionID) + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{ + txn.TheirUser: { + txn.TheirDevice: &event.Content{Parsed: content}, + }, + }} + _, err := vh.client.SendToDevice(ctx, evtType, &req) + if err != nil { + return fmt.Errorf("failed to send start event: %w", err) + } + } + return nil +} + +func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *event.Event) { + logCtx := vh.getLog(ctx).With(). + Str("verification_action", "verification request"). + Stringer("sender", evt.Sender) + if evt.RoomID != "" { + logCtx = logCtx. + Stringer("room_id", evt.RoomID). + Stringer("event_id", evt.ID) + } + log := logCtx.Logger() + + var verificationRequest *event.VerificationRequestEventContent + switch evt.Type { + case event.EventMessage: + to := evt.Content.AsMessage().To + if to != vh.client.UserID { + log.Info().Stringer("to", to).Msg("Ignoring verification request for another user") + return + } + + verificationRequest = event.VerificationRequestEventContentFromMessage(evt) + case event.ToDeviceVerificationRequest: + verificationRequest = evt.Content.AsVerificationRequest() + default: + log.Warn().Str("type", evt.Type.Type).Msg("Ignoring verification request of unknown type") + return + } + + if verificationRequest.FromDevice == vh.client.DeviceID { + log.Warn().Msg("Ignoring verification request from our own device. Why did it even get sent to us?") + return + } + + if verificationRequest.TransactionID == "" { + log.Warn().Msg("Ignoring verification request without a transaction ID") + return + } + + log = log.With().Any("requested_methods", verificationRequest.Methods).Logger() + ctx = log.WithContext(ctx) + log.Info().Msg("Received verification request") + + vh.activeTransactionsLock.Lock() + _, ok := vh.activeTransactions[verificationRequest.TransactionID] + if ok { + vh.activeTransactionsLock.Unlock() + log.Info().Msg("Ignoring verification request for an already active transaction") + return + } + vh.activeTransactions[verificationRequest.TransactionID] = &verificationTransaction{ + RoomID: evt.RoomID, + VerificationStep: verificationStepRequested, + TransactionID: verificationRequest.TransactionID, + TheirDevice: verificationRequest.FromDevice, + TheirUser: evt.Sender, + SupportedMethods: verificationRequest.Methods, + } + vh.activeTransactionsLock.Unlock() + + vh.verificationRequested(ctx, verificationRequest.TransactionID, evt.Sender) +} + +func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *verificationTransaction, evt *event.Event) { + log := vh.getLog(ctx).With(). + Str("verification_action", "verification ready"). + Logger() + + if txn.VerificationStep != verificationStepRequested { + log.Warn().Msg("Ignoring verification ready event for a transaction that is not in the requested state") + return + } + + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + + readyEvt := evt.Content.AsVerificationReady() + + // Update the transaction state. + txn.VerificationStep = verificationStepReady + txn.TheirDevice = readyEvt.FromDevice + txn.SupportedMethods = readyEvt.Methods + + // If we sent this verification request, send cancellations to all of the + // other devices. + if len(txn.SentToDeviceIDs) > 0 { + content := &event.Content{ + Parsed: &event.VerificationCancelEventContent{ + ToDeviceVerificationEvent: event.ToDeviceVerificationEvent{TransactionID: txn.TransactionID}, + Code: event.VerificationCancelCodeAccepted, + Reason: "The verification was accepted on another device.", + }, + } + devices, err := vh.mach.CryptoStore.GetDevices(ctx, txn.TheirUser) + if err != nil { + vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to get devices for %s: %w", txn.TheirUser, err)) + return + } + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} + for deviceID := range devices { + if deviceID == txn.TheirDevice { + // Don't ever send a cancellation to the device that accepted + // the request. + continue + } + + req.Messages[txn.TheirUser][deviceID] = content + } + _, err = vh.client.SendToDevice(ctx, event.ToDeviceVerificationRequest, &req) + if err != nil { + log.Warn().Err(err).Msg("Failed to send cancellation requests") + } + } + err := vh.generateAndShowQRCode(ctx, txn) + if err != nil { + vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to generate and show QR code: %w", err)) + } +} + +func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *verificationTransaction, evt *event.Event) { + startEvt := evt.Content.AsVerificationStart() + log := vh.getLog(ctx).With(). + Str("verification_action", "verification start"). + Str("method", string(startEvt.Method)). + Logger() + + if txn.VerificationStep != verificationStepReady { + log.Warn().Msg("Ignoring verification start event for a transaction that is not in the ready state") + return + } + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + txn.VerificationStep = verificationStepStarted + + switch startEvt.Method { + case event.VerificationMethodSAS: + // TODO + log.Info().Msg("Received SAS verification start event") + err := vh.onVerificationStartSAS(ctx, txn, evt) + if err != nil { + // TODO should we cancel on all errors? + vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to handle SAS verification start: %w", err)) + } + case event.VerificationMethodReciprocate: + log.Info().Msg("Received reciprocate start event") + + if !bytes.Equal(txn.QRCodeSharedSecret, startEvt.Secret) { + vh.verificationError(ctx, txn.TransactionID, errors.New("reciprocated shared secret does not match")) + return + } + + vh.qrCodeScaned(ctx, txn.TransactionID) + default: + // Note that we should never get m.qr_code.show.v1 or m.qr_code.scan.v1 + // here, since the start command for scanning and showing QR codes + // should be of type m.reciprocate.v1. + log.Error().Str("method", string(startEvt.Method)).Msg("Unsupported verification method in start event") + } +} + +func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verificationTransaction, evt *event.Event) { + vh.getLog(ctx).Info(). + Str("verification_action", "done"). + Stringer("transaction_id", txn.TransactionID). + Msg("Verification done") + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + delete(vh.activeTransactions, txn.TransactionID) +} + +func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *verificationTransaction, evt *event.Event) { + cancelEvt := evt.Content.AsVerificationCancel() + vh.getLog(ctx).Info(). + Str("verification_action", "cancel"). + Stringer("transaction_id", txn.TransactionID). + Str("cancel_code", string(cancelEvt.Code)). + Str("reason", cancelEvt.Reason). + Msg("Verification was cancelled") + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + delete(vh.activeTransactions, txn.TransactionID) + vh.verificationCancelled(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) +} + +// SAS verification events +func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *verificationTransaction, evt *event.Event) { + // TODO + vh.getLog(ctx).Error().Any("evt", evt).Msg("ACCEPT UNIMPLEMENTED") +} + +func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verificationTransaction, evt *event.Event) { + // TODO + vh.getLog(ctx).Error().Any("evt", evt).Msg("KEY UNIMPLEMENTED") +} + +func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verificationTransaction, evt *event.Event) { + // TODO + vh.getLog(ctx).Error().Any("evt", evt).Msg("MAC UNIMPLEMENTED") +} diff --git a/event/content.go b/event/content.go index 0439d9a2..68b8fa03 100644 --- a/event/content.go +++ b/event/content.go @@ -513,3 +513,38 @@ func (content *Content) AsModPolicy() *ModPolicyContent { } return casted } +func (content *Content) AsVerificationRequest() *VerificationRequestEventContent { + casted, ok := content.Parsed.(*VerificationRequestEventContent) + if !ok { + return &VerificationRequestEventContent{} + } + return casted +} +func (content *Content) AsVerificationReady() *VerificationReadyEventContent { + casted, ok := content.Parsed.(*VerificationReadyEventContent) + if !ok { + return &VerificationReadyEventContent{} + } + return casted +} +func (content *Content) AsVerificationStart() *VerificationStartEventContent { + casted, ok := content.Parsed.(*VerificationStartEventContent) + if !ok { + return &VerificationStartEventContent{} + } + return casted +} +func (content *Content) AsVerificationDone() *VerificationDoneEventContent { + casted, ok := content.Parsed.(*VerificationDoneEventContent) + if !ok { + return &VerificationDoneEventContent{} + } + return casted +} +func (content *Content) AsVerificationCancel() *VerificationCancelEventContent { + casted, ok := content.Parsed.(*VerificationCancelEventContent) + if !ok { + return &VerificationCancelEventContent{} + } + return casted +} diff --git a/id/crypto.go b/id/crypto.go index 48a63e78..355a84a8 100644 --- a/id/crypto.go +++ b/id/crypto.go @@ -7,6 +7,7 @@ package id import ( + "encoding/base64" "fmt" "strings" @@ -74,6 +75,12 @@ func (ed25519 Ed25519) String() string { return string(ed25519) } +func (ed25519 Ed25519) Bytes() []byte { + val, _ := base64.RawStdEncoding.DecodeString(string(ed25519)) + // TODO handle errors + return val +} + func (ed25519 Ed25519) Fingerprint() string { spacedSigningKey := make([]byte, len(ed25519)+(len(ed25519)-1)/4) var ptr = 0 @@ -97,6 +104,12 @@ func (curve25519 Curve25519) String() string { return string(curve25519) } +func (curve25519 Curve25519) Bytes() []byte { + val, _ := base64.RawStdEncoding.DecodeString(string(curve25519)) + // TODO handle errors + return val +} + // A DeviceID is an arbitrary string that references a specific device. type DeviceID string From 2f279590facc71cc4538dc955b99d6c84f390fe2 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 26 Jan 2024 10:31:12 -0700 Subject: [PATCH 0116/1647] verificationhelper/sas: begin implementing flow Signed-off-by: Sumner Evans --- crypto/verificationhelper/sas.go | 562 ++++++++++++++++++ crypto/verificationhelper/sas_test.go | 28 + .../verificationhelper/verificationhelper.go | 40 +- event/content.go | 21 + 4 files changed, 621 insertions(+), 30 deletions(-) create mode 100644 crypto/verificationhelper/sas.go create mode 100644 crypto/verificationhelper/sas_test.go diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go new file mode 100644 index 00000000..a02cc34e --- /dev/null +++ b/crypto/verificationhelper/sas.go @@ -0,0 +1,562 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package verificationhelper + +import ( + "bytes" + "context" + "crypto/ecdh" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "strings" + + "go.mau.fi/util/jsonbytes" + "golang.org/x/crypto/hkdf" + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +// StartSAS starts a SAS verification flow. The transaction ID should be the +// transaction ID of a verification request that was received via the +// [RequiredCallbacks.VerificationRequested] callback. +func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.VerificationTransactionID) error { + log := vh.getLog(ctx).With(). + Str("verification_action", "accept verification"). + Stringer("transaction_id", txnID). + Logger() + + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + txn, ok := vh.activeTransactions[txnID] + if !ok { + return fmt.Errorf("unknown transaction ID") + } + txn.StartedByUs = true + if !slices.Contains(txn.SupportedMethods, event.VerificationMethodSAS) { + return fmt.Errorf("the other device does not support SAS verification") + } + + // Ensure that we have their device key. + _, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + if err != nil { + log.Err(err).Msg("Failed to fetch device") + return err + } + + log.Info().Msg("Sending start event") + txn.StartEventContent = &event.VerificationStartEventContent{ + FromDevice: vh.client.DeviceID, + Method: event.VerificationMethodSAS, + + Hashes: []event.VerificationHashMethod{event.VerificationHashMethodSHA256}, + KeyAgreementProtocols: []event.KeyAgreementProtocol{event.KeyAgreementProtocolCurve25519HKDFSHA256}, + MessageAuthenticationCodes: []event.MACMethod{ + event.MACMethodHKDFHMACSHA256, + event.MACMethodHKDFHMACSHA256V2, + }, + ShortAuthenticationString: []event.SASMethod{ + event.SASMethodDecimal, + event.SASMethodEmoji, + }, + } + return vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationStart, txn.StartEventContent) +} + +// ConfirmSAS indicates that the user has confirmed that the SAS matches SAS +// shown on the other user's device. +func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.VerificationTransactionID) error { + log := vh.getLog(ctx).With(). + Str("verification_action", "confirm SAS"). + Stringer("transaction_id", txnID). + Logger() + + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + txn, ok := vh.activeTransactions[txnID] + if !ok { + return fmt.Errorf("unknown transaction ID") + } + + var err error + keys := map[id.KeyID]jsonbytes.UnpaddedBytes{} + + log.Info().Msg("Signing keys") + + // TODO actually sign some keys + // My device key + myDevice := vh.mach.OwnIdentity() + myDeviceKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, myDevice.DeviceID.String()) + keys[myDeviceKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, myDeviceKeyID.String(), myDevice.IdentityKey.String()) + if err != nil { + return err + } + + // Master signing key + // TODO how to detect whether or not we trust the master key? + + var keyIDs []string + for keyID := range keys { + keyIDs = append(keyIDs, keyID.String()) + } + slices.Sort(keyIDs) + keysMAC, err := vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, "KEY_IDS", strings.Join(keyIDs, ",")) + if err != nil { + return err + } + + macEventContent := &event.VerificationMacEventContent{ + Keys: keysMAC, + MAC: keys, + } + return vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationMAC, macEventContent) +} + +// onVerificationStartSAS handles the m.key.verification.start events with +// method of m.sas.v1 by implementing steps 4-7 of [Section 11.12.2.2] of the +// Spec. +// +// [Section 11.12.2.2]: https://spec.matrix.org/v1.9/client-server-api/#short-authentication-string-sas-verification +func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *verificationTransaction, evt *event.Event) error { + startEvt := evt.Content.AsVerificationStart() + log := vh.getLog(ctx) + log.Info().Msg("Received SAS verification start event") + + _, err := vh.mach.GetOrFetchDevice(ctx, evt.Sender, startEvt.FromDevice) + if err != nil { + log.Err(err).Msg("Failed to fetch device") + return err + } + + keyAggreementProtocol := event.KeyAgreementProtocolCurve25519HKDFSHA256 + if !startEvt.SupportsKeyAgreementProtocol(keyAggreementProtocol) { + return fmt.Errorf("the other device does not support any key agreement protocols that we support") + } + + hashAlgorithm := event.VerificationHashMethodSHA256 + if !startEvt.SupportsHashMethod(hashAlgorithm) { + return fmt.Errorf("the other device does not support any hash algorithms that we support") + } + + macMethod := event.MACMethodHKDFHMACSHA256V2 + if !startEvt.SupportsMACMethod(macMethod) { + if startEvt.SupportsMACMethod(event.MACMethodHKDFHMACSHA256) { + macMethod = event.MACMethodHKDFHMACSHA256 + } else { + return fmt.Errorf("the other device does not support any message authentication codes that we support") + } + } + + var sasMethods []event.SASMethod + for _, sasMethod := range startEvt.ShortAuthenticationString { + if sasMethod == event.SASMethodDecimal || sasMethod == event.SASMethodEmoji { + sasMethods = append(sasMethods, sasMethod) + } + } + if len(sasMethods) == 0 { + return fmt.Errorf("the other device does not support any short authentication string methods that we support") + } + + ephemeralKey, err := ecdh.X25519().GenerateKey(rand.Reader) + if err != nil { + return fmt.Errorf("failed to generate ephemeral key: %w", err) + } + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + txn.MACMethod = macMethod + txn.EphemeralKey = ephemeralKey + txn.StartEventContent = startEvt + + commitment, err := calculateCommitment(ephemeralKey.PublicKey(), startEvt) + if err != nil { + return fmt.Errorf("failed to calculate commitment: %w", err) + } + + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationAccept, &event.VerificationAcceptEventContent{ + Commitment: commitment, + Hash: hashAlgorithm, + KeyAgreementProtocol: keyAggreementProtocol, + MessageAuthenticationCode: macMethod, + ShortAuthenticationString: sasMethods, + }) + if err != nil { + return fmt.Errorf("failed to send accept event: %w", err) + } + return nil +} + +func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, startEvt *event.VerificationStartEventContent) ([]byte, error) { + // The commitmentHashInput is the hash (encoded as unpadded base64) of the + // concatenation of the device's ephemeral public key (encoded as + // unpadded base64) and the canonical JSON representation of the + // m.key.verification.start message. + // + // I have no idea why they chose to base64-encode the public key before + // hashing it, but we are just stuck on that. + commitmentHashInput := sha256.New() + commitmentHashInput.Write([]byte(base64.RawStdEncoding.EncodeToString(ephemeralPubKey.Bytes()))) + encodedStartEvt, err := json.Marshal(startEvt) + if err != nil { + return nil, err + } + commitmentHashInput.Write(canonicaljson.CanonicalJSONAssumeValid(encodedStartEvt)) + return commitmentHashInput.Sum(nil), nil +} + +// onVerificationAccept handles the m.key.verification.accept SAS verification +// event. This follows Step 4 of [Section 11.12.2.2] of the Spec. +// +// [Section 11.12.2.2]: https://spec.matrix.org/v1.9/client-server-api/#short-authentication-string-sas-verification +func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *verificationTransaction, evt *event.Event) { + acceptEvt := evt.Content.AsVerificationAccept() + log := vh.getLog(ctx).With(). + Str("verification_action", "accept"). + Stringer("transaction_id", txn.TransactionID). + Str("commitment", base64.RawStdEncoding.EncodeToString(acceptEvt.Commitment)). + Str("hash", string(acceptEvt.Hash)). + Str("key_agreement_protocol", string(acceptEvt.KeyAgreementProtocol)). + Str("message_authentication_code", string(acceptEvt.MessageAuthenticationCode)). + Any("short_authentication_string", acceptEvt.ShortAuthenticationString). + Logger() + log.Info().Msg("Received SAS verification accept event") + + ephemeralKey, err := ecdh.X25519().GenerateKey(rand.Reader) + if err != nil { + log.Err(err).Msg("Failed to generate ephemeral key") + return + } + + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationKey, &event.VerificationKeyEventContent{ + Key: ephemeralKey.PublicKey().Bytes(), + }) + if err != nil { + log.Err(err).Msg("Failed to send key event") + return + } + + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + txn.MACMethod = acceptEvt.MessageAuthenticationCode + txn.Commitment = acceptEvt.Commitment + txn.EphemeralKey = ephemeralKey + txn.EphemeralPublicKeyShared = true +} + +func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verificationTransaction, evt *event.Event) { + log := vh.getLog(ctx).With(). + Str("verification_action", "key"). + Logger() + keyEvt := evt.Content.AsVerificationKey() + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + + var err error + txn.OtherPublicKey, err = ecdh.X25519().NewPublicKey(keyEvt.Key) + if err != nil { + log.Err(err).Msg("Failed to generate other public key") + return + } + + if txn.EphemeralPublicKeyShared { + // Verify that the commitment hash is correct + commitment, err := calculateCommitment(txn.OtherPublicKey, txn.StartEventContent) + if err != nil { + log.Err(err).Msg("Failed to calculate commitment") + return + } + if !bytes.Equal(commitment, txn.Commitment) { + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, &event.VerificationCancelEventContent{ + Code: event.VerificationCancelCodeKeyMismatch, + Reason: "The key was not the one we expected.", + }) + if err != nil { + log.Err(err).Msg("Failed to send cancellation event") + } + return + } + } else { + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationKey, &event.VerificationKeyEventContent{ + Key: txn.EphemeralKey.PublicKey().Bytes(), + }) + if err != nil { + log.Err(err).Msg("Failed to send key event") + return + } + txn.EphemeralPublicKeyShared = true + } + + sasBytes, err := vh.verificationSASHKDF(txn) + if err != nil { + log.Err(err).Msg("Failed to compute HKDF for SAS") + return + } + + var decimals []int + var emojis []rune + if txn.StartEventContent.SupportsSASMethod(event.SASMethodDecimal) { + decimals = []int{ + (int(sasBytes[0])<<5 | int(sasBytes[1])>>3) + 1000, + ((int(sasBytes[1])&0x07)<<10 | int(sasBytes[2])<<2 | int(sasBytes[3])>>6) + 1000, + ((int(sasBytes[3])&0x3f)<<7 | int(sasBytes[4])>>1) + 1000, + } + } + if txn.StartEventContent.SupportsSASMethod(event.SASMethodEmoji) { + sasNum := uint64(sasBytes[0])<<40 | uint64(sasBytes[1])<<32 | uint64(sasBytes[2])<<24 | + uint64(sasBytes[3])<<16 | uint64(sasBytes[4])<<8 | uint64(sasBytes[5]) + + for i := 0; i < 7; i++ { + // Right shift the number and then mask the lowest 6 bits. + emojiIdx := (sasNum >> uint(48-(i+1)*6)) & 0b111111 + emojis = append(emojis, allEmojis[emojiIdx]) + } + } + vh.showSAS(ctx, txn.TransactionID, emojis, decimals) +} + +func (vh *VerificationHelper) verificationSASHKDF(txn *verificationTransaction) ([]byte, error) { + sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey) + if err != nil { + return nil, err + } + + // Perform the SAS HKDF calculation according to Section 11.12.2.2.4 of the + // Spec: + // https://spec.matrix.org/v1.9/client-server-api/#sas-hkdf-calculation + myInfo := strings.Join([]string{ + vh.client.UserID.String(), + vh.client.DeviceID.String(), + base64.RawStdEncoding.EncodeToString(txn.EphemeralKey.PublicKey().Bytes()), + }, "|") + + theirInfo := strings.Join([]string{ + txn.TheirUser.String(), + txn.TheirDevice.String(), + base64.RawStdEncoding.EncodeToString(txn.OtherPublicKey.Bytes()), + }, "|") + + var infoBuf bytes.Buffer + infoBuf.WriteString("MATRIX_KEY_VERIFICATION_SAS|") + if txn.StartedByUs { + infoBuf.WriteString(myInfo + "|" + theirInfo) + } else { + infoBuf.WriteString(theirInfo + "|" + myInfo) + } + infoBuf.WriteRune('|') + infoBuf.WriteString(txn.TransactionID.String()) + + reader := hkdf.New(sha256.New, sharedSecret, nil, infoBuf.Bytes()) + output := make([]byte, 6) + _, err = reader.Read(output) + return output, err +} + +// BrokenB64Encode implements the incorrect base64 serialization in libolm for +// the hkdf-hmac-sha256 MAC method. The bug is caused by the input and output +// buffers being equal to one another during the base64 encoding. +// +// This function is narrowly scoped to this specific bug, and does not work +// generally (it only supports if the input is 32-bytes). +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/3783 and +// https://gitlab.matrix.org/matrix-org/olm/-/merge_requests/16 for details. +// +// Deprecated: never use this. It is only here for compatibility with the +// broken libolm implementation. +func BrokenB64Encode(input []byte) string { + encodeBase64 := []byte{ + 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, + 0x49, 0x4A, 0x4B, 0x4C, 0x4D, 0x4E, 0x4F, 0x50, + 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, + 0x59, 0x5A, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, + 0x67, 0x68, 0x69, 0x6A, 0x6B, 0x6C, 0x6D, 0x6E, + 0x6F, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, + 0x77, 0x78, 0x79, 0x7A, 0x30, 0x31, 0x32, 0x33, + 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x2B, 0x2F, + } + + output := make([]byte, 43) + copy(output, input) + + pos := 0 + outputPos := 0 + for pos != 30 { + value := int32(output[pos]) + value <<= 8 + value |= int32(output[pos+1]) + value <<= 8 + value |= int32(output[pos+2]) + pos += 3 + output[outputPos] = encodeBase64[(value>>18)&0x3F] + output[outputPos+1] = encodeBase64[(value>>12)&0x3F] + output[outputPos+2] = encodeBase64[(value>>6)&0x3F] + output[outputPos+3] = encodeBase64[value&0x3F] + outputPos += 4 + } + // This is the mangling that libolm does to the base64 encoding. + value := int32(output[pos]) + value <<= 8 + value |= int32(output[pos+1]) + value <<= 2 + output[outputPos] = encodeBase64[(value>>12)&0x3F] + output[outputPos+1] = encodeBase64[(value>>6)&0x3F] + output[outputPos+2] = encodeBase64[value&0x3F] + return string(output) +} + +func (vh *VerificationHelper) verificationMACHKDF(txn *verificationTransaction, senderUser id.UserID, senderDevice id.DeviceID, receivingUser id.UserID, receivingDevice id.DeviceID, keyID, key string) ([]byte, error) { + sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey) + if err != nil { + return nil, err + } + fmt.Printf("KEYID %s\n", keyID) + fmt.Printf("KEY %s\n", key) + + var infoBuf bytes.Buffer + infoBuf.WriteString("MATRIX_KEY_VERIFICATION_MAC") + infoBuf.WriteString(senderUser.String()) + infoBuf.WriteString(senderDevice.String()) + infoBuf.WriteString(receivingUser.String()) + infoBuf.WriteString(receivingDevice.String()) + infoBuf.WriteString(txn.TransactionID.String()) + infoBuf.WriteString(keyID) + + reader := hkdf.New(sha256.New, sharedSecret, nil, infoBuf.Bytes()) + macKey := make([]byte, 32) + _, err = reader.Read(macKey) + if err != nil { + return nil, err + } + + hash := hmac.New(sha256.New, macKey) + hash.Write([]byte(key)) + sum := hash.Sum(nil) + if txn.MACMethod == event.MACMethodHKDFHMACSHA256 { + fmt.Printf("MANGLING %v\n", sum) + fmt.Printf("%s\n", BrokenB64Encode(sum)) + sum, err = base64.RawStdEncoding.DecodeString(BrokenB64Encode(sum)) + if err != nil { + panic(err) + } + fmt.Printf("MANGLING %v\n", sum) + } + return sum, nil +} + +var allEmojis = []rune{ + '🐶', + '🐱', + '🦁', + '🐎', + '🦄', + '🐷', + '🐘', + '🐰', + '🐼', + '🐓', + '🐧', + '🐢', + '🐟', + '🐙', + '🦋', + '🌷', + '🌳', + '🌵', + '🍄', + '🌏', + '🌙', + '☁', + '🔥', + '🍌', + '🍎', + '🍓', + '🌽', + '🍕', + '🎂', + '❤', + '😀', + '🤖', + '🎩', + '👓', + '🔧', + '🎅', + '👍', + '☂', + '⌛', + '⏰', + '🎁', + '💡', + '📕', + '✏', + '📎', + '✂', + '🔒', + '🔑', + '🔨', + '☎', + '🏁', + '🚂', + '🚲', + '✈', + '🚀', + '🏆', + '⚽', + '🎸', + '🎺', + '🔔', + '⚓', + '🎧', + '📁', + '📌', +} + +func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verificationTransaction, evt *event.Event) { + log := vh.getLog(ctx).With(). + Str("verification_action", "mac"). + Logger() + log.Info().Msg("Received SAS verification MAC event") + macEvt := evt.Content.AsVerificationMAC() + jsonBytes, _ := json.Marshal(macEvt) + fmt.Printf("%s\n", jsonBytes) + var keyIDs []string + // for keyID, mac := range macEvt.MAC { + // log.Info().Str("key_id", keyID.String()).Msg("Received MAC for key") + // keyIDs = append(keyIDs, keyID.String()) + + // var key string + + // expectedMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, keyID.String(), key) + // if err != nil { + // vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to calculate key MAC: %w", err)) + // return + // } + // if !bytes.Equal(expectedMAC, mac) { + // vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("MAC mismatch for key %s", keyID)) + // return + // } + // } + + log.Info().Msg("Verifying MAC for all sent keys") + expectedKeyMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) + if err != nil { + vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to calculate key list MAC: %w", err)) + return + } + fmt.Printf("%d %v\n", len(expectedKeyMAC), expectedKeyMAC) + fmt.Printf("%d %v\n", len(macEvt.Keys), macEvt.Keys) + if !bytes.Equal(expectedKeyMAC, macEvt.Keys) { + vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("key list MAC mismatch")) + return + } + + // TODO actually do a trust thing +} diff --git a/crypto/verificationhelper/sas_test.go b/crypto/verificationhelper/sas_test.go new file mode 100644 index 00000000..78e88b80 --- /dev/null +++ b/crypto/verificationhelper/sas_test.go @@ -0,0 +1,28 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package verificationhelper_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/crypto/verificationhelper" +) + +func TestBrokenB64Encode(t *testing.T) { + // See example from the PR that fixed the issue: + // https://gitlab.matrix.org/matrix-org/olm/-/merge_requests/16 + input := []byte{ + 121, 105, 187, 19, 37, 94, 119, 248, 224, 34, 94, 29, 157, 5, + 15, 230, 246, 115, 236, 217, 80, 78, 56, 200, 80, 200, 82, 158, + 168, 179, 10, 230, + } + + b64 := verificationhelper.BrokenB64Encode(input) + assert.Equal(t, "eWm7NyVeVmXgbVhnYlZobllsWm9ibGxzV205aWJHeHo", b64) +} diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 6dee8b7b..3c34caa4 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -12,6 +12,7 @@ import ( "crypto/ecdh" "errors" "fmt" + "sync" "github.com/rs/zerolog" "go.mau.fi/util/jsontime" @@ -79,6 +80,7 @@ type verificationTransaction struct { StartedByUs bool // Whether the verification was started by us StartEventContent *event.VerificationStartEventContent // The m.key.verification.start event content Commitment []byte // The commitment from the m.key.verification.accept event + MACMethod event.MACMethod // The method used to calculate the MAC EphemeralKey *ecdh.PrivateKey // The ephemeral key EphemeralPublicKeyShared bool // Whether this device's ephemeral public key has been shared OtherPublicKey *ecdh.PublicKey // The other device's ephemeral public key @@ -103,9 +105,9 @@ type RequiredCallbacks interface { } type showSASCallbacks interface { - // ShowSAS is called when the SAS verification has generated a short - // authentication string to show. It is guaranteed that either the emojis - // list, or the decimals list, or both will be present. + // ShowSAS is a callback that is called when the SAS verification has + // generated a short authentication string to show. It is guaranteed that + // either the emojis list, or the decimals list, or both will be present. ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) } @@ -163,13 +165,9 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, call helper.verificationDone = c.VerificationDone } - if c, ok := callbacks.(showEmojiCallbacks); ok { + if c, ok := callbacks.(showSASCallbacks); ok { helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS) - helper.showEmojis = c.ShowEmojis - } - if c, ok := callbacks.(showDecimalCallbacks); ok { - helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS) - helper.showDecimal = c.ShowDecimal + helper.showSAS = c.ShowSAS } if c, ok := callbacks.(showQRCodeCallbacks); ok { helper.supportedMethods = append(helper.supportedMethods, @@ -579,6 +577,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri Str("verification_action", "verification start"). Str("method", string(startEvt.Method)). Logger() + ctx = log.WithContext(ctx) if txn.VerificationStep != verificationStepReady { log.Warn().Msg("Ignoring verification start event for a transaction that is not in the ready state") @@ -590,11 +589,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri switch startEvt.Method { case event.VerificationMethodSAS: - // TODO - log.Info().Msg("Received SAS verification start event") - err := vh.onVerificationStartSAS(ctx, txn, evt) - if err != nil { - // TODO should we cancel on all errors? + if err := vh.onVerificationStartSAS(ctx, txn, evt); err != nil { vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to handle SAS verification start: %w", err)) } case event.VerificationMethodReciprocate: @@ -622,6 +617,7 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verif vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() delete(vh.activeTransactions, txn.TransactionID) + vh.verificationDone(ctx, txn.TransactionID) } func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *verificationTransaction, evt *event.Event) { @@ -637,19 +633,3 @@ func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *ver delete(vh.activeTransactions, txn.TransactionID) vh.verificationCancelled(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) } - -// SAS verification events -func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *verificationTransaction, evt *event.Event) { - // TODO - vh.getLog(ctx).Error().Any("evt", evt).Msg("ACCEPT UNIMPLEMENTED") -} - -func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verificationTransaction, evt *event.Event) { - // TODO - vh.getLog(ctx).Error().Any("evt", evt).Msg("KEY UNIMPLEMENTED") -} - -func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verificationTransaction, evt *event.Event) { - // TODO - vh.getLog(ctx).Error().Any("evt", evt).Msg("MAC UNIMPLEMENTED") -} diff --git a/event/content.go b/event/content.go index 68b8fa03..2a2833d3 100644 --- a/event/content.go +++ b/event/content.go @@ -548,3 +548,24 @@ func (content *Content) AsVerificationCancel() *VerificationCancelEventContent { } return casted } +func (content *Content) AsVerificationAccept() *VerificationAcceptEventContent { + casted, ok := content.Parsed.(*VerificationAcceptEventContent) + if !ok { + return &VerificationAcceptEventContent{} + } + return casted +} +func (content *Content) AsVerificationKey() *VerificationKeyEventContent { + casted, ok := content.Parsed.(*VerificationKeyEventContent) + if !ok { + return &VerificationKeyEventContent{} + } + return casted +} +func (content *Content) AsVerificationMAC() *VerificationMacEventContent { + casted, ok := content.Parsed.(*VerificationMacEventContent) + if !ok { + return &VerificationMacEventContent{} + } + return casted +} From 340ab4239a331fa7adcd83af677ff6ba5741b22f Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Mon, 12 Feb 2024 14:00:33 +0200 Subject: [PATCH 0117/1647] Use device signing key to verify interactive verification Remove unnecessary base64 as well. --- crypto/verificationhelper/reciprocate.go | 27 +++++------------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index e1c5d403..cc1c33e7 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -9,7 +9,6 @@ package verificationhelper import ( "bytes" "context" - "encoding/base64" "fmt" "golang.org/x/exp/slices" @@ -60,11 +59,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // Verify the master key is correct crossSigningPubkeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) - crossSigningMasterKeyBytes, err := base64.RawStdEncoding.DecodeString(crossSigningPubkeys.MasterKey.String()) - if err != nil { - return err - } - if bytes.Equal(crossSigningMasterKeyBytes, qrCode.Key1[:]) { + if bytes.Equal(crossSigningPubkeys.MasterKey.Bytes(), qrCode.Key1[:]) { log.Info().Msg("Verified that the other device has the same master key") } else { return fmt.Errorf("the master key does not match") @@ -72,12 +67,8 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // Verify that the device key that the other device things we have is // correct. - myDevice := vh.mach.OwnIdentity() - myDeviceKeyBytes, err := base64.RawStdEncoding.DecodeString(myDevice.IdentityKey.String()) - if err != nil { - return err - } - if bytes.Equal(myDeviceKeyBytes, qrCode.Key2[:]) { + myKeys := vh.mach.OwnIdentity() + if bytes.Equal(myKeys.SigningKey.Bytes(), qrCode.Key2[:]) { log.Info().Msg("Verified that the other device has the correct key for this device") } else { return fmt.Errorf("the other device has the wrong key for this device") @@ -100,11 +91,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by } // Verify that the other device's key is what we expect. - myDeviceKeyBytes, err := base64.RawStdEncoding.DecodeString(theirDevice.IdentityKey.String()) - if err != nil { - return err - } - if bytes.Equal(myDeviceKeyBytes, qrCode.Key1[:]) { + if bytes.Equal(theirDevice.SigningKey.Bytes(), qrCode.Key1[:]) { log.Info().Msg("Verified that the other device key is what we expected") } else { return fmt.Errorf("the other device's key is not what we expected") @@ -112,11 +99,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // Verify that what they think the master key is is correct. crossSigningPubkeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) - crossSigningMasterKeyBytes, err := base64.RawStdEncoding.DecodeString(crossSigningPubkeys.MasterKey.String()) - if err != nil { - return err - } - if bytes.Equal(crossSigningMasterKeyBytes, qrCode.Key2[:]) { + if bytes.Equal(crossSigningPubkeys.MasterKey.Bytes(), qrCode.Key2[:]) { log.Info().Msg("Verified that the other device has the correct master key") } else { return fmt.Errorf("the master key does not match") From 64d2b520054e94fe3ded038eae9dce072b99b129 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 12 Feb 2024 07:31:49 -0700 Subject: [PATCH 0118/1647] verificationhelper: only show QR code if other device can scan Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 2 +- crypto/verificationhelper/verificationhelper.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index cc1c33e7..3923a512 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -187,7 +187,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve Str("verification_action", "generate and show QR code"). Stringer("transaction_id", txn.TransactionID). Logger() - if vh.showQRCode == nil || !slices.Contains(txn.SupportedMethods, event.VerificationMethodQRCodeShow) { + if vh.showQRCode == nil || !slices.Contains(txn.SupportedMethods, event.VerificationMethodQRCodeScan) { log.Warn().Msg("Ignoring QR code generation request as showing a QR code is not enabled") return nil } diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 3c34caa4..e0e23df7 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -128,6 +128,7 @@ type VerificationHelper struct { activeTransactions map[id.VerificationTransactionID]*verificationTransaction activeTransactionsLock sync.Mutex + // supportedMethods are the methods that *we* support supportedMethods []event.VerificationMethod verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) verificationError func(ctx context.Context, txnID id.VerificationTransactionID, err error) From 6274cb650f66bc69c61306e968294f884da1b242 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 12 Feb 2024 07:34:08 -0700 Subject: [PATCH 0119/1647] verificationhelper: update warnings Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index 3923a512..b6f5663b 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -187,8 +187,12 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve Str("verification_action", "generate and show QR code"). Stringer("transaction_id", txn.TransactionID). Logger() - if vh.showQRCode == nil || !slices.Contains(txn.SupportedMethods, event.VerificationMethodQRCodeScan) { - log.Warn().Msg("Ignoring QR code generation request as showing a QR code is not enabled") + if vh.showQRCode == nil { + log.Warn().Msg("Ignoring QR code generation request as showing a QR code is not enabled on this device") + return nil + } + if !slices.Contains(txn.SupportedMethods, event.VerificationMethodQRCodeScan) { + log.Warn().Msg("Ignoring QR code generation request as other device cannot scan QR codes") return nil } From 92362d93d28cc971a3becb703b97203b05c1dda7 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 12 Feb 2024 09:58:20 -0700 Subject: [PATCH 0120/1647] verificationhelper: update a few docstrings Signed-off-by: Sumner Evans --- crypto/verificationhelper/sas.go | 2 +- crypto/verificationhelper/verificationhelper.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index a02cc34e..3b237ecc 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -29,7 +29,7 @@ import ( // StartSAS starts a SAS verification flow. The transaction ID should be the // transaction ID of a verification request that was received via the -// [RequiredCallbacks.VerificationRequested] callback. +// VerificationRequested callback in [RequiredCallbacks]. func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.VerificationTransactionID) error { log := vh.getLog(ctx).With(). Str("verification_action", "accept verification"). diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index e0e23df7..9a44fcb2 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -395,7 +395,7 @@ func (vh *VerificationHelper) StartInRoomVerification(ctx context.Context, roomI // AcceptVerification accepts a verification request. The transaction ID should // be the transaction ID of a verification request that was received via the -// [RequiredCallbacks.VerificationRequested] callback. +// VerificationRequested callback in [RequiredCallbacks]. func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.VerificationTransactionID) error { log := vh.getLog(ctx).With(). Str("verification_action", "accept verification"). From f1c0d68dccc9de83ee3ccf78085d6fdc0da96945 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 12 Feb 2024 17:41:11 -0700 Subject: [PATCH 0121/1647] verificationhelper/recipricate: properly check if the master key is trusted Signed-off-by: Sumner Evans --- crypto/verificationhelper/doc.go | 11 +++++++++++ crypto/verificationhelper/reciprocate.go | 18 ++++++++---------- crypto/verificationhelper/sas.go | 1 + .../verificationhelper/verificationhelper.go | 9 ++++++--- 4 files changed, 26 insertions(+), 13 deletions(-) create mode 100644 crypto/verificationhelper/doc.go diff --git a/crypto/verificationhelper/doc.go b/crypto/verificationhelper/doc.go new file mode 100644 index 00000000..29931654 --- /dev/null +++ b/crypto/verificationhelper/doc.go @@ -0,0 +1,11 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package verificationhelper provides a helper for the interactive +// verification process according to [Section 11.12.2] of the Spec. +// +// [Section 11.12.2]: https://spec.matrix.org/v1.9/client-server-api/#device-verification +package verificationhelper diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index b6f5663b..7974f5c6 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -98,8 +98,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by } // Verify that what they think the master key is is correct. - crossSigningPubkeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) - if bytes.Equal(crossSigningPubkeys.MasterKey.Bytes(), qrCode.Key2[:]) { + if bytes.Equal(vh.mach.GetOwnCrossSigningPublicKeys(ctx).MasterKey.Bytes(), qrCode.Key2[:]) { log.Info().Msg("Verified that the other device has the correct master key") } else { return fmt.Errorf("the master key does not match") @@ -141,9 +140,6 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // Broadcast that the verification is complete. vh.verificationDone(ctx, txn.TransactionID) - // TODO do we need to also somehow broadcast that we are now a trusted - // device? - return nil } @@ -176,9 +172,6 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id // Broadcast that the verification is complete. vh.verificationDone(ctx, txn.TransactionID) - // TODO do we need to also somehow broadcast that we are now a trusted - // device? - return nil } @@ -201,8 +194,13 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve mode := QRCodeModeCrossSigning if vh.client.UserID == txn.TheirUser { // This is a self-signing situation. - // TODO determine if it's trusted or not. - mode = QRCodeModeSelfVerifyingMasterKeyUntrusted + if trusted, err := vh.mach.IsUserTrusted(ctx, vh.client.UserID); err != nil { + return err + } else if trusted { + mode = QRCodeModeSelfVerifyingMasterKeyTrusted + } else { + mode = QRCodeModeSelfVerifyingMasterKeyUntrusted + } } var key1, key2 []byte diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 3b237ecc..771863b7 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -42,6 +42,7 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio if !ok { return fmt.Errorf("unknown transaction ID") } + txn.VerificationStep = verificationStepStarted txn.StartedByUs = true if !slices.Contains(txn.SupportedMethods, event.VerificationMethodSAS) { return fmt.Errorf("the other device does not support SAS verification") diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 9a44fcb2..5044f459 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -580,12 +580,15 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri Logger() ctx = log.WithContext(ctx) - if txn.VerificationStep != verificationStepReady { + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + + if txn.VerificationStep == verificationStepStarted { + log.Info().Msg("Got a verification start request from the other device, but the verification is already in progress") + } else if txn.VerificationStep != verificationStepReady { log.Warn().Msg("Ignoring verification start event for a transaction that is not in the ready state") return } - vh.activeTransactionsLock.Lock() - defer vh.activeTransactionsLock.Unlock() txn.VerificationStep = verificationStepStarted switch startEvt.Method { From 08b1a6a00bce6eb6e0863ee93fef7c5a720d8444 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 13 Feb 2024 09:23:28 -0700 Subject: [PATCH 0122/1647] verificationhelper: sign other device with self-signing key if they don't trust the key Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index 7974f5c6..6d426fcd 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -45,9 +45,8 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by switch qrCode.Mode { case QRCodeModeCrossSigning: - // TODO panic("unimplemented") - // TODO sign their master key + // TODO verify and sign their master key case QRCodeModeSelfVerifyingMasterKeyTrusted: // The QR was created by a device that trusts the master key, which // means that we don't trust the key. Key1 is the master key public @@ -111,7 +110,11 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by return fmt.Errorf("failed to update device trust state after verifying: %w", err) } - // TODO Cross-sign their device with the cross-signing key + // Cross-sign their device with the self-signing key + err = vh.mach.SignOwnDevice(ctx, theirDevice) + if err != nil { + return fmt.Errorf("failed to sign their device: %w", err) + } default: return fmt.Errorf("unknown QR code mode %d", qrCode.Mode) } From 2882267761aa446365d2b0f640667f73565f4b7c Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 13 Feb 2024 10:50:37 -0700 Subject: [PATCH 0123/1647] verificationhelper: trust and cross-sign on QR code confirmation Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 26 +++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index 6d426fcd..0e784598 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -162,7 +162,31 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id } log.Info().Msg("Confirming QR code scanned") - // TODO trust the keys somehow + if txn.TheirUser == vh.client.UserID { + // Self-signing situation. Trust their device. + + // Get their device + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + if err != nil { + return err + } + + // Trust their device + theirDevice.Trust = id.TrustStateVerified + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) + if err != nil { + return fmt.Errorf("failed to update device trust state after verifying: %w", err) + } + + // Cross-sign their device with the self-signing key + if vh.mach.CrossSigningKeys != nil { + err = vh.mach.SignOwnDevice(ctx, theirDevice) + if err != nil { + return fmt.Errorf("failed to sign their device: %w", err) + } + } + } + // TODO: handle QR codes that are not self-signing situations err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { From bfca57156590cbe3391ca86bc067095a7cfb82ec Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 13 Feb 2024 18:11:28 -0700 Subject: [PATCH 0124/1647] verificationhelper: more thoroughly check verification states for QR process Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 33 ++- crypto/verificationhelper/sas.go | 34 ++- .../verificationhelper/verificationhelper.go | 194 ++++++++++++------ 3 files changed, 170 insertions(+), 91 deletions(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index 0e784598..ddb8f62c 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -35,10 +35,11 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by if !ok { log.Warn().Msg("Ignoring QR code scan for an unknown transaction") return nil - } else if txn.VerificationStep != verificationStepReady { + } else if txn.VerificationState != verificationStateReady { log.Warn().Msg("Ignoring QR code scan for a transaction that is not in the ready state") return nil } + txn.VerificationState = verificationStateTheirQRScanned // Verify the keys log.Info().Msg("Verifying keys from QR code") @@ -120,46 +121,40 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by } // Send a m.key.verification.start event with the secret - startEvt := &event.VerificationStartEventContent{ + txn.StartEventContent = &event.VerificationStartEventContent{ FromDevice: vh.client.DeviceID, Method: event.VerificationMethodReciprocate, Secret: qrCode.SharedSecret, } - err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationStart, startEvt) + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationStart, txn.StartEventContent) if err != nil { return err } // Immediately send the m.key.verification.done event, as our side of the // transaction is done. - err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) - if err != nil { - return err - } - - vh.activeTransactionsLock.Lock() - defer vh.activeTransactionsLock.Unlock() - delete(vh.activeTransactions, txn.TransactionID) - - // Broadcast that the verification is complete. - vh.verificationDone(ctx, txn.TransactionID) - return nil + return vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) } +// ConfirmQRCodeScanned confirms that our QR code has been scanned and sends the +// m.key.verification.done event to the other device. func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) error { log := vh.getLog(ctx).With(). Str("verification_action", "confirm QR code scanned"). Stringer("transaction_id", txnID). Logger() + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() txn, ok := vh.activeTransactions[txnID] if !ok { log.Warn().Msg("Ignoring QR code scan confirmation for an unknown transaction") return nil - } else if txn.VerificationStep != verificationStepStarted { + } else if txn.VerificationState != verificationStateOurQRScanned { log.Warn().Msg("Ignoring QR code scan confirmation for a transaction that is not in the started state") return nil } + log.Info().Msg("Confirming QR code scanned") if txn.TheirUser == vh.client.UserID { @@ -193,9 +188,7 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id return err } - vh.activeTransactionsLock.Lock() - defer vh.activeTransactionsLock.Unlock() - delete(vh.activeTransactions, txn.TransactionID) + txn.VerificationState = verificationStateDone // Broadcast that the verification is complete. vh.verificationDone(ctx, txn.TransactionID) @@ -211,7 +204,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve log.Warn().Msg("Ignoring QR code generation request as showing a QR code is not enabled on this device") return nil } - if !slices.Contains(txn.SupportedMethods, event.VerificationMethodQRCodeScan) { + if !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeScan) { log.Warn().Msg("Ignoring QR code generation request as other device cannot scan QR codes") return nil } diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 771863b7..f20d2b60 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -15,6 +15,7 @@ import ( "crypto/sha256" "encoding/base64" "encoding/json" + "errors" "fmt" "strings" @@ -41,10 +42,13 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio txn, ok := vh.activeTransactions[txnID] if !ok { return fmt.Errorf("unknown transaction ID") + } else if txn.VerificationState != verificationStateReady { + return errors.New("transaction is not in ready state") } - txn.VerificationStep = verificationStepStarted + + txn.VerificationState = verificationStateSASStarted txn.StartedByUs = true - if !slices.Contains(txn.SupportedMethods, event.VerificationMethodSAS) { + if !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS) { return fmt.Errorf("the other device does not support SAS verification") } @@ -55,6 +59,8 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio return err } + // TODO check if the other device already has sent a start event + log.Info().Msg("Sending start event") txn.StartEventContent = &event.VerificationStartEventContent{ FromDevice: vh.client.DeviceID, @@ -129,6 +135,11 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat // // [Section 11.12.2.2]: https://spec.matrix.org/v1.9/client-server-api/#short-authentication-string-sas-verification func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *verificationTransaction, evt *event.Event) error { + if txn.VerificationState != verificationStateReady { + vh.unexpectedEvent(ctx, txn) + return nil // return nil since we already sent a cancellation event in vh.unexpectedEvent + } + startEvt := evt.Content.AsVerificationStart() log := vh.getLog(ctx) log.Info().Msg("Received SAS verification start event") @@ -172,8 +183,6 @@ func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *v if err != nil { return fmt.Errorf("failed to generate ephemeral key: %w", err) } - vh.activeTransactionsLock.Lock() - defer vh.activeTransactionsLock.Unlock() txn.MACMethod = macMethod txn.EphemeralKey = ephemeralKey txn.StartEventContent = startEvt @@ -193,6 +202,7 @@ func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *v if err != nil { return fmt.Errorf("failed to send accept event: %w", err) } + txn.VerificationState = verificationStateSASAccepted return nil } @@ -231,6 +241,13 @@ func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *ver Logger() log.Info().Msg("Received SAS verification accept event") + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + if txn.VerificationState != verificationStateSASStarted { + vh.unexpectedEvent(ctx, txn) + return + } + ephemeralKey, err := ecdh.X25519().GenerateKey(rand.Reader) if err != nil { log.Err(err).Msg("Failed to generate ephemeral key") @@ -245,8 +262,7 @@ func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *ver return } - vh.activeTransactionsLock.Lock() - defer vh.activeTransactionsLock.Unlock() + txn.VerificationState = verificationStateSASAccepted txn.MACMethod = acceptEvt.MessageAuthenticationCode txn.Commitment = acceptEvt.Commitment txn.EphemeralKey = ephemeralKey @@ -261,6 +277,11 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verifi vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() + if txn.VerificationState != verificationStateSASAccepted { + vh.unexpectedEvent(ctx, txn) + return + } + var err error txn.OtherPublicKey, err = ecdh.X25519().NewPublicKey(keyEvt.Key) if err != nil { @@ -295,6 +316,7 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verifi } txn.EphemeralPublicKeyShared = true } + txn.VerificationState = verificationStateSASKeysExchanged sasBytes, err := vh.verificationSASHKDF(txn) if err != nil { diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 5044f459..1d6ba912 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -15,6 +15,7 @@ import ( "sync" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "go.mau.fi/util/jsontime" "golang.org/x/exp/maps" "golang.org/x/exp/slices" @@ -25,22 +26,43 @@ import ( "maunium.net/go/mautrix/id" ) -type verificationStep int +type verificationState int const ( - verificationStepRequested verificationStep = iota - verificationStepReady - verificationStepStarted + verificationStateRequested verificationState = iota + verificationStateReady + verificationStateCancelled + verificationStateDone + + verificationStateTheirQRScanned // We scanned their QR code + verificationStateOurQRScanned // They scanned our QR code + + verificationStateSASStarted // An SAS verification has been started + verificationStateSASAccepted // An SAS verification has been accepted + verificationStateSASKeysExchanged // An SAS verification has exchanged keys + verificationStateSASMAC // An SAS verification has exchanged MACs ) -func (step verificationStep) String() string { +func (step verificationState) String() string { switch step { - case verificationStepRequested: + case verificationStateRequested: return "requested" - case verificationStepReady: + case verificationStateReady: return "ready" - case verificationStepStarted: - return "started" + case verificationStateCancelled: + return "cancelled" + case verificationStateTheirQRScanned: + return "their_qr_scanned" + case verificationStateOurQRScanned: + return "our_qr_scanned" + case verificationStateSASStarted: + return "sas_started" + case verificationStateSASAccepted: + return "sas_accepted" + case verificationStateSASKeysExchanged: + return "sas_keys_exchanged" + case verificationStateSASMAC: + return "sas_mac" default: return fmt.Sprintf("verificationStep(%d)", step) } @@ -51,17 +73,19 @@ type verificationTransaction struct { // empty if it is a to-device verification. RoomID id.RoomID - // VerificationStep is the current step of the verification flow. - VerificationStep verificationStep + // VerificationState is the current step of the verification flow. + VerificationState verificationState // TransactionID is the ID of the verification transaction. TransactionID id.VerificationTransactionID // TheirDevice is the device ID of the device that either made the initial // request or accepted our request. TheirDevice id.DeviceID - // TheirUser is the user ID of the other user. TheirUser id.UserID + // TheirSupportedMethods is a list of verification methods that the other + // device supports. + TheirSupportedMethods []event.VerificationMethod // SentToDeviceIDs is a list of devices which the initial request was sent // to. This is only used for to-device verification requests, and is meant @@ -69,10 +93,6 @@ type verificationTransaction struct { // verification request is accepted via a m.key.verification.ready event. SentToDeviceIDs []id.DeviceID - // SupportedMethods is a list of verification methods that the other device - // supports. - SupportedMethods []event.VerificationMethod - // QRCodeSharedSecret is the shared secret that was encoded in the QR code // that we showed. QRCodeSharedSecret []byte @@ -236,40 +256,55 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { } transactionID = id.VerificationTransactionID(txnID) } + log = log.With().Stringer("transaction_id", transactionID).Logger() + vh.activeTransactionsLock.Lock() txn, ok := vh.activeTransactions[transactionID] vh.activeTransactionsLock.Unlock() - if !ok { - log.Warn(). - Stringer("transaction_id", transactionID). - Msg("Ignoring verification event for an unknown transaction") + if !ok || txn.VerificationState == verificationStateCancelled || txn.VerificationState == verificationStateDone { + var code event.VerificationCancelCode + var reason string + if !ok { + log.Warn().Msg("Ignoring verification event for an unknown transaction and sending cancellation") - txn = &verificationTransaction{ - RoomID: evt.RoomID, - TheirUser: evt.Sender, + // We have to create a fake transaction so that the call to + // verificationCancelled works. + txn = &verificationTransaction{ + RoomID: evt.RoomID, + TheirUser: evt.Sender, + } + txn.TransactionID = evt.Content.Parsed.(event.VerificationTransactionable).GetTransactionID() + if txn.TransactionID == "" { + txn.TransactionID = id.VerificationTransactionID(evt.ID) + } + if fromDevice, ok := evt.Content.Raw["from_device"]; ok { + txn.TheirDevice = id.DeviceID(fromDevice.(string)) + } + code = event.VerificationCancelCodeUnknownTransaction + reason = "The transaction ID was not recognized." + } else if txn.VerificationState == verificationStateCancelled { + log.Warn().Msg("Ignoring verification event for a cancelled transaction") + code = event.VerificationCancelCodeUnexpectedMessage + reason = "The transaction is cancelled." + } else if txn.VerificationState == verificationStateDone { + code = event.VerificationCancelCodeUnexpectedMessage + reason = "The transaction is done." } - txn.TransactionID = evt.Content.Parsed.(event.VerificationTransactionable).GetTransactionID() - if txn.TransactionID == "" { - txn.TransactionID = id.VerificationTransactionID(evt.ID) - } - if fromDevice, ok := evt.Content.Raw["from_device"]; ok { - txn.TheirDevice = id.DeviceID(fromDevice.(string)) - } - cancelEvt := event.VerificationCancelEventContent{ - Code: event.VerificationCancelCodeUnknownTransaction, - Reason: "The transaction ID was not recognized.", - } - err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, &cancelEvt) + + // Send the actual cancellation event. + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, &event.VerificationCancelEventContent{ + Code: code, + Reason: reason, + }) if err != nil { log.Err(err).Msg("Failed to send cancellation event") } - vh.verificationCancelled(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) + vh.verificationCancelled(ctx, txn.TransactionID, code, reason) return } logCtx := vh.getLog(ctx).With(). - Stringer("transaction_id", transactionID). - Stringer("transaction_step", txn.VerificationStep). + Stringer("transaction_step", txn.VerificationState). Stringer("sender", evt.Sender) if evt.RoomID != "" { logCtx = logCtx. @@ -345,10 +380,10 @@ func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserI vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() vh.activeTransactions[txnID] = &verificationTransaction{ - VerificationStep: verificationStepRequested, - TransactionID: txnID, - TheirUser: to, - SentToDeviceIDs: maps.Keys(devices), + VerificationState: verificationStateRequested, + TransactionID: txnID, + TheirUser: to, + SentToDeviceIDs: maps.Keys(devices), } return txnID, nil } @@ -385,10 +420,10 @@ func (vh *VerificationHelper) StartInRoomVerification(ctx context.Context, roomI vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() vh.activeTransactions[txnID] = &verificationTransaction{ - RoomID: roomID, - VerificationStep: verificationStepRequested, - TransactionID: txnID, - TheirUser: to, + RoomID: roomID, + VerificationState: verificationStateRequested, + TransactionID: txnID, + TheirUser: to, } return txnID, nil } @@ -406,6 +441,9 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V if !ok { return fmt.Errorf("unknown transaction ID") } + if txn.VerificationState != verificationStateRequested { + return fmt.Errorf("transaction is not in the requested state") + } log.Info().Msg("Sending ready event") readyEvt := &event.VerificationReadyEventContent{ @@ -416,7 +454,7 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V if err != nil { return err } - txn.VerificationStep = verificationStepReady + txn.VerificationState = verificationStateReady return vh.generateAndShowQRCode(ctx, txn) } @@ -504,12 +542,12 @@ func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *ev return } vh.activeTransactions[verificationRequest.TransactionID] = &verificationTransaction{ - RoomID: evt.RoomID, - VerificationStep: verificationStepRequested, - TransactionID: verificationRequest.TransactionID, - TheirDevice: verificationRequest.FromDevice, - TheirUser: evt.Sender, - SupportedMethods: verificationRequest.Methods, + RoomID: evt.RoomID, + VerificationState: verificationStateRequested, + TransactionID: verificationRequest.TransactionID, + TheirDevice: verificationRequest.FromDevice, + TheirUser: evt.Sender, + TheirSupportedMethods: verificationRequest.Methods, } vh.activeTransactionsLock.Unlock() @@ -521,7 +559,7 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri Str("verification_action", "verification ready"). Logger() - if txn.VerificationStep != verificationStepRequested { + if txn.VerificationState != verificationStateRequested { log.Warn().Msg("Ignoring verification ready event for a transaction that is not in the requested state") return } @@ -532,9 +570,9 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri readyEvt := evt.Content.AsVerificationReady() // Update the transaction state. - txn.VerificationStep = verificationStepReady + txn.VerificationState = verificationStateReady txn.TheirDevice = readyEvt.FromDevice - txn.SupportedMethods = readyEvt.Methods + txn.TheirSupportedMethods = readyEvt.Methods // If we sent this verification request, send cancellations to all of the // other devices. @@ -583,36 +621,57 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationStep == verificationStepStarted { - log.Info().Msg("Got a verification start request from the other device, but the verification is already in progress") - } else if txn.VerificationStep != verificationStepReady { + if txn.VerificationState != verificationStateReady { log.Warn().Msg("Ignoring verification start event for a transaction that is not in the ready state") return } - txn.VerificationStep = verificationStepStarted switch startEvt.Method { case event.VerificationMethodSAS: + txn.VerificationState = verificationStateSASStarted if err := vh.onVerificationStartSAS(ctx, txn, evt); err != nil { vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to handle SAS verification start: %w", err)) + // TODO cancel? } case event.VerificationMethodReciprocate: log.Info().Msg("Received reciprocate start event") - if !bytes.Equal(txn.QRCodeSharedSecret, startEvt.Secret) { vh.verificationError(ctx, txn.TransactionID, errors.New("reciprocated shared secret does not match")) return } - + txn.VerificationState = verificationStateOurQRScanned vh.qrCodeScaned(ctx, txn.TransactionID) default: // Note that we should never get m.qr_code.show.v1 or m.qr_code.scan.v1 // here, since the start command for scanning and showing QR codes // should be of type m.reciprocate.v1. log.Error().Str("method", string(startEvt.Method)).Msg("Unsupported verification method in start event") + + cancelEvt := event.VerificationCancelEventContent{ + Code: event.VerificationCancelCodeUnknownMethod, + Reason: fmt.Sprintf("unknown method %s", startEvt.Method), + } + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, &cancelEvt) + if err != nil { + log.Err(err).Msg("Failed to send cancellation event") + } + vh.verificationCancelled(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) } } +func (vh *VerificationHelper) unexpectedEvent(ctx context.Context, txn *verificationTransaction) { + cancelEvt := event.VerificationCancelEventContent{ + Code: event.VerificationCancelCodeUnexpectedMessage, + Reason: "Got event for a transaction that is not in the correct state", + } + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, &cancelEvt) + if err != nil { + log.Err(err).Msg("Failed to send cancellation event") + } + txn.VerificationState = verificationStateCancelled + vh.verificationCancelled(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) +} + func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verificationTransaction, evt *event.Event) { vh.getLog(ctx).Info(). Str("verification_action", "done"). @@ -620,8 +679,13 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verif Msg("Verification done") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - delete(vh.activeTransactions, txn.TransactionID) - vh.verificationDone(ctx, txn.TransactionID) + + if txn.VerificationState == verificationStateTheirQRScanned || txn.VerificationState == verificationStateSASMAC { + txn.VerificationState = verificationStateDone + vh.verificationDone(ctx, txn.TransactionID) + } else { + vh.unexpectedEvent(ctx, txn) + } } func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *verificationTransaction, evt *event.Event) { @@ -634,6 +698,6 @@ func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *ver Msg("Verification was cancelled") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - delete(vh.activeTransactions, txn.TransactionID) + txn.VerificationState = verificationStateCancelled vh.verificationCancelled(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) } From 990b623244b5ed73d7bf09a19a3905a9a576b0a1 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 15 Feb 2024 21:26:29 -0700 Subject: [PATCH 0125/1647] pre-commit: prevent literal HTTP methods Signed-off-by: Sumner Evans --- .pre-commit-config.yaml | 5 ++ client.go | 140 ++++++++++++++++++++-------------------- 2 files changed, 75 insertions(+), 70 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a656f0a8..5fffa9fb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,3 +17,8 @@ repos: - "maunium.net/go/mautrix" - "-w" - id: go-vet-repo-mod + + - repo: https://github.com/beeper/pre-commit-go + rev: v0.3.1 + hooks: + - id: prevent-literal-http-methods diff --git a/client.go b/client.go index c9523190..4f94d986 100644 --- a/client.go +++ b/client.go @@ -122,7 +122,7 @@ func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown Path: "/.well-known/matrix/client", } - req, err := http.NewRequestWithContext(ctx, "GET", wellKnownURL.String(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnownURL.String(), nil) if err != nil { return nil, err } @@ -591,14 +591,14 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof func (cli *Client) Whoami(ctx context.Context) (resp *RespWhoami, err error) { urlPath := cli.BuildClientURL("v3", "account", "whoami") - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } // CreateFilter makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter func (cli *Client) CreateFilter(ctx context.Context, filter *Filter) (resp *RespCreateFilter, err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "filter") - _, err = cli.MakeRequest(ctx, "POST", urlPath, filter, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, filter, &resp) return } @@ -779,7 +779,7 @@ func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRe // GetLoginFlows fetches the login flows that the homeserver supports using https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3login func (cli *Client) GetLoginFlows(ctx context.Context) (resp *RespLoginFlows, err error) { urlPath := cli.BuildClientURL("v3", "login") - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } @@ -823,7 +823,7 @@ func (cli *Client) Login(ctx context.Context, req *ReqLogin) (resp *RespLogin, e // This does not clear the credentials from the client instance. See ClearCredentials() instead. func (cli *Client) Logout(ctx context.Context) (resp *RespLogout, err error) { urlPath := cli.BuildClientURL("v3", "logout") - _, err = cli.MakeRequest(ctx, "POST", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, nil, &resp) return } @@ -831,21 +831,21 @@ func (cli *Client) Logout(ctx context.Context) (resp *RespLogout, err error) { // This does not clear the credentials from the client instance. See ClearCredentials() instead. func (cli *Client) LogoutAll(ctx context.Context) (resp *RespLogout, err error) { urlPath := cli.BuildClientURL("v3", "logout", "all") - _, err = cli.MakeRequest(ctx, "POST", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, nil, &resp) return } // Versions returns the list of supported Matrix versions on this homeserver. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientversions func (cli *Client) Versions(ctx context.Context) (resp *RespVersions, err error) { urlPath := cli.BuildClientURL("versions") - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } // Capabilities returns capabilities on this homeserver. See https://spec.matrix.org/v1.3/client-server-api/#capabilities-negotiation func (cli *Client) Capabilities(ctx context.Context) (resp *RespCapabilities, err error) { urlPath := cli.BuildClientURL("v3", "capabilities") - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } @@ -862,7 +862,7 @@ func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName strin } else { urlPath = cli.BuildClientURL("v3", "join", roomIDorAlias) } - _, err = cli.MakeRequest(ctx, "POST", urlPath, content, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, content, &resp) if err == nil && cli.StateStore != nil { err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin) if err != nil { @@ -877,7 +877,7 @@ func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName strin // Unlike JoinRoom, this method can only be used to join rooms that the server already knows about. // It's mostly intended for bridges and other things where it's already certain that the server is in the room. func (cli *Client) JoinRoomByID(ctx context.Context, roomID id.RoomID) (resp *RespJoinRoom, err error) { - _, err = cli.MakeRequest(ctx, "POST", cli.BuildClientURL("v3", "rooms", roomID, "join"), nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildClientURL("v3", "rooms", roomID, "join"), nil, &resp) if err == nil && cli.StateStore != nil { err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin) if err != nil { @@ -889,14 +889,14 @@ func (cli *Client) JoinRoomByID(ctx context.Context, roomID id.RoomID) (resp *Re func (cli *Client) GetProfile(ctx context.Context, mxid id.UserID) (resp *RespUserProfile, err error) { urlPath := cli.BuildClientURL("v3", "profile", mxid) - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } // GetDisplayName returns the display name of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname func (cli *Client) GetDisplayName(ctx context.Context, mxid id.UserID) (resp *RespUserDisplayName, err error) { urlPath := cli.BuildClientURL("v3", "profile", mxid, "displayname") - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } @@ -911,7 +911,7 @@ func (cli *Client) SetDisplayName(ctx context.Context, displayName string) (err s := struct { DisplayName string `json:"displayname"` }{displayName} - _, err = cli.MakeRequest(ctx, "PUT", urlPath, &s, nil) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &s, nil) return } @@ -922,7 +922,7 @@ func (cli *Client) GetAvatarURL(ctx context.Context, mxid id.UserID) (url id.Con AvatarURL id.ContentURI `json:"avatar_url"` }{} - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &s) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &s) if err != nil { return } @@ -941,7 +941,7 @@ func (cli *Client) SetAvatarURL(ctx context.Context, url id.ContentURI) (err err s := struct { AvatarURL string `json:"avatar_url"` }{url.String()} - _, err = cli.MakeRequest(ctx, "PUT", urlPath, &s, nil) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &s, nil) if err != nil { return err } @@ -952,21 +952,21 @@ func (cli *Client) SetAvatarURL(ctx context.Context, url id.ContentURI) (err err // BeeperUpdateProfile sets custom fields in the user's profile. func (cli *Client) BeeperUpdateProfile(ctx context.Context, data map[string]any) (err error) { urlPath := cli.BuildClientURL("v3", "profile", cli.UserID) - _, err = cli.MakeRequest(ctx, "PATCH", urlPath, &data, nil) + _, err = cli.MakeRequest(ctx, http.MethodPatch, urlPath, &data, nil) return } // GetAccountData gets the user's account data of this type. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3useruseridaccount_datatype func (cli *Client) GetAccountData(ctx context.Context, name string, output interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "account_data", name) - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, output) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, output) return } // SetAccountData sets the user's account data of this type. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridaccount_datatype func (cli *Client) SetAccountData(ctx context.Context, name string, data interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "account_data", name) - _, err = cli.MakeRequest(ctx, "PUT", urlPath, data, nil) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, data, nil) if err != nil { return err } @@ -977,14 +977,14 @@ func (cli *Client) SetAccountData(ctx context.Context, name string, data interfa // GetRoomAccountData gets the user's account data of this type in a specific room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridaccount_datatype func (cli *Client) GetRoomAccountData(ctx context.Context, roomID id.RoomID, name string, output interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "account_data", name) - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, output) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, output) return } // SetRoomAccountData sets the user's account data of this type in a specific room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridroomsroomidaccount_datatype func (cli *Client) SetRoomAccountData(ctx context.Context, roomID id.RoomID, name string, data interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "account_data", name) - _, err = cli.MakeRequest(ctx, "PUT", urlPath, data, nil) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, data, nil) if err != nil { return err } @@ -1042,7 +1042,7 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event urlData := ClientURLPath{"v3", "rooms", roomID, "send", eventType.String(), txnID} urlPath := cli.BuildURLWithQuery(urlData, queryParams) - _, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) return } @@ -1050,7 +1050,7 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) - _, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) if err == nil && cli.StateStore != nil { cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) } @@ -1063,7 +1063,7 @@ func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{ "ts": strconv.FormatInt(ts, 10), }) - _, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) if err == nil && cli.StateStore != nil { cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) } @@ -1117,7 +1117,7 @@ func (cli *Client) RedactEvent(ctx context.Context, roomID id.RoomID, eventID id txnID = cli.TxnID() } urlPath := cli.BuildClientURL("v3", "rooms", roomID, "redact", eventID, txnID) - _, err = cli.MakeRequest(ctx, "PUT", urlPath, req.Extra, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req.Extra, &resp) return } @@ -1129,7 +1129,7 @@ func (cli *Client) RedactEvent(ctx context.Context, roomID id.RoomID, eventID id // fmt.Println("Room:", resp.RoomID) func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *RespCreateRoom, err error) { urlPath := cli.BuildClientURL("v3", "createRoom") - _, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) if err == nil && cli.StateStore != nil { storeErr := cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin) if storeErr != nil { @@ -1168,7 +1168,7 @@ func (cli *Client) LeaveRoom(ctx context.Context, roomID id.RoomID, optionalReq panic("invalid number of arguments to LeaveRoom") } u := cli.BuildClientURL("v3", "rooms", roomID, "leave") - _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, u, req, &resp) if err == nil && cli.StateStore != nil { err = cli.StateStore.SetMembership(ctx, roomID, cli.UserID, event.MembershipLeave) if err != nil { @@ -1181,14 +1181,14 @@ func (cli *Client) LeaveRoom(ctx context.Context, roomID id.RoomID, optionalReq // ForgetRoom forgets a room entirely. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidforget func (cli *Client) ForgetRoom(ctx context.Context, roomID id.RoomID) (resp *RespForgetRoom, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "forget") - _, err = cli.MakeRequest(ctx, "POST", u, struct{}{}, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, u, struct{}{}, &resp) return } // InviteUser invites a user to a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite func (cli *Client) InviteUser(ctx context.Context, roomID id.RoomID, req *ReqInviteUser) (resp *RespInviteUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "invite") - _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, u, req, &resp) if err == nil && cli.StateStore != nil { err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipInvite) if err != nil { @@ -1201,14 +1201,14 @@ func (cli *Client) InviteUser(ctx context.Context, roomID id.RoomID, req *ReqInv // InviteUserByThirdParty invites a third-party identifier to a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite-1 func (cli *Client) InviteUserByThirdParty(ctx context.Context, roomID id.RoomID, req *ReqInvite3PID) (resp *RespInviteUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "invite") - _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, u, req, &resp) return } // KickUser kicks a user from a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidkick func (cli *Client) KickUser(ctx context.Context, roomID id.RoomID, req *ReqKickUser) (resp *RespKickUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "kick") - _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, u, req, &resp) if err == nil && cli.StateStore != nil { err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipLeave) if err != nil { @@ -1221,7 +1221,7 @@ func (cli *Client) KickUser(ctx context.Context, roomID id.RoomID, req *ReqKickU // BanUser bans a user from a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidban func (cli *Client) BanUser(ctx context.Context, roomID id.RoomID, req *ReqBanUser) (resp *RespBanUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "ban") - _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, u, req, &resp) if err == nil && cli.StateStore != nil { err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipBan) if err != nil { @@ -1234,7 +1234,7 @@ func (cli *Client) BanUser(ctx context.Context, roomID id.RoomID, req *ReqBanUse // UnbanUser unbans a user from a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidunban func (cli *Client) UnbanUser(ctx context.Context, roomID id.RoomID, req *ReqUnbanUser) (resp *RespUnbanUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "unban") - _, err = cli.MakeRequest(ctx, "POST", u, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, u, req, &resp) if err == nil && cli.StateStore != nil { err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipLeave) if err != nil { @@ -1248,7 +1248,7 @@ func (cli *Client) UnbanUser(ctx context.Context, roomID id.RoomID, req *ReqUnba func (cli *Client) UserTyping(ctx context.Context, roomID id.RoomID, typing bool, timeout time.Duration) (resp *RespTyping, err error) { req := ReqTyping{Typing: typing, Timeout: timeout.Milliseconds()} u := cli.BuildClientURL("v3", "rooms", roomID, "typing", cli.UserID) - _, err = cli.MakeRequest(ctx, "PUT", u, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPut, u, req, &resp) return } @@ -1256,7 +1256,7 @@ func (cli *Client) UserTyping(ctx context.Context, roomID id.RoomID, typing bool func (cli *Client) GetPresence(ctx context.Context, userID id.UserID) (resp *RespPresence, err error) { resp = new(RespPresence) u := cli.BuildClientURL("v3", "presence", userID, "status") - _, err = cli.MakeRequest(ctx, "GET", u, nil, resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, resp) return } @@ -1268,7 +1268,7 @@ func (cli *Client) GetOwnPresence(ctx context.Context) (resp *RespPresence, err func (cli *Client) SetPresence(ctx context.Context, status event.Presence) (err error) { req := ReqPresence{Presence: status} u := cli.BuildClientURL("v3", "presence", cli.UserID, "status") - _, err = cli.MakeRequest(ctx, "PUT", u, req, nil) + _, err = cli.MakeRequest(ctx, http.MethodPut, u, req, nil) return } @@ -1310,7 +1310,7 @@ func (cli *Client) updateStoreWithOutgoingEvent(ctx context.Context, roomID id.R // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstateeventtypestatekey func (cli *Client) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) (err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) - _, err = cli.MakeRequest(ctx, "GET", u, nil, outContent) + _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, outContent) if err == nil && cli.StateStore != nil { cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, outContent) } @@ -1382,13 +1382,13 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt // GetMediaConfig fetches the configuration of the content repository, such as upload limitations. func (cli *Client) GetMediaConfig(ctx context.Context) (resp *RespMediaConfig, err error) { u := cli.BuildURL(MediaURLPath{"v3", "config"}) - _, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) return } // UploadLink uploads an HTTP URL and then returns an MXC URI. func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUpload, error) { - req, err := http.NewRequestWithContext(ctx, "GET", link, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, link, nil) if err != nil { return nil, err } @@ -1687,7 +1687,7 @@ func (cli *Client) GetURLPreview(ctx context.Context, url string) (*RespPreviewU // This API is primarily designed for application services which may want to efficiently look up joined members in a room. func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *RespJoinedMembers, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members") - _, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) if err == nil && cli.StateStore != nil { clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, event.MembershipJoin) cli.cliOrContextLog(ctx).Warn().Err(clearErr). @@ -1726,7 +1726,7 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb query["not_membership"] = string(extra.NotMembership) } u := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "members"}, query) - _, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) if err == nil && cli.StateStore != nil { var clearMemberships []event.Membership if extra.Membership != "" { @@ -1751,7 +1751,7 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb // This API is primarily designed for application services which may want to efficiently look up joined rooms. func (cli *Client) JoinedRooms(ctx context.Context) (resp *RespJoinedRooms, err error) { u := cli.BuildClientURL("v3", "joined_rooms") - _, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) return } @@ -1790,7 +1790,7 @@ func (cli *Client) Messages(ctx context.Context, roomID id.RoomID, from, to stri } urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "messages"}, query) - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } @@ -1825,13 +1825,13 @@ func (cli *Client) Context(ctx context.Context, roomID id.RoomID, eventID id.Eve } urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "context", eventID}, query) - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } func (cli *Client) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (resp *event.Event, err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "event", eventID) - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } @@ -1852,13 +1852,13 @@ func (cli *Client) MarkReadWithContent(ctx context.Context, roomID id.RoomID, ev // To mark a message in a specific thread as read, use pass a ReqSendReceipt as the content. func (cli *Client) SendReceipt(ctx context.Context, roomID id.RoomID, eventID id.EventID, receiptType event.ReceiptType, content interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "receipt", receiptType, eventID) - _, err = cli.MakeRequest(ctx, "POST", urlPath, content, nil) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, content, nil) return } func (cli *Client) SetReadMarkers(ctx context.Context, roomID id.RoomID, content interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "read_markers") - _, err = cli.MakeRequest(ctx, "POST", urlPath, content, nil) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, content, nil) return } @@ -1872,7 +1872,7 @@ func (cli *Client) AddTag(ctx context.Context, roomID id.RoomID, tag string, ord func (cli *Client) AddTagWithCustomData(ctx context.Context, roomID id.RoomID, tag string, data interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags", tag) - _, err = cli.MakeRequest(ctx, "PUT", urlPath, data, nil) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, data, nil) return } @@ -1883,13 +1883,13 @@ func (cli *Client) GetTags(ctx context.Context, roomID id.RoomID) (tags event.Ta func (cli *Client) GetTagsWithCustomData(ctx context.Context, roomID id.RoomID, resp interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags") - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } func (cli *Client) RemoveTag(ctx context.Context, roomID id.RoomID, tag string) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags", tag) - _, err = cli.MakeRequest(ctx, "DELETE", urlPath, nil, nil) + _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil) return } @@ -1904,49 +1904,49 @@ func (cli *Client) SetTags(ctx context.Context, roomID id.RoomID, tags event.Tag // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3voipturnserver func (cli *Client) TurnServer(ctx context.Context) (resp *RespTurnServer, err error) { urlPath := cli.BuildClientURL("v3", "voip", "turnServer") - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } func (cli *Client) CreateAlias(ctx context.Context, alias id.RoomAlias, roomID id.RoomID) (resp *RespAliasCreate, err error) { urlPath := cli.BuildClientURL("v3", "directory", "room", alias) - _, err = cli.MakeRequest(ctx, "PUT", urlPath, &ReqAliasCreate{RoomID: roomID}, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqAliasCreate{RoomID: roomID}, &resp) return } func (cli *Client) ResolveAlias(ctx context.Context, alias id.RoomAlias) (resp *RespAliasResolve, err error) { urlPath := cli.BuildClientURL("v3", "directory", "room", alias) - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } func (cli *Client) DeleteAlias(ctx context.Context, alias id.RoomAlias) (resp *RespAliasDelete, err error) { urlPath := cli.BuildClientURL("v3", "directory", "room", alias) - _, err = cli.MakeRequest(ctx, "DELETE", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp) return } func (cli *Client) GetAliases(ctx context.Context, roomID id.RoomID) (resp *RespAliasList, err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "aliases") - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } func (cli *Client) UploadKeys(ctx context.Context, req *ReqUploadKeys) (resp *RespUploadKeys, err error) { urlPath := cli.BuildClientURL("v3", "keys", "upload") - _, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) return } func (cli *Client) QueryKeys(ctx context.Context, req *ReqQueryKeys) (resp *RespQueryKeys, err error) { urlPath := cli.BuildClientURL("v3", "keys", "query") - _, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) return } func (cli *Client) ClaimKeys(ctx context.Context, req *ReqClaimKeys) (resp *RespClaimKeys, err error) { urlPath := cli.BuildClientURL("v3", "keys", "claim") - _, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) return } @@ -1955,7 +1955,7 @@ func (cli *Client) GetKeyChanges(ctx context.Context, from, to string) (resp *Re "from": from, "to": to, }) - _, err = cli.MakeRequest(ctx, "POST", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, nil, &resp) return } @@ -2113,37 +2113,37 @@ func (cli *Client) DeleteKeyBackupVersion(ctx context.Context, version id.KeyBac func (cli *Client) SendToDevice(ctx context.Context, eventType event.Type, req *ReqSendToDevice) (resp *RespSendToDevice, err error) { urlPath := cli.BuildClientURL("v3", "sendToDevice", eventType.String(), cli.TxnID()) - _, err = cli.MakeRequest(ctx, "PUT", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp) return } func (cli *Client) GetDevicesInfo(ctx context.Context) (resp *RespDevicesInfo, err error) { urlPath := cli.BuildClientURL("v3", "devices") - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } func (cli *Client) GetDeviceInfo(ctx context.Context, deviceID id.DeviceID) (resp *RespDeviceInfo, err error) { urlPath := cli.BuildClientURL("v3", "devices", deviceID) - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } func (cli *Client) SetDeviceInfo(ctx context.Context, deviceID id.DeviceID, req *ReqDeviceInfo) error { urlPath := cli.BuildClientURL("v3", "devices", deviceID) - _, err := cli.MakeRequest(ctx, "PUT", urlPath, req, nil) + _, err := cli.MakeRequest(ctx, http.MethodPut, urlPath, req, nil) return err } func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice) error { urlPath := cli.BuildClientURL("v3", "devices", deviceID) - _, err := cli.MakeRequest(ctx, "DELETE", urlPath, req, nil) + _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil) return err } func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices) error { urlPath := cli.BuildClientURL("v3", "delete_devices") - _, err := cli.MakeRequest(ctx, "DELETE", urlPath, req, nil) + _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil) return err } @@ -2176,7 +2176,7 @@ func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCross func (cli *Client) UploadSignatures(ctx context.Context, req *ReqUploadSignatures) (resp *RespUploadSignatures, err error) { urlPath := cli.BuildClientURL("v3", "keys", "signatures", "upload") - _, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) return } @@ -2190,13 +2190,13 @@ func (cli *Client) GetScopedPushRules(ctx context.Context, scope string) (resp * u, _ := url.Parse(cli.BuildClientURL("v3", "pushrules", scope)) // client.BuildURL returns the URL without a trailing slash, but the pushrules endpoint requires the slash. u.Path += "/" - _, err = cli.MakeRequest(ctx, "GET", u.String(), nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, u.String(), nil, &resp) return } func (cli *Client) GetPushRule(ctx context.Context, scope string, kind pushrules.PushRuleType, ruleID string) (resp *pushrules.PushRule, err error) { urlPath := cli.BuildClientURL("v3", "pushrules", scope, kind, ruleID) - _, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) if resp != nil { resp.Type = kind } @@ -2205,7 +2205,7 @@ func (cli *Client) GetPushRule(ctx context.Context, scope string, kind pushrules func (cli *Client) DeletePushRule(ctx context.Context, scope string, kind pushrules.PushRuleType, ruleID string) error { urlPath := cli.BuildClientURL("v3", "pushrules", scope, kind, ruleID) - _, err := cli.MakeRequest(ctx, "DELETE", urlPath, nil, nil) + _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil) return err } @@ -2218,7 +2218,7 @@ func (cli *Client) PutPushRule(ctx context.Context, scope string, kind pushrules query["before"] = req.Before } urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "pushrules", scope, kind, ruleID}, query) - _, err := cli.MakeRequest(ctx, "PUT", urlPath, req, nil) + _, err := cli.MakeRequest(ctx, http.MethodPut, urlPath, req, nil) return err } @@ -2239,7 +2239,7 @@ func (cli *Client) BatchSend(ctx context.Context, roomID id.RoomID, req *ReqBatc if len(req.BatchID) > 0 { query["batch_id"] = req.BatchID.String() } - _, err = cli.MakeRequest(ctx, "POST", cli.BuildURLWithQuery(path, query), req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildURLWithQuery(path, query), req, &resp) return } From 7ffbe34f0c43247d14c5bdd1540d6a4a572ca23d Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Thu, 15 Feb 2024 14:27:40 +0200 Subject: [PATCH 0126/1647] Log if importing partial megolm session --- crypto/keybackup.go | 4 ++++ crypto/keysharing.go | 3 +++ 2 files changed, 7 insertions(+) diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 9090e76c..93440a58 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -159,6 +159,10 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. maxMessages = config.RotationPeriodMessages } + if firstKnownIndex := igsInternal.FirstKnownIndex(); firstKnownIndex > 0 { + log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session") + } + igs := &InboundGroupSession{ Internal: *igsInternal, SigningKey: keyBackupData.SenderClaimedKeys.Ed25519, diff --git a/crypto/keysharing.go b/crypto/keysharing.go index 2e8947f6..fa422ca5 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -168,6 +168,9 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt if content.MaxMessages != 0 { maxMessages = content.MaxMessages } + if firstKnownIndex := igsInternal.FirstKnownIndex(); firstKnownIndex > 0 { + log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session") + } igs := &InboundGroupSession{ Internal: *igsInternal, SigningKey: evt.Keys.Ed25519, From 66ba71153e742d32190ee2d880a2406af494e200 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Fri, 16 Feb 2024 09:56:19 +0200 Subject: [PATCH 0127/1647] Remove withheld keys when scanning all IGS rows --- crypto/sql_store.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crypto/sql_store.go b/crypto/sql_store.go index ef1be25b..a3b3b74a 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -496,7 +496,7 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`, roomID, store.AccountID, ) @@ -505,7 +505,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL`, store.AccountID, ) @@ -514,7 +514,7 @@ func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.Row func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL AND key_backup_version != $2`, store.AccountID, version, ) From a1b18a005a4011fad9f03339e04d345077ed9cb1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 16 Feb 2024 16:52:21 +0200 Subject: [PATCH 0128/1647] Add flag to run bridge even if homeserver is outdated --- bridge/bridge.go | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/bridge/bridge.go b/bridge/bridge.go index 8841cc37..abebdd62 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -49,6 +49,7 @@ var version = flag.MakeFull("v", "version", "View bridge version and quit.", "fa var versionJSON = flag.Make().LongKey("version-json").Usage("Print a JSON object representing the bridge version and quit.").Default("false").Bool() var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if the database schema is too new").Default("false").Bool() var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Run even if the database contains tables from other programs (like Synapse)").Default("false").Bool() +var ignoreUnsupportedServer = flag.Make().LongKey("ignore-unsupported-server").Usage("Run even if the Matrix homeserver is outdated").Default("false").Bool() var wantHelp, _ = flag.MakeHelpFlag() var _ appservice.StateStore = (*sqlstatestore.SQLStateStore)(nil) @@ -304,19 +305,27 @@ func (br *Bridge) ensureConnection(ctx context.Context) { } } + unsupportedServerLogLevel := zerolog.FatalLevel + if *ignoreUnsupportedServer { + unsupportedServerLogLevel = zerolog.ErrorLevel + } if br.Config.Homeserver.Software == bridgeconfig.SoftwareHungry && !br.SpecVersions.Supports(mautrix.BeeperFeatureHungry) { br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The config claims the homeserver is hungryserv, but the /versions response didn't confirm it") os.Exit(18) } else if !br.SpecVersions.ContainsGreaterOrEqual(MinSpecVersion) { - br.ZLog.WithLevel(zerolog.FatalLevel). + br.ZLog.WithLevel(unsupportedServerLogLevel). Stringer("server_supports", br.SpecVersions.GetLatest()). Stringer("bridge_requires", MinSpecVersion). Msg("The homeserver is outdated (supported spec versions are below minimum required by bridge)") - os.Exit(18) + if !*ignoreUnsupportedServer { + os.Exit(18) + } } else if fr, ok := br.Child.(CSFeatureRequirer); ok { if msg, hasFeatures := fr.CheckFeatures(&br.SpecVersions); !hasFeatures { - br.ZLog.WithLevel(zerolog.FatalLevel).Msg(msg) - os.Exit(18) + br.ZLog.WithLevel(unsupportedServerLogLevel).Msg(msg) + if !*ignoreUnsupportedServer { + os.Exit(18) + } } } From 4cae27c4454bf7a5356a3a655e32d8685dc97ef6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 16 Feb 2024 16:54:18 +0200 Subject: [PATCH 0129/1647] Bump version to v0.18.0-beta.1 --- CHANGELOG.md | 18 ++++++++++++++++-- go.mod | 6 +++--- go.sum | 12 ++++++------ version.go | 2 +- 4 files changed, 26 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 57518395..3298c379 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,24 @@ -## unreleased +## v0.18.0 (unreleased) +### beta.1 (2024-02-16) + +* Bumped minimum Go version to 1.21. +* *(bridge)* Bumped minimum Matrix spec version to v1.4. +* **Breaking change *(crypto)*** Deleted old half-broken interactive + verification code and replaced it with a new `verificationhelper`. + * The new verification helper is still experimental. + * Both QR and emoji verification are supported (in theory). +* *(crypto)* Added support for server-side key backup. +* *(crypto)* Added support for receiving and sending secrets like cross-signing + private keys via secret sharing. +* *(crypto)* Added support for tracking which devices megolm sessions were + initially shared to, and allowing re-sharing the keys to those sessions. +* *(client)* Changed cross-signing key upload method to accept a callback for + user-interactive auth instead of only hardcoding password support. * *(appservice)* Dropped support for legacy non-prefixed appservice paths (e.g. `/transactions` instead of `/_matrix/app/v1/transactions`). * *(appservice)* Dropped support for legacy `access_token` authorization in appservice endpoints. -* *(bridge)* Bumped minimum Matrix spec version to v1.4. * *(bridge)* Fixed `RawArgs` field in command events of command state callbacks. * *(appservice)* Added `CreateFull` helper function for creating an `AppService` instance with all the mandatory fields set. diff --git a/go.mod b/go.mod index b36e8d3c..226ceb44 100644 --- a/go.mod +++ b/go.mod @@ -9,13 +9,13 @@ require ( github.com/mattn/go-sqlite3 v1.14.22 github.com/rs/zerolog v1.32.0 github.com/stretchr/testify v1.8.4 - github.com/tidwall/gjson v1.17.0 + github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.0 - go.mau.fi/util v0.3.1-0.20240208085450-32294da153ab + go.mau.fi/util v0.4.0 go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.19.0 - golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 + golang.org/x/exp v0.0.0-20240213143201-ec583247a57a golang.org/x/net v0.21.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 diff --git a/go.sum b/go.sum index abe1986c..dbcaf985 100644 --- a/go.sum +++ b/go.sum @@ -27,8 +27,8 @@ github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWR github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= -github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= +github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= @@ -37,14 +37,14 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.0 h1:EfOIvIMZIzHdB/R/zVrikYLPPwJlfMcNczJFMs1m6sA= github.com/yuin/goldmark v1.7.0/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.3.1-0.20240208085450-32294da153ab h1:XZ8W5vHWlXSGmHn1U+Fvbh+xZr9wuHTvbY+qV7aybDY= -go.mau.fi/util v0.3.1-0.20240208085450-32294da153ab/go.mod h1:rRypwgXVEPILomtFPyQcnbOeuRqf+nRN84vh/CICq4w= +go.mau.fi/util v0.4.0 h1:S2X3qU4pUcb/vxBRfAuZjbrR9xVMAXSjQojNBLPBbhs= +go.mau.fi/util v0.4.0/go.mod h1:leeiHtgVBuN+W9aDii3deAXnfC563iN3WK6BF8/AjNw= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 h1:/RIbNt/Zr7rVhIkQhooTxCxFcdWLGIKnZA4IXNFSrvo= -golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= +golang.org/x/exp v0.0.0-20240213143201-ec583247a57a h1:HinSgX1tJRX3KsL//Gxynpw5CTOAIPhgL4W8PNiIpVE= +golang.org/x/exp v0.0.0-20240213143201-ec583247a57a/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/version.go b/version.go index d92a7977..ed53f6ae 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.17.0" +const Version = "v0.18.0-beta.1" var GoModVersion = "" var Commit = "" From fd986fc43a4eb99d0f2a8a4d950d57cc57085ad3 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 16 Feb 2024 09:36:35 -0700 Subject: [PATCH 0130/1647] crypto: add some license headers Signed-off-by: Sumner Evans --- crypto/aescbc/aes_cbc.go | 6 ++++++ crypto/aescbc/aes_cbc_test.go | 6 ++++++ crypto/aescbc/errors.go | 6 ++++++ crypto/backup/encryptedsessiondata.go | 6 ++++++ crypto/backup/encryptedsessiondata_test.go | 6 ++++++ crypto/backup/ephemeralkey.go | 6 ++++++ crypto/backup/ephemeralkey_test.go | 6 ++++++ crypto/backup/megolmbackup.go | 6 ++++++ crypto/backup/megolmbackupkey.go | 6 ++++++ crypto/pkcs7/pkcs7.go | 6 ++++++ crypto/pkcs7/pkcs7_test.go | 6 ++++++ crypto/signatures/signatures.go | 6 ++++++ 12 files changed, 72 insertions(+) diff --git a/crypto/aescbc/aes_cbc.go b/crypto/aescbc/aes_cbc.go index f1fdc84d..d69a5f49 100644 --- a/crypto/aescbc/aes_cbc.go +++ b/crypto/aescbc/aes_cbc.go @@ -1,3 +1,9 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + package aescbc import ( diff --git a/crypto/aescbc/aes_cbc_test.go b/crypto/aescbc/aes_cbc_test.go index 06dcee0d..bb03f706 100644 --- a/crypto/aescbc/aes_cbc_test.go +++ b/crypto/aescbc/aes_cbc_test.go @@ -1,3 +1,9 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + package aescbc_test import ( diff --git a/crypto/aescbc/errors.go b/crypto/aescbc/errors.go index 542c3450..f3d2d7ce 100644 --- a/crypto/aescbc/errors.go +++ b/crypto/aescbc/errors.go @@ -1,3 +1,9 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + package aescbc import "errors" diff --git a/crypto/backup/encryptedsessiondata.go b/crypto/backup/encryptedsessiondata.go index 8ac74151..37b0a6c8 100644 --- a/crypto/backup/encryptedsessiondata.go +++ b/crypto/backup/encryptedsessiondata.go @@ -1,3 +1,9 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + package backup import ( diff --git a/crypto/backup/encryptedsessiondata_test.go b/crypto/backup/encryptedsessiondata_test.go index 8aab1390..761c4328 100644 --- a/crypto/backup/encryptedsessiondata_test.go +++ b/crypto/backup/encryptedsessiondata_test.go @@ -1,3 +1,9 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + package backup_test import ( diff --git a/crypto/backup/ephemeralkey.go b/crypto/backup/ephemeralkey.go index d0ee03a6..e481e7a9 100644 --- a/crypto/backup/ephemeralkey.go +++ b/crypto/backup/ephemeralkey.go @@ -1,3 +1,9 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + package backup import ( diff --git a/crypto/backup/ephemeralkey_test.go b/crypto/backup/ephemeralkey_test.go index 93d24563..0842f54f 100644 --- a/crypto/backup/ephemeralkey_test.go +++ b/crypto/backup/ephemeralkey_test.go @@ -1,3 +1,9 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + package backup_test import ( diff --git a/crypto/backup/megolmbackup.go b/crypto/backup/megolmbackup.go index dea3e704..0a2a42a2 100644 --- a/crypto/backup/megolmbackup.go +++ b/crypto/backup/megolmbackup.go @@ -1,3 +1,9 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + package backup import ( diff --git a/crypto/backup/megolmbackupkey.go b/crypto/backup/megolmbackupkey.go index 8a57b4cf..8f23d104 100644 --- a/crypto/backup/megolmbackupkey.go +++ b/crypto/backup/megolmbackupkey.go @@ -1,3 +1,9 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + package backup import ( diff --git a/crypto/pkcs7/pkcs7.go b/crypto/pkcs7/pkcs7.go index 1018e52b..c83c5afd 100644 --- a/crypto/pkcs7/pkcs7.go +++ b/crypto/pkcs7/pkcs7.go @@ -1,3 +1,9 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + package pkcs7 import "bytes" diff --git a/crypto/pkcs7/pkcs7_test.go b/crypto/pkcs7/pkcs7_test.go index 6ef835c0..5edc9a92 100644 --- a/crypto/pkcs7/pkcs7_test.go +++ b/crypto/pkcs7/pkcs7_test.go @@ -1,3 +1,9 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + package pkcs7_test import ( diff --git a/crypto/signatures/signatures.go b/crypto/signatures/signatures.go index 7ad19316..0c4422f9 100644 --- a/crypto/signatures/signatures.go +++ b/crypto/signatures/signatures.go @@ -1,3 +1,9 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + package signatures import ( From 0b8e46e84dd8ef3de9c0b1fb8599c227ffed690d Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 16 Feb 2024 12:01:42 -0700 Subject: [PATCH 0131/1647] event/verification: Mac -> MAC Signed-off-by: Sumner Evans --- crypto/verificationhelper/sas.go | 2 +- event/content.go | 10 +++++----- event/verification.go | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index f20d2b60..98641506 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -122,7 +122,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat return err } - macEventContent := &event.VerificationMacEventContent{ + macEventContent := &event.VerificationMACEventContent{ Keys: keysMAC, MAC: keys, } diff --git a/event/content.go b/event/content.go index 2a2833d3..bdb3eeb8 100644 --- a/event/content.go +++ b/event/content.go @@ -64,7 +64,7 @@ var TypeMap = map[Type]reflect.Type{ InRoomVerificationAccept: reflect.TypeOf(VerificationAcceptEventContent{}), InRoomVerificationKey: reflect.TypeOf(VerificationKeyEventContent{}), - InRoomVerificationMAC: reflect.TypeOf(VerificationMacEventContent{}), + InRoomVerificationMAC: reflect.TypeOf(VerificationMACEventContent{}), ToDeviceRoomKey: reflect.TypeOf(RoomKeyEventContent{}), ToDeviceForwardedRoomKey: reflect.TypeOf(ForwardedRoomKeyEventContent{}), @@ -83,7 +83,7 @@ var TypeMap = map[Type]reflect.Type{ ToDeviceVerificationAccept: reflect.TypeOf(VerificationAcceptEventContent{}), ToDeviceVerificationKey: reflect.TypeOf(VerificationKeyEventContent{}), - ToDeviceVerificationMAC: reflect.TypeOf(VerificationMacEventContent{}), + ToDeviceVerificationMAC: reflect.TypeOf(VerificationMACEventContent{}), ToDeviceOrgMatrixRoomKeyWithheld: reflect.TypeOf(RoomKeyWithheldEventContent{}), @@ -562,10 +562,10 @@ func (content *Content) AsVerificationKey() *VerificationKeyEventContent { } return casted } -func (content *Content) AsVerificationMAC() *VerificationMacEventContent { - casted, ok := content.Parsed.(*VerificationMacEventContent) +func (content *Content) AsVerificationMAC() *VerificationMACEventContent { + casted, ok := content.Parsed.(*VerificationMACEventContent) if !ok { - return &VerificationMacEventContent{} + return &VerificationMACEventContent{} } return casted } diff --git a/event/verification.go b/event/verification.go index 60fcb9d4..5ab903c0 100644 --- a/event/verification.go +++ b/event/verification.go @@ -308,13 +308,13 @@ type VerificationKeyEventContent struct { Key jsonbytes.UnpaddedBytes `json:"key"` } -// VerificationMacEventContent represents the content of an +// VerificationMACEventContent represents the content of an // [m.key.verification.mac] event (both the to-device and the in-room version) // as described in [Section 11.12.2.2.2] of the Spec. // // [m.key.verification.mac]: https://spec.matrix.org/v1.9/client-server-api/#mkeyverificationmac // [Section 11.12.2.2.2]: https://spec.matrix.org/v1.9/client-server-api/#verification-messages-specific-to-sas -type VerificationMacEventContent struct { +type VerificationMACEventContent struct { ToDeviceVerificationEvent InRoomVerificationEvent From 492c42f8b289a36801b43b1a1a0586b8ffd4ec1e Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 16 Feb 2024 12:26:42 -0700 Subject: [PATCH 0132/1647] verificationhelper/sas: actually exchange and verify MAC keys Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 21 ++- crypto/verificationhelper/sas.go | 143 ++++++++++++++---- .../verificationhelper/verificationhelper.go | 25 ++- 3 files changed, 143 insertions(+), 46 deletions(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index ddb8f62c..e49070ec 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -133,7 +133,16 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // Immediately send the m.key.verification.done event, as our side of the // transaction is done. - return vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) + if err != nil { + return err + } + txn.SentOurDone = true + if txn.ReceivedTheirDone { + txn.VerificationState = verificationStateDone + vh.verificationDone(ctx, txn.TransactionID) + } + return nil } // ConfirmQRCodeScanned confirms that our QR code has been scanned and sends the @@ -187,11 +196,11 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id if err != nil { return err } - - txn.VerificationState = verificationStateDone - - // Broadcast that the verification is complete. - vh.verificationDone(ctx, txn.TransactionID) + txn.SentOurDone = true + if txn.ReceivedTheirDone { + txn.VerificationState = verificationStateDone + vh.verificationDone(ctx, txn.TransactionID) + } return nil } diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 98641506..40b86f3f 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -93,6 +93,8 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat txn, ok := vh.activeTransactions[txnID] if !ok { return fmt.Errorf("unknown transaction ID") + } else if txn.VerificationState != verificationStateSASKeysExchanged { + return errors.New("transaction is not in keys exchanged state") } var err error @@ -100,17 +102,23 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat log.Info().Msg("Signing keys") - // TODO actually sign some keys // My device key myDevice := vh.mach.OwnIdentity() myDeviceKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, myDevice.DeviceID.String()) - keys[myDeviceKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, myDeviceKeyID.String(), myDevice.IdentityKey.String()) + keys[myDeviceKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, myDeviceKeyID.String(), myDevice.SigningKey.String()) if err != nil { return err } // Master signing key - // TODO how to detect whether or not we trust the master key? + crossSigningKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) + if crossSigningKeys != nil { + crossSigningKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, crossSigningKeys.MasterKey.String()) + keys[crossSigningKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, crossSigningKeyID.String(), crossSigningKeys.MasterKey.String()) + if err != nil { + return err + } + } var keyIDs []string for keyID := range keys { @@ -126,7 +134,21 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat Keys: keysMAC, MAC: keys, } - return vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationMAC, macEventContent) + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationMAC, macEventContent) + if err != nil { + return err + } + + txn.SentOurMAC = true + if txn.ReceivedTheirMAC { + txn.VerificationState = verificationStateSASMACExchanged + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) + if err != nil { + return err + } + txn.SentOurDone = true + } + return nil } // onVerificationStartSAS handles the m.key.verification.start events with @@ -141,7 +163,10 @@ func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *v } startEvt := evt.Content.AsVerificationStart() - log := vh.getLog(ctx) + log := vh.getLog(ctx).With(). + Str("verification_action", "start_sas"). + Stringer("transaction_id", txn.TransactionID). + Logger() log.Info().Msg("Received SAS verification start event") _, err := vh.mach.GetOrFetchDevice(ctx, evt.Sender, startEvt.FromDevice) @@ -441,8 +466,6 @@ func (vh *VerificationHelper) verificationMACHKDF(txn *verificationTransaction, if err != nil { return nil, err } - fmt.Printf("KEYID %s\n", keyID) - fmt.Printf("KEY %s\n", key) var infoBuf bytes.Buffer infoBuf.WriteString("MATRIX_KEY_VERIFICATION_MAC") @@ -464,13 +487,10 @@ func (vh *VerificationHelper) verificationMACHKDF(txn *verificationTransaction, hash.Write([]byte(key)) sum := hash.Sum(nil) if txn.MACMethod == event.MACMethodHKDFHMACSHA256 { - fmt.Printf("MANGLING %v\n", sum) - fmt.Printf("%s\n", BrokenB64Encode(sum)) sum, err = base64.RawStdEncoding.DecodeString(BrokenB64Encode(sum)) if err != nil { panic(err) } - fmt.Printf("MANGLING %v\n", sum) } return sum, nil } @@ -548,38 +568,97 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi Logger() log.Info().Msg("Received SAS verification MAC event") macEvt := evt.Content.AsVerificationMAC() - jsonBytes, _ := json.Marshal(macEvt) - fmt.Printf("%s\n", jsonBytes) - var keyIDs []string - // for keyID, mac := range macEvt.MAC { - // log.Info().Str("key_id", keyID.String()).Msg("Received MAC for key") - // keyIDs = append(keyIDs, keyID.String()) - - // var key string - - // expectedMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, keyID.String(), key) - // if err != nil { - // vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to calculate key MAC: %w", err)) - // return - // } - // if !bytes.Equal(expectedMAC, mac) { - // vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("MAC mismatch for key %s", keyID)) - // return - // } - // } + // Verifying Keys MAC log.Info().Msg("Verifying MAC for all sent keys") + var hasTheirDeviceKey bool + var keyIDs []string + for keyID := range macEvt.MAC { + keyIDs = append(keyIDs, keyID.String()) + _, kID := keyID.Parse() + if kID == txn.TheirDevice.String() { + hasTheirDeviceKey = true + } + } + slices.Sort(keyIDs) expectedKeyMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) if err != nil { vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to calculate key list MAC: %w", err)) return } - fmt.Printf("%d %v\n", len(expectedKeyMAC), expectedKeyMAC) - fmt.Printf("%d %v\n", len(macEvt.Keys), macEvt.Keys) if !bytes.Equal(expectedKeyMAC, macEvt.Keys) { vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("key list MAC mismatch")) return } + if !hasTheirDeviceKey { + vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("their device key not found")) + return + } - // TODO actually do a trust thing + // Verify the MAC for each key + for keyID, mac := range macEvt.MAC { + log.Info().Str("key_id", keyID.String()).Msg("Received MAC for key") + + alg, kID := keyID.Parse() + if alg != id.KeyAlgorithmEd25519 { + vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("unsupported key algorithm %s", alg)) + return + } + + var key string + var theirDevice *id.Device + if kID == txn.TheirDevice.String() { + theirDevice, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + if err != nil { + vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to fetch their device: %w", err)) + return + } + key = theirDevice.SigningKey.String() + } else { // This is the master key + crossSigningKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) + if crossSigningKeys == nil { + vh.verificationError(ctx, txn.TransactionID, errors.New("cross-signing keys not found")) + return + } + if kID != crossSigningKeys.MasterKey.String() { + vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("unknown key ID %s", keyID)) + return + } + key = crossSigningKeys.MasterKey.String() + } + + expectedMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, keyID.String(), key) + if err != nil { + vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to calculate key MAC: %w", err)) + return + } + if !bytes.Equal(expectedMAC, mac) { + vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("MAC mismatch for key %s", keyID)) + return + } + + // Trust their device + if kID == txn.TheirDevice.String() { + theirDevice.Trust = id.TrustStateVerified + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) + if err != nil { + vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to update device trust state after verifying: %w", err)) + return + } + } + } + log.Info().Msg("All MACs verified") + + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + txn.ReceivedTheirMAC = true + if txn.SentOurMAC { + txn.VerificationState = verificationStateSASMACExchanged + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) + if err != nil { + vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to send verification done event: %w", err)) + return + } + txn.SentOurDone = true + } } diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 1d6ba912..939b0d8c 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -40,7 +40,7 @@ const ( verificationStateSASStarted // An SAS verification has been started verificationStateSASAccepted // An SAS verification has been accepted verificationStateSASKeysExchanged // An SAS verification has exchanged keys - verificationStateSASMAC // An SAS verification has exchanged MACs + verificationStateSASMACExchanged // An SAS verification has exchanged MACs ) func (step verificationState) String() string { @@ -61,7 +61,7 @@ func (step verificationState) String() string { return "sas_accepted" case verificationStateSASKeysExchanged: return "sas_keys_exchanged" - case verificationStateSASMAC: + case verificationStateSASMACExchanged: return "sas_mac" default: return fmt.Sprintf("verificationStep(%d)", step) @@ -104,6 +104,10 @@ type verificationTransaction struct { EphemeralKey *ecdh.PrivateKey // The ephemeral key EphemeralPublicKeyShared bool // Whether this device's ephemeral public key has been shared OtherPublicKey *ecdh.PublicKey // The other device's ephemeral public key + ReceivedTheirMAC bool // Whether we have received their MAC + SentOurMAC bool // Whether we have sent our MAC + ReceivedTheirDone bool // Whether we have received their done event + SentOurDone bool // Whether we have sent our done event } // RequiredCallbacks is an interface representing the callbacks required for @@ -151,6 +155,7 @@ type VerificationHelper struct { // supportedMethods are the methods that *we* support supportedMethods []event.VerificationMethod verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) + // TODO maybe just only have verificationCancelled instead of verifictionError? verificationError func(ctx context.Context, txnID id.VerificationTransactionID, err error) verificationCancelled func(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) verificationDone func(ctx context.Context, txnID id.VerificationTransactionID) @@ -207,9 +212,9 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, call } func (vh *VerificationHelper) getLog(ctx context.Context) *zerolog.Logger { - logger := vh.client.Log.With(). - Any("supported_methods", vh.supportedMethods). + logger := zerolog.Ctx(ctx).With(). Str("component", "verification"). + Any("supported_methods", vh.supportedMethods). Logger() return &logger } @@ -680,11 +685,15 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verif vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState == verificationStateTheirQRScanned || txn.VerificationState == verificationStateSASMAC { - txn.VerificationState = verificationStateDone - vh.verificationDone(ctx, txn.TransactionID) - } else { + if txn.VerificationState != verificationStateTheirQRScanned && txn.VerificationState != verificationStateSASMACExchanged { vh.unexpectedEvent(ctx, txn) + return + } + + txn.VerificationState = verificationStateDone + txn.ReceivedTheirDone = true + if txn.SentOurDone { + vh.verificationDone(ctx, txn.TransactionID) } } From 169ed443c8bc3cbbce9eb27b2b29b0055801a645 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 16 Feb 2024 13:23:02 -0700 Subject: [PATCH 0133/1647] verificationhelper: handle both devices sending start event Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 3 ++ crypto/verificationhelper/sas.go | 4 +- .../verificationhelper/verificationhelper.go | 51 ++++++++++++++++++- 3 files changed, 54 insertions(+), 4 deletions(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index e49070ec..ab177eb9 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -30,6 +30,8 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by Stringer("transaction_id", qrCode.TransactionID). Int("mode", int(qrCode.Mode)). Logger() + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() txn, ok := vh.activeTransactions[qrCode.TransactionID] if !ok { @@ -121,6 +123,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by } // Send a m.key.verification.start event with the secret + txn.StartedByUs = true txn.StartEventContent = &event.VerificationStartEventContent{ FromDevice: vh.client.DeviceID, Method: event.VerificationMethodReciprocate, diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 40b86f3f..f1aac7e9 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -44,6 +44,8 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio return fmt.Errorf("unknown transaction ID") } else if txn.VerificationState != verificationStateReady { return errors.New("transaction is not in ready state") + } else if txn.StartEventContent != nil { + return errors.New("start event already sent or received") } txn.VerificationState = verificationStateSASStarted @@ -59,8 +61,6 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio return err } - // TODO check if the other device already has sent a start event - log.Info().Msg("Sending start event") txn.StartEventContent = &event.VerificationStartEventContent{ FromDevice: vh.client.DeviceID, diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 939b0d8c..2e5ba5cd 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -626,8 +626,55 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != verificationStateReady { - log.Warn().Msg("Ignoring verification start event for a transaction that is not in the ready state") + if txn.VerificationState == verificationStateSASStarted || txn.VerificationState == verificationStateOurQRScanned || txn.VerificationState == verificationStateTheirQRScanned { + // We might have sent the event, and they also sent an event. + if txn.StartEventContent == nil || !txn.StartedByUs { + // We didn't sent a start event yet, so we have gotten ourselves + // into a bad state. They've either sent two start events, or we + // have gone on to a new state. + vh.unexpectedEvent(ctx, txn) + return + } + + // Otherwise, we need to implement the following algorithm from Section + // 11.12.2.1 of the Spec: + // https://spec.matrix.org/v1.9/client-server-api/#key-verification-framework + // + // If Alice's and Bob's clients both send an m.key.verification.start + // message, and both specify the same verification method, then the + // m.key.verification.start message sent by the user whose ID is the + // lexicographically largest user ID should be ignored, and the + // situation should be treated the same as if only the user with the + // lexicographically smallest user ID had sent the + // m.key.verification.start message. In the case where the user IDs are + // the same (that is, when a user is verifying their own device), then + // the device IDs should be compared instead. If the two + // m.key.verification.start messages do not specify the same + // verification method, then the verification should be cancelled with + // a code of m.unexpected_message. + + if txn.StartEventContent.Method != startEvt.Method { + cancelEvt := event.VerificationCancelEventContent{ + Code: event.VerificationCancelCodeUnexpectedMessage, + Reason: "The start events have different verification methods.", + } + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, &cancelEvt) + if err != nil { + log.Err(err).Msg("Failed to send cancellation event") + } + txn.VerificationState = verificationStateCancelled + vh.verificationCancelled(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) + return + } + + if txn.TheirUser < vh.client.UserID || (txn.TheirUser == vh.client.UserID && txn.TheirDevice < vh.client.DeviceID) { + // Use their start event instead of ours + txn.StartedByUs = false + txn.StartEventContent = startEvt + } + + } else if txn.VerificationState != verificationStateReady { + vh.unexpectedEvent(ctx, txn) return } From b7f434cd767ee5f71a3a896e565550ae52dcf99b Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 16 Feb 2024 16:02:11 -0700 Subject: [PATCH 0134/1647] verificationhelper: streamline error/cancellation logic Signed-off-by: Sumner Evans --- crypto/verificationhelper/sas.go | 33 +++-- .../verificationhelper/verificationhelper.go | 126 ++++++++---------- 2 files changed, 73 insertions(+), 86 deletions(-) diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index f1aac7e9..256ff934 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -157,11 +157,6 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat // // [Section 11.12.2.2]: https://spec.matrix.org/v1.9/client-server-api/#short-authentication-string-sas-verification func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *verificationTransaction, evt *event.Event) error { - if txn.VerificationState != verificationStateReady { - vh.unexpectedEvent(ctx, txn) - return nil // return nil since we already sent a cancellation event in vh.unexpectedEvent - } - startEvt := evt.Content.AsVerificationStart() log := vh.getLog(ctx).With(). Str("verification_action", "start_sas"). @@ -269,7 +264,8 @@ func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *ver vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() if txn.VerificationState != verificationStateSASStarted { - vh.unexpectedEvent(ctx, txn) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, + "received accept event for a transaction that is not in the started state") return } @@ -303,7 +299,8 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verifi defer vh.activeTransactionsLock.Unlock() if txn.VerificationState != verificationStateSASAccepted { - vh.unexpectedEvent(ctx, txn) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, + "received key event for a transaction that is not in the accepted state") return } @@ -583,15 +580,15 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi slices.Sort(keyIDs) expectedKeyMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) if err != nil { - vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to calculate key list MAC: %w", err)) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeSASMismatch, "failed to calculate key list MAC: %v", err) return } if !bytes.Equal(expectedKeyMAC, macEvt.Keys) { - vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("key list MAC mismatch")) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeSASMismatch, "key list MAC mismatch") return } if !hasTheirDeviceKey { - vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("their device key not found")) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeSASMismatch, "their device key not found in list of keys") return } @@ -601,7 +598,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi alg, kID := keyID.Parse() if alg != id.KeyAlgorithmEd25519 { - vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("unsupported key algorithm %s", alg)) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownMethod, "unsupported key algorithm %s", alg) return } @@ -610,18 +607,18 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi if kID == txn.TheirDevice.String() { theirDevice, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) if err != nil { - vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to fetch their device: %w", err)) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to fetch their device: %v", err) return } key = theirDevice.SigningKey.String() } else { // This is the master key crossSigningKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) if crossSigningKeys == nil { - vh.verificationError(ctx, txn.TransactionID, errors.New("cross-signing keys not found")) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "cross-signing keys not found") return } if kID != crossSigningKeys.MasterKey.String() { - vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("unknown key ID %s", keyID)) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "unknown key ID %s", keyID) return } key = crossSigningKeys.MasterKey.String() @@ -629,11 +626,11 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi expectedMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, keyID.String(), key) if err != nil { - vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to calculate key MAC: %w", err)) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to calculate key MAC: %v", err) return } if !bytes.Equal(expectedMAC, mac) { - vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("MAC mismatch for key %s", keyID)) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeSASMismatch, "MAC mismatch for key %s", keyID) return } @@ -642,7 +639,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi theirDevice.Trust = id.TrustStateVerified err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) if err != nil { - vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to update device trust state after verifying: %w", err)) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to update device trust state after verifying: %v", err) return } } @@ -656,7 +653,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi txn.VerificationState = verificationStateSASMACExchanged err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { - vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to send verification done event: %w", err)) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to send verification done event: %v", err) return } txn.SentOurDone = true diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 2e5ba5cd..f8fd5c87 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -10,12 +10,10 @@ import ( "bytes" "context" "crypto/ecdh" - "errors" "fmt" "sync" "github.com/rs/zerolog" - "github.com/rs/zerolog/log" "go.mau.fi/util/jsontime" "golang.org/x/exp/maps" "golang.org/x/exp/slices" @@ -117,10 +115,6 @@ type RequiredCallbacks interface { // from another device. VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) - // VerificationError is called when an error occurs during the verification - // process. - VerificationError(ctx context.Context, txnID id.VerificationTransactionID, err error) - // VerificationCancelled is called when the verification is cancelled. VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) @@ -153,12 +147,10 @@ type VerificationHelper struct { activeTransactionsLock sync.Mutex // supportedMethods are the methods that *we* support - supportedMethods []event.VerificationMethod - verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) - // TODO maybe just only have verificationCancelled instead of verifictionError? - verificationError func(ctx context.Context, txnID id.VerificationTransactionID, err error) - verificationCancelled func(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) - verificationDone func(ctx context.Context, txnID id.VerificationTransactionID) + supportedMethods []event.VerificationMethod + verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) + verificationCancelledCallback func(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) + verificationDone func(ctx context.Context, txnID id.VerificationTransactionID) showSAS func(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) @@ -183,11 +175,7 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, call panic("callbacks must implement VerificationRequested") } else { helper.verificationRequested = c.VerificationRequested - helper.verificationError = func(ctx context.Context, txnID id.VerificationTransactionID, err error) { - zerolog.Ctx(ctx).Err(err).Msg("Verification error") - c.VerificationError(ctx, txnID, err) - } - helper.verificationCancelled = c.VerificationCancelled + helper.verificationCancelledCallback = c.VerificationCancelled helper.verificationDone = c.VerificationDone } @@ -297,14 +285,7 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { } // Send the actual cancellation event. - err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, &event.VerificationCancelEventContent{ - Code: code, - Reason: reason, - }) - if err != nil { - log.Err(err).Msg("Failed to send cancellation event") - } - vh.verificationCancelled(ctx, txn.TransactionID, code, reason) + vh.cancelVerificationTxn(ctx, txn, code, reason) return } @@ -464,6 +445,23 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V return vh.generateAndShowQRCode(ctx, txn) } +// CancelVerification cancels a verification request. The transaction ID should +// be the transaction ID of a verification request that was received via the +// VerificationRequested callback in [RequiredCallbacks]. +func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) error { + log := vh.getLog(ctx).With(). + Str("verification_action", "cancel verification"). + Stringer("transaction_id", txnID). + Logger() + ctx = log.WithContext(ctx) + + txn, ok := vh.activeTransactions[txnID] + if !ok { + return fmt.Errorf("unknown transaction ID") + } + return vh.cancelVerificationTxn(ctx, txn, code, reason) +} + // sendVerificationEvent sends a verification event to the other user's device // setting the m.relates_to or transaction ID as necessary. // @@ -497,6 +495,27 @@ func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn *ve return nil } +func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn *verificationTransaction, code event.VerificationCancelCode, reasonFmtStr string, fmtArgs ...any) error { + log := vh.getLog(ctx) + reason := fmt.Sprintf(reasonFmtStr, fmtArgs...) + log.Info(). + Stringer("transaction_id", txn.TransactionID). + Str("code", string(code)). + Str("reason", reason). + Msg("Sending cancellation event") + cancelEvt := &event.VerificationCancelEventContent{ + Code: code, + Reason: reason, + } + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, cancelEvt) + if err != nil { + return err + } + txn.VerificationState = verificationStateCancelled + vh.verificationCancelledCallback(ctx, txn.TransactionID, code, reason) + return nil +} + func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *event.Event) { logCtx := vh.getLog(ctx).With(). Str("verification_action", "verification request"). @@ -591,7 +610,7 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri } devices, err := vh.mach.CryptoStore.GetDevices(ctx, txn.TheirUser) if err != nil { - vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to get devices for %s: %w", txn.TheirUser, err)) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to get devices for %s: %v", txn.TheirUser, err) return } req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} @@ -609,9 +628,10 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri log.Warn().Err(err).Msg("Failed to send cancellation requests") } } + err := vh.generateAndShowQRCode(ctx, txn) if err != nil { - vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to generate and show QR code: %w", err)) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to generate and show QR code: %v", err) } } @@ -632,7 +652,8 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri // We didn't sent a start event yet, so we have gotten ourselves // into a bad state. They've either sent two start events, or we // have gone on to a new state. - vh.unexpectedEvent(ctx, txn) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, + "got repeat start event from other user") return } @@ -654,16 +675,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri // a code of m.unexpected_message. if txn.StartEventContent.Method != startEvt.Method { - cancelEvt := event.VerificationCancelEventContent{ - Code: event.VerificationCancelCodeUnexpectedMessage, - Reason: "The start events have different verification methods.", - } - err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, &cancelEvt) - if err != nil { - log.Err(err).Msg("Failed to send cancellation event") - } - txn.VerificationState = verificationStateCancelled - vh.verificationCancelled(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "the start events have different verification methods") return } @@ -672,9 +684,9 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri txn.StartedByUs = false txn.StartEventContent = startEvt } - } else if txn.VerificationState != verificationStateReady { - vh.unexpectedEvent(ctx, txn) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, + "got start event for transaction that is not in ready state") return } @@ -682,13 +694,12 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri case event.VerificationMethodSAS: txn.VerificationState = verificationStateSASStarted if err := vh.onVerificationStartSAS(ctx, txn, evt); err != nil { - vh.verificationError(ctx, txn.TransactionID, fmt.Errorf("failed to handle SAS verification start: %w", err)) - // TODO cancel? + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to handle SAS verification start: %v", err) } case event.VerificationMethodReciprocate: log.Info().Msg("Received reciprocate start event") if !bytes.Equal(txn.QRCodeSharedSecret, startEvt.Secret) { - vh.verificationError(ctx, txn.TransactionID, errors.New("reciprocated shared secret does not match")) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "reciprocated shared secret does not match") return } txn.VerificationState = verificationStateOurQRScanned @@ -698,32 +709,10 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri // here, since the start command for scanning and showing QR codes // should be of type m.reciprocate.v1. log.Error().Str("method", string(startEvt.Method)).Msg("Unsupported verification method in start event") - - cancelEvt := event.VerificationCancelEventContent{ - Code: event.VerificationCancelCodeUnknownMethod, - Reason: fmt.Sprintf("unknown method %s", startEvt.Method), - } - err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, &cancelEvt) - if err != nil { - log.Err(err).Msg("Failed to send cancellation event") - } - vh.verificationCancelled(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownMethod, fmt.Sprintf("unknown method %s", startEvt.Method)) } } -func (vh *VerificationHelper) unexpectedEvent(ctx context.Context, txn *verificationTransaction) { - cancelEvt := event.VerificationCancelEventContent{ - Code: event.VerificationCancelCodeUnexpectedMessage, - Reason: "Got event for a transaction that is not in the correct state", - } - err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, &cancelEvt) - if err != nil { - log.Err(err).Msg("Failed to send cancellation event") - } - txn.VerificationState = verificationStateCancelled - vh.verificationCancelled(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) -} - func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verificationTransaction, evt *event.Event) { vh.getLog(ctx).Info(). Str("verification_action", "done"). @@ -733,7 +722,8 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verif defer vh.activeTransactionsLock.Unlock() if txn.VerificationState != verificationStateTheirQRScanned && txn.VerificationState != verificationStateSASMACExchanged { - vh.unexpectedEvent(ctx, txn) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, + "got done event for transaction that is not in QR-scanned or MAC-exchanged state") return } @@ -755,5 +745,5 @@ func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *ver vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() txn.VerificationState = verificationStateCancelled - vh.verificationCancelled(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) + vh.verificationCancelledCallback(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) } From 740c588b963998343a700f0a7e5a313522e3c412 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 17 Feb 2024 14:45:20 +0200 Subject: [PATCH 0135/1647] Add MSC4095 types --- event/beeper.go | 24 ++++++++++++++++++++++++ event/message.go | 2 ++ responses.go | 14 +------------- 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/event/beeper.go b/event/beeper.go index e37b06c2..77019fcd 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -61,3 +61,27 @@ type BeeperRoomKeyAckEventContent struct { SessionID id.SessionID `json:"session_id"` FirstMessageIndex int `json:"first_message_index"` } + +type LinkPreview struct { + CanonicalURL string `json:"og:url,omitempty"` + Title string `json:"og:title,omitempty"` + Type string `json:"og:type,omitempty"` + Description string `json:"og:description,omitempty"` + + ImageURL id.ContentURIString `json:"og:image,omitempty"` + + ImageSize int `json:"matrix:image:size,omitempty"` + ImageWidth int `json:"og:image:width,omitempty"` + ImageHeight int `json:"og:image:height,omitempty"` + ImageType string `json:"og:image:type,omitempty"` +} + +// BeeperLinkPreview contains the data for a bundled URL preview as specified in MSC4095 +// +// https://github.com/matrix-org/matrix-spec-proposals/pull/4095 +type BeeperLinkPreview struct { + LinkPreview + + MatchedURL string `json:"matched_url"` + ImageEncryption *EncryptedFileInfo `json:"beeper:image:encryption,omitempty"` +} diff --git a/event/message.go b/event/message.go index 6512f9be..d8b27c3d 100644 --- a/event/message.go +++ b/event/message.go @@ -116,6 +116,8 @@ type MessageEventContent struct { BeeperGalleryImages []*MessageEventContent `json:"com.beeper.gallery.images,omitempty"` BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"` BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"` + + BeeperLinkPreviews []*BeeperLinkPreview `json:"com.beeper.linkpreviews,omitempty"` } func (content *MessageEventContent) GetRelatesTo() *RelatesTo { diff --git a/responses.go b/responses.go index e182a722..96e30a1f 100644 --- a/responses.go +++ b/responses.go @@ -118,19 +118,7 @@ type RespCreateMXC struct { } // RespPreviewURL is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3preview_url -type RespPreviewURL struct { - CanonicalURL string `json:"og:url,omitempty"` - Title string `json:"og:title,omitempty"` - Type string `json:"og:type,omitempty"` - Description string `json:"og:description,omitempty"` - - ImageURL id.ContentURIString `json:"og:image,omitempty"` - - ImageSize int `json:"matrix:image:size,omitempty"` - ImageWidth int `json:"og:image:width,omitempty"` - ImageHeight int `json:"og:image:height,omitempty"` - ImageType string `json:"og:image:type,omitempty"` -} +type RespPreviewURL = event.LinkPreview // RespUserInteractive is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#user-interactive-authentication-api type RespUserInteractive struct { From b1b1c97a115cf3e8d72c81cdbb090a546df5dbb0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 17 Feb 2024 15:34:17 +0200 Subject: [PATCH 0136/1647] Make matched_url optional --- event/beeper.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/event/beeper.go b/event/beeper.go index 77019fcd..51ddd77f 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -82,6 +82,6 @@ type LinkPreview struct { type BeeperLinkPreview struct { LinkPreview - MatchedURL string `json:"matched_url"` + MatchedURL string `json:"matched_url,omitempty"` ImageEncryption *EncryptedFileInfo `json:"beeper:image:encryption,omitempty"` } From 5e73f1674a5d37bb2d2bb5c98ff3eb7dbb56ff9b Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Sun, 18 Feb 2024 19:26:12 -0700 Subject: [PATCH 0137/1647] verification: add CancelVerification to interface Signed-off-by: Sumner Evans --- client.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client.go b/client.go index 4f94d986..a6f76cbb 100644 --- a/client.go +++ b/client.go @@ -40,6 +40,7 @@ type VerificationHelper interface { StartVerification(ctx context.Context, to id.UserID) (id.VerificationTransactionID, error) StartInRoomVerification(ctx context.Context, roomID id.RoomID, to id.UserID) (id.VerificationTransactionID, error) AcceptVerification(ctx context.Context, txnID id.VerificationTransactionID) error + CancelVerification(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) error HandleScannedQRData(ctx context.Context, data []byte) error ConfirmQRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) error From ccbf0ee9884cbafe5c7e141f7eac91ad1f06aedd Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 Feb 2024 00:17:13 +0200 Subject: [PATCH 0138/1647] Add note about event source being moved to 0.17 changelog Fixes #184 --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3298c379..956db43c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,8 @@ functions. * **Breaking change *(everything)*** Added context parameters to all functions (started by [@recht] in [#144]). +* **Breaking change *(client)*** Moved event source from sync event handler + function parameters to the `Mautrix.Source` field inside the event struct. * **Breaking change *(client)*** Moved `EventSource` to `event.Source`. * *(client)* Removed deprecated `OldEventIgnorer`. The non-deprecated version (`Client.DontProcessOldEvents`) is still available. From 0f7c7169642d19e19d4d30ef7d36a97630b4d8f3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 Feb 2024 00:18:29 +0200 Subject: [PATCH 0139/1647] Fix field name --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 956db43c..f1815d89 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,7 +30,8 @@ * **Breaking change *(everything)*** Added context parameters to all functions (started by [@recht] in [#144]). * **Breaking change *(client)*** Moved event source from sync event handler - function parameters to the `Mautrix.Source` field inside the event struct. + function parameters to the `Mautrix.EventSource` field inside the event + struct. * **Breaking change *(client)*** Moved `EventSource` to `event.Source`. * *(client)* Removed deprecated `OldEventIgnorer`. The non-deprecated version (`Client.DontProcessOldEvents`) is still available. From 128fc8cd89d84c18a36d907de058dac25af811f8 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 19 Feb 2024 16:33:56 -0700 Subject: [PATCH 0140/1647] event/verification: remove Supports* functions Use slices.Contains instead Signed-off-by: Sumner Evans --- crypto/verificationhelper/sas.go | 12 ++++++------ event/verification.go | 23 ----------------------- 2 files changed, 6 insertions(+), 29 deletions(-) diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 256ff934..8160a4e1 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -171,18 +171,18 @@ func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *v } keyAggreementProtocol := event.KeyAgreementProtocolCurve25519HKDFSHA256 - if !startEvt.SupportsKeyAgreementProtocol(keyAggreementProtocol) { + if !slices.Contains(startEvt.KeyAgreementProtocols, keyAggreementProtocol) { return fmt.Errorf("the other device does not support any key agreement protocols that we support") } hashAlgorithm := event.VerificationHashMethodSHA256 - if !startEvt.SupportsHashMethod(hashAlgorithm) { + if !slices.Contains(startEvt.Hashes, hashAlgorithm) { return fmt.Errorf("the other device does not support any hash algorithms that we support") } macMethod := event.MACMethodHKDFHMACSHA256V2 - if !startEvt.SupportsMACMethod(macMethod) { - if startEvt.SupportsMACMethod(event.MACMethodHKDFHMACSHA256) { + if !slices.Contains(startEvt.MessageAuthenticationCodes, macMethod) { + if slices.Contains(startEvt.MessageAuthenticationCodes, event.MACMethodHKDFHMACSHA256) { macMethod = event.MACMethodHKDFHMACSHA256 } else { return fmt.Errorf("the other device does not support any message authentication codes that we support") @@ -348,14 +348,14 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verifi var decimals []int var emojis []rune - if txn.StartEventContent.SupportsSASMethod(event.SASMethodDecimal) { + if slices.Contains(txn.StartEventContent.ShortAuthenticationString, event.SASMethodDecimal) { decimals = []int{ (int(sasBytes[0])<<5 | int(sasBytes[1])>>3) + 1000, ((int(sasBytes[1])&0x07)<<10 | int(sasBytes[2])<<2 | int(sasBytes[3])>>6) + 1000, ((int(sasBytes[3])&0x3f)<<7 | int(sasBytes[4])>>1) + 1000, } } - if txn.StartEventContent.SupportsSASMethod(event.SASMethodEmoji) { + if slices.Contains(txn.StartEventContent.ShortAuthenticationString, event.SASMethodEmoji) { sasNum := uint64(sasBytes[0])<<40 | uint64(sasBytes[1])<<32 | uint64(sasBytes[2])<<24 | uint64(sasBytes[3])<<16 | uint64(sasBytes[4])<<8 | uint64(sasBytes[5]) diff --git a/event/verification.go b/event/verification.go index 5ab903c0..b1851de3 100644 --- a/event/verification.go +++ b/event/verification.go @@ -9,7 +9,6 @@ package event import ( "go.mau.fi/util/jsonbytes" "go.mau.fi/util/jsontime" - "golang.org/x/exp/slices" "maunium.net/go/mautrix/id" ) @@ -105,12 +104,6 @@ func VerificationRequestEventContentFromMessage(evt *Event) *VerificationRequest } } -// SupportsVerificationMethod returns whether the given verification method is -// supported by the sender. -func (vrec *VerificationRequestEventContent) SupportsVerificationMethod(method VerificationMethod) bool { - return slices.Contains(vrec.Methods, method) -} - // VerificationReadyEventContent represents the content of an // [m.key.verification.ready] event (both the to-device and the in-room // version) as described in [Section 11.12.2.1] of the Spec. @@ -199,22 +192,6 @@ type VerificationStartEventContent struct { Secret jsonbytes.UnpaddedBytes `json:"secret,omitempty"` } -func (vsec *VerificationStartEventContent) SupportsKeyAgreementProtocol(proto KeyAgreementProtocol) bool { - return slices.Contains(vsec.KeyAgreementProtocols, proto) -} - -func (vsec *VerificationStartEventContent) SupportsHashMethod(alg VerificationHashMethod) bool { - return slices.Contains(vsec.Hashes, alg) -} - -func (vsec *VerificationStartEventContent) SupportsMACMethod(meth MACMethod) bool { - return slices.Contains(vsec.MessageAuthenticationCodes, meth) -} - -func (vsec *VerificationStartEventContent) SupportsSASMethod(meth SASMethod) bool { - return slices.Contains(vsec.ShortAuthenticationString, meth) -} - // VerificationDoneEventContent represents the content of an // [m.key.verification.done] event (both the to-device and the in-room version) // as described in [Section 11.12.2.1] of the Spec. From a6644eb03027827713bb1fef99f41b2c00a569e4 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 19 Feb 2024 16:40:32 -0700 Subject: [PATCH 0141/1647] verificationhelper: add callback for scan QR code This callback indicates that the other device is showing a QR code and is ready for our device to scan it. Signed-off-by: Sumner Evans --- crypto/verificationhelper/verificationhelper.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index f8fd5c87..025af25e 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -130,6 +130,11 @@ type showSASCallbacks interface { } type showQRCodeCallbacks interface { + // ScanQRCode is called when another device has sent a + // m.key.verification.ready event and indicated that they are capable of + // showing a QR code. + ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) + // ShowQRCode is called when the verification has been accepted and a QR // code should be shown to the user. ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode) @@ -154,6 +159,7 @@ type VerificationHelper struct { showSAS func(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) + scanQRCode func(ctx context.Context, txnID id.VerificationTransactionID) showQRCode func(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode) qrCodeScaned func(ctx context.Context, txnID id.VerificationTransactionID) } @@ -186,6 +192,7 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, call if c, ok := callbacks.(showQRCodeCallbacks); ok { helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate) + helper.scanQRCode = c.ScanQRCode helper.showQRCode = c.ShowQRCode helper.qrCodeScaned = c.QRCodeScanned } @@ -442,6 +449,10 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V } txn.VerificationState = verificationStateReady + if slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { + vh.scanQRCode(ctx, txn.TransactionID) + } + return vh.generateAndShowQRCode(ctx, txn) } @@ -629,6 +640,10 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri } } + if slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { + vh.scanQRCode(ctx, txn.TransactionID) + } + err := vh.generateAndShowQRCode(ctx, txn) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to generate and show QR code: %v", err) From 7f0d53ac91ba14da30cd99e6f149cff158d30277 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 21 Feb 2024 19:10:49 +0200 Subject: [PATCH 0142/1647] Treat missing upload size limit as 50mb --- bridge/bridge.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridge/bridge.go b/bridge/bridge.go index abebdd62..56f27a8b 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -404,6 +404,9 @@ func (br *Bridge) fetchMediaConfig(ctx context.Context) { if err != nil { br.ZLog.Warn().Err(err).Msg("Failed to fetch media config") } else { + if cfg.UploadSize == 0 { + cfg.UploadSize = 50 * 1024 * 1024 + } br.MediaConfig = *cfg } } From 6abf3c4adc14abaea4fbda9c9ab22196e75af792 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Wed, 21 Feb 2024 13:27:06 +0200 Subject: [PATCH 0143/1647] Use the encoded form of megolm session key in backup session data We're using the encoded presentation elsewhere as a string and this inconsistency is a footgun. --- crypto/backup/megolmbackup.go | 2 +- crypto/keybackup.go | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/crypto/backup/megolmbackup.go b/crypto/backup/megolmbackup.go index 0a2a42a2..71b8e88b 100644 --- a/crypto/backup/megolmbackup.go +++ b/crypto/backup/megolmbackup.go @@ -35,5 +35,5 @@ type MegolmSessionData struct { ForwardingKeyChain []string `json:"forwarding_curve25519_key_chain"` SenderClaimedKeys SenderClaimedKeys `json:"sender_claimed_keys"` SenderKey id.SenderKey `json:"sender_key"` - SessionKey []byte `json:"session_key"` + SessionKey string `json:"session_key"` } diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 93440a58..cf5e747f 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -2,7 +2,6 @@ package crypto import ( "context" - "encoding/base64" "fmt" "time" @@ -140,7 +139,7 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. return fmt.Errorf("ignoring room key in backup with weird algorithm %s", keyBackupData.Algorithm) } - igsInternal, err := olm.InboundGroupSessionImport([]byte(base64.RawStdEncoding.EncodeToString(keyBackupData.SessionKey))) + igsInternal, err := olm.InboundGroupSessionImport([]byte(keyBackupData.SessionKey)) if err != nil { return fmt.Errorf("failed to import inbound group session: %w", err) } else if igsInternal.ID() != sessionID { From 581aa8015501209bdfd98f7d923144ffd94165fd Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 23 Feb 2024 21:11:51 +0200 Subject: [PATCH 0144/1647] Fix some error logs --- client.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index a6f76cbb..0015aede 100644 --- a/client.go +++ b/client.go @@ -1691,9 +1691,11 @@ func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *R _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) if err == nil && cli.StateStore != nil { clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, event.MembershipJoin) - cli.cliOrContextLog(ctx).Warn().Err(clearErr). - Stringer("room_id", roomID). - Msg("Failed to clear cached member list after fetching joined members") + if clearErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(clearErr). + Stringer("room_id", roomID). + Msg("Failed to clear cached member list after fetching joined members") + } for userID, member := range resp.Joined { updateErr := cli.StateStore.SetMember(ctx, roomID, userID, &event.MemberEventContent{ Membership: event.MembershipJoin, @@ -1701,7 +1703,7 @@ func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *R Displayname: member.DisplayName, }) if updateErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(clearErr). + cli.cliOrContextLog(ctx).Warn().Err(updateErr). Stringer("room_id", roomID). Stringer("user_id", userID). Msg("Failed to update membership in state store after fetching joined members") @@ -1735,9 +1737,11 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb } if extra.NotMembership == "" { clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, clearMemberships...) - cli.cliOrContextLog(ctx).Warn().Err(clearErr). - Stringer("room_id", roomID). - Msg("Failed to clear cached member list after fetching joined members") + if clearErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(clearErr). + Stringer("room_id", roomID). + Msg("Failed to clear cached member list after fetching joined members") + } } for _, evt := range resp.Chunk { UpdateStateStore(ctx, cli.StateStore, evt) From cbd13347246359e3819ac12ee2576788bce36705 Mon Sep 17 00:00:00 2001 From: G-ht <44305945+grvn-ht@users.noreply.github.com> Date: Sat, 24 Feb 2024 14:06:27 +0100 Subject: [PATCH 0145/1647] Add more Synapse admin API wrappers (#181) Co-authored-by: Tulir Asokan --- synapseadmin/roomapi.go | 237 ++++++++++++++++++++++++++++++++++++++++ synapseadmin/userapi.go | 91 ++++++++++++++- 2 files changed, 326 insertions(+), 2 deletions(-) create mode 100644 synapseadmin/roomapi.go diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go new file mode 100644 index 00000000..0953377e --- /dev/null +++ b/synapseadmin/roomapi.go @@ -0,0 +1,237 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package synapseadmin + +import ( + "context" + "encoding/json" + "net/http" + "strconv" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type ReqListRoom struct { + SearchTerm string + OrderBy string + Direction mautrix.Direction + From int + Limit int +} + +func (req *ReqListRoom) BuildQuery() map[string]string { + query := map[string]string{ + "from": strconv.Itoa(req.From), + } + if req.SearchTerm != "" { + query["search_term"] = req.SearchTerm + } + if req.OrderBy != "" { + query["order_by"] = req.OrderBy + } + if req.Direction != 0 { + query["dir"] = string(req.Direction) + } + if req.Limit != 0 { + query["limit"] = strconv.Itoa(req.Limit) + } + return query +} + +type RoomInfo struct { + RoomID id.RoomID `json:"room_id"` + Name string `json:"name"` + CanonicalAlias id.RoomAlias `json:"canonical_alias"` + JoinedMembers int `json:"joined_members"` + JoinedLocalMembers int `json:"joined_local_members"` + Version string `json:"version"` + Creator id.UserID `json:"creator"` + Encryption id.Algorithm `json:"encryption"` + Federatable bool `json:"federatable"` + Public bool `json:"public"` + JoinRules event.JoinRule `json:"join_rules"` + GuestAccess event.GuestAccess `json:"guest_access"` + HistoryVisibility event.HistoryVisibility `json:"history_visibility"` + StateEvents int `json:"state_events"` + RoomType event.RoomType `json:"room_type"` +} + +type RespListRooms struct { + Rooms []RoomInfo `json:"rooms"` + Offset int `json:"offset"` + TotalRooms int `json:"total_rooms"` + NextBatch int `json:"next_batch"` + PrevBatch int `json:"prev_batch"` +} + +// ListRooms returns a list of rooms on the server. +// +// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#list-room-api +func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRooms, error) { + var resp RespListRooms + var reqURL string + reqURL = cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) + _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ + Method: http.MethodGet, + URL: reqURL, + ResponseJSON: &resp, + }) + return resp, err +} + +type RespRoomMessages = mautrix.RespMessages + +// RoomMessages returns a list of messages in a room. +// +// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#room-messages-api +func (cli *Client) RoomMessages(ctx context.Context, roomID id.RoomID, from, to string, dir mautrix.Direction, filter *mautrix.FilterPart, limit int) (resp *RespRoomMessages, err error) { + query := map[string]string{ + "from": from, + "dir": string(dir), + } + if filter != nil { + filterJSON, err := json.Marshal(filter) + if err != nil { + return nil, err + } + query["filter"] = string(filterJSON) + } + if to != "" { + query["to"] = to + } + if limit != 0 { + query["limit"] = strconv.Itoa(limit) + } + urlPath := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms", roomID, "messages"}, query) + _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ + Method: http.MethodGet, + URL: urlPath, + ResponseJSON: &resp, + }) + return resp, err +} + +type ReqDeleteRoom struct { + Purge bool `json:"purge,omitempty"` + Block bool `json:"block,omitempty"` + Message string `json:"message,omitempty"` + RoomName string `json:"room_name,omitempty"` + NewRoomUserID id.UserID `json:"new_room_user_id,omitempty"` +} + +type RespDeleteRoom struct { + DeleteID string `json:"delete_id"` +} + +// DeleteRoom deletes a room from the server, optionally blocking it and/or purging all data from the database. +// +// This calls the async version of the endpoint, which will return immediately and delete the room in the background. +// +// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#version-2-new-version +func (cli *Client) DeleteRoom(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (RespDeleteRoom, error) { + reqURL := cli.BuildAdminURL("v2", "rooms", roomID) + var resp RespDeleteRoom + _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ + Method: http.MethodDelete, + URL: reqURL, + ResponseJSON: &resp, + RequestJSON: &req, + }) + return resp, err +} + +type RespRoomsMembers struct { + Members []id.UserID `json:"members"` + Total int `json:"total"` +} + +// RoomMembers gets the full list of members in a room. +// +// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#room-members-api +func (cli *Client) RoomMembers(ctx context.Context, roomID id.RoomID) (RespRoomsMembers, error) { + reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "members") + var resp RespRoomsMembers + _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ + Method: http.MethodGet, + URL: reqURL, + ResponseJSON: &resp, + }) + return resp, err +} + +type ReqMakeRoomAdmin struct { + UserID id.UserID `json:"user_id"` +} + +// MakeRoomAdmin promotes a user to admin in a room. This requires that a local user has permission to promote users in the room. +// +// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#make-room-admin-api +func (cli *Client) MakeRoomAdmin(ctx context.Context, roomIDOrAlias string, req ReqMakeRoomAdmin) error { + reqURL := cli.BuildAdminURL("v1", "rooms", roomIDOrAlias, "make_room_admin") + _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ + Method: http.MethodPost, + URL: reqURL, + RequestJSON: &req, + }) + return err +} + +type ReqJoinUserToRoom struct { + UserID id.UserID `json:"user_id"` +} + +// JoinUserToRoom makes a local user join the given room. +// +// https://matrix-org.github.io/synapse/latest/admin_api/room_membership.html +func (cli *Client) JoinUserToRoom(ctx context.Context, roomID id.RoomID, req ReqJoinUserToRoom) error { + reqURL := cli.BuildAdminURL("v1", "join", roomID) + _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ + Method: http.MethodPost, + URL: reqURL, + RequestJSON: &req, + }) + return err +} + +type ReqBlockRoom struct { + Block bool `json:"block"` +} + +// BlockRoom blocks or unblocks a room. +// +// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#block-room-api +func (cli *Client) BlockRoom(ctx context.Context, roomID id.RoomID, req ReqBlockRoom) error { + reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block") + _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ + Method: http.MethodPut, + URL: reqURL, + RequestJSON: &req, + }) + return err +} + +// RoomsBlockResponse represents the response containing wether a room is blocked or not +type RoomsBlockResponse struct { + Block bool `json:"block"` + UserID id.UserID `json:"user_id"` +} + +// GetRoomBlockStatus gets whether a room is currently blocked. +// +// https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#get-block-status +func (cli *Client) GetRoomBlockStatus(ctx context.Context, roomID id.RoomID) (RoomsBlockResponse, error) { + var resp RoomsBlockResponse + reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block") + _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ + Method: http.MethodGet, + URL: reqURL, + ResponseJSON: &resp, + }) + return resp, err +} diff --git a/synapseadmin/userapi.go b/synapseadmin/userapi.go index aa1ce2a7..31d0a6dc 100644 --- a/synapseadmin/userapi.go +++ b/synapseadmin/userapi.go @@ -21,7 +21,6 @@ import ( type ReqResetPassword struct { // The user whose password to reset. UserID id.UserID `json:"-"` - // The new password for the user. Required. NewPassword string `json:"new_password"` // Whether all the user's existing devices should be logged out after the password change. @@ -86,7 +85,7 @@ type RespUserInfo struct { UserID id.UserID `json:"name"` DisplayName string `json:"displayname"` AvatarURL id.ContentURIString `json:"avatar_url"` - Guest int `json:"is_guest"` + Guest bool `json:"is_guest"` Admin bool `json:"admin"` Deactivated bool `json:"deactivated"` Erased bool `json:"erased"` @@ -109,3 +108,91 @@ func (cli *Client) GetUserInfo(ctx context.Context, userID id.UserID) (resp *Res }) return } + +type ReqDeleteUser struct { + Erase bool `json:"erase"` +} + +// DeactivateAccount deactivates a specific local user account. +// +// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#deactivate-account +func (cli *Client) DeactivateAccount(ctx context.Context, userID id.UserID, req ReqDeleteUser) error { + reqURL := cli.BuildAdminURL("v1", "deactivate", userID) + _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ + Method: http.MethodPost, + URL: reqURL, + RequestJSON: &req, + }) + return err +} + +type ReqCreateOrModifyAccount struct { + Password string `json:"password,omitempty"` + LogoutDevices *bool `json:"logout_devices,omitempty"` + + Deactivated *bool `json:"deactivated,omitempty"` + Admin *bool `json:"admin,omitempty"` + Locked *bool `json:"locked,omitempty"` + + Displayname string `json:"displayname,omitempty"` + AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` + UserType string `json:"user_type,omitempty"` +} + +// CreateOrModifyAccount creates or modifies an account on the server. +// +// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#create-or-modify-account +func (cli *Client) CreateOrModifyAccount(ctx context.Context, userID id.UserID, req ReqCreateOrModifyAccount) error { + reqURL := cli.BuildAdminURL("v2", "users", userID) + _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ + Method: http.MethodPut, + URL: reqURL, + RequestJSON: &req, + }) + return err +} + +type RatelimitOverride struct { + MessagesPerSecond int `json:"messages_per_second"` + BurstCount int `json:"burst_count"` +} + +type ReqSetRatelimit = RatelimitOverride + +// SetUserRatelimit overrides the message sending ratelimit for a specific user. +// +// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#set-ratelimit +func (cli *Client) SetUserRatelimit(ctx context.Context, userID id.UserID, req ReqSetRatelimit) error { + reqURL := cli.BuildAdminURL("v1", "users", userID, "override_ratelimit") + _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ + Method: http.MethodPost, + URL: reqURL, + RequestJSON: &req, + }) + return err +} + +type RespUserRatelimit = RatelimitOverride + +// GetUserRatelimit gets the ratelimit override for the given user. +// +// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#get-status-of-ratelimit +func (cli *Client) GetUserRatelimit(ctx context.Context, userID id.UserID) (resp RespUserRatelimit, err error) { + _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ + Method: http.MethodGet, + URL: cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), + ResponseJSON: &resp, + }) + return +} + +// DeleteUserRatelimit deletes the ratelimit override for the given user, returning them to the default ratelimits. +// +// https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#delete-ratelimit +func (cli *Client) DeleteUserRatelimit(ctx context.Context, userID id.UserID) (err error) { + _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ + Method: http.MethodDelete, + URL: cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), + }) + return +} From a8e1ae1936730f6aae6beb3ad660a5fbb73b4d07 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 3 Mar 2024 12:47:29 +0200 Subject: [PATCH 0146/1647] Link to FAQ in some error cases --- bridge/bridge.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bridge/bridge.go b/bridge/bridge.go index 56f27a8b..c50ede30 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -333,8 +333,10 @@ func (br *Bridge) ensureConnection(ctx context.Context) { if err != nil { if errors.Is(err, mautrix.MUnknownToken) { br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?") + br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-token for more info") } else if errors.Is(err, mautrix.MExclusive) { br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was accepted, but the /register request was not. Are the homeserver domain, bot username and username template in the config correct, and do they match the values in the registration?") + br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-register for more info") } else { br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("/whoami request failed with unknown error") } @@ -387,6 +389,7 @@ func (br *Bridge) ensureConnection(ctx context.Context) { } if outOfRetries { evt.Msg("Homeserver -> bridge connection is not working") + br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-ping for more info") os.Exit(13) } evt.Msg("Homeserver -> bridge connection is not working, retrying in 5 seconds...") @@ -535,6 +538,7 @@ func (br *Bridge) init() { err = br.validateConfig() if err != nil { br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Configuration error") + br.ZLog.Info().Msg("See https://docs.mau.fi/faq/field-unconfigured for more info") os.Exit(11) } @@ -667,6 +671,7 @@ func (br *Bridge) LogDBUpgradeErrorAndExit(name string, err error) { os.Exit(18) } else if errors.Is(err, dbutil.ErrForeignTables) { br.ZLog.Info().Msg("You can use --ignore-foreign-tables to ignore this error") + br.ZLog.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info") } else if errors.Is(err, dbutil.ErrNotOwned) { br.ZLog.Info().Msg("Sharing the same database with different programs is not supported") } else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) { From bb6c88faf3cea0c65c6f3671d13dc2e72256298f Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Wed, 6 Mar 2024 08:00:06 +0200 Subject: [PATCH 0147/1647] Add callback on megolm session receive --- crypto/keybackup.go | 2 +- crypto/keyimport.go | 2 +- crypto/keysharing.go | 2 +- crypto/machine.go | 11 +++++++++-- 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/crypto/keybackup.go b/crypto/keybackup.go index cf5e747f..d3701e93 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -179,6 +179,6 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. if err != nil { return fmt.Errorf("failed to store new inbound group session: %w", err) } - mach.markSessionReceived(sessionID) + mach.markSessionReceived(ctx, sessionID) return nil } diff --git a/crypto/keyimport.go b/crypto/keyimport.go index 2d9f3486..da51774f 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -122,7 +122,7 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor if err != nil { return false, fmt.Errorf("failed to store imported session: %w", err) } - mach.markSessionReceived(igs.ID()) + mach.markSessionReceived(ctx, igs.ID()) return true, nil } diff --git a/crypto/keysharing.go b/crypto/keysharing.go index fa422ca5..05e7f894 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -189,7 +189,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt log.Error().Err(err).Msg("Failed to store new inbound group session") return false } - mach.markSessionReceived(content.SessionID) + mach.markSessionReceived(ctx, content.SessionID) log.Debug().Msg("Received forwarded inbound group session") return true } diff --git a/crypto/machine.go b/crypto/machine.go index 4a691166..4417faf3 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -52,6 +52,9 @@ type OlmMachine struct { keyWaiters map[id.SessionID]chan struct{} keyWaitersLock sync.Mutex + // Optional callback which is called when we save a session to store + SessionReceived func(context.Context, id.SessionID) + devicesToUnwedge map[id.IdentityKey]bool devicesToUnwedgeLock sync.Mutex recentlyUnwedged map[id.IdentityKey]time.Time @@ -520,7 +523,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen log.Error().Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session") return } - mach.markSessionReceived(sessionID) + mach.markSessionReceived(ctx, sessionID) log.Debug(). Str("session_id", sessionID.String()). Str("sender_key", senderKey.String()). @@ -530,7 +533,11 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen Msg("Received inbound group session") } -func (mach *OlmMachine) markSessionReceived(id id.SessionID) { +func (mach *OlmMachine) markSessionReceived(ctx context.Context, id id.SessionID) { + if mach.SessionReceived != nil { + mach.SessionReceived(ctx, id) + } + mach.keyWaitersLock.Lock() ch, ok := mach.keyWaiters[id] if ok { From a10c1142030a017ca158cd51de546e86bd852127 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 20 Feb 2024 08:53:37 -0700 Subject: [PATCH 0148/1647] verification: remove (go)olm SAS code Signed-off-by: Sumner Evans --- crypto/goolm/sas/sas.go | 76 ----------------- crypto/goolm/sas/sas_test.go | 112 ------------------------ crypto/olm/verification.go | 142 ------------------------------- crypto/olm/verification_goolm.go | 23 ----- 4 files changed, 353 deletions(-) delete mode 100644 crypto/goolm/sas/sas.go delete mode 100644 crypto/goolm/sas/sas_test.go delete mode 100644 crypto/olm/verification.go delete mode 100644 crypto/olm/verification_goolm.go diff --git a/crypto/goolm/sas/sas.go b/crypto/goolm/sas/sas.go deleted file mode 100644 index e34ba41c..00000000 --- a/crypto/goolm/sas/sas.go +++ /dev/null @@ -1,76 +0,0 @@ -// Package sas provides the means to do SAS between keys -package sas - -import ( - "io" - - "maunium.net/go/mautrix/crypto/goolm" - "maunium.net/go/mautrix/crypto/goolm/crypto" -) - -// SAS contains the key pair and secret for SAS. -type SAS struct { - KeyPair crypto.Curve25519KeyPair - Secret []byte -} - -// New creates a new SAS with a new key pair. -func New() (*SAS, error) { - kp, err := crypto.Curve25519GenerateKey(nil) - if err != nil { - return nil, err - } - s := &SAS{ - KeyPair: kp, - } - return s, nil -} - -// GetPubkey returns the public key of the key pair base64 encoded -func (s SAS) GetPubkey() []byte { - return goolm.Base64Encode(s.KeyPair.PublicKey) -} - -// SetTheirKey sets the key of the other party and computes the shared secret. -func (s *SAS) SetTheirKey(key []byte) error { - keyDecoded, err := goolm.Base64Decode(key) - if err != nil { - return err - } - sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded) - if err != nil { - return err - } - s.Secret = sharedSecret - return nil -} - -// GenerateBytes creates length bytes from the shared secret and info. -func (s SAS) GenerateBytes(info []byte, length uint) ([]byte, error) { - byteReader := crypto.HKDFSHA256(s.Secret, nil, info) - output := make([]byte, length) - if _, err := io.ReadFull(byteReader, output); err != nil { - return nil, err - } - return output, nil -} - -// calculateMAC returns a base64 encoded MAC of input. -func (s *SAS) calculateMAC(input, info []byte, length uint) ([]byte, error) { - key, err := s.GenerateBytes(info, length) - if err != nil { - return nil, err - } - mac := crypto.HMACSHA256(key, input) - return goolm.Base64Encode(mac), nil -} - -// CalculateMACFixes returns a base64 encoded, 32 byte long MAC of input. -func (s SAS) CalculateMAC(input, info []byte) ([]byte, error) { - return s.calculateMAC(input, info, 32) -} - -// CalculateMACLongKDF returns a base64 encoded, 256 byte long MAC of input. -func (s SAS) CalculateMACLongKDF(input, info []byte) ([]byte, error) { - return s.calculateMAC(input, info, 256) -} diff --git a/crypto/goolm/sas/sas_test.go b/crypto/goolm/sas/sas_test.go deleted file mode 100644 index c0acec70..00000000 --- a/crypto/goolm/sas/sas_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package sas_test - -import ( - "bytes" - "testing" - - "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/goolm/sas" -) - -func initSAS() (*sas.SAS, *sas.SAS, error) { - alicePrivate := crypto.Curve25519PrivateKey([]byte{ - 0x77, 0x07, 0x6D, 0x0A, 0x73, 0x18, 0xA5, 0x7D, - 0x3C, 0x16, 0xC1, 0x72, 0x51, 0xB2, 0x66, 0x45, - 0xDF, 0x4C, 0x2F, 0x87, 0xEB, 0xC0, 0x99, 0x2A, - 0xB1, 0x77, 0xFB, 0xA5, 0x1D, 0xB9, 0x2C, 0x2A, - }) - bobPrivate := crypto.Curve25519PrivateKey([]byte{ - 0x5D, 0xAB, 0x08, 0x7E, 0x62, 0x4A, 0x8A, 0x4B, - 0x79, 0xE1, 0x7F, 0x8B, 0x83, 0x80, 0x0E, 0xE6, - 0x6F, 0x3B, 0xB1, 0x29, 0x26, 0x18, 0xB6, 0xFD, - 0x1C, 0x2F, 0x8B, 0x27, 0xFF, 0x88, 0xE0, 0xEB, - }) - - aliceSAS, err := sas.New() - if err != nil { - return nil, nil, err - } - aliceSAS.KeyPair.PrivateKey = alicePrivate - aliceSAS.KeyPair.PublicKey, err = alicePrivate.PubKey() - if err != nil { - return nil, nil, err - } - - bobSAS, err := sas.New() - if err != nil { - return nil, nil, err - } - bobSAS.KeyPair.PrivateKey = bobPrivate - bobSAS.KeyPair.PublicKey, err = bobPrivate.PubKey() - if err != nil { - return nil, nil, err - } - return aliceSAS, bobSAS, nil -} - -func TestGenerateBytes(t *testing.T) { - aliceSAS, bobSAS, err := initSAS() - if err != nil { - t.Fatal(err) - } - alicePublicEncoded := []byte("hSDwCYkwp1R0i33ctD73Wg2/Og0mOBr066SpjqqbTmo") - bobPublicEncoded := []byte("3p7bfXt9wbTTW2HC7OQ1Nz+DQ8hbeGdNrfx+FG+IK08") - - if !bytes.Equal(aliceSAS.GetPubkey(), alicePublicEncoded) { - t.Fatal("public keys not equal") - } - if !bytes.Equal(bobSAS.GetPubkey(), bobPublicEncoded) { - t.Fatal("public keys not equal") - } - - err = aliceSAS.SetTheirKey(bobSAS.GetPubkey()) - if err != nil { - t.Fatal(err) - } - err = bobSAS.SetTheirKey(aliceSAS.GetPubkey()) - if err != nil { - t.Fatal(err) - } - - aliceBytes, err := aliceSAS.GenerateBytes([]byte("SAS"), 6) - if err != nil { - t.Fatal(err) - } - bobBytes, err := bobSAS.GenerateBytes([]byte("SAS"), 6) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(aliceBytes, bobBytes) { - t.Fatal("results are not equal") - } -} - -func TestSASMac(t *testing.T) { - aliceSAS, bobSAS, err := initSAS() - if err != nil { - t.Fatal(err) - } - err = aliceSAS.SetTheirKey(bobSAS.GetPubkey()) - if err != nil { - t.Fatal(err) - } - err = bobSAS.SetTheirKey(aliceSAS.GetPubkey()) - if err != nil { - t.Fatal(err) - } - - plainText := []byte("Hello world!") - info := []byte("MAC") - - aliceMac, err := aliceSAS.CalculateMAC(plainText, info) - if err != nil { - t.Fatal(err) - } - bobMac, err := bobSAS.CalculateMAC(plainText, info) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(aliceMac, bobMac) { - t.Fatal("results are not equal") - } -} diff --git a/crypto/olm/verification.go b/crypto/olm/verification.go deleted file mode 100644 index bb0db7be..00000000 --- a/crypto/olm/verification.go +++ /dev/null @@ -1,142 +0,0 @@ -//go:build !nosas && !goolm - -package olm - -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -// #include -import "C" - -import ( - "crypto/rand" - "unsafe" -) - -// SAS stores an Olm Short Authentication String (SAS) object. -type SAS struct { - int *C.OlmSAS - mem []byte -} - -// NewBlankSAS initializes an empty SAS object. -func NewBlankSAS() *SAS { - memory := make([]byte, sasSize()) - return &SAS{ - int: C.olm_sas(unsafe.Pointer(&memory[0])), - mem: memory, - } -} - -// sasSize is the size of a SAS object in bytes. -func sasSize() uint { - return uint(C.olm_sas_size()) -} - -// sasRandomLength is the number of random bytes needed to create an SAS object. -func (sas *SAS) sasRandomLength() uint { - return uint(C.olm_create_sas_random_length(sas.int)) -} - -// NewSAS creates a new SAS object. -func NewSAS() *SAS { - sas := NewBlankSAS() - random := make([]byte, sas.sasRandomLength()+1) - _, err := rand.Read(random) - if err != nil { - panic(NotEnoughGoRandom) - } - r := C.olm_create_sas( - (*C.OlmSAS)(sas.int), - unsafe.Pointer(&random[0]), - C.size_t(len(random))) - if r == errorVal() { - panic(sas.lastError()) - } else { - return sas - } -} - -// clear clears the memory used to back an SAS object. -func (sas *SAS) clear() uint { - return uint(C.olm_clear_sas(sas.int)) -} - -// lastError returns the most recent error to happen to an SAS object. -func (sas *SAS) lastError() error { - return convertError(C.GoString(C.olm_sas_last_error(sas.int))) -} - -// pubkeyLength is the size of a public key in bytes. -func (sas *SAS) pubkeyLength() uint { - return uint(C.olm_sas_pubkey_length((*C.OlmSAS)(sas.int))) -} - -// GetPubkey gets the public key for the SAS object. -func (sas *SAS) GetPubkey() []byte { - pubkey := make([]byte, sas.pubkeyLength()) - r := C.olm_sas_get_pubkey( - (*C.OlmSAS)(sas.int), - unsafe.Pointer(&pubkey[0]), - C.size_t(len(pubkey))) - if r == errorVal() { - panic(sas.lastError()) - } - return pubkey -} - -// SetTheirKey sets the public key of the other user. -func (sas *SAS) SetTheirKey(theirKey []byte) error { - theirKeyCopy := make([]byte, len(theirKey)) - copy(theirKeyCopy, theirKey) - r := C.olm_sas_set_their_key( - (*C.OlmSAS)(sas.int), - unsafe.Pointer(&theirKeyCopy[0]), - C.size_t(len(theirKeyCopy))) - if r == errorVal() { - return sas.lastError() - } - return nil -} - -// GenerateBytes generates bytes to use for the short authentication string. -func (sas *SAS) GenerateBytes(info []byte, count uint) ([]byte, error) { - infoCopy := make([]byte, len(info)) - copy(infoCopy, info) - output := make([]byte, count) - r := C.olm_sas_generate_bytes( - (*C.OlmSAS)(sas.int), - unsafe.Pointer(&infoCopy[0]), - C.size_t(len(infoCopy)), - unsafe.Pointer(&output[0]), - C.size_t(len(output))) - if r == errorVal() { - return nil, sas.lastError() - } - return output, nil -} - -// macLength is the size of a message authentication code generated by olm_sas_calculate_mac. -func (sas *SAS) macLength() uint { - return uint(C.olm_sas_mac_length((*C.OlmSAS)(sas.int))) -} - -// CalculateMAC generates a message authentication code (MAC) based on the shared secret. -func (sas *SAS) CalculateMAC(input []byte, info []byte) ([]byte, error) { - inputCopy := make([]byte, len(input)) - copy(inputCopy, input) - infoCopy := make([]byte, len(info)) - copy(infoCopy, info) - mac := make([]byte, sas.macLength()) - r := C.olm_sas_calculate_mac( - (*C.OlmSAS)(sas.int), - unsafe.Pointer(&inputCopy[0]), - C.size_t(len(inputCopy)), - unsafe.Pointer(&infoCopy[0]), - C.size_t(len(infoCopy)), - unsafe.Pointer(&mac[0]), - C.size_t(len(mac))) - if r == errorVal() { - return nil, sas.lastError() - } - return mac, nil -} diff --git a/crypto/olm/verification_goolm.go b/crypto/olm/verification_goolm.go deleted file mode 100644 index fab51e5c..00000000 --- a/crypto/olm/verification_goolm.go +++ /dev/null @@ -1,23 +0,0 @@ -//go:build !nosas && goolm - -package olm - -import ( - "maunium.net/go/mautrix/crypto/goolm/sas" -) - -// SAS stores an Olm Short Authentication String (SAS) object. -type SAS struct { - sas.SAS -} - -// NewSAS creates a new SAS object. -func NewSAS() *SAS { - newSAS, err := sas.New() - if err != nil { - panic(err) - } - return &SAS{ - SAS: *newSAS, - } -} From 284ab0d62c5770ce3c7a66a52b10d9fca99db99e Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 8 Mar 2024 14:04:19 -0700 Subject: [PATCH 0149/1647] olm: remove SHA256 base64 utility Signed-off-by: Sumner Evans --- crypto/olm/sha256b64.go | 15 --------------- 1 file changed, 15 deletions(-) delete mode 100644 crypto/olm/sha256b64.go diff --git a/crypto/olm/sha256b64.go b/crypto/olm/sha256b64.go deleted file mode 100644 index 711c9454..00000000 --- a/crypto/olm/sha256b64.go +++ /dev/null @@ -1,15 +0,0 @@ -package olm - -import ( - "crypto/sha256" - "encoding/base64" -) - -// SHA256B64 calculates the SHA-256 hash of the input and encodes it as base64. -func SHA256B64(input []byte) string { - if len(input) == 0 { - panic(EmptyInput) - } - hash := sha256.Sum256([]byte(input)) - return base64.RawStdEncoding.EncodeToString(hash[:]) -} From a6b4b3bf347903ba43946236f9d5b661671824f9 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 8 Mar 2024 15:27:31 -0700 Subject: [PATCH 0150/1647] ci: run tests with goolm as well Co-authored-by: Tulir Asokan Signed-off-by: Sumner Evans --- .github/workflows/go.yml | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 488e4dd5..66f6aee1 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -35,7 +35,7 @@ jobs: fail-fast: false matrix: go-version: ["1.21", "1.22"] - name: Build ${{ matrix.go-version == '1.22' && '(latest)' || '(old)' }} + name: Build (${{ matrix.go-version == '1.22' && 'latest' || 'old' }}, libolm) steps: - uses: actions/checkout@v4 @@ -59,3 +59,31 @@ jobs: - name: Test run: go test -json -v ./... 2>&1 | gotestfmt + + build-goolm: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + go-version: ["1.21", "1.22"] + name: Build (${{ matrix.go-version == '1.22' && 'latest' || 'old' }}, goolm) + + steps: + - uses: actions/checkout@v4 + + - name: Set up Go ${{ matrix.go-version }} + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go-version }} + cache: true + + - name: Set up gotestfmt + uses: GoTestTools/gotestfmt-action@v2 + with: + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Build + run: go build -tags=goolm -v ./... + + - name: Test + run: go test -tags=goolm -json -v ./... 2>&1 | gotestfmt From b8e4202c0faba38576065f142bb4fbec75aeb6a4 Mon Sep 17 00:00:00 2001 From: Malte E <97891689+maltee1@users.noreply.github.com> Date: Sat, 9 Mar 2024 15:33:09 +0100 Subject: [PATCH 0151/1647] Add handler for power levels in bridges (#189) --- bridge/bridge.go | 5 +++++ bridge/matrix.go | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/bridge/bridge.go b/bridge/bridge.go index c50ede30..4a7ba465 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -91,6 +91,11 @@ type DisappearingPortal interface { ScheduleDisappearing() } +type PowerLevelHandlingPortal interface { + Portal + HandleMatrixPowerLevels(sender User, evt *event.Event) +} + type User interface { GetPermissionLevel() bridgeconfig.PermissionLevel IsLoggedIn() bool diff --git a/bridge/matrix.go b/bridge/matrix.go index 5aa457fa..b49279aa 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -65,6 +65,7 @@ func NewMatrixHandler(br *Bridge) *MatrixHandler { br.EventProcessor.On(event.StateEncryption, handler.HandleEncryption) br.EventProcessor.On(event.EphemeralEventReceipt, handler.HandleReceipt) br.EventProcessor.On(event.EphemeralEventTyping, handler.HandleTyping) + br.EventProcessor.On(event.StatePowerLevels, handler.HandlePowerLevels) return handler } @@ -686,3 +687,18 @@ func (mx *MatrixHandler) HandleTyping(_ context.Context, evt *event.Event) { } typingPortal.HandleMatrixTyping(evt.Content.AsTyping().UserIDs) } + +func (mx *MatrixHandler) HandlePowerLevels(_ context.Context, evt *event.Event) { + if mx.shouldIgnoreEvent(evt) { + return + } + portal := mx.bridge.Child.GetIPortal(evt.RoomID) + if portal == nil { + return + } + powerLevelPortal, ok := portal.(PowerLevelHandlingPortal) + if ok { + user := mx.bridge.Child.GetIUser(evt.Sender, true) + powerLevelPortal.HandleMatrixPowerLevels(user, evt) + } +} From 6b1a039bebb8b092c4ef9523356d26111734984d Mon Sep 17 00:00:00 2001 From: Malte E Date: Tue, 5 Mar 2024 22:08:48 +0100 Subject: [PATCH 0152/1647] add join rule handler --- bridge/bridge.go | 5 +++++ bridge/matrix.go | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/bridge/bridge.go b/bridge/bridge.go index 4a7ba465..4f559bc8 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -96,6 +96,11 @@ type PowerLevelHandlingPortal interface { HandleMatrixPowerLevels(sender User, evt *event.Event) } +type JoinRuleHandlingPortal interface { + Portal + HandleMatrixJoinRule(sender User, evt *event.Event) +} + type User interface { GetPermissionLevel() bridgeconfig.PermissionLevel IsLoggedIn() bool diff --git a/bridge/matrix.go b/bridge/matrix.go index b49279aa..d27377a5 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -66,6 +66,7 @@ func NewMatrixHandler(br *Bridge) *MatrixHandler { br.EventProcessor.On(event.EphemeralEventReceipt, handler.HandleReceipt) br.EventProcessor.On(event.EphemeralEventTyping, handler.HandleTyping) br.EventProcessor.On(event.StatePowerLevels, handler.HandlePowerLevels) + br.EventProcessor.On(event.StateJoinRules, handler.HandleJoinRule) return handler } @@ -702,3 +703,18 @@ func (mx *MatrixHandler) HandlePowerLevels(_ context.Context, evt *event.Event) powerLevelPortal.HandleMatrixPowerLevels(user, evt) } } + +func (mx *MatrixHandler) HandleJoinRule(_ context.Context, evt *event.Event) { + if mx.shouldIgnoreEvent(evt) { + return + } + portal := mx.bridge.Child.GetIPortal(evt.RoomID) + if portal == nil { + return + } + joinRulePortal, ok := portal.(JoinRuleHandlingPortal) + if ok { + user := mx.bridge.Child.GetIUser(evt.Sender, true) + joinRulePortal.HandleMatrixJoinRule(user, evt) + } +} From 41dfb400647ec5b990e015115e8ce68552b7a7e7 Mon Sep 17 00:00:00 2001 From: Malte E Date: Sat, 9 Mar 2024 13:37:59 +0100 Subject: [PATCH 0153/1647] add ban/unban handling --- bridge/bridge.go | 6 ++++++ bridge/matrix.go | 44 +++++++++++++++++++++++++++----------------- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/bridge/bridge.go b/bridge/bridge.go index 4f559bc8..ea66c1a4 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -101,6 +101,12 @@ type JoinRuleHandlingPortal interface { HandleMatrixJoinRule(sender User, evt *event.Event) } +type BanHandlingPortal interface { + Portal + HandleMatrixBan(sender User, ghost Ghost, evt *event.Event) + HandleMatrixUnban(sender User, ghost Ghost, evt *event.Event) +} + type User interface { GetPermissionLevel() bridgeconfig.PermissionLevel IsLoggedIn() bool diff --git a/bridge/matrix.go b/bridge/matrix.go index d27377a5..99843329 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -276,27 +276,37 @@ func (mx *MatrixHandler) HandleMembership(ctx context.Context, evt *event.Event) } else if user.GetPermissionLevel() < bridgeconfig.PermissionLevelUser || !user.IsLoggedIn() { return } - - mhp, ok := portal.(MembershipHandlingPortal) - if !ok { + bhp, bhpOk := portal.(BanHandlingPortal) + mhp, mhpOk := portal.(MembershipHandlingPortal) + if !(mhpOk || bhpOk) { return } - - if content.Membership == event.MembershipLeave { - if evt.Unsigned.PrevContent != nil { - _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) - prevContent, ok := evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent) - if ok && prevContent.Membership != "join" { - return + var prevContent *event.MemberEventContent + if evt.Unsigned.PrevContent != nil { + _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) + var ok bool + prevContent, ok = evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent) + if !ok { + prevContent = &event.MemberEventContent{Membership: event.MembershipLeave} + } + } + if bhpOk { + if content.Membership == event.MembershipBan { + bhp.HandleMatrixBan(user, ghost, evt) + } else if content.Membership == event.MembershipLeave && prevContent.Membership == event.MembershipBan { + bhp.HandleMatrixUnban(user, ghost, evt) + } + } + if mhpOk { + if content.Membership == event.MembershipLeave && prevContent.Membership == event.MembershipJoin { + if isSelf { + mhp.HandleMatrixLeave(user, evt) + } else if ghost != nil { + mhp.HandleMatrixKick(user, ghost, evt) } + } else if content.Membership == event.MembershipInvite && !isSelf && ghost != nil { + mhp.HandleMatrixInvite(user, ghost, evt) } - if isSelf { - mhp.HandleMatrixLeave(user, evt) - } else if ghost != nil { - mhp.HandleMatrixKick(user, ghost, evt) - } - } else if content.Membership == event.MembershipInvite && !isSelf && ghost != nil { - mhp.HandleMatrixInvite(user, ghost, evt) } // TODO kicking/inviting non-ghost users users } From db41583fddb7eb64a103a5fc13fd40dc78ebeefb Mon Sep 17 00:00:00 2001 From: Malte E Date: Sun, 10 Mar 2024 13:47:09 +0100 Subject: [PATCH 0154/1647] add knock handling --- bridge/bridge.go | 8 ++++++++ bridge/matrix.go | 20 ++++++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/bridge/bridge.go b/bridge/bridge.go index ea66c1a4..4a6cf68e 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -107,6 +107,14 @@ type BanHandlingPortal interface { HandleMatrixUnban(sender User, ghost Ghost, evt *event.Event) } +type KnockHandlingPortal interface { + Portal + HandleMatrixKnock(sender User, evt *event.Event) + HandleMatrixRetractKnock(sender User, evt *event.Event) + HandleMatrixAcceptKnock(sender User, ghost Ghost, evt *event.Event) + HandleMatrixRejectKnock(sender User, ghost Ghost, evt *event.Event) +} + type User interface { GetPermissionLevel() bridgeconfig.PermissionLevel IsLoggedIn() bool diff --git a/bridge/matrix.go b/bridge/matrix.go index 99843329..5b646b34 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -278,7 +278,8 @@ func (mx *MatrixHandler) HandleMembership(ctx context.Context, evt *event.Event) } bhp, bhpOk := portal.(BanHandlingPortal) mhp, mhpOk := portal.(MembershipHandlingPortal) - if !(mhpOk || bhpOk) { + khp, khpOk := portal.(KnockHandlingPortal) + if !(mhpOk || bhpOk || khpOk) { return } var prevContent *event.MemberEventContent @@ -290,13 +291,28 @@ func (mx *MatrixHandler) HandleMembership(ctx context.Context, evt *event.Event) prevContent = &event.MemberEventContent{Membership: event.MembershipLeave} } } - if bhpOk { + if bhpOk && ghost != nil { if content.Membership == event.MembershipBan { bhp.HandleMatrixBan(user, ghost, evt) } else if content.Membership == event.MembershipLeave && prevContent.Membership == event.MembershipBan { bhp.HandleMatrixUnban(user, ghost, evt) } } + if khpOk { + if content.Membership == event.MembershipKnock { + khp.HandleMatrixKnock(user, evt) + } else if prevContent.Membership == event.MembershipKnock { + if content.Membership == event.MembershipInvite && ghost != nil { + khp.HandleMatrixAcceptKnock(user, ghost, evt) + } else if content.Membership == event.MembershipLeave { + if isSelf { + khp.HandleMatrixRetractKnock(user, evt) + } else if ghost != nil { + khp.HandleMatrixRejectKnock(user, ghost, evt) + } + } + } + } if mhpOk { if content.Membership == event.MembershipLeave && prevContent.Membership == event.MembershipJoin { if isSelf { From 1423650a2908296ee385040e2be90c93332791a7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 8 Mar 2024 15:09:35 +0200 Subject: [PATCH 0155/1647] Don't use UIA wrapper for appservice user registrations --- appservice/intent.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/appservice/intent.go b/appservice/intent.go index bdf0f066..e091582a 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -10,6 +10,7 @@ import ( "context" "errors" "fmt" + "net/http" "strings" "sync" @@ -50,11 +51,11 @@ func (as *AppService) NewIntentAPI(localpart string) *IntentAPI { } func (intent *IntentAPI) Register(ctx context.Context) error { - _, _, err := intent.Client.Register(ctx, &mautrix.ReqRegister{ + _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister{ Username: intent.Localpart, Type: mautrix.AuthTypeAppservice, InhibitLogin: true, - }) + }, nil) return err } From 311a20cea9806c86136d1f5b25e1ad8b4e590d4f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 10 Mar 2024 20:34:59 +0200 Subject: [PATCH 0156/1647] Update CHANGELOG.md --- CHANGELOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f1815d89..9b97af05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,15 @@ ## v0.18.0 (unreleased) +* *(bridge)* Fixed upload size limit not having a default if the server + returned no value. +* *(synapseadmin)* Added wrappers for some room and user admin APIs. + (thanks to [@grvn-ht] in [#181]). +* *(crypto/verificationhelper)* Fixed bugs. +* *(crypto)* Fixed key backup uploading doing too much base64. + +[@grvn-ht]: https://github.com/grvn-ht +[#181]: https://github.com/mautrix/go/pull/181 + ### beta.1 (2024-02-16) * Bumped minimum Go version to 1.21. From a36f60a4f3431f02e6be664f5c73f3f5cdf61d5a Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Mon, 11 Mar 2024 10:33:02 +0200 Subject: [PATCH 0157/1647] Parse Beeper inbox preview event in sync --- responses.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/responses.go b/responses.go index 96e30a1f..9e5fd0aa 100644 --- a/responses.go +++ b/responses.go @@ -309,6 +309,12 @@ func (slr SyncLeftRoom) MarshalJSON() ([]byte, error) { return marshalAndDeleteEmpty((marshalableSyncLeftRoom)(slr), syncLeftRoomPathsToDelete) } +type BeeperInboxPreviewEvent struct { + EventID id.EventID `json:"event_id"` + Timestamp jsontime.UnixMilli `json:"origin_server_ts"` + Event *event.Event `json:"event,omitempty"` +} + type SyncJoinedRoom struct { Summary LazyLoadSummary `json:"summary"` State SyncEventsList `json:"state"` @@ -319,6 +325,8 @@ type SyncJoinedRoom struct { UnreadNotifications *UnreadNotificationCounts `json:"unread_notifications,omitempty"` // https://github.com/matrix-org/matrix-spec-proposals/pull/2654 MSC2654UnreadCount *int `json:"org.matrix.msc2654.unread_count,omitempty"` + // Beeper extension + BeeperInboxPreview *BeeperInboxPreviewEvent `json:"com.beeper.inbox.preview,omitempty"` } type UnreadNotificationCounts struct { From d18dcfc7eb602490b50761ddabaf5cc8eedb9b72 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 11 Mar 2024 15:37:57 +0200 Subject: [PATCH 0158/1647] Update dependencies --- go.mod | 12 ++++++------ go.sum | 24 ++++++++++++------------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/go.mod b/go.mod index 226ceb44..8cbc974c 100644 --- a/go.mod +++ b/go.mod @@ -8,15 +8,15 @@ require ( github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.22 github.com/rs/zerolog v1.32.0 - github.com/stretchr/testify v1.8.4 + github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.0 - go.mau.fi/util v0.4.0 + go.mau.fi/util v0.4.1-0.20240311133655-ff64e137ce44 go.mau.fi/zeroconfig v0.1.2 - golang.org/x/crypto v0.19.0 - golang.org/x/exp v0.0.0-20240213143201-ec583247a57a - golang.org/x/net v0.21.0 + golang.org/x/crypto v0.21.0 + golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 + golang.org/x/net v0.22.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 maunium.net/go/maulogger/v2 v2.4.1 @@ -30,6 +30,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.17.0 // indirect + golang.org/x/sys v0.18.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index dbcaf985..26e8e0a5 100644 --- a/go.sum +++ b/go.sum @@ -24,8 +24,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0= github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -37,21 +37,21 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.0 h1:EfOIvIMZIzHdB/R/zVrikYLPPwJlfMcNczJFMs1m6sA= github.com/yuin/goldmark v1.7.0/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.4.0 h1:S2X3qU4pUcb/vxBRfAuZjbrR9xVMAXSjQojNBLPBbhs= -go.mau.fi/util v0.4.0/go.mod h1:leeiHtgVBuN+W9aDii3deAXnfC563iN3WK6BF8/AjNw= +go.mau.fi/util v0.4.1-0.20240311133655-ff64e137ce44 h1:d5nG84/nftM2sBibpoT8X4aCTOoueoe26DBBLGHi41k= +go.mau.fi/util v0.4.1-0.20240311133655-ff64e137ce44/go.mod h1:jOAREC/go8T6rGic01cu6WRa90xi9U4z3QmDjRf8xpo= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= -golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/exp v0.0.0-20240213143201-ec583247a57a h1:HinSgX1tJRX3KsL//Gxynpw5CTOAIPhgL4W8PNiIpVE= -golang.org/x/exp v0.0.0-20240213143201-ec583247a57a/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= -golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= +golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ= +golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= +golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= +golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From 3b65d98c0cb80cfc2c4f1b41c33e8a3598d8de63 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 8 Mar 2024 15:01:44 -0700 Subject: [PATCH 0159/1647] olm/pk: make an interface Signed-off-by: Sumner Evans --- crypto/cross_sign_key.go | 54 +++++++++---------- crypto/cross_sign_signing.go | 10 ++-- crypto/cross_sign_ssss.go | 12 ++--- crypto/cross_sign_test.go | 56 ++++++++++---------- crypto/goolm/pk/decryption.go | 4 +- crypto/goolm/pk/pk_test.go | 13 +++-- crypto/goolm/pk/signing.go | 43 +++++++++++---- crypto/olm/pk_goolm.go | 80 +++++++--------------------- crypto/olm/pk_interface.go | 41 +++++++++++++++ crypto/olm/{pk.go => pk_libolm.go} | 84 ++++++++++++++++++------------ 10 files changed, 222 insertions(+), 175 deletions(-) create mode 100644 crypto/olm/pk_interface.go rename crypto/olm/{pk.go => pk_libolm.go} (67%) diff --git a/crypto/cross_sign_key.go b/crypto/cross_sign_key.go index 45e56b4b..f7dc08cb 100644 --- a/crypto/cross_sign_key.go +++ b/crypto/cross_sign_key.go @@ -19,16 +19,16 @@ import ( // CrossSigningKeysCache holds the three cross-signing keys for the current user. type CrossSigningKeysCache struct { - MasterKey *olm.PkSigning - SelfSigningKey *olm.PkSigning - UserSigningKey *olm.PkSigning + MasterKey olm.PKSigning + SelfSigningKey olm.PKSigning + UserSigningKey olm.PKSigning } func (cskc *CrossSigningKeysCache) PublicKeys() *CrossSigningPublicKeysCache { return &CrossSigningPublicKeysCache{ - MasterKey: cskc.MasterKey.PublicKey, - SelfSigningKey: cskc.SelfSigningKey.PublicKey, - UserSigningKey: cskc.UserSigningKey.PublicKey, + MasterKey: cskc.MasterKey.PublicKey(), + SelfSigningKey: cskc.SelfSigningKey.PublicKey(), + UserSigningKey: cskc.UserSigningKey.PublicKey(), } } @@ -40,28 +40,28 @@ type CrossSigningSeeds struct { func (mach *OlmMachine) ExportCrossSigningKeys() CrossSigningSeeds { return CrossSigningSeeds{ - MasterKey: mach.CrossSigningKeys.MasterKey.Seed, - SelfSigningKey: mach.CrossSigningKeys.SelfSigningKey.Seed, - UserSigningKey: mach.CrossSigningKeys.UserSigningKey.Seed, + MasterKey: mach.CrossSigningKeys.MasterKey.Seed(), + SelfSigningKey: mach.CrossSigningKeys.SelfSigningKey.Seed(), + UserSigningKey: mach.CrossSigningKeys.UserSigningKey.Seed(), } } func (mach *OlmMachine) ImportCrossSigningKeys(keys CrossSigningSeeds) (err error) { var keysCache CrossSigningKeysCache - if keysCache.MasterKey, err = olm.NewPkSigningFromSeed(keys.MasterKey); err != nil { + if keysCache.MasterKey, err = olm.NewPKSigningFromSeed(keys.MasterKey); err != nil { return } - if keysCache.SelfSigningKey, err = olm.NewPkSigningFromSeed(keys.SelfSigningKey); err != nil { + if keysCache.SelfSigningKey, err = olm.NewPKSigningFromSeed(keys.SelfSigningKey); err != nil { return } - if keysCache.UserSigningKey, err = olm.NewPkSigningFromSeed(keys.UserSigningKey); err != nil { + if keysCache.UserSigningKey, err = olm.NewPKSigningFromSeed(keys.UserSigningKey); err != nil { return } mach.Log.Debug(). - Str("master", keysCache.MasterKey.PublicKey.String()). - Str("self_signing", keysCache.SelfSigningKey.PublicKey.String()). - Str("user_signing", keysCache.UserSigningKey.PublicKey.String()). + Str("master", keysCache.MasterKey.PublicKey().String()). + Str("self_signing", keysCache.SelfSigningKey.PublicKey().String()). + Str("user_signing", keysCache.UserSigningKey.PublicKey().String()). Msg("Imported own cross-signing keys") mach.CrossSigningKeys = &keysCache @@ -73,19 +73,19 @@ func (mach *OlmMachine) ImportCrossSigningKeys(keys CrossSigningSeeds) (err erro func (mach *OlmMachine) GenerateCrossSigningKeys() (*CrossSigningKeysCache, error) { var keysCache CrossSigningKeysCache var err error - if keysCache.MasterKey, err = olm.NewPkSigning(); err != nil { + if keysCache.MasterKey, err = olm.NewPKSigning(); err != nil { return nil, fmt.Errorf("failed to generate master key: %w", err) } - if keysCache.SelfSigningKey, err = olm.NewPkSigning(); err != nil { + if keysCache.SelfSigningKey, err = olm.NewPKSigning(); err != nil { return nil, fmt.Errorf("failed to generate self-signing key: %w", err) } - if keysCache.UserSigningKey, err = olm.NewPkSigning(); err != nil { + if keysCache.UserSigningKey, err = olm.NewPKSigning(); err != nil { return nil, fmt.Errorf("failed to generate user-signing key: %w", err) } mach.Log.Debug(). - Str("master", keysCache.MasterKey.PublicKey.String()). - Str("self_signing", keysCache.SelfSigningKey.PublicKey.String()). - Str("user_signing", keysCache.UserSigningKey.PublicKey.String()). + Str("master", keysCache.MasterKey.PublicKey().String()). + Str("self_signing", keysCache.SelfSigningKey.PublicKey().String()). + Str("user_signing", keysCache.UserSigningKey.PublicKey().String()). Msg("Generated cross-signing keys") return &keysCache, nil } @@ -93,12 +93,12 @@ func (mach *OlmMachine) GenerateCrossSigningKeys() (*CrossSigningKeysCache, erro // PublishCrossSigningKeys signs and uploads the public keys of the given cross-signing keys to the server. func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *CrossSigningKeysCache, uiaCallback mautrix.UIACallback) error { userID := mach.Client.UserID - masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey.String()) + masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String()) masterKey := mautrix.CrossSigningKeys{ UserID: userID, Usage: []id.CrossSigningUsage{id.XSUsageMaster}, Keys: map[id.KeyID]id.Ed25519{ - masterKeyID: keys.MasterKey.PublicKey, + masterKeyID: keys.MasterKey.PublicKey(), }, } masterSig, err := mach.account.Internal.SignJSON(masterKey) @@ -111,27 +111,27 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross UserID: userID, Usage: []id.CrossSigningUsage{id.XSUsageSelfSigning}, Keys: map[id.KeyID]id.Ed25519{ - id.NewKeyID(id.KeyAlgorithmEd25519, keys.SelfSigningKey.PublicKey.String()): keys.SelfSigningKey.PublicKey, + id.NewKeyID(id.KeyAlgorithmEd25519, keys.SelfSigningKey.PublicKey().String()): keys.SelfSigningKey.PublicKey(), }, } selfSig, err := keys.MasterKey.SignJSON(selfKey) if err != nil { return fmt.Errorf("failed to sign self-signing key: %w", err) } - selfKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey.String(), selfSig) + selfKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), selfSig) userKey := mautrix.CrossSigningKeys{ UserID: userID, Usage: []id.CrossSigningUsage{id.XSUsageUserSigning}, Keys: map[id.KeyID]id.Ed25519{ - id.NewKeyID(id.KeyAlgorithmEd25519, keys.UserSigningKey.PublicKey.String()): keys.UserSigningKey.PublicKey, + id.NewKeyID(id.KeyAlgorithmEd25519, keys.UserSigningKey.PublicKey().String()): keys.UserSigningKey.PublicKey(), }, } userSig, err := keys.MasterKey.SignJSON(userKey) if err != nil { return fmt.Errorf("failed to sign user-signing key: %w", err) } - userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey.String(), userSig) + userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), userSig) err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{ Master: masterKey, diff --git a/crypto/cross_sign_signing.go b/crypto/cross_sign_signing.go index c0efd54e..86920728 100644 --- a/crypto/cross_sign_signing.go +++ b/crypto/cross_sign_signing.go @@ -60,7 +60,7 @@ func (mach *OlmMachine) SignUser(ctx context.Context, userID id.UserID, masterKe Str("signature", signature). Msg("Signed master key of user with our user-signing key") - if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil { + if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey(), signature); err != nil { return fmt.Errorf("error storing signature in crypto store: %w", err) } @@ -77,7 +77,7 @@ func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error { userID := mach.Client.UserID deviceID := mach.Client.DeviceID - masterKey := mach.CrossSigningKeys.MasterKey.PublicKey + masterKey := mach.CrossSigningKeys.MasterKey.PublicKey() masterKeyObj := mautrix.ReqKeysSignatures{ UserID: userID, @@ -149,7 +149,7 @@ func (mach *OlmMachine) SignOwnDevice(ctx context.Context, device *id.Device) er Str("signature", signature). Msg("Signed own device key with self-signing key") - if err := mach.CryptoStore.PutSignature(ctx, device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil { + if err := mach.CryptoStore.PutSignature(ctx, device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey(), signature); err != nil { return fmt.Errorf("error storing signature in crypto store: %w", err) } @@ -180,12 +180,12 @@ func (mach *OlmMachine) getFullDeviceKeys(ctx context.Context, device *id.Device } // signAndUpload signs the given key signatures object and uploads it to the server. -func (mach *OlmMachine) signAndUpload(ctx context.Context, req mautrix.ReqKeysSignatures, userID id.UserID, signedThing string, key *olm.PkSigning) (string, error) { +func (mach *OlmMachine) signAndUpload(ctx context.Context, req mautrix.ReqKeysSignatures, userID id.UserID, signedThing string, key olm.PKSigning) (string, error) { signature, err := key.SignJSON(req) if err != nil { return "", fmt.Errorf("failed to sign JSON: %w", err) } - req.Signatures = signatures.NewSingleSignature(mach.Client.UserID, id.KeyAlgorithmEd25519, key.PublicKey.String(), signature) + req.Signatures = signatures.NewSingleSignature(mach.Client.UserID, id.KeyAlgorithmEd25519, key.PublicKey().String(), signature) resp, err := mach.Client.UploadSignatures(ctx, &mautrix.ReqUploadSignatures{ userID: map[string]mautrix.ReqKeysSignatures{ diff --git a/crypto/cross_sign_ssss.go b/crypto/cross_sign_ssss.go index a87f87de..389a9fd2 100644 --- a/crypto/cross_sign_ssss.go +++ b/crypto/cross_sign_ssss.go @@ -110,24 +110,24 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, u // UploadCrossSigningKeysToSSSS stores the given cross-signing keys on the server encrypted with the given key. func (mach *OlmMachine) UploadCrossSigningKeysToSSSS(ctx context.Context, key *ssss.Key, keys *CrossSigningKeysCache) error { - if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningMaster, keys.MasterKey.Seed, key); err != nil { + if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningMaster, keys.MasterKey.Seed(), key); err != nil { return err } - if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningSelf, keys.SelfSigningKey.Seed, key); err != nil { + if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningSelf, keys.SelfSigningKey.Seed(), key); err != nil { return err } - if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed, key); err != nil { + if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed(), key); err != nil { return err } // Also store these locally - if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageMaster, keys.MasterKey.PublicKey); err != nil { + if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageMaster, keys.MasterKey.PublicKey()); err != nil { return err } - if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageSelfSigning, keys.SelfSigningKey.PublicKey); err != nil { + if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageSelfSigning, keys.SelfSigningKey.PublicKey()); err != nil { return err } - if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageUserSigning, keys.UserSigningKey.PublicKey); err != nil { + if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageUserSigning, keys.UserSigningKey.PublicKey()); err != nil { return err } diff --git a/crypto/cross_sign_test.go b/crypto/cross_sign_test.go index b53da102..e11fb018 100644 --- a/crypto/cross_sign_test.go +++ b/crypto/cross_sign_test.go @@ -37,13 +37,13 @@ func getOlmMachine(t *testing.T) *OlmMachine { } userID := id.UserID("@mautrix") - mk, _ := olm.NewPkSigning() - ssk, _ := olm.NewPkSigning() - usk, _ := olm.NewPkSigning() + mk, _ := olm.NewPKSigning() + ssk, _ := olm.NewPKSigning() + usk, _ := olm.NewPKSigning() - sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageMaster, mk.PublicKey) - sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageSelfSigning, ssk.PublicKey) - sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageUserSigning, usk.PublicKey) + sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageMaster, mk.PublicKey()) + sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageSelfSigning, ssk.PublicKey()) + sqlStore.PutCrossSigningKey(context.TODO(), userID, id.XSUsageUserSigning, usk.PublicKey()) return &OlmMachine{ CryptoStore: sqlStore, @@ -70,10 +70,10 @@ func TestTrustOwnDevice(t *testing.T) { t.Error("Own device trusted while it shouldn't be") } - m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, - ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1") + m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), + ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey(), "sig1") m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, ownDevice.SigningKey, - ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, "sig2") + ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "sig2") if trusted, _ := m.IsUserTrusted(context.TODO(), ownDevice.UserID); !trusted { t.Error("Own user not trusted while they should be") @@ -90,22 +90,22 @@ func TestTrustOtherUser(t *testing.T) { t.Error("Other user trusted while they shouldn't be") } - theirMasterKey, _ := olm.NewPkSigning() - m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey) + theirMasterKey, _ := olm.NewPKSigning() + m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey()) - m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, - m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1") + m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), + m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey(), "sig1") // sign them with self-signing instead of user-signing key - m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey, - m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, "invalid_sig") + m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), + m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "invalid_sig") if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { t.Error("Other user trusted before their master key has been signed with our user-signing key") } - m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey, - m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, "sig2") + m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), + m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2") if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted { t.Error("Other user not trusted while they should be") @@ -127,29 +127,29 @@ func TestTrustOtherDevice(t *testing.T) { t.Error("Other device trusted while it shouldn't be") } - theirMasterKey, _ := olm.NewPkSigning() - m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey) - theirSSK, _ := olm.NewPkSigning() - m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageSelfSigning, theirSSK.PublicKey) + theirMasterKey, _ := olm.NewPKSigning() + m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey()) + theirSSK, _ := olm.NewPKSigning() + m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageSelfSigning, theirSSK.PublicKey()) - m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, - m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1") - m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey, - m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, "sig2") + m.CryptoStore.PutSignature(context.TODO(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), + m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey(), "sig1") + m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), + m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2") if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted { t.Error("Other user not trusted while they should be") } - m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey, - otherUser, theirMasterKey.PublicKey, "sig3") + m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey(), + otherUser, theirMasterKey.PublicKey(), "sig3") if m.IsDeviceTrusted(theirDevice) { t.Error("Other device trusted before it has been signed with user's SSK") } m.CryptoStore.PutSignature(context.TODO(), otherUser, theirDevice.SigningKey, - otherUser, theirSSK.PublicKey, "sig4") + otherUser, theirSSK.PublicKey(), "sig4") if !m.IsDeviceTrusted(theirDevice) { t.Error("Other device not trusted while it should be") diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index 3fb3c2a5..d08e09f4 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -45,8 +45,8 @@ func NewDecryptionFromPrivate(privateKey crypto.Curve25519PrivateKey) (*Decrypti return s, nil } -// PubKey returns the public key base 64 encoded. -func (s Decryption) PubKey() id.Curve25519 { +// PublicKey returns the public key base 64 encoded. +func (s Decryption) PublicKey() id.Curve25519 { return s.KeyPair.B64Encoded() } diff --git a/crypto/goolm/pk/pk_test.go b/crypto/goolm/pk/pk_test.go index 91bab5b9..7ac524be 100644 --- a/crypto/goolm/pk/pk_test.go +++ b/crypto/goolm/pk/pk_test.go @@ -30,14 +30,14 @@ func TestEncryptionDecryption(t *testing.T) { if err != nil { t.Fatal(err) } - if !bytes.Equal([]byte(decryption.PubKey()), alicePublic) { + if !bytes.Equal([]byte(decryption.PublicKey()), alicePublic) { t.Fatal("public key not correct") } if !bytes.Equal(decryption.PrivateKey(), alicePrivate) { t.Fatal("private key not correct") } - encryption, err := pk.NewEncryption(decryption.PubKey()) + encryption, err := pk.NewEncryption(decryption.PublicKey()) if err != nil { t.Fatal(err) } @@ -66,7 +66,10 @@ func TestSigning(t *testing.T) { } message := []byte("We hold these truths to be self-evident, that all men are created equal, that they are endowed by their Creator with certain unalienable Rights, that among these are Life, Liberty and the pursuit of Happiness.") signing, _ := pk.NewSigningFromSeed(seed) - signature := signing.Sign(message) + signature, err := signing.Sign(message) + if err != nil { + t.Fatal(err) + } signatureDecoded, err := goolm.Base64Decode(signature) if err != nil { t.Fatal(err) @@ -101,7 +104,7 @@ func TestDecryptionPickling(t *testing.T) { if err != nil { t.Fatal(err) } - if !bytes.Equal([]byte(decryption.PubKey()), alicePublic) { + if !bytes.Equal([]byte(decryption.PublicKey()), alicePublic) { t.Fatal("public key not correct") } if !bytes.Equal(decryption.PrivateKey(), alicePrivate) { @@ -125,7 +128,7 @@ func TestDecryptionPickling(t *testing.T) { if err != nil { t.Fatal(err) } - if !bytes.Equal([]byte(newDecription.PubKey()), alicePublic) { + if !bytes.Equal([]byte(newDecription.PublicKey()), alicePublic) { t.Fatal("public key not correct") } if !bytes.Equal(newDecription.PrivateKey(), alicePrivate) { diff --git a/crypto/goolm/pk/signing.go b/crypto/goolm/pk/signing.go index 046838ff..a98330d5 100644 --- a/crypto/goolm/pk/signing.go +++ b/crypto/goolm/pk/signing.go @@ -2,7 +2,11 @@ package pk import ( "crypto/rand" + "encoding/json" + "github.com/tidwall/sjson" + + "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/id" @@ -10,15 +14,15 @@ import ( // Signing is used for signing a pk type Signing struct { - KeyPair crypto.Ed25519KeyPair `json:"key_pair"` - Seed []byte `json:"seed"` + keyPair crypto.Ed25519KeyPair + seed []byte } // NewSigningFromSeed constructs a new Signing based on a seed. func NewSigningFromSeed(seed []byte) (*Signing, error) { s := &Signing{} - s.Seed = seed - s.KeyPair = crypto.Ed25519GenerateFromSeed(seed) + s.seed = seed + s.keyPair = crypto.Ed25519GenerateFromSeed(seed) return s, nil } @@ -32,13 +36,34 @@ func NewSigning() (*Signing, error) { return NewSigningFromSeed(seed) } -// Sign returns the signature of the message base64 encoded. -func (s Signing) Sign(message []byte) []byte { - signature := s.KeyPair.Sign(message) - return goolm.Base64Encode(signature) +// Seed returns the seed of the key pair. +func (s Signing) Seed() []byte { + return s.seed } // PublicKey returns the public key of the key pair base 64 encoded. func (s Signing) PublicKey() id.Ed25519 { - return s.KeyPair.B64Encoded() + return s.keyPair.B64Encoded() +} + +// Sign returns the signature of the message base64 encoded. +func (s Signing) Sign(message []byte) ([]byte, error) { + signature := s.keyPair.Sign(message) + return goolm.Base64Encode(signature), nil +} + +// SignJSON creates a signature for the given object after encoding it to +// canonical JSON. +func (s Signing) SignJSON(obj any) (string, error) { + objJSON, err := json.Marshal(obj) + if err != nil { + return "", err + } + objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") + objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") + signature, err := s.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON)) + if err != nil { + return "", err + } + return string(signature), nil } diff --git a/crypto/olm/pk_goolm.go b/crypto/olm/pk_goolm.go index 9659e918..372c94fa 100644 --- a/crypto/olm/pk_goolm.go +++ b/crypto/olm/pk_goolm.go @@ -1,71 +1,29 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// When the goolm build flag is enabled, this file will make [PKSigning] +// constructors use the goolm constuctors. + //go:build goolm package olm -import ( - "encoding/json" +import "maunium.net/go/mautrix/crypto/goolm/pk" - "github.com/tidwall/sjson" - - "maunium.net/go/mautrix/crypto/canonicaljson" - "maunium.net/go/mautrix/crypto/goolm/pk" - "maunium.net/go/mautrix/id" -) - -// PkSigning stores a key pair for signing messages. -type PkSigning struct { - pk.Signing - PublicKey id.Ed25519 - Seed []byte +// NewPKSigningFromSeed creates a new PKSigning object using the given seed. +func NewPKSigningFromSeed(seed []byte) (PKSigning, error) { + return pk.NewSigningFromSeed(seed) } -// Clear clears the underlying memory of a PkSigning object. -func (p *PkSigning) Clear() { - p.Signing = pk.Signing{} +// NewPKSigning creates a new [PKSigning] object, containing a key pair for +// signing messages. +func NewPKSigning() (PKSigning, error) { + return pk.NewSigning() } -// NewPkSigningFromSeed creates a new PkSigning object using the given seed. -func NewPkSigningFromSeed(seed []byte) (*PkSigning, error) { - p := &PkSigning{} - signing, err := pk.NewSigningFromSeed(seed) - if err != nil { - return nil, err - } - p.Signing = *signing - p.Seed = seed - p.PublicKey = p.Signing.PublicKey() - return p, nil -} - -// NewPkSigning creates a new PkSigning object, containing a key pair for signing messages. -func NewPkSigning() (*PkSigning, error) { - p := &PkSigning{} - signing, err := pk.NewSigning() - if err != nil { - return nil, err - } - p.Signing = *signing - p.Seed = signing.Seed - p.PublicKey = p.Signing.PublicKey() - return p, err -} - -// Sign creates a signature for the given message using this key. -func (p *PkSigning) Sign(message []byte) ([]byte, error) { - return p.Signing.Sign(message), nil -} - -// SignJSON creates a signature for the given object after encoding it to canonical JSON. -func (p *PkSigning) SignJSON(obj interface{}) (string, error) { - objJSON, err := json.Marshal(obj) - if err != nil { - return "", err - } - objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") - objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") - signature, err := p.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON)) - if err != nil { - return "", err - } - return string(signature), nil +func NewPKDecryption(privateKey []byte) (PKDecryption, error) { + return pk.NewDecryption() } diff --git a/crypto/olm/pk_interface.go b/crypto/olm/pk_interface.go new file mode 100644 index 00000000..11c41431 --- /dev/null +++ b/crypto/olm/pk_interface.go @@ -0,0 +1,41 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package olm + +import ( + "maunium.net/go/mautrix/crypto/goolm/pk" + "maunium.net/go/mautrix/id" +) + +// PKSigning is an interface for signing messages. +type PKSigning interface { + // Seed returns the seed of the key. + Seed() []byte + + // PublicKey returns the public key. + PublicKey() id.Ed25519 + + // Sign creates a signature for the given message using this key. + Sign(message []byte) ([]byte, error) + + // SignJSON creates a signature for the given object after encoding it to + // canonical JSON. + SignJSON(obj any) (string, error) +} + +var _ PKSigning = (*pk.Signing)(nil) + +// PKDecryption is an interface for decrypting messages. +type PKDecryption interface { + // PublicKey returns the public key. + PublicKey() id.Curve25519 + + // Decrypt verifies and decrypts the given message. + Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error) +} + +var _ PKDecryption = (*pk.Decryption)(nil) diff --git a/crypto/olm/pk.go b/crypto/olm/pk_libolm.go similarity index 67% rename from crypto/olm/pk.go rename to crypto/olm/pk_libolm.go index ba390afe..0854b4d1 100644 --- a/crypto/olm/pk.go +++ b/crypto/olm/pk_libolm.go @@ -1,3 +1,9 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + //go:build !goolm package olm @@ -18,14 +24,17 @@ import ( "maunium.net/go/mautrix/id" ) -// PkSigning stores a key pair for signing messages. -type PkSigning struct { +// LibOlmPKSigning stores a key pair for signing messages. +type LibOlmPKSigning struct { int *C.OlmPkSigning mem []byte - PublicKey id.Ed25519 - Seed []byte + publicKey id.Ed25519 + seed []byte } +// Ensure that LibOlmPKSigning implements PKSigning. +var _ PKSigning = (*LibOlmPKSigning)(nil) + func pkSigningSize() uint { return uint(C.olm_pk_signing_size()) } @@ -42,48 +51,57 @@ func pkSigningSignatureLength() uint { return uint(C.olm_pk_signature_length()) } -func NewBlankPkSigning() *PkSigning { +func newBlankPKSigning() *LibOlmPKSigning { memory := make([]byte, pkSigningSize()) - return &PkSigning{ + return &LibOlmPKSigning{ int: C.olm_pk_signing(unsafe.Pointer(&memory[0])), mem: memory, } } -// Clear clears the underlying memory of a PkSigning object. -func (p *PkSigning) Clear() { - C.olm_clear_pk_signing((*C.OlmPkSigning)(p.int)) -} - -// NewPkSigningFromSeed creates a new PkSigning object using the given seed. -func NewPkSigningFromSeed(seed []byte) (*PkSigning, error) { - p := NewBlankPkSigning() - p.Clear() +// NewPKSigningFromSeed creates a new [PKSigning] object using the given seed. +func NewPKSigningFromSeed(seed []byte) (PKSigning, error) { + p := newBlankPKSigning() + p.clear() pubKey := make([]byte, pkSigningPublicKeyLength()) if C.olm_pk_signing_key_from_seed((*C.OlmPkSigning)(p.int), unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)), unsafe.Pointer(&seed[0]), C.size_t(len(seed))) == errorVal() { return nil, p.lastError() } - p.PublicKey = id.Ed25519(pubKey) - p.Seed = seed + p.publicKey = id.Ed25519(pubKey) + p.seed = seed return p, nil } -// NewPkSigning creates a new PkSigning object, containing a key pair for signing messages. -func NewPkSigning() (*PkSigning, error) { +// NewPKSigning creates a new LibOlmPKSigning object, containing a key pair for +// signing messages. +func NewPKSigning() (PKSigning, error) { // Generate the seed seed := make([]byte, pkSigningSeedLength()) _, err := rand.Read(seed) if err != nil { panic(NotEnoughGoRandom) } - pk, err := NewPkSigningFromSeed(seed) + pk, err := NewPKSigningFromSeed(seed) return pk, err } +func (p *LibOlmPKSigning) PublicKey() id.Ed25519 { + return p.publicKey +} + +func (p *LibOlmPKSigning) Seed() []byte { + return p.seed +} + +// clear clears the underlying memory of a LibOlmPKSigning object. +func (p *LibOlmPKSigning) clear() { + C.olm_clear_pk_signing((*C.OlmPkSigning)(p.int)) +} + // Sign creates a signature for the given message using this key. -func (p *PkSigning) Sign(message []byte) ([]byte, error) { +func (p *LibOlmPKSigning) Sign(message []byte) ([]byte, error) { signature := make([]byte, pkSigningSignatureLength()) if C.olm_pk_sign((*C.OlmPkSigning)(p.int), (*C.uint8_t)(unsafe.Pointer(&message[0])), C.size_t(len(message)), (*C.uint8_t)(unsafe.Pointer(&signature[0])), C.size_t(len(signature))) == errorVal() { @@ -93,7 +111,7 @@ func (p *PkSigning) Sign(message []byte) ([]byte, error) { } // SignJSON creates a signature for the given object after encoding it to canonical JSON. -func (p *PkSigning) SignJSON(obj interface{}) (string, error) { +func (p *LibOlmPKSigning) SignJSON(obj interface{}) (string, error) { objJSON, err := json.Marshal(obj) if err != nil { return "", err @@ -107,12 +125,13 @@ func (p *PkSigning) SignJSON(obj interface{}) (string, error) { return string(signature), nil } -// lastError returns the last error that happened in relation to this PkSigning object. -func (p *PkSigning) lastError() error { +// lastError returns the last error that happened in relation to this +// LibOlmPKSigning object. +func (p *LibOlmPKSigning) lastError() error { return convertError(C.GoString(C.olm_pk_signing_last_error((*C.OlmPkSigning)(p.int)))) } -type PkDecryption struct { +type LibOlmPKDecryption struct { int *C.OlmPkDecryption mem []byte PublicKey []byte @@ -126,13 +145,13 @@ func pkDecryptionPublicKeySize() uint { return uint(C.olm_pk_key_length()) } -func NewPkDecryption(privateKey []byte) (*PkDecryption, error) { +func NewPkDecryption(privateKey []byte) (*LibOlmPKDecryption, error) { memory := make([]byte, pkDecryptionSize()) - p := &PkDecryption{ + p := &LibOlmPKDecryption{ int: C.olm_pk_decryption(unsafe.Pointer(&memory[0])), mem: memory, } - p.Clear() + p.clear() pubKey := make([]byte, pkDecryptionPublicKeySize()) if C.olm_pk_key_from_private((*C.OlmPkDecryption)(p.int), @@ -145,7 +164,7 @@ func NewPkDecryption(privateKey []byte) (*PkDecryption, error) { return p, nil } -func (p *PkDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) { +func (p *LibOlmPKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) { maxPlaintextLength := uint(C.olm_pk_max_plaintext_length((*C.OlmPkDecryption)(p.int), C.size_t(len(ciphertext)))) plaintext := make([]byte, maxPlaintextLength) @@ -162,11 +181,12 @@ func (p *PkDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byt } // Clear clears the underlying memory of a PkDecryption object. -func (p *PkDecryption) Clear() { +func (p *LibOlmPKDecryption) clear() { C.olm_clear_pk_decryption((*C.OlmPkDecryption)(p.int)) } -// lastError returns the last error that happened in relation to this PkDecryption object. -func (p *PkDecryption) lastError() error { +// lastError returns the last error that happened in relation to this +// LibOlmPKDecryption object. +func (p *LibOlmPKDecryption) lastError() error { return convertError(C.GoString(C.olm_pk_decryption_last_error((*C.OlmPkDecryption)(p.int)))) } From 2728a8f8aa5dbbe69ddf87105a279337110df0e2 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 8 Mar 2024 15:16:41 -0700 Subject: [PATCH 0160/1647] olm/pk: add fuzz test for the Sign function Signed-off-by: Sumner Evans --- crypto/olm/pk_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 crypto/olm/pk_test.go diff --git a/crypto/olm/pk_test.go b/crypto/olm/pk_test.go new file mode 100644 index 00000000..b57e6571 --- /dev/null +++ b/crypto/olm/pk_test.go @@ -0,0 +1,45 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Only run this test if goo is disabled (that is, libolm is used). +//go:build !goolm + +package olm_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/crypto/goolm/pk" + "maunium.net/go/mautrix/crypto/olm" +) + +func FuzzSign(f *testing.F) { + seed := []byte("Quohboh3ka3ooghequier9lee8Bahwoh") + goolmPkSigning, err := pk.NewSigningFromSeed(seed) + require.NoError(f, err) + + libolmPkSigning, err := olm.NewPKSigningFromSeed(seed) + require.NoError(f, err) + + f.Add([]byte("message")) + + f.Fuzz(func(t *testing.T, message []byte) { + // libolm breaks with empty messages, so don't perform differential + // fuzzing on that. + if len(message) == 0 { + return + } + + libolmResult, libolmErr := libolmPkSigning.Sign(message) + goolmResult, goolmErr := goolmPkSigning.Sign(message) + + assert.Equal(t, goolmErr, libolmErr) + assert.Equal(t, goolmResult, libolmResult) + }) +} From 94246ffc85aa16507d5b26824c41f59fba474678 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 11 Mar 2024 20:35:50 +0200 Subject: [PATCH 0161/1647] Drop maulogger support --- CHANGELOG.md | 2 ++ appservice/appservice.go | 5 +---- bridge/bridge.go | 5 ----- bridge/commands/event.go | 3 --- bridge/commands/processor.go | 2 -- bridge/websocket.go | 2 +- client.go | 20 ++------------------ go.mod | 1 - go.sum | 2 -- 9 files changed, 6 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b97af05..821c8500 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ ## v0.18.0 (unreleased) +* **Breaking change *(client, bridge, appservice)*** Dropped support for + maulogger. Only zerolog loggers are now provided by default. * *(bridge)* Fixed upload size limit not having a default if the server returned no value. * *(synapseadmin)* Added wrappers for some room and user admin APIs. diff --git a/appservice/appservice.go b/appservice/appservice.go index b64a84a1..ef9c6236 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -24,7 +24,6 @@ import ( "github.com/rs/zerolog" "golang.org/x/net/publicsuffix" "gopkg.in/yaml.v3" - "maunium.net/go/maulogger/v2/maulogadapt" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" @@ -355,7 +354,7 @@ func (as *AppService) SetHomeserverURL(homeserverURL string) error { // This does not do any validation, and it does not cache the client. // Usually you should prefer [AppService.Client] or [AppService.Intent] over this method. func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client { - client := &mautrix.Client{ + return &mautrix.Client{ HomeserverURL: as.hsURLForClient, UserID: userID, SetAppServiceUserID: true, @@ -366,8 +365,6 @@ func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client { Client: as.HTTPClient, DefaultHTTPRetries: as.DefaultHTTPRetries, } - client.Logger = maulogadapt.ZeroAsMau(&client.Log) - return client } // NewExternalMautrixClient creates a new [mautrix.Client] instance for an external user, diff --git a/bridge/bridge.go b/bridge/bridge.go index 4a7ba465..cfc31044 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -29,8 +29,6 @@ import ( "go.mau.fi/util/exzerolog" "gopkg.in/yaml.v3" flag "maunium.net/go/mauflag" - "maunium.net/go/maulogger/v2" - "maunium.net/go/maulogger/v2/maulogadapt" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" @@ -201,8 +199,6 @@ type Bridge struct { Crypto Crypto CryptoPickleKey string - // Deprecated: Switch to ZLog - Log maulogger.Logger ZLog *zerolog.Logger MediaConfig mautrix.RespMediaConfig @@ -536,7 +532,6 @@ func (br *Bridge) init() { os.Exit(12) } exzerolog.SetupDefaults(br.ZLog) - br.Log = maulogadapt.ZeroAsMau(br.ZLog) br.DoublePuppet = &doublePuppetUtil{br: br, log: br.ZLog.With().Str("component", "double puppet").Logger()} diff --git a/bridge/commands/event.go b/bridge/commands/event.go index f1443d63..42b49b68 100644 --- a/bridge/commands/event.go +++ b/bridge/commands/event.go @@ -12,7 +12,6 @@ import ( "strings" "github.com/rs/zerolog" - "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" @@ -38,8 +37,6 @@ type Event struct { ReplyTo id.EventID Ctx context.Context ZLog *zerolog.Logger - // Deprecated: switch to ZLog - Log maulogger.Logger } // MainIntent returns the intent to use when replying to the command. diff --git a/bridge/commands/processor.go b/bridge/commands/processor.go index 70dd16e9..6158a7cd 100644 --- a/bridge/commands/processor.go +++ b/bridge/commands/processor.go @@ -12,7 +12,6 @@ import ( "strings" "github.com/rs/zerolog" - "maunium.net/go/maulogger/v2/maulogadapt" "maunium.net/go/mautrix/bridge" "maunium.net/go/mautrix/id" @@ -91,7 +90,6 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id. ReplyTo: replyTo, Ctx: ctx, ZLog: &log, - Log: maulogadapt.ZeroAsMau(&log), } log.Debug().Msg("Received command") diff --git a/bridge/websocket.go b/bridge/websocket.go index 6b3d9abb..44a3d8d8 100644 --- a/bridge/websocket.go +++ b/bridge/websocket.go @@ -119,7 +119,7 @@ func (br *Bridge) PingServer() (start, serverTs, end time.Time) { } start = time.Now() var resp wsPingData - br.Log.Debugln("Pinging appservice websocket") + br.ZLog.Debug().Msg("Pinging appservice websocket") ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() err := br.AS.RequestWebsocket(ctx, &appservice.WebsocketRequest{ diff --git a/client.go b/client.go index 0015aede..ad68e793 100644 --- a/client.go +++ b/client.go @@ -19,7 +19,6 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/retryafter" - "maunium.net/go/maulogger/v2/maulogadapt" "maunium.net/go/mautrix/crypto/backup" "maunium.net/go/mautrix/event" @@ -49,17 +48,6 @@ type VerificationHelper interface { ConfirmSAS(ctx context.Context, txnID id.VerificationTransactionID) error } -// Deprecated: switch to zerolog -type Logger interface { - Debugfln(message string, args ...interface{}) -} - -// Deprecated: switch to zerolog -type WarnLogger interface { - Logger - Warnfln(message string, args ...interface{}) -} - // Client represents a Matrix client. type Client struct { HomeserverURL *url.URL // The base homeserver URL @@ -75,8 +63,6 @@ type Client struct { Verification VerificationHelper Log zerolog.Logger - // Deprecated: switch to the zerolog instance in Log - Logger Logger RequestHook func(req *http.Request) ResponseHook func(req *http.Request, resp *http.Response, duration time.Duration) @@ -2295,7 +2281,7 @@ func NewClient(homeserverURL string, userID id.UserID, accessToken string) (*Cli if err != nil { return nil, err } - cli := &Client{ + return &Client{ AccessToken: accessToken, UserAgent: DefaultUserAgent, HomeserverURL: hsURL, @@ -2307,7 +2293,5 @@ func NewClient(homeserverURL string, userID id.UserID, accessToken string) (*Cli // The client will work with this storer: it just won't remember across restarts. // In practice, a database backend should be used. Store: NewMemorySyncStore(), - } - cli.Logger = maulogadapt.ZeroAsMau(&cli.Log) - return cli, nil + }, nil } diff --git a/go.mod b/go.mod index 8cbc974c..425fcae5 100644 --- a/go.mod +++ b/go.mod @@ -19,7 +19,6 @@ require ( golang.org/x/net v0.22.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 - maunium.net/go/maulogger/v2 v2.4.1 ) require ( diff --git a/go.sum b/go.sum index 26e8e0a5..42f5e3e3 100644 --- a/go.sum +++ b/go.sum @@ -60,5 +60,3 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= -maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= -maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= From 08397c8b9ad29f31fa4c773ebc108c354e6b1170 Mon Sep 17 00:00:00 2001 From: Brad Murray Date: Mon, 11 Mar 2024 18:50:06 -0400 Subject: [PATCH 0162/1647] Fix responding to m.secret.request messages (#195) --- crypto/sharing.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/crypto/sharing.go b/crypto/sharing.go index cf14499f..18088b8e 100644 --- a/crypto/sharing.go +++ b/crypto/sharing.go @@ -121,9 +121,11 @@ func (mach *OlmMachine) HandleSecretRequest(ctx context.Context, userID id.UserI return } else if secret != "" { log.Debug().Msg("Responding to secret request") - mach.sendToOneDevice(ctx, mach.Client.UserID, content.RequestingDeviceID, event.ToDeviceSecretRequest, &event.SecretSendEventContent{ - RequestID: content.RequestID, - Secret: secret, + mach.SendEncryptedToDevice(ctx, device, event.ToDeviceSecretSend, event.Content{ + Parsed: event.SecretSendEventContent{ + RequestID: content.RequestID, + Secret: secret, + }, }) } else { log.Debug().Msg("No stored secret found, secret request ignored") From 8128b00e0082526b0365ca088f2adc2bb72da82a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 12 Mar 2024 21:15:39 +0200 Subject: [PATCH 0163/1647] Add key server that passes the federation tester (#197) --- federation/keyserver.go | 203 +++++++++++++++++++++++++++++++++++++++ federation/signingkey.go | 123 ++++++++++++++++++++++++ 2 files changed, 326 insertions(+) create mode 100644 federation/keyserver.go create mode 100644 federation/signingkey.go diff --git a/federation/keyserver.go b/federation/keyserver.go new file mode 100644 index 00000000..3e74bfdf --- /dev/null +++ b/federation/keyserver.go @@ -0,0 +1,203 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "encoding/json" + "fmt" + "net/http" + "strconv" + "time" + + "github.com/gorilla/mux" + "go.mau.fi/util/jsontime" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/id" +) + +type ServerVersion struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// ServerKeyProvider is an interface that returns private server keys for server key requests. +type ServerKeyProvider interface { + Get(r *http.Request) (serverName string, key *SigningKey) +} + +// StaticServerKey is an implementation of [ServerKeyProvider] that always returns the same server name and key. +type StaticServerKey struct { + ServerName string + Key *SigningKey +} + +func (ssk *StaticServerKey) Get(r *http.Request) (serverName string, key *SigningKey) { + return ssk.ServerName, ssk.Key +} + +// KeyServer implements a basic Matrix key server that can serve its own keys, plus the federation version endpoint. +// +// It does not implement querying keys of other servers, nor any other federation endpoints. +type KeyServer struct { + KeyProvider ServerKeyProvider + Version ServerVersion + WellKnownTarget string +} + +// Register registers the key server endpoints to the given router. +func (ks *KeyServer) Register(r *mux.Router) { + r.HandleFunc("/.well-known/matrix/server", ks.GetWellKnown).Methods(http.MethodGet) + r.HandleFunc("/_matrix/federation/v1/version", ks.GetServerVersion).Methods(http.MethodGet) + keyRouter := r.PathPrefix("/_matrix/key").Subrouter() + keyRouter.HandleFunc("/v2/server", ks.GetServerKey).Methods(http.MethodGet) + keyRouter.HandleFunc("/v2/query/{serverName}", ks.GetQueryKeys).Methods(http.MethodGet) + keyRouter.HandleFunc("/v2/query", ks.PostQueryKeys).Methods(http.MethodPost) + keyRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "Unrecognized endpoint", + }) + }) + keyRouter.MethodNotAllowedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "Invalid method for endpoint", + }) + }) +} + +func jsonResponse(w http.ResponseWriter, code int, data any) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(code) + _ = json.NewEncoder(w).Encode(data) +} + +// RespWellKnown is the response body for the `GET /.well-known/matrix/server` endpoint. +type RespWellKnown struct { + Server string `json:"m.server"` +} + +// GetWellKnown implements the `GET /.well-known/matrix/server` endpoint +// +// https://spec.matrix.org/v1.9/server-server-api/#get_well-knownmatrixserver +func (ks *KeyServer) GetWellKnown(w http.ResponseWriter, r *http.Request) { + if ks.WellKnownTarget == "" { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MNotFound.ErrCode, + Err: "No well-known target set", + }) + } else { + jsonResponse(w, http.StatusOK, &RespWellKnown{Server: ks.WellKnownTarget}) + } +} + +// RespServerVersion is the response body for the `GET /_matrix/federation/v1/version` endpoint +type RespServerVersion struct { + Server ServerVersion `json:"server"` +} + +// GetServerVersion implements the `GET /_matrix/federation/v1/version` endpoint +// +// https://spec.matrix.org/v1.9/server-server-api/#get_matrixfederationv1version +func (ks *KeyServer) GetServerVersion(w http.ResponseWriter, r *http.Request) { + jsonResponse(w, http.StatusOK, &RespServerVersion{Server: ks.Version}) +} + +// GetServerKey implements the `GET /_matrix/key/v2/server` endpoint. +// +// https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2server +func (ks *KeyServer) GetServerKey(w http.ResponseWriter, r *http.Request) { + domain, key := ks.KeyProvider.Get(r) + if key == nil { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MNotFound.ErrCode, + Err: fmt.Sprintf("No signing key found for %q", r.Host), + }) + } else { + jsonResponse(w, http.StatusOK, key.GenerateKeyResponse(domain, nil)) + } +} + +// ReqQueryKeys is the request body for the `POST /_matrix/key/v2/query` endpoint +type ReqQueryKeys struct { + ServerKeys map[string]map[id.KeyID]QueryKeysCriteria `json:"server_keys"` +} + +type QueryKeysCriteria struct { + MinimumValidUntilTS jsontime.UnixMilli `json:"minimum_valid_until_ts"` +} + +// PostQueryKeysResponse is the response body for the `POST /_matrix/key/v2/query` endpoint +type PostQueryKeysResponse struct { + ServerKeys map[string]*ServerKeyResponse `json:"server_keys"` +} + +// PostQueryKeys implements the `POST /_matrix/key/v2/query` endpoint +// +// https://spec.matrix.org/v1.9/server-server-api/#post_matrixkeyv2query +func (ks *KeyServer) PostQueryKeys(w http.ResponseWriter, r *http.Request) { + var req ReqQueryKeys + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + ErrCode: mautrix.MBadJSON.ErrCode, + Err: fmt.Sprintf("failed to parse request: %v", err), + }) + return + } + + resp := &PostQueryKeysResponse{ + ServerKeys: make(map[string]*ServerKeyResponse), + } + for serverName, keys := range req.ServerKeys { + domain, key := ks.KeyProvider.Get(r) + if domain != serverName { + continue + } + for keyID, criteria := range keys { + if key.ID == keyID && criteria.MinimumValidUntilTS.Before(time.Now().Add(24*time.Hour)) { + resp.ServerKeys[serverName] = key.GenerateKeyResponse(serverName, nil) + } + } + } + jsonResponse(w, http.StatusOK, resp) +} + +// GetQueryKeysResponse is the response body for the `GET /_matrix/key/v2/query/{serverName}` endpoint +type GetQueryKeysResponse struct { + ServerKeys []*ServerKeyResponse `json:"server_keys"` +} + +// GetQueryKeys implements the `GET /_matrix/key/v2/query/{serverName}` endpoint +// +// https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2queryservername +func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) { + serverName := mux.Vars(r)["serverName"] + minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts") + minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64) + if err != nil && minimumValidUntilTSString != "" { + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + ErrCode: mautrix.MInvalidParam.ErrCode, + Err: fmt.Sprintf("failed to parse ?minimum_valid_until_ts: %v", err), + }) + return + } else if time.UnixMilli(minimumValidUntilTS).After(time.Now().Add(24 * time.Hour)) { + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + ErrCode: mautrix.MInvalidParam.ErrCode, + Err: "minimum_valid_until_ts may not be more than 24 hours in the future", + }) + return + } + resp := &GetQueryKeysResponse{ + ServerKeys: []*ServerKeyResponse{}, + } + if domain, key := ks.KeyProvider.Get(r); key != nil && domain == serverName { + resp.ServerKeys = append(resp.ServerKeys, key.GenerateKeyResponse(serverName, nil)) + } + jsonResponse(w, http.StatusOK, resp) +} diff --git a/federation/signingkey.go b/federation/signingkey.go new file mode 100644 index 00000000..3d118233 --- /dev/null +++ b/federation/signingkey.go @@ -0,0 +1,123 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "crypto/ed25519" + "encoding/base64" + "encoding/json" + "fmt" + "strings" + "time" + + "go.mau.fi/util/jsontime" + + "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/id" +) + +// SigningKey is a Matrix federation signing key pair. +type SigningKey struct { + ID id.KeyID + Pub id.SigningKey + Priv ed25519.PrivateKey +} + +// SynapseString returns a string representation of the private key compatible with Synapse's .signing.key file format. +// +// The output of this function can be parsed back into a [SigningKey] using the [ParseSynapseKey] function. +func (sk *SigningKey) SynapseString() string { + alg, id := sk.ID.Parse() + return fmt.Sprintf("%s %s %s", alg, id, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed())) +} + +// ParseSynapseKey parses a Synapse-compatible private key string into a SigningKey. +func ParseSynapseKey(key string) (*SigningKey, error) { + parts := strings.Split(key, " ") + if len(parts) != 3 { + return nil, fmt.Errorf("invalid key format (expected 3 space-separated parts, got %d)", len(parts)) + } else if parts[0] != string(id.KeyAlgorithmEd25519) { + return nil, fmt.Errorf("unsupported key algorithm %s (only ed25519 is supported)", parts[0]) + } + seed, err := base64.RawStdEncoding.DecodeString(parts[2]) + if err != nil { + return nil, fmt.Errorf("invalid private key: %w", err) + } + priv := ed25519.NewKeyFromSeed(seed) + pub := base64.RawStdEncoding.EncodeToString(priv.Public().(ed25519.PublicKey)) + return &SigningKey{ + ID: id.NewKeyID(id.KeyAlgorithmEd25519, parts[1]), + Pub: id.SigningKey(pub), + Priv: priv, + }, nil +} + +// GenerateSigningKey generates a new random signing key. +func GenerateSigningKey() *SigningKey { + pub, priv, err := ed25519.GenerateKey(nil) + if err != nil { + panic(err) + } + return &SigningKey{ + ID: id.NewKeyID(id.KeyAlgorithmEd25519, base64.RawURLEncoding.EncodeToString(pub[:4])), + Pub: id.SigningKey(base64.RawStdEncoding.EncodeToString(pub)), + Priv: priv, + } +} + +// ServerKeyResponse is the response body for the `GET /_matrix/key/v2/server` endpoint. +// It's also used inside the query endpoint response structs. +type ServerKeyResponse struct { + ServerName string `json:"server_name"` + VerifyKeys map[id.KeyID]ServerVerifyKey `json:"verify_keys"` + OldVerifyKeys map[id.KeyID]OldVerifyKey `json:"old_verify_keys,omitempty"` + Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"` + ValidUntilTS jsontime.UnixMilli `json:"valid_until_ts"` +} + +type ServerVerifyKey struct { + Key id.SigningKey `json:"key"` +} + +type OldVerifyKey struct { + Key id.SigningKey `json:"key"` + ExpiredTS jsontime.UnixMilli `json:"expired_ts"` +} + +func (sk *SigningKey) SignJSON(data any) ([]byte, error) { + marshaled, err := json.Marshal(data) + if err != nil { + return nil, err + } + return sk.SignRawJSON(marshaled), nil +} + +func (sk *SigningKey) SignRawJSON(data json.RawMessage) []byte { + return ed25519.Sign(sk.Priv, canonicaljson.CanonicalJSONAssumeValid(data)) +} + +// GenerateKeyResponse generates a key response signed by this key with the given server name and optionally some old verify keys. +func (sk *SigningKey) GenerateKeyResponse(serverName string, oldVerifyKeys map[id.KeyID]OldVerifyKey) *ServerKeyResponse { + skr := &ServerKeyResponse{ + ServerName: serverName, + OldVerifyKeys: oldVerifyKeys, + ValidUntilTS: jsontime.UM(time.Now().Add(24 * time.Hour)), + VerifyKeys: map[id.KeyID]ServerVerifyKey{ + sk.ID: {Key: sk.Pub}, + }, + } + signature, err := sk.SignJSON(skr) + if err != nil { + panic(err) + } + skr.Signatures = map[string]map[id.KeyID]string{ + serverName: { + sk.ID: base64.RawURLEncoding.EncodeToString(signature), + }, + } + return skr +} From f0b728f502867952e61081933d7aa70d09210c7e Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Wed, 13 Mar 2024 11:06:52 +0200 Subject: [PATCH 0164/1647] Require OGS update to succeed during EncryptMegolmEvent Otherwise we could end up reusing the same ratchet multiple times. --- crypto/encryptmegolm.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 62bcc044..d592bd1c 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -118,7 +118,7 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID log.Debug().Msg("Encrypted event successfully") err = mach.CryptoStore.UpdateOutboundGroupSession(ctx, session) if err != nil { - log.Warn().Err(err).Msg("Failed to update megolm session in crypto store after encrypting") + return nil, fmt.Errorf("failed to update outbound group session after encrypting: %w", err) } encrypted := &event.EncryptedEventContent{ Algorithm: id.AlgorithmMegolmV1, From 5224780563b27a98d2d62947f8f8062006193d9f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 13 Mar 2024 16:57:01 +0200 Subject: [PATCH 0165/1647] Split UserID.Parse into generic ParseCommonIdentifier --- id/userid.go | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/id/userid.go b/id/userid.go index 3aae3b21..f1b6e27f 100644 --- a/id/userid.go +++ b/id/userid.go @@ -34,21 +34,37 @@ var ( ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed") ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters") ErrEmptyLocalpart = errors.New("empty localparts are not allowed") + ErrEmptyIdentifier = errors.New("identifier is empty") ) +// ParseCommonIdentifier parses a common identifier according to https://spec.matrix.org/v1.9/appendices/#common-identifier-format +func ParseCommonIdentifier[Stringish ~string](identifier Stringish) (sigil byte, localpart, homeserver string, err error) { + if len(identifier) == 0 { + return 0, "", "", ErrEmptyIdentifier + } + sigil = identifier[0] + strIdentifier := string(identifier) + if strings.ContainsRune(strIdentifier, ':') { + parts := strings.SplitN(strIdentifier, ":", 2) + localpart = parts[0][1:] + homeserver = parts[1] + } else { + localpart = strIdentifier[1:] + } + return +} + // Parse parses the user ID into the localpart and server name. // // Note that this only enforces very basic user ID formatting requirements: user IDs start with // a @, and contain a : after the @. If you want to enforce localpart validity, see the // ParseAndValidate and ValidateUserLocalpart functions. func (userID UserID) Parse() (localpart, homeserver string, err error) { - if len(userID) == 0 || userID[0] != '@' || !strings.ContainsRune(string(userID), ':') { - // This error wrapping lets you use errors.Is() nicely even though the message contains the user ID + var sigil byte + sigil, localpart, homeserver, err = ParseCommonIdentifier(userID) + if err != nil || sigil != '@' || homeserver == "" { err = fmt.Errorf("'%s' %w", userID, ErrInvalidUserID) - return } - parts := strings.SplitN(string(userID), ":", 2) - localpart, homeserver = strings.TrimPrefix(parts[0], "@"), parts[1] return } From 20fde3d163133d31d70df91bc094ce9a7ccce546 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 13 Mar 2024 17:01:07 +0200 Subject: [PATCH 0166/1647] Remove error in ParseCommonIdentifier --- id/userid.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/id/userid.go b/id/userid.go index f1b6e27f..53b68b96 100644 --- a/id/userid.go +++ b/id/userid.go @@ -34,13 +34,12 @@ var ( ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed") ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters") ErrEmptyLocalpart = errors.New("empty localparts are not allowed") - ErrEmptyIdentifier = errors.New("identifier is empty") ) // ParseCommonIdentifier parses a common identifier according to https://spec.matrix.org/v1.9/appendices/#common-identifier-format -func ParseCommonIdentifier[Stringish ~string](identifier Stringish) (sigil byte, localpart, homeserver string, err error) { +func ParseCommonIdentifier[Stringish ~string](identifier Stringish) (sigil byte, localpart, homeserver string) { if len(identifier) == 0 { - return 0, "", "", ErrEmptyIdentifier + return } sigil = identifier[0] strIdentifier := string(identifier) @@ -61,8 +60,8 @@ func ParseCommonIdentifier[Stringish ~string](identifier Stringish) (sigil byte, // ParseAndValidate and ValidateUserLocalpart functions. func (userID UserID) Parse() (localpart, homeserver string, err error) { var sigil byte - sigil, localpart, homeserver, err = ParseCommonIdentifier(userID) - if err != nil || sigil != '@' || homeserver == "" { + sigil, localpart, homeserver = ParseCommonIdentifier(userID) + if sigil != '@' || homeserver == "" { err = fmt.Errorf("'%s' %w", userID, ErrInvalidUserID) } return From a7bf4858932b6296b69533fa94330c06402b123d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 13 Mar 2024 21:23:04 +0200 Subject: [PATCH 0167/1647] Update changelog --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 821c8500..3614632e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,9 +8,16 @@ (thanks to [@grvn-ht] in [#181]). * *(crypto/verificationhelper)* Fixed bugs. * *(crypto)* Fixed key backup uploading doing too much base64. +* *(crypto)* Changed `EncryptMegolmEvent` to return an error if persisting the + megolm session fails. This ensures that database errors won't cause messages + to be sent with duplicate indexes. +* *(id)* Added `ParseCommonIdentifier` function to parse any Matrix identifier + in the [Common Identifier Format]. +* *(federation)* Added simple key server that passes the federation tester. [@grvn-ht]: https://github.com/grvn-ht [#181]: https://github.com/mautrix/go/pull/181 +[Common Identifier Format]: https://spec.matrix.org/v1.9/appendices/#common-identifier-format ### beta.1 (2024-02-16) From fad4448ab7ef4abdb002be01c9e5ee4fa5e16362 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Tue, 12 Mar 2024 13:00:55 +0200 Subject: [PATCH 0168/1647] Use a callback to receive secret response To properly receive and store a requested secret, we usually need to validate it against something like a public key to ensure we got the correct one. This changes the API so that we instead use a callback to receive any incoming secret matching our request but we'll fail when we hit the specified timeout if we never receive anything that is accepted. --- crypto/sharing.go | 62 +++++++++++++++++++++++++++++------------------ 1 file changed, 39 insertions(+), 23 deletions(-) diff --git a/crypto/sharing.go b/crypto/sharing.go index 18088b8e..c0f3e209 100644 --- a/crypto/sharing.go +++ b/crypto/sharing.go @@ -16,13 +16,26 @@ import ( "maunium.net/go/mautrix/id" ) -func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, timeout time.Duration) (secret string, err error) { - secret, err = mach.CryptoStore.GetSecret(ctx, name) - if err != nil || secret != "" { - return +// Callback function to process a received secret. +// +// Returning true or an error will immediately return from the wait loop, returning false will continue waiting for new responses. +type SecretReceiverFunc func(string) (bool, error) + +func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, receiver SecretReceiverFunc, timeout time.Duration) (err error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + // always offer our stored secret first, if any + secret, err := mach.CryptoStore.GetSecret(ctx, name) + if err != nil { + return err + } else if secret != "" { + if ok, err := receiver(secret); ok || err != nil { + return err + } } - requestID, secretChan := random.String(64), make(chan string, 1) + requestID, secretChan := random.String(64), make(chan string, 5) mach.secretLock.Lock() mach.secretListeners[requestID] = secretChan mach.secretLock.Unlock() @@ -43,17 +56,27 @@ func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, return } - select { - case <-ctx.Done(): - err = ctx.Err() - case <-time.After(timeout): - case secret = <-secretChan: - } + // best effort cancel request from all devices when returning + defer func() { + go mach.sendToOneDevice(context.Background(), mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{ + Action: event.SecretRequestCancellation, + RequestID: requestID, + RequestingDeviceID: mach.Client.DeviceID, + }) + }() - if secret != "" { - err = mach.CryptoStore.PutSecret(ctx, name, secret) + for { + select { + case <-ctx.Done(): + return ctx.Err() + case secret = <-secretChan: + if ok, err := receiver(secret); err != nil { + return err + } else if ok { + return mach.CryptoStore.PutSecret(ctx, name, secret) + } + } } - return } func (mach *OlmMachine) HandleSecretRequest(ctx context.Context, userID id.UserID, content *event.SecretRequestEventContent) { @@ -159,17 +182,10 @@ func (mach *OlmMachine) receiveSecret(ctx context.Context, evt *DecryptedOlmEven return } + // secret channel is buffered and we don't want to block + // at worst we drop _some_ of the responses select { case secretChan <- content.Secret: default: } - - // best effort cancel this for all other targets - go func() { - mach.sendToOneDevice(ctx, mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{ - Action: event.SecretRequestCancellation, - RequestID: content.RequestID, - RequestingDeviceID: mach.Client.DeviceID, - }) - }() } From b556d65da936837a3eb32259cf2ff9d9bc41d479 Mon Sep 17 00:00:00 2001 From: Malte E Date: Fri, 15 Mar 2024 22:29:33 +0100 Subject: [PATCH 0169/1647] add handler for accepting/rejecting/retracting invites --- bridge/bridge.go | 7 +++++++ bridge/matrix.go | 13 +++++++++++++ 2 files changed, 20 insertions(+) diff --git a/bridge/bridge.go b/bridge/bridge.go index 4a6cf68e..b5d8a365 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -115,6 +115,13 @@ type KnockHandlingPortal interface { HandleMatrixRejectKnock(sender User, ghost Ghost, evt *event.Event) } +type InviteHandlingPortal interface { + Portal + HandleMatrixAcceptInvite(sender User, evt *event.Event) + HandleMatrixRejectInvite(sender User, evt *event.Event) + HandleMatrixRetractInvite(sender User, ghost Ghost, evt *event.Event) +} + type User interface { GetPermissionLevel() bridgeconfig.PermissionLevel IsLoggedIn() bool diff --git a/bridge/matrix.go b/bridge/matrix.go index 5b646b34..1f16eb97 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -279,6 +279,7 @@ func (mx *MatrixHandler) HandleMembership(ctx context.Context, evt *event.Event) bhp, bhpOk := portal.(BanHandlingPortal) mhp, mhpOk := portal.(MembershipHandlingPortal) khp, khpOk := portal.(KnockHandlingPortal) + ihp, ihpOk := portal.(InviteHandlingPortal) if !(mhpOk || bhpOk || khpOk) { return } @@ -291,6 +292,18 @@ func (mx *MatrixHandler) HandleMembership(ctx context.Context, evt *event.Event) prevContent = &event.MemberEventContent{Membership: event.MembershipLeave} } } + if ihpOk && ghost != nil && prevContent.Membership == event.MembershipInvite && content.Membership != event.MembershipBan { + if content.Membership == event.MembershipJoin { + ihp.HandleMatrixAcceptInvite(user, evt) + } + if content.Membership == event.MembershipLeave { + if isSelf { + ihp.HandleMatrixRejectInvite(user, evt) + } else { + ihp.HandleMatrixRetractInvite(user, ghost, evt) + } + } + } if bhpOk && ghost != nil { if content.Membership == event.MembershipBan { bhp.HandleMatrixBan(user, ghost, evt) From 5dedc9806a6c9dcabbd88f88c8e24b63899efcc0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 16 Mar 2024 12:51:38 +0200 Subject: [PATCH 0170/1647] Bump version to v0.18.0 --- CHANGELOG.md | 5 ++++- go.mod | 4 ++-- go.sum | 8 ++++---- version.go | 2 +- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3614632e..cece9947 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -## v0.18.0 (unreleased) +## v0.18.0 (2024-03-16) * **Breaking change *(client, bridge, appservice)*** Dropped support for maulogger. Only zerolog loggers are now provided by default. @@ -11,6 +11,9 @@ * *(crypto)* Changed `EncryptMegolmEvent` to return an error if persisting the megolm session fails. This ensures that database errors won't cause messages to be sent with duplicate indexes. +* *(crypto)* Changed `GetOrRequestSecret` to use a callback instead of returning + the value directly. This allows validating the value in order to ignore + invalid secrets. * *(id)* Added `ParseCommonIdentifier` function to parse any Matrix identifier in the [Common Identifier Format]. * *(federation)* Added simple key server that passes the federation tester. diff --git a/go.mod b/go.mod index 425fcae5..3e349953 100644 --- a/go.mod +++ b/go.mod @@ -12,10 +12,10 @@ require ( github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.0 - go.mau.fi/util v0.4.1-0.20240311133655-ff64e137ce44 + go.mau.fi/util v0.4.1 go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.21.0 - golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 + golang.org/x/exp v0.0.0-20240314144324-c7f7c6466f7f golang.org/x/net v0.22.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 diff --git a/go.sum b/go.sum index 42f5e3e3..705282d4 100644 --- a/go.sum +++ b/go.sum @@ -37,14 +37,14 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.0 h1:EfOIvIMZIzHdB/R/zVrikYLPPwJlfMcNczJFMs1m6sA= github.com/yuin/goldmark v1.7.0/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.4.1-0.20240311133655-ff64e137ce44 h1:d5nG84/nftM2sBibpoT8X4aCTOoueoe26DBBLGHi41k= -go.mau.fi/util v0.4.1-0.20240311133655-ff64e137ce44/go.mod h1:jOAREC/go8T6rGic01cu6WRa90xi9U4z3QmDjRf8xpo= +go.mau.fi/util v0.4.1 h1:3EC9KxIXo5+h869zDGf5OOZklRd/FjeVnimTwtm3owg= +go.mau.fi/util v0.4.1/go.mod h1:GjkTEBsehYZbSh2LlE6cWEn+6ZIZTGrTMM/5DMNlmFY= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ= -golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= +golang.org/x/exp v0.0.0-20240314144324-c7f7c6466f7f h1:3CW0unweImhOzd5FmYuRsD4Y4oQFKZIjAnKbjV4WIrw= +golang.org/x/exp v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/version.go b/version.go index ed53f6ae..82817bca 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.18.0-beta.1" +const Version = "v0.18.0" var GoModVersion = "" var Commit = "" From 8ba307b28d5c89352a03ca1cb102acd09cbeb53e Mon Sep 17 00:00:00 2001 From: Adam Van Ymeren Date: Sat, 16 Mar 2024 11:36:58 -0700 Subject: [PATCH 0171/1647] Fix Unsigned.IsEmpty() when all we have is HSOrder --- event/events.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/event/events.go b/event/events.go index f7b4d4d6..4653a531 100644 --- a/event/events.go +++ b/event/events.go @@ -149,5 +149,6 @@ type Unsigned struct { func (us *Unsigned) IsEmpty() bool { return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && - us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil + us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil && + us.BeeperHSOrder == 0 } From 9fe66581e53832d3e4a4e568bcfaf6d89c5fadd5 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Mon, 18 Mar 2024 13:05:33 +0200 Subject: [PATCH 0172/1647] Check that shared IGS has higher index than stored Copies the logic from key import. --- crypto/keysharing.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/crypto/keysharing.go b/crypto/keysharing.go index 05e7f894..d1b2e92c 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -184,6 +184,11 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt MaxMessages: maxMessages, IsScheduled: content.IsScheduled, } + existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID()) + if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() { + // We already have an equivalent or better session in the store, so don't override it. + return false + } err = mach.CryptoStore.PutGroupSession(ctx, content.RoomID, content.SenderKey, content.SessionID, igs) if err != nil { log.Error().Err(err).Msg("Failed to store new inbound group session") From 0095e1fb78d9468eed522687e4d41a384b6df2d4 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Thu, 28 Mar 2024 10:28:17 +0200 Subject: [PATCH 0173/1647] Assume the device list is up-to-date on key backup restore Fetching devices in a loop can cause request storming if there's a lot of unknown signatures for a key backup. A client implementation should always ensure that the devices are updated from device list changed updates from sync. --- crypto/keybackup.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/crypto/keybackup.go b/crypto/keybackup.go index d3701e93..afdb84fd 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -66,8 +66,10 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) var key id.Ed25519 if keyName == crossSigningPubkeys.MasterKey.String() { key = crossSigningPubkeys.MasterKey - } else if device, err := mach.GetOrFetchDevice(ctx, mach.Client.UserID, id.DeviceID(keyName)); err != nil { - log.Warn().Err(err).Msg("Failed to fetch device") + } else if device, err := mach.CryptoStore.GetDevice(ctx, mach.Client.UserID, id.DeviceID(keyName)); err != nil { + return nil, fmt.Errorf("failed to get device %s/%s from store: %w", mach.Client.UserID, keyName, err) + } else if device == nil { + log.Warn().Err(err).Msg("Device does not exist, ignoring signature") continue } else if !mach.IsDeviceTrusted(device) { log.Warn().Err(err).Msg("Device is not trusted") From 64cc843952f05a61ce314387fc021f1db9294e76 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Tue, 2 Apr 2024 13:38:55 +0300 Subject: [PATCH 0174/1647] Invalidate memory cache when storing own cross-signing keys When another device does cross-signing reset we would incorrectly cache the old keys indefinitely. --- crypto/cross_sign_store.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go index 28d0bad0..456ab6ed 100644 --- a/crypto/cross_sign_store.go +++ b/crypto/cross_sign_store.go @@ -96,5 +96,12 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK } } } + + // Clear internal cache so that it refreshes from crypto store + if userID == mach.Client.UserID && mach.crossSigningPubkeys != nil { + log.Debug().Msg("Resetting internal cross-signing key cache") + mach.crossSigningPubkeys = nil + mach.crossSigningPubkeysFetched = false + } } } From 898b235a840b3e796cc35b094ca55847c2963b78 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Wed, 13 Mar 2024 08:49:13 +0200 Subject: [PATCH 0175/1647] Allow overriding http.Client with FullRequest --- client.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index ad68e793..686c0e19 100644 --- a/client.go +++ b/client.go @@ -338,6 +338,7 @@ type FullRequest struct { SensitiveContent bool Handler ClientResponseHandler Logger *zerolog.Logger + Client *http.Client } var requestID int32 @@ -424,7 +425,10 @@ func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]b if len(cli.AccessToken) > 0 { req.Header.Set("Authorization", "Bearer "+cli.AccessToken) } - return cli.executeCompiledRequest(req, params.MaxAttempts-1, 4*time.Second, params.ResponseJSON, params.Handler) + if params.Client == nil { + params.Client = cli.Client + } + return cli.executeCompiledRequest(req, params.MaxAttempts-1, 4*time.Second, params.ResponseJSON, params.Handler, params.Client) } func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { @@ -435,7 +439,7 @@ func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { return log } -func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler) ([]byte, error) { +func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler, client *http.Client) ([]byte, error) { log := zerolog.Ctx(req.Context()) if req.Body != nil { if req.GetBody == nil { @@ -453,7 +457,7 @@ func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff Int("retry_in_seconds", int(backoff.Seconds())). Msg("Request failed, retrying") time.Sleep(backoff) - return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler) + return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, client) } func readRequestBody(req *http.Request, res *http.Response) ([]byte, error) { @@ -535,17 +539,17 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { } } -func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler) ([]byte, error) { +func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler, client *http.Client) ([]byte, error) { cli.RequestStart(req) startTime := time.Now() - res, err := cli.Client.Do(req) + res, err := client.Do(req) duration := time.Now().Sub(startTime) if res != nil { defer res.Body.Close() } if err != nil { if retries > 0 { - return cli.doRetry(req, err, retries, backoff, responseJSON, handler) + return cli.doRetry(req, err, retries, backoff, responseJSON, handler, client) } err = HTTPError{ Request: req, @@ -560,7 +564,7 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) { backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff) - return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler) + return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, client) } var body []byte From 640086dbf98ada3fba17edb4704231d42c33948c Mon Sep 17 00:00:00 2001 From: Malte E <97891689+maltee1@users.noreply.github.com> Date: Fri, 5 Apr 2024 01:27:36 +0200 Subject: [PATCH 0176/1647] Fix default prevContent in bridge membership event handler (#204) --- bridge/matrix.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/bridge/matrix.go b/bridge/matrix.go index 1f16eb97..7c1a5e25 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -283,14 +283,10 @@ func (mx *MatrixHandler) HandleMembership(ctx context.Context, evt *event.Event) if !(mhpOk || bhpOk || khpOk) { return } - var prevContent *event.MemberEventContent + prevContent := &event.MemberEventContent{Membership: event.MembershipLeave} if evt.Unsigned.PrevContent != nil { _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) - var ok bool - prevContent, ok = evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent) - if !ok { - prevContent = &event.MemberEventContent{Membership: event.MembershipLeave} - } + prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent) } if ihpOk && ghost != nil && prevContent.Membership == event.MembershipInvite && content.Membership != event.MembershipBan { if content.Membership == event.MembershipJoin { From 423d32ddf6d615d1d3b44ec443156f1836bec628 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 13 Apr 2024 13:56:26 +0300 Subject: [PATCH 0177/1647] Add real context to HTML parser context struct --- format/htmlparser.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/format/htmlparser.go b/format/htmlparser.go index eb2a662b..8ddd8818 100644 --- a/format/htmlparser.go +++ b/format/htmlparser.go @@ -7,6 +7,7 @@ package format import ( + "context" "fmt" "math" "strconv" @@ -33,14 +34,16 @@ func (ts TagStack) Has(tag string) bool { } type Context struct { + Ctx context.Context ReturnData map[string]any TagStack TagStack PreserveWhitespace bool } -func NewContext() Context { +func NewContext(ctx context.Context) Context { return Context{ + Ctx: ctx, ReturnData: map[string]any{}, TagStack: make(TagStack, 0, 4), } @@ -411,7 +414,7 @@ func HTMLToText(html string) string { Newline: "\n", HorizontalLine: "\n---\n", PillConverter: DefaultPillConverter, - }).Parse(html, NewContext()) + }).Parse(html, NewContext(context.TODO())) } // HTMLToMarkdown converts Matrix HTML into markdown with the default settings. @@ -429,5 +432,5 @@ func HTMLToMarkdown(html string) string { } return fmt.Sprintf("[%s](%s)", text, href) }, - }).Parse(html, NewContext()) + }).Parse(html, NewContext(context.TODO())) } From a19dab189714649037720d2907e2d0405528643a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Apr 2024 13:57:50 +0300 Subject: [PATCH 0178/1647] Bump version to v0.18.1 --- CHANGELOG.md | 17 +++++++++++++++++ crypto/cross_sign_validation.go | 2 +- go.mod | 12 ++++++------ go.sum | 24 ++++++++++++------------ version.go | 2 +- 5 files changed, 37 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cece9947..d7d17bc1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,20 @@ +## v0.18.1 (2024-04-16) + +* *(format)* Added a `context.Context` field to HTMLParser's Context struct. +* *(bridge)* Added support for handling join rules, knocks, invites and bans + (thanks to [@maltee1] in [#193] and [#204]). +* *(crypto)* Changed forwarded room key handling to only accept keys with a + lower first known index than the existing session if there is one. +* *(crypto)* Changed key backup restore to assume own device list is up to date + to avoid re-requesting device list for every deleted device that has signed + key backup. +* *(crypto)* Fixed memory cache not being invalidated when storing own + cross-signing keys + +[@maltee1]: https://github.com/maltee1 +[#193]: https://github.com/mautrix/go/pull/193 +[#204]: https://github.com/mautrix/go/pull/204 + ## v0.18.0 (2024-03-16) * **Breaking change *(client, bridge, appservice)*** Dropped support for diff --git a/crypto/cross_sign_validation.go b/crypto/cross_sign_validation.go index ff2452ec..04a179df 100644 --- a/crypto/cross_sign_validation.go +++ b/crypto/cross_sign_validation.go @@ -32,7 +32,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi } theirMSK, ok := theirKeys[id.XSUsageMaster] if !ok { - mach.machOrContextLog(ctx).Error(). + mach.machOrContextLog(ctx).Debug(). Str("user_id", device.UserID.String()). Msg("Master key of user not found") return id.TrustStateUnset, nil diff --git a/go.mod b/go.mod index 3e349953..5afaff8d 100644 --- a/go.mod +++ b/go.mod @@ -11,12 +11,12 @@ require ( github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.7.0 - go.mau.fi/util v0.4.1 + github.com/yuin/goldmark v1.7.1 + go.mau.fi/util v0.4.2 go.mau.fi/zeroconfig v0.1.2 - golang.org/x/crypto v0.21.0 - golang.org/x/exp v0.0.0-20240314144324-c7f7c6466f7f - golang.org/x/net v0.22.0 + golang.org/x/crypto v0.22.0 + golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8 + golang.org/x/net v0.24.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) @@ -29,6 +29,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.18.0 // indirect + golang.org/x/sys v0.19.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 705282d4..3704b9f2 100644 --- a/go.sum +++ b/go.sum @@ -35,23 +35,23 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.0 h1:EfOIvIMZIzHdB/R/zVrikYLPPwJlfMcNczJFMs1m6sA= -github.com/yuin/goldmark v1.7.0/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.4.1 h1:3EC9KxIXo5+h869zDGf5OOZklRd/FjeVnimTwtm3owg= -go.mau.fi/util v0.4.1/go.mod h1:GjkTEBsehYZbSh2LlE6cWEn+6ZIZTGrTMM/5DMNlmFY= +github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U= +github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +go.mau.fi/util v0.4.2 h1:RR3TOcRHmCF9Bx/3YG4S65MYfa+nV6/rn8qBWW4Mi30= +go.mau.fi/util v0.4.2/go.mod h1:PlAVfUUcPyHPrwnvjkJM9UFcPE7qGPDJqk+Oufa1Gtw= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/exp v0.0.0-20240314144324-c7f7c6466f7f h1:3CW0unweImhOzd5FmYuRsD4Y4oQFKZIjAnKbjV4WIrw= -golang.org/x/exp v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= -golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= -golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= +golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8 h1:ESSUROHIBHg7USnszlcdmjBEwdMj9VUvU+OPk4yl2mc= +golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= +golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= +golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/version.go b/version.go index 82817bca..e00141ae 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.18.0" +const Version = "v0.18.1" var GoModVersion = "" var Commit = "" From ff9e2e0f1d3700990f9875b469f84a879f61603b Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 18 Apr 2024 15:22:12 -0600 Subject: [PATCH 0179/1647] machine/ShareKeys: save keys before sending server request in case it fails Signed-off-by: Sumner Evans --- crypto/account.go | 1 - crypto/machine.go | 6 ++++++ crypto/machine_test.go | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/crypto/account.go b/crypto/account.go index d242df6f..2f012e59 100644 --- a/crypto/account.go +++ b/crypto/account.go @@ -81,6 +81,5 @@ func (account *OlmAccount) getOneTimeKeys(userID id.UserID, deviceID id.DeviceID key.IsSigned = true oneTimeKeys[id.NewKeyID(id.KeyAlgorithmSignedCurve25519, keyID)] = key } - account.Internal.MarkKeysAsPublished() return oneTimeKeys } diff --git a/crypto/machine.go b/crypto/machine.go index 4417faf3..41149f01 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -681,6 +681,11 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro log.Debug().Msg("No one-time keys nor device keys got when trying to share keys") return nil } + // Save the keys before sending the upload request in case there is a + // network failure. + if err := mach.saveAccount(ctx); err != nil { + return err + } req := &mautrix.ReqUploadKeys{ DeviceKeys: deviceKeys, OneTimeKeys: oneTimeKeys, @@ -691,6 +696,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro return err } mach.lastOTKUpload = time.Now() + mach.account.Internal.MarkKeysAsPublished() mach.account.Shared = true return mach.saveAccount(ctx) } diff --git a/crypto/machine_test.go b/crypto/machine_test.go index d3750d34..057c2ae1 100644 --- a/crypto/machine_test.go +++ b/crypto/machine_test.go @@ -77,6 +77,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { otk = otkTmp break } + machineIn.account.Internal.MarkKeysAsPublished() // create outbound olm session for sending machine using OTK olmSession, err := machineOut.account.Internal.NewOutboundSession(machineIn.account.IdentityKey(), otk.Key) From 6cc490d9ab8ca902893c6c841281e1a178f7fbe5 Mon Sep 17 00:00:00 2001 From: Malte E <97891689+maltee1@users.noreply.github.com> Date: Sun, 21 Apr 2024 15:22:26 +0200 Subject: [PATCH 0180/1647] check ghost != nil in correct line (#208) --- bridge/matrix.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bridge/matrix.go b/bridge/matrix.go index 7c1a5e25..446a0b0a 100644 --- a/bridge/matrix.go +++ b/bridge/matrix.go @@ -288,14 +288,14 @@ func (mx *MatrixHandler) HandleMembership(ctx context.Context, evt *event.Event) _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent) } - if ihpOk && ghost != nil && prevContent.Membership == event.MembershipInvite && content.Membership != event.MembershipBan { + if ihpOk && prevContent.Membership == event.MembershipInvite && content.Membership != event.MembershipBan { if content.Membership == event.MembershipJoin { ihp.HandleMatrixAcceptInvite(user, evt) } if content.Membership == event.MembershipLeave { if isSelf { ihp.HandleMatrixRejectInvite(user, evt) - } else { + } else if ghost != nil { ihp.HandleMatrixRetractInvite(user, ghost, evt) } } From 2810465ef29466b531d3424bd4175fa8f0af4c26 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 25 Apr 2024 09:40:57 -0600 Subject: [PATCH 0181/1647] verificationhelper: ensure that the keys are fetched before starting Signed-off-by: Sumner Evans --- crypto/devicelist.go | 4 ++++ crypto/verificationhelper/verificationhelper.go | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/crypto/devicelist.go b/crypto/devicelist.go index 16c4164e..e98ba45a 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -93,6 +93,10 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id } } +// FetchKeys fetches the devices of a list of other users. If includeUntracked +// is set to false, then the users are filtered to to only include user IDs +// whose device lists have been stored with the PutDevices function on the +// [Store]. See the FilterTrackedUsers function on [Store] for details. func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device, err error) { req := &mautrix.ReqQueryKeys{ DeviceKeys: mautrix.DeviceKeysRequest{}, diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 025af25e..c8c00f16 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -337,6 +337,12 @@ func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserI devices, err := vh.mach.CryptoStore.GetDevices(ctx, to) if err != nil { return "", fmt.Errorf("failed to get devices for user: %w", err) + } else if len(devices) == 0 { + // HACK: we are doing this because the client doesn't wait until it has + // the devices before starting verification. + if _, err := vh.mach.FetchKeys(ctx, []id.UserID{to}, true); err != nil { + return "", err + } } vh.getLog(ctx).Info(). From c0e030fc85fcf869c2d95d148a676c3570ab5ffb Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Sun, 12 May 2024 18:10:48 -0600 Subject: [PATCH 0182/1647] crypto/olm: remove Signatures definition Signed-off-by: Sumner Evans --- crypto/olm/olm.go | 6 ------ crypto/olm/olm_goolm.go | 7 ------- 2 files changed, 13 deletions(-) diff --git a/crypto/olm/olm.go b/crypto/olm/olm.go index 685e1b6b..fa1ae856 100644 --- a/crypto/olm/olm.go +++ b/crypto/olm/olm.go @@ -5,12 +5,6 @@ package olm // #cgo LDFLAGS: -lolm -lstdc++ // #include import "C" -import ( - "maunium.net/go/mautrix/id" -) - -// Signatures is the data structure used to sign JSON objects. -type Signatures map[id.UserID]map[id.DeviceKeyID]string // Version returns the version number of the olm library. func Version() (major, minor, patch uint8) { diff --git a/crypto/olm/olm_goolm.go b/crypto/olm/olm_goolm.go index dbe12a76..a1489ded 100644 --- a/crypto/olm/olm_goolm.go +++ b/crypto/olm/olm_goolm.go @@ -2,13 +2,6 @@ package olm -import ( - "maunium.net/go/mautrix/id" -) - -// Signatures is the data structure used to sign JSON objects. -type Signatures map[id.UserID]map[id.DeviceKeyID]string - // Version returns the version number of the olm library. func Version() (major, minor, patch uint8) { return 3, 2, 15 From 01fde7d9a8245f8dbf2870306a604e6d57295465 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 13 May 2024 15:42:34 -0600 Subject: [PATCH 0183/1647] verificationhelper/StartVerification: actually set devices after FetchKeys Signed-off-by: Sumner Evans --- crypto/verificationhelper/verificationhelper.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index c8c00f16..31f237bd 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -340,8 +340,10 @@ func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserI } else if len(devices) == 0 { // HACK: we are doing this because the client doesn't wait until it has // the devices before starting verification. - if _, err := vh.mach.FetchKeys(ctx, []id.UserID{to}, true); err != nil { + if keys, err := vh.mach.FetchKeys(ctx, []id.UserID{to}, true); err != nil { return "", err + } else { + devices = keys[to] } } From d10103dcf588589ee0630bce429562ce52903552 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 13 May 2024 21:36:59 -0600 Subject: [PATCH 0184/1647] crypto/encryptmegolm: return error if sharing outbound session fails This allows us to catch and throw "database is locked" errors. This will ensure that if saving the key fails, then we won't share the key out to anyone. Signed-off-by: Sumner Evans --- crypto/encryptmegolm.go | 14 ++++++++++---- crypto/machine.go | 17 ++++++++++------- crypto/machine_test.go | 6 ++++-- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index d592bd1c..19ff5c8c 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -136,19 +136,23 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID return encrypted, nil } -func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) *OutboundGroupSession { +func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) (*OutboundGroupSession, error) { encryptionEvent, err := mach.StateStore.GetEncryptionEvent(ctx, roomID) if err != nil { mach.machOrContextLog(ctx).Err(err). Stringer("room_id", roomID). Msg("Failed to get encryption event in room") + return nil, fmt.Errorf("failed to get encryption event in room %s: %w", roomID, err) } session := NewOutboundGroupSession(roomID, encryptionEvent) if !mach.DontStoreOutboundKeys { signingKey, idKey := mach.account.Keys() - mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false) + err := mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false) + if err != nil { + return nil, err + } } - return session + return session, err } type deviceSessionWrapper struct { @@ -183,7 +187,9 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, Logger() ctx = log.WithContext(ctx) if session == nil || session.Expired() { - session = mach.newOutboundGroupSession(ctx, roomID) + if session, err = mach.newOutboundGroupSession(ctx, roomID); err != nil { + return err + } } log = log.With().Str("session_id", session.ID().String()).Logger() ctx = log.WithContext(ctx) diff --git a/crypto/machine.go b/crypto/machine.go index 41149f01..0cdeebea 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -505,23 +505,22 @@ func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.De return err } -func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string, maxAge time.Duration, maxMessages int, isScheduled bool) { +func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string, maxAge time.Duration, maxMessages int, isScheduled bool) error { log := zerolog.Ctx(ctx) igs, err := NewInboundGroupSession(senderKey, signingKey, roomID, sessionKey, maxAge, maxMessages, isScheduled) if err != nil { - log.Error().Err(err).Msg("Failed to create inbound group session") - return + return fmt.Errorf("failed to create inbound group session: %w", err) } else if igs.ID() != sessionID { log.Warn(). Str("expected_session_id", sessionID.String()). Str("actual_session_id", igs.ID().String()). Msg("Mismatched session ID while creating inbound group session") - return + return fmt.Errorf("mismatched session ID while creating inbound group session") } err = mach.CryptoStore.PutGroupSession(ctx, roomID, senderKey, sessionID, igs) if err != nil { - log.Error().Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session") - return + log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session") + return fmt.Errorf("failed to store new inbound group session: %w", err) } mach.markSessionReceived(ctx, sessionID) log.Debug(). @@ -531,6 +530,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen Int("max_messages", maxMessages). Bool("is_scheduled", isScheduled). Msg("Received inbound group session") + return nil } func (mach *OlmMachine) markSessionReceived(ctx context.Context, id id.SessionID) { @@ -626,7 +626,10 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve Msg("Redacted previous megolm sessions") } } - mach.createGroupSession(ctx, evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, maxAge, maxMessages, content.IsScheduled) + err = mach.createGroupSession(ctx, evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, maxAge, maxMessages, content.IsScheduled) + if err != nil { + log.Err(err).Msg("Failed to create inbound group session") + } } func (mach *OlmMachine) HandleRoomKeyWithheld(ctx context.Context, content *event.RoomKeyWithheldEventContent) { diff --git a/crypto/machine_test.go b/crypto/machine_test.go index 057c2ae1..807b65b2 100644 --- a/crypto/machine_test.go +++ b/crypto/machine_test.go @@ -56,7 +56,8 @@ func newMachine(t *testing.T, userID id.UserID) *OlmMachine { func TestRatchetMegolmSession(t *testing.T) { mach := newMachine(t, "user1") - outSess := mach.newOutboundGroupSession(context.TODO(), "meow") + outSess, err := mach.newOutboundGroupSession(context.TODO(), "meow") + assert.NoError(t, err) inSess, err := mach.CryptoStore.GetGroupSession(context.TODO(), "meow", mach.OwnIdentity().IdentityKey, outSess.ID()) require.NoError(t, err) assert.Equal(t, uint32(0), inSess.Internal.FirstKnownIndex()) @@ -96,7 +97,8 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { }) // create & store outbound megolm session for sending the event later - megolmOutSession := machineOut.newOutboundGroupSession(context.TODO(), "room1") + megolmOutSession, err := machineOut.newOutboundGroupSession(context.TODO(), "room1") + assert.NoError(t, err) megolmOutSession.Shared = true machineOut.CryptoStore.AddOutboundGroupSession(context.TODO(), megolmOutSession) From 043918073770fbdc789dcdbc3634b73aa991f748 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 14 May 2024 10:32:11 -0600 Subject: [PATCH 0185/1647] crypto/sql_store: fix a couple places where the error value is unused Signed-off-by: Sumner Evans --- crypto/sql_store.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/crypto/sql_store.go b/crypto/sql_store.go index a3b3b74a..9b1d28c2 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -331,6 +331,9 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room } } igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String) + if err != nil { + return nil, err + } if senderKey == "" { senderKey = id.Curve25519(senderKeyDB.String) } @@ -479,6 +482,9 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In return nil, err } igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String) + if err != nil { + return nil, err + } return &InboundGroupSession{ Internal: *igs, SigningKey: id.Ed25519(signingKey.String), From 78f5e4373b0810f0cbf503e4a5be848eeb5642ef Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 14 May 2024 12:43:15 +0100 Subject: [PATCH 0186/1647] Pass error to `Client.ResponseHook` --- client.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index 686c0e19..cc133f95 100644 --- a/client.go +++ b/client.go @@ -65,7 +65,7 @@ type Client struct { Log zerolog.Logger RequestHook func(req *http.Request) - ResponseHook func(req *http.Request, resp *http.Response, duration time.Duration) + ResponseHook func(req *http.Request, resp *http.Response, err error, duration time.Duration) SyncPresence event.Presence @@ -291,10 +291,10 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er Str("method", req.Method). Str("url", req.URL.String()). Dur("duration", duration) + if cli.ResponseHook != nil { + cli.ResponseHook(req, resp, err, duration) + } if resp != nil { - if cli.ResponseHook != nil { - cli.ResponseHook(req, resp, duration) - } mime := resp.Header.Get("Content-Type") length := resp.ContentLength if length == -1 && contentLength > 0 { From 5490cc6aee199ff40b51dc1c68fd455101463ca3 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 14 May 2024 11:08:01 -0600 Subject: [PATCH 0187/1647] crypto/sql_store: add logging on PutGroupSession Signed-off-by: Sumner Evans --- crypto/sql_store.go | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 9b1d28c2..87f50426 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -283,6 +283,18 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.Room if err != nil { return fmt.Errorf("failed to marshal ratchet safety info: %w", err) } + zerolog.Ctx(ctx).Debug(). + Stringer("session_id", sessionID). + Str("account_id", store.AccountID). + Stringer("sender_key", senderKey). + Stringer("signing_key", session.SigningKey). + Stringer("room_id", roomID). + Time("received_at", session.ReceivedAt). + Int64("max_age", session.MaxAge). + Int("max_messages", session.MaxMessages). + Bool("is_scheduled", session.IsScheduled). + Stringer("key_backup_version", session.KeyBackupVersion). + Msg("Upserting megolm inbound group session") _, err = store.DB.Exec(ctx, ` INSERT INTO crypto_megolm_inbound_session ( session_id, sender_key, signing_key, room_id, session, forwarding_chains, @@ -293,7 +305,7 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.Room room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains, ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at, max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled, - key_backup_version=excluded.key_backup_version + key_backup_version=excluded.key_backup_version `, sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains, ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages), From 34ef1b3705dd62a3ead935191c79312a20bdce1e Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 14 May 2024 11:24:55 -0600 Subject: [PATCH 0188/1647] crypto/sql_store: don't check sender_key in GetGroupSession Signed-off-by: Sumner Evans --- crypto/sql_store.go | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 87f50426..1e56f74d 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -315,7 +315,7 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.Room } // GetGroupSession retrieves an inbound Megolm group session for a room, sender and session. -func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) { +func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.RoomID, _ id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) { var senderKeyDB, signingKey, forwardingChains, withheldCode, withheldReason sql.NullString var sessionBytes, ratchetSafetyBytes []byte var receivedAt sql.NullTime @@ -325,8 +325,8 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room err := store.DB.QueryRow(ctx, ` SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session - WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`, - roomID, senderKey, sessionID, store.AccountID, + WHERE room_id=$1 AND session_id=$2 AND account_id=$3`, + roomID, sessionID, store.AccountID, ).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) if errors.Is(err, sql.ErrNoRows) { return nil, nil @@ -337,7 +337,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room RoomID: roomID, Algorithm: id.AlgorithmMegolmV1, SessionID: sessionID, - SenderKey: senderKey, + SenderKey: id.Curve25519(senderKeyDB.String), Code: event.RoomKeyWithheldCode(withheldCode.String), Reason: withheldReason.String, } @@ -346,13 +346,10 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room if err != nil { return nil, err } - if senderKey == "" { - senderKey = id.Curve25519(senderKeyDB.String) - } return &InboundGroupSession{ Internal: *igs, SigningKey: id.Ed25519(signingKey.String), - SenderKey: senderKey, + SenderKey: id.Curve25519(senderKeyDB.String), RoomID: roomID, ForwardingChains: chains, RatchetSafety: rs, From b31dbb0bd0a3d059bc9cd4a48978db939b0e435e Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 14 May 2024 11:37:22 -0600 Subject: [PATCH 0189/1647] store: update interface to not take sender key According to https://spec.matrix.org/latest/client-server-api/#mmegolmv1aes-sha2, clients MUST NOT store or lookup sessions using the sender key. This commit removes the sender key from most of the functions related to putting and getting group sessions from the Store interface. Notably, RedactGroupSessions still accepts a sender key because it's meant for batch deletion of group sessions. Signed-off-by: Sumner Evans Signed-off-by: Sumner Evans --- crypto/store.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crypto/store.go b/crypto/store.go index 3b6e6564..5918e7e4 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -50,13 +50,13 @@ type Store interface { // PutGroupSession inserts an inbound Megolm session into the store. If an earlier withhold event has been inserted // with PutWithheldGroupSession, this call should replace that. However, PutWithheldGroupSession must not replace // sessions inserted with this call. - PutGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, *InboundGroupSession) error + PutGroupSession(context.Context, *InboundGroupSession) error // GetGroupSession gets an inbound Megolm session from the store. If the group session has been withheld // (i.e. a room key withheld event has been saved with PutWithheldGroupSession), this should return the // ErrGroupSessionWithheld error. The caller may use GetWithheldGroupSession to find more details. - GetGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error) + GetGroupSession(context.Context, id.RoomID, id.SessionID) (*InboundGroupSession, error) // RedactGroupSession removes the session data for the given inbound Megolm session from the store. - RedactGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, string) error + RedactGroupSession(context.Context, id.RoomID, id.SessionID, string) error // RedactGroupSessions removes the session data for all inbound Megolm sessions from a specific device and/or in a specific room. RedactGroupSessions(context.Context, id.RoomID, id.SenderKey, string) ([]id.SessionID, error) // RedactExpiredGroupSessions removes the session data for all inbound Megolm sessions that have expired. @@ -66,7 +66,7 @@ type Store interface { // PutWithheldGroupSession tells the store that a specific Megolm session was withheld. PutWithheldGroupSession(context.Context, event.RoomKeyWithheldEventContent) error // GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession. - GetWithheldGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID) (*event.RoomKeyWithheldEventContent, error) + GetWithheldGroupSession(context.Context, id.RoomID, id.SessionID) (*event.RoomKeyWithheldEventContent, error) // GetGroupSessionsForRoom gets all the inbound Megolm sessions for a specific room. This is used for creating key // export files. Unlike GetGroupSession, this should not return any errors about withheld keys. From d0de43f3952044f7ae29fe42ee8f3487c3f19453 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 14 May 2024 12:27:37 -0600 Subject: [PATCH 0190/1647] crypto/sql_store: don't take sender key on group session methods Fixes compatibility with the Store interface. Signed-off-by: Sumner Evans --- crypto/sql_store.go | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 1e56f74d..689c25f0 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -276,7 +276,7 @@ func datePtr(t time.Time) *time.Time { } // PutGroupSession stores an inbound Megolm group session for a room, sender and session. -func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error { +func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, session *InboundGroupSession) error { sessionBytes := session.Internal.Pickle(store.PickleKey) forwardingChains := strings.Join(session.ForwardingChains, ",") ratchetSafety, err := json.Marshal(&session.RatchetSafety) @@ -284,11 +284,11 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.Room return fmt.Errorf("failed to marshal ratchet safety info: %w", err) } zerolog.Ctx(ctx).Debug(). - Stringer("session_id", sessionID). + Stringer("session_id", session.ID()). Str("account_id", store.AccountID). - Stringer("sender_key", senderKey). + Stringer("sender_key", session.SenderKey). Stringer("signing_key", session.SigningKey). - Stringer("room_id", roomID). + Stringer("room_id", session.RoomID). Time("received_at", session.ReceivedAt). Int64("max_age", session.MaxAge). Int("max_messages", session.MaxMessages). @@ -307,7 +307,7 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.Room max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled, key_backup_version=excluded.key_backup_version `, - sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains, + session.ID(), session.SenderKey, session.SigningKey, session.RoomID, sessionBytes, forwardingChains, ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages), session.IsScheduled, session.KeyBackupVersion, store.AccountID, ) @@ -315,8 +315,8 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.Room } // GetGroupSession retrieves an inbound Megolm group session for a room, sender and session. -func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.RoomID, _ id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) { - var senderKeyDB, signingKey, forwardingChains, withheldCode, withheldReason sql.NullString +func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) (*InboundGroupSession, error) { + var senderKey, signingKey, forwardingChains, withheldCode, withheldReason sql.NullString var sessionBytes, ratchetSafetyBytes []byte var receivedAt sql.NullTime var maxAge, maxMessages sql.NullInt64 @@ -327,7 +327,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room FROM crypto_megolm_inbound_session WHERE room_id=$1 AND session_id=$2 AND account_id=$3`, roomID, sessionID, store.AccountID, - ).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) + ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { @@ -337,7 +337,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room RoomID: roomID, Algorithm: id.AlgorithmMegolmV1, SessionID: sessionID, - SenderKey: id.Curve25519(senderKeyDB.String), + SenderKey: id.Curve25519(senderKey.String), Code: event.RoomKeyWithheldCode(withheldCode.String), Reason: withheldReason.String, } @@ -349,7 +349,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room return &InboundGroupSession{ Internal: *igs, SigningKey: id.Ed25519(signingKey.String), - SenderKey: id.Curve25519(senderKeyDB.String), + SenderKey: id.Curve25519(senderKey.String), RoomID: roomID, ForwardingChains: chains, RatchetSafety: rs, @@ -361,7 +361,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room }, nil } -func (store *SQLCryptoStore) RedactGroupSession(ctx context.Context, _ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error { +func (store *SQLCryptoStore) RedactGroupSession(ctx context.Context, _ id.RoomID, sessionID id.SessionID, reason string) error { _, err := store.DB.Exec(ctx, ` UPDATE crypto_megolm_inbound_session SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL @@ -437,13 +437,13 @@ func (store *SQLCryptoStore) PutWithheldGroupSession(ctx context.Context, conten return err } -func (store *SQLCryptoStore) GetWithheldGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { - var code, reason sql.NullString +func (store *SQLCryptoStore) GetWithheldGroupSession(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { + var senderKey, code, reason sql.NullString err := store.DB.QueryRow(ctx, ` - SELECT withheld_code, withheld_reason FROM crypto_megolm_inbound_session - WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`, - roomID, senderKey, sessionID, store.AccountID, - ).Scan(&code, &reason) + SELECT withheld_code, withheld_reason, sender_key FROM crypto_megolm_inbound_session + WHERE room_id=$1 AND session_id=$2 AND account_id=$3`, + roomID, sessionID, store.AccountID, + ).Scan(&code, &reason, &senderKey) if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil || !code.Valid { @@ -453,7 +453,7 @@ func (store *SQLCryptoStore) GetWithheldGroupSession(ctx context.Context, roomID RoomID: roomID, Algorithm: id.AlgorithmMegolmV1, SessionID: sessionID, - SenderKey: senderKey, + SenderKey: id.Curve25519(senderKey.String), Code: event.RoomKeyWithheldCode(code.String), Reason: reason.String, }, nil From a87716a3583b3daa329734cb383b0e507e13fe67 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 14 May 2024 12:29:30 -0600 Subject: [PATCH 0191/1647] crypto/store: don't rely on sender key for storing and lookups * Fixes compatibility with the Store interface * Increases the usage of "defer"s for "gs.lock.Unlock" and "gs.lock.RUnlock" * Increases the usage of "golang.org/x/exp/maps" Signed-off-by: Sumner Evans Signed-off-by: Sumner Evans --- crypto/store.go | 207 +++++++++++++++++++----------------------------- 1 file changed, 83 insertions(+), 124 deletions(-) diff --git a/crypto/store.go b/crypto/store.go index 5918e7e4..a84d4f13 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -13,6 +13,7 @@ import ( "sync" "go.mau.fi/util/dbutil" + "golang.org/x/exp/maps" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -160,8 +161,8 @@ type MemoryStore struct { Account *OlmAccount Sessions map[id.SenderKey]OlmSessionList - GroupSessions map[id.RoomID]map[id.SenderKey]map[id.SessionID]*InboundGroupSession - WithheldGroupSessions map[id.RoomID]map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent + GroupSessions map[id.RoomID]map[id.SessionID]*InboundGroupSession + WithheldGroupSessions map[id.RoomID]map[id.SessionID]*event.RoomKeyWithheldEventContent OutGroupSessions map[id.RoomID]*OutboundGroupSession SharedGroupSessions map[id.UserID]map[id.IdentityKey]map[id.SessionID]struct{} MessageIndices map[messageIndexKey]messageIndexValue @@ -182,8 +183,8 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore { save: saveCallback, Sessions: make(map[id.SenderKey]OlmSessionList), - GroupSessions: make(map[id.RoomID]map[id.SenderKey]map[id.SessionID]*InboundGroupSession), - WithheldGroupSessions: make(map[id.RoomID]map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent), + GroupSessions: make(map[id.RoomID]map[id.SessionID]*InboundGroupSession), + WithheldGroupSessions: make(map[id.RoomID]map[id.SessionID]*event.RoomKeyWithheldEventContent), OutGroupSessions: make(map[id.RoomID]*OutboundGroupSession), SharedGroupSessions: make(map[id.UserID]map[id.IdentityKey]map[id.SessionID]struct{}), MessageIndices: make(map[messageIndexKey]messageIndexValue), @@ -197,9 +198,8 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore { func (gs *MemoryStore) Flush(_ context.Context) error { gs.lock.Lock() - err := gs.save() - gs.lock.Unlock() - return err + defer gs.lock.Unlock() + return gs.save() } func (gs *MemoryStore) GetAccount(_ context.Context) (*OlmAccount, error) { @@ -208,31 +208,29 @@ func (gs *MemoryStore) GetAccount(_ context.Context) (*OlmAccount, error) { func (gs *MemoryStore) PutAccount(_ context.Context, account *OlmAccount) error { gs.lock.Lock() + defer gs.lock.Unlock() gs.Account = account - err := gs.save() - gs.lock.Unlock() - return err + return gs.save() } func (gs *MemoryStore) GetSessions(_ context.Context, senderKey id.SenderKey) (OlmSessionList, error) { gs.lock.Lock() + defer gs.lock.Unlock() sessions, ok := gs.Sessions[senderKey] if !ok { sessions = []*OlmSession{} gs.Sessions[senderKey] = sessions } - gs.lock.Unlock() return sessions, nil } func (gs *MemoryStore) AddSession(_ context.Context, senderKey id.SenderKey, session *OlmSession) error { gs.lock.Lock() - sessions, _ := gs.Sessions[senderKey] + defer gs.lock.Unlock() + sessions := gs.Sessions[senderKey] gs.Sessions[senderKey] = append(sessions, session) sort.Sort(gs.Sessions[senderKey]) - err := gs.save() - gs.lock.Unlock() - return err + return gs.save() } func (gs *MemoryStore) UpdateSession(_ context.Context, _ id.SenderKey, _ *OlmSession) error { @@ -242,102 +240,86 @@ func (gs *MemoryStore) UpdateSession(_ context.Context, _ id.SenderKey, _ *OlmSe func (gs *MemoryStore) HasSession(_ context.Context, senderKey id.SenderKey) bool { gs.lock.RLock() + defer gs.lock.RUnlock() sessions, ok := gs.Sessions[senderKey] - gs.lock.RUnlock() return ok && len(sessions) > 0 && !sessions[0].Expired() } func (gs *MemoryStore) GetLatestSession(_ context.Context, senderKey id.SenderKey) (*OlmSession, error) { gs.lock.RLock() + defer gs.lock.RUnlock() sessions, ok := gs.Sessions[senderKey] - gs.lock.RUnlock() if !ok || len(sessions) == 0 { return nil, nil } return sessions[0], nil } -func (gs *MemoryStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*InboundGroupSession { +func (gs *MemoryStore) getGroupSessions(roomID id.RoomID) map[id.SessionID]*InboundGroupSession { room, ok := gs.GroupSessions[roomID] if !ok { - room = make(map[id.SenderKey]map[id.SessionID]*InboundGroupSession) + room = make(map[id.SessionID]*InboundGroupSession) gs.GroupSessions[roomID] = room } - sender, ok := room[senderKey] - if !ok { - sender = make(map[id.SessionID]*InboundGroupSession) - room[senderKey] = sender - } - return sender + return room } -func (gs *MemoryStore) PutGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error { +func (gs *MemoryStore) PutGroupSession(_ context.Context, igs *InboundGroupSession) error { gs.lock.Lock() - gs.getGroupSessions(roomID, senderKey)[sessionID] = igs - err := gs.save() - gs.lock.Unlock() - return err + defer gs.lock.Unlock() + gs.getGroupSessions(igs.RoomID)[igs.ID()] = igs + return gs.save() } -func (gs *MemoryStore) GetGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) { +func (gs *MemoryStore) GetGroupSession(_ context.Context, roomID id.RoomID, sessionID id.SessionID) (*InboundGroupSession, error) { gs.lock.Lock() - session, ok := gs.getGroupSessions(roomID, senderKey)[sessionID] + defer gs.lock.Unlock() + session, ok := gs.getGroupSessions(roomID)[sessionID] if !ok { - withheld, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID] - gs.lock.Unlock() + withheld, ok := gs.getWithheldGroupSessions(roomID)[sessionID] if ok { return nil, fmt.Errorf("%w (%s)", ErrGroupSessionWithheld, withheld.Code) } return nil, nil } - gs.lock.Unlock() return session, nil } -func (gs *MemoryStore) RedactGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error { +func (gs *MemoryStore) RedactGroupSession(_ context.Context, roomID id.RoomID, sessionID id.SessionID, reason string) error { gs.lock.Lock() - delete(gs.getGroupSessions(roomID, senderKey), sessionID) - err := gs.save() - gs.lock.Unlock() - return err + defer gs.lock.Unlock() + delete(gs.getGroupSessions(roomID), sessionID) + return gs.save() } func (gs *MemoryStore) RedactGroupSessions(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) { gs.lock.Lock() + defer gs.lock.Unlock() var sessionIDs []id.SessionID if roomID != "" && senderKey != "" { - sessions := gs.getGroupSessions(roomID, senderKey) - for sessionID := range sessions { - sessionIDs = append(sessionIDs, sessionID) - delete(sessions, sessionID) + sessions := gs.getGroupSessions(roomID) + for sessionID, session := range sessions { + if session.SenderKey == senderKey { + sessionIDs = append(sessionIDs, sessionID) + delete(sessions, sessionID) + } } } else if senderKey != "" { for _, room := range gs.GroupSessions { - sessions, ok := room[senderKey] - if ok { - for sessionID := range sessions { + for sessionID, session := range room { + if session.SenderKey == senderKey { sessionIDs = append(sessionIDs, sessionID) + delete(room, sessionID) } - delete(room, senderKey) } } } else if roomID != "" { - room, ok := gs.GroupSessions[roomID] - if ok { - for senderKey := range room { - sessions := room[senderKey] - for sessionID := range sessions { - sessionIDs = append(sessionIDs, sessionID) - } - } - delete(gs.GroupSessions, roomID) - } + sessionIDs = maps.Keys(gs.GroupSessions[roomID]) + delete(gs.GroupSessions, roomID) } else { return nil, fmt.Errorf("room ID or sender key must be provided for redacting sessions") } - err := gs.save() - gs.lock.Unlock() - return sessionIDs, err + return sessionIDs, gs.save() } func (gs *MemoryStore) RedactExpiredGroupSessions(_ context.Context) ([]id.SessionID, error) { @@ -348,32 +330,26 @@ func (gs *MemoryStore) RedactOutdatedGroupSessions(_ context.Context) ([]id.Sess return nil, fmt.Errorf("not implemented") } -func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*event.RoomKeyWithheldEventContent { +func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID) map[id.SessionID]*event.RoomKeyWithheldEventContent { room, ok := gs.WithheldGroupSessions[roomID] if !ok { - room = make(map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent) + room = make(map[id.SessionID]*event.RoomKeyWithheldEventContent) gs.WithheldGroupSessions[roomID] = room } - sender, ok := room[senderKey] - if !ok { - sender = make(map[id.SessionID]*event.RoomKeyWithheldEventContent) - room[senderKey] = sender - } - return sender + return room } func (gs *MemoryStore) PutWithheldGroupSession(_ context.Context, content event.RoomKeyWithheldEventContent) error { gs.lock.Lock() - gs.getWithheldGroupSessions(content.RoomID, content.SenderKey)[content.SessionID] = &content - err := gs.save() - gs.lock.Unlock() - return err + defer gs.lock.Unlock() + gs.getWithheldGroupSessions(content.RoomID)[content.SessionID] = &content + return gs.save() } -func (gs *MemoryStore) GetWithheldGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { +func (gs *MemoryStore) GetWithheldGroupSession(_ context.Context, roomID id.RoomID, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { gs.lock.Lock() - session, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID] - gs.lock.Unlock() + defer gs.lock.Unlock() + session, ok := gs.getWithheldGroupSessions(roomID)[sessionID] if !ok { return nil, nil } @@ -387,51 +363,38 @@ func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.Room if !ok { return nil } - var result []*InboundGroupSession - for _, sessions := range room { - for _, session := range sessions { - result = append(result, session) - } - } - return dbutil.NewSliceIter[*InboundGroupSession](result) + return dbutil.NewSliceIter(maps.Values(room)) } func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) dbutil.RowIter[*InboundGroupSession] { gs.lock.Lock() + defer gs.lock.Unlock() var result []*InboundGroupSession for _, room := range gs.GroupSessions { - for _, sessions := range room { - for _, session := range sessions { - result = append(result, session) - } - } + result = append(result, maps.Values(room)...) } - gs.lock.Unlock() - return dbutil.NewSliceIter[*InboundGroupSession](result) + return dbutil.NewSliceIter(result) } func (gs *MemoryStore) GetGroupSessionsWithoutKeyBackupVersion(_ context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] { gs.lock.Lock() + defer gs.lock.Unlock() var result []*InboundGroupSession for _, room := range gs.GroupSessions { - for _, sessions := range room { - for _, session := range sessions { - if session.KeyBackupVersion != version { - result = append(result, session) - } + for _, session := range room { + if session.KeyBackupVersion != version { + result = append(result, session) } } } - gs.lock.Unlock() - return dbutil.NewSliceIter[*InboundGroupSession](result) + return dbutil.NewSliceIter(result) } func (gs *MemoryStore) AddOutboundGroupSession(_ context.Context, session *OutboundGroupSession) error { gs.lock.Lock() + defer gs.lock.Unlock() gs.OutGroupSessions[session.RoomID] = session - err := gs.save() - gs.lock.Unlock() - return err + return gs.save() } func (gs *MemoryStore) UpdateOutboundGroupSession(_ context.Context, _ *OutboundGroupSession) error { @@ -441,8 +404,8 @@ func (gs *MemoryStore) UpdateOutboundGroupSession(_ context.Context, _ *Outbound func (gs *MemoryStore) GetOutboundGroupSession(_ context.Context, roomID id.RoomID) (*OutboundGroupSession, error) { gs.lock.RLock() + defer gs.lock.RUnlock() session, ok := gs.OutGroupSessions[roomID] - gs.lock.RUnlock() if !ok { return nil, nil } @@ -451,18 +414,18 @@ func (gs *MemoryStore) GetOutboundGroupSession(_ context.Context, roomID id.Room func (gs *MemoryStore) RemoveOutboundGroupSession(_ context.Context, roomID id.RoomID) error { gs.lock.Lock() + defer gs.lock.Unlock() session, ok := gs.OutGroupSessions[roomID] if !ok || session == nil { - gs.lock.Unlock() return nil } delete(gs.OutGroupSessions, roomID) - gs.lock.Unlock() return nil } func (gs *MemoryStore) MarkOutboundGroupSessionShared(_ context.Context, userID id.UserID, identityKey id.IdentityKey, sessionID id.SessionID) error { gs.lock.Lock() + defer gs.lock.Unlock() if _, ok := gs.SharedGroupSessions[userID]; !ok { gs.SharedGroupSessions[userID] = make(map[id.IdentityKey]map[id.SessionID]struct{}) @@ -475,7 +438,6 @@ func (gs *MemoryStore) MarkOutboundGroupSessionShared(_ context.Context, userID identities[identityKey][sessionID] = struct{}{} - gs.lock.Unlock() return nil } @@ -521,11 +483,11 @@ func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.Send func (gs *MemoryStore) GetDevices(_ context.Context, userID id.UserID) (map[id.DeviceID]*id.Device, error) { gs.lock.RLock() + defer gs.lock.RUnlock() devices, ok := gs.Devices[userID] if !ok { devices = nil } - gs.lock.RUnlock() return devices, nil } @@ -560,30 +522,30 @@ func (gs *MemoryStore) FindDeviceByKey(_ context.Context, userID id.UserID, iden func (gs *MemoryStore) PutDevice(_ context.Context, userID id.UserID, device *id.Device) error { gs.lock.Lock() + defer gs.lock.Unlock() devices, ok := gs.Devices[userID] if !ok { devices = make(map[id.DeviceID]*id.Device) gs.Devices[userID] = devices } devices[device.DeviceID] = device - err := gs.save() - gs.lock.Unlock() - return err + return gs.save() } func (gs *MemoryStore) PutDevices(_ context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error { gs.lock.Lock() + defer gs.lock.Unlock() gs.Devices[userID] = devices err := gs.save() if err == nil { delete(gs.OutdatedUsers, userID) } - gs.lock.Unlock() return err } func (gs *MemoryStore) FilterTrackedUsers(_ context.Context, users []id.UserID) ([]id.UserID, error) { gs.lock.RLock() + defer gs.lock.RUnlock() var ptr int for _, userID := range users { _, ok := gs.Devices[userID] @@ -592,33 +554,33 @@ func (gs *MemoryStore) FilterTrackedUsers(_ context.Context, users []id.UserID) ptr++ } } - gs.lock.RUnlock() return users[:ptr], nil } func (gs *MemoryStore) MarkTrackedUsersOutdated(_ context.Context, users []id.UserID) error { gs.lock.Lock() + defer gs.lock.Unlock() for _, userID := range users { if _, ok := gs.Devices[userID]; ok { gs.OutdatedUsers[userID] = struct{}{} } } - gs.lock.Unlock() return nil } func (gs *MemoryStore) GetOutdatedTrackedUsers(_ context.Context) ([]id.UserID, error) { gs.lock.RLock() + defer gs.lock.RUnlock() users := make([]id.UserID, 0, len(gs.OutdatedUsers)) for userID := range gs.OutdatedUsers { users = append(users, userID) } - gs.lock.RUnlock() return users, nil } func (gs *MemoryStore) PutCrossSigningKey(_ context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error { gs.lock.RLock() + defer gs.lock.RUnlock() userKeys, ok := gs.CrossSigningKeys[userID] if !ok { userKeys = make(map[id.CrossSigningUsage]id.CrossSigningKey) @@ -635,7 +597,6 @@ func (gs *MemoryStore) PutCrossSigningKey(_ context.Context, userID id.UserID, u } } err := gs.save() - gs.lock.RUnlock() return err } @@ -651,6 +612,7 @@ func (gs *MemoryStore) GetCrossSigningKeys(_ context.Context, userID id.UserID) func (gs *MemoryStore) PutSignature(_ context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error { gs.lock.RLock() + defer gs.lock.RUnlock() signedUserSigs, ok := gs.KeySignatures[signedUserID] if !ok { signedUserSigs = make(map[id.Ed25519]map[id.UserID]map[id.Ed25519]string) @@ -667,9 +629,7 @@ func (gs *MemoryStore) PutSignature(_ context.Context, signedUserID id.UserID, s signaturesForKey[signerUserID] = signedByUser } signedByUser[signerKey] = signature - err := gs.save() - gs.lock.RUnlock() - return err + return gs.save() } func (gs *MemoryStore) GetSignaturesForKeyBy(_ context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) { @@ -700,8 +660,9 @@ func (gs *MemoryStore) IsKeySignedBy(ctx context.Context, userID id.UserID, key } func (gs *MemoryStore) DropSignaturesByKey(_ context.Context, userID id.UserID, key id.Ed25519) (int64, error) { - var count int64 gs.lock.RLock() + defer gs.lock.RUnlock() + var count int64 for _, userSigs := range gs.KeySignatures { for _, keySigs := range userSigs { if signedBySigner, ok := keySigs[userID]; ok { @@ -712,27 +673,25 @@ func (gs *MemoryStore) DropSignaturesByKey(_ context.Context, userID id.UserID, } } } - gs.lock.RUnlock() return count, nil } func (gs *MemoryStore) PutSecret(_ context.Context, name id.Secret, value string) error { gs.lock.Lock() + defer gs.lock.Unlock() gs.Secrets[name] = value - gs.lock.Unlock() return nil } -func (gs *MemoryStore) GetSecret(_ context.Context, name id.Secret) (value string, _ error) { +func (gs *MemoryStore) GetSecret(_ context.Context, name id.Secret) (string, error) { gs.lock.RLock() - value = gs.Secrets[name] - gs.lock.RUnlock() - return + defer gs.lock.RUnlock() + return gs.Secrets[name], nil } func (gs *MemoryStore) DeleteSecret(_ context.Context, name id.Secret) error { gs.lock.Lock() + defer gs.lock.Unlock() delete(gs.Secrets, name) - gs.lock.Unlock() return nil } From de0347db00ccd2212e1a7ff136b2c705fbf09643 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 14 May 2024 12:31:46 -0600 Subject: [PATCH 0192/1647] crypto: fix usages of Store interface Signed-off-by: Sumner Evans --- crypto/decryptmegolm.go | 8 ++++---- crypto/keybackup.go | 2 +- crypto/keyimport.go | 4 ++-- crypto/keysharing.go | 10 +++++----- crypto/machine.go | 6 +++--- crypto/machine_test.go | 4 ++-- crypto/store_test.go | 4 ++-- 7 files changed, 19 insertions(+), 19 deletions(-) diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index abe01871..faabdbd6 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -192,7 +192,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve mach.megolmDecryptLock.Lock() defer mach.megolmDecryptLock.Unlock() - sess, err := mach.CryptoStore.GetGroupSession(ctx, encryptionRoomID, content.SenderKey, content.SessionID) + sess, err := mach.CryptoStore.GetGroupSession(ctx, encryptionRoomID, content.SessionID) if err != nil { return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err) } else if sess == nil { @@ -254,7 +254,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve Int("max_messages", sess.MaxMessages). Logger() if sess.MaxMessages > 0 && int(ratchetTargetIndex) >= sess.MaxMessages && len(sess.RatchetSafety.MissedIndices) == 0 && mach.DeleteFullyUsedKeysOnDecrypt { - err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached") + err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached") if err != nil { log.Err(err).Msg("Failed to delete fully used session") return sess, plaintext, messageIndex, RatchetError @@ -265,14 +265,14 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve if err = sess.RatchetTo(ratchetTargetIndex); err != nil { log.Err(err).Msg("Failed to ratchet session") return sess, plaintext, messageIndex, RatchetError - } else if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil { + } else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil { log.Err(err).Msg("Failed to store ratcheted session") return sess, plaintext, messageIndex, RatchetError } else { log.Info().Msg("Ratcheted session forward") } } else if didModify { - if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil { + if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil { log.Err(err).Msg("Failed to store updated ratchet safety data") return sess, plaintext, messageIndex, RatchetError } else { diff --git a/crypto/keybackup.go b/crypto/keybackup.go index afdb84fd..820f3114 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -177,7 +177,7 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. MaxMessages: maxMessages, KeyBackupVersion: version, } - err = mach.CryptoStore.PutGroupSession(ctx, roomID, keyBackupData.SenderKey, sessionID, igs) + err = mach.CryptoStore.PutGroupSession(ctx, igs) if err != nil { return fmt.Errorf("failed to store new inbound group session: %w", err) } diff --git a/crypto/keyimport.go b/crypto/keyimport.go index da51774f..6c320f43 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -113,12 +113,12 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor ReceivedAt: time.Now().UTC(), } - existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID()) + existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() { // We already have an equivalent or better session in the store, so don't override it. return false, nil } - err = mach.CryptoStore.PutGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID(), igs) + err = mach.CryptoStore.PutGroupSession(ctx, igs) if err != nil { return false, fmt.Errorf("failed to store imported session: %w", err) } diff --git a/crypto/keysharing.go b/crypto/keysharing.go index d1b2e92c..ad0011e5 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -184,12 +184,12 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt MaxMessages: maxMessages, IsScheduled: content.IsScheduled, } - existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID()) + existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() { // We already have an equivalent or better session in the store, so don't override it. return false } - err = mach.CryptoStore.PutGroupSession(ctx, content.RoomID, content.SenderKey, content.SessionID, igs) + err = mach.CryptoStore.PutGroupSession(ctx, igs) if err != nil { log.Error().Err(err).Msg("Failed to store new inbound group session") return false @@ -308,7 +308,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User return } - igs, err := mach.CryptoStore.GetGroupSession(ctx, content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID) + igs, err := mach.CryptoStore.GetGroupSession(ctx, content.Body.RoomID, content.Body.SessionID) if err != nil { if errors.Is(err, ErrGroupSessionWithheld) { log.Debug().Err(err).Msg("Requested group session not available") @@ -365,7 +365,7 @@ func (mach *OlmMachine) HandleBeeperRoomKeyAck(ctx context.Context, sender id.Us Int("first_message_index", content.FirstMessageIndex). Logger() - sess, err := mach.CryptoStore.GetGroupSession(ctx, content.RoomID, "", content.SessionID) + sess, err := mach.CryptoStore.GetGroupSession(ctx, content.RoomID, content.SessionID) if err != nil { if errors.Is(err, ErrGroupSessionWithheld) { log.Debug().Err(err).Msg("Acked group session was already redacted") @@ -385,7 +385,7 @@ func (mach *OlmMachine) HandleBeeperRoomKeyAck(ctx context.Context, sender id.Us isInbound := sess.SenderKey == mach.OwnIdentity().IdentityKey if isInbound && mach.DeleteOutboundKeysOnAck && content.FirstMessageIndex == 0 { log.Debug().Msg("Redacting inbound copy of outbound group session after ack") - err = mach.CryptoStore.RedactGroupSession(ctx, content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked") + err = mach.CryptoStore.RedactGroupSession(ctx, content.RoomID, content.SessionID, "outbound session acked") if err != nil { log.Err(err).Msg("Failed to redact group session") } diff --git a/crypto/machine.go b/crypto/machine.go index 0cdeebea..efd65799 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -517,7 +517,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen Msg("Mismatched session ID while creating inbound group session") return fmt.Errorf("mismatched session ID while creating inbound group session") } - err = mach.CryptoStore.PutGroupSession(ctx, roomID, senderKey, sessionID, igs) + err = mach.CryptoStore.PutGroupSession(ctx, igs) if err != nil { log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session") return fmt.Errorf("failed to store new inbound group session: %w", err) @@ -557,7 +557,7 @@ func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, se } mach.keyWaitersLock.Unlock() // Handle race conditions where a session appears between the failed decryption and WaitForSession call. - sess, err := mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID) + sess, err := mach.CryptoStore.GetGroupSession(ctx, roomID, sessionID) if sess != nil || errors.Is(err, ErrGroupSessionWithheld) { return true } @@ -565,7 +565,7 @@ func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, se case <-ch: return true case <-time.After(timeout): - sess, err = mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID) + sess, err = mach.CryptoStore.GetGroupSession(ctx, roomID, sessionID) // Check if the session somehow appeared in the store without telling us // We accept withheld sessions as received, as then the decryption attempt will show the error. return sess != nil || errors.Is(err, ErrGroupSessionWithheld) diff --git a/crypto/machine_test.go b/crypto/machine_test.go index 807b65b2..59c86236 100644 --- a/crypto/machine_test.go +++ b/crypto/machine_test.go @@ -58,7 +58,7 @@ func TestRatchetMegolmSession(t *testing.T) { mach := newMachine(t, "user1") outSess, err := mach.newOutboundGroupSession(context.TODO(), "meow") assert.NoError(t, err) - inSess, err := mach.CryptoStore.GetGroupSession(context.TODO(), "meow", mach.OwnIdentity().IdentityKey, outSess.ID()) + inSess, err := mach.CryptoStore.GetGroupSession(context.TODO(), "meow", outSess.ID()) require.NoError(t, err) assert.Equal(t, uint32(0), inSess.Internal.FirstKnownIndex()) err = inSess.RatchetTo(10) @@ -130,7 +130,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { if err != nil { t.Errorf("Error creating inbound megolm session: %v", err) } - if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), "room1", senderKey, igs.ID(), igs); err != nil { + if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs); err != nil { t.Errorf("Error storing inbound megolm session: %v", err) } } diff --git a/crypto/store_test.go b/crypto/store_test.go index e6969e3e..740273dd 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -158,12 +158,12 @@ func TestStoreMegolmSession(t *testing.T) { RoomID: "room1", } - err = store.PutGroupSession(context.TODO(), "room1", acc.IdentityKey(), igs.ID(), igs) + err = store.PutGroupSession(context.TODO(), igs) if err != nil { t.Errorf("Error storing inbound group session: %v", err) } - retrieved, err := store.GetGroupSession(context.TODO(), "room1", acc.IdentityKey(), igs.ID()) + retrieved, err := store.GetGroupSession(context.TODO(), "room1", igs.ID()) if err != nil { t.Errorf("Error retrieving inbound group session: %v", err) } From 3651e46c1ebaf3cf06737aeb7a97bfcd880b303a Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 14 May 2024 21:55:32 -0600 Subject: [PATCH 0193/1647] ShareGroupSession: return error in more cases * If getting the devices from the database fails * If FetchKeys fails * If createOutboundSessions fails Signed-off-by: Sumner Evans --- crypto/encryptmegolm.go | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 19ff5c8c..634a685f 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -200,16 +200,17 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, olmSessions := make(map[id.UserID]map[id.DeviceID]deviceSessionWrapper) missingSessions := make(map[id.UserID]map[id.DeviceID]*id.Device) missingUserSessions := make(map[id.DeviceID]*id.Device) - var fetchKeys []id.UserID + var fetchKeysForUsers []id.UserID for _, userID := range users { log := log.With().Str("target_user_id", userID.String()).Logger() devices, err := mach.CryptoStore.GetDevices(ctx, userID) if err != nil { - log.Error().Err(err).Msg("Failed to get devices of user") + log.Err(err).Msg("Failed to get devices of user") + return fmt.Errorf("failed to get devices of user %s: %w", userID, err) } else if devices == nil { log.Debug().Msg("GetDevices returned nil, will fetch keys and retry") - fetchKeys = append(fetchKeys, userID) + fetchKeysForUsers = append(fetchKeysForUsers, userID) } else if len(devices) == 0 { log.Trace().Msg("User has no devices, skipping") } else { @@ -233,18 +234,19 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, } } - if len(fetchKeys) > 0 { - log.Debug().Strs("users", strishArray(fetchKeys)).Msg("Fetching missing keys") - if keys, err := mach.FetchKeys(ctx, fetchKeys, true); err != nil { - log.Err(err).Strs("users", strishArray(fetchKeys)).Msg("Failed to fetch missing keys") - } else if keys != nil { - for userID, devices := range keys { - log.Debug(). - Int("device_count", len(devices)). - Str("target_user_id", userID.String()). - Msg("Got device keys for user") - missingSessions[userID] = devices - } + if len(fetchKeysForUsers) > 0 { + log.Debug().Strs("users", strishArray(fetchKeysForUsers)).Msg("Fetching missing keys") + keys, err := mach.FetchKeys(ctx, fetchKeysForUsers, true) + if err != nil { + log.Err(err).Strs("users", strishArray(fetchKeysForUsers)).Msg("Failed to fetch missing keys") + return fmt.Errorf("failed to fetch missing keys: %w", err) + } + for userID, devices := range keys { + log.Debug(). + Int("device_count", len(devices)). + Str("target_user_id", userID.String()). + Msg("Got device keys for user") + missingSessions[userID] = devices } } @@ -252,7 +254,8 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, log.Debug().Msg("Creating missing olm sessions") err = mach.createOutboundSessions(ctx, missingSessions) if err != nil { - log.Error().Err(err).Msg("Failed to create missing olm sessions") + log.Err(err).Msg("Failed to create missing olm sessions") + return fmt.Errorf("failed to create missing olm sessions: %w", err) } } From 654b82ec73cc942fd3e61f1f367dd17980255a27 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 16 May 2024 16:04:51 +0300 Subject: [PATCH 0194/1647] Update dependencies --- go.mod | 10 +++++----- go.sum | 20 ++++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index 5afaff8d..4c975e1c 100644 --- a/go.mod +++ b/go.mod @@ -12,11 +12,11 @@ require ( github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.1 - go.mau.fi/util v0.4.2 + go.mau.fi/util v0.4.3-0.20240516130206-1051a7dd4dd2 go.mau.fi/zeroconfig v0.1.2 - golang.org/x/crypto v0.22.0 - golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8 - golang.org/x/net v0.24.0 + golang.org/x/crypto v0.23.0 + golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 + golang.org/x/net v0.25.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) @@ -29,6 +29,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.19.0 // indirect + golang.org/x/sys v0.20.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 3704b9f2..b3e72206 100644 --- a/go.sum +++ b/go.sum @@ -37,21 +37,21 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U= github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.4.2 h1:RR3TOcRHmCF9Bx/3YG4S65MYfa+nV6/rn8qBWW4Mi30= -go.mau.fi/util v0.4.2/go.mod h1:PlAVfUUcPyHPrwnvjkJM9UFcPE7qGPDJqk+Oufa1Gtw= +go.mau.fi/util v0.4.3-0.20240516130206-1051a7dd4dd2 h1:yL6ZH3O5AzH+5rKDrHGEdvmkmdePwjWfWvChD6nmY/M= +go.mau.fi/util v0.4.3-0.20240516130206-1051a7dd4dd2/go.mod h1:m+PJpPMadAW6cj3ldyuO5bLhFreWdwcu+3QTwYNGlGk= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= -golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= -golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8 h1:ESSUROHIBHg7USnszlcdmjBEwdMj9VUvU+OPk4yl2mc= -golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= -golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= -golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= -golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From 3bd42f5a8236b907e196da3e1482b7a2a4fc8668 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 16 May 2024 17:14:08 +0300 Subject: [PATCH 0195/1647] Add option to disable tracking megolm session ratchet state The tracking is meant for bridges/bots that want to delete old ratchet states after they're not needed, but for normal clients it's just unnecessary overhead --- crypto/decryptmegolm.go | 5 +++++ crypto/machine.go | 1 + go.mod | 2 +- go.sum | 4 ++-- 4 files changed, 9 insertions(+), 3 deletions(-) diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index faabdbd6..14beb96b 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -213,6 +213,11 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve return sess, nil, messageIndex, fmt.Errorf("%w %d", DuplicateMessageIndex, messageIndex) } + // Normal clients don't care about tracking the ratchet state, so let them bypass the rest of the function + if mach.DisableRatchetTracking { + return sess, plaintext, messageIndex, nil + } + expectedMessageIndex := sess.RatchetSafety.NextIndex didModify := false switch { diff --git a/crypto/machine.go b/crypto/machine.go index efd65799..abb8d540 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -78,6 +78,7 @@ type OlmMachine struct { RatchetKeysOnDecrypt bool DeleteFullyUsedKeysOnDecrypt bool DeleteKeysOnDeviceDelete bool + DisableRatchetTracking bool DisableDeviceChangeKeyRotation bool diff --git a/go.mod b/go.mod index 4c975e1c..f1c025b4 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.1 - go.mau.fi/util v0.4.3-0.20240516130206-1051a7dd4dd2 + go.mau.fi/util v0.4.3-0.20240516141139-2ebe792cd8f7 go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.23.0 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 diff --git a/go.sum b/go.sum index b3e72206..4e4a26c9 100644 --- a/go.sum +++ b/go.sum @@ -37,8 +37,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U= github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.4.3-0.20240516130206-1051a7dd4dd2 h1:yL6ZH3O5AzH+5rKDrHGEdvmkmdePwjWfWvChD6nmY/M= -go.mau.fi/util v0.4.3-0.20240516130206-1051a7dd4dd2/go.mod h1:m+PJpPMadAW6cj3ldyuO5bLhFreWdwcu+3QTwYNGlGk= +go.mau.fi/util v0.4.3-0.20240516141139-2ebe792cd8f7 h1:2hnc2iS7usHT3aqIQ8HVtKtPgic+13EVSdZ1m8UBL/E= +go.mau.fi/util v0.4.3-0.20240516141139-2ebe792cd8f7/go.mod h1:m+PJpPMadAW6cj3ldyuO5bLhFreWdwcu+3QTwYNGlGk= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= From dd1dfb9bab27a403e312bde65deb272a23585ca8 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 20 May 2024 10:57:51 -0600 Subject: [PATCH 0196/1647] pkcs7: update parameter names and documentation Signed-off-by: Sumner Evans --- crypto/pkcs7/pkcs7.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/crypto/pkcs7/pkcs7.go b/crypto/pkcs7/pkcs7.go index c83c5afd..dc28ed6a 100644 --- a/crypto/pkcs7/pkcs7.go +++ b/crypto/pkcs7/pkcs7.go @@ -8,23 +8,23 @@ package pkcs7 import "bytes" -// Pad implements PKCS#7 padding as defined in [RFC2315]. It pads the plaintext -// to the given blockSize in the range [1, 255]. This is normally used in -// AES-CBC encryption. +// Pad implements PKCS#7 padding as defined in [RFC2315]. It pads the data to +// the given blockSize in the range [1, 255]. This is normally used in AES-CBC +// encryption. // // [RFC2315]: https://www.ietf.org/rfc/rfc2315.txt -func Pad(plaintext []byte, blockSize int) []byte { - padding := blockSize - len(plaintext)%blockSize - return append(plaintext, bytes.Repeat([]byte{byte(padding)}, padding)...) +func Pad(data []byte, blockSize int) []byte { + padding := blockSize - len(data)%blockSize + return append(data, bytes.Repeat([]byte{byte(padding)}, padding)...) } // Unpad implements PKCS#7 unpadding as defined in [RFC2315]. It unpads the -// plaintext by reading the padding amount from the last byte of the plaintext. -// This is normally used in AES-CBC decryption. +// data by reading the padding amount from the last byte of the data. This is +// normally used in AES-CBC decryption. // // [RFC2315]: https://www.ietf.org/rfc/rfc2315.txt -func Unpad(plaintext []byte) []byte { - length := len(plaintext) - unpadding := int(plaintext[length-1]) - return plaintext[:length-unpadding] +func Unpad(data []byte) []byte { + length := len(data) + unpadding := int(data[length-1]) + return data[:length-unpadding] } From 1c054a4f5c103e053506d82a99f9d2e71da7fbbf Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 20 May 2024 10:59:34 -0600 Subject: [PATCH 0197/1647] verificationhelper: actually sign master key Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index ab177eb9..5eb654e3 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -76,6 +76,9 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by return fmt.Errorf("the other device has the wrong key for this device") } + if err := vh.mach.SignOwnMasterKey(ctx); err != nil { + return fmt.Errorf("failed to sign own master key: %w", err) + } case QRCodeModeSelfVerifyingMasterKeyUntrusted: // The QR was created by a device that does not trust the master key, // which means that we do trust the master key. Key1 is the other @@ -192,8 +195,10 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id return fmt.Errorf("failed to sign their device: %w", err) } } + } else { + // TODO: handle QR codes that are not self-signing situations + panic("unimplemented") } - // TODO: handle QR codes that are not self-signing situations err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { From 800d061426383a8a743dc362903e1c5dd0915069 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 20 May 2024 10:59:52 -0600 Subject: [PATCH 0198/1647] verificationhelper: fix check for whether we trust the master key Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index 5eb654e3..d543dd9f 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -231,7 +231,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve mode := QRCodeModeCrossSigning if vh.client.UserID == txn.TheirUser { // This is a self-signing situation. - if trusted, err := vh.mach.IsUserTrusted(ctx, vh.client.UserID); err != nil { + if trusted, err := vh.mach.CryptoStore.IsKeySignedBy(ctx, vh.client.UserID, ownCrossSigningPublicKeys.MasterKey, vh.client.UserID, vh.mach.OwnIdentity().SigningKey); err != nil { return err } else if trusted { mode = QRCodeModeSelfVerifyingMasterKeyTrusted From 816d94077ded55bf4d7f9b9d6193af291006d4f4 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 20 May 2024 11:13:15 -0600 Subject: [PATCH 0199/1647] verificationhelper: verify we trust master key when scanning a device that doesn't Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index d543dd9f..cb53833f 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -46,6 +46,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // Verify the keys log.Info().Msg("Verifying keys from QR code") + ownCrossSigningPublicKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) switch qrCode.Mode { case QRCodeModeCrossSigning: panic("unimplemented") @@ -60,8 +61,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by } // Verify the master key is correct - crossSigningPubkeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) - if bytes.Equal(crossSigningPubkeys.MasterKey.Bytes(), qrCode.Key1[:]) { + if bytes.Equal(ownCrossSigningPublicKeys.MasterKey.Bytes(), qrCode.Key1[:]) { log.Info().Msg("Verified that the other device has the same master key") } else { return fmt.Errorf("the master key does not match") @@ -85,6 +85,13 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // device's device key, and Key2 is what the other device thinks the // master key is. + // Check that we actually trust the master key. + if trusted, err := vh.mach.CryptoStore.IsKeySignedBy(ctx, vh.client.UserID, ownCrossSigningPublicKeys.MasterKey, vh.client.UserID, vh.mach.OwnIdentity().SigningKey); err != nil { + return err + } else if !trusted { + return fmt.Errorf("the master key is not trusted by this device") + } + if vh.client.UserID != txn.TheirUser { return fmt.Errorf("mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) } From ef65138cf9ec244601ae9cdc312866476ad90bda Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 21 May 2024 07:12:53 -0600 Subject: [PATCH 0200/1647] verification: check IdentityKey instead of SigningKey in QR mode 2 Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index cb53833f..2052d522 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -103,7 +103,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by } // Verify that the other device's key is what we expect. - if bytes.Equal(theirDevice.SigningKey.Bytes(), qrCode.Key1[:]) { + if bytes.Equal(theirDevice.IdentityKey.Bytes(), qrCode.Key1[:]) { log.Info().Msg("Verified that the other device key is what we expected") } else { return fmt.Errorf("the other device's key is not what we expected") From 4c8b63da5b62eaab32791b292a47c079394eef8b Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 21 May 2024 09:54:06 -0600 Subject: [PATCH 0201/1647] verification: log transaction ID and from_device on verification request Signed-off-by: Sumner Evans --- crypto/verificationhelper/verificationhelper.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 31f237bd..cede8156 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -573,7 +573,11 @@ func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *ev return } - log = log.With().Any("requested_methods", verificationRequest.Methods).Logger() + log = log.With(). + Any("requested_methods", verificationRequest.Methods). + Stringer("transaction_id", verificationRequest.TransactionID). + Stringer("from_device", verificationRequest.FromDevice). + Logger() ctx = log.WithContext(ctx) log.Info().Msg("Received verification request") From 9917b3ad3c44106dcaacc06fed3da6d987ee1640 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 21 May 2024 19:33:19 +0100 Subject: [PATCH 0202/1647] Add `UpdateRequestOnRetry` client hook Enables modifying the request object between retries, eg. to switch contexts after cancel. --- client.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/client.go b/client.go index cc133f95..10d1b2b9 100644 --- a/client.go +++ b/client.go @@ -67,6 +67,8 @@ type Client struct { RequestHook func(req *http.Request) ResponseHook func(req *http.Request, resp *http.Response, err error, duration time.Duration) + UpdateRequestOnRetry func(req *http.Request, cause error) *http.Request + SyncPresence event.Presence StreamSyncMinAge time.Duration @@ -457,6 +459,9 @@ func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff Int("retry_in_seconds", int(backoff.Seconds())). Msg("Request failed, retrying") time.Sleep(backoff) + if cli.UpdateRequestOnRetry != nil { + req = cli.UpdateRequestOnRetry(req, cause) + } return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, client) } From 55f47fbb16e90819a9d33fdfd2a1c0ec1c70f296 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 22 May 2024 17:20:40 -0600 Subject: [PATCH 0203/1647] verificationhelper: fix sending cancellation to other devices Signed-off-by: Sumner Evans --- crypto/verificationhelper/verificationhelper.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index cede8156..46112b5b 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -646,7 +646,7 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri req.Messages[txn.TheirUser][deviceID] = content } - _, err = vh.client.SendToDevice(ctx, event.ToDeviceVerificationRequest, &req) + _, err = vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { log.Warn().Err(err).Msg("Failed to send cancellation requests") } From 843ba24d0ab7b8eaa735207a59f0ce79a21f0591 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 23 May 2024 10:02:10 -0600 Subject: [PATCH 0204/1647] cross signing: don't require master private key to sign master public key Signed-off-by: Sumner Evans --- crypto/cross_sign_signing.go | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/crypto/cross_sign_signing.go b/crypto/cross_sign_signing.go index 86920728..1d80cc91 100644 --- a/crypto/cross_sign_signing.go +++ b/crypto/cross_sign_signing.go @@ -19,15 +19,15 @@ import ( ) var ( - ErrCrossSigningKeysNotCached = errors.New("cross-signing private keys not in cache") - ErrUserSigningKeyNotCached = errors.New("user-signing private key not in cache") - ErrSelfSigningKeyNotCached = errors.New("self-signing private key not in cache") - ErrSignatureUploadFail = errors.New("server-side failure uploading signatures") - ErrCantSignOwnMasterKey = errors.New("signing your own master key is not allowed") - ErrCantSignOtherDevice = errors.New("signing other users' devices is not allowed") - ErrUserNotInQueryResponse = errors.New("could not find user in query keys response") - ErrDeviceNotInQueryResponse = errors.New("could not find device in query keys response") - ErrOlmAccountNotLoaded = errors.New("olm account has not been loaded") + ErrCrossSigningPubkeysNotCached = errors.New("cross-signing public keys not in cache") + ErrUserSigningKeyNotCached = errors.New("user-signing private key not in cache") + ErrSelfSigningKeyNotCached = errors.New("self-signing private key not in cache") + ErrSignatureUploadFail = errors.New("server-side failure uploading signatures") + ErrCantSignOwnMasterKey = errors.New("signing your own master key is not allowed") + ErrCantSignOtherDevice = errors.New("signing other users' devices is not allowed") + ErrUserNotInQueryResponse = errors.New("could not find user in query keys response") + ErrDeviceNotInQueryResponse = errors.New("could not find device in query keys response") + ErrOlmAccountNotLoaded = errors.New("olm account has not been loaded") ErrCrossSigningMasterKeyNotFound = errors.New("cross-signing master key not found") ErrMasterKeyMACNotFound = errors.New("found cross-signing master key, but didn't find corresponding MAC in verification request") @@ -69,15 +69,16 @@ func (mach *OlmMachine) SignUser(ctx context.Context, userID id.UserID, masterKe // SignOwnMasterKey uses the current account for signing the current user's master key and uploads the signature. func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error { - if mach.CrossSigningKeys == nil { - return ErrCrossSigningKeysNotCached + crossSigningPubkeys := mach.GetOwnCrossSigningPublicKeys(ctx) + if crossSigningPubkeys == nil { + return ErrCrossSigningPubkeysNotCached } else if mach.account == nil { return ErrOlmAccountNotLoaded } userID := mach.Client.UserID deviceID := mach.Client.DeviceID - masterKey := mach.CrossSigningKeys.MasterKey.PublicKey() + masterKey := crossSigningPubkeys.MasterKey masterKeyObj := mautrix.ReqKeysSignatures{ UserID: userID, From 3e8221b17d3e3714d758d8fd453c6904207bd5e9 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 23 May 2024 12:55:09 -0600 Subject: [PATCH 0205/1647] verificationhelper: don't send cancellation to self Signed-off-by: Sumner Evans --- crypto/verificationhelper/verificationhelper.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 46112b5b..2acdd47f 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -638,9 +638,10 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri } req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} for deviceID := range devices { - if deviceID == txn.TheirDevice { + if deviceID == txn.TheirDevice || deviceID == vh.client.DeviceID { // Don't ever send a cancellation to the device that accepted - // the request. + // the request or to our own device (which can happen if this + // is a self-verification). continue } From 842852a6c148b49c3008f6a324dedf19e9b5640b Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 23 May 2024 16:45:49 -0600 Subject: [PATCH 0206/1647] crypto/cross_sign_ssss: trust master key during generation and upload Signed-off-by: Sumner Evans --- crypto/cross_sign_ssss.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/crypto/cross_sign_ssss.go b/crypto/cross_sign_ssss.go index 389a9fd2..540d625d 100644 --- a/crypto/cross_sign_ssss.go +++ b/crypto/cross_sign_ssss.go @@ -100,6 +100,12 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, u return "", nil, fmt.Errorf("failed to publish cross-signing keys: %w", err) } + // Trust the master key + err = mach.SignOwnMasterKey(ctx) + if err != nil { + return "", nil, fmt.Errorf("failed to sign own master key: %w", err) + } + err = mach.SSSS.SetDefaultKeyID(ctx, key.ID) if err != nil { return "", nil, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) From 30132a2c85f1a998af043adbb1b02253ea61b35c Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 21 May 2024 12:34:09 -0600 Subject: [PATCH 0207/1647] statestore: implement FindSharedRooms on MemoryStateStore Signed-off-by: Sumner Evans --- statestore.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/statestore.go b/statestore.go index 8fe5f8b3..c1193009 100644 --- a/statestore.go +++ b/statestore.go @@ -269,3 +269,14 @@ func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID cfg, err := store.GetEncryptionEvent(ctx, roomID) return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err } + +func (store *MemoryStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) (rooms []id.RoomID, err error) { + store.membersLock.RLock() + defer store.membersLock.RUnlock() + for roomID, members := range store.Members { + if _, ok := members[userID]; ok { + rooms = append(rooms, roomID) + } + } + return rooms, nil +} From 3bb4648c01eb82c3100c1f01ccd07d55c0c2d00c Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 24 May 2024 11:10:50 -0600 Subject: [PATCH 0208/1647] verification/qr: use SigningKey instead of IdentityKey It turns out that it's supposed to be the signing key. See discussion about it in the #e2e:matrix.org room: https://matrix.to/#/!vlnjqGLpLJlFmBSkfQ:matrix.org/$J6UbQwsakEsHMbv5yH7RUpM-OlklZ4U3Ti3VqWp9p8E?via=matrix.org&via=privacytools.io&via=envs.net This commit reverts commit ef65138cf9ec244601ae9cdc312866476ad90bda: verification: check IdentityKey instead of SigningKey in QR mode 2 It also fixes generation to use the signing key instead of the identity key. Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index 2052d522..c86f7d19 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -103,7 +103,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by } // Verify that the other device's key is what we expect. - if bytes.Equal(theirDevice.IdentityKey.Bytes(), qrCode.Key1[:]) { + if bytes.Equal(theirDevice.SigningKey.Bytes(), qrCode.Key1[:]) { log.Info().Msg("Verified that the other device key is what we expected") } else { return fmt.Errorf("the other device's key is not what we expected") @@ -268,10 +268,10 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve if err != nil { return err } - key2 = theirDevice.IdentityKey.Bytes() + key2 = theirDevice.SigningKey.Bytes() case QRCodeModeSelfVerifyingMasterKeyUntrusted: // Key 1 is the current device's key - key1 = vh.mach.OwnIdentity().IdentityKey.Bytes() + key1 = vh.mach.OwnIdentity().SigningKey.Bytes() // Key 2 is the master signing key. key2 = ownCrossSigningPublicKeys.MasterKey.Bytes() From 2e50f99e52f236e47cfafb3d9ddd428448881abe Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 24 May 2024 12:46:56 -0600 Subject: [PATCH 0209/1647] verificationhelper: don't move state to done until both devices have sent the done event Signed-off-by: Sumner Evans --- crypto/verificationhelper/verificationhelper.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 2acdd47f..2c7c05f8 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -755,9 +755,9 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verif return } - txn.VerificationState = verificationStateDone txn.ReceivedTheirDone = true if txn.SentOurDone { + txn.VerificationState = verificationStateDone vh.verificationDone(ctx, txn.TransactionID) } } From 3dbf8ef2f0192bc903cb62b2a0f3b863bbfc8087 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 21 May 2024 15:59:48 -0600 Subject: [PATCH 0210/1647] verificationhelper: better errors/logs and more aggressive cancellations Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 51 ++++---- crypto/verificationhelper/sas.go | 10 +- .../verificationhelper/verificationhelper.go | 116 +++++++++++++----- event/verification.go | 4 + 4 files changed, 119 insertions(+), 62 deletions(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index c86f7d19..9676f295 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -35,11 +35,9 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by txn, ok := vh.activeTransactions[qrCode.TransactionID] if !ok { - log.Warn().Msg("Ignoring QR code scan for an unknown transaction") - return nil + return fmt.Errorf("unknown transaction ID found in QR code") } else if txn.VerificationState != verificationStateReady { - log.Warn().Msg("Ignoring QR code scan for a transaction that is not in the ready state") - return nil + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "transaction found in the QR code is not in the ready state") } txn.VerificationState = verificationStateTheirQRScanned @@ -57,14 +55,14 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // key, and Key2 is what the other device thinks our device key is. if vh.client.UserID != txn.TheirUser { - return fmt.Errorf("mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) } // Verify the master key is correct if bytes.Equal(ownCrossSigningPublicKeys.MasterKey.Bytes(), qrCode.Key1[:]) { log.Info().Msg("Verified that the other device has the same master key") } else { - return fmt.Errorf("the master key does not match") + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the master key does not match") } // Verify that the device key that the other device things we have is @@ -73,11 +71,11 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by if bytes.Equal(myKeys.SigningKey.Bytes(), qrCode.Key2[:]) { log.Info().Msg("Verified that the other device has the correct key for this device") } else { - return fmt.Errorf("the other device has the wrong key for this device") + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the other device has the wrong key for this device") } if err := vh.mach.SignOwnMasterKey(ctx); err != nil { - return fmt.Errorf("failed to sign own master key: %w", err) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "failed to sign own master key: %w", err) } case QRCodeModeSelfVerifyingMasterKeyUntrusted: // The QR was created by a device that does not trust the master key, @@ -89,47 +87,47 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by if trusted, err := vh.mach.CryptoStore.IsKeySignedBy(ctx, vh.client.UserID, ownCrossSigningPublicKeys.MasterKey, vh.client.UserID, vh.mach.OwnIdentity().SigningKey); err != nil { return err } else if !trusted { - return fmt.Errorf("the master key is not trusted by this device") + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeMasterKeyNotTrusted, "the master key is not trusted by this device, cannot verify device that does not trust the master key") } if vh.client.UserID != txn.TheirUser { - return fmt.Errorf("mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) } // Get their device theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) if err != nil { - return err + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to get their device: %w", err) } // Verify that the other device's key is what we expect. if bytes.Equal(theirDevice.SigningKey.Bytes(), qrCode.Key1[:]) { log.Info().Msg("Verified that the other device key is what we expected") } else { - return fmt.Errorf("the other device's key is not what we expected") + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the other device's key is not what we expected") } // Verify that what they think the master key is is correct. if bytes.Equal(vh.mach.GetOwnCrossSigningPublicKeys(ctx).MasterKey.Bytes(), qrCode.Key2[:]) { log.Info().Msg("Verified that the other device has the correct master key") } else { - return fmt.Errorf("the master key does not match") + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the master key does not match") } // Trust their device theirDevice.Trust = id.TrustStateVerified err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) if err != nil { - return fmt.Errorf("failed to update device trust state after verifying: %w", err) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to update device trust state after verifying: %+v", err) } // Cross-sign their device with the self-signing key err = vh.mach.SignOwnDevice(ctx, theirDevice) if err != nil { - return fmt.Errorf("failed to sign their device: %w", err) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their device: %+v", err) } default: - return fmt.Errorf("unknown QR code mode %d", qrCode.Mode) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "unknown QR code mode %d", qrCode.Mode) } // Send a m.key.verification.start event with the secret @@ -141,17 +139,20 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by } err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationStart, txn.StartEventContent) if err != nil { - return err + return fmt.Errorf("failed to send m.key.verification.start event: %w", err) } + log.Debug().Msg("Successfully sent the m.key.verification.start event") // Immediately send the m.key.verification.done event, as our side of the // transaction is done. err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { - return err + return fmt.Errorf("failed to send m.key.verification.done event: %w", err) } + log.Debug().Msg("Successfully sent the m.key.verification.done event") txn.SentOurDone = true if txn.ReceivedTheirDone { + log.Debug().Msg("We already received their done event. Setting verification state to done.") txn.VerificationState = verificationStateDone vh.verificationDone(ctx, txn.TransactionID) } @@ -225,15 +226,17 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve Stringer("transaction_id", txn.TransactionID). Logger() if vh.showQRCode == nil { - log.Warn().Msg("Ignoring QR code generation request as showing a QR code is not enabled on this device") + log.Info().Msg("Ignoring QR code generation request as showing a QR code is not enabled on this device") return nil - } - if !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeScan) { - log.Warn().Msg("Ignoring QR code generation request as other device cannot scan QR codes") + } else if !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeScan) { + log.Info().Msg("Ignoring QR code generation request as other device cannot scan QR codes") return nil } ownCrossSigningPublicKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) + if ownCrossSigningPublicKeys == nil || len(ownCrossSigningPublicKeys.MasterKey) == 0 { + return fmt.Errorf("failed to get own cross-signing master public key") + } mode := QRCodeModeCrossSigning if vh.client.UserID == txn.TheirUser { @@ -245,6 +248,8 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve } else { mode = QRCodeModeSelfVerifyingMasterKeyUntrusted } + } else { + panic("unimplemented") } var key1, key2 []byte @@ -276,7 +281,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve // Key 2 is the master signing key. key2 = ownCrossSigningPublicKeys.MasterKey.Bytes() default: - log.Fatal().Str("mode", string(mode)).Msg("Unknown QR code mode") + log.Fatal().Int("mode", int(mode)).Msg("Unknown QR code mode") } qrCode := NewQRCode(mode, txn.TransactionID, [32]byte(key1), [32]byte(key2)) diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 8160a4e1..6e523ba5 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -580,7 +580,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi slices.Sort(keyIDs) expectedKeyMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeSASMismatch, "failed to calculate key list MAC: %v", err) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeSASMismatch, "failed to calculate key list MAC: %w", err) return } if !bytes.Equal(expectedKeyMAC, macEvt.Keys) { @@ -607,7 +607,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi if kID == txn.TheirDevice.String() { theirDevice, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to fetch their device: %v", err) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to fetch their device: %w", err) return } key = theirDevice.SigningKey.String() @@ -626,7 +626,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi expectedMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, keyID.String(), key) if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to calculate key MAC: %v", err) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to calculate key MAC: %w", err) return } if !bytes.Equal(expectedMAC, mac) { @@ -639,7 +639,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi theirDevice.Trust = id.TrustStateVerified err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to update device trust state after verifying: %v", err) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to update device trust state after verifying: %w", err) return } } @@ -653,7 +653,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi txn.VerificationState = verificationStateSASMACExchanged err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to send verification done event: %v", err) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to send verification done event: %w", err) return } txn.SentOurDone = true diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 2c7c05f8..e2166a15 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -178,31 +178,31 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, call } if c, ok := callbacks.(RequiredCallbacks); !ok { - panic("callbacks must implement VerificationRequested") + panic("callbacks must implement RequiredCallbacks") } else { helper.verificationRequested = c.VerificationRequested helper.verificationCancelledCallback = c.VerificationCancelled helper.verificationDone = c.VerificationDone } + supportedMethods := map[event.VerificationMethod]struct{}{} if c, ok := callbacks.(showSASCallbacks); ok { - helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS) + supportedMethods[event.VerificationMethodSAS] = struct{}{} helper.showSAS = c.ShowSAS } if c, ok := callbacks.(showQRCodeCallbacks); ok { - helper.supportedMethods = append(helper.supportedMethods, - event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate) + supportedMethods[event.VerificationMethodQRCodeShow] = struct{}{} + supportedMethods[event.VerificationMethodReciprocate] = struct{}{} helper.scanQRCode = c.ScanQRCode helper.showQRCode = c.ShowQRCode helper.qrCodeScaned = c.QRCodeScanned } if supportsScan { - helper.supportedMethods = append(helper.supportedMethods, - event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate) + supportedMethods[event.VerificationMethodQRCodeScan] = struct{}{} + supportedMethods[event.VerificationMethodReciprocate] = struct{}{} } - slices.Sort(helper.supportedMethods) - helper.supportedMethods = slices.Compact(helper.supportedMethods) + helper.supportedMethods = maps.Keys(supportedMethods) return &helper } @@ -332,6 +332,10 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { // StartVerification starts an interactive verification flow with the given // user via a to-device event. func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserID) (id.VerificationTransactionID, error) { + if len(vh.supportedMethods) == 0 { + return "", fmt.Errorf("no supported verification methods") + } + txnID := id.NewVerificationTransactionID() devices, err := vh.mach.CryptoStore.GetDevices(ctx, to) @@ -389,8 +393,8 @@ func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserI return txnID, nil } -// StartVerification starts an interactive verification flow with the given -// user in the given room. +// StartInRoomVerification starts an interactive verification flow with the +// given user in the given room. func (vh *VerificationHelper) StartInRoomVerification(ctx context.Context, roomID id.RoomID, to id.UserID) (id.VerificationTransactionID, error) { log := vh.getLog(ctx).With(). Str("verification_action", "start in-room verification"). @@ -441,15 +445,34 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V txn, ok := vh.activeTransactions[txnID] if !ok { return fmt.Errorf("unknown transaction ID") - } - if txn.VerificationState != verificationStateRequested { + } else if txn.VerificationState != verificationStateRequested { return fmt.Errorf("transaction is not in the requested state") } + supportedMethods := map[event.VerificationMethod]struct{}{} + for _, method := range txn.TheirSupportedMethods { + switch method { + case event.VerificationMethodSAS: + if slices.Contains(vh.supportedMethods, event.VerificationMethodSAS) { + supportedMethods[event.VerificationMethodSAS] = struct{}{} + } + case event.VerificationMethodQRCodeShow: + if slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) { + supportedMethods[event.VerificationMethodQRCodeScan] = struct{}{} + supportedMethods[event.VerificationMethodReciprocate] = struct{}{} + } + case event.VerificationMethodQRCodeScan: + if slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeShow) { + supportedMethods[event.VerificationMethodQRCodeShow] = struct{}{} + supportedMethods[event.VerificationMethodReciprocate] = struct{}{} + } + } + } + log.Info().Msg("Sending ready event") readyEvt := &event.VerificationReadyEventContent{ FromDevice: vh.client.DeviceID, - Methods: vh.supportedMethods, + Methods: maps.Keys(supportedMethods), } err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationReady, readyEvt) if err != nil { @@ -457,7 +480,7 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V } txn.VerificationState = verificationStateReady - if slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { + if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { vh.scanQRCode(ctx, txn.TransactionID) } @@ -497,7 +520,7 @@ func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn *ve Parsed: content, }) if err != nil { - return fmt.Errorf("failed to send start event: %w", err) + return fmt.Errorf("failed to send %s event to %s: %w", evtType.String(), txn.RoomID, err) } } else { content.(event.VerificationTransactionable).SetTransactionID(txn.TransactionID) @@ -508,15 +531,19 @@ func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn *ve }} _, err := vh.client.SendToDevice(ctx, evtType, &req) if err != nil { - return fmt.Errorf("failed to send start event: %w", err) + return fmt.Errorf("failed to send %s event to %s: %w", evtType.String(), txn.TheirDevice, err) } } return nil } +// cancelVerificationTxn cancels a verification transaction with the given code +// and reason. It always returns an error, which is the formatted error message +// (this is allows the caller to return the result of this function call +// directly to expose the error to its caller). func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn *verificationTransaction, code event.VerificationCancelCode, reasonFmtStr string, fmtArgs ...any) error { log := vh.getLog(ctx) - reason := fmt.Sprintf(reasonFmtStr, fmtArgs...) + reason := fmt.Errorf(reasonFmtStr, fmtArgs...).Error() log.Info(). Stringer("transaction_id", txn.TransactionID). Str("code", string(code)). @@ -528,11 +555,11 @@ func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn *ve } err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, cancelEvt) if err != nil { - return err + return fmt.Errorf("failed to send cancel verification event (code: %s, reason: %s): %w", code, reason, err) } txn.VerificationState = verificationStateCancelled vh.verificationCancelledCallback(ctx, txn.TransactionID, code, reason) - return nil + return fmt.Errorf("verification cancelled (code: %s): %s", code, reason) } func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *event.Event) { @@ -568,7 +595,7 @@ func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *ev return } - if verificationRequest.TransactionID == "" { + if len(verificationRequest.TransactionID) == 0 { log.Warn().Msg("Ignoring verification request without a transaction ID") return } @@ -581,11 +608,33 @@ func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *ev ctx = log.WithContext(ctx) log.Info().Msg("Received verification request") + // Check if we support any of the methods listed + var supportsAnyMethod bool + for _, method := range verificationRequest.Methods { + switch method { + case event.VerificationMethodSAS: + supportsAnyMethod = slices.Contains(vh.supportedMethods, event.VerificationMethodSAS) + case event.VerificationMethodQRCodeScan: + supportsAnyMethod = slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeShow) && + slices.Contains(verificationRequest.Methods, event.VerificationMethodReciprocate) + case event.VerificationMethodQRCodeShow: + supportsAnyMethod = slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && + slices.Contains(verificationRequest.Methods, event.VerificationMethodReciprocate) + } + if supportsAnyMethod { + break + } + } + if !supportsAnyMethod { + log.Warn().Msg("Ignoring verification request that doesn't have any methods we support") + return + } + vh.activeTransactionsLock.Lock() - _, ok := vh.activeTransactions[verificationRequest.TransactionID] + existing, ok := vh.activeTransactions[verificationRequest.TransactionID] if ok { vh.activeTransactionsLock.Unlock() - log.Info().Msg("Ignoring verification request for an already active transaction") + vh.cancelVerificationTxn(ctx, existing, event.VerificationCancelCodeUnexpectedMessage, "received a new verification request for the same transaction ID") return } vh.activeTransactions[verificationRequest.TransactionID] = &verificationTransaction{ @@ -607,7 +656,7 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri Logger() if txn.VerificationState != verificationStateRequested { - log.Warn().Msg("Ignoring verification ready event for a transaction that is not in the requested state") + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "verification ready event received for a transaction that is not in the requested state") return } @@ -633,7 +682,7 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri } devices, err := vh.mach.CryptoStore.GetDevices(ctx, txn.TheirUser) if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to get devices for %s: %v", txn.TheirUser, err) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to get devices for %s: %w", txn.TheirUser, err) return } req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} @@ -653,13 +702,13 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri } } - if slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { + if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { vh.scanQRCode(ctx, txn.TransactionID) } err := vh.generateAndShowQRCode(ctx, txn) if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to generate and show QR code: %v", err) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to generate and show QR code: %w", err) } } @@ -680,8 +729,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri // We didn't sent a start event yet, so we have gotten ourselves // into a bad state. They've either sent two start events, or we // have gone on to a new state. - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, - "got repeat start event from other user") + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got repeat start event from other user") return } @@ -713,8 +761,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri txn.StartEventContent = startEvt } } else if txn.VerificationState != verificationStateReady { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, - "got start event for transaction that is not in ready state") + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got start event for transaction that is not in ready state") return } @@ -722,7 +769,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri case event.VerificationMethodSAS: txn.VerificationState = verificationStateSASStarted if err := vh.onVerificationStartSAS(ctx, txn, evt); err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to handle SAS verification start: %v", err) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to handle SAS verification start: %w", err) } case event.VerificationMethodReciprocate: log.Info().Msg("Received reciprocate start event") @@ -749,9 +796,10 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verif vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != verificationStateTheirQRScanned && txn.VerificationState != verificationStateSASMACExchanged { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, - "got done event for transaction that is not in QR-scanned or MAC-exchanged state") + if !slices.Contains([]verificationState{ + verificationStateTheirQRScanned, verificationStateOurQRScanned, verificationStateSASMACExchanged, + }, txn.VerificationState) { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got done event for transaction that is not in QR-scanned or MAC-exchanged state") return } diff --git a/event/verification.go b/event/verification.go index b1851de3..6101896f 100644 --- a/event/verification.go +++ b/event/verification.go @@ -220,6 +220,10 @@ const ( VerificationCancelCodeAccepted VerificationCancelCode = "m.accepted" VerificationCancelCodeSASMismatch VerificationCancelCode = "m.mismatched_sas" VerificationCancelCodeCommitmentMismatch VerificationCancelCode = "m.mismatched_commitment" + + // Non-spec codes + VerificationCancelCodeInternalError VerificationCancelCode = "com.beeper.internal_error" + VerificationCancelCodeMasterKeyNotTrusted VerificationCancelCode = "com.beeper.master_key_not_trusted" // the master key is not trusted by this device, but the QR code that was scanned was from a device that doesn't trust the master key ) // VerificationCancelEventContent represents the content of an From 2195043eba98e7123c90006996a81ec3f94e83fe Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 22 May 2024 11:56:35 -0600 Subject: [PATCH 0211/1647] verificationhelper: add E2E tests Signed-off-by: Sumner Evans --- crypto/verificationhelper/callbacks_test.go | 136 ++++ crypto/verificationhelper/mockserver_test.go | 248 +++++++ crypto/verificationhelper/qrcode_test.go | 71 +- crypto/verificationhelper/reciprocate.go | 3 +- .../verificationhelper_self_test.go | 687 ++++++++++++++++++ 5 files changed, 1120 insertions(+), 25 deletions(-) create mode 100644 crypto/verificationhelper/callbacks_test.go create mode 100644 crypto/verificationhelper/mockserver_test.go create mode 100644 crypto/verificationhelper/verificationhelper_self_test.go diff --git a/crypto/verificationhelper/callbacks_test.go b/crypto/verificationhelper/callbacks_test.go new file mode 100644 index 00000000..7fc35129 --- /dev/null +++ b/crypto/verificationhelper/callbacks_test.go @@ -0,0 +1,136 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package verificationhelper_test + +import ( + "context" + + "maunium.net/go/mautrix/crypto/verificationhelper" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type MockVerificationCallbacks interface { + GetRequestedVerifications() map[id.UserID][]id.VerificationTransactionID + GetScanQRCodeTransactions() []id.VerificationTransactionID + GetQRCodeShown(id.VerificationTransactionID) *verificationhelper.QRCode +} + +type baseVerificationCallbacks struct { + scanQRCodeTransactions []id.VerificationTransactionID + verificationsRequested map[id.UserID][]id.VerificationTransactionID + qrCodesShown map[id.VerificationTransactionID]*verificationhelper.QRCode + qrCodesScanned map[id.VerificationTransactionID]struct{} + doneTransactions map[id.VerificationTransactionID]struct{} + verificationCancellation map[id.VerificationTransactionID]*event.VerificationCancelEventContent +} + +func newBaseVerificationCallbacks() *baseVerificationCallbacks { + return &baseVerificationCallbacks{ + verificationsRequested: map[id.UserID][]id.VerificationTransactionID{}, + qrCodesShown: map[id.VerificationTransactionID]*verificationhelper.QRCode{}, + qrCodesScanned: map[id.VerificationTransactionID]struct{}{}, + doneTransactions: map[id.VerificationTransactionID]struct{}{}, + verificationCancellation: map[id.VerificationTransactionID]*event.VerificationCancelEventContent{}, + } +} + +func (c *baseVerificationCallbacks) GetRequestedVerifications() map[id.UserID][]id.VerificationTransactionID { + return c.verificationsRequested +} + +func (c *baseVerificationCallbacks) GetScanQRCodeTransactions() []id.VerificationTransactionID { + return c.scanQRCodeTransactions +} + +func (c *baseVerificationCallbacks) GetQRCodeShown(txnID id.VerificationTransactionID) *verificationhelper.QRCode { + return c.qrCodesShown[txnID] +} + +func (c *baseVerificationCallbacks) WasOurQRCodeScanned(txnID id.VerificationTransactionID) bool { + _, ok := c.qrCodesScanned[txnID] + return ok +} + +func (c *baseVerificationCallbacks) IsVerificationDone(txnID id.VerificationTransactionID) bool { + _, ok := c.doneTransactions[txnID] + return ok +} + +func (c *baseVerificationCallbacks) GetVerificationCancellation(txnID id.VerificationTransactionID) *event.VerificationCancelEventContent { + return c.verificationCancellation[txnID] +} + +func (c *baseVerificationCallbacks) VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) { + c.verificationsRequested[from] = append(c.verificationsRequested[from], txnID) +} + +func (c *baseVerificationCallbacks) VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) { + c.verificationCancellation[txnID] = &event.VerificationCancelEventContent{ + Code: code, + Reason: reason, + } +} + +func (c *baseVerificationCallbacks) VerificationDone(ctx context.Context, txnID id.VerificationTransactionID) { + c.doneTransactions[txnID] = struct{}{} +} + +type sasVerificationCallbacks struct { + *baseVerificationCallbacks +} + +func newSASVerificationCallbacks() *sasVerificationCallbacks { + return &sasVerificationCallbacks{newBaseVerificationCallbacks()} +} + +func newSASVerificationCallbacksWithBase(base *baseVerificationCallbacks) *sasVerificationCallbacks { + return &sasVerificationCallbacks{base} +} + +func (*sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) { + panic("show sas") +} + +type qrCodeVerificationCallbacks struct { + *baseVerificationCallbacks +} + +func newQRCodeVerificationCallbacks() *qrCodeVerificationCallbacks { + return &qrCodeVerificationCallbacks{newBaseVerificationCallbacks()} +} + +func newQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *qrCodeVerificationCallbacks { + return &qrCodeVerificationCallbacks{base} +} + +func (c *qrCodeVerificationCallbacks) ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) { + c.scanQRCodeTransactions = append(c.scanQRCodeTransactions, txnID) +} + +func (c *qrCodeVerificationCallbacks) ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *verificationhelper.QRCode) { + c.qrCodesShown[txnID] = qrCode +} + +func (c *qrCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) { + c.qrCodesScanned[txnID] = struct{}{} +} + +type allVerificationCallbacks struct { + *baseVerificationCallbacks + *sasVerificationCallbacks + *qrCodeVerificationCallbacks +} + +func newAllVerificationCallbacks() *allVerificationCallbacks { + base := newBaseVerificationCallbacks() + return &allVerificationCallbacks{ + base, + newSASVerificationCallbacksWithBase(base), + newQRCodeVerificationCallbacksWithBase(base), + } +} diff --git a/crypto/verificationhelper/mockserver_test.go b/crypto/verificationhelper/mockserver_test.go new file mode 100644 index 00000000..3c640267 --- /dev/null +++ b/crypto/verificationhelper/mockserver_test.go @@ -0,0 +1,248 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package verificationhelper_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/gorilla/mux" + "github.com/stretchr/testify/require" + "go.mau.fi/util/random" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/cryptohelper" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +// mockServer is a mock Matrix server that wraps an [httptest.Server] to allow +// testing of the interactive verification process. +type mockServer struct { + *httptest.Server + + AccessTokenToUserID map[string]id.UserID + DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event + AccountData map[id.UserID]map[event.Type]json.RawMessage + DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys + MasterKeys map[id.UserID]mautrix.CrossSigningKeys + SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys + UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys +} + +func DecodeVarsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + var err error + for k, v := range vars { + vars[k], err = url.PathUnescape(v) + if err != nil { + panic(err) + } + } + next.ServeHTTP(w, r) + }) +} + +func createMockServer(t *testing.T) *mockServer { + t.Helper() + + server := mockServer{ + AccessTokenToUserID: map[string]id.UserID{}, + DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{}, + AccountData: map[id.UserID]map[event.Type]json.RawMessage{}, + DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, + MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + } + + router := mux.NewRouter().SkipClean(true).StrictSlash(false).UseEncodedPath() + router.Use(DecodeVarsMiddleware) + router.HandleFunc("/_matrix/client/v3/login", server.postLogin).Methods(http.MethodPost) + router.HandleFunc("/_matrix/client/v3/keys/query", server.postKeysQuery).Methods(http.MethodPost) + router.HandleFunc("/_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice).Methods(http.MethodPut) + router.HandleFunc("/_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData).Methods(http.MethodPut) + router.HandleFunc("/_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload).Methods(http.MethodPost) + router.HandleFunc("/_matrix/client/v3/keys/signatures/upload", server.emptyResp).Methods(http.MethodPost) + router.HandleFunc("/_matrix/client/v3/keys/upload", server.postKeysUpload).Methods(http.MethodPost) + + server.Server = httptest.NewServer(router) + return &server +} + +func (ms *mockServer) getUserID(r *http.Request) id.UserID { + authHeader := r.Header.Get("Authorization") + authHeader = strings.TrimPrefix(authHeader, "Bearer ") + userID, ok := ms.AccessTokenToUserID[authHeader] + if !ok { + panic("no user ID found for access token " + authHeader) + } + return userID +} + +func (s *mockServer) emptyResp(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte("{}")) +} + +func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) { + var loginReq mautrix.ReqLogin + json.NewDecoder(r.Body).Decode(&loginReq) + + deviceID := loginReq.DeviceID + if deviceID == "" { + deviceID = id.DeviceID(random.String(10)) + } + + accessToken := random.String(30) + userID := id.UserID(loginReq.Identifier.User) + s.AccessTokenToUserID[accessToken] = userID + + json.NewEncoder(w).Encode(&mautrix.RespLogin{ + AccessToken: accessToken, + DeviceID: deviceID, + UserID: userID, + }) +} + +func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + var req mautrix.ReqSendToDevice + json.NewDecoder(r.Body).Decode(&req) + evtType := event.Type{Type: vars["type"], Class: event.ToDeviceEventType} + + for user, devices := range req.Messages { + for device, content := range devices { + if _, ok := s.DeviceInbox[user]; !ok { + s.DeviceInbox[user] = map[id.DeviceID][]event.Event{} + } + content.ParseRaw(evtType) + s.DeviceInbox[user][device] = append(s.DeviceInbox[user][device], event.Event{ + Sender: s.getUserID(r), + Type: evtType, + Content: *content, + }) + } + } + s.emptyResp(w, r) +} + +func (s *mockServer) putAccountData(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + userID := id.UserID(vars["userID"]) + eventType := event.Type{Type: vars["type"], Class: event.AccountDataEventType} + + jsonData, _ := io.ReadAll(r.Body) + if _, ok := s.AccountData[userID]; !ok { + s.AccountData[userID] = map[event.Type]json.RawMessage{} + } + s.AccountData[userID][eventType] = json.RawMessage(jsonData) + s.emptyResp(w, r) +} + +func (s *mockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) { + var req mautrix.ReqQueryKeys + json.NewDecoder(r.Body).Decode(&req) + resp := mautrix.RespQueryKeys{ + MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, + } + for user := range req.DeviceKeys { + resp.MasterKeys[user] = s.MasterKeys[user] + resp.UserSigningKeys[user] = s.UserSigningKeys[user] + resp.SelfSigningKeys[user] = s.SelfSigningKeys[user] + resp.DeviceKeys[user] = s.DeviceKeys[user] + } + json.NewEncoder(w).Encode(&resp) +} + +func (s *mockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) { + var req mautrix.ReqUploadKeys + json.NewDecoder(r.Body).Decode(&req) + + userID := s.getUserID(r) + if _, ok := s.DeviceKeys[userID]; !ok { + s.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{} + } + s.DeviceKeys[userID][req.DeviceKeys.DeviceID] = *req.DeviceKeys + + json.NewEncoder(w).Encode(&mautrix.RespUploadKeys{ + OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: 50}, + }) +} + +func (s *mockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) { + var req mautrix.UploadCrossSigningKeysReq + json.NewDecoder(r.Body).Decode(&req) + + userID := s.getUserID(r) + s.MasterKeys[userID] = req.Master + s.SelfSigningKeys[userID] = req.SelfSigning + s.UserSigningKeys[userID] = req.UserSigning + + s.emptyResp(w, r) +} + +func (ms *mockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) { + t.Helper() + client, err := mautrix.NewClient(ms.URL, "", "") + require.NoError(t, err) + client.StateStore = mautrix.NewMemoryStateStore() + + _, err = client.Login(ctx, &mautrix.ReqLogin{ + Type: mautrix.AuthTypePassword, + Identifier: mautrix.UserIdentifier{ + Type: mautrix.IdentifierTypeUser, + User: userID.String(), + }, + DeviceID: deviceID, + Password: "password", + StoreCredentials: true, + }) + require.NoError(t, err) + + cryptoStore := crypto.NewMemoryStore(nil) + cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), cryptoStore) + require.NoError(t, err) + client.Crypto = cryptoHelper + + err = cryptoHelper.Init(ctx) + require.NoError(t, err) + + err = cryptoHelper.Machine().ShareKeys(ctx, 50) + require.NoError(t, err) + + return client, cryptoStore +} + +func (ms *mockServer) dispatchToDevice(t *testing.T, ctx context.Context, client *mautrix.Client) { + t.Helper() + + for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] { + client.Syncer.(*mautrix.DefaultSyncer).Dispatch(ctx, &evt) + ms.DeviceInbox[client.UserID][client.DeviceID] = ms.DeviceInbox[client.UserID][client.DeviceID][1:] + } +} + +func addDeviceID(ctx context.Context, cryptoStore crypto.Store, userID id.UserID, deviceID id.DeviceID) { + err := cryptoStore.PutDevice(ctx, userID, &id.Device{ + UserID: userID, + DeviceID: deviceID, + }) + if err != nil { + panic(err) + } +} diff --git a/crypto/verificationhelper/qrcode_test.go b/crypto/verificationhelper/qrcode_test.go index d2767734..fda3de2c 100644 --- a/crypto/verificationhelper/qrcode_test.go +++ b/crypto/verificationhelper/qrcode_test.go @@ -8,51 +8,76 @@ package verificationhelper_test import ( "bytes" + "encoding/base64" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "maunium.net/go/mautrix/crypto/verificationhelper" + "maunium.net/go/mautrix/id" ) func TestQRCode_Roundtrip(t *testing.T) { var key1, key2 [32]byte copy(key1[:], bytes.Repeat([]byte{0x01}, 32)) copy(key2[:], bytes.Repeat([]byte{0x02}, 32)) - qrCode := verificationhelper.NewQRCode(verificationhelper.QRCodeModeCrossSigning, "test", key1, key2) + txnID := id.VerificationTransactionID(strings.Repeat("a", 20)) + qrCode := verificationhelper.NewQRCode(verificationhelper.QRCodeModeCrossSigning, txnID, key1, key2) encoded := qrCode.Bytes() decoded, err := verificationhelper.NewQRCodeFromBytes(encoded) require.NoError(t, err) assert.Equal(t, verificationhelper.QRCodeModeCrossSigning, decoded.Mode) - assert.EqualValues(t, "test", decoded.TransactionID) + assert.EqualValues(t, txnID, decoded.TransactionID) assert.Equal(t, key1, decoded.Key1) assert.Equal(t, key2, decoded.Key2) } func TestQRCodeDecode(t *testing.T) { - qrcodeData := []byte{ - 0x4d, 0x41, 0x54, 0x52, 0x49, 0x58, 0x02, 0x01, 0x00, 0x20, 0x47, 0x6e, 0x41, 0x65, 0x43, 0x76, - 0x74, 0x57, 0x6a, 0x7a, 0x4d, 0x4f, 0x56, 0x57, 0x51, 0x54, 0x6b, 0x74, 0x33, 0x35, 0x59, 0x52, - 0x55, 0x72, 0x75, 0x6a, 0x6d, 0x52, 0x50, 0x63, 0x38, 0x61, 0x18, 0x32, 0x7c, 0xc3, 0x8c, 0xc2, - 0xa6, 0xc2, 0xb5, 0xc2, 0xa7, 0x50, 0x57, 0x67, 0x19, 0x5e, 0xc3, 0xaf, 0xc2, 0xa0, 0xc2, 0x98, - 0xc2, 0x9d, 0x36, 0xc3, 0xad, 0x7a, 0x10, 0x2e, 0x18, 0x3e, 0x4e, 0xc3, 0x84, 0xc3, 0x81, 0x45, - 0x0c, 0xc2, 0xae, 0x19, 0x78, 0xc2, 0x99, 0x06, 0xc2, 0x92, 0xc2, 0x94, 0xc2, 0x8e, 0xc2, 0xb7, - 0x59, 0xc2, 0x96, 0xc2, 0xad, 0xc3, 0xbd, 0x70, 0x6a, 0x11, 0xc2, 0xba, 0xc2, 0xa9, 0x29, 0xc3, - 0x8f, 0x0d, 0xc2, 0xb8, 0xc2, 0x88, 0x67, 0x5b, 0xc3, 0xb3, 0x01, 0xc2, 0xb0, 0x63, 0x2e, 0xc2, - 0xa5, 0xc3, 0xb3, 0x60, 0xc3, 0x82, 0x04, 0xc3, 0xa3, 0x72, 0x7d, 0x7c, 0x1d, 0xc2, 0xb6, 0xc2, - 0xba, 0xc2, 0x81, 0x1e, 0xc2, 0x99, 0xc2, 0xb8, 0x7f, 0x0a, + testCases := []struct { + b64 string + txnID string + key1 string + key2 string + sharedSecret string + }{ + { + "TUFUUklYAgEAIEduQWVDdnRXanpNT1ZXUVRrdDM1WVJVcnVqbVJQYzhhGDJ8w4zCpsK1wqdQV2cZXsOvwqDCmMKdNsOtehAuGD5Ow4TDgUUMwq4ZeMKZBsKSwpTCjsK3WcKWwq3DvXBqEcK6wqkpw48NwrjCiGdbw7MBwrBjLsKlw7Ngw4IEw6NyfXwdwrbCusKBHsKZwrh/Cg==", + "GnAeCvtWjzMOVWQTkt35YRUrujmRPc8a", + "GDJ8w4zCpsK1wqdQV2cZXsOvwqDCmMKdNsOtehAuGD4=", + "TsOEw4FFDMKuGXjCmQbCksKUwo7Ct1nClsKtw71wahE=", + "wrrCqSnDjw3CuMKIZ1vDswHCsGMuwqXDs2DDggTDo3J9fB3CtsK6woEewpnCuH8K", + }, + { + "TUFUUklYAgEAIGM1YjljNzE3ZWIzYjRmYzBiZDhhZjA0MDQ4NDY5MDdle4oLkpUdO1cTu5M3K3B4BlnpxtAbVgXCuQKOIqMmt+xAjVvaEXF39X0z5waRY9UE0b5PKiWvOBSJHEGkxX28Y2OEDLIWP/kCVUlyXXENlj0=", + "c5b9c717eb3b4fc0bd8af0404846907e", + "e4oLkpUdO1cTu5M3K3B4BlnpxtAbVgXCuQKOIqMmt+w=", + "QI1b2hFxd/V9M+cGkWPVBNG+TyolrzgUiRxBpMV9vGM=", + "Y4QMshY/+QJVSXJdcQ2WPQ==", + }, + } + + for _, tc := range testCases { + t.Run(tc.b64, func(t *testing.T) { + qrcodeData, err := base64.StdEncoding.DecodeString(tc.b64) + require.NoError(t, err) + expectedKey1, err := base64.StdEncoding.DecodeString(tc.key1) + require.NoError(t, err) + expectedKey2, err := base64.StdEncoding.DecodeString(tc.key2) + require.NoError(t, err) + expectedSharedSecret, err := base64.StdEncoding.DecodeString(tc.sharedSecret) + require.NoError(t, err) + + decoded, err := verificationhelper.NewQRCodeFromBytes(qrcodeData) + require.NoError(t, err) + assert.Equal(t, verificationhelper.QRCodeModeSelfVerifyingMasterKeyTrusted, decoded.Mode) + assert.EqualValues(t, tc.txnID, decoded.TransactionID) + assert.EqualValues(t, expectedKey1, decoded.Key1) + assert.EqualValues(t, expectedKey2, decoded.Key2) + assert.EqualValues(t, expectedSharedSecret, decoded.SharedSecret) + }) } - decoded, err := verificationhelper.NewQRCodeFromBytes(qrcodeData) - require.NoError(t, err) - assert.Equal(t, verificationhelper.QRCodeModeSelfVerifyingMasterKeyTrusted, decoded.Mode) - assert.EqualValues(t, "GnAeCvtWjzMOVWQTkt35YRUrujmRPc8a", decoded.TransactionID) - assert.Equal(t, - [32]byte{0x18, 0x32, 0x7c, 0xc3, 0x8c, 0xc2, 0xa6, 0xc2, 0xb5, 0xc2, 0xa7, 0x50, 0x57, 0x67, 0x19, 0x5e, 0xc3, 0xaf, 0xc2, 0xa0, 0xc2, 0x98, 0xc2, 0x9d, 0x36, 0xc3, 0xad, 0x7a, 0x10, 0x2e, 0x18, 0x3e}, - decoded.Key1) - assert.Equal(t, - [32]byte{0x4e, 0xc3, 0x84, 0xc3, 0x81, 0x45, 0xc, 0xc2, 0xae, 0x19, 0x78, 0xc2, 0x99, 0x6, 0xc2, 0x92, 0xc2, 0x94, 0xc2, 0x8e, 0xc2, 0xb7, 0x59, 0xc2, 0x96, 0xc2, 0xad, 0xc3, 0xbd, 0x70, 0x6a, 0x11}, - decoded.Key2) } diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index 9676f295..e99e55fd 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -174,8 +174,7 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id log.Warn().Msg("Ignoring QR code scan confirmation for an unknown transaction") return nil } else if txn.VerificationState != verificationStateOurQRScanned { - log.Warn().Msg("Ignoring QR code scan confirmation for a transaction that is not in the started state") - return nil + return fmt.Errorf("transaction is not in the scanned state") } log.Info().Msg("Confirming QR code scanned") diff --git a/crypto/verificationhelper/verificationhelper_self_test.go b/crypto/verificationhelper/verificationhelper_self_test.go new file mode 100644 index 00000000..08e2c6e4 --- /dev/null +++ b/crypto/verificationhelper/verificationhelper_self_test.go @@ -0,0 +1,687 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package verificationhelper_test + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/cryptohelper" + "maunium.net/go/mautrix/crypto/verificationhelper" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +var userID = id.UserID("@alice:example.org") +var sendingDeviceID = id.DeviceID("sending") +var receivingDeviceID = id.DeviceID("receiving") + +func init() { + log.Logger = zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339}).With().Timestamp().Logger().Level(zerolog.TraceLevel) + zerolog.DefaultContextLogger = &log.Logger +} + +func initServerAndLogin(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { + t.Helper() + ts = createMockServer(t) + + sendingClient, sendingCryptoStore = ts.Login(t, ctx, userID, sendingDeviceID) + sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() + receivingClient, receivingCryptoStore = ts.Login(t, ctx, userID, receivingDeviceID) + receivingMachine = receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() + + err := sendingCryptoStore.PutDevice(ctx, userID, sendingMachine.OwnIdentity()) + require.NoError(t, err) + err = sendingCryptoStore.PutDevice(ctx, userID, receivingMachine.OwnIdentity()) + require.NoError(t, err) + err = receivingCryptoStore.PutDevice(ctx, userID, sendingMachine.OwnIdentity()) + require.NoError(t, err) + err = receivingCryptoStore.PutDevice(ctx, userID, receivingMachine.OwnIdentity()) + require.NoError(t, err) + return +} + +func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, receivingClient *mautrix.Client, sendingMachine, receivingMachine *crypto.OlmMachine) (sendingCallbacks, receivingCallbacks *allVerificationCallbacks, sendingHelper, receivingHelper *verificationhelper.VerificationHelper) { + t.Helper() + sendingCallbacks = newAllVerificationCallbacks() + sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, sendingCallbacks, true) + require.NoError(t, sendingHelper.Init(ctx)) + + receivingCallbacks = newAllVerificationCallbacks() + receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receivingCallbacks, true) + require.NoError(t, receivingHelper.Init(ctx)) + return +} + +func TestSelfVerification_Start(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + receivingDeviceID2 := id.DeviceID("receiving2") + + testCases := []struct { + supportsScan bool + callbacks MockVerificationCallbacks + startVerificationErrMsg string + expectedVerificationMethods []event.VerificationMethod + }{ + {false, newBaseVerificationCallbacks(), "no supported verification methods", nil}, + {true, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {false, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, + {true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, + {true, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, + {true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + ts := createMockServer(t) + defer ts.Close() + + client, cryptoStore := ts.Login(t, ctx, userID, sendingDeviceID) + addDeviceID(ctx, cryptoStore, userID, sendingDeviceID) + addDeviceID(ctx, cryptoStore, userID, receivingDeviceID) + addDeviceID(ctx, cryptoStore, userID, receivingDeviceID2) + + senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), tc.callbacks, tc.supportsScan) + err := senderHelper.Init(ctx) + require.NoError(t, err) + + txnID, err := senderHelper.StartVerification(ctx, userID) + if tc.startVerificationErrMsg != "" { + assert.ErrorContains(t, err, tc.startVerificationErrMsg) + return + } + + assert.NoError(t, err) + assert.NotEmpty(t, txnID) + + toDeviceInbox := ts.DeviceInbox[userID] + + // Ensure that we didn't send a verification request to the + // sending device. + assert.Empty(t, toDeviceInbox[sendingDeviceID]) + + // Ensure that the verification request was sent to both of + // the other devices. + assert.NotEmpty(t, toDeviceInbox[receivingDeviceID]) + assert.NotEmpty(t, toDeviceInbox[receivingDeviceID2]) + assert.Equal(t, toDeviceInbox[receivingDeviceID], toDeviceInbox[receivingDeviceID2]) + assert.Len(t, toDeviceInbox[receivingDeviceID], 1) + + // Ensure that the verification request is correct. + verificationRequest := toDeviceInbox[receivingDeviceID][0].Content.AsVerificationRequest() + assert.Equal(t, sendingDeviceID, verificationRequest.FromDevice) + assert.Equal(t, txnID, verificationRequest.TransactionID) + assert.ElementsMatch(t, tc.expectedVerificationMethods, verificationRequest.Methods) + }) + } +} + +func TestSelfVerification_Accept_NoSupportedMethods(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + ts := createMockServer(t) + defer ts.Close() + + sendingClient, sendingCryptoStore := ts.Login(t, ctx, userID, sendingDeviceID) + receivingClient, _ := ts.Login(t, ctx, userID, receivingDeviceID) + addDeviceID(ctx, sendingCryptoStore, userID, sendingDeviceID) + addDeviceID(ctx, sendingCryptoStore, userID, receivingDeviceID) + + sendingMachine := sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() + recoveryKey, cache, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + assert.NoError(t, err) + assert.NotEmpty(t, recoveryKey) + assert.NotNil(t, cache) + + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, newAllVerificationCallbacks(), true) + err = sendingHelper.Init(ctx) + require.NoError(t, err) + + receivingCallbacks := newBaseVerificationCallbacks() + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), receivingCallbacks, false) + err = receivingHelper.Init(ctx) + require.NoError(t, err) + + txnID, err := sendingHelper.StartVerification(ctx, userID) + require.NoError(t, err) + require.NotEmpty(t, txnID) + + ts.dispatchToDevice(t, ctx, receivingClient) + + // Ensure that the receiver ignored the request because it + // doesn't support any of the verification methods in the + // request. + assert.Empty(t, receivingCallbacks.GetRequestedVerifications()) +} + +func TestSelfVerification_Accept_CorrectMethodsPresented(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + testCases := []struct { + sendingSupportsScan bool + receivingSupportsScan bool + sendingCallbacks MockVerificationCallbacks + receivingCallbacks MockVerificationCallbacks + expectedVerificationMethods []event.VerificationMethod + }{ + {false, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, + {true, false, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, true, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {true, false, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, + {true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + defer ts.Close() + + recoveryKey, sendingCrossSigningKeysCache, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + assert.NoError(t, err) + assert.NotEmpty(t, recoveryKey) + assert.NotNil(t, sendingCrossSigningKeysCache) + + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, tc.sendingCallbacks, tc.sendingSupportsScan) + err = sendingHelper.Init(ctx) + require.NoError(t, err) + + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, tc.receivingCallbacks, tc.receivingSupportsScan) + err = receivingHelper.Init(ctx) + require.NoError(t, err) + + txnID, err := sendingHelper.StartVerification(ctx, userID) + require.NoError(t, err) + + // Process the verification request on the receiving device. + ts.dispatchToDevice(t, ctx, receivingClient) + + // Ensure that the receiving device received a verification + // request with the correct transaction ID. + assert.ElementsMatch(t, []id.VerificationTransactionID{txnID}, tc.receivingCallbacks.GetRequestedVerifications()[userID]) + + // Have the receiving device accept the verification request. + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + + _, sendingIsQRCallbacks := tc.sendingCallbacks.(*qrCodeVerificationCallbacks) + _, sendingIsAllCallbacks := tc.sendingCallbacks.(*allVerificationCallbacks) + sendingCanShowQR := sendingIsQRCallbacks || sendingIsAllCallbacks + _, receivingIsQRCallbacks := tc.receivingCallbacks.(*qrCodeVerificationCallbacks) + _, receivingIsAllCallbacks := tc.receivingCallbacks.(*allVerificationCallbacks) + receivingCanShowQR := receivingIsQRCallbacks || receivingIsAllCallbacks + + // Ensure that if the receiving device should show a QR code that + // it has the correct content. + if tc.sendingSupportsScan && receivingCanShowQR { + receivingShownQRCode := tc.receivingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, receivingShownQRCode) + assert.Equal(t, txnID, receivingShownQRCode.TransactionID) + assert.NotEmpty(t, receivingShownQRCode.SharedSecret) + } + + // Check for whether the receiving device should be scanning a QR + // code. + if tc.receivingSupportsScan && sendingCanShowQR { + assert.Contains(t, tc.receivingCallbacks.GetScanQRCodeTransactions(), txnID) + } + + // Check that the m.key.verification.ready event has the correct + // content. + sendingInbox := ts.DeviceInbox[userID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + readyEvt := sendingInbox[0].Content.AsVerificationReady() + assert.Equal(t, txnID, readyEvt.TransactionID) + assert.Equal(t, receivingDeviceID, readyEvt.FromDevice) + assert.ElementsMatch(t, tc.expectedVerificationMethods, readyEvt.Methods) + + // Receive the m.key.verification.ready event on the sending + // device. + ts.dispatchToDevice(t, ctx, sendingClient) + + // Ensure that if the sending device should show a QR code that it + // has the correct content. + if tc.receivingSupportsScan && sendingCanShowQR { + sendingShownQRCode := tc.sendingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, sendingShownQRCode) + assert.Equal(t, txnID, sendingShownQRCode.TransactionID) + assert.NotEmpty(t, sendingShownQRCode.SharedSecret) + } + + // Check for whether the sending device should be scanning a QR + // code. + if tc.sendingSupportsScan && receivingCanShowQR { + assert.Contains(t, tc.sendingCallbacks.GetScanQRCodeTransactions(), txnID) + } + }) + } +} + +func TestSelfVerification_Accept_QRContents(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + testCases := []struct { + sendingGeneratedCrossSigningKeys bool + receivingGeneratedCrossSigningKeys bool + expectedAcceptError string + }{ + {true, false, ""}, + {false, true, ""}, + {false, false, "failed to get own cross-signing master public key"}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("sendingGenerated=%t receivingGenerated=%t err=%s", tc.sendingGeneratedCrossSigningKeys, tc.receivingGeneratedCrossSigningKeys, tc.expectedAcceptError), func(t *testing.T) { + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + defer ts.Close() + sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + var err error + + var sendingRecoveryKey, receivingRecoveryKey string + var sendingCrossSigningKeysCache, receivingCrossSigningKeysCache *crypto.CrossSigningKeysCache + + if tc.sendingGeneratedCrossSigningKeys { + sendingRecoveryKey, sendingCrossSigningKeysCache, err = sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + assert.NotEmpty(t, sendingRecoveryKey) + assert.NotNil(t, sendingCrossSigningKeysCache) + } + + if tc.receivingGeneratedCrossSigningKeys { + receivingRecoveryKey, receivingCrossSigningKeysCache, err = receivingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + assert.NotEmpty(t, receivingRecoveryKey) + assert.NotNil(t, receivingCrossSigningKeysCache) + } + + // Send the verification request from the sender device and accept + // it on the receiving device and receive the verification ready + // event on the sending device. + txnID, err := sendingHelper.StartVerification(ctx, userID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, receivingClient) + + err = receivingHelper.AcceptVerification(ctx, txnID) + if tc.expectedAcceptError != "" { + assert.ErrorContains(t, err, tc.expectedAcceptError) + return + } else { + require.NoError(t, err) + } + + ts.dispatchToDevice(t, ctx, sendingClient) + + receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, receivingShownQRCode) + assert.NotEmpty(t, receivingShownQRCode.SharedSecret) + assert.Equal(t, txnID, receivingShownQRCode.TransactionID) + + sendingShownQRCode := sendingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, sendingShownQRCode) + assert.NotEmpty(t, sendingShownQRCode.SharedSecret) + assert.Equal(t, txnID, sendingShownQRCode.TransactionID) + + // See the spec for the QR Code format: + // https://spec.matrix.org/v1.10/client-server-api/#qr-code-format + if tc.receivingGeneratedCrossSigningKeys { + masterKeyBytes := receivingMachine.GetOwnCrossSigningPublicKeys(ctx).MasterKey.Bytes() + + // The receiving device should have shown a QR Code with + // trusted mode + assert.Equal(t, verificationhelper.QRCodeModeSelfVerifyingMasterKeyTrusted, receivingShownQRCode.Mode) + assert.EqualValues(t, masterKeyBytes, receivingShownQRCode.Key1) // master key + assert.EqualValues(t, sendingMachine.OwnIdentity().SigningKey.Bytes(), receivingShownQRCode.Key2) // other device key + + // The sending device should have shown a QR code with + // untrusted mode. + assert.Equal(t, verificationhelper.QRCodeModeSelfVerifyingMasterKeyUntrusted, sendingShownQRCode.Mode) + assert.EqualValues(t, sendingMachine.OwnIdentity().SigningKey.Bytes(), sendingShownQRCode.Key1) // own device key + assert.EqualValues(t, masterKeyBytes, sendingShownQRCode.Key2) // master key + } else if tc.sendingGeneratedCrossSigningKeys { + masterKeyBytes := sendingMachine.GetOwnCrossSigningPublicKeys(ctx).MasterKey.Bytes() + + // The receiving device should have shown a QR code with + // untrusted mode + assert.Equal(t, verificationhelper.QRCodeModeSelfVerifyingMasterKeyUntrusted, receivingShownQRCode.Mode) + assert.EqualValues(t, receivingMachine.OwnIdentity().SigningKey.Bytes(), receivingShownQRCode.Key1) // own device key + assert.EqualValues(t, masterKeyBytes, receivingShownQRCode.Key2) // master key + + // The sending device should have shown a QR code with trusted + // mode. + assert.Equal(t, verificationhelper.QRCodeModeSelfVerifyingMasterKeyTrusted, sendingShownQRCode.Mode) + assert.EqualValues(t, masterKeyBytes, sendingShownQRCode.Key1) // master key + assert.EqualValues(t, receivingMachine.OwnIdentity().SigningKey.Bytes(), sendingShownQRCode.Key2) // other device key + } + }) + } +} + +// TestAcceptSelfVerificationCancelOnNonParticipatingDevices ensures that we do +// not regress https://github.com/mautrix/go/pull/230. +func TestSelfVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + defer ts.Close() + _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + + nonParticipatingDeviceID1 := id.DeviceID("non-participating1") + nonParticipatingDeviceID2 := id.DeviceID("non-participating2") + addDeviceID(ctx, sendingCryptoStore, userID, nonParticipatingDeviceID1) + addDeviceID(ctx, sendingCryptoStore, userID, nonParticipatingDeviceID2) + addDeviceID(ctx, receivingCryptoStore, userID, nonParticipatingDeviceID1) + addDeviceID(ctx, receivingCryptoStore, userID, nonParticipatingDeviceID2) + + _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + assert.NoError(t, err) + + // Send the verification request from the sender device and accept it on + // the receiving device. + txnID, err := sendingHelper.StartVerification(ctx, userID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + + // Receive the m.key.verification.ready event on the sending device. + ts.dispatchToDevice(t, ctx, sendingClient) + + // The sending and receiving devices should not have any cancellation + // events in their inboxes. + assert.Empty(t, ts.DeviceInbox[userID][sendingDeviceID]) + assert.Empty(t, ts.DeviceInbox[userID][receivingDeviceID]) + + // There should now be cancellation events in the non-participating devices + // inboxes (in addition to the request event). + assert.Len(t, ts.DeviceInbox[userID][nonParticipatingDeviceID1], 2) + assert.Len(t, ts.DeviceInbox[userID][nonParticipatingDeviceID2], 2) + assert.Equal(t, ts.DeviceInbox[userID][nonParticipatingDeviceID1][1], ts.DeviceInbox[userID][nonParticipatingDeviceID2][1]) + cancellationEvent := ts.DeviceInbox[userID][nonParticipatingDeviceID1][1].Content.AsVerificationCancel() + assert.Equal(t, txnID, cancellationEvent.TransactionID) + assert.Equal(t, event.VerificationCancelCodeAccepted, cancellationEvent.Code) +} + +func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + testCases := []struct { + sendingGeneratedCrossSigningKeys bool + sendingScansQR bool // false indicates that receiving device should emulate a scan + }{ + {false, false}, + {false, true}, + {true, false}, + {true, true}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR), func(t *testing.T) { + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + defer ts.Close() + sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + var err error + + if tc.sendingGeneratedCrossSigningKeys { + _, _, err = sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + } else { + _, _, err = receivingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + } + + // Send the verification request from the sender device and accept + // it on the receiving device and receive the verification ready + // event on the sending device. + txnID, err := sendingHelper.StartVerification(ctx, userID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, sendingClient) + + receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, receivingShownQRCode) + sendingShownQRCode := sendingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, sendingShownQRCode) + + if tc.sendingScansQR { + // Emulate scanning the QR code shown by the receiving device + // on the sending device. + err := sendingHelper.HandleScannedQRData(ctx, receivingShownQRCode.Bytes()) + require.NoError(t, err) + + // Ensure that the receiving device received a verification + // start event and a verification done event. + receivingInbox := ts.DeviceInbox[userID][receivingDeviceID] + assert.Len(t, receivingInbox, 2) + + startEvt := receivingInbox[0].Content.AsVerificationStart() + assert.Equal(t, txnID, startEvt.TransactionID) + assert.Equal(t, sendingDeviceID, startEvt.FromDevice) + assert.Equal(t, event.VerificationMethodReciprocate, startEvt.Method) + assert.EqualValues(t, receivingShownQRCode.SharedSecret, startEvt.Secret) + + doneEvt := receivingInbox[1].Content.AsVerificationDone() + assert.Equal(t, txnID, doneEvt.TransactionID) + + // Handle the start and done events on the receiving client and + // confirm the scan. + ts.dispatchToDevice(t, ctx, receivingClient) + + // Ensure that the receiving device detected that its QR code + // was scanned. + assert.True(t, receivingCallbacks.WasOurQRCodeScanned(txnID)) + err = receivingHelper.ConfirmQRCodeScanned(ctx, txnID) + require.NoError(t, err) + + // Ensure that the sending device received a verification done + // event. + sendingInbox := ts.DeviceInbox[userID][sendingDeviceID] + require.Len(t, sendingInbox, 1) + doneEvt = sendingInbox[0].Content.AsVerificationDone() + assert.Equal(t, txnID, doneEvt.TransactionID) + + ts.dispatchToDevice(t, ctx, sendingClient) + } else { // receiving scans QR + // Emulate scanning the QR code shown by the sending device on + // the receiving device. + err := receivingHelper.HandleScannedQRData(ctx, sendingShownQRCode.Bytes()) + require.NoError(t, err) + + // Ensure that the sending device received a verification + // start event and a verification done event. + sendingInbox := ts.DeviceInbox[userID][sendingDeviceID] + assert.Len(t, sendingInbox, 2) + + startEvt := sendingInbox[0].Content.AsVerificationStart() + assert.Equal(t, txnID, startEvt.TransactionID) + assert.Equal(t, receivingDeviceID, startEvt.FromDevice) + assert.Equal(t, event.VerificationMethodReciprocate, startEvt.Method) + assert.EqualValues(t, sendingShownQRCode.SharedSecret, startEvt.Secret) + + doneEvt := sendingInbox[1].Content.AsVerificationDone() + assert.Equal(t, txnID, doneEvt.TransactionID) + + // Handle the start and done events on the receiving client and + // confirm the scan. + ts.dispatchToDevice(t, ctx, sendingClient) + + // Ensure that the sending device detected that its QR code was + // scanned. + assert.True(t, sendingCallbacks.WasOurQRCodeScanned(txnID)) + err = sendingHelper.ConfirmQRCodeScanned(ctx, txnID) + require.NoError(t, err) + + // Ensure that the receiving device received a verification + // done event. + receivingInbox := ts.DeviceInbox[userID][receivingDeviceID] + require.Len(t, receivingInbox, 1) + doneEvt = receivingInbox[0].Content.AsVerificationDone() + assert.Equal(t, txnID, doneEvt.TransactionID) + + ts.dispatchToDevice(t, ctx, receivingClient) + } + + // Ensure that both devices have marked the verification as done. + assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) + assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) + }) + } +} + +func TestSelfVerification_ErrorOnDoubleAccept(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + defer ts.Close() + _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + + _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + + txnID, err := sendingHelper.StartVerification(ctx, userID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.ErrorContains(t, err, "transaction is not in the requested state") +} + +func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + defer ts.Close() + sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + var err error + + _, _, err = sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + + // Send the verification request from the sender device and accept + // it on the receiving device and receive the verification ready + // event on the sending device. + txnID, err := sendingHelper.StartVerification(ctx, userID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, sendingClient) + + receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes() + sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes() + + // Corrupt the QR codes (the 20th byte should be in the transaction ID) + receivingShownQRCodeBytes[20]++ + sendingShownQRCodeBytes[20]++ + + // Emulate scanning the QR code shown by the receiving device + // on the sending device. + err = sendingHelper.HandleScannedQRData(ctx, receivingShownQRCodeBytes) + assert.ErrorContains(t, err, "unknown transaction ID found in QR code") + + // Emulate scanning the QR code shown by the sending device on + // the receiving device. + err = receivingHelper.HandleScannedQRData(ctx, sendingShownQRCodeBytes) + assert.ErrorContains(t, err, "unknown transaction ID found in QR code") +} + +func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + testCases := []struct { + sendingGeneratedCrossSigningKeys bool + sendingScansQR bool // false indicates that receiving device should emulate a scan + corruptByte int + expectedError string + }{ + // The 50th byte should be in the first key + {false, false, 50, "the other device's key is not what we expected"}, // receiver scans sender QR code, sender doesn't trust the master key => mode 0x02 => key1 == sender device key + {false, true, 50, "the master key does not match"}, // sender scans receiver QR code, receiver trusts the master key => mode 0x01 => key1 == master key + {true, false, 50, "the master key does not match"}, // receiver scans sender QR code, sender trusts the master key => mode 0x01 => key1 == master key + {true, true, 50, "the other device's key is not what we expected"}, // sender scans receiver QR Code, receiver doesn't trust the master key => mode 0x02 => key1 == receiver device key + // The 100th byte should be in the second key + {false, false, 100, "the master key does not match"}, // receiver scans sender QR code, sender doesn't trust the master key => mode 0x02 => key2 == master key + {false, true, 100, "the other device has the wrong key for this device"}, // sender scans receiver QR code, receiver trusts the master key => mode 0x01 => key2 == sender device key + {true, false, 100, "the other device has the wrong key for this device"}, // receiver scans sender QR code, sender trusts the master key => mode 0x01 => key2 == receiver device key + {true, true, 100, "the master key does not match"}, // sender scans receiver QR Code, receiver doesn't trust the master key => mode 0x02 => key2 == master key + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t corrupt=%d", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR, tc.corruptByte), func(t *testing.T) { + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + defer ts.Close() + sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + var err error + + if tc.sendingGeneratedCrossSigningKeys { + _, _, err = sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + } else { + _, _, err = receivingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + } + + // Send the verification request from the sender device and accept + // it on the receiving device and receive the verification ready + // event on the sending device. + txnID, err := sendingHelper.StartVerification(ctx, userID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, sendingClient) + + receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes() + sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes() + + // Corrupt the QR codes + receivingShownQRCodeBytes[tc.corruptByte]++ + sendingShownQRCodeBytes[tc.corruptByte]++ + + if tc.sendingScansQR { + // Emulate scanning the QR code shown by the receiving device + // on the sending device. + err := sendingHelper.HandleScannedQRData(ctx, receivingShownQRCodeBytes) + assert.ErrorContains(t, err, tc.expectedError) + + // Ensure that the receiving device received a cancellation. + receivingInbox := ts.DeviceInbox[userID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + ts.dispatchToDevice(t, ctx, receivingClient) + cancellation := receivingCallbacks.GetVerificationCancellation(txnID) + require.NotNil(t, cancellation) + assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code) + assert.Equal(t, tc.expectedError, cancellation.Reason) + } else { // receiving scans QR + // Emulate scanning the QR code shown by the sending device on + // the receiving device. + err := receivingHelper.HandleScannedQRData(ctx, sendingShownQRCodeBytes) + assert.ErrorContains(t, err, tc.expectedError) + + // Ensure that the sending device received a cancellation. + sendingInbox := ts.DeviceInbox[userID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + ts.dispatchToDevice(t, ctx, sendingClient) + cancellation := sendingCallbacks.GetVerificationCancellation(txnID) + require.NotNil(t, cancellation) + assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code) + assert.Equal(t, tc.expectedError, cancellation.Reason) + } + }) + } +} From b196541e9865959193dd069f5cc9174d9fb3b757 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 23 May 2024 00:38:57 +0300 Subject: [PATCH 0212/1647] Fix crypto_secrets table schema --- crypto/sql_store.go | 10 +++++----- crypto/sql_store_upgrade/00-latest-revision.sql | 9 ++++++--- crypto/sql_store_upgrade/15-fix-secrets.sql | 16 ++++++++++++++++ 3 files changed, 27 insertions(+), 8 deletions(-) create mode 100644 crypto/sql_store_upgrade/15-fix-secrets.sql diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 689c25f0..15731aca 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -891,15 +891,15 @@ func (store *SQLCryptoStore) PutSecret(ctx context.Context, name id.Secret, valu return err } _, err = store.DB.Exec(ctx, ` - INSERT INTO crypto_secrets (name, secret) VALUES ($1, $2) - ON CONFLICT (name) DO UPDATE SET secret=excluded.secret - `, name, bytes) + INSERT INTO crypto_secrets (account_id, name, secret) VALUES ($1, $2, $3) + ON CONFLICT (account_id, name) DO UPDATE SET secret=excluded.secret + `, store.AccountID, name, bytes) return err } func (store *SQLCryptoStore) GetSecret(ctx context.Context, name id.Secret) (value string, err error) { var bytes []byte - err = store.DB.QueryRow(ctx, `SELECT secret FROM crypto_secrets WHERE name=$1`, name).Scan(&bytes) + err = store.DB.QueryRow(ctx, `SELECT secret FROM crypto_secrets WHERE account_id=$1 AND name=$2`, store.AccountID, name).Scan(&bytes) if errors.Is(err, sql.ErrNoRows) { return "", nil } else if err != nil { @@ -910,6 +910,6 @@ func (store *SQLCryptoStore) GetSecret(ctx context.Context, name id.Secret) (val } func (store *SQLCryptoStore) DeleteSecret(ctx context.Context, name id.Secret) (err error) { - _, err = store.DB.Exec(ctx, "DELETE FROM crypto_secrets WHERE name=$1", name) + _, err = store.DB.Exec(ctx, "DELETE FROM crypto_secrets WHERE account_id=$1 AND name=$2", store.AccountID, name) return } diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql index 06aea750..7e039af5 100644 --- a/crypto/sql_store_upgrade/00-latest-revision.sql +++ b/crypto/sql_store_upgrade/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v14 (compatible with v9+): Latest revision +-- v0 -> v15: Latest revision CREATE TABLE IF NOT EXISTS crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, @@ -105,6 +105,9 @@ CREATE TABLE IF NOT EXISTS crypto_cross_signing_signatures ( ); CREATE TABLE IF NOT EXISTS crypto_secrets ( - name TEXT PRIMARY KEY NOT NULL, - secret bytea NOT NULL + account_id TEXT NOT NULL, + name TEXT NOT NULL, + secret bytea NOT NULL, + + PRIMARY KEY (account_id, name) ); diff --git a/crypto/sql_store_upgrade/15-fix-secrets.sql b/crypto/sql_store_upgrade/15-fix-secrets.sql new file mode 100644 index 00000000..47235397 --- /dev/null +++ b/crypto/sql_store_upgrade/15-fix-secrets.sql @@ -0,0 +1,16 @@ +-- v15: Fix crypto_secrets table +CREATE TABLE crypto_secrets_new ( + account_id TEXT NOT NULL, + name TEXT NOT NULL, + secret bytea NOT NULL, + + PRIMARY KEY (account_id, name) +); + +INSERT INTO crypto_secrets_new (account_id, name, secret) +SELECT '', name, secret +FROM crypto_secrets; + +DROP TABLE crypto_secrets; + +ALTER TABLE crypto_secrets_new RENAME TO crypto_secrets; From 826c8cf28e41ce34c030054e094d8af9887230d6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 25 May 2024 23:01:07 +0300 Subject: [PATCH 0213/1647] Update m.relates_to in raw decrypted payload --- crypto/decryptmegolm.go | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 14beb96b..49de2e44 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -14,6 +14,9 @@ import ( "strings" "github.com/rs/zerolog" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.mau.fi/util/exgjson" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/event" @@ -36,6 +39,8 @@ type megolmEvent struct { Content event.Content `json:"content"` } +var relatesToPath = exgjson.Path("m.relates_to") + // DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2 func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) @@ -107,6 +112,26 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event } } + if content.RelatesTo != nil { + relation := gjson.GetBytes(evt.Content.VeryRaw, relatesToPath) + if relation.Exists() { + var raw []byte + if relation.Index > 0 { + raw = evt.Content.VeryRaw[relation.Index : relation.Index+len(relation.Raw)] + } else { + raw = []byte(relation.Raw) + } + updatedPlaintext, err := sjson.SetRawBytes(plaintext, relatesToPath, raw) + if err != nil { + log.Warn().Msg("Failed to copy m.relates_to to decrypted payload") + } else if updatedPlaintext != nil { + plaintext = updatedPlaintext + } + } else { + log.Warn().Msg("Failed to find m.relates_to in raw encrypted event even though it was present in parsed content") + } + } + megolmEvt := &megolmEvent{} err = json.Unmarshal(plaintext, &megolmEvt) if err != nil { @@ -124,19 +149,6 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event return nil, fmt.Errorf("failed to parse content of megolm payload event: %w", err) } } - if content.RelatesTo != nil { - relatable, ok := megolmEvt.Content.Parsed.(event.Relatable) - if ok { - if relatable.OptionalGetRelatesTo() == nil { - relatable.SetRelatesTo(content.RelatesTo) - } else { - log.Trace().Msg("Not overriding relation data as encrypted payload already has it") - } - } - if _, hasRelation := megolmEvt.Content.Raw["m.relates_to"]; !hasRelation { - megolmEvt.Content.Raw["m.relates_to"] = evt.Content.Raw["m.relates_to"] - } - } log.Debug().Msg("Event decrypted successfully") megolmEvt.Type.Class = evt.Type.Class return &event.Event{ From 881879ea0a30d66541d78da3413ee0d05dfe8dc6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 25 May 2024 23:02:13 +0300 Subject: [PATCH 0214/1647] Do first sync with timeout 0 --- client.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 10d1b2b9..6d059230 100644 --- a/client.go +++ b/client.go @@ -199,14 +199,20 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { } } lastSuccessfulSync := time.Now().Add(-cli.StreamSyncMinAge - 1*time.Hour) + // Always do first sync with 0 timeout + isFailing := true for { streamResp := false if cli.StreamSyncMinAge > 0 && time.Since(lastSuccessfulSync) > cli.StreamSyncMinAge { cli.Log.Debug().Msg("Last sync is old, will stream next response") streamResp = true } + timeout := 30000 + if isFailing { + timeout = 0 + } resSync, err := cli.FullSyncRequest(ctx, ReqSync{ - Timeout: 30000, + Timeout: timeout, Since: nextBatch, FilterID: filterID, FullState: false, @@ -214,6 +220,7 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { StreamResponse: streamResp, }) if err != nil { + isFailing = true if ctx.Err() != nil { return ctx.Err() } @@ -221,6 +228,9 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { if err2 != nil { return err2 } + if duration <= 0 { + continue + } select { case <-ctx.Done(): return ctx.Err() @@ -228,6 +238,7 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { continue } } + isFailing = false lastSuccessfulSync = time.Now() // Check that the syncing state hasn't changed From d64447c3f7d57884c469350a172ff5578cef8f1f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 25 May 2024 23:02:58 +0300 Subject: [PATCH 0215/1647] Clamp megolm session rotation periods to sensible limits --- crypto/keyexport.go | 7 ------- crypto/sessions.go | 5 ++++- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/crypto/keyexport.go b/crypto/keyexport.go index 91bfb6c6..bb373f4d 100644 --- a/crypto/keyexport.go +++ b/crypto/keyexport.go @@ -117,13 +117,6 @@ func exportSessionsJSON(sessions []*InboundGroupSession) ([]byte, error) { return json.Marshal(exportedSessions) } -func min(a, b int) int { - if a > b { - return b - } - return a -} - func formatKeyExportData(data []byte) []byte { base64Data := make([]byte, base64.StdEncoding.EncodedLen(len(data))) base64.StdEncoding.Encode(base64Data, data) diff --git a/crypto/sessions.go b/crypto/sessions.go index 045af933..6075a644 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -196,11 +196,14 @@ func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.Encrypti RoomID: roomID, } if encryptionContent != nil { + // Clamp rotation period to prevent unreasonable values + // Similar to https://github.com/matrix-org/matrix-rust-sdk/blob/matrix-sdk-crypto-0.7.1/crates/matrix-sdk-crypto/src/olm/group_sessions/outbound.rs#L415-L441 if encryptionContent.RotationPeriodMillis != 0 { ogs.MaxAge = time.Duration(encryptionContent.RotationPeriodMillis) * time.Millisecond + ogs.MaxAge = min(max(ogs.MaxAge, 1*time.Hour), 365*24*time.Hour) } if encryptionContent.RotationPeriodMessages != 0 { - ogs.MaxMessages = encryptionContent.RotationPeriodMessages + ogs.MaxMessages = min(max(encryptionContent.RotationPeriodMessages, 1), 10000) } } return ogs From a2169274da2999d08239532b4ff7fb8136ec1fc5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 25 May 2024 23:03:26 +0300 Subject: [PATCH 0216/1647] Include room ID and first known index in SessionReceived callback --- crypto/keybackup.go | 5 +++-- crypto/keyimport.go | 5 +++-- crypto/keysharing.go | 5 +++-- crypto/machine.go | 8 ++++---- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 820f3114..7d8148f6 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -160,7 +160,8 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. maxMessages = config.RotationPeriodMessages } - if firstKnownIndex := igsInternal.FirstKnownIndex(); firstKnownIndex > 0 { + firstKnownIndex := igsInternal.FirstKnownIndex() + if firstKnownIndex > 0 { log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session") } @@ -181,6 +182,6 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. if err != nil { return fmt.Errorf("failed to store new inbound group session: %w", err) } - mach.markSessionReceived(ctx, sessionID) + mach.markSessionReceived(ctx, roomID, sessionID, firstKnownIndex) return nil } diff --git a/crypto/keyimport.go b/crypto/keyimport.go index 6c320f43..693ff6b8 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -114,7 +114,8 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor ReceivedAt: time.Now().UTC(), } existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) - if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() { + firstKnownIndex := igs.Internal.FirstKnownIndex() + if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= firstKnownIndex { // We already have an equivalent or better session in the store, so don't override it. return false, nil } @@ -122,7 +123,7 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor if err != nil { return false, fmt.Errorf("failed to store imported session: %w", err) } - mach.markSessionReceived(ctx, igs.ID()) + mach.markSessionReceived(ctx, session.RoomID, igs.ID(), firstKnownIndex) return true, nil } diff --git a/crypto/keysharing.go b/crypto/keysharing.go index ad0011e5..362dee81 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -168,7 +168,8 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt if content.MaxMessages != 0 { maxMessages = content.MaxMessages } - if firstKnownIndex := igsInternal.FirstKnownIndex(); firstKnownIndex > 0 { + firstKnownIndex := igsInternal.FirstKnownIndex() + if firstKnownIndex > 0 { log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session") } igs := &InboundGroupSession{ @@ -194,7 +195,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt log.Error().Err(err).Msg("Failed to store new inbound group session") return false } - mach.markSessionReceived(ctx, content.SessionID) + mach.markSessionReceived(ctx, content.RoomID, content.SessionID, firstKnownIndex) log.Debug().Msg("Received forwarded inbound group session") return true } diff --git a/crypto/machine.go b/crypto/machine.go index abb8d540..c9c06c3b 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -53,7 +53,7 @@ type OlmMachine struct { keyWaitersLock sync.Mutex // Optional callback which is called when we save a session to store - SessionReceived func(context.Context, id.SessionID) + SessionReceived func(context.Context, id.RoomID, id.SessionID, uint32) devicesToUnwedge map[id.IdentityKey]bool devicesToUnwedgeLock sync.Mutex @@ -523,7 +523,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session") return fmt.Errorf("failed to store new inbound group session: %w", err) } - mach.markSessionReceived(ctx, sessionID) + mach.markSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex()) log.Debug(). Str("session_id", sessionID.String()). Str("sender_key", senderKey.String()). @@ -534,9 +534,9 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen return nil } -func (mach *OlmMachine) markSessionReceived(ctx context.Context, id id.SessionID) { +func (mach *OlmMachine) markSessionReceived(ctx context.Context, roomID id.RoomID, id id.SessionID, firstKnownIndex uint32) { if mach.SessionReceived != nil { - mach.SessionReceived(ctx, id) + mach.SessionReceived(ctx, roomID, id, firstKnownIndex) } mach.keyWaitersLock.Lock() From 98c491e069f3db74a71906b46899830d80fd90e4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 25 May 2024 23:03:39 +0300 Subject: [PATCH 0217/1647] Add constants for room versions --- event/state.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/event/state.go b/event/state.go index d6b6cf70..e03e6a85 100644 --- a/event/state.go +++ b/event/state.go @@ -56,13 +56,29 @@ type Predecessor struct { EventID id.EventID `json:"event_id"` } +type RoomVersion string + +const ( + RoomV1 RoomVersion = "1" + RoomV2 RoomVersion = "2" + RoomV3 RoomVersion = "3" + RoomV4 RoomVersion = "4" + RoomV5 RoomVersion = "5" + RoomV6 RoomVersion = "6" + RoomV7 RoomVersion = "7" + RoomV8 RoomVersion = "8" + RoomV9 RoomVersion = "9" + RoomV10 RoomVersion = "10" + RoomV11 RoomVersion = "11" +) + // CreateEventContent represents the content of a m.room.create state event. // https://spec.matrix.org/v1.2/client-server-api/#mroomcreate type CreateEventContent struct { Type RoomType `json:"type,omitempty"` Creator id.UserID `json:"creator,omitempty"` Federate bool `json:"m.federate,omitempty"` - RoomVersion string `json:"room_version,omitempty"` + RoomVersion RoomVersion `json:"room_version,omitempty"` Predecessor *Predecessor `json:"predecessor,omitempty"` } From 797aed1e83bcf4b2a2e15005a337249080967b3d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 25 May 2024 23:03:54 +0300 Subject: [PATCH 0218/1647] Update `m.megolm_backup.v1` event type to reference secret ID constant --- event/type.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/event/type.go b/event/type.go index a4b36392..2c801d5e 100644 --- a/event/type.go +++ b/event/type.go @@ -244,7 +244,7 @@ var ( AccountDataCrossSigningMaster = Type{string(id.SecretXSMaster), AccountDataEventType} AccountDataCrossSigningUser = Type{string(id.SecretXSUserSigning), AccountDataEventType} AccountDataCrossSigningSelf = Type{string(id.SecretXSSelfSigning), AccountDataEventType} - AccountDataMegolmBackupKey = Type{"m.megolm_backup.v1", AccountDataEventType} + AccountDataMegolmBackupKey = Type{string(id.SecretMegolmBackupV1), AccountDataEventType} ) // Device-to-device events From 2497fe4397a65c8a7bc9faed0aa815826c43221a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 25 May 2024 23:55:28 +0300 Subject: [PATCH 0219/1647] Export function to parse megolm message index --- crypto/decryptmegolm.go | 2 +- crypto/encryptmegolm.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 49de2e44..1b714d09 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -182,7 +182,7 @@ const missedIndexCutoff = 10 func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Context, sess *InboundGroupSession, evt *event.Event, content *event.EncryptedEventContent) (uint, error) { log := *zerolog.Ctx(ctx) - messageIndex, decodeErr := parseMessageIndex(content.MegolmCiphertext) + messageIndex, decodeErr := ParseMegolmMessageIndex(content.MegolmCiphertext) if decodeErr != nil { log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt") return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 634a685f..fd7b8ea2 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -61,7 +61,7 @@ func IsShareError(err error) bool { return err == SessionExpired || err == SessionNotShared || err == NoGroupSession } -func parseMessageIndex(ciphertext []byte) (uint, error) { +func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) { decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(ciphertext))) var err error _, err = base64.RawStdEncoding.Decode(decoded, ciphertext) @@ -109,7 +109,7 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID if err != nil { return nil, err } - idx, err := parseMessageIndex(ciphertext) + idx, err := ParseMegolmMessageIndex(ciphertext) if err != nil { log.Warn().Err(err).Msg("Failed to get megolm message index of encrypted event") } else { From d7011a7f8b1b78f833372d11dd89657d143a2aa8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 25 May 2024 23:55:53 +0300 Subject: [PATCH 0220/1647] Return imported session in ImportRoomKeyFromBackup --- crypto/keybackup.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 7d8148f6..8709fe92 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -114,7 +114,7 @@ func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version id.Key continue } - err = mach.ImportRoomKeyFromBackup(ctx, version, roomID, sessionID, sessionData) + _, err = mach.ImportRoomKeyFromBackup(ctx, version, roomID, sessionID, sessionData) if err != nil { log.Warn().Err(err).Msg("Failed to import room key from backup") failedCount++ @@ -132,23 +132,23 @@ func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version id.Key return nil } -func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) error { +func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) (*InboundGroupSession, error) { log := zerolog.Ctx(ctx).With(). Str("room_id", roomID.String()). Str("session_id", sessionID.String()). Logger() if keyBackupData.Algorithm != id.AlgorithmMegolmV1 { - return fmt.Errorf("ignoring room key in backup with weird algorithm %s", keyBackupData.Algorithm) + return nil, fmt.Errorf("ignoring room key in backup with weird algorithm %s", keyBackupData.Algorithm) } igsInternal, err := olm.InboundGroupSessionImport([]byte(keyBackupData.SessionKey)) if err != nil { - return fmt.Errorf("failed to import inbound group session: %w", err) + return nil, fmt.Errorf("failed to import inbound group session: %w", err) } else if igsInternal.ID() != sessionID { log.Warn(). Stringer("actual_session_id", igsInternal.ID()). Msg("Mismatched session ID while creating inbound group session from key backup") - return fmt.Errorf("mismatched session ID while creating inbound group session from key backup") + return nil, fmt.Errorf("mismatched session ID while creating inbound group session from key backup") } var maxAge time.Duration @@ -180,8 +180,8 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. } err = mach.CryptoStore.PutGroupSession(ctx, igs) if err != nil { - return fmt.Errorf("failed to store new inbound group session: %w", err) + return nil, fmt.Errorf("failed to store new inbound group session: %w", err) } mach.markSessionReceived(ctx, roomID, sessionID, firstKnownIndex) - return nil + return igs, nil } From 5afa391317622cf22bdc262ab529efdd00dc4be3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 26 May 2024 17:30:10 +0300 Subject: [PATCH 0221/1647] Refactor MarkTrackedUsersOutdated to use single query --- crypto/sql_store.go | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 15731aca..26b4ddbe 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -759,6 +759,17 @@ func (store *SQLCryptoStore) PutDevices(ctx context.Context, userID id.UserID, d }) } +func userIDsToParams(users []id.UserID) (placeholders string, params []any) { + queryString := make([]string, len(users)) + params = make([]any, len(users)) + for i, user := range users { + queryString[i] = fmt.Sprintf("$%d", i+1) + params[i] = user + } + placeholders = strings.Join(queryString, ",") + return +} + // FilterTrackedUsers finds all the user IDs out of the given ones for which the database contains identity information. func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id.UserID) ([]id.UserID, error) { var rows dbutil.Rows @@ -766,13 +777,8 @@ func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id. if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil { rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users)) } else { - queryString := make([]string, len(users)) - params := make([]interface{}, len(users)) - for i, user := range users { - queryString[i] = fmt.Sprintf("?%d", i+1) - params[i] = user - } - rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...) + placeholders, params := userIDsToParams(users) + rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+placeholders+")", params...) } if err != nil { return users, err @@ -781,18 +787,14 @@ func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id. } // MarkTrackedUsersOutdated flags that the device list for given users are outdated. -func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users []id.UserID) error { - return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error { - // TODO refactor to use a single query - for _, userID := range users { - _, err := store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = $1", userID) - if err != nil { - return fmt.Errorf("failed to update user in the tracked users list: %w", err) - } - } - - return nil - }) +func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users []id.UserID) (err error) { + if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil { + _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = ANY($1)", PostgresArrayWrapper(users)) + } else { + placeholders, params := userIDsToParams(users) + _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id IN ("+placeholders+")", params...) + } + return } // GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated. From 96676c13594886edbe27e9549e1602b3449ced8a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 26 May 2024 18:13:10 +0300 Subject: [PATCH 0222/1647] Remove separate go.mod for example --- example/go.mod | 29 -------------------------- example/go.sum | 56 -------------------------------------------------- go.mod | 1 + go.sum | 7 +++++++ 4 files changed, 8 insertions(+), 85 deletions(-) delete mode 100644 example/go.mod delete mode 100644 example/go.sum diff --git a/example/go.mod b/example/go.mod deleted file mode 100644 index 60583640..00000000 --- a/example/go.mod +++ /dev/null @@ -1,29 +0,0 @@ -module maunium.net/go/mautrix/example - -go 1.21 - -toolchain go1.22.0 - -require ( - github.com/chzyer/readline v1.5.1 - github.com/mattn/go-sqlite3 v1.14.22 - github.com/rs/zerolog v1.32.0 - go.mau.fi/util v0.3.1-0.20240208085450-32294da153ab - maunium.net/go/mautrix v0.17.1-0.20240208085632-2d1786ced444 -) - -require ( - github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect - github.com/tidwall/gjson v1.17.0 // indirect - github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.0 // indirect - github.com/tidwall/sjson v1.2.5 // indirect - golang.org/x/crypto v0.19.0 // indirect - golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 // indirect - golang.org/x/net v0.21.0 // indirect - golang.org/x/sys v0.17.0 // indirect - maunium.net/go/maulogger/v2 v2.4.1 // indirect -) - -//replace maunium.net/go/mautrix => ../ diff --git a/example/go.sum b/example/go.sum deleted file mode 100644 index f81f31c2..00000000 --- a/example/go.sum +++ /dev/null @@ -1,56 +0,0 @@ -github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= -github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= -github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= -github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= -github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= -github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= -github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= -github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= -github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= -github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= -github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0= -github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.17.0 h1:/Jocvlh98kcTfpN2+JzGQWQcqrPQwDrVEMApx/M5ZwM= -github.com/tidwall/gjson v1.17.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= -github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= -github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= -github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= -github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -go.mau.fi/util v0.3.1-0.20240208085450-32294da153ab h1:XZ8W5vHWlXSGmHn1U+Fvbh+xZr9wuHTvbY+qV7aybDY= -go.mau.fi/util v0.3.1-0.20240208085450-32294da153ab/go.mod h1:rRypwgXVEPILomtFPyQcnbOeuRqf+nRN84vh/CICq4w= -golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= -golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 h1:/RIbNt/Zr7rVhIkQhooTxCxFcdWLGIKnZA4IXNFSrvo= -golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= -golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= -maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= -maunium.net/go/mautrix v0.17.1-0.20240208085632-2d1786ced444 h1:PkpCzQotFakHkGKAatiQdb+XjP/HLQM40xuiy2JtHes= -maunium.net/go/mautrix v0.17.1-0.20240208085632-2d1786ced444/go.mod h1:tMIBWuMXrtjXAqMtaD1VHiT0B3TCxraYlqtncLIyKF0= diff --git a/go.mod b/go.mod index f1c025b4..6213dbf5 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module maunium.net/go/mautrix go 1.21 require ( + github.com/chzyer/readline v1.5.1 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 diff --git a/go.sum b/go.sum index 4e4a26c9..307aa876 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,11 @@ github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= +github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= +github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= +github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= +github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= +github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= +github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -47,6 +53,7 @@ golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJ golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From ec471738fc1393ae102adffea4cafe983fe50e6a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 26 May 2024 18:13:22 +0300 Subject: [PATCH 0223/1647] Add interface to override UpdateStateStore behavior --- statestore.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/statestore.go b/statestore.go index c1193009..7b4e2d2d 100644 --- a/statestore.go +++ b/statestore.go @@ -37,10 +37,18 @@ type StateStore interface { GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) } +type StateStoreUpdater interface { + UpdateState(ctx context.Context, evt *event.Event) +} + func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) { if store == nil || evt == nil || evt.StateKey == nil { return } + if directUpdater, ok := store.(StateStoreUpdater); ok { + directUpdater.UpdateState(ctx, evt) + return + } // We only care about events without a state key (power levels, encryption) or member events with state key if evt.Type != event.StateMember && evt.GetStateKey() != "" { return From 0b07ae99420e48777edf9652c2ed943fde5cf90d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 26 May 2024 18:27:48 +0300 Subject: [PATCH 0224/1647] Ignore conflicts when inserting withheld group sessions --- crypto/machine.go | 1 + crypto/sql_store.go | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/crypto/machine.go b/crypto/machine.go index c9c06c3b..8e9a6c66 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -638,6 +638,7 @@ func (mach *OlmMachine) HandleRoomKeyWithheld(ctx context.Context, content *even zerolog.Ctx(ctx).Debug().Interface("content", content).Msg("Non-megolm room key withheld event") return } + // TODO log if there's a conflict? (currently ignored) err := mach.CryptoStore.PutWithheldGroupSession(ctx, *content) if err != nil { zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to save room key withheld event") diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 26b4ddbe..0d824364 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -432,8 +432,11 @@ func (store *SQLCryptoStore) RedactOutdatedGroupSessions(ctx context.Context) ([ } func (store *SQLCryptoStore) PutWithheldGroupSession(ctx context.Context, content event.RoomKeyWithheldEventContent) error { - _, err := store.DB.Exec(ctx, "INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)", - content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID) + _, err := store.DB.Exec(ctx, ` + INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (session_id, account_id) DO NOTHING + `, content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID) return err } From c1eb217b9e02531e2fdf630e0df2bed1b4390d51 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 26 May 2024 18:29:22 +0300 Subject: [PATCH 0225/1647] Add draft of high-level client framework --- .gitignore | 2 +- hicli/cryptohelper.go | 82 ++++ hicli/database/account.go | 65 ++++ hicli/database/accountdata.go | 71 ++++ hicli/database/database.go | 60 +++ hicli/database/event.go | 193 ++++++++++ hicli/database/room.go | 115 ++++++ hicli/database/sessionrequest.go | 69 ++++ hicli/database/state.go | 32 ++ hicli/database/statestore.go | 149 +++++++ hicli/database/timeline.go | 47 +++ .../database/upgrades/00-latest-revision.sql | 107 ++++++ hicli/database/upgrades/upgrades.go | 22 ++ hicli/decryptionqueue.go | 194 ++++++++++ hicli/hicli.go | 159 ++++++++ hicli/hitest/hitest.go | 69 ++++ hicli/login.go | 77 ++++ hicli/sync.go | 362 ++++++++++++++++++ hicli/syncwrap.go | 100 +++++ hicli/verify.go | 158 ++++++++ 20 files changed, 2132 insertions(+), 1 deletion(-) create mode 100644 hicli/cryptohelper.go create mode 100644 hicli/database/account.go create mode 100644 hicli/database/accountdata.go create mode 100644 hicli/database/database.go create mode 100644 hicli/database/event.go create mode 100644 hicli/database/room.go create mode 100644 hicli/database/sessionrequest.go create mode 100644 hicli/database/state.go create mode 100644 hicli/database/statestore.go create mode 100644 hicli/database/timeline.go create mode 100644 hicli/database/upgrades/00-latest-revision.sql create mode 100644 hicli/database/upgrades/upgrades.go create mode 100644 hicli/decryptionqueue.go create mode 100644 hicli/hicli.go create mode 100644 hicli/hitest/hitest.go create mode 100644 hicli/login.go create mode 100644 hicli/sync.go create mode 100644 hicli/syncwrap.go create mode 100644 hicli/verify.go diff --git a/.gitignore b/.gitignore index f37a7d0c..c01f2f30 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ .idea/ .vscode/ -*.db +*.db* *.log diff --git a/hicli/cryptohelper.go b/hicli/cryptohelper.go new file mode 100644 index 00000000..eb054af9 --- /dev/null +++ b/hicli/cryptohelper.go @@ -0,0 +1,82 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type hiCryptoHelper HiClient + +var _ mautrix.CryptoHelper = (*hiCryptoHelper)(nil) + +func (h *hiCryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) { + h.encryptLock.Lock() + defer h.encryptLock.Unlock() + encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, roomID, evtType, content) + if err != nil { + if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.NoGroupSession) && !errors.Is(err, crypto.SessionNotShared) { + return + } + h.Log.Debug(). + Err(err). + Str("room_id", roomID.String()). + Msg("Got session error while encrypting event, sharing group session and trying again") + var users []id.UserID + users, err = h.ClientStore.GetRoomJoinedOrInvitedMembers(ctx, roomID) + if err != nil { + err = fmt.Errorf("failed to get room member list: %w", err) + } else if err = h.Crypto.ShareGroupSession(ctx, roomID, users); err != nil { + err = fmt.Errorf("failed to share group session: %w", err) + } else if encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, roomID, evtType, content); err != nil { + err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err) + } + } + return +} + +func (h *hiCryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) { + return h.Crypto.DecryptMegolmEvent(ctx, evt) +} + +func (h *hiCryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { + return h.Crypto.WaitForSession(ctx, roomID, senderKey, sessionID, timeout) +} + +func (h *hiCryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { + err := h.Crypto.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{ + userID: {deviceID}, + h.Account.UserID: {"*"}, + }) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("room_id", roomID). + Stringer("session_id", sessionID). + Stringer("user_id", userID). + Msg("Failed to send room key request") + } else { + zerolog.Ctx(ctx).Debug(). + Stringer("room_id", roomID). + Stringer("session_id", sessionID). + Stringer("user_id", userID). + Msg("Sent room key request") + } +} + +func (h *hiCryptoHelper) Init(ctx context.Context) error { + return nil +} diff --git a/hicli/database/account.go b/hicli/database/account.go new file mode 100644 index 00000000..49b50771 --- /dev/null +++ b/hicli/database/account.go @@ -0,0 +1,65 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/id" +) + +const ( + getAccountQuery = `SELECT user_id, device_id, access_token, homeserver_url, next_batch FROM account WHERE user_id = $1` + putNextBatchQuery = `UPDATE account SET next_batch = $1 WHERE user_id = $2` + upsertAccountQuery = ` + INSERT INTO account (user_id, device_id, access_token, homeserver_url, next_batch) + VALUES ($1, $2, $3, $4, $5) ON CONFLICT (user_id) + DO UPDATE SET device_id = excluded.device_id, + access_token = excluded.access_token, + homeserver_url = excluded.homeserver_url, + next_batch = excluded.next_batch + ` +) + +type AccountQuery struct { + *dbutil.QueryHelper[*Account] +} + +func (aq *AccountQuery) GetFirstUserID(ctx context.Context) (userID id.UserID, err error) { + err = aq.GetDB().QueryRow(ctx, `SELECT user_id FROM account LIMIT 1`).Scan(&userID) + return +} + +func (aq *AccountQuery) Get(ctx context.Context, userID id.UserID) (*Account, error) { + return aq.QueryOne(ctx, getAccountQuery, userID) +} + +func (aq *AccountQuery) PutNextBatch(ctx context.Context, userID id.UserID, nextBatch string) error { + return aq.Exec(ctx, putNextBatchQuery, nextBatch, userID) +} + +func (aq *AccountQuery) Put(ctx context.Context, account *Account) error { + return aq.Exec(ctx, upsertAccountQuery, account.sqlVariables()...) +} + +type Account struct { + UserID id.UserID + DeviceID id.DeviceID + AccessToken string + HomeserverURL string + NextBatch string +} + +func (a *Account) Scan(row dbutil.Scannable) (*Account, error) { + return dbutil.ValueOrErr(a, row.Scan(&a.UserID, &a.DeviceID, &a.AccessToken, &a.HomeserverURL, &a.NextBatch)) +} + +func (a *Account) sqlVariables() []any { + return []any{a.UserID, a.DeviceID, a.AccessToken, a.HomeserverURL, a.NextBatch} +} diff --git a/hicli/database/accountdata.go b/hicli/database/accountdata.go new file mode 100644 index 00000000..963886c3 --- /dev/null +++ b/hicli/database/accountdata.go @@ -0,0 +1,71 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "encoding/json" + "unsafe" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + upsertAccountDataQuery = ` + INSERT INTO account_data (user_id, type, content) VALUES ($1, $2, $3) + ON CONFLICT (user_id, type) DO UPDATE SET content = excluded.content + ` + upsertRoomAccountDataQuery = ` + INSERT INTO room_account_data (user_id, room_id, type, content) VALUES ($1, $2, $3, $4) + ON CONFLICT (user_id, room_id, type) DO UPDATE SET content = excluded.content + ` +) + +type AccountDataQuery struct { + *dbutil.QueryHelper[*AccountData] +} + +func unsafeJSONString(content json.RawMessage) *string { + if content == nil { + return nil + } + str := unsafe.String(unsafe.SliceData(content), len(content)) + return &str +} + +func (adq *AccountDataQuery) Put(ctx context.Context, userID id.UserID, eventType event.Type, content json.RawMessage) error { + return adq.Exec(ctx, upsertAccountDataQuery, userID, eventType.Type, unsafeJSONString(content)) +} + +func (adq *AccountDataQuery) PutRoom(ctx context.Context, userID id.UserID, roomID id.RoomID, eventType event.Type, content json.RawMessage) error { + return adq.Exec(ctx, upsertRoomAccountDataQuery, userID, roomID, eventType.Type, unsafeJSONString(content)) +} + +type AccountData struct { + UserID id.UserID + RoomID id.RoomID + Type string + Content json.RawMessage +} + +func (a *AccountData) Scan(row dbutil.Scannable) (*AccountData, error) { + var roomID sql.NullString + err := row.Scan(&a.UserID, &roomID, &a.Type, (*[]byte)(&a.Content)) + if err != nil { + return nil, err + } + a.RoomID = id.RoomID(roomID.String) + return a, nil +} + +func (a *AccountData) sqlVariables() []any { + return []any{a.UserID, dbutil.StrPtr(a.RoomID), a.Type, unsafeJSONString(a.Content)} +} diff --git a/hicli/database/database.go b/hicli/database/database.go new file mode 100644 index 00000000..c1273ab7 --- /dev/null +++ b/hicli/database/database.go @@ -0,0 +1,60 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/hicli/database/upgrades" +) + +type Database struct { + *dbutil.Database + + Account AccountQuery + AccountData AccountDataQuery + Room RoomQuery + Event EventQuery + CurrentState CurrentStateQuery + Timeline TimelineQuery + SessionRequest SessionRequestQuery +} + +func New(rawDB *dbutil.Database) *Database { + rawDB.UpgradeTable = upgrades.Table + return &Database{ + Database: rawDB, + + Account: AccountQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newAccount)}, + AccountData: AccountDataQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newAccountData)}, + Room: RoomQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newRoom)}, + Event: EventQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newEvent)}, + CurrentState: CurrentStateQuery{Database: rawDB}, + Timeline: TimelineQuery{Database: rawDB}, + SessionRequest: SessionRequestQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newSessionRequest)}, + } +} + +func newSessionRequest(_ *dbutil.QueryHelper[*SessionRequest]) *SessionRequest { + return &SessionRequest{} +} + +func newEvent(_ *dbutil.QueryHelper[*Event]) *Event { + return &Event{} +} + +func newRoom(_ *dbutil.QueryHelper[*Room]) *Room { + return &Room{} +} + +func newAccountData(_ *dbutil.QueryHelper[*AccountData]) *AccountData { + return &AccountData{} +} + +func newAccount(_ *dbutil.QueryHelper[*Account]) *Account { + return &Account{} +} diff --git a/hicli/database/event.go b/hicli/database/event.go new file mode 100644 index 00000000..b7b15eea --- /dev/null +++ b/hicli/database/event.go @@ -0,0 +1,193 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "database/sql" + "encoding/json" + "time" + + "github.com/tidwall/gjson" + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exgjson" + "golang.org/x/net/context" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + getEventBaseQuery = ` + SELECT rowid, room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, + redacted_by, relates_to, megolm_session_id, decryption_error + FROM event + ` + getFailedEventsByMegolmSessionID = getEventBaseQuery + `WHERE room_id = $1 AND megolm_session_id = $2 AND decryption_error IS NOT NULL` + upsertEventQuery = ` + INSERT INTO event (room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, redacted_by, relates_to, megolm_session_id, decryption_error) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + ON CONFLICT (event_id) DO UPDATE + SET decrypted=COALESCE(event.decrypted, excluded.decrypted), + decrypted_type=COALESCE(event.decrypted_type, excluded.decrypted_type), + redacted_by=COALESCE(event.redacted_by, excluded.redacted_by), + decryption_error=CASE WHEN COALESCE(event.decrypted, excluded.decrypted) IS NULL THEN COALESCE(excluded.decryption_error, event.decryption_error) END + RETURNING rowid + ` + updateEventDecryptedQuery = `UPDATE event SET decrypted = $1, decrypted_type = $2, decryption_error = NULL WHERE rowid = $3` +) + +type EventQuery struct { + *dbutil.QueryHelper[*Event] +} + +func (eq *EventQuery) GetFailedByMegolmSessionID(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) ([]*Event, error) { + return eq.QueryMany(ctx, getFailedEventsByMegolmSessionID, roomID, sessionID) +} + +func (eq *EventQuery) Upsert(ctx context.Context, evt *Event) (rowID int64, err error) { + err = eq.GetDB().QueryRow(ctx, upsertEventQuery, evt.sqlVariables()...).Scan(&rowID) + return +} + +func (eq *EventQuery) UpdateDecrypted(ctx context.Context, rowID int64, decrypted json.RawMessage, decryptedType string) error { + return eq.Exec(ctx, updateEventDecryptedQuery, unsafeJSONString(decrypted), decryptedType, rowID) +} + +type Event struct { + RowID int64 + + RoomID id.RoomID + ID id.EventID + Sender id.UserID + Type string + StateKey *string + Timestamp time.Time + + Content json.RawMessage + Decrypted json.RawMessage + DecryptedType string + Unsigned json.RawMessage + + RedactedBy id.EventID + RelatesTo id.EventID + + MegolmSessionID id.SessionID + DecryptionError string +} + +func MautrixToEvent(evt *event.Event) *Event { + dbEvt := &Event{ + RoomID: evt.RoomID, + ID: evt.ID, + Sender: evt.Sender, + Type: evt.Type.Type, + StateKey: evt.StateKey, + Timestamp: time.UnixMilli(evt.Timestamp), + Content: evt.Content.VeryRaw, + RelatesTo: getRelatesTo(evt), + MegolmSessionID: getMegolmSessionID(evt), + } + dbEvt.Unsigned, _ = json.Marshal(&evt.Unsigned) + if evt.Unsigned.RedactedBecause != nil { + dbEvt.RedactedBy = evt.Unsigned.RedactedBecause.ID + } + return dbEvt +} + +func (e *Event) AsRawMautrix() *event.Event { + evt := &event.Event{ + RoomID: e.RoomID, + ID: e.ID, + Sender: e.Sender, + Type: event.Type{Type: e.Type, Class: event.MessageEventType}, + StateKey: e.StateKey, + Timestamp: e.Timestamp.UnixMilli(), + Content: event.Content{VeryRaw: e.Content}, + } + if e.Decrypted != nil { + evt.Content.VeryRaw = e.Decrypted + evt.Type.Type = e.DecryptedType + evt.Mautrix.WasEncrypted = true + } + if e.StateKey != nil { + evt.Type.Class = event.StateEventType + } + _ = json.Unmarshal(e.Unsigned, &evt.Unsigned) + return evt +} + +func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { + var timestamp int64 + var redactedBy, relatesTo, megolmSessionID, decryptionError, decryptedType sql.NullString + err := row.Scan( + &e.RowID, + &e.RoomID, + &e.ID, + &e.Sender, + &e.Type, + &e.StateKey, + ×tamp, + (*[]byte)(&e.Content), + (*[]byte)(&e.Decrypted), + &decryptedType, + (*[]byte)(&e.Unsigned), + &redactedBy, + &relatesTo, + &megolmSessionID, + &decryptionError, + ) + if err != nil { + return nil, err + } + e.Timestamp = time.UnixMilli(timestamp) + e.RedactedBy = id.EventID(redactedBy.String) + e.RelatesTo = id.EventID(relatesTo.String) + e.MegolmSessionID = id.SessionID(megolmSessionID.String) + e.DecryptedType = decryptedType.String + e.DecryptionError = decryptionError.String + return e, nil +} + +var relatesToPath = exgjson.Path("m.relates_to", "event_id") + +func getRelatesTo(evt *event.Event) id.EventID { + res := gjson.GetBytes(evt.Content.VeryRaw, relatesToPath) + if res.Exists() && res.Type == gjson.String { + return id.EventID(res.Str) + } + return "" +} + +func getMegolmSessionID(evt *event.Event) id.SessionID { + if evt.Type != event.EventEncrypted { + return "" + } + res := gjson.GetBytes(evt.Content.VeryRaw, "session_id") + if res.Exists() && res.Type == gjson.String { + return id.SessionID(res.Str) + } + return "" +} + +func (e *Event) sqlVariables() []any { + return []any{ + e.RoomID, + e.ID, + e.Sender, + e.Type, + e.StateKey, + e.Timestamp.UnixMilli(), + unsafeJSONString(e.Content), + unsafeJSONString(e.Decrypted), + dbutil.StrPtr(e.DecryptedType), + unsafeJSONString(e.Unsigned), + dbutil.StrPtr(e.RedactedBy), + dbutil.StrPtr(e.RelatesTo), + dbutil.StrPtr(e.MegolmSessionID), + dbutil.StrPtr(e.DecryptionError), + } +} diff --git a/hicli/database/room.go b/hicli/database/room.go new file mode 100644 index 00000000..c7d13fca --- /dev/null +++ b/hicli/database/room.go @@ -0,0 +1,115 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + getRoomByIDQuery = ` + SELECT room_id, creation_content, name, avatar, topic, lazy_load_summary, encryption_event, has_member_list, prev_batch + FROM room WHERE room_id = $1 + ` + ensureRoomExistsQuery = ` + INSERT INTO room (room_id) VALUES ($1) + ON CONFLICT (room_id) DO NOTHING + ` + upsertRoomFromSyncQuery = ` + UPDATE room + SET creation_content = COALESCE(room.creation_content, $2), + name = COALESCE($3, room.name), + avatar = COALESCE($4, room.avatar), + topic = COALESCE($5, room.topic), + lazy_load_summary = COALESCE($6, room.lazy_load_summary), + encryption_event = COALESCE($7, room.encryption_event), + has_member_list = room.has_member_list OR $8, + prev_batch = COALESCE(room.prev_batch, $9) + WHERE room_id = $1 + ` + setRoomPrevBatchQuery = ` + INSERT INTO room (room_id, prev_batch) VALUES ($1, $2) + ON CONFLICT (room_id) DO UPDATE SET prev_batch = excluded.prev_batch + ` +) + +type RoomQuery struct { + *dbutil.QueryHelper[*Room] +} + +func (rq *RoomQuery) Get(ctx context.Context, roomID id.RoomID) (*Room, error) { + return rq.QueryOne(ctx, getRoomByIDQuery, roomID) +} + +func (rq *RoomQuery) Upsert(ctx context.Context, room *Room) error { + return rq.Exec(ctx, upsertRoomFromSyncQuery, room.sqlVariables()...) +} + +func (rq *RoomQuery) CreateRow(ctx context.Context, roomID id.RoomID) error { + return rq.Exec(ctx, ensureRoomExistsQuery, roomID) +} + +func (rq *RoomQuery) SetPrevBatch(ctx context.Context, roomID id.RoomID, prevBatch string) error { + return rq.Exec(ctx, setRoomPrevBatchQuery, roomID, prevBatch) +} + +type Room struct { + ID id.RoomID + CreationContent *event.CreateEventContent + + Name *string + Avatar *id.ContentURI + Topic *string + + LazyLoadSummary *mautrix.LazyLoadSummary + + EncryptionEvent *event.EncryptionEventContent + HasMemberList bool + + PrevBatch string +} + +func (r *Room) Scan(row dbutil.Scannable) (*Room, error) { + var prevBatch sql.NullString + err := row.Scan( + &r.ID, + dbutil.JSON{Data: &r.CreationContent}, + &r.Name, + &r.Avatar, + &r.Topic, + dbutil.JSON{Data: &r.LazyLoadSummary}, + dbutil.JSON{Data: &r.EncryptionEvent}, + &r.HasMemberList, + &prevBatch, + ) + if err != nil { + return nil, err + } + r.PrevBatch = prevBatch.String + return r, nil +} + +func (r *Room) sqlVariables() []any { + return []any{ + r.ID, + dbutil.JSONPtr(r.CreationContent), + r.Name, + r.Avatar, + r.Topic, + dbutil.JSONPtr(r.LazyLoadSummary), + dbutil.JSONPtr(r.EncryptionEvent), + r.HasMemberList, + dbutil.StrPtr(r.PrevBatch), + } +} diff --git a/hicli/database/sessionrequest.go b/hicli/database/sessionrequest.go new file mode 100644 index 00000000..6690c13f --- /dev/null +++ b/hicli/database/sessionrequest.go @@ -0,0 +1,69 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/id" +) + +const ( + putSessionRequestQueueEntry = ` + INSERT INTO session_request (room_id, session_id, sender, min_index, backup_checked, request_sent) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (session_id) DO UPDATE + SET min_index = MIN(excluded.min_index, session_request.min_index), + backup_checked = excluded.backup_checked OR session_request.backup_checked, + request_sent = excluded.request_sent OR session_request.request_sent + ` + removeSessionRequestQuery = ` + DELETE FROM session_request WHERE session_id = $1 AND min_index >= $2 + ` + getNextSessionsToRequestQuery = ` + SELECT room_id, session_id, sender, min_index, backup_checked, request_sent + FROM session_request + WHERE request_sent = false OR backup_checked = false + ORDER BY backup_checked, rowid + LIMIT $1 + ` +) + +type SessionRequestQuery struct { + *dbutil.QueryHelper[*SessionRequest] +} + +func (srq *SessionRequestQuery) Next(ctx context.Context, count int) ([]*SessionRequest, error) { + return srq.QueryMany(ctx, getNextSessionsToRequestQuery, count) +} + +func (srq *SessionRequestQuery) Remove(ctx context.Context, sessionID id.SessionID, minIndex uint32) error { + return srq.Exec(ctx, removeSessionRequestQuery, sessionID, minIndex) +} + +func (srq *SessionRequestQuery) Put(ctx context.Context, sr *SessionRequest) error { + return srq.Exec(ctx, putSessionRequestQueueEntry, sr.sqlVariables()...) +} + +type SessionRequest struct { + RoomID id.RoomID + SessionID id.SessionID + Sender id.UserID + MinIndex uint32 + BackupChecked bool + RequestSent bool +} + +func (s *SessionRequest) Scan(row dbutil.Scannable) (*SessionRequest, error) { + return dbutil.ValueOrErr(s, row.Scan(&s.RoomID, &s.SessionID, &s.Sender, &s.MinIndex, &s.BackupChecked, &s.RequestSent)) +} + +func (s *SessionRequest) sqlVariables() []any { + return []any{s.RoomID, s.SessionID, s.Sender, s.MinIndex, s.BackupChecked, s.RequestSent} +} diff --git a/hicli/database/state.go b/hicli/database/state.go new file mode 100644 index 00000000..47c91dcf --- /dev/null +++ b/hicli/database/state.go @@ -0,0 +1,32 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + setCurrentStateQuery = ` + INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (room_id, event_type, state_key) DO UPDATE SET event_rowid = excluded.event_rowid, membership = excluded.membership + ` +) + +type CurrentStateQuery struct { + *dbutil.Database +} + +func (csq *CurrentStateQuery) Set(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID int64, membership event.Membership) error { + _, err := csq.Exec(ctx, setCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership)) + return err +} diff --git a/hicli/database/statestore.go b/hicli/database/statestore.go new file mode 100644 index 00000000..e0471ef2 --- /dev/null +++ b/hicli/database/statestore.go @@ -0,0 +1,149 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "errors" + + "go.mau.fi/util/dbutil" + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + getMembershipQuery = ` + SELECT membership FROM current_state + WHERE room_id = $1 AND event_type = 'm.room.member' AND state_key = $2 + ` + getStateEventContentQuery = ` + SELECT event.content FROM current_state cs + LEFT JOIN event ON event.rowid = cs.event_rowid + WHERE cs.room_id = $1 AND cs.event_type = $2 AND cs.state_key = $3 + ` + getRoomJoinedOrInvitedMembersQuery = ` + SELECT state_key FROM current_state + WHERE room_id = $1 AND event_type = 'm.room.member' AND membership IN ('join', 'invite') + ` + isRoomEncryptedQuery = ` + SELECT room.encryption_event IS NOT NULL FROM room WHERE room_id = $1 + ` + getRoomEncryptionEventQuery = ` + SELECT room.encryption_event FROM room WHERE room_id = $1 + ` + findSharedRoomsQuery = ` + SELECT room_id FROM current_state + WHERE event_type = 'm.room.member' AND state_key = $1 AND membership = 'join' + ` +) + +type ClientStateStore struct { + *Database +} + +var _ mautrix.StateStore = (*ClientStateStore)(nil) +var _ mautrix.StateStoreUpdater = (*ClientStateStore)(nil) +var _ crypto.StateStore = (*ClientStateStore)(nil) + +func (c *ClientStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { + return c.IsMembership(ctx, roomID, userID, event.MembershipJoin) +} + +func (c *ClientStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { + return c.IsMembership(ctx, roomID, userID, event.MembershipInvite, event.MembershipJoin) +} + +func (c *ClientStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool { + var membership event.Membership + err := c.QueryRow(ctx, getMembershipQuery, roomID, userID).Scan(&membership) + if errors.Is(err, sql.ErrNoRows) { + err = nil + membership = event.MembershipLeave + } + return slices.Contains(allowedMemberships, membership) +} + +func (c *ClientStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) { + content, err := c.TryGetMember(ctx, roomID, userID) + if content == nil { + content = &event.MemberEventContent{Membership: event.MembershipLeave} + } + return content, err +} + +func (c *ClientStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (content *event.MemberEventContent, err error) { + err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StateMember.Type, userID).Scan(&dbutil.JSON{Data: &content}) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (c *ClientStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (content *event.PowerLevelsEventContent, err error) { + err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StatePowerLevels.Type, "").Scan(&dbutil.JSON{Data: &content}) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (c *ClientStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) { + rows, err := c.Query(ctx, getRoomJoinedOrInvitedMembersQuery, roomID) + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() +} + +func (c *ClientStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (isEncrypted bool, err error) { + err = c.QueryRow(ctx, isRoomEncryptedQuery, roomID).Scan(&isEncrypted) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (c *ClientStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (content *event.EncryptionEventContent, err error) { + err = c.QueryRow(ctx, getRoomEncryptionEventQuery, roomID). + Scan(&dbutil.JSON{Data: &content}) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (c *ClientStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) ([]id.RoomID, error) { + // TODO for multiuser support, this might need to filter by the local user's membership + rows, err := c.Query(ctx, findSharedRoomsQuery, userID) + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList() +} + +// Update methods are all intentionally no-ops as the state store wants to have the full event + +func (c *ClientStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error { + return nil +} + +func (c *ClientStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error { + return nil +} + +func (c *ClientStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error { + return nil +} + +func (c *ClientStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error { + return nil +} + +func (c *ClientStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { + return nil +} + +func (c *ClientStateStore) UpdateState(ctx context.Context, evt *event.Event) {} diff --git a/hicli/database/timeline.go b/hicli/database/timeline.go new file mode 100644 index 00000000..585e55bb --- /dev/null +++ b/hicli/database/timeline.go @@ -0,0 +1,47 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/id" +) + +const ( + clearTimelineQuery = ` + DELETE FROM timeline WHERE room_id = $1 + ` + setTimelineQuery = ` + INSERT INTO timeline (room_id, event_rowid) VALUES ($1, $2) + ` +) + +type MassInsertableRowID int64 + +func (m MassInsertableRowID) GetMassInsertValues() [1]any { + return [1]any{m} +} + +var setTimelineQueryBuilder = dbutil.NewMassInsertBuilder[MassInsertableRowID, [1]any](setTimelineQuery, "($1, $%d)") + +type TimelineQuery struct { + *dbutil.Database +} + +func (tq *TimelineQuery) Clear(ctx context.Context, roomID id.RoomID) error { + _, err := tq.Exec(ctx, clearTimelineQuery, roomID) + return err +} + +func (tq *TimelineQuery) Append(ctx context.Context, roomID id.RoomID, rowIDs []MassInsertableRowID) error { + query, params := setTimelineQueryBuilder.Build([1]any{roomID}, rowIDs) + _, err := tq.Exec(ctx, query, params...) + return err +} diff --git a/hicli/database/upgrades/00-latest-revision.sql b/hicli/database/upgrades/00-latest-revision.sql new file mode 100644 index 00000000..cc85f25a --- /dev/null +++ b/hicli/database/upgrades/00-latest-revision.sql @@ -0,0 +1,107 @@ +-- v0 -> v1: Latest revision +CREATE TABLE account ( + user_id TEXT NOT NULL PRIMARY KEY, + device_id TEXT NOT NULL, + access_token TEXT NOT NULL, + homeserver_url TEXT NOT NULL, + + next_batch TEXT NOT NULL +) STRICT; + +CREATE TABLE room ( + room_id TEXT NOT NULL PRIMARY KEY, + creation_content TEXT, + + name TEXT, + avatar TEXT, + topic TEXT, + lazy_load_summary TEXT, + + encryption_event TEXT, + has_member_list INTEGER NOT NULL DEFAULT false, + + prev_batch TEXT +) STRICT; +CREATE INDEX room_type_idx ON room (creation_content ->> 'type'); + +CREATE TABLE account_data ( + user_id TEXT NOT NULL, + type TEXT NOT NULL, + content TEXT NOT NULL, + + PRIMARY KEY (user_id, type) +) STRICT; + +CREATE TABLE room_account_data ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + content TEXT NOT NULL, + + PRIMARY KEY (user_id, room_id, type), + CONSTRAINT room_account_data_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE +) STRICT; + +CREATE TABLE event ( + rowid INTEGER PRIMARY KEY, + + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + sender TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT, + timestamp INTEGER NOT NULL, + + content TEXT NOT NULL, + decrypted TEXT, + decrypted_type TEXT, + unsigned TEXT NOT NULL, + + redacted_by TEXT, + relates_to TEXT, + + megolm_session_id TEXT, + decryption_error TEXT, + + CONSTRAINT event_id_unique_key UNIQUE (event_id), + CONSTRAINT event_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE +) STRICT; +CREATE INDEX event_room_id_idx ON event (room_id); +CREATE INDEX event_redacted_by_idx ON event (room_id, redacted_by); +CREATE INDEX event_relates_to_idx ON event (room_id, relates_to); +CREATE INDEX event_megolm_session_id_idx ON event (room_id, megolm_session_id); + +CREATE TABLE session_request ( + room_id TEXT NOT NULL, + session_id TEXT NOT NULL, + sender TEXT NOT NULL, + min_index INTEGER NOT NULL, + backup_checked INTEGER NOT NULL DEFAULT false, + request_sent INTEGER NOT NULL DEFAULT false, + + PRIMARY KEY (session_id), + CONSTRAINT session_request_queue_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE +) STRICT; + +CREATE TABLE timeline ( + rowid INTEGER PRIMARY KEY, + room_id TEXT NOT NULL, + event_rowid INTEGER NOT NULL, + + CONSTRAINT timeline_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE, + CONSTRAINT timeline_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE CASCADE +) STRICT; +CREATE INDEX timeline_room_id_idx ON timeline (room_id); + +CREATE TABLE current_state ( + room_id TEXT NOT NULL, + event_type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_rowid INTEGER NOT NULL, + + membership TEXT, + + PRIMARY KEY (room_id, event_type, state_key), + CONSTRAINT current_state_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE, + CONSTRAINT current_state_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) +) STRICT, WITHOUT ROWID; diff --git a/hicli/database/upgrades/upgrades.go b/hicli/database/upgrades/upgrades.go new file mode 100644 index 00000000..9d0bd1a0 --- /dev/null +++ b/hicli/database/upgrades/upgrades.go @@ -0,0 +1,22 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package upgrades + +import ( + "embed" + + "go.mau.fi/util/dbutil" +) + +var Table dbutil.UpgradeTable + +//go:embed *.sql +var upgrades embed.FS + +func init() { + Table.RegisterFS(upgrades) +} diff --git a/hicli/decryptionqueue.go b/hicli/decryptionqueue.go new file mode 100644 index 00000000..551713a8 --- /dev/null +++ b/hicli/decryptionqueue.go @@ -0,0 +1,194 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "fmt" + "sync" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" +) + +func (h *HiClient) fetchFromKeyBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) (*crypto.InboundGroupSession, error) { + data, err := h.Client.GetKeyBackupForRoomAndSession(ctx, h.KeyBackupVersion, roomID, sessionID) + if err != nil { + return nil, err + } else if data == nil { + return nil, nil + } + decrypted, err := data.SessionData.Decrypt(h.KeyBackupKey) + if err != nil { + return nil, err + } + return h.Crypto.ImportRoomKeyFromBackup(ctx, h.KeyBackupVersion, roomID, sessionID, decrypted) +} + +func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, firstKnownIndex uint32) { + log := zerolog.Ctx(ctx) + err := h.DB.SessionRequest.Remove(ctx, sessionID, firstKnownIndex) + if err != nil { + log.Warn().Err(err).Msg("Failed to remove session request after receiving megolm session") + } + events, err := h.DB.Event.GetFailedByMegolmSessionID(ctx, roomID, sessionID) + if err != nil { + log.Err(err).Msg("Failed to get events that failed to decrypt to retry decryption") + return + } else if len(events) == 0 { + log.Trace().Msg("No events to retry decryption for") + return + } + decrypted := events[:0] + for _, evt := range events { + if evt.Decrypted != nil { + continue + } + + evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix()) + if err != nil { + log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session") + } else { + decrypted = append(decrypted, evt) + } + } + if len(decrypted) > 0 { + err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + for _, evt := range decrypted { + err = h.DB.Event.UpdateDecrypted(ctx, evt.RowID, evt.Decrypted, evt.DecryptedType) + if err != nil { + return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, err) + } + } + return nil + }) + if err != nil { + log.Err(err).Msg("Failed to save decrypted events") + } + } +} + +func (h *HiClient) WakeupRequestQueue() { + select { + case h.requestQueueWakeup <- struct{}{}: + default: + } +} + +func (h *HiClient) RunRequestQueue(ctx context.Context) { + log := zerolog.Ctx(ctx).With().Str("action", "request queue").Logger() + ctx = log.WithContext(ctx) + log.Info().Msg("Starting key request queue") + defer func() { + log.Info().Msg("Stopping key request queue") + }() + for { + err := h.FetchKeysForOutdatedUsers(ctx) + if err != nil { + log.Err(err).Msg("Failed to fetch outdated device lists for tracked users") + } + madeRequests, err := h.RequestQueuedSessions(ctx) + if err != nil { + log.Err(err).Msg("Failed to handle session request queue") + } else if madeRequests { + continue + } + select { + case <-ctx.Done(): + return + case <-h.requestQueueWakeup: + } + } +} + +func (h *HiClient) requestQueuedSession(ctx context.Context, req *database.SessionRequest, doneFunc func()) { + defer doneFunc() + log := zerolog.Ctx(ctx) + if !req.BackupChecked { + sess, err := h.fetchFromKeyBackup(ctx, req.RoomID, req.SessionID) + if err != nil { + log.Err(err). + Stringer("session_id", req.SessionID). + Msg("Failed to fetch session from key backup") + + // TODO should this have retries instead of just storing it's checked? + req.BackupChecked = true + err = h.DB.SessionRequest.Put(ctx, req) + if err != nil { + log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after trying to check backup") + } + } else if sess == nil || sess.Internal.FirstKnownIndex() > req.MinIndex { + req.BackupChecked = true + err = h.DB.SessionRequest.Put(ctx, req) + if err != nil { + log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after checking backup") + } + } else { + log.Debug().Stringer("session_id", req.SessionID). + Msg("Found session with sufficiently low first known index, removing from queue") + err = h.DB.SessionRequest.Remove(ctx, req.SessionID, sess.Internal.FirstKnownIndex()) + if err != nil { + log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to remove session from request queue") + } + } + } else { + err := h.Crypto.SendRoomKeyRequest(ctx, req.RoomID, "", req.SessionID, "", map[id.UserID][]id.DeviceID{ + h.Account.UserID: {"*"}, + req.Sender: {"*"}, + }) + //var err error + if err != nil { + log.Err(err). + Stringer("session_id", req.SessionID). + Msg("Failed to send key request") + } else { + log.Debug().Stringer("session_id", req.SessionID).Msg("Sent key request") + req.RequestSent = true + err = h.DB.SessionRequest.Put(ctx, req) + if err != nil { + log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after sending request") + } + } + } +} + +const MaxParallelRequests = 5 + +func (h *HiClient) RequestQueuedSessions(ctx context.Context) (bool, error) { + sessions, err := h.DB.SessionRequest.Next(ctx, MaxParallelRequests) + if err != nil { + return false, fmt.Errorf("failed to get next events to decrypt: %w", err) + } else if len(sessions) == 0 { + return false, nil + } + var wg sync.WaitGroup + wg.Add(len(sessions)) + for _, req := range sessions { + go h.requestQueuedSession(ctx, req, wg.Done) + } + wg.Wait() + + return true, err +} + +func (h *HiClient) FetchKeysForOutdatedUsers(ctx context.Context) error { + outdatedUsers, err := h.Crypto.CryptoStore.GetOutdatedTrackedUsers(ctx) + if err != nil { + return err + } else if len(outdatedUsers) == 0 { + return nil + } + _, err = h.Crypto.FetchKeys(ctx, outdatedUsers, false) + if err != nil { + return err + } + // TODO backoff for users that fail to be fetched? + return nil +} diff --git a/hicli/hicli.go b/hicli/hicli.go new file mode 100644 index 00000000..9b889d3c --- /dev/null +++ b/hicli/hicli.go @@ -0,0 +1,159 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package hicli contains a highly opinionated high-level framework for developing instant messaging clients on Matrix. +package hicli + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "sync" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/backup" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" +) + +type HiClient struct { + DB *database.Database + Account *database.Account + Client *mautrix.Client + Crypto *crypto.OlmMachine + CryptoStore *crypto.SQLCryptoStore + ClientStore *database.ClientStateStore + Log zerolog.Logger + + Verified bool + + KeyBackupVersion id.KeyBackupVersion + KeyBackupKey *backup.MegolmBackupKey + + firstSyncReceived bool + syncingID int + syncLock sync.Mutex + encryptLock sync.Mutex + + requestQueueWakeup chan struct{} +} + +func New(rawDB *dbutil.Database, log zerolog.Logger, pickleKey []byte) *HiClient { + rawDB.Owner = "hicli" + rawDB.IgnoreForeignTables = true + db := database.New(rawDB) + db.Log = dbutil.ZeroLogger(log.With().Str("db_section", "hicli").Logger()) + c := &HiClient{ + DB: db, + Log: log, + + requestQueueWakeup: make(chan struct{}, 1), + } + c.ClientStore = &database.ClientStateStore{Database: db} + c.Client = &mautrix.Client{ + UserAgent: mautrix.DefaultUserAgent, + Client: &http.Client{ + Transport: &http.Transport{ + DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext, + TLSHandshakeTimeout: 10 * time.Second, + // This needs to be relatively high to allow initial syncs + ResponseHeaderTimeout: 180 * time.Second, + ForceAttemptHTTP2: true, + }, + Timeout: 180 * time.Second, + }, + Syncer: (*hiSyncer)(c), + Store: (*hiStore)(c), + StateStore: c.ClientStore, + Log: log.With().Str("component", "mautrix client").Logger(), + } + c.CryptoStore = crypto.NewSQLCryptoStore(rawDB, dbutil.ZeroLogger(log.With().Str("db_section", "crypto").Logger()), "", "", pickleKey) + cryptoLog := log.With().Str("component", "crypto").Logger() + c.Crypto = crypto.NewOlmMachine(c.Client, &cryptoLog, c.CryptoStore, c.ClientStore) + c.Crypto.SessionReceived = c.handleReceivedMegolmSession + c.Crypto.DisableRatchetTracking = true + c.Crypto.DisableDecryptKeyFetching = true + c.Client.Crypto = (*hiCryptoHelper)(c) + return c +} + +func (h *HiClient) IsLoggedIn() bool { + return h.Account != nil +} + +func (h *HiClient) Start(ctx context.Context, userID id.UserID) error { + err := h.DB.Upgrade(ctx) + if err != nil { + return fmt.Errorf("failed to upgrade hicli db: %w", err) + } + err = h.CryptoStore.DB.Upgrade(ctx) + if err != nil { + return fmt.Errorf("failed to upgrade crypto db: %w", err) + } + account, err := h.DB.Account.Get(ctx, userID) + if err != nil { + return err + } + if account != nil { + zerolog.Ctx(ctx).Debug().Stringer("user_id", account.UserID).Msg("Preparing client with existing credentials") + h.Account = account + h.CryptoStore.AccountID = account.UserID.String() + h.CryptoStore.DeviceID = account.DeviceID + h.Client.UserID = account.UserID + h.Client.DeviceID = account.DeviceID + h.Client.AccessToken = account.AccessToken + h.Client.HomeserverURL, err = url.Parse(account.HomeserverURL) + if err != nil { + return err + } + err = h.Crypto.Load(ctx) + if err != nil { + return fmt.Errorf("failed to load olm machine: %w", err) + } + + h.Verified, err = h.checkIsCurrentDeviceVerified(ctx) + if err != nil { + return err + } + zerolog.Ctx(ctx).Debug().Bool("verified", h.Verified).Msg("Checked current device verification status") + if h.Verified { + err = h.loadPrivateKeys(ctx) + if err != nil { + return err + } + go h.Sync() + go h.RunRequestQueue(ctx) + } + } + return nil +} + +func (h *HiClient) Sync() { + h.Client.StopSync() + h.syncLock.Lock() + defer h.syncLock.Unlock() + h.syncingID++ + syncingID := h.syncingID + log := h.Log.With(). + Str("action", "sync"). + Int("sync_id", syncingID). + Logger() + ctx := log.WithContext(context.Background()) + log.Info().Msg("Starting syncing") + err := h.Client.SyncWithContext(ctx) + if err != nil { + log.Err(err).Msg("Fatal error in syncer") + } else { + log.Info().Msg("Syncing stopped") + } +} diff --git a/hicli/hitest/hitest.go b/hicli/hitest/hitest.go new file mode 100644 index 00000000..ec94a328 --- /dev/null +++ b/hicli/hitest/hitest.go @@ -0,0 +1,69 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package main + +import ( + "context" + "io" + "os" + "os/signal" + "syscall" + + "github.com/chzyer/readline" + _ "github.com/mattn/go-sqlite3" + "go.mau.fi/util/dbutil" + _ "go.mau.fi/util/dbutil/litestream" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exzerolog" + "go.mau.fi/zeroconfig" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/hicli" + "maunium.net/go/mautrix/id" +) + +var writerTypeReadline zeroconfig.WriterType = "hitest_readline" + +func main() { + rl := exerrors.Must(readline.New("> ")) + defer func() { + _ = rl.Close() + }() + zeroconfig.RegisterWriter(writerTypeReadline, func(config *zeroconfig.WriterConfig) (io.Writer, error) { + return rl.Stdout(), nil + }) + log := exerrors.Must((&zeroconfig.Config{ + Writers: []zeroconfig.WriterConfig{{ + Type: writerTypeReadline, + Format: zeroconfig.LogFormatPrettyColored, + }}, + }).Compile()) + exzerolog.SetupDefaults(log) + + rawDB := exerrors.Must(dbutil.NewWithDialect("hicli.db", "sqlite3-fk-wal")) + ctx := log.WithContext(context.Background()) + cli := hicli.New(rawDB, *log, []byte("meow")) + userID, _ := cli.DB.Account.GetFirstUserID(ctx) + exerrors.PanicIfNotNil(cli.Start(ctx, userID)) + if !cli.IsLoggedIn() { + rl.SetPrompt("User ID: ") + userID := id.UserID(exerrors.Must(rl.Readline())) + _, serverName := exerrors.Must2(userID.Parse()) + discovery, err := mautrix.DiscoverClientAPI(ctx, serverName) + if discovery == nil { + log.Fatal().Err(err).Msg("Failed to discover homeserver") + } + password := exerrors.Must(rl.ReadPassword("Password: ")) + recoveryCode := exerrors.Must(rl.ReadPassword("Recovery code: ")) + exerrors.PanicIfNotNil(cli.LoginAndVerify(ctx, discovery.Homeserver.BaseURL, userID.String(), string(password), string(recoveryCode))) + } + rl.SetPrompt("> ") + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c +} diff --git a/hicli/login.go b/hicli/login.go new file mode 100644 index 00000000..47ea5a4d --- /dev/null +++ b/hicli/login.go @@ -0,0 +1,77 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "fmt" + "net/url" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/hicli/database" +) + +func (h *HiClient) LoginPassword(ctx context.Context, homeserverURL, username, password string) error { + var err error + h.Client.HomeserverURL, err = url.Parse(homeserverURL) + if err != nil { + return err + } + return h.Login(ctx, &mautrix.ReqLogin{ + Type: mautrix.AuthTypePassword, + Identifier: mautrix.UserIdentifier{ + Type: mautrix.IdentifierTypeUser, + User: username, + }, + Password: password, + InitialDeviceDisplayName: "mautrix client", + }) +} + +func (h *HiClient) Login(ctx context.Context, req *mautrix.ReqLogin) error { + req.StoreCredentials = true + req.StoreHomeserverURL = true + resp, err := h.Client.Login(ctx, req) + if err != nil { + return err + } + h.Account = &database.Account{ + UserID: resp.UserID, + DeviceID: resp.DeviceID, + AccessToken: resp.AccessToken, + HomeserverURL: h.Client.HomeserverURL.String(), + } + h.CryptoStore.AccountID = resp.UserID.String() + h.CryptoStore.DeviceID = resp.DeviceID + err = h.DB.Account.Put(ctx, h.Account) + if err != nil { + return err + } + err = h.Crypto.Load(ctx) + if err != nil { + return fmt.Errorf("failed to load olm machine: %w", err) + } + err = h.Crypto.ShareKeys(ctx, 0) + if err != nil { + return err + } + return nil +} + +func (h *HiClient) LoginAndVerify(ctx context.Context, homeserverURL, username, password, recoveryCode string) error { + err := h.LoginPassword(ctx, homeserverURL, username, password) + if err != nil { + return err + } + err = h.VerifyWithRecoveryCode(ctx, recoveryCode) + if err != nil { + return err + } + go h.Sync() + go h.RunRequestQueue(ctx) + return nil +} diff --git a/hicli/sync.go b/hicli/sync.go new file mode 100644 index 00000000..d0064015 --- /dev/null +++ b/hicli/sync.go @@ -0,0 +1,362 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "errors" + "fmt" + + "github.com/rs/zerolog" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.mau.fi/util/exzerolog" + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" +) + +type syncContext struct { + shouldWakeupRequestQueue bool +} + +func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { + log := zerolog.Ctx(ctx) + postponedToDevices := resp.ToDevice.Events[:0] + for _, evt := range resp.ToDevice.Events { + evt.Type.Class = event.ToDeviceEventType + err := evt.Content.ParseRaw(evt.Type) + if err != nil { + log.Warn().Err(err). + Stringer("event_type", &evt.Type). + Stringer("sender", evt.Sender). + Msg("Failed to parse to-device event, skipping") + continue + } + + switch content := evt.Content.Parsed.(type) { + case *event.EncryptedEventContent: + h.Crypto.HandleEncryptedEvent(ctx, evt) + case *event.RoomKeyWithheldEventContent: + h.Crypto.HandleRoomKeyWithheld(ctx, content) + default: + postponedToDevices = append(postponedToDevices, evt) + } + } + resp.ToDevice.Events = postponedToDevices + + return nil +} + +func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { + h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount) + go h.asyncPostProcessSyncResponse(ctx, resp, since) + if ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue { + h.WakeupRequestQueue() + } + return nil +} + +func (h *HiClient) asyncPostProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) { + for _, evt := range resp.ToDevice.Events { + switch content := evt.Content.Parsed.(type) { + case *event.SecretRequestEventContent: + h.Crypto.HandleSecretRequest(ctx, evt.Sender, content) + case *event.RoomKeyRequestEventContent: + h.Crypto.HandleRoomKeyRequest(ctx, evt.Sender, content) + } + } +} + +func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { + if len(resp.DeviceLists.Changed) > 0 { + zerolog.Ctx(ctx).Debug(). + Array("users", exzerolog.ArrayOfStringers(resp.DeviceLists.Changed)). + Msg("Marking changed device lists for tracked users as outdated") + err := h.Crypto.CryptoStore.MarkTrackedUsersOutdated(ctx, resp.DeviceLists.Changed) + if err != nil { + return fmt.Errorf("failed to mark changed device lists as outdated: %w", err) + } + ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true + } + + for _, evt := range resp.AccountData.Events { + evt.Type.Class = event.AccountDataEventType + err := h.DB.AccountData.Put(ctx, h.Account.UserID, evt.Type, evt.Content.VeryRaw) + if err != nil { + return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err) + } + } + for roomID, room := range resp.Rooms.Join { + err := h.processSyncJoinedRoom(ctx, roomID, room) + if err != nil { + return fmt.Errorf("failed to process joined room %s: %w", roomID, err) + } + } + for roomID, room := range resp.Rooms.Leave { + err := h.processSyncLeftRoom(ctx, roomID, room) + if err != nil { + return fmt.Errorf("failed to process left room %s: %w", roomID, err) + } + } + h.Account.NextBatch = resp.NextBatch + err := h.DB.Account.PutNextBatch(ctx, h.Account.UserID, resp.NextBatch) + if err != nil { + return fmt.Errorf("failed to save next_batch: %w", err) + } + return nil +} + +func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error { + existingRoomData, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get room data: %w", err) + } else if existingRoomData == nil { + err = h.DB.Room.CreateRow(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to ensure room row exists: %w", err) + } + existingRoomData = &database.Room{ID: roomID} + } + + for _, evt := range room.AccountData.Events { + evt.Type.Class = event.AccountDataEventType + evt.RoomID = roomID + err = h.DB.AccountData.PutRoom(ctx, h.Account.UserID, roomID, evt.Type, evt.Content.VeryRaw) + if err != nil { + return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err) + } + } + err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary) + if err != nil { + return err + } + return nil +} + +func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncLeftRoom) error { + existingRoomData, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get room data: %w", err) + } else if existingRoomData == nil { + return nil + } + return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary) +} + +func isDecryptionErrorRetryable(err error) bool { + return errors.Is(err, crypto.NoSessionFound) || errors.Is(err, olm.UnknownMessageIndex) || errors.Is(err, crypto.ErrGroupSessionWithheld) +} + +func removeReplyFallback(evt *event.Event) []byte { + content, ok := evt.Content.Parsed.(*event.MessageEventContent) + if ok && content.RelatesTo.GetReplyTo() != "" { + prevFormattedBody := content.FormattedBody + content.RemoveReplyFallback() + if content.FormattedBody != prevFormattedBody { + bytes, err := sjson.SetBytes(evt.Content.VeryRaw, "formatted_body", content.FormattedBody) + if err == nil { + return bytes + } + bytes, err = sjson.SetBytes(evt.Content.VeryRaw, "body", content.Body) + if err == nil { + return bytes + } + } + } + return nil +} + +func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) ([]byte, string, error) { + err := evt.Content.ParseRaw(evt.Type) + if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { + return nil, "", err + } + decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt) + if err != nil { + return nil, "", err + } + withoutFallback := removeReplyFallback(decrypted) + if withoutFallback != nil { + return withoutFallback, decrypted.Type.Type, nil + } + return decrypted.Content.VeryRaw, decrypted.Type.Type, nil +} + +func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.Room, state *mautrix.SyncEventsList, timeline *mautrix.SyncTimeline, summary *mautrix.LazyLoadSummary) error { + decryptionQueue := make(map[id.SessionID]*database.SessionRequest) + roomDataChanged := false + processEvent := func(evt *event.Event) (database.MassInsertableRowID, error) { + evt.RoomID = room.ID + dbEvt := database.MautrixToEvent(evt) + contentWithoutFallback := removeReplyFallback(evt) + if contentWithoutFallback != nil { + dbEvt.Content = contentWithoutFallback + } + var decryptionErr error + if evt.Type == event.EventEncrypted { + dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt) + if decryptionErr != nil { + dbEvt.DecryptionError = decryptionErr.Error() + } + } + rowID, err := h.DB.Event.Upsert(ctx, dbEvt) + if err != nil { + return -1, fmt.Errorf("failed to save event %s: %w", evt.ID, err) + } + if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) { + req, ok := decryptionQueue[dbEvt.MegolmSessionID] + if !ok { + req = &database.SessionRequest{ + RoomID: room.ID, + SessionID: dbEvt.MegolmSessionID, + Sender: evt.Sender, + } + } + minIndex, _ := crypto.ParseMegolmMessageIndex(evt.Content.AsEncrypted().MegolmCiphertext) + req.MinIndex = min(uint32(minIndex), req.MinIndex) + decryptionQueue[dbEvt.MegolmSessionID] = req + } + if evt.StateKey != nil { + var membership event.Membership + if evt.Type == event.StateMember { + membership = event.Membership(gjson.GetBytes(evt.Content.VeryRaw, "membership").Str) + } + err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, rowID, membership) + if err != nil { + return -1, fmt.Errorf("failed to save current state event ID %s for %s/%s: %w", evt.ID, evt.Type.Type, *evt.StateKey, err) + } + roomDataChanged = processImportantEvent(ctx, evt, room) || roomDataChanged + } + return database.MassInsertableRowID(rowID), nil + } + var err error + for _, evt := range state.Events { + evt.Type.Class = event.StateEventType + _, err = processEvent(evt) + if err != nil { + return err + } + } + if len(timeline.Events) > 0 { + timelineIDs := make([]database.MassInsertableRowID, len(timeline.Events)) + for i, evt := range timeline.Events { + if evt.StateKey != nil { + evt.Type.Class = event.StateEventType + } else { + evt.Type.Class = event.MessageEventType + } + timelineIDs[i], err = processEvent(evt) + if err != nil { + return err + } + } + for _, entry := range decryptionQueue { + err = h.DB.SessionRequest.Put(ctx, entry) + if err != nil { + return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err) + } + } + if len(decryptionQueue) > 0 { + ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true + } + if timeline.Limited { + err = h.DB.Timeline.Clear(ctx, room.ID) + if err != nil { + return fmt.Errorf("failed to clear old timeline: %w", err) + } + } + err = h.DB.Timeline.Append(ctx, room.ID, timelineIDs) + if err != nil { + return fmt.Errorf("failed to append timeline: %w", err) + } + } + if timeline.PrevBatch != "" && room.PrevBatch == "" { + room.PrevBatch = timeline.PrevBatch + roomDataChanged = true + } + if summary.Heroes != nil { + roomDataChanged = roomDataChanged || room.LazyLoadSummary == nil || + !slices.Equal(summary.Heroes, room.LazyLoadSummary.Heroes) || + !intPtrEqual(summary.JoinedMemberCount, room.LazyLoadSummary.JoinedMemberCount) || + !intPtrEqual(summary.InvitedMemberCount, room.LazyLoadSummary.InvitedMemberCount) + room.LazyLoadSummary = summary + } + if roomDataChanged { + err = h.DB.Room.Upsert(ctx, room) + if err != nil { + return fmt.Errorf("failed to save room data: %w", err) + } + } + return nil +} + +func intPtrEqual(a, b *int) bool { + if a == nil || b == nil { + return a == b + } + return *a == *b +} + +func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomData *database.Room) (roomDataChanged bool) { + if evt.StateKey == nil { + return + } + switch evt.Type { + case event.StateCreate, event.StateRoomName, event.StateRoomAvatar, event.StateTopic, event.StateEncryption: + if *evt.StateKey != "" { + return + } + default: + return + } + err := evt.Content.ParseRaw(evt.Type) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err). + Stringer("event_type", &evt.Type). + Stringer("event_id", evt.ID). + Msg("Failed to parse state event, skipping") + return + } + switch evt.Type { + case event.StateCreate: + if existingRoomData.CreationContent == nil { + roomDataChanged = true + } + existingRoomData.CreationContent, _ = evt.Content.Parsed.(*event.CreateEventContent) + case event.StateEncryption: + newEncryption, _ := evt.Content.Parsed.(*event.EncryptionEventContent) + if existingRoomData.EncryptionEvent == nil || existingRoomData.EncryptionEvent.Algorithm == newEncryption.Algorithm { + roomDataChanged = true + existingRoomData.EncryptionEvent = newEncryption + } + case event.StateRoomName: + content, ok := evt.Content.Parsed.(*event.RoomNameEventContent) + if ok { + roomDataChanged = existingRoomData.Name == nil || *existingRoomData.Name != content.Name + existingRoomData.Name = &content.Name + } + case event.StateRoomAvatar: + content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent) + if ok { + roomDataChanged = existingRoomData.Avatar == nil || *existingRoomData.Avatar != content.URL + existingRoomData.Avatar = &content.URL + } + case event.StateTopic: + content, ok := evt.Content.Parsed.(*event.TopicEventContent) + if ok { + roomDataChanged = existingRoomData.Topic == nil || *existingRoomData.Topic != content.Topic + existingRoomData.Topic = &content.Topic + } + } + return +} diff --git a/hicli/syncwrap.go b/hicli/syncwrap.go new file mode 100644 index 00000000..eccdb7b1 --- /dev/null +++ b/hicli/syncwrap.go @@ -0,0 +1,100 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "fmt" + "time" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/id" +) + +type hiSyncer HiClient + +var _ mautrix.Syncer = (*hiSyncer)(nil) + +type contextKey int + +const ( + syncContextKey contextKey = iota +) + +func (h *hiSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { + c := (*HiClient)(h) + ctx = context.WithValue(ctx, syncContextKey, &syncContext{}) + err := c.preProcessSyncResponse(ctx, resp, since) + if err != nil { + return err + } + err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + return c.processSyncResponse(ctx, resp, since) + }) + if err != nil { + return err + } + err = c.postProcessSyncResponse(ctx, resp, since) + if err != nil { + return err + } + c.firstSyncReceived = true + return nil +} + +func (h *hiSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) { + (*HiClient)(h).Log.Err(err).Msg("Sync failed, retrying in 1 second") + return 1 * time.Second, nil +} + +func (h *hiSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter { + if !h.Verified { + return &mautrix.Filter{ + Presence: mautrix.FilterPart{ + NotRooms: []id.RoomID{"*"}, + }, + Room: mautrix.RoomFilter{ + NotRooms: []id.RoomID{"*"}, + }, + } + } + return &mautrix.Filter{ + Presence: mautrix.FilterPart{ + NotRooms: []id.RoomID{"*"}, + }, + Room: mautrix.RoomFilter{ + State: mautrix.FilterPart{ + LazyLoadMembers: true, + }, + Timeline: mautrix.FilterPart{ + Limit: 100, + LazyLoadMembers: true, + }, + }, + } +} + +type hiStore HiClient + +var _ mautrix.SyncStore = (*hiStore)(nil) + +// Filter ID save and load are intentionally no-ops: we want to recreate filters when restarting syncing + +func (h *hiStore) SaveFilterID(_ context.Context, _ id.UserID, _ string) error { return nil } +func (h *hiStore) LoadFilterID(_ context.Context, _ id.UserID) (string, error) { return "", nil } + +func (h *hiStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error { + // This is intentionally a no-op: we don't want to save the next batch before processing the sync + return nil +} + +func (h *hiStore) LoadNextBatch(_ context.Context, userID id.UserID) (string, error) { + if h.Account.UserID != userID { + return "", fmt.Errorf("mismatching user ID") + } + return h.Account.NextBatch, nil +} diff --git a/hicli/verify.go b/hicli/verify.go new file mode 100644 index 00000000..2062519a --- /dev/null +++ b/hicli/verify.go @@ -0,0 +1,158 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "encoding/base64" + "fmt" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/backup" + "maunium.net/go/mautrix/crypto/ssss" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func (h *HiClient) checkIsCurrentDeviceVerified(ctx context.Context) (bool, error) { + keys := h.Crypto.GetOwnCrossSigningPublicKeys(ctx) + if keys == nil { + return false, fmt.Errorf("own cross-signing keys not found") + } + isVerified, err := h.Crypto.CryptoStore.IsKeySignedBy(ctx, h.Account.UserID, h.Crypto.GetAccount().SigningKey(), h.Account.UserID, keys.SelfSigningKey) + if err != nil { + return false, fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err) + } + return isVerified, nil +} + +func (h *HiClient) fetchKeyBackupKey(ctx context.Context, ssssKey *ssss.Key) error { + latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx) + if err != nil { + return fmt.Errorf("failed to get key backup latest version: %w", err) + } + h.KeyBackupVersion = latestVersion.Version + data, err := h.Crypto.SSSS.GetDecryptedAccountData(ctx, event.AccountDataMegolmBackupKey, ssssKey) + if err != nil { + return fmt.Errorf("failed to get megolm backup key from SSSS: %w", err) + } + key, err := backup.MegolmBackupKeyFromBytes(data) + if err != nil { + return fmt.Errorf("failed to parse megolm backup key: %w", err) + } + err = h.CryptoStore.PutSecret(ctx, id.SecretMegolmBackupV1, base64.StdEncoding.EncodeToString(key.Bytes())) + if err != nil { + return fmt.Errorf("failed to store megolm backup key: %w", err) + } + h.KeyBackupKey = key + return nil +} + +func (h *HiClient) getAndDecodeSecret(ctx context.Context, secret id.Secret) ([]byte, error) { + secretData, err := h.CryptoStore.GetSecret(ctx, secret) + if err != nil { + return nil, fmt.Errorf("failed to get secret %s: %w", secret, err) + } + data, err := base64.StdEncoding.DecodeString(secretData) + if err != nil { + return nil, fmt.Errorf("failed to decode secret %s: %w", secret, err) + } + return data, nil +} + +func (h *HiClient) loadPrivateKeys(ctx context.Context) error { + zerolog.Ctx(ctx).Debug().Msg("Loading cross-signing private keys") + masterKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSMaster) + if err != nil { + return fmt.Errorf("failed to get master key: %w", err) + } + selfSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSSelfSigning) + if err != nil { + return fmt.Errorf("failed to get self-signing key: %w", err) + } + userSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSUserSigning) + if err != nil { + return fmt.Errorf("failed to get user signing key: %w", err) + } + err = h.Crypto.ImportCrossSigningKeys(crypto.CrossSigningSeeds{ + MasterKey: masterKeySeed, + SelfSigningKey: selfSigningKeySeed, + UserSigningKey: userSigningKeySeed, + }) + if err != nil { + return fmt.Errorf("failed to import cross-signing private keys: %w", err) + } + zerolog.Ctx(ctx).Debug().Msg("Loading key backup key") + keyBackupKey, err := h.getAndDecodeSecret(ctx, id.SecretMegolmBackupV1) + if err != nil { + return fmt.Errorf("failed to get megolm backup key: %w", err) + } + h.KeyBackupKey, err = backup.MegolmBackupKeyFromBytes(keyBackupKey) + if err != nil { + return fmt.Errorf("failed to parse megolm backup key: %w", err) + } + zerolog.Ctx(ctx).Debug().Msg("Fetching key backup version") + latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx) + if err != nil { + return fmt.Errorf("failed to get key backup latest version: %w", err) + } + h.KeyBackupVersion = latestVersion.Version + zerolog.Ctx(ctx).Debug().Msg("Secrets loaded") + return nil +} + +func (h *HiClient) storeCrossSigningPrivateKeys(ctx context.Context) error { + keys := h.Crypto.CrossSigningKeys + err := h.CryptoStore.PutSecret(ctx, id.SecretXSMaster, base64.StdEncoding.EncodeToString(keys.MasterKey.Seed())) + if err != nil { + return err + } + err = h.CryptoStore.PutSecret(ctx, id.SecretXSSelfSigning, base64.StdEncoding.EncodeToString(keys.SelfSigningKey.Seed())) + if err != nil { + return err + } + err = h.CryptoStore.PutSecret(ctx, id.SecretXSUserSigning, base64.StdEncoding.EncodeToString(keys.UserSigningKey.Seed())) + if err != nil { + return err + } + return nil +} + +func (h *HiClient) VerifyWithRecoveryCode(ctx context.Context, code string) error { + _, keyData, err := h.Crypto.SSSS.GetDefaultKeyData(ctx) + if err != nil { + return fmt.Errorf("failed to get default SSSS key data: %w", err) + } + key, err := keyData.VerifyRecoveryKey(code) + if err != nil { + return err + } + err = h.Crypto.FetchCrossSigningKeysFromSSSS(ctx, key) + if err != nil { + return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err) + } + err = h.Crypto.SignOwnDevice(ctx, h.Crypto.OwnIdentity()) + if err != nil { + return fmt.Errorf("failed to sign own device: %w", err) + } + err = h.Crypto.SignOwnMasterKey(ctx) + if err != nil { + return fmt.Errorf("failed to sign own master key: %w", err) + } + err = h.storeCrossSigningPrivateKeys(ctx) + if err != nil { + return fmt.Errorf("failed to store cross-signing private keys: %w", err) + } + err = h.fetchKeyBackupKey(ctx, key) + if err != nil { + return fmt.Errorf("failed to fetch key backup key: %w", err) + } + h.Verified = true + return nil +} From 9254a5d6c1d97cefd94e208e2f4a921012eb1a5c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 3 May 2024 11:44:34 +0200 Subject: [PATCH 0226/1647] Add base for v2 bridge architecture --- appservice/http.go | 7 + bridgev2/bridge.go | 102 ++++ bridgev2/bridgeconfig/appservice.go | 131 ++++ bridgev2/bridgeconfig/config.go | 52 ++ bridgev2/bridgeconfig/encryption.go | 48 ++ bridgev2/bridgeconfig/homeserver.go | 38 ++ bridgev2/bridgeconfig/permissions.go | 71 +++ bridgev2/bridgeconfig/upgrade.go | 132 ++++ bridgev2/cmdevent.go | 96 +++ bridgev2/cmdhandler.go | 94 +++ bridgev2/cmdhelp.go | 129 ++++ bridgev2/cmdmeta.go | 50 ++ bridgev2/cmdprocessor.go | 118 ++++ bridgev2/database/database.go | 49 ++ bridgev2/database/ghost.go | 92 +++ bridgev2/database/message.go | 146 +++++ bridgev2/database/portal.go | 128 ++++ bridgev2/database/reaction.go | 110 ++++ bridgev2/database/upgrades/00-latest.sql | 133 ++++ bridgev2/database/upgrades/upgrades.go | 22 + bridgev2/database/user.go | 89 +++ bridgev2/database/userlogin.go | 111 ++++ bridgev2/ghost.go | 87 +++ bridgev2/matrix/cmdadmin.go | 79 +++ bridgev2/matrix/cmddoublepuppet.go | 74 +++ bridgev2/matrix/connector.go | 183 ++++++ bridgev2/matrix/crypto.go | 499 +++++++++++++++ bridgev2/matrix/cryptostore.go | 63 ++ bridgev2/matrix/doublepuppet.go | 175 ++++++ bridgev2/matrix/intent.go | 178 ++++++ bridgev2/matrix/matrix.go | 175 ++++++ bridgev2/matrix/no-crypto.go | 28 + bridgev2/matrix/websocket.go.dis | 163 +++++ bridgev2/matrixinterface.go | 52 ++ bridgev2/messagestatus.go | 86 +++ bridgev2/networkid/bridgeid.go | 59 ++ bridgev2/networkinterface.go | 217 +++++++ bridgev2/portal.go | 733 +++++++++++++++++++++++ bridgev2/queue.go | 92 +++ bridgev2/user.go | 85 +++ bridgev2/userlogin.go | 111 ++++ event/beeper.go | 7 +- requests.go | 6 +- versions.go | 1 + 44 files changed, 5097 insertions(+), 4 deletions(-) create mode 100644 bridgev2/bridge.go create mode 100644 bridgev2/bridgeconfig/appservice.go create mode 100644 bridgev2/bridgeconfig/config.go create mode 100644 bridgev2/bridgeconfig/encryption.go create mode 100644 bridgev2/bridgeconfig/homeserver.go create mode 100644 bridgev2/bridgeconfig/permissions.go create mode 100644 bridgev2/bridgeconfig/upgrade.go create mode 100644 bridgev2/cmdevent.go create mode 100644 bridgev2/cmdhandler.go create mode 100644 bridgev2/cmdhelp.go create mode 100644 bridgev2/cmdmeta.go create mode 100644 bridgev2/cmdprocessor.go create mode 100644 bridgev2/database/database.go create mode 100644 bridgev2/database/ghost.go create mode 100644 bridgev2/database/message.go create mode 100644 bridgev2/database/portal.go create mode 100644 bridgev2/database/reaction.go create mode 100644 bridgev2/database/upgrades/00-latest.sql create mode 100644 bridgev2/database/upgrades/upgrades.go create mode 100644 bridgev2/database/user.go create mode 100644 bridgev2/database/userlogin.go create mode 100644 bridgev2/ghost.go create mode 100644 bridgev2/matrix/cmdadmin.go create mode 100644 bridgev2/matrix/cmddoublepuppet.go create mode 100644 bridgev2/matrix/connector.go create mode 100644 bridgev2/matrix/crypto.go create mode 100644 bridgev2/matrix/cryptostore.go create mode 100644 bridgev2/matrix/doublepuppet.go create mode 100644 bridgev2/matrix/intent.go create mode 100644 bridgev2/matrix/matrix.go create mode 100644 bridgev2/matrix/no-crypto.go create mode 100644 bridgev2/matrix/websocket.go.dis create mode 100644 bridgev2/matrixinterface.go create mode 100644 bridgev2/messagestatus.go create mode 100644 bridgev2/networkid/bridgeid.go create mode 100644 bridgev2/networkinterface.go create mode 100644 bridgev2/portal.go create mode 100644 bridgev2/queue.go create mode 100644 bridgev2/user.go create mode 100644 bridgev2/userlogin.go diff --git a/appservice/http.go b/appservice/http.go index 38bcecf8..47f6a282 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -211,10 +211,17 @@ func (as *AppService) handleEvents(ctx context.Context, evts []*event.Event, def for _, evt := range evts { evt.Mautrix.ReceivedAt = time.Now() if defaultTypeClass != event.UnknownEventType { + if defaultTypeClass == event.EphemeralEventType { + evt.Mautrix.EventSource = event.SourceEphemeral + } else if defaultTypeClass == event.ToDeviceEventType { + evt.Mautrix.EventSource = event.SourceToDevice + } evt.Type.Class = defaultTypeClass } else if evt.StateKey != nil { + evt.Mautrix.EventSource = event.SourceTimeline & event.SourceJoin evt.Type.Class = event.StateEventType } else { + evt.Mautrix.EventSource = event.SourceTimeline evt.Type.Class = event.MessageEventType } err := evt.Content.ParseRaw(evt.Type) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go new file mode 100644 index 00000000..1550c72a --- /dev/null +++ b/bridgev2/bridge.go @@ -0,0 +1,102 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "errors" + "os" + "os/signal" + "sync" + "syscall" + + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + + "maunium.net/go/mautrix/id" +) + +var ErrNotLoggedIn = errors.New("not logged in") + +type Bridge struct { + ID networkid.BridgeID + DB *database.Database + Log zerolog.Logger + + Matrix MatrixConnector + Bot MatrixAPI + Network NetworkConnector + Commands *CommandProcessor + + // TODO move to config + CommandPrefix string + + usersByMXID map[id.UserID]*User + userLoginsByID map[networkid.UserLoginID]*UserLogin + portalsByID map[networkid.PortalID]*Portal + portalsByMXID map[id.RoomID]*Portal + ghostsByID map[networkid.UserID]*Ghost + cacheLock sync.Mutex +} + +func NewBridge(bridgeID networkid.BridgeID, db *dbutil.Database, log zerolog.Logger, matrix MatrixConnector, network NetworkConnector) *Bridge { + br := &Bridge{ + ID: bridgeID, + DB: database.New(bridgeID, db), + Log: log, + + Matrix: matrix, + Network: network, + + usersByMXID: make(map[id.UserID]*User), + userLoginsByID: make(map[networkid.UserLoginID]*UserLogin), + portalsByID: make(map[networkid.PortalID]*Portal), + portalsByMXID: make(map[id.RoomID]*Portal), + ghostsByID: make(map[networkid.UserID]*Ghost), + } + br.Commands = NewProcessor(br) + br.Matrix.Init(br) + br.Bot = br.Matrix.BotIntent() + br.Network.Init(br) + return br +} + +func (br *Bridge) Start() { + br.Log.Info().Msg("Starting bridge") + ctx := br.Log.WithContext(context.Background()) + + exerrors.PanicIfNotNil(br.DB.Upgrade(ctx)) + br.Log.Info().Msg("Starting Matrix connector") + exerrors.PanicIfNotNil(br.Matrix.Start(ctx)) + br.Log.Info().Msg("Starting network connector") + exerrors.PanicIfNotNil(br.Network.Start(ctx)) + + logins, err := br.GetAllUserLogins(ctx) + if err != nil { + br.Log.Fatal().Err(err).Msg("Failed to get user logins") + } + for _, login := range logins { + br.Log.Info().Str("id", string(login.ID)).Msg("Starting user login") + err = login.Client.Connect(login.Log.WithContext(ctx)) + if err != nil { + br.Log.Err(err).Msg("Failed to connect existing client") + } + } + if len(logins) == 0 { + br.Log.Info().Msg("No user logins found") + } + + br.Log.Info().Msg("Bridge started") + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c + br.Log.Info().Msg("Shutting down bridge") +} diff --git a/bridgev2/bridgeconfig/appservice.go b/bridgev2/bridgeconfig/appservice.go new file mode 100644 index 00000000..37ed9306 --- /dev/null +++ b/bridgev2/bridgeconfig/appservice.go @@ -0,0 +1,131 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +import ( + "fmt" + "html/template" + "regexp" + "strings" + + "go.mau.fi/util/exerrors" + "go.mau.fi/util/random" + + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/id" +) + +type AppserviceConfig struct { + Address string `yaml:"address"` + Hostname string `yaml:"hostname"` + Port uint16 `yaml:"port"` + + ID string `yaml:"id"` + Bot BotUserConfig `yaml:"bot"` + + ASToken string `yaml:"as_token"` + HSToken string `yaml:"hs_token"` + + EphemeralEvents bool `yaml:"ephemeral_events"` + AsyncTransactions bool `yaml:"async_transactions"` + + UsernameTemplate string `yaml:"username_template"` + usernameTemplate *template.Template `yaml:"-"` +} + +func (asc *AppserviceConfig) FormatUsername(username string) string { + if asc.usernameTemplate == nil { + asc.usernameTemplate = exerrors.Must(template.New("username").Parse(asc.UsernameTemplate)) + } + var buf strings.Builder + _ = asc.usernameTemplate.Execute(&buf, username) + return buf.String() +} + +func (config *Config) MakeUserIDRegex(matcher string) *regexp.Regexp { + usernamePlaceholder := strings.ToLower(random.String(16)) + usernameTemplate := fmt.Sprintf("@%s:%s", + config.AppService.FormatUsername(usernamePlaceholder), + config.Homeserver.Domain) + usernameTemplate = regexp.QuoteMeta(usernameTemplate) + usernameTemplate = strings.Replace(usernameTemplate, usernamePlaceholder, matcher, 1) + usernameTemplate = fmt.Sprintf("^%s$", usernameTemplate) + return regexp.MustCompile(usernameTemplate) +} + +// GetRegistration copies the data from the bridge config into an *appservice.Registration struct. +// This can't be used with the homeserver, see GenerateRegistration for generating files for the homeserver. +func (asc *AppserviceConfig) GetRegistration() *appservice.Registration { + reg := &appservice.Registration{} + asc.copyToRegistration(reg) + reg.SenderLocalpart = asc.Bot.Username + reg.ServerToken = asc.HSToken + reg.AppToken = asc.ASToken + return reg +} + +func (asc *AppserviceConfig) copyToRegistration(registration *appservice.Registration) { + registration.ID = asc.ID + registration.URL = asc.Address + falseVal := false + registration.RateLimited = &falseVal + registration.EphemeralEvents = asc.EphemeralEvents + registration.SoruEphemeralEvents = asc.EphemeralEvents +} + +// GenerateRegistration generates a registration file for the homeserver. +func (config *Config) GenerateRegistration() *appservice.Registration { + registration := appservice.CreateRegistration() + config.AppService.HSToken = registration.ServerToken + config.AppService.ASToken = registration.AppToken + config.AppService.copyToRegistration(registration) + + registration.SenderLocalpart = random.String(32) + botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$", + regexp.QuoteMeta(config.AppService.Bot.Username), + regexp.QuoteMeta(config.Homeserver.Domain))) + registration.Namespaces.UserIDs.Register(botRegex, true) + registration.Namespaces.UserIDs.Register(config.MakeUserIDRegex(".*"), true) + + return registration +} + +func (config *Config) MakeAppService() *appservice.AppService { + as := appservice.Create() + as.HomeserverDomain = config.Homeserver.Domain + _ = as.SetHomeserverURL(config.Homeserver.Address) + as.Host.Hostname = config.AppService.Hostname + as.Host.Port = config.AppService.Port + as.Registration = config.AppService.GetRegistration() + return as +} + +type BotUserConfig struct { + Username string `yaml:"username"` + Displayname string `yaml:"displayname"` + Avatar string `yaml:"avatar"` + + ParsedAvatar id.ContentURI `yaml:"-"` +} + +type serializableBUC BotUserConfig + +func (buc *BotUserConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + var sbuc serializableBUC + err := unmarshal(&sbuc) + if err != nil { + return err + } + *buc = (BotUserConfig)(sbuc) + if buc.Avatar != "" && buc.Avatar != "remove" { + buc.ParsedAvatar, err = id.ParseContentURI(buc.Avatar) + if err != nil { + return fmt.Errorf("%w in bot avatar", err) + } + } + return nil +} diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go new file mode 100644 index 00000000..eeaa6d48 --- /dev/null +++ b/bridgev2/bridgeconfig/config.go @@ -0,0 +1,52 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +import ( + "go.mau.fi/util/dbutil" + "go.mau.fi/zeroconfig" +) + +type Config struct { + Homeserver HomeserverConfig `yaml:"homeserver"` + AppService AppserviceConfig `yaml:"appservice"` + Database dbutil.Config `yaml:"database"` + Bridge BridgeConfig `yaml:"bridge"` // TODO this is more like matrix than bridge + Provisioning ProvisioningConfig `yaml:"provisioning"` + DoublePuppet DoublePuppetConfig `yaml:"double_puppet"` + Encryption EncryptionConfig `yaml:"encryption"` + Permissions PermissionConfig `yaml:"permissions"` + ManagementRoomTexts ManagementRoomTexts `yaml:"management_room_texts"` + Logging zeroconfig.Config `yaml:"logging"` +} + +type BridgeConfig struct { + MessageStatusEvents bool `yaml:"message_status_events"` + DeliveryReceipts bool `yaml:"delivery_receipts"` + MessageErrorNotices bool `yaml:"message_error_notices"` + SyncDirectChatList bool `yaml:"sync_direct_chat_list"` + FederateRooms bool `yaml:"federate_rooms"` +} + +type ProvisioningConfig struct { + Prefix string `yaml:"prefix"` + SharedSecret string `yaml:"shared_secret"` + DebugEndpoints bool `yaml:"debug_endpoints"` +} + +type DoublePuppetConfig struct { + Servers map[string]string `yaml:"servers"` + AllowDiscovery bool `yaml:"allow_discovery"` + Secrets map[string]string `yaml:"secrets"` +} + +type ManagementRoomTexts struct { + Welcome string `yaml:"welcome"` + WelcomeConnected string `yaml:"welcome_connected"` + WelcomeUnconnected string `yaml:"welcome_unconnected"` + AdditionalHelp string `yaml:"additional_help"` +} diff --git a/bridgev2/bridgeconfig/encryption.go b/bridgev2/bridgeconfig/encryption.go new file mode 100644 index 00000000..93a427d3 --- /dev/null +++ b/bridgev2/bridgeconfig/encryption.go @@ -0,0 +1,48 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +import ( + "maunium.net/go/mautrix/id" +) + +type EncryptionConfig struct { + Allow bool `yaml:"allow"` + Default bool `yaml:"default"` + Require bool `yaml:"require"` + Appservice bool `yaml:"appservice"` + + PlaintextMentions bool `yaml:"plaintext_mentions"` + + PickleKey string `yaml:"pickle_key"` + + DeleteKeys struct { + DeleteOutboundOnAck bool `yaml:"delete_outbound_on_ack"` + DontStoreOutbound bool `yaml:"dont_store_outbound"` + RatchetOnDecrypt bool `yaml:"ratchet_on_decrypt"` + DeleteFullyUsedOnDecrypt bool `yaml:"delete_fully_used_on_decrypt"` + DeletePrevOnNewSession bool `yaml:"delete_prev_on_new_session"` + DeleteOnDeviceDelete bool `yaml:"delete_on_device_delete"` + PeriodicallyDeleteExpired bool `yaml:"periodically_delete_expired"` + DeleteOutdatedInbound bool `yaml:"delete_outdated_inbound"` + } `yaml:"delete_keys"` + + VerificationLevels struct { + Receive id.TrustState `yaml:"receive"` + Send id.TrustState `yaml:"send"` + Share id.TrustState `yaml:"share"` + } `yaml:"verification_levels"` + AllowKeySharing bool `yaml:"allow_key_sharing"` + + Rotation struct { + EnableCustom bool `yaml:"enable_custom"` + Milliseconds int64 `yaml:"milliseconds"` + Messages int `yaml:"messages"` + + DisableDeviceChangeKeyRotation bool `yaml:"disable_device_change_key_rotation"` + } `yaml:"rotation"` +} diff --git a/bridgev2/bridgeconfig/homeserver.go b/bridgev2/bridgeconfig/homeserver.go new file mode 100644 index 00000000..8d888d4f --- /dev/null +++ b/bridgev2/bridgeconfig/homeserver.go @@ -0,0 +1,38 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +type HomeserverSoftware string + +const ( + SoftwareStandard HomeserverSoftware = "standard" + SoftwareAsmux HomeserverSoftware = "asmux" + SoftwareHungry HomeserverSoftware = "hungry" +) + +var AllowedHomeserverSoftware = map[HomeserverSoftware]bool{ + SoftwareStandard: true, + SoftwareAsmux: true, + SoftwareHungry: true, +} + +type HomeserverConfig struct { + Address string `yaml:"address"` + Domain string `yaml:"domain"` + AsyncMedia bool `yaml:"async_media"` + + PublicAddress string `yaml:"public_address,omitempty"` + + Software HomeserverSoftware `yaml:"software"` + + StatusEndpoint string `yaml:"status_endpoint"` + MessageSendCheckpointEndpoint string `yaml:"message_send_checkpoint_endpoint"` + + Websocket bool `yaml:"websocket"` + WSProxy string `yaml:"websocket_proxy"` + WSPingInterval int `yaml:"ping_interval_seconds"` +} diff --git a/bridgev2/bridgeconfig/permissions.go b/bridgev2/bridgeconfig/permissions.go new file mode 100644 index 00000000..198e140e --- /dev/null +++ b/bridgev2/bridgeconfig/permissions.go @@ -0,0 +1,71 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +import ( + "strconv" + "strings" + + "maunium.net/go/mautrix/id" +) + +type PermissionConfig map[string]PermissionLevel + +type PermissionLevel int + +const ( + PermissionLevelBlock PermissionLevel = 0 + PermissionLevelRelay PermissionLevel = 5 + PermissionLevelUser PermissionLevel = 10 + PermissionLevelAdmin PermissionLevel = 100 +) + +var namesToLevels = map[string]PermissionLevel{ + "block": PermissionLevelBlock, + "relay": PermissionLevelRelay, + "user": PermissionLevelUser, + "admin": PermissionLevelAdmin, +} + +func RegisterPermissionLevel(name string, level PermissionLevel) { + namesToLevels[name] = level +} + +func (pc *PermissionConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + rawPC := make(map[string]string) + err := unmarshal(&rawPC) + if err != nil { + return err + } + + if *pc == nil { + *pc = make(map[string]PermissionLevel) + } + for key, value := range rawPC { + level, ok := namesToLevels[strings.ToLower(value)] + if ok { + (*pc)[key] = level + } else if val, err := strconv.Atoi(value); err == nil { + (*pc)[key] = PermissionLevel(val) + } else { + (*pc)[key] = PermissionLevelBlock + } + } + return nil +} + +func (pc PermissionConfig) Get(userID id.UserID) PermissionLevel { + if level, ok := pc[string(userID)]; ok { + return level + } else if level, ok = pc[userID.Homeserver()]; len(userID.Homeserver()) > 0 && ok { + return level + } else if level, ok = pc["*"]; ok { + return level + } else { + return PermissionLevelBlock + } +} diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go new file mode 100644 index 00000000..d4006f43 --- /dev/null +++ b/bridgev2/bridgeconfig/upgrade.go @@ -0,0 +1,132 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/rs/zerolog" + up "go.mau.fi/util/configupgrade" + "go.mau.fi/zeroconfig" + "gopkg.in/yaml.v3" +) + +func doUpgrade(helper *up.Helper) { + helper.Copy(up.Str, "homeserver", "address") + helper.Copy(up.Str, "homeserver", "domain") + if legacyAsmuxFlag, ok := helper.Get(up.Bool, "homeserver", "asmux"); ok && legacyAsmuxFlag == "true" { + helper.Set(up.Str, string(SoftwareAsmux), "homeserver", "software") + } else { + helper.Copy(up.Str, "homeserver", "software") + } + helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint") + helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint") + helper.Copy(up.Bool, "homeserver", "async_media") + helper.Copy(up.Str|up.Null, "homeserver", "websocket_proxy") + helper.Copy(up.Bool, "homeserver", "websocket") + helper.Copy(up.Int, "homeserver", "ping_interval_seconds") + + helper.Copy(up.Str|up.Null, "appservice", "address") + helper.Copy(up.Str|up.Null, "appservice", "hostname") + helper.Copy(up.Int|up.Null, "appservice", "port") + if dbType, ok := helper.Get(up.Str, "appservice", "database", "type"); ok && dbType == "sqlite3" { + helper.Set(up.Str, "sqlite3-fk-wal", "appservice", "database", "type") + } else { + helper.Copy(up.Str, "appservice", "database", "type") + } + helper.Copy(up.Str, "appservice", "database", "uri") + helper.Copy(up.Int, "appservice", "database", "max_open_conns") + helper.Copy(up.Int, "appservice", "database", "max_idle_conns") + helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_idle_time") + helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_lifetime") + helper.Copy(up.Str, "appservice", "id") + helper.Copy(up.Str, "appservice", "bot", "username") + helper.Copy(up.Str, "appservice", "bot", "displayname") + helper.Copy(up.Str, "appservice", "bot", "avatar") + helper.Copy(up.Bool, "appservice", "ephemeral_events") + helper.Copy(up.Bool, "appservice", "async_transactions") + helper.Copy(up.Str, "appservice", "as_token") + helper.Copy(up.Str, "appservice", "hs_token") + + if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "print_level") != nil || helper.GetNode("logging", "file_name_format") != nil) { + _, _ = fmt.Fprintln(os.Stderr, "Migrating legacy log config") + migrateLegacyLogConfig(helper) + } else if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "handlers") != nil) { + _, _ = fmt.Fprintln(os.Stderr, "Migrating Python log config is not currently supported") + // TODO implement? + //migratePythonLogConfig(helper) + } else { + helper.Copy(up.Map, "logging") + } +} + +type legacyLogConfig struct { + Directory string `yaml:"directory"` + FileNameFormat string `yaml:"file_name_format"` + FileDateFormat string `yaml:"file_date_format"` + FileMode uint32 `yaml:"file_mode"` + TimestampFormat string `yaml:"timestamp_format"` + RawPrintLevel string `yaml:"print_level"` + JSONStdout bool `yaml:"print_json"` + JSONFile bool `yaml:"file_json"` +} + +func migrateLegacyLogConfig(helper *up.Helper) { + var llc legacyLogConfig + var newConfig zeroconfig.Config + err := helper.GetBaseNode("logging").Decode(&newConfig) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Base config is corrupted: failed to decode example log config:", err) + return + } else if len(newConfig.Writers) != 2 || newConfig.Writers[0].Type != "stdout" || newConfig.Writers[1].Type != "file" { + _, _ = fmt.Fprintln(os.Stderr, "Base log config is not in expected format") + return + } + err = helper.GetNode("logging").Decode(&llc) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to decode legacy log config:", err) + return + } + if llc.RawPrintLevel != "" { + level, err := zerolog.ParseLevel(llc.RawPrintLevel) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to parse minimum stdout log level:", err) + } else { + newConfig.Writers[0].MinLevel = &level + } + } + if llc.Directory != "" && llc.FileNameFormat != "" { + if llc.FileNameFormat == "{{.Date}}-{{.Index}}.log" { + llc.FileNameFormat = "bridge.log" + } else { + llc.FileNameFormat = strings.ReplaceAll(llc.FileNameFormat, "{{.Date}}", "") + llc.FileNameFormat = strings.ReplaceAll(llc.FileNameFormat, "{{.Index}}", "") + } + newConfig.Writers[1].Filename = filepath.Join(llc.Directory, llc.FileNameFormat) + } else if llc.FileNameFormat == "" { + newConfig.Writers = newConfig.Writers[0:1] + } + if llc.JSONStdout { + newConfig.Writers[0].TimeFormat = "" + newConfig.Writers[0].Format = "json" + } else if llc.TimestampFormat != "" { + newConfig.Writers[0].TimeFormat = llc.TimestampFormat + } + var updatedConfig yaml.Node + err = updatedConfig.Encode(&newConfig) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to encode migrated log config:", err) + return + } + *helper.GetBaseNode("logging").Node = updatedConfig +} + +// Upgrader is a config upgrader that copies the default fields in the homeserver, appservice and logging blocks. +var Upgrader = up.SimpleUpgrader(doUpgrade) diff --git a/bridgev2/cmdevent.go b/bridgev2/cmdevent.go new file mode 100644 index 00000000..de43ccca --- /dev/null +++ b/bridgev2/cmdevent.go @@ -0,0 +1,96 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/id" +) + +// CommandEvent stores all data which might be used to handle commands +type CommandEvent struct { + Bot MatrixAPI + Bridge *Bridge + Portal *Portal + Processor *CommandProcessor + Handler MinimalCommandHandler + RoomID id.RoomID + EventID id.EventID + User *User + Command string + Args []string + RawArgs string + ReplyTo id.EventID + Ctx context.Context + Log *zerolog.Logger +} + +// Reply sends a reply to command as notice, with optional string formatting and automatic $cmdprefix replacement. +func (ce *CommandEvent) Reply(msg string, args ...any) { + msg = strings.ReplaceAll(msg, "$cmdprefix ", ce.Bridge.CommandPrefix+" ") + if len(args) > 0 { + msg = fmt.Sprintf(msg, args...) + } + ce.ReplyAdvanced(msg, true, false) +} + +// ReplyAdvanced sends a reply to command as notice. It allows using HTML and disabling markdown, +// but doesn't have built-in string formatting. +func (ce *CommandEvent) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { + content := format.RenderMarkdown(msg, allowMarkdown, allowHTML) + content.MsgType = event.MsgNotice + _, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, time.Now()) + if err != nil { + ce.Log.Err(err).Msgf("Failed to reply to command") + } +} + +// React sends a reaction to the command. +func (ce *CommandEvent) React(key string) { + _, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventReaction, &event.Content{ + Parsed: &event.ReactionEventContent{ + RelatesTo: event.RelatesTo{ + Type: event.RelAnnotation, + EventID: ce.EventID, + Key: key, + }, + }, + }, time.Now()) + if err != nil { + ce.Log.Err(err).Msgf("Failed to react to command") + } +} + +// Redact redacts the command. +func (ce *CommandEvent) Redact(req ...mautrix.ReqRedact) { + _, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: ce.EventID, + }, + }, time.Now()) + if err != nil { + ce.Log.Err(err).Msgf("Failed to redact command") + } +} + +// MarkRead marks the command event as read. +func (ce *CommandEvent) MarkRead() { + // TODO + //err := ce.Bot.SendReceipt(ce.Ctx, ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil) + //if err != nil { + // ce.Log.Err(err).Msgf("Failed to mark command as read") + //} +} diff --git a/bridgev2/cmdhandler.go b/bridgev2/cmdhandler.go new file mode 100644 index 00000000..9f9c69ec --- /dev/null +++ b/bridgev2/cmdhandler.go @@ -0,0 +1,94 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "maunium.net/go/mautrix/event" +) + +type MinimalCommandHandler interface { + Run(*CommandEvent) +} + +type MinimalCommandHandlerFunc func(*CommandEvent) + +func (mhf MinimalCommandHandlerFunc) Run(ce *CommandEvent) { + mhf(ce) +} + +type CommandState struct { + Next MinimalCommandHandler + Action string + Meta any +} + +type CommandHandler interface { + MinimalCommandHandler + GetName() string +} + +type AliasedCommandHandler interface { + CommandHandler + GetAliases() []string +} + +type FullHandler struct { + Func func(*CommandEvent) + + Name string + Aliases []string + Help HelpMeta + + RequiresAdmin bool + RequiresPortal bool + RequiresLogin bool + + RequiresEventLevel event.Type +} + +func (fh *FullHandler) GetHelp() HelpMeta { + fh.Help.Command = fh.Name + return fh.Help +} + +func (fh *FullHandler) GetName() string { + return fh.Name +} + +func (fh *FullHandler) GetAliases() []string { + return fh.Aliases +} + +func (fh *FullHandler) ShowInHelp(ce *CommandEvent) bool { + return true + //return !fh.RequiresAdmin || ce.User.GetPermissionLevel() >= bridgeconfig.PermissionLevelAdmin +} + +func (fh *FullHandler) userHasRoomPermission(ce *CommandEvent) bool { + return true + //levels, err := ce.MainIntent().PowerLevels(ce.Ctx, ce.RoomID) + //if err != nil { + // ce.ZLog.Warn().Err(err).Msg("Failed to check room power levels") + // ce.Reply("Failed to get room power levels to see if you're allowed to use that command") + // return false + //} + //return levels.GetUserLevel(ce.User.GetMXID()) >= levels.GetEventLevel(fh.RequiresEventLevel) +} + +func (fh *FullHandler) Run(ce *CommandEvent) { + //if fh.RequiresAdmin && ce.User.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin { + // ce.Reply("That command is limited to bridge administrators.") + //} else if fh.RequiresEventLevel.Type != "" && ce.User.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin && !fh.userHasRoomPermission(ce) { + // ce.Reply("That command requires room admin rights.") + //} else if fh.RequiresPortal && ce.Portal == nil { + // ce.Reply("That command can only be ran in portal rooms.") + //} else if fh.RequiresLogin && !ce.User.IsLoggedIn() { + // ce.Reply("That command requires you to be logged in.") + //} else { + fh.Func(ce) + //} +} diff --git a/bridgev2/cmdhelp.go b/bridgev2/cmdhelp.go new file mode 100644 index 00000000..53d5076e --- /dev/null +++ b/bridgev2/cmdhelp.go @@ -0,0 +1,129 @@ +// Copyright (c) 2022 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "fmt" + "sort" + "strings" +) + +type HelpfulHandler interface { + CommandHandler + GetHelp() HelpMeta + ShowInHelp(*CommandEvent) bool +} + +type HelpSection struct { + Name string + Order int +} + +var ( + // Deprecated: this should be used as a placeholder that needs to be fixed + HelpSectionUnclassified = HelpSection{"Unclassified", -1} + + HelpSectionGeneral = HelpSection{"General", 0} + HelpSectionAuth = HelpSection{"Authentication", 10} + HelpSectionAdmin = HelpSection{"Administration", 50} +) + +type HelpMeta struct { + Command string + Section HelpSection + Description string + Args string +} + +func (hm *HelpMeta) String() string { + if len(hm.Args) == 0 { + return fmt.Sprintf("**%s** - %s", hm.Command, hm.Description) + } + return fmt.Sprintf("**%s** %s - %s", hm.Command, hm.Args, hm.Description) +} + +type helpSectionList []HelpSection + +func (h helpSectionList) Len() int { + return len(h) +} + +func (h helpSectionList) Less(i, j int) bool { + return h[i].Order < h[j].Order +} + +func (h helpSectionList) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +type helpMetaList []HelpMeta + +func (h helpMetaList) Len() int { + return len(h) +} + +func (h helpMetaList) Less(i, j int) bool { + return h[i].Command < h[j].Command +} + +func (h helpMetaList) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +var _ sort.Interface = (helpSectionList)(nil) +var _ sort.Interface = (helpMetaList)(nil) + +func FormatHelp(ce *CommandEvent) string { + sections := make(map[HelpSection]helpMetaList) + for _, handler := range ce.Processor.handlers { + helpfulHandler, ok := handler.(HelpfulHandler) + if !ok || !helpfulHandler.ShowInHelp(ce) { + continue + } + help := helpfulHandler.GetHelp() + if help.Description == "" { + continue + } + sections[help.Section] = append(sections[help.Section], help) + } + + sortedSections := make(helpSectionList, 0, len(sections)) + for section := range sections { + sortedSections = append(sortedSections, section) + } + sort.Sort(sortedSections) + + var output strings.Builder + output.Grow(10240) + + var prefixMsg string + if ce.RoomID == ce.User.ManagementRoom { + prefixMsg = "This is your management room: prefixing commands with `%s` is not required." + } else if ce.Portal != nil { + prefixMsg = "**This is a portal room**: you must always prefix commands with `%s`. Management commands will not be bridged." + } else { + prefixMsg = "This is not your management room: prefixing commands with `%s` is required." + } + _, _ = fmt.Fprintf(&output, prefixMsg, ce.Bridge.CommandPrefix) + output.WriteByte('\n') + output.WriteString("Parameters in [square brackets] are optional, while parameters in are required.") + output.WriteByte('\n') + output.WriteByte('\n') + + for _, section := range sortedSections { + output.WriteString("#### ") + output.WriteString(section.Name) + output.WriteByte('\n') + sort.Sort(sections[section]) + for _, command := range sections[section] { + output.WriteString(command.String()) + output.WriteByte('\n') + } + output.WriteByte('\n') + } + return output.String() +} diff --git a/bridgev2/cmdmeta.go b/bridgev2/cmdmeta.go new file mode 100644 index 00000000..4020f569 --- /dev/null +++ b/bridgev2/cmdmeta.go @@ -0,0 +1,50 @@ +// Copyright (c) 2022 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +var CommandHelp = &FullHandler{ + Func: func(ce *CommandEvent) { + ce.Reply(FormatHelp(ce)) + }, + Name: "help", + Help: HelpMeta{ + Section: HelpSectionGeneral, + Description: "Show this help message.", + }, +} + +var CommandVersion = &FullHandler{ + Func: func(ce *CommandEvent) { + ce.Reply("Bridge versions are not yet implemented") + //ce.Reply("[%s](%s) %s (%s)", ce.Bridge.Name, ce.Bridge.URL, ce.Bridge.LinkifiedVersion, ce.Bridge.BuildTime) + }, + Name: "version", + Help: HelpMeta{ + Section: HelpSectionGeneral, + Description: "Get the bridge version.", + }, +} + +var CommandCancel = &FullHandler{ + Func: func(ce *CommandEvent) { + state := ce.User.CommandState.Swap(nil) + if state != nil { + action := state.Action + if action == "" { + action = "Unknown action" + } + ce.Reply("%s cancelled.", action) + } else { + ce.Reply("No ongoing command.") + } + }, + Name: "cancel", + Help: HelpMeta{ + Section: HelpSectionGeneral, + Description: "Cancel an ongoing action.", + }, +} diff --git a/bridgev2/cmdprocessor.go b/bridgev2/cmdprocessor.go new file mode 100644 index 00000000..e412bf8e --- /dev/null +++ b/bridgev2/cmdprocessor.go @@ -0,0 +1,118 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "runtime/debug" + "strings" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/id" +) + +type CommandProcessor struct { + bridge *Bridge + log *zerolog.Logger + + handlers map[string]CommandHandler + aliases map[string]string +} + +// NewProcessor creates a CommandProcessor +func NewProcessor(bridge *Bridge) *CommandProcessor { + proc := &CommandProcessor{ + bridge: bridge, + log: &bridge.Log, + + handlers: make(map[string]CommandHandler), + aliases: make(map[string]string), + } + proc.AddHandlers(CommandHelp, CommandVersion, CommandCancel) + return proc +} + +func (proc *CommandProcessor) AddHandlers(handlers ...CommandHandler) { + for _, handler := range handlers { + proc.AddHandler(handler) + } +} + +func (proc *CommandProcessor) AddHandler(handler CommandHandler) { + proc.handlers[handler.GetName()] = handler + aliased, ok := handler.(AliasedCommandHandler) + if ok { + for _, alias := range aliased.GetAliases() { + proc.aliases[alias] = handler.GetName() + } + } +} + +// Handle handles messages to the bridge +func (proc *CommandProcessor) Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user *User, message string, replyTo id.EventID) { + defer func() { + err := recover() + if err != nil { + zerolog.Ctx(ctx).Error(). + Str(zerolog.ErrorStackFieldName, string(debug.Stack())). + Interface(zerolog.ErrorFieldName, err). + Msg("Panic in Matrix command handler") + } + }() + args := strings.Fields(message) + if len(args) == 0 { + args = []string{"unknown-command"} + } + command := strings.ToLower(args[0]) + rawArgs := strings.TrimLeft(strings.TrimPrefix(message, command), " ") + log := zerolog.Ctx(ctx).With().Str("mx_command", command).Logger() + ctx = log.WithContext(ctx) + portal, err := proc.bridge.GetPortalByMXID(ctx, roomID) + if err != nil { + // :( + } + ce := &CommandEvent{ + Bot: proc.bridge.Bot, + Bridge: proc.bridge, + Portal: portal, + Processor: proc, + RoomID: roomID, + EventID: eventID, + User: user, + Command: command, + Args: args[1:], + RawArgs: rawArgs, + ReplyTo: replyTo, + Ctx: ctx, + Log: &log, + } + log.Debug().Msg("Received command") + + realCommand, ok := proc.aliases[ce.Command] + if !ok { + realCommand = ce.Command + } + + var handler MinimalCommandHandler + handler, ok = proc.handlers[realCommand] + if !ok { + state := ce.User.CommandState.Load() + if state != nil && state.Next != nil { + ce.Command = "" + ce.RawArgs = message + ce.Args = args + ce.Handler = state.Next + state.Next.Run(ce) + } else { + ce.Reply("Unknown command, use the `help` command for help.") + } + } else { + ce.Handler = handler + handler.Run(ce) + } +} diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go new file mode 100644 index 00000000..688a40da --- /dev/null +++ b/bridgev2/database/database.go @@ -0,0 +1,49 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + + "maunium.net/go/mautrix/bridgev2/database/upgrades" +) + +type Database struct { + *dbutil.Database + + BridgeID networkid.BridgeID + Portal *PortalQuery + Ghost *GhostQuery + Message *MessageQuery + Reaction *ReactionQuery + User *UserQuery + UserLogin *UserLoginQuery +} + +func New(bridgeID networkid.BridgeID, db *dbutil.Database) *Database { + db.UpgradeTable = upgrades.Table + return &Database{ + Database: db, + BridgeID: bridgeID, + Portal: &PortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newPortal)}, + Ghost: &GhostQuery{bridgeID, dbutil.MakeQueryHelper(db, newGhost)}, + Message: &MessageQuery{bridgeID, dbutil.MakeQueryHelper(db, newMessage)}, + Reaction: &ReactionQuery{bridgeID, dbutil.MakeQueryHelper(db, newReaction)}, + User: &UserQuery{bridgeID, dbutil.MakeQueryHelper(db, newUser)}, + UserLogin: &UserLoginQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserLogin)}, + } +} + +func ensureBridgeIDMatches(ptr *networkid.BridgeID, expected networkid.BridgeID) { + if *ptr == "" { + *ptr = expected + } else if *ptr != expected { + panic("bridge ID mismatch") + } +} diff --git a/bridgev2/database/ghost.go b/bridgev2/database/ghost.go new file mode 100644 index 00000000..e56eb13a --- /dev/null +++ b/bridgev2/database/ghost.go @@ -0,0 +1,92 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type GhostQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*Ghost] +} + +type Ghost struct { + BridgeID networkid.BridgeID + ID networkid.UserID + + Name string + AvatarID networkid.AvatarID + AvatarMXC id.ContentURIString + NameSet bool + AvatarSet bool + Metadata map[string]any +} + +func newGhost(_ *dbutil.QueryHelper[*Ghost]) *Ghost { + return &Ghost{} +} + +const ( + getGhostBaseQuery = ` + SELECT bridge_id, id, name, avatar_id, avatar_mxc, name_set, avatar_set, metadata FROM ghost + ` + getGhostByIDQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND id=$2` + insertGhostQuery = ` + INSERT INTO ghost (bridge_id, id, name, avatar_id, avatar_mxc, name_set, avatar_set, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ` + updateGhostQuery = ` + UPDATE ghost SET name=$3, avatar_id=$4, avatar_mxc=$5, name_set=$6, avatar_set=$7, metadata=$8 + WHERE bridge_id=$1 AND id=$2 + ` +) + +func (gq *GhostQuery) GetByID(ctx context.Context, id networkid.UserID) (*Ghost, error) { + return gq.QueryOne(ctx, getGhostByIDQuery, gq.BridgeID, id) +} + +func (gq *GhostQuery) Insert(ctx context.Context, ghost *Ghost) error { + ensureBridgeIDMatches(&ghost.BridgeID, gq.BridgeID) + return gq.Exec(ctx, insertGhostQuery, ghost.sqlVariables()...) +} + +func (gq *GhostQuery) Update(ctx context.Context, ghost *Ghost) error { + ensureBridgeIDMatches(&ghost.BridgeID, gq.BridgeID) + return gq.Exec(ctx, updateGhostQuery, ghost.sqlVariables()...) +} + +func (g *Ghost) Scan(row dbutil.Scannable) (*Ghost, error) { + err := row.Scan( + &g.BridgeID, &g.ID, + &g.Name, &g.AvatarID, &g.AvatarMXC, + &g.NameSet, &g.AvatarSet, dbutil.JSON{Data: &g.Metadata}, + ) + if err != nil { + return nil, err + } + if g.Metadata == nil { + g.Metadata = make(map[string]any) + } + return g, nil +} + +func (g *Ghost) sqlVariables() []any { + if g.Metadata == nil { + g.Metadata = make(map[string]any) + } + return []any{ + g.BridgeID, g.ID, + g.Name, g.AvatarID, g.AvatarMXC, + g.NameSet, g.AvatarSet, dbutil.JSON{Data: g.Metadata}, + } +} diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go new file mode 100644 index 00000000..539aa61a --- /dev/null +++ b/bridgev2/database/message.go @@ -0,0 +1,146 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "time" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type MessageQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*Message] +} + +type Message struct { + RowID int64 + BridgeID networkid.BridgeID + ID networkid.MessageID + PartID networkid.PartID + MXID id.EventID + + RoomID networkid.PortalID + SenderID networkid.UserID + Timestamp time.Time + + RelatesToRowID int64 + + Metadata map[string]any +} + +func newMessage(_ *dbutil.QueryHelper[*Message]) *Message { + return &Message{} +} + +const ( + getMessageBaseQuery = ` + SELECT rowid, bridge_id, id, part_id, mxid, room_id, sender_id, timestamp, relates_to, metadata FROM message + ` + getAllMessagePartsByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND id=$2` + getMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND part_id=$3` + getMessageByMXIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` + getLastMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND id=$2 ORDER BY part_id DESC LIMIT 1` + getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND id=$2 ORDER BY part_id ASC LIMIT 1` + getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND timestamp>$3 AND timestamp<=$4` + insertMessageQuery = ` + INSERT INTO message (bridge_id, id, part_id, mxid, room_id, sender_id, timestamp, relates_to, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + RETURNING rowid + ` + updateMessageQuery = ` + UPDATE message SET id=$2, part_id=$3, mxid=$4, room_id=$5, sender_id=$6, timestamp=$7, relates_to=$8, metadata=$9 + WHERE bridge_id=$1 AND rowid=$10 + ` + deleteAllMessagePartsByIDQuery = ` + DELETE FROM message WHERE bridge_id=$1 AND id=$2 + ` + deleteMessagePartByRowIDQuery = ` + DELETE FROM message WHERE bridge_id=$1 AND rowid=$2 + ` +) + +func (mq *MessageQuery) GetAllPartsByID(ctx context.Context, id networkid.MessageID) ([]*Message, error) { + return mq.QueryMany(ctx, getAllMessagePartsByIDQuery, mq.BridgeID, id) +} + +func (mq *MessageQuery) GetPartByID(ctx context.Context, id networkid.MessageID, partID networkid.PartID) (*Message, error) { + return mq.QueryOne(ctx, getMessagePartByIDQuery, mq.BridgeID, id, partID) +} + +func (mq *MessageQuery) GetPartByMXID(ctx context.Context, mxid id.EventID) (*Message, error) { + return mq.QueryOne(ctx, getMessageByMXIDQuery, mq.BridgeID, mxid) +} + +func (mq *MessageQuery) GetLastPartByID(ctx context.Context, id networkid.MessageID) (*Message, error) { + return mq.QueryOne(ctx, getLastMessagePartByIDQuery, mq.BridgeID, id) +} + +func (mq *MessageQuery) GetFirstPartByID(ctx context.Context, id networkid.MessageID) (*Message, error) { + return mq.QueryOne(ctx, getFirstMessagePartByIDQuery, mq.BridgeID, id) +} + +func (mq *MessageQuery) GetFirstOrSpecificPartByID(ctx context.Context, id networkid.MessageOptionalPartID) (*Message, error) { + if id.PartID == nil { + return mq.GetFirstPartByID(ctx, id.MessageID) + } else { + return mq.GetPartByID(ctx, id.MessageID, *id.PartID) + } +} + +func (mq *MessageQuery) GetMessagesBetweenTimeQuery(ctx context.Context, start, end time.Time) ([]*Message, error) { + return mq.QueryMany(ctx, getMessagesBetweenTimeQuery, mq.BridgeID, start.UnixNano(), end.UnixNano()) +} + +func (mq *MessageQuery) Insert(ctx context.Context, msg *Message) error { + ensureBridgeIDMatches(&msg.BridgeID, mq.BridgeID) + return mq.GetDB().QueryRow(ctx, insertMessageQuery, msg.sqlVariables()...).Scan(&msg.RowID) +} + +func (mq *MessageQuery) Update(ctx context.Context, msg *Message) error { + ensureBridgeIDMatches(&msg.BridgeID, mq.BridgeID) + return mq.Exec(ctx, updateMessageQuery, msg.updateSQLVariables()...) +} + +func (mq *MessageQuery) DeleteAllParts(ctx context.Context, id networkid.MessageID) error { + return mq.Exec(ctx, deleteAllMessagePartsByIDQuery, mq.BridgeID, id) +} + +func (mq *MessageQuery) Delete(ctx context.Context, rowID int64) error { + return mq.Exec(ctx, deleteMessagePartByRowIDQuery, mq.BridgeID, rowID) +} + +func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { + var timestamp int64 + var relatesTo sql.NullInt64 + err := row.Scan( + &m.RowID, &m.BridgeID, &m.ID, &m.PartID, &m.MXID, &m.RoomID, &m.SenderID, + ×tamp, &relatesTo, dbutil.JSON{Data: &m.Metadata}, + ) + if err != nil { + return nil, err + } + m.Timestamp = time.Unix(0, timestamp) + m.RelatesToRowID = relatesTo.Int64 + return m, nil +} + +func (m *Message) sqlVariables() []any { + return []any{ + m.BridgeID, m.ID, m.PartID, m.MXID, m.RoomID, m.SenderID, + m.Timestamp.UnixNano(), dbutil.NumPtr(m.RelatesToRowID), dbutil.JSON{Data: m.Metadata}, + } +} + +func (m *Message) updateSQLVariables() []any { + return append(m.sqlVariables(), m.RowID) +} diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go new file mode 100644 index 00000000..c0f581f6 --- /dev/null +++ b/bridgev2/database/portal.go @@ -0,0 +1,128 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type PortalQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*Portal] +} + +type Portal struct { + BridgeID networkid.BridgeID + ID networkid.PortalID + MXID id.RoomID + + ParentID networkid.PortalID + Name string + Topic string + AvatarID networkid.AvatarID + AvatarMXC id.ContentURIString + NameSet bool + TopicSet bool + AvatarSet bool + InSpace bool + Metadata map[string]any +} + +func newPortal(_ *dbutil.QueryHelper[*Portal]) *Portal { + return &Portal{} +} + +const ( + getPortalBaseQuery = ` + SELECT bridge_id, id, mxid, parent_id, name, topic, avatar_id, avatar_mxc, + name_set, topic_set, avatar_set, in_space, + metadata + FROM portal + ` + getPortalByIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2` + getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` + getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2` + + insertPortalQuery = ` + INSERT INTO portal ( + bridge_id, id, mxid, + parent_id, name, topic, avatar_id, avatar_mxc, + name_set, avatar_set, topic_set, in_space, + metadata + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + ` + updatePortalQuery = ` + UPDATE portal + SET mxid=$3, parent_id=$4, name=$5, topic=$6, avatar_id=$7, avatar_mxc=$8, + name_set=$9, avatar_set=$10, topic_set=$11, in_space=$12, metadata=$13 + WHERE bridge_id=$1 AND id=$2 + ` + reIDPortalQuery = `UPDATE portal SET id=$3 WHERE bridge_id=$1 AND id=$2` +) + +func (pq *PortalQuery) GetByID(ctx context.Context, id networkid.PortalID) (*Portal, error) { + return pq.QueryOne(ctx, getPortalByIDQuery, pq.BridgeID, id) +} + +func (pq *PortalQuery) GetByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) { + return pq.QueryOne(ctx, getPortalByMXIDQuery, pq.BridgeID, mxid) +} + +func (pq *PortalQuery) GetChildren(ctx context.Context, parentID networkid.PortalID) ([]*Portal, error) { + return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentID) +} + +func (pq *PortalQuery) ReID(ctx context.Context, oldID, newID networkid.PortalID) error { + return pq.Exec(ctx, reIDPortalQuery, pq.BridgeID, oldID, newID) +} + +func (pq *PortalQuery) Insert(ctx context.Context, p *Portal) error { + ensureBridgeIDMatches(&p.BridgeID, pq.BridgeID) + return pq.Exec(ctx, insertPortalQuery, p.sqlVariables()...) +} + +func (pq *PortalQuery) Update(ctx context.Context, p *Portal) error { + ensureBridgeIDMatches(&p.BridgeID, pq.BridgeID) + return pq.Exec(ctx, updatePortalQuery, p.sqlVariables()...) +} + +func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { + var mxid, parentID sql.NullString + err := row.Scan( + &p.BridgeID, &p.ID, &mxid, + &parentID, &p.Name, &p.Topic, &p.AvatarID, &p.AvatarMXC, + &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.InSpace, + dbutil.JSON{Data: &p.Metadata}, + ) + if err != nil { + return nil, err + } + if p.Metadata == nil { + p.Metadata = make(map[string]any) + } + p.MXID = id.RoomID(mxid.String) + p.ParentID = networkid.PortalID(parentID.String) + return p, nil +} + +func (p *Portal) sqlVariables() []any { + if p.Metadata == nil { + p.Metadata = make(map[string]any) + } + return []any{ + p.BridgeID, p.ID, dbutil.StrPtr(p.MXID), + dbutil.StrPtr(p.ParentID), p.Name, p.Topic, p.AvatarID, p.AvatarMXC, + p.NameSet, p.TopicSet, p.AvatarSet, p.InSpace, + dbutil.JSON{Data: p.Metadata}, + } +} diff --git a/bridgev2/database/reaction.go b/bridgev2/database/reaction.go new file mode 100644 index 00000000..f5d5a469 --- /dev/null +++ b/bridgev2/database/reaction.go @@ -0,0 +1,110 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "time" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type ReactionQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*Reaction] +} + +type Reaction struct { + BridgeID networkid.BridgeID + RoomID networkid.PortalID + MessageID networkid.MessageID + MessagePartID networkid.PartID + SenderID networkid.UserID + EmojiID networkid.EmojiID + MXID id.EventID + + Timestamp time.Time + Metadata map[string]any +} + +func newReaction(_ *dbutil.QueryHelper[*Reaction]) *Reaction { + return &Reaction{} +} + +const ( + getReactionBaseQuery = ` + SELECT bridge_id, message_id, message_part_id, sender_id, emoji_id, room_id, mxid, timestamp, metadata FROM reaction + ` + getReactionByIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3 AND sender_id=$4 AND emoji_id=$5` + getAllReactionsToMessageBySenderQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND sender_id=$3` + getAllReactionsToMessageQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2` + getReactionByMXIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` + upsertReactionQuery = ` + INSERT INTO reaction (bridge_id, message_id, message_part_id, sender_id, emoji_id, room_id, mxid, timestamp, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + ON CONFLICT (bridge_id, message_id, message_part_id, sender_id, emoji_id) + DO UPDATE SET mxid=excluded.mxid, timestamp=excluded.timestamp, metadata=excluded.metadata + ` + deleteReactionQuery = ` + DELETE FROM reaction WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3 AND sender_id=$4 AND emoji_id=$5 + ` +) + +func (rq *ReactionQuery) GetByID(ctx context.Context, messageID networkid.MessageID, messagePartID networkid.PartID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) { + return rq.QueryOne(ctx, getReactionByIDQuery, rq.BridgeID, messageID, messagePartID, senderID, emojiID) +} + +func (rq *ReactionQuery) GetAllToMessageBySender(ctx context.Context, messageID networkid.MessageID, senderID networkid.UserID) ([]*Reaction, error) { + return rq.QueryMany(ctx, getAllReactionsToMessageBySenderQuery, rq.BridgeID, messageID, senderID) +} + +func (rq *ReactionQuery) GetAllToMessage(ctx context.Context, messageID networkid.MessageID) ([]*Reaction, error) { + return rq.QueryMany(ctx, getAllReactionsToMessageQuery, rq.BridgeID, messageID) +} + +func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) { + return rq.QueryOne(ctx, getReactionByMXIDQuery, rq.BridgeID, mxid) +} + +func (rq *ReactionQuery) Upsert(ctx context.Context, reaction *Reaction) error { + ensureBridgeIDMatches(&reaction.BridgeID, rq.BridgeID) + return rq.Exec(ctx, upsertReactionQuery, reaction.sqlVariables()...) +} + +func (rq *ReactionQuery) Delete(ctx context.Context, reaction *Reaction) error { + ensureBridgeIDMatches(&reaction.BridgeID, rq.BridgeID) + return rq.Exec(ctx, deleteReactionQuery, reaction.BridgeID, reaction.MessageID, reaction.MessagePartID, reaction.SenderID, reaction.EmojiID) +} + +func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) { + var timestamp int64 + err := row.Scan( + &r.BridgeID, &r.MessageID, &r.MessagePartID, &r.SenderID, &r.EmojiID, + &r.RoomID, &r.MXID, ×tamp, dbutil.JSON{Data: &r.Metadata}, + ) + if err != nil { + return nil, err + } + if r.Metadata == nil { + r.Metadata = make(map[string]any) + } + r.Timestamp = time.Unix(0, timestamp) + return r, nil +} + +func (r *Reaction) sqlVariables() []any { + if r.Metadata == nil { + r.Metadata = make(map[string]any) + } + return []any{ + r.BridgeID, r.MessageID, r.MessagePartID, r.SenderID, r.EmojiID, + r.RoomID, r.MXID, r.Timestamp.UnixNano(), dbutil.JSON{Data: r.Metadata}, + } +} diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql new file mode 100644 index 00000000..2f107689 --- /dev/null +++ b/bridgev2/database/upgrades/00-latest.sql @@ -0,0 +1,133 @@ +-- v0 -> v1: Latest revision +CREATE TABLE portal ( + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, + mxid TEXT, + + parent_id TEXT, + name TEXT NOT NULL, + topic TEXT NOT NULL, + avatar_id TEXT NOT NULL, + avatar_mxc TEXT NOT NULL, + name_set BOOLEAN NOT NULL, + avatar_set BOOLEAN NOT NULL, + topic_set BOOLEAN NOT NULL, + in_space BOOLEAN NOT NULL, + metadata jsonb NOT NULL, + + PRIMARY KEY (bridge_id, id), + CONSTRAINT portal_parent_fkey FOREIGN KEY (bridge_id, parent_id) + -- Deletes aren't allowed to cascade here: + -- children should be re-parented or cleaned up manually + REFERENCES portal (bridge_id, id) ON UPDATE CASCADE +); + +CREATE TABLE ghost ( + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, + + name TEXT NOT NULL, + avatar_id TEXT NOT NULL, + avatar_mxc TEXT NOT NULL, + name_set BOOLEAN NOT NULL, + avatar_set BOOLEAN NOT NULL, + metadata jsonb NOT NULL, + + PRIMARY KEY (bridge_id, id) +); + +CREATE TABLE message ( + -- Messages have an extra rowid to allow a single relates_to column with ON DELETE SET NULL + -- If the foreign key used (bridge_id, relates_to), then deleting the target column + -- would try to set bridge_id to null as well. + + -- only: sqlite (line commented) +-- rowid INTEGER PRIMARY KEY, + -- only: postgres + rowid BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, + part_id TEXT NOT NULL, + mxid TEXT NOT NULL, + + room_id TEXT NOT NULL, + sender_id TEXT NOT NULL, + timestamp BIGINT NOT NULL, + relates_to BIGINT, + metadata jsonb NOT NULL, + + CONSTRAINT message_relation_fkey FOREIGN KEY (relates_to) + REFERENCES message (rowid) ON DELETE SET NULL, + CONSTRAINT message_room_fkey FOREIGN KEY (bridge_id, room_id) + REFERENCES portal (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT message_sender_fkey FOREIGN KEY (bridge_id, sender_id) + REFERENCES ghost (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT message_real_pkey UNIQUE (bridge_id, id, part_id) +); + +CREATE TABLE reaction ( + bridge_id TEXT NOT NULL, + message_id TEXT NOT NULL, + message_part_id TEXT NOT NULL, + sender_id TEXT NOT NULL, + emoji_id TEXT NOT NULL, + room_id TEXT NOT NULL, + mxid TEXT NOT NULL, + + timestamp BIGINT NOT NULL, + metadata jsonb NOT NULL, + + PRIMARY KEY (bridge_id, message_id, message_part_id, sender_id, emoji_id), + CONSTRAINT reaction_room_fkey FOREIGN KEY (bridge_id, room_id) + REFERENCES portal (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT reaction_message_fkey FOREIGN KEY (bridge_id, message_id, message_part_id) + REFERENCES message (bridge_id, id, part_id) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT reaction_sender_fkey FOREIGN KEY (bridge_id, sender_id) + REFERENCES ghost (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE +); + +CREATE TABLE "user" ( + bridge_id TEXT NOT NULL, + mxid TEXT NOT NULL, + + management_room TEXT, + access_token TEXT, + + PRIMARY KEY (bridge_id, mxid) +); + +CREATE TABLE user_login ( + bridge_id TEXT NOT NULL, + user_mxid TEXT NOT NULL, + id TEXT NOT NULL, + space_room TEXT, + metadata jsonb NOT NULL, + + PRIMARY KEY (bridge_id, user_mxid, id), + CONSTRAINT user_login_user_fkey FOREIGN KEY (bridge_id, user_mxid) + REFERENCES "user" (bridge_id, mxid) + ON DELETE CASCADE ON UPDATE CASCADE +); + +CREATE TABLE user_portal ( + bridge_id TEXT NOT NULL, + user_mxid TEXT NOT NULL, + login_id TEXT NOT NULL, + portal_id TEXT NOT NULL, + in_space BOOLEAN NOT NULL, + preferred BOOLEAN NOT NULL, + + PRIMARY KEY (bridge_id, user_mxid, login_id, portal_id), + CONSTRAINT user_portal_user_login_fkey FOREIGN KEY (bridge_id, user_mxid, login_id) + REFERENCES user_login (bridge_id, user_mxid, id) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT user_portal_portal_fkey FOREIGN KEY (bridge_id, portal_id) + REFERENCES portal (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE +); diff --git a/bridgev2/database/upgrades/upgrades.go b/bridgev2/database/upgrades/upgrades.go new file mode 100644 index 00000000..4fef472e --- /dev/null +++ b/bridgev2/database/upgrades/upgrades.go @@ -0,0 +1,22 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package upgrades + +import ( + "embed" + + "go.mau.fi/util/dbutil" +) + +var Table dbutil.UpgradeTable + +//go:embed *.sql +var rawUpgrades embed.FS + +func init() { + Table.RegisterFS(rawUpgrades) +} diff --git a/bridgev2/database/user.go b/bridgev2/database/user.go new file mode 100644 index 00000000..c5d2d0aa --- /dev/null +++ b/bridgev2/database/user.go @@ -0,0 +1,89 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type UserQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*User] +} + +type User struct { + BridgeID networkid.BridgeID + MXID id.UserID + + ManagementRoom id.RoomID + AccessToken string +} + +func newUser(_ *dbutil.QueryHelper[*User]) *User { + return &User{} +} + +const ( + getUserBaseQuery = ` + SELECT bridge_id, mxid, management_room, access_token FROM "user" + ` + getUserByMXIDQuery = getUserBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` + insertUserQuery = ` + INSERT INTO "user" (bridge_id, mxid, management_room, access_token) + VALUES ($1, $2, $3, $4) + ` + updateUserQuery = ` + UPDATE "user" SET management_room=$3, access_token=$4 + WHERE bridge_id=$1 AND mxid=$2 + ` + findUserLoginsByPortalIDQuery = ` + SELECT login_id + FROM user_portal + WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$3 + ORDER BY CASE WHEN preferred THEN 0 ELSE 1 END, login_id + ` +) + +func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) { + return uq.QueryOne(ctx, getUserByMXIDQuery, uq.BridgeID, userID) +} + +func (uq *UserQuery) Insert(ctx context.Context, user *User) error { + ensureBridgeIDMatches(&user.BridgeID, uq.BridgeID) + return uq.Exec(ctx, insertUserQuery, user.sqlVariables()...) +} + +func (uq *UserQuery) Update(ctx context.Context, user *User) error { + ensureBridgeIDMatches(&user.BridgeID, uq.BridgeID) + return uq.Exec(ctx, updateUserQuery, user.sqlVariables()...) +} + +func (uq *UserQuery) FindLoginsByPortalID(ctx context.Context, userID id.UserID, portalID networkid.PortalID) ([]networkid.UserLoginID, error) { + rows, err := uq.GetDB().Query(ctx, findUserLoginsByPortalIDQuery, uq.BridgeID, userID, portalID) + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[networkid.UserLoginID], err).AsList() +} + +func (u *User) Scan(row dbutil.Scannable) (*User, error) { + var managementRoom, accessToken sql.NullString + err := row.Scan(&u.BridgeID, &u.MXID, &managementRoom, &accessToken) + if err != nil { + return nil, err + } + u.ManagementRoom = id.RoomID(managementRoom.String) + u.AccessToken = accessToken.String + return u, nil +} + +func (u *User) sqlVariables() []any { + return []any{u.BridgeID, u.MXID, dbutil.StrPtr(u.ManagementRoom), dbutil.StrPtr(u.AccessToken)} +} diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go new file mode 100644 index 00000000..8dfd6cbc --- /dev/null +++ b/bridgev2/database/userlogin.go @@ -0,0 +1,111 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type UserLoginQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*UserLogin] +} + +type UserLogin struct { + BridgeID networkid.BridgeID + UserMXID id.UserID + ID networkid.UserLoginID + SpaceRoom id.RoomID + Metadata map[string]any +} + +func newUserLogin(_ *dbutil.QueryHelper[*UserLogin]) *UserLogin { + return &UserLogin{} +} + +const ( + getUserLoginBaseQuery = ` + SELECT bridge_id, user_mxid, id, space_room, metadata FROM user_login + ` + getAllLoginsQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1` + getAllLoginsForUserQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND user_mxid=$2` + insertUserLoginQuery = ` + INSERT INTO user_login (bridge_id, user_mxid, id, space_room, metadata) + VALUES ($1, $2, $3, $4, $5) + ` + updateUserLoginQuery = ` + UPDATE user_login SET space_room=$4, metadata=$5 + WHERE bridge_id=$1 AND user_mxid=$2 AND id=$3 + ` + insertUserPortalQuery = ` + INSERT INTO user_portal (bridge_id, user_mxid, login_id, portal_id, in_space, preferred) + VALUES ($1, $2, $3, $4, false, false) + ON CONFLICT (bridge_id, user_mxid, login_id, portal_id) DO NOTHING + ` + upsertUserPortalQuery = ` + INSERT INTO user_portal (bridge_id, user_mxid, login_id, portal_id, in_space, preferred) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (bridge_id, user_mxid, login_id, portal_id) DO UPDATE SET in_space=excluded.in_space, preferred=excluded.preferred + ` + markLoginAsPreferredQuery = ` + UPDATE user_portal SET preferred=(login_id=$3) WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$4 + ` +) + +func (uq *UserLoginQuery) GetAll(ctx context.Context) ([]*UserLogin, error) { + return uq.QueryMany(ctx, getAllLoginsQuery, uq.BridgeID) +} + +func (uq *UserLoginQuery) GetAllForUser(ctx context.Context, userID id.UserID) ([]*UserLogin, error) { + return uq.QueryMany(ctx, getAllLoginsForUserQuery, uq.BridgeID, userID) +} + +func (uq *UserLoginQuery) Insert(ctx context.Context, login *UserLogin) error { + ensureBridgeIDMatches(&login.BridgeID, uq.BridgeID) + return uq.Exec(ctx, insertUserLoginQuery, login.sqlVariables()...) +} + +func (uq *UserLoginQuery) Update(ctx context.Context, login *UserLogin) error { + ensureBridgeIDMatches(&login.BridgeID, uq.BridgeID) + return uq.Exec(ctx, updateUserLoginQuery, login.sqlVariables()...) +} + +func (uq *UserLoginQuery) EnsureUserPortalExists(ctx context.Context, login *UserLogin, portalID networkid.PortalID) error { + ensureBridgeIDMatches(&login.BridgeID, uq.BridgeID) + return uq.Exec(ctx, insertUserPortalQuery, login.BridgeID, login.UserMXID, login.ID, portalID) +} + +func (uq *UserLoginQuery) MarkLoginAsPreferredInPortal(ctx context.Context, login *UserLogin, portalID networkid.PortalID) error { + ensureBridgeIDMatches(&login.BridgeID, uq.BridgeID) + return uq.Exec(ctx, markLoginAsPreferredQuery, login.BridgeID, login.UserMXID, login.ID, portalID) +} + +func (u *UserLogin) Scan(row dbutil.Scannable) (*UserLogin, error) { + var spaceRoom sql.NullString + err := row.Scan(&u.BridgeID, &u.UserMXID, &u.ID, &spaceRoom, dbutil.JSON{Data: &u.Metadata}) + if err != nil { + return nil, err + } + if u.Metadata == nil { + u.Metadata = make(map[string]any) + } + u.SpaceRoom = id.RoomID(spaceRoom.String) + return u, nil +} + +func (u *UserLogin) sqlVariables() []any { + if u.Metadata == nil { + u.Metadata = make(map[string]any) + } + return []any{u.BridgeID, u.UserMXID, u.ID, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: u.Metadata}} +} diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go new file mode 100644 index 00000000..467cafed --- /dev/null +++ b/bridgev2/ghost.go @@ -0,0 +1,87 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "fmt" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type Ghost struct { + *database.Ghost + Bridge *Bridge + Log zerolog.Logger + Intent MatrixAPI + MXID id.UserID +} + +func (br *Bridge) loadGhost(ctx context.Context, dbGhost *database.Ghost, queryErr error, id *networkid.UserID) (*Ghost, error) { + if queryErr != nil { + return nil, fmt.Errorf("failed to query db: %w", queryErr) + } + if dbGhost == nil { + if id == nil { + return nil, nil + } + dbGhost = &database.Ghost{ + BridgeID: br.ID, + ID: *id, + } + err := br.DB.Ghost.Insert(ctx, dbGhost) + if err != nil { + return nil, fmt.Errorf("failed to insert new ghost: %w", err) + } + } + mxid := br.Matrix.FormatGhostMXID(dbGhost.ID) + ghost := &Ghost{ + Ghost: dbGhost, + Bridge: br, + Log: br.Log.With().Str("ghost_id", string(dbGhost.ID)).Logger(), + Intent: br.Matrix.GhostIntent(mxid), + MXID: mxid, + } + br.ghostsByID[ghost.ID] = ghost + return ghost, nil +} + +func (br *Bridge) unlockedGetGhostByID(ctx context.Context, id networkid.UserID, onlyIfExists bool) (*Ghost, error) { + cached, ok := br.ghostsByID[id] + if ok { + return cached, nil + } + idPtr := &id + if onlyIfExists { + idPtr = nil + } + db, err := br.DB.Ghost.GetByID(ctx, id) + return br.loadGhost(ctx, db, err, idPtr) +} + +func (br *Bridge) GetGhostByMXID(ctx context.Context, mxid id.UserID) (*Ghost, error) { + ghostID, ok := br.Matrix.ParseGhostMXID(mxid) + if !ok { + return nil, nil + } + return br.GetGhostByID(ctx, ghostID) +} + +func (br *Bridge) GetGhostByID(ctx context.Context, id networkid.UserID) (*Ghost, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.unlockedGetGhostByID(ctx, id, false) +} + +func (ghost *Ghost) IntentFor(portal *Portal) MatrixAPI { + // TODO use user double puppet intent if appropriate + return ghost.Intent +} diff --git a/bridgev2/matrix/cmdadmin.go b/bridgev2/matrix/cmdadmin.go new file mode 100644 index 00000000..45a83b4f --- /dev/null +++ b/bridgev2/matrix/cmdadmin.go @@ -0,0 +1,79 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "strconv" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/id" +) + +var CommandDiscardMegolmSession = &bridgev2.FullHandler{ + Func: func(ce *bridgev2.CommandEvent) { + matrix := ce.Bridge.Matrix.(*Connector) + if matrix.Crypto == nil { + ce.Reply("This bridge instance doesn't have end-to-bridge encryption enabled") + } else { + matrix.Crypto.ResetSession(ce.Ctx, ce.RoomID) + ce.Reply("Successfully reset Megolm session in this room. New decryption keys will be shared the next time a message is sent from the remote network.") + } + }, + Name: "discard-megolm-session", + Aliases: []string{"discard-session"}, + Help: bridgev2.HelpMeta{ + Section: bridgev2.HelpSectionAdmin, + Description: "Discard the Megolm session in the room", + }, + RequiresAdmin: true, +} + +func fnSetPowerLevel(ce *bridgev2.CommandEvent) { + var level int + var userID id.UserID + var err error + if len(ce.Args) == 1 { + level, err = strconv.Atoi(ce.Args[0]) + if err != nil { + ce.Reply("Invalid power level \"%s\"", ce.Args[0]) + return + } + userID = ce.User.MXID + } else if len(ce.Args) == 2 { + userID = id.UserID(ce.Args[0]) + _, _, err := userID.Parse() + if err != nil { + ce.Reply("Invalid user ID \"%s\"", ce.Args[0]) + return + } + level, err = strconv.Atoi(ce.Args[1]) + if err != nil { + ce.Reply("Invalid power level \"%s\"", ce.Args[1]) + return + } + } else { + ce.Reply("**Usage:** `set-pl [user] `") + return + } + _, err = ce.Bot.(*ASIntent).Matrix.SetPowerLevel(ce.Ctx, ce.RoomID, userID, level) + if err != nil { + ce.Reply("Failed to set power levels: %v", err) + } +} + +var CommandSetPowerLevel = &bridgev2.FullHandler{ + Func: fnSetPowerLevel, + Name: "set-pl", + Aliases: []string{"set-power-level"}, + Help: bridgev2.HelpMeta{ + Section: bridgev2.HelpSectionAdmin, + Description: "Change the power level in a portal room.", + Args: "[_user ID_] <_power level_>", + }, + RequiresAdmin: true, + RequiresPortal: true, +} diff --git a/bridgev2/matrix/cmddoublepuppet.go b/bridgev2/matrix/cmddoublepuppet.go new file mode 100644 index 00000000..1b755f36 --- /dev/null +++ b/bridgev2/matrix/cmddoublepuppet.go @@ -0,0 +1,74 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "maunium.net/go/mautrix/bridgev2" +) + +var CommandLoginMatrix = &bridgev2.FullHandler{ + Func: fnLoginMatrix, + Name: "login-matrix", + Help: bridgev2.HelpMeta{ + Section: bridgev2.HelpSectionAuth, + Description: "Enable double puppeting.", + Args: "<_access token_>", + }, + RequiresLogin: true, +} + +func fnLoginMatrix(ce *bridgev2.CommandEvent) { + if len(ce.Args) == 0 { + ce.Reply("**Usage:** `login-matrix `") + return + } + //err := ce.User.SwitchCustomMXID(ce.Args[0], ce.User.GetMXID()) + //if err != nil { + // ce.Reply("Failed to enable double puppeting: %v", err) + //} else { + // ce.Reply("Successfully switched puppet") + //} +} + +var CommandPingMatrix = &bridgev2.FullHandler{ + Func: fnPingMatrix, + Name: "ping-matrix", + Help: bridgev2.HelpMeta{ + Section: bridgev2.HelpSectionAuth, + Description: "Ping the Matrix server with the double puppet.", + }, + RequiresLogin: true, +} + +func fnPingMatrix(ce *bridgev2.CommandEvent) { + //resp, err := puppet.CustomIntent().Whoami(ce.Ctx) + //if err != nil { + // ce.Reply("Failed to validate Matrix login: %v", err) + //} else { + // ce.Reply("Confirmed valid access token for %s / %s", resp.UserID, resp.DeviceID) + //} +} + +var CommandLogoutMatrix = &bridgev2.FullHandler{ + Func: fnLogoutMatrix, + Name: "logout-matrix", + Help: bridgev2.HelpMeta{ + Section: bridgev2.HelpSectionAuth, + Description: "Disable double puppeting.", + }, + RequiresLogin: true, +} + +func fnLogoutMatrix(ce *bridgev2.CommandEvent) { + //puppet := ce.User.GetIDoublePuppet() + //if puppet == nil || puppet.CustomIntent() == nil { + // ce.Reply("You don't have double puppeting enabled.") + // return + //} + //puppet.ClearCustomMXID() + //ce.Reply("Successfully disabled double puppeting.") +} diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go new file mode 100644 index 00000000..22dc50b3 --- /dev/null +++ b/bridgev2/matrix/connector.go @@ -0,0 +1,183 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "context" + "regexp" + "sync" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridge/status" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/bridgeconfig" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/sqlstatestore" +) + +type Crypto interface { + HandleMemberEvent(context.Context, *event.Event) + Decrypt(context.Context, *event.Event) (*event.Event, error) + Encrypt(context.Context, id.RoomID, event.Type, *event.Content) error + WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool + RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) + ResetSession(context.Context, id.RoomID) + Init(ctx context.Context) error + Start() + Stop() + Reset(ctx context.Context, startAfterReset bool) + Client() *mautrix.Client + ShareKeys(context.Context) error +} + +type Connector struct { + //DB *dbutil.Database + AS *appservice.AppService + Bot *appservice.IntentAPI + StateStore *sqlstatestore.SQLStateStore + Crypto Crypto + Log *zerolog.Logger + Config *bridgeconfig.Config + Bridge *bridgev2.Bridge + + EventProcessor *appservice.EventProcessor + + userIDRegex *regexp.Regexp + + // TODO move to config + AsyncUploads bool + + Websocket bool + wsStopPinger chan struct{} + wsStarted chan struct{} + wsStopped chan struct{} + wsShortCircuitReconnectBackoff chan struct{} + wsStartupWait *sync.WaitGroup +} + +func NewConnector(cfg *bridgeconfig.Config) *Connector { + c := &Connector{} + c.Config = cfg + c.userIDRegex = cfg.MakeUserIDRegex("(.+)") + return c +} + +func (br *Connector) Init(bridge *bridgev2.Bridge) { + br.Bridge = bridge + br.Log = &bridge.Log + br.StateStore = sqlstatestore.NewSQLStateStore(bridge.DB.Database, dbutil.ZeroLogger(br.Log.With().Str("db_section", "matrix").Logger()), false) + br.AS = br.Config.MakeAppService() + br.AS.Log = bridge.Log + br.AS.StateStore = br.StateStore + br.EventProcessor = appservice.NewEventProcessor(br.AS) + for evtType := range status.CheckpointTypes { + br.EventProcessor.On(evtType, br.sendBridgeCheckpoint) + } + br.EventProcessor.On(event.EventMessage, br.handleRoomEvent) + br.EventProcessor.On(event.EventSticker, br.handleRoomEvent) + br.EventProcessor.On(event.EventReaction, br.handleRoomEvent) + br.EventProcessor.On(event.EventRedaction, br.handleRoomEvent) + br.EventProcessor.On(event.EventEncrypted, br.handleEncryptedEvent) + br.EventProcessor.On(event.StateMember, br.handleRoomEvent) + br.Bot = br.AS.BotIntent() + br.Crypto = NewCryptoHelper(br) + br.Bridge.Commands.AddHandlers(CommandDiscardMegolmSession, CommandSetPowerLevel) +} + +func (br *Connector) Start(ctx context.Context) error { + br.EventProcessor.Start(ctx) + err := br.StateStore.Upgrade(ctx) + if err != nil { + return err + } + go br.AS.Start() + if br.Crypto != nil { + err = br.Crypto.Init(ctx) + if err != nil { + return err + } + br.Crypto.Start() + } + return nil +} + +var _ bridgev2.MatrixConnector = (*Connector)(nil) + +func (br *Connector) GhostIntent(userID id.UserID) bridgev2.MatrixAPI { + return &ASIntent{ + Matrix: br.AS.Intent(userID), + Connector: br, + } +} + +func (br *Connector) SendMessageStatus(ctx context.Context, evt bridgev2.MessageStatus) { + log := zerolog.Ctx(ctx) + err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{evt.ToCheckpoint()}) + if err != nil { + log.Err(err).Msg("Failed to send message checkpoint") + } +} + +func (br *Connector) SendMessageCheckpoints(checkpoints []*status.MessageCheckpoint) error { + checkpointsJSON := status.CheckpointsJSON{Checkpoints: checkpoints} + + if br.Websocket { + return br.AS.SendWebsocket(&appservice.WebsocketRequest{ + Command: "message_checkpoint", + Data: checkpointsJSON, + }) + } + + endpoint := br.Config.Homeserver.MessageSendCheckpointEndpoint + if endpoint == "" { + return nil + } + + return checkpointsJSON.SendHTTP(endpoint, br.AS.Registration.AppToken) +} + +func (br *Connector) ParseGhostMXID(userID id.UserID) (networkid.UserID, bool) { + match := br.userIDRegex.FindStringSubmatch(string(userID)) + if match == nil || userID == br.Bot.UserID { + return "", false + } + decoded, err := id.DecodeUserLocalpart(match[1]) + if err != nil { + return "", false + } + return networkid.UserID(decoded), true +} + +func (br *Connector) FormatGhostMXID(userID networkid.UserID) id.UserID { + localpart := br.Config.AppService.FormatUsername(id.EncodeUserLocalpart(string(userID))) + return id.NewUserID(localpart, br.Config.Homeserver.Domain) +} + +func (br *Connector) UserIntent(user *bridgev2.User) bridgev2.MatrixAPI { + // TODO implement double puppeting + return nil +} + +func (br *Connector) BotIntent() bridgev2.MatrixAPI { + return &ASIntent{Connector: br, Matrix: br.Bot} +} + +func (br *Connector) GetMemberInfo(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) { + // TODO fetch from network sometimes? + return br.AS.StateStore.GetMember(ctx, roomID, userID) +} + +func (br *Connector) ServerName() string { + return br.Config.Homeserver.Domain +} diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go new file mode 100644 index 00000000..008c2f3b --- /dev/null +++ b/bridgev2/matrix/crypto.go @@ -0,0 +1,499 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build cgo && !nocrypto + +package matrix + +import ( + "context" + "errors" + "fmt" + "os" + "runtime/debug" + "sync" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/sqlstatestore" +) + +var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil) + +var NoSessionFound = crypto.NoSessionFound +var DuplicateMessageIndex = crypto.DuplicateMessageIndex +var UnknownMessageIndex = olm.UnknownMessageIndex + +type CryptoHelper struct { + bridge *Connector + client *mautrix.Client + mach *crypto.OlmMachine + store *SQLCryptoStore + log *zerolog.Logger + + lock sync.RWMutex + syncDone sync.WaitGroup + cancelSync func() + + cancelPeriodicDeleteLoop func() +} + +func NewCryptoHelper(c *Connector) Crypto { + if !c.Config.Encryption.Allow { + c.Log.Debug().Msg("Bridge built with end-to-bridge encryption, but disabled in config") + return nil + } + log := c.Log.With().Str("component", "crypto").Logger() + return &CryptoHelper{ + bridge: c, + log: &log, + } +} + +func (helper *CryptoHelper) Init(ctx context.Context) error { + if len(helper.bridge.Config.Encryption.PickleKey) == 0 { + panic("CryptoPickleKey not set") + } + helper.log.Debug().Msg("Initializing end-to-bridge encryption...") + + helper.store = NewSQLCryptoStore( + helper.bridge.Bridge.DB.Database, + dbutil.ZeroLogger(helper.bridge.Log.With().Str("db_section", "crypto").Logger()), + string(helper.bridge.Bridge.ID), + helper.bridge.AS.BotMXID(), + fmt.Sprintf("@%s:%s", helper.bridge.Config.AppService.FormatUsername("%"), helper.bridge.AS.HomeserverDomain), + helper.bridge.Config.Encryption.PickleKey, + ) + + err := helper.store.DB.Upgrade(ctx) + if err != nil { + // TODO copy this function back + //helper.bridge.LogDBUpgradeErrorAndExit("crypto", err) + panic(err) + } + + var isExistingDevice bool + helper.client, isExistingDevice, err = helper.loginBot(ctx) + if err != nil { + return err + } + + helper.log.Debug(). + Str("device_id", helper.client.DeviceID.String()). + Msg("Logged in as bridge bot") + helper.mach = crypto.NewOlmMachine(helper.client, helper.log, helper.store, helper.bridge.StateStore) + helper.mach.AllowKeyShare = helper.allowKeyShare + + encryptionConfig := helper.bridge.Config.Encryption + helper.mach.SendKeysMinTrust = encryptionConfig.VerificationLevels.Receive + helper.mach.PlaintextMentions = encryptionConfig.PlaintextMentions + + helper.mach.DeleteOutboundKeysOnAck = encryptionConfig.DeleteKeys.DeleteOutboundOnAck + helper.mach.DontStoreOutboundKeys = encryptionConfig.DeleteKeys.DontStoreOutbound + helper.mach.RatchetKeysOnDecrypt = encryptionConfig.DeleteKeys.RatchetOnDecrypt + helper.mach.DeleteFullyUsedKeysOnDecrypt = encryptionConfig.DeleteKeys.DeleteFullyUsedOnDecrypt + helper.mach.DeletePreviousKeysOnReceive = encryptionConfig.DeleteKeys.DeletePrevOnNewSession + helper.mach.DeleteKeysOnDeviceDelete = encryptionConfig.DeleteKeys.DeleteOnDeviceDelete + helper.mach.DisableDeviceChangeKeyRotation = encryptionConfig.Rotation.DisableDeviceChangeKeyRotation + if encryptionConfig.DeleteKeys.PeriodicallyDeleteExpired { + ctx, cancel := context.WithCancel(context.Background()) + helper.cancelPeriodicDeleteLoop = cancel + go helper.mach.ExpiredKeyDeleteLoop(ctx) + } + + if encryptionConfig.DeleteKeys.DeleteOutdatedInbound { + deleted, err := helper.store.RedactOutdatedGroupSessions(ctx) + if err != nil { + return err + } + if len(deleted) > 0 { + helper.log.Debug().Int("deleted", len(deleted)).Msg("Deleted inbound keys which lacked expiration metadata") + } + } + + helper.client.Syncer = &cryptoSyncer{helper.mach} + helper.client.Store = helper.store + + err = helper.mach.Load(ctx) + if err != nil { + return err + } + if isExistingDevice { + helper.verifyKeysAreOnServer(ctx) + } + + go helper.resyncEncryptionInfo(context.TODO()) + + return nil +} + +func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { + log := helper.log.With().Str("action", "resync encryption event").Logger() + rows, err := helper.store.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) + if err != nil { + log.Err(err).Msg("Failed to query rooms for resync") + return + } + roomIDs, err := dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList() + if err != nil { + log.Err(err).Msg("Failed to scan rooms for resync") + return + } + if len(roomIDs) > 0 { + log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms") + for _, roomID := range roomIDs { + var evt event.EncryptionEventContent + err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt) + if err != nil { + log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event") + _, err = helper.store.DB.Exec(ctx, ` + UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}' + `, roomID) + if err != nil { + log.Err(err).Str("room_id", roomID.String()).Msg("Failed to unmark room for resync after failed sync") + } + } else { + maxAge := evt.RotationPeriodMillis + if maxAge <= 0 { + maxAge = (7 * 24 * time.Hour).Milliseconds() + } + maxMessages := evt.RotationPeriodMessages + if maxMessages <= 0 { + maxMessages = 100 + } + log.Debug(). + Str("room_id", roomID.String()). + Int64("max_age_ms", maxAge). + Int("max_messages", maxMessages). + Interface("content", &evt). + Msg("Resynced encryption event") + _, err = helper.store.DB.Exec(ctx, ` + UPDATE crypto_megolm_inbound_session + SET max_age=$1, max_messages=$2 + WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL + `, maxAge, maxMessages, roomID) + if err != nil { + log.Err(err).Str("room_id", roomID.String()).Msg("Failed to update megolm session table") + } else { + log.Debug().Str("room_id", roomID.String()).Msg("Updated megolm session table") + } + } + } + } +} + +func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device, info event.RequestedKeyInfo) *crypto.KeyShareRejection { + cfg := helper.bridge.Config.Encryption + if !cfg.AllowKeySharing { + return &crypto.KeyShareRejectNoResponse + } else if device.Trust == id.TrustStateBlacklisted { + return &crypto.KeyShareRejectBlacklisted + } else if trustState := helper.mach.ResolveTrust(device); trustState >= cfg.VerificationLevels.Share { + portal, err := helper.bridge.Bridge.GetPortalByMXID(ctx, info.RoomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal to handle key request") + return &crypto.KeyShareRejectNoResponse + } else if portal == nil { + zerolog.Ctx(ctx).Debug().Msg("Rejecting key request: room is not a portal") + return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnavailable, Reason: "Requested room is not a portal room"} + } + user, err := helper.bridge.Bridge.GetExistingUserByMXID(ctx, device.UserID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get user to handle key request") + return &crypto.KeyShareRejectNoResponse + } else if user == nil { + // TODO + return &crypto.KeyShareRejectNoResponse + } else if true { + // TODO admin check and is in room check + return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "Key sharing is not yet implemented in bridgev2"} + } + zerolog.Ctx(ctx).Debug().Msg("Accepting key request") + return nil + } else { + return &crypto.KeyShareRejectUnverified + } +} + +func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool, error) { + deviceID, err := helper.store.FindDeviceID(ctx) + if err != nil { + return nil, false, fmt.Errorf("failed to find existing device ID: %w", err) + } else if len(deviceID) > 0 { + helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database") + } + // Create a new client instance with the default AS settings (including as_token), + // the Login call will then override the access token in the client. + client := helper.bridge.AS.NewMautrixClient(helper.bridge.AS.BotMXID()) + flows, err := client.GetLoginFlows(ctx) + if err != nil { + return nil, deviceID != "", fmt.Errorf("failed to get supported login flows: %w", err) + } else if !flows.HasFlow(mautrix.AuthTypeAppservice) { + return nil, deviceID != "", fmt.Errorf("homeserver does not support appservice login") + } + resp, err := client.Login(ctx, &mautrix.ReqLogin{ + Type: mautrix.AuthTypeAppservice, + Identifier: mautrix.UserIdentifier{ + Type: mautrix.IdentifierTypeUser, + User: string(helper.bridge.AS.BotMXID()), + }, + DeviceID: deviceID, + StoreCredentials: true, + + // TODO find proper bridge name + InitialDeviceDisplayName: "Megabridge", // fmt.Sprintf("%s bridge", helper.bridge.ProtocolName), + }) + if err != nil { + return nil, deviceID != "", fmt.Errorf("failed to log in as bridge bot: %w", err) + } + helper.store.DeviceID = resp.DeviceID + return client, deviceID != "", nil +} + +func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) { + helper.log.Debug().Msg("Making sure keys are still on server") + resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ + DeviceKeys: map[id.UserID]mautrix.DeviceIDList{ + helper.client.UserID: {helper.client.DeviceID}, + }, + }) + if err != nil { + helper.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to query own keys to make sure device still exists") + os.Exit(33) + } + device, ok := resp.DeviceKeys[helper.client.UserID][helper.client.DeviceID] + if ok && len(device.Keys) > 0 { + return + } + helper.log.Warn().Msg("Existing device doesn't have keys on server, resetting crypto") + helper.Reset(ctx, false) +} + +func (helper *CryptoHelper) Start() { + if helper.bridge.Config.Encryption.Appservice { + helper.log.Debug().Msg("End-to-bridge encryption is in appservice mode, registering event listeners and not starting syncer") + helper.bridge.AS.Registration.EphemeralEvents = true + helper.mach.AddAppserviceListener(helper.bridge.EventProcessor) + return + } + helper.syncDone.Add(1) + defer helper.syncDone.Done() + helper.log.Debug().Msg("Starting syncer for receiving to-device messages") + var ctx context.Context + ctx, helper.cancelSync = context.WithCancel(context.Background()) + err := helper.client.SyncWithContext(ctx) + if err != nil && !errors.Is(err, context.Canceled) { + helper.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Fatal error syncing") + os.Exit(51) + } else { + helper.log.Info().Msg("Bridge bot to-device syncer stopped without error") + } +} + +func (helper *CryptoHelper) Stop() { + helper.log.Debug().Msg("CryptoHelper.Stop() called, stopping bridge bot sync") + helper.client.StopSync() + if helper.cancelSync != nil { + helper.cancelSync() + } + if helper.cancelPeriodicDeleteLoop != nil { + helper.cancelPeriodicDeleteLoop() + } + helper.syncDone.Wait() +} + +func (helper *CryptoHelper) clearDatabase(ctx context.Context) { + _, err := helper.store.DB.Exec(ctx, "DELETE FROM crypto_account") + if err != nil { + helper.log.Warn().Err(err).Msg("Failed to clear crypto_account table") + } + _, err = helper.store.DB.Exec(ctx, "DELETE FROM crypto_olm_session") + if err != nil { + helper.log.Warn().Err(err).Msg("Failed to clear crypto_olm_session table") + } + _, err = helper.store.DB.Exec(ctx, "DELETE FROM crypto_megolm_outbound_session") + if err != nil { + helper.log.Warn().Err(err).Msg("Failed to clear crypto_megolm_outbound_session table") + } + //_, _ = helper.store.DB.Exec("DELETE FROM crypto_device") + //_, _ = helper.store.DB.Exec("DELETE FROM crypto_tracked_user") + //_, _ = helper.store.DB.Exec("DELETE FROM crypto_cross_signing_keys") + //_, _ = helper.store.DB.Exec("DELETE FROM crypto_cross_signing_signatures") +} + +func (helper *CryptoHelper) Reset(ctx context.Context, startAfterReset bool) { + helper.lock.Lock() + defer helper.lock.Unlock() + helper.log.Info().Msg("Resetting end-to-bridge encryption device") + helper.Stop() + helper.log.Debug().Msg("Crypto syncer stopped, clearing database") + helper.clearDatabase(ctx) + helper.log.Debug().Msg("Crypto database cleared, logging out of all sessions") + _, err := helper.client.LogoutAll(ctx) + if err != nil { + helper.log.Warn().Err(err).Msg("Failed to log out all devices") + } + helper.client = nil + helper.store = nil + helper.mach = nil + err = helper.Init(ctx) + if err != nil { + helper.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Error reinitializing end-to-bridge encryption") + os.Exit(50) + } + helper.log.Info().Msg("End-to-bridge encryption successfully reset") + if startAfterReset { + go helper.Start() + } +} + +func (helper *CryptoHelper) Client() *mautrix.Client { + return helper.client +} + +func (helper *CryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) { + return helper.mach.DecryptMegolmEvent(ctx, evt) +} + +func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content *event.Content) (err error) { + helper.lock.RLock() + defer helper.lock.RUnlock() + var encrypted *event.EncryptedEventContent + encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content) + if err != nil { + if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) { + return + } + helper.log.Debug().Err(err). + Str("room_id", roomID.String()). + Msg("Got error while encrypting event for room, sharing group session and trying again...") + var users []id.UserID + users, err = helper.store.GetRoomJoinedOrInvitedMembers(ctx, roomID) + if err != nil { + err = fmt.Errorf("failed to get room member list: %w", err) + } else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil { + err = fmt.Errorf("failed to share group session: %w", err) + } else if encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content); err != nil { + err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err) + } + } + if encrypted != nil { + content.Parsed = encrypted + content.Raw = nil + } + return +} + +func (helper *CryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { + helper.lock.RLock() + defer helper.lock.RUnlock() + return helper.mach.WaitForSession(ctx, roomID, senderKey, sessionID, timeout) +} + +func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { + helper.lock.RLock() + defer helper.lock.RUnlock() + if deviceID == "" { + deviceID = "*" + } + err := helper.mach.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}}) + if err != nil { + helper.log.Warn().Err(err). + Str("user_id", userID.String()). + Str("device_id", deviceID.String()). + Str("session_id", sessionID.String()). + Str("room_id", roomID.String()). + Msg("Failed to send key request") + } else { + helper.log.Debug(). + Str("user_id", userID.String()). + Str("device_id", deviceID.String()). + Str("session_id", sessionID.String()). + Str("room_id", roomID.String()). + Msg("Sent key request") + } +} + +func (helper *CryptoHelper) ResetSession(ctx context.Context, roomID id.RoomID) { + helper.lock.RLock() + defer helper.lock.RUnlock() + err := helper.mach.CryptoStore.RemoveOutboundGroupSession(ctx, roomID) + if err != nil { + helper.log.Debug().Err(err). + Str("room_id", roomID.String()). + Msg("Error manually removing outbound group session in room") + } +} + +func (helper *CryptoHelper) HandleMemberEvent(ctx context.Context, evt *event.Event) { + helper.lock.RLock() + defer helper.lock.RUnlock() + helper.mach.HandleMemberEvent(ctx, evt) +} + +// ShareKeys uploads the given number of one-time-keys to the server. +func (helper *CryptoHelper) ShareKeys(ctx context.Context) error { + return helper.mach.ShareKeys(ctx, -1) +} + +type cryptoSyncer struct { + *crypto.OlmMachine +} + +func (syncer *cryptoSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { + done := make(chan struct{}) + go func() { + defer func() { + if err := recover(); err != nil { + syncer.Log.Error(). + Str("since", since). + Interface("error", err). + Str("stack", string(debug.Stack())). + Msg("Processing sync response panicked") + } + done <- struct{}{} + }() + syncer.Log.Trace().Str("since", since).Msg("Starting sync response handling") + syncer.ProcessSyncResponse(ctx, resp, since) + syncer.Log.Trace().Str("since", since).Msg("Successfully handled sync response") + }() + select { + case <-done: + case <-time.After(30 * time.Second): + syncer.Log.Warn().Str("since", since).Msg("Handling sync response is taking unusually long") + } + return nil +} + +func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) { + if errors.Is(err, mautrix.MUnknownToken) { + return 0, err + } + syncer.Log.Error().Err(err).Msg("Error /syncing, waiting 10 seconds") + return 10 * time.Second, nil +} + +func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter { + everything := []event.Type{{Type: "*"}} + return &mautrix.Filter{ + Presence: mautrix.FilterPart{NotTypes: everything}, + AccountData: mautrix.FilterPart{NotTypes: everything}, + Room: mautrix.RoomFilter{ + IncludeLeave: false, + Ephemeral: mautrix.FilterPart{NotTypes: everything}, + AccountData: mautrix.FilterPart{NotTypes: everything}, + State: mautrix.FilterPart{NotTypes: everything}, + Timeline: mautrix.FilterPart{NotTypes: everything}, + }, + } +} diff --git a/bridgev2/matrix/cryptostore.go b/bridgev2/matrix/cryptostore.go new file mode 100644 index 00000000..234797a6 --- /dev/null +++ b/bridgev2/matrix/cryptostore.go @@ -0,0 +1,63 @@ +// Copyright (c) 2022 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build cgo && !nocrypto + +package matrix + +import ( + "context" + + "github.com/lib/pq" + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/id" +) + +func init() { + crypto.PostgresArrayWrapper = pq.Array +} + +type SQLCryptoStore struct { + *crypto.SQLCryptoStore + UserID id.UserID + GhostIDFormat string +} + +var _ crypto.Store = (*SQLCryptoStore)(nil) + +func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, accountID string, userID id.UserID, ghostIDFormat, pickleKey string) *SQLCryptoStore { + return &SQLCryptoStore{ + SQLCryptoStore: crypto.NewSQLCryptoStore(db, log, accountID, "", []byte(pickleKey)), + UserID: userID, + GhostIDFormat: ghostIDFormat, + } +} + +func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) (members []id.UserID, err error) { + var rows dbutil.Rows + rows, err = store.DB.Query(ctx, ` + SELECT user_id FROM mx_user_profile + WHERE room_id=$1 + AND (membership='join' OR membership='invite') + AND user_id<>$2 + AND user_id NOT LIKE $3 + `, roomID, store.UserID, store.GhostIDFormat) + if err != nil { + return + } + for rows.Next() { + var userID id.UserID + err = rows.Scan(&userID) + if err != nil { + return members, err + } else { + members = append(members, userID) + } + } + return +} diff --git a/bridgev2/matrix/doublepuppet.go b/bridgev2/matrix/doublepuppet.go new file mode 100644 index 00000000..4c1aca6a --- /dev/null +++ b/bridgev2/matrix/doublepuppet.go @@ -0,0 +1,175 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "context" + "crypto/hmac" + "crypto/sha512" + "encoding/hex" + "errors" + "fmt" + "strings" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/id" +) + +type doublePuppetUtil struct { + br *Connector + log zerolog.Logger +} + +func (dp *doublePuppetUtil) newClient(ctx context.Context, mxid id.UserID, accessToken string) (*mautrix.Client, error) { + _, homeserver, err := mxid.Parse() + if err != nil { + return nil, err + } + homeserverURL, found := dp.br.Config.DoublePuppet.Servers[homeserver] + if !found { + if homeserver == dp.br.AS.HomeserverDomain { + homeserverURL = "" + } else if dp.br.Config.DoublePuppet.AllowDiscovery { + resp, err := mautrix.DiscoverClientAPI(ctx, homeserver) + if err != nil { + return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err) + } + homeserverURL = resp.Homeserver.BaseURL + dp.log.Debug(). + Str("homeserver", homeserver). + Str("url", homeserverURL). + Str("user_id", mxid.String()). + Msg("Discovered URL to enable double puppeting for user") + } else { + return nil, fmt.Errorf("double puppeting from %s is not allowed", homeserver) + } + } + return dp.br.AS.NewExternalMautrixClient(mxid, accessToken, homeserverURL) +} + +func (dp *doublePuppetUtil) newIntent(ctx context.Context, mxid id.UserID, accessToken string) (*appservice.IntentAPI, error) { + client, err := dp.newClient(ctx, mxid, accessToken) + if err != nil { + return nil, err + } + + ia := dp.br.AS.NewIntentAPI("custom") + ia.Client = client + ia.Localpart, _, _ = mxid.Parse() + ia.UserID = mxid + ia.IsCustomPuppet = true + return ia, nil +} + +func (dp *doublePuppetUtil) autoLogin(ctx context.Context, mxid id.UserID, loginSecret string) (string, error) { + dp.log.Debug().Str("user_id", mxid.String()).Msg("Logging into user account with shared secret") + client, err := dp.newClient(ctx, mxid, "") + if err != nil { + return "", fmt.Errorf("failed to create mautrix client to log in: %v", err) + } + // TODO proper bridge name + //bridgeName := fmt.Sprintf("%s Bridge", dp.br.ProtocolName) + bridgeName := "Megabridge" + req := mautrix.ReqLogin{ + Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(mxid)}, + DeviceID: id.DeviceID(bridgeName), + InitialDeviceDisplayName: bridgeName, + } + if loginSecret == "appservice" { + client.AccessToken = dp.br.AS.Registration.AppToken + req.Type = mautrix.AuthTypeAppservice + } else { + loginFlows, err := client.GetLoginFlows(ctx) + if err != nil { + return "", fmt.Errorf("failed to get supported login flows: %w", err) + } + mac := hmac.New(sha512.New, []byte(loginSecret)) + mac.Write([]byte(mxid)) + token := hex.EncodeToString(mac.Sum(nil)) + switch { + case loginFlows.HasFlow(mautrix.AuthTypeDevtureSharedSecret): + req.Type = mautrix.AuthTypeDevtureSharedSecret + req.Token = token + case loginFlows.HasFlow(mautrix.AuthTypePassword): + req.Type = mautrix.AuthTypePassword + req.Password = token + default: + return "", fmt.Errorf("no supported auth types for shared secret auth found") + } + } + resp, err := client.Login(ctx, &req) + if err != nil { + return "", err + } + return resp.AccessToken, nil +} + +var ( + ErrMismatchingMXID = errors.New("whoami result does not match custom mxid") + ErrNoAccessToken = errors.New("no access token provided") + ErrNoMXID = errors.New("no mxid provided") +) + +const useConfigASToken = "appservice-config" +const asTokenModePrefix = "as_token:" + +func (dp *doublePuppetUtil) Setup(ctx context.Context, mxid id.UserID, savedAccessToken string, reloginOnFail bool) (intent *appservice.IntentAPI, newAccessToken string, err error) { + if len(mxid) == 0 { + err = ErrNoMXID + return + } + _, homeserver, _ := mxid.Parse() + loginSecret, hasSecret := dp.br.Config.DoublePuppet.Secrets[homeserver] + // Special case appservice: prefix to not login and use it as an as_token directly. + if hasSecret && strings.HasPrefix(loginSecret, asTokenModePrefix) { + intent, err = dp.newIntent(ctx, mxid, strings.TrimPrefix(loginSecret, asTokenModePrefix)) + if err != nil { + return + } + intent.SetAppServiceUserID = true + if savedAccessToken != useConfigASToken { + var resp *mautrix.RespWhoami + resp, err = intent.Whoami(ctx) + if err == nil && resp.UserID != mxid { + err = ErrMismatchingMXID + } + } + return intent, useConfigASToken, err + } + if savedAccessToken == "" || savedAccessToken == useConfigASToken { + if reloginOnFail && hasSecret { + savedAccessToken, err = dp.autoLogin(ctx, mxid, loginSecret) + } else { + err = ErrNoAccessToken + } + if err != nil { + return + } + } + intent, err = dp.newIntent(ctx, mxid, savedAccessToken) + if err != nil { + return + } + var resp *mautrix.RespWhoami + resp, err = intent.Whoami(ctx) + if err != nil { + if reloginOnFail && hasSecret && errors.Is(err, mautrix.MUnknownToken) { + intent.AccessToken, err = dp.autoLogin(ctx, mxid, loginSecret) + if err == nil { + newAccessToken = intent.AccessToken + } + } + } else if resp.UserID != mxid { + err = ErrMismatchingMXID + } else { + newAccessToken = savedAccessToken + } + return +} diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go new file mode 100644 index 00000000..ef4913b9 --- /dev/null +++ b/bridgev2/matrix/intent.go @@ -0,0 +1,178 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "context" + "fmt" + "time" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/crypto/attachment" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +// ASIntent implements the bridge ghost API interface using a real Matrix homeserver as the backend. +type ASIntent struct { + Matrix *appservice.IntentAPI + Connector *Connector +} + +var _ bridgev2.MatrixAPI = (*ASIntent)(nil) + +func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) { + // TODO remove this once hungryserv and synapse support sending m.room.redactions directly in all room versions + if eventType == event.EventRedaction { + parsedContent := content.Parsed.(*event.RedactionEventContent) + return as.Matrix.RedactEvent(ctx, roomID, parsedContent.Redacts, mautrix.ReqRedact{ + Reason: parsedContent.Reason, + Extra: content.Raw, + }) + } + if eventType != event.EventReaction && eventType != event.EventRedaction { + if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { + return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) + } else if encrypted { + err = as.Connector.Crypto.Encrypt(ctx, roomID, eventType, content) + if err != nil { + return nil, err + } + eventType = event.EventEncrypted + } + } + if ts.IsZero() { + return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content) + } else { + return as.Matrix.SendMassagedMessageEvent(ctx, roomID, eventType, content, ts.UnixMilli()) + } +} + +func (as *ASIntent) SendState(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) { + if ts.IsZero() { + return as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content) + } else { + return as.Matrix.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, content, ts.UnixMilli()) + } +} + +func (as *ASIntent) DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error) { + if file != nil { + uri = file.URL + } + parsedURI, err := uri.Parse() + if err != nil { + return nil, err + } + data, err := as.Matrix.DownloadBytes(ctx, parsedURI) + if err != nil { + return nil, err + } + if file != nil { + err = file.DecryptInPlace(data) + if err != nil { + return nil, err + } + } + return data, nil +} + +func (as *ASIntent) UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) { + if roomID != "" { + var encrypted bool + if encrypted, err = as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { + err = fmt.Errorf("failed to check if room is encrypted: %w", err) + return + } else if encrypted { + file = &event.EncryptedFileInfo{ + EncryptedFile: *attachment.NewEncryptedFile(), + } + file.EncryptInPlace(data) + mimeType = "application/octet-stream" + fileName = "" + } + } + req := mautrix.ReqUploadMedia{ + ContentBytes: data, + ContentType: mimeType, + FileName: fileName, + } + if as.Connector.AsyncUploads { + var resp *mautrix.RespCreateMXC + resp, err = as.Matrix.UploadAsync(ctx, req) + if resp != nil { + url = resp.ContentURI.CUString() + } + } else { + var resp *mautrix.RespMediaUpload + resp, err = as.Matrix.UploadMedia(ctx, req) + if resp != nil { + url = resp.ContentURI.CUString() + } + } + if file != nil { + file.URL = url + url = "" + } + return +} + +func (as *ASIntent) SetDisplayName(ctx context.Context, name string) error { + return as.Matrix.SetDisplayName(ctx, name) +} + +func (as *ASIntent) SetAvatarURL(ctx context.Context, avatarURL id.ContentURIString) error { + parsedAvatarURL, err := avatarURL.Parse() + if err != nil { + return err + } + return as.Matrix.SetAvatarURL(ctx, parsedAvatarURL) +} + +func (as *ASIntent) GetMXID() id.UserID { + return as.Matrix.UserID +} + +func (as *ASIntent) InviteUser(ctx context.Context, roomID id.RoomID, userID id.UserID) error { + _, err := as.Matrix.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{ + Reason: "", + UserID: userID, + }) + return err +} + +func (as *ASIntent) EnsureJoined(ctx context.Context, roomID id.RoomID) error { + return as.Matrix.EnsureJoined(ctx, roomID) +} + +func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) { + if as.Connector.Config.Encryption.Default { + content := &event.EncryptionEventContent{Algorithm: id.AlgorithmMegolmV1} + if rot := as.Connector.Config.Encryption.Rotation; rot.EnableCustom { + content.RotationPeriodMillis = rot.Milliseconds + content.RotationPeriodMessages = rot.Messages + } + req.InitialState = append(req.InitialState, &event.Event{ + Type: event.StateEncryption, + Content: event.Content{ + Parsed: content, + }, + }) + } + resp, err := as.Matrix.CreateRoom(ctx, req) + if err != nil { + return "", err + } + return resp.RoomID, nil +} + +func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID) error { + // TODO implement non-beeper delete + return as.Matrix.BeeperDeleteRoom(ctx, roomID) +} diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go new file mode 100644 index 00000000..7dfeacf0 --- /dev/null +++ b/bridgev2/matrix/matrix.go @@ -0,0 +1,175 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "context" + "errors" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func (br *Connector) handleRoomEvent(ctx context.Context, evt *event.Event) { + if br.shouldIgnoreEvent(evt) { + return + } + if (evt.Type == event.EventMessage || evt.Type == event.EventSticker) && !evt.Mautrix.WasEncrypted && br.Config.Encryption.Require { + zerolog.Ctx(ctx).Warn().Msg("Dropping unencrypted event as encryption is configured to be required") + // TODO send metrics + return + } + br.Bridge.QueueMatrixEvent(ctx, evt) +} + +func (br *Connector) handleEphemeralEvent(ctx context.Context, evt *event.Event) { + if br.shouldIgnoreEvent(evt) { + return + } + br.Bridge.QueueMatrixEvent(ctx, evt) +} + +func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) { + if br.shouldIgnoreEvent(evt) { + return + } + content := evt.Content.AsEncrypted() + log := zerolog.Ctx(ctx).With(). + Str("event_id", evt.ID.String()). + Str("session_id", content.SessionID.String()). + Logger() + ctx = log.WithContext(ctx) + if br.Crypto == nil { + // TODO send metrics + log.Error().Msg("Can't decrypt message: no crypto") + return + } + log.Debug().Msg("Decrypting received event") + + decryptionStart := time.Now() + decrypted, err := br.Crypto.Decrypt(ctx, evt) + decryptionRetryCount := 0 + if errors.Is(err, NoSessionFound) { + decryptionRetryCount = 1 + log.Debug(). + Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). + Msg("Couldn't find session, waiting for keys to arrive...") + // TODO send metrics + if br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { + log.Debug().Msg("Got keys after waiting, trying to decrypt event again") + decrypted, err = br.Crypto.Decrypt(ctx, evt) + } else { + go br.waitLongerForSession(ctx, evt, decryptionStart) + return + } + } + if err != nil { + log.Warn().Err(err).Msg("Failed to decrypt event") + // TODO send metrics + return + } + br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, "", time.Since(decryptionStart)) +} + +func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time) { + log := zerolog.Ctx(ctx) + content := evt.Content.AsEncrypted() + log.Debug(). + Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())). + Msg("Couldn't find session, requesting keys and waiting longer...") + + go br.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) + var errorEventID id.EventID + // TODO send metrics + + if !br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { + log.Debug().Msg("Didn't get session, giving up trying to decrypt event") + // TODO send metrics + return + } + + log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again") + decrypted, err := br.Crypto.Decrypt(ctx, evt) + if err != nil { + log.Error().Err(err).Msg("Failed to decrypt event") + // TODO send metrics + return + } + + br.postDecrypt(ctx, evt, decrypted, 2, errorEventID, time.Since(decryptionStart)) +} + +type CommandProcessor interface { + Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user bridgev2.User, message string, replyTo id.EventID) +} + +func (br *Connector) sendBridgeCheckpoint(_ context.Context, evt *event.Event) { + if !evt.Mautrix.CheckpointSent { + //go br.SendMessageSuccessCheckpoint(evt, status.MsgStepBridge, 0) + } +} + +func (br *Connector) shouldIgnoreEvent(evt *event.Event) bool { + if evt.Sender == br.Bot.UserID { + return true + } + _, isGhost := br.ParseGhostMXID(evt.Sender) + if isGhost { + return true + } + // TODO exclude double puppeted events + return false +} + +const initialSessionWaitTimeout = 3 * time.Second +const extendedSessionWaitTimeout = 22 * time.Second + +func copySomeKeys(original, decrypted *event.Event) { + isScheduled, _ := original.Content.Raw["com.beeper.scheduled"].(bool) + _, alreadyExists := decrypted.Content.Raw["com.beeper.scheduled"] + if isScheduled && !alreadyExists { + decrypted.Content.Raw["com.beeper.scheduled"] = true + } +} + +func (br *Connector) postDecrypt(ctx context.Context, original, decrypted *event.Event, retryCount int, errorEventID id.EventID, duration time.Duration) { + log := zerolog.Ctx(ctx) + minLevel := br.Config.Encryption.VerificationLevels.Send + if decrypted.Mautrix.TrustState < minLevel { + logEvt := log.Warn(). + Str("user_id", decrypted.Sender.String()). + Bool("forwarded_keys", decrypted.Mautrix.ForwardedKeys). + Stringer("device_trust", decrypted.Mautrix.TrustState). + Stringer("min_trust", minLevel) + if decrypted.Mautrix.TrustSource != nil { + dev := decrypted.Mautrix.TrustSource + logEvt. + Str("device_id", dev.DeviceID.String()). + Str("device_signing_key", dev.SigningKey.String()) + } else { + logEvt.Str("device_id", "unknown") + } + logEvt.Msg("Dropping event due to insufficient verification level") + //err := deviceUnverifiedErrorWithExplanation(decrypted.Mautrix.TrustState) + //go mx.sendCryptoStatusError(ctx, decrypted, errorEventID, err, retryCount, true) + return + } + copySomeKeys(original, decrypted) + + // TODO checkpoint + decrypted.Mautrix.CheckpointSent = true + decrypted.Mautrix.DecryptionDuration = duration + decrypted.Mautrix.EventSource |= event.SourceDecrypted + br.EventProcessor.Dispatch(ctx, decrypted) + if errorEventID != "" { + _, _ = br.Bot.RedactEvent(ctx, decrypted.RoomID, errorEventID) + } +} diff --git a/bridgev2/matrix/no-crypto.go b/bridgev2/matrix/no-crypto.go new file mode 100644 index 00000000..5b05272c --- /dev/null +++ b/bridgev2/matrix/no-crypto.go @@ -0,0 +1,28 @@ +// Copyright (c) 2023 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build !cgo || nocrypto + +package matrix + +import ( + "errors" + + "maunium.net/go/mautrix/bridge" +) + +func NewCryptoHelper(bridge *bridge.Bridge) bridge.Crypto { + if bridge.Config.Bridge.GetEncryptionConfig().Allow { + bridge.ZLog.Warn().Msg("Bridge built without end-to-bridge encryption, but encryption is enabled in config") + } else { + bridge.ZLog.Debug().Msg("Bridge built without end-to-bridge encryption") + } + return nil +} + +var NoSessionFound = errors.New("nil") +var UnknownMessageIndex = NoSessionFound +var DuplicateMessageIndex = NoSessionFound diff --git a/bridgev2/matrix/websocket.go.dis b/bridgev2/matrix/websocket.go.dis new file mode 100644 index 00000000..cf4b0517 --- /dev/null +++ b/bridgev2/matrix/websocket.go.dis @@ -0,0 +1,163 @@ +package matrix + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "go.mau.fi/util/jsontime" + + "maunium.net/go/mautrix/appservice" +) + +const defaultReconnectBackoff = 2 * time.Second +const maxReconnectBackoff = 2 * time.Minute +const reconnectBackoffReset = 5 * time.Minute + +func (br *Connector) startWebsocket(wg *sync.WaitGroup) { + log := br.Log.With().Str("action", "appservice websocket").Logger() + var wgOnce sync.Once + onConnect := func() { + wssBr, ok := br.Child.(WebsocketStartingBridge) + if ok { + wssBr.OnWebsocketConnect() + } + if br.latestState != nil { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + br.latestState.Timestamp = jsontime.UnixNow() + err := br.SendBridgeState(ctx, br.latestState) + if err != nil { + log.Err(err).Msg("Failed to resend latest bridge state after websocket reconnect") + } else { + log.Debug().Any("bridge_state", br.latestState).Msg("Resent bridge state after websocket reconnect") + } + }() + } + wgOnce.Do(wg.Done) + select { + case br.wsStarted <- struct{}{}: + default: + } + } + reconnectBackoff := defaultReconnectBackoff + lastDisconnect := time.Now().UnixNano() + br.wsStopped = make(chan struct{}) + defer func() { + log.Debug().Msg("Appservice websocket loop finished") + close(br.wsStopped) + }() + addr := br.Config.Homeserver.WSProxy + if addr == "" { + addr = br.Config.Homeserver.Address + } + for { + err := br.AS.StartWebsocket(addr, onConnect) + if errors.Is(err, appservice.ErrWebsocketManualStop) { + return + } else if closeCommand := (&appservice.CloseCommand{}); errors.As(err, &closeCommand) && closeCommand.Status == appservice.MeowConnectionReplaced { + log.Info().Msg("Appservice websocket closed by another instance of the bridge, shutting down...") + br.ManualStop(0) + return + } else if err != nil { + log.Err(err).Msg("Error in appservice websocket") + } + if br.Stopping { + return + } + now := time.Now().UnixNano() + if lastDisconnect+reconnectBackoffReset.Nanoseconds() < now { + reconnectBackoff = defaultReconnectBackoff + } else { + reconnectBackoff *= 2 + if reconnectBackoff > maxReconnectBackoff { + reconnectBackoff = maxReconnectBackoff + } + } + lastDisconnect = now + log.Info(). + Int("backoff_seconds", int(reconnectBackoff.Seconds())). + Msg("Websocket disconnected, reconnecting...") + select { + case <-br.wsShortCircuitReconnectBackoff: + log.Debug().Msg("Reconnect backoff was short-circuited") + case <-time.After(reconnectBackoff): + } + if br.Stopping { + return + } + } +} + +type wsPingData struct { + Timestamp int64 `json:"timestamp"` +} + +func (br *Connector) PingServer() (start, serverTs, end time.Time) { + if !br.Websocket { + panic(fmt.Errorf("PingServer called without websocket enabled")) + } + if !br.AS.HasWebsocket() { + br.Log.Debug().Msg("Received server ping request, but no websocket connected. Trying to short-circuit backoff sleep") + select { + case br.wsShortCircuitReconnectBackoff <- struct{}{}: + default: + br.Log.Warn().Msg("Failed to ping websocket: not connected and no backoff?") + return + } + select { + case <-br.wsStarted: + case <-time.After(15 * time.Second): + if !br.AS.HasWebsocket() { + br.Log.Warn().Msg("Failed to ping websocket: didn't connect after 15 seconds of waiting") + return + } + } + } + start = time.Now() + var resp wsPingData + br.Log.Debug().Msg("Pinging appservice websocket") + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + err := br.AS.RequestWebsocket(ctx, &appservice.WebsocketRequest{ + Command: "ping", + Data: &wsPingData{Timestamp: start.UnixMilli()}, + }, &resp) + end = time.Now() + if err != nil { + br.Log.Warn().Err(err).Dur("duration", end.Sub(start)).Msg("Websocket ping returned error") + br.AS.StopWebsocket(fmt.Errorf("websocket ping returned error in %s: %w", end.Sub(start), err)) + } else { + serverTs = time.Unix(0, resp.Timestamp*int64(time.Millisecond)) + br.Log.Debug(). + Dur("duration", end.Sub(start)). + Dur("req_duration", serverTs.Sub(start)). + Dur("resp_duration", end.Sub(serverTs)). + Msg("Websocket ping returned success") + } + return +} + +func (br *Connector) websocketServerPinger() { + interval := time.Duration(br.Config.Homeserver.WSPingInterval) * time.Second + clock := time.NewTicker(interval) + defer func() { + br.Log.Info().Msg("Stopping websocket pinger") + clock.Stop() + }() + br.Log.Info().Dur("interval_duration", interval).Msg("Starting websocket pinger") + for { + select { + case <-clock.C: + br.PingServer() + case <-br.wsStopPinger: + return + } + if br.Stopping { + return + } + } +} diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go new file mode 100644 index 00000000..e8e23628 --- /dev/null +++ b/bridgev2/matrixinterface.go @@ -0,0 +1,52 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "time" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type MatrixConnector interface { + Init(*Bridge) + Start(ctx context.Context) error + + ParseGhostMXID(userID id.UserID) (networkid.UserID, bool) + FormatGhostMXID(userID networkid.UserID) id.UserID + + GhostIntent(userID id.UserID) MatrixAPI + UserIntent(user *User) MatrixAPI + BotIntent() MatrixAPI + + SendMessageStatus(ctx context.Context, status MessageStatus) + + GetMemberInfo(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) + + ServerName() string +} + +type MatrixAPI interface { + GetMXID() id.UserID + + SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) + SendState(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) + DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error) + UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) + + SetDisplayName(ctx context.Context, name string) error + SetAvatarURL(ctx context.Context, avatarURL id.ContentURIString) error + + CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) + DeleteRoom(ctx context.Context, roomID id.RoomID) error + InviteUser(ctx context.Context, roomID id.RoomID, userID id.UserID) error + EnsureJoined(ctx context.Context, roomID id.RoomID) error +} diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go new file mode 100644 index 00000000..c7d8d0b4 --- /dev/null +++ b/bridgev2/messagestatus.go @@ -0,0 +1,86 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "maunium.net/go/mautrix/bridge/status" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type MessageStatus struct { + RoomID id.RoomID + EventID id.EventID + Status event.MessageStatus + ErrorReason event.MessageStatusReason + DeliveredTo []id.UserID + Error error // Internal error to be tracked in message checkpoints + Message string // Human-readable message shown to users +} + +func (ms *MessageStatus) CheckpointStatus() status.MessageCheckpointStatus { + switch ms.Status { + case event.MessageStatusSuccess: + if ms.DeliveredTo != nil { + return status.MsgStatusDelivered + } + return status.MsgStatusSuccess + case event.MessageStatusPending: + return status.MsgStatusWillRetry + case event.MessageStatusRetriable, event.MessageStatusFail: + switch ms.ErrorReason { + case event.MessageStatusTooOld: + return status.MsgStatusTimeout + case event.MessageStatusUnsupported: + return status.MsgStatusUnsupported + default: + return status.MsgStatusPermFailure + } + default: + return "UNKNOWN" + } +} + +func (ms *MessageStatus) ToCheckpoint() *status.MessageCheckpoint { + checkpoint := &status.MessageCheckpoint{ + RoomID: ms.RoomID, + EventID: ms.EventID, + Step: status.MsgStepRemote, + Status: ms.CheckpointStatus(), + ReportedBy: status.MsgReportedByBridge, + } + if ms.Error != nil { + checkpoint.Info = ms.Error.Error() + } else if ms.Message != "" { + checkpoint.Info = ms.Message + } + return checkpoint +} + +func (ms *MessageStatus) ToEvent() *event.BeeperMessageStatusEventContent { + content := &event.BeeperMessageStatusEventContent{ + RelatesTo: event.RelatesTo{ + Type: event.RelAnnotation, + EventID: ms.EventID, + }, + Status: ms.Status, + Reason: ms.ErrorReason, + Message: ms.Message, + } + if ms.Error != nil { + content.InternalError = ms.Error.Error() + } + if ms.DeliveredTo != nil { + content.DeliveredToUsers = &ms.DeliveredTo + } + return content +} + +func (ms *MessageStatus) ErrorAsMessage() *MessageStatus { + ms.Message = ms.Error.Error() + return ms +} diff --git a/bridgev2/networkid/bridgeid.go b/bridgev2/networkid/bridgeid.go new file mode 100644 index 00000000..75406df3 --- /dev/null +++ b/bridgev2/networkid/bridgeid.go @@ -0,0 +1,59 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package networkid + +// BridgeID is an opaque identifier for a bridge +type BridgeID string + +// PortalID is the ID of a room on the remote network. +// +// Portal IDs must be globally unique and refer to a single chat. +// This means that user IDs can't be used directly as DM chat IDs, instead the ID must contain both user IDs (e.g. "user1-user2"). +// If generating such IDs manually, sorting the users is recommended to ensure they're consistent. +type PortalID string + +// UserID is the ID of a user on the remote network. +type UserID string + +// UserLoginID is the ID of the user being controlled on the remote network. It may be the same shape as [UserID]. +type UserLoginID string + +// MessageID is the ID of a message on the remote network. +// +// Message IDs must be unique across rooms and consistent across users. +type MessageID string + +// PartID is the ID of a message part on the remote network (e.g. index of image in album). +// +// Part IDs are only unique within a message, not globally. +// To refer to a specific message part globally, use the MessagePartID tuple struct. +type PartID string + +// MessagePartID refers to a specific part of a message by combining a message ID and a part ID. +type MessagePartID struct { + MessageID MessageID + PartID PartID +} + +// MessageOptionalPartID refers to a specific part of a message by combining a message ID and an optional part ID. +// If the part ID is not set, this should refer to the first part ID sorted alphabetically. +type MessageOptionalPartID struct { + MessageID MessageID + PartID *PartID +} + +// AvatarID is the ID of a user or room avatar on the remote network. +// +// It may be a real URL, an opaque identifier, or anything in between. +type AvatarID string + +// EmojiID is the ID of a reaction emoji on the remote network. +// +// On networks that only allow one reaction per message, an empty string should be used +// to apply the unique constraints in the database appropriately. +// On networks that allow multiple emojis, this is the unicode emoji or a network-specific shortcode. +type EmojiID string diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go new file mode 100644 index 00000000..b70fd89c --- /dev/null +++ b/bridgev2/networkinterface.go @@ -0,0 +1,217 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" +) + +type ConvertedMessagePart struct { + ID networkid.PartID + Type event.Type + Content *event.MessageEventContent + Extra map[string]any + DBMetadata map[string]any +} + +type EventSender struct { + IsFromMe bool + SenderLogin networkid.UserLoginID + Sender networkid.UserID +} + +type ConvertedMessage struct { + ID networkid.MessageID + EventSender + Timestamp time.Time + ReplyTo *networkid.MessageOptionalPartID + ThreadRoot *networkid.MessageOptionalPartID + Parts []*ConvertedMessagePart + // For edits, set this field to skip editing the event + Unchanged bool +} + +type NetworkConnector interface { + Init(*Bridge) + Start(context.Context) error + PrepareLogin(ctx context.Context, login *UserLogin) error +} + +type NetworkAPI interface { + Connect(ctx context.Context) error + IsLoggedIn() bool + GetChatInfo(ctx context.Context, portal *Portal) (*PortalInfo, error) + + HandleMatrixMessage(ctx context.Context, msg *MatrixMessage) (message *database.Message, err error) + HandleMatrixEdit(ctx context.Context, msg *MatrixEdit) error + HandleMatrixReaction(ctx context.Context, msg *MatrixReaction) (emojiID networkid.EmojiID, err error) + HandleMatrixReactionRemove(ctx context.Context, msg *MatrixReactionRemove) error + HandleMatrixMessageRemove(ctx context.Context, msg *MatrixMessageRemove) error +} + +type RemoteEventType int + +const ( + RemoteEventMessage RemoteEventType = iota + RemoteEventEdit + RemoteEventReaction + RemoteEventReactionRemove + RemoteEventMessageRemove +) + +type RemoteEvent interface { + GetType() RemoteEventType + GetPortalID() networkid.PortalID + ShouldCreatePortal() bool + AddLogContext(c zerolog.Context) zerolog.Context +} + +type RemoteMessage interface { + RemoteEvent + ConvertMessage(ctx context.Context, portal *Portal) (*ConvertedMessage, error) +} + +type RemoteEdit interface { + RemoteEvent + GetTargetMessage() networkid.MessageID + ConvertEdit(ctx context.Context, portal *Portal, existing []*database.Message) (*ConvertedMessage, error) +} + +type RemoteReaction interface { + RemoteEvent + GetSender() EventSender + GetTargetMessage() networkid.MessageID + GetReactionEmoji() (string, networkid.EmojiID) +} + +type RemoteReactionRemove interface { + RemoteEvent + GetSender() EventSender + GetTargetMessage() networkid.MessageID + GetRemovedEmojiID() networkid.EmojiID +} + +type RemoteMessageRemove interface { + RemoteEvent + GetSender() EventSender + GetTargetMessage() networkid.MessageID +} + +// SimpleRemoteEvent is a simple implementation of RemoteEvent that can be used with struct fields and some callbacks. +type SimpleRemoteEvent[T any] struct { + Type RemoteEventType + LogContext func(c zerolog.Context) zerolog.Context + PortalID networkid.PortalID + Data T + CreatePortal bool + + Sender EventSender + TargetMessage networkid.MessageID + EmojiID networkid.EmojiID + Emoji string + + ConvertMessageFunc func(ctx context.Context, portal *Portal, data T) (*ConvertedMessage, error) + ConvertEditFunc func(ctx context.Context, portal *Portal, existing []*database.Message, data T) (*ConvertedMessage, error) +} + +var ( + _ RemoteMessage = (*SimpleRemoteEvent[any])(nil) + _ RemoteEdit = (*SimpleRemoteEvent[any])(nil) + _ RemoteReaction = (*SimpleRemoteEvent[any])(nil) + _ RemoteReactionRemove = (*SimpleRemoteEvent[any])(nil) + _ RemoteMessageRemove = (*SimpleRemoteEvent[any])(nil) +) + +func (sre *SimpleRemoteEvent[T]) AddLogContext(c zerolog.Context) zerolog.Context { + return sre.LogContext(c) +} + +func (sre *SimpleRemoteEvent[T]) GetPortalID() networkid.PortalID { + return sre.PortalID +} + +func (sre *SimpleRemoteEvent[T]) ConvertMessage(ctx context.Context, portal *Portal) (*ConvertedMessage, error) { + return sre.ConvertMessageFunc(ctx, portal, sre.Data) +} + +func (sre *SimpleRemoteEvent[T]) ConvertEdit(ctx context.Context, portal *Portal, existing []*database.Message) (*ConvertedMessage, error) { + return sre.ConvertEditFunc(ctx, portal, existing, sre.Data) +} + +func (sre *SimpleRemoteEvent[T]) GetSender() EventSender { + return sre.Sender +} + +func (sre *SimpleRemoteEvent[T]) GetTargetMessage() networkid.MessageID { + return sre.TargetMessage +} + +func (sre *SimpleRemoteEvent[T]) GetReactionEmoji() (string, networkid.EmojiID) { + return sre.Emoji, sre.EmojiID +} + +func (sre *SimpleRemoteEvent[T]) GetRemovedEmojiID() networkid.EmojiID { + return sre.EmojiID +} + +func (sre *SimpleRemoteEvent[T]) GetType() RemoteEventType { + return sre.Type +} + +func (sre *SimpleRemoteEvent[T]) ShouldCreatePortal() bool { + return sre.CreatePortal +} + +type OrigSender struct { + User *User + event.MemberEventContent +} + +type MatrixEventBase[ContentType any] struct { + // The raw event being bridged. + Event *event.Event + // The parsed content struct of the event. Custom fields can be found in Event.Content.Raw. + Content ContentType + // The room where the event happened. + Portal *Portal + + // The original sender user ID. Only present in case the event is being relayed (and Sender is not the same user). + OrigSender *OrigSender +} + +type MatrixMessage struct { + MatrixEventBase[*event.MessageEventContent] + ThreadRoot *database.Message + ReplyTo *database.Message +} + +type MatrixEdit struct { + MatrixEventBase[*event.MessageEventContent] + EditTarget *database.Message +} + +type MatrixReaction struct { + MatrixEventBase[*event.ReactionEventContent] + TargetMessage *database.Message +} + +type MatrixReactionRemove struct { + MatrixEventBase[*event.RedactionEventContent] + TargetReaction *database.Reaction +} + +type MatrixMessageRemove struct { + MatrixEventBase[*event.RedactionEventContent] + TargetMessage *database.Message +} diff --git a/bridgev2/portal.go b/bridgev2/portal.go new file mode 100644 index 00000000..072d3971 --- /dev/null +++ b/bridgev2/portal.go @@ -0,0 +1,733 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/exslices" + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type portalMatrixEvent struct { + evt *event.Event + sender *User +} + +type portalRemoteEvent struct { + evt RemoteEvent + source *UserLogin +} + +func (pme *portalMatrixEvent) isPortalEvent() {} +func (pre *portalRemoteEvent) isPortalEvent() {} + +type portalEvent interface { + isPortalEvent() +} + +type Portal struct { + *database.Portal + Bridge *Bridge + Log zerolog.Logger + Parent *Portal + Relay *UserLogin + + currentlyTyping []id.UserID + currentlyTypingLock sync.Mutex + + roomCreateLock sync.Mutex + + events chan portalEvent +} + +const PortalEventBuffer = 64 + +func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, queryErr error, id *networkid.PortalID) (*Portal, error) { + if queryErr != nil { + return nil, fmt.Errorf("failed to query db: %w", queryErr) + } + if dbPortal == nil { + if id == nil { + return nil, nil + } + dbPortal = &database.Portal{ + BridgeID: br.ID, + ID: *id, + } + err := br.DB.Portal.Insert(ctx, dbPortal) + if err != nil { + return nil, fmt.Errorf("failed to insert new portal: %w", err) + } + } + portal := &Portal{ + Portal: dbPortal, + Bridge: br, + + events: make(chan portalEvent, PortalEventBuffer), + } + br.portalsByID[portal.ID] = portal + if portal.MXID != "" { + br.portalsByMXID[portal.MXID] = portal + } + if portal.ParentID != "" { + var err error + portal.Parent, err = br.unlockedGetPortalByID(ctx, portal.ParentID, false) + if err != nil { + return nil, fmt.Errorf("failed to load parent portal (%s): %w", portal.ParentID, err) + } + } + portal.updateLogger() + go portal.eventLoop() + return portal, nil +} + +func (portal *Portal) updateLogger() { + logWith := portal.Bridge.Log.With().Str("portal_id", string(portal.ID)) + if portal.MXID != "" { + logWith = logWith.Stringer("portal_mxid", portal.MXID) + } + portal.Log = logWith.Logger() +} + +func (br *Bridge) unlockedGetPortalByID(ctx context.Context, id networkid.PortalID, onlyIfExists bool) (*Portal, error) { + cached, ok := br.portalsByID[id] + if ok { + return cached, nil + } + idPtr := &id + if onlyIfExists { + idPtr = nil + } + db, err := br.DB.Portal.GetByID(ctx, id) + return br.loadPortal(ctx, db, err, idPtr) +} + +func (br *Bridge) GetPortalByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + cached, ok := br.portalsByMXID[mxid] + if ok { + return cached, nil + } + db, err := br.DB.Portal.GetByMXID(ctx, mxid) + return br.loadPortal(ctx, db, err, nil) +} + +func (br *Bridge) GetPortalByID(ctx context.Context, id networkid.PortalID) (*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.unlockedGetPortalByID(ctx, id, false) +} + +func (br *Bridge) GetExistingPortalByID(ctx context.Context, id networkid.PortalID) (*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.unlockedGetPortalByID(ctx, id, true) +} + +func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) { + select { + case portal.events <- evt: + default: + zerolog.Ctx(ctx).Error(). + Str("portal_id", string(portal.ID)). + Msg("Portal event channel is full") + } +} + +func (portal *Portal) eventLoop() { + for rawEvt := range portal.events { + switch evt := rawEvt.(type) { + case *portalMatrixEvent: + portal.handleMatrixEvent(evt.sender, evt.evt) + case *portalRemoteEvent: + portal.handleRemoteEvent(evt.source, evt.evt) + default: + panic(fmt.Errorf("illegal type %T in eventLoop", evt)) + } + } +} + +func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User) (*UserLogin, error) { + logins, err := portal.Bridge.DB.User.FindLoginsByPortalID(ctx, user.MXID, portal.ID) + if err != nil { + return nil, err + } + portal.Bridge.cacheLock.Lock() + defer portal.Bridge.cacheLock.Unlock() + for _, loginID := range logins { + login, ok := user.logins[loginID] + if ok && login.Client != nil { + return login, nil + } + } + // Portal has relay, use it + if portal.Relay != nil { + return nil, nil + } + var firstLogin *UserLogin + for _, login := range user.logins { + firstLogin = login + break + } + if firstLogin != nil { + zerolog.Ctx(ctx).Warn(). + Str("chosen_login_id", string(firstLogin.ID)). + Msg("No usable user portal rows found, returning random login") + return firstLogin, nil + } else { + return nil, ErrNotLoggedIn + } +} + +func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { + if evt.Mautrix.EventSource&event.SourceEphemeral != 0 { + switch evt.Type { + case event.EphemeralEventReceipt: + portal.handleMatrixReceipts(evt) + case event.EphemeralEventTyping: + portal.handleMatrixTyping(evt) + } + return + } + log := portal.Log.With(). + Str("action", "handle matrix event"). + Stringer("event_id", evt.ID). + Stringer("sender", sender.MXID). + Logger() + ctx := log.WithContext(context.TODO()) + login, err := portal.FindPreferredLogin(ctx, sender) + if err != nil { + log.Err(err).Msg("Failed to get user login to handle Matrix event") + // TODO send metrics + return + } + var origSender *OrigSender + if login == nil { + login = portal.Relay + origSender = &OrigSender{ + User: sender, + } + + memberInfo, err := portal.Bridge.Matrix.GetMemberInfo(ctx, portal.MXID, sender.MXID) + if err != nil { + log.Warn().Err(err).Msg("Failed to get member info for user being relayed") + } else if memberInfo != nil { + origSender.MemberEventContent = *memberInfo + } + } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("login_id", string(login.ID)) + }) + switch evt.Type { + case event.EventMessage, event.EventSticker: + portal.handleMatrixMessage(ctx, login, origSender, evt) + case event.EventReaction: + if origSender != nil { + log.Debug().Msg("Ignoring reaction event from relayed user") + // TODO send metrics + return + } + portal.handleMatrixReaction(ctx, login, evt) + case event.EventRedaction: + portal.handleMatrixRedaction(ctx, login, origSender, evt) + case event.StateRoomName: + case event.StateTopic: + case event.StateRoomAvatar: + case event.StateEncryption: + } +} + +func (portal *Portal) handleMatrixReceipts(evt *event.Event) { + content, ok := evt.Content.Parsed.(event.ReceiptEventContent) + if !ok { + return + } + ctx := context.TODO() + for evtID, receipts := range content { + readReceipts, ok := receipts[event.ReceiptTypeRead] + if !ok { + continue + } + for userID, receipt := range readReceipts { + sender, err := portal.Bridge.GetUserByMXID(ctx, userID) + if err != nil { + // TODO log + return + } + portal.handleMatrixReadReceipt(ctx, sender, evtID, receipt) + } + } +} + +func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, eventID id.EventID, receipt event.ReadReceipt) { + // TODO send read receipt(s) to network +} + +func (portal *Portal) handleMatrixTyping(evt *event.Event) { + content, ok := evt.Content.Parsed.(*event.TypingEventContent) + if !ok { + return + } + portal.currentlyTypingLock.Lock() + defer portal.currentlyTypingLock.Unlock() + slices.Sort(content.UserIDs) + stoppedTyping, startedTyping := exslices.SortedDiff(portal.currentlyTyping, content.UserIDs, func(a, b id.UserID) int { + return strings.Compare(string(a), string(b)) + }) + for range stoppedTyping { + // TODO send typing stop events + } + for range startedTyping { + // TODO send typing start events + } + portal.currentlyTyping = content.UserIDs +} + +func (portal *Portal) periodicTypingUpdater() { + for { + // TODO make delay configurable by network connector + time.Sleep(5 * time.Second) + portal.currentlyTypingLock.Lock() + if len(portal.currentlyTyping) == 0 { + portal.currentlyTypingLock.Unlock() + continue + } + // TODO send typing events + portal.currentlyTypingLock.Unlock() + } +} + +func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(*event.MessageEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + // TODO send metrics + return + } + if content.RelatesTo.GetReplaceID() != "" { + portal.handleMatrixEdit(ctx, sender, origSender, evt, content) + return + } + + // TODO get capabilities from network connector + threadsSupported := true + repliesSupported := true + var threadRoot, replyTo *database.Message + var err error + if threadsSupported { + threadRootID := content.RelatesTo.GetThreadParent() + if threadRootID != "" { + threadRoot, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, threadRootID) + if err != nil { + log.Err(err).Msg("Failed to get thread root message from database") + } + } + } + if repliesSupported { + var replyToID id.EventID + if threadsSupported { + replyToID = content.RelatesTo.GetNonFallbackReplyTo() + } else { + replyToID = content.RelatesTo.GetReplyTo() + } + if replyToID != "" { + replyTo, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, replyToID) + if err != nil { + log.Err(err).Msg("Failed to get reply target message from database") + } + } + } + + message, err := sender.Client.HandleMatrixMessage(ctx, &MatrixMessage{ + MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{ + Event: evt, + Content: content, + OrigSender: origSender, + Portal: portal, + }, + ThreadRoot: threadRoot, + ReplyTo: replyTo, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix message") + // TODO send metrics here or inside HandleMatrixMessage? + return + } + if message.Metadata == nil { + message.Metadata = make(map[string]any) + } + message.Metadata["sender_mxid"] = evt.Sender + // Hack to ensure the ghost row exists + // TODO move to better place (like login) + portal.Bridge.GetGhostByID(ctx, message.SenderID) + err = portal.Bridge.DB.Message.Insert(ctx, message) + if err != nil { + log.Err(err).Msg("Failed to save message to database") + } + // TODO send success metrics +} + +func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent) { + editTargetID := content.RelatesTo.GetReplaceID() + log := zerolog.Ctx(ctx) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Stringer("edit_target_mxid", editTargetID) + }) + if content.NewContent != nil { + content = content.NewContent + } + editTarget, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, editTargetID) + if err != nil { + log.Err(err).Msg("Failed to get edit target message from database") + // TODO send metrics + return + } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("edit_target_remote_id", string(editTarget.ID)) + }) + err = sender.Client.HandleMatrixEdit(ctx, &MatrixEdit{ + MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{ + Event: evt, + Content: content, + OrigSender: origSender, + Portal: portal, + }, + EditTarget: editTarget, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix edit") + // TODO send metrics here or inside HandleMatrixEdit? + return + } + err = portal.Bridge.DB.Message.Update(ctx, editTarget) + if err != nil { + log.Err(err).Msg("Failed to save message to database after editing") + } + // TODO send success metrics +} + +func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) { + +} + +func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { + +} + +func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { + log := portal.Log.With(). + Str("source_id", string(source.ID)). + Str("action", "handle remote event"). + Logger() + log.UpdateContext(evt.AddLogContext) + ctx := log.WithContext(context.TODO()) + if portal.MXID == "" { + if !evt.ShouldCreatePortal() { + return + } + err := portal.CreateMatrixRoom(ctx, source) + if err != nil { + log.Err(err).Msg("Failed to create portal to handle event") + // TODO error + return + } + } + switch evt.GetType() { + case RemoteEventMessage: + portal.handleRemoteMessage(ctx, source, evt.(RemoteMessage)) + case RemoteEventEdit: + portal.handleRemoteEdit(ctx, source, evt.(RemoteEdit)) + case RemoteEventReaction: + portal.handleRemoteReaction(ctx, source, evt.(RemoteReaction)) + case RemoteEventReactionRemove: + portal.handleRemoteReactionRemove(ctx, source, evt.(RemoteReactionRemove)) + case RemoteEventMessageRemove: + portal.handleRemoteMessageRemove(ctx, source, evt.(RemoteMessageRemove)) + } +} + +func (portal *Portal) getIntentFor(ctx context.Context, sender EventSender, source *UserLogin) MatrixAPI { + var intent MatrixAPI + if sender.IsFromMe { + intent = portal.Bridge.Matrix.UserIntent(source.User) + } + if intent == nil && sender.SenderLogin != "" { + senderLogin := portal.Bridge.GetCachedUserLoginByID(sender.SenderLogin) + if senderLogin != nil { + intent = portal.Bridge.Matrix.UserIntent(senderLogin.User) + } + } + if intent == nil { + ghost, err := portal.Bridge.GetGhostByID(ctx, sender.Sender) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for message sender") + return nil + } + // TODO update ghost info + intent = ghost.Intent + } + return intent +} + +func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) { + log := zerolog.Ctx(ctx) + converted, err := evt.ConvertMessage(ctx, portal) + if err != nil { + // TODO log and notify room? + return + } + var relatesToRowID int64 + var replyTo, threadRoot, prevThreadEvent *database.Message + if converted.ReplyTo != nil { + replyTo, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, *converted.ReplyTo) + if err != nil { + log.Err(err).Msg("Failed to get reply target message from database") + } else { + relatesToRowID = replyTo.RowID + } + } + if converted.ThreadRoot != nil { + threadRoot, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, *converted.ThreadRoot) + if err != nil { + log.Err(err).Msg("Failed to get thread root message from database") + } else { + relatesToRowID = threadRoot.RowID + } + // TODO thread roots need to be saved in the database in a way that allows fetching + // the first bridged thread message even if the original one isn't bridged + + // TODO 2 fetch last event in thread properly + prevThreadEvent = threadRoot + } + intent := portal.getIntentFor(ctx, converted.EventSender, source) + if intent == nil { + return + } + for _, part := range converted.Parts { + if threadRoot != nil && prevThreadEvent != nil { + part.Content.GetRelatesTo().SetThread(threadRoot.MXID, prevThreadEvent.MXID) + } + if replyTo != nil { + part.Content.GetRelatesTo().SetReplyTo(replyTo.MXID) + if part.Content.Mentions == nil { + part.Content.Mentions = &event.Mentions{} + } + replyTargetSenderMXID, ok := replyTo.Metadata["sender_mxid"].(string) + if ok && !slices.Contains(part.Content.Mentions.UserIDs, id.UserID(replyTargetSenderMXID)) { + part.Content.Mentions.UserIDs = append(part.Content.Mentions.UserIDs, id.UserID(replyTargetSenderMXID)) + } + } + resp, err := intent.SendMessage(ctx, portal.MXID, part.Type, &event.Content{ + Parsed: part.Content, + Raw: part.Extra, + }, converted.Timestamp) + if err != nil { + log.Err(err).Str("part_id", string(part.ID)).Msg("Failed to send message part to Matrix") + continue + } + if part.DBMetadata == nil { + part.DBMetadata = make(map[string]any) + } + // TODO make metadata fields less hacky + part.DBMetadata["sender_mxid"] = intent.GetMXID() + dbMessage := &database.Message{ + ID: converted.ID, + PartID: part.ID, + MXID: resp.EventID, + RoomID: portal.ID, + SenderID: converted.Sender, + Timestamp: converted.Timestamp, + RelatesToRowID: relatesToRowID, + Metadata: part.DBMetadata, + } + err = portal.Bridge.DB.Message.Insert(ctx, dbMessage) + if err != nil { + log.Err(err).Str("part_id", string(part.ID)).Msg("Failed to save message part to database") + } + if prevThreadEvent != nil { + prevThreadEvent = dbMessage + } + } +} + +func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) { + +} + +func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) { + +} + +func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) { + +} + +func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) { + +} + +var stateElementFunctionalMembers = event.Type{Class: event.StateEventType, Type: "io.element.functional_members"} + +type PortalInfo struct { + Name string + Topic string + AvatarID networkid.AvatarID + AvatarMXC id.ContentURIString + + Members []networkid.UserID + + IsDirectChat bool + IsSpace bool +} + +func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) error { + portal.roomCreateLock.Lock() + defer portal.roomCreateLock.Unlock() + if portal.MXID != "" { + return nil + } + log := zerolog.Ctx(ctx).With(). + Str("action", "create matrix room"). + Logger() + ctx = log.WithContext(ctx) + log.Info().Msg("Creating Matrix room") + + info, err := source.Client.GetChatInfo(ctx, portal) + if err != nil { + log.Err(err).Msg("Failed to update portal info for creation") + return err + } + portal.Name = info.Name + portal.Topic = info.Topic + portal.AvatarID = info.AvatarID + portal.AvatarMXC = info.AvatarMXC + invite := make([]id.UserID, 0, len(info.Members)+1) + inviteIntents := make([]MatrixAPI, 0, len(info.Members)+1) + for _, memberID := range info.Members { + ghost, err := portal.Bridge.GetGhostByID(ctx, memberID) + if err != nil { + log.Err(err).Str("memebr_id", string(memberID)).Msg("Failed to get portal member ghost") + } else { + invite = append(invite, ghost.MXID) + inviteIntents = append(inviteIntents, ghost.Intent) + } + } + // TODO should the source user mxid come from members? + invite = append(invite, source.UserMXID) + inviteIntents = append(inviteIntents, portal.Bridge.Matrix.UserIntent(source.User)) + + req := mautrix.ReqCreateRoom{ + Visibility: "private", + Name: portal.Name, + Topic: portal.Topic, + CreationContent: make(map[string]any), + InitialState: make([]*event.Event, 0, 4), + Preset: "private_chat", + IsDirect: info.IsDirectChat, + PowerLevelOverride: &event.PowerLevelsEventContent{ + Users: map[id.UserID]int{ + portal.Bridge.Bot.GetMXID(): 9001, + }, + }, + BeeperLocalRoomID: id.RoomID(fmt.Sprintf("!%s:%s", portal.ID, portal.Bridge.Matrix.ServerName())), + BeeperInitialMembers: invite, + } + // TODO find this properly from the matrix connector + isBeeper := true + // TODO remove this after initial_members is supported in hungryserv + if isBeeper { + req.BeeperAutoJoinInvites = true + req.Invite = invite + } + if info.IsSpace { + req.CreationContent["type"] = event.RoomTypeSpace + } + emptyString := "" + req.InitialState = append(req.InitialState, &event.Event{ + StateKey: &emptyString, + Type: stateElementFunctionalMembers, + Content: event.Content{Raw: map[string]any{ + "service_members": []id.UserID{portal.Bridge.Bot.GetMXID()}, + }}, + }) + if req.Topic == "" { + // Add explicit topic event if topic is empty to ensure the event is set. + // This ensures that there won't be an extra event later if PUT /state/... is called. + req.InitialState = append(req.InitialState, &event.Event{ + StateKey: &emptyString, + Type: event.StateTopic, + Content: event.Content{Parsed: &event.TopicEventContent{Topic: ""}}, + }) + } + if portal.AvatarMXC != "" { + req.InitialState = append(req.InitialState, &event.Event{ + StateKey: &emptyString, + Type: event.StateRoomAvatar, + // TODO change RoomAvatarEventContent to have id.ContentURIString instead of id.ContentURI? + Content: event.Content{Raw: map[string]any{"url": portal.AvatarMXC}}, + }) + } + if portal.Parent != nil { + // TODO create parent portal if it doesn't exist? + req.InitialState = append(req.InitialState, &event.Event{ + StateKey: (*string)(&portal.Parent.MXID), + Type: event.StateSpaceParent, + Content: event.Content{Parsed: &event.SpaceParentEventContent{ + Via: []string{portal.Bridge.Matrix.ServerName()}, + Canonical: true, + }}, + }) + } + roomID, err := portal.Bridge.Bot.CreateRoom(ctx, &req) + if err != nil { + log.Err(err).Msg("Failed to create Matrix room") + return err + } + log.Info().Stringer("room_id", roomID).Msg("Matrix room created") + portal.AvatarSet = true + portal.TopicSet = true + portal.NameSet = true + portal.MXID = roomID + portal.Bridge.cacheLock.Lock() + portal.Bridge.portalsByMXID[roomID] = portal + portal.Bridge.cacheLock.Unlock() + portal.updateLogger() + err = portal.Bridge.DB.Portal.Update(ctx, portal.Portal) + if err != nil { + log.Err(err).Msg("Failed to save portal to database after creating Matrix room") + return err + } + if portal.Parent != nil { + // TODO add m.space.child event + } + if !isBeeper { + for i, mxid := range invite { + intent := inviteIntents[i] + // TODO handle errors + if intent != nil { + intent.EnsureJoined(ctx, portal.MXID) + } else { + portal.Bridge.Bot.InviteUser(ctx, portal.MXID, mxid) + } + } + } + return nil +} diff --git a/bridgev2/queue.go b/bridgev2/queue.go new file mode 100644 index 00000000..f27094df --- /dev/null +++ b/bridgev2/queue.go @@ -0,0 +1,92 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "strings" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/event" +) + +func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { + // TODO maybe HandleMatrixEvent would be more appropriate as this also handles bot invites and commands + + log := zerolog.Ctx(ctx) + var sender *User + if evt.Sender != "" { + var err error + sender, err = br.GetUserByMXID(ctx, evt.Sender) + if err != nil { + log.Err(err).Msg("Failed to get sender user for incoming Matrix event") + // TODO send metrics + return + } + } + if sender == nil && evt.Type.Class != event.EphemeralEventType { + log.Error().Msg("Missing sender for incoming non-ephemeral Matrix event") + // TODO send metrics + return + } + if evt.Type == event.EventMessage { + msg := evt.Content.AsMessage() + if msg != nil { + msg.RemoveReplyFallback() + + if strings.HasPrefix(msg.Body, br.CommandPrefix) || evt.RoomID == sender.ManagementRoom { + br.Commands.Handle( + ctx, + evt.RoomID, + evt.ID, + sender, + strings.TrimPrefix(msg.Body, br.CommandPrefix+" "), + msg.RelatesTo.GetReplyTo(), + ) + return + } + } + } + if evt.Type == event.StateMember && evt.GetStateKey() == br.Bot.GetMXID().String() && evt.Content.AsMember().Membership == event.MembershipInvite { + br.Bot.EnsureJoined(ctx, evt.RoomID) + // TODO handle errors + if sender.ManagementRoom == "" { + sender.ManagementRoom = evt.RoomID + br.DB.User.Update(ctx, sender.User) + } + return + } + portal, err := br.GetPortalByMXID(ctx, evt.RoomID) + if err != nil { + log.Err(err).Msg("Failed to get portal for incoming Matrix event") + // TODO send metrics + return + } else if portal != nil { + portal.queueEvent(ctx, &portalMatrixEvent{ + evt: evt, + sender: sender, + }) + } +} + +func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { + log := login.Log + ctx := log.WithContext(context.TODO()) + portal, err := br.GetPortalByID(ctx, evt.GetPortalID()) + if err != nil { + log.Err(err).Str("portal_id", string(portal.ID)). + Msg("Failed to get portal to handle remote event") + return + } + // TODO put this in a better place, and maybe cache to avoid constant db queries + br.DB.UserLogin.EnsureUserPortalExists(ctx, login.UserLogin, portal.ID) + portal.queueEvent(ctx, &portalRemoteEvent{ + evt: evt, + source: login, + }) +} diff --git a/bridgev2/user.go b/bridgev2/user.go new file mode 100644 index 00000000..82d841c2 --- /dev/null +++ b/bridgev2/user.go @@ -0,0 +1,85 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "fmt" + "sync/atomic" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type User struct { + *database.User + Bridge *Bridge + Log zerolog.Logger + + CommandState atomic.Pointer[CommandState] + + logins map[networkid.UserLoginID]*UserLogin +} + +func (br *Bridge) loadUser(ctx context.Context, dbUser *database.User, queryErr error, userID *id.UserID) (*User, error) { + if queryErr != nil { + return nil, fmt.Errorf("failed to query db: %w", queryErr) + } + if dbUser == nil { + if userID == nil { + return nil, nil + } + dbUser = &database.User{ + BridgeID: br.ID, + MXID: *userID, + } + err := br.DB.User.Insert(ctx, dbUser) + if err != nil { + return nil, fmt.Errorf("failed to insert new user: %w", err) + } + } + user := &User{ + User: dbUser, + Bridge: br, + Log: br.Log.With().Stringer("user_mxid", dbUser.MXID).Logger(), + logins: make(map[networkid.UserLoginID]*UserLogin), + } + br.usersByMXID[user.MXID] = user + err := br.unlockedLoadUserLoginsByMXID(ctx, user) + if err != nil { + return nil, fmt.Errorf("failed to load user logins: %w", err) + } + return user, nil +} + +func (br *Bridge) unlockedGetUserByMXID(ctx context.Context, userID id.UserID, onlyIfExists bool) (*User, error) { + cached, ok := br.usersByMXID[userID] + if ok { + return cached, nil + } + idPtr := &userID + if onlyIfExists { + idPtr = nil + } + db, err := br.DB.User.GetByMXID(ctx, userID) + return br.loadUser(ctx, db, err, idPtr) +} + +func (br *Bridge) GetUserByMXID(ctx context.Context, userID id.UserID) (*User, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.unlockedGetUserByMXID(ctx, userID, false) +} + +func (br *Bridge) GetExistingUserByMXID(ctx context.Context, userID id.UserID) (*User, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.unlockedGetUserByMXID(ctx, userID, true) +} diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go new file mode 100644 index 00000000..0a04662d --- /dev/null +++ b/bridgev2/userlogin.go @@ -0,0 +1,111 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "fmt" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" +) + +type UserLogin struct { + *database.UserLogin + Bridge *Bridge + User *User + Log zerolog.Logger + + Client NetworkAPI +} + +func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *database.UserLogin) (*UserLogin, error) { + if user == nil { + var err error + user, err = br.unlockedGetUserByMXID(ctx, dbUserLogin.UserMXID, true) + if err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } + } + userLogin := &UserLogin{ + UserLogin: dbUserLogin, + Bridge: br, + User: user, + Log: user.Log.With().Str("login_id", string(dbUserLogin.ID)).Logger(), + } + err := br.Network.PrepareLogin(ctx, userLogin) + if err != nil { + return nil, fmt.Errorf("failed to prepare: %w", err) + } + user.logins[userLogin.ID] = userLogin + br.userLoginsByID[userLogin.ID] = userLogin + return userLogin, nil +} + +func (br *Bridge) loadManyUserLogins(ctx context.Context, user *User, logins []*database.UserLogin) ([]*UserLogin, error) { + output := make([]*UserLogin, len(logins)) + for i, dbLogin := range logins { + if cached, ok := br.userLoginsByID[dbLogin.ID]; ok { + output[i] = cached + } else { + loaded, err := br.loadUserLogin(ctx, user, dbLogin) + if err != nil { + return nil, fmt.Errorf("failed to load user login: %w", err) + } + output[i] = loaded + } + } + return output, nil +} + +func (br *Bridge) unlockedLoadUserLoginsByMXID(ctx context.Context, user *User) error { + logins, err := br.DB.UserLogin.GetAllForUser(ctx, user.MXID) + if err != nil { + return err + } + _, err = br.loadManyUserLogins(ctx, user, logins) + return err +} + +func (br *Bridge) GetAllUserLogins(ctx context.Context) ([]*UserLogin, error) { + logins, err := br.DB.UserLogin.GetAll(ctx) + if err != nil { + return nil, err + } + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.loadManyUserLogins(ctx, nil, logins) +} + +func (br *Bridge) GetCachedUserLoginByID(id networkid.UserLoginID) *UserLogin { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.userLoginsByID[id] +} + +func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, client NetworkAPI) (*UserLogin, error) { + data.BridgeID = user.BridgeID + data.UserMXID = user.MXID + ul := &UserLogin{ + UserLogin: data, + Bridge: user.Bridge, + User: user, + Log: user.Log.With().Str("login_id", string(data.ID)).Logger(), + Client: client, + } + err := user.Bridge.DB.UserLogin.Insert(ctx, ul.UserLogin) + if err != nil { + return nil, err + } + user.Bridge.cacheLock.Lock() + defer user.Bridge.cacheLock.Unlock() + user.Bridge.userLoginsByID[ul.ID] = ul + user.logins[ul.ID] = ul + return ul, nil +} diff --git a/event/beeper.go b/event/beeper.go index 51ddd77f..5e412504 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -36,8 +36,11 @@ type BeeperMessageStatusEventContent struct { RelatesTo RelatesTo `json:"m.relates_to"` Status MessageStatus `json:"status"` Reason MessageStatusReason `json:"reason,omitempty"` - Error string `json:"error,omitempty"` - Message string `json:"message,omitempty"` + // Deprecated: clients were showing this to users even though they aren't supposed to. + // Use InternalError for error messages that should be included in bug reports, but not shown in the UI. + Error string `json:"error,omitempty"` + InternalError string `json:"internal_error,omitempty"` + Message string `json:"message,omitempty"` LastRetry id.EventID `json:"last_retry,omitempty"` diff --git a/requests.go b/requests.go index cdf020a0..b6c2f895 100644 --- a/requests.go +++ b/requests.go @@ -118,8 +118,10 @@ type ReqCreateRoom struct { PowerLevelOverride *event.PowerLevelsEventContent `json:"power_level_content_override,omitempty"` - MeowRoomID id.RoomID `json:"fi.mau.room_id,omitempty"` - BeeperAutoJoinInvites bool `json:"com.beeper.auto_join_invites,omitempty"` + MeowRoomID id.RoomID `json:"fi.mau.room_id,omitempty"` + BeeperInitialMembers []id.UserID `json:"com.beeper.initial_members,omitempty"` + BeeperAutoJoinInvites bool `json:"com.beeper.auto_join_invites,omitempty"` + BeeperLocalRoomID id.RoomID `json:"com.beeper.local_room_id,omitempty"` } // ReqRedact is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidredacteventidtxnid diff --git a/versions.go b/versions.go index d3dd3c67..bdddf729 100644 --- a/versions.go +++ b/versions.go @@ -54,6 +54,7 @@ type UnstableFeature struct { } var ( + FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} From 84f77cbafe192c2f8d333922c8172460a3a6c779 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 27 May 2024 07:23:01 -0600 Subject: [PATCH 0227/1647] crypto/cross signing: actually save signatures in store on publish Signed-off-by: Sumner Evans --- crypto/cross_sign_key.go | 10 ++++++++++ crypto/cross_sign_ssss.go | 6 ------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/crypto/cross_sign_key.go b/crypto/cross_sign_key.go index f7dc08cb..3d01fb99 100644 --- a/crypto/cross_sign_key.go +++ b/crypto/cross_sign_key.go @@ -142,6 +142,16 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross return err } + if err := mach.CryptoStore.PutSignature(ctx, userID, keys.MasterKey.PublicKey(), userID, mach.account.SigningKey(), masterSig); err != nil { + return fmt.Errorf("error storing signature of master key by device signing key in crypto store: %w", err) + } + if err := mach.CryptoStore.PutSignature(ctx, userID, keys.SelfSigningKey.PublicKey(), userID, keys.MasterKey.PublicKey(), selfSig); err != nil { + return fmt.Errorf("error storing signature of self-signing key by master key in crypto store: %w", err) + } + if err := mach.CryptoStore.PutSignature(ctx, userID, keys.UserSigningKey.PublicKey(), userID, keys.MasterKey.PublicKey(), userSig); err != nil { + return fmt.Errorf("error storing signature of user-signing key by master key in crypto store: %w", err) + } + mach.CrossSigningKeys = keys mach.crossSigningPubkeys = keys.PublicKeys() diff --git a/crypto/cross_sign_ssss.go b/crypto/cross_sign_ssss.go index 540d625d..389a9fd2 100644 --- a/crypto/cross_sign_ssss.go +++ b/crypto/cross_sign_ssss.go @@ -100,12 +100,6 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, u return "", nil, fmt.Errorf("failed to publish cross-signing keys: %w", err) } - // Trust the master key - err = mach.SignOwnMasterKey(ctx) - if err != nil { - return "", nil, fmt.Errorf("failed to sign own master key: %w", err) - } - err = mach.SSSS.SetDefaultKeyID(ctx, key.ID) if err != nil { return "", nil, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) From cd7f343cfdc5d9deafc12e2c38e7923204feac74 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Sun, 26 May 2024 17:46:33 -0600 Subject: [PATCH 0228/1647] verificationhelper: split QR code tests into separate file Signed-off-by: Sumner Evans --- ..._test.go => verificationhelper_qr_test.go} | 334 +----------------- .../verificationhelper_test.go | 329 +++++++++++++++++ 2 files changed, 339 insertions(+), 324 deletions(-) rename crypto/verificationhelper/{verificationhelper_self_test.go => verificationhelper_qr_test.go} (50%) create mode 100644 crypto/verificationhelper/verificationhelper_test.go diff --git a/crypto/verificationhelper/verificationhelper_self_test.go b/crypto/verificationhelper/verificationhelper_qr_test.go similarity index 50% rename from crypto/verificationhelper/verificationhelper_self_test.go rename to crypto/verificationhelper/verificationhelper_qr_test.go index 08e2c6e4..7391f151 100644 --- a/crypto/verificationhelper/verificationhelper_self_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_test.go @@ -9,269 +9,17 @@ package verificationhelper_test import ( "context" "fmt" - "os" "testing" - "time" - "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto" - "maunium.net/go/mautrix/crypto/cryptohelper" "maunium.net/go/mautrix/crypto/verificationhelper" "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" ) -var userID = id.UserID("@alice:example.org") -var sendingDeviceID = id.DeviceID("sending") -var receivingDeviceID = id.DeviceID("receiving") - -func init() { - log.Logger = zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339}).With().Timestamp().Logger().Level(zerolog.TraceLevel) - zerolog.DefaultContextLogger = &log.Logger -} - -func initServerAndLogin(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { - t.Helper() - ts = createMockServer(t) - - sendingClient, sendingCryptoStore = ts.Login(t, ctx, userID, sendingDeviceID) - sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() - receivingClient, receivingCryptoStore = ts.Login(t, ctx, userID, receivingDeviceID) - receivingMachine = receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() - - err := sendingCryptoStore.PutDevice(ctx, userID, sendingMachine.OwnIdentity()) - require.NoError(t, err) - err = sendingCryptoStore.PutDevice(ctx, userID, receivingMachine.OwnIdentity()) - require.NoError(t, err) - err = receivingCryptoStore.PutDevice(ctx, userID, sendingMachine.OwnIdentity()) - require.NoError(t, err) - err = receivingCryptoStore.PutDevice(ctx, userID, receivingMachine.OwnIdentity()) - require.NoError(t, err) - return -} - -func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, receivingClient *mautrix.Client, sendingMachine, receivingMachine *crypto.OlmMachine) (sendingCallbacks, receivingCallbacks *allVerificationCallbacks, sendingHelper, receivingHelper *verificationhelper.VerificationHelper) { - t.Helper() - sendingCallbacks = newAllVerificationCallbacks() - sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, sendingCallbacks, true) - require.NoError(t, sendingHelper.Init(ctx)) - - receivingCallbacks = newAllVerificationCallbacks() - receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receivingCallbacks, true) - require.NoError(t, receivingHelper.Init(ctx)) - return -} - -func TestSelfVerification_Start(t *testing.T) { - ctx := log.Logger.WithContext(context.TODO()) - receivingDeviceID2 := id.DeviceID("receiving2") - - testCases := []struct { - supportsScan bool - callbacks MockVerificationCallbacks - startVerificationErrMsg string - expectedVerificationMethods []event.VerificationMethod - }{ - {false, newBaseVerificationCallbacks(), "no supported verification methods", nil}, - {true, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {false, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, - {true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, - {true, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {false, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, - {true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, - } - - for i, tc := range testCases { - t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - ts := createMockServer(t) - defer ts.Close() - - client, cryptoStore := ts.Login(t, ctx, userID, sendingDeviceID) - addDeviceID(ctx, cryptoStore, userID, sendingDeviceID) - addDeviceID(ctx, cryptoStore, userID, receivingDeviceID) - addDeviceID(ctx, cryptoStore, userID, receivingDeviceID2) - - senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), tc.callbacks, tc.supportsScan) - err := senderHelper.Init(ctx) - require.NoError(t, err) - - txnID, err := senderHelper.StartVerification(ctx, userID) - if tc.startVerificationErrMsg != "" { - assert.ErrorContains(t, err, tc.startVerificationErrMsg) - return - } - - assert.NoError(t, err) - assert.NotEmpty(t, txnID) - - toDeviceInbox := ts.DeviceInbox[userID] - - // Ensure that we didn't send a verification request to the - // sending device. - assert.Empty(t, toDeviceInbox[sendingDeviceID]) - - // Ensure that the verification request was sent to both of - // the other devices. - assert.NotEmpty(t, toDeviceInbox[receivingDeviceID]) - assert.NotEmpty(t, toDeviceInbox[receivingDeviceID2]) - assert.Equal(t, toDeviceInbox[receivingDeviceID], toDeviceInbox[receivingDeviceID2]) - assert.Len(t, toDeviceInbox[receivingDeviceID], 1) - - // Ensure that the verification request is correct. - verificationRequest := toDeviceInbox[receivingDeviceID][0].Content.AsVerificationRequest() - assert.Equal(t, sendingDeviceID, verificationRequest.FromDevice) - assert.Equal(t, txnID, verificationRequest.TransactionID) - assert.ElementsMatch(t, tc.expectedVerificationMethods, verificationRequest.Methods) - }) - } -} - -func TestSelfVerification_Accept_NoSupportedMethods(t *testing.T) { - ctx := log.Logger.WithContext(context.TODO()) - - ts := createMockServer(t) - defer ts.Close() - - sendingClient, sendingCryptoStore := ts.Login(t, ctx, userID, sendingDeviceID) - receivingClient, _ := ts.Login(t, ctx, userID, receivingDeviceID) - addDeviceID(ctx, sendingCryptoStore, userID, sendingDeviceID) - addDeviceID(ctx, sendingCryptoStore, userID, receivingDeviceID) - - sendingMachine := sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() - recoveryKey, cache, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") - assert.NoError(t, err) - assert.NotEmpty(t, recoveryKey) - assert.NotNil(t, cache) - - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, newAllVerificationCallbacks(), true) - err = sendingHelper.Init(ctx) - require.NoError(t, err) - - receivingCallbacks := newBaseVerificationCallbacks() - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), receivingCallbacks, false) - err = receivingHelper.Init(ctx) - require.NoError(t, err) - - txnID, err := sendingHelper.StartVerification(ctx, userID) - require.NoError(t, err) - require.NotEmpty(t, txnID) - - ts.dispatchToDevice(t, ctx, receivingClient) - - // Ensure that the receiver ignored the request because it - // doesn't support any of the verification methods in the - // request. - assert.Empty(t, receivingCallbacks.GetRequestedVerifications()) -} - -func TestSelfVerification_Accept_CorrectMethodsPresented(t *testing.T) { - ctx := log.Logger.WithContext(context.TODO()) - - testCases := []struct { - sendingSupportsScan bool - receivingSupportsScan bool - sendingCallbacks MockVerificationCallbacks - receivingCallbacks MockVerificationCallbacks - expectedVerificationMethods []event.VerificationMethod - }{ - {false, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, - {true, false, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {false, true, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {true, false, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, - {true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, - } - - for i, tc := range testCases { - t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) - defer ts.Close() - - recoveryKey, sendingCrossSigningKeysCache, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") - assert.NoError(t, err) - assert.NotEmpty(t, recoveryKey) - assert.NotNil(t, sendingCrossSigningKeysCache) - - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, tc.sendingCallbacks, tc.sendingSupportsScan) - err = sendingHelper.Init(ctx) - require.NoError(t, err) - - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, tc.receivingCallbacks, tc.receivingSupportsScan) - err = receivingHelper.Init(ctx) - require.NoError(t, err) - - txnID, err := sendingHelper.StartVerification(ctx, userID) - require.NoError(t, err) - - // Process the verification request on the receiving device. - ts.dispatchToDevice(t, ctx, receivingClient) - - // Ensure that the receiving device received a verification - // request with the correct transaction ID. - assert.ElementsMatch(t, []id.VerificationTransactionID{txnID}, tc.receivingCallbacks.GetRequestedVerifications()[userID]) - - // Have the receiving device accept the verification request. - err = receivingHelper.AcceptVerification(ctx, txnID) - require.NoError(t, err) - - _, sendingIsQRCallbacks := tc.sendingCallbacks.(*qrCodeVerificationCallbacks) - _, sendingIsAllCallbacks := tc.sendingCallbacks.(*allVerificationCallbacks) - sendingCanShowQR := sendingIsQRCallbacks || sendingIsAllCallbacks - _, receivingIsQRCallbacks := tc.receivingCallbacks.(*qrCodeVerificationCallbacks) - _, receivingIsAllCallbacks := tc.receivingCallbacks.(*allVerificationCallbacks) - receivingCanShowQR := receivingIsQRCallbacks || receivingIsAllCallbacks - - // Ensure that if the receiving device should show a QR code that - // it has the correct content. - if tc.sendingSupportsScan && receivingCanShowQR { - receivingShownQRCode := tc.receivingCallbacks.GetQRCodeShown(txnID) - require.NotNil(t, receivingShownQRCode) - assert.Equal(t, txnID, receivingShownQRCode.TransactionID) - assert.NotEmpty(t, receivingShownQRCode.SharedSecret) - } - - // Check for whether the receiving device should be scanning a QR - // code. - if tc.receivingSupportsScan && sendingCanShowQR { - assert.Contains(t, tc.receivingCallbacks.GetScanQRCodeTransactions(), txnID) - } - - // Check that the m.key.verification.ready event has the correct - // content. - sendingInbox := ts.DeviceInbox[userID][sendingDeviceID] - assert.Len(t, sendingInbox, 1) - readyEvt := sendingInbox[0].Content.AsVerificationReady() - assert.Equal(t, txnID, readyEvt.TransactionID) - assert.Equal(t, receivingDeviceID, readyEvt.FromDevice) - assert.ElementsMatch(t, tc.expectedVerificationMethods, readyEvt.Methods) - - // Receive the m.key.verification.ready event on the sending - // device. - ts.dispatchToDevice(t, ctx, sendingClient) - - // Ensure that if the sending device should show a QR code that it - // has the correct content. - if tc.receivingSupportsScan && sendingCanShowQR { - sendingShownQRCode := tc.sendingCallbacks.GetQRCodeShown(txnID) - require.NotNil(t, sendingShownQRCode) - assert.Equal(t, txnID, sendingShownQRCode.TransactionID) - assert.NotEmpty(t, sendingShownQRCode.SharedSecret) - } - - // Check for whether the sending device should be scanning a QR - // code. - if tc.sendingSupportsScan && receivingCanShowQR { - assert.Contains(t, tc.sendingCallbacks.GetScanQRCodeTransactions(), txnID) - } - }) - } -} - func TestSelfVerification_Accept_QRContents(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) @@ -312,7 +60,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) { // Send the verification request from the sender device and accept // it on the receiving device and receive the verification ready // event on the sending device. - txnID, err := sendingHelper.StartVerification(ctx, userID) + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) ts.dispatchToDevice(t, ctx, receivingClient) @@ -371,50 +119,6 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) { } } -// TestAcceptSelfVerificationCancelOnNonParticipatingDevices ensures that we do -// not regress https://github.com/mautrix/go/pull/230. -func TestSelfVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { - ctx := log.Logger.WithContext(context.TODO()) - ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) - defer ts.Close() - _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) - - nonParticipatingDeviceID1 := id.DeviceID("non-participating1") - nonParticipatingDeviceID2 := id.DeviceID("non-participating2") - addDeviceID(ctx, sendingCryptoStore, userID, nonParticipatingDeviceID1) - addDeviceID(ctx, sendingCryptoStore, userID, nonParticipatingDeviceID2) - addDeviceID(ctx, receivingCryptoStore, userID, nonParticipatingDeviceID1) - addDeviceID(ctx, receivingCryptoStore, userID, nonParticipatingDeviceID2) - - _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") - assert.NoError(t, err) - - // Send the verification request from the sender device and accept it on - // the receiving device. - txnID, err := sendingHelper.StartVerification(ctx, userID) - require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) - err = receivingHelper.AcceptVerification(ctx, txnID) - require.NoError(t, err) - - // Receive the m.key.verification.ready event on the sending device. - ts.dispatchToDevice(t, ctx, sendingClient) - - // The sending and receiving devices should not have any cancellation - // events in their inboxes. - assert.Empty(t, ts.DeviceInbox[userID][sendingDeviceID]) - assert.Empty(t, ts.DeviceInbox[userID][receivingDeviceID]) - - // There should now be cancellation events in the non-participating devices - // inboxes (in addition to the request event). - assert.Len(t, ts.DeviceInbox[userID][nonParticipatingDeviceID1], 2) - assert.Len(t, ts.DeviceInbox[userID][nonParticipatingDeviceID2], 2) - assert.Equal(t, ts.DeviceInbox[userID][nonParticipatingDeviceID1][1], ts.DeviceInbox[userID][nonParticipatingDeviceID2][1]) - cancellationEvent := ts.DeviceInbox[userID][nonParticipatingDeviceID1][1].Content.AsVerificationCancel() - assert.Equal(t, txnID, cancellationEvent.TransactionID) - assert.Equal(t, event.VerificationCancelCodeAccepted, cancellationEvent.Code) -} - func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) @@ -446,7 +150,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { // Send the verification request from the sender device and accept // it on the receiving device and receive the verification ready // event on the sending device. - txnID, err := sendingHelper.StartVerification(ctx, userID) + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) @@ -466,7 +170,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { // Ensure that the receiving device received a verification // start event and a verification done event. - receivingInbox := ts.DeviceInbox[userID][receivingDeviceID] + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] assert.Len(t, receivingInbox, 2) startEvt := receivingInbox[0].Content.AsVerificationStart() @@ -490,7 +194,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { // Ensure that the sending device received a verification done // event. - sendingInbox := ts.DeviceInbox[userID][sendingDeviceID] + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] require.Len(t, sendingInbox, 1) doneEvt = sendingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) @@ -504,7 +208,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { // Ensure that the sending device received a verification // start event and a verification done event. - sendingInbox := ts.DeviceInbox[userID][sendingDeviceID] + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] assert.Len(t, sendingInbox, 2) startEvt := sendingInbox[0].Content.AsVerificationStart() @@ -528,7 +232,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { // Ensure that the receiving device received a verification // done event. - receivingInbox := ts.DeviceInbox[userID][receivingDeviceID] + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] require.Len(t, receivingInbox, 1) doneEvt = receivingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) @@ -543,24 +247,6 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { } } -func TestSelfVerification_ErrorOnDoubleAccept(t *testing.T) { - ctx := log.Logger.WithContext(context.TODO()) - ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) - defer ts.Close() - _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) - - _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") - require.NoError(t, err) - - txnID, err := sendingHelper.StartVerification(ctx, userID) - require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) - err = receivingHelper.AcceptVerification(ctx, txnID) - require.NoError(t, err) - err = receivingHelper.AcceptVerification(ctx, txnID) - require.ErrorContains(t, err, "transaction is not in the requested state") -} - func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) @@ -575,7 +261,7 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { // Send the verification request from the sender device and accept // it on the receiving device and receive the verification ready // event on the sending device. - txnID, err := sendingHelper.StartVerification(ctx, userID) + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) @@ -639,7 +325,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { // Send the verification request from the sender device and accept // it on the receiving device and receive the verification ready // event on the sending device. - txnID, err := sendingHelper.StartVerification(ctx, userID) + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) ts.dispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) @@ -660,7 +346,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { assert.ErrorContains(t, err, tc.expectedError) // Ensure that the receiving device received a cancellation. - receivingInbox := ts.DeviceInbox[userID][receivingDeviceID] + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] assert.Len(t, receivingInbox, 1) ts.dispatchToDevice(t, ctx, receivingClient) cancellation := receivingCallbacks.GetVerificationCancellation(txnID) @@ -674,7 +360,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { assert.ErrorContains(t, err, tc.expectedError) // Ensure that the sending device received a cancellation. - sendingInbox := ts.DeviceInbox[userID][sendingDeviceID] + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] assert.Len(t, sendingInbox, 1) ts.dispatchToDevice(t, ctx, sendingClient) cancellation := sendingCallbacks.GetVerificationCancellation(txnID) diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go new file mode 100644 index 00000000..d9e53b91 --- /dev/null +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -0,0 +1,329 @@ +package verificationhelper_test + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/cryptohelper" + "maunium.net/go/mautrix/crypto/verificationhelper" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +var aliceUserID = id.UserID("@alice:example.org") +var sendingDeviceID = id.DeviceID("sending") +var receivingDeviceID = id.DeviceID("receiving") + +func init() { + log.Logger = zerolog.New(zerolog.ConsoleWriter{Out: os.Stdout, TimeFormat: time.RFC3339}).With().Timestamp().Logger().Level(zerolog.TraceLevel) + zerolog.DefaultContextLogger = &log.Logger +} + +func initServerAndLogin(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { + t.Helper() + ts = createMockServer(t) + + sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID) + sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() + receivingClient, receivingCryptoStore = ts.Login(t, ctx, aliceUserID, receivingDeviceID) + receivingMachine = receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() + + err := sendingCryptoStore.PutDevice(ctx, aliceUserID, sendingMachine.OwnIdentity()) + require.NoError(t, err) + err = sendingCryptoStore.PutDevice(ctx, aliceUserID, receivingMachine.OwnIdentity()) + require.NoError(t, err) + err = receivingCryptoStore.PutDevice(ctx, aliceUserID, sendingMachine.OwnIdentity()) + require.NoError(t, err) + err = receivingCryptoStore.PutDevice(ctx, aliceUserID, receivingMachine.OwnIdentity()) + require.NoError(t, err) + return +} + +func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, receivingClient *mautrix.Client, sendingMachine, receivingMachine *crypto.OlmMachine) (sendingCallbacks, receivingCallbacks *allVerificationCallbacks, sendingHelper, receivingHelper *verificationhelper.VerificationHelper) { + t.Helper() + sendingCallbacks = newAllVerificationCallbacks() + sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, sendingCallbacks, true) + require.NoError(t, sendingHelper.Init(ctx)) + + receivingCallbacks = newAllVerificationCallbacks() + receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receivingCallbacks, true) + require.NoError(t, receivingHelper.Init(ctx)) + return +} + +func TestVerification_Start(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + receivingDeviceID2 := id.DeviceID("receiving2") + + testCases := []struct { + supportsScan bool + callbacks MockVerificationCallbacks + startVerificationErrMsg string + expectedVerificationMethods []event.VerificationMethod + }{ + {false, newBaseVerificationCallbacks(), "no supported verification methods", nil}, + {true, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {false, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, + {true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, + {true, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, + {true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + ts := createMockServer(t) + defer ts.Close() + + client, cryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID) + addDeviceID(ctx, cryptoStore, aliceUserID, sendingDeviceID) + addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID) + addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID2) + + senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), tc.callbacks, tc.supportsScan) + err := senderHelper.Init(ctx) + require.NoError(t, err) + + txnID, err := senderHelper.StartVerification(ctx, aliceUserID) + if tc.startVerificationErrMsg != "" { + assert.ErrorContains(t, err, tc.startVerificationErrMsg) + return + } + + assert.NoError(t, err) + assert.NotEmpty(t, txnID) + + toDeviceInbox := ts.DeviceInbox[aliceUserID] + + // Ensure that we didn't send a verification request to the + // sending device. + assert.Empty(t, toDeviceInbox[sendingDeviceID]) + + // Ensure that the verification request was sent to both of + // the other devices. + assert.NotEmpty(t, toDeviceInbox[receivingDeviceID]) + assert.NotEmpty(t, toDeviceInbox[receivingDeviceID2]) + assert.Equal(t, toDeviceInbox[receivingDeviceID], toDeviceInbox[receivingDeviceID2]) + assert.Len(t, toDeviceInbox[receivingDeviceID], 1) + + // Ensure that the verification request is correct. + verificationRequest := toDeviceInbox[receivingDeviceID][0].Content.AsVerificationRequest() + assert.Equal(t, sendingDeviceID, verificationRequest.FromDevice) + assert.Equal(t, txnID, verificationRequest.TransactionID) + assert.ElementsMatch(t, tc.expectedVerificationMethods, verificationRequest.Methods) + }) + } +} + +func TestVerification_Accept_NoSupportedMethods(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + ts := createMockServer(t) + defer ts.Close() + + sendingClient, sendingCryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID) + receivingClient, _ := ts.Login(t, ctx, aliceUserID, receivingDeviceID) + addDeviceID(ctx, sendingCryptoStore, aliceUserID, sendingDeviceID) + addDeviceID(ctx, sendingCryptoStore, aliceUserID, receivingDeviceID) + + sendingMachine := sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() + recoveryKey, cache, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + assert.NoError(t, err) + assert.NotEmpty(t, recoveryKey) + assert.NotNil(t, cache) + + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, newAllVerificationCallbacks(), true) + err = sendingHelper.Init(ctx) + require.NoError(t, err) + + receivingCallbacks := newBaseVerificationCallbacks() + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), receivingCallbacks, false) + err = receivingHelper.Init(ctx) + require.NoError(t, err) + + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + require.NotEmpty(t, txnID) + + ts.dispatchToDevice(t, ctx, receivingClient) + + // Ensure that the receiver ignored the request because it + // doesn't support any of the verification methods in the + // request. + assert.Empty(t, receivingCallbacks.GetRequestedVerifications()) +} + +func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + testCases := []struct { + sendingSupportsScan bool + receivingSupportsScan bool + sendingCallbacks MockVerificationCallbacks + receivingCallbacks MockVerificationCallbacks + expectedVerificationMethods []event.VerificationMethod + }{ + {false, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, + {true, false, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, true, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {true, false, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, + {true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, + } + + for i, tc := range testCases { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + defer ts.Close() + + recoveryKey, sendingCrossSigningKeysCache, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + assert.NoError(t, err) + assert.NotEmpty(t, recoveryKey) + assert.NotNil(t, sendingCrossSigningKeysCache) + + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, tc.sendingCallbacks, tc.sendingSupportsScan) + err = sendingHelper.Init(ctx) + require.NoError(t, err) + + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, tc.receivingCallbacks, tc.receivingSupportsScan) + err = receivingHelper.Init(ctx) + require.NoError(t, err) + + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + + // Process the verification request on the receiving device. + ts.dispatchToDevice(t, ctx, receivingClient) + + // Ensure that the receiving device received a verification + // request with the correct transaction ID. + assert.ElementsMatch(t, []id.VerificationTransactionID{txnID}, tc.receivingCallbacks.GetRequestedVerifications()[aliceUserID]) + + // Have the receiving device accept the verification request. + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + + _, sendingIsQRCallbacks := tc.sendingCallbacks.(*qrCodeVerificationCallbacks) + _, sendingIsAllCallbacks := tc.sendingCallbacks.(*allVerificationCallbacks) + sendingCanShowQR := sendingIsQRCallbacks || sendingIsAllCallbacks + _, receivingIsQRCallbacks := tc.receivingCallbacks.(*qrCodeVerificationCallbacks) + _, receivingIsAllCallbacks := tc.receivingCallbacks.(*allVerificationCallbacks) + receivingCanShowQR := receivingIsQRCallbacks || receivingIsAllCallbacks + + // Ensure that if the receiving device should show a QR code that + // it has the correct content. + if tc.sendingSupportsScan && receivingCanShowQR { + receivingShownQRCode := tc.receivingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, receivingShownQRCode) + assert.Equal(t, txnID, receivingShownQRCode.TransactionID) + assert.NotEmpty(t, receivingShownQRCode.SharedSecret) + } + + // Check for whether the receiving device should be scanning a QR + // code. + if tc.receivingSupportsScan && sendingCanShowQR { + assert.Contains(t, tc.receivingCallbacks.GetScanQRCodeTransactions(), txnID) + } + + // Check that the m.key.verification.ready event has the correct + // content. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + readyEvt := sendingInbox[0].Content.AsVerificationReady() + assert.Equal(t, txnID, readyEvt.TransactionID) + assert.Equal(t, receivingDeviceID, readyEvt.FromDevice) + assert.ElementsMatch(t, tc.expectedVerificationMethods, readyEvt.Methods) + + // Receive the m.key.verification.ready event on the sending + // device. + ts.dispatchToDevice(t, ctx, sendingClient) + + // Ensure that if the sending device should show a QR code that it + // has the correct content. + if tc.receivingSupportsScan && sendingCanShowQR { + sendingShownQRCode := tc.sendingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, sendingShownQRCode) + assert.Equal(t, txnID, sendingShownQRCode.TransactionID) + assert.NotEmpty(t, sendingShownQRCode.SharedSecret) + } + + // Check for whether the sending device should be scanning a QR + // code. + if tc.sendingSupportsScan && receivingCanShowQR { + assert.Contains(t, tc.sendingCallbacks.GetScanQRCodeTransactions(), txnID) + } + }) + } +} + +// TestAcceptSelfVerificationCancelOnNonParticipatingDevices ensures that we do +// not regress https://github.com/mautrix/go/pull/230. +func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + defer ts.Close() + _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + + nonParticipatingDeviceID1 := id.DeviceID("non-participating1") + nonParticipatingDeviceID2 := id.DeviceID("non-participating2") + addDeviceID(ctx, sendingCryptoStore, aliceUserID, nonParticipatingDeviceID1) + addDeviceID(ctx, sendingCryptoStore, aliceUserID, nonParticipatingDeviceID2) + addDeviceID(ctx, receivingCryptoStore, aliceUserID, nonParticipatingDeviceID1) + addDeviceID(ctx, receivingCryptoStore, aliceUserID, nonParticipatingDeviceID2) + + _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + assert.NoError(t, err) + + // Send the verification request from the sender device and accept it on + // the receiving device. + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + + // Receive the m.key.verification.ready event on the sending device. + ts.dispatchToDevice(t, ctx, sendingClient) + + // The sending and receiving devices should not have any cancellation + // events in their inboxes. + assert.Empty(t, ts.DeviceInbox[aliceUserID][sendingDeviceID]) + assert.Empty(t, ts.DeviceInbox[aliceUserID][receivingDeviceID]) + + // There should now be cancellation events in the non-participating devices + // inboxes (in addition to the request event). + assert.Len(t, ts.DeviceInbox[aliceUserID][nonParticipatingDeviceID1], 2) + assert.Len(t, ts.DeviceInbox[aliceUserID][nonParticipatingDeviceID2], 2) + assert.Equal(t, ts.DeviceInbox[aliceUserID][nonParticipatingDeviceID1][1], ts.DeviceInbox[aliceUserID][nonParticipatingDeviceID2][1]) + cancellationEvent := ts.DeviceInbox[aliceUserID][nonParticipatingDeviceID1][1].Content.AsVerificationCancel() + assert.Equal(t, txnID, cancellationEvent.TransactionID) + assert.Equal(t, event.VerificationCancelCodeAccepted, cancellationEvent.Code) +} + +func TestVerification_ErrorOnDoubleAccept(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + defer ts.Close() + _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + + _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.ErrorContains(t, err, "transaction is not in the requested state") +} From a2abce8215a9a20e4ae9a8765330117b1f9fd090 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Sun, 26 May 2024 20:50:47 -0600 Subject: [PATCH 0229/1647] verificationhelper: add tests for SAS flow Signed-off-by: Sumner Evans --- crypto/verificationhelper/callbacks_test.go | 17 +- .../verificationhelper_sas_test.go | 278 ++++++++++++++++++ 2 files changed, 293 insertions(+), 2 deletions(-) create mode 100644 crypto/verificationhelper/verificationhelper_sas_test.go diff --git a/crypto/verificationhelper/callbacks_test.go b/crypto/verificationhelper/callbacks_test.go index 7fc35129..7b1055d1 100644 --- a/crypto/verificationhelper/callbacks_test.go +++ b/crypto/verificationhelper/callbacks_test.go @@ -27,6 +27,8 @@ type baseVerificationCallbacks struct { qrCodesScanned map[id.VerificationTransactionID]struct{} doneTransactions map[id.VerificationTransactionID]struct{} verificationCancellation map[id.VerificationTransactionID]*event.VerificationCancelEventContent + emojisShown map[id.VerificationTransactionID][]rune + decimalsShown map[id.VerificationTransactionID][]int } func newBaseVerificationCallbacks() *baseVerificationCallbacks { @@ -36,6 +38,8 @@ func newBaseVerificationCallbacks() *baseVerificationCallbacks { qrCodesScanned: map[id.VerificationTransactionID]struct{}{}, doneTransactions: map[id.VerificationTransactionID]struct{}{}, verificationCancellation: map[id.VerificationTransactionID]*event.VerificationCancelEventContent{}, + emojisShown: map[id.VerificationTransactionID][]rune{}, + decimalsShown: map[id.VerificationTransactionID][]int{}, } } @@ -65,6 +69,14 @@ func (c *baseVerificationCallbacks) GetVerificationCancellation(txnID id.Verific return c.verificationCancellation[txnID] } +func (c *baseVerificationCallbacks) GetEmojisShown(txnID id.VerificationTransactionID) []rune { + return c.emojisShown[txnID] +} + +func (c *baseVerificationCallbacks) GetDecimalsShown(txnID id.VerificationTransactionID) []int { + return c.decimalsShown[txnID] +} + func (c *baseVerificationCallbacks) VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) { c.verificationsRequested[from] = append(c.verificationsRequested[from], txnID) } @@ -92,8 +104,9 @@ func newSASVerificationCallbacksWithBase(base *baseVerificationCallbacks) *sasVe return &sasVerificationCallbacks{base} } -func (*sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) { - panic("show sas") +func (c *sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) { + c.emojisShown[txnID] = emojis + c.decimalsShown[txnID] = decimals } type qrCodeVerificationCallbacks struct { diff --git a/crypto/verificationhelper/verificationhelper_sas_test.go b/crypto/verificationhelper/verificationhelper_sas_test.go new file mode 100644 index 00000000..04cddbf4 --- /dev/null +++ b/crypto/verificationhelper/verificationhelper_sas_test.go @@ -0,0 +1,278 @@ +package verificationhelper_test + +import ( + "context" + "fmt" + "testing" + + "github.com/rs/zerolog/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" + + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func TestVerification_SAS(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + testCases := []struct { + sendingGeneratedCrossSigningKeys bool + sendingStartsSAS bool + sendingConfirmsFirst bool + }{ + {true, true, true}, + {true, true, false}, + {true, false, true}, + {true, false, false}, + {false, true, true}, + {false, true, false}, + {false, false, true}, + {false, false, false}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("sendingGenerated=%t sendingStartsSAS=%t sendingConfirmsFirst=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingStartsSAS, tc.sendingConfirmsFirst), func(t *testing.T) { + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + defer ts.Close() + sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + var err error + + var sendingRecoveryKey, receivingRecoveryKey string + var sendingCrossSigningKeysCache, receivingCrossSigningKeysCache *crypto.CrossSigningKeysCache + + if tc.sendingGeneratedCrossSigningKeys { + sendingRecoveryKey, sendingCrossSigningKeysCache, err = sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + assert.NotEmpty(t, sendingRecoveryKey) + assert.NotNil(t, sendingCrossSigningKeysCache) + } else { + receivingRecoveryKey, receivingCrossSigningKeysCache, err = receivingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + assert.NotEmpty(t, receivingRecoveryKey) + assert.NotNil(t, receivingCrossSigningKeysCache) + } + + // Send the verification request from the sender device and accept + // it on the receiving device and receive the verification ready + // event on the sending device. + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, sendingClient) + + // Test that the start event is correct + var startEvt *event.VerificationStartEventContent + if tc.sendingStartsSAS { + err = sendingHelper.StartSAS(ctx, txnID) + require.NoError(t, err) + + // Ensure that the receiving device received a verification + // start event. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + startEvt = receivingInbox[0].Content.AsVerificationStart() + assert.Equal(t, sendingDeviceID, startEvt.FromDevice) + } else { + err = receivingHelper.StartSAS(ctx, txnID) + require.NoError(t, err) + + // Ensure that the receiving device received a verification + // start event. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + startEvt = sendingInbox[0].Content.AsVerificationStart() + assert.Equal(t, receivingDeviceID, startEvt.FromDevice) + } + assert.Equal(t, txnID, startEvt.TransactionID) + assert.Equal(t, event.VerificationMethodSAS, startEvt.Method) + assert.Contains(t, startEvt.Hashes, event.VerificationHashMethodSHA256) + assert.Contains(t, startEvt.KeyAgreementProtocols, event.KeyAgreementProtocolCurve25519HKDFSHA256) + assert.Contains(t, startEvt.MessageAuthenticationCodes, event.MACMethodHKDFHMACSHA256) + assert.Contains(t, startEvt.MessageAuthenticationCodes, event.MACMethodHKDFHMACSHA256V2) + assert.Contains(t, startEvt.ShortAuthenticationString, event.SASMethodDecimal) + assert.Contains(t, startEvt.ShortAuthenticationString, event.SASMethodEmoji) + + // Test that the accept event is correct + var acceptEvt *event.VerificationAcceptEventContent + if tc.sendingStartsSAS { + // Process the verification start event on the receiving + // device. + ts.dispatchToDevice(t, ctx, receivingClient) + + // Receiving device sent the accept event to the sending device + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + acceptEvt = sendingInbox[0].Content.AsVerificationAccept() + } else { + // Process the verification start event on the sending device. + ts.dispatchToDevice(t, ctx, sendingClient) + + // Sending device sent the accept event to the receiving device + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + acceptEvt = receivingInbox[0].Content.AsVerificationAccept() + } + assert.Equal(t, txnID, acceptEvt.TransactionID) + assert.Equal(t, acceptEvt.Hash, event.VerificationHashMethodSHA256) + assert.Equal(t, acceptEvt.KeyAgreementProtocol, event.KeyAgreementProtocolCurve25519HKDFSHA256) + assert.Equal(t, acceptEvt.MessageAuthenticationCode, event.MACMethodHKDFHMACSHA256V2) + assert.Contains(t, acceptEvt.ShortAuthenticationString, event.SASMethodDecimal) + assert.Contains(t, acceptEvt.ShortAuthenticationString, event.SASMethodEmoji) + assert.NotEmpty(t, acceptEvt.Commitment) + + // Test that the first key event is correct + var firstKeyEvt *event.VerificationKeyEventContent + if tc.sendingStartsSAS { + // Process the verification accept event on the sending device. + ts.dispatchToDevice(t, ctx, sendingClient) + + // Sending device sends first key event to the receiving + // device. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + firstKeyEvt = receivingInbox[0].Content.AsVerificationKey() + } else { + // Process the verification accept event on the receiving + // device. + ts.dispatchToDevice(t, ctx, receivingClient) + + // Receiving device sends first key event to the sending + // device. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + firstKeyEvt = sendingInbox[0].Content.AsVerificationKey() + } + assert.Equal(t, txnID, firstKeyEvt.TransactionID) + assert.NotEmpty(t, firstKeyEvt.Key) + assert.Len(t, firstKeyEvt.Key, 32) + + // Test that the second key event is correct + var secondKeyEvt *event.VerificationKeyEventContent + if tc.sendingStartsSAS { + // Process the first key event on the receiving device. + ts.dispatchToDevice(t, ctx, receivingClient) + + // Receiving device sends second key event to the sending + // device. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + secondKeyEvt = sendingInbox[0].Content.AsVerificationKey() + + // Ensure that the receiving device showed emojis and SAS numbers. + assert.Len(t, receivingCallbacks.GetDecimalsShown(txnID), 3) + assert.Len(t, receivingCallbacks.GetEmojisShown(txnID), 7) + } else { + // Process the first key event on the sending device. + ts.dispatchToDevice(t, ctx, sendingClient) + + // Sending device sends second key event to the receiving + // device. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + secondKeyEvt = receivingInbox[0].Content.AsVerificationKey() + + // Ensure that the sending device showed emojis and SAS numbers. + assert.Len(t, sendingCallbacks.GetDecimalsShown(txnID), 3) + assert.Len(t, sendingCallbacks.GetEmojisShown(txnID), 7) + } + assert.Equal(t, txnID, secondKeyEvt.TransactionID) + assert.NotEmpty(t, secondKeyEvt.Key) + assert.Len(t, secondKeyEvt.Key, 32) + + // Ensure that the SAS codes are the same. + if tc.sendingStartsSAS { + // Process the second key event on the sending device. + ts.dispatchToDevice(t, ctx, sendingClient) + } else { + // Process the second key event on the receiving device. + ts.dispatchToDevice(t, ctx, receivingClient) + } + assert.Equal(t, sendingCallbacks.GetDecimalsShown(txnID), receivingCallbacks.GetDecimalsShown(txnID)) + assert.Equal(t, sendingCallbacks.GetEmojisShown(txnID), receivingCallbacks.GetEmojisShown(txnID)) + + // Test that the first MAC event is correct + var firstMACEvt *event.VerificationMACEventContent + if tc.sendingConfirmsFirst { + err = sendingHelper.ConfirmSAS(ctx, txnID) + require.NoError(t, err) + + // The receiving device should have received the MAC event. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + firstMACEvt = receivingInbox[0].Content.AsVerificationMAC() + + // The MAC event should have a MAC for the sending device ID. + assert.Contains(t, maps.Keys(firstMACEvt.MAC), id.NewKeyID(id.KeyAlgorithmEd25519, sendingDeviceID.String())) + } else { + err = receivingHelper.ConfirmSAS(ctx, txnID) + require.NoError(t, err) + + // The sending device should have received the MAC event. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + firstMACEvt = sendingInbox[0].Content.AsVerificationMAC() + + // The MAC event should have a MAC for the receiving device ID. + assert.Contains(t, maps.Keys(firstMACEvt.MAC), id.NewKeyID(id.KeyAlgorithmEd25519, receivingDeviceID.String())) + } + assert.Equal(t, txnID, firstMACEvt.TransactionID) + + // The master key and the sending device ID should be in the + // MAC event's mac keys. + if tc.sendingGeneratedCrossSigningKeys { + assert.Contains(t, maps.Keys(firstMACEvt.MAC), id.NewKeyID(id.KeyAlgorithmEd25519, sendingCrossSigningKeysCache.MasterKey.PublicKey().String())) + } else { + assert.Contains(t, maps.Keys(firstMACEvt.MAC), id.NewKeyID(id.KeyAlgorithmEd25519, receivingCrossSigningKeysCache.MasterKey.PublicKey().String())) + } + + // Test that the second MAC event is correct + var secondMACEvt *event.VerificationMACEventContent + if tc.sendingConfirmsFirst { + err = receivingHelper.ConfirmSAS(ctx, txnID) + require.NoError(t, err) + + // The sending device should have received the MAC event. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + secondMACEvt = sendingInbox[0].Content.AsVerificationMAC() + + // The MAC event should have a MAC for the receiving device ID. + assert.Contains(t, maps.Keys(secondMACEvt.MAC), id.NewKeyID(id.KeyAlgorithmEd25519, receivingDeviceID.String())) + } else { + err = sendingHelper.ConfirmSAS(ctx, txnID) + require.NoError(t, err) + + // The receiving device should have received the MAC event. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + secondMACEvt = receivingInbox[0].Content.AsVerificationMAC() + + // The MAC event should have a MAC for the sending device ID. + assert.Contains(t, maps.Keys(secondMACEvt.MAC), id.NewKeyID(id.KeyAlgorithmEd25519, sendingDeviceID.String())) + } + assert.Equal(t, txnID, secondMACEvt.TransactionID) + + // The master key and the sending device ID should be in the + // MAC event's mac keys. + if tc.sendingGeneratedCrossSigningKeys { + assert.Contains(t, maps.Keys(firstMACEvt.MAC), id.NewKeyID(id.KeyAlgorithmEd25519, sendingCrossSigningKeysCache.MasterKey.PublicKey().String())) + } else { + assert.Contains(t, maps.Keys(firstMACEvt.MAC), id.NewKeyID(id.KeyAlgorithmEd25519, receivingCrossSigningKeysCache.MasterKey.PublicKey().String())) + } + + // Test the transaction is done on both sides. We have to dispatch + // twice to process and drain all of the events. + ts.dispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, sendingClient) + ts.dispatchToDevice(t, ctx, receivingClient) + assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) + assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) + }) + } +} From 5bdc3fdca0a11e2d1dc59e84ed43bd59bb78a04e Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Sun, 26 May 2024 22:18:12 -0600 Subject: [PATCH 0230/1647] verificationhelper: implement cross-signing Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 52 +++++++++++++++++++----- 1 file changed, 42 insertions(+), 10 deletions(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index e99e55fd..851878aa 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -9,6 +9,7 @@ package verificationhelper import ( "bytes" "context" + "errors" "fmt" "golang.org/x/exp/slices" @@ -47,8 +48,26 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by ownCrossSigningPublicKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) switch qrCode.Mode { case QRCodeModeCrossSigning: - panic("unimplemented") - // TODO verify and sign their master key + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) + if err != nil { + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUser, err) + } + if bytes.Equal(theirSigningKeys.MasterKey.Bytes(), qrCode.Key1[:]) { + log.Info().Msg("Verified that the other device has the master key we expected") + } else { + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the other device does not have the master key we expected") + } + + // Verify the master key is correct + if bytes.Equal(ownCrossSigningPublicKeys.MasterKey.Bytes(), qrCode.Key2[:]) { + log.Info().Msg("Verified that the other device has the same master key") + } else { + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the master key does not match") + } + + if err := vh.mach.SignUser(ctx, txn.TheirUser, theirSigningKeys.MasterKey); err != nil { + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their master key: %w", err) + } case QRCodeModeSelfVerifyingMasterKeyTrusted: // The QR was created by a device that trusts the master key, which // means that we don't trust the key. Key1 is the master key public @@ -75,7 +94,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by } if err := vh.mach.SignOwnMasterKey(ctx); err != nil { - return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "failed to sign own master key: %w", err) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign own master key: %w", err) } case QRCodeModeSelfVerifyingMasterKeyUntrusted: // The QR was created by a device that does not trust the master key, @@ -203,8 +222,15 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id } } } else { - // TODO: handle QR codes that are not self-signing situations - panic("unimplemented") + // Cross-signing situation. Sign their master key. + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) + if err != nil { + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUser, err) + } + + if err := vh.mach.SignUser(ctx, txn.TheirUser, theirSigningKeys.MasterKey); err != nil { + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their master key: %w", err) + } } err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) @@ -234,21 +260,27 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve ownCrossSigningPublicKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) if ownCrossSigningPublicKeys == nil || len(ownCrossSigningPublicKeys.MasterKey) == 0 { - return fmt.Errorf("failed to get own cross-signing master public key") + return errors.New("failed to get own cross-signing master public key") } + ownMasterKeyTrusted, err := vh.mach.CryptoStore.IsKeySignedBy(ctx, vh.client.UserID, ownCrossSigningPublicKeys.MasterKey, vh.client.UserID, vh.mach.OwnIdentity().SigningKey) + if err != nil { + return err + } mode := QRCodeModeCrossSigning if vh.client.UserID == txn.TheirUser { // This is a self-signing situation. - if trusted, err := vh.mach.CryptoStore.IsKeySignedBy(ctx, vh.client.UserID, ownCrossSigningPublicKeys.MasterKey, vh.client.UserID, vh.mach.OwnIdentity().SigningKey); err != nil { - return err - } else if trusted { + if ownMasterKeyTrusted { mode = QRCodeModeSelfVerifyingMasterKeyTrusted } else { mode = QRCodeModeSelfVerifyingMasterKeyUntrusted } } else { - panic("unimplemented") + // This is a cross-signing situation. + if !ownMasterKeyTrusted { + return errors.New("cannot cross-sign other device when own master key is not trusted") + } + mode = QRCodeModeCrossSigning } var key1, key2 []byte From c1e7cc53004289cbc99455b151a6d451524298ec Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Sun, 26 May 2024 22:18:34 -0600 Subject: [PATCH 0231/1647] verificationhelper: add test for QR code cross-signing Signed-off-by: Sumner Evans --- crypto/verificationhelper/mockserver_test.go | 7 + .../verificationhelper_qr_crosssign_test.go | 154 ++++++++++++++++++ ....go => verificationhelper_qr_self_test.go} | 8 +- .../verificationhelper_sas_test.go | 2 +- .../verificationhelper_test.go | 37 +++-- 5 files changed, 191 insertions(+), 17 deletions(-) create mode 100644 crypto/verificationhelper/verificationhelper_qr_crosssign_test.go rename crypto/verificationhelper/{verificationhelper_qr_test.go => verificationhelper_qr_self_test.go} (98%) diff --git a/crypto/verificationhelper/mockserver_test.go b/crypto/verificationhelper/mockserver_test.go index 3c640267..e35f51b2 100644 --- a/crypto/verificationhelper/mockserver_test.go +++ b/crypto/verificationhelper/mockserver_test.go @@ -17,6 +17,7 @@ import ( "testing" "github.com/gorilla/mux" + "github.com/rs/zerolog/log" "github.com/stretchr/testify/require" "go.mau.fi/util/random" @@ -222,6 +223,12 @@ func (ms *mockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, err = cryptoHelper.Init(ctx) require.NoError(t, err) + machineLog := log.Logger.With(). + Stringer("my_user_id", userID). + Stringer("my_device_id", deviceID). + Logger() + cryptoHelper.Machine().Log = &machineLog + err = cryptoHelper.Machine().ShareKeys(ctx, 50) require.NoError(t, err) diff --git a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go new file mode 100644 index 00000000..2bbed25e --- /dev/null +++ b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go @@ -0,0 +1,154 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package verificationhelper_test + +import ( + "context" + "fmt" + "testing" + + "github.com/rs/zerolog/log" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + testCases := []struct { + sendingScansQR bool // false indicates that receiving device should emulate a scan + }{ + {false}, + {true}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("sendingScansQR=%t", tc.sendingScansQR), func(t *testing.T) { + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginAliceBob(t, ctx) + defer ts.Close() + sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + var err error + + // Generate cross-signing keys for both users + _, _, err = sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + _, _, err = receivingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + + // Fetch each other's keys + sendingMachine.FetchKeys(ctx, []id.UserID{bobUserID}, true) + receivingMachine.FetchKeys(ctx, []id.UserID{aliceUserID}, true) + + // Send the verification request from the sender device and accept + // it on the receiving device and receive the verification ready + // event on the sending device. + txnID, err := sendingHelper.StartVerification(ctx, bobUserID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, sendingClient) + + receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, receivingShownQRCode) + sendingShownQRCode := sendingCallbacks.GetQRCodeShown(txnID) + require.NotNil(t, sendingShownQRCode) + + if tc.sendingScansQR { + // Emulate scanning the QR code shown by the receiving device + // on the sending device. + err := sendingHelper.HandleScannedQRData(ctx, receivingShownQRCode.Bytes()) + require.NoError(t, err) + + // Ensure that the receiving device received a verification + // start event and a verification done event. + receivingInbox := ts.DeviceInbox[bobUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 2) + + startEvt := receivingInbox[0].Content.AsVerificationStart() + assert.Equal(t, txnID, startEvt.TransactionID) + assert.Equal(t, sendingDeviceID, startEvt.FromDevice) + assert.Equal(t, event.VerificationMethodReciprocate, startEvt.Method) + assert.EqualValues(t, receivingShownQRCode.SharedSecret, startEvt.Secret) + + doneEvt := receivingInbox[1].Content.AsVerificationDone() + assert.Equal(t, txnID, doneEvt.TransactionID) + + // Handle the start and done events on the receiving client and + // confirm the scan. + ts.dispatchToDevice(t, ctx, receivingClient) + + // Ensure that the receiving device detected that its QR code + // was scanned. + assert.True(t, receivingCallbacks.WasOurQRCodeScanned(txnID)) + err = receivingHelper.ConfirmQRCodeScanned(ctx, txnID) + require.NoError(t, err) + + // Ensure that the sending device received a verification done + // event. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + require.Len(t, sendingInbox, 1) + doneEvt = sendingInbox[0].Content.AsVerificationDone() + assert.Equal(t, txnID, doneEvt.TransactionID) + + ts.dispatchToDevice(t, ctx, sendingClient) + } else { // receiving scans QR + // Emulate scanning the QR code shown by the sending device on + // the receiving device. + err := receivingHelper.HandleScannedQRData(ctx, sendingShownQRCode.Bytes()) + require.NoError(t, err) + + // Ensure that the sending device received a verification + // start event and a verification done event. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 2) + + startEvt := sendingInbox[0].Content.AsVerificationStart() + assert.Equal(t, txnID, startEvt.TransactionID) + assert.Equal(t, receivingDeviceID, startEvt.FromDevice) + assert.Equal(t, event.VerificationMethodReciprocate, startEvt.Method) + assert.EqualValues(t, sendingShownQRCode.SharedSecret, startEvt.Secret) + + doneEvt := sendingInbox[1].Content.AsVerificationDone() + assert.Equal(t, txnID, doneEvt.TransactionID) + + // Handle the start and done events on the receiving client and + // confirm the scan. + ts.dispatchToDevice(t, ctx, sendingClient) + + // Ensure that the sending device detected that its QR code was + // scanned. + assert.True(t, sendingCallbacks.WasOurQRCodeScanned(txnID)) + err = sendingHelper.ConfirmQRCodeScanned(ctx, txnID) + require.NoError(t, err) + + // Ensure that the receiving device received a verification + // done event. + receivingInbox := ts.DeviceInbox[bobUserID][receivingDeviceID] + require.Len(t, receivingInbox, 1) + doneEvt = receivingInbox[0].Content.AsVerificationDone() + assert.Equal(t, txnID, doneEvt.TransactionID) + + ts.dispatchToDevice(t, ctx, receivingClient) + } + + // Ensure that both devices have marked the verification as done. + assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) + assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) + + bobTrustsAlice, err := receivingMachine.IsUserTrusted(ctx, aliceUserID) + assert.NoError(t, err) + assert.True(t, bobTrustsAlice) + aliceTrustsBob, err := sendingMachine.IsUserTrusted(ctx, bobUserID) + assert.NoError(t, err) + assert.True(t, aliceTrustsBob) + }) + } +} diff --git a/crypto/verificationhelper/verificationhelper_qr_test.go b/crypto/verificationhelper/verificationhelper_qr_self_test.go similarity index 98% rename from crypto/verificationhelper/verificationhelper_qr_test.go rename to crypto/verificationhelper/verificationhelper_qr_self_test.go index 7391f151..443157b7 100644 --- a/crypto/verificationhelper/verificationhelper_qr_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_self_test.go @@ -35,7 +35,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGenerated=%t receivingGenerated=%t err=%s", tc.sendingGeneratedCrossSigningKeys, tc.receivingGeneratedCrossSigningKeys, tc.expectedAcceptError), func(t *testing.T) { - ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -134,7 +134,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR), func(t *testing.T) { - ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -250,7 +250,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) - ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -309,7 +309,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t corrupt=%d", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR, tc.corruptByte), func(t *testing.T) { - ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error diff --git a/crypto/verificationhelper/verificationhelper_sas_test.go b/crypto/verificationhelper/verificationhelper_sas_test.go index 04cddbf4..e986cf85 100644 --- a/crypto/verificationhelper/verificationhelper_sas_test.go +++ b/crypto/verificationhelper/verificationhelper_sas_test.go @@ -35,7 +35,7 @@ func TestVerification_SAS(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGenerated=%t sendingStartsSAS=%t sendingConfirmsFirst=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingStartsSAS, tc.sendingConfirmsFirst), func(t *testing.T) { - ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index d9e53b91..0f68e261 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -21,6 +21,7 @@ import ( ) var aliceUserID = id.UserID("@alice:example.org") +var bobUserID = id.UserID("@bob:example.org") var sendingDeviceID = id.DeviceID("sending") var receivingDeviceID = id.DeviceID("receiving") @@ -29,7 +30,7 @@ func init() { zerolog.DefaultContextLogger = &log.Logger } -func initServerAndLogin(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { +func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { t.Helper() ts = createMockServer(t) @@ -38,14 +39,26 @@ func initServerAndLogin(t *testing.T, ctx context.Context) (ts *mockServer, send receivingClient, receivingCryptoStore = ts.Login(t, ctx, aliceUserID, receivingDeviceID) receivingMachine = receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() - err := sendingCryptoStore.PutDevice(ctx, aliceUserID, sendingMachine.OwnIdentity()) - require.NoError(t, err) - err = sendingCryptoStore.PutDevice(ctx, aliceUserID, receivingMachine.OwnIdentity()) - require.NoError(t, err) - err = receivingCryptoStore.PutDevice(ctx, aliceUserID, sendingMachine.OwnIdentity()) - require.NoError(t, err) - err = receivingCryptoStore.PutDevice(ctx, aliceUserID, receivingMachine.OwnIdentity()) - require.NoError(t, err) + require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, sendingMachine.OwnIdentity())) + require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, receivingMachine.OwnIdentity())) + require.NoError(t, receivingCryptoStore.PutDevice(ctx, aliceUserID, sendingMachine.OwnIdentity())) + require.NoError(t, receivingCryptoStore.PutDevice(ctx, aliceUserID, receivingMachine.OwnIdentity())) + return +} + +func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { + t.Helper() + ts = createMockServer(t) + + sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID) + sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() + receivingClient, receivingCryptoStore = ts.Login(t, ctx, bobUserID, receivingDeviceID) + receivingMachine = receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() + + require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, sendingMachine.OwnIdentity())) + require.NoError(t, sendingCryptoStore.PutDevice(ctx, bobUserID, receivingMachine.OwnIdentity())) + require.NoError(t, receivingCryptoStore.PutDevice(ctx, aliceUserID, sendingMachine.OwnIdentity())) + require.NoError(t, receivingCryptoStore.PutDevice(ctx, bobUserID, receivingMachine.OwnIdentity())) return } @@ -183,7 +196,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { for i, tc := range testCases { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) defer ts.Close() recoveryKey, sendingCrossSigningKeysCache, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") @@ -270,7 +283,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { // not regress https://github.com/mautrix/go/pull/230. func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) - ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) defer ts.Close() _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) @@ -312,7 +325,7 @@ func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { func TestVerification_ErrorOnDoubleAccept(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) - ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLogin(t, ctx) + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) defer ts.Close() _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) From a6a3876403ab0f7355e79f107963c5846d006577 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 27 May 2024 09:02:04 -0600 Subject: [PATCH 0232/1647] keybackup: don't NPE if we couldn't get cross signing pubkeys Signed-off-by: Sumner Evans --- crypto/keybackup.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 820f3114..e0aff254 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -54,6 +54,9 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) } crossSigningPubkeys := mach.GetOwnCrossSigningPublicKeys(ctx) + if crossSigningPubkeys == nil { + return nil, ErrCrossSigningPubkeysNotCached + } signatureVerified := false for keyID := range userSignatures { From 289ef6f5dbfad593fbe1a30a5a683371d57d19f2 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 27 May 2024 09:03:15 -0600 Subject: [PATCH 0233/1647] verificationhelper: ensure cross-signing public keys are cached when handling QR data Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index 851878aa..9c6de067 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -14,6 +14,7 @@ import ( "golang.org/x/exp/slices" + "maunium.net/go/mautrix/crypto" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -46,6 +47,10 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by log.Info().Msg("Verifying keys from QR code") ownCrossSigningPublicKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) + if ownCrossSigningPublicKeys == nil { + return crypto.ErrCrossSigningPubkeysNotCached + } + switch qrCode.Mode { case QRCodeModeCrossSigning: theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) @@ -127,7 +132,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by } // Verify that what they think the master key is is correct. - if bytes.Equal(vh.mach.GetOwnCrossSigningPublicKeys(ctx).MasterKey.Bytes(), qrCode.Key2[:]) { + if bytes.Equal(ownCrossSigningPublicKeys.MasterKey.Bytes(), qrCode.Key2[:]) { log.Info().Msg("Verified that the other device has the correct master key") } else { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the master key does not match") From 1c750ffd0d0fcbdde04ccda52a31b906a1b192fd Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 27 May 2024 11:36:42 -0600 Subject: [PATCH 0234/1647] verificationhelper: fix CancelVerification * Calling `CancelVerification` no longer echoes an error back representing the reason for the cancellation. * Calling `CancelVerification` right after starting verification (but before another device has accepted the verification) now sends out the cancellation events to all devices that the request was initially sent out to. * Adds a test to ensure that the above statements are actually true. Signed-off-by: Sumner Evans verificationhelper: add test for cancellating right after starting verification Signed-off-by: Sumner Evans --- .../verificationhelper/verificationhelper.go | 54 +++++++++++++++---- .../verificationhelper_test.go | 50 +++++++++++++++++ 2 files changed, 93 insertions(+), 11 deletions(-) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index e2166a15..1fc1fd22 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -491,17 +491,52 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V // be the transaction ID of a verification request that was received via the // VerificationRequested callback in [RequiredCallbacks]. func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) error { - log := vh.getLog(ctx).With(). - Str("verification_action", "cancel verification"). - Stringer("transaction_id", txnID). - Logger() - ctx = log.WithContext(ctx) - txn, ok := vh.activeTransactions[txnID] if !ok { return fmt.Errorf("unknown transaction ID") } - return vh.cancelVerificationTxn(ctx, txn, code, reason) + log := vh.getLog(ctx).With(). + Str("verification_action", "cancel verification"). + Stringer("transaction_id", txnID). + Str("code", string(code)). + Str("reason", reason). + Logger() + ctx = log.WithContext(ctx) + + log.Info().Msg("Sending cancellation event") + cancelEvt := &event.VerificationCancelEventContent{Code: code, Reason: reason} + if len(txn.RoomID) > 0 { + // Sending the cancellation event to the room. + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, cancelEvt) + if err != nil { + return fmt.Errorf("failed to send cancel verification event (code: %s, reason: %s): %w", code, reason, err) + } + } else { + cancelEvt.SetTransactionID(txn.TransactionID) + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{ + txn.TheirUser: {}, + }} + if len(txn.TheirDevice) > 0 { + // Send the cancellation event to only the device that accepted the + // verification request. All of the other devices already received a + // cancellation event with code "m.acceped". + req.Messages[txn.TheirUser][txn.TheirDevice] = &event.Content{Parsed: cancelEvt} + } else { + // Send the cancellation event to all of the devices that we sent the + // request to. + for _, deviceID := range txn.SentToDeviceIDs { + if deviceID != vh.client.DeviceID { + req.Messages[txn.TheirUser][deviceID] = &event.Content{Parsed: cancelEvt} + } + } + } + _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) + if err != nil { + return fmt.Errorf("failed to send m.key.verification.cancel event to %v: %w", maps.Keys(req.Messages[txn.TheirUser]), err) + } + } + txn.VerificationState = verificationStateCancelled + return nil } // sendVerificationEvent sends a verification event to the other user's device @@ -549,10 +584,7 @@ func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn *ve Str("code", string(code)). Str("reason", reason). Msg("Sending cancellation event") - cancelEvt := &event.VerificationCancelEventContent{ - Code: code, - Reason: reason, - } + cancelEvt := &event.VerificationCancelEventContent{Code: code, Reason: reason} err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, cancelEvt) if err != nil { return fmt.Errorf("failed to send cancel verification event (code: %s, reason: %s): %w", code, reason, err) diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index 0f68e261..d6ed9b09 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -139,6 +139,56 @@ func TestVerification_Start(t *testing.T) { } } +func TestVerification_StartThenCancel(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + for _, sendingCancels := range []bool{true, false} { + t.Run(fmt.Sprintf("sendingCancels=%t", sendingCancels), func(t *testing.T) { + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() + _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + + assert.Empty(t, ts.DeviceInbox[aliceUserID][sendingDeviceID]) + + // Process the request event on the receiving device. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationRequest().TransactionID) + ts.dispatchToDevice(t, ctx, receivingClient) + + // Cancel the verification request on the sending device. + var cancelEvt *event.VerificationCancelEventContent + if sendingCancels { + err = sendingHelper.CancelVerification(ctx, txnID, event.VerificationCancelCodeUser, "Recovery code preferred") + assert.NoError(t, err) + + // The sending device should not have a cancellation event. + assert.Empty(t, ts.DeviceInbox[aliceUserID][sendingDeviceID]) + + // Ensure that the cancellation event was sent to the receiving device. + assert.Len(t, ts.DeviceInbox[aliceUserID][receivingDeviceID], 1) + cancelEvt = ts.DeviceInbox[aliceUserID][receivingDeviceID][0].Content.AsVerificationCancel() + } else { + err = receivingHelper.CancelVerification(ctx, txnID, event.VerificationCancelCodeUser, "Recovery code preferred") + assert.NoError(t, err) + + // The receiving device should not have a cancellation event. + assert.Empty(t, ts.DeviceInbox[aliceUserID][receivingDeviceID]) + + // Ensure that the cancellation event was sent to the sending device. + assert.Len(t, ts.DeviceInbox[aliceUserID][sendingDeviceID], 1) + cancelEvt = ts.DeviceInbox[aliceUserID][sendingDeviceID][0].Content.AsVerificationCancel() + } + assert.Equal(t, txnID, cancelEvt.TransactionID) + assert.Equal(t, event.VerificationCancelCodeUser, cancelEvt.Code) + assert.Equal(t, "Recovery code preferred", cancelEvt.Reason) + }) + } +} + func TestVerification_Accept_NoSupportedMethods(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) From 3885a6378ea5ce5ffc8800a8bdc4bb2e4cf60d3d Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 27 May 2024 15:24:42 -0600 Subject: [PATCH 0235/1647] verificationhelper: cancel if multiple requests received from same device Signed-off-by: Sumner Evans --- .../verificationhelper/verificationhelper.go | 25 ++++++--- .../verificationhelper_test.go | 55 ++++++++++++++++++- 2 files changed, 72 insertions(+), 8 deletions(-) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 1fc1fd22..aa38692a 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -663,13 +663,7 @@ func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *ev } vh.activeTransactionsLock.Lock() - existing, ok := vh.activeTransactions[verificationRequest.TransactionID] - if ok { - vh.activeTransactionsLock.Unlock() - vh.cancelVerificationTxn(ctx, existing, event.VerificationCancelCodeUnexpectedMessage, "received a new verification request for the same transaction ID") - return - } - vh.activeTransactions[verificationRequest.TransactionID] = &verificationTransaction{ + newTxn := &verificationTransaction{ RoomID: evt.RoomID, VerificationState: verificationStateRequested, TransactionID: verificationRequest.TransactionID, @@ -677,6 +671,23 @@ func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *ev TheirUser: evt.Sender, TheirSupportedMethods: verificationRequest.Methods, } + for existingTxnID, existingTxn := range vh.activeTransactions { + if existingTxn.TheirUser == evt.Sender && existingTxn.TheirDevice == verificationRequest.FromDevice { + vh.cancelVerificationTxn(ctx, existingTxn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") + vh.cancelVerificationTxn(ctx, newTxn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") + delete(vh.activeTransactions, existingTxnID) + vh.activeTransactionsLock.Unlock() + return + } + + if existingTxnID == verificationRequest.TransactionID { + vh.cancelVerificationTxn(ctx, existingTxn, event.VerificationCancelCodeUnexpectedMessage, "received a new verification request for the same transaction ID") + delete(vh.activeTransactions, existingTxnID) + vh.activeTransactionsLock.Unlock() + return + } + } + vh.activeTransactions[verificationRequest.TransactionID] = newTxn vh.activeTransactionsLock.Unlock() vh.verificationRequested(ctx, verificationRequest.TransactionID, evt.Sender) diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index d6ed9b09..e8be5771 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -388,5 +388,58 @@ func TestVerification_ErrorOnDoubleAccept(t *testing.T) { err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) err = receivingHelper.AcceptVerification(ctx, txnID) - require.ErrorContains(t, err, "transaction is not in the requested state") + assert.ErrorContains(t, err, "transaction is not in the requested state") +} + +// TestVerification_CancelOnDoubleStart ensures that the receiving device +// cancels both transactions if the sending device starts two verifications. +// +// This test ensures that the following bullet point from [Section 10.12.2.2.1 +// of the Spec] is followed: +// +// - When the same device attempts to initiate multiple verification attempts, +// the recipient should cancel all attempts with that device. +// +// [Section 10.12.2.2.1 of the Spec]: https://spec.matrix.org/v1.10/client-server-api/#error-and-exception-handling +func TestVerification_CancelOnDoubleStart(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() + sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + + _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + + // Send and accept the first verification request. + txnID1, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID1) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.ready event + + // Send a second verification request + txnID2, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, receivingClient) + + // Ensure that the sending device received a cancellation event for both of + // the ongoing transactions. + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + require.Len(t, sendingInbox, 2) + cancelEvt1 := sendingInbox[0].Content.AsVerificationCancel() + cancelEvt2 := sendingInbox[1].Content.AsVerificationCancel() + cancelledTxnIDs := []id.VerificationTransactionID{cancelEvt1.TransactionID, cancelEvt2.TransactionID} + assert.Contains(t, cancelledTxnIDs, txnID1) + assert.Contains(t, cancelledTxnIDs, txnID2) + assert.Equal(t, event.VerificationCancelCodeUnexpectedMessage, cancelEvt1.Code) + assert.Equal(t, event.VerificationCancelCodeUnexpectedMessage, cancelEvt2.Code) + assert.Equal(t, "received multiple verification requests from the same device", cancelEvt1.Reason) + assert.Equal(t, "received multiple verification requests from the same device", cancelEvt2.Reason) + + assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID1)) + assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID2)) + ts.dispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.cancel events + assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID1)) + assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID2)) } From cd4146f72810b0112d8a7226c01286c198ec845f Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 27 May 2024 16:18:03 -0600 Subject: [PATCH 0236/1647] verificationhelper: make auto-cancellations more spec-compliant * Prevents sending cancellation events in response to cancellation events that we don't know about. * Streamlines sending cancellations for all other unknown-transaction cases. * Ensures that the activeTransactionsLock is locked when calling cancelVerificationTxn. Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 4 +- crypto/verificationhelper/sas.go | 4 +- .../verificationhelper/verificationhelper.go | 88 +++++++++---------- 3 files changed, 46 insertions(+), 50 deletions(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index 9c6de067..2ea0a0ed 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -177,7 +177,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by txn.SentOurDone = true if txn.ReceivedTheirDone { log.Debug().Msg("We already received their done event. Setting verification state to done.") - txn.VerificationState = verificationStateDone + delete(vh.activeTransactions, txn.TransactionID) vh.verificationDone(ctx, txn.TransactionID) } return nil @@ -244,7 +244,7 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id } txn.SentOurDone = true if txn.ReceivedTheirDone { - txn.VerificationState = verificationStateDone + delete(vh.activeTransactions, txn.TransactionID) vh.verificationDone(ctx, txn.TransactionID) } return nil diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 6e523ba5..bf8c6050 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -564,6 +564,8 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi Str("verification_action", "mac"). Logger() log.Info().Msg("Received SAS verification MAC event") + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() macEvt := evt.Content.AsVerificationMAC() // Verifying Keys MAC @@ -646,8 +648,6 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi } log.Info().Msg("All MACs verified") - vh.activeTransactionsLock.Lock() - defer vh.activeTransactionsLock.Unlock() txn.ReceivedTheirMAC = true if txn.SentOurMAC { txn.VerificationState = verificationStateSASMACExchanged diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index aa38692a..c7f77734 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -29,8 +29,6 @@ type verificationState int const ( verificationStateRequested verificationState = iota verificationStateReady - verificationStateCancelled - verificationStateDone verificationStateTheirQRScanned // We scanned their QR code verificationStateOurQRScanned // They scanned our QR code @@ -47,8 +45,6 @@ func (step verificationState) String() string { return "requested" case verificationStateReady: return "ready" - case verificationStateCancelled: - return "cancelled" case verificationStateTheirQRScanned: return "their_qr_scanned" case verificationStateOurQRScanned: @@ -249,54 +245,49 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { if evt.ID != "" { transactionID = id.VerificationTransactionID(evt.ID) } else { - txnID, ok := evt.Content.Raw["transaction_id"].(string) - if !ok { + if txnID, ok := evt.Content.Parsed.(event.VerificationTransactionable); !ok { log.Warn().Msg("Ignoring verification event without a transaction ID") return + } else { + transactionID = txnID.GetTransactionID() } - transactionID = id.VerificationTransactionID(txnID) } log = log.With().Stringer("transaction_id", transactionID).Logger() vh.activeTransactionsLock.Lock() txn, ok := vh.activeTransactions[transactionID] - vh.activeTransactionsLock.Unlock() - if !ok || txn.VerificationState == verificationStateCancelled || txn.VerificationState == verificationStateDone { - var code event.VerificationCancelCode - var reason string - if !ok { - log.Warn().Msg("Ignoring verification event for an unknown transaction and sending cancellation") - - // We have to create a fake transaction so that the call to - // verificationCancelled works. - txn = &verificationTransaction{ - RoomID: evt.RoomID, - TheirUser: evt.Sender, - } - txn.TransactionID = evt.Content.Parsed.(event.VerificationTransactionable).GetTransactionID() - if txn.TransactionID == "" { - txn.TransactionID = id.VerificationTransactionID(evt.ID) - } - if fromDevice, ok := evt.Content.Raw["from_device"]; ok { - txn.TheirDevice = id.DeviceID(fromDevice.(string)) - } - code = event.VerificationCancelCodeUnknownTransaction - reason = "The transaction ID was not recognized." - } else if txn.VerificationState == verificationStateCancelled { - log.Warn().Msg("Ignoring verification event for a cancelled transaction") - code = event.VerificationCancelCodeUnexpectedMessage - reason = "The transaction is cancelled." - } else if txn.VerificationState == verificationStateDone { - code = event.VerificationCancelCodeUnexpectedMessage - reason = "The transaction is done." + if !ok { + // If it's a cancellation event for an unknown transaction, we + // can just ignore it. + if evt.Type == event.ToDeviceVerificationCancel || evt.Type == event.InRoomVerificationCancel { + log.Info().Msg("Ignoring verification cancellation event for an unknown transaction") + return } - // Send the actual cancellation event. - vh.cancelVerificationTxn(ctx, txn, code, reason) + // We have to create a fake transaction so that the call to + // verificationCancelled works. + txn = &verificationTransaction{ + RoomID: evt.RoomID, + TheirUser: evt.Sender, + } + if transactionable, ok := evt.Content.Parsed.(event.VerificationTransactionable); ok { + txn.TransactionID = transactionable.GetTransactionID() + } else { + txn.TransactionID = id.VerificationTransactionID(evt.ID) + } + if fromDevice, ok := evt.Content.Raw["from_device"]; ok { + txn.TheirDevice = id.DeviceID(fromDevice.(string)) + } + + // Send a cancellation event. + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownTransaction, "The transaction ID was not recognized.") + vh.activeTransactionsLock.Unlock() return + } else { + vh.activeTransactionsLock.Unlock() } - logCtx := vh.getLog(ctx).With(). + logCtx := log.With(). Stringer("transaction_step", txn.VerificationState). Stringer("sender", evt.Sender) if evt.RoomID != "" { @@ -491,6 +482,9 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V // be the transaction ID of a verification request that was received via the // VerificationRequested callback in [RequiredCallbacks]. func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) error { + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + txn, ok := vh.activeTransactions[txnID] if !ok { return fmt.Errorf("unknown transaction ID") @@ -535,7 +529,7 @@ func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.V return fmt.Errorf("failed to send m.key.verification.cancel event to %v: %w", maps.Keys(req.Messages[txn.TheirUser]), err) } } - txn.VerificationState = verificationStateCancelled + delete(vh.activeTransactions, txn.TransactionID) return nil } @@ -576,6 +570,8 @@ func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn *ve // and reason. It always returns an error, which is the formatted error message // (this is allows the caller to return the result of this function call // directly to expose the error to its caller). +// +// Must always be called with the activeTransactionsLock held. func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn *verificationTransaction, code event.VerificationCancelCode, reasonFmtStr string, fmtArgs ...any) error { log := vh.getLog(ctx) reason := fmt.Errorf(reasonFmtStr, fmtArgs...).Error() @@ -589,7 +585,7 @@ func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn *ve if err != nil { return fmt.Errorf("failed to send cancel verification event (code: %s, reason: %s): %w", code, reason, err) } - txn.VerificationState = verificationStateCancelled + delete(vh.activeTransactions, txn.TransactionID) vh.verificationCancelledCallback(ctx, txn.TransactionID, code, reason) return fmt.Errorf("verification cancelled (code: %s): %s", code, reason) } @@ -698,14 +694,14 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri Str("verification_action", "verification ready"). Logger() + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + if txn.VerificationState != verificationStateRequested { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "verification ready event received for a transaction that is not in the requested state") return } - vh.activeTransactionsLock.Lock() - defer vh.activeTransactionsLock.Unlock() - readyEvt := evt.Content.AsVerificationReady() // Update the transaction state. @@ -848,7 +844,7 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verif txn.ReceivedTheirDone = true if txn.SentOurDone { - txn.VerificationState = verificationStateDone + delete(vh.activeTransactions, txn.TransactionID) vh.verificationDone(ctx, txn.TransactionID) } } @@ -863,6 +859,6 @@ func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *ver Msg("Verification was cancelled") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn.VerificationState = verificationStateCancelled + delete(vh.activeTransactions, txn.TransactionID) vh.verificationCancelledCallback(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) } From 0b10e7346dd1fe55b57770b8cbc2c7901294b6e9 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 27 May 2024 18:23:15 -0600 Subject: [PATCH 0237/1647] verificationhelper: implement timeout logic Added 10-minute timeout for verification requests as per https://spec.matrix.org/v1.10/client-server-api/#error-and-exception-handling Signed-off-by: Sumner Evans --- .../verificationhelper/verificationhelper.go | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index c7f77734..fb6b1b40 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -12,6 +12,7 @@ import ( "crypto/ecdh" "fmt" "sync" + "time" "github.com/rs/zerolog" "go.mau.fi/util/jsontime" @@ -349,14 +350,16 @@ func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserI Any("device_ids", maps.Keys(devices)). Msg("Sending verification request") + now := time.Now() content := &event.Content{ Parsed: &event.VerificationRequestEventContent{ ToDeviceVerificationEvent: event.ToDeviceVerificationEvent{TransactionID: txnID}, FromDevice: vh.client.DeviceID, Methods: vh.supportedMethods, - Timestamp: jsontime.UnixMilliNow(), + Timestamp: jsontime.UM(now), }, } + vh.expireTransactionAt(txnID, now.Add(time.Minute*10)) req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{to: {}}} for deviceID := range devices { @@ -583,6 +586,7 @@ func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn *ve cancelEvt := &event.VerificationCancelEventContent{Code: code, Reason: reason} err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, cancelEvt) if err != nil { + log.Err(err).Msg("failed to send cancellation event") return fmt.Errorf("failed to send cancel verification event (code: %s, reason: %s): %w", code, reason, err) } delete(vh.activeTransactions, txn.TransactionID) @@ -623,6 +627,11 @@ func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *ev return } + if verificationRequest.Timestamp.Add(10 * time.Minute).Before(time.Now()) { + log.Warn().Msg("Ignoring verification request that is over ten minutes old") + return + } + if len(verificationRequest.TransactionID) == 0 { log.Warn().Msg("Ignoring verification request without a transaction ID") return @@ -686,9 +695,26 @@ func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *ev vh.activeTransactions[verificationRequest.TransactionID] = newTxn vh.activeTransactionsLock.Unlock() + vh.expireTransactionAt(verificationRequest.TransactionID, verificationRequest.Timestamp.Add(time.Minute*10)) vh.verificationRequested(ctx, verificationRequest.TransactionID, evt.Sender) } +func (vh *VerificationHelper) expireTransactionAt(txnID id.VerificationTransactionID, expireAt time.Time) { + go func() { + time.Sleep(time.Until(expireAt)) + + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + + txn, ok := vh.activeTransactions[txnID] + if !ok { + return + } + + vh.cancelVerificationTxn(context.Background(), txn, event.VerificationCancelCodeTimeout, "verification timed out") + }() +} + func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *verificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "verification ready"). From 3c7b3e13efe494ba223ff553348f1e046416d4f3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 28 May 2024 20:49:23 +0300 Subject: [PATCH 0238/1647] Add initial user and room metadata support --- bridgev2/database/ghost.go | 55 +++-- bridgev2/database/portal.go | 47 ++-- bridgev2/database/upgrades/00-latest.sql | 44 ++-- bridgev2/database/userlogin.go | 15 +- bridgev2/ghost.go | 139 +++++++++++ bridgev2/matrix/connector.go | 20 ++ bridgev2/matrix/intent.go | 7 + bridgev2/matrixinterface.go | 2 + bridgev2/networkinterface.go | 3 + bridgev2/portal.go | 290 ++++++++++++++++++++--- bridgev2/queue.go | 2 +- bridgev2/userlogin.go | 10 + client.go | 4 +- event/beeper.go | 9 + event/state.go | 6 +- 15 files changed, 552 insertions(+), 101 deletions(-) diff --git a/bridgev2/database/ghost.go b/bridgev2/database/ghost.go index e56eb13a..a814e7c4 100644 --- a/bridgev2/database/ghost.go +++ b/bridgev2/database/ghost.go @@ -8,6 +8,7 @@ package database import ( "context" + "encoding/hex" "go.mau.fi/util/dbutil" @@ -20,16 +21,25 @@ type GhostQuery struct { *dbutil.QueryHelper[*Ghost] } +type GhostMetadata struct { + IsBot bool `json:"is_bot,omitempty"` + Identifiers []string `json:"identifiers,omitempty"` + ContactInfoSet bool `json:"contact_info_set,omitempty"` + + Extra map[string]any `json:"extra"` +} + type Ghost struct { BridgeID networkid.BridgeID ID networkid.UserID - Name string - AvatarID networkid.AvatarID - AvatarMXC id.ContentURIString - NameSet bool - AvatarSet bool - Metadata map[string]any + Name string + AvatarID networkid.AvatarID + AvatarHash [32]byte + AvatarMXC id.ContentURIString + NameSet bool + AvatarSet bool + Metadata GhostMetadata } func newGhost(_ *dbutil.QueryHelper[*Ghost]) *Ghost { @@ -38,15 +48,15 @@ func newGhost(_ *dbutil.QueryHelper[*Ghost]) *Ghost { const ( getGhostBaseQuery = ` - SELECT bridge_id, id, name, avatar_id, avatar_mxc, name_set, avatar_set, metadata FROM ghost + SELECT bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, name_set, avatar_set, metadata FROM ghost ` getGhostByIDQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND id=$2` insertGhostQuery = ` - INSERT INTO ghost (bridge_id, id, name, avatar_id, avatar_mxc, name_set, avatar_set, metadata) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + INSERT INTO ghost (bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, name_set, avatar_set, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) ` updateGhostQuery = ` - UPDATE ghost SET name=$3, avatar_id=$4, avatar_mxc=$5, name_set=$6, avatar_set=$7, metadata=$8 + UPDATE ghost SET name=$3, avatar_id=$4, avatar_hash=$5, avatar_mxc=$6, name_set=$7, avatar_set=$8, metadata=$9 WHERE bridge_id=$1 AND id=$2 ` ) @@ -66,27 +76,38 @@ func (gq *GhostQuery) Update(ctx context.Context, ghost *Ghost) error { } func (g *Ghost) Scan(row dbutil.Scannable) (*Ghost, error) { + var avatarHash string err := row.Scan( &g.BridgeID, &g.ID, - &g.Name, &g.AvatarID, &g.AvatarMXC, + &g.Name, &g.AvatarID, &avatarHash, &g.AvatarMXC, &g.NameSet, &g.AvatarSet, dbutil.JSON{Data: &g.Metadata}, ) if err != nil { return nil, err } - if g.Metadata == nil { - g.Metadata = make(map[string]any) + if g.Metadata.Extra == nil { + g.Metadata.Extra = make(map[string]any) + } + if avatarHash != "" { + data, _ := hex.DecodeString(avatarHash) + if len(data) == 32 { + g.AvatarHash = *(*[32]byte)(data) + } } return g, nil } func (g *Ghost) sqlVariables() []any { - if g.Metadata == nil { - g.Metadata = make(map[string]any) + if g.Metadata.Extra == nil { + g.Metadata.Extra = make(map[string]any) + } + var avatarHash string + if g.AvatarHash != [32]byte{} { + avatarHash = hex.EncodeToString(g.AvatarHash[:]) } return []any{ g.BridgeID, g.ID, - g.Name, g.AvatarID, g.AvatarMXC, - g.NameSet, g.AvatarSet, dbutil.JSON{Data: g.Metadata}, + g.Name, g.AvatarID, avatarHash, g.AvatarMXC, + g.NameSet, g.AvatarSet, dbutil.JSON{Data: &g.Metadata}, } } diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index c0f581f6..36a38dfe 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -9,6 +9,7 @@ package database import ( "context" "database/sql" + "encoding/hex" "go.mau.fi/util/dbutil" @@ -26,16 +27,17 @@ type Portal struct { ID networkid.PortalID MXID id.RoomID - ParentID networkid.PortalID - Name string - Topic string - AvatarID networkid.AvatarID - AvatarMXC id.ContentURIString - NameSet bool - TopicSet bool - AvatarSet bool - InSpace bool - Metadata map[string]any + ParentID networkid.PortalID + Name string + Topic string + AvatarID networkid.AvatarID + AvatarHash [32]byte + AvatarMXC id.ContentURIString + NameSet bool + TopicSet bool + AvatarSet bool + InSpace bool + Metadata map[string]any } func newPortal(_ *dbutil.QueryHelper[*Portal]) *Portal { @@ -44,7 +46,7 @@ func newPortal(_ *dbutil.QueryHelper[*Portal]) *Portal { const ( getPortalBaseQuery = ` - SELECT bridge_id, id, mxid, parent_id, name, topic, avatar_id, avatar_mxc, + SELECT bridge_id, id, mxid, parent_id, name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, topic_set, avatar_set, in_space, metadata FROM portal @@ -56,15 +58,15 @@ const ( insertPortalQuery = ` INSERT INTO portal ( bridge_id, id, mxid, - parent_id, name, topic, avatar_id, avatar_mxc, + parent_id, name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, avatar_set, topic_set, in_space, metadata - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) ` updatePortalQuery = ` UPDATE portal - SET mxid=$3, parent_id=$4, name=$5, topic=$6, avatar_id=$7, avatar_mxc=$8, - name_set=$9, avatar_set=$10, topic_set=$11, in_space=$12, metadata=$13 + SET mxid=$3, parent_id=$4, name=$5, topic=$6, avatar_id=$7, avatar_hash=$8, avatar_mxc=$9, + name_set=$10, avatar_set=$11, topic_set=$12, in_space=$13, metadata=$14 WHERE bridge_id=$1 AND id=$2 ` reIDPortalQuery = `UPDATE portal SET id=$3 WHERE bridge_id=$1 AND id=$2` @@ -98,9 +100,10 @@ func (pq *PortalQuery) Update(ctx context.Context, p *Portal) error { func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { var mxid, parentID sql.NullString + var avatarHash string err := row.Scan( &p.BridgeID, &p.ID, &mxid, - &parentID, &p.Name, &p.Topic, &p.AvatarID, &p.AvatarMXC, + &parentID, &p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC, &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.InSpace, dbutil.JSON{Data: &p.Metadata}, ) @@ -110,6 +113,12 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { if p.Metadata == nil { p.Metadata = make(map[string]any) } + if avatarHash != "" { + data, _ := hex.DecodeString(avatarHash) + if len(data) == 32 { + p.AvatarHash = *(*[32]byte)(data) + } + } p.MXID = id.RoomID(mxid.String) p.ParentID = networkid.PortalID(parentID.String) return p, nil @@ -119,9 +128,13 @@ func (p *Portal) sqlVariables() []any { if p.Metadata == nil { p.Metadata = make(map[string]any) } + var avatarHash string + if p.AvatarHash != [32]byte{} { + avatarHash = hex.EncodeToString(p.AvatarHash[:]) + } return []any{ p.BridgeID, p.ID, dbutil.StrPtr(p.MXID), - dbutil.StrPtr(p.ParentID), p.Name, p.Topic, p.AvatarID, p.AvatarMXC, + dbutil.StrPtr(p.ParentID), p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC, p.NameSet, p.TopicSet, p.AvatarSet, p.InSpace, dbutil.JSON{Data: p.Metadata}, } diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 2f107689..b64a5507 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,19 +1,20 @@ -- v0 -> v1: Latest revision CREATE TABLE portal ( - bridge_id TEXT NOT NULL, - id TEXT NOT NULL, - mxid TEXT, + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, + mxid TEXT, - parent_id TEXT, - name TEXT NOT NULL, - topic TEXT NOT NULL, - avatar_id TEXT NOT NULL, - avatar_mxc TEXT NOT NULL, - name_set BOOLEAN NOT NULL, - avatar_set BOOLEAN NOT NULL, - topic_set BOOLEAN NOT NULL, - in_space BOOLEAN NOT NULL, - metadata jsonb NOT NULL, + parent_id TEXT, + name TEXT NOT NULL, + topic TEXT NOT NULL, + avatar_id TEXT NOT NULL, + avatar_hash TEXT NOT NULL, + avatar_mxc TEXT NOT NULL, + name_set BOOLEAN NOT NULL, + avatar_set BOOLEAN NOT NULL, + topic_set BOOLEAN NOT NULL, + in_space BOOLEAN NOT NULL, + metadata jsonb NOT NULL, PRIMARY KEY (bridge_id, id), CONSTRAINT portal_parent_fkey FOREIGN KEY (bridge_id, parent_id) @@ -23,15 +24,16 @@ CREATE TABLE portal ( ); CREATE TABLE ghost ( - bridge_id TEXT NOT NULL, - id TEXT NOT NULL, + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, - name TEXT NOT NULL, - avatar_id TEXT NOT NULL, - avatar_mxc TEXT NOT NULL, - name_set BOOLEAN NOT NULL, - avatar_set BOOLEAN NOT NULL, - metadata jsonb NOT NULL, + name TEXT NOT NULL, + avatar_id TEXT NOT NULL, + avatar_hash TEXT NOT NULL, + avatar_mxc TEXT NOT NULL, + name_set BOOLEAN NOT NULL, + avatar_set BOOLEAN NOT NULL, + metadata jsonb NOT NULL, PRIMARY KEY (bridge_id, id) ); diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go index 8dfd6cbc..bec14841 100644 --- a/bridgev2/database/userlogin.go +++ b/bridgev2/database/userlogin.go @@ -37,9 +37,14 @@ const ( getUserLoginBaseQuery = ` SELECT bridge_id, user_mxid, id, space_room, metadata FROM user_login ` - getAllLoginsQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1` - getAllLoginsForUserQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND user_mxid=$2` - insertUserLoginQuery = ` + getAllLoginsQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1` + getAllLoginsForUserQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND user_mxid=$2` + getAllLoginsInPortalQuery = ` + SELECT ul.bridge_id, ul.user_mxid, ul.id, ul.space_room, ul.metadata FROM user_portal + LEFT JOIN user_login ul ON user_portal.bridge_id=ul.bridge_id AND user_portal.user_mxid=ul.user_mxid AND user_portal.login_id=ul.id + WHERE user_portal.bridge_id=$1 AND user_portal.portal_id=$2 + ` + insertUserLoginQuery = ` INSERT INTO user_login (bridge_id, user_mxid, id, space_room, metadata) VALUES ($1, $2, $3, $4, $5) ` @@ -66,6 +71,10 @@ func (uq *UserLoginQuery) GetAll(ctx context.Context) ([]*UserLogin, error) { return uq.QueryMany(ctx, getAllLoginsQuery, uq.BridgeID) } +func (uq *UserLoginQuery) GetAllInPortal(ctx context.Context, portalID networkid.PortalID) ([]*UserLogin, error) { + return uq.QueryMany(ctx, getAllLoginsInPortalQuery, uq.BridgeID, portalID) +} + func (uq *UserLoginQuery) GetAllForUser(ctx context.Context, userID id.UserID) ([]*UserLogin, error) { return uq.QueryMany(ctx, getAllLoginsForUserQuery, uq.BridgeID, userID) } diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index 467cafed..79eff612 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -8,12 +8,17 @@ package bridgev2 import ( "context" + "crypto/sha256" "fmt" + "net/http" "github.com/rs/zerolog" + "go.mau.fi/util/exmime" + "golang.org/x/exp/slices" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -85,3 +90,137 @@ func (ghost *Ghost) IntentFor(portal *Portal) MatrixAPI { // TODO use user double puppet intent if appropriate return ghost.Intent } + +type Avatar struct { + ID networkid.AvatarID + Get func(ctx context.Context) ([]byte, error) + Remove bool +} + +func (a *Avatar) Reupload(ctx context.Context, intent MatrixAPI, currentHash [32]byte) (id.ContentURIString, [32]byte, error) { + data, err := a.Get(ctx) + if err != nil { + return "", [32]byte{}, err + } + hash := sha256.Sum256(data) + if hash == currentHash { + return "", hash, nil + } + mime := http.DetectContentType(data) + fileName := "avatar" + exmime.ExtensionFromMimetype(mime) + uri, _, err := intent.UploadMedia(ctx, "", data, fileName, mime) + if err != nil { + return "", hash, err + } + return uri, hash, nil +} + +type UserInfo struct { + Identifiers []string + Name *string + Avatar *Avatar + IsBot *bool +} + +func (ghost *Ghost) UpdateName(ctx context.Context, name string) bool { + if ghost.Name == name && ghost.NameSet { + return false + } + ghost.Name = name + ghost.NameSet = false + err := ghost.Intent.SetDisplayName(ctx, name) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to set display name") + } else { + ghost.NameSet = true + } + return true +} + +func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool { + if ghost.AvatarID == avatar.ID && ghost.AvatarSet { + return false + } + ghost.AvatarID = avatar.ID + if !avatar.Remove { + newMXC, newHash, err := avatar.Reupload(ctx, ghost.Intent, ghost.AvatarHash) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload avatar") + return true + } else if newHash == ghost.AvatarHash { + return true + } + ghost.AvatarMXC = newMXC + } else { + ghost.AvatarMXC = "" + } + ghost.AvatarSet = false + if err := ghost.Intent.SetAvatarURL(ctx, ghost.AvatarMXC); err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to set avatar URL") + } else { + ghost.AvatarSet = true + } + return true +} + +func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool) bool { + if identifiers != nil { + slices.Sort(identifiers) + } + if ghost.Metadata.ContactInfoSet && + (identifiers == nil || slices.Equal(identifiers, ghost.Metadata.Identifiers)) && + (isBot == nil || *isBot == ghost.Metadata.IsBot) { + return false + } + if identifiers != nil { + ghost.Metadata.Identifiers = identifiers + } + if isBot != nil { + ghost.Metadata.IsBot = *isBot + } + meta := &event.BeeperProfileExtra{ + RemoteID: string(ghost.ID), + Identifiers: ghost.Metadata.Identifiers, + Service: "", // TODO set + Network: "", // TODO set + IsBridgeBot: false, + IsNetworkBot: ghost.Metadata.IsBot, + } + err := ghost.Intent.SetExtraProfileMeta(ctx, meta) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to set extra profile metadata") + } else { + ghost.Metadata.ContactInfoSet = true + } + return true +} + +func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin) { + if ghost.Name != "" && ghost.NameSet { + return + } + info, err := source.Client.GetUserInfo(ctx, ghost) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get info to update ghost") + } + ghost.UpdateInfo(ctx, info) +} + +func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) { + update := false + if info.Name != nil { + update = ghost.UpdateName(ctx, *info.Name) || update + } + if info.Avatar != nil { + update = ghost.UpdateAvatar(ctx, info.Avatar) || update + } + if info.Identifiers != nil || info.IsBot != nil { + update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot) || update + } + if update { + err := ghost.Bridge.DB.Ghost.Update(ctx, ghost.Ghost) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to update ghost in database after updating info") + } + } +} diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 22dc50b3..d4af2506 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -51,6 +51,8 @@ type Connector struct { Config *bridgeconfig.Config Bridge *bridgev2.Bridge + SpecVersions *mautrix.RespVersions + EventProcessor *appservice.EventProcessor userIDRegex *regexp.Regexp @@ -102,6 +104,10 @@ func (br *Connector) Start(ctx context.Context) error { return err } go br.AS.Start() + br.SpecVersions, err = br.Bot.Versions(ctx) + if err != nil { + return err + } if br.Crypto != nil { err = br.Crypto.Init(ctx) if err != nil { @@ -173,6 +179,20 @@ func (br *Connector) BotIntent() bridgev2.MatrixAPI { return &ASIntent{Connector: br, Matrix: br.Bot} } +func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { + // TODO use cache? + members, err := br.Bot.Members(ctx, roomID) + if err != nil { + return nil, err + } + output := make(map[id.UserID]*event.MemberEventContent, len(members.Chunk)) + for _, evt := range members.Chunk { + _ = evt.Content.ParseRaw(evt.Type) + output[id.UserID(evt.GetStateKey())] = evt.Content.AsMember() + } + return output, nil +} + func (br *Connector) GetMemberInfo(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) { // TODO fetch from network sometimes? return br.AS.StateStore.GetMember(ctx, roomID, userID) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index ef4913b9..b9158bf3 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -135,6 +135,13 @@ func (as *ASIntent) SetAvatarURL(ctx context.Context, avatarURL id.ContentURIStr return as.Matrix.SetAvatarURL(ctx, parsedAvatarURL) } +func (as *ASIntent) SetExtraProfileMeta(ctx context.Context, data any) error { + if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) { + return nil + } + return as.Matrix.BeeperUpdateProfile(ctx, data) +} + func (as *ASIntent) GetMXID() id.UserID { return as.Matrix.UserID } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index e8e23628..4dd6dd46 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -29,6 +29,7 @@ type MatrixConnector interface { SendMessageStatus(ctx context.Context, status MessageStatus) + GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) GetMemberInfo(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) ServerName() string @@ -44,6 +45,7 @@ type MatrixAPI interface { SetDisplayName(ctx context.Context, name string) error SetAvatarURL(ctx context.Context, avatarURL id.ContentURIString) error + SetExtraProfileMeta(ctx context.Context, data any) error CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) DeleteRoom(ctx context.Context, roomID id.RoomID) error diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index b70fd89c..e8be6765 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -51,7 +51,10 @@ type NetworkConnector interface { type NetworkAPI interface { Connect(ctx context.Context) error IsLoggedIn() bool + + IsThisUser(ctx context.Context, userID networkid.UserID) bool GetChatInfo(ctx context.Context, portal *Portal) (*PortalInfo, error) + GetUserInfo(ctx context.Context, ghost *Ghost) (*UserInfo, error) HandleMatrixMessage(ctx context.Context, msg *MatrixMessage) (message *database.Message, err error) HandleMatrixEdit(ctx context.Context, msg *MatrixEdit) error diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 072d3971..046a56b5 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -8,6 +8,7 @@ package bridgev2 import ( "context" + "errors" "fmt" "strings" "sync" @@ -481,7 +482,7 @@ func (portal *Portal) getIntentFor(ctx context.Context, sender EventSender, sour zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for message sender") return nil } - // TODO update ghost info + ghost.UpdateInfoIfNecessary(ctx, source) intent = ghost.Intent } return intent @@ -587,15 +588,246 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use var stateElementFunctionalMembers = event.Type{Class: event.StateEventType, Type: "io.element.functional_members"} type PortalInfo struct { - Name string - Topic string - AvatarID networkid.AvatarID - AvatarMXC id.ContentURIString + Name *string + Topic *string + Avatar *Avatar Members []networkid.UserID - IsDirectChat bool - IsSpace bool + IsDirectChat *bool + IsSpace *bool +} + +func (portal *Portal) UpdateName(ctx context.Context, name string, sender *Ghost, ts time.Time) bool { + if portal.Name == name && (portal.NameSet || portal.MXID == "") { + return false + } + portal.Name = name + portal.NameSet = portal.sendRoomMeta(ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name}) + return true +} + +func (portal *Portal) UpdateTopic(ctx context.Context, topic string, sender *Ghost, ts time.Time) bool { + if portal.Topic == topic && (portal.TopicSet || portal.MXID == "") { + return false + } + portal.Topic = topic + portal.TopicSet = portal.sendRoomMeta(ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic}) + return true +} + +func (portal *Portal) UpdateAvatar(ctx context.Context, avatar *Avatar, sender *Ghost, ts time.Time) bool { + if portal.AvatarID == avatar.ID && (portal.AvatarSet || portal.MXID == "") { + return false + } + portal.AvatarID = avatar.ID + intent := portal.Bridge.Bot + if sender != nil { + intent = sender.IntentFor(portal) + } + if avatar.Remove { + portal.AvatarMXC = "" + portal.AvatarHash = [32]byte{} + } else { + newMXC, newHash, err := avatar.Reupload(ctx, intent, portal.AvatarHash) + if err != nil { + portal.AvatarSet = false + zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload room avatar") + return true + } else if newHash == portal.AvatarHash { + return true + } + portal.AvatarMXC = newMXC + portal.AvatarHash = newHash + } + portal.AvatarSet = portal.sendRoomMeta(ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}) + return true +} + +func (portal *Portal) GetTopLevelParent() *Portal { + // TODO ensure there's no infinite recursion? + if portal.Parent == nil { + // TODO only return self if this is a space portal + return portal + } + return portal.Parent.GetTopLevelParent() +} + +func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { + bridgeInfo := event.BridgeEventContent{ + BridgeBot: portal.Bridge.Bot.GetMXID(), + Creator: portal.Bridge.Bot.GetMXID(), + Protocol: event.BridgeInfoSection{ + ID: "signal", // TODO fill properly + DisplayName: "Signal", // TODO fill properly + AvatarURL: "", // TODO fill properly + ExternalURL: "https://signal.org/", // TODO fill properly + }, + Channel: event.BridgeInfoSection{ + ID: string(portal.ID), + DisplayName: portal.Name, + AvatarURL: portal.AvatarMXC, + // TODO external URL? + }, + // TODO room type + } + parent := portal.GetTopLevelParent() + if parent != nil { + bridgeInfo.Network = &event.BridgeInfoSection{ + ID: string(parent.ID), + DisplayName: parent.Name, + AvatarURL: parent.AvatarMXC, + // TODO external URL? + } + } + // TODO use something globally unique instead of bridge ID? + // maybe ask the matrix connector to use serverName+appserviceID+bridgeID + stateKey := string(portal.BridgeID) + return stateKey, bridgeInfo +} + +func (portal *Portal) UpdateBridgeInfo(ctx context.Context) { + if portal.MXID == "" { + return + } + stateKey, bridgeInfo := portal.getBridgeInfo() + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo) + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo) +} + +func (portal *Portal) sendRoomMeta(ctx context.Context, sender *Ghost, ts time.Time, eventType event.Type, stateKey string, content any) bool { + if portal.MXID == "" { + return false + } + + intent := portal.Bridge.Bot + if sender != nil { + intent = sender.IntentFor(portal) + } + wrappedContent := &event.Content{Parsed: content} + _, err := intent.SendState(ctx, portal.MXID, eventType, stateKey, wrappedContent, ts) + if errors.Is(err, mautrix.MForbidden) && intent != portal.Bridge.Bot { + wrappedContent.Raw = map[string]any{ + "fi.mau.bridge.set_by": intent.GetMXID(), + } + _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateRoomName, "", wrappedContent, ts) + } + if err != nil { + zerolog.Ctx(ctx).Err(err). + Str("event_type", eventType.Type). + Msg("Failed to set room metadata") + return false + } + return true +} + +func (portal *Portal) SyncParticipants(ctx context.Context, members []networkid.UserID, source *UserLogin) ([]id.UserID, error) { + loginsInPortal, err := portal.Bridge.GetUserLoginsInPortal(ctx, portal.ID) + if err != nil { + return nil, fmt.Errorf("failed to get user logins in portal: %w", err) + } + expectedUserIDs := make([]id.UserID, 0, len(members)) + expectedExtraUsers := make([]id.UserID, 0) + expectedIntents := make([]MatrixAPI, len(members)) + for i, member := range members { + for _, login := range loginsInPortal { + if login.Client.IsThisUser(ctx, member) { + userIntent := portal.Bridge.Matrix.UserIntent(login.User) + if userIntent != nil { + expectedIntents[i] = userIntent + } else { + expectedExtraUsers = append(expectedExtraUsers, login.UserMXID) + expectedUserIDs = append(expectedUserIDs, login.UserMXID) + } + break + } + } + ghost, err := portal.Bridge.GetGhostByID(ctx, member) + if err != nil { + return nil, fmt.Errorf("failed to get ghost for %s: %w", member, err) + } + ghost.UpdateInfoIfNecessary(ctx, source) + if expectedIntents[i] == nil { + expectedIntents[i] = ghost.Intent + } + expectedUserIDs = append(expectedUserIDs, expectedIntents[i].GetMXID()) + } + if portal.MXID == "" { + return expectedUserIDs, nil + } + currentMembers, err := portal.Bridge.Matrix.GetMembers(ctx, portal.MXID) + for _, intent := range expectedIntents { + mxid := intent.GetMXID() + memberEvt, ok := currentMembers[mxid] + delete(currentMembers, mxid) + if !ok || memberEvt.Membership != event.MembershipJoin { + err = intent.EnsureJoined(ctx, portal.MXID) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("user_id", mxid). + Msg("Failed to ensure user is joined to room") + } + } + } + for _, mxid := range expectedExtraUsers { + memberEvt, ok := currentMembers[mxid] + delete(currentMembers, mxid) + if !ok || (memberEvt.Membership != event.MembershipJoin && memberEvt.Membership != event.MembershipInvite) { + err = portal.Bridge.Bot.InviteUser(ctx, portal.MXID, mxid) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("user_id", mxid). + Msg("Failed to invite user to room") + } + } + } + if portal.Relay == nil { + for extraMember, memberEvt := range currentMembers { + if memberEvt.Membership == event.MembershipLeave || memberEvt.Membership == event.MembershipBan { + continue + } + _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateMember, extraMember.String(), &event.Content{ + Parsed: &event.MemberEventContent{ + Membership: event.MembershipLeave, + AvatarURL: memberEvt.AvatarURL, + Displayname: memberEvt.Displayname, + Reason: "User is not in remote chat", + }, + }, time.Now()) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("user_id", extraMember). + Msg("Failed to remove user from room") + } + } + } + return expectedUserIDs, nil +} + +func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, sender *Ghost, ts time.Time) { + changed := false + if info.Name != nil { + changed = portal.UpdateName(ctx, *info.Name, sender, ts) || changed + } + if info.Topic != nil { + changed = portal.UpdateTopic(ctx, *info.Topic, sender, ts) || changed + } + if info.Avatar != nil { + changed = portal.UpdateAvatar(ctx, info.Avatar, sender, ts) || changed + } + //if info.Members != nil && portal.MXID != "" { + // _, err := portal.SyncParticipants(ctx, info.Members, source) + // if err != nil { + // zerolog.Ctx(ctx).Err(err).Msg("Failed to sync room members") + // } + //} + if changed { + portal.UpdateBridgeInfo(ctx) + err := portal.Bridge.DB.Portal.Update(ctx, portal.Portal) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal to database after updating info") + } + } } func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) error { @@ -615,24 +847,12 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) e log.Err(err).Msg("Failed to update portal info for creation") return err } - portal.Name = info.Name - portal.Topic = info.Topic - portal.AvatarID = info.AvatarID - portal.AvatarMXC = info.AvatarMXC - invite := make([]id.UserID, 0, len(info.Members)+1) - inviteIntents := make([]MatrixAPI, 0, len(info.Members)+1) - for _, memberID := range info.Members { - ghost, err := portal.Bridge.GetGhostByID(ctx, memberID) - if err != nil { - log.Err(err).Str("memebr_id", string(memberID)).Msg("Failed to get portal member ghost") - } else { - invite = append(invite, ghost.MXID) - inviteIntents = append(inviteIntents, ghost.Intent) - } + portal.UpdateInfo(ctx, info, nil, time.Time{}) + initialMembers, err := portal.SyncParticipants(ctx, info.Members, source) + if err != nil { + log.Err(err).Msg("Failed to process participant list for portal creation") + return err } - // TODO should the source user mxid come from members? - invite = append(invite, source.UserMXID) - inviteIntents = append(inviteIntents, portal.Bridge.Matrix.UserIntent(source.User)) req := mautrix.ReqCreateRoom{ Visibility: "private", @@ -641,23 +861,23 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) e CreationContent: make(map[string]any), InitialState: make([]*event.Event, 0, 4), Preset: "private_chat", - IsDirect: info.IsDirectChat, + IsDirect: *info.IsDirectChat, PowerLevelOverride: &event.PowerLevelsEventContent{ Users: map[id.UserID]int{ portal.Bridge.Bot.GetMXID(): 9001, }, }, BeeperLocalRoomID: id.RoomID(fmt.Sprintf("!%s:%s", portal.ID, portal.Bridge.Matrix.ServerName())), - BeeperInitialMembers: invite, + BeeperInitialMembers: initialMembers, } // TODO find this properly from the matrix connector isBeeper := true // TODO remove this after initial_members is supported in hungryserv if isBeeper { req.BeeperAutoJoinInvites = true - req.Invite = invite + req.Invite = initialMembers } - if info.IsSpace { + if *info.IsSpace { req.CreationContent["type"] = event.RoomTypeSpace } emptyString := "" @@ -681,8 +901,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) e req.InitialState = append(req.InitialState, &event.Event{ StateKey: &emptyString, Type: event.StateRoomAvatar, - // TODO change RoomAvatarEventContent to have id.ContentURIString instead of id.ContentURI? - Content: event.Content{Raw: map[string]any{"url": portal.AvatarMXC}}, + Content: event.Content{Parsed: &event.RoomAvatarEventContent{URL: portal.AvatarMXC}}, }) } if portal.Parent != nil { @@ -719,14 +938,9 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) e // TODO add m.space.child event } if !isBeeper { - for i, mxid := range invite { - intent := inviteIntents[i] - // TODO handle errors - if intent != nil { - intent.EnsureJoined(ctx, portal.MXID) - } else { - portal.Bridge.Bot.InviteUser(ctx, portal.MXID, mxid) - } + _, err = portal.SyncParticipants(ctx, info.Members, source) + if err != nil { + log.Err(err).Msg("Failed to sync participants after room creation") } } return nil diff --git a/bridgev2/queue.go b/bridgev2/queue.go index f27094df..4e3a6549 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -79,7 +79,7 @@ func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { ctx := log.WithContext(context.TODO()) portal, err := br.GetPortalByID(ctx, evt.GetPortalID()) if err != nil { - log.Err(err).Str("portal_id", string(portal.ID)). + log.Err(err).Str("portal_id", string(evt.GetPortalID())). Msg("Failed to get portal to handle remote event") return } diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 0a04662d..8d082ac1 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -83,6 +83,16 @@ func (br *Bridge) GetAllUserLogins(ctx context.Context) ([]*UserLogin, error) { return br.loadManyUserLogins(ctx, nil, logins) } +func (br *Bridge) GetUserLoginsInPortal(ctx context.Context, portalID networkid.PortalID) ([]*UserLogin, error) { + logins, err := br.DB.UserLogin.GetAllInPortal(ctx, portalID) + if err != nil { + return nil, err + } + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.loadManyUserLogins(ctx, nil, logins) +} + func (br *Bridge) GetCachedUserLoginByID(id networkid.UserLoginID) *UserLogin { br.cacheLock.Lock() defer br.cacheLock.Unlock() diff --git a/client.go b/client.go index 10d1b2b9..ab99245c 100644 --- a/client.go +++ b/client.go @@ -946,9 +946,9 @@ func (cli *Client) SetAvatarURL(ctx context.Context, url id.ContentURI) (err err } // BeeperUpdateProfile sets custom fields in the user's profile. -func (cli *Client) BeeperUpdateProfile(ctx context.Context, data map[string]any) (err error) { +func (cli *Client) BeeperUpdateProfile(ctx context.Context, data any) (err error) { urlPath := cli.BuildClientURL("v3", "profile", cli.UserID) - _, err = cli.MakeRequest(ctx, http.MethodPatch, urlPath, &data, nil) + _, err = cli.MakeRequest(ctx, http.MethodPatch, urlPath, data, nil) return } diff --git a/event/beeper.go b/event/beeper.go index 5e412504..3287e494 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -88,3 +88,12 @@ type BeeperLinkPreview struct { MatchedURL string `json:"matched_url,omitempty"` ImageEncryption *EncryptedFileInfo `json:"beeper:image:encryption,omitempty"` } + +type BeeperProfileExtra struct { + RemoteID string `json:"com.beeper.bridge.remote_id,omitempty"` + Identifiers []string `json:"com.beeper.bridge.identifiers,omitempty"` + Service string `json:"com.beeper.bridge.service,omitempty"` + Network string `json:"com.beeper.bridge.network,omitempty"` + IsBridgeBot bool `json:"com.beeper.bridge.is_bridge_bot,omitempty"` + IsNetworkBot bool `json:"com.beeper.bridge.is_network_bot,omitempty"` +} diff --git a/event/state.go b/event/state.go index d6b6cf70..83f007a1 100644 --- a/event/state.go +++ b/event/state.go @@ -26,8 +26,8 @@ type RoomNameEventContent struct { // RoomAvatarEventContent represents the content of a m.room.avatar state event. // https://spec.matrix.org/v1.2/client-server-api/#mroomavatar type RoomAvatarEventContent struct { - URL id.ContentURI `json:"url"` - Info *FileInfo `json:"info,omitempty"` + URL id.ContentURIString `json:"url"` + Info *FileInfo `json:"info,omitempty"` } // ServerACLEventContent represents the content of a m.room.server_acl state event. @@ -149,6 +149,8 @@ type BridgeEventContent struct { Protocol BridgeInfoSection `json:"protocol"` Network *BridgeInfoSection `json:"network,omitempty"` Channel BridgeInfoSection `json:"channel"` + + BeeperRoomType string `json:"com.beeper.room_type,omitempty"` } type SpaceChildEventContent struct { From 248de0e6adb2ca42ca55bcb1755b7337dda0c260 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 29 May 2024 16:55:54 +0300 Subject: [PATCH 0239/1647] Add config field for network connector --- bridgev2/bridgeconfig/config.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index eeaa6d48..cfe1cffe 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -9,6 +9,7 @@ package bridgeconfig import ( "go.mau.fi/util/dbutil" "go.mau.fi/zeroconfig" + "gopkg.in/yaml.v3" ) type Config struct { @@ -22,6 +23,7 @@ type Config struct { Permissions PermissionConfig `yaml:"permissions"` ManagementRoomTexts ManagementRoomTexts `yaml:"management_room_texts"` Logging zeroconfig.Config `yaml:"logging"` + Network yaml.Node `yaml:"network"` } type BridgeConfig struct { From 57f6cd89e3bf148e7ab497ef076e054fd506e2ab Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 28 May 2024 17:24:37 -0600 Subject: [PATCH 0240/1647] (*Client).Download: return entire response instead of just body Signed-off-by: Sumner Evans --- appservice/intent.go | 6 +++--- client.go | 12 ++---------- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/appservice/intent.go b/appservice/intent.go index e091582a..39c22d7f 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -427,9 +427,9 @@ func (intent *IntentAPI) SetAvatarURL(ctx context.Context, avatarURL id.ContentU } if !avatarURL.IsEmpty() { // Some homeservers require the avatar to be downloaded before setting it - body, _ := intent.Client.Download(ctx, avatarURL) - if body != nil { - _ = body.Close() + resp, _ := intent.Download(ctx, avatarURL) + if resp != nil { + _ = resp.Body.Close() } } return intent.Client.SetAvatarURL(ctx, avatarURL) diff --git a/client.go b/client.go index 10d1b2b9..2ae05c44 100644 --- a/client.go +++ b/client.go @@ -1403,14 +1403,6 @@ func (cli *Client) GetDownloadURL(mxcURL id.ContentURI) string { return cli.BuildURLWithQuery(MediaURLPath{"v3", "download", mxcURL.Homeserver, mxcURL.FileID}, map[string]string{"allow_redirect": "true"}) } -func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (io.ReadCloser, error) { - resp, err := cli.download(ctx, mxcURL) - if err != nil { - return nil, err - } - return resp.Body, nil -} - func (cli *Client) doMediaRetry(req *http.Request, cause error, retries int, backoff time.Duration) (*http.Response, error) { log := zerolog.Ctx(req.Context()) if req.Body != nil { @@ -1467,7 +1459,7 @@ func (cli *Client) doMediaRequest(req *http.Request, retries int, backoff time.D return res, err } -func (cli *Client) download(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) { +func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) { ctxLog := zerolog.Ctx(ctx) if ctxLog.GetLevel() == zerolog.Disabled || ctxLog == zerolog.DefaultContextLogger { ctx = cli.Log.WithContext(ctx) @@ -1481,7 +1473,7 @@ func (cli *Client) download(ctx context.Context, mxcURL id.ContentURI) (*http.Re } func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) { - resp, err := cli.download(ctx, mxcURL) + resp, err := cli.Download(ctx, mxcURL) if err != nil { return nil, err } From 2ec680ba4e72502f808554ee09d3e7a81c32bdb7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 29 May 2024 23:45:08 +0300 Subject: [PATCH 0241/1647] Add receipt storage, relation caches, pagination and other stuff --- hicli/database/database.go | 6 + hicli/database/event.go | 255 ++++++++++++++++-- hicli/database/receipt.go | 69 +++++ hicli/database/room.go | 49 +++- hicli/database/state.go | 2 +- hicli/database/timeline.go | 73 ++++- .../database/upgrades/00-latest-revision.sql | 129 ++++++++- hicli/decryptionqueue.go | 8 + hicli/events.go | 25 ++ hicli/hicli.go | 10 + hicli/hitest/hitest.go | 1 + hicli/login.go | 9 +- hicli/paginate.go | 91 +++++++ hicli/sync.go | 208 +++++++++----- hicli/syncwrap.go | 8 +- 15 files changed, 816 insertions(+), 127 deletions(-) create mode 100644 hicli/database/receipt.go create mode 100644 hicli/events.go create mode 100644 hicli/paginate.go diff --git a/hicli/database/database.go b/hicli/database/database.go index c1273ab7..8ec2b42a 100644 --- a/hicli/database/database.go +++ b/hicli/database/database.go @@ -22,6 +22,7 @@ type Database struct { CurrentState CurrentStateQuery Timeline TimelineQuery SessionRequest SessionRequestQuery + Receipt ReceiptQuery } func New(rawDB *dbutil.Database) *Database { @@ -36,6 +37,7 @@ func New(rawDB *dbutil.Database) *Database { CurrentState: CurrentStateQuery{Database: rawDB}, Timeline: TimelineQuery{Database: rawDB}, SessionRequest: SessionRequestQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newSessionRequest)}, + Receipt: ReceiptQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newReceipt)}, } } @@ -51,6 +53,10 @@ func newRoom(_ *dbutil.QueryHelper[*Room]) *Room { return &Room{} } +func newReceipt(_ *dbutil.QueryHelper[*Receipt]) *Receipt { + return &Receipt{} +} + func newAccountData(_ *dbutil.QueryHelper[*AccountData]) *AccountData { return &AccountData{} } diff --git a/hicli/database/event.go b/hicli/database/event.go index b7b15eea..ab68583f 100644 --- a/hicli/database/event.go +++ b/hicli/database/event.go @@ -9,6 +9,8 @@ package database import ( "database/sql" "encoding/json" + "fmt" + "strings" "time" "github.com/tidwall/gjson" @@ -22,14 +24,15 @@ import ( const ( getEventBaseQuery = ` - SELECT rowid, room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, - redacted_by, relates_to, megolm_session_id, decryption_error + SELECT rowid, -1, room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, + redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, reactions, last_edit_rowid FROM event ` + getEventByID = getEventBaseQuery + `WHERE event_id = $1` getFailedEventsByMegolmSessionID = getEventBaseQuery + `WHERE room_id = $1 AND megolm_session_id = $2 AND decryption_error IS NOT NULL` upsertEventQuery = ` - INSERT INTO event (room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, redacted_by, relates_to, megolm_session_id, decryption_error) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + INSERT INTO event (room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) ON CONFLICT (event_id) DO UPDATE SET decrypted=COALESCE(event.decrypted, excluded.decrypted), decrypted_type=COALESCE(event.decrypted_type, excluded.decrypted_type), @@ -38,6 +41,38 @@ const ( RETURNING rowid ` updateEventDecryptedQuery = `UPDATE event SET decrypted = $1, decrypted_type = $2, decryption_error = NULL WHERE rowid = $3` + getTimelineQuery = ` + SELECT event.rowid, timeline.rowid, event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, + redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, reactions, last_edit_rowid + FROM timeline + JOIN event ON event.rowid = timeline.event_rowid + WHERE timeline.room_id = $1 AND timeline.rowid < $2 + ORDER BY timeline.rowid DESC + LIMIT $3 + ` + getEventReactionsQuery = getEventBaseQuery + ` + WHERE room_id = ? + AND type = 'm.reaction' + AND relation_type = 'm.annotation' + AND redacted_by IS NULL + AND relates_to IN (%s) + ` + getEventEditRowIDsQuery = ` + SELECT main.event_id, edit.rowid + FROM event main + JOIN event edit ON edit.room_id = main.room_id + AND edit.relates_to = main.event_id + AND edit.relation_type = 'm.replace' + AND edit.type = main.type + AND edit.sender = main.sender + AND edit.redacted_by IS NULL + WHERE main.event_id IN (%s) + ORDER BY main.event_id, edit.timestamp + ` + setLastEditRowIDQuery = ` + UPDATE event SET last_edit_rowid = $2 WHERE event_id = $1 + ` + updateReactionCountsQuery = `UPDATE event SET reactions = $2 WHERE event_id = $1` ) type EventQuery struct { @@ -48,35 +83,171 @@ func (eq *EventQuery) GetFailedByMegolmSessionID(ctx context.Context, roomID id. return eq.QueryMany(ctx, getFailedEventsByMegolmSessionID, roomID, sessionID) } -func (eq *EventQuery) Upsert(ctx context.Context, evt *Event) (rowID int64, err error) { +func (eq *EventQuery) GetByID(ctx context.Context, eventID id.EventID) (*Event, error) { + return eq.QueryOne(ctx, getEventByID, eventID) +} + +func (eq *EventQuery) Upsert(ctx context.Context, evt *Event) (rowID EventRowID, err error) { err = eq.GetDB().QueryRow(ctx, upsertEventQuery, evt.sqlVariables()...).Scan(&rowID) + if err == nil { + evt.RowID = rowID + } return } -func (eq *EventQuery) UpdateDecrypted(ctx context.Context, rowID int64, decrypted json.RawMessage, decryptedType string) error { +func (eq *EventQuery) UpdateDecrypted(ctx context.Context, rowID EventRowID, decrypted json.RawMessage, decryptedType string) error { return eq.Exec(ctx, updateEventDecryptedQuery, unsafeJSONString(decrypted), decryptedType, rowID) } +func (eq *EventQuery) GetTimeline(ctx context.Context, roomID id.RoomID, limit int, before TimelineRowID) ([]*Event, error) { + return eq.QueryMany(ctx, getTimelineQuery, roomID, before, limit) +} + +func (eq *EventQuery) FillReactionCounts(ctx context.Context, roomID id.RoomID, events []*Event) error { + eventIDs := make([]id.EventID, 0) + eventMap := make(map[id.EventID]*Event) + for i, evt := range events { + if evt.Reactions == nil { + eventIDs[i] = evt.ID + eventMap[evt.ID] = evt + } + } + result, err := eq.GetReactions(ctx, roomID, eventIDs...) + if err != nil { + return err + } + for evtID, res := range result { + eventMap[evtID].Reactions = res.Counts + } + return nil +} + +func (eq *EventQuery) FillLastEditRowIDs(ctx context.Context, roomID id.RoomID, events []*Event) error { + eventIDs := make([]id.EventID, 0) + eventMap := make(map[id.EventID]*Event) + for i, evt := range events { + if evt.LastEditRowID == 0 { + eventIDs[i] = evt.ID + eventMap[evt.ID] = evt + } + } + return eq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error { + result, err := eq.GetEditRowIDs(ctx, roomID, eventIDs...) + if err != nil { + return err + } + for evtID, res := range result { + lastEditRowID := res[len(res)-1] + eventMap[evtID].LastEditRowID = lastEditRowID + err = eq.Exec(ctx, setLastEditRowIDQuery, evtID, lastEditRowID) + if err != nil { + return err + } + } + return nil + }) +} + +var reactionKeyPath = exgjson.Path("m.relates_to", "key") + +type GetReactionsResult struct { + Events []*Event + Counts map[string]int +} + +func buildMultiEventGetFunction(roomID id.RoomID, eventIDs []id.EventID, query string) (string, []any) { + params := make([]any, len(eventIDs)+1) + params[0] = roomID + for i, evtID := range eventIDs { + params[i+1] = evtID + } + placeholders := strings.Repeat("?,", len(eventIDs)) + placeholders = placeholders[:len(placeholders)-1] + return fmt.Sprintf(query, placeholders), params +} + +type editRowIDTuple struct { + eventID id.EventID + editRowID EventRowID +} + +func (eq *EventQuery) GetEditRowIDs(ctx context.Context, roomID id.RoomID, eventIDs ...id.EventID) (map[id.EventID][]EventRowID, error) { + query, params := buildMultiEventGetFunction(roomID, eventIDs, getEventEditRowIDsQuery) + rows, err := eq.GetDB().Query(ctx, query, params...) + output := make(map[id.EventID][]EventRowID) + return output, dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (tuple editRowIDTuple, err error) { + err = row.Scan(&tuple.eventID, &tuple.editRowID) + return + }, err).Iter(func(tuple editRowIDTuple) (bool, error) { + output[tuple.eventID] = append(output[tuple.eventID], tuple.editRowID) + return true, nil + }) +} + +func (eq *EventQuery) GetReactions(ctx context.Context, roomID id.RoomID, eventIDs ...id.EventID) (map[id.EventID]*GetReactionsResult, error) { + result := make(map[id.EventID]*GetReactionsResult, len(eventIDs)) + for _, evtID := range eventIDs { + result[evtID] = &GetReactionsResult{Counts: make(map[string]int)} + } + return result, eq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error { + query, params := buildMultiEventGetFunction(roomID, eventIDs, getEventReactionsQuery) + events, err := eq.QueryMany(ctx, query, params...) + if err != nil { + return err + } else if len(events) == 0 { + return nil + } + for _, evt := range events { + dest := result[evt.RelatesTo] + dest.Events = append(dest.Events, evt) + keyRes := gjson.GetBytes(evt.Content, reactionKeyPath) + if keyRes.Type == gjson.String { + dest.Counts[keyRes.Str]++ + } + } + for evtID, res := range result { + if len(res.Counts) > 0 { + err = eq.Exec(ctx, updateReactionCountsQuery, evtID, dbutil.JSON{Data: &res.Counts}) + if err != nil { + return err + } + } + } + return nil + }) +} + +type EventRowID int64 + +func (m EventRowID) GetMassInsertValues() [1]any { + return [1]any{m} +} + type Event struct { - RowID int64 + RowID EventRowID `json:"fi.mau.hicli.rowid"` + TimelineRowID TimelineRowID `json:"fi.mau.hicli.timeline_rowid"` - RoomID id.RoomID - ID id.EventID - Sender id.UserID - Type string - StateKey *string - Timestamp time.Time + RoomID id.RoomID `json:"room_id"` + ID id.EventID `json:"event_id"` + Sender id.UserID `json:"sender"` + Type string `json:"type"` + StateKey *string `json:"state_key,omitempty"` + Timestamp time.Time `json:"timestamp"` - Content json.RawMessage - Decrypted json.RawMessage - DecryptedType string - Unsigned json.RawMessage + Content json.RawMessage `json:"content"` + Decrypted json.RawMessage `json:"decrypted,omitempty"` + DecryptedType string `json:"decrypted_type,omitempty"` + Unsigned json.RawMessage `json:"unsigned,omitempty"` - RedactedBy id.EventID - RelatesTo id.EventID + RedactedBy id.EventID `json:"redacted_by,omitempty"` + RelatesTo id.EventID `json:"relates_to,omitempty"` + RelationType event.RelationType `json:"relation_type,omitempty"` - MegolmSessionID id.SessionID + MegolmSessionID id.SessionID `json:"-,omitempty"` DecryptionError string + + Reactions map[string]int + LastEditRowID EventRowID } func MautrixToEvent(evt *event.Event) *Event { @@ -88,9 +259,9 @@ func MautrixToEvent(evt *event.Event) *Event { StateKey: evt.StateKey, Timestamp: time.UnixMilli(evt.Timestamp), Content: evt.Content.VeryRaw, - RelatesTo: getRelatesTo(evt), MegolmSessionID: getMegolmSessionID(evt), } + dbEvt.RelatesTo, dbEvt.RelationType = getRelatesTo(evt) dbEvt.Unsigned, _ = json.Marshal(&evt.Unsigned) if evt.Unsigned.RedactedBecause != nil { dbEvt.RedactedBy = evt.Unsigned.RedactedBecause.ID @@ -122,9 +293,11 @@ func (e *Event) AsRawMautrix() *event.Event { func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { var timestamp int64 - var redactedBy, relatesTo, megolmSessionID, decryptionError, decryptedType sql.NullString + var redactedBy, relatesTo, relationType, megolmSessionID, decryptionError, decryptedType sql.NullString + var lastEditRowID sql.NullInt64 err := row.Scan( &e.RowID, + &e.TimelineRowID, &e.RoomID, &e.ID, &e.Sender, @@ -137,8 +310,11 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { (*[]byte)(&e.Unsigned), &redactedBy, &relatesTo, + &relationType, &megolmSessionID, &decryptionError, + dbutil.JSON{Data: &e.Reactions}, + &lastEditRowID, ) if err != nil { return nil, err @@ -146,20 +322,26 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { e.Timestamp = time.UnixMilli(timestamp) e.RedactedBy = id.EventID(redactedBy.String) e.RelatesTo = id.EventID(relatesTo.String) + e.RelationType = event.RelationType(relatesTo.String) e.MegolmSessionID = id.SessionID(megolmSessionID.String) e.DecryptedType = decryptedType.String e.DecryptionError = decryptionError.String + e.LastEditRowID = EventRowID(lastEditRowID.Int64) return e, nil } var relatesToPath = exgjson.Path("m.relates_to", "event_id") +var relationTypePath = exgjson.Path("m.relates_to", "rel_type") -func getRelatesTo(evt *event.Event) id.EventID { - res := gjson.GetBytes(evt.Content.VeryRaw, relatesToPath) - if res.Exists() && res.Type == gjson.String { - return id.EventID(res.Str) +func getRelatesTo(evt *event.Event) (id.EventID, event.RelationType) { + if evt.StateKey != nil { + return "", "" } - return "" + results := gjson.GetManyBytes(evt.Content.VeryRaw, relatesToPath, relationTypePath) + if len(results) == 2 && results[0].Exists() && results[1].Exists() && results[0].Type == gjson.String && results[1].Type == gjson.String { + return id.EventID(results[0].Str), event.RelationType(results[1].Str) + } + return "", "" } func getMegolmSessionID(evt *event.Event) id.SessionID { @@ -174,6 +356,10 @@ func getMegolmSessionID(evt *event.Event) id.SessionID { } func (e *Event) sqlVariables() []any { + var reactions any + if e.Reactions != nil { + reactions = e.Reactions + } return []any{ e.RoomID, e.ID, @@ -187,7 +373,22 @@ func (e *Event) sqlVariables() []any { unsafeJSONString(e.Unsigned), dbutil.StrPtr(e.RedactedBy), dbutil.StrPtr(e.RelatesTo), + dbutil.StrPtr(e.RelationType), dbutil.StrPtr(e.MegolmSessionID), dbutil.StrPtr(e.DecryptionError), + dbutil.JSON{Data: reactions}, + dbutil.NumPtr(e.LastEditRowID), } } + +func (e *Event) CanUseForPreview() bool { + return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || + (e.Type == event.EventEncrypted.Type && + (e.DecryptedType == event.EventMessage.Type || e.DecryptedType == event.EventSticker.Type))) && + e.RelationType != event.RelReplace +} + +func (e *Event) BumpsSortingTimestamp() bool { + return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || e.Type == event.EventEncrypted.Type) && + e.RelationType != event.RelReplace +} diff --git a/hicli/database/receipt.go b/hicli/database/receipt.go new file mode 100644 index 00000000..8757f4d6 --- /dev/null +++ b/hicli/database/receipt.go @@ -0,0 +1,69 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "time" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const ( + upsertReceiptQuery = ` + INSERT INTO receipt (room_id, user_id, receipt_type, thread_id, event_id, timestamp) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (room_id, user_id, receipt_type, thread_id) DO UPDATE + SET event_id = excluded.event_id, + timestamp = excluded.timestamp + ` +) + +var receiptMassInserter = dbutil.NewMassInsertBuilder[*Receipt, [1]any](upsertReceiptQuery, "($1, $%d, $%d, $%d, $%d, $%d)") + +type ReceiptQuery struct { + *dbutil.QueryHelper[*Receipt] +} + +func (rq *ReceiptQuery) Put(ctx context.Context, receipt *Receipt) error { + return rq.Exec(ctx, upsertReceiptQuery, receipt.sqlVariables()...) +} + +func (rq *ReceiptQuery) PutMany(ctx context.Context, roomID id.RoomID, receipts ...*Receipt) error { + query, params := receiptMassInserter.Build([1]any{roomID}, receipts) + return rq.Exec(ctx, query, params...) +} + +type Receipt struct { + RoomID id.RoomID + UserID id.UserID + ReceiptType event.ReceiptType + ThreadID event.ThreadID + EventID id.EventID + Timestamp time.Time +} + +func (r *Receipt) Scan(row dbutil.Scannable) (*Receipt, error) { + var ts int64 + err := row.Scan(&r.RoomID, &r.UserID, &r.ReceiptType, &r.ThreadID, &r.EventID, &ts) + if err != nil { + return nil, err + } + r.Timestamp = time.UnixMilli(ts) + return r, nil +} + +func (r *Receipt) sqlVariables() []any { + return []any{r.RoomID, r.UserID, r.ReceiptType, r.ThreadID, r.EventID, r.Timestamp.UnixMilli()} +} + +func (r *Receipt) GetMassInsertValues() [5]any { + return [5]any{r.UserID, r.ReceiptType, r.ThreadID, r.EventID, r.Timestamp.UnixMilli()} +} diff --git a/hicli/database/room.go b/hicli/database/room.go index c7d13fca..0f6f5a74 100644 --- a/hicli/database/room.go +++ b/hicli/database/room.go @@ -9,6 +9,7 @@ package database import ( "context" "database/sql" + "time" "go.mau.fi/util/dbutil" @@ -18,10 +19,12 @@ import ( ) const ( - getRoomByIDQuery = ` - SELECT room_id, creation_content, name, avatar, topic, lazy_load_summary, encryption_event, has_member_list, prev_batch - FROM room WHERE room_id = $1 + getRoomBaseQuery = ` + SELECT room_id, creation_content, name, avatar, topic, lazy_load_summary, encryption_event, has_member_list, + preview_event_rowid, sorting_timestamp, prev_batch + FROM room ` + getRoomByIDQuery = getRoomBaseQuery + `WHERE room_id = $1` ensureRoomExistsQuery = ` INSERT INTO room (room_id) VALUES ($1) ON CONFLICT (room_id) DO NOTHING @@ -35,12 +38,20 @@ const ( lazy_load_summary = COALESCE($6, room.lazy_load_summary), encryption_event = COALESCE($7, room.encryption_event), has_member_list = room.has_member_list OR $8, - prev_batch = COALESCE(room.prev_batch, $9) + preview_event_rowid = COALESCE($9, room.preview_event_rowid), + sorting_timestamp = COALESCE($10, room.sorting_timestamp), + prev_batch = COALESCE($11, room.prev_batch) WHERE room_id = $1 ` setRoomPrevBatchQuery = ` - INSERT INTO room (room_id, prev_batch) VALUES ($1, $2) - ON CONFLICT (room_id) DO UPDATE SET prev_batch = excluded.prev_batch + UPDATE room SET prev_batch = $2 WHERE room_id = $1 + ` + updateRoomPreviewIfLaterOnTimelineQuery = ` + UPDATE room + SET preview_event_rowid = $2 + WHERE room_id = $1 + AND COALESCE((SELECT rowid FROM timeline WHERE event_rowid = $2), -1) + > COALESCE((SELECT rowid FROM timeline WHERE event_rowid = preview_event_rowid), 0) ` ) @@ -64,6 +75,10 @@ func (rq *RoomQuery) SetPrevBatch(ctx context.Context, roomID id.RoomID, prevBat return rq.Exec(ctx, setRoomPrevBatchQuery, roomID, prevBatch) } +func (rq *RoomQuery) UpdatePreviewIfLaterOnTimeline(ctx context.Context, roomID id.RoomID, rowID EventRowID) error { + return rq.Exec(ctx, updateRoomPreviewIfLaterOnTimelineQuery, roomID, rowID) +} + type Room struct { ID id.RoomID CreationContent *event.CreateEventContent @@ -77,11 +92,15 @@ type Room struct { EncryptionEvent *event.EncryptionEventContent HasMemberList bool + PreviewEventRowID EventRowID + SortingTimestamp time.Time + PrevBatch string } func (r *Room) Scan(row dbutil.Scannable) (*Room, error) { var prevBatch sql.NullString + var previewEventRowID, sortingTimestamp sql.NullInt64 err := row.Scan( &r.ID, dbutil.JSON{Data: &r.CreationContent}, @@ -91,12 +110,16 @@ func (r *Room) Scan(row dbutil.Scannable) (*Room, error) { dbutil.JSON{Data: &r.LazyLoadSummary}, dbutil.JSON{Data: &r.EncryptionEvent}, &r.HasMemberList, + &previewEventRowID, + &sortingTimestamp, &prevBatch, ) if err != nil { return nil, err } r.PrevBatch = prevBatch.String + r.PreviewEventRowID = EventRowID(previewEventRowID.Int64) + r.SortingTimestamp = time.UnixMilli(sortingTimestamp.Int64) return r, nil } @@ -110,6 +133,20 @@ func (r *Room) sqlVariables() []any { dbutil.JSONPtr(r.LazyLoadSummary), dbutil.JSONPtr(r.EncryptionEvent), r.HasMemberList, + dbutil.NumPtr(r.PreviewEventRowID), + dbutil.UnixMilliPtr(r.SortingTimestamp), dbutil.StrPtr(r.PrevBatch), } } + +func (r *Room) BumpSortingTimestamp(evt *Event) bool { + if !evt.BumpsSortingTimestamp() || evt.Timestamp.Before(r.SortingTimestamp) { + return false + } + r.SortingTimestamp = evt.Timestamp + now := time.Now() + if r.SortingTimestamp.After(now) { + r.SortingTimestamp = now + } + return true +} diff --git a/hicli/database/state.go b/hicli/database/state.go index 47c91dcf..31c5adda 100644 --- a/hicli/database/state.go +++ b/hicli/database/state.go @@ -26,7 +26,7 @@ type CurrentStateQuery struct { *dbutil.Database } -func (csq *CurrentStateQuery) Set(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID int64, membership event.Membership) error { +func (csq *CurrentStateQuery) Set(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID EventRowID, membership event.Membership) error { _, err := csq.Exec(ctx, setCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership)) return err } diff --git a/hicli/database/timeline.go b/hicli/database/timeline.go index 585e55bb..107c7ddb 100644 --- a/hicli/database/timeline.go +++ b/hicli/database/timeline.go @@ -8,6 +8,9 @@ package database import ( "context" + "database/sql" + "errors" + "sync" "go.mau.fi/util/dbutil" @@ -18,30 +21,86 @@ const ( clearTimelineQuery = ` DELETE FROM timeline WHERE room_id = $1 ` - setTimelineQuery = ` + appendTimelineQuery = ` INSERT INTO timeline (room_id, event_rowid) VALUES ($1, $2) ` + prependTimelineQuery = ` + INSERT INTO timeline (room_id, rowid, event_rowid) VALUES ($1, $2, $3) + ` + findMinRowIDQuery = `SELECT MIN(rowid) FROM timeline` ) -type MassInsertableRowID int64 +type TimelineRowID int64 -func (m MassInsertableRowID) GetMassInsertValues() [1]any { - return [1]any{m} +type TimelinePrepend struct { + Timeline TimelineRowID + Event EventRowID } -var setTimelineQueryBuilder = dbutil.NewMassInsertBuilder[MassInsertableRowID, [1]any](setTimelineQuery, "($1, $%d)") +func (tp TimelinePrepend) GetMassInsertValues() [2]any { + return [2]any{tp.Timeline, tp.Event} +} + +var appendTimelineQueryBuilder = dbutil.NewMassInsertBuilder[EventRowID, [1]any](appendTimelineQuery, "($1, $%d)") +var prependTimelineQueryBuilder = dbutil.NewMassInsertBuilder[TimelinePrepend, [1]any](prependTimelineQuery, "($1, $%d, $%d)") type TimelineQuery struct { *dbutil.Database + + minRowID TimelineRowID + minRowIDFound bool + prependLock sync.Mutex } +// Clear clears the timeline of a given room. func (tq *TimelineQuery) Clear(ctx context.Context, roomID id.RoomID) error { _, err := tq.Exec(ctx, clearTimelineQuery, roomID) return err } -func (tq *TimelineQuery) Append(ctx context.Context, roomID id.RoomID, rowIDs []MassInsertableRowID) error { - query, params := setTimelineQueryBuilder.Build([1]any{roomID}, rowIDs) +func (tq *TimelineQuery) reserveRowIDs(ctx context.Context, count int) (startFrom TimelineRowID, err error) { + tq.prependLock.Lock() + defer tq.prependLock.Unlock() + if !tq.minRowIDFound { + err = tq.QueryRow(ctx, findMinRowIDQuery).Scan(&tq.minRowID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return + } + if tq.minRowID >= 0 { + // No negative row IDs exist, start at -1 + tq.minRowID = -1 + } else { + // We fetched the lowest row ID, but we want the next available one, so decrement one + tq.minRowID-- + } + } + startFrom = tq.minRowID + tq.minRowID -= TimelineRowID(count) + return +} + +// Prepend adds the given event row IDs to the beginning of the timeline. +// The events must be sorted in reverse chronological order (newest event first). +func (tq *TimelineQuery) Prepend(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) error { + startFrom, err := tq.reserveRowIDs(ctx, len(rowIDs)) + if err != nil { + return err + } + prependEntries := make([]TimelinePrepend, len(rowIDs)) + for i, rowID := range rowIDs { + prependEntries[i] = TimelinePrepend{ + Timeline: startFrom - TimelineRowID(i), + Event: rowID, + } + } + query, params := prependTimelineQueryBuilder.Build([1]any{roomID}, prependEntries) + _, err = tq.Exec(ctx, query, params...) + return err +} + +// Append adds the given event row IDs to the end of the timeline. +func (tq *TimelineQuery) Append(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) error { + query, params := appendTimelineQueryBuilder.Build([1]any{roomID}, rowIDs) _, err := tq.Exec(ctx, query, params...) return err } diff --git a/hicli/database/upgrades/00-latest-revision.sql b/hicli/database/upgrades/00-latest-revision.sql index cc85f25a..53aa7df2 100644 --- a/hicli/database/upgrades/00-latest-revision.sql +++ b/hicli/database/upgrades/00-latest-revision.sql @@ -9,20 +9,26 @@ CREATE TABLE account ( ) STRICT; CREATE TABLE room ( - room_id TEXT NOT NULL PRIMARY KEY, - creation_content TEXT, + room_id TEXT NOT NULL PRIMARY KEY, + creation_content TEXT, - name TEXT, - avatar TEXT, - topic TEXT, - lazy_load_summary TEXT, + name TEXT, + avatar TEXT, + topic TEXT, + lazy_load_summary TEXT, - encryption_event TEXT, - has_member_list INTEGER NOT NULL DEFAULT false, + encryption_event TEXT, + has_member_list INTEGER NOT NULL DEFAULT false, - prev_batch TEXT + preview_event_rowid INTEGER, + sorting_timestamp INTEGER, + + prev_batch TEXT, + + CONSTRAINT room_preview_event_fkey FOREIGN KEY (preview_event_rowid) REFERENCES event (rowid) ON DELETE SET NULL ) STRICT; CREATE INDEX room_type_idx ON room (creation_content ->> 'type'); +CREATE INDEX room_sorting_timestamp_idx ON room (sorting_timestamp DESC); CREATE TABLE account_data ( user_id TEXT NOT NULL, @@ -59,10 +65,14 @@ CREATE TABLE event ( redacted_by TEXT, relates_to TEXT, + relation_type TEXT, megolm_session_id TEXT, decryption_error TEXT, + reactions TEXT, + last_edit_rowid INTEGER, + CONSTRAINT event_id_unique_key UNIQUE (event_id), CONSTRAINT event_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE ) STRICT; @@ -71,6 +81,92 @@ CREATE INDEX event_redacted_by_idx ON event (room_id, redacted_by); CREATE INDEX event_relates_to_idx ON event (room_id, relates_to); CREATE INDEX event_megolm_session_id_idx ON event (room_id, megolm_session_id); +CREATE TRIGGER event_update_redacted_by + AFTER INSERT + ON event + WHEN NEW.type = 'm.room.redaction' +BEGIN + UPDATE event SET redacted_by = NEW.event_id WHERE room_id = NEW.room_id AND event_id = NEW.content ->> 'redacts'; +END; + +CREATE TRIGGER event_update_last_edit_when_redacted + AFTER UPDATE + ON event + WHEN OLD.redacted_by IS NULL + AND NEW.redacted_by IS NOT NULL + AND NEW.relation_type = 'm.replace' +BEGIN + UPDATE event + SET last_edit_rowid = (SELECT rowid + FROM event edit + WHERE edit.room_id = event.room_id + AND edit.relates_to = event.event_id + AND edit.relation_type = 'm.replace' + AND edit.type = event.type + AND edit.sender = event.sender + AND edit.redacted_by IS NULL + ORDER BY edit.timestamp DESC + LIMIT 1) + WHERE event_id = NEW.relates_to + AND last_edit_rowid = NEW.rowid; +END; + +CREATE TRIGGER event_insert_update_last_edit + AFTER INSERT + ON event + WHEN NEW.relation_type = 'm.replace' + AND NEW.redacted_by IS NULL +BEGIN + UPDATE event + SET last_edit_rowid = NEW.rowid + WHERE event_id = NEW.relates_to + AND type = NEW.type + AND sender = NEW.sender + AND state_key IS NULL + AND NEW.timestamp > COALESCE((SELECT prev_edit.timestamp FROM event prev_edit WHERE prev_edit.rowid = event.last_edit_rowid), 0); +END; + +CREATE TRIGGER event_insert_fill_reactions + AFTER INSERT + ON event + WHEN NEW.type = 'm.reaction' + AND NEW.relation_type = 'm.annotation' + AND NEW.redacted_by IS NULL + AND typeof(NEW.content ->> '$."m.relates_to".key') = 'text' +BEGIN + UPDATE event + SET reactions=json_set( + reactions, + '$.' || json_quote(NEW.content ->> '$."m.relates_to".key'), + coalesce( + reactions ->> ('$.' || json_quote(NEW.content ->> '$."m.relates_to".key')), + 0 + ) + 1) + WHERE event_id = NEW.relates_to + AND reactions IS NOT NULL; +END; + +CREATE TRIGGER event_redact_fill_reactions + AFTER UPDATE + ON event + WHEN NEW.type = 'm.reaction' + AND NEW.relation_type = 'm.annotation' + AND NEW.redacted_by IS NOT NULL + AND OLD.redacted_by IS NULL + AND typeof(NEW.content ->> '$."m.relates_to".key') = 'text' +BEGIN + UPDATE event + SET reactions=json_set( + reactions, + '$.' || json_quote(NEW.content ->> '$."m.relates_to".key'), + coalesce( + reactions ->> ('$.' || json_quote(NEW.content ->> '$."m.relates_to".key')), + 0 + ) - 1) + WHERE event_id = NEW.relates_to + AND reactions IS NOT NULL; +END; + CREATE TABLE session_request ( room_id TEXT NOT NULL, session_id TEXT NOT NULL, @@ -99,9 +195,22 @@ CREATE TABLE current_state ( state_key TEXT NOT NULL, event_rowid INTEGER NOT NULL, - membership TEXT, + membership TEXT, PRIMARY KEY (room_id, event_type, state_key), CONSTRAINT current_state_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE, CONSTRAINT current_state_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ) STRICT, WITHOUT ROWID; + +CREATE TABLE receipt ( + room_id TEXT NOT NULL, + user_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + thread_id TEXT NOT NULL, + event_id TEXT NOT NULL, + timestamp INTEGER NOT NULL, + + PRIMARY KEY (room_id, user_id, receipt_type, thread_id), + CONSTRAINT receipt_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE + -- note: there's no foreign key on event ID because receipts could point at events that are too far in history. +) STRICT; diff --git a/hicli/decryptionqueue.go b/hicli/decryptionqueue.go index 551713a8..4358297c 100644 --- a/hicli/decryptionqueue.go +++ b/hicli/decryptionqueue.go @@ -66,11 +66,19 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro if err != nil { return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, err) } + if evt.CanUseForPreview() { + err = h.DB.Room.UpdatePreviewIfLaterOnTimeline(ctx, evt.RoomID, evt.RowID) + if err != nil { + return fmt.Errorf("failed to update room %s preview to %d: %w", evt.RoomID, evt.RowID, err) + } + } } return nil }) if err != nil { log.Err(err).Msg("Failed to save decrypted events") + } else { + h.DispatchEvent(&EventsDecrypted{Events: decrypted}) } } } diff --git a/hicli/events.go b/hicli/events.go new file mode 100644 index 00000000..9e9d43de --- /dev/null +++ b/hicli/events.go @@ -0,0 +1,25 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" +) + +type SyncComplete struct { +} + +type EventsDecrypted struct { + Events []*database.Event +} + +type Typing struct { + RoomID id.RoomID `json:"room_id"` + event.TypingEventContent +} diff --git a/hicli/hicli.go b/hicli/hicli.go index 9b889d3c..8816afaf 100644 --- a/hicli/hicli.go +++ b/hicli/hicli.go @@ -9,6 +9,7 @@ package hicli import ( "context" + "errors" "fmt" "net" "net/http" @@ -46,6 +47,15 @@ type HiClient struct { encryptLock sync.Mutex requestQueueWakeup chan struct{} + + paginationInterrupterLock sync.Mutex + paginationInterrupter map[id.RoomID]context.CancelCauseFunc +} + +var ErrTimelineReset = errors.New("got limited timeline sync response") + +func (h *HiClient) DispatchEvent(evt any) { + // TODO } func New(rawDB *dbutil.Database, log zerolog.Logger, pickleKey []byte) *HiClient { diff --git a/hicli/hitest/hitest.go b/hicli/hitest/hitest.go index ec94a328..b097dad1 100644 --- a/hicli/hitest/hitest.go +++ b/hicli/hitest/hitest.go @@ -29,6 +29,7 @@ import ( var writerTypeReadline zeroconfig.WriterType = "hitest_readline" func main() { + hicli.InitialDeviceDisplayName = "mautrix hitest" rl := exerrors.Must(readline.New("> ")) defer func() { _ = rl.Close() diff --git a/hicli/login.go b/hicli/login.go index 47ea5a4d..2f9efb2d 100644 --- a/hicli/login.go +++ b/hicli/login.go @@ -13,8 +13,11 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" ) +var InitialDeviceDisplayName = "mautrix hiclient" + func (h *HiClient) LoginPassword(ctx context.Context, homeserverURL, username, password string) error { var err error h.Client.HomeserverURL, err = url.Parse(homeserverURL) @@ -28,7 +31,7 @@ func (h *HiClient) LoginPassword(ctx context.Context, homeserverURL, username, p User: username, }, Password: password, - InitialDeviceDisplayName: "mautrix client", + InitialDeviceDisplayName: InitialDeviceDisplayName, }) } @@ -59,6 +62,10 @@ func (h *HiClient) Login(ctx context.Context, req *mautrix.ReqLogin) error { if err != nil { return err } + _, err = h.Crypto.FetchKeys(ctx, []id.UserID{h.Account.UserID}, true) + if err != nil { + return fmt.Errorf("failed to fetch own devices: %w", err) + } return nil } diff --git a/hicli/paginate.go b/hicli/paginate.go new file mode 100644 index 00000000..6fce03da --- /dev/null +++ b/hicli/paginate.go @@ -0,0 +1,91 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "errors" + "fmt" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" +) + +var ErrPaginationAlreadyInProgress = errors.New("pagination is already in progress") + +func (h *HiClient) Paginate(ctx context.Context, roomID id.RoomID, minTimelineID database.TimelineRowID, limit int) ([]*database.Event, error) { + evts, err := h.DB.Event.GetTimeline(ctx, roomID, limit, minTimelineID) + if err != nil { + return nil, err + } else if len(evts) > 0 { + return evts, nil + } else { + return h.PaginateServer(ctx, roomID, limit) + } +} + +func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit int) ([]*database.Event, error) { + ctx, cancel := context.WithCancelCause(ctx) + h.paginationInterrupterLock.Lock() + if _, alreadyPaginating := h.paginationInterrupter[roomID]; alreadyPaginating { + h.paginationInterrupterLock.Unlock() + return nil, ErrPaginationAlreadyInProgress + } + h.paginationInterrupter[roomID] = cancel + h.paginationInterrupterLock.Unlock() + defer func() { + h.paginationInterrupterLock.Lock() + delete(h.paginationInterrupter, roomID) + h.paginationInterrupterLock.Unlock() + }() + + room, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return nil, fmt.Errorf("failed to get room from database: %w", err) + } + resp, err := h.Client.Messages(ctx, roomID, room.PrevBatch, "", mautrix.DirectionBackward, nil, limit) + if err != nil { + return nil, fmt.Errorf("failed to get messages from server: %w", err) + } + events := make([]*database.Event, len(resp.Chunk)) + wakeupSessionRequests := false + err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + if err = ctx.Err(); err != nil { + return err + } + eventRowIDs := make([]database.EventRowID, len(resp.Chunk)) + decryptionQueue := make(map[id.SessionID]*database.SessionRequest) + for i, evt := range resp.Chunk { + events[i], err = h.processEvent(ctx, evt, decryptionQueue, true) + if err != nil { + return err + } + eventRowIDs[i] = events[i].RowID + } + wakeupSessionRequests = len(decryptionQueue) > 0 + for _, entry := range decryptionQueue { + err = h.DB.SessionRequest.Put(ctx, entry) + if err != nil { + return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err) + } + } + err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End) + if err != nil { + return fmt.Errorf("failed to set prev_batch: %w", err) + } + err = h.DB.Timeline.Prepend(ctx, room.ID, eventRowIDs) + if err != nil { + return fmt.Errorf("failed to prepend events to timeline: %w", err) + } + return nil + }) + if err == nil && wakeupSessionRequests { + h.WakeupRequestQueue() + } + return events, err +} diff --git a/hicli/sync.go b/hicli/sync.go index d0064015..1f569b3c 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -57,13 +57,13 @@ func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.Res return nil } -func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { +func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) { h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount) go h.asyncPostProcessSyncResponse(ctx, resp, since) if ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue { h.WakeupRequestQueue() } - return nil + h.firstSyncReceived = true } func (h *HiClient) asyncPostProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) { @@ -116,6 +116,24 @@ func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSy return nil } +func receiptsToList(content *event.ReceiptEventContent) []*database.Receipt { + receiptList := make([]*database.Receipt, 0) + for eventID, receipts := range *content { + for receiptType, users := range receipts { + for userID, receiptInfo := range users { + receiptList = append(receiptList, &database.Receipt{ + UserID: userID, + ReceiptType: receiptType, + ThreadID: receiptInfo.ThreadID, + EventID: eventID, + Timestamp: receiptInfo.Timestamp, + }) + } + } + } + return receiptList +} + func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error { existingRoomData, err := h.DB.Room.Get(ctx, roomID) if err != nil { @@ -140,6 +158,29 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, if err != nil { return err } + for _, evt := range room.Ephemeral.Events { + evt.Type.Class = event.EphemeralEventType + err = evt.Content.ParseRaw(evt.Type) + if err != nil { + zerolog.Ctx(ctx).Debug().Err(err).Msg("Failed to parse ephemeral event content") + continue + } + switch evt.Type { + case event.EphemeralEventReceipt: + err = h.DB.Receipt.PutMany(ctx, roomID, receiptsToList(evt.Content.AsReceipt())...) + if err != nil { + return fmt.Errorf("failed to save receipts: %w", err) + } + case event.EphemeralEventTyping: + go h.DispatchEvent(&Typing{ + RoomID: roomID, + TypingEventContent: *evt.Content.AsTyping(), + }) + } + if evt.Type != event.EphemeralEventReceipt { + continue + } + } return nil } @@ -164,11 +205,8 @@ func removeReplyFallback(evt *event.Event) []byte { content.RemoveReplyFallback() if content.FormattedBody != prevFormattedBody { bytes, err := sjson.SetBytes(evt.Content.VeryRaw, "formatted_body", content.FormattedBody) - if err == nil { - return bytes - } - bytes, err = sjson.SetBytes(evt.Content.VeryRaw, "body", content.Body) - if err == nil { + bytes, err2 := sjson.SetBytes(bytes, "body", content.Body) + if err == nil && err2 == nil { return bytes } } @@ -192,70 +230,104 @@ func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) ([]byte, return decrypted.Content.VeryRaw, decrypted.Type.Type, nil } -func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.Room, state *mautrix.SyncEventsList, timeline *mautrix.SyncTimeline, summary *mautrix.LazyLoadSummary) error { - decryptionQueue := make(map[id.SessionID]*database.SessionRequest) - roomDataChanged := false - processEvent := func(evt *event.Event) (database.MassInsertableRowID, error) { - evt.RoomID = room.ID - dbEvt := database.MautrixToEvent(evt) - contentWithoutFallback := removeReplyFallback(evt) - if contentWithoutFallback != nil { - dbEvt.Content = contentWithoutFallback - } - var decryptionErr error - if evt.Type == event.EventEncrypted { - dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt) - if decryptionErr != nil { - dbEvt.DecryptionError = decryptionErr.Error() - } - } - rowID, err := h.DB.Event.Upsert(ctx, dbEvt) +func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptionQueue map[id.SessionID]*database.SessionRequest, checkDB bool) (*database.Event, error) { + if checkDB { + dbEvt, err := h.DB.Event.GetByID(ctx, evt.ID) if err != nil { - return -1, fmt.Errorf("failed to save event %s: %w", evt.ID, err) + return nil, fmt.Errorf("failed to check if event %s exists: %w", evt.ID, err) + } else if dbEvt != nil { + return dbEvt, nil } - if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) { - req, ok := decryptionQueue[dbEvt.MegolmSessionID] - if !ok { - req = &database.SessionRequest{ - RoomID: room.ID, - SessionID: dbEvt.MegolmSessionID, - Sender: evt.Sender, - } + } + dbEvt := database.MautrixToEvent(evt) + contentWithoutFallback := removeReplyFallback(evt) + if contentWithoutFallback != nil { + dbEvt.Content = contentWithoutFallback + } + var decryptionErr error + if evt.Type == event.EventEncrypted { + dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt) + if decryptionErr != nil { + dbEvt.DecryptionError = decryptionErr.Error() + } + } else if evt.Type == event.EventRedaction { + if evt.Redacts != "" && gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str != evt.Redacts.String() { + var err error + evt.Content.VeryRaw, err = sjson.SetBytes(evt.Content.VeryRaw, "redacts", evt.Redacts) + if err != nil { + return dbEvt, fmt.Errorf("failed to set redacts field: %w", err) } - minIndex, _ := crypto.ParseMegolmMessageIndex(evt.Content.AsEncrypted().MegolmCiphertext) - req.MinIndex = min(uint32(minIndex), req.MinIndex) - decryptionQueue[dbEvt.MegolmSessionID] = req + } + } + _, err := h.DB.Event.Upsert(ctx, dbEvt) + if err != nil { + return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err) + } + if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) { + req, ok := decryptionQueue[dbEvt.MegolmSessionID] + if !ok { + req = &database.SessionRequest{ + RoomID: evt.RoomID, + SessionID: dbEvt.MegolmSessionID, + Sender: evt.Sender, + } + } + minIndex, _ := crypto.ParseMegolmMessageIndex(evt.Content.AsEncrypted().MegolmCiphertext) + req.MinIndex = min(uint32(minIndex), req.MinIndex) + decryptionQueue[dbEvt.MegolmSessionID] = req + } + return dbEvt, err +} + +func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.Room, state *mautrix.SyncEventsList, timeline *mautrix.SyncTimeline, summary *mautrix.LazyLoadSummary) error { + updatedRoom := &database.Room{ + ID: room.ID, + + SortingTimestamp: room.SortingTimestamp, + } + decryptionQueue := make(map[id.SessionID]*database.SessionRequest) + processNewEvent := func(evt *event.Event, isTimeline bool) (database.EventRowID, error) { + evt.RoomID = room.ID + dbEvt, err := h.processEvent(ctx, evt, decryptionQueue, false) + if err != nil { + return -1, err + } + if isTimeline { + if dbEvt.CanUseForPreview() { + updatedRoom.PreviewEventRowID = dbEvt.RowID + } + updatedRoom.BumpSortingTimestamp(dbEvt) } if evt.StateKey != nil { var membership event.Membership if evt.Type == event.StateMember { membership = event.Membership(gjson.GetBytes(evt.Content.VeryRaw, "membership").Str) } - err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, rowID, membership) + err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership) if err != nil { return -1, fmt.Errorf("failed to save current state event ID %s for %s/%s: %w", evt.ID, evt.Type.Type, *evt.StateKey, err) } - roomDataChanged = processImportantEvent(ctx, evt, room) || roomDataChanged + processImportantEvent(ctx, evt, room, updatedRoom) } - return database.MassInsertableRowID(rowID), nil + return dbEvt.RowID, nil } var err error for _, evt := range state.Events { evt.Type.Class = event.StateEventType - _, err = processEvent(evt) + _, err = processNewEvent(evt, false) if err != nil { return err } } if len(timeline.Events) > 0 { - timelineIDs := make([]database.MassInsertableRowID, len(timeline.Events)) + timelineIDs := make([]database.EventRowID, len(timeline.Events)) for i, evt := range timeline.Events { if evt.StateKey != nil { evt.Type.Class = event.StateEventType } else { evt.Type.Class = event.MessageEventType } - timelineIDs[i], err = processEvent(evt) + timelineIDs[i], err = processNewEvent(evt, true) if err != nil { return err } @@ -274,6 +346,12 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R if err != nil { return fmt.Errorf("failed to clear old timeline: %w", err) } + updatedRoom.PrevBatch = timeline.PrevBatch + h.paginationInterrupterLock.Lock() + if interrupt, ok := h.paginationInterrupter[room.ID]; ok { + interrupt(ErrTimelineReset) + } + h.paginationInterrupterLock.Unlock() } err = h.DB.Timeline.Append(ctx, room.ID, timelineIDs) if err != nil { @@ -281,18 +359,17 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R } } if timeline.PrevBatch != "" && room.PrevBatch == "" { - room.PrevBatch = timeline.PrevBatch - roomDataChanged = true + updatedRoom.PrevBatch = timeline.PrevBatch } - if summary.Heroes != nil { - roomDataChanged = roomDataChanged || room.LazyLoadSummary == nil || - !slices.Equal(summary.Heroes, room.LazyLoadSummary.Heroes) || - !intPtrEqual(summary.JoinedMemberCount, room.LazyLoadSummary.JoinedMemberCount) || - !intPtrEqual(summary.InvitedMemberCount, room.LazyLoadSummary.InvitedMemberCount) - room.LazyLoadSummary = summary + if summary.Heroes != nil && (room.LazyLoadSummary == nil || + !slices.Equal(summary.Heroes, room.LazyLoadSummary.Heroes) || + !intPtrEqual(summary.JoinedMemberCount, room.LazyLoadSummary.JoinedMemberCount) || + !intPtrEqual(summary.InvitedMemberCount, room.LazyLoadSummary.InvitedMemberCount)) { + updatedRoom.LazyLoadSummary = summary } - if roomDataChanged { - err = h.DB.Room.Upsert(ctx, room) + // TODO check if updatedRoom contains anything + if true { + err = h.DB.Room.Upsert(ctx, updatedRoom) if err != nil { return fmt.Errorf("failed to save room data: %w", err) } @@ -307,7 +384,7 @@ func intPtrEqual(a, b *int) bool { return *a == *b } -func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomData *database.Room) (roomDataChanged bool) { +func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomData, updatedRoom *database.Room) (roomDataChanged bool) { if evt.StateKey == nil { return } @@ -329,33 +406,26 @@ func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomDa } switch evt.Type { case event.StateCreate: - if existingRoomData.CreationContent == nil { - roomDataChanged = true - } - existingRoomData.CreationContent, _ = evt.Content.Parsed.(*event.CreateEventContent) + updatedRoom.CreationContent, _ = evt.Content.Parsed.(*event.CreateEventContent) case event.StateEncryption: newEncryption, _ := evt.Content.Parsed.(*event.EncryptionEventContent) if existingRoomData.EncryptionEvent == nil || existingRoomData.EncryptionEvent.Algorithm == newEncryption.Algorithm { - roomDataChanged = true - existingRoomData.EncryptionEvent = newEncryption + updatedRoom.EncryptionEvent = newEncryption } case event.StateRoomName: content, ok := evt.Content.Parsed.(*event.RoomNameEventContent) - if ok { - roomDataChanged = existingRoomData.Name == nil || *existingRoomData.Name != content.Name - existingRoomData.Name = &content.Name + if ok && (existingRoomData.Name == nil || *existingRoomData.Name != content.Name) { + updatedRoom.Name = &content.Name } case event.StateRoomAvatar: content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent) - if ok { - roomDataChanged = existingRoomData.Avatar == nil || *existingRoomData.Avatar != content.URL - existingRoomData.Avatar = &content.URL + if ok && (existingRoomData.Avatar == nil || *existingRoomData.Avatar != content.URL) { + updatedRoom.Avatar = &content.URL } case event.StateTopic: content, ok := evt.Content.Parsed.(*event.TopicEventContent) - if ok { - roomDataChanged = existingRoomData.Topic == nil || *existingRoomData.Topic != content.Topic - existingRoomData.Topic = &content.Topic + if ok && (existingRoomData.Topic == nil || *existingRoomData.Topic != content.Topic) { + updatedRoom.Topic = &content.Topic } } return diff --git a/hicli/syncwrap.go b/hicli/syncwrap.go index eccdb7b1..46d31e57 100644 --- a/hicli/syncwrap.go +++ b/hicli/syncwrap.go @@ -32,17 +32,13 @@ func (h *hiSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, if err != nil { return err } - err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + err = c.DB.DoTxn(ctx, nil, func(ctx context.Context) error { return c.processSyncResponse(ctx, resp, since) }) if err != nil { return err } - err = c.postProcessSyncResponse(ctx, resp, since) - if err != nil { - return err - } - c.firstSyncReceived = true + c.postProcessSyncResponse(ctx, resp, since) return nil } From 2b78b75885c785828798bc8da46b78caf3cb3b8b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 31 May 2024 14:54:01 +0300 Subject: [PATCH 0242/1647] Add beeper streaming flag to sync requests --- client.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index 2ae05c44..4fe7ec50 100644 --- a/client.go +++ b/client.go @@ -610,12 +610,13 @@ func (cli *Client) SyncRequest(ctx context.Context, timeout int, since, filterID } type ReqSync struct { - Timeout int - Since string - FilterID string - FullState bool - SetPresence event.Presence - StreamResponse bool + Timeout int + Since string + FilterID string + FullState bool + SetPresence event.Presence + StreamResponse bool + BeeperStreaming bool } func (req *ReqSync) BuildQuery() map[string]string { @@ -634,6 +635,11 @@ func (req *ReqSync) BuildQuery() map[string]string { if req.FullState { query["full_state"] = "true" } + if req.BeeperStreaming { + // TODO remove this + query["streaming"] = "" + query["com.beeper.streaming"] = "true" + } return query } From b10a140a5c146e1db9a0c7b57ce958377738adf6 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 31 May 2024 12:12:44 -0600 Subject: [PATCH 0243/1647] goolm/crypto: use crypto/ed25519 Equal functions Previously, the code was using raw byte comparisons, which is not correct, as it makes timing attacks possible. Signed-off-by: Sumner Evans --- crypto/goolm/crypto/ed25519.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/crypto/goolm/crypto/ed25519.go b/crypto/goolm/crypto/ed25519.go index bc21300c..f0c56297 100644 --- a/crypto/goolm/crypto/ed25519.go +++ b/crypto/goolm/crypto/ed25519.go @@ -1,7 +1,6 @@ package crypto import ( - "bytes" "crypto/ed25519" "encoding/base64" "fmt" @@ -118,7 +117,7 @@ type Ed25519PrivateKey ed25519.PrivateKey // Equal compares the private key to the given private key. func (c Ed25519PrivateKey) Equal(x Ed25519PrivateKey) bool { - return bytes.Equal(c, x) + return ed25519.PrivateKey(c).Equal(ed25519.PrivateKey(x)) } // PubKey returns the public key derived from the private key. @@ -137,7 +136,7 @@ type Ed25519PublicKey ed25519.PublicKey // Equal compares the public key to the given public key. func (c Ed25519PublicKey) Equal(x Ed25519PublicKey) bool { - return bytes.Equal(c, x) + return ed25519.PublicKey(c).Equal(ed25519.PublicKey(x)) } // B64Encoded returns a base64 encoded string of the public key. From 2ed8d0d0b3f24d72ac5fd8c0e5109d3320cb4faa Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 1 Jun 2024 03:08:41 +0300 Subject: [PATCH 0244/1647] Add proper room name calculation and fix some bugs --- event/content.go | 10 ++ event/state.go | 4 + event/type.go | 4 +- hicli/database/database.go | 7 +- hicli/database/event.go | 22 +-- hicli/database/receipt.go | 12 ++ hicli/database/room.go | 84 +++++++++-- hicli/database/state.go | 21 ++- hicli/database/timeline.go | 36 +++-- .../database/upgrades/00-latest-revision.sql | 7 +- hicli/hicli.go | 22 ++- hicli/hitest/hitest.go | 9 +- hicli/paginate.go | 8 +- hicli/sync.go | 131 ++++++++++++++++-- 14 files changed, 307 insertions(+), 70 deletions(-) diff --git a/event/content.go b/event/content.go index bdb3eeb8..e22b6435 100644 --- a/event/content.go +++ b/event/content.go @@ -40,6 +40,8 @@ var TypeMap = map[Type]reflect.Type{ StateSpaceChild: reflect.TypeOf(SpaceChildEventContent{}), StateInsertionMarker: reflect.TypeOf(InsertionMarkerContent{}), + StateElementFunctionalMembers: reflect.TypeOf(ElementFunctionalMembersContent{}), + EventMessage: reflect.TypeOf(MessageEventContent{}), EventSticker: reflect.TypeOf(MessageEventContent{}), EventEncrypted: reflect.TypeOf(EncryptedEventContent{}), @@ -211,6 +213,7 @@ func init() { gob.Register(&BridgeEventContent{}) gob.Register(&SpaceChildEventContent{}) gob.Register(&SpaceParentEventContent{}) + gob.Register(&ElementFunctionalMembersContent{}) gob.Register(&RoomNameEventContent{}) gob.Register(&RoomAvatarEventContent{}) gob.Register(&TopicEventContent{}) @@ -352,6 +355,13 @@ func (content *Content) AsSpaceParent() *SpaceParentEventContent { } return casted } +func (content *Content) AsElementFunctionalMembers() *ElementFunctionalMembersContent { + casted, ok := content.Parsed.(*ElementFunctionalMembersContent) + if !ok { + return &ElementFunctionalMembersContent{} + } + return casted +} func (content *Content) AsMessage() *MessageEventContent { casted, ok := content.Parsed.(*MessageEventContent) if !ok { diff --git a/event/state.go b/event/state.go index e03e6a85..16b8eead 100644 --- a/event/state.go +++ b/event/state.go @@ -191,3 +191,7 @@ type InsertionMarkerContent struct { InsertionID id.EventID `json:"org.matrix.msc2716.marker.insertion"` Timestamp int64 `json:"com.beeper.timestamp,omitempty"` } + +type ElementFunctionalMembersContent struct { + FunctionalMembers []id.UserID `json:"functional_members"` +} diff --git a/event/type.go b/event/type.go index 2c801d5e..56752bc3 100644 --- a/event/type.go +++ b/event/type.go @@ -112,7 +112,7 @@ func (et *Type) GuessClass() TypeClass { StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type, StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type, StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type, - StateInsertionMarker.Type: + StateInsertionMarker.Type, StateElementFunctionalMembers.Type: return StateEventType case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type: return EphemeralEventType @@ -193,6 +193,8 @@ var ( // Deprecated: MSC2716 has been abandoned StateInsertionMarker = Type{"org.matrix.msc2716.marker", StateEventType} + + StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType} ) // Message events diff --git a/hicli/database/database.go b/hicli/database/database.go index 8ec2b42a..601ca64b 100644 --- a/hicli/database/database.go +++ b/hicli/database/database.go @@ -27,15 +27,16 @@ type Database struct { func New(rawDB *dbutil.Database) *Database { rawDB.UpgradeTable = upgrades.Table + eventQH := dbutil.MakeQueryHelper(rawDB, newEvent) return &Database{ Database: rawDB, Account: AccountQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newAccount)}, AccountData: AccountDataQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newAccountData)}, Room: RoomQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newRoom)}, - Event: EventQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newEvent)}, - CurrentState: CurrentStateQuery{Database: rawDB}, - Timeline: TimelineQuery{Database: rawDB}, + Event: EventQuery{QueryHelper: eventQH}, + CurrentState: CurrentStateQuery{QueryHelper: eventQH}, + Timeline: TimelineQuery{QueryHelper: eventQH}, SessionRequest: SessionRequestQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newSessionRequest)}, Receipt: ReceiptQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newReceipt)}, } diff --git a/hicli/database/event.go b/hicli/database/event.go index ab68583f..6f681373 100644 --- a/hicli/database/event.go +++ b/hicli/database/event.go @@ -41,16 +41,7 @@ const ( RETURNING rowid ` updateEventDecryptedQuery = `UPDATE event SET decrypted = $1, decrypted_type = $2, decryption_error = NULL WHERE rowid = $3` - getTimelineQuery = ` - SELECT event.rowid, timeline.rowid, event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, - redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, reactions, last_edit_rowid - FROM timeline - JOIN event ON event.rowid = timeline.event_rowid - WHERE timeline.room_id = $1 AND timeline.rowid < $2 - ORDER BY timeline.rowid DESC - LIMIT $3 - ` - getEventReactionsQuery = getEventBaseQuery + ` + getEventReactionsQuery = getEventBaseQuery + ` WHERE room_id = ? AND type = 'm.reaction' AND relation_type = 'm.annotation' @@ -60,9 +51,10 @@ const ( getEventEditRowIDsQuery = ` SELECT main.event_id, edit.rowid FROM event main - JOIN event edit ON edit.room_id = main.room_id - AND edit.relates_to = main.event_id - AND edit.relation_type = 'm.replace' + JOIN event edit ON + edit.room_id = main.room_id + AND edit.relates_to = main.event_id + AND edit.relation_type = 'm.replace' AND edit.type = main.type AND edit.sender = main.sender AND edit.redacted_by IS NULL @@ -99,10 +91,6 @@ func (eq *EventQuery) UpdateDecrypted(ctx context.Context, rowID EventRowID, dec return eq.Exec(ctx, updateEventDecryptedQuery, unsafeJSONString(decrypted), decryptedType, rowID) } -func (eq *EventQuery) GetTimeline(ctx context.Context, roomID id.RoomID, limit int, before TimelineRowID) ([]*Event, error) { - return eq.QueryMany(ctx, getTimelineQuery, roomID, before, limit) -} - func (eq *EventQuery) FillReactionCounts(ctx context.Context, roomID id.RoomID, events []*Event) error { eventIDs := make([]id.EventID, 0) eventMap := make(map[id.EventID]*Event) diff --git a/hicli/database/receipt.go b/hicli/database/receipt.go index 8757f4d6..a3370fba 100644 --- a/hicli/database/receipt.go +++ b/hicli/database/receipt.go @@ -11,6 +11,7 @@ import ( "time" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exslices" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -37,6 +38,17 @@ func (rq *ReceiptQuery) Put(ctx context.Context, receipt *Receipt) error { } func (rq *ReceiptQuery) PutMany(ctx context.Context, roomID id.RoomID, receipts ...*Receipt) error { + if len(receipts) > 1000 { + return rq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error { + for _, receiptChunk := range exslices.Chunk(receipts, 200) { + err := rq.PutMany(ctx, roomID, receiptChunk...) + if err != nil { + return err + } + } + return nil + }) + } query, params := receiptMassInserter.Build([1]any{roomID}, receipts) return rq.Exec(ctx, query, params...) } diff --git a/hicli/database/room.go b/hicli/database/room.go index 0f6f5a74..d8a162b6 100644 --- a/hicli/database/room.go +++ b/hicli/database/room.go @@ -20,7 +20,8 @@ import ( const ( getRoomBaseQuery = ` - SELECT room_id, creation_content, name, avatar, topic, lazy_load_summary, encryption_event, has_member_list, + SELECT room_id, creation_content, name, name_quality, avatar, topic, canonical_alias, + lazy_load_summary, encryption_event, has_member_list, preview_event_rowid, sorting_timestamp, prev_batch FROM room ` @@ -33,14 +34,16 @@ const ( UPDATE room SET creation_content = COALESCE(room.creation_content, $2), name = COALESCE($3, room.name), - avatar = COALESCE($4, room.avatar), - topic = COALESCE($5, room.topic), - lazy_load_summary = COALESCE($6, room.lazy_load_summary), - encryption_event = COALESCE($7, room.encryption_event), - has_member_list = room.has_member_list OR $8, - preview_event_rowid = COALESCE($9, room.preview_event_rowid), - sorting_timestamp = COALESCE($10, room.sorting_timestamp), - prev_batch = COALESCE($11, room.prev_batch) + name_quality = CASE WHEN $3 IS NOT NULL THEN $4 ELSE room.name_quality END, + avatar = COALESCE($5, room.avatar), + topic = COALESCE($6, room.topic), + canonical_alias = COALESCE($7, room.canonical_alias), + lazy_load_summary = COALESCE($8, room.lazy_load_summary), + encryption_event = COALESCE($9, room.encryption_event), + has_member_list = room.has_member_list OR $10, + preview_event_rowid = COALESCE($11, room.preview_event_rowid), + sorting_timestamp = COALESCE($12, room.sorting_timestamp), + prev_batch = COALESCE($13, room.prev_batch) WHERE room_id = $1 ` setRoomPrevBatchQuery = ` @@ -79,13 +82,24 @@ func (rq *RoomQuery) UpdatePreviewIfLaterOnTimeline(ctx context.Context, roomID return rq.Exec(ctx, updateRoomPreviewIfLaterOnTimelineQuery, roomID, rowID) } +type NameQuality int + +const ( + NameQualityNil NameQuality = iota + NameQualityParticipants + NameQualityCanonicalAlias + NameQualityExplicit +) + type Room struct { ID id.RoomID CreationContent *event.CreateEventContent - Name *string - Avatar *id.ContentURI - Topic *string + Name *string + NameQuality NameQuality + Avatar *id.ContentURI + Topic *string + CanonicalAlias *id.RoomAlias LazyLoadSummary *mautrix.LazyLoadSummary @@ -98,6 +112,48 @@ type Room struct { PrevBatch string } +func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) { + if r.Name != nil && r.NameQuality >= other.NameQuality { + other.Name = r.Name + other.NameQuality = r.NameQuality + hasChanges = true + } + if r.Avatar != nil { + other.Avatar = r.Avatar + hasChanges = true + } + if r.Topic != nil { + other.Topic = r.Topic + hasChanges = true + } + if r.CanonicalAlias != nil { + other.CanonicalAlias = r.CanonicalAlias + hasChanges = true + } + if r.LazyLoadSummary != nil { + other.LazyLoadSummary = r.LazyLoadSummary + hasChanges = true + } + if r.EncryptionEvent != nil && other.EncryptionEvent == nil { + other.EncryptionEvent = r.EncryptionEvent + hasChanges = true + } + other.HasMemberList = other.HasMemberList || r.HasMemberList + if r.PreviewEventRowID > other.PreviewEventRowID { + other.PreviewEventRowID = r.PreviewEventRowID + hasChanges = true + } + if r.SortingTimestamp.After(other.SortingTimestamp) { + other.SortingTimestamp = r.SortingTimestamp + hasChanges = true + } + if r.PrevBatch != "" && other.PrevBatch == "" { + other.PrevBatch = r.PrevBatch + hasChanges = true + } + return +} + func (r *Room) Scan(row dbutil.Scannable) (*Room, error) { var prevBatch sql.NullString var previewEventRowID, sortingTimestamp sql.NullInt64 @@ -105,8 +161,10 @@ func (r *Room) Scan(row dbutil.Scannable) (*Room, error) { &r.ID, dbutil.JSON{Data: &r.CreationContent}, &r.Name, + &r.NameQuality, &r.Avatar, &r.Topic, + &r.CanonicalAlias, dbutil.JSON{Data: &r.LazyLoadSummary}, dbutil.JSON{Data: &r.EncryptionEvent}, &r.HasMemberList, @@ -128,8 +186,10 @@ func (r *Room) sqlVariables() []any { r.ID, dbutil.JSONPtr(r.CreationContent), r.Name, + r.NameQuality, r.Avatar, r.Topic, + r.CanonicalAlias, dbutil.JSONPtr(r.LazyLoadSummary), dbutil.JSONPtr(r.EncryptionEvent), r.HasMemberList, diff --git a/hicli/database/state.go b/hicli/database/state.go index 31c5adda..1b542f9f 100644 --- a/hicli/database/state.go +++ b/hicli/database/state.go @@ -20,13 +20,28 @@ const ( INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (room_id, event_type, state_key) DO UPDATE SET event_rowid = excluded.event_rowid, membership = excluded.membership ` + getCurrentRoomStateQuery = ` + SELECT event.rowid, -1, event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type, unsigned, + redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, reactions, last_edit_rowid + FROM current_state cs + JOIN event ON cs.event_rowid = event.rowid + WHERE cs.room_id = $1 + ` + getCurrentStateEventQuery = getCurrentRoomStateQuery + `AND cs.event_type = $2 AND cs.state_key = $3` ) type CurrentStateQuery struct { - *dbutil.Database + *dbutil.QueryHelper[*Event] } func (csq *CurrentStateQuery) Set(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID EventRowID, membership event.Membership) error { - _, err := csq.Exec(ctx, setCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership)) - return err + return csq.Exec(ctx, setCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership)) +} + +func (csq *CurrentStateQuery) Get(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*Event, error) { + return csq.QueryOne(ctx, getCurrentStateEventQuery, roomID, eventType.Type, stateKey) +} + +func (csq *CurrentStateQuery) GetAll(ctx context.Context, roomID id.RoomID) ([]*Event, error) { + return csq.QueryMany(ctx, getCurrentRoomStateQuery, roomID) } diff --git a/hicli/database/timeline.go b/hicli/database/timeline.go index 107c7ddb..3a24603a 100644 --- a/hicli/database/timeline.go +++ b/hicli/database/timeline.go @@ -28,24 +28,33 @@ const ( INSERT INTO timeline (room_id, rowid, event_rowid) VALUES ($1, $2, $3) ` findMinRowIDQuery = `SELECT MIN(rowid) FROM timeline` + getTimelineQuery = ` + SELECT event.rowid, timeline.rowid, event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, + redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, reactions, last_edit_rowid + FROM timeline + JOIN event ON event.rowid = timeline.event_rowid + WHERE timeline.room_id = $1 AND timeline.rowid < $2 + ORDER BY timeline.rowid DESC + LIMIT $3 + ` ) type TimelineRowID int64 -type TimelinePrepend struct { +type TimelineRowTuple struct { Timeline TimelineRowID Event EventRowID } -func (tp TimelinePrepend) GetMassInsertValues() [2]any { +func (tp TimelineRowTuple) GetMassInsertValues() [2]any { return [2]any{tp.Timeline, tp.Event} } var appendTimelineQueryBuilder = dbutil.NewMassInsertBuilder[EventRowID, [1]any](appendTimelineQuery, "($1, $%d)") -var prependTimelineQueryBuilder = dbutil.NewMassInsertBuilder[TimelinePrepend, [1]any](prependTimelineQuery, "($1, $%d, $%d)") +var prependTimelineQueryBuilder = dbutil.NewMassInsertBuilder[TimelineRowTuple, [1]any](prependTimelineQuery, "($1, $%d, $%d)") type TimelineQuery struct { - *dbutil.Database + *dbutil.QueryHelper[*Event] minRowID TimelineRowID minRowIDFound bool @@ -54,15 +63,14 @@ type TimelineQuery struct { // Clear clears the timeline of a given room. func (tq *TimelineQuery) Clear(ctx context.Context, roomID id.RoomID) error { - _, err := tq.Exec(ctx, clearTimelineQuery, roomID) - return err + return tq.Exec(ctx, clearTimelineQuery, roomID) } func (tq *TimelineQuery) reserveRowIDs(ctx context.Context, count int) (startFrom TimelineRowID, err error) { tq.prependLock.Lock() defer tq.prependLock.Unlock() if !tq.minRowIDFound { - err = tq.QueryRow(ctx, findMinRowIDQuery).Scan(&tq.minRowID) + err = tq.GetDB().QueryRow(ctx, findMinRowIDQuery).Scan(&tq.minRowID) if err != nil && !errors.Is(err, sql.ErrNoRows) { return } @@ -86,21 +94,23 @@ func (tq *TimelineQuery) Prepend(ctx context.Context, roomID id.RoomID, rowIDs [ if err != nil { return err } - prependEntries := make([]TimelinePrepend, len(rowIDs)) + prependEntries := make([]TimelineRowTuple, len(rowIDs)) for i, rowID := range rowIDs { - prependEntries[i] = TimelinePrepend{ + prependEntries[i] = TimelineRowTuple{ Timeline: startFrom - TimelineRowID(i), Event: rowID, } } query, params := prependTimelineQueryBuilder.Build([1]any{roomID}, prependEntries) - _, err = tq.Exec(ctx, query, params...) - return err + return tq.Exec(ctx, query, params...) } // Append adds the given event row IDs to the end of the timeline. func (tq *TimelineQuery) Append(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) error { query, params := appendTimelineQueryBuilder.Build([1]any{roomID}, rowIDs) - _, err := tq.Exec(ctx, query, params...) - return err + return tq.Exec(ctx, query, params...) +} + +func (tq *TimelineQuery) Get(ctx context.Context, roomID id.RoomID, limit int, before TimelineRowID) ([]*Event, error) { + return tq.QueryMany(ctx, getTimelineQuery, roomID, before, limit) } diff --git a/hicli/database/upgrades/00-latest-revision.sql b/hicli/database/upgrades/00-latest-revision.sql index 53aa7df2..33e9fc97 100644 --- a/hicli/database/upgrades/00-latest-revision.sql +++ b/hicli/database/upgrades/00-latest-revision.sql @@ -13,8 +13,10 @@ CREATE TABLE room ( creation_content TEXT, name TEXT, + name_quality INTEGER NOT NULL DEFAULT 0, avatar TEXT, topic TEXT, + canonical_alias TEXT, lazy_load_summary TEXT, encryption_event TEXT, @@ -47,6 +49,7 @@ CREATE TABLE room_account_data ( PRIMARY KEY (user_id, room_id, type), CONSTRAINT room_account_data_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE ) STRICT; +CREATE INDEX room_account_data_room_id_idx ON room_account_data (room_id); CREATE TABLE event ( rowid INTEGER PRIMARY KEY, @@ -123,7 +126,8 @@ BEGIN AND type = NEW.type AND sender = NEW.sender AND state_key IS NULL - AND NEW.timestamp > COALESCE((SELECT prev_edit.timestamp FROM event prev_edit WHERE prev_edit.rowid = event.last_edit_rowid), 0); + AND NEW.timestamp > + COALESCE((SELECT prev_edit.timestamp FROM event prev_edit WHERE prev_edit.rowid = event.last_edit_rowid), 0); END; CREATE TRIGGER event_insert_fill_reactions @@ -178,6 +182,7 @@ CREATE TABLE session_request ( PRIMARY KEY (session_id), CONSTRAINT session_request_queue_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE ) STRICT; +CREATE INDEX session_request_room_idx ON session_request (room_id); CREATE TABLE timeline ( rowid INTEGER PRIMARY KEY, diff --git a/hicli/hicli.go b/hicli/hicli.go index 8816afaf..5dcbb2f2 100644 --- a/hicli/hicli.go +++ b/hicli/hicli.go @@ -44,6 +44,7 @@ type HiClient struct { firstSyncReceived bool syncingID int syncLock sync.Mutex + stopSync context.CancelFunc encryptLock sync.Mutex requestQueueWakeup chan struct{} @@ -150,6 +151,9 @@ func (h *HiClient) Start(ctx context.Context, userID id.UserID) error { func (h *HiClient) Sync() { h.Client.StopSync() + if fn := h.stopSync; fn != nil { + fn() + } h.syncLock.Lock() defer h.syncLock.Unlock() h.syncingID++ @@ -158,12 +162,26 @@ func (h *HiClient) Sync() { Str("action", "sync"). Int("sync_id", syncingID). Logger() - ctx := log.WithContext(context.Background()) + ctx, cancel := context.WithCancel(log.WithContext(context.Background())) + h.stopSync = cancel log.Info().Msg("Starting syncing") err := h.Client.SyncWithContext(ctx) - if err != nil { + if err != nil && ctx.Err() == nil { log.Err(err).Msg("Fatal error in syncer") } else { log.Info().Msg("Syncing stopped") } } + +func (h *HiClient) Stop() { + h.Client.StopSync() + if fn := h.stopSync; fn != nil { + fn() + } + h.syncLock.Lock() + h.syncLock.Unlock() + err := h.DB.Close() + if err != nil { + h.Log.Err(err).Msg("Failed to close database cleanly") + } +} diff --git a/hicli/hitest/hitest.go b/hicli/hitest/hitest.go index b097dad1..4779c85c 100644 --- a/hicli/hitest/hitest.go +++ b/hicli/hitest/hitest.go @@ -15,6 +15,7 @@ import ( "github.com/chzyer/readline" _ "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog" "go.mau.fi/util/dbutil" _ "go.mau.fi/util/dbutil/litestream" "go.mau.fi/util/exerrors" @@ -37,7 +38,9 @@ func main() { zeroconfig.RegisterWriter(writerTypeReadline, func(config *zeroconfig.WriterConfig) (io.Writer, error) { return rl.Stdout(), nil }) + debug := zerolog.DebugLevel log := exerrors.Must((&zeroconfig.Config{ + MinLevel: &debug, Writers: []zeroconfig.WriterConfig{{ Type: writerTypeReadline, Format: zeroconfig.LogFormatPrettyColored, @@ -54,10 +57,7 @@ func main() { rl.SetPrompt("User ID: ") userID := id.UserID(exerrors.Must(rl.Readline())) _, serverName := exerrors.Must2(userID.Parse()) - discovery, err := mautrix.DiscoverClientAPI(ctx, serverName) - if discovery == nil { - log.Fatal().Err(err).Msg("Failed to discover homeserver") - } + discovery := exerrors.Must(mautrix.DiscoverClientAPI(ctx, serverName)) password := exerrors.Must(rl.ReadPassword("Password: ")) recoveryCode := exerrors.Must(rl.ReadPassword("Recovery code: ")) exerrors.PanicIfNotNil(cli.LoginAndVerify(ctx, discovery.Homeserver.BaseURL, userID.String(), string(password), string(recoveryCode))) @@ -67,4 +67,5 @@ func main() { c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt, syscall.SIGTERM) <-c + cli.Stop() } diff --git a/hicli/paginate.go b/hicli/paginate.go index 6fce03da..0155a626 100644 --- a/hicli/paginate.go +++ b/hicli/paginate.go @@ -18,8 +18,8 @@ import ( var ErrPaginationAlreadyInProgress = errors.New("pagination is already in progress") -func (h *HiClient) Paginate(ctx context.Context, roomID id.RoomID, minTimelineID database.TimelineRowID, limit int) ([]*database.Event, error) { - evts, err := h.DB.Event.GetTimeline(ctx, roomID, limit, minTimelineID) +func (h *HiClient) Paginate(ctx context.Context, roomID id.RoomID, maxTimelineID database.TimelineRowID, limit int) ([]*database.Event, error) { + evts, err := h.DB.Timeline.Get(ctx, roomID, limit, maxTimelineID) if err != nil { return nil, err } else if len(evts) > 0 { @@ -74,6 +74,10 @@ func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit i return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err) } } + err = h.DB.Event.FillLastEditRowIDs(ctx, roomID, events) + if err != nil { + return fmt.Errorf("failed to fill last edit row IDs: %w", err) + } err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End) if err != nil { return fmt.Errorf("failed to set prev_batch: %w", err) diff --git a/hicli/sync.go b/hicli/sync.go index 1f569b3c..785717c2 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -10,6 +10,7 @@ import ( "context" "errors" "fmt" + "strings" "github.com/rs/zerolog" "github.com/tidwall/gjson" @@ -284,6 +285,17 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R ID: room.ID, SortingTimestamp: room.SortingTimestamp, + NameQuality: room.NameQuality, + } + heroesChanged := false + if summary.Heroes == nil && summary.JoinedMemberCount == nil && summary.InvitedMemberCount == nil { + summary = room.LazyLoadSummary + } else if room.LazyLoadSummary == nil || + !slices.Equal(summary.Heroes, room.LazyLoadSummary.Heroes) || + !intPtrEqual(summary.JoinedMemberCount, room.LazyLoadSummary.JoinedMemberCount) || + !intPtrEqual(summary.InvitedMemberCount, room.LazyLoadSummary.InvitedMemberCount) { + updatedRoom.LazyLoadSummary = summary + heroesChanged = true } decryptionQueue := make(map[id.SessionID]*database.SessionRequest) processNewEvent := func(evt *event.Event, isTimeline bool) (database.EventRowID, error) { @@ -302,6 +314,11 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R var membership event.Membership if evt.Type == event.StateMember { membership = event.Membership(gjson.GetBytes(evt.Content.VeryRaw, "membership").Str) + if summary != nil && slices.Contains(summary.Heroes, id.UserID(*evt.StateKey)) { + heroesChanged = true + } + } else if evt.Type == event.StateElementFunctionalMembers { + heroesChanged = true } err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership) if err != nil { @@ -358,17 +375,19 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R return fmt.Errorf("failed to append timeline: %w", err) } } + // Calculate name from participants if participants changed and current name was generated from participants, or if the room name was unset + if (heroesChanged && updatedRoom.NameQuality <= database.NameQualityParticipants) || updatedRoom.NameQuality == database.NameQualityNil { + name, err := h.calculateRoomParticipantName(ctx, room.ID, summary) + if err != nil { + return fmt.Errorf("failed to calculate room name: %w", err) + } + updatedRoom.Name = &name + updatedRoom.NameQuality = database.NameQualityParticipants + } if timeline.PrevBatch != "" && room.PrevBatch == "" { updatedRoom.PrevBatch = timeline.PrevBatch } - if summary.Heroes != nil && (room.LazyLoadSummary == nil || - !slices.Equal(summary.Heroes, room.LazyLoadSummary.Heroes) || - !intPtrEqual(summary.JoinedMemberCount, room.LazyLoadSummary.JoinedMemberCount) || - !intPtrEqual(summary.InvitedMemberCount, room.LazyLoadSummary.InvitedMemberCount)) { - updatedRoom.LazyLoadSummary = summary - } - // TODO check if updatedRoom contains anything - if true { + if updatedRoom.CheckChangesAndCopyInto(room) { err = h.DB.Room.Upsert(ctx, updatedRoom) if err != nil { return fmt.Errorf("failed to save room data: %w", err) @@ -377,6 +396,70 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R return nil } +func joinMemberNames(names []string, totalCount int) string { + if len(names) == 1 { + return names[0] + } else if len(names) < 5 || (len(names) == 5 && totalCount <= 6) { + return strings.Join(names[:len(names)-1], ", ") + " and " + names[len(names)-1] + } else { + return fmt.Sprintf("%s and %d others", strings.Join(names[:4], ", "), totalCount-5) + } +} + +func (h *HiClient) calculateRoomParticipantName(ctx context.Context, roomID id.RoomID, summary *mautrix.LazyLoadSummary) (string, error) { + if summary == nil || len(summary.Heroes) == 0 { + return "Empty room", nil + } + var functionalMembers []id.UserID + functionalMembersEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateElementFunctionalMembers, "") + if err != nil { + return "", fmt.Errorf("failed to get %s event: %w", event.StateElementFunctionalMembers.Type, err) + } else if functionalMembersEvt != nil { + mautrixEvt := functionalMembersEvt.AsRawMautrix() + _ = mautrixEvt.Content.ParseRaw(mautrixEvt.Type) + content, ok := mautrixEvt.Content.Parsed.(*event.ElementFunctionalMembersContent) + if ok { + functionalMembers = content.FunctionalMembers + } + } + var members, leftMembers []string + var memberCount int + if summary.JoinedMemberCount != nil && *summary.JoinedMemberCount > 0 { + memberCount = *summary.JoinedMemberCount + } else if summary.InvitedMemberCount != nil { + memberCount = *summary.InvitedMemberCount + } + for _, hero := range summary.Heroes { + if slices.Contains(functionalMembers, hero) { + memberCount-- + continue + } else if len(members) >= 5 { + break + } + heroEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateMember, hero.String()) + if err != nil { + return "", fmt.Errorf("failed to get %s's member event: %w", hero, err) + } + results := gjson.GetManyBytes(heroEvt.Content, "membership", "displayname") + name := results[1].Str + if name == "" { + name = hero.String() + } + if results[0].Str == "join" || results[0].Str == "invite" { + members = append(members, name) + } else { + leftMembers = append(leftMembers, name) + } + } + if len(members) > 0 { + return joinMemberNames(members, memberCount), nil + } else if len(leftMembers) > 0 { + return fmt.Sprintf("Empty room (was %s)", joinMemberNames(leftMembers, memberCount)), nil + } else { + return "Empty room", nil + } +} + func intPtrEqual(a, b *int) bool { if a == nil || b == nil { return a == b @@ -389,7 +472,7 @@ func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomDa return } switch evt.Type { - case event.StateCreate, event.StateRoomName, event.StateRoomAvatar, event.StateTopic, event.StateEncryption: + case event.StateCreate, event.StateRoomName, event.StateCanonicalAlias, event.StateRoomAvatar, event.StateTopic, event.StateEncryption: if *evt.StateKey != "" { return } @@ -414,17 +497,41 @@ func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomDa } case event.StateRoomName: content, ok := evt.Content.Parsed.(*event.RoomNameEventContent) - if ok && (existingRoomData.Name == nil || *existingRoomData.Name != content.Name) { + if ok { updatedRoom.Name = &content.Name + updatedRoom.NameQuality = database.NameQualityExplicit + if content.Name == "" { + if updatedRoom.CanonicalAlias != nil && *updatedRoom.CanonicalAlias != "" { + updatedRoom.Name = (*string)(updatedRoom.CanonicalAlias) + updatedRoom.NameQuality = database.NameQualityCanonicalAlias + } else if existingRoomData.CanonicalAlias != nil && *existingRoomData.CanonicalAlias != "" { + updatedRoom.Name = (*string)(existingRoomData.CanonicalAlias) + updatedRoom.NameQuality = database.NameQualityCanonicalAlias + } else { + updatedRoom.NameQuality = database.NameQualityNil + } + } + } + case event.StateCanonicalAlias: + content, ok := evt.Content.Parsed.(*event.CanonicalAliasEventContent) + if ok { + updatedRoom.CanonicalAlias = &content.Alias + if updatedRoom.NameQuality <= database.NameQualityCanonicalAlias { + updatedRoom.Name = (*string)(&content.Alias) + updatedRoom.NameQuality = database.NameQualityCanonicalAlias + if content.Alias == "" { + updatedRoom.NameQuality = database.NameQualityNil + } + } } case event.StateRoomAvatar: content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent) - if ok && (existingRoomData.Avatar == nil || *existingRoomData.Avatar != content.URL) { + if ok { updatedRoom.Avatar = &content.URL } case event.StateTopic: content, ok := evt.Content.Parsed.(*event.TopicEventContent) - if ok && (existingRoomData.Topic == nil || *existingRoomData.Topic != content.Topic) { + if ok { updatedRoom.Topic = &content.Topic } } From 320c99ce66412917e5bcf474d15536bad75eecb2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 1 Jun 2024 22:50:21 +0300 Subject: [PATCH 0245/1647] Add initial sync event dispatching --- hicli/database/event.go | 43 ++++++++++++------- hicli/database/room.go | 8 +++- hicli/database/timeline.go | 31 +++++++------ .../database/upgrades/00-latest-revision.sql | 25 ++++++----- hicli/decryptionqueue.go | 8 +++- hicli/events.go | 10 ++++- hicli/hicli.go | 10 ++--- hicli/hitest/hitest.go | 26 ++++++++++- hicli/paginate.go | 36 +++++++++++++++- hicli/sync.go | 23 +++++++--- hicli/syncwrap.go | 2 +- 11 files changed, 166 insertions(+), 56 deletions(-) diff --git a/hicli/database/event.go b/hicli/database/event.go index 6f681373..4ea50dd5 100644 --- a/hicli/database/event.go +++ b/hicli/database/event.go @@ -28,6 +28,7 @@ const ( redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, reactions, last_edit_rowid FROM event ` + getManyEventsByRowID = getEventBaseQuery + `WHERE rowid IN (%s)` getEventByID = getEventBaseQuery + `WHERE event_id = $1` getFailedEventsByMegolmSessionID = getEventBaseQuery + `WHERE room_id = $1 AND megolm_session_id = $2 AND decryption_error IS NOT NULL` upsertEventQuery = ` @@ -79,6 +80,11 @@ func (eq *EventQuery) GetByID(ctx context.Context, eventID id.EventID) (*Event, return eq.QueryOne(ctx, getEventByID, eventID) } +func (eq *EventQuery) GetByRowIDs(ctx context.Context, rowIDs ...EventRowID) ([]*Event, error) { + query, params := buildMultiEventGetFunction(nil, rowIDs, getManyEventsByRowID) + return eq.QueryMany(ctx, query, params...) +} + func (eq *EventQuery) Upsert(ctx context.Context, evt *Event) (rowID EventRowID, err error) { err = eq.GetDB().QueryRow(ctx, upsertEventQuery, evt.sqlVariables()...).Scan(&rowID) if err == nil { @@ -114,7 +120,7 @@ func (eq *EventQuery) FillLastEditRowIDs(ctx context.Context, roomID id.RoomID, eventIDs := make([]id.EventID, 0) eventMap := make(map[id.EventID]*Event) for i, evt := range events { - if evt.LastEditRowID == 0 { + if evt.LastEditRowID == nil { eventIDs[i] = evt.ID eventMap[evt.ID] = evt } @@ -126,12 +132,21 @@ func (eq *EventQuery) FillLastEditRowIDs(ctx context.Context, roomID id.RoomID, } for evtID, res := range result { lastEditRowID := res[len(res)-1] - eventMap[evtID].LastEditRowID = lastEditRowID + eventMap[evtID].LastEditRowID = &lastEditRowID + delete(eventMap, evtID) err = eq.Exec(ctx, setLastEditRowIDQuery, evtID, lastEditRowID) if err != nil { return err } } + var zero EventRowID + for evtID, evt := range eventMap { + evt.LastEditRowID = &zero + err = eq.Exec(ctx, setLastEditRowIDQuery, evtID, zero) + if err != nil { + return err + } + } return nil }) } @@ -143,11 +158,11 @@ type GetReactionsResult struct { Counts map[string]int } -func buildMultiEventGetFunction(roomID id.RoomID, eventIDs []id.EventID, query string) (string, []any) { - params := make([]any, len(eventIDs)+1) - params[0] = roomID +func buildMultiEventGetFunction[T any](preParams []any, eventIDs []T, query string) (string, []any) { + params := make([]any, len(preParams)+len(eventIDs)) + copy(params, preParams) for i, evtID := range eventIDs { - params[i+1] = evtID + params[i+len(preParams)] = evtID } placeholders := strings.Repeat("?,", len(eventIDs)) placeholders = placeholders[:len(placeholders)-1] @@ -160,7 +175,7 @@ type editRowIDTuple struct { } func (eq *EventQuery) GetEditRowIDs(ctx context.Context, roomID id.RoomID, eventIDs ...id.EventID) (map[id.EventID][]EventRowID, error) { - query, params := buildMultiEventGetFunction(roomID, eventIDs, getEventEditRowIDsQuery) + query, params := buildMultiEventGetFunction([]any{roomID}, eventIDs, getEventEditRowIDsQuery) rows, err := eq.GetDB().Query(ctx, query, params...) output := make(map[id.EventID][]EventRowID) return output, dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (tuple editRowIDTuple, err error) { @@ -178,7 +193,7 @@ func (eq *EventQuery) GetReactions(ctx context.Context, roomID id.RoomID, eventI result[evtID] = &GetReactionsResult{Counts: make(map[string]int)} } return result, eq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error { - query, params := buildMultiEventGetFunction(roomID, eventIDs, getEventReactionsQuery) + query, params := buildMultiEventGetFunction([]any{roomID}, eventIDs, getEventReactionsQuery) events, err := eq.QueryMany(ctx, query, params...) if err != nil { return err @@ -212,8 +227,8 @@ func (m EventRowID) GetMassInsertValues() [1]any { } type Event struct { - RowID EventRowID `json:"fi.mau.hicli.rowid"` - TimelineRowID TimelineRowID `json:"fi.mau.hicli.timeline_rowid"` + RowID EventRowID `json:"rowid"` + TimelineRowID TimelineRowID `json:"timeline_rowid"` RoomID id.RoomID `json:"room_id"` ID id.EventID `json:"event_id"` @@ -235,7 +250,7 @@ type Event struct { DecryptionError string Reactions map[string]int - LastEditRowID EventRowID + LastEditRowID *EventRowID } func MautrixToEvent(evt *event.Event) *Event { @@ -282,7 +297,6 @@ func (e *Event) AsRawMautrix() *event.Event { func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { var timestamp int64 var redactedBy, relatesTo, relationType, megolmSessionID, decryptionError, decryptedType sql.NullString - var lastEditRowID sql.NullInt64 err := row.Scan( &e.RowID, &e.TimelineRowID, @@ -302,7 +316,7 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { &megolmSessionID, &decryptionError, dbutil.JSON{Data: &e.Reactions}, - &lastEditRowID, + &e.LastEditRowID, ) if err != nil { return nil, err @@ -314,7 +328,6 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { e.MegolmSessionID = id.SessionID(megolmSessionID.String) e.DecryptedType = decryptedType.String e.DecryptionError = decryptionError.String - e.LastEditRowID = EventRowID(lastEditRowID.Int64) return e, nil } @@ -365,7 +378,7 @@ func (e *Event) sqlVariables() []any { dbutil.StrPtr(e.MegolmSessionID), dbutil.StrPtr(e.DecryptionError), dbutil.JSON{Data: reactions}, - dbutil.NumPtr(e.LastEditRowID), + e.LastEditRowID, } } diff --git a/hicli/database/room.go b/hicli/database/room.go index d8a162b6..e3b140d0 100644 --- a/hicli/database/room.go +++ b/hicli/database/room.go @@ -55,6 +55,7 @@ const ( WHERE room_id = $1 AND COALESCE((SELECT rowid FROM timeline WHERE event_rowid = $2), -1) > COALESCE((SELECT rowid FROM timeline WHERE event_rowid = preview_event_rowid), 0) + RETURNING preview_event_rowid ` ) @@ -78,8 +79,11 @@ func (rq *RoomQuery) SetPrevBatch(ctx context.Context, roomID id.RoomID, prevBat return rq.Exec(ctx, setRoomPrevBatchQuery, roomID, prevBatch) } -func (rq *RoomQuery) UpdatePreviewIfLaterOnTimeline(ctx context.Context, roomID id.RoomID, rowID EventRowID) error { - return rq.Exec(ctx, updateRoomPreviewIfLaterOnTimelineQuery, roomID, rowID) +func (rq *RoomQuery) UpdatePreviewIfLaterOnTimeline(ctx context.Context, roomID id.RoomID, rowID EventRowID) (previewChanged bool, err error) { + var newPreviewRowID EventRowID + err = rq.GetDB().QueryRow(ctx, updateRoomPreviewIfLaterOnTimelineQuery, roomID, rowID).Scan(&newPreviewRowID) + previewChanged = newPreviewRowID == rowID + return } type NameQuality int diff --git a/hicli/database/timeline.go b/hicli/database/timeline.go index 3a24603a..891f6acb 100644 --- a/hicli/database/timeline.go +++ b/hicli/database/timeline.go @@ -22,7 +22,7 @@ const ( DELETE FROM timeline WHERE room_id = $1 ` appendTimelineQuery = ` - INSERT INTO timeline (room_id, event_rowid) VALUES ($1, $2) + INSERT INTO timeline (room_id, event_rowid) VALUES ($1, $2) RETURNING rowid, event_rowid ` prependTimelineQuery = ` INSERT INTO timeline (room_id, rowid, event_rowid) VALUES ($1, $2, $3) @@ -42,12 +42,17 @@ const ( type TimelineRowID int64 type TimelineRowTuple struct { - Timeline TimelineRowID - Event EventRowID + Timeline TimelineRowID `json:"timeline_rowid"` + Event EventRowID `json:"event_rowid"` } -func (tp TimelineRowTuple) GetMassInsertValues() [2]any { - return [2]any{tp.Timeline, tp.Event} +var timelineRowTupleScanner = dbutil.ConvertRowFn[TimelineRowTuple](func(row dbutil.Scannable) (trt TimelineRowTuple, err error) { + err = row.Scan(&trt.Timeline, &trt.Event) + return +}) + +func (trt TimelineRowTuple) GetMassInsertValues() [2]any { + return [2]any{trt.Timeline, trt.Event} } var appendTimelineQueryBuilder = dbutil.NewMassInsertBuilder[EventRowID, [1]any](appendTimelineQuery, "($1, $%d)") @@ -89,12 +94,13 @@ func (tq *TimelineQuery) reserveRowIDs(ctx context.Context, count int) (startFro // Prepend adds the given event row IDs to the beginning of the timeline. // The events must be sorted in reverse chronological order (newest event first). -func (tq *TimelineQuery) Prepend(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) error { - startFrom, err := tq.reserveRowIDs(ctx, len(rowIDs)) +func (tq *TimelineQuery) Prepend(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) (prependEntries []TimelineRowTuple, err error) { + var startFrom TimelineRowID + startFrom, err = tq.reserveRowIDs(ctx, len(rowIDs)) if err != nil { - return err + return } - prependEntries := make([]TimelineRowTuple, len(rowIDs)) + prependEntries = make([]TimelineRowTuple, len(rowIDs)) for i, rowID := range rowIDs { prependEntries[i] = TimelineRowTuple{ Timeline: startFrom - TimelineRowID(i), @@ -102,13 +108,14 @@ func (tq *TimelineQuery) Prepend(ctx context.Context, roomID id.RoomID, rowIDs [ } } query, params := prependTimelineQueryBuilder.Build([1]any{roomID}, prependEntries) - return tq.Exec(ctx, query, params...) + err = tq.Exec(ctx, query, params...) + return } // Append adds the given event row IDs to the end of the timeline. -func (tq *TimelineQuery) Append(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) error { +func (tq *TimelineQuery) Append(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) ([]TimelineRowTuple, error) { query, params := appendTimelineQueryBuilder.Build([1]any{roomID}, rowIDs) - return tq.Exec(ctx, query, params...) + return timelineRowTupleScanner.NewRowIter(tq.GetDB().Query(ctx, query, params...)).AsList() } func (tq *TimelineQuery) Get(ctx context.Context, roomID id.RoomID, limit int, before TimelineRowID) ([]*Event, error) { diff --git a/hicli/database/upgrades/00-latest-revision.sql b/hicli/database/upgrades/00-latest-revision.sql index 33e9fc97..8d99a315 100644 --- a/hicli/database/upgrades/00-latest-revision.sql +++ b/hicli/database/upgrades/00-latest-revision.sql @@ -100,16 +100,18 @@ CREATE TRIGGER event_update_last_edit_when_redacted AND NEW.relation_type = 'm.replace' BEGIN UPDATE event - SET last_edit_rowid = (SELECT rowid - FROM event edit - WHERE edit.room_id = event.room_id - AND edit.relates_to = event.event_id - AND edit.relation_type = 'm.replace' - AND edit.type = event.type - AND edit.sender = event.sender - AND edit.redacted_by IS NULL - ORDER BY edit.timestamp DESC - LIMIT 1) + SET last_edit_rowid = COALESCE( + (SELECT rowid + FROM event edit + WHERE edit.room_id = event.room_id + AND edit.relates_to = event.event_id + AND edit.relation_type = 'm.replace' + AND edit.type = event.type + AND edit.sender = event.sender + AND edit.redacted_by IS NULL + ORDER BY edit.timestamp DESC + LIMIT 1), + 0) WHERE event_id = NEW.relates_to AND last_edit_rowid = NEW.rowid; END; @@ -190,7 +192,8 @@ CREATE TABLE timeline ( event_rowid INTEGER NOT NULL, CONSTRAINT timeline_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE, - CONSTRAINT timeline_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE CASCADE + CONSTRAINT timeline_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE CASCADE, + CONSTRAINT timeline_event_unique_key UNIQUE (event_rowid) ) STRICT; CREATE INDEX timeline_room_id_idx ON timeline (room_id); diff --git a/hicli/decryptionqueue.go b/hicli/decryptionqueue.go index 4358297c..02466b69 100644 --- a/hicli/decryptionqueue.go +++ b/hicli/decryptionqueue.go @@ -60,6 +60,7 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro } } if len(decrypted) > 0 { + previewRowIDChanges := make(map[id.RoomID]database.EventRowID) err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { for _, evt := range decrypted { err = h.DB.Event.UpdateDecrypted(ctx, evt.RowID, evt.Decrypted, evt.DecryptedType) @@ -67,9 +68,12 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, err) } if evt.CanUseForPreview() { - err = h.DB.Room.UpdatePreviewIfLaterOnTimeline(ctx, evt.RoomID, evt.RowID) + var previewChanged bool + previewChanged, err = h.DB.Room.UpdatePreviewIfLaterOnTimeline(ctx, evt.RoomID, evt.RowID) if err != nil { return fmt.Errorf("failed to update room %s preview to %d: %w", evt.RoomID, evt.RowID, err) + } else if previewChanged { + previewRowIDChanges[evt.RoomID] = evt.RowID } } } @@ -78,7 +82,7 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro if err != nil { log.Err(err).Msg("Failed to save decrypted events") } else { - h.DispatchEvent(&EventsDecrypted{Events: decrypted}) + h.EventHandler(&EventsDecrypted{Events: decrypted, PreviewRowIDs: previewRowIDChanges}) } } } diff --git a/hicli/events.go b/hicli/events.go index 9e9d43de..6cdfc98a 100644 --- a/hicli/events.go +++ b/hicli/events.go @@ -12,11 +12,19 @@ import ( "maunium.net/go/mautrix/id" ) +type SyncRoom struct { + Meta *database.Room `json:"meta"` + Timeline []database.TimelineRowTuple `json:"timeline"` + Reset bool `json:"reset"` +} + type SyncComplete struct { + Rooms map[id.RoomID]*SyncRoom `json:"rooms"` } type EventsDecrypted struct { - Events []*database.Event + PreviewRowIDs map[id.RoomID]database.EventRowID `json:"room_preview_rowids"` + Events []*database.Event `json:"events"` } type Typing struct { diff --git a/hicli/hicli.go b/hicli/hicli.go index 5dcbb2f2..5e2957c0 100644 --- a/hicli/hicli.go +++ b/hicli/hicli.go @@ -41,6 +41,8 @@ type HiClient struct { KeyBackupVersion id.KeyBackupVersion KeyBackupKey *backup.MegolmBackupKey + EventHandler func(evt any) + firstSyncReceived bool syncingID int syncLock sync.Mutex @@ -55,11 +57,7 @@ type HiClient struct { var ErrTimelineReset = errors.New("got limited timeline sync response") -func (h *HiClient) DispatchEvent(evt any) { - // TODO -} - -func New(rawDB *dbutil.Database, log zerolog.Logger, pickleKey []byte) *HiClient { +func New(rawDB *dbutil.Database, log zerolog.Logger, pickleKey []byte, evtHandler func(any)) *HiClient { rawDB.Owner = "hicli" rawDB.IgnoreForeignTables = true db := database.New(rawDB) @@ -69,6 +67,8 @@ func New(rawDB *dbutil.Database, log zerolog.Logger, pickleKey []byte) *HiClient Log: log, requestQueueWakeup: make(chan struct{}, 1), + + EventHandler: evtHandler, } c.ClientStore = &database.ClientStateStore{Database: db} c.Client = &mautrix.Client{ diff --git a/hicli/hitest/hitest.go b/hicli/hitest/hitest.go index 4779c85c..c55c27dc 100644 --- a/hicli/hitest/hitest.go +++ b/hicli/hitest/hitest.go @@ -8,6 +8,7 @@ package main import ( "context" + "fmt" "io" "os" "os/signal" @@ -50,7 +51,30 @@ func main() { rawDB := exerrors.Must(dbutil.NewWithDialect("hicli.db", "sqlite3-fk-wal")) ctx := log.WithContext(context.Background()) - cli := hicli.New(rawDB, *log, []byte("meow")) + cli := hicli.New(rawDB, *log, []byte("meow"), func(a any) { + _, _ = fmt.Fprintf(rl, "Received event of type %T\n", a) + switch evt := a.(type) { + case *hicli.SyncComplete: + for _, room := range evt.Rooms { + name := "name unset" + if room.Meta.Name != nil { + name = *room.Meta.Name + } + _, _ = fmt.Fprintf(rl, "Room %s (%s) in sync:\n", name, room.Meta.ID) + _, _ = fmt.Fprintf(rl, " Preview: %d, sort: %v\n", room.Meta.PreviewEventRowID, room.Meta.SortingTimestamp) + _, _ = fmt.Fprintf(rl, " Timeline: +%d %v, reset: %t\n", len(room.Timeline), room.Timeline, room.Reset) + } + case *hicli.EventsDecrypted: + for _, decrypted := range evt.Events { + _, _ = fmt.Fprintf(rl, "Delayed decryption of %s completed: %s / %s\n", decrypted.ID, decrypted.DecryptedType, decrypted.Decrypted) + } + if len(evt.PreviewRowIDs) > 0 { + _, _ = fmt.Fprintf(rl, "Room previews updated: %+v\n", evt.PreviewRowIDs) + } + case *hicli.Typing: + _, _ = fmt.Fprintf(rl, "Typing list in %s: %+v\n", evt.RoomID, evt.UserIDs) + } + }) userID, _ := cli.DB.Account.GetFirstUserID(ctx) exerrors.PanicIfNotNil(cli.Start(ctx, userID)) if !cli.IsLoggedIn() { diff --git a/hicli/paginate.go b/hicli/paginate.go index 0155a626..957ac3e3 100644 --- a/hicli/paginate.go +++ b/hicli/paginate.go @@ -18,6 +18,36 @@ import ( var ErrPaginationAlreadyInProgress = errors.New("pagination is already in progress") +func (h *HiClient) GetEventsByRowIDs(ctx context.Context, rowIDs []database.EventRowID) ([]*database.Event, error) { + events, err := h.DB.Event.GetByRowIDs(ctx, rowIDs...) + if err != nil { + return nil, err + } else if len(events) == 0 { + return events, nil + } + firstRoomID := events[0].RoomID + allInSameRoom := true + for _, evt := range events { + if evt.RoomID != firstRoomID { + allInSameRoom = false + break + } + } + if allInSameRoom { + err = h.DB.Event.FillLastEditRowIDs(ctx, firstRoomID, events) + if err != nil { + return events, fmt.Errorf("failed to fill last edit row IDs: %w", err) + } + err = h.DB.Event.FillReactionCounts(ctx, firstRoomID, events) + if err != nil { + return events, fmt.Errorf("failed to fill reaction counts: %w", err) + } + } else { + // TODO slow path where events are collected and filling is done one room at a time? + } + return events, nil +} + func (h *HiClient) Paginate(ctx context.Context, roomID id.RoomID, maxTimelineID database.TimelineRowID, limit int) ([]*database.Event, error) { evts, err := h.DB.Timeline.Get(ctx, roomID, limit, maxTimelineID) if err != nil { @@ -82,10 +112,14 @@ func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit i if err != nil { return fmt.Errorf("failed to set prev_batch: %w", err) } - err = h.DB.Timeline.Prepend(ctx, room.ID, eventRowIDs) + var tuples []database.TimelineRowTuple + tuples, err = h.DB.Timeline.Prepend(ctx, room.ID, eventRowIDs) if err != nil { return fmt.Errorf("failed to prepend events to timeline: %w", err) } + for i, evt := range events { + evt.TimelineRowID = tuples[i].Timeline + } return nil }) if err == nil && wakeupSessionRequests { diff --git a/hicli/sync.go b/hicli/sync.go index 785717c2..6502b43e 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -28,6 +28,8 @@ import ( type syncContext struct { shouldWakeupRequestQueue bool + + evt *SyncComplete } func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { @@ -61,10 +63,12 @@ func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.Res func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) { h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount) go h.asyncPostProcessSyncResponse(ctx, resp, since) - if ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue { + syncCtx := ctx.Value(syncContextKey).(*syncContext) + if syncCtx.shouldWakeupRequestQueue { h.WakeupRequestQueue() } h.firstSyncReceived = true + h.EventHandler(syncCtx.evt) } func (h *HiClient) asyncPostProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) { @@ -173,7 +177,7 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, return fmt.Errorf("failed to save receipts: %w", err) } case event.EphemeralEventTyping: - go h.DispatchEvent(&Typing{ + go h.EventHandler(&Typing{ RoomID: roomID, TypingEventContent: *evt.Content.AsTyping(), }) @@ -246,7 +250,7 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio dbEvt.Content = contentWithoutFallback } var decryptionErr error - if evt.Type == event.EventEncrypted { + if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" { dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt) if decryptionErr != nil { dbEvt.DecryptionError = decryptionErr.Error() @@ -336,6 +340,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R return err } } + var timelineRowTuples []database.TimelineRowTuple if len(timeline.Events) > 0 { timelineIDs := make([]database.EventRowID, len(timeline.Events)) for i, evt := range timeline.Events { @@ -370,7 +375,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R } h.paginationInterrupterLock.Unlock() } - err = h.DB.Timeline.Append(ctx, room.ID, timelineIDs) + timelineRowTuples, err = h.DB.Timeline.Append(ctx, room.ID, timelineIDs) if err != nil { return fmt.Errorf("failed to append timeline: %w", err) } @@ -387,12 +392,20 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R if timeline.PrevBatch != "" && room.PrevBatch == "" { updatedRoom.PrevBatch = timeline.PrevBatch } - if updatedRoom.CheckChangesAndCopyInto(room) { + roomChanged := updatedRoom.CheckChangesAndCopyInto(room) + if roomChanged { err = h.DB.Room.Upsert(ctx, updatedRoom) if err != nil { return fmt.Errorf("failed to save room data: %w", err) } } + if roomChanged || len(timelineRowTuples) > 0 { + ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{ + Meta: room, + Timeline: timelineRowTuples, + Reset: timeline.Limited, + } + } return nil } diff --git a/hicli/syncwrap.go b/hicli/syncwrap.go index 46d31e57..13837202 100644 --- a/hicli/syncwrap.go +++ b/hicli/syncwrap.go @@ -27,7 +27,7 @@ const ( func (h *hiSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { c := (*HiClient)(h) - ctx = context.WithValue(ctx, syncContextKey, &syncContext{}) + ctx = context.WithValue(ctx, syncContextKey, &syncContext{evt: &SyncComplete{Rooms: make(map[id.RoomID]*SyncRoom, len(resp.Rooms.Join))}}) err := c.preProcessSyncResponse(ctx, resp, since) if err != nil { return err From 409a7a81660404f290a687cd2b55a0a5be48bf7b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 1 Jun 2024 23:51:23 +0300 Subject: [PATCH 0246/1647] Add initial message sending support --- hicli/cryptohelper.go | 29 +-- hicli/database/event.go | 56 ++++- hicli/database/statestore.go | 9 + .../database/upgrades/00-latest-revision.sql | 3 + hicli/events.go | 5 + hicli/hitest/hitest.go | 24 ++- hicli/send.go | 196 ++++++++++++++++++ 7 files changed, 284 insertions(+), 38 deletions(-) create mode 100644 hicli/send.go diff --git a/hicli/cryptohelper.go b/hicli/cryptohelper.go index eb054af9..2a2e9626 100644 --- a/hicli/cryptohelper.go +++ b/hicli/cryptohelper.go @@ -8,14 +8,12 @@ package hicli import ( "context" - "errors" "fmt" "time" "github.com/rs/zerolog" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -24,29 +22,14 @@ type hiCryptoHelper HiClient var _ mautrix.CryptoHelper = (*hiCryptoHelper)(nil) -func (h *hiCryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) { - h.encryptLock.Lock() - defer h.encryptLock.Unlock() - encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, roomID, evtType, content) +func (h *hiCryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (*event.EncryptedEventContent, error) { + roomMeta, err := h.DB.Room.Get(ctx, roomID) if err != nil { - if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.NoGroupSession) && !errors.Is(err, crypto.SessionNotShared) { - return - } - h.Log.Debug(). - Err(err). - Str("room_id", roomID.String()). - Msg("Got session error while encrypting event, sharing group session and trying again") - var users []id.UserID - users, err = h.ClientStore.GetRoomJoinedOrInvitedMembers(ctx, roomID) - if err != nil { - err = fmt.Errorf("failed to get room member list: %w", err) - } else if err = h.Crypto.ShareGroupSession(ctx, roomID, users); err != nil { - err = fmt.Errorf("failed to share group session: %w", err) - } else if encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, roomID, evtType, content); err != nil { - err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err) - } + return nil, fmt.Errorf("failed to get room metadata: %w", err) + } else if roomMeta == nil { + return nil, fmt.Errorf("unknown room") } - return + return (*HiClient)(h).Encrypt(ctx, roomMeta, evtType, content) } func (h *hiCryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) { diff --git a/hicli/database/event.go b/hicli/database/event.go index 4ea50dd5..f4df8868 100644 --- a/hicli/database/event.go +++ b/hicli/database/event.go @@ -25,22 +25,35 @@ import ( const ( getEventBaseQuery = ` SELECT rowid, -1, room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, - redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, reactions, last_edit_rowid + transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, reactions, last_edit_rowid FROM event ` getManyEventsByRowID = getEventBaseQuery + `WHERE rowid IN (%s)` getEventByID = getEventBaseQuery + `WHERE event_id = $1` getFailedEventsByMegolmSessionID = getEventBaseQuery + `WHERE room_id = $1 AND megolm_session_id = $2 AND decryption_error IS NOT NULL` - upsertEventQuery = ` - INSERT INTO event (room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) + insertEventBaseQuery = ` + INSERT INTO event ( + room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, + transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + ` + insertEventQuery = insertEventBaseQuery + `RETURNING rowid` + upsertEventQuery = insertEventBaseQuery + ` ON CONFLICT (event_id) DO UPDATE SET decrypted=COALESCE(event.decrypted, excluded.decrypted), decrypted_type=COALESCE(event.decrypted_type, excluded.decrypted_type), redacted_by=COALESCE(event.redacted_by, excluded.redacted_by), - decryption_error=CASE WHEN COALESCE(event.decrypted, excluded.decrypted) IS NULL THEN COALESCE(excluded.decryption_error, event.decryption_error) END + decryption_error=CASE WHEN COALESCE(event.decrypted, excluded.decrypted) IS NULL THEN COALESCE(excluded.decryption_error, event.decryption_error) END, + timestamp=excluded.timestamp, + unsigned=COALESCE(excluded.unsigned, event.unsigned) + ON CONFLICT (transaction_id) DO UPDATE + SET event_id=excluded.event_id, + timestamp=excluded.timestamp, + unsigned=excluded.unsigned RETURNING rowid ` + updateEventIDQuery = `UPDATE event SET event_id=$2 WHERE rowid=$1` updateEventDecryptedQuery = `UPDATE event SET decrypted = $1, decrypted_type = $2, decryption_error = NULL WHERE rowid = $3` getEventReactionsQuery = getEventBaseQuery + ` WHERE room_id = ? @@ -93,6 +106,18 @@ func (eq *EventQuery) Upsert(ctx context.Context, evt *Event) (rowID EventRowID, return } +func (eq *EventQuery) Insert(ctx context.Context, evt *Event) (rowID EventRowID, err error) { + err = eq.GetDB().QueryRow(ctx, insertEventQuery, evt.sqlVariables()...).Scan(&rowID) + if err == nil { + evt.RowID = rowID + } + return +} + +func (eq *EventQuery) UpdateID(ctx context.Context, rowID EventRowID, newID id.EventID) error { + return eq.Exec(ctx, updateEventIDQuery, rowID, newID) +} + func (eq *EventQuery) UpdateDecrypted(ctx context.Context, rowID EventRowID, decrypted json.RawMessage, decryptedType string) error { return eq.Exec(ctx, updateEventDecryptedQuery, unsafeJSONString(decrypted), decryptedType, rowID) } @@ -242,6 +267,8 @@ type Event struct { DecryptedType string `json:"decrypted_type,omitempty"` Unsigned json.RawMessage `json:"unsigned,omitempty"` + TransactionID string `json:"transaction_id,omitempty"` + RedactedBy id.EventID `json:"redacted_by,omitempty"` RelatesTo id.EventID `json:"relates_to,omitempty"` RelationType event.RelationType `json:"relation_type,omitempty"` @@ -263,8 +290,12 @@ func MautrixToEvent(evt *event.Event) *Event { Timestamp: time.UnixMilli(evt.Timestamp), Content: evt.Content.VeryRaw, MegolmSessionID: getMegolmSessionID(evt), + TransactionID: evt.Unsigned.TransactionID, } - dbEvt.RelatesTo, dbEvt.RelationType = getRelatesTo(evt) + if !strings.HasPrefix(dbEvt.TransactionID, "hicli-mautrix-go_") { + dbEvt.TransactionID = "" + } + dbEvt.RelatesTo, dbEvt.RelationType = getRelatesToFromEvent(evt) dbEvt.Unsigned, _ = json.Marshal(&evt.Unsigned) if evt.Unsigned.RedactedBecause != nil { dbEvt.RedactedBy = evt.Unsigned.RedactedBecause.ID @@ -296,7 +327,7 @@ func (e *Event) AsRawMautrix() *event.Event { func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { var timestamp int64 - var redactedBy, relatesTo, relationType, megolmSessionID, decryptionError, decryptedType sql.NullString + var transactionID, redactedBy, relatesTo, relationType, megolmSessionID, decryptionError, decryptedType sql.NullString err := row.Scan( &e.RowID, &e.TimelineRowID, @@ -310,6 +341,7 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { (*[]byte)(&e.Decrypted), &decryptedType, (*[]byte)(&e.Unsigned), + &transactionID, &redactedBy, &relatesTo, &relationType, @@ -322,6 +354,7 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { return nil, err } e.Timestamp = time.UnixMilli(timestamp) + e.TransactionID = transactionID.String e.RedactedBy = id.EventID(redactedBy.String) e.RelatesTo = id.EventID(relatesTo.String) e.RelationType = event.RelationType(relatesTo.String) @@ -334,11 +367,15 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { var relatesToPath = exgjson.Path("m.relates_to", "event_id") var relationTypePath = exgjson.Path("m.relates_to", "rel_type") -func getRelatesTo(evt *event.Event) (id.EventID, event.RelationType) { +func getRelatesToFromEvent(evt *event.Event) (id.EventID, event.RelationType) { if evt.StateKey != nil { return "", "" } - results := gjson.GetManyBytes(evt.Content.VeryRaw, relatesToPath, relationTypePath) + return GetRelatesToFromBytes(evt.Content.VeryRaw) +} + +func GetRelatesToFromBytes(content []byte) (id.EventID, event.RelationType) { + results := gjson.GetManyBytes(content, relatesToPath, relationTypePath) if len(results) == 2 && results[0].Exists() && results[1].Exists() && results[0].Type == gjson.String && results[1].Type == gjson.String { return id.EventID(results[0].Str), event.RelationType(results[1].Str) } @@ -372,6 +409,7 @@ func (e *Event) sqlVariables() []any { unsafeJSONString(e.Decrypted), dbutil.StrPtr(e.DecryptedType), unsafeJSONString(e.Unsigned), + dbutil.StrPtr(e.TransactionID), dbutil.StrPtr(e.RedactedBy), dbutil.StrPtr(e.RelatesTo), dbutil.StrPtr(e.RelationType), diff --git a/hicli/database/statestore.go b/hicli/database/statestore.go index e0471ef2..baf84df1 100644 --- a/hicli/database/statestore.go +++ b/hicli/database/statestore.go @@ -30,6 +30,10 @@ const ( LEFT JOIN event ON event.rowid = cs.event_rowid WHERE cs.room_id = $1 AND cs.event_type = $2 AND cs.state_key = $3 ` + getRoomJoinedMembersQuery = ` + SELECT state_key FROM current_state + WHERE room_id = $1 AND event_type = 'm.room.member' AND membership = 'join' + ` getRoomJoinedOrInvitedMembersQuery = ` SELECT state_key FROM current_state WHERE room_id = $1 AND event_type = 'm.room.member' AND membership IN ('join', 'invite') @@ -96,6 +100,11 @@ func (c *ClientStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) return } +func (c *ClientStateStore) GetRoomJoinedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) { + rows, err := c.Query(ctx, getRoomJoinedMembersQuery, roomID) + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() +} + func (c *ClientStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) { rows, err := c.Query(ctx, getRoomJoinedOrInvitedMembersQuery, roomID) return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() diff --git a/hicli/database/upgrades/00-latest-revision.sql b/hicli/database/upgrades/00-latest-revision.sql index 8d99a315..df6499a1 100644 --- a/hicli/database/upgrades/00-latest-revision.sql +++ b/hicli/database/upgrades/00-latest-revision.sql @@ -66,6 +66,8 @@ CREATE TABLE event ( decrypted_type TEXT, unsigned TEXT NOT NULL, + transaction_id TEXT, + redacted_by TEXT, relates_to TEXT, relation_type TEXT, @@ -77,6 +79,7 @@ CREATE TABLE event ( last_edit_rowid INTEGER, CONSTRAINT event_id_unique_key UNIQUE (event_id), + CONSTRAINT transaction_id_unique_key UNIQUE (transaction_id), CONSTRAINT event_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE ) STRICT; CREATE INDEX event_room_id_idx ON event (room_id); diff --git a/hicli/events.go b/hicli/events.go index 6cdfc98a..75894111 100644 --- a/hicli/events.go +++ b/hicli/events.go @@ -31,3 +31,8 @@ type Typing struct { RoomID id.RoomID `json:"room_id"` event.TypingEventContent } + +type SendComplete struct { + Event *database.Event `json:"event"` + Error error `json:"error"` +} diff --git a/hicli/hitest/hitest.go b/hicli/hitest/hitest.go index c55c27dc..88ac287a 100644 --- a/hicli/hitest/hitest.go +++ b/hicli/hitest/hitest.go @@ -10,9 +10,7 @@ import ( "context" "fmt" "io" - "os" - "os/signal" - "syscall" + "strings" "github.com/chzyer/readline" _ "github.com/mattn/go-sqlite3" @@ -24,6 +22,7 @@ import ( "go.mau.fi/zeroconfig" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/hicli" "maunium.net/go/mautrix/id" ) @@ -88,8 +87,21 @@ func main() { } rl.SetPrompt("> ") - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - <-c + for { + line, err := rl.Readline() + if err != nil { + break + } + fields := strings.Fields(line) + switch strings.ToLower(fields[0]) { + case "send": + resp, err := cli.Send(ctx, id.RoomID(fields[1]), event.EventMessage, &event.MessageEventContent{ + Body: strings.Join(fields[2:], " "), + MsgType: event.MsgText, + }) + _, _ = fmt.Fprintln(rl, err) + _, _ = fmt.Fprintf(rl, "%+v\n", resp) + } + } cli.Stop() } diff --git a/hicli/send.go b/hicli/send.go new file mode 100644 index 00000000..db26eca8 --- /dev/null +++ b/hicli/send.go @@ -0,0 +1,196 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" +) + +func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (*database.Event, error) { + roomMeta, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return nil, fmt.Errorf("failed to get room metadata: %w", err) + } else if roomMeta == nil { + return nil, fmt.Errorf("unknown room") + } + var decryptedType event.Type + var decryptedContent json.RawMessage + var megolmSessionID id.SessionID + if roomMeta.EncryptionEvent != nil && evtType != event.EventReaction { + decryptedType = evtType + decryptedContent, err = json.Marshal(content) + if err != nil { + return nil, fmt.Errorf("failed to marshal event content: %w", err) + } + encryptedContent, err := h.Encrypt(ctx, roomMeta, evtType, content) + if err != nil { + return nil, fmt.Errorf("failed to encrypt event: %w", err) + } + megolmSessionID = encryptedContent.SessionID + content = encryptedContent + evtType = event.EventEncrypted + } + mainContent, err := json.Marshal(content) + if err != nil { + return nil, fmt.Errorf("failed to marshal event content: %w", err) + } + var zero database.EventRowID + txnID := "hicli-" + h.Client.TxnID() + relatesTo, relationType := database.GetRelatesToFromBytes(mainContent) + dbEvt := &database.Event{ + RoomID: roomID, + ID: id.EventID(fmt.Sprintf("~%s", txnID)), + Sender: h.Account.UserID, + Type: evtType.Type, + Timestamp: time.Now(), + Content: mainContent, + Decrypted: decryptedContent, + DecryptedType: decryptedType.Type, + Unsigned: []byte("{}"), + TransactionID: txnID, + RelatesTo: relatesTo, + RelationType: relationType, + MegolmSessionID: megolmSessionID, + DecryptionError: "", + Reactions: map[string]int{}, + LastEditRowID: &zero, + } + _, err = h.DB.Event.Insert(ctx, dbEvt) + if err != nil { + return nil, fmt.Errorf("failed to insert event into database: %w", err) + } + go func() { + var err error + defer func() { + h.EventHandler(&SendComplete{ + Event: dbEvt, + Error: err, + }) + }() + var resp *mautrix.RespSendEvent + resp, err = h.Client.SendMessageEvent(ctx, roomID, evtType, content, mautrix.ReqSendEvent{ + Timestamp: dbEvt.Timestamp.UnixMilli(), + TransactionID: txnID, + DontEncrypt: true, + }) + if err != nil { + // TODO save send error to db? + err = fmt.Errorf("failed to send event: %w", err) + return + } + dbEvt.ID = resp.EventID + err = h.DB.Event.UpdateID(ctx, dbEvt.RowID, dbEvt.ID) + if err != nil { + err = fmt.Errorf("failed to update event ID in database: %w", err) + } + }() + return dbEvt, nil +} + +func (h *HiClient) Encrypt(ctx context.Context, room *database.Room, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) { + h.encryptLock.Lock() + defer h.encryptLock.Unlock() + encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, room.ID, evtType, content) + if errors.Is(err, crypto.SessionExpired) || errors.Is(err, crypto.NoGroupSession) || errors.Is(err, crypto.SessionNotShared) { + if err = h.shareGroupSession(ctx, room); err != nil { + err = fmt.Errorf("failed to share group session: %w", err) + } else if encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, room.ID, evtType, content); err != nil { + err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err) + } + } + return +} + +func (h *HiClient) EnsureGroupSessionShared(ctx context.Context, roomID id.RoomID) error { + h.encryptLock.Lock() + defer h.encryptLock.Unlock() + if session, err := h.CryptoStore.GetOutboundGroupSession(ctx, roomID); err != nil { + return fmt.Errorf("failed to get previous outbound group session: %w", err) + } else if session != nil && session.Shared && !session.Expired() { + return nil + } else if roomMeta, err := h.DB.Room.Get(ctx, roomID); err != nil { + return fmt.Errorf("failed to get room metadata: %w", err) + } else if roomMeta == nil { + return fmt.Errorf("unknown room") + } else { + return h.shareGroupSession(ctx, roomMeta) + } +} + +func (h *HiClient) shareGroupSession(ctx context.Context, room *database.Room) error { + if !room.HasMemberList { + resp, err := h.Client.Members(ctx, room.ID) + if err != nil { + return fmt.Errorf("failed to get room member list: %w", err) + } + err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + for _, evt := range resp.Chunk { + dbEvt, err := h.processEvent(ctx, evt, nil, true) + if err != nil { + return err + } + membership := event.Membership(evt.Content.Raw["membership"].(string)) + err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership) + if err != nil { + return err + } + } + return nil + }) + if err != nil { + return fmt.Errorf("failed to process room member list: %w", err) + } + } + shareToInvited := h.shouldShareKeysToInvitedUsers(ctx, room.ID) + var users []id.UserID + var err error + if shareToInvited { + users, err = h.ClientStore.GetRoomJoinedOrInvitedMembers(ctx, room.ID) + } else { + users, err = h.ClientStore.GetRoomJoinedMembers(ctx, room.ID) + } + if err != nil { + return fmt.Errorf("failed to get room member list: %w", err) + } else if err = h.Crypto.ShareGroupSession(ctx, room.ID, users); err != nil { + return fmt.Errorf("failed to share group session: %w", err) + } + return nil +} + +func (h *HiClient) shouldShareKeysToInvitedUsers(ctx context.Context, roomID id.RoomID) bool { + historyVisibility, err := h.DB.CurrentState.Get(ctx, roomID, event.StateHistoryVisibility, "") + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get history visibility event") + return false + } + mautrixEvt := historyVisibility.AsRawMautrix() + err = mautrixEvt.Content.ParseRaw(mautrixEvt.Type) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to parse history visibility event") + return false + } + hv, ok := mautrixEvt.Content.Parsed.(*event.HistoryVisibilityEventContent) + if !ok { + zerolog.Ctx(ctx).Warn().Msg("Unexpected parsed content type for history visibility event") + return false + } + return hv.HistoryVisibility == event.HistoryVisibilityInvited || + hv.HistoryVisibility == event.HistoryVisibilityShared || + hv.HistoryVisibility == event.HistoryVisibilityWorldReadable +} From 557fb94669499796a76c8eecdc700d545217cb68 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 2 Jun 2024 21:30:42 +0300 Subject: [PATCH 0247/1647] Allow using separate crypto db and adjust other things --- go.mod | 2 +- go.sum | 4 +-- hicli/hicli.go | 30 ++++++++++++++++++----- hicli/hitest/hitest.go | 4 +-- hicli/send.go | 55 +++++++++++++++++++++++++----------------- 5 files changed, 62 insertions(+), 33 deletions(-) diff --git a/go.mod b/go.mod index 6213dbf5..6562fcc8 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.1 - go.mau.fi/util v0.4.3-0.20240516141139-2ebe792cd8f7 + go.mau.fi/util v0.4.3-0.20240602182959-603e3d7117c1 go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.23.0 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 diff --git a/go.sum b/go.sum index 307aa876..1de567b2 100644 --- a/go.sum +++ b/go.sum @@ -43,8 +43,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U= github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.4.3-0.20240516141139-2ebe792cd8f7 h1:2hnc2iS7usHT3aqIQ8HVtKtPgic+13EVSdZ1m8UBL/E= -go.mau.fi/util v0.4.3-0.20240516141139-2ebe792cd8f7/go.mod h1:m+PJpPMadAW6cj3ldyuO5bLhFreWdwcu+3QTwYNGlGk= +go.mau.fi/util v0.4.3-0.20240602182959-603e3d7117c1 h1:HFo2MX0/echGXJhlVzWlNHH1I/4dtkq9UkIBsVhX/tU= +go.mau.fi/util v0.4.3-0.20240602182959-603e3d7117c1/go.mod h1:4etkIWotzgsWICu/1I34Y2LFFekINhFsyWYHXEsxXdY= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= diff --git a/hicli/hicli.go b/hicli/hicli.go index 5e2957c0..7caee4a8 100644 --- a/hicli/hicli.go +++ b/hicli/hicli.go @@ -57,11 +57,18 @@ type HiClient struct { var ErrTimelineReset = errors.New("got limited timeline sync response") -func New(rawDB *dbutil.Database, log zerolog.Logger, pickleKey []byte, evtHandler func(any)) *HiClient { - rawDB.Owner = "hicli" - rawDB.IgnoreForeignTables = true +func New(rawDB, cryptoDB *dbutil.Database, log zerolog.Logger, pickleKey []byte, evtHandler func(any)) *HiClient { + if cryptoDB == nil { + cryptoDB = rawDB + } + if rawDB.Owner == "" { + rawDB.Owner = "hicli" + rawDB.IgnoreForeignTables = true + } + if rawDB.Log == nil { + rawDB.Log = dbutil.ZeroLogger(log.With().Str("db_section", "hicli").Logger()) + } db := database.New(rawDB) - db.Log = dbutil.ZeroLogger(log.With().Str("db_section", "hicli").Logger()) c := &HiClient{ DB: db, Log: log, @@ -88,7 +95,7 @@ func New(rawDB *dbutil.Database, log zerolog.Logger, pickleKey []byte, evtHandle StateStore: c.ClientStore, Log: log.With().Str("component", "mautrix client").Logger(), } - c.CryptoStore = crypto.NewSQLCryptoStore(rawDB, dbutil.ZeroLogger(log.With().Str("db_section", "crypto").Logger()), "", "", pickleKey) + c.CryptoStore = crypto.NewSQLCryptoStore(cryptoDB, dbutil.ZeroLogger(log.With().Str("db_section", "crypto").Logger()), "", "", pickleKey) cryptoLog := log.With().Str("component", "crypto").Logger() c.Crypto = crypto.NewOlmMachine(c.Client, &cryptoLog, c.CryptoStore, c.ClientStore) c.Crypto.SessionReceived = c.handleReceivedMegolmSession @@ -102,7 +109,10 @@ func (h *HiClient) IsLoggedIn() bool { return h.Account != nil } -func (h *HiClient) Start(ctx context.Context, userID id.UserID) error { +func (h *HiClient) Start(ctx context.Context, userID id.UserID, expectedAccount *database.Account) error { + if expectedAccount != nil && userID != expectedAccount.UserID { + panic(fmt.Errorf("invalid parameters: different user ID in expected account and user ID")) + } err := h.DB.Upgrade(ctx) if err != nil { return fmt.Errorf("failed to upgrade hicli db: %w", err) @@ -114,6 +124,14 @@ func (h *HiClient) Start(ctx context.Context, userID id.UserID) error { account, err := h.DB.Account.Get(ctx, userID) if err != nil { return err + } else if account == nil && expectedAccount != nil { + err = h.DB.Account.Put(ctx, expectedAccount) + if err != nil { + return err + } + account = expectedAccount + } else if expectedAccount != nil && expectedAccount.DeviceID != account.DeviceID { + return fmt.Errorf("device ID mismatch: expected %s, got %s", expectedAccount.DeviceID, account.DeviceID) } if account != nil { zerolog.Ctx(ctx).Debug().Stringer("user_id", account.UserID).Msg("Preparing client with existing credentials") diff --git a/hicli/hitest/hitest.go b/hicli/hitest/hitest.go index 88ac287a..97705304 100644 --- a/hicli/hitest/hitest.go +++ b/hicli/hitest/hitest.go @@ -50,7 +50,7 @@ func main() { rawDB := exerrors.Must(dbutil.NewWithDialect("hicli.db", "sqlite3-fk-wal")) ctx := log.WithContext(context.Background()) - cli := hicli.New(rawDB, *log, []byte("meow"), func(a any) { + cli := hicli.New(rawDB, nil, *log, []byte("meow"), func(a any) { _, _ = fmt.Fprintf(rl, "Received event of type %T\n", a) switch evt := a.(type) { case *hicli.SyncComplete: @@ -75,7 +75,7 @@ func main() { } }) userID, _ := cli.DB.Account.GetFirstUserID(ctx) - exerrors.PanicIfNotNil(cli.Start(ctx, userID)) + exerrors.PanicIfNotNil(cli.Start(ctx, userID, nil)) if !cli.IsLoggedIn() { rl.SetPrompt("User ID: ") userID := id.UserID(exerrors.Must(rl.Readline())) diff --git a/hicli/send.go b/hicli/send.go index db26eca8..66175e75 100644 --- a/hicli/send.go +++ b/hicli/send.go @@ -133,33 +133,44 @@ func (h *HiClient) EnsureGroupSessionShared(ctx context.Context, roomID id.RoomI } } -func (h *HiClient) shareGroupSession(ctx context.Context, room *database.Room) error { - if !room.HasMemberList { - resp, err := h.Client.Members(ctx, room.ID) - if err != nil { - return fmt.Errorf("failed to get room member list: %w", err) - } - err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { - for _, evt := range resp.Chunk { - dbEvt, err := h.processEvent(ctx, evt, nil, true) - if err != nil { - return err - } - membership := event.Membership(evt.Content.Raw["membership"].(string)) - err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership) - if err != nil { - return err - } +func (h *HiClient) loadMembers(ctx context.Context, room *database.Room) error { + if room.HasMemberList { + return nil + } + resp, err := h.Client.Members(ctx, room.ID) + if err != nil { + return fmt.Errorf("failed to get room member list: %w", err) + } + err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + for _, evt := range resp.Chunk { + dbEvt, err := h.processEvent(ctx, evt, nil, true) + if err != nil { + return err + } + membership := event.Membership(evt.Content.Raw["membership"].(string)) + err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership) + if err != nil { + return err } - return nil - }) - if err != nil { - return fmt.Errorf("failed to process room member list: %w", err) } + return h.DB.Room.Upsert(ctx, &database.Room{ + ID: room.ID, + HasMemberList: true, + }) + }) + if err != nil { + return fmt.Errorf("failed to process room member list: %w", err) + } + return nil +} + +func (h *HiClient) shareGroupSession(ctx context.Context, room *database.Room) error { + err := h.loadMembers(ctx, room) + if err != nil { + return err } shareToInvited := h.shouldShareKeysToInvitedUsers(ctx, room.ID) var users []id.UserID - var err error if shareToInvited { users, err = h.ClientStore.GetRoomJoinedOrInvitedMembers(ctx, room.ID) } else { From a599b15466ae09b69d7b4fbedce2a930002e1da4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 3 Jun 2024 22:33:36 +0300 Subject: [PATCH 0248/1647] Add bridge login interface --- bridgev2/cmdhandler.go | 1 + bridgev2/cmdlogin.go | 288 ++++++++++++++++++++++++++++++++ bridgev2/cmdmeta.go | 3 + bridgev2/login-step.schema.json | 138 +++++++++++++++ bridgev2/login-steps.uml | 43 +++++ bridgev2/login.go | 124 ++++++++++++++ bridgev2/matrix/connector.go | 17 +- bridgev2/matrix/provisioning.go | 241 ++++++++++++++++++++++++++ bridgev2/networkinterface.go | 5 +- bridgev2/userlogin.go | 2 +- go.mod | 2 + go.sum | 3 + 12 files changed, 858 insertions(+), 9 deletions(-) create mode 100644 bridgev2/cmdlogin.go create mode 100644 bridgev2/login-step.schema.json create mode 100644 bridgev2/login-steps.uml create mode 100644 bridgev2/login.go create mode 100644 bridgev2/matrix/provisioning.go diff --git a/bridgev2/cmdhandler.go b/bridgev2/cmdhandler.go index 9f9c69ec..55db056f 100644 --- a/bridgev2/cmdhandler.go +++ b/bridgev2/cmdhandler.go @@ -24,6 +24,7 @@ type CommandState struct { Next MinimalCommandHandler Action string Meta any + Cancel func() } type CommandHandler interface { diff --git a/bridgev2/cmdlogin.go b/bridgev2/cmdlogin.go new file mode 100644 index 00000000..b9f6cd98 --- /dev/null +++ b/bridgev2/cmdlogin.go @@ -0,0 +1,288 @@ +// Copyright (c) 2022 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "regexp" + "strings" + "time" + + "github.com/skip2/go-qrcode" + "golang.org/x/net/html" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +var CommandLogin = &FullHandler{ + Func: fnLogin, + Name: "login", + Help: HelpMeta{ + Section: HelpSectionAuth, + Description: "Log into the bridge", + }, +} + +func formatFlowsReply(flows []LoginFlow) string { + var buf strings.Builder + for _, flow := range flows { + _, _ = fmt.Fprintf(&buf, "* `%s` - %s\n", flow.ID, flow.Description) + } + return buf.String() +} + +func fnLogin(ce *CommandEvent) { + flows := ce.Bridge.Network.GetLoginFlows() + var chosenFlowID string + if len(ce.Args) > 0 { + inputFlowID := strings.ToLower(ce.Args[0]) + for _, flow := range flows { + if flow.ID == inputFlowID { + chosenFlowID = flow.ID + break + } + } + if chosenFlowID == "" { + ce.Reply("Invalid login flow `%s`. Available options:\n\n%s", ce.Args[0], formatFlowsReply(flows)) + return + } + } else if len(flows) == 1 { + chosenFlowID = flows[0].ID + } else { + ce.Reply("Please specify a login flow, e.g. `login %s`.\n\n%s", flows[0].ID, formatFlowsReply(flows)) + return + } + + login, err := ce.Bridge.Network.CreateLogin(ce.Ctx, ce.User, chosenFlowID) + if err != nil { + ce.Reply("Failed to prepare login process: %v", err) + return + } + nextStep, err := login.Start(ce.Ctx) + if err != nil { + ce.Reply("Failed to start login: %v", err) + return + } + doLoginStep(ce, login, nextStep) +} + +type userInputLoginCommandState struct { + Login LoginProcessUserInput + Data map[string]string + RemainingFields []LoginInputDataField +} + +func (uilcs *userInputLoginCommandState) promptNext(ce *CommandEvent) { + // TODO reply prompting field + ce.User.CommandState.Store(&CommandState{ + Next: MinimalCommandHandlerFunc(uilcs.submitNext), + Action: "Login", + Meta: uilcs, + Cancel: uilcs.Login.Cancel, + }) +} + +func (uilcs *userInputLoginCommandState) submitNext(ce *CommandEvent) { + field := uilcs.RemainingFields[0] + var err error + uilcs.Data[field.ID], err = field.Validate(ce.RawArgs) + if err != nil { + ce.Reply("Invalid value: %v", err) + return + } else if len(uilcs.RemainingFields) > 1 { + uilcs.RemainingFields = uilcs.RemainingFields[1:] + uilcs.promptNext(ce) + return + } + ce.User.CommandState.Store(nil) + if nextStep, err := uilcs.Login.SubmitUserInput(ce.Ctx, uilcs.Data); err != nil { + ce.Reply("Failed to submit input: %v", err) + } else { + doLoginStep(ce, uilcs.Login, nextStep) + } +} + +const qrSizePx = 512 + +func sendQR(ce *CommandEvent, qr string, prevEventID *id.EventID) error { + qrData, err := qrcode.Encode(qr, qrcode.Low, qrSizePx) + if err != nil { + return fmt.Errorf("failed to encode QR code: %w", err) + } + qrMXC, qrFile, err := ce.Bot.UploadMedia(ce.Ctx, ce.RoomID, qrData, "qr.png", "image/png") + if err != nil { + return fmt.Errorf("failed to upload image: %w", err) + } + content := &event.MessageEventContent{ + MsgType: event.MsgImage, + FileName: "qr.png", + URL: qrMXC, + File: qrFile, + + Body: qr, + Format: event.FormatHTML, + FormattedBody: fmt.Sprintf("
%s
", html.EscapeString(qr)), + } + if *prevEventID != "" { + content.SetEdit(*prevEventID) + } + newEventID, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, time.Now()) + if err != nil { + return err + } + if *prevEventID == "" { + *prevEventID = newEventID.EventID + } + return nil +} + +type contextKey int + +const ( + contextKeyPrevEventID contextKey = iota +) + +func doLoginDisplayAndWait(ce *CommandEvent, login LoginProcessDisplayAndWait, step *LoginStep) { + prevEvent, ok := ce.Ctx.Value(contextKeyPrevEventID).(*id.EventID) + if !ok { + prevEvent = new(id.EventID) + ce.Ctx = context.WithValue(ce.Ctx, contextKeyPrevEventID, prevEvent) + } + switch step.DisplayAndWaitParams.Type { + case LoginDisplayTypeQR: + err := sendQR(ce, step.DisplayAndWaitParams.Data, prevEvent) + if err != nil { + ce.Reply("Failed to send QR code: %v", err) + login.Cancel() + return + } + case LoginDisplayTypeEmoji: + ce.ReplyAdvanced(step.DisplayAndWaitParams.Data, false, false) + case LoginDisplayTypeCode: + ce.ReplyAdvanced(fmt.Sprintf("%s", html.EscapeString(step.DisplayAndWaitParams.Data)), false, true) + default: + ce.Reply("Unsupported display type %q", step.DisplayAndWaitParams.Type) + login.Cancel() + return + } + nextStep, err := login.Wait(ce.Ctx) + // Redact the QR code, unless the next step is refreshing the code (in which case the event is just edited) + if *prevEvent != "" && (nextStep == nil || nextStep.StepID != step.StepID) { + _, _ = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: *prevEvent, + }, + }, time.Now()) + *prevEvent = "" + } + if err != nil { + ce.Reply("Login failed: %v", err) + return + } + doLoginStep(ce, login, nextStep) +} + +type cookieLoginCommandState struct { + Login LoginProcessCookies + Data *LoginCookiesParams +} + +func (clcs *cookieLoginCommandState) prompt(ce *CommandEvent) { + ce.User.CommandState.Store(&CommandState{ + Next: MinimalCommandHandlerFunc(clcs.submit), + Action: "Login", + Meta: clcs, + Cancel: clcs.Login.Cancel, + }) +} + +var curlCookieRegex = regexp.MustCompile(`-H '[cC]ookie: ([^']*)'`) + +func missingKeys(required []string, data map[string]string) (missing []string) { + for _, requiredKey := range required { + if _, ok := data[requiredKey]; !ok { + missing = append(missing, requiredKey) + } + } + return +} + +func (clcs *cookieLoginCommandState) submit(ce *CommandEvent) { + ce.Redact() + + cookies := make(map[string]string) + if strings.HasPrefix(strings.TrimSpace(ce.RawArgs), "curl") { + if len(clcs.Data.LocalStorageKeys) > 0 || len(clcs.Data.SpecialKeys) > 0 { + ce.Reply("Special keys and localStorage keys can't be extracted from curl commands - please provide the data as JSON instead") + return + } + cookieHeader := curlCookieRegex.FindStringSubmatch(ce.RawArgs) + if len(cookieHeader) != 2 { + ce.Reply("Couldn't find `-H 'Cookie: ...'` in curl command") + return + } + parsed := (&http.Request{Header: http.Header{"Cookie": {cookieHeader[1]}}}).Cookies() + for _, cookie := range parsed { + cookies[cookie.Name] = cookie.Value + } + } else { + err := json.Unmarshal([]byte(ce.RawArgs), &cookies) + if err != nil { + ce.Reply("Failed to parse input as JSON: %v", err) + return + } + } + missingCookies := missingKeys(clcs.Data.CookieKeys, cookies) + if len(missingCookies) > 0 { + ce.Reply("Missing required cookies: %+v", missingCookies) + return + } + missingLocalStorage := missingKeys(clcs.Data.LocalStorageKeys, cookies) + if len(missingLocalStorage) > 0 { + ce.Reply("Missing required localStorage keys: %+v", missingLocalStorage) + return + } + missingSpecial := missingKeys(clcs.Data.SpecialKeys, cookies) + if len(missingSpecial) > 0 { + ce.Reply("Missing required special keys: %+v", missingSpecial) + return + } + ce.User.CommandState.Store(nil) + nextStep, err := clcs.Login.SubmitCookies(ce.Ctx, cookies) + if err != nil { + ce.Reply("Login failed: %v", err) + } + doLoginStep(ce, clcs.Login, nextStep) +} + +func doLoginStep(ce *CommandEvent, login LoginProcess, step *LoginStep) { + ce.Reply(step.Instructions) + + switch step.Type { + case LoginStepTypeDisplayAndWait: + doLoginDisplayAndWait(ce, login.(LoginProcessDisplayAndWait), step) + case LoginStepTypeCookies: + (&cookieLoginCommandState{ + Login: login.(LoginProcessCookies), + Data: step.CookiesParams, + }).prompt(ce) + case LoginStepTypeUserInput: + (&userInputLoginCommandState{ + Login: login.(LoginProcessUserInput), + RemainingFields: step.UserInputParams.Fields, + Data: make(map[string]string), + }).promptNext(ce) + case LoginStepTypeComplete: + // Nothing to do other than instructions + default: + panic(fmt.Errorf("unknown login step type %q", step.Type)) + } +} diff --git a/bridgev2/cmdmeta.go b/bridgev2/cmdmeta.go index 4020f569..1ea611f9 100644 --- a/bridgev2/cmdmeta.go +++ b/bridgev2/cmdmeta.go @@ -37,6 +37,9 @@ var CommandCancel = &FullHandler{ if action == "" { action = "Unknown action" } + if state.Cancel != nil { + state.Cancel() + } ce.Reply("%s cancelled.", action) } else { ce.Reply("No ongoing command.") diff --git a/bridgev2/login-step.schema.json b/bridgev2/login-step.schema.json new file mode 100644 index 00000000..fa120778 --- /dev/null +++ b/bridgev2/login-step.schema.json @@ -0,0 +1,138 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://go.mau.fi/mautrix/bridgev2/login-step.json", + "title": "Login step data", + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": ["user_input", "cookies", "display_and_wait", "complete"] + }, + "step_id": { + "type": "string", + "description": "An unique ID identifying this step. This can be used to implement special behavior in clients." + }, + "instructions": { + "type": "string", + "description": "Human-readable instructions for completing this login step." + }, + "user_input": { + "type": "object", + "title": "User input params", + "description": "Parameters for the `user_input` login type", + "properties": { + "fields": { + "type": "array", + "description": "The list of fields that the user must fill", + "items": { + "title": "Field", + "description": "A field that the user must fill", + "type": "object", + "properties": { + "type": { + "type": "string", + "enum": ["username", "phone_number", "email", "password", "2fa_code"] + }, + "id": { + "type": "string", + "description": "The ID of the field. This should be used when submitting the form.", + "examples": ["uid", "email", "2fa_password", "meow"] + }, + "name": { + "type": "string", + "description": "The name of the field", + "examples": ["Username", "Password", "Phone number", "2FA code", "Meow"] + }, + "pattern": { + "type": "string", + "description": "A regular expression that the field value must match" + } + }, + "required": ["type", "id", "name"] + } + } + }, + "required": ["fields"] + }, + "cookies": { + "type": "object", + "title": "Cookie params", + "description": "Parameters for the `cookies` login type", + "properties": { + "url": { + "type": "string", + "description": "The URL to open when using a webview to extract cookies" + }, + "cookie_domain": { + "type": "string", + "description": "The domain of the cookies to extract" + }, + "cookie_keys": { + "type": "array", + "description": "The cookie names to extract", + "items": { + "type": "string" + } + }, + "local_storage_keys": { + "type": "array", + "description": "The local storage keys to extract", + "items": { + "type": "string" + } + }, + "special_keys": { + "type": "array", + "description": "Special-cased extraction types that clients must support individually", + "items": { + "type": "string" + } + } + }, + "required": ["url"] + }, + "display_and_wait": { + "type": "object", + "title": "Display and wait params", + "description": "Parameters for the `display_and_wait` login type", + "properties": { + "type": { + "type": "string", + "description": "The type of thing to display", + "enum": ["qr", "emoji", "code"] + }, + "data": { + "type": "string", + "description": "The thing to display (raw data for QR, unicode emoji for emoji, plain string for code)" + }, + "image_url": { + "type": "string", + "description": "An image containing the thing to display. If present, this is recommended over using data directly. For emojis, the URL to the canonical image representation of the emoji" + } + }, + "required": ["type", "data"] + }, + "complete": { + "type": "object", + "title": "Login complete information", + "description": "Information about a successful login", + "properties": { + "user_login_id": { + "type": "string", + "description": "The ID of the user login entry" + } + } + } + }, + "required": [ + "type", + "step_id", + "instructions" + ], + "oneOf": [ + {"title":"User input type","properties":{"type": {"type":"string","const": "user_input"}}, "required": ["user_input"]}, + {"title":"Cookies type","properties":{"type": {"type":"string","const": "cookies"}}, "required": ["cookies"]}, + {"title":"Display and wait type","properties":{"type": {"type":"string","const": "display_and_wait"}}, "required": ["display_and_wait"]}, + {"title":"Login complete","properties":{"type": {"type":"string","const": "complete"}}} + ] +} diff --git a/bridgev2/login-steps.uml b/bridgev2/login-steps.uml new file mode 100644 index 00000000..5af9c88e --- /dev/null +++ b/bridgev2/login-steps.uml @@ -0,0 +1,43 @@ +title Login flows + +participant User +participant Client +participant Bridge +participant User's device + +alt Username+Password/Phone number/2FA code + Client->+Bridge: /login + Bridge->-Client: step=user_input, fields=[...] + Client->User: input box(es) + User->Client: submit input + Client->+Bridge: /login/user_input + Bridge->-Client: success=true, step=next step +end + +alt Cookies + Client->+Bridge: /login + Bridge->-Client: step=cookies, url=..., cookies=[...] + Client->User: webview + User->Client: login in webview + Client->Bridge: /login/cookies + Bridge->-Client: success=true, step=next step +end + +alt QR/Emoji/Code + Client->+Bridge: /login + Bridge->-Client: step=display_and_wait, data=... + Client->+Bridge: /login/wait + Client->User: display QR/emoji/code + loop Refresh QR + Bridge->-Client: step=display_and_wait, data=new QR + Client->User: display new QR + Client->+Bridge: /login/wait + end +else Successful case + User->User's device: Scan QR/tap emoji/enter code + User's device->Bridge: Login successful + Bridge->-Client: success=true, step=next step +else Error + Bridge->Client: error=timeout + Client->User: error +end diff --git a/bridgev2/login.go b/bridgev2/login.go new file mode 100644 index 00000000..31500680 --- /dev/null +++ b/bridgev2/login.go @@ -0,0 +1,124 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + + "maunium.net/go/mautrix/bridgev2/networkid" +) + +type LoginProcess interface { + Start(ctx context.Context) (*LoginStep, error) + Cancel() +} + +type LoginProcessDisplayAndWait interface { + LoginProcess + Wait(ctx context.Context) (*LoginStep, error) +} + +type LoginProcessUserInput interface { + LoginProcess + SubmitUserInput(ctx context.Context, input map[string]string) (*LoginStep, error) +} + +type LoginProcessCookies interface { + LoginProcess + SubmitCookies(ctx context.Context, cookies map[string]string) (*LoginStep, error) +} + +type LoginFlow struct { + Name string `json:"name"` + Description string `json:"description"` + ID string `json:"id"` +} + +type LoginStepType string + +const ( + LoginStepTypeUserInput LoginStepType = "user_input" + LoginStepTypeCookies LoginStepType = "cookies" + LoginStepTypeDisplayAndWait LoginStepType = "display_and_wait" + LoginStepTypeComplete LoginStepType = "complete" +) + +type LoginDisplayType string + +const ( + LoginDisplayTypeQR LoginDisplayType = "qr" + LoginDisplayTypeEmoji LoginDisplayType = "emoji" + LoginDisplayTypeCode LoginDisplayType = "code" +) + +type LoginStep struct { + // The type of login step + Type LoginStepType `json:"type"` + // A unique ID for this step. The ID should be same for every login using the same flow, + // but it should be different for different bridges and step types. + // + // For example, Telegram's QR scan followed by a 2-factor password + // might use the IDs `fi.mau.telegram.qr` and `fi.mau.telegram.2fa_password`. + StepID string `json:"step_id"` + // Instructions contains human-readable instructions for completing the login step. + Instructions string `json:"instructions"` + + // Exactly one of the following structs must be filled depending on the step type. + + DisplayAndWaitParams *LoginDisplayAndWaitParams `json:"display_and_wait"` + CookiesParams *LoginCookiesParams `json:"cookies"` + UserInputParams *LoginUserInputParams `json:"user_input"` + CompleteParams *LoginCompleteParams `json:"complete"` +} + +type LoginDisplayAndWaitParams struct { + // The type of thing to display (QR, emoji or text code) + Type LoginDisplayType `json:"type"` + // The thing to display (raw data for QR, unicode emoji for emoji, plain string for code) + Data string `json:"data"` + // An image containing the thing to display. If present, this is recommended over using data directly. + // For emojis, the URL to the canonical image representation of the emoji + ImageURL string `json:"image_url,omitempty"` +} + +type LoginCookiesParams struct { + URL string `json:"url"` + + CookieDomain string `json:"cookie_domain,omitempty"` + CookieKeys []string `json:"cookie_keys,omitempty"` + LocalStorageKeys []string `json:"local_storage_keys,omitempty"` + SpecialKeys []string `json:"special_keys,omitempty"` +} + +type LoginInputFieldType string + +const ( + LoginInputFieldTypeUsername LoginInputFieldType = "username" + LoginInputFieldTypePassword LoginInputFieldType = "password" + LoginInputFieldTypePhoneNumber LoginInputFieldType = "phone_number" + LoginInputFieldTypeEmail LoginInputFieldType = "email" + LoginInputFieldType2FACode LoginInputFieldType = "2fa_code" +) + +type LoginInputDataField struct { + Type LoginInputFieldType `json:"type"` + ID string `json:"id"` + Name string `json:"name"` + Pattern string `json:"pattern,omitempty"` + Validate func(string) (string, error) `json:"-"` +} + +type LoginUserInputParams struct { + Fields []LoginInputDataField `json:"fields"` +} + +type LoginCompleteParams struct { + UserLoginID networkid.UserLoginID `json:"user_login_id"` +} + +type LoginSubmit struct { +} diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index d4af2506..58919f47 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -43,13 +43,14 @@ type Crypto interface { type Connector struct { //DB *dbutil.Database - AS *appservice.AppService - Bot *appservice.IntentAPI - StateStore *sqlstatestore.SQLStateStore - Crypto Crypto - Log *zerolog.Logger - Config *bridgeconfig.Config - Bridge *bridgev2.Bridge + AS *appservice.AppService + Bot *appservice.IntentAPI + StateStore *sqlstatestore.SQLStateStore + Crypto Crypto + Log *zerolog.Logger + Config *bridgeconfig.Config + Bridge *bridgev2.Bridge + Provisioning *ProvisioningAPI SpecVersions *mautrix.RespVersions @@ -95,9 +96,11 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.Bot = br.AS.BotIntent() br.Crypto = NewCryptoHelper(br) br.Bridge.Commands.AddHandlers(CommandDiscardMegolmSession, CommandSetPowerLevel) + br.Provisioning = &ProvisioningAPI{br: br} } func (br *Connector) Start(ctx context.Context) error { + br.Provisioning.Init() br.EventProcessor.Start(ctx) err := br.StateStore.Upgrade(ctx) if err != nil { diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go new file mode 100644 index 00000000..7164d961 --- /dev/null +++ b/bridgev2/matrix/provisioning.go @@ -0,0 +1,241 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "context" + "encoding/json" + "net/http" + "strings" + "sync" + + "github.com/gorilla/mux" + "github.com/rs/xid" + "github.com/rs/zerolog" + "github.com/rs/zerolog/hlog" + "github.com/rs/zerolog/log" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/id" +) + +type ProvisioningAPI struct { + br *Connector + log zerolog.Logger + net bridgev2.NetworkConnector + + logins map[string]*ProvLogin + loginsLock sync.RWMutex +} + +type ProvLogin struct { + ID string + Process bridgev2.LoginProcess + NextStep *bridgev2.LoginStep + Lock sync.Mutex +} + +type provisioningContextKey int + +const ( + provisioningUserKey provisioningContextKey = iota + provisioningLoginKey +) + +func (prov *ProvisioningAPI) Init() { + prov.net = prov.br.Bridge.Network + prov.log = prov.br.Log.With().Str("component", "provisioning").Logger() + router := prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter() + router.Use(hlog.NewHandler(prov.log)) + // TODO add access logger + //router.Use(requestlog.AccessLogger(true)) + router.Use(prov.AuthMiddleware) + router.Path("/v3/login/flows").Methods(http.MethodGet).HandlerFunc(prov.GetLoginFlows) + router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginStart) + router.Path("/v3/login/step/{loginID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginSubmitInput) + router.Path("/v3/login/step/{loginID}/{stepID}/{stepType:wait}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginWait) + + if prov.br.Config.Provisioning.DebugEndpoints { + log.Debug().Msg("Enabling debug API at /debug") + r := prov.br.AS.Router.PathPrefix("/debug").Subrouter() + r.Use(prov.AuthMiddleware) + r.PathPrefix("/pprof").Handler(http.DefaultServeMux) + } +} + +func jsonResponse(w http.ResponseWriter, status int, response any) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(response) +} + +func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + if auth != prov.br.Config.Provisioning.SharedSecret { + zerolog.Ctx(r.Context()).Warn().Msg("Authentication token does not match shared secret") + jsonResponse(w, http.StatusForbidden, &mautrix.RespError{ + Err: "Authentication token does not match shared secret", + ErrCode: mautrix.MForbidden.ErrCode, + }) + return + } + userID := r.URL.Query().Get("user_id") + user, err := prov.br.Bridge.GetUserByMXID(r.Context(), id.UserID(userID)) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get user") + jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ + Err: "Failed to get user", + ErrCode: "M_UNKNOWN", + }) + return + } + // TODO handle user being nil? + + ctx := context.WithValue(r.Context(), provisioningUserKey, user) + if loginID, ok := mux.Vars(r)["loginID"]; ok { + prov.loginsLock.RLock() + login, ok := prov.logins[loginID] + prov.loginsLock.RUnlock() + if !ok { + zerolog.Ctx(r.Context()).Warn().Str("login_id", loginID).Msg("Login not found") + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + Err: "Login not found", + ErrCode: mautrix.MNotFound.ErrCode, + }) + return + } + login.Lock.Lock() + // This will only unlock after the handler runs + defer login.Lock.Unlock() + stepID := mux.Vars(r)["stepID"] + if login.NextStep.StepID != stepID { + zerolog.Ctx(r.Context()).Warn(). + Str("request_step_id", stepID). + Str("expected_step_id", login.NextStep.StepID). + Msg("Step ID does not match") + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + Err: "Step ID does not match", + ErrCode: mautrix.MBadState.ErrCode, + }) + return + } + stepType := mux.Vars(r)["stepType"] + if login.NextStep.Type != bridgev2.LoginStepType(stepType) { + zerolog.Ctx(r.Context()).Warn(). + Str("request_step_type", stepType). + Str("expected_step_type", string(login.NextStep.Type)). + Msg("Step type does not match") + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + Err: "Step type does not match", + ErrCode: mautrix.MBadState.ErrCode, + }) + return + } + ctx = context.WithValue(r.Context(), provisioningLoginKey, login) + } + h.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +type RespLoginFlows struct { + Flows []bridgev2.LoginFlow `json:"flows"` +} + +type RespSubmitLogin struct { + LoginID string `json:"login_id"` + *bridgev2.LoginStep +} + +func (prov *ProvisioningAPI) GetLoginFlows(w http.ResponseWriter, r *http.Request) { + jsonResponse(w, http.StatusOK, &RespLoginFlows{ + Flows: prov.net.GetLoginFlows(), + }) +} + +func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Request) { + login, err := prov.net.CreateLogin( + r.Context(), + r.Context().Value(provisioningUserKey).(*bridgev2.User), + mux.Vars(r)["flowID"], + ) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process") + jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ + Err: "Failed to create login process", + ErrCode: "M_UNKNOWN", + }) + return + } + firstStep, err := login.Start(r.Context()) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to start login") + jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ + Err: "Failed to start login", + ErrCode: "M_UNKNOWN", + }) + return + } + loginID := xid.New().String() + prov.loginsLock.Lock() + prov.logins[loginID] = &ProvLogin{ + ID: loginID, + Process: login, + NextStep: firstStep, + } + prov.loginsLock.Unlock() + jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep}) +} + +func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http.Request) { + var params map[string]string + err := json.NewDecoder(r.Body).Decode(¶ms) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + Err: "Failed to decode request body", + ErrCode: mautrix.MNotJSON.ErrCode, + }) + return + } + login := r.Context().Value(provisioningLoginKey).(*ProvLogin) + var nextStep *bridgev2.LoginStep + switch login.NextStep.Type { + case bridgev2.LoginStepTypeUserInput: + nextStep, err = login.Process.(bridgev2.LoginProcessUserInput).SubmitUserInput(r.Context(), params) + case bridgev2.LoginStepTypeCookies: + nextStep, err = login.Process.(bridgev2.LoginProcessCookies).SubmitCookies(r.Context(), params) + default: + panic("Impossible state") + } + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input") + jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ + Err: "Failed to submit input", + ErrCode: "M_UNKNOWN", + }) + return + } + login.NextStep = nextStep + jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) +} + +func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Request) { + login := r.Context().Value(provisioningLoginKey).(*ProvLogin) + nextStep, err := login.Process.(bridgev2.LoginProcessDisplayAndWait).Wait(r.Context()) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input") + jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ + Err: "Failed to submit input", + ErrCode: "M_UNKNOWN", + }) + return + } + login.NextStep = nextStep + jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) +} diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index e8be6765..2c2125ac 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -45,7 +45,10 @@ type ConvertedMessage struct { type NetworkConnector interface { Init(*Bridge) Start(context.Context) error - PrepareLogin(ctx context.Context, login *UserLogin) error + LoadUserLogin(ctx context.Context, login *UserLogin) error + + GetLoginFlows() []LoginFlow + CreateLogin(ctx context.Context, user *User, flowID string) (LoginProcess, error) } type NetworkAPI interface { diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 8d082ac1..b970bda6 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -39,7 +39,7 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da User: user, Log: user.Log.With().Str("login_id", string(dbUserLogin.ID)).Logger(), } - err := br.Network.PrepareLogin(ctx, userLogin) + err := br.Network.LoadUserLogin(ctx, userLogin) if err != nil { return nil, fmt.Errorf("failed to prepare: %w", err) } diff --git a/go.mod b/go.mod index f1c025b4..ef64ece3 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,9 @@ require ( github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.22 + github.com/rs/xid v1.5.0 github.com/rs/zerolog v1.32.0 + github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 diff --git a/go.sum b/go.sum index 4e4a26c9..c7295275 100644 --- a/go.sum +++ b/go.sum @@ -21,9 +21,12 @@ github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxU github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0= github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= +github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= From 0707c89fbadc572ee4d8ab4fe1ca17d1c34916e5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 5 Jun 2024 13:18:40 +0300 Subject: [PATCH 0249/1647] Add logout support and adjust other things --- bridgev2/cmdlogin.go | 66 ++++++++++++++++++++++++++++++++++ bridgev2/cmdprocessor.go | 5 ++- bridgev2/database/userlogin.go | 7 ++++ bridgev2/login.go | 14 +++++--- bridgev2/matrix/connector.go | 3 -- bridgev2/matrix/intent.go | 2 +- bridgev2/networkinterface.go | 1 + bridgev2/userlogin.go | 21 +++++++++++ 8 files changed, 110 insertions(+), 9 deletions(-) diff --git a/bridgev2/cmdlogin.go b/bridgev2/cmdlogin.go index b9f6cd98..d15fac72 100644 --- a/bridgev2/cmdlogin.go +++ b/bridgev2/cmdlogin.go @@ -18,6 +18,7 @@ import ( "github.com/skip2/go-qrcode" "golang.org/x/net/html" + "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -28,6 +29,7 @@ var CommandLogin = &FullHandler{ Help: HelpMeta{ Section: HelpSectionAuth, Description: "Log into the bridge", + Args: "[_flow ID_]", }, } @@ -286,3 +288,67 @@ func doLoginStep(ce *CommandEvent, login LoginProcess, step *LoginStep) { panic(fmt.Errorf("unknown login step type %q", step.Type)) } } + +var CommandLogout = &FullHandler{ + Func: fnLogout, + Name: "logout", + Help: HelpMeta{ + Section: HelpSectionAuth, + Description: "Log out of the bridge", + Args: "<_login ID_>", + }, +} + +func getUserLogins(user *User) string { + user.Bridge.cacheLock.Lock() + logins := make([]string, len(user.logins)) + for key := range user.logins { + logins = append(logins, fmt.Sprintf("* `%s`", key)) + } + user.Bridge.cacheLock.Unlock() + return strings.Join(logins, "\n") +} + +func fnLogout(ce *CommandEvent) { + if len(ce.Args) == 0 { + ce.Reply("Usage: `$cmdprefix logout `\n\nYour logins:\n\n%s", getUserLogins(ce.User)) + return + } + login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + if login == nil || login.UserMXID != ce.User.MXID { + ce.Reply("Login `%s` not found", ce.Args[0]) + return + } + login.Logout(ce.Ctx) + ce.Reply("Logged out") +} + +var CommandSetPreferredLogin = &FullHandler{ + Func: fnSetPreferredLogin, + Name: "set-preferred-login", + Aliases: []string{"prefer"}, + Help: HelpMeta{ + Section: HelpSectionAuth, + Description: "Set the preferred login ID for sending messages to this portal (only relevant when logged into multiple accounts via the bridge)", + Args: "<_login ID_>", + }, + RequiresPortal: true, +} + +func fnSetPreferredLogin(ce *CommandEvent) { + if len(ce.Args) == 0 { + ce.Reply("Usage: `$cmdprefix set-preferred-login `\n\nYour logins:\n\n%s", getUserLogins(ce.User)) + return + } + login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + if login == nil || login.UserMXID != ce.User.MXID { + ce.Reply("Login `%s` not found", ce.Args[0]) + return + } + err := login.MarkAsPreferredIn(ce.Ctx, ce.Portal) + if err != nil { + ce.Reply("Failed to set preferred login: %v", err) + } else { + ce.Reply("Preferred login set") + } +} diff --git a/bridgev2/cmdprocessor.go b/bridgev2/cmdprocessor.go index e412bf8e..b3289cad 100644 --- a/bridgev2/cmdprocessor.go +++ b/bridgev2/cmdprocessor.go @@ -33,7 +33,10 @@ func NewProcessor(bridge *Bridge) *CommandProcessor { handlers: make(map[string]CommandHandler), aliases: make(map[string]string), } - proc.AddHandlers(CommandHelp, CommandVersion, CommandCancel) + proc.AddHandlers( + CommandHelp, CommandVersion, CommandCancel, + CommandLogin, CommandLogout, CommandSetPreferredLogin, + ) return proc } diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go index bec14841..8169980c 100644 --- a/bridgev2/database/userlogin.go +++ b/bridgev2/database/userlogin.go @@ -52,6 +52,9 @@ const ( UPDATE user_login SET space_room=$4, metadata=$5 WHERE bridge_id=$1 AND user_mxid=$2 AND id=$3 ` + deleteUserLoginQuery = ` + DELETE FROM user_login WHERE bridge_id=$1 AND id=$2 + ` insertUserPortalQuery = ` INSERT INTO user_portal (bridge_id, user_mxid, login_id, portal_id, in_space, preferred) VALUES ($1, $2, $3, $4, false, false) @@ -89,6 +92,10 @@ func (uq *UserLoginQuery) Update(ctx context.Context, login *UserLogin) error { return uq.Exec(ctx, updateUserLoginQuery, login.sqlVariables()...) } +func (uq *UserLoginQuery) Delete(ctx context.Context, loginID networkid.UserLoginID) error { + return uq.Exec(ctx, deleteUserLoginQuery, uq.BridgeID, loginID) +} + func (uq *UserLoginQuery) EnsureUserPortalExists(ctx context.Context, login *UserLogin, portalID networkid.PortalID) error { ensureBridgeIDMatches(&login.BridgeID, uq.BridgeID) return uq.Exec(ctx, insertUserPortalQuery, login.BridgeID, login.UserMXID, login.ID, portalID) diff --git a/bridgev2/login.go b/bridgev2/login.go index 31500680..78c4e7de 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -105,14 +105,20 @@ const ( ) type LoginInputDataField struct { - Type LoginInputFieldType `json:"type"` - ID string `json:"id"` - Name string `json:"name"` - Pattern string `json:"pattern,omitempty"` + // The type of input field as a hint for the client. + Type LoginInputFieldType `json:"type"` + // The ID of the field to be used as the key in the map that is submitted to the connector. + ID string `json:"id"` + // The name of the field shown to the user. + Name string `json:"name"` + // A regex pattern that the client can use to validate input client-side. + Pattern string `json:"pattern,omitempty"` + // A function that validates the input and optionally cleans it up before it's submitted to the connector. Validate func(string) (string, error) `json:"-"` } type LoginUserInputParams struct { + // The fields that the user needs to fill in. Fields []LoginInputDataField `json:"fields"` } diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 58919f47..4d28d3ee 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -58,9 +58,6 @@ type Connector struct { userIDRegex *regexp.Regexp - // TODO move to config - AsyncUploads bool - Websocket bool wsStopPinger chan struct{} wsStarted chan struct{} diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index b9158bf3..800d8b75 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -103,7 +103,7 @@ func (as *ASIntent) UploadMedia(ctx context.Context, roomID id.RoomID, data []by ContentType: mimeType, FileName: fileName, } - if as.Connector.AsyncUploads { + if as.Connector.Config.Homeserver.AsyncMedia { var resp *mautrix.RespCreateMXC resp, err = as.Matrix.UploadAsync(ctx, req) if resp != nil { diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 2c2125ac..6dd094a1 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -54,6 +54,7 @@ type NetworkConnector interface { type NetworkAPI interface { Connect(ctx context.Context) error IsLoggedIn() bool + LogoutRemote(ctx context.Context) IsThisUser(ctx context.Context, userID networkid.UserID) bool GetChatInfo(ctx context.Context, portal *Portal) (*PortalInfo, error) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index b970bda6..74956e26 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -119,3 +119,24 @@ func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, client user.logins[ul.ID] = ul return ul, nil } + +func (ul *UserLogin) Save(ctx context.Context) error { + return ul.Bridge.DB.UserLogin.Update(ctx, ul.UserLogin) +} + +func (ul *UserLogin) Logout(ctx context.Context) { + ul.Client.LogoutRemote(ctx) + err := ul.Bridge.DB.UserLogin.Delete(ctx, ul.ID) + if err != nil { + ul.Log.Err(err).Msg("Failed to delete user login") + } + ul.Bridge.cacheLock.Lock() + defer ul.Bridge.cacheLock.Unlock() + delete(ul.User.logins, ul.ID) + delete(ul.Bridge.userLoginsByID, ul.ID) + // TODO kick user out of rooms? +} + +func (ul *UserLogin) MarkAsPreferredIn(ctx context.Context, portal *Portal) error { + return ul.Bridge.DB.UserLogin.MarkLoginAsPreferredInPortal(ctx, ul.UserLogin, portal.ID) +} From 218ed06e73f650ddc33d6a5a001e676265374569 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 5 Jun 2024 13:50:31 +0300 Subject: [PATCH 0250/1647] Move ID and sender from ConvertedMessage to RemoteMessage --- bridgev2/networkinterface.go | 31 ++++++++++++++++++------------- bridgev2/portal.go | 21 ++++++++++++++------- 2 files changed, 32 insertions(+), 20 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 6dd094a1..3936e9c8 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -32,8 +32,9 @@ type EventSender struct { } type ConvertedMessage struct { - ID networkid.MessageID - EventSender + // TODO are these ever ambiguous at the time of forming the RemoteMessage? + //ID networkid.MessageID + //EventSender Timestamp time.Time ReplyTo *networkid.MessageOptionalPartID ThreadRoot *networkid.MessageOptionalPartID @@ -82,36 +83,35 @@ type RemoteEvent interface { GetPortalID() networkid.PortalID ShouldCreatePortal() bool AddLogContext(c zerolog.Context) zerolog.Context + GetSender() EventSender } type RemoteMessage interface { RemoteEvent - ConvertMessage(ctx context.Context, portal *Portal) (*ConvertedMessage, error) + GetID() networkid.MessageID + ConvertMessage(ctx context.Context, portal *Portal, intent MatrixAPI) (*ConvertedMessage, error) } type RemoteEdit interface { RemoteEvent GetTargetMessage() networkid.MessageID - ConvertEdit(ctx context.Context, portal *Portal, existing []*database.Message) (*ConvertedMessage, error) + ConvertEdit(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message) (*ConvertedMessage, error) } type RemoteReaction interface { RemoteEvent - GetSender() EventSender GetTargetMessage() networkid.MessageID GetReactionEmoji() (string, networkid.EmojiID) } type RemoteReactionRemove interface { RemoteEvent - GetSender() EventSender GetTargetMessage() networkid.MessageID GetRemovedEmojiID() networkid.EmojiID } type RemoteMessageRemove interface { RemoteEvent - GetSender() EventSender GetTargetMessage() networkid.MessageID } @@ -123,13 +123,14 @@ type SimpleRemoteEvent[T any] struct { Data T CreatePortal bool + ID networkid.MessageID Sender EventSender TargetMessage networkid.MessageID EmojiID networkid.EmojiID Emoji string - ConvertMessageFunc func(ctx context.Context, portal *Portal, data T) (*ConvertedMessage, error) - ConvertEditFunc func(ctx context.Context, portal *Portal, existing []*database.Message, data T) (*ConvertedMessage, error) + ConvertMessageFunc func(ctx context.Context, portal *Portal, intent MatrixAPI, data T) (*ConvertedMessage, error) + ConvertEditFunc func(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message, data T) (*ConvertedMessage, error) } var ( @@ -148,12 +149,16 @@ func (sre *SimpleRemoteEvent[T]) GetPortalID() networkid.PortalID { return sre.PortalID } -func (sre *SimpleRemoteEvent[T]) ConvertMessage(ctx context.Context, portal *Portal) (*ConvertedMessage, error) { - return sre.ConvertMessageFunc(ctx, portal, sre.Data) +func (sre *SimpleRemoteEvent[T]) ConvertMessage(ctx context.Context, portal *Portal, intent MatrixAPI) (*ConvertedMessage, error) { + return sre.ConvertMessageFunc(ctx, portal, intent, sre.Data) } -func (sre *SimpleRemoteEvent[T]) ConvertEdit(ctx context.Context, portal *Portal, existing []*database.Message) (*ConvertedMessage, error) { - return sre.ConvertEditFunc(ctx, portal, existing, sre.Data) +func (sre *SimpleRemoteEvent[T]) ConvertEdit(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message) (*ConvertedMessage, error) { + return sre.ConvertEditFunc(ctx, portal, intent, existing, sre.Data) +} + +func (sre *SimpleRemoteEvent[T]) GetID() networkid.MessageID { + return sre.ID } func (sre *SimpleRemoteEvent[T]) GetSender() EventSender { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 046a56b5..a950bee4 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -490,7 +490,18 @@ func (portal *Portal) getIntentFor(ctx context.Context, sender EventSender, sour func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) { log := zerolog.Ctx(ctx) - converted, err := evt.ConvertMessage(ctx, portal) + existing, err := portal.Bridge.DB.Message.GetFirstPartByID(ctx, evt.GetID()) + if err != nil { + log.Err(err).Msg("Failed to check if message is a duplicate") + } else if existing != nil { + log.Debug().Stringer("existing_mxid", existing.MXID).Msg("Ignoring duplicate message") + return + } + intent := portal.getIntentFor(ctx, evt.GetSender(), source) + if intent == nil { + return + } + converted, err := evt.ConvertMessage(ctx, portal, intent) if err != nil { // TODO log and notify room? return @@ -518,10 +529,6 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin // TODO 2 fetch last event in thread properly prevThreadEvent = threadRoot } - intent := portal.getIntentFor(ctx, converted.EventSender, source) - if intent == nil { - return - } for _, part := range converted.Parts { if threadRoot != nil && prevThreadEvent != nil { part.Content.GetRelatesTo().SetThread(threadRoot.MXID, prevThreadEvent.MXID) @@ -550,11 +557,11 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin // TODO make metadata fields less hacky part.DBMetadata["sender_mxid"] = intent.GetMXID() dbMessage := &database.Message{ - ID: converted.ID, + ID: evt.GetID(), PartID: part.ID, MXID: resp.EventID, RoomID: portal.ID, - SenderID: converted.Sender, + SenderID: evt.GetSender().Sender, Timestamp: converted.Timestamp, RelatesToRowID: relatesToRowID, Metadata: part.DBMetadata, From feebb5813f45574a9380cd8532908d8d7c8b85ea Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 5 Jun 2024 11:02:07 -0600 Subject: [PATCH 0251/1647] bridgev2/login: add description Signed-off-by: Sumner Evans --- bridgev2/login-step.schema.json | 7 ++++++- bridgev2/login.go | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/bridgev2/login-step.schema.json b/bridgev2/login-step.schema.json index fa120778..518698fa 100644 --- a/bridgev2/login-step.schema.json +++ b/bridgev2/login-step.schema.json @@ -40,9 +40,14 @@ }, "name": { "type": "string", - "description": "The name of the field", + "description": "The name of the field shown to the user", "examples": ["Username", "Password", "Phone number", "2FA code", "Meow"] }, + "description": { + "type": "string", + "description": "The description of the field shown to the user", + "examples": ["Include the country code with a +"] + }, "pattern": { "type": "string", "description": "A regular expression that the field value must match" diff --git a/bridgev2/login.go b/bridgev2/login.go index 78c4e7de..7d6018f7 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -111,6 +111,8 @@ type LoginInputDataField struct { ID string `json:"id"` // The name of the field shown to the user. Name string `json:"name"` + // The description of the field shown to the user. + Description string `json:"description"` // A regex pattern that the client can use to validate input client-side. Pattern string `json:"pattern,omitempty"` // A function that validates the input and optionally cleans it up before it's submitted to the connector. From d7ffa71838248e42aafd495325e62a9c7307afdf Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 5 Jun 2024 11:14:21 -0600 Subject: [PATCH 0252/1647] login: add default validate function depending on field type Signed-off-by: Sumner Evans --- bridgev2/cmdlogin.go | 1 + bridgev2/login.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/bridgev2/cmdlogin.go b/bridgev2/cmdlogin.go index d15fac72..740915a1 100644 --- a/bridgev2/cmdlogin.go +++ b/bridgev2/cmdlogin.go @@ -94,6 +94,7 @@ func (uilcs *userInputLoginCommandState) promptNext(ce *CommandEvent) { func (uilcs *userInputLoginCommandState) submitNext(ce *CommandEvent) { field := uilcs.RemainingFields[0] + field.FillDefaultValidate() var err error uilcs.Data[field.ID], err = field.Validate(ce.RawArgs) if err != nil { diff --git a/bridgev2/login.go b/bridgev2/login.go index 7d6018f7..5b25cb94 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -8,6 +8,9 @@ package bridgev2 import ( "context" + "fmt" + "regexp" + "strings" "maunium.net/go/mautrix/bridgev2/networkid" ) @@ -119,6 +122,34 @@ type LoginInputDataField struct { Validate func(string) (string, error) `json:"-"` } +var phoneNumberRe = regexp.MustCompile(`\+\d+`) + +func (f *LoginInputDataField) FillDefaultValidate() { + noopValidate := func(input string) (string, error) { return input, nil } + if f.Validate != nil { + return + } + switch f.Type { + case LoginInputFieldTypePhoneNumber: + f.Validate = func(phone string) (string, error) { + phone = strings.ReplaceAll(phone, " ", "") + if !phoneNumberRe.MatchString(phone) { + return "", fmt.Errorf("invalid phone number") + } + return phone, nil + } + case LoginInputFieldTypeEmail: + f.Validate = func(email string) (string, error) { + if !strings.ContainsRune(email, '@') { + return "", fmt.Errorf("invalid email") + } + return email, nil + } + default: + f.Validate = noopValidate + } +} + type LoginUserInputParams struct { // The fields that the user needs to fill in. Fields []LoginInputDataField `json:"fields"` From f7c4ff64559580465ac2d189f795b80a4ca3f7ad Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 6 Jun 2024 13:41:53 +0300 Subject: [PATCH 0253/1647] Add more accurate logs in cmdprocessor --- bridgev2/cmdprocessor.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/bridgev2/cmdprocessor.go b/bridgev2/cmdprocessor.go index b3289cad..55ad0c9a 100644 --- a/bridgev2/cmdprocessor.go +++ b/bridgev2/cmdprocessor.go @@ -73,8 +73,6 @@ func (proc *CommandProcessor) Handle(ctx context.Context, roomID id.RoomID, even } command := strings.ToLower(args[0]) rawArgs := strings.TrimLeft(strings.TrimPrefix(message, command), " ") - log := zerolog.Ctx(ctx).With().Str("mx_command", command).Logger() - ctx = log.WithContext(ctx) portal, err := proc.bridge.GetPortalByMXID(ctx, roomID) if err != nil { // :( @@ -92,9 +90,7 @@ func (proc *CommandProcessor) Handle(ctx context.Context, roomID id.RoomID, even RawArgs: rawArgs, ReplyTo: replyTo, Ctx: ctx, - Log: &log, } - log.Debug().Msg("Received command") realCommand, ok := proc.aliases[ce.Command] if !ok { @@ -110,11 +106,21 @@ func (proc *CommandProcessor) Handle(ctx context.Context, roomID id.RoomID, even ce.RawArgs = message ce.Args = args ce.Handler = state.Next + log := zerolog.Ctx(ctx).With().Str("action", state.Action).Logger() + ce.Log = &log + ce.Ctx = log.WithContext(ctx) + log.Debug().Msg("Received reply to command state") state.Next.Run(ce) } else { + zerolog.Ctx(ctx).Debug().Str("mx_command", command).Msg("Received unknown command") ce.Reply("Unknown command, use the `help` command for help.") } } else { + log := zerolog.Ctx(ctx).With().Str("mx_command", command).Logger() + ctx = log.WithContext(ctx) + ce.Log = &log + ce.Ctx = ctx + log.Debug().Msg("Received command") ce.Handler = handler handler.Run(ce) } From 55fa856e768ca8c669dd0dbaaeb8ba0a29c1e49e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 6 Jun 2024 13:42:08 +0300 Subject: [PATCH 0254/1647] Allow overriding message rowid --- bridgev2/database/upgrades/00-latest.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index b64a5507..df8c0a4a 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -46,7 +46,7 @@ CREATE TABLE message ( -- only: sqlite (line commented) -- rowid INTEGER PRIMARY KEY, -- only: postgres - rowid BIGINT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + rowid BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, bridge_id TEXT NOT NULL, id TEXT NOT NULL, From be4a3e17ea96530975412e89280a5d9c916940b6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 6 Jun 2024 13:42:18 +0300 Subject: [PATCH 0255/1647] Add more fields to cookie login --- bridgev2/login-step.schema.json | 8 ++++++++ bridgev2/login.go | 4 +++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/bridgev2/login-step.schema.json b/bridgev2/login-step.schema.json index 518698fa..42edb56f 100644 --- a/bridgev2/login-step.schema.json +++ b/bridgev2/login-step.schema.json @@ -68,6 +68,10 @@ "type": "string", "description": "The URL to open when using a webview to extract cookies" }, + "user_agent": { + "type": "string", + "description": "The user agent to use when opening the URL" + }, "cookie_domain": { "type": "string", "description": "The domain of the cookies to extract" @@ -92,6 +96,10 @@ "items": { "type": "string" } + }, + "special_extract_js": { + "type": "string", + "description": "JavaScript code that can be evaluated inside the webview to extract the special keys" } }, "required": ["url"] diff --git a/bridgev2/login.go b/bridgev2/login.go index 5b25cb94..23647042 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -89,12 +89,14 @@ type LoginDisplayAndWaitParams struct { } type LoginCookiesParams struct { - URL string `json:"url"` + URL string `json:"url"` + UserAgent string `json:"user_agent,omitempty"` CookieDomain string `json:"cookie_domain,omitempty"` CookieKeys []string `json:"cookie_keys,omitempty"` LocalStorageKeys []string `json:"local_storage_keys,omitempty"` SpecialKeys []string `json:"special_keys,omitempty"` + SpecialExtractJS string `json:"special_extract_js,omitempty"` } type LoginInputFieldType string From a0e309fa55ab2938b45817c6d6166a101efb9e59 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 6 Jun 2024 16:00:41 +0300 Subject: [PATCH 0256/1647] Mostly implement Matrix reactions and redactions --- bridgev2/database/message.go | 5 ++ bridgev2/networkinterface.go | 3 +- bridgev2/portal.go | 112 ++++++++++++++++++++++++++++++++++- 3 files changed, 117 insertions(+), 3 deletions(-) diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 539aa61a..b38d193e 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -48,6 +48,7 @@ const ( ` getAllMessagePartsByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND id=$2` getMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND part_id=$3` + getMessagePartByRowIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND rowid=$2` getMessageByMXIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` getLastMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND id=$2 ORDER BY part_id DESC LIMIT 1` getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND id=$2 ORDER BY part_id ASC LIMIT 1` @@ -89,6 +90,10 @@ func (mq *MessageQuery) GetFirstPartByID(ctx context.Context, id networkid.Messa return mq.QueryOne(ctx, getFirstMessagePartByIDQuery, mq.BridgeID, id) } +func (mq *MessageQuery) GetByRowID(ctx context.Context, rowID int64) (*Message, error) { + return mq.QueryOne(ctx, getMessagePartByRowIDQuery, mq.BridgeID, rowID) +} + func (mq *MessageQuery) GetFirstOrSpecificPartByID(ctx context.Context, id networkid.MessageOptionalPartID) (*Message, error) { if id.PartID == nil { return mq.GetFirstPartByID(ctx, id.MessageID) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 3936e9c8..1aa02263 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -63,7 +63,7 @@ type NetworkAPI interface { HandleMatrixMessage(ctx context.Context, msg *MatrixMessage) (message *database.Message, err error) HandleMatrixEdit(ctx context.Context, msg *MatrixEdit) error - HandleMatrixReaction(ctx context.Context, msg *MatrixReaction) (emojiID networkid.EmojiID, err error) + HandleMatrixReaction(ctx context.Context, msg *MatrixReaction) (reaction *database.Reaction, err error) HandleMatrixReactionRemove(ctx context.Context, msg *MatrixReactionRemove) error HandleMatrixMessageRemove(ctx context.Context, msg *MatrixMessageRemove) error } @@ -216,6 +216,7 @@ type MatrixEdit struct { type MatrixReaction struct { MatrixEventBase[*event.ReactionEventContent] TargetMessage *database.Message + GetExisting func(ctx context.Context, senderID networkid.UserID, emojiID networkid.EmojiID) (*database.Reaction, error) } type MatrixReactionRemove struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a950bee4..002a3908 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -400,6 +400,10 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o log.Err(err).Msg("Failed to get edit target message from database") // TODO send metrics return + } else if editTarget == nil { + log.Warn().Msg("Edit target message not found in database") + // TODO send metrics + return } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("edit_target_remote_id", string(editTarget.ID)) @@ -426,11 +430,115 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o } func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) { - + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(*event.ReactionEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + // TODO send metrics + return + } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Stringer("reaction_target_mxid", content.RelatesTo.EventID) + }) + reactionTarget, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, content.RelatesTo.EventID) + if err != nil { + log.Err(err).Msg("Failed to get reaction target message from database") + // TODO send metrics + return + } else if reactionTarget == nil { + log.Warn().Msg("Reaction target message not found in database") + // TODO send metrics + return + } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("reaction_target_remote_id", string(reactionTarget.ID)) + }) + dbReaction, err := sender.Client.HandleMatrixReaction(ctx, &MatrixReaction{ + MatrixEventBase: MatrixEventBase[*event.ReactionEventContent]{ + Event: evt, + Content: content, + Portal: portal, + }, + TargetMessage: reactionTarget, + GetExisting: func(ctx context.Context, senderID networkid.UserID, emojiID networkid.EmojiID) (*database.Reaction, error) { + return portal.Bridge.DB.Reaction.GetByID(ctx, reactionTarget.ID, reactionTarget.PartID, senderID, emojiID) + }, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix reaction") + // TODO send metrics here or inside HandleMatrixReaction? + return + } + // TODO figure out how to delete outdated reactions if appropriate + if dbReaction != nil { + err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) + if err != nil { + log.Err(err).Msg("Failed to save reaction to database") + } + } else { + log.Debug().Msg("Reaction was ignored") + } + // TODO send success metrics } func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { - + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(*event.RedactionEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + // TODO send metrics + return + } + if evt.Redacts != "" && content.Redacts != evt.Redacts { + content.Redacts = evt.Redacts + } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Stringer("redaction_target_mxid", content.Redacts) + }) + redactionTargetMsg, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, content.Redacts) + if err != nil { + log.Err(err).Msg("Failed to get redaction target message from database") + // TODO send metrics + return + } + redactionTargetReaction, err := portal.Bridge.DB.Reaction.GetByMXID(ctx, content.Redacts) + if err != nil { + log.Err(err).Msg("Failed to get redaction target reaction from database") + // TODO send metrics + return + } + if redactionTargetMsg != nil { + err = sender.Client.HandleMatrixMessageRemove(ctx, &MatrixMessageRemove{ + MatrixEventBase: MatrixEventBase[*event.RedactionEventContent]{ + Event: evt, + Content: content, + Portal: portal, + OrigSender: origSender, + }, + TargetMessage: redactionTargetMsg, + }) + } else if redactionTargetReaction != nil { + err = sender.Client.HandleMatrixReactionRemove(ctx, &MatrixReactionRemove{ + MatrixEventBase: MatrixEventBase[*event.RedactionEventContent]{ + Event: evt, + Content: content, + Portal: portal, + OrigSender: origSender, + }, + TargetReaction: redactionTargetReaction, + }) + } else { + log.Debug().Msg("Redaction target message not found in database") + // TODO send metrics + return + } + if err != nil { + log.Err(err).Msg("Failed to handle Matrix redaction") + // TODO send metrics here or inside HandleMatrixMessageRemove and HandleMatrixReactionRemove? + return + } + // TODO delete msg/reaction db row + // TODO send success metrics } func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { From 3204ffed2ddfb318a098301ab1cf596293b31775 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 6 Jun 2024 18:04:47 +0300 Subject: [PATCH 0257/1647] Add new type for converted edits --- bridgev2/networkinterface.go | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 1aa02263..0af72071 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -39,8 +39,20 @@ type ConvertedMessage struct { ReplyTo *networkid.MessageOptionalPartID ThreadRoot *networkid.MessageOptionalPartID Parts []*ConvertedMessagePart - // For edits, set this field to skip editing the event - Unchanged bool +} + +type ConvertedEditPart struct { + Part *database.Message + + Type event.Type + Content *event.MessageEventContent + Extra map[string]any +} + +type ConvertedEdit struct { + Timestamp time.Time + ModifiedParts []*ConvertedEditPart + DeletedParts []*database.Message } type NetworkConnector interface { @@ -95,7 +107,7 @@ type RemoteMessage interface { type RemoteEdit interface { RemoteEvent GetTargetMessage() networkid.MessageID - ConvertEdit(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message) (*ConvertedMessage, error) + ConvertEdit(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message) (*ConvertedEdit, error) } type RemoteReaction interface { @@ -130,7 +142,7 @@ type SimpleRemoteEvent[T any] struct { Emoji string ConvertMessageFunc func(ctx context.Context, portal *Portal, intent MatrixAPI, data T) (*ConvertedMessage, error) - ConvertEditFunc func(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message, data T) (*ConvertedMessage, error) + ConvertEditFunc func(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message, data T) (*ConvertedEdit, error) } var ( @@ -153,7 +165,7 @@ func (sre *SimpleRemoteEvent[T]) ConvertMessage(ctx context.Context, portal *Por return sre.ConvertMessageFunc(ctx, portal, intent, sre.Data) } -func (sre *SimpleRemoteEvent[T]) ConvertEdit(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message) (*ConvertedMessage, error) { +func (sre *SimpleRemoteEvent[T]) ConvertEdit(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message) (*ConvertedEdit, error) { return sre.ConvertEditFunc(ctx, portal, intent, existing, sre.Data) } From 97d803723b00dd87ead9abb9f0c657b43f187621 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 6 Jun 2024 19:44:17 +0300 Subject: [PATCH 0258/1647] Add support for remote edits --- bridgev2/networkinterface.go | 6 +++- bridgev2/portal.go | 55 +++++++++++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 0af72071..89ce81aa 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -44,9 +44,13 @@ type ConvertedMessage struct { type ConvertedEditPart struct { Part *database.Message - Type event.Type + Type event.Type + // The Content and Extra fields will be put inside `m.new_content` automatically. + // SetEdit must NOT be called by the network connector. Content *event.MessageEventContent Extra map[string]any + // TopLevelExtra can be used to specify custom fields at the top level of the content rather than inside `m.new_content`. + TopLevelExtra map[string]any } type ConvertedEdit struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 002a3908..f7fdb3cd 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -685,7 +685,60 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin } func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) { - + log := zerolog.Ctx(ctx) + existing, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, evt.GetTargetMessage()) + if err != nil { + log.Err(err).Msg("Failed to get edit target message") + return + } else if existing == nil { + log.Warn().Msg("Edit target message not found") + return + } + intent := portal.getIntentFor(ctx, evt.GetSender(), source) + if intent == nil { + return + } + converted, err := evt.ConvertEdit(ctx, portal, intent, existing) + if err != nil { + // TODO log and notify room? + return + } + for _, part := range converted.ModifiedParts { + part.Content.SetEdit(part.Part.MXID) + if part.TopLevelExtra == nil { + part.TopLevelExtra = make(map[string]any) + } + if part.Extra != nil { + part.TopLevelExtra["m.new_content"] = part.Extra + } + wrappedContent := &event.Content{ + Parsed: part.Content, + Raw: part.TopLevelExtra, + } + _, err = intent.SendMessage(ctx, portal.MXID, part.Type, wrappedContent, converted.Timestamp) + if err != nil { + log.Err(err).Stringer("part_mxid", part.Part.MXID).Msg("Failed to edit message part") + } + err = portal.Bridge.DB.Message.Update(ctx, part.Part) + if err != nil { + log.Err(err).Int64("part_rowid", part.Part.RowID).Msg("Failed to update message part in database") + } + } + for _, part := range converted.DeletedParts { + redactContent := &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: part.MXID, + }, + } + _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, redactContent, converted.Timestamp) + if err != nil { + log.Err(err).Stringer("part_mxid", part.MXID).Msg("Failed to redact message part deleted in edit") + } + err = portal.Bridge.DB.Message.Delete(ctx, part.RowID) + if err != nil { + log.Err(err).Int64("part_rowid", part.RowID).Msg("Failed to delete message part from database") + } + } } func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) { From a150a476041936112772d31b4c6aa8cea5ddc9c8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 6 Jun 2024 20:58:22 +0300 Subject: [PATCH 0259/1647] Adjust remote event interfaces and add support for reactions --- bridgev2/networkinterface.go | 75 ++++++++++++++++++++++++------------ bridgev2/portal.go | 65 ++++++++++++++++++++++++++++--- 2 files changed, 111 insertions(+), 29 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 89ce81aa..365c2436 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -32,10 +32,6 @@ type EventSender struct { } type ConvertedMessage struct { - // TODO are these ever ambiguous at the time of forming the RemoteMessage? - //ID networkid.MessageID - //EventSender - Timestamp time.Time ReplyTo *networkid.MessageOptionalPartID ThreadRoot *networkid.MessageOptionalPartID Parts []*ConvertedMessagePart @@ -54,7 +50,6 @@ type ConvertedEditPart struct { } type ConvertedEdit struct { - Timestamp time.Time ModifiedParts []*ConvertedEditPart DeletedParts []*database.Message } @@ -87,7 +82,8 @@ type NetworkAPI interface { type RemoteEventType int const ( - RemoteEventMessage RemoteEventType = iota + RemoteEventUnknown RemoteEventType = iota + RemoteEventMessage RemoteEventEdit RemoteEventReaction RemoteEventReactionRemove @@ -102,6 +98,21 @@ type RemoteEvent interface { GetSender() EventSender } +type RemoteEventWithTargetMessage interface { + RemoteEvent + GetTargetMessage() networkid.MessageID +} + +type RemoteEventWithTargetPart interface { + RemoteEventWithTargetMessage + GetTargetMessagePart() networkid.PartID +} + +type RemoteEventWithTimestamp interface { + RemoteEvent + GetTimestamp() time.Time +} + type RemoteMessage interface { RemoteEvent GetID() networkid.MessageID @@ -109,26 +120,27 @@ type RemoteMessage interface { } type RemoteEdit interface { - RemoteEvent - GetTargetMessage() networkid.MessageID + RemoteEventWithTargetMessage ConvertEdit(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message) (*ConvertedEdit, error) } type RemoteReaction interface { - RemoteEvent - GetTargetMessage() networkid.MessageID + RemoteEventWithTargetMessage GetReactionEmoji() (string, networkid.EmojiID) } +type RemoteReactionWithMeta interface { + RemoteReaction + GetReactionDBMetadata() map[string]any +} + type RemoteReactionRemove interface { - RemoteEvent - GetTargetMessage() networkid.MessageID + RemoteEventWithTargetMessage GetRemovedEmojiID() networkid.EmojiID } type RemoteMessageRemove interface { - RemoteEvent - GetTargetMessage() networkid.MessageID + RemoteEventWithTargetMessage } // SimpleRemoteEvent is a simple implementation of RemoteEvent that can be used with struct fields and some callbacks. @@ -139,22 +151,26 @@ type SimpleRemoteEvent[T any] struct { Data T CreatePortal bool - ID networkid.MessageID - Sender EventSender - TargetMessage networkid.MessageID - EmojiID networkid.EmojiID - Emoji string + ID networkid.MessageID + Sender EventSender + TargetMessage networkid.MessageID + EmojiID networkid.EmojiID + Emoji string + ReactionDBMeta map[string]any + Timestamp time.Time ConvertMessageFunc func(ctx context.Context, portal *Portal, intent MatrixAPI, data T) (*ConvertedMessage, error) ConvertEditFunc func(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message, data T) (*ConvertedEdit, error) } var ( - _ RemoteMessage = (*SimpleRemoteEvent[any])(nil) - _ RemoteEdit = (*SimpleRemoteEvent[any])(nil) - _ RemoteReaction = (*SimpleRemoteEvent[any])(nil) - _ RemoteReactionRemove = (*SimpleRemoteEvent[any])(nil) - _ RemoteMessageRemove = (*SimpleRemoteEvent[any])(nil) + _ RemoteMessage = (*SimpleRemoteEvent[any])(nil) + _ RemoteEdit = (*SimpleRemoteEvent[any])(nil) + _ RemoteEventWithTimestamp = (*SimpleRemoteEvent[any])(nil) + _ RemoteReaction = (*SimpleRemoteEvent[any])(nil) + _ RemoteReactionWithMeta = (*SimpleRemoteEvent[any])(nil) + _ RemoteReactionRemove = (*SimpleRemoteEvent[any])(nil) + _ RemoteMessageRemove = (*SimpleRemoteEvent[any])(nil) ) func (sre *SimpleRemoteEvent[T]) AddLogContext(c zerolog.Context) zerolog.Context { @@ -165,6 +181,13 @@ func (sre *SimpleRemoteEvent[T]) GetPortalID() networkid.PortalID { return sre.PortalID } +func (sre *SimpleRemoteEvent[T]) GetTimestamp() time.Time { + if sre.Timestamp.IsZero() { + return time.Now() + } + return sre.Timestamp +} + func (sre *SimpleRemoteEvent[T]) ConvertMessage(ctx context.Context, portal *Portal, intent MatrixAPI) (*ConvertedMessage, error) { return sre.ConvertMessageFunc(ctx, portal, intent, sre.Data) } @@ -193,6 +216,10 @@ func (sre *SimpleRemoteEvent[T]) GetRemovedEmojiID() networkid.EmojiID { return sre.EmojiID } +func (sre *SimpleRemoteEvent[T]) GetReactionDBMetadata() map[string]any { + return sre.ReactionDBMeta +} + func (sre *SimpleRemoteEvent[T]) GetType() RemoteEventType { return sre.Type } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index f7fdb3cd..68b49743 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -16,6 +16,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/exslices" + "go.mau.fi/util/variationselector" "golang.org/x/exp/slices" "maunium.net/go/mautrix" @@ -637,6 +638,7 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin // TODO 2 fetch last event in thread properly prevThreadEvent = threadRoot } + ts := getEventTS(evt) for _, part := range converted.Parts { if threadRoot != nil && prevThreadEvent != nil { part.Content.GetRelatesTo().SetThread(threadRoot.MXID, prevThreadEvent.MXID) @@ -654,7 +656,7 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin resp, err := intent.SendMessage(ctx, portal.MXID, part.Type, &event.Content{ Parsed: part.Content, Raw: part.Extra, - }, converted.Timestamp) + }, ts) if err != nil { log.Err(err).Str("part_id", string(part.ID)).Msg("Failed to send message part to Matrix") continue @@ -670,7 +672,7 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin MXID: resp.EventID, RoomID: portal.ID, SenderID: evt.GetSender().Sender, - Timestamp: converted.Timestamp, + Timestamp: ts, RelatesToRowID: relatesToRowID, Metadata: part.DBMetadata, } @@ -703,6 +705,7 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e // TODO log and notify room? return } + ts := getEventTS(evt) for _, part := range converted.ModifiedParts { part.Content.SetEdit(part.Part.MXID) if part.TopLevelExtra == nil { @@ -715,7 +718,7 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e Parsed: part.Content, Raw: part.TopLevelExtra, } - _, err = intent.SendMessage(ctx, portal.MXID, part.Type, wrappedContent, converted.Timestamp) + _, err = intent.SendMessage(ctx, portal.MXID, part.Type, wrappedContent, ts) if err != nil { log.Err(err).Stringer("part_mxid", part.Part.MXID).Msg("Failed to edit message part") } @@ -730,7 +733,7 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e Redacts: part.MXID, }, } - _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, redactContent, converted.Timestamp) + _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, redactContent, ts) if err != nil { log.Err(err).Stringer("part_mxid", part.MXID).Msg("Failed to redact message part deleted in edit") } @@ -741,8 +744,60 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e } } -func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) { +func (portal *Portal) getTargetMessagePart(ctx context.Context, evt RemoteEventWithTargetMessage) (*database.Message, error) { + if partTargeter, ok := evt.(RemoteEventWithTargetPart); ok { + return portal.Bridge.DB.Message.GetPartByID(ctx, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart()) + } else { + return portal.Bridge.DB.Message.GetFirstPartByID(ctx, evt.GetTargetMessage()) + } +} +func getEventTS(evt RemoteEvent) time.Time { + if tsProvider, ok := evt.(RemoteEventWithTimestamp); ok { + return tsProvider.GetTimestamp() + } + return time.Now() +} + +func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) { + log := zerolog.Ctx(ctx) + targetMessage, err := portal.getTargetMessagePart(ctx, evt) + if err != nil { + log.Err(err).Msg("Failed to get target message for reaction") + return + } + ts := getEventTS(evt) + intent := portal.getIntentFor(ctx, evt.GetSender(), source) + emoji, emojiID := evt.GetReactionEmoji() + resp, err := intent.SendMessage(ctx, portal.MXID, event.EventReaction, &event.Content{ + Parsed: &event.ReactionEventContent{ + RelatesTo: event.RelatesTo{ + Type: event.RelAnnotation, + EventID: targetMessage.MXID, + Key: variationselector.Add(emoji), + }, + }, + }, ts) + if err != nil { + log.Err(err).Msg("Failed to send reaction to Matrix") + return + } + dbReaction := &database.Reaction{ + RoomID: portal.ID, + MessageID: targetMessage.ID, + MessagePartID: targetMessage.PartID, + SenderID: evt.GetSender().Sender, + EmojiID: emojiID, + MXID: resp.EventID, + Timestamp: ts, + } + if metaProvider, ok := evt.(RemoteReactionWithMeta); ok { + dbReaction.Metadata = metaProvider.GetReactionDBMetadata() + } + err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) + if err != nil { + log.Err(err).Msg("Failed to save reaction to database") + } } func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) { From 8670a1cbb3b46d97b46abbefeb744608727aaee8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 6 Jun 2024 21:23:42 +0300 Subject: [PATCH 0260/1647] Fix reaction deduplication --- bridgev2/networkinterface.go | 9 +++- bridgev2/portal.go | 95 ++++++++++++++++++++++++++++++------ 2 files changed, 89 insertions(+), 15 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 365c2436..0fd9d641 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -74,6 +74,7 @@ type NetworkAPI interface { HandleMatrixMessage(ctx context.Context, msg *MatrixMessage) (message *database.Message, err error) HandleMatrixEdit(ctx context.Context, msg *MatrixEdit) error + PreHandleMatrixReaction(ctx context.Context, msg *MatrixReaction) (MatrixReactionPreResponse, error) HandleMatrixReaction(ctx context.Context, msg *MatrixReaction) (reaction *database.Reaction, err error) HandleMatrixReactionRemove(ctx context.Context, msg *MatrixReactionRemove) error HandleMatrixMessageRemove(ctx context.Context, msg *MatrixMessageRemove) error @@ -259,7 +260,13 @@ type MatrixEdit struct { type MatrixReaction struct { MatrixEventBase[*event.ReactionEventContent] TargetMessage *database.Message - GetExisting func(ctx context.Context, senderID networkid.UserID, emojiID networkid.EmojiID) (*database.Reaction, error) +} + +type MatrixReactionPreResponse struct { + SenderID networkid.UserID + EmojiID networkid.EmojiID + Emoji string + MaxReactions int } type MatrixReactionRemove struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 68b49743..c18b1f80 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -373,6 +373,15 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin // TODO send metrics here or inside HandleMatrixMessage? return } + if message.MXID == "" { + message.MXID = evt.ID + } + if message.RoomID == "" { + message.RoomID = portal.ID + } + if message.Timestamp.IsZero() { + message.Timestamp = time.UnixMilli(evt.Timestamp) + } if message.Metadata == nil { message.Metadata = make(map[string]any) } @@ -454,30 +463,66 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("reaction_target_remote_id", string(reactionTarget.ID)) }) - dbReaction, err := sender.Client.HandleMatrixReaction(ctx, &MatrixReaction{ + react := &MatrixReaction{ MatrixEventBase: MatrixEventBase[*event.ReactionEventContent]{ Event: evt, Content: content, Portal: portal, }, TargetMessage: reactionTarget, - GetExisting: func(ctx context.Context, senderID networkid.UserID, emojiID networkid.EmojiID) (*database.Reaction, error) { - return portal.Bridge.DB.Reaction.GetByID(ctx, reactionTarget.ID, reactionTarget.PartID, senderID, emojiID) - }, - }) + } + preResp, err := sender.Client.PreHandleMatrixReaction(ctx, react) + if err != nil { + log.Err(err).Msg("Failed to pre-handle Matrix reaction") + // TODO send metrics + return + } + existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID) + if err != nil { + log.Err(err).Msg("Failed to check if reaction is a duplicate") + return + } else if existing != nil { + if existing.EmojiID != "" || existing.Metadata["emoji"] == preResp.Emoji { + log.Debug().Msg("Ignoring duplicate reaction") + // TODO send metrics + return + } + _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: existing.MXID, + }, + }, time.Now()) + if err != nil { + log.Err(err).Msg("Failed to remove old reaction") + } + } + if preResp.MaxReactions > 0 { + // TODO get all reactions to message by sender in order to remove oldest ones + // (this is necessary for telegram where reaction limit is 1 or 3 based on premium status) + } + dbReaction, err := sender.Client.HandleMatrixReaction(ctx, react) if err != nil { log.Err(err).Msg("Failed to handle Matrix reaction") // TODO send metrics here or inside HandleMatrixReaction? return } - // TODO figure out how to delete outdated reactions if appropriate - if dbReaction != nil { - err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) - if err != nil { - log.Err(err).Msg("Failed to save reaction to database") - } - } else { - log.Debug().Msg("Reaction was ignored") + // Fill all fields that are known to allow omitting them in connector code + if dbReaction.RoomID == "" { + dbReaction.RoomID = portal.ID + } + if dbReaction.MessageID == "" { + dbReaction.MessageID = reactionTarget.ID + dbReaction.MessagePartID = reactionTarget.PartID + } + if dbReaction.MXID == "" { + dbReaction.MXID = evt.ID + } + if dbReaction.Timestamp.IsZero() { + dbReaction.Timestamp = time.UnixMilli(evt.Timestamp) + } + err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) + if err != nil { + log.Err(err).Msg("Failed to save reaction to database") } // TODO send success metrics } @@ -766,9 +811,17 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi log.Err(err).Msg("Failed to get target message for reaction") return } + emoji, emojiID := evt.GetReactionEmoji() + existingReaction, err := portal.Bridge.DB.Reaction.GetByID(ctx, targetMessage.ID, targetMessage.PartID, evt.GetSender().Sender, emojiID) + if err != nil { + log.Err(err).Msg("Failed to check if reaction is a duplicate") + return + } else if existingReaction != nil && (emojiID != "" || existingReaction.Metadata["emoji"] == emoji) { + log.Debug().Msg("Ignoring duplicate reaction") + return + } ts := getEventTS(evt) intent := portal.getIntentFor(ctx, evt.GetSender(), source) - emoji, emojiID := evt.GetReactionEmoji() resp, err := intent.SendMessage(ctx, portal.MXID, event.EventReaction, &event.Content{ Parsed: &event.ReactionEventContent{ RelatesTo: event.RelatesTo{ @@ -793,11 +846,25 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi } if metaProvider, ok := evt.(RemoteReactionWithMeta); ok { dbReaction.Metadata = metaProvider.GetReactionDBMetadata() + } else if emojiID == "" { + dbReaction.Metadata = map[string]any{ + "emoji": emoji, + } } err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) if err != nil { log.Err(err).Msg("Failed to save reaction to database") } + if existingReaction != nil { + _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: existingReaction.MXID, + }, + }, ts) + if err != nil { + log.Err(err).Msg("Failed to redact old reaction") + } + } } func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) { From 4836aec6cfa20b6921794302fc0ec57184ff2b10 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 6 Jun 2024 21:41:52 +0300 Subject: [PATCH 0261/1647] Auto-fill more fields when handling Matrix reactions --- bridgev2/networkinterface.go | 1 + bridgev2/portal.go | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 0fd9d641..fd9a7f73 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -260,6 +260,7 @@ type MatrixEdit struct { type MatrixReaction struct { MatrixEventBase[*event.ReactionEventContent] TargetMessage *database.Message + PreHandleResp *MatrixReactionPreResponse } type MatrixReactionPreResponse struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c18b1f80..a19c0322 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -496,6 +496,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi log.Err(err).Msg("Failed to remove old reaction") } } + react.PreHandleResp = &preResp if preResp.MaxReactions > 0 { // TODO get all reactions to message by sender in order to remove oldest ones // (this is necessary for telegram where reaction limit is 1 or 3 based on premium status) @@ -520,6 +521,19 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if dbReaction.Timestamp.IsZero() { dbReaction.Timestamp = time.UnixMilli(evt.Timestamp) } + if dbReaction.Metadata == nil { + dbReaction.Metadata = make(map[string]any) + } + if preResp.EmojiID == "" && dbReaction.EmojiID == "" { + if _, alreadySet := dbReaction.Metadata["emoji"]; !alreadySet { + dbReaction.Metadata["emoji"] = preResp.Emoji + } + } else if dbReaction.EmojiID == "" { + dbReaction.EmojiID = preResp.EmojiID + } + if dbReaction.SenderID == "" { + dbReaction.SenderID = preResp.SenderID + } err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) if err != nil { log.Err(err).Msg("Failed to save reaction to database") From 476f6fbc2595ff4be10f45c6bc84bad4f186d3d9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 7 Jun 2024 12:54:39 +0300 Subject: [PATCH 0262/1647] Add support for remote message/reaction removals --- bridgev2/database/reaction.go | 15 ++++++---- bridgev2/portal.go | 55 +++++++++++++++++++++++++++++++++-- 2 files changed, 63 insertions(+), 7 deletions(-) diff --git a/bridgev2/database/reaction.go b/bridgev2/database/reaction.go index f5d5a469..4abec731 100644 --- a/bridgev2/database/reaction.go +++ b/bridgev2/database/reaction.go @@ -42,11 +42,12 @@ const ( getReactionBaseQuery = ` SELECT bridge_id, message_id, message_part_id, sender_id, emoji_id, room_id, mxid, timestamp, metadata FROM reaction ` - getReactionByIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3 AND sender_id=$4 AND emoji_id=$5` - getAllReactionsToMessageBySenderQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND sender_id=$3` - getAllReactionsToMessageQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2` - getReactionByMXIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` - upsertReactionQuery = ` + getReactionByIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3 AND sender_id=$4 AND emoji_id=$5` + getReactionByIDWithoutMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND sender_id=$3 AND emoji_id=$4 ORDER BY message_part_id ASC LIMIT 1` + getAllReactionsToMessageBySenderQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND sender_id=$3 ORDER BY timestamp DESC` + getAllReactionsToMessageQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2` + getReactionByMXIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` + upsertReactionQuery = ` INSERT INTO reaction (bridge_id, message_id, message_part_id, sender_id, emoji_id, room_id, mxid, timestamp, metadata) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) ON CONFLICT (bridge_id, message_id, message_part_id, sender_id, emoji_id) @@ -61,6 +62,10 @@ func (rq *ReactionQuery) GetByID(ctx context.Context, messageID networkid.Messag return rq.QueryOne(ctx, getReactionByIDQuery, rq.BridgeID, messageID, messagePartID, senderID, emojiID) } +func (rq *ReactionQuery) GetByIDWithoutMessagePart(ctx context.Context, messageID networkid.MessageID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) { + return rq.QueryOne(ctx, getReactionByIDWithoutMessagePartQuery, rq.BridgeID, messageID, senderID, emojiID) +} + func (rq *ReactionQuery) GetAllToMessageBySender(ctx context.Context, messageID networkid.MessageID, senderID networkid.UserID) ([]*Reaction, error) { return rq.QueryMany(ctx, getAllReactionsToMessageBySenderQuery, rq.BridgeID, messageID, senderID) } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a19c0322..b3960fbb 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -811,6 +811,14 @@ func (portal *Portal) getTargetMessagePart(ctx context.Context, evt RemoteEventW } } +func (portal *Portal) getTargetReaction(ctx context.Context, evt RemoteReactionRemove) (*database.Reaction, error) { + if partTargeter, ok := evt.(RemoteEventWithTargetPart); ok { + return portal.Bridge.DB.Reaction.GetByID(ctx, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart(), evt.GetSender().Sender, evt.GetRemovedEmojiID()) + } else { + return portal.Bridge.DB.Reaction.GetByIDWithoutMessagePart(ctx, evt.GetTargetMessage(), evt.GetSender().Sender, evt.GetRemovedEmojiID()) + } +} + func getEventTS(evt RemoteEvent) time.Time { if tsProvider, ok := evt.(RemoteEventWithTimestamp); ok { return tsProvider.GetTimestamp() @@ -882,11 +890,54 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi } func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) { - + log := zerolog.Ctx(ctx) + targetReaction, err := portal.getTargetReaction(ctx, evt) + if err != nil { + log.Err(err).Msg("Failed to get target reaction for removal") + return + } else if targetReaction == nil { + log.Warn().Msg("Target reaction not found") + return + } + intent := portal.getIntentFor(ctx, evt.GetSender(), source) + ts := getEventTS(evt) + _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: targetReaction.MXID, + }, + }, ts) + if err != nil { + log.Err(err).Stringer("reaction_mxid", targetReaction.MXID).Msg("Failed to redact reaction") + } + err = portal.Bridge.DB.Reaction.Delete(ctx, targetReaction) + if err != nil { + log.Err(err).Msg("Failed to delete target reaction from database") + } } func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) { - + log := zerolog.Ctx(ctx) + targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, evt.GetTargetMessage()) + if err != nil { + log.Err(err).Msg("Failed to get target message for removal") + return + } + intent := portal.getIntentFor(ctx, evt.GetSender(), source) + ts := getEventTS(evt) + for _, part := range targetParts { + _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: part.MXID, + }, + }, ts) + if err != nil { + log.Err(err).Stringer("part_mxid", part.MXID).Msg("Failed to redact message part") + } + } + err = portal.Bridge.DB.Message.DeleteAllParts(ctx, evt.GetTargetMessage()) + if err != nil { + log.Err(err).Msg("Failed to delete target message from database") + } } var stateElementFunctionalMembers = event.Type{Class: event.StateEventType, Type: "io.element.functional_members"} From 6466bf9452600d36358aa17ab0240306df2d9d41 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 7 Jun 2024 12:55:01 +0300 Subject: [PATCH 0263/1647] Add support for max reaction count when handling Matrix reactions --- bridgev2/networkinterface.go | 3 +++ bridgev2/portal.go | 26 ++++++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index fd9a7f73..7d58eae8 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -261,6 +261,9 @@ type MatrixReaction struct { MatrixEventBase[*event.ReactionEventContent] TargetMessage *database.Message PreHandleResp *MatrixReactionPreResponse + + // When MaxReactions is >0 in the pre-response, this is the list of previous reactions that should be preserved. + ExistingReactionsToKeep []*database.Reaction } type MatrixReactionPreResponse struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index b3960fbb..4234bb8c 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -498,6 +498,32 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi } react.PreHandleResp = &preResp if preResp.MaxReactions > 0 { + allReactions, err := portal.Bridge.DB.Reaction.GetAllToMessageBySender(ctx, reactionTarget.ID, preResp.SenderID) + if err != nil { + log.Err(err).Msg("Failed to get all reactions to message by sender") + // TODO send metrics + return + } + if len(allReactions) < preResp.MaxReactions { + react.ExistingReactionsToKeep = allReactions + } else { + // Keep n-1 previous reactions and remove the rest + react.ExistingReactionsToKeep = allReactions[:preResp.MaxReactions-1] + for _, oldReaction := range allReactions[preResp.MaxReactions-1:] { + _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: oldReaction.MXID, + }, + }, time.Now()) + if err != nil { + log.Err(err).Msg("Failed to remove previous reaction after limit was exceeded") + } + err = portal.Bridge.DB.Reaction.Delete(ctx, oldReaction) + if err != nil { + log.Err(err).Msg("Failed to delete previous reaction from database after limit was exceeded") + } + } + } // TODO get all reactions to message by sender in order to remove oldest ones // (this is necessary for telegram where reaction limit is 1 or 3 based on premium status) } From 2580ef78d77d26811e8c9401bad774a09c58bf5b Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 10 Jun 2024 09:46:38 +0100 Subject: [PATCH 0264/1647] Add `FullRequest.BackoffDuration` field --- client.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 4fe7ec50..7fb9548b 100644 --- a/client.go +++ b/client.go @@ -337,6 +337,7 @@ type FullRequest struct { RequestLength int64 ResponseJSON interface{} MaxAttempts int + BackoffDuration time.Duration SensitiveContent bool Handler ClientResponseHandler Logger *zerolog.Logger @@ -413,6 +414,9 @@ func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]b if params.MaxAttempts == 0 { params.MaxAttempts = 1 + cli.DefaultHTTPRetries } + if params.BackoffDuration == 0 { + params.BackoffDuration = 4 * time.Second + } if params.Logger == nil { params.Logger = &cli.Log } @@ -430,7 +434,7 @@ func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]b if params.Client == nil { params.Client = cli.Client } - return cli.executeCompiledRequest(req, params.MaxAttempts-1, 4*time.Second, params.ResponseJSON, params.Handler, params.Client) + return cli.executeCompiledRequest(req, params.MaxAttempts-1, params.BackoffDuration, params.ResponseJSON, params.Handler, params.Client) } func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { From 92de4a8a51c6625e54fc4230e6508f197ee86153 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 10 Jun 2024 09:46:48 +0100 Subject: [PATCH 0265/1647] Add `ReqSync.Client` field --- client.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/client.go b/client.go index 7fb9548b..19a49267 100644 --- a/client.go +++ b/client.go @@ -621,6 +621,7 @@ type ReqSync struct { SetPresence event.Presence StreamResponse bool BeeperStreaming bool + Client *http.Client } func (req *ReqSync) BuildQuery() map[string]string { @@ -654,6 +655,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp Method: http.MethodGet, URL: urlPath, ResponseJSON: &resp, + Client: req.Client, // We don't want automatic retries for SyncRequest, the Sync() wrapper handles those. MaxAttempts: 1, } From 9c77bffa4306b4d295e6ec51e3c1d6344cc7c4ee Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 10 Jun 2024 11:49:30 +0100 Subject: [PATCH 0266/1647] Add `Client.DefaultHTTPBackoff` --- client.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 19a49267..c3ac8d91 100644 --- a/client.go +++ b/client.go @@ -76,6 +76,8 @@ type Client struct { // Number of times that mautrix will retry any HTTP request // if the request fails entirely or returns a HTTP gateway error (502-504) DefaultHTTPRetries int + // Amount of time to wait between HTTP retries, defaults to 4 seconds + DefaultHTTPBackoff time.Duration // Set to true to disable automatically sleeping on 429 errors. IgnoreRateLimit bool @@ -415,7 +417,11 @@ func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]b params.MaxAttempts = 1 + cli.DefaultHTTPRetries } if params.BackoffDuration == 0 { - params.BackoffDuration = 4 * time.Second + if cli.DefaultHTTPBackoff == 0 { + params.BackoffDuration = 4 * time.Second + } else { + params.BackoffDuration = cli.DefaultHTTPBackoff + } } if params.Logger == nil { params.Logger = &cli.Log From 9ba40c5d17f303d69ba785ecb645c664b53ccf46 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 10 Jun 2024 15:11:59 +0300 Subject: [PATCH 0267/1647] Add message checkpoints, status events and error notices --- bridgev2/cmdprocessor.go | 26 ++++- bridgev2/matrix/connector.go | 27 ++++- bridgev2/matrix/cryptoerror.go | 93 ++++++++++++++++ bridgev2/matrix/matrix.go | 52 ++++++--- bridgev2/matrixinterface.go | 2 +- bridgev2/messagestatus.go | 188 +++++++++++++++++++++++++++++---- bridgev2/portal.go | 139 +++++++++++++++++------- bridgev2/queue.go | 14 ++- 8 files changed, 454 insertions(+), 87 deletions(-) create mode 100644 bridgev2/matrix/cryptoerror.go diff --git a/bridgev2/cmdprocessor.go b/bridgev2/cmdprocessor.go index 55ad0c9a..1be36b47 100644 --- a/bridgev2/cmdprocessor.go +++ b/bridgev2/cmdprocessor.go @@ -8,11 +8,14 @@ package bridgev2 import ( "context" + "fmt" "runtime/debug" "strings" "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridge/status" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -59,13 +62,32 @@ func (proc *CommandProcessor) AddHandler(handler CommandHandler) { // Handle handles messages to the bridge func (proc *CommandProcessor) Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user *User, message string, replyTo id.EventID) { defer func() { + statusInfo := &MessageStatusEventInfo{ + RoomID: roomID, + EventID: eventID, + EventType: event.EventMessage, + Sender: user.MXID, + } + ms := MessageStatus{ + Step: status.MsgStepCommand, + Status: event.MessageStatusSuccess, + } err := recover() if err != nil { zerolog.Ctx(ctx).Error(). - Str(zerolog.ErrorStackFieldName, string(debug.Stack())). - Interface(zerolog.ErrorFieldName, err). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()). + Any(zerolog.ErrorFieldName, err). Msg("Panic in Matrix command handler") + ms.Status = event.MessageStatusFail + ms.IsCertain = true + if realErr, ok := err.(error); ok { + ms.InternalError = realErr + } else { + ms.InternalError = fmt.Errorf("%v", err) + } + ms.ErrorAsMessage = true } + proc.bridge.Matrix.SendMessageStatus(ctx, &ms, statusInfo) }() args := strings.Fields(message) if len(args) == 0 { diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 4d28d3ee..7e9cc892 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -127,12 +127,35 @@ func (br *Connector) GhostIntent(userID id.UserID) bridgev2.MatrixAPI { } } -func (br *Connector) SendMessageStatus(ctx context.Context, evt bridgev2.MessageStatus) { +func (br *Connector) SendMessageStatus(ctx context.Context, ms *bridgev2.MessageStatus, evt *bridgev2.MessageStatusEventInfo) { + br.internalSendMessageStatus(ctx, ms, evt, "") +} + +func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2.MessageStatus, evt *bridgev2.MessageStatusEventInfo, editEvent id.EventID) id.EventID { log := zerolog.Ctx(ctx) - err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{evt.ToCheckpoint()}) + err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{ms.ToCheckpoint(evt)}) if err != nil { log.Err(err).Msg("Failed to send message checkpoint") } + if br.Config.Bridge.MessageStatusEvents { + _, err = br.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, ms.ToMSSEvent(evt)) + if err != nil { + log.Err(err).Msg("Failed to send MSS event") + } + } + if ms.SendNotice && br.Config.Bridge.MessageErrorNotices && (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) { + content := ms.ToNoticeEvent(evt) + if editEvent != "" { + content.SetEdit(editEvent) + } + resp, err := br.Bot.SendMessageEvent(ctx, evt.RoomID, event.EventMessage, content) + if err != nil { + log.Err(err).Msg("Failed to send notice event") + } else { + return resp.EventID + } + } + return "" } func (br *Connector) SendMessageCheckpoints(checkpoints []*status.MessageCheckpoint) error { diff --git a/bridgev2/matrix/cryptoerror.go b/bridgev2/matrix/cryptoerror.go new file mode 100644 index 00000000..93cf0e75 --- /dev/null +++ b/bridgev2/matrix/cryptoerror.go @@ -0,0 +1,93 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "context" + "errors" + "fmt" + + "maunium.net/go/mautrix/bridge/status" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +var ( + errDeviceNotTrusted = errors.New("your device is not trusted") + errMessageNotEncrypted = errors.New("unencrypted message") + errNoDecryptionKeys = errors.New("the bridge hasn't received the decryption keys") + errNoCrypto = errors.New("this bridge has not been configured to support encryption") +) + +func errorToHumanMessage(err error) string { + var withheld *event.RoomKeyWithheldEventContent + switch { + case errors.Is(err, errDeviceNotTrusted), errors.Is(err, errNoDecryptionKeys), errors.Is(err, errNoCrypto): + return err.Error() + case errors.Is(err, UnknownMessageIndex): + return "the keys received by the bridge can't decrypt the message" + case errors.Is(err, DuplicateMessageIndex): + return "your client encrypted multiple messages with the same key" + case errors.As(err, &withheld): + if withheld.Code == event.RoomKeyWithheldBeeperRedacted { + return "your client used an outdated encryption session" + } + return "your client refused to share decryption keys with the bridge" + case errors.Is(err, errMessageNotEncrypted): + return "the message is not encrypted" + default: + return "the bridge failed to decrypt the message" + } +} + +func deviceUnverifiedErrorWithExplanation(trust id.TrustState) error { + var explanation string + switch trust { + case id.TrustStateBlacklisted: + explanation = "device is blacklisted" + case id.TrustStateUnset: + explanation = "unverified" + case id.TrustStateUnknownDevice: + explanation = "device info not found" + case id.TrustStateForwarded: + explanation = "keys were forwarded from an unknown device" + case id.TrustStateCrossSignedUntrusted: + explanation = "cross-signing keys changed after setting up the bridge" + default: + return errDeviceNotTrusted + } + return fmt.Errorf("%w (%s)", errDeviceNotTrusted, explanation) +} + +func (br *Connector) sendCryptoStatusError(ctx context.Context, evt *event.Event, err error, errorEventID *id.EventID, retryNum int, isFinal bool) { + ms := &bridgev2.MessageStatus{ + Step: status.MsgStepDecrypted, + Status: event.MessageStatusRetriable, + ErrorReason: event.MessageStatusUndecryptable, + InternalError: err, + Message: errorToHumanMessage(err), + IsCertain: true, + SendNotice: true, + RetryNum: retryNum, + } + if !isFinal { + ms.Status = event.MessageStatusPending + // Don't send notice for first error + if retryNum == 0 { + ms.SendNotice = false + } + } + var editEventID id.EventID + if errorEventID != nil { + editEventID = *errorEventID + } + respEventID := br.internalSendMessageStatus(ctx, ms, bridgev2.StatusEventInfoFromEvent(evt), editEventID) + if errorEventID != nil && *errorEventID == "" { + *errorEventID = respEventID + } +} diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index 7dfeacf0..bd2a1bb9 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -9,10 +9,12 @@ package matrix import ( "context" "errors" + "fmt" "time" "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -24,7 +26,7 @@ func (br *Connector) handleRoomEvent(ctx context.Context, evt *event.Event) { } if (evt.Type == event.EventMessage || evt.Type == event.EventSticker) && !evt.Mautrix.WasEncrypted && br.Config.Encryption.Require { zerolog.Ctx(ctx).Warn().Msg("Dropping unencrypted event as encryption is configured to be required") - // TODO send metrics + br.sendCryptoStatusError(ctx, evt, errMessageNotEncrypted, nil, 0, true) return } br.Bridge.QueueMatrixEvent(ctx, evt) @@ -48,7 +50,7 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) Logger() ctx = log.WithContext(ctx) if br.Crypto == nil { - // TODO send metrics + br.sendCryptoStatusError(ctx, evt, errNoCrypto, nil, 0, true) log.Error().Msg("Can't decrypt message: no crypto") return } @@ -62,7 +64,7 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) log.Debug(). Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). Msg("Couldn't find session, waiting for keys to arrive...") - // TODO send metrics + go br.sendCryptoStatusError(ctx, evt, err, nil, 0, false) if br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { log.Debug().Msg("Got keys after waiting, trying to decrypt event again") decrypted, err = br.Crypto.Decrypt(ctx, evt) @@ -73,10 +75,10 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) } if err != nil { log.Warn().Err(err).Msg("Failed to decrypt event") - // TODO send metrics + go br.sendCryptoStatusError(ctx, evt, err, nil, decryptionRetryCount, true) return } - br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, "", time.Since(decryptionStart)) + br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, nil, time.Since(decryptionStart)) } func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time) { @@ -87,12 +89,12 @@ func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, Msg("Couldn't find session, requesting keys and waiting longer...") go br.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) - var errorEventID id.EventID - // TODO send metrics + var errorEventID *id.EventID + go br.sendCryptoStatusError(ctx, evt, fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), errorEventID, 1, false) if !br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { log.Debug().Msg("Didn't get session, giving up trying to decrypt event") - // TODO send metrics + go br.sendCryptoStatusError(ctx, evt, errNoDecryptionKeys, errorEventID, 2, true) return } @@ -100,7 +102,7 @@ func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decrypted, err := br.Crypto.Decrypt(ctx, evt) if err != nil { log.Error().Err(err).Msg("Failed to decrypt event") - // TODO send metrics + go br.sendCryptoStatusError(ctx, evt, err, errorEventID, 2, true) return } @@ -111,9 +113,25 @@ type CommandProcessor interface { Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user bridgev2.User, message string, replyTo id.EventID) } -func (br *Connector) sendBridgeCheckpoint(_ context.Context, evt *event.Event) { +func (br *Connector) sendSuccessCheckpoint(ctx context.Context, evt *event.Event, step status.MessageCheckpointStep, retryNum int) { + err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{{ + RoomID: evt.RoomID, + EventID: evt.ID, + EventType: evt.Type, + MessageType: evt.Content.AsMessage().MsgType, + Step: step, + Status: status.MsgStatusSuccess, + ReportedBy: status.MsgReportedByBridge, + RetryNum: retryNum, + }}) + if err != nil { + zerolog.Ctx(ctx).Err(err).Str("checkpoint_step", string(step)).Msg("Failed to send checkpoint") + } +} + +func (br *Connector) sendBridgeCheckpoint(ctx context.Context, evt *event.Event) { if !evt.Mautrix.CheckpointSent { - //go br.SendMessageSuccessCheckpoint(evt, status.MsgStepBridge, 0) + go br.sendSuccessCheckpoint(ctx, evt, status.MsgStepBridge, 0) } } @@ -140,7 +158,7 @@ func copySomeKeys(original, decrypted *event.Event) { } } -func (br *Connector) postDecrypt(ctx context.Context, original, decrypted *event.Event, retryCount int, errorEventID id.EventID, duration time.Duration) { +func (br *Connector) postDecrypt(ctx context.Context, original, decrypted *event.Event, retryCount int, errorEventID *id.EventID, duration time.Duration) { log := zerolog.Ctx(ctx) minLevel := br.Config.Encryption.VerificationLevels.Send if decrypted.Mautrix.TrustState < minLevel { @@ -158,18 +176,18 @@ func (br *Connector) postDecrypt(ctx context.Context, original, decrypted *event logEvt.Str("device_id", "unknown") } logEvt.Msg("Dropping event due to insufficient verification level") - //err := deviceUnverifiedErrorWithExplanation(decrypted.Mautrix.TrustState) - //go mx.sendCryptoStatusError(ctx, decrypted, errorEventID, err, retryCount, true) + err := deviceUnverifiedErrorWithExplanation(decrypted.Mautrix.TrustState) + go br.sendCryptoStatusError(ctx, decrypted, err, errorEventID, retryCount, true) return } copySomeKeys(original, decrypted) - // TODO checkpoint + go br.sendSuccessCheckpoint(ctx, decrypted, status.MsgStepDecrypted, retryCount) decrypted.Mautrix.CheckpointSent = true decrypted.Mautrix.DecryptionDuration = duration decrypted.Mautrix.EventSource |= event.SourceDecrypted br.EventProcessor.Dispatch(ctx, decrypted) - if errorEventID != "" { - _, _ = br.Bot.RedactEvent(ctx, decrypted.RoomID, errorEventID) + if errorEventID != nil && *errorEventID != "" { + _, _ = br.Bot.RedactEvent(ctx, decrypted.RoomID, *errorEventID) } } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 4dd6dd46..b791a028 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -27,7 +27,7 @@ type MatrixConnector interface { UserIntent(user *User) MatrixAPI BotIntent() MatrixAPI - SendMessageStatus(ctx context.Context, status MessageStatus) + SendMessageStatus(ctx context.Context, status *MessageStatus, evt *MessageStatusEventInfo) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) GetMemberInfo(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index c7d8d0b4..43fe5da0 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -7,22 +7,122 @@ package bridgev2 import ( + "errors" + "fmt" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) -type MessageStatus struct { +var ( + ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) + ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) +) + +type MessageStatusEventInfo struct { RoomID id.RoomID EventID id.EventID - Status event.MessageStatus - ErrorReason event.MessageStatusReason - DeliveredTo []id.UserID - Error error // Internal error to be tracked in message checkpoints - Message string // Human-readable message shown to users + EventType event.Type + MessageType event.MessageType + Sender id.UserID + ThreadRoot id.EventID } -func (ms *MessageStatus) CheckpointStatus() status.MessageCheckpointStatus { +func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo { + var threadRoot id.EventID + if relatable, ok := evt.Content.Parsed.(event.Relatable); ok { + threadRoot = relatable.OptionalGetRelatesTo().GetThreadParent() + } + return &MessageStatusEventInfo{ + RoomID: evt.RoomID, + EventID: evt.ID, + EventType: evt.Type, + MessageType: evt.Content.AsMessage().MsgType, + Sender: evt.Sender, + ThreadRoot: threadRoot, + } +} + +type MessageStatus struct { + Step status.MessageCheckpointStep + RetryNum int + + Status event.MessageStatus + ErrorReason event.MessageStatusReason + DeliveredTo []id.UserID + InternalError error // Internal error to be tracked in message checkpoints + Message string // Human-readable message shown to users + + ErrorAsMessage bool + IsCertain bool + SendNotice bool +} + +func WrapErrorInStatus(err error) MessageStatus { + var alreadyWrapped MessageStatus + var ok bool + if alreadyWrapped, ok = err.(MessageStatus); ok { + return alreadyWrapped + } else if errors.As(err, &alreadyWrapped) { + alreadyWrapped.InternalError = err + return alreadyWrapped + } + return MessageStatus{ + Status: event.MessageStatusRetriable, + ErrorReason: event.MessageStatusGenericError, + InternalError: err, + } +} + +func (ms MessageStatus) WithSendNotice(send bool) MessageStatus { + ms.SendNotice = send + return ms +} + +func (ms MessageStatus) WithIsCertain(certain bool) MessageStatus { + ms.IsCertain = certain + return ms +} + +func (ms MessageStatus) WithMessage(msg string) MessageStatus { + ms.Message = msg + ms.ErrorAsMessage = false + return ms +} + +func (ms MessageStatus) WithStep(step status.MessageCheckpointStep) MessageStatus { + ms.Step = step + return ms +} + +func (ms MessageStatus) WithStatus(status event.MessageStatus) MessageStatus { + ms.Status = status + return ms +} + +func (ms MessageStatus) WithErrorReason(reason event.MessageStatusReason) MessageStatus { + ms.ErrorReason = reason + return ms +} + +func (ms MessageStatus) WithErrorAsMessage() MessageStatus { + ms.ErrorAsMessage = true + return ms +} + +func (ms MessageStatus) Error() string { + return ms.InternalError.Error() +} + +func (ms MessageStatus) Unwrap() error { + return ms.InternalError +} + +func (ms *MessageStatus) checkpointStatus() status.MessageCheckpointStatus { switch ms.Status { case event.MessageStatusSuccess: if ms.DeliveredTo != nil { @@ -45,34 +145,44 @@ func (ms *MessageStatus) CheckpointStatus() status.MessageCheckpointStatus { } } -func (ms *MessageStatus) ToCheckpoint() *status.MessageCheckpoint { - checkpoint := &status.MessageCheckpoint{ - RoomID: ms.RoomID, - EventID: ms.EventID, - Step: status.MsgStepRemote, - Status: ms.CheckpointStatus(), - ReportedBy: status.MsgReportedByBridge, +func (ms *MessageStatus) ToCheckpoint(evt *MessageStatusEventInfo) *status.MessageCheckpoint { + step := status.MsgStepRemote + if ms.Step != "" { + step = ms.Step } - if ms.Error != nil { - checkpoint.Info = ms.Error.Error() + checkpoint := &status.MessageCheckpoint{ + RoomID: evt.RoomID, + EventID: evt.EventID, + Step: step, + Status: ms.checkpointStatus(), + RetryNum: ms.RetryNum, + ReportedBy: status.MsgReportedByBridge, + EventType: evt.EventType, + MessageType: evt.MessageType, + } + if ms.InternalError != nil { + checkpoint.Info = ms.InternalError.Error() } else if ms.Message != "" { checkpoint.Info = ms.Message } return checkpoint } -func (ms *MessageStatus) ToEvent() *event.BeeperMessageStatusEventContent { +func (ms *MessageStatus) ToMSSEvent(evt *MessageStatusEventInfo) *event.BeeperMessageStatusEventContent { content := &event.BeeperMessageStatusEventContent{ RelatesTo: event.RelatesTo{ Type: event.RelAnnotation, - EventID: ms.EventID, + EventID: evt.EventID, }, Status: ms.Status, Reason: ms.ErrorReason, Message: ms.Message, } - if ms.Error != nil { - content.InternalError = ms.Error.Error() + if ms.InternalError != nil { + content.InternalError = ms.InternalError.Error() + if ms.ErrorAsMessage { + content.Message = content.InternalError + } } if ms.DeliveredTo != nil { content.DeliveredToUsers = &ms.DeliveredTo @@ -80,7 +190,39 @@ func (ms *MessageStatus) ToEvent() *event.BeeperMessageStatusEventContent { return content } -func (ms *MessageStatus) ErrorAsMessage() *MessageStatus { - ms.Message = ms.Error.Error() - return ms +func (ms *MessageStatus) ToNoticeEvent(evt *MessageStatusEventInfo) *event.MessageEventContent { + certainty := "may not have been" + if ms.IsCertain { + certainty = "was not" + } + evtType := "message" + switch evt.EventType { + case event.EventReaction: + evtType = "reaction" + case event.EventRedaction: + evtType = "redaction" + } + msg := ms.Message + if ms.ErrorAsMessage { + msg = ms.InternalError.Error() + } + messagePrefix := fmt.Sprintf("Your %s %s bridged", evtType, certainty) + if ms.Step == status.MsgStepCommand { + messagePrefix = "Handling your command panicked" + } + content := &event.MessageEventContent{ + MsgType: event.MsgText, + Body: fmt.Sprintf("\u26a0\ufe0f %s: %s", messagePrefix, msg), + RelatesTo: &event.RelatesTo{}, + Mentions: &event.Mentions{}, + } + if evt.ThreadRoot != "" { + content.RelatesTo.SetThread(evt.ThreadRoot, evt.EventID) + } else { + content.RelatesTo.SetReplyTo(evt.EventID) + } + if evt.Sender != "" { + content.Mentions.UserIDs = []id.UserID{evt.Sender} + } + return content } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 4234bb8c..ccb4aeda 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -198,6 +198,24 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User) (*User } } +func (portal *Portal) sendSuccessStatus(ctx context.Context, evt *event.Event) { + portal.Bridge.Matrix.SendMessageStatus(ctx, &MessageStatus{Status: event.MessageStatusSuccess}, StatusEventInfoFromEvent(evt)) +} + +func (portal *Portal) sendErrorStatus(ctx context.Context, evt *event.Event, err error) { + status := WrapErrorInStatus(err) + if status.Status == "" { + status.Status = event.MessageStatusRetriable + } + if status.ErrorReason == "" { + status.ErrorReason = event.MessageStatusGenericError + } + if status.InternalError == nil { + status.InternalError = err + } + portal.Bridge.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) +} + func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { if evt.Mautrix.EventSource&event.SourceEphemeral != 0 { switch evt.Type { @@ -217,7 +235,7 @@ func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { login, err := portal.FindPreferredLogin(ctx, sender) if err != nil { log.Err(err).Msg("Failed to get user login to handle Matrix event") - // TODO send metrics + portal.sendErrorStatus(ctx, evt, WrapErrorInStatus(err).WithMessage("Failed to get login to handle event").WithIsCertain(true).WithSendNotice(true)) return } var origSender *OrigSender @@ -243,7 +261,7 @@ func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { case event.EventReaction: if origSender != nil { log.Debug().Msg("Ignoring reaction event from relayed user") - // TODO send metrics + portal.sendErrorStatus(ctx, evt, ErrIgnoringReactionFromRelayedUser) return } portal.handleMatrixReaction(ctx, login, evt) @@ -321,7 +339,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin content, ok := evt.Content.Parsed.(*event.MessageEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - // TODO send metrics + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) return } if content.RelatesTo.GetReplaceID() != "" { @@ -370,7 +388,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin }) if err != nil { log.Err(err).Msg("Failed to handle Matrix message") - // TODO send metrics here or inside HandleMatrixMessage? + portal.sendErrorStatus(ctx, evt, err) return } if message.MXID == "" { @@ -393,7 +411,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin if err != nil { log.Err(err).Msg("Failed to save message to database") } - // TODO send success metrics + portal.sendSuccessStatus(ctx, evt) } func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent) { @@ -408,11 +426,11 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o editTarget, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, editTargetID) if err != nil { log.Err(err).Msg("Failed to get edit target message from database") - // TODO send metrics + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get edit target: %w", ErrDatabaseError, err)) return } else if editTarget == nil { log.Warn().Msg("Edit target message not found in database") - // TODO send metrics + portal.sendErrorStatus(ctx, evt, fmt.Errorf("edit %w", ErrTargetMessageNotFound)) return } log.UpdateContext(func(c zerolog.Context) zerolog.Context { @@ -429,14 +447,14 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o }) if err != nil { log.Err(err).Msg("Failed to handle Matrix edit") - // TODO send metrics here or inside HandleMatrixEdit? + portal.sendErrorStatus(ctx, evt, err) return } err = portal.Bridge.DB.Message.Update(ctx, editTarget) if err != nil { log.Err(err).Msg("Failed to save message to database after editing") } - // TODO send success metrics + portal.sendSuccessStatus(ctx, evt) } func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) { @@ -444,7 +462,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi content, ok := evt.Content.Parsed.(*event.ReactionEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - // TODO send metrics + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) return } log.UpdateContext(func(c zerolog.Context) zerolog.Context { @@ -453,11 +471,11 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi reactionTarget, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, content.RelatesTo.EventID) if err != nil { log.Err(err).Msg("Failed to get reaction target message from database") - // TODO send metrics + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get reaction target: %w", ErrDatabaseError, err)) return } else if reactionTarget == nil { log.Warn().Msg("Reaction target message not found in database") - // TODO send metrics + portal.sendErrorStatus(ctx, evt, fmt.Errorf("reaction %w", ErrTargetMessageNotFound)) return } log.UpdateContext(func(c zerolog.Context) zerolog.Context { @@ -474,7 +492,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi preResp, err := sender.Client.PreHandleMatrixReaction(ctx, react) if err != nil { log.Err(err).Msg("Failed to pre-handle Matrix reaction") - // TODO send metrics + portal.sendErrorStatus(ctx, evt, err) return } existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID) @@ -484,7 +502,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi } else if existing != nil { if existing.EmojiID != "" || existing.Metadata["emoji"] == preResp.Emoji { log.Debug().Msg("Ignoring duplicate reaction") - // TODO send metrics + portal.sendSuccessStatus(ctx, evt) return } _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ @@ -501,7 +519,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi allReactions, err := portal.Bridge.DB.Reaction.GetAllToMessageBySender(ctx, reactionTarget.ID, preResp.SenderID) if err != nil { log.Err(err).Msg("Failed to get all reactions to message by sender") - // TODO send metrics + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get previous reactions: %w", ErrDatabaseError, err)) return } if len(allReactions) < preResp.MaxReactions { @@ -524,13 +542,11 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi } } } - // TODO get all reactions to message by sender in order to remove oldest ones - // (this is necessary for telegram where reaction limit is 1 or 3 based on premium status) } dbReaction, err := sender.Client.HandleMatrixReaction(ctx, react) if err != nil { log.Err(err).Msg("Failed to handle Matrix reaction") - // TODO send metrics here or inside HandleMatrixReaction? + portal.sendErrorStatus(ctx, evt, err) return } // Fill all fields that are known to allow omitting them in connector code @@ -564,7 +580,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if err != nil { log.Err(err).Msg("Failed to save reaction to database") } - // TODO send success metrics + portal.sendSuccessStatus(ctx, evt) } func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { @@ -572,7 +588,7 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog content, ok := evt.Content.Parsed.(*event.RedactionEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - // TODO send metrics + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) return } if evt.Redacts != "" && content.Redacts != evt.Redacts { @@ -584,16 +600,9 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog redactionTargetMsg, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, content.Redacts) if err != nil { log.Err(err).Msg("Failed to get redaction target message from database") - // TODO send metrics + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get redaction target message: %w", ErrDatabaseError, err)) return - } - redactionTargetReaction, err := portal.Bridge.DB.Reaction.GetByMXID(ctx, content.Redacts) - if err != nil { - log.Err(err).Msg("Failed to get redaction target reaction from database") - // TODO send metrics - return - } - if redactionTargetMsg != nil { + } else if redactionTargetMsg != nil { err = sender.Client.HandleMatrixMessageRemove(ctx, &MatrixMessageRemove{ MatrixEventBase: MatrixEventBase[*event.RedactionEventContent]{ Event: evt, @@ -603,6 +612,10 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog }, TargetMessage: redactionTargetMsg, }) + } else if redactionTargetReaction, err := portal.Bridge.DB.Reaction.GetByMXID(ctx, content.Redacts); err != nil { + log.Err(err).Msg("Failed to get redaction target reaction from database") + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get redaction target message reaction: %w", ErrDatabaseError, err)) + return } else if redactionTargetReaction != nil { err = sender.Client.HandleMatrixReactionRemove(ctx, &MatrixReactionRemove{ MatrixEventBase: MatrixEventBase[*event.RedactionEventContent]{ @@ -615,16 +628,16 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog }) } else { log.Debug().Msg("Redaction target message not found in database") - // TODO send metrics + portal.sendErrorStatus(ctx, evt, fmt.Errorf("redaction %w", ErrTargetMessageNotFound)) return } if err != nil { log.Err(err).Msg("Failed to handle Matrix redaction") - // TODO send metrics here or inside HandleMatrixMessageRemove and HandleMatrixReactionRemove? + portal.sendErrorStatus(ctx, evt, err) return } // TODO delete msg/reaction db row - // TODO send success metrics + portal.sendSuccessStatus(ctx, evt) } func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { @@ -646,6 +659,8 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { } } switch evt.GetType() { + case RemoteEventUnknown: + log.Debug().Msg("Ignoring remote event with type unknown") case RemoteEventMessage: portal.handleRemoteMessage(ctx, source, evt.(RemoteMessage)) case RemoteEventEdit: @@ -656,6 +671,8 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { portal.handleRemoteReactionRemove(ctx, source, evt.(RemoteReactionRemove)) case RemoteEventMessageRemove: portal.handleRemoteMessageRemove(ctx, source, evt.(RemoteMessageRemove)) + default: + log.Warn().Int("type", int(evt.GetType())).Msg("Got remote event with unknown type") } } @@ -695,9 +712,11 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin if intent == nil { return } + ts := getEventTS(evt) converted, err := evt.ConvertMessage(ctx, portal, intent) if err != nil { - // TODO log and notify room? + log.Err(err).Msg("Failed to convert remote message") + portal.sendRemoteErrorNotice(ctx, intent, err, ts, "message") return } var relatesToRowID int64 @@ -723,7 +742,6 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin // TODO 2 fetch last event in thread properly prevThreadEvent = threadRoot } - ts := getEventTS(evt) for _, part := range converted.Parts { if threadRoot != nil && prevThreadEvent != nil { part.Content.GetRelatesTo().SetThread(threadRoot.MXID, prevThreadEvent.MXID) @@ -746,6 +764,10 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin log.Err(err).Str("part_id", string(part.ID)).Msg("Failed to send message part to Matrix") continue } + log.Debug(). + Stringer("event_id", resp.EventID). + Str("part_id", string(part.ID)). + Msg("Sent message part to Matrix") if part.DBMetadata == nil { part.DBMetadata = make(map[string]any) } @@ -771,6 +793,24 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin } } +func (portal *Portal) sendRemoteErrorNotice(ctx context.Context, intent MatrixAPI, err error, ts time.Time, evtTypeName string) { + resp, sendErr := intent.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{ + Parsed: &event.MessageEventContent{ + MsgType: event.MsgNotice, + Body: fmt.Sprintf("An error occurred while processing an incoming %s", evtTypeName), + Mentions: &event.Mentions{}, + }, + Raw: map[string]any{ + "fi.mau.bridge.internal_error": err.Error(), + }, + }, ts) + if sendErr != nil { + zerolog.Ctx(ctx).Err(sendErr).Msg("Failed to send error notice after remote event handling failed") + } else { + zerolog.Ctx(ctx).Debug().Stringer("event_id", resp.EventID).Msg("Sent error notice after remote event handling failed") + } +} + func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) { log := zerolog.Ctx(ctx) existing, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, evt.GetTargetMessage()) @@ -785,12 +825,13 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e if intent == nil { return } + ts := getEventTS(evt) converted, err := evt.ConvertEdit(ctx, portal, intent, existing) if err != nil { - // TODO log and notify room? + log.Err(err).Msg("Failed to convert remote edit") + portal.sendRemoteErrorNotice(ctx, intent, err, ts, "edit") return } - ts := getEventTS(evt) for _, part := range converted.ModifiedParts { part.Content.SetEdit(part.Part.MXID) if part.TopLevelExtra == nil { @@ -803,9 +844,14 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e Parsed: part.Content, Raw: part.TopLevelExtra, } - _, err = intent.SendMessage(ctx, portal.MXID, part.Type, wrappedContent, ts) + resp, err := intent.SendMessage(ctx, portal.MXID, part.Type, wrappedContent, ts) if err != nil { log.Err(err).Stringer("part_mxid", part.Part.MXID).Msg("Failed to edit message part") + } else { + log.Debug(). + Stringer("event_id", resp.EventID). + Str("part_id", string(part.Part.ID)). + Msg("Sent message part edit to Matrix") } err = portal.Bridge.DB.Message.Update(ctx, part.Part) if err != nil { @@ -818,9 +864,15 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e Redacts: part.MXID, }, } - _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, redactContent, ts) + resp, err := intent.SendMessage(ctx, portal.MXID, event.EventRedaction, redactContent, ts) if err != nil { log.Err(err).Stringer("part_mxid", part.MXID).Msg("Failed to redact message part deleted in edit") + } else { + log.Debug(). + Stringer("redaction_event_id", resp.EventID). + Stringer("redacted_event_id", part.MXID). + Str("part_id", string(part.ID)). + Msg("Sent redaction of message part to Matrix") } err = portal.Bridge.DB.Message.Delete(ctx, part.RowID) if err != nil { @@ -883,6 +935,9 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi log.Err(err).Msg("Failed to send reaction to Matrix") return } + log.Debug(). + Stringer("event_id", resp.EventID). + Msg("Sent reaction to Matrix") dbReaction := &database.Reaction{ RoomID: portal.ID, MessageID: targetMessage.ID, @@ -951,13 +1006,19 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use intent := portal.getIntentFor(ctx, evt.GetSender(), source) ts := getEventTS(evt) for _, part := range targetParts { - _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + resp, err := intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ Redacts: part.MXID, }, }, ts) if err != nil { log.Err(err).Stringer("part_mxid", part.MXID).Msg("Failed to redact message part") + } else { + log.Debug(). + Stringer("redaction_event_id", resp.EventID). + Stringer("redacted_event_id", part.MXID). + Str("part_id", string(part.ID)). + Msg("Sent redaction of message part to Matrix") } } err = portal.Bridge.DB.Message.DeleteAllParts(ctx, evt.GetTargetMessage()) diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 4e3a6549..aa246267 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -8,6 +8,8 @@ package bridgev2 import ( "context" + "errors" + "fmt" "strings" "github.com/rs/zerolog" @@ -25,13 +27,15 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { sender, err = br.GetUserByMXID(ctx, evt.Sender) if err != nil { log.Err(err).Msg("Failed to get sender user for incoming Matrix event") - // TODO send metrics + status := WrapErrorInStatus(fmt.Errorf("%w: failed to get sender user: %w", ErrDatabaseError, err)) + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) return } } if sender == nil && evt.Type.Class != event.EphemeralEventType { log.Error().Msg("Missing sender for incoming non-ephemeral Matrix event") - // TODO send metrics + status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) return } if evt.Type == event.EventMessage { @@ -64,13 +68,17 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { portal, err := br.GetPortalByMXID(ctx, evt.RoomID) if err != nil { log.Err(err).Msg("Failed to get portal for incoming Matrix event") - // TODO send metrics + status := WrapErrorInStatus(fmt.Errorf("%w: failed to get portal: %w", ErrDatabaseError, err)) + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) return } else if portal != nil { portal.queueEvent(ctx, &portalMatrixEvent{ evt: evt, sender: sender, }) + } else { + status := WrapErrorInStatus(ErrNoPortal) + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) } } From 39ce0103d40bdc68871534559ff90f7d1a0fe0d5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 10 Jun 2024 21:41:14 +0300 Subject: [PATCH 0268/1647] Add DM receivers to portals --- bridgev2/bridge.go | 4 +- bridgev2/database/message.go | 22 +++--- bridgev2/database/portal.go | 35 ++++++---- bridgev2/database/reaction.go | 12 ++-- bridgev2/database/upgrades/00-latest.sql | 88 +++++++++++++----------- bridgev2/database/user.go | 6 +- bridgev2/database/userlogin.go | 28 ++++---- bridgev2/networkid/bridgeid.go | 36 ++++++++-- bridgev2/networkinterface.go | 8 +-- bridgev2/portal.go | 70 ++++++++++++------- bridgev2/queue.go | 9 ++- bridgev2/userlogin.go | 6 +- 12 files changed, 193 insertions(+), 131 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 1550c72a..fe0a725a 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -41,7 +41,7 @@ type Bridge struct { usersByMXID map[id.UserID]*User userLoginsByID map[networkid.UserLoginID]*UserLogin - portalsByID map[networkid.PortalID]*Portal + portalsByKey map[networkid.PortalKey]*Portal portalsByMXID map[id.RoomID]*Portal ghostsByID map[networkid.UserID]*Ghost cacheLock sync.Mutex @@ -58,7 +58,7 @@ func NewBridge(bridgeID networkid.BridgeID, db *dbutil.Database, log zerolog.Log usersByMXID: make(map[id.UserID]*User), userLoginsByID: make(map[networkid.UserLoginID]*UserLogin), - portalsByID: make(map[networkid.PortalID]*Portal), + portalsByKey: make(map[networkid.PortalKey]*Portal), portalsByMXID: make(map[id.RoomID]*Portal), ghostsByID: make(map[networkid.UserID]*Ghost), } diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index b38d193e..c1261be7 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -29,7 +29,7 @@ type Message struct { PartID networkid.PartID MXID id.EventID - RoomID networkid.PortalID + Room networkid.PortalKey SenderID networkid.UserID Timestamp time.Time @@ -44,7 +44,7 @@ func newMessage(_ *dbutil.QueryHelper[*Message]) *Message { const ( getMessageBaseQuery = ` - SELECT rowid, bridge_id, id, part_id, mxid, room_id, sender_id, timestamp, relates_to, metadata FROM message + SELECT rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, relates_to, metadata FROM message ` getAllMessagePartsByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND id=$2` getMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND part_id=$3` @@ -52,15 +52,15 @@ const ( getMessageByMXIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` getLastMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND id=$2 ORDER BY part_id DESC LIMIT 1` getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND id=$2 ORDER BY part_id ASC LIMIT 1` - getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND timestamp>$3 AND timestamp<=$4` + getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5` insertMessageQuery = ` - INSERT INTO message (bridge_id, id, part_id, mxid, room_id, sender_id, timestamp, relates_to, metadata) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + INSERT INTO message (bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, relates_to, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING rowid ` updateMessageQuery = ` - UPDATE message SET id=$2, part_id=$3, mxid=$4, room_id=$5, sender_id=$6, timestamp=$7, relates_to=$8, metadata=$9 - WHERE bridge_id=$1 AND rowid=$10 + UPDATE message SET id=$2, part_id=$3, mxid=$4, room_id=$5, room_receiver=$6, sender_id=$7, timestamp=$8, relates_to=$9, metadata=$10 + WHERE bridge_id=$1 AND rowid=$11 ` deleteAllMessagePartsByIDQuery = ` DELETE FROM message WHERE bridge_id=$1 AND id=$2 @@ -102,8 +102,8 @@ func (mq *MessageQuery) GetFirstOrSpecificPartByID(ctx context.Context, id netwo } } -func (mq *MessageQuery) GetMessagesBetweenTimeQuery(ctx context.Context, start, end time.Time) ([]*Message, error) { - return mq.QueryMany(ctx, getMessagesBetweenTimeQuery, mq.BridgeID, start.UnixNano(), end.UnixNano()) +func (mq *MessageQuery) GetMessagesBetweenTimeQuery(ctx context.Context, portal networkid.PortalKey, start, end time.Time) ([]*Message, error) { + return mq.QueryMany(ctx, getMessagesBetweenTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, start.UnixNano(), end.UnixNano()) } func (mq *MessageQuery) Insert(ctx context.Context, msg *Message) error { @@ -128,7 +128,7 @@ func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { var timestamp int64 var relatesTo sql.NullInt64 err := row.Scan( - &m.RowID, &m.BridgeID, &m.ID, &m.PartID, &m.MXID, &m.RoomID, &m.SenderID, + &m.RowID, &m.BridgeID, &m.ID, &m.PartID, &m.MXID, &m.Room.ID, &m.Room.Receiver, &m.SenderID, ×tamp, &relatesTo, dbutil.JSON{Data: &m.Metadata}, ) if err != nil { @@ -141,7 +141,7 @@ func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { func (m *Message) sqlVariables() []any { return []any{ - m.BridgeID, m.ID, m.PartID, m.MXID, m.RoomID, m.SenderID, + m.BridgeID, m.ID, m.PartID, m.MXID, m.Room.ID, m.Room.Receiver, m.SenderID, m.Timestamp.UnixNano(), dbutil.NumPtr(m.RelatesToRowID), dbutil.JSON{Data: m.Metadata}, } } diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 36a38dfe..2456b9ca 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -24,8 +24,8 @@ type PortalQuery struct { type Portal struct { BridgeID networkid.BridgeID - ID networkid.PortalID - MXID id.RoomID + networkid.PortalKey + MXID id.RoomID ParentID networkid.PortalID Name string @@ -46,34 +46,39 @@ func newPortal(_ *dbutil.QueryHelper[*Portal]) *Portal { const ( getPortalBaseQuery = ` - SELECT bridge_id, id, mxid, parent_id, name, topic, avatar_id, avatar_hash, avatar_mxc, + SELECT bridge_id, id, receiver, mxid, parent_id, name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, topic_set, avatar_set, in_space, metadata FROM portal ` - getPortalByIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2` - getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` - getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2` + getPortalByIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND receiver=$3` + getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')` + getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` + getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2` insertPortalQuery = ` INSERT INTO portal ( - bridge_id, id, mxid, + bridge_id, id, receiver, mxid, parent_id, name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, avatar_set, topic_set, in_space, metadata - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) ` updatePortalQuery = ` UPDATE portal - SET mxid=$3, parent_id=$4, name=$5, topic=$6, avatar_id=$7, avatar_hash=$8, avatar_mxc=$9, - name_set=$10, avatar_set=$11, topic_set=$12, in_space=$13, metadata=$14 - WHERE bridge_id=$1 AND id=$2 + SET mxid=$4, parent_id=$5, name=$6, topic=$7, avatar_id=$8, avatar_hash=$9, avatar_mxc=$10, + name_set=$11, avatar_set=$12, topic_set=$13, in_space=$14, metadata=$15 + WHERE bridge_id=$1 AND id=$2 AND receiver=$3 ` reIDPortalQuery = `UPDATE portal SET id=$3 WHERE bridge_id=$1 AND id=$2` ) -func (pq *PortalQuery) GetByID(ctx context.Context, id networkid.PortalID) (*Portal, error) { - return pq.QueryOne(ctx, getPortalByIDQuery, pq.BridgeID, id) +func (pq *PortalQuery) GetByID(ctx context.Context, key networkid.PortalKey) (*Portal, error) { + return pq.QueryOne(ctx, getPortalByIDQuery, pq.BridgeID, key.ID, key.Receiver) +} + +func (pq *PortalQuery) GetByIDWithUncertainReceiver(ctx context.Context, key networkid.PortalKey) (*Portal, error) { + return pq.QueryOne(ctx, getPortalByIDWithUncertainReceiverQuery, pq.BridgeID, key.ID, key.Receiver) } func (pq *PortalQuery) GetByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) { @@ -102,7 +107,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { var mxid, parentID sql.NullString var avatarHash string err := row.Scan( - &p.BridgeID, &p.ID, &mxid, + &p.BridgeID, &p.ID, &p.Receiver, &mxid, &parentID, &p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC, &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.InSpace, dbutil.JSON{Data: &p.Metadata}, @@ -133,7 +138,7 @@ func (p *Portal) sqlVariables() []any { avatarHash = hex.EncodeToString(p.AvatarHash[:]) } return []any{ - p.BridgeID, p.ID, dbutil.StrPtr(p.MXID), + p.BridgeID, p.ID, p.Receiver, dbutil.StrPtr(p.MXID), dbutil.StrPtr(p.ParentID), p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC, p.NameSet, p.TopicSet, p.AvatarSet, p.InSpace, dbutil.JSON{Data: p.Metadata}, diff --git a/bridgev2/database/reaction.go b/bridgev2/database/reaction.go index 4abec731..5b01459b 100644 --- a/bridgev2/database/reaction.go +++ b/bridgev2/database/reaction.go @@ -23,7 +23,7 @@ type ReactionQuery struct { type Reaction struct { BridgeID networkid.BridgeID - RoomID networkid.PortalID + Room networkid.PortalKey MessageID networkid.MessageID MessagePartID networkid.PartID SenderID networkid.UserID @@ -40,7 +40,7 @@ func newReaction(_ *dbutil.QueryHelper[*Reaction]) *Reaction { const ( getReactionBaseQuery = ` - SELECT bridge_id, message_id, message_part_id, sender_id, emoji_id, room_id, mxid, timestamp, metadata FROM reaction + SELECT bridge_id, message_id, message_part_id, sender_id, emoji_id, room_id, room_receiver, mxid, timestamp, metadata FROM reaction ` getReactionByIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3 AND sender_id=$4 AND emoji_id=$5` getReactionByIDWithoutMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND sender_id=$3 AND emoji_id=$4 ORDER BY message_part_id ASC LIMIT 1` @@ -48,8 +48,8 @@ const ( getAllReactionsToMessageQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2` getReactionByMXIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` upsertReactionQuery = ` - INSERT INTO reaction (bridge_id, message_id, message_part_id, sender_id, emoji_id, room_id, mxid, timestamp, metadata) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + INSERT INTO reaction (bridge_id, message_id, message_part_id, sender_id, emoji_id, room_id, room_receiver, mxid, timestamp, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) ON CONFLICT (bridge_id, message_id, message_part_id, sender_id, emoji_id) DO UPDATE SET mxid=excluded.mxid, timestamp=excluded.timestamp, metadata=excluded.metadata ` @@ -92,7 +92,7 @@ func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) { var timestamp int64 err := row.Scan( &r.BridgeID, &r.MessageID, &r.MessagePartID, &r.SenderID, &r.EmojiID, - &r.RoomID, &r.MXID, ×tamp, dbutil.JSON{Data: &r.Metadata}, + &r.Room.ID, &r.Room.Receiver, &r.MXID, ×tamp, dbutil.JSON{Data: &r.Metadata}, ) if err != nil { return nil, err @@ -110,6 +110,6 @@ func (r *Reaction) sqlVariables() []any { } return []any{ r.BridgeID, r.MessageID, r.MessagePartID, r.SenderID, r.EmojiID, - r.RoomID, r.MXID, r.Timestamp.UnixNano(), dbutil.JSON{Data: r.Metadata}, + r.Room.ID, r.Room.Receiver, r.MXID, r.Timestamp.UnixNano(), dbutil.JSON{Data: r.Metadata}, } } diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index df8c0a4a..dd03a0eb 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,26 +1,31 @@ -- v0 -> v1: Latest revision CREATE TABLE portal ( - bridge_id TEXT NOT NULL, - id TEXT NOT NULL, - mxid TEXT, + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, + receiver TEXT NOT NULL, + mxid TEXT, - parent_id TEXT, - name TEXT NOT NULL, - topic TEXT NOT NULL, - avatar_id TEXT NOT NULL, - avatar_hash TEXT NOT NULL, - avatar_mxc TEXT NOT NULL, - name_set BOOLEAN NOT NULL, - avatar_set BOOLEAN NOT NULL, - topic_set BOOLEAN NOT NULL, - in_space BOOLEAN NOT NULL, - metadata jsonb NOT NULL, + parent_id TEXT, + -- This is not accessed by the bridge, it's only used for the portal parent foreign key. + -- Parent groups are probably never DMs, so they don't need a receiver. + parent_receiver TEXT NOT NULL DEFAULT '', - PRIMARY KEY (bridge_id, id), - CONSTRAINT portal_parent_fkey FOREIGN KEY (bridge_id, parent_id) + name TEXT NOT NULL, + topic TEXT NOT NULL, + avatar_id TEXT NOT NULL, + avatar_hash TEXT NOT NULL, + avatar_mxc TEXT NOT NULL, + name_set BOOLEAN NOT NULL, + avatar_set BOOLEAN NOT NULL, + topic_set BOOLEAN NOT NULL, + in_space BOOLEAN NOT NULL, + metadata jsonb NOT NULL, + + PRIMARY KEY (bridge_id, id, receiver), + CONSTRAINT portal_parent_fkey FOREIGN KEY (bridge_id, parent_id, parent_receiver) -- Deletes aren't allowed to cascade here: -- children should be re-parented or cleaned up manually - REFERENCES portal (bridge_id, id) ON UPDATE CASCADE + REFERENCES portal (bridge_id, id, receiver) ON UPDATE CASCADE ); CREATE TABLE ghost ( @@ -46,23 +51,24 @@ CREATE TABLE message ( -- only: sqlite (line commented) -- rowid INTEGER PRIMARY KEY, -- only: postgres - rowid BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, + rowid BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, - bridge_id TEXT NOT NULL, - id TEXT NOT NULL, - part_id TEXT NOT NULL, - mxid TEXT NOT NULL, + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, + part_id TEXT NOT NULL, + mxid TEXT NOT NULL, - room_id TEXT NOT NULL, - sender_id TEXT NOT NULL, - timestamp BIGINT NOT NULL, - relates_to BIGINT, - metadata jsonb NOT NULL, + room_id TEXT NOT NULL, + room_receiver TEXT NOT NULL, + sender_id TEXT NOT NULL, + timestamp BIGINT NOT NULL, + relates_to BIGINT, + metadata jsonb NOT NULL, CONSTRAINT message_relation_fkey FOREIGN KEY (relates_to) REFERENCES message (rowid) ON DELETE SET NULL, - CONSTRAINT message_room_fkey FOREIGN KEY (bridge_id, room_id) - REFERENCES portal (bridge_id, id) + CONSTRAINT message_room_fkey FOREIGN KEY (bridge_id, room_id, room_receiver) + REFERENCES portal (bridge_id, id, receiver) ON DELETE CASCADE ON UPDATE CASCADE, CONSTRAINT message_sender_fkey FOREIGN KEY (bridge_id, sender_id) REFERENCES ghost (bridge_id, id) @@ -77,14 +83,15 @@ CREATE TABLE reaction ( sender_id TEXT NOT NULL, emoji_id TEXT NOT NULL, room_id TEXT NOT NULL, + room_receiver TEXT NOT NULL, mxid TEXT NOT NULL, timestamp BIGINT NOT NULL, metadata jsonb NOT NULL, PRIMARY KEY (bridge_id, message_id, message_part_id, sender_id, emoji_id), - CONSTRAINT reaction_room_fkey FOREIGN KEY (bridge_id, room_id) - REFERENCES portal (bridge_id, id) + CONSTRAINT reaction_room_fkey FOREIGN KEY (bridge_id, room_id, room_receiver) + REFERENCES portal (bridge_id, id, receiver) ON DELETE CASCADE ON UPDATE CASCADE, CONSTRAINT reaction_message_fkey FOREIGN KEY (bridge_id, message_id, message_part_id) REFERENCES message (bridge_id, id, part_id) @@ -118,18 +125,19 @@ CREATE TABLE user_login ( ); CREATE TABLE user_portal ( - bridge_id TEXT NOT NULL, - user_mxid TEXT NOT NULL, - login_id TEXT NOT NULL, - portal_id TEXT NOT NULL, - in_space BOOLEAN NOT NULL, - preferred BOOLEAN NOT NULL, + bridge_id TEXT NOT NULL, + user_mxid TEXT NOT NULL, + login_id TEXT NOT NULL, + portal_id TEXT NOT NULL, + portal_receiver TEXT NOT NULL, + in_space BOOLEAN NOT NULL, + preferred BOOLEAN NOT NULL, - PRIMARY KEY (bridge_id, user_mxid, login_id, portal_id), + PRIMARY KEY (bridge_id, user_mxid, login_id, portal_id, portal_receiver), CONSTRAINT user_portal_user_login_fkey FOREIGN KEY (bridge_id, user_mxid, login_id) REFERENCES user_login (bridge_id, user_mxid, id) ON DELETE CASCADE ON UPDATE CASCADE, - CONSTRAINT user_portal_portal_fkey FOREIGN KEY (bridge_id, portal_id) - REFERENCES portal (bridge_id, id) + CONSTRAINT user_portal_portal_fkey FOREIGN KEY (bridge_id, portal_id, portal_receiver) + REFERENCES portal (bridge_id, id, receiver) ON DELETE CASCADE ON UPDATE CASCADE ); diff --git a/bridgev2/database/user.go b/bridgev2/database/user.go index c5d2d0aa..c549e234 100644 --- a/bridgev2/database/user.go +++ b/bridgev2/database/user.go @@ -49,7 +49,7 @@ const ( findUserLoginsByPortalIDQuery = ` SELECT login_id FROM user_portal - WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$3 + WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$3 AND portal_receiver=$4 ORDER BY CASE WHEN preferred THEN 0 ELSE 1 END, login_id ` ) @@ -68,8 +68,8 @@ func (uq *UserQuery) Update(ctx context.Context, user *User) error { return uq.Exec(ctx, updateUserQuery, user.sqlVariables()...) } -func (uq *UserQuery) FindLoginsByPortalID(ctx context.Context, userID id.UserID, portalID networkid.PortalID) ([]networkid.UserLoginID, error) { - rows, err := uq.GetDB().Query(ctx, findUserLoginsByPortalIDQuery, uq.BridgeID, userID, portalID) +func (uq *UserQuery) FindLoginsByPortalID(ctx context.Context, userID id.UserID, portal networkid.PortalKey) ([]networkid.UserLoginID, error) { + rows, err := uq.GetDB().Query(ctx, findUserLoginsByPortalIDQuery, uq.BridgeID, userID, portal.ID, portal.Receiver) return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[networkid.UserLoginID], err).AsList() } diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go index 8169980c..36b2032b 100644 --- a/bridgev2/database/userlogin.go +++ b/bridgev2/database/userlogin.go @@ -42,7 +42,7 @@ const ( getAllLoginsInPortalQuery = ` SELECT ul.bridge_id, ul.user_mxid, ul.id, ul.space_room, ul.metadata FROM user_portal LEFT JOIN user_login ul ON user_portal.bridge_id=ul.bridge_id AND user_portal.user_mxid=ul.user_mxid AND user_portal.login_id=ul.id - WHERE user_portal.bridge_id=$1 AND user_portal.portal_id=$2 + WHERE user_portal.bridge_id=$1 AND user_portal.portal_id=$2 AND user_portal.portal_receiver=$3 ` insertUserLoginQuery = ` INSERT INTO user_login (bridge_id, user_mxid, id, space_room, metadata) @@ -56,17 +56,17 @@ const ( DELETE FROM user_login WHERE bridge_id=$1 AND id=$2 ` insertUserPortalQuery = ` - INSERT INTO user_portal (bridge_id, user_mxid, login_id, portal_id, in_space, preferred) - VALUES ($1, $2, $3, $4, false, false) - ON CONFLICT (bridge_id, user_mxid, login_id, portal_id) DO NOTHING + INSERT INTO user_portal (bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred) + VALUES ($1, $2, $3, $4, $5, false, false) + ON CONFLICT (bridge_id, user_mxid, login_id, portal_id, portal_receiver) DO NOTHING ` upsertUserPortalQuery = ` - INSERT INTO user_portal (bridge_id, user_mxid, login_id, portal_id, in_space, preferred) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (bridge_id, user_mxid, login_id, portal_id) DO UPDATE SET in_space=excluded.in_space, preferred=excluded.preferred + INSERT INTO user_portal (bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (bridge_id, user_mxid, login_id, portal_id, portal_receiver) DO UPDATE SET in_space=excluded.in_space, preferred=excluded.preferred ` markLoginAsPreferredQuery = ` - UPDATE user_portal SET preferred=(login_id=$3) WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$4 + UPDATE user_portal SET preferred=(login_id=$3) WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$4 AND portal_receiver=$5 ` ) @@ -74,8 +74,8 @@ func (uq *UserLoginQuery) GetAll(ctx context.Context) ([]*UserLogin, error) { return uq.QueryMany(ctx, getAllLoginsQuery, uq.BridgeID) } -func (uq *UserLoginQuery) GetAllInPortal(ctx context.Context, portalID networkid.PortalID) ([]*UserLogin, error) { - return uq.QueryMany(ctx, getAllLoginsInPortalQuery, uq.BridgeID, portalID) +func (uq *UserLoginQuery) GetAllInPortal(ctx context.Context, portal networkid.PortalKey) ([]*UserLogin, error) { + return uq.QueryMany(ctx, getAllLoginsInPortalQuery, uq.BridgeID, portal.ID, portal.Receiver) } func (uq *UserLoginQuery) GetAllForUser(ctx context.Context, userID id.UserID) ([]*UserLogin, error) { @@ -96,14 +96,14 @@ func (uq *UserLoginQuery) Delete(ctx context.Context, loginID networkid.UserLogi return uq.Exec(ctx, deleteUserLoginQuery, uq.BridgeID, loginID) } -func (uq *UserLoginQuery) EnsureUserPortalExists(ctx context.Context, login *UserLogin, portalID networkid.PortalID) error { +func (uq *UserLoginQuery) EnsureUserPortalExists(ctx context.Context, login *UserLogin, portal networkid.PortalKey) error { ensureBridgeIDMatches(&login.BridgeID, uq.BridgeID) - return uq.Exec(ctx, insertUserPortalQuery, login.BridgeID, login.UserMXID, login.ID, portalID) + return uq.Exec(ctx, insertUserPortalQuery, login.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) } -func (uq *UserLoginQuery) MarkLoginAsPreferredInPortal(ctx context.Context, login *UserLogin, portalID networkid.PortalID) error { +func (uq *UserLoginQuery) MarkLoginAsPreferredInPortal(ctx context.Context, login *UserLogin, portal networkid.PortalKey) error { ensureBridgeIDMatches(&login.BridgeID, uq.BridgeID) - return uq.Exec(ctx, markLoginAsPreferredQuery, login.BridgeID, login.UserMXID, login.ID, portalID) + return uq.Exec(ctx, markLoginAsPreferredQuery, login.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) } func (u *UserLogin) Scan(row dbutil.Scannable) (*UserLogin, error) { diff --git a/bridgev2/networkid/bridgeid.go b/bridgev2/networkid/bridgeid.go index 75406df3..57900cb7 100644 --- a/bridgev2/networkid/bridgeid.go +++ b/bridgev2/networkid/bridgeid.go @@ -6,16 +6,44 @@ package networkid +import ( + "fmt" + + "github.com/rs/zerolog" +) + // BridgeID is an opaque identifier for a bridge type BridgeID string // PortalID is the ID of a room on the remote network. -// -// Portal IDs must be globally unique and refer to a single chat. -// This means that user IDs can't be used directly as DM chat IDs, instead the ID must contain both user IDs (e.g. "user1-user2"). -// If generating such IDs manually, sorting the users is recommended to ensure they're consistent. type PortalID string +// PortalKey is the unique key of a room on the remote network. It combines a portal ID and a receiver ID. +// +// The Receiver field is generally only used for DMs, and should be empty for group chats. +// The purpose is to segregate DMs by receiver, so that the same DM has separate rooms even +// if both sides are logged into the bridge. Also, for networks that use user IDs as DM chat IDs, +// the receiver is necessary to have separate rooms for separate users who have a DM with the same +// remote user. +type PortalKey struct { + ID PortalID + Receiver UserLoginID +} + +func (pk PortalKey) String() string { + if pk.Receiver == "" { + return string(pk.ID) + } + return fmt.Sprintf("%s/%s", pk.ID, pk.Receiver) +} + +func (pk PortalKey) MarshalZerologObject(evt *zerolog.Event) { + evt.Str("portal_id", string(pk.ID)) + if pk.Receiver != "" { + evt.Str("portal_receiver", string(pk.Receiver)) + } +} + // UserID is the ID of a user on the remote network. type UserID string diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 7d58eae8..b30a764c 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -93,7 +93,7 @@ const ( type RemoteEvent interface { GetType() RemoteEventType - GetPortalID() networkid.PortalID + GetPortalKey() networkid.PortalKey ShouldCreatePortal() bool AddLogContext(c zerolog.Context) zerolog.Context GetSender() EventSender @@ -148,7 +148,7 @@ type RemoteMessageRemove interface { type SimpleRemoteEvent[T any] struct { Type RemoteEventType LogContext func(c zerolog.Context) zerolog.Context - PortalID networkid.PortalID + PortalKey networkid.PortalKey Data T CreatePortal bool @@ -178,8 +178,8 @@ func (sre *SimpleRemoteEvent[T]) AddLogContext(c zerolog.Context) zerolog.Contex return sre.LogContext(c) } -func (sre *SimpleRemoteEvent[T]) GetPortalID() networkid.PortalID { - return sre.PortalID +func (sre *SimpleRemoteEvent[T]) GetPortalKey() networkid.PortalKey { + return sre.PortalKey } func (sre *SimpleRemoteEvent[T]) GetTimestamp() time.Time { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index ccb4aeda..dab5129c 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -60,7 +60,7 @@ type Portal struct { const PortalEventBuffer = 64 -func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, queryErr error, id *networkid.PortalID) (*Portal, error) { +func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, queryErr error, id *networkid.PortalKey) (*Portal, error) { if queryErr != nil { return nil, fmt.Errorf("failed to query db: %w", queryErr) } @@ -69,8 +69,8 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que return nil, nil } dbPortal = &database.Portal{ - BridgeID: br.ID, - ID: *id, + BridgeID: br.ID, + PortalKey: *id, } err := br.DB.Portal.Insert(ctx, dbPortal) if err != nil { @@ -83,13 +83,13 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que events: make(chan portalEvent, PortalEventBuffer), } - br.portalsByID[portal.ID] = portal + br.portalsByKey[portal.PortalKey] = portal if portal.MXID != "" { br.portalsByMXID[portal.MXID] = portal } if portal.ParentID != "" { var err error - portal.Parent, err = br.unlockedGetPortalByID(ctx, portal.ParentID, false) + portal.Parent, err = br.unlockedGetPortalByID(ctx, networkid.PortalKey{ID: portal.ParentID}, false) if err != nil { return nil, fmt.Errorf("failed to load parent portal (%s): %w", portal.ParentID, err) } @@ -107,8 +107,8 @@ func (portal *Portal) updateLogger() { portal.Log = logWith.Logger() } -func (br *Bridge) unlockedGetPortalByID(ctx context.Context, id networkid.PortalID, onlyIfExists bool) (*Portal, error) { - cached, ok := br.portalsByID[id] +func (br *Bridge) unlockedGetPortalByID(ctx context.Context, id networkid.PortalKey, onlyIfExists bool) (*Portal, error) { + cached, ok := br.portalsByKey[id] if ok { return cached, nil } @@ -131,16 +131,28 @@ func (br *Bridge) GetPortalByMXID(ctx context.Context, mxid id.RoomID) (*Portal, return br.loadPortal(ctx, db, err, nil) } -func (br *Bridge) GetPortalByID(ctx context.Context, id networkid.PortalID) (*Portal, error) { +func (br *Bridge) GetPortalByID(ctx context.Context, id networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() return br.unlockedGetPortalByID(ctx, id, false) } -func (br *Bridge) GetExistingPortalByID(ctx context.Context, id networkid.PortalID) (*Portal, error) { +func (br *Bridge) GetExistingPortalByID(ctx context.Context, id networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() - return br.unlockedGetPortalByID(ctx, id, true) + if id.Receiver == "" { + return br.unlockedGetPortalByID(ctx, id, true) + } + cached, ok := br.portalsByKey[id] + if ok { + return cached, nil + } + cached, ok = br.portalsByKey[networkid.PortalKey{ID: id.ID}] + if ok { + return cached, nil + } + db, err := br.DB.Portal.GetByIDWithUncertainReceiver(ctx, id) + return br.loadPortal(ctx, db, err, nil) } func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) { @@ -167,7 +179,7 @@ func (portal *Portal) eventLoop() { } func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User) (*UserLogin, error) { - logins, err := portal.Bridge.DB.User.FindLoginsByPortalID(ctx, user.MXID, portal.ID) + logins, err := portal.Bridge.DB.User.FindLoginsByPortalID(ctx, user.MXID, portal.PortalKey) if err != nil { return nil, err } @@ -394,8 +406,8 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin if message.MXID == "" { message.MXID = evt.ID } - if message.RoomID == "" { - message.RoomID = portal.ID + if message.Room.ID == "" { + message.Room = portal.PortalKey } if message.Timestamp.IsZero() { message.Timestamp = time.UnixMilli(evt.Timestamp) @@ -550,8 +562,8 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi return } // Fill all fields that are known to allow omitting them in connector code - if dbReaction.RoomID == "" { - dbReaction.RoomID = portal.ID + if dbReaction.Room.ID == "" { + dbReaction.Room = portal.PortalKey } if dbReaction.MessageID == "" { dbReaction.MessageID = reactionTarget.ID @@ -777,7 +789,7 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin ID: evt.GetID(), PartID: part.ID, MXID: resp.EventID, - RoomID: portal.ID, + Room: portal.PortalKey, SenderID: evt.GetSender().Sender, Timestamp: ts, RelatesToRowID: relatesToRowID, @@ -939,7 +951,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi Stringer("event_id", resp.EventID). Msg("Sent reaction to Matrix") dbReaction := &database.Reaction{ - RoomID: portal.ID, + Room: portal.PortalKey, MessageID: targetMessage.ID, MessagePartID: targetMessage.PartID, SenderID: evt.GetSender().Sender, @@ -1163,17 +1175,20 @@ func (portal *Portal) sendRoomMeta(ctx context.Context, sender *Ghost, ts time.T return true } -func (portal *Portal) SyncParticipants(ctx context.Context, members []networkid.UserID, source *UserLogin) ([]id.UserID, error) { - loginsInPortal, err := portal.Bridge.GetUserLoginsInPortal(ctx, portal.ID) +func (portal *Portal) SyncParticipants(ctx context.Context, members []networkid.UserID, source *UserLogin) ([]id.UserID, []id.UserID, error) { + loginsInPortal, err := portal.Bridge.GetUserLoginsInPortal(ctx, portal.PortalKey) if err != nil { - return nil, fmt.Errorf("failed to get user logins in portal: %w", err) + return nil, nil, fmt.Errorf("failed to get user logins in portal: %w", err) } expectedUserIDs := make([]id.UserID, 0, len(members)) expectedExtraUsers := make([]id.UserID, 0) expectedIntents := make([]MatrixAPI, len(members)) + extraFunctionalMembers := make([]id.UserID, 0) for i, member := range members { + isLoggedInUser := false for _, login := range loginsInPortal { if login.Client.IsThisUser(ctx, member) { + isLoggedInUser = true userIntent := portal.Bridge.Matrix.UserIntent(login.User) if userIntent != nil { expectedIntents[i] = userIntent @@ -1186,16 +1201,19 @@ func (portal *Portal) SyncParticipants(ctx context.Context, members []networkid. } ghost, err := portal.Bridge.GetGhostByID(ctx, member) if err != nil { - return nil, fmt.Errorf("failed to get ghost for %s: %w", member, err) + return nil, nil, fmt.Errorf("failed to get ghost for %s: %w", member, err) } ghost.UpdateInfoIfNecessary(ctx, source) if expectedIntents[i] == nil { expectedIntents[i] = ghost.Intent + if isLoggedInUser { + extraFunctionalMembers = append(extraFunctionalMembers, ghost.Intent.GetMXID()) + } } expectedUserIDs = append(expectedUserIDs, expectedIntents[i].GetMXID()) } if portal.MXID == "" { - return expectedUserIDs, nil + return expectedUserIDs, extraFunctionalMembers, nil } currentMembers, err := portal.Bridge.Matrix.GetMembers(ctx, portal.MXID) for _, intent := range expectedIntents { @@ -1243,7 +1261,7 @@ func (portal *Portal) SyncParticipants(ctx context.Context, members []networkid. } } } - return expectedUserIDs, nil + return expectedUserIDs, extraFunctionalMembers, nil } func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, sender *Ghost, ts time.Time) { @@ -1290,7 +1308,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) e return err } portal.UpdateInfo(ctx, info, nil, time.Time{}) - initialMembers, err := portal.SyncParticipants(ctx, info.Members, source) + initialMembers, extraFunctionalMembers, err := portal.SyncParticipants(ctx, info.Members, source) if err != nil { log.Err(err).Msg("Failed to process participant list for portal creation") return err @@ -1327,7 +1345,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) e StateKey: &emptyString, Type: stateElementFunctionalMembers, Content: event.Content{Raw: map[string]any{ - "service_members": []id.UserID{portal.Bridge.Bot.GetMXID()}, + "service_members": append(extraFunctionalMembers, portal.Bridge.Bot.GetMXID()), }}, }) if req.Topic == "" { @@ -1380,7 +1398,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) e // TODO add m.space.child event } if !isBeeper { - _, err = portal.SyncParticipants(ctx, info.Members, source) + _, _, err = portal.SyncParticipants(ctx, info.Members, source) if err != nil { log.Err(err).Msg("Failed to sync participants after room creation") } diff --git a/bridgev2/queue.go b/bridgev2/queue.go index aa246267..02fc3895 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -85,14 +85,17 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { log := login.Log ctx := log.WithContext(context.TODO()) - portal, err := br.GetPortalByID(ctx, evt.GetPortalID()) + portal, err := br.GetPortalByID(ctx, evt.GetPortalKey()) if err != nil { - log.Err(err).Str("portal_id", string(evt.GetPortalID())). + log.Err(err).Object("portal_id", evt.GetPortalKey()). Msg("Failed to get portal to handle remote event") return } // TODO put this in a better place, and maybe cache to avoid constant db queries - br.DB.UserLogin.EnsureUserPortalExists(ctx, login.UserLogin, portal.ID) + err = br.DB.UserLogin.EnsureUserPortalExists(ctx, login.UserLogin, portal.PortalKey) + if err != nil { + log.Warn().Err(err).Msg("Failed to ensure user portal row exists") + } portal.queueEvent(ctx, &portalRemoteEvent{ evt: evt, source: login, diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 74956e26..429c79cd 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -83,8 +83,8 @@ func (br *Bridge) GetAllUserLogins(ctx context.Context) ([]*UserLogin, error) { return br.loadManyUserLogins(ctx, nil, logins) } -func (br *Bridge) GetUserLoginsInPortal(ctx context.Context, portalID networkid.PortalID) ([]*UserLogin, error) { - logins, err := br.DB.UserLogin.GetAllInPortal(ctx, portalID) +func (br *Bridge) GetUserLoginsInPortal(ctx context.Context, portal networkid.PortalKey) ([]*UserLogin, error) { + logins, err := br.DB.UserLogin.GetAllInPortal(ctx, portal) if err != nil { return nil, err } @@ -138,5 +138,5 @@ func (ul *UserLogin) Logout(ctx context.Context) { } func (ul *UserLogin) MarkAsPreferredIn(ctx context.Context, portal *Portal) error { - return ul.Bridge.DB.UserLogin.MarkLoginAsPreferredInPortal(ctx, ul.UserLogin, portal.ID) + return ul.Bridge.DB.UserLogin.MarkLoginAsPreferredInPortal(ctx, ul.UserLogin, portal.PortalKey) } From f97d365ea9eb095ec9ccbe493d6d89c4d2dc2224 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 11 Jun 2024 15:04:02 +0300 Subject: [PATCH 0269/1647] Add proper main for v2 matrix bridges --- bridge/bridge.go | 4 +- bridge/bridgeconfig/config.go | 4 +- bridgev2/bridge.go | 55 ++-- bridgev2/bridgeconfig/config.go | 24 +- bridgev2/bridgeconfig/upgrade.go | 171 +++++----- bridgev2/cmdevent.go | 2 +- bridgev2/cmdhelp.go | 2 +- bridgev2/ghost.go | 5 +- bridgev2/matrix/connector.go | 194 +++++++++++- bridgev2/matrix/crypto.go | 5 +- bridgev2/matrix/doublepuppet.go | 4 +- bridgev2/matrix/intent.go | 6 + bridgev2/matrix/mxmain/config.go | 36 +++ bridgev2/matrix/mxmain/dberror.go | 75 +++++ bridgev2/matrix/mxmain/example-config.yaml | 217 +++++++++++++ bridgev2/matrix/mxmain/main.go | 348 +++++++++++++++++++++ bridgev2/networkinterface.go | 36 +++ bridgev2/portal.go | 7 +- bridgev2/queue.go | 4 +- go.mod | 12 +- go.sum | 24 +- 21 files changed, 1071 insertions(+), 164 deletions(-) create mode 100644 bridgev2/matrix/mxmain/config.go create mode 100644 bridgev2/matrix/mxmain/dberror.go create mode 100644 bridgev2/matrix/mxmain/example-config.yaml create mode 100644 bridgev2/matrix/mxmain/main.go diff --git a/bridge/bridge.go b/bridge/bridge.go index 1fb04eb8..053c9021 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -276,7 +276,7 @@ func (br *Bridge) GenerateRegistration() { os.Exit(21) } - updateTokens := func(helper *configupgrade.Helper) { + updateTokens := func(helper configupgrade.Helper) { helper.Set(configupgrade.Str, reg.AppToken, "appservice", "as_token") helper.Set(configupgrade.Str, reg.ServerToken, "appservice", "hs_token") } @@ -775,7 +775,7 @@ func (br *Bridge) ResendBridgeInfo() { if !br.SaveConfig { br.ZLog.Warn().Msg("Not setting resend_bridge_info to false in config due to --no-update flag") } else { - _, _, err := configupgrade.Do(br.ConfigPath, true, br.ConfigUpgrader, configupgrade.SimpleUpgrader(func(helper *configupgrade.Helper) { + _, _, err := configupgrade.Do(br.ConfigPath, true, br.ConfigUpgrader, configupgrade.SimpleUpgrader(func(helper configupgrade.Helper) { helper.Set(configupgrade.Bool, "false", "bridge", "resend_bridge_info") })) if err != nil { diff --git a/bridge/bridgeconfig/config.go b/bridge/bridgeconfig/config.go index 2e8548b5..dfb6b7e5 100644 --- a/bridge/bridgeconfig/config.go +++ b/bridge/bridgeconfig/config.go @@ -223,7 +223,7 @@ type BaseConfig struct { Logging zeroconfig.Config `yaml:"logging"` } -func doUpgrade(helper *up.Helper) { +func doUpgrade(helper up.Helper) { helper.Copy(up.Str, "homeserver", "address") helper.Copy(up.Str, "homeserver", "domain") if legacyAsmuxFlag, ok := helper.Get(up.Bool, "homeserver", "asmux"); ok && legacyAsmuxFlag == "true" { @@ -283,7 +283,7 @@ type legacyLogConfig struct { JSONFile bool `yaml:"file_json"` } -func migrateLegacyLogConfig(helper *up.Helper) { +func migrateLegacyLogConfig(helper up.Helper) { var llc legacyLogConfig var newConfig zeroconfig.Config err := helper.GetBaseNode("logging").Decode(&newConfig) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index fe0a725a..0f77cea5 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -9,18 +9,15 @@ package bridgev2 import ( "context" "errors" - "os" - "os/signal" + "fmt" "sync" - "syscall" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" - "go.mau.fi/util/exerrors" + "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/id" ) @@ -35,9 +32,7 @@ type Bridge struct { Bot MatrixAPI Network NetworkConnector Commands *CommandProcessor - - // TODO move to config - CommandPrefix string + Config *bridgeconfig.BridgeConfig usersByMXID map[id.UserID]*User userLoginsByID map[networkid.UserLoginID]*UserLogin @@ -47,7 +42,7 @@ type Bridge struct { cacheLock sync.Mutex } -func NewBridge(bridgeID networkid.BridgeID, db *dbutil.Database, log zerolog.Logger, matrix MatrixConnector, network NetworkConnector) *Bridge { +func NewBridge(bridgeID networkid.BridgeID, db *dbutil.Database, log zerolog.Logger, cfg *bridgeconfig.BridgeConfig, matrix MatrixConnector, network NetworkConnector) *Bridge { br := &Bridge{ ID: bridgeID, DB: database.New(bridgeID, db), @@ -55,6 +50,7 @@ func NewBridge(bridgeID networkid.BridgeID, db *dbutil.Database, log zerolog.Log Matrix: matrix, Network: network, + Config: cfg, usersByMXID: make(map[id.UserID]*User), userLoginsByID: make(map[networkid.UserLoginID]*UserLogin), @@ -62,6 +58,9 @@ func NewBridge(bridgeID networkid.BridgeID, db *dbutil.Database, log zerolog.Log portalsByMXID: make(map[id.RoomID]*Portal), ghostsByID: make(map[networkid.UserID]*Ghost), } + if br.Config == nil { + br.Config = &bridgeconfig.BridgeConfig{CommandPrefix: "!bridge"} + } br.Commands = NewProcessor(br) br.Matrix.Init(br) br.Bot = br.Matrix.BotIntent() @@ -69,19 +68,41 @@ func NewBridge(bridgeID networkid.BridgeID, db *dbutil.Database, log zerolog.Log return br } -func (br *Bridge) Start() { +type DBUpgradeError struct { + Err error + Section string +} + +func (e DBUpgradeError) Error() string { + return e.Err.Error() +} + +func (e DBUpgradeError) Unwrap() error { + return e.Err +} + +func (br *Bridge) Start() error { br.Log.Info().Msg("Starting bridge") ctx := br.Log.WithContext(context.Background()) - exerrors.PanicIfNotNil(br.DB.Upgrade(ctx)) + err := br.DB.Upgrade(ctx) + if err != nil { + return DBUpgradeError{Err: err, Section: "main"} + } br.Log.Info().Msg("Starting Matrix connector") - exerrors.PanicIfNotNil(br.Matrix.Start(ctx)) + err = br.Matrix.Start(ctx) + if err != nil { + return fmt.Errorf("failed to start Matrix connector: %w", err) + } br.Log.Info().Msg("Starting network connector") - exerrors.PanicIfNotNil(br.Network.Start(ctx)) + err = br.Network.Start(ctx) + if err != nil { + return fmt.Errorf("failed to start network connector: %w", err) + } logins, err := br.GetAllUserLogins(ctx) if err != nil { - br.Log.Fatal().Err(err).Msg("Failed to get user logins") + return fmt.Errorf("failed to get user logins: %w", err) } for _, login := range logins { br.Log.Info().Str("id", string(login.ID)).Msg("Starting user login") @@ -92,11 +113,9 @@ func (br *Bridge) Start() { } if len(logins) == 0 { br.Log.Info().Msg("No user logins found") + // TODO send UNCONFIGURED bridge state } br.Log.Info().Msg("Bridge started") - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - <-c - br.Log.Info().Msg("Shutting down bridge") + return nil } diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index cfe1cffe..1153407a 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -13,20 +13,26 @@ import ( ) type Config struct { - Homeserver HomeserverConfig `yaml:"homeserver"` - AppService AppserviceConfig `yaml:"appservice"` - Database dbutil.Config `yaml:"database"` - Bridge BridgeConfig `yaml:"bridge"` // TODO this is more like matrix than bridge - Provisioning ProvisioningConfig `yaml:"provisioning"` - DoublePuppet DoublePuppetConfig `yaml:"double_puppet"` - Encryption EncryptionConfig `yaml:"encryption"` + Network yaml.Node `yaml:"network"` + Bridge BridgeConfig `yaml:"bridge"` + Database dbutil.Config `yaml:"database"` + Homeserver HomeserverConfig `yaml:"homeserver"` + AppService AppserviceConfig `yaml:"appservice"` + Matrix MatrixConfig `yaml:"matrix"` + Provisioning ProvisioningConfig `yaml:"provisioning"` + DoublePuppet DoublePuppetConfig `yaml:"double_puppet"` + Encryption EncryptionConfig `yaml:"encryption"` + Logging zeroconfig.Config `yaml:"logging"` + Permissions PermissionConfig `yaml:"permissions"` ManagementRoomTexts ManagementRoomTexts `yaml:"management_room_texts"` - Logging zeroconfig.Config `yaml:"logging"` - Network yaml.Node `yaml:"network"` } type BridgeConfig struct { + CommandPrefix string `yaml:"command_prefix"` +} + +type MatrixConfig struct { MessageStatusEvents bool `yaml:"message_status_events"` DeliveryReceipts bool `yaml:"delivery_receipts"` MessageErrorNotices bool `yaml:"message_error_notices"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index d4006f43..570f2c95 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -9,23 +9,28 @@ package bridgeconfig import ( "fmt" "os" - "path/filepath" - "strings" - "github.com/rs/zerolog" up "go.mau.fi/util/configupgrade" - "go.mau.fi/zeroconfig" - "gopkg.in/yaml.v3" + "go.mau.fi/util/random" ) -func doUpgrade(helper *up.Helper) { +func doUpgrade(helper up.Helper) { + helper.Copy(up.Str, "bridge", "command_prefix") + + if dbType, ok := helper.Get(up.Str, "database", "type"); ok && dbType == "sqlite3" { + helper.Set(up.Str, "sqlite3-fk-wal", "database", "type") + } else { + helper.Copy(up.Str, "database", "type") + } + helper.Copy(up.Str, "database", "uri") + helper.Copy(up.Int, "database", "max_open_conns") + helper.Copy(up.Int, "database", "max_idle_conns") + helper.Copy(up.Str|up.Null, "database", "max_conn_idle_time") + helper.Copy(up.Str|up.Null, "database", "max_conn_lifetime") + helper.Copy(up.Str, "homeserver", "address") helper.Copy(up.Str, "homeserver", "domain") - if legacyAsmuxFlag, ok := helper.Get(up.Bool, "homeserver", "asmux"); ok && legacyAsmuxFlag == "true" { - helper.Set(up.Str, string(SoftwareAsmux), "homeserver", "software") - } else { - helper.Copy(up.Str, "homeserver", "software") - } + helper.Copy(up.Str, "homeserver", "software") helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint") helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint") helper.Copy(up.Bool, "homeserver", "async_media") @@ -36,16 +41,6 @@ func doUpgrade(helper *up.Helper) { helper.Copy(up.Str|up.Null, "appservice", "address") helper.Copy(up.Str|up.Null, "appservice", "hostname") helper.Copy(up.Int|up.Null, "appservice", "port") - if dbType, ok := helper.Get(up.Str, "appservice", "database", "type"); ok && dbType == "sqlite3" { - helper.Set(up.Str, "sqlite3-fk-wal", "appservice", "database", "type") - } else { - helper.Copy(up.Str, "appservice", "database", "type") - } - helper.Copy(up.Str, "appservice", "database", "uri") - helper.Copy(up.Int, "appservice", "database", "max_open_conns") - helper.Copy(up.Int, "appservice", "database", "max_idle_conns") - helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_idle_time") - helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_lifetime") helper.Copy(up.Str, "appservice", "id") helper.Copy(up.Str, "appservice", "bot", "username") helper.Copy(up.Str, "appservice", "bot", "displayname") @@ -54,79 +49,83 @@ func doUpgrade(helper *up.Helper) { helper.Copy(up.Bool, "appservice", "async_transactions") helper.Copy(up.Str, "appservice", "as_token") helper.Copy(up.Str, "appservice", "hs_token") + helper.Copy(up.Str, "appservice", "username_template") + + helper.Copy(up.Bool, "matrix", "message_status_events") + helper.Copy(up.Bool, "matrix", "delivery_receipts") + helper.Copy(up.Bool, "matrix", "message_error_notices") + helper.Copy(up.Bool, "matrix", "sync_direct_chat_list") + helper.Copy(up.Bool, "matrix", "federate_rooms") + + helper.Copy(up.Str, "provisioning", "prefix") + if secret, ok := helper.Get(up.Str, "provisioning", "shared_secret"); !ok || secret == "generate" { + sharedSecret := random.String(64) + helper.Set(up.Str, sharedSecret, "provisioning", "shared_secret") + } else { + helper.Copy(up.Str, "provisioning", "shared_secret") + } + helper.Copy(up.Bool, "provisioning", "debug_endpoints") + + helper.Copy(up.Map, "double_puppet", "servers") + helper.Copy(up.Bool, "double_puppet", "allow_discovery") + helper.Copy(up.Map, "double_puppet", "secrets") + + helper.Copy(up.Bool, "encryption", "allow") + helper.Copy(up.Bool, "encryption", "default") + helper.Copy(up.Bool, "encryption", "require") + helper.Copy(up.Bool, "encryption", "appservice") + helper.Copy(up.Bool, "encryption", "allow_key_sharing") + if secret, ok := helper.Get(up.Str, "encryption", "pickle_key"); !ok || secret == "generate" { + helper.Set(up.Str, random.String(64), "encryption", "pickle_key") + } else { + helper.Copy(up.Str, "encryption", "pickle_key") + } + helper.Copy(up.Bool, "encryption", "delete_keys", "delete_outbound_on_ack") + helper.Copy(up.Bool, "encryption", "delete_keys", "dont_store_outbound") + helper.Copy(up.Bool, "encryption", "delete_keys", "ratchet_on_decrypt") + helper.Copy(up.Bool, "encryption", "delete_keys", "delete_fully_used_on_decrypt") + helper.Copy(up.Bool, "encryption", "delete_keys", "delete_prev_on_new_session") + helper.Copy(up.Bool, "encryption", "delete_keys", "delete_on_device_delete") + helper.Copy(up.Bool, "encryption", "delete_keys", "periodically_delete_expired") + helper.Copy(up.Bool, "encryption", "delete_keys", "delete_outdated_inbound") + helper.Copy(up.Str, "encryption", "verification_levels", "receive") + helper.Copy(up.Str, "encryption", "verification_levels", "send") + helper.Copy(up.Str, "encryption", "verification_levels", "share") + helper.Copy(up.Bool, "encryption", "rotation", "enable_custom") + helper.Copy(up.Int, "encryption", "rotation", "milliseconds") + helper.Copy(up.Int, "encryption", "rotation", "messages") + helper.Copy(up.Bool, "encryption", "rotation", "disable_device_change_key_rotation") if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "print_level") != nil || helper.GetNode("logging", "file_name_format") != nil) { - _, _ = fmt.Fprintln(os.Stderr, "Migrating legacy log config") - migrateLegacyLogConfig(helper) + _, _ = fmt.Fprintln(os.Stderr, "Migrating legacy log configs is not supported") } else if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "handlers") != nil) { - _, _ = fmt.Fprintln(os.Stderr, "Migrating Python log config is not currently supported") - // TODO implement? - //migratePythonLogConfig(helper) + _, _ = fmt.Fprintln(os.Stderr, "Migrating Python log configs is not supported") } else { helper.Copy(up.Map, "logging") } } -type legacyLogConfig struct { - Directory string `yaml:"directory"` - FileNameFormat string `yaml:"file_name_format"` - FileDateFormat string `yaml:"file_date_format"` - FileMode uint32 `yaml:"file_mode"` - TimestampFormat string `yaml:"timestamp_format"` - RawPrintLevel string `yaml:"print_level"` - JSONStdout bool `yaml:"print_json"` - JSONFile bool `yaml:"file_json"` -} - -func migrateLegacyLogConfig(helper *up.Helper) { - var llc legacyLogConfig - var newConfig zeroconfig.Config - err := helper.GetBaseNode("logging").Decode(&newConfig) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Base config is corrupted: failed to decode example log config:", err) - return - } else if len(newConfig.Writers) != 2 || newConfig.Writers[0].Type != "stdout" || newConfig.Writers[1].Type != "file" { - _, _ = fmt.Fprintln(os.Stderr, "Base log config is not in expected format") - return - } - err = helper.GetNode("logging").Decode(&llc) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to decode legacy log config:", err) - return - } - if llc.RawPrintLevel != "" { - level, err := zerolog.ParseLevel(llc.RawPrintLevel) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to parse minimum stdout log level:", err) - } else { - newConfig.Writers[0].MinLevel = &level - } - } - if llc.Directory != "" && llc.FileNameFormat != "" { - if llc.FileNameFormat == "{{.Date}}-{{.Index}}.log" { - llc.FileNameFormat = "bridge.log" - } else { - llc.FileNameFormat = strings.ReplaceAll(llc.FileNameFormat, "{{.Date}}", "") - llc.FileNameFormat = strings.ReplaceAll(llc.FileNameFormat, "{{.Index}}", "") - } - newConfig.Writers[1].Filename = filepath.Join(llc.Directory, llc.FileNameFormat) - } else if llc.FileNameFormat == "" { - newConfig.Writers = newConfig.Writers[0:1] - } - if llc.JSONStdout { - newConfig.Writers[0].TimeFormat = "" - newConfig.Writers[0].Format = "json" - } else if llc.TimestampFormat != "" { - newConfig.Writers[0].TimeFormat = llc.TimestampFormat - } - var updatedConfig yaml.Node - err = updatedConfig.Encode(&newConfig) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to encode migrated log config:", err) - return - } - *helper.GetBaseNode("logging").Node = updatedConfig +var SpacedBlocks = [][]string{ + {"bridge"}, + {"database"}, + {"homeserver"}, + {"homeserver", "software"}, + {"homeserver", "websocket"}, + {"appservice"}, + {"appservice", "hostname"}, + {"appservice", "id"}, + {"appservice", "ephemeral_events"}, + {"appservice", "as_token"}, + {"appservice", "username_template"}, + {"matrix"}, + {"provisioning"}, + {"double_puppet"}, + {"encryption"}, + {"logging"}, } // Upgrader is a config upgrader that copies the default fields in the homeserver, appservice and logging blocks. -var Upgrader = up.SimpleUpgrader(doUpgrade) +var Upgrader up.SpacedUpgrader = &up.StructUpgrader{ + SimpleUpgrader: up.SimpleUpgrader(doUpgrade), + Blocks: SpacedBlocks, +} diff --git a/bridgev2/cmdevent.go b/bridgev2/cmdevent.go index de43ccca..0c80330d 100644 --- a/bridgev2/cmdevent.go +++ b/bridgev2/cmdevent.go @@ -40,7 +40,7 @@ type CommandEvent struct { // Reply sends a reply to command as notice, with optional string formatting and automatic $cmdprefix replacement. func (ce *CommandEvent) Reply(msg string, args ...any) { - msg = strings.ReplaceAll(msg, "$cmdprefix ", ce.Bridge.CommandPrefix+" ") + msg = strings.ReplaceAll(msg, "$cmdprefix ", ce.Bridge.Config.CommandPrefix+" ") if len(args) > 0 { msg = fmt.Sprintf(msg, args...) } diff --git a/bridgev2/cmdhelp.go b/bridgev2/cmdhelp.go index 53d5076e..043d487c 100644 --- a/bridgev2/cmdhelp.go +++ b/bridgev2/cmdhelp.go @@ -108,7 +108,7 @@ func FormatHelp(ce *CommandEvent) string { } else { prefixMsg = "This is not your management room: prefixing commands with `%s` is required." } - _, _ = fmt.Fprintf(&output, prefixMsg, ce.Bridge.CommandPrefix) + _, _ = fmt.Fprintf(&output, prefixMsg, ce.Bridge.Config.CommandPrefix) output.WriteByte('\n') output.WriteString("Parameters in [square brackets] are optional, while parameters in are required.") output.WriteByte('\n') diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index 79eff612..762856d2 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -178,11 +178,12 @@ func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, if isBot != nil { ghost.Metadata.IsBot = *isBot } + bridgeName := ghost.Bridge.Network.GetName() meta := &event.BeeperProfileExtra{ RemoteID: string(ghost.ID), Identifiers: ghost.Metadata.Identifiers, - Service: "", // TODO set - Network: "", // TODO set + Service: bridgeName.BeeperBridgeType, + Network: bridgeName.NetworkID, IsBridgeBot: false, IsNetworkBot: ghost.Metadata.IsBot, } diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 7e9cc892..3eb9176d 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -8,7 +8,11 @@ package matrix import ( "context" + "encoding/json" + "errors" + "os" "regexp" + "strings" "sync" "time" @@ -52,7 +56,9 @@ type Connector struct { Bridge *bridgev2.Bridge Provisioning *ProvisioningAPI - SpecVersions *mautrix.RespVersions + MediaConfig mautrix.RespMediaConfig + SpecVersions *mautrix.RespVersions + IgnoreUnsupportedServer bool EventProcessor *appservice.EventProcessor @@ -66,10 +72,13 @@ type Connector struct { wsStartupWait *sync.WaitGroup } +var _ bridgev2.MatrixConnector = (*Connector)(nil) + func NewConnector(cfg *bridgeconfig.Config) *Connector { c := &Connector{} c.Config = cfg c.userIDRegex = cfg.MakeUserIDRegex("(.+)") + c.MediaConfig.UploadSize = 50 * 1024 * 1024 return c } @@ -81,6 +90,9 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.AS.Log = bridge.Log br.AS.StateStore = br.StateStore br.EventProcessor = appservice.NewEventProcessor(br.AS) + if !br.Config.AppService.AsyncTransactions { + br.EventProcessor.ExecMode = appservice.Sync + } for evtType := range status.CheckpointTypes { br.EventProcessor.On(evtType, br.sendBridgeCheckpoint) } @@ -98,27 +110,181 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { func (br *Connector) Start(ctx context.Context) error { br.Provisioning.Init() - br.EventProcessor.Start(ctx) err := br.StateStore.Upgrade(ctx) if err != nil { - return err + return bridgev2.DBUpgradeError{Section: "matrix_state", Err: err} } go br.AS.Start() - br.SpecVersions, err = br.Bot.Versions(ctx) - if err != nil { - return err - } + br.ensureConnection(ctx) + go br.fetchMediaConfig(ctx) if br.Crypto != nil { err = br.Crypto.Init(ctx) if err != nil { return err } - br.Crypto.Start() } + br.EventProcessor.Start(ctx) + go br.UpdateBotProfile(ctx) + if br.Crypto != nil { + go br.Crypto.Start() + } + br.AS.Ready = true return nil } -var _ bridgev2.MatrixConnector = (*Connector)(nil) +var MinSpecVersion = mautrix.SpecV14 + +func (br *Connector) ensureConnection(ctx context.Context) { + for { + versions, err := br.Bot.Versions(ctx) + if err != nil { + br.Log.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...") + time.Sleep(10 * time.Second) + } else { + br.SpecVersions = versions + break + } + } + + unsupportedServerLogLevel := zerolog.FatalLevel + if br.IgnoreUnsupportedServer { + unsupportedServerLogLevel = zerolog.ErrorLevel + } + if br.Config.Homeserver.Software == bridgeconfig.SoftwareHungry && !br.SpecVersions.Supports(mautrix.BeeperFeatureHungry) { + br.Log.WithLevel(zerolog.FatalLevel).Msg("The config claims the homeserver is hungryserv, but the /versions response didn't confirm it") + os.Exit(18) + } else if !br.SpecVersions.ContainsGreaterOrEqual(MinSpecVersion) { + br.Log.WithLevel(unsupportedServerLogLevel). + Stringer("server_supports", br.SpecVersions.GetLatest()). + Stringer("bridge_requires", MinSpecVersion). + Msg("The homeserver is outdated (supported spec versions are below minimum required by bridge)") + if !br.IgnoreUnsupportedServer { + os.Exit(18) + } + } + + resp, err := br.Bot.Whoami(ctx) + if err != nil { + if errors.Is(err, mautrix.MUnknownToken) { + br.Log.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?") + br.Log.Info().Msg("See https://docs.mau.fi/faq/as-token for more info") + } else if errors.Is(err, mautrix.MExclusive) { + br.Log.WithLevel(zerolog.FatalLevel).Msg("The as_token was accepted, but the /register request was not. Are the homeserver domain, bot username and username template in the config correct, and do they match the values in the registration?") + br.Log.Info().Msg("See https://docs.mau.fi/faq/as-register for more info") + } else { + br.Log.WithLevel(zerolog.FatalLevel).Err(err).Msg("/whoami request failed with unknown error") + } + os.Exit(16) + } else if resp.UserID != br.Bot.UserID { + br.Log.WithLevel(zerolog.FatalLevel). + Stringer("got_user_id", resp.UserID). + Stringer("expected_user_id", br.Bot.UserID). + Msg("Unexpected user ID in whoami call") + os.Exit(17) + } + + if br.Websocket { + br.Log.Debug().Msg("Websocket mode: no need to check status of homeserver -> bridge connection") + return + } else if !br.SpecVersions.Supports(mautrix.FeatureAppservicePing) { + br.Log.Debug().Msg("Homeserver does not support checking status of homeserver -> bridge connection") + return + } + var pingResp *mautrix.RespAppservicePing + var txnID string + var retryCount int + const maxRetries = 6 + for { + txnID = br.Bot.TxnID() + pingResp, err = br.Bot.AppservicePing(ctx, br.Config.AppService.ID, txnID) + if err == nil { + break + } + var httpErr mautrix.HTTPError + var pingErrBody string + if errors.As(err, &httpErr) && httpErr.RespError != nil { + if val, ok := httpErr.RespError.ExtraData["body"].(string); ok { + pingErrBody = strings.TrimSpace(val) + } + } + outOfRetries := retryCount >= maxRetries + level := zerolog.ErrorLevel + if outOfRetries { + level = zerolog.FatalLevel + } + evt := br.Log.WithLevel(level).Err(err).Str("txn_id", txnID) + if pingErrBody != "" { + bodyBytes := []byte(pingErrBody) + if json.Valid(bodyBytes) { + evt.RawJSON("body", bodyBytes) + } else { + evt.Str("body", pingErrBody) + } + } + if outOfRetries { + evt.Msg("Homeserver -> bridge connection is not working") + br.Log.Info().Msg("See https://docs.mau.fi/faq/as-ping for more info") + os.Exit(13) + } + evt.Msg("Homeserver -> bridge connection is not working, retrying in 5 seconds...") + time.Sleep(5 * time.Second) + retryCount++ + } + br.Log.Debug(). + Str("txn_id", txnID). + Int64("duration_ms", pingResp.DurationMS). + Msg("Homeserver -> bridge connection works") +} + +func (br *Connector) fetchMediaConfig(ctx context.Context) { + cfg, err := br.Bot.GetMediaConfig(ctx) + if err != nil { + br.Log.Warn().Err(err).Msg("Failed to fetch media config") + } else { + if cfg.UploadSize == 0 { + cfg.UploadSize = 50 * 1024 * 1024 + } + br.MediaConfig = *cfg + } +} + +func (br *Connector) UpdateBotProfile(ctx context.Context) { + br.Log.Debug().Msg("Updating bot profile") + botConfig := &br.Config.AppService.Bot + + var err error + var mxc id.ContentURI + if botConfig.Avatar == "remove" { + err = br.Bot.SetAvatarURL(ctx, mxc) + } else if !botConfig.ParsedAvatar.IsEmpty() { + err = br.Bot.SetAvatarURL(ctx, botConfig.ParsedAvatar) + } + if err != nil { + br.Log.Warn().Err(err).Msg("Failed to update bot avatar") + } + + if botConfig.Displayname == "remove" { + err = br.Bot.SetDisplayName(ctx, "") + } else if len(botConfig.Displayname) > 0 { + err = br.Bot.SetDisplayName(ctx, botConfig.Displayname) + } + if err != nil { + br.Log.Warn().Err(err).Msg("Failed to update bot displayname") + } + + if br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) { + br.Log.Debug().Msg("Setting contact info on the appservice bot") + netName := br.Bridge.Network.GetName() + err = br.Bot.BeeperUpdateProfile(ctx, event.BeeperProfileExtra{ + Service: netName.BeeperBridgeType, + Network: netName.NetworkID, + IsBridgeBot: true, + }) + if err != nil { + br.Log.Warn().Err(err).Msg("Failed to update bot contact info") + } + } +} func (br *Connector) GhostIntent(userID id.UserID) bridgev2.MatrixAPI { return &ASIntent{ @@ -137,13 +303,13 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 if err != nil { log.Err(err).Msg("Failed to send message checkpoint") } - if br.Config.Bridge.MessageStatusEvents { + if br.Config.Matrix.MessageStatusEvents { _, err = br.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, ms.ToMSSEvent(evt)) if err != nil { log.Err(err).Msg("Failed to send MSS event") } } - if ms.SendNotice && br.Config.Bridge.MessageErrorNotices && (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) { + if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) { content := ms.ToNoticeEvent(evt) if editEvent != "" { content.SetEdit(editEvent) @@ -155,6 +321,12 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 return resp.EventID } } + if ms.Status == event.MessageStatusSuccess && br.Config.Matrix.DeliveryReceipts { + err = br.Bot.SendReceipt(ctx, evt.RoomID, evt.EventID, event.ReceiptTypeRead, nil) + if err != nil { + log.Err(err).Msg("Failed to send Matrix delivery receipt") + } + } return "" } diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index 008c2f3b..56b0e179 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -21,6 +21,7 @@ import ( "go.mau.fi/util/dbutil" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/crypto" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/event" @@ -77,9 +78,7 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { err := helper.store.DB.Upgrade(ctx) if err != nil { - // TODO copy this function back - //helper.bridge.LogDBUpgradeErrorAndExit("crypto", err) - panic(err) + return bridgev2.DBUpgradeError{Section: "crypto", Err: err} } var isExistingDevice bool diff --git a/bridgev2/matrix/doublepuppet.go b/bridgev2/matrix/doublepuppet.go index 4c1aca6a..3c0f65e2 100644 --- a/bridgev2/matrix/doublepuppet.go +++ b/bridgev2/matrix/doublepuppet.go @@ -74,9 +74,7 @@ func (dp *doublePuppetUtil) autoLogin(ctx context.Context, mxid id.UserID, login if err != nil { return "", fmt.Errorf("failed to create mautrix client to log in: %v", err) } - // TODO proper bridge name - //bridgeName := fmt.Sprintf("%s Bridge", dp.br.ProtocolName) - bridgeName := "Megabridge" + bridgeName := fmt.Sprintf("%s Bridge", dp.br.Bridge.Network.GetName().DisplayName) req := mautrix.ReqLogin{ Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(mxid)}, DeviceID: id.DeviceID(bridgeName), diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 800d8b75..25856cb0 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -172,6 +172,12 @@ func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) }, }) } + if !as.Connector.Config.Matrix.FederateRooms { + if req.CreationContent == nil { + req.CreationContent = make(map[string]any) + } + req.CreationContent["m.federate"] = false + } resp, err := as.Matrix.CreateRoom(ctx, req) if err != nil { return "", err diff --git a/bridgev2/matrix/mxmain/config.go b/bridgev2/matrix/mxmain/config.go new file mode 100644 index 00000000..a684d8a2 --- /dev/null +++ b/bridgev2/matrix/mxmain/config.go @@ -0,0 +1,36 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mxmain + +import ( + _ "embed" + "strings" + "text/template" + + "go.mau.fi/util/exerrors" +) + +//go:embed example-config.yaml +var MatrixExampleConfigBase string + +var matrixExampleConfigBaseTemplate = exerrors.Must(template.New("example-config.yaml"). + Delims("$<<", ">>"). + Parse(MatrixExampleConfigBase)) + +func (br *BridgeMain) makeFullExampleConfig(networkExample string) string { + var buf strings.Builder + buf.WriteString("# Network-specific config options\n") + buf.WriteString("network:\n") + for _, line := range strings.Split(networkExample, "\n") { + buf.WriteString(" ") + buf.WriteString(line) + buf.WriteRune('\n') + } + buf.WriteRune('\n') + exerrors.PanicIfNotNil(matrixExampleConfigBaseTemplate.Execute(&buf, br.Connector.GetName())) + return buf.String() +} diff --git a/bridgev2/matrix/mxmain/dberror.go b/bridgev2/matrix/mxmain/dberror.go new file mode 100644 index 00000000..eb34ccfa --- /dev/null +++ b/bridgev2/matrix/mxmain/dberror.go @@ -0,0 +1,75 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mxmain + +import ( + "errors" + "os" + + "github.com/lib/pq" + "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog" + + "go.mau.fi/util/dbutil" +) + +type zerologPQError pq.Error + +func (zpe *zerologPQError) MarshalZerologObject(evt *zerolog.Event) { + maybeStr := func(field, value string) { + if value != "" { + evt.Str(field, value) + } + } + maybeStr("severity", zpe.Severity) + if name := zpe.Code.Name(); name != "" { + evt.Str("code", name) + } else if zpe.Code != "" { + evt.Str("code", string(zpe.Code)) + } + //maybeStr("message", zpe.Message) + maybeStr("detail", zpe.Detail) + maybeStr("hint", zpe.Hint) + maybeStr("position", zpe.Position) + maybeStr("internal_position", zpe.InternalPosition) + maybeStr("internal_query", zpe.InternalQuery) + maybeStr("where", zpe.Where) + maybeStr("schema", zpe.Schema) + maybeStr("table", zpe.Table) + maybeStr("column", zpe.Column) + maybeStr("data_type_name", zpe.DataTypeName) + maybeStr("constraint", zpe.Constraint) + maybeStr("file", zpe.File) + maybeStr("line", zpe.Line) + maybeStr("routine", zpe.Routine) +} + +func (br *BridgeMain) LogDBUpgradeErrorAndExit(name string, err error) { + logEvt := br.Log.WithLevel(zerolog.FatalLevel). + Err(err). + Str("db_section", name) + var errWithLine *dbutil.PQErrorWithLine + if errors.As(err, &errWithLine) { + logEvt.Str("sql_line", errWithLine.Line) + } + var pqe *pq.Error + if errors.As(err, &pqe) { + logEvt.Object("pq_error", (*zerologPQError)(pqe)) + } + logEvt.Msg("Failed to initialize database") + if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { + os.Exit(18) + } else if errors.Is(err, dbutil.ErrForeignTables) { + br.Log.Info().Msg("You can use --ignore-foreign-tables to ignore this error") + br.Log.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info") + } else if errors.Is(err, dbutil.ErrNotOwned) { + br.Log.Info().Msg("Sharing the same database with different programs is not supported") + } else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) { + br.Log.Info().Msg("Downgrading the bridge is not supported") + } + os.Exit(15) +} diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml new file mode 100644 index 00000000..06d9f8ca --- /dev/null +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -0,0 +1,217 @@ +# Config options that affect the central bridge module. +bridge: + # The prefix for commands. Only required in non-management rooms. + command_prefix: '$<>' + +# Config for the bridge's database. +database: + # The database type. "sqlite3-fk-wal" and "postgres" are supported. + type: postgres + # The database URI. + # SQLite: A raw file path is supported, but `file:?_txlock=immediate` is recommended. + # https://github.com/mattn/go-sqlite3#connection-string + # Postgres: Connection string. For example, postgres://user:password@host/database?sslmode=disable + # To connect via Unix socket, use something like postgres:///dbname?host=/var/run/postgresql + uri: postgres://user:password@host/database?sslmode=disable + # Maximum number of connections. + max_open_conns: 5 + max_idle_conns: 1 + # Maximum connection idle time and lifetime before they're closed. Disabled if null. + # Parsed with https://pkg.go.dev/time#ParseDuration + max_conn_idle_time: null + max_conn_lifetime: null + +# Homeserver details. +homeserver: + # The address that this appservice can use to connect to the homeserver. + # Local addresses without HTTPS are generally recommended when the bridge is running on the same machine, + # but https also works if they run on different machines. + address: http://example.localhost:8008 + # The domain of the homeserver (also known as server_name, used for MXIDs, etc). + domain: example.com + + # What software is the homeserver running? + # Standard Matrix homeservers like Synapse, Dendrite and Conduit should just use "standard" here. + software: standard + # The URL to push real-time bridge status to. + # If set, the bridge will make POST requests to this URL whenever a user's remote network connection state changes. + # The bridge will use the appservice as_token to authorize requests. + status_endpoint: http://localhost:4001 + # Endpoint for reporting per-message status. + # If set, the bridge will make POST requests to this URL when processing a message from Matrix. + # It will make one request when receiving the message (step BRIDGE), one after decrypting if applicable + # (step DECRYPTED) and one after sending to the remote network (step REMOTE). Errors will also be reported. + # The bridge will use the appservice as_token to authorize requests. + message_send_checkpoint_endpoint: http://localhost:4001 + # Does the homeserver support https://github.com/matrix-org/matrix-spec-proposals/pull/2246? + async_media: false + + # Should the bridge use a websocket for connecting to the homeserver? + # The server side is currently not documented anywhere and is only implemented by mautrix-wsproxy, + # mautrix-asmux (deprecated), and hungryserv (proprietary). + websocket: false + # How often should the websocket be pinged? Pinging will be disabled if this is zero. + ping_interval_seconds: 0 + +# Application service host/registration related details. +# Changing these values requires regeneration of the registration. +appservice: + # The address that the homeserver can use to connect to this appservice. + address: http://localhost:$<> + + # The hostname and port where this appservice should listen. + # For Docker, you generally have to change the hostname to 0.0.0.0. + hostname: 127.0.0.1 + port: $<> + + # The unique ID of this appservice. + id: $<<.NetworkID>> + # Appservice bot details. + bot: + # Username of the appservice bot. + username: $<<.NetworkID>>bot + # Display name and avatar for bot. Set to "remove" to remove display name/avatar, leave empty + # to leave display name/avatar as-is. + displayname: $<<.DisplayName>> bridge bot + avatar: $<<.NetworkIcon>> + + # Whether to receive ephemeral events via appservice transactions. + ephemeral_events: true + # Should incoming events be handled asynchronously? + # This may be necessary for large public instances with lots of messages going through. + # However, messages will not be guaranteed to be bridged in the same order they were sent in. + async_transactions: false + + # Authentication tokens for AS <-> HS communication. Autogenerated; do not modify. + as_token: "This value is generated when generating the registration" + hs_token: "This value is generated when generating the registration" + + # Localpart template of MXIDs for remote users. + # {{.}} is replaced with the internal ID of the user. + username_template: $<<.NetworkID>>_{{.}} + +# Config options that affect the Matrix connector of the bridge. +matrix: + # Whether the bridge should send the message status as a custom com.beeper.message_send_status event. + message_status_events: false + # Whether the bridge should send a read receipt after successfully bridging a message. + delivery_receipts: false + # Whether the bridge should send error notices via m.notice events when a message fails to bridge. + message_error_notices: true + sync_direct_chat_list: false + # Whether created rooms should have federation enabled. If false, created portal rooms + # will never be federated. Changing this option requires recreating rooms. + federate_rooms: false + +# Settings for provisioning API +provisioning: + # Prefix for the provisioning API paths. + prefix: /_matrix/provision + # Shared secret for authentication. If set to "generate" or null, a random secret will be generated, + # or if set to "disable", the provisioning API will be disabled. + shared_secret: generate + # Enable debug API at /debug with provisioning authentication. + debug_endpoints: true + +# Settings for enabling double puppeting +double_puppet: + # Servers to always allow double puppeting from. + # This is only for other servers and should NOT contain the server the bridge is on. + servers: + matrix.org: https://matrix-client.matrix.org + # Whether to allow client API URL discovery for other servers. When using this option, + # users on other servers can use double puppeting even if their server URLs aren't + # explicitly added to the servers map above. + allow_discovery: false + # Shared secrets for automatic double puppeting. + # See https://docs.mau.fi/bridges/general/double-puppeting.html for instructions. + secrets: + example.com: as_token:foobar + +# End-to-bridge encryption support options. +# +# See https://docs.mau.fi/bridges/general/end-to-bridge-encryption.html for more info. +encryption: + # Whether to enable encryption at all. If false, the bridge will not function in encrypted rooms. + allow: false + # Whether to force-enable encryption in all bridged rooms. + default: false + # Whether to require all messages to be encrypted and drop any unencrypted messages. + require: false + # Whether to use MSC2409/MSC3202 instead of /sync long polling for receiving encryption-related data. + # This option is not yet compatible with standard Matrix servers like Synapse and should not be used. + appservice: false + # Enable key sharing? If enabled, key requests for rooms where users are in will be fulfilled. + # You must use a client that supports requesting keys from other users to use this feature. + allow_key_sharing: true + # Pickle key for encrypting encryption keys in the bridge database. + # If set to generate, a random key will be generated. + pickle_key: generate + # Options for deleting megolm sessions from the bridge. + delete_keys: + # Beeper-specific: delete outbound sessions when hungryserv confirms + # that the user has uploaded the key to key backup. + delete_outbound_on_ack: false + # Don't store outbound sessions in the inbound table. + dont_store_outbound: false + # Ratchet megolm sessions forward after decrypting messages. + ratchet_on_decrypt: false + # Delete fully used keys (index >= max_messages) after decrypting messages. + delete_fully_used_on_decrypt: false + # Delete previous megolm sessions from same device when receiving a new one. + delete_prev_on_new_session: false + # Delete megolm sessions received from a device when the device is deleted. + delete_on_device_delete: false + # Periodically delete megolm sessions when 2x max_age has passed since receiving the session. + periodically_delete_expired: false + # Delete inbound megolm sessions that don't have the received_at field used for + # automatic ratcheting and expired session deletion. This is meant as a migration + # to delete old keys prior to the bridge update. + delete_outdated_inbound: false + # What level of device verification should be required from users? + # + # Valid levels: + # unverified - Send keys to all device in the room. + # cross-signed-untrusted - Require valid cross-signing, but trust all cross-signing keys. + # cross-signed-tofu - Require valid cross-signing, trust cross-signing keys on first use (and reject changes). + # cross-signed-verified - Require valid cross-signing, plus a valid user signature from the bridge bot. + # Note that creating user signatures from the bridge bot is not currently possible. + # verified - Require manual per-device verification + # (currently only possible by modifying the `trust` column in the `crypto_device` database table). + verification_levels: + # Minimum level for which the bridge should send keys to when bridging messages from the remote network to Matrix. + receive: unverified + # Minimum level that the bridge should accept for incoming Matrix messages. + send: unverified + # Minimum level that the bridge should require for accepting key requests. + share: cross-signed-tofu + # Options for Megolm room key rotation. These options allow you to configure the m.room.encryption event content. + # See https://spec.matrix.org/v1.10/client-server-api/#mroomencryption for more information about that event. + rotation: + # Enable custom Megolm room key rotation settings. Note that these + # settings will only apply to rooms created after this option is set. + enable_custom: false + # The maximum number of milliseconds a session should be used + # before changing it. The Matrix spec recommends 604800000 (a week) + # as the default. + milliseconds: 604800000 + # The maximum number of messages that should be sent with a given a + # session before changing it. The Matrix spec recommends 100 as the + # default. + messages: 100 + # Disable rotating keys when a user's devices change? + # You should not enable this option unless you understand all the implications. + disable_device_change_key_rotation: false + +# Logging config. See https://github.com/tulir/zeroconfig for details. +logging: + min_level: debug + writers: + - type: stdout + format: pretty-colored + - type: file + format: json + filename: ./logs/bridge.log + max_size: 100 + max_backups: 10 + compress: false diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go new file mode 100644 index 00000000..85269fce --- /dev/null +++ b/bridgev2/matrix/mxmain/main.go @@ -0,0 +1,348 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mxmain + +import ( + _ "embed" + "encoding/json" + "errors" + "fmt" + "os" + "os/signal" + "runtime" + "strings" + "syscall" + "time" + + "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog" + "go.mau.fi/util/configupgrade" + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exzerolog" + "gopkg.in/yaml.v3" + flag "maunium.net/go/mauflag" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/bridgeconfig" + "maunium.net/go/mautrix/bridgev2/matrix" +) + +var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String() +var writeExampleConfig = flag.MakeFull("e", "generate-example-config", "Save the example config to the config path and quit.", "false").Bool() +var dontSaveConfig = flag.MakeFull("n", "no-update", "Don't save updated config to disk.", "false").Bool() +var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String() +var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool() +var version = flag.MakeFull("v", "version", "View bridge version and quit.", "false").Bool() +var versionJSON = flag.Make().LongKey("version-json").Usage("Print a JSON object representing the bridge version and quit.").Default("false").Bool() +var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if the database schema is too new").Default("false").Bool() +var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Run even if the database contains tables from other programs (like Synapse)").Default("false").Bool() +var ignoreUnsupportedServer = flag.Make().LongKey("ignore-unsupported-server").Usage("Run even if the Matrix homeserver is outdated").Default("false").Bool() +var wantHelp, _ = flag.MakeHelpFlag() + +type BridgeMain struct { + Name string + Description string + URL string + Version string + + PostInit func() + + Connector bridgev2.NetworkConnector + Log *zerolog.Logger + DB *dbutil.Database + Config *bridgeconfig.Config + Matrix *matrix.Connector + Bridge *bridgev2.Bridge + + ConfigPath string + RegistrationPath string + SaveConfig bool + + baseVersion string + commit string + LinkifiedVersion string + VersionDesc string + BuildTime time.Time + + AdditionalShortFlags string + AdditionalLongFlags string +} + +type VersionJSONOutput struct { + Name string + URL string + + Version string + IsRelease bool + Commit string + FormattedVersion string + BuildTime time.Time + + OS string + Arch string + + Mautrix struct { + Version string + Commit string + } +} + +func (br *BridgeMain) Run() { + flag.SetHelpTitles( + fmt.Sprintf("%s - %s", br.Name, br.Description), + fmt.Sprintf("%s [-hgvn%s] [-c ] [-r ]%s", br.Name, br.AdditionalShortFlags, br.AdditionalLongFlags)) + err := flag.Parse() + br.ConfigPath = *configPath + br.RegistrationPath = *registrationPath + br.SaveConfig = !*dontSaveConfig + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, err) + flag.PrintHelp() + os.Exit(1) + } else if *wantHelp { + flag.PrintHelp() + os.Exit(0) + } else if *version { + fmt.Println(br.VersionDesc) + return + } else if *versionJSON { + output := VersionJSONOutput{ + URL: br.URL, + Name: br.Name, + + Version: br.baseVersion, + IsRelease: br.Version == br.baseVersion, + Commit: br.commit, + FormattedVersion: br.Version, + BuildTime: br.BuildTime, + + OS: runtime.GOOS, + Arch: runtime.GOARCH, + } + output.Mautrix.Commit = mautrix.Commit + output.Mautrix.Version = mautrix.Version + _ = json.NewEncoder(os.Stdout).Encode(output) + return + } else if *writeExampleConfig { + networkExample, _, _ := br.Connector.GetConfig() + exerrors.PanicIfNotNil(os.WriteFile(*configPath, []byte(br.makeFullExampleConfig(networkExample)), 0600)) + return + } + + br.loadConfig() + if *generateRegistration { + br.GenerateRegistration() + return + } + + br.Init() + err = br.Bridge.Start() + if err != nil { + var dbUpgradeErr bridgev2.DBUpgradeError + if errors.As(err, &dbUpgradeErr) { + br.LogDBUpgradeErrorAndExit(dbUpgradeErr.Section, dbUpgradeErr.Err) + } else { + br.Log.Fatal().Err(err).Msg("Failed to start bridge") + } + } + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c + br.Log.Info().Msg("Shutting down bridge") +} + +func (br *BridgeMain) GenerateRegistration() { + if !br.SaveConfig { + // We need to save the generated as_token and hs_token in the config + _, _ = fmt.Fprintln(os.Stderr, "--no-update is not compatible with --generate-registration") + os.Exit(5) + } else if br.Config.Homeserver.Domain == "example.com" { + _, _ = fmt.Fprintln(os.Stderr, "Homeserver domain is not set") + os.Exit(20) + } + reg := br.Config.GenerateRegistration() + err := reg.Save(br.RegistrationPath) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to save registration:", err) + os.Exit(21) + } + + updateTokens := func(helper configupgrade.Helper) { + helper.Set(configupgrade.Str, reg.AppToken, "appservice", "as_token") + helper.Set(configupgrade.Str, reg.ServerToken, "appservice", "hs_token") + } + upgrader, _ := br.getConfigUpgrader() + _, _, err = configupgrade.Do(br.ConfigPath, true, upgrader, configupgrade.SimpleUpgrader(updateTokens)) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to save config:", err) + os.Exit(22) + } + fmt.Println("Registration generated. See https://docs.mau.fi/bridges/general/registering-appservices.html for instructions on installing the registration.") + os.Exit(0) +} + +func (br *BridgeMain) Init() { + var err error + br.Log, err = br.Config.Logging.Compile() + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to initialize logger:", err) + os.Exit(12) + } + err = br.validateConfig() + if err != nil { + br.Log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Configuration error") + br.Log.Info().Msg("See https://docs.mau.fi/faq/field-unconfigured for more info") + os.Exit(11) + } + + br.Log.Info(). + Str("name", br.Name). + Str("version", br.Version). + Time("built_at", br.BuildTime). + Str("go_version", runtime.Version()). + Msg("Initializing bridge") + + exzerolog.SetupDefaults(br.Log) + br.initDB() + br.Matrix = matrix.NewConnector(br.Config) + br.Matrix.IgnoreUnsupportedServer = *ignoreUnsupportedServer + br.Bridge = bridgev2.NewBridge("", br.DB, *br.Log, &br.Config.Bridge, br.Matrix, br.Connector) + if br.PostInit != nil { + br.PostInit() + } +} + +func (br *BridgeMain) initDB() { + br.Log.Debug().Msg("Initializing database connection") + dbConfig := br.Config.Database + if (dbConfig.Type == "sqlite3-fk-wal" || dbConfig.Type == "litestream") && dbConfig.MaxOpenConns != 1 && !strings.Contains(dbConfig.URI, "_txlock=immediate") { + var fixedExampleURI string + if !strings.HasPrefix(dbConfig.URI, "file:") { + fixedExampleURI = fmt.Sprintf("file:%s?_txlock=immediate", dbConfig.URI) + } else if !strings.ContainsRune(dbConfig.URI, '?') { + fixedExampleURI = fmt.Sprintf("%s?_txlock=immediate", dbConfig.URI) + } else { + fixedExampleURI = fmt.Sprintf("%s&_txlock=immediate", dbConfig.URI) + } + br.Log.Warn(). + Str("fixed_uri_example", fixedExampleURI). + Msg("Using SQLite without _txlock=immediate is not recommended") + } + var err error + br.DB, err = dbutil.NewFromConfig(br.Name, dbConfig, dbutil.ZeroLogger(br.Log.With().Str("db_section", "main").Logger())) + if err != nil { + br.Log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to initialize database connection") + if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { + os.Exit(18) + } + os.Exit(14) + } + br.DB.IgnoreUnsupportedDatabase = *ignoreUnsupportedDatabase + br.DB.IgnoreForeignTables = *ignoreForeignTables +} + +func (br *BridgeMain) validateConfig() error { + switch { + case br.Config.Homeserver.Address == "http://example.localhost:8008": + return errors.New("homeserver.address not configured") + case br.Config.Homeserver.Domain == "example.com": + return errors.New("homeserver.domain not configured") + case !bridgeconfig.AllowedHomeserverSoftware[br.Config.Homeserver.Software]: + return errors.New("invalid value for homeserver.software (use `standard` if you don't know what the field is for)") + case br.Config.AppService.ASToken == "This value is generated when generating the registration": + return errors.New("appservice.as_token not configured. Did you forget to generate the registration? ") + case br.Config.AppService.HSToken == "This value is generated when generating the registration": + return errors.New("appservice.hs_token not configured. Did you forget to generate the registration? ") + case br.Config.Database.URI == "postgres://user:password@host/database?sslmode=disable": + return errors.New("appservice.database not configured") + default: + cfgValidator, ok := br.Connector.(bridgev2.ConfigValidatingNetwork) + if ok { + err := cfgValidator.ValidateConfig() + if err != nil { + return err + } + } + return nil + } +} + +func (br *BridgeMain) getConfigUpgrader() (configupgrade.BaseUpgrader, any) { + networkExample, networkData, networkUpgrader := br.Connector.GetConfig() + baseConfig := br.makeFullExampleConfig(networkExample) + networkUpgraderProxied := &configupgrade.ProxyUpgrader{Target: networkUpgrader, Prefix: []string{"network"}} + upgrader := configupgrade.MergeUpgraders(baseConfig, networkUpgraderProxied, bridgeconfig.Upgrader) + return upgrader, networkData +} + +func (br *BridgeMain) loadConfig() { + upgrader, networkData := br.getConfigUpgrader() + configData, upgraded, err := configupgrade.Do(br.ConfigPath, br.SaveConfig, upgrader) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Error updating config:", err) + if !upgraded { + os.Exit(10) + } + } + + var cfg bridgeconfig.Config + err = yaml.Unmarshal(configData, &cfg) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to parse config:", err) + os.Exit(10) + } + err = cfg.Network.Decode(networkData) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to parse network config:", err) + os.Exit(10) + } + br.Config = &cfg +} + +func (br *BridgeMain) InitVersion(tag, commit, rawBuildTime string) { + br.baseVersion = br.Version + if len(tag) > 0 && tag[0] == 'v' { + tag = tag[1:] + } + if tag != br.Version { + suffix := "" + if !strings.HasSuffix(br.Version, "+dev") { + suffix = "+dev" + } + if len(commit) > 8 { + br.Version = fmt.Sprintf("%s%s.%s", br.Version, suffix, commit[:8]) + } else { + br.Version = fmt.Sprintf("%s%s.unknown", br.Version, suffix) + } + } + + br.LinkifiedVersion = fmt.Sprintf("v%s", br.Version) + if tag == br.Version { + br.LinkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", br.Version, br.URL, tag) + } else if len(commit) > 8 { + br.LinkifiedVersion = strings.Replace(br.LinkifiedVersion, commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", commit[:8], br.URL, commit), 1) + } + var buildTime time.Time + if rawBuildTime != "unknown" { + buildTime, _ = time.Parse(time.RFC3339, rawBuildTime) + } + var builtWith string + if buildTime.IsZero() { + rawBuildTime = "unknown" + builtWith = runtime.Version() + } else { + rawBuildTime = buildTime.Format(time.RFC1123) + builtWith = fmt.Sprintf("built at %s with %s", rawBuildTime, runtime.Version()) + } + mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.Version, mautrix.DefaultUserAgent) + br.VersionDesc = fmt.Sprintf("%s %s (%s)", br.Name, br.Version, builtWith) + br.commit = commit + br.BuildTime = buildTime +} diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index b30a764c..5b07ab63 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -11,10 +11,12 @@ import ( "time" "github.com/rs/zerolog" + "go.mau.fi/util/configupgrade" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" ) type ConvertedMessagePart struct { @@ -54,15 +56,49 @@ type ConvertedEdit struct { DeletedParts []*database.Message } +type BridgeName struct { + // The displayname of the network, e.g. `Discord` + DisplayName string + // The URL to the website of the network, e.g. `https://discord.com` + NetworkURL string + // The icon of the network as a mxc:// URI + NetworkIcon id.ContentURIString + // An identifier uniquely identifying the network, e.g. `discord` + NetworkID string + // An identifier uniquely identifying the bridge software, e.g. `discordgo` + BeeperBridgeType string + // The default appservice port to use in the example config, defaults to 8080 if unset + DefaultPort uint16 + // The default command prefix to use in the example config, defaults to NetworkID if unset. Must include the ! prefix. + DefaultCommandPrefix string +} + +func (bn BridgeName) AsBridgeInfoSection() event.BridgeInfoSection { + return event.BridgeInfoSection{ + ID: bn.BeeperBridgeType, + DisplayName: bn.DisplayName, + AvatarURL: bn.NetworkIcon, + ExternalURL: bn.NetworkURL, + } +} + type NetworkConnector interface { Init(*Bridge) Start(context.Context) error LoadUserLogin(ctx context.Context, login *UserLogin) error + GetName() BridgeName + GetConfig() (example string, data any, upgrader configupgrade.Upgrader) + GetLoginFlows() []LoginFlow CreateLogin(ctx context.Context, user *User, flowID string) (LoginProcess, error) } +type ConfigValidatingNetwork interface { + NetworkConnector + ValidateConfig() error +} + type NetworkAPI interface { Connect(ctx context.Context) error IsLoggedIn() bool diff --git a/bridgev2/portal.go b/bridgev2/portal.go index dab5129c..504d3359 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1111,12 +1111,7 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { bridgeInfo := event.BridgeEventContent{ BridgeBot: portal.Bridge.Bot.GetMXID(), Creator: portal.Bridge.Bot.GetMXID(), - Protocol: event.BridgeInfoSection{ - ID: "signal", // TODO fill properly - DisplayName: "Signal", // TODO fill properly - AvatarURL: "", // TODO fill properly - ExternalURL: "https://signal.org/", // TODO fill properly - }, + Protocol: portal.Bridge.Network.GetName().AsBridgeInfoSection(), Channel: event.BridgeInfoSection{ ID: string(portal.ID), DisplayName: portal.Name, diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 02fc3895..cf4d27e8 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -43,13 +43,13 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { if msg != nil { msg.RemoveReplyFallback() - if strings.HasPrefix(msg.Body, br.CommandPrefix) || evt.RoomID == sender.ManagementRoom { + if strings.HasPrefix(msg.Body, br.Config.CommandPrefix) || evt.RoomID == sender.ManagementRoom { br.Commands.Handle( ctx, evt.RoomID, evt.ID, sender, - strings.TrimPrefix(msg.Body, br.CommandPrefix+" "), + strings.TrimPrefix(msg.Body, br.Config.CommandPrefix+" "), msg.RelatesTo.GetReplyTo(), ) return diff --git a/go.mod b/go.mod index ef64ece3..82129a54 100644 --- a/go.mod +++ b/go.mod @@ -8,17 +8,17 @@ require ( github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.22 github.com/rs/xid v1.5.0 - github.com/rs/zerolog v1.32.0 + github.com/rs/zerolog v1.33.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.1 - go.mau.fi/util v0.4.3-0.20240516141139-2ebe792cd8f7 + go.mau.fi/util v0.4.3-0.20240611114927-6ef09885dd97 go.mau.fi/zeroconfig v0.1.2 - golang.org/x/crypto v0.23.0 - golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 - golang.org/x/net v0.25.0 + golang.org/x/crypto v0.24.0 + golang.org/x/exp v0.0.0-20240604190554-fc45aab8b7f8 + golang.org/x/net v0.26.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) @@ -31,6 +31,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.20.0 // indirect + golang.org/x/sys v0.21.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index c7295275..dce463ef 100644 --- a/go.sum +++ b/go.sum @@ -23,8 +23,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0= -github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= +github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= @@ -40,21 +40,21 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U= github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.4.3-0.20240516141139-2ebe792cd8f7 h1:2hnc2iS7usHT3aqIQ8HVtKtPgic+13EVSdZ1m8UBL/E= -go.mau.fi/util v0.4.3-0.20240516141139-2ebe792cd8f7/go.mod h1:m+PJpPMadAW6cj3ldyuO5bLhFreWdwcu+3QTwYNGlGk= +go.mau.fi/util v0.4.3-0.20240611114927-6ef09885dd97 h1:btYXIv4Iqnboc9FQS99dh8XwMF2QftOhfTeh02K2b4o= +go.mau.fi/util v0.4.3-0.20240611114927-6ef09885dd97/go.mod h1:4etkIWotzgsWICu/1I34Y2LFFekINhFsyWYHXEsxXdY= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= -golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= -golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= +golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/exp v0.0.0-20240604190554-fc45aab8b7f8 h1:LoYXNGAShUG3m/ehNk4iFctuhGX/+R1ZpfJ4/ia80JM= +golang.org/x/exp v0.0.0-20240604190554-fc45aab8b7f8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= +golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From 464b7bc44bdde7a4da6aa968e819ac88f15bffcb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 11 Jun 2024 16:23:05 +0300 Subject: [PATCH 0270/1647] Allow Matrix auth for provisioning API --- bridgev2/matrix/connector.go | 2 + bridgev2/matrix/doublepuppet.go | 35 +++++++++---- bridgev2/matrix/mxmain/example-config.yaml | 6 ++- bridgev2/matrix/provisioning.go | 59 +++++++++++++++++----- go.mod | 2 +- go.sum | 4 +- 6 files changed, 83 insertions(+), 25 deletions(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 3eb9176d..bc1c9b8b 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -55,6 +55,7 @@ type Connector struct { Config *bridgeconfig.Config Bridge *bridgev2.Bridge Provisioning *ProvisioningAPI + DoublePuppet *doublePuppetUtil MediaConfig mautrix.RespMediaConfig SpecVersions *mautrix.RespVersions @@ -106,6 +107,7 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.Crypto = NewCryptoHelper(br) br.Bridge.Commands.AddHandlers(CommandDiscardMegolmSession, CommandSetPowerLevel) br.Provisioning = &ProvisioningAPI{br: br} + br.DoublePuppet = newDoublePuppetUtil(br) } func (br *Connector) Start(ctx context.Context) error { diff --git a/bridgev2/matrix/doublepuppet.go b/bridgev2/matrix/doublepuppet.go index 3c0f65e2..410b2652 100644 --- a/bridgev2/matrix/doublepuppet.go +++ b/bridgev2/matrix/doublepuppet.go @@ -14,6 +14,7 @@ import ( "errors" "fmt" "strings" + "sync" "github.com/rs/zerolog" @@ -25,6 +26,17 @@ import ( type doublePuppetUtil struct { br *Connector log zerolog.Logger + + discoveryCache map[string]string + discoveryCacheLock sync.Mutex +} + +func newDoublePuppetUtil(br *Connector) *doublePuppetUtil { + return &doublePuppetUtil{ + br: br, + log: br.Log.With().Str("component", "double puppet").Logger(), + discoveryCache: make(map[string]string), + } } func (dp *doublePuppetUtil) newClient(ctx context.Context, mxid id.UserID, accessToken string) (*mautrix.Client, error) { @@ -37,16 +49,21 @@ func (dp *doublePuppetUtil) newClient(ctx context.Context, mxid id.UserID, acces if homeserver == dp.br.AS.HomeserverDomain { homeserverURL = "" } else if dp.br.Config.DoublePuppet.AllowDiscovery { - resp, err := mautrix.DiscoverClientAPI(ctx, homeserver) - if err != nil { - return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err) + dp.discoveryCacheLock.Lock() + defer dp.discoveryCacheLock.Unlock() + if homeserverURL, found = dp.discoveryCache[homeserver]; !found { + resp, err := mautrix.DiscoverClientAPI(ctx, homeserver) + if err != nil { + return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err) + } + homeserverURL = resp.Homeserver.BaseURL + dp.discoveryCache[homeserver] = homeserverURL + dp.log.Debug(). + Str("homeserver", homeserver). + Str("url", homeserverURL). + Str("user_id", mxid.String()). + Msg("Discovered URL to enable double puppeting for user") } - homeserverURL = resp.Homeserver.BaseURL - dp.log.Debug(). - Str("homeserver", homeserver). - Str("url", homeserverURL). - Str("user_id", mxid.String()). - Msg("Discovered URL to enable double puppeting for user") } else { return nil, fmt.Errorf("double puppeting from %s is not allowed", homeserver) } diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 06d9f8ca..2a4b805a 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -110,6 +110,10 @@ provisioning: # Shared secret for authentication. If set to "generate" or null, a random secret will be generated, # or if set to "disable", the provisioning API will be disabled. shared_secret: generate + # Whether to allow provisioning API requests to be authed using Matrix access tokens. + # This follows the same rules as double puppeting to determine which server to contact to check the token, + # which means that by default, it only works for users on the same server as the bridge. + allow_matrix_auth: true # Enable debug API at /debug with provisioning authentication. debug_endpoints: true @@ -118,7 +122,7 @@ double_puppet: # Servers to always allow double puppeting from. # This is only for other servers and should NOT contain the server the bridge is on. servers: - matrix.org: https://matrix-client.matrix.org + anotherserver.example.org: https://matrix.anotherserver.example.org # Whether to allow client API URL discovery for other servers. When using this option, # users on other servers can use double puppeting even if their server URLs aren't # explicitly added to the servers map above. diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 7164d961..864edbfe 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -9,21 +9,28 @@ package matrix import ( "context" "encoding/json" + "fmt" "net/http" "strings" "sync" + "time" "github.com/gorilla/mux" "github.com/rs/xid" "github.com/rs/zerolog" "github.com/rs/zerolog/hlog" - "github.com/rs/zerolog/log" + "go.mau.fi/util/requestlog" "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/id" ) +type matrixAuthCacheEntry struct { + Expires time.Time + UserID id.UserID +} + type ProvisioningAPI struct { br *Connector log zerolog.Logger @@ -31,6 +38,9 @@ type ProvisioningAPI struct { logins map[string]*ProvLogin loginsLock sync.RWMutex + + matrixAuthCache map[string]matrixAuthCacheEntry + matrixAuthCacheLock sync.Mutex } type ProvLogin struct { @@ -48,12 +58,13 @@ const ( ) func (prov *ProvisioningAPI) Init() { + prov.matrixAuthCache = make(map[string]matrixAuthCacheEntry) + prov.logins = make(map[string]*ProvLogin) prov.net = prov.br.Bridge.Network prov.log = prov.br.Log.With().Str("component", "provisioning").Logger() router := prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter() router.Use(hlog.NewHandler(prov.log)) - // TODO add access logger - //router.Use(requestlog.AccessLogger(true)) + router.Use(requestlog.AccessLogger(false)) router.Use(prov.AuthMiddleware) router.Path("/v3/login/flows").Methods(http.MethodGet).HandlerFunc(prov.GetLoginFlows) router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginStart) @@ -61,7 +72,7 @@ func (prov *ProvisioningAPI) Init() { router.Path("/v3/login/step/{loginID}/{stepID}/{stepType:wait}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginWait) if prov.br.Config.Provisioning.DebugEndpoints { - log.Debug().Msg("Enabling debug API at /debug") + prov.log.Debug().Msg("Enabling debug API at /debug") r := prov.br.AS.Router.PathPrefix("/debug").Subrouter() r.Use(prov.AuthMiddleware) r.PathPrefix("/pprof").Handler(http.DefaultServeMux) @@ -74,19 +85,43 @@ func jsonResponse(w http.ResponseWriter, status int, response any) { _ = json.NewEncoder(w).Encode(response) } +func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.UserID, token string) error { + prov.matrixAuthCacheLock.Lock() + defer prov.matrixAuthCacheLock.Unlock() + if cached, ok := prov.matrixAuthCache[token]; ok && cached.Expires.After(time.Now()) && cached.UserID == userID { + return nil + } else if client, err := prov.br.DoublePuppet.newClient(ctx, userID, token); err != nil { + return err + } else if whoami, err := client.Whoami(ctx); err != nil { + return err + } else if whoami.UserID != userID { + return fmt.Errorf("mismatching user ID (%q != %q)", whoami.UserID, userID) + } else { + prov.matrixAuthCache[token] = matrixAuthCacheEntry{ + Expires: time.Now().Add(5 * time.Minute), + UserID: whoami.UserID, + } + return nil + } +} + func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + userID := id.UserID(r.URL.Query().Get("user_id")) if auth != prov.br.Config.Provisioning.SharedSecret { - zerolog.Ctx(r.Context()).Warn().Msg("Authentication token does not match shared secret") - jsonResponse(w, http.StatusForbidden, &mautrix.RespError{ - Err: "Authentication token does not match shared secret", - ErrCode: mautrix.MForbidden.ErrCode, - }) - return + err := prov.checkMatrixAuth(r.Context(), userID, auth) + if err != nil { + zerolog.Ctx(r.Context()).Warn().Err(err). + Msg("Provisioning API request contained invalid auth") + jsonResponse(w, http.StatusForbidden, &mautrix.RespError{ + Err: "Invalid auth token", + ErrCode: mautrix.MForbidden.ErrCode, + }) + return + } } - userID := r.URL.Query().Get("user_id") - user, err := prov.br.Bridge.GetUserByMXID(r.Context(), id.UserID(userID)) + user, err := prov.br.Bridge.GetUserByMXID(r.Context(), userID) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get user") jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ diff --git a/go.mod b/go.mod index 82129a54..4247ace2 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.1 - go.mau.fi/util v0.4.3-0.20240611114927-6ef09885dd97 + go.mau.fi/util v0.4.3-0.20240611132549-e72a5f4745e7 go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.24.0 golang.org/x/exp v0.0.0-20240604190554-fc45aab8b7f8 diff --git a/go.sum b/go.sum index dce463ef..d918cc61 100644 --- a/go.sum +++ b/go.sum @@ -40,8 +40,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U= github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.4.3-0.20240611114927-6ef09885dd97 h1:btYXIv4Iqnboc9FQS99dh8XwMF2QftOhfTeh02K2b4o= -go.mau.fi/util v0.4.3-0.20240611114927-6ef09885dd97/go.mod h1:4etkIWotzgsWICu/1I34Y2LFFekINhFsyWYHXEsxXdY= +go.mau.fi/util v0.4.3-0.20240611132549-e72a5f4745e7 h1:DviEWXBpeOlFrqIf5s/iBDp1ewZx8fe6imMJ78kq3tA= +go.mau.fi/util v0.4.3-0.20240611132549-e72a5f4745e7/go.mod h1:Eaj7jl37ehkA7S6vE/vfPs5PsY8e91FKZ2BqA3OM/NU= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= From 1ff72aeffb74f4341a1df6ff0a71b3852e223f7f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 11 Jun 2024 17:04:06 +0300 Subject: [PATCH 0271/1647] Mostly implement double puppeting --- bridgev2/matrix/cmddoublepuppet.go | 48 ++++++++++-------- bridgev2/matrix/connector.go | 17 +++++-- bridgev2/matrix/doublepuppet.go | 81 ++++-------------------------- bridgev2/matrixinterface.go | 2 +- bridgev2/portal.go | 6 +-- bridgev2/user.go | 69 +++++++++++++++++++++++++ 6 files changed, 125 insertions(+), 98 deletions(-) diff --git a/bridgev2/matrix/cmddoublepuppet.go b/bridgev2/matrix/cmddoublepuppet.go index 1b755f36..13d24f54 100644 --- a/bridgev2/matrix/cmddoublepuppet.go +++ b/bridgev2/matrix/cmddoublepuppet.go @@ -26,12 +26,12 @@ func fnLoginMatrix(ce *bridgev2.CommandEvent) { ce.Reply("**Usage:** `login-matrix `") return } - //err := ce.User.SwitchCustomMXID(ce.Args[0], ce.User.GetMXID()) - //if err != nil { - // ce.Reply("Failed to enable double puppeting: %v", err) - //} else { - // ce.Reply("Successfully switched puppet") - //} + err := ce.User.LoginDoublePuppet(ce.Ctx, ce.Args[0]) + if err != nil { + ce.Reply("Failed to enable double puppeting: %v", err) + } else { + ce.Reply("Successfully switched puppets") + } } var CommandPingMatrix = &bridgev2.FullHandler{ @@ -41,16 +41,25 @@ var CommandPingMatrix = &bridgev2.FullHandler{ Section: bridgev2.HelpSectionAuth, Description: "Ping the Matrix server with the double puppet.", }, - RequiresLogin: true, } func fnPingMatrix(ce *bridgev2.CommandEvent) { - //resp, err := puppet.CustomIntent().Whoami(ce.Ctx) - //if err != nil { - // ce.Reply("Failed to validate Matrix login: %v", err) - //} else { - // ce.Reply("Confirmed valid access token for %s / %s", resp.UserID, resp.DeviceID) - //} + intent := ce.User.DoublePuppet(ce.Ctx) + if intent == nil { + ce.Reply("You don't have double puppeting enabled.") + return + } + asIntent := intent.(*ASIntent) + resp, err := asIntent.Matrix.Whoami(ce.Ctx) + if err != nil { + ce.Reply("Failed to validate Matrix login: %v", err) + } else { + if asIntent.Matrix.SetAppServiceUserID && resp.DeviceID == "" { + ce.Reply("Confirmed valid access token for %s (appservice double puppeting)", resp.UserID) + } else { + ce.Reply("Confirmed valid access token for %s / %s", resp.UserID, resp.DeviceID) + } + } } var CommandLogoutMatrix = &bridgev2.FullHandler{ @@ -64,11 +73,10 @@ var CommandLogoutMatrix = &bridgev2.FullHandler{ } func fnLogoutMatrix(ce *bridgev2.CommandEvent) { - //puppet := ce.User.GetIDoublePuppet() - //if puppet == nil || puppet.CustomIntent() == nil { - // ce.Reply("You don't have double puppeting enabled.") - // return - //} - //puppet.ClearCustomMXID() - //ce.Reply("Successfully disabled double puppeting.") + if ce.User.AccessToken == "" { + ce.Reply("You don't have double puppeting enabled.") + return + } + ce.User.LogoutDoublePuppet(ce.Ctx) + ce.Reply("Successfully disabled double puppeting.") } diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index bc1c9b8b..6511b6ac 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -105,7 +105,10 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.EventProcessor.On(event.StateMember, br.handleRoomEvent) br.Bot = br.AS.BotIntent() br.Crypto = NewCryptoHelper(br) - br.Bridge.Commands.AddHandlers(CommandDiscardMegolmSession, CommandSetPowerLevel) + br.Bridge.Commands.AddHandlers( + CommandDiscardMegolmSession, CommandSetPowerLevel, + CommandLoginMatrix, CommandPingMatrix, CommandLogoutMatrix, + ) br.Provisioning = &ProvisioningAPI{br: br} br.DoublePuppet = newDoublePuppetUtil(br) } @@ -367,9 +370,15 @@ func (br *Connector) FormatGhostMXID(userID networkid.UserID) id.UserID { return id.NewUserID(localpart, br.Config.Homeserver.Domain) } -func (br *Connector) UserIntent(user *bridgev2.User) bridgev2.MatrixAPI { - // TODO implement double puppeting - return nil +func (br *Connector) NewUserIntent(ctx context.Context, userID id.UserID, accessToken string) (bridgev2.MatrixAPI, string, error) { + intent, newToken, err := br.DoublePuppet.Setup(ctx, userID, accessToken) + if err != nil { + if errors.Is(err, ErrNoAccessToken) { + err = nil + } + return nil, accessToken, err + } + return &ASIntent{Connector: br, Matrix: intent}, newToken, nil } func (br *Connector) BotIntent() bridgev2.MatrixAPI { diff --git a/bridgev2/matrix/doublepuppet.go b/bridgev2/matrix/doublepuppet.go index 410b2652..ace33f30 100644 --- a/bridgev2/matrix/doublepuppet.go +++ b/bridgev2/matrix/doublepuppet.go @@ -8,9 +8,6 @@ package matrix import ( "context" - "crypto/hmac" - "crypto/sha512" - "encoding/hex" "errors" "fmt" "strings" @@ -24,8 +21,7 @@ import ( ) type doublePuppetUtil struct { - br *Connector - log zerolog.Logger + br *Connector discoveryCache map[string]string discoveryCacheLock sync.Mutex @@ -34,7 +30,6 @@ type doublePuppetUtil struct { func newDoublePuppetUtil(br *Connector) *doublePuppetUtil { return &doublePuppetUtil{ br: br, - log: br.Log.With().Str("component", "double puppet").Logger(), discoveryCache: make(map[string]string), } } @@ -58,7 +53,7 @@ func (dp *doublePuppetUtil) newClient(ctx context.Context, mxid id.UserID, acces } homeserverURL = resp.Homeserver.BaseURL dp.discoveryCache[homeserver] = homeserverURL - dp.log.Debug(). + zerolog.Ctx(ctx).Debug(). Str("homeserver", homeserver). Str("url", homeserverURL). Str("user_id", mxid.String()). @@ -85,47 +80,6 @@ func (dp *doublePuppetUtil) newIntent(ctx context.Context, mxid id.UserID, acces return ia, nil } -func (dp *doublePuppetUtil) autoLogin(ctx context.Context, mxid id.UserID, loginSecret string) (string, error) { - dp.log.Debug().Str("user_id", mxid.String()).Msg("Logging into user account with shared secret") - client, err := dp.newClient(ctx, mxid, "") - if err != nil { - return "", fmt.Errorf("failed to create mautrix client to log in: %v", err) - } - bridgeName := fmt.Sprintf("%s Bridge", dp.br.Bridge.Network.GetName().DisplayName) - req := mautrix.ReqLogin{ - Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(mxid)}, - DeviceID: id.DeviceID(bridgeName), - InitialDeviceDisplayName: bridgeName, - } - if loginSecret == "appservice" { - client.AccessToken = dp.br.AS.Registration.AppToken - req.Type = mautrix.AuthTypeAppservice - } else { - loginFlows, err := client.GetLoginFlows(ctx) - if err != nil { - return "", fmt.Errorf("failed to get supported login flows: %w", err) - } - mac := hmac.New(sha512.New, []byte(loginSecret)) - mac.Write([]byte(mxid)) - token := hex.EncodeToString(mac.Sum(nil)) - switch { - case loginFlows.HasFlow(mautrix.AuthTypeDevtureSharedSecret): - req.Type = mautrix.AuthTypeDevtureSharedSecret - req.Token = token - case loginFlows.HasFlow(mautrix.AuthTypePassword): - req.Type = mautrix.AuthTypePassword - req.Password = token - default: - return "", fmt.Errorf("no supported auth types for shared secret auth found") - } - } - resp, err := client.Login(ctx, &req) - if err != nil { - return "", err - } - return resp.AccessToken, nil -} - var ( ErrMismatchingMXID = errors.New("whoami result does not match custom mxid") ErrNoAccessToken = errors.New("no access token provided") @@ -135,14 +89,13 @@ var ( const useConfigASToken = "appservice-config" const asTokenModePrefix = "as_token:" -func (dp *doublePuppetUtil) Setup(ctx context.Context, mxid id.UserID, savedAccessToken string, reloginOnFail bool) (intent *appservice.IntentAPI, newAccessToken string, err error) { +func (dp *doublePuppetUtil) Setup(ctx context.Context, mxid id.UserID, savedAccessToken string) (intent *appservice.IntentAPI, newAccessToken string, err error) { if len(mxid) == 0 { err = ErrNoMXID return } _, homeserver, _ := mxid.Parse() loginSecret, hasSecret := dp.br.Config.DoublePuppet.Secrets[homeserver] - // Special case appservice: prefix to not login and use it as an as_token directly. if hasSecret && strings.HasPrefix(loginSecret, asTokenModePrefix) { intent, err = dp.newIntent(ctx, mxid, strings.TrimPrefix(loginSecret, asTokenModePrefix)) if err != nil { @@ -157,16 +110,9 @@ func (dp *doublePuppetUtil) Setup(ctx context.Context, mxid id.UserID, savedAcce } } return intent, useConfigASToken, err - } - if savedAccessToken == "" || savedAccessToken == useConfigASToken { - if reloginOnFail && hasSecret { - savedAccessToken, err = dp.autoLogin(ctx, mxid, loginSecret) - } else { - err = ErrNoAccessToken - } - if err != nil { - return - } + } else if savedAccessToken == "" || savedAccessToken == useConfigASToken { + err = ErrNoAccessToken + return } intent, err = dp.newIntent(ctx, mxid, savedAccessToken) if err != nil { @@ -174,17 +120,12 @@ func (dp *doublePuppetUtil) Setup(ctx context.Context, mxid id.UserID, savedAcce } var resp *mautrix.RespWhoami resp, err = intent.Whoami(ctx) - if err != nil { - if reloginOnFail && hasSecret && errors.Is(err, mautrix.MUnknownToken) { - intent.AccessToken, err = dp.autoLogin(ctx, mxid, loginSecret) - if err == nil { - newAccessToken = intent.AccessToken - } + if err == nil { + if resp.UserID != mxid { + err = ErrMismatchingMXID + } else { + newAccessToken = savedAccessToken } - } else if resp.UserID != mxid { - err = ErrMismatchingMXID - } else { - newAccessToken = savedAccessToken } return } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index b791a028..c017eb92 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -24,7 +24,7 @@ type MatrixConnector interface { FormatGhostMXID(userID networkid.UserID) id.UserID GhostIntent(userID id.UserID) MatrixAPI - UserIntent(user *User) MatrixAPI + NewUserIntent(ctx context.Context, userID id.UserID, accessToken string) (MatrixAPI, string, error) BotIntent() MatrixAPI SendMessageStatus(ctx context.Context, status *MessageStatus, evt *MessageStatusEventInfo) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 504d3359..49ed3f96 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -691,12 +691,12 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { func (portal *Portal) getIntentFor(ctx context.Context, sender EventSender, source *UserLogin) MatrixAPI { var intent MatrixAPI if sender.IsFromMe { - intent = portal.Bridge.Matrix.UserIntent(source.User) + intent = source.User.DoublePuppet(ctx) } if intent == nil && sender.SenderLogin != "" { senderLogin := portal.Bridge.GetCachedUserLoginByID(sender.SenderLogin) if senderLogin != nil { - intent = portal.Bridge.Matrix.UserIntent(senderLogin.User) + intent = senderLogin.User.DoublePuppet(ctx) } } if intent == nil { @@ -1184,7 +1184,7 @@ func (portal *Portal) SyncParticipants(ctx context.Context, members []networkid. for _, login := range loginsInPortal { if login.Client.IsThisUser(ctx, member) { isLoggedInUser = true - userIntent := portal.Bridge.Matrix.UserIntent(login.User) + userIntent := login.User.DoublePuppet(ctx) if userIntent != nil { expectedIntents[i] = userIntent } else { diff --git a/bridgev2/user.go b/bridgev2/user.go index 82d841c2..bf8eaf13 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -9,6 +9,7 @@ package bridgev2 import ( "context" "fmt" + "sync" "sync/atomic" "github.com/rs/zerolog" @@ -25,6 +26,10 @@ type User struct { CommandState atomic.Pointer[CommandState] + doublePuppetIntent MatrixAPI + doublePuppetInitialized bool + doublePuppetLock sync.Mutex + logins map[networkid.UserLoginID]*UserLogin } @@ -83,3 +88,67 @@ func (br *Bridge) GetExistingUserByMXID(ctx context.Context, userID id.UserID) ( defer br.cacheLock.Unlock() return br.unlockedGetUserByMXID(ctx, userID, true) } + +func (user *User) LogoutDoublePuppet(ctx context.Context) { + user.doublePuppetLock.Lock() + defer user.doublePuppetLock.Unlock() + user.AccessToken = "" + err := user.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save removed access token") + } + user.doublePuppetIntent = nil + user.doublePuppetInitialized = false +} + +func (user *User) LoginDoublePuppet(ctx context.Context, token string) error { + if token == "" { + return fmt.Errorf("no token provided") + } + user.doublePuppetLock.Lock() + defer user.doublePuppetLock.Unlock() + intent, newToken, err := user.Bridge.Matrix.NewUserIntent(ctx, user.MXID, token) + if err != nil { + return err + } + user.AccessToken = newToken + user.doublePuppetIntent = intent + user.doublePuppetInitialized = true + err = user.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save new access token") + } + if newToken != token { + return fmt.Errorf("logging in manually is not supported when automatic double puppeting is enabled") + } + return nil +} + +func (user *User) DoublePuppet(ctx context.Context) MatrixAPI { + user.doublePuppetLock.Lock() + defer user.doublePuppetLock.Unlock() + if user.doublePuppetInitialized { + return user.doublePuppetIntent + } + user.doublePuppetInitialized = true + log := user.Log.With().Str("action", "setup double puppet").Logger() + ctx = log.WithContext(ctx) + intent, newToken, err := user.Bridge.Matrix.NewUserIntent(ctx, user.MXID, user.AccessToken) + if err != nil { + log.Err(err).Msg("Failed to create new user intent") + return nil + } + user.doublePuppetIntent = intent + if newToken != user.AccessToken { + user.AccessToken = newToken + err = user.Save(ctx) + if err != nil { + log.Warn().Err(err).Msg("Failed to save new access token") + } + } + return intent +} + +func (user *User) Save(ctx context.Context) error { + return user.Bridge.DB.User.Update(ctx, user.User) +} From f14c5aafb912ca321767ddad0d1f94ed4836938d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 11 Jun 2024 17:06:47 +0300 Subject: [PATCH 0272/1647] Don't send MSS event for first decryption error --- bridgev2/matrix/connector.go | 2 +- bridgev2/matrix/cryptoerror.go | 1 + bridgev2/messagestatus.go | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 6511b6ac..3fbda800 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -308,7 +308,7 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 if err != nil { log.Err(err).Msg("Failed to send message checkpoint") } - if br.Config.Matrix.MessageStatusEvents { + if !ms.DisableMSS && br.Config.Matrix.MessageStatusEvents { _, err = br.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, ms.ToMSSEvent(evt)) if err != nil { log.Err(err).Msg("Failed to send MSS event") diff --git a/bridgev2/matrix/cryptoerror.go b/bridgev2/matrix/cryptoerror.go index 93cf0e75..55110429 100644 --- a/bridgev2/matrix/cryptoerror.go +++ b/bridgev2/matrix/cryptoerror.go @@ -80,6 +80,7 @@ func (br *Connector) sendCryptoStatusError(ctx context.Context, evt *event.Event // Don't send notice for first error if retryNum == 0 { ms.SendNotice = false + ms.DisableMSS = true } } var editEventID id.EventID diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 43fe5da0..16107422 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -60,6 +60,7 @@ type MessageStatus struct { ErrorAsMessage bool IsCertain bool SendNotice bool + DisableMSS bool } func WrapErrorInStatus(err error) MessageStatus { From f9e2159842f21bf238bbed16d95271c6a39e0b07 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 11 Jun 2024 17:17:20 +0300 Subject: [PATCH 0273/1647] Ignore incoming double puppeted events --- appservice/appservice.go | 1 + appservice/intent.go | 44 ++++++++++++++++++++++++++-------- bridgev2/matrix/connector.go | 2 +- bridgev2/matrix/intent.go | 3 +++ bridgev2/matrix/matrix.go | 9 ++++++- bridgev2/matrix/mxmain/main.go | 1 + 6 files changed, 48 insertions(+), 12 deletions(-) diff --git a/appservice/appservice.go b/appservice/appservice.go index ef9c6236..814cd5c4 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -193,6 +193,7 @@ type AppService struct { } const DoublePuppetKey = "fi.mau.double_puppet_source" +const DoublePuppetTSKey = "fi.mau.double_puppet_ts" func getDefaultProcessID() string { pid := syscall.Getpid() diff --git a/appservice/intent.go b/appservice/intent.go index 39c22d7f..31ba4732 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -130,40 +130,64 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext return nil } -func (intent *IntentAPI) AddDoublePuppetValue(into interface{}) interface{} { +func (intent *IntentAPI) AddDoublePuppetValue(into any) any { + return intent.AddDoublePuppetValueWithTS(into, 0) +} + +func (intent *IntentAPI) AddDoublePuppetValueWithTS(into any, ts int64) any { if !intent.IsCustomPuppet || intent.as.DoublePuppetValue == "" { return into } + // Only use ts deduplication feature with appservice double puppeting + if !intent.SetAppServiceUserID { + ts = 0 + } switch val := into.(type) { - case *map[string]interface{}: + case *map[string]any: if *val == nil { - valNonPtr := make(map[string]interface{}) + valNonPtr := make(map[string]any) *val = valNonPtr } (*val)[DoublePuppetKey] = intent.as.DoublePuppetValue + if ts != 0 { + (*val)[DoublePuppetTSKey] = ts + } return val - case map[string]interface{}: + case map[string]any: val[DoublePuppetKey] = intent.as.DoublePuppetValue + if ts != 0 { + val[DoublePuppetTSKey] = ts + } return val case *event.Content: if val.Raw == nil { - val.Raw = make(map[string]interface{}) + val.Raw = make(map[string]any) } val.Raw[DoublePuppetKey] = intent.as.DoublePuppetValue + if ts != 0 { + val.Raw[DoublePuppetTSKey] = ts + } return val case event.Content: if val.Raw == nil { - val.Raw = make(map[string]interface{}) + val.Raw = make(map[string]any) } val.Raw[DoublePuppetKey] = intent.as.DoublePuppetValue + if ts != 0 { + val.Raw[DoublePuppetTSKey] = ts + } return val default: - return &event.Content{ - Raw: map[string]interface{}{ + content := &event.Content{ + Raw: map[string]any{ DoublePuppetKey: intent.as.DoublePuppetValue, }, Parsed: val, } + if ts != 0 { + content.Raw[DoublePuppetTSKey] = ts + } + return content } } @@ -179,7 +203,7 @@ func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } - contentJSON = intent.AddDoublePuppetValue(contentJSON) + contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts) return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) } @@ -197,7 +221,7 @@ func (intent *IntentAPI) SendMassagedStateEvent(ctx context.Context, roomID id.R if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } - contentJSON = intent.AddDoublePuppetValue(contentJSON) + contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts) return intent.Client.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ts) } diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 3fbda800..04905bfd 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -86,7 +86,7 @@ func NewConnector(cfg *bridgeconfig.Config) *Connector { func (br *Connector) Init(bridge *bridgev2.Bridge) { br.Bridge = bridge br.Log = &bridge.Log - br.StateStore = sqlstatestore.NewSQLStateStore(bridge.DB.Database, dbutil.ZeroLogger(br.Log.With().Str("db_section", "matrix").Logger()), false) + br.StateStore = sqlstatestore.NewSQLStateStore(bridge.DB.Database, dbutil.ZeroLogger(br.Log.With().Str("db_section", "matrix_state").Logger()), false) br.AS = br.Config.MakeAppService() br.AS.Log = bridge.Log br.AS.StateStore = br.StateStore diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 25856cb0..6c81349b 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -40,6 +40,9 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) } else if encrypted { + if as.Matrix.IsCustomPuppet { + as.Matrix.AddDoublePuppetValueWithTS(content, ts.UnixMilli()) + } err = as.Connector.Crypto.Encrypt(ctx, roomID, eventType, content) if err != nil { return nil, err diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index bd2a1bb9..8a96f039 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -14,6 +14,7 @@ import ( "github.com/rs/zerolog" + "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/event" @@ -143,7 +144,13 @@ func (br *Connector) shouldIgnoreEvent(evt *event.Event) bool { if isGhost { return true } - // TODO exclude double puppeted events + dpVal, ok := evt.Content.Raw[appservice.DoublePuppetKey] + if ok && dpVal == br.AS.DoublePuppetValue { + dpTS, ok := evt.Content.Raw[appservice.DoublePuppetTSKey].(float64) + if !ok || int64(dpTS) == evt.Timestamp { + return true + } + } return false } diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 85269fce..284a15e1 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -214,6 +214,7 @@ func (br *BridgeMain) Init() { br.Matrix = matrix.NewConnector(br.Config) br.Matrix.IgnoreUnsupportedServer = *ignoreUnsupportedServer br.Bridge = bridgev2.NewBridge("", br.DB, *br.Log, &br.Config.Bridge, br.Matrix, br.Connector) + br.Matrix.AS.DoublePuppetValue = br.Name if br.PostInit != nil { br.PostInit() } From d58e8f88173b790c6287aa3576ed3051e725ac38 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 11 Jun 2024 20:26:04 +0300 Subject: [PATCH 0274/1647] Add bridge state queue for user logins --- bridgev2/bridge.go | 3 +- bridgev2/bridgestate.go | 132 +++++++++++++++++++++++++++++++++ bridgev2/matrix/connector.go | 18 ++++- bridgev2/matrix/mxmain/main.go | 2 +- bridgev2/matrixinterface.go | 2 + bridgev2/userlogin.go | 22 +++++- 6 files changed, 175 insertions(+), 4 deletions(-) create mode 100644 bridgev2/bridgestate.go diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 0f77cea5..a12da24b 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -15,6 +15,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/dbutil" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -113,7 +114,7 @@ func (br *Bridge) Start() error { } if len(logins) == 0 { br.Log.Info().Msg("No user logins found") - // TODO send UNCONFIGURED bridge state + br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured}) } br.Log.Info().Msg("Bridge started") diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go new file mode 100644 index 00000000..961d9e31 --- /dev/null +++ b/bridgev2/bridgestate.go @@ -0,0 +1,132 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "runtime/debug" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridge/status" +) + +type BridgeStateQueue struct { + prev *status.BridgeState + ch chan status.BridgeState + bridge *Bridge + user status.BridgeStateFiller +} + +func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { + for { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + if err := br.Matrix.SendBridgeStatus(ctx, &state); err != nil { + br.Log.Warn().Err(err).Msg("Failed to update global bridge state") + cancel() + time.Sleep(5 * time.Second) + continue + } else { + br.Log.Debug().Any("bridge_state", state).Msg("Sent new global bridge state") + cancel() + break + } + } +} + +func (br *Bridge) NewBridgeStateQueue(user status.BridgeStateFiller) *BridgeStateQueue { + bsq := &BridgeStateQueue{ + ch: make(chan status.BridgeState, 10), + bridge: br, + user: user, + } + go bsq.loop() + return bsq +} + +func (bsq *BridgeStateQueue) loop() { + defer func() { + err := recover() + if err != nil { + bsq.bridge.Log.Error(). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()). + Any(zerolog.ErrorFieldName, err). + Msg("Panic in bridge state loop") + } + }() + for state := range bsq.ch { + bsq.immediateSendBridgeState(state) + } +} + +func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) { + retryIn := 2 + for { + if bsq.prev != nil && bsq.prev.ShouldDeduplicate(&state) { + bsq.bridge.Log.Debug(). + Str("state_event", string(state.StateEvent)). + Msg("Not sending bridge state as it's a duplicate") + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + err := bsq.bridge.Matrix.SendBridgeStatus(ctx, &state) + cancel() + + if err != nil { + bsq.bridge.Log.Warn().Err(err). + Int("retry_in_seconds", retryIn). + Msg("Failed to update bridge state") + time.Sleep(time.Duration(retryIn) * time.Second) + retryIn *= 2 + if retryIn > 64 { + retryIn = 64 + } + } else { + bsq.prev = &state + bsq.bridge.Log.Debug(). + Any("bridge_state", state). + Msg("Sent new bridge state") + return + } + } +} + +func (bsq *BridgeStateQueue) Send(state status.BridgeState) { + if bsq == nil { + return + } + + state = state.Fill(bsq.user) + + if len(bsq.ch) >= 8 { + bsq.bridge.Log.Warn().Msg("Bridge state queue is nearly full, discarding an item") + select { + case <-bsq.ch: + default: + } + } + select { + case bsq.ch <- state: + default: + bsq.bridge.Log.Error().Msg("Bridge state queue is full, dropped new state") + } +} + +func (bsq *BridgeStateQueue) GetPrev() status.BridgeState { + if bsq != nil && bsq.prev != nil { + return *bsq.prev + } + return status.BridgeState{} +} + +func (bsq *BridgeStateQueue) SetPrev(prev status.BridgeState) { + if bsq != nil { + bsq.prev = &prev + } +} diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 04905bfd..76c78bfd 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -46,7 +46,6 @@ type Crypto interface { } type Connector struct { - //DB *dbutil.Database AS *appservice.AppService Bot *appservice.IntentAPI StateStore *sqlstatestore.SQLStateStore @@ -71,6 +70,7 @@ type Connector struct { wsStopped chan struct{} wsShortCircuitReconnectBackoff chan struct{} wsStartupWait *sync.WaitGroup + latestState *status.BridgeState } var _ bridgev2.MatrixConnector = (*Connector)(nil) @@ -298,6 +298,22 @@ func (br *Connector) GhostIntent(userID id.UserID) bridgev2.MatrixAPI { } } +func (br *Connector) SendBridgeStatus(ctx context.Context, state *status.BridgeState) error { + if br.Websocket { + // FIXME this doesn't account for multiple users + br.latestState = state + + return br.AS.SendWebsocket(&appservice.WebsocketRequest{ + Command: "bridge_status", + Data: state, + }) + } else if br.Config.Homeserver.StatusEndpoint != "" { + return state.SendHTTP(ctx, br.Config.Homeserver.StatusEndpoint, br.Config.AppService.ASToken) + } else { + return nil + } +} + func (br *Connector) SendMessageStatus(ctx context.Context, ms *bridgev2.MessageStatus, evt *bridgev2.MessageStatusEventInfo) { br.internalSendMessageStatus(ctx, ms, evt, "") } diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 284a15e1..2f881439 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -195,6 +195,7 @@ func (br *BridgeMain) Init() { _, _ = fmt.Fprintln(os.Stderr, "Failed to initialize logger:", err) os.Exit(12) } + exzerolog.SetupDefaults(br.Log) err = br.validateConfig() if err != nil { br.Log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Configuration error") @@ -209,7 +210,6 @@ func (br *BridgeMain) Init() { Str("go_version", runtime.Version()). Msg("Initializing bridge") - exzerolog.SetupDefaults(br.Log) br.initDB() br.Matrix = matrix.NewConnector(br.Config) br.Matrix.IgnoreUnsupportedServer = *ignoreUnsupportedServer diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index c017eb92..350bf005 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -11,6 +11,7 @@ import ( "time" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -27,6 +28,7 @@ type MatrixConnector interface { NewUserIntent(ctx context.Context, userID id.UserID, accessToken string) (MatrixAPI, string, error) BotIntent() MatrixAPI + SendBridgeStatus(ctx context.Context, state *status.BridgeState) error SendMessageStatus(ctx context.Context, status *MessageStatus, evt *MessageStatusEventInfo) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 429c79cd..b531bdfb 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -12,8 +12,10 @@ import ( "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" ) type UserLogin struct { @@ -22,7 +24,8 @@ type UserLogin struct { User *User Log zerolog.Logger - Client NetworkAPI + Client NetworkAPI + BridgeState *BridgeStateQueue } func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *database.UserLogin) (*UserLogin, error) { @@ -45,6 +48,7 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da } user.logins[userLogin.ID] = userLogin br.userLoginsByID[userLogin.ID] = userLogin + userLogin.BridgeState = br.NewBridgeStateQueue(userLogin) return userLogin, nil } @@ -135,8 +139,24 @@ func (ul *UserLogin) Logout(ctx context.Context) { delete(ul.User.logins, ul.ID) delete(ul.Bridge.userLoginsByID, ul.ID) // TODO kick user out of rooms? + ul.BridgeState.Send(status.BridgeState{StateEvent: status.StateLoggedOut}) } func (ul *UserLogin) MarkAsPreferredIn(ctx context.Context, portal *Portal) error { return ul.Bridge.DB.UserLogin.MarkLoginAsPreferredInPortal(ctx, ul.UserLogin, portal.PortalKey) } + +var _ status.BridgeStateFiller = (*UserLogin)(nil) + +func (ul *UserLogin) GetMXID() id.UserID { + return ul.UserMXID +} + +func (ul *UserLogin) GetRemoteID() string { + return string(ul.ID) +} + +func (ul *UserLogin) GetRemoteName() string { + name, _ := ul.Metadata["remote_name"].(string) + return name +} From fd5603f92209be2d07032bdc8ade51541872b590 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 11 Jun 2024 21:08:24 +0300 Subject: [PATCH 0275/1647] Update readme --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d45860e7..ac41ca78 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ A Golang Matrix framework. Used by [gomuks](https://matrix.org/docs/projects/cli [go-neb](https://github.com/matrix-org/go-neb), [mautrix-whatsapp](https://github.com/mautrix/whatsapp) and others. -Matrix room: [`#maunium:maunium.net`](https://matrix.to/#/#maunium:maunium.net) +Matrix room: [`#go:maunium.net`](https://matrix.to/#/#go:maunium.net) This project is based on [matrix-org/gomatrix](https://github.com/matrix-org/gomatrix). The original project is licensed under [Apache 2.0](https://github.com/matrix-org/gomatrix/blob/master/LICENSE). @@ -14,6 +14,9 @@ In addition to the basic client API features the original project has, this fram * Appservice support (Intent API like mautrix-python, room state storage, etc) * End-to-end encryption support (incl. interactive SAS verification) +* High-level module for building puppeting bridges +* High-level module for building chat clients +* Wrapper functions for the Synapse admin API * Structs for parsing event content * Helpers for parsing and generating Matrix HTML * Helpers for handling push rules From 2d30fad138f6752140d5c542a39ade27392dd99f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Jun 2024 00:00:39 +0300 Subject: [PATCH 0276/1647] Add some godocs and an example for mxmain --- bridgev2/matrix/mxmain/main.go | 99 +++++++++++++++++++++-------- bridgev2/matrix/mxmain/main_test.go | 49 ++++++++++++++ bridgev2/networkid/bridgeid.go | 40 ++++++++++-- 3 files changed, 157 insertions(+), 31 deletions(-) create mode 100644 bridgev2/matrix/mxmain/main_test.go diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 2f881439..977b424a 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -4,6 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +// Package mxmain contains initialization code for a single-network Matrix bridge using the bridgev2 package. package mxmain import ( @@ -45,20 +46,34 @@ var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Ru var ignoreUnsupportedServer = flag.Make().LongKey("ignore-unsupported-server").Usage("Run even if the Matrix homeserver is outdated").Default("false").Bool() var wantHelp, _ = flag.MakeHelpFlag() +// BridgeMain contains the main function for a Matrix bridge. type BridgeMain struct { - Name string + // Name is the name of the bridge project, e.g. mautrix-signal. + // Note that when making your own bridges that isn't under github.com/mautrix, + // you should invent your own name and not use the mautrix-* naming scheme. + Name string + // Description is a brief description of the bridge, usually of the form "A Matrix-OtherPlatform puppeting bridge." Description string - URL string - Version string + // URL is the Git repository address for the bridge. + URL string + // Version is the latest release of the bridge. InitVersion will compare this to the provided + // git tag to see if the built version is the release or a dev build. + // You can either bump this right after a release or right before, as long as it matches on the release commit. + Version string + // PostInit is a function that will be called after the bridge has been initialized but before it is started. PostInit func() + // Connector is the network connector for the bridge. Connector bridgev2.NetworkConnector - Log *zerolog.Logger - DB *dbutil.Database - Config *bridgeconfig.Config - Matrix *matrix.Connector - Bridge *bridgev2.Bridge + + // All fields below are set automatically in Run or InitVersion should not be set manually. + + Log *zerolog.Logger + DB *dbutil.Database + Config *bridgeconfig.Config + Matrix *matrix.Connector + Bridge *bridgev2.Bridge ConfigPath string RegistrationPath string @@ -93,7 +108,19 @@ type VersionJSONOutput struct { } } +// Run runs the bridge and waits for SIGTERM before stopping. func (br *BridgeMain) Run() { + br.PreInit() + br.Init() + br.Start() + br.WaitForInterrupt() + br.Stop() +} + +// PreInit parses CLI flags and loads the config file. This is called by [Run] and does not need to be called manually. +// +// This also handles all flags that cause the bridge to exit immediately (e.g. `--version` and `--generate-registration`). +func (br *BridgeMain) PreInit() { flag.SetHelpTitles( fmt.Sprintf("%s - %s", br.Name, br.Description), fmt.Sprintf("%s [-hgvn%s] [-c ] [-r ]%s", br.Name, br.AdditionalShortFlags, br.AdditionalLongFlags)) @@ -134,28 +161,11 @@ func (br *BridgeMain) Run() { exerrors.PanicIfNotNil(os.WriteFile(*configPath, []byte(br.makeFullExampleConfig(networkExample)), 0600)) return } - - br.loadConfig() + br.LoadConfig() if *generateRegistration { br.GenerateRegistration() return } - - br.Init() - err = br.Bridge.Start() - if err != nil { - var dbUpgradeErr bridgev2.DBUpgradeError - if errors.As(err, &dbUpgradeErr) { - br.LogDBUpgradeErrorAndExit(dbUpgradeErr.Section, dbUpgradeErr.Err) - } else { - br.Log.Fatal().Err(err).Msg("Failed to start bridge") - } - } - - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - <-c - br.Log.Info().Msg("Shutting down bridge") } func (br *BridgeMain) GenerateRegistration() { @@ -188,6 +198,8 @@ func (br *BridgeMain) GenerateRegistration() { os.Exit(0) } +// Init sets up logging, database connection and creates the Matrix connector and central Bridge struct. +// This is called by [Run] and does not need to be called manually. func (br *BridgeMain) Init() { var err error br.Log, err = br.Config.Logging.Compile() @@ -283,7 +295,9 @@ func (br *BridgeMain) getConfigUpgrader() (configupgrade.BaseUpgrader, any) { return upgrader, networkData } -func (br *BridgeMain) loadConfig() { +// LoadConfig upgrades and loads the config file. +// This is called by [Run] and does not need to be called manually. +func (br *BridgeMain) LoadConfig() { upgrader, networkData := br.getConfigUpgrader() configData, upgraded, err := configupgrade.Do(br.ConfigPath, br.SaveConfig, upgrader) if err != nil { @@ -307,6 +321,37 @@ func (br *BridgeMain) loadConfig() { br.Config = &cfg } +// Start starts the bridge after everything has been initialized. +// This is called by [Run] and does not need to be called manually. +func (br *BridgeMain) Start() { + err := br.Bridge.Start() + if err != nil { + var dbUpgradeErr bridgev2.DBUpgradeError + if errors.As(err, &dbUpgradeErr) { + br.LogDBUpgradeErrorAndExit(dbUpgradeErr.Section, dbUpgradeErr.Err) + } else { + br.Log.Fatal().Err(err).Msg("Failed to start bridge") + } + } +} + +// WaitForInterrupt waits for a SIGINT or SIGTERM signal. +func (br *BridgeMain) WaitForInterrupt() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c +} + +// Stop cleanly stops the bridge. This is called by [Run] and does not need to be called manually. +func (br *BridgeMain) Stop() { + br.Log.Info().Msg("Shutting down bridge") + // TODO actually stop cleanly +} + +// InitVersion formats the bridge version and build time nicely for things like +// the `version` bridge command on Matrix and the `--version` CLI flag. +// +// The values should generally be set by the build system. See the [BridgeMain] example for usage. func (br *BridgeMain) InitVersion(tag, commit, rawBuildTime string) { br.baseVersion = br.Version if len(tag) > 0 && tag[0] == 'v' { diff --git a/bridgev2/matrix/mxmain/main_test.go b/bridgev2/matrix/mxmain/main_test.go new file mode 100644 index 00000000..e7a9880b --- /dev/null +++ b/bridgev2/matrix/mxmain/main_test.go @@ -0,0 +1,49 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mxmain_test + +import ( + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/matrix/mxmain" +) + +// Information to find out exactly which commit the bridge was built from. +// These are filled at build time with the -X linker flag. +// +// For example: +// +// go build -ldflags "-X main.Tag=$(git describe --exact-match --tags 2>/dev/null) -X main.Commit=$(git rev-parse HEAD) -X 'main.BuildTime=`date -Iseconds`'" +// +// You may additionally want to fill the mautrix-go version using another ldflag: +// +// export MAUTRIX_VERSION=$(cat go.mod | grep 'maunium.net/go/mautrix ' | head -n1 | awk '{ print $2 }') +// go build -ldflags "-X 'maunium.net/go/mautrix.GoModVersion=$MAUTRIX_VERSION'" +// +// (to use both at the same time, merge the ldflags into one, `-ldflags "-X ... -X ..."`) +var ( + Tag = "unknown" + Commit = "unknown" + BuildTime = "unknown" +) + +func ExampleBridgeMain() { + var yourConnector bridgev2.NetworkConnector + m := mxmain.BridgeMain{ + Name: "example-matrix-bridge", + URL: "https://github.com/octocat/matrix-bridge", + Description: "An example Matrix bridge.", + Version: "1.0.0", + + Connector: yourConnector, + } + m.PostInit = func() { + // If you want some code to run after all the setup is done, but before the bridge is started, + // you can set a function in PostInit. This is not required if you don't need to do anything special. + } + m.InitVersion(Tag, Commit, BuildTime) + m.Run() +} diff --git a/bridgev2/networkid/bridgeid.go b/bridgev2/networkid/bridgeid.go index 57900cb7..f57e74bf 100644 --- a/bridgev2/networkid/bridgeid.go +++ b/bridgev2/networkid/bridgeid.go @@ -4,6 +4,19 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +// Package networkid contains string types used to represent different kinds of identifiers on remote networks. +// +// Except for [BridgeID], all types in this package are only generated by network connectors. +// Network connectors may generate and parse these types any way they want, all other components +// will treat them as opaque identifiers and will not parse them nor assume anything about them. +// However, identifiers are stored in the bridge database, so backwards-compatibility must be +// considered when changing the format. +// +// All IDs are scoped to a bridge, i.e. they don't need to be unique across different bridges. +// However, most IDs need to be globally unique within the bridge, i.e. the same ID must refer +// to the same entity even from another user's point of view. If the remote network does not +// directly provide such globally unique identifiers, the network connector should prefix them +// with a user ID or other identifier to make them unique. package networkid import ( @@ -15,7 +28,8 @@ import ( // BridgeID is an opaque identifier for a bridge type BridgeID string -// PortalID is the ID of a room on the remote network. +// PortalID is the ID of a room on the remote network. A portal ID alone should identify group chats +// uniquely, and also DMs when scoped to a user login ID (see [PortalKey]). type PortalID string // PortalKey is the unique key of a room on the remote network. It combines a portal ID and a receiver ID. @@ -25,6 +39,10 @@ type PortalID string // if both sides are logged into the bridge. Also, for networks that use user IDs as DM chat IDs, // the receiver is necessary to have separate rooms for separate users who have a DM with the same // remote user. +// +// It is also permitted to use a non-empty receiver for group chats if there is a good reason to +// segregate them. For example, Telegram's non-supergroups have user-scoped message IDs instead +// of chat-scoped IDs, which is easier to manage with segregated rooms. type PortalKey struct { ID PortalID Receiver UserLoginID @@ -45,14 +63,23 @@ func (pk PortalKey) MarshalZerologObject(evt *zerolog.Event) { } // UserID is the ID of a user on the remote network. +// +// User IDs must be globally unique within the bridge for identifying a specific remote user. type UserID string -// UserLoginID is the ID of the user being controlled on the remote network. It may be the same shape as [UserID]. +// UserLoginID is the ID of the user being controlled on the remote network. +// +// It may be the same shape as [UserID]. However, being the same shape is not required, and the +// central bridge module and Matrix connectors will never assume it is. Instead, the bridge will +// use methods like [maunium.net/go/mautrix/bridgev2.NetworkAPI.IsThisUser] to check if a user ID +// is associated with a given UserLogin. +// The network connector is of course allowed to assume a UserLoginID is equivalent to a UserID, +// because it is the one defining both types. type UserLoginID string // MessageID is the ID of a message on the remote network. // -// Message IDs must be unique across rooms and consistent across users. +// Message IDs must be unique across rooms and consistent across users (i.e. globally unique within the bridge). type MessageID string // PartID is the ID of a message part on the remote network (e.g. index of image in album). @@ -76,7 +103,12 @@ type MessageOptionalPartID struct { // AvatarID is the ID of a user or room avatar on the remote network. // -// It may be a real URL, an opaque identifier, or anything in between. +// It may be a real URL, an opaque identifier, or anything in between. It should be an identifier that +// can be acquired from the remote network without downloading the entire avatar. +// +// In general, it is preferred to use a stable identifier which only changes when the avatar changes. +// However, the bridge will also hash the avatar data to check for changes before sending an avatar +// update to Matrix, so the avatar ID being slightly unstable won't be the end of the world. type AvatarID string // EmojiID is the ID of a reaction emoji on the remote network. From d33172c5d25e2335f1672d1b0e896a483e0bf6a7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Jun 2024 00:05:50 +0300 Subject: [PATCH 0277/1647] Move ldflag docs to InitVersion --- bridgev2/matrix/mxmain/main.go | 19 ++++++++++++++++++- bridgev2/matrix/mxmain/main_test.go | 13 ++----------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 977b424a..061f681e 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -351,7 +351,24 @@ func (br *BridgeMain) Stop() { // InitVersion formats the bridge version and build time nicely for things like // the `version` bridge command on Matrix and the `--version` CLI flag. // -// The values should generally be set by the build system. See the [BridgeMain] example for usage. +// The values should generally be set by the build system. For example, assuming you have +// +// var ( +// Tag = "unknown" +// Commit = "unknown" +// BuildTime = "unknown" +// ) +// +// in your main package, then you'd use the following ldflags to fill them appropriately: +// +// go build -ldflags "-X main.Tag=$(git describe --exact-match --tags 2>/dev/null) -X main.Commit=$(git rev-parse HEAD) -X 'main.BuildTime=`date -Iseconds`'" +// +// You may additionally want to fill the mautrix-go version using another ldflag: +// +// export MAUTRIX_VERSION=$(cat go.mod | grep 'maunium.net/go/mautrix ' | head -n1 | awk '{ print $2 }') +// go build -ldflags "-X 'maunium.net/go/mautrix.GoModVersion=$MAUTRIX_VERSION'" +// +// (to use both at the same time, simply merge the ldflags into one, `-ldflags "-X '...' -X ..."`) func (br *BridgeMain) InitVersion(tag, commit, rawBuildTime string) { br.baseVersion = br.Version if len(tag) > 0 && tag[0] == 'v' { diff --git a/bridgev2/matrix/mxmain/main_test.go b/bridgev2/matrix/mxmain/main_test.go index e7a9880b..9a71344d 100644 --- a/bridgev2/matrix/mxmain/main_test.go +++ b/bridgev2/matrix/mxmain/main_test.go @@ -13,17 +13,6 @@ import ( // Information to find out exactly which commit the bridge was built from. // These are filled at build time with the -X linker flag. -// -// For example: -// -// go build -ldflags "-X main.Tag=$(git describe --exact-match --tags 2>/dev/null) -X main.Commit=$(git rev-parse HEAD) -X 'main.BuildTime=`date -Iseconds`'" -// -// You may additionally want to fill the mautrix-go version using another ldflag: -// -// export MAUTRIX_VERSION=$(cat go.mod | grep 'maunium.net/go/mautrix ' | head -n1 | awk '{ print $2 }') -// go build -ldflags "-X 'maunium.net/go/mautrix.GoModVersion=$MAUTRIX_VERSION'" -// -// (to use both at the same time, merge the ldflags into one, `-ldflags "-X ... -X ..."`) var ( Tag = "unknown" Commit = "unknown" @@ -31,7 +20,9 @@ var ( ) func ExampleBridgeMain() { + // Set this yourself var yourConnector bridgev2.NetworkConnector + m := mxmain.BridgeMain{ Name: "example-matrix-bridge", URL: "https://github.com/octocat/matrix-bridge", From 7262fa71df24793df90c779b3b9d8727d7cf48ce Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Jun 2024 00:53:18 +0300 Subject: [PATCH 0278/1647] bridgev2/login: improve default phone number validator --- bridgev2/login.go | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/bridgev2/login.go b/bridgev2/login.go index 23647042..b592e64b 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -9,7 +9,6 @@ package bridgev2 import ( "context" "fmt" - "regexp" "strings" "maunium.net/go/mautrix/bridgev2/networkid" @@ -124,7 +123,16 @@ type LoginInputDataField struct { Validate func(string) (string, error) `json:"-"` } -var phoneNumberRe = regexp.MustCompile(`\+\d+`) +var numberCleaner = strings.NewReplacer("-", "", " ", "", "(", "", ")", "") + +func isOnlyNumbers(input string) bool { + for _, r := range input { + if r < '0' || r > '9' { + return false + } + } + return true +} func (f *LoginInputDataField) FillDefaultValidate() { noopValidate := func(input string) (string, error) { return input, nil } @@ -134,9 +142,13 @@ func (f *LoginInputDataField) FillDefaultValidate() { switch f.Type { case LoginInputFieldTypePhoneNumber: f.Validate = func(phone string) (string, error) { - phone = strings.ReplaceAll(phone, " ", "") - if !phoneNumberRe.MatchString(phone) { - return "", fmt.Errorf("invalid phone number") + phone = numberCleaner.Replace(phone) + if len(phone) < 2 { + return "", fmt.Errorf("phone number must start with + and contain numbers") + } else if phone[0] != '+' { + return "", fmt.Errorf("phone number must start with +") + } else if !isOnlyNumbers(phone[1:]) { + return "", fmt.Errorf("phone number must only contain numbers") } return phone, nil } From 3a98f57d193bf985cfe829043f49f5a77778e963 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Jun 2024 01:05:37 +0300 Subject: [PATCH 0279/1647] bridgev2: add more unorganized documentation --- bridgev2/login.go | 12 ++++ bridgev2/networkinterface.go | 32 +++++++++ bridgev2/unorganized-docs/README.md | 66 +++++++++++++++++++ .../incoming-matrix-message.uml | 23 +++++++ .../incoming-remote-message.uml | 22 +++++++ .../login-step.schema.json | 0 .../{ => unorganized-docs}/login-steps.uml | 0 7 files changed, 155 insertions(+) create mode 100644 bridgev2/unorganized-docs/README.md create mode 100644 bridgev2/unorganized-docs/incoming-matrix-message.uml create mode 100644 bridgev2/unorganized-docs/incoming-remote-message.uml rename bridgev2/{ => unorganized-docs}/login-step.schema.json (100%) rename bridgev2/{ => unorganized-docs}/login-steps.uml (100%) diff --git a/bridgev2/login.go b/bridgev2/login.go index b592e64b..636619ce 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -14,8 +14,20 @@ import ( "maunium.net/go/mautrix/bridgev2/networkid" ) +// LoginProcess represents a single occurrence of a user logging into the remote network. type LoginProcess interface { + // Start starts the process and returns the first step. + // + // For example, a network using QR login may connect to the network, fetch a QR code, + // and return a DisplayAndWait-type step. + // + // This will only ever be called once. Start(ctx context.Context) (*LoginStep, error) + // Cancel stops the login process and cleans up any resources. + // No other methods will be called after cancel. + // + // Cancel will not be called if any other method returned an error: + // errors are always treated as fatal and the process is assumed to be automatically cancelled. Cancel() } diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 5b07ab63..df23e8c4 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -56,6 +56,7 @@ type ConvertedEdit struct { DeletedParts []*database.Message } +// BridgeName contains information about the network that a connector bridges to. type BridgeName struct { // The displayname of the network, e.g. `Discord` DisplayName string @@ -82,23 +83,50 @@ func (bn BridgeName) AsBridgeInfoSection() event.BridgeInfoSection { } } +// NetworkConnector is the main interface that a network connector must implement. type NetworkConnector interface { + // Init is called when the bridge is initialized. The connector should store the bridge instance for later use. + // This should not do any network calls or other blocking operations. Init(*Bridge) + // Start is called when the bridge is starting. + // The connector should do any non-user-specific startup actions necessary. + // User logins will be loaded separately, so the connector should not load them here. Start(context.Context) error + // LoadUserLogin is called when a UserLogin is loaded from the database in order to fill the [UserLogin.Client] field. + // + // This is called within the bridge's global cache lock, so it must not do any slow operations, + // such as connecting to the network. Instead, connecting should happen when [NetworkAPI.Connect] is called later. LoadUserLogin(ctx context.Context, login *UserLogin) error GetName() BridgeName + // GetConfig returns all the parts of the network connector's config file. Specifically: + // - example: a string containing an example config file + // - data: an interface to unmarshal the actual config into + // - upgrader: a config upgrader to ensure all fields are present and to do any migrations from old configs GetConfig() (example string, data any, upgrader configupgrade.Upgrader) + // GetLoginFlows returns a list of login flows that the network supports. GetLoginFlows() []LoginFlow + // CreateLogin is called when a user wants to log in to the network. + // + // This should generally not do any work, it should just return a LoginProcess that remembers + // the user and will execute the requested flow. The actual work should start when [LoginProcess.Start] is called. CreateLogin(ctx context.Context, user *User, flowID string) (LoginProcess, error) } +// ConfigValidatingNetwork is an optional interface that network connectors can implement to validate config fields +// before the bridge is started. +// +// When the ValidateConfig method is called, the config data will already be unmarshaled into the +// object returned by [NetworkConnector.GetConfig]. +// +// This mechanism is usually used to refuse bridge startup if a mandatory field has an invalid value. type ConfigValidatingNetwork interface { NetworkConnector ValidateConfig() error } +// NetworkAPI is an interface representing a remote network client for a single user login. type NetworkAPI interface { Connect(ctx context.Context) error IsLoggedIn() bool @@ -127,6 +155,10 @@ const ( RemoteEventMessageRemove ) +// RemoteEvent represents a single event from the remote network, such as a message or a reaction. +// +// When a [NetworkAPI] receives an event from the remote network, it should convert it into a [RemoteEvent] +// and pass it to the bridge for processing using [Bridge.QueueRemoteEvent]. type RemoteEvent interface { GetType() RemoteEventType GetPortalKey() networkid.PortalKey diff --git a/bridgev2/unorganized-docs/README.md b/bridgev2/unorganized-docs/README.md new file mode 100644 index 00000000..62cc731a --- /dev/null +++ b/bridgev2/unorganized-docs/README.md @@ -0,0 +1,66 @@ +# Megabridge +Megabridge, also known as bridgev2 (final naming is subject to change), is a +new high-level framework for writing puppeting Matrix bridges with hopefully +minimal boilerplate code. + +## General architecture +Megabridge is split into three components: network connectors, the central +bridge module, and Matrix connectors. + +* Network connectors are responsible for connecting to the remote (non-Matrix) + network and handling all the protocol-specific details. +* The central bridge module has most of the generic bridge logic, such as + keeping track of portal mappings and handling messages. +* Matrix connectors are responsible for connecting to Matrix. Initially there + will be two Matrix connectors: one for the standard setup that connects to + a Matrix homeserver as an application service, and another for Beeper's local + bridge system. However, in the future there could be a third connector which + uses a single bot account and [MSC4144] instead of an appservice with ghost + users. + + [MSC4144]: https://github.com/matrix-org/matrix-spec-proposals/pull/4144 + +The central bridge module defines interfaces that it uses to interact with the +connectors on both sides. Additionally, the connectors are allowed to directly +call interface methods on other side. + +## Getting started with a new network connector +To create a new network connector, you need to implement the +`NetworkConnector`, `LoginProcess`, `NetworkAPI` and `RemoteEvent` interfaces. + +* `NetworkConnector` is the main entry point to the remote network. It is + responsible for general non-user-specific things, as well as creating + `NetworkAPI`s and starting login flows. +* `LoginProcess` is a state machine for logging into the remote network. +* `NetworkAPI` is the remote network client for a single login. It is + responsible for maintaining the connection to the remote network, receiving + incoming events, sending outgoing events, and fetching information like + chat/user metadata. +* `RemoteEvent` represents a single event from the remote network, such as a + message or a reaction. When the NetworkAPI receives an event, it should create + a `RemoteEvent` object and pass it to the bridge using `Bridge.QueueRemoteEvent`. + +### Login +Logins are implemented by combining three types of steps: + +* `user_input` asks the user to enter some information, such as a phone number, + username, email, password, or 2FA code. +* `cookies` either asks the user to extract cookies from their browser, or opens + a webview to do it automatically (depending on whether the login is being done + via bridge commands or a more advanced client). +* `display_and_wait` displays a QR code or other data to the user and waits until + the remote network accepts the login. + +The general flow is: + +1. Login handler (bridge command or client) calls `NetworkConnector.GetLoginFlows` + to get available login flows, and asks the user to pick one (or alternatively + automatically picks the first one if there's only one option). +2. Login handler calls `NetworkConnector.CreateLogin` with the chosen flow ID and + the network connector returns a `LoginProcess` object that remembers the user + and flow. +3. Login handler calls `LoginProcess.Start` to get the first step. +4. Login handler calls the appropriate functions (`Wait`, `SubmitUserInput` or + `SubmitCookies`) based on the step data as many times as needed. +5. When the login is done, the login process creates the `UserLogin` object and + returns a `complete` step. diff --git a/bridgev2/unorganized-docs/incoming-matrix-message.uml b/bridgev2/unorganized-docs/incoming-matrix-message.uml new file mode 100644 index 00000000..ae13ee74 --- /dev/null +++ b/bridgev2/unorganized-docs/incoming-matrix-message.uml @@ -0,0 +1,23 @@ +title Bridge v2 incoming Matrix message + +participant Network Library +participant Network Connector +participant Bridge +participant Portal +participant Database +participant Matrix + +Matrix->Bridge: QueueMatrixEvent(evt) +note over Bridge: GetPortalByID(evt.GetPortalID()) +Bridge->Portal: portal.events <- evt +loop event queue consumer + Portal->+Portal: \n evt := <-portal.events + note over Portal: Check for edit, reply/thread, etc + Portal->+Network Connector: HandleMatrixMessage(evt, replyTo) + Network Connector->Network Connector: msg := ConvertMatrixMessage(evt) + Network Connector->+Network Library: SendMessage(msg) + Network Library->-Network Connector: OK + Network Connector->-Portal: *database.Message{msg.ID} + Portal->-Database: Message.Insert() + Portal->Matrix: Success checkpoint +end diff --git a/bridgev2/unorganized-docs/incoming-remote-message.uml b/bridgev2/unorganized-docs/incoming-remote-message.uml new file mode 100644 index 00000000..f86d6e65 --- /dev/null +++ b/bridgev2/unorganized-docs/incoming-remote-message.uml @@ -0,0 +1,22 @@ +title Bridge v2 incoming remote message + +participant Network Library +participant Network Connector +participant Bridge +participant Portal +participant Database +participant Matrix + +Network Library->Network Connector: New event +Network Connector->Bridge: QueueRemoteEvent(evt) +note over Bridge: GetPortalByID(evt.GetPortalID()) +Bridge->Portal: portal.events <- evt +loop event queue consumer + Portal->+Portal: \n evt := <-portal.events + note over Portal: CreateMatrixRoom() if applicable + Portal->+Network Connector: ConvertRemoteMessage(evt) + Network Connector->-Portal: *ConvertedMessage + Portal->+Matrix: SendMessage(convertedMsg) + Matrix->-Portal: event ID + Portal->-Database: Message.Insert() +end diff --git a/bridgev2/login-step.schema.json b/bridgev2/unorganized-docs/login-step.schema.json similarity index 100% rename from bridgev2/login-step.schema.json rename to bridgev2/unorganized-docs/login-step.schema.json diff --git a/bridgev2/login-steps.uml b/bridgev2/unorganized-docs/login-steps.uml similarity index 100% rename from bridgev2/login-steps.uml rename to bridgev2/unorganized-docs/login-steps.uml From f0690182c7ecf5255e5ce8d86d8528c6b21a5665 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Jun 2024 15:04:43 +0300 Subject: [PATCH 0280/1647] bridgev2: add interface for network connectors to find out the max file size --- bridgev2/matrix/connector.go | 4 ++++ bridgev2/networkinterface.go | 10 ++++++++++ 2 files changed, 14 insertions(+) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 76c78bfd..bd86fea6 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -250,6 +250,10 @@ func (br *Connector) fetchMediaConfig(ctx context.Context) { cfg.UploadSize = 50 * 1024 * 1024 } br.MediaConfig = *cfg + mfsn, ok := br.Bridge.Network.(bridgev2.MaxFileSizeingNetwork) + if ok { + mfsn.SetMaxFileSize(br.MediaConfig.UploadSize) + } } } diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index df23e8c4..9dc6407a 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -126,6 +126,16 @@ type ConfigValidatingNetwork interface { ValidateConfig() error } +// MaxFileSizeingNetwork is an optional interface that network connectors can implement +// to find out the maximum file size that can be uploaded to Matrix. +// +// The SetMaxFileSize will be called asynchronously soon after startup. +// Before the function is called, the connector may assume a default limit of 50 MiB. +type MaxFileSizeingNetwork interface { + NetworkConnector + SetMaxFileSize(maxSize int64) +} + // NetworkAPI is an interface representing a remote network client for a single user login. type NetworkAPI interface { Connect(ctx context.Context) error From cf6b0e71f0a008d41c4ab6b60b71fb52054b8d83 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Jun 2024 19:24:07 +0300 Subject: [PATCH 0281/1647] bridgev2: read receipt support --- bridgev2/database/database.go | 32 +++--- bridgev2/database/upgrades/00-latest.sql | 1 + bridgev2/database/user.go | 11 -- bridgev2/database/userlogin.go | 23 ---- bridgev2/database/userportal.go | 127 +++++++++++++++++++++ bridgev2/matrix/connector.go | 2 + bridgev2/matrix/intent.go | 17 +++ bridgev2/matrix/matrix.go | 22 +++- bridgev2/matrixinterface.go | 1 + bridgev2/networkinterface.go | 38 ++++++- bridgev2/portal.go | 136 ++++++++++++++++++++--- bridgev2/queue.go | 2 +- bridgev2/userlogin.go | 2 +- 13 files changed, 342 insertions(+), 72 deletions(-) create mode 100644 bridgev2/database/userportal.go diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go index 688a40da..c4e2598f 100644 --- a/bridgev2/database/database.go +++ b/bridgev2/database/database.go @@ -17,26 +17,28 @@ import ( type Database struct { *dbutil.Database - BridgeID networkid.BridgeID - Portal *PortalQuery - Ghost *GhostQuery - Message *MessageQuery - Reaction *ReactionQuery - User *UserQuery - UserLogin *UserLoginQuery + BridgeID networkid.BridgeID + Portal *PortalQuery + Ghost *GhostQuery + Message *MessageQuery + Reaction *ReactionQuery + User *UserQuery + UserLogin *UserLoginQuery + UserPortal *UserPortalQuery } func New(bridgeID networkid.BridgeID, db *dbutil.Database) *Database { db.UpgradeTable = upgrades.Table return &Database{ - Database: db, - BridgeID: bridgeID, - Portal: &PortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newPortal)}, - Ghost: &GhostQuery{bridgeID, dbutil.MakeQueryHelper(db, newGhost)}, - Message: &MessageQuery{bridgeID, dbutil.MakeQueryHelper(db, newMessage)}, - Reaction: &ReactionQuery{bridgeID, dbutil.MakeQueryHelper(db, newReaction)}, - User: &UserQuery{bridgeID, dbutil.MakeQueryHelper(db, newUser)}, - UserLogin: &UserLoginQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserLogin)}, + Database: db, + BridgeID: bridgeID, + Portal: &PortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newPortal)}, + Ghost: &GhostQuery{bridgeID, dbutil.MakeQueryHelper(db, newGhost)}, + Message: &MessageQuery{bridgeID, dbutil.MakeQueryHelper(db, newMessage)}, + Reaction: &ReactionQuery{bridgeID, dbutil.MakeQueryHelper(db, newReaction)}, + User: &UserQuery{bridgeID, dbutil.MakeQueryHelper(db, newUser)}, + UserLogin: &UserLoginQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserLogin)}, + UserPortal: &UserPortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserPortal)}, } } diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index dd03a0eb..8d9d150e 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -132,6 +132,7 @@ CREATE TABLE user_portal ( portal_receiver TEXT NOT NULL, in_space BOOLEAN NOT NULL, preferred BOOLEAN NOT NULL, + last_read BIGINT, PRIMARY KEY (bridge_id, user_mxid, login_id, portal_id, portal_receiver), CONSTRAINT user_portal_user_login_fkey FOREIGN KEY (bridge_id, user_mxid, login_id) diff --git a/bridgev2/database/user.go b/bridgev2/database/user.go index c549e234..de3b316a 100644 --- a/bridgev2/database/user.go +++ b/bridgev2/database/user.go @@ -46,12 +46,6 @@ const ( UPDATE "user" SET management_room=$3, access_token=$4 WHERE bridge_id=$1 AND mxid=$2 ` - findUserLoginsByPortalIDQuery = ` - SELECT login_id - FROM user_portal - WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$3 AND portal_receiver=$4 - ORDER BY CASE WHEN preferred THEN 0 ELSE 1 END, login_id - ` ) func (uq *UserQuery) GetByMXID(ctx context.Context, userID id.UserID) (*User, error) { @@ -68,11 +62,6 @@ func (uq *UserQuery) Update(ctx context.Context, user *User) error { return uq.Exec(ctx, updateUserQuery, user.sqlVariables()...) } -func (uq *UserQuery) FindLoginsByPortalID(ctx context.Context, userID id.UserID, portal networkid.PortalKey) ([]networkid.UserLoginID, error) { - rows, err := uq.GetDB().Query(ctx, findUserLoginsByPortalIDQuery, uq.BridgeID, userID, portal.ID, portal.Receiver) - return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[networkid.UserLoginID], err).AsList() -} - func (u *User) Scan(row dbutil.Scannable) (*User, error) { var managementRoom, accessToken sql.NullString err := row.Scan(&u.BridgeID, &u.MXID, &managementRoom, &accessToken) diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go index 36b2032b..65ed4bf8 100644 --- a/bridgev2/database/userlogin.go +++ b/bridgev2/database/userlogin.go @@ -55,19 +55,6 @@ const ( deleteUserLoginQuery = ` DELETE FROM user_login WHERE bridge_id=$1 AND id=$2 ` - insertUserPortalQuery = ` - INSERT INTO user_portal (bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred) - VALUES ($1, $2, $3, $4, $5, false, false) - ON CONFLICT (bridge_id, user_mxid, login_id, portal_id, portal_receiver) DO NOTHING - ` - upsertUserPortalQuery = ` - INSERT INTO user_portal (bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred) - VALUES ($1, $2, $3, $4, $5, $6, $7) - ON CONFLICT (bridge_id, user_mxid, login_id, portal_id, portal_receiver) DO UPDATE SET in_space=excluded.in_space, preferred=excluded.preferred - ` - markLoginAsPreferredQuery = ` - UPDATE user_portal SET preferred=(login_id=$3) WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$4 AND portal_receiver=$5 - ` ) func (uq *UserLoginQuery) GetAll(ctx context.Context) ([]*UserLogin, error) { @@ -96,16 +83,6 @@ func (uq *UserLoginQuery) Delete(ctx context.Context, loginID networkid.UserLogi return uq.Exec(ctx, deleteUserLoginQuery, uq.BridgeID, loginID) } -func (uq *UserLoginQuery) EnsureUserPortalExists(ctx context.Context, login *UserLogin, portal networkid.PortalKey) error { - ensureBridgeIDMatches(&login.BridgeID, uq.BridgeID) - return uq.Exec(ctx, insertUserPortalQuery, login.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) -} - -func (uq *UserLoginQuery) MarkLoginAsPreferredInPortal(ctx context.Context, login *UserLogin, portal networkid.PortalKey) error { - ensureBridgeIDMatches(&login.BridgeID, uq.BridgeID) - return uq.Exec(ctx, markLoginAsPreferredQuery, login.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) -} - func (u *UserLogin) Scan(row dbutil.Scannable) (*UserLogin, error) { var spaceRoom sql.NullString err := row.Scan(&u.BridgeID, &u.UserMXID, &u.ID, &spaceRoom, dbutil.JSON{Data: &u.Metadata}) diff --git a/bridgev2/database/userportal.go b/bridgev2/database/userportal.go new file mode 100644 index 00000000..2a41fc91 --- /dev/null +++ b/bridgev2/database/userportal.go @@ -0,0 +1,127 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "time" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type UserPortalQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*UserPortal] +} + +type UserPortal struct { + BridgeID networkid.BridgeID + UserMXID id.UserID + LoginID networkid.UserLoginID + Portal networkid.PortalKey + InSpace *bool + Preferred *bool + LastRead time.Time +} + +func newUserPortal(_ *dbutil.QueryHelper[*UserPortal]) *UserPortal { + return &UserPortal{} +} + +const ( + getUserPortalBaseQuery = ` + SELECT bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred, last_read + FROM user_portal + ` + getUserPortalQuery = getUserPortalBaseQuery + ` + WHERE bridge_id=$1 AND user_mxid=$2 AND login_id=$3 AND portal_id=$4 AND portal_receiver=$5 + ` + findUserLoginsByPortalIDQuery = getUserPortalBaseQuery + ` + WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$3 AND portal_receiver=$4 + ORDER BY CASE WHEN preferred THEN 0 ELSE 1 END, login_id + ` + insertUserPortalQuery = ` + INSERT INTO user_portal (bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred) + VALUES ($1, $2, $3, $4, $5, false, false) + ON CONFLICT (bridge_id, user_mxid, login_id, portal_id, portal_receiver) DO NOTHING + ` + upsertUserPortalQuery = ` + INSERT INTO user_portal (bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred, last_read) + VALUES ($1, $2, $3, $4, $5, COALESCE($6, false), COALESCE($7, false), $8) + ON CONFLICT (bridge_id, user_mxid, login_id, portal_id, portal_receiver) DO UPDATE + SET in_space=COALESCE($6, user_portal.in_space), + preferred=COALESCE($7, user_portal.preferred), + last_read=COALESCE($8, user_portal.last_read) + ` + markLoginAsPreferredQuery = ` + UPDATE user_portal SET preferred=(login_id=$3) WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$4 AND portal_receiver=$5 + ` +) + +func UserPortalFor(ul *UserLogin, portal networkid.PortalKey) *UserPortal { + return &UserPortal{ + BridgeID: ul.BridgeID, + UserMXID: ul.UserMXID, + LoginID: ul.ID, + Portal: portal, + } +} + +func (upq *UserPortalQuery) GetAllByUser(ctx context.Context, userID id.UserID, portal networkid.PortalKey) ([]*UserPortal, error) { + return upq.QueryMany(ctx, findUserLoginsByPortalIDQuery, upq.BridgeID, userID, portal.ID, portal.Receiver) +} + +func (upq *UserPortalQuery) Get(ctx context.Context, login *UserLogin, portal networkid.PortalKey) (*UserPortal, error) { + return upq.QueryOne(ctx, getUserPortalQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) +} + +func (upq *UserPortalQuery) Put(ctx context.Context, up *UserPortal) error { + ensureBridgeIDMatches(&up.BridgeID, upq.BridgeID) + return upq.Exec(ctx, upsertUserPortalQuery, up.sqlVariables()...) +} + +func (upq *UserPortalQuery) EnsureExists(ctx context.Context, login *UserLogin, portal networkid.PortalKey) error { + return upq.Exec(ctx, insertUserPortalQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) +} + +func (upq *UserPortalQuery) MarkAsPreferred(ctx context.Context, login *UserLogin, portal networkid.PortalKey) error { + return upq.Exec(ctx, markLoginAsPreferredQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) +} + +func (up *UserPortal) Scan(row dbutil.Scannable) (*UserPortal, error) { + var lastRead sql.NullInt64 + err := row.Scan( + &up.BridgeID, &up.UserMXID, &up.LoginID, &up.Portal.ID, &up.Portal.Receiver, + &up.InSpace, &up.Preferred, &lastRead, + ) + if err != nil { + return nil, err + } + if lastRead.Valid { + up.LastRead = time.Unix(0, lastRead.Int64) + } + return up, nil +} + +func (up *UserPortal) sqlVariables() []any { + return []any{ + up.BridgeID, up.UserMXID, up.LoginID, up.Portal.ID, up.Portal.Receiver, + up.InSpace, + up.Preferred, + dbutil.ConvertedPtr(up.LastRead, time.Time.UnixNano), + } +} + +func (up *UserPortal) ResetValues() { + up.InSpace = nil + up.Preferred = nil + up.LastRead = time.Time{} +} diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index bd86fea6..c3bf99de 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -103,6 +103,8 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.EventProcessor.On(event.EventRedaction, br.handleRoomEvent) br.EventProcessor.On(event.EventEncrypted, br.handleEncryptedEvent) br.EventProcessor.On(event.StateMember, br.handleRoomEvent) + br.EventProcessor.On(event.EphemeralEventReceipt, br.handleEphemeralEvent) + br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent) br.Bot = br.AS.BotIntent() br.Crypto = NewCryptoHelper(br) br.Bridge.Commands.AddHandlers( diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 6c81349b..fbe1e261 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -65,6 +65,23 @@ func (as *ASIntent) SendState(ctx context.Context, roomID id.RoomID, eventType e } } +func (as *ASIntent) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, ts time.Time) error { + extraData := map[string]any{} + if !ts.IsZero() { + extraData["ts"] = ts.UnixMilli() + } + as.Matrix.AddDoublePuppetValue(extraData) + req := mautrix.ReqSetReadMarkers{ + Read: eventID, + BeeperReadExtra: extraData, + } + if as.Matrix.IsCustomPuppet { + req.FullyRead = eventID + req.BeeperFullyReadExtra = extraData + } + return as.Matrix.SetReadMarkers(ctx, roomID, &req) +} + func (as *ASIntent) DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error) { if file != nil { uri = file.URL diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index 8a96f039..4fde9941 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -34,8 +34,26 @@ func (br *Connector) handleRoomEvent(ctx context.Context, evt *event.Event) { } func (br *Connector) handleEphemeralEvent(ctx context.Context, evt *event.Event) { - if br.shouldIgnoreEvent(evt) { - return + if evt.Type == event.EphemeralEventReceipt { + receiptContent := *evt.Content.AsReceipt() + for eventID, receipts := range receiptContent { + for receiptType, userReceipts := range receipts { + for userID, receipt := range userReceipts { + if br.AS.DoublePuppetValue != "" && receipt.Extra[appservice.DoublePuppetKey] == br.AS.DoublePuppetValue { + delete(userReceipts, userID) + } + } + if len(userReceipts) == 0 { + delete(receipts, receiptType) + } + } + if len(receipts) == 0 { + delete(receiptContent, eventID) + } + } + if len(receiptContent) == 0 { + return + } } br.Bridge.QueueMatrixEvent(ctx, evt) } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 350bf005..da8c92a8 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -42,6 +42,7 @@ type MatrixAPI interface { SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) SendState(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) + MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, ts time.Time) error DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error) UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 9dc6407a..def970ed 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -152,6 +152,7 @@ type NetworkAPI interface { HandleMatrixReaction(ctx context.Context, msg *MatrixReaction) (reaction *database.Reaction, err error) HandleMatrixReactionRemove(ctx context.Context, msg *MatrixReactionRemove) error HandleMatrixMessageRemove(ctx context.Context, msg *MatrixMessageRemove) error + HandleMatrixReadReceipt(ctx context.Context, msg *MatrixReadReceipt) error } type RemoteEventType int @@ -163,6 +164,9 @@ const ( RemoteEventReaction RemoteEventReactionRemove RemoteEventMessageRemove + RemoteEventReadReceipt + RemoteEventDeliveryReceipt + RemoteEventTyping ) // RemoteEvent represents a single event from the remote network, such as a message or a reaction. @@ -172,11 +176,15 @@ const ( type RemoteEvent interface { GetType() RemoteEventType GetPortalKey() networkid.PortalKey - ShouldCreatePortal() bool AddLogContext(c zerolog.Context) zerolog.Context GetSender() EventSender } +type RemoteEventThatMayCreatePortal interface { + RemoteEvent + ShouldCreatePortal() bool +} + type RemoteEventWithTargetMessage interface { RemoteEvent GetTargetMessage() networkid.MessageID @@ -222,6 +230,17 @@ type RemoteMessageRemove interface { RemoteEventWithTargetMessage } +type RemoteReceipt interface { + RemoteEvent + GetLastReceiptTarget() networkid.MessageID + GetReceiptTargets() []networkid.MessageID +} + +type RemoteTyping interface { + RemoteEvent + GetTimeout() time.Duration +} + // SimpleRemoteEvent is a simple implementation of RemoteEvent that can be used with struct fields and some callbacks. type SimpleRemoteEvent[T any] struct { Type RemoteEventType @@ -360,3 +379,20 @@ type MatrixMessageRemove struct { MatrixEventBase[*event.RedactionEventContent] TargetMessage *database.Message } + +type MatrixReadReceipt struct { + Portal *Portal + // The event ID that the receipt is targeting + EventID id.EventID + // The exact message that was read. This may be nil if the event ID isn't a message. + ExactMessage *database.Message + // The timestamp that the user has read up to. This is either the timestamp of the message + // (if one is present) or the timestamp of the receipt. + ReadUpTo time.Time + // The ReadUpTo timestamp of the previous message + LastRead time.Time + // The receipt metadata. + Receipt event.ReadReceipt +} + +type MatrixTyping struct{} diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 49ed3f96..d1b35664 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -178,22 +178,25 @@ func (portal *Portal) eventLoop() { } } -func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User) (*UserLogin, error) { - logins, err := portal.Bridge.DB.User.FindLoginsByPortalID(ctx, user.MXID, portal.PortalKey) +func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) { + logins, err := portal.Bridge.DB.UserPortal.GetAllByUser(ctx, user.MXID, portal.PortalKey) if err != nil { - return nil, err + return nil, nil, err } portal.Bridge.cacheLock.Lock() defer portal.Bridge.cacheLock.Unlock() - for _, loginID := range logins { - login, ok := user.logins[loginID] + for _, up := range logins { + login, ok := user.logins[up.LoginID] if ok && login.Client != nil { - return login, nil + return login, up, nil } } + if !allowRelay { + return nil, nil, ErrNotLoggedIn + } // Portal has relay, use it if portal.Relay != nil { - return nil, nil + return nil, nil, nil } var firstLogin *UserLogin for _, login := range user.logins { @@ -204,9 +207,9 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User) (*User zerolog.Ctx(ctx).Warn(). Str("chosen_login_id", string(firstLogin.ID)). Msg("No usable user portal rows found, returning random login") - return firstLogin, nil + return firstLogin, nil, nil } else { - return nil, ErrNotLoggedIn + return nil, nil, ErrNotLoggedIn } } @@ -244,7 +247,7 @@ func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { Stringer("sender", sender.MXID). Logger() ctx := log.WithContext(context.TODO()) - login, err := portal.FindPreferredLogin(ctx, sender) + login, _, err := portal.FindPreferredLogin(ctx, sender, true) if err != nil { log.Err(err).Msg("Failed to get user login to handle Matrix event") portal.sendErrorStatus(ctx, evt, WrapErrorInStatus(err).WithMessage("Failed to get login to handle event").WithIsCertain(true).WithSendNotice(true)) @@ -287,29 +290,71 @@ func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { } func (portal *Portal) handleMatrixReceipts(evt *event.Event) { - content, ok := evt.Content.Parsed.(event.ReceiptEventContent) + content, ok := evt.Content.Parsed.(*event.ReceiptEventContent) if !ok { return } - ctx := context.TODO() - for evtID, receipts := range content { + for evtID, receipts := range *content { readReceipts, ok := receipts[event.ReceiptTypeRead] if !ok { continue } for userID, receipt := range readReceipts { - sender, err := portal.Bridge.GetUserByMXID(ctx, userID) + sender, err := portal.Bridge.GetUserByMXID(context.TODO(), userID) if err != nil { // TODO log return } - portal.handleMatrixReadReceipt(ctx, sender, evtID, receipt) + portal.handleMatrixReadReceipt(sender, evtID, receipt) } } } -func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, eventID id.EventID, receipt event.ReadReceipt) { - // TODO send read receipt(s) to network +func (portal *Portal) handleMatrixReadReceipt(user *User, eventID id.EventID, receipt event.ReadReceipt) { + log := portal.Log.With(). + Str("action", "handle matrix read receipt"). + Stringer("event_id", eventID). + Stringer("user_id", user.MXID). + Logger() + ctx := log.WithContext(context.TODO()) + login, userPortal, err := portal.FindPreferredLogin(ctx, user, false) + if err != nil { + log.Err(err).Msg("Failed to get preferred login for user") + return + } + evt := &MatrixReadReceipt{ + Portal: portal, + EventID: eventID, + Receipt: receipt, + } + if userPortal == nil { + userPortal = database.UserPortalFor(login.UserLogin, portal.PortalKey) + } else { + evt.LastRead = userPortal.LastRead + } + evt.ExactMessage, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, eventID) + if err != nil { + log.Err(err).Msg("Failed to get exact message from database") + } else if evt.ExactMessage != nil { + evt.ReadUpTo = evt.ExactMessage.Timestamp + } else { + evt.ReadUpTo = receipt.Timestamp + } + err = login.Client.HandleMatrixReadReceipt(ctx, evt) + if err != nil { + log.Err(err).Msg("Failed to handle read receipt") + return + } + userPortal.ResetValues() + if evt.ExactMessage != nil { + userPortal.LastRead = evt.ExactMessage.Timestamp + } else { + userPortal.LastRead = receipt.Timestamp + } + err = portal.Bridge.DB.UserPortal.Put(ctx, userPortal) + if err != nil { + log.Err(err).Msg("Failed to save user portal metadata") + } } func (portal *Portal) handleMatrixTyping(evt *event.Event) { @@ -660,7 +705,8 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { log.UpdateContext(evt.AddLogContext) ctx := log.WithContext(context.TODO()) if portal.MXID == "" { - if !evt.ShouldCreatePortal() { + mcp, ok := evt.(RemoteEventThatMayCreatePortal) + if !ok || !mcp.ShouldCreatePortal() { return } err := portal.CreateMatrixRoom(ctx, source) @@ -683,6 +729,12 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { portal.handleRemoteReactionRemove(ctx, source, evt.(RemoteReactionRemove)) case RemoteEventMessageRemove: portal.handleRemoteMessageRemove(ctx, source, evt.(RemoteMessageRemove)) + case RemoteEventReadReceipt: + portal.handleRemoteReadReceipt(ctx, source, evt.(RemoteReceipt)) + case RemoteEventDeliveryReceipt: + portal.handleRemoteDeliveryReceipt(ctx, source, evt.(RemoteReceipt)) + case RemoteEventTyping: + portal.handleRemoteTyping(ctx, source, evt.(RemoteTyping)) default: log.Warn().Int("type", int(evt.GetType())).Msg("Got remote event with unknown type") } @@ -1039,6 +1091,54 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use } } +func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReceipt) { + log := zerolog.Ctx(ctx) + var err error + var lastTarget *database.Message + if lastTargetID := evt.GetLastReceiptTarget(); lastTargetID != "" { + lastTarget, err = portal.Bridge.DB.Message.GetLastPartByID(ctx, lastTargetID) + if err != nil { + log.Err(err).Str("last_target_id", string(lastTargetID)). + Msg("Failed to get last target message for read receipt") + return + } else if lastTarget == nil { + log.Debug().Str("last_target_id", string(lastTargetID)). + Msg("Last target message not found") + } + } + if lastTarget == nil { + for _, targetID := range evt.GetReceiptTargets() { + target, err := portal.Bridge.DB.Message.GetLastPartByID(ctx, targetID) + if err != nil { + log.Err(err).Str("target_id", string(targetID)). + Msg("Failed to get target message for read receipt") + return + } else if target != nil && (lastTarget == nil || target.Timestamp.After(lastTarget.Timestamp)) { + lastTarget = target + } + } + } + if lastTarget == nil { + log.Warn().Msg("No target message found for read receipt") + return + } + intent := portal.getIntentFor(ctx, evt.GetSender(), source) + err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt)) + if err != nil { + log.Err(err).Stringer("target_mxid", lastTarget.MXID).Msg("Failed to bridge read receipt") + } else { + log.Debug().Stringer("target_mxid", lastTarget.MXID).Msg("Bridged read receipt") + } +} + +func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteReceipt) { + +} + +func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) { + +} + var stateElementFunctionalMembers = event.Type{Class: event.StateEventType, Type: "io.element.functional_members"} type PortalInfo struct { diff --git a/bridgev2/queue.go b/bridgev2/queue.go index cf4d27e8..4ec9fb30 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -92,7 +92,7 @@ func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { return } // TODO put this in a better place, and maybe cache to avoid constant db queries - err = br.DB.UserLogin.EnsureUserPortalExists(ctx, login.UserLogin, portal.PortalKey) + err = br.DB.UserPortal.EnsureExists(ctx, login.UserLogin, portal.PortalKey) if err != nil { log.Warn().Err(err).Msg("Failed to ensure user portal row exists") } diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index b531bdfb..aa809736 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -143,7 +143,7 @@ func (ul *UserLogin) Logout(ctx context.Context) { } func (ul *UserLogin) MarkAsPreferredIn(ctx context.Context, portal *Portal) error { - return ul.Bridge.DB.UserLogin.MarkLoginAsPreferredInPortal(ctx, ul.UserLogin, portal.PortalKey) + return ul.Bridge.DB.UserPortal.MarkAsPreferred(ctx, ul.UserLogin, portal.PortalKey) } var _ status.BridgeStateFiller = (*UserLogin)(nil) From baa700e123f7431b2fdab28e0a6bbe95a5e54102 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Jun 2024 19:53:28 +0300 Subject: [PATCH 0282/1647] bridgev2: add feature list --- bridgev2/unorganized-docs/FEATURES.md | 47 +++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 bridgev2/unorganized-docs/FEATURES.md diff --git a/bridgev2/unorganized-docs/FEATURES.md b/bridgev2/unorganized-docs/FEATURES.md new file mode 100644 index 00000000..dc3bf6e5 --- /dev/null +++ b/bridgev2/unorganized-docs/FEATURES.md @@ -0,0 +1,47 @@ +# Megabridge features + +* [ ] Messages + * [x] Text (incl. formatting and mentions) + * [x] Attachments + * [ ] Polls + * [x] Replies + * [ ] Threads + * [x] Edits + * [x] Reactions + * [ ] Reaction mass-syncing + * [x] Deletions + * [x] Message status events and error notices + * [ ] Backfilling history +* [x] Login +* [x] Logout +* [ ] Re-login after credential expiry +* [ ] Disappearing messages +* [x] Read receipts +* [ ] Presence +* [ ] Typing notifications +* [ ] Chat metadata + * [ ] Archive/low priority + * [ ] Pin/favorite + * [ ] Mark unread + * [ ] Mute status + * [ ] Temporary mutes ("snooze") +* [x] User metadata (name/avatar) +* [ ] Group metadata + * [ ] Initial meta and full resyncs + * [x] Name, avatar, topic + * [x] Members + * [ ] Permissions + * [ ] Change events + * [ ] Name, avatar, topic + * [ ] Members (join, leave, invite, kick, ban, knock) + * [ ] Permissions (promote, demote) +* [ ] Misc actions + * [ ] Invites / accepting message requests + * [ ] Create group + * [ ] Create DM + * [ ] Get contact list + * [ ] Check if identifier is on remote network + * [ ] Search users on remote network + * [ ] Delete chat + * [ ] Report spam +* [ ] Custom emojis From 5af319fe3f43f7f76152601c115b695dde58e35f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Jun 2024 21:13:51 +0300 Subject: [PATCH 0283/1647] client: add support for retrying requests with seekable readers --- client.go | 66 +++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 42 insertions(+), 24 deletions(-) diff --git a/client.go b/client.go index ab68a426..7733b349 100644 --- a/client.go +++ b/client.go @@ -454,14 +454,21 @@ func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler, client *http.Client) ([]byte, error) { log := zerolog.Ctx(req.Context()) if req.Body != nil { - if req.GetBody == nil { - log.Warn().Msg("Failed to get new body to retry request: GetBody is nil") - return nil, cause - } var err error - req.Body, err = req.GetBody() - if err != nil { - log.Warn().Err(err).Msg("Failed to get new body to retry request") + if req.GetBody != nil { + req.Body, err = req.GetBody() + if err != nil { + log.Warn().Err(err).Msg("Failed to get new body to retry request") + return nil, cause + } + } else if bodySeeker, ok := req.Body.(io.ReadSeeker); ok { + _, err = bodySeeker.Seek(0, io.SeekStart) + if err != nil { + log.Warn().Err(err).Msg("Failed to seek to beginning of request body") + return nil, cause + } + } else { + log.Warn().Msg("Failed to get new body to retry request: GetBody is nil and Body is not an io.ReadSeeker") return nil, cause } } @@ -1424,14 +1431,21 @@ func (cli *Client) GetDownloadURL(mxcURL id.ContentURI) string { func (cli *Client) doMediaRetry(req *http.Request, cause error, retries int, backoff time.Duration) (*http.Response, error) { log := zerolog.Ctx(req.Context()) if req.Body != nil { - if req.GetBody == nil { - log.Warn().Msg("Failed to get new body to retry request: GetBody is nil") - return nil, cause - } var err error - req.Body, err = req.GetBody() - if err != nil { - log.Warn().Err(err).Msg("Failed to get new body to retry request") + if req.GetBody != nil { + req.Body, err = req.GetBody() + if err != nil { + log.Warn().Err(err).Msg("Failed to get new body to retry request") + return nil, cause + } + } else if bodySeeker, ok := req.Body.(io.ReadSeeker); ok { + _, err = bodySeeker.Seek(0, io.SeekStart) + if err != nil { + log.Warn().Err(err).Msg("Failed to seek to beginning of request body") + return nil, cause + } + } else { + log.Warn().Msg("Failed to get new body to retry request: GetBody is nil and Body is not an io.ReadSeeker") return nil, cause } } @@ -1573,12 +1587,13 @@ type ReqUploadMedia struct { UnstableUploadURL string } -func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType string, content io.Reader) (*http.Response, error) { +func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType string, content io.Reader, contentLength int64) (*http.Response, error) { cli.Log.Debug().Str("url", url).Msg("Uploading media to external URL") req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, content) if err != nil { return nil, err } + req.ContentLength = contentLength req.Header.Set("Content-Type", contentType) req.Header.Set("User-Agent", cli.UserAgent+" (external media uploader)") @@ -1587,18 +1602,17 @@ func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType str func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (*RespMediaUpload, error) { retries := cli.DefaultHTTPRetries - if data.ContentBytes == nil { - // Can't retry with a reader + reader := data.Content + if data.ContentBytes != nil { + data.ContentLength = int64(len(data.ContentBytes)) + reader = bytes.NewReader(data.ContentBytes) + } + readerSeeker, canSeek := reader.(io.ReadSeeker) + if !canSeek { retries = 0 } for { - reader := data.Content - if reader == nil { - reader = bytes.NewReader(data.ContentBytes) - } else { - data.Content = nil - } - resp, err := cli.tryUploadMediaToURL(ctx, data.UnstableUploadURL, data.ContentType, reader) + resp, err := cli.tryUploadMediaToURL(ctx, data.UnstableUploadURL, data.ContentType, reader, data.ContentLength) if err == nil { if resp.StatusCode >= 200 && resp.StatusCode < 300 { // Everything is fine @@ -1614,6 +1628,10 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (* cli.Log.Warn().Str("url", data.UnstableUploadURL).Err(err). Msg("Error uploading media to external URL, retrying") retries-- + _, err = readerSeeker.Seek(0, io.SeekStart) + if err != nil { + return nil, fmt.Errorf("failed to seek back to start of reader: %w", err) + } } query := map[string]string{} From bbba811760a195aa77bd36609167d2c8062d145b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 13 Jun 2024 16:36:57 +0300 Subject: [PATCH 0284/1647] bridgev2/mxmain: actually exit if flags tell to exit --- bridgev2/matrix/mxmain/main.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 061f681e..0a235e7d 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -137,7 +137,7 @@ func (br *BridgeMain) PreInit() { os.Exit(0) } else if *version { fmt.Println(br.VersionDesc) - return + os.Exit(0) } else if *versionJSON { output := VersionJSONOutput{ URL: br.URL, @@ -155,16 +155,16 @@ func (br *BridgeMain) PreInit() { output.Mautrix.Commit = mautrix.Commit output.Mautrix.Version = mautrix.Version _ = json.NewEncoder(os.Stdout).Encode(output) - return + os.Exit(0) } else if *writeExampleConfig { networkExample, _, _ := br.Connector.GetConfig() exerrors.PanicIfNotNil(os.WriteFile(*configPath, []byte(br.makeFullExampleConfig(networkExample)), 0600)) - return + os.Exit(0) } br.LoadConfig() if *generateRegistration { br.GenerateRegistration() - return + os.Exit(0) } } From ecd47aa2315423cc23e98eff02cf9597a3a00b89 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 13 Jun 2024 21:39:59 +0300 Subject: [PATCH 0285/1647] bridgev2/mxmain: fix some defaults in example config --- bridgev2/matrix/mxmain/example-config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 2a4b805a..fdfd3e4a 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -36,13 +36,13 @@ homeserver: # The URL to push real-time bridge status to. # If set, the bridge will make POST requests to this URL whenever a user's remote network connection state changes. # The bridge will use the appservice as_token to authorize requests. - status_endpoint: http://localhost:4001 + status_endpoint: # Endpoint for reporting per-message status. # If set, the bridge will make POST requests to this URL when processing a message from Matrix. # It will make one request when receiving the message (step BRIDGE), one after decrypting if applicable # (step DECRYPTED) and one after sending to the remote network (step REMOTE). Errors will also be reported. # The bridge will use the appservice as_token to authorize requests. - message_send_checkpoint_endpoint: http://localhost:4001 + message_send_checkpoint_endpoint: # Does the homeserver support https://github.com/matrix-org/matrix-spec-proposals/pull/2246? async_media: false @@ -101,7 +101,7 @@ matrix: sync_direct_chat_list: false # Whether created rooms should have federation enabled. If false, created portal rooms # will never be federated. Changing this option requires recreating rooms. - federate_rooms: false + federate_rooms: true # Settings for provisioning API provisioning: @@ -115,7 +115,7 @@ provisioning: # which means that by default, it only works for users on the same server as the bridge. allow_matrix_auth: true # Enable debug API at /debug with provisioning authentication. - debug_endpoints: true + debug_endpoints: false # Settings for enabling double puppeting double_puppet: From 2863a1323b603a151610699878be2d584555674e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 13 Jun 2024 21:40:53 +0300 Subject: [PATCH 0286/1647] bridgev2/matrix: always import postgres and litestream --- bridgev2/matrix/connector.go | 2 ++ bridgev2/matrix/crypto.go | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index c3bf99de..09209bba 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -16,8 +16,10 @@ import ( "sync" "time" + _ "github.com/lib/pq" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" + _ "go.mau.fi/util/dbutil/litestream" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index 56b0e179..427b369d 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -17,6 +17,7 @@ import ( "sync" "time" + "github.com/lib/pq" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" @@ -29,6 +30,10 @@ import ( "maunium.net/go/mautrix/sqlstatestore" ) +func init() { + crypto.PostgresArrayWrapper = pq.Array +} + var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil) var NoSessionFound = crypto.NoSessionFound From 5272547ae7dc65884168540682c0b0d0e8acf5c3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 14 Jun 2024 12:42:50 +0300 Subject: [PATCH 0287/1647] bridgev2/mxmain: implement version command --- bridgev2/cmdmeta.go | 12 ------------ bridgev2/cmdprocessor.go | 2 +- bridgev2/matrix/mxmain/main.go | 10 ++++++++++ 3 files changed, 11 insertions(+), 13 deletions(-) diff --git a/bridgev2/cmdmeta.go b/bridgev2/cmdmeta.go index 1ea611f9..d866998f 100644 --- a/bridgev2/cmdmeta.go +++ b/bridgev2/cmdmeta.go @@ -17,18 +17,6 @@ var CommandHelp = &FullHandler{ }, } -var CommandVersion = &FullHandler{ - Func: func(ce *CommandEvent) { - ce.Reply("Bridge versions are not yet implemented") - //ce.Reply("[%s](%s) %s (%s)", ce.Bridge.Name, ce.Bridge.URL, ce.Bridge.LinkifiedVersion, ce.Bridge.BuildTime) - }, - Name: "version", - Help: HelpMeta{ - Section: HelpSectionGeneral, - Description: "Get the bridge version.", - }, -} - var CommandCancel = &FullHandler{ Func: func(ce *CommandEvent) { state := ce.User.CommandState.Swap(nil) diff --git a/bridgev2/cmdprocessor.go b/bridgev2/cmdprocessor.go index 1be36b47..efeb11a2 100644 --- a/bridgev2/cmdprocessor.go +++ b/bridgev2/cmdprocessor.go @@ -37,7 +37,7 @@ func NewProcessor(bridge *Bridge) *CommandProcessor { aliases: make(map[string]string), } proc.AddHandlers( - CommandHelp, CommandVersion, CommandCancel, + CommandHelp, CommandCancel, CommandLogin, CommandLogout, CommandSetPreferredLogin, ) return proc diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 0a235e7d..dd4c0328 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -227,6 +227,16 @@ func (br *BridgeMain) Init() { br.Matrix.IgnoreUnsupportedServer = *ignoreUnsupportedServer br.Bridge = bridgev2.NewBridge("", br.DB, *br.Log, &br.Config.Bridge, br.Matrix, br.Connector) br.Matrix.AS.DoublePuppetValue = br.Name + br.Bridge.Commands.AddHandler(&bridgev2.FullHandler{ + Func: func(ce *bridgev2.CommandEvent) { + ce.Reply("[%s](%s) %s (%s)", br.Name, br.URL, br.LinkifiedVersion, br.BuildTime.Format(time.RFC1123)) + }, + Name: "version", + Help: bridgev2.HelpMeta{ + Section: bridgev2.HelpSectionGeneral, + Description: "Get the bridge version.", + }, + }) if br.PostInit != nil { br.PostInit() } From b456fb6e0a6b24c4ed0fa06a0d09da6959ccc672 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 14 Jun 2024 12:47:08 +0300 Subject: [PATCH 0288/1647] bridgev2: initialize bridge state queue when creating UserLogin via User.NewLogin --- bridgev2/userlogin.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index aa809736..3a980c01 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -117,6 +117,7 @@ func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, client if err != nil { return nil, err } + ul.BridgeState = user.Bridge.NewBridgeStateQueue(ul) user.Bridge.cacheLock.Lock() defer user.Bridge.cacheLock.Unlock() user.Bridge.userLoginsByID[ul.ID] = ul From 8a727c001dc48492a5196ebcfa10677e2d8257c3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 14 Jun 2024 12:58:42 +0300 Subject: [PATCH 0289/1647] bridgev2: include remote name in user login list --- bridgev2/cmdlogin.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bridgev2/cmdlogin.go b/bridgev2/cmdlogin.go index 740915a1..32e59cde 100644 --- a/bridgev2/cmdlogin.go +++ b/bridgev2/cmdlogin.go @@ -303,8 +303,9 @@ var CommandLogout = &FullHandler{ func getUserLogins(user *User) string { user.Bridge.cacheLock.Lock() logins := make([]string, len(user.logins)) - for key := range user.logins { - logins = append(logins, fmt.Sprintf("* `%s`", key)) + for key, val := range user.logins { + remoteName, _ := val.Metadata["remote_name"].(string) + logins = append(logins, fmt.Sprintf("* `%s` (%s)", key, remoteName)) } user.Bridge.cacheLock.Unlock() return strings.Join(logins, "\n") From 869a04dadbe4ef1bcc0337dfc42f4145aba70319 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 14 Jun 2024 12:59:01 +0300 Subject: [PATCH 0290/1647] bridgev2/matrix: don't send status events for ephemeral events --- bridgev2/matrix/connector.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 09209bba..555e408b 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -327,6 +327,9 @@ func (br *Connector) SendMessageStatus(ctx context.Context, ms *bridgev2.Message } func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2.MessageStatus, evt *bridgev2.MessageStatusEventInfo, editEvent id.EventID) id.EventID { + if evt.EventType.IsEphemeral() { + return "" + } log := zerolog.Ctx(ctx) err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{ms.ToCheckpoint(evt)}) if err != nil { From 9304e2b9d711c36870df7dde7ace6485fa9640ff Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 15 Jun 2024 14:24:12 +0300 Subject: [PATCH 0291/1647] bridgev2: update FEATURES.md --- bridgev2/unorganized-docs/FEATURES.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bridgev2/unorganized-docs/FEATURES.md b/bridgev2/unorganized-docs/FEATURES.md index dc3bf6e5..908ca975 100644 --- a/bridgev2/unorganized-docs/FEATURES.md +++ b/bridgev2/unorganized-docs/FEATURES.md @@ -19,6 +19,8 @@ * [x] Read receipts * [ ] Presence * [ ] Typing notifications +* [ ] Spaces +* [ ] Relay mode * [ ] Chat metadata * [ ] Archive/low priority * [ ] Pin/favorite From bad4de70f7098f25abe9cdfb66ea6f36341dcdbf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 15 Jun 2024 14:33:30 +0300 Subject: [PATCH 0292/1647] hicli: fix some bugs --- hicli/database/room.go | 7 ++++++- hicli/database/state.go | 2 +- hicli/hitest/hitest.go | 3 +++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/hicli/database/room.go b/hicli/database/room.go index e3b140d0..1538ef12 100644 --- a/hicli/database/room.go +++ b/hicli/database/room.go @@ -9,6 +9,7 @@ package database import ( "context" "database/sql" + "errors" "time" "go.mau.fi/util/dbutil" @@ -82,7 +83,11 @@ func (rq *RoomQuery) SetPrevBatch(ctx context.Context, roomID id.RoomID, prevBat func (rq *RoomQuery) UpdatePreviewIfLaterOnTimeline(ctx context.Context, roomID id.RoomID, rowID EventRowID) (previewChanged bool, err error) { var newPreviewRowID EventRowID err = rq.GetDB().QueryRow(ctx, updateRoomPreviewIfLaterOnTimelineQuery, roomID, rowID).Scan(&newPreviewRowID) - previewChanged = newPreviewRowID == rowID + if errors.Is(err, sql.ErrNoRows) { + err = nil + } else if err == nil { + previewChanged = newPreviewRowID == rowID + } return } diff --git a/hicli/database/state.go b/hicli/database/state.go index 1b542f9f..845de6ed 100644 --- a/hicli/database/state.go +++ b/hicli/database/state.go @@ -22,7 +22,7 @@ const ( ` getCurrentRoomStateQuery = ` SELECT event.rowid, -1, event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type, unsigned, - redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, reactions, last_edit_rowid + transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, reactions, last_edit_rowid FROM current_state cs JOIN event ON cs.event_rowid = event.rowid WHERE cs.room_id = $1 diff --git a/hicli/hitest/hitest.go b/hicli/hitest/hitest.go index 97705304..c6873bac 100644 --- a/hicli/hitest/hitest.go +++ b/hicli/hitest/hitest.go @@ -93,6 +93,9 @@ func main() { break } fields := strings.Fields(line) + if len(fields) == 0 { + continue + } switch strings.ToLower(fields[0]) { case "send": resp, err := cli.Send(ctx, id.RoomID(fields[1]), event.EventMessage, &event.MessageEventContent{ From 1d800734acfb026c272b61d916ce2518e3b47995 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 15 Jun 2024 19:44:04 +0300 Subject: [PATCH 0293/1647] hicli: better get event functions --- hicli/database/event.go | 11 ++++++++--- hicli/database/room.go | 26 +++++++++++++------------- hicli/paginate.go | 12 ++++++++++++ 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/hicli/database/event.go b/hicli/database/event.go index f4df8868..cb9568f6 100644 --- a/hicli/database/event.go +++ b/hicli/database/event.go @@ -28,6 +28,7 @@ const ( transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, reactions, last_edit_rowid FROM event ` + getEventByRowID = getEventBaseQuery + `WHERE rowid = $1` getManyEventsByRowID = getEventBaseQuery + `WHERE rowid IN (%s)` getEventByID = getEventBaseQuery + `WHERE event_id = $1` getFailedEventsByMegolmSessionID = getEventBaseQuery + `WHERE room_id = $1 AND megolm_session_id = $2 AND decryption_error IS NOT NULL` @@ -93,6 +94,10 @@ func (eq *EventQuery) GetByID(ctx context.Context, eventID id.EventID) (*Event, return eq.QueryOne(ctx, getEventByID, eventID) } +func (eq *EventQuery) GetByRowID(ctx context.Context, rowID EventRowID) (*Event, error) { + return eq.QueryOne(ctx, getEventByRowID, rowID) +} + func (eq *EventQuery) GetByRowIDs(ctx context.Context, rowIDs ...EventRowID) ([]*Event, error) { query, params := buildMultiEventGetFunction(nil, rowIDs, getManyEventsByRowID) return eq.QueryMany(ctx, query, params...) @@ -274,10 +279,10 @@ type Event struct { RelationType event.RelationType `json:"relation_type,omitempty"` MegolmSessionID id.SessionID `json:"-,omitempty"` - DecryptionError string + DecryptionError string `json:"decryption_error,omitempty"` - Reactions map[string]int - LastEditRowID *EventRowID + Reactions map[string]int `json:"reactions,omitempty"` + LastEditRowID *EventRowID `json:"last_edit_rowid,omitempty"` } func MautrixToEvent(evt *event.Event) *Event { diff --git a/hicli/database/room.go b/hicli/database/room.go index 1538ef12..92adc279 100644 --- a/hicli/database/room.go +++ b/hicli/database/room.go @@ -101,24 +101,24 @@ const ( ) type Room struct { - ID id.RoomID - CreationContent *event.CreateEventContent + ID id.RoomID `json:"room_id"` + CreationContent *event.CreateEventContent `json:"creation_content,omitempty"` - Name *string - NameQuality NameQuality - Avatar *id.ContentURI - Topic *string - CanonicalAlias *id.RoomAlias + Name *string `json:"name,omitempty"` + NameQuality NameQuality `json:"name_quality"` + Avatar *id.ContentURI `json:"avatar,omitempty"` + Topic *string `json:"topic,omitempty"` + CanonicalAlias *id.RoomAlias `json:"canonical_alias,omitempty"` - LazyLoadSummary *mautrix.LazyLoadSummary + LazyLoadSummary *mautrix.LazyLoadSummary `json:"lazy_load_summary,omitempty"` - EncryptionEvent *event.EncryptionEventContent - HasMemberList bool + EncryptionEvent *event.EncryptionEventContent `json:"encryption_event,omitempty"` + HasMemberList bool `json:"has_member_list"` - PreviewEventRowID EventRowID - SortingTimestamp time.Time + PreviewEventRowID EventRowID `json:"preview_event_rowid"` + SortingTimestamp time.Time `json:"sorting_timestamp"` - PrevBatch string + PrevBatch string `json:"prev_batch"` } func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) { diff --git a/hicli/paginate.go b/hicli/paginate.go index 957ac3e3..9992b36e 100644 --- a/hicli/paginate.go +++ b/hicli/paginate.go @@ -48,6 +48,18 @@ func (h *HiClient) GetEventsByRowIDs(ctx context.Context, rowIDs []database.Even return events, nil } +func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*database.Event, error) { + if evt, err := h.DB.Event.GetByID(ctx, eventID); err != nil { + return nil, fmt.Errorf("failed to get event from database: %w", err) + } else if evt != nil { + return evt, nil + } else if serverEvt, err := h.Client.GetEvent(ctx, roomID, eventID); err != nil { + return nil, fmt.Errorf("failed to get event from server: %w", err) + } else { + return h.processEvent(ctx, serverEvt, nil, false) + } +} + func (h *HiClient) Paginate(ctx context.Context, roomID id.RoomID, maxTimelineID database.TimelineRowID, limit int) ([]*database.Event, error) { evts, err := h.DB.Timeline.Get(ctx, roomID, limit, maxTimelineID) if err != nil { From b571d922e041f9b423d51b0d0ecf022489c711c5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 15 Jun 2024 19:44:35 +0300 Subject: [PATCH 0294/1647] hicli: include all new events in sync completed event --- hicli/events.go | 1 + hicli/sync.go | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/hicli/events.go b/hicli/events.go index 75894111..a30dda8d 100644 --- a/hicli/events.go +++ b/hicli/events.go @@ -15,6 +15,7 @@ import ( type SyncRoom struct { Meta *database.Room `json:"meta"` Timeline []database.TimelineRowTuple `json:"timeline"` + Events []*database.Event `json:"events"` Reset bool `json:"reset"` } diff --git a/hicli/sync.go b/hicli/sync.go index 1764f09e..24f0cfed 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -302,6 +302,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R heroesChanged = true } decryptionQueue := make(map[id.SessionID]*database.SessionRequest) + allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events)) processNewEvent := func(evt *event.Event, isTimeline bool) (database.EventRowID, error) { evt.RoomID = room.ID dbEvt, err := h.processEvent(ctx, evt, decryptionQueue, false) @@ -330,6 +331,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R } processImportantEvent(ctx, evt, room, updatedRoom) } + allNewEvents = append(allNewEvents, dbEvt) return dbEvt.RowID, nil } var err error @@ -399,11 +401,12 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R return fmt.Errorf("failed to save room data: %w", err) } } - if roomChanged || len(timelineRowTuples) > 0 { + if roomChanged || len(timelineRowTuples) > 0 || len(allNewEvents) > 0 { ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{ Meta: room, Timeline: timelineRowTuples, Reset: timeline.Limited, + Events: allNewEvents, } } return nil From 9ed8ca3d37c66b8fed702bfdc991b66d92cb78c6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 15 Jun 2024 19:44:46 +0300 Subject: [PATCH 0295/1647] hicli: check spec versions on startup --- hicli/hicli.go | 19 +++++++++++++++++++ hicli/login.go | 4 ++++ 2 files changed, 23 insertions(+) diff --git a/hicli/hicli.go b/hicli/hicli.go index 7caee4a8..7524b6bc 100644 --- a/hicli/hicli.go +++ b/hicli/hicli.go @@ -19,6 +19,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exerrors" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto" @@ -145,6 +146,10 @@ func (h *HiClient) Start(ctx context.Context, userID id.UserID, expectedAccount if err != nil { return err } + err = h.CheckServerVersions(ctx) + if err != nil { + return err + } err = h.Crypto.Load(ctx) if err != nil { return fmt.Errorf("failed to load olm machine: %w", err) @@ -167,6 +172,20 @@ func (h *HiClient) Start(ctx context.Context, userID id.UserID, expectedAccount return nil } +var ErrFailedToCheckServerVersions = errors.New("failed to check server versions") +var ErrOutdatedServer = errors.New("homeserver is outdated") +var MinimumSpecVersion = mautrix.SpecV11 + +func (h *HiClient) CheckServerVersions(ctx context.Context) error { + versions, err := h.Client.Versions(ctx) + if err != nil { + return exerrors.NewDualError(ErrFailedToCheckServerVersions, err) + } else if !versions.Contains(MinimumSpecVersion) { + return fmt.Errorf("%w (minimum: %s, highest supported: %s)", ErrOutdatedServer, MinimumSpecVersion, versions.GetLatest()) + } + return nil +} + func (h *HiClient) Sync() { h.Client.StopSync() if fn := h.stopSync; fn != nil { diff --git a/hicli/login.go b/hicli/login.go index 2f9efb2d..d33ea422 100644 --- a/hicli/login.go +++ b/hicli/login.go @@ -36,6 +36,10 @@ func (h *HiClient) LoginPassword(ctx context.Context, homeserverURL, username, p } func (h *HiClient) Login(ctx context.Context, req *mautrix.ReqLogin) error { + err := h.CheckServerVersions(ctx) + if err != nil { + return err + } req.StoreCredentials = true req.StoreHomeserverURL = true resp, err := h.Client.Login(ctx, req) From a44362dc71d36abef1558cbdd215c260085c9b72 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 15 Jun 2024 20:15:06 +0300 Subject: [PATCH 0296/1647] client: stop using MakeFullRequest unnecessarily --- client.go | 15 +++--------- synapseadmin/register.go | 13 ++-------- synapseadmin/roomapi.go | 49 ++++++------------------------------- synapseadmin/userapi.go | 53 +++++++--------------------------------- 4 files changed, 22 insertions(+), 108 deletions(-) diff --git a/client.go b/client.go index ffb7b7a1..2b958688 100644 --- a/client.go +++ b/client.go @@ -334,7 +334,7 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er } } -func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL string, reqBody interface{}, resBody interface{}) ([]byte, error) { +func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL string, reqBody any, resBody any) ([]byte, error) { return cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody}) } @@ -1528,13 +1528,8 @@ func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]b // // See https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav1create func (cli *Client) CreateMXC(ctx context.Context) (*RespCreateMXC, error) { - u, _ := url.Parse(cli.BuildURL(MediaURLPath{"v1", "create"})) var m RespCreateMXC - _, err := cli.MakeFullRequest(ctx, FullRequest{ - Method: http.MethodPost, - URL: u.String(), - ResponseJSON: &m, - }) + _, err := cli.MakeRequest(ctx, http.MethodPost, cli.BuildURL(MediaURLPath{"v1", "create"}), nil, &m) return &m, err } @@ -1653,11 +1648,7 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (* notifyURL := cli.BuildURLWithQuery(MediaURLPath{"unstable", "com.beeper.msc3870", "upload", data.MXC.Homeserver, data.MXC.FileID, "complete"}, query) var m *RespMediaUpload - _, err := cli.MakeFullRequest(ctx, FullRequest{ - Method: http.MethodPost, - URL: notifyURL, - ResponseJSON: m, - }) + _, err := cli.MakeRequest(ctx, http.MethodPost, notifyURL, nil, &m) if err != nil { return nil, err } diff --git a/synapseadmin/register.go b/synapseadmin/register.go index d7a94f6f..641f9b56 100644 --- a/synapseadmin/register.go +++ b/synapseadmin/register.go @@ -73,11 +73,7 @@ func (req *ReqSharedSecretRegister) Sign(secret string) string { // This does not need to be called manually as SharedSecretRegister will automatically call this if no nonce is provided. func (cli *Client) GetRegisterNonce(ctx context.Context) (string, error) { var resp respGetRegisterNonce - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), - ResponseJSON: &resp, - }) + _, err := cli.MakeRequest(ctx, http.MethodGet, cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), nil, &resp) if err != nil { return "", err } @@ -97,12 +93,7 @@ func (cli *Client) SharedSecretRegister(ctx context.Context, sharedSecret string } req.SHA1Checksum = req.Sign(sharedSecret) var resp mautrix.RespRegister - _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodPost, - URL: cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), - RequestJSON: req, - ResponseJSON: &resp, - }) + _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), &req, &resp) if err != nil { return nil, err } diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go index 0953377e..6c072e23 100644 --- a/synapseadmin/roomapi.go +++ b/synapseadmin/roomapi.go @@ -77,11 +77,7 @@ func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRoom var resp RespListRooms var reqURL string reqURL = cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: reqURL, - ResponseJSON: &resp, - }) + _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } @@ -109,11 +105,7 @@ func (cli *Client) RoomMessages(ctx context.Context, roomID id.RoomID, from, to query["limit"] = strconv.Itoa(limit) } urlPath := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms", roomID, "messages"}, query) - _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: urlPath, - ResponseJSON: &resp, - }) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return resp, err } @@ -137,12 +129,7 @@ type RespDeleteRoom struct { func (cli *Client) DeleteRoom(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (RespDeleteRoom, error) { reqURL := cli.BuildAdminURL("v2", "rooms", roomID) var resp RespDeleteRoom - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodDelete, - URL: reqURL, - ResponseJSON: &resp, - RequestJSON: &req, - }) + _, err := cli.MakeRequest(ctx, http.MethodDelete, reqURL, &req, &resp) return resp, err } @@ -157,11 +144,7 @@ type RespRoomsMembers struct { func (cli *Client) RoomMembers(ctx context.Context, roomID id.RoomID) (RespRoomsMembers, error) { reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "members") var resp RespRoomsMembers - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: reqURL, - ResponseJSON: &resp, - }) + _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } @@ -174,11 +157,7 @@ type ReqMakeRoomAdmin struct { // https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#make-room-admin-api func (cli *Client) MakeRoomAdmin(ctx context.Context, roomIDOrAlias string, req ReqMakeRoomAdmin) error { reqURL := cli.BuildAdminURL("v1", "rooms", roomIDOrAlias, "make_room_admin") - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodPost, - URL: reqURL, - RequestJSON: &req, - }) + _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -191,11 +170,7 @@ type ReqJoinUserToRoom struct { // https://matrix-org.github.io/synapse/latest/admin_api/room_membership.html func (cli *Client) JoinUserToRoom(ctx context.Context, roomID id.RoomID, req ReqJoinUserToRoom) error { reqURL := cli.BuildAdminURL("v1", "join", roomID) - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodPost, - URL: reqURL, - RequestJSON: &req, - }) + _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -208,11 +183,7 @@ type ReqBlockRoom struct { // https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#block-room-api func (cli *Client) BlockRoom(ctx context.Context, roomID id.RoomID, req ReqBlockRoom) error { reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block") - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodPut, - URL: reqURL, - RequestJSON: &req, - }) + _, err := cli.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) return err } @@ -228,10 +199,6 @@ type RoomsBlockResponse struct { func (cli *Client) GetRoomBlockStatus(ctx context.Context, roomID id.RoomID) (RoomsBlockResponse, error) { var resp RoomsBlockResponse reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block") - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: reqURL, - ResponseJSON: &resp, - }) + _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } diff --git a/synapseadmin/userapi.go b/synapseadmin/userapi.go index 31d0a6dc..9cbb17e4 100644 --- a/synapseadmin/userapi.go +++ b/synapseadmin/userapi.go @@ -32,11 +32,7 @@ type ReqResetPassword struct { // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#reset-password func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) error { reqURL := cli.BuildAdminURL("v1", "reset_password", req.UserID) - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodPost, - URL: reqURL, - RequestJSON: &req, - }) + _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -48,11 +44,7 @@ func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) erro // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#check-username-availability func (cli *Client) UsernameAvailable(ctx context.Context, username string) (resp *mautrix.RespRegisterAvailable, err error) { u := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "username_available"}, map[string]string{"username": username}) - _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: u, - ResponseJSON: &resp, - }) + _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) if err == nil && !resp.Available { err = fmt.Errorf(`request returned OK status without "available": true`) } @@ -73,11 +65,7 @@ type RespListDevices struct { // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#list-all-devices func (cli *Client) ListDevices(ctx context.Context, userID id.UserID) (resp *RespListDevices, err error) { - _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: cli.BuildAdminURL("v2", "users", userID, "devices"), - ResponseJSON: &resp, - }) + _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID, "devices"), nil, &resp) return } @@ -101,11 +89,7 @@ type RespUserInfo struct { // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#query-user-account func (cli *Client) GetUserInfo(ctx context.Context, userID id.UserID) (resp *RespUserInfo, err error) { - _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: cli.BuildAdminURL("v2", "users", userID), - ResponseJSON: &resp, - }) + _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID), nil, &resp) return } @@ -118,11 +102,7 @@ type ReqDeleteUser struct { // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#deactivate-account func (cli *Client) DeactivateAccount(ctx context.Context, userID id.UserID, req ReqDeleteUser) error { reqURL := cli.BuildAdminURL("v1", "deactivate", userID) - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodPost, - URL: reqURL, - RequestJSON: &req, - }) + _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -144,11 +124,7 @@ type ReqCreateOrModifyAccount struct { // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#create-or-modify-account func (cli *Client) CreateOrModifyAccount(ctx context.Context, userID id.UserID, req ReqCreateOrModifyAccount) error { reqURL := cli.BuildAdminURL("v2", "users", userID) - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodPut, - URL: reqURL, - RequestJSON: &req, - }) + _, err := cli.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) return err } @@ -164,11 +140,7 @@ type ReqSetRatelimit = RatelimitOverride // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#set-ratelimit func (cli *Client) SetUserRatelimit(ctx context.Context, userID id.UserID, req ReqSetRatelimit) error { reqURL := cli.BuildAdminURL("v1", "users", userID, "override_ratelimit") - _, err := cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodPost, - URL: reqURL, - RequestJSON: &req, - }) + _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -178,11 +150,7 @@ type RespUserRatelimit = RatelimitOverride // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#get-status-of-ratelimit func (cli *Client) GetUserRatelimit(ctx context.Context, userID id.UserID) (resp RespUserRatelimit, err error) { - _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodGet, - URL: cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), - ResponseJSON: &resp, - }) + _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, &resp) return } @@ -190,9 +158,6 @@ func (cli *Client) GetUserRatelimit(ctx context.Context, userID id.UserID) (resp // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#delete-ratelimit func (cli *Client) DeleteUserRatelimit(ctx context.Context, userID id.UserID) (err error) { - _, err = cli.MakeFullRequest(ctx, mautrix.FullRequest{ - Method: http.MethodDelete, - URL: cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), - }) + _, err = cli.MakeRequest(ctx, http.MethodDelete, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, nil) return } From 3b55fedc173e41e6c2e6be4b9d65691617262452 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 15 Jun 2024 20:16:19 +0300 Subject: [PATCH 0297/1647] client: cache spec versions supported by server --- appservice/appservice.go | 16 ++++++++++------ bridge/bridge.go | 1 + bridgev2/matrix/connector.go | 1 + client.go | 4 ++++ versions.go | 11 +++++++++++ 5 files changed, 27 insertions(+), 6 deletions(-) diff --git a/appservice/appservice.go b/appservice/appservice.go index 814cd5c4..90ace5d9 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -56,6 +56,8 @@ func Create() *AppService { DeviceLists: make(chan *mautrix.DeviceLists, EventChannelSize), QueryHandler: &QueryHandlerStub{}, + SpecVersions: &mautrix.RespVersions{}, + DefaultHTTPRetries: 4, } @@ -158,12 +160,13 @@ type AppService struct { QueryHandler QueryHandler StateStore StateStore - Router *mux.Router - UserAgent string - server *http.Server - HTTPClient *http.Client - botClient *mautrix.Client - botIntent *IntentAPI + Router *mux.Router + UserAgent string + server *http.Server + HTTPClient *http.Client + botClient *mautrix.Client + botIntent *IntentAPI + SpecVersions *mautrix.RespVersions DefaultHTTPRetries int @@ -365,6 +368,7 @@ func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client { Log: as.Log.With().Str("as_user_id", userID.String()).Logger(), Client: as.HTTPClient, DefaultHTTPRetries: as.DefaultHTTPRetries, + SpecVersions: as.SpecVersions, } } diff --git a/bridge/bridge.go b/bridge/bridge.go index 053c9021..40d4c615 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -328,6 +328,7 @@ func (br *Bridge) ensureConnection(ctx context.Context) { time.Sleep(10 * time.Second) } else { br.SpecVersions = *versions + *br.AS.SpecVersions = *versions break } } diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 555e408b..3493e7a3 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -151,6 +151,7 @@ func (br *Connector) ensureConnection(ctx context.Context) { time.Sleep(10 * time.Second) } else { br.SpecVersions = versions + *br.AS.SpecVersions = *versions break } } diff --git a/client.go b/client.go index 2b958688..6f1ef6e1 100644 --- a/client.go +++ b/client.go @@ -61,6 +61,7 @@ type Client struct { StateStore StateStore Crypto CryptoHelper Verification VerificationHelper + SpecVersions *RespVersions Log zerolog.Logger @@ -871,6 +872,9 @@ func (cli *Client) LogoutAll(ctx context.Context) (resp *RespLogout, err error) func (cli *Client) Versions(ctx context.Context) (resp *RespVersions, err error) { urlPath := cli.BuildClientURL("versions") _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + if resp != nil { + cli.SpecVersions = resp + } return } diff --git a/versions.go b/versions.go index bdddf729..246889c0 100644 --- a/versions.go +++ b/versions.go @@ -19,6 +19,9 @@ type RespVersions struct { } func (versions *RespVersions) ContainsFunc(match func(found SpecVersion) bool) bool { + if versions == nil { + return false + } for _, found := range versions.Versions { if match(found) { return true @@ -40,6 +43,9 @@ func (versions *RespVersions) ContainsGreaterOrEqual(version SpecVersion) bool { } func (versions *RespVersions) GetLatest() (latest SpecVersion) { + if versions == nil { + return + } for _, ver := range versions.Versions { if ver.GreaterThan(latest) { latest = ver @@ -65,6 +71,9 @@ var ( ) func (versions *RespVersions) Supports(feature UnstableFeature) bool { + if versions == nil { + return false + } return versions.UnstableFeatures[feature.UnstableFlag] || (!feature.SpecVersion.IsEmpty() && versions.ContainsGreaterOrEqual(feature.SpecVersion)) } @@ -96,6 +105,8 @@ var ( SpecV17 = MustParseSpecVersion("v1.7") SpecV18 = MustParseSpecVersion("v1.8") SpecV19 = MustParseSpecVersion("v1.9") + SpecV110 = MustParseSpecVersion("v1.10") + SpecV111 = MustParseSpecVersion("v1.11") ) func (svf SpecVersionFormat) String() string { From b959fbf737216300e08d40434d920bd1876421c6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 15 Jun 2024 20:18:40 +0300 Subject: [PATCH 0298/1647] client: return `*http.Response` in `MakeFullRequest` --- client.go | 68 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/client.go b/client.go index 6f1ef6e1..553b7e04 100644 --- a/client.go +++ b/client.go @@ -336,7 +336,8 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er } func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL string, reqBody any, resBody any) ([]byte, error) { - return cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody}) + data, _, err := cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody}) + return data, err } type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) @@ -354,6 +355,7 @@ type FullRequest struct { BackoffDuration time.Duration SensitiveContent bool Handler ClientResponseHandler + DontReadResponse bool Logger *zerolog.Logger Client *http.Client } @@ -418,13 +420,7 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e return req, nil } -// MakeFullRequest makes a JSON HTTP request to the given URL. -// If "resBody" is not nil, the response body will be json.Unmarshalled into it. -// -// Returns the HTTP body as bytes on 2xx with a nil error. Returns an error if the response is not 2xx along -// with the HTTP body bytes if it got that far. This error is an HTTPError which includes the returned -// HTTP status code and possibly a RespError as the WrappedError, if the HTTP body could be decoded as a RespError. -func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]byte, error) { +func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]byte, *http.Response, error) { if params.MaxAttempts == 0 { params.MaxAttempts = 1 + cli.DefaultHTTPRetries } @@ -440,10 +436,14 @@ func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]b } req, err := params.compileRequest(ctx) if err != nil { - return nil, err + return nil, nil, err } if params.Handler == nil { - params.Handler = handleNormalResponse + if params.DontReadResponse { + params.Handler = noopHandleResponse + } else { + params.Handler = handleNormalResponse + } } req.Header.Set("User-Agent", cli.UserAgent) if len(cli.AccessToken) > 0 { @@ -452,7 +452,7 @@ func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]b if params.Client == nil { params.Client = cli.Client } - return cli.executeCompiledRequest(req, params.MaxAttempts-1, params.BackoffDuration, params.ResponseJSON, params.Handler, params.Client) + return cli.executeCompiledRequest(req, params.MaxAttempts-1, params.BackoffDuration, params.ResponseJSON, params.Handler, params.DontReadResponse, params.Client) } func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { @@ -463,7 +463,7 @@ func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { return log } -func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler, client *http.Client) ([]byte, error) { +func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) { log := zerolog.Ctx(req.Context()) if req.Body != nil { var err error @@ -471,17 +471,17 @@ func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff req.Body, err = req.GetBody() if err != nil { log.Warn().Err(err).Msg("Failed to get new body to retry request") - return nil, cause + return nil, nil, cause } } else if bodySeeker, ok := req.Body.(io.ReadSeeker); ok { _, err = bodySeeker.Seek(0, io.SeekStart) if err != nil { log.Warn().Err(err).Msg("Failed to seek to beginning of request body") - return nil, cause + return nil, nil, cause } } else { log.Warn().Msg("Failed to get new body to retry request: GetBody is nil and Body is not an io.ReadSeeker") - return nil, cause + return nil, nil, cause } } log.Warn().Err(cause). @@ -491,10 +491,10 @@ func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff if cli.UpdateRequestOnRetry != nil { req = cli.UpdateRequestOnRetry(req, cause) } - return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, client) + return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, client) } -func readRequestBody(req *http.Request, res *http.Response) ([]byte, error) { +func readResponseBody(req *http.Request, res *http.Response) ([]byte, error) { contents, err := io.ReadAll(res.Body) if err != nil { return nil, HTTPError{ @@ -536,8 +536,12 @@ func streamResponse(req *http.Request, res *http.Response, responseJSON interfac } } +func noopHandleResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { + return nil, nil +} + func handleNormalResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { - if contents, err := readRequestBody(req, res); err != nil { + if contents, err := readResponseBody(req, res); err != nil { return nil, err } else if responseJSON == nil { return contents, nil @@ -556,7 +560,7 @@ func handleNormalResponse(req *http.Request, res *http.Response, responseJSON in } func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { - contents, err := readRequestBody(req, res) + contents, err := readResponseBody(req, res) if err != nil { return contents, err } @@ -573,17 +577,17 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { } } -func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler, client *http.Client) ([]byte, error) { +func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) { cli.RequestStart(req) startTime := time.Now() res, err := client.Do(req) duration := time.Now().Sub(startTime) - if res != nil { + if res != nil && !dontReadResponse { defer res.Body.Close() } if err != nil { if retries > 0 { - return cli.doRetry(req, err, retries, backoff, responseJSON, handler, client) + return cli.doRetry(req, err, retries, backoff, responseJSON, handler, dontReadResponse, client) } err = HTTPError{ Request: req, @@ -593,12 +597,12 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof WrappedError: err, } cli.LogRequestDone(req, res, err, nil, 0, duration) - return nil, err + return nil, res, err } if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) { backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff) - return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, client) + return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, client) } var body []byte @@ -609,7 +613,7 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof body, err = handler(req, res, responseJSON) cli.LogRequestDone(req, res, nil, err, len(body), duration) } - return body, err + return body, res, err } // Whoami gets the user ID of the current user. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3accountwhoami @@ -688,7 +692,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp fullReq.Handler = streamResponse } start := time.Now() - _, err = cli.MakeFullRequest(ctx, fullReq) + _, _, err = cli.MakeFullRequest(ctx, fullReq) duration := time.Now().Sub(start) timeout := time.Duration(req.Timeout) * time.Millisecond buffer := 10 * time.Second @@ -738,7 +742,7 @@ func (cli *Client) RegisterAvailable(ctx context.Context, username string) (resp func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { var bodyBytes []byte - bodyBytes, err = cli.MakeFullRequest(ctx, FullRequest{ + bodyBytes, _, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: url, RequestJSON: req, @@ -818,7 +822,7 @@ func (cli *Client) GetLoginFlows(ctx context.Context) (resp *RespLoginFlows, err // Login a user to the homeserver according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3login func (cli *Client) Login(ctx context.Context, req *ReqLogin) (resp *RespLogin, err error) { - _, err = cli.MakeFullRequest(ctx, FullRequest{ + _, _, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v3", "login"), RequestJSON: req, @@ -1395,7 +1399,7 @@ func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON inter // State gets all state in a room. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) { - _, err = cli.MakeFullRequest(ctx, FullRequest{ + _, _, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodGet, URL: cli.BuildClientURL("v3", "rooms", roomID, "state"), ResponseJSON: &stateMap, @@ -1687,7 +1691,7 @@ func (cli *Client) UploadMedia(ctx context.Context, data ReqUploadMedia) (*RespM } var m RespMediaUpload - _, err := cli.MakeFullRequest(ctx, FullRequest{ + _, _, err := cli.MakeFullRequest(ctx, FullRequest{ Method: method, URL: u.String(), Headers: headers, @@ -2187,7 +2191,7 @@ type UIACallback = func(*RespUserInteractive) interface{} // Because the endpoint requires user-interactive authentication a callback must be provided that, // given the UI auth parameters, produces the required result (or nil to end the flow). func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error { - content, err := cli.MakeFullRequest(ctx, FullRequest{ + content, _, err := cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v3", "keys", "device_signing", "upload"), RequestJSON: keys, @@ -2278,7 +2282,7 @@ func (cli *Client) BatchSend(ctx context.Context, roomID id.RoomID, req *ReqBatc } func (cli *Client) AppservicePing(ctx context.Context, id, txnID string) (resp *RespAppservicePing, err error) { - _, err = cli.MakeFullRequest(ctx, FullRequest{ + _, _, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v1", "appservice", id, "ping"), RequestJSON: &ReqAppservicePing{TxnID: txnID}, From ef97d96754612d50a751cbfa96af5639ecacb8d1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 15 Jun 2024 20:19:48 +0300 Subject: [PATCH 0299/1647] client: use authenticated media download endpoint if cached spec versions say it's supported --- client.go | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 553b7e04..401da029 100644 --- a/client.go +++ b/client.go @@ -1443,6 +1443,7 @@ func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUploa return cli.Upload(ctx, res.Body, res.Header.Get("Content-Type"), res.ContentLength) } +// Deprecated: unauthenticated media is deprecated as of Matrix v1.11. Use [Download] or [DownloadBytes] instead. func (cli *Client) GetDownloadURL(mxcURL id.ContentURI) string { return cli.BuildURLWithQuery(MediaURLPath{"v3", "download", mxcURL.Homeserver, mxcURL.FileID}, map[string]string{"allow_redirect": "true"}) } @@ -1515,12 +1516,21 @@ func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Re if ctxLog.GetLevel() == zerolog.Disabled || ctxLog == zerolog.DefaultContextLogger { ctx = cli.Log.WithContext(ctx) } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, cli.GetDownloadURL(mxcURL), nil) - if err != nil { - return nil, err + if cli.SpecVersions.ContainsGreaterOrEqual(SpecV111) { + _, resp, err := cli.MakeFullRequest(ctx, FullRequest{ + Method: http.MethodGet, + URL: cli.BuildClientURL("v1", "media", "download", mxcURL.Homeserver, mxcURL.FileID), + DontReadResponse: true, + }) + return resp, err + } else { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, cli.GetDownloadURL(mxcURL), nil) + if err != nil { + return nil, err + } + req.Header.Set("User-Agent", cli.UserAgent+" (media downloader)") + return cli.doMediaRequest(req, cli.DefaultHTTPRetries, 4*time.Second) } - req.Header.Set("User-Agent", cli.UserAgent+" (media downloader)") - return cli.doMediaRequest(req, cli.DefaultHTTPRetries, 4*time.Second) } func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) { From f7de98ba777d62474749c4bb90a4aa00b772d251 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 16 Jun 2024 13:11:21 +0300 Subject: [PATCH 0300/1647] client: add backwards-compatibility for MakeFullRequest --- client.go | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/client.go b/client.go index 401da029..a57131bf 100644 --- a/client.go +++ b/client.go @@ -336,8 +336,7 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er } func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL string, reqBody any, resBody any) ([]byte, error) { - data, _, err := cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody}) - return data, err + return cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody}) } type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) @@ -420,7 +419,12 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e return req, nil } -func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]byte, *http.Response, error) { +func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]byte, error) { + data, _, err := cli.MakeFullRequestWithResp(ctx, params) + return data, err +} + +func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullRequest) ([]byte, *http.Response, error) { if params.MaxAttempts == 0 { params.MaxAttempts = 1 + cli.DefaultHTTPRetries } @@ -692,7 +696,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp fullReq.Handler = streamResponse } start := time.Now() - _, _, err = cli.MakeFullRequest(ctx, fullReq) + _, err = cli.MakeFullRequest(ctx, fullReq) duration := time.Now().Sub(start) timeout := time.Duration(req.Timeout) * time.Millisecond buffer := 10 * time.Second @@ -742,7 +746,7 @@ func (cli *Client) RegisterAvailable(ctx context.Context, username string) (resp func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { var bodyBytes []byte - bodyBytes, _, err = cli.MakeFullRequest(ctx, FullRequest{ + bodyBytes, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: url, RequestJSON: req, @@ -822,7 +826,7 @@ func (cli *Client) GetLoginFlows(ctx context.Context) (resp *RespLoginFlows, err // Login a user to the homeserver according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3login func (cli *Client) Login(ctx context.Context, req *ReqLogin) (resp *RespLogin, err error) { - _, _, err = cli.MakeFullRequest(ctx, FullRequest{ + _, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v3", "login"), RequestJSON: req, @@ -1399,7 +1403,7 @@ func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON inter // State gets all state in a room. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) { - _, _, err = cli.MakeFullRequest(ctx, FullRequest{ + _, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodGet, URL: cli.BuildClientURL("v3", "rooms", roomID, "state"), ResponseJSON: &stateMap, @@ -1517,7 +1521,7 @@ func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Re ctx = cli.Log.WithContext(ctx) } if cli.SpecVersions.ContainsGreaterOrEqual(SpecV111) { - _, resp, err := cli.MakeFullRequest(ctx, FullRequest{ + _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{ Method: http.MethodGet, URL: cli.BuildClientURL("v1", "media", "download", mxcURL.Homeserver, mxcURL.FileID), DontReadResponse: true, @@ -1701,7 +1705,7 @@ func (cli *Client) UploadMedia(ctx context.Context, data ReqUploadMedia) (*RespM } var m RespMediaUpload - _, _, err := cli.MakeFullRequest(ctx, FullRequest{ + _, err := cli.MakeFullRequest(ctx, FullRequest{ Method: method, URL: u.String(), Headers: headers, @@ -2201,7 +2205,7 @@ type UIACallback = func(*RespUserInteractive) interface{} // Because the endpoint requires user-interactive authentication a callback must be provided that, // given the UI auth parameters, produces the required result (or nil to end the flow). func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error { - content, _, err := cli.MakeFullRequest(ctx, FullRequest{ + content, err := cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v3", "keys", "device_signing", "upload"), RequestJSON: keys, @@ -2292,7 +2296,7 @@ func (cli *Client) BatchSend(ctx context.Context, roomID id.RoomID, req *ReqBatc } func (cli *Client) AppservicePing(ctx context.Context, id, txnID string) (resp *RespAppservicePing, err error) { - _, _, err = cli.MakeFullRequest(ctx, FullRequest{ + _, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v1", "appservice", id, "ping"), RequestJSON: &ReqAppservicePing{TxnID: txnID}, From 8b7a3ea230ee5ee0c2d2c49add0084ad2e906863 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 16 Jun 2024 23:29:14 +0300 Subject: [PATCH 0301/1647] client: use MSC3916 endpoints for config and preview_url too --- client.go | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index a57131bf..3fb9919e 100644 --- a/client.go +++ b/client.go @@ -1425,7 +1425,12 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt // GetMediaConfig fetches the configuration of the content repository, such as upload limitations. func (cli *Client) GetMediaConfig(ctx context.Context) (resp *RespMediaConfig, err error) { - u := cli.BuildURL(MediaURLPath{"v3", "config"}) + var u string + if cli.SpecVersions.ContainsGreaterOrEqual(SpecV111) { + u = cli.BuildClientURL("v1", "media", "config") + } else { + u = cli.BuildURL(MediaURLPath{"v3", "config"}) + } _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) return } @@ -1721,7 +1726,13 @@ func (cli *Client) UploadMedia(ctx context.Context, data ReqUploadMedia) (*RespM // // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3preview_url func (cli *Client) GetURLPreview(ctx context.Context, url string) (*RespPreviewURL, error) { - reqURL := cli.BuildURLWithQuery(MediaURLPath{"v3", "preview_url"}, map[string]string{ + var urlPath PrefixableURLPath + if cli.SpecVersions.ContainsGreaterOrEqual(SpecV111) { + urlPath = ClientURLPath{"v1", "media", "preview_url"} + } else { + urlPath = MediaURLPath{"v3", "preview_url"} + } + reqURL := cli.BuildURLWithQuery(urlPath, map[string]string{ "url": url, }) var output RespPreviewURL From ace2f37f01f4f6f903f2a3959329693bc4f17bb2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 16 Jun 2024 23:31:11 +0300 Subject: [PATCH 0302/1647] Bump version to v0.19.0-beta.1 --- CHANGELOG.md | 40 ++++++++++++++++++++++++++++++++++++++++ go.mod | 6 +++--- go.sum | 12 ++++++------ 3 files changed, 49 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d7d17bc1..d3b968fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,43 @@ +## v0.19.0 (unreleased) + +### beta.1 (2024-06-16) + +* *(bridgev2)* Added experimental high-level bridge framework. +* *(hicli)* Added experimental high-level client framework. +* **Slightly breaking changes** + * *(crypto)* Added room ID and first known index parameters to + `SessionReceived` callback. + * *(crypto)* Changed `ImportRoomKeyFromBackup` to return the imported + session. + * *(client)* Added `error` parameter to `ResponseHook`. + * *(client)* Changed `Download` to return entire response instead of just an + `io.Reader`. +* *(crypto)* Changed initial olm device sharing to save keys before sharing to + ensure keys aren't accidentally regenerated in case the request fails. +* *(crypto)* Changed `EncryptMegolmEvent` and `ShareGroupSession` to return + more errors instead of only logging and ignoring them. +* *(crypto)* Added option to completely disable megolm ratchet tracking. + * The tracking is meant for bots and bridges which may want to delete old + keys, but for normal clients it's just unnecessary overhead. +* *(crypto)* Changed Megolm session storage methods in `Store` to not take + sender key as parameter. + * This causes a breaking change to the layout of the `MemoryStore` struct. + Using MemoryStore in production is not recommended. +* *(crypto)* Changed `DecryptMegolmEvent` to copy `m.relates_to` in the raw + content too instead of only in the parsed struct. +* *(crypto)* Exported function to parse megolm message index from raw + ciphertext bytes. +* *(crypto/sqlstore)* Fixed schema of `crypto_secrets` table to include + account ID. +* *(crypto/verificationhelper)* Fixed more bugs. +* *(client)* Added `UpdateRequestOnRetry` hook which is called immediately + before retrying a normal HTTP request. +* *(client)* Added support for MSC3916 media download endpoint. + * Support is automatically detected from spec versions. The `SpecVersions` + property can either be filled manually, or `Versions` can be called to + automatically populate the field with the response. +* *(event)* Added constants for known room versions. + ## v0.18.1 (2024-04-16) * *(format)* Added a `context.Context` field to HTMLParser's Context struct. diff --git a/go.mod b/go.mod index 80230c85..5c264e03 100644 --- a/go.mod +++ b/go.mod @@ -14,11 +14,11 @@ require ( github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.7.1 - go.mau.fi/util v0.4.3-0.20240611132549-e72a5f4745e7 + github.com/yuin/goldmark v1.7.2 + go.mau.fi/util v0.5.0 go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.24.0 - golang.org/x/exp v0.0.0-20240604190554-fc45aab8b7f8 + golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 golang.org/x/net v0.26.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 diff --git a/go.sum b/go.sum index 7a592e41..88e9ec63 100644 --- a/go.sum +++ b/go.sum @@ -44,16 +44,16 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U= -github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.4.3-0.20240611132549-e72a5f4745e7 h1:DviEWXBpeOlFrqIf5s/iBDp1ewZx8fe6imMJ78kq3tA= -go.mau.fi/util v0.4.3-0.20240611132549-e72a5f4745e7/go.mod h1:Eaj7jl37ehkA7S6vE/vfPs5PsY8e91FKZ2BqA3OM/NU= +github.com/yuin/goldmark v1.7.2 h1:NjGd7lO7zrUn/A7eKwn5PEOt4ONYGqpxSEeZuduvgxc= +github.com/yuin/goldmark v1.7.2/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +go.mau.fi/util v0.5.0 h1:8yELAl+1CDRrwGe9NUmREgVclSs26Z68pTWePHVxuDo= +go.mau.fi/util v0.5.0/go.mod h1:DsJzUrJAG53lCZnnYvq9/mOyLuPScWwYhvETiTrpdP4= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= -golang.org/x/exp v0.0.0-20240604190554-fc45aab8b7f8 h1:LoYXNGAShUG3m/ehNk4iFctuhGX/+R1ZpfJ4/ia80JM= -golang.org/x/exp v0.0.0-20240604190554-fc45aab8b7f8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI= +golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY= +golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI= golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From 2759e45688d79d6eea70904881abbe59ebd255ed Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 17 Jun 2024 14:11:39 +0300 Subject: [PATCH 0303/1647] bridgev2: add network interface for registering push notifications --- bridgev2/networkinterface.go | 52 ++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index def970ed..efb61657 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -8,6 +8,7 @@ package bridgev2 import ( "context" + "fmt" "time" "github.com/rs/zerolog" @@ -155,6 +156,57 @@ type NetworkAPI interface { HandleMatrixReadReceipt(ctx context.Context, msg *MatrixReadReceipt) error } +type PushType int + +func (pt PushType) String() string { + return pt.GoString() +} + +func (pt PushType) GoString() string { + switch pt { + case PushTypeUnknown: + return "PushTypeUnknown" + case PushTypeWeb: + return "PushTypeWeb" + case PushTypeAPNs: + return "PushTypeAPNs" + case PushTypeFCM: + return "PushTypeFCM" + default: + return fmt.Sprintf("PushType(%d)", int(pt)) + } +} + +const ( + PushTypeUnknown PushType = iota + PushTypeWeb + PushTypeAPNs + PushTypeFCM +) + +type WebPushConfig struct { + VapidKey string `json:"vapid_key"` +} + +type FCMPushConfig struct { + SenderID string `json:"sender_id"` +} + +type APNsPushConfig struct { + BundleID string `json:"bundle_id"` +} + +type PushConfig struct { + Web *WebPushConfig `json:"web,omitempty"` + FCM *FCMPushConfig `json:"fcm,omitempty"` + APNs *APNsPushConfig `json:"apns,omitempty"` +} + +type PushableNetworkAPI interface { + RegisterPushNotifications(ctx context.Context, pushType PushType, token string) error + GetPushConfigs() *PushConfig +} + type RemoteEventType int const ( From 3828c08f27fa559c730938aab620bf89910298cb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 17 Jun 2024 14:56:34 +0300 Subject: [PATCH 0304/1647] bridgev2: add debug command for registering pusher --- bridgev2/cmddebug.go | 59 ++++++++++++++++++++++++++++++++++++ bridgev2/cmdlogin.go | 6 ++-- bridgev2/cmdprocessor.go | 1 + bridgev2/networkinterface.go | 14 +++++++++ 4 files changed, 77 insertions(+), 3 deletions(-) create mode 100644 bridgev2/cmddebug.go diff --git a/bridgev2/cmddebug.go b/bridgev2/cmddebug.go new file mode 100644 index 00000000..400470ed --- /dev/null +++ b/bridgev2/cmddebug.go @@ -0,0 +1,59 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "strings" + + "maunium.net/go/mautrix/bridgev2/networkid" +) + +var CommandRegisterPush = &FullHandler{ + Func: func(ce *CommandEvent) { + if len(ce.Args) < 3 { + ce.Reply("Usage: `$cmdprefix debug-register-push `\n\nYour logins:\n\n%s", ce.User.GetFormattedUserLogins()) + return + } + pushType := PushTypeFromString(ce.Args[1]) + if pushType == PushTypeUnknown { + ce.Reply("Unknown push type `%s`. Allowed types: `web`, `apns`, `fcm`", ce.Args[1]) + return + } + login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + if login == nil || login.UserMXID != ce.User.MXID { + ce.Reply("Login `%s` not found", ce.Args[0]) + return + } + pushable, ok := login.Client.(PushableNetworkAPI) + if !ok { + ce.Reply("This network connector does not support push registration") + return + } + pushToken := strings.Join(ce.Args[2:], " ") + if pushToken == "null" { + pushToken = "" + } + err := pushable.RegisterPushNotifications(ce.Ctx, pushType, pushToken) + if err != nil { + ce.Reply("Failed to register pusher: %v", err) + return + } + if pushToken == "" { + ce.Reply("Pusher de-registered successfully") + } else { + ce.Reply("Pusher registered successfully") + } + }, + Name: "debug-register-push", + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Register a pusher", + Args: "<_login ID_> <_push type_> <_push token_>", + }, + RequiresAdmin: true, + RequiresLogin: true, +} diff --git a/bridgev2/cmdlogin.go b/bridgev2/cmdlogin.go index 32e59cde..b09e9dfb 100644 --- a/bridgev2/cmdlogin.go +++ b/bridgev2/cmdlogin.go @@ -300,7 +300,7 @@ var CommandLogout = &FullHandler{ }, } -func getUserLogins(user *User) string { +func (user *User) GetFormattedUserLogins() string { user.Bridge.cacheLock.Lock() logins := make([]string, len(user.logins)) for key, val := range user.logins { @@ -313,7 +313,7 @@ func getUserLogins(user *User) string { func fnLogout(ce *CommandEvent) { if len(ce.Args) == 0 { - ce.Reply("Usage: `$cmdprefix logout `\n\nYour logins:\n\n%s", getUserLogins(ce.User)) + ce.Reply("Usage: `$cmdprefix logout `\n\nYour logins:\n\n%s", ce.User.GetFormattedUserLogins()) return } login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) @@ -339,7 +339,7 @@ var CommandSetPreferredLogin = &FullHandler{ func fnSetPreferredLogin(ce *CommandEvent) { if len(ce.Args) == 0 { - ce.Reply("Usage: `$cmdprefix set-preferred-login `\n\nYour logins:\n\n%s", getUserLogins(ce.User)) + ce.Reply("Usage: `$cmdprefix set-preferred-login `\n\nYour logins:\n\n%s", ce.User.GetFormattedUserLogins()) return } login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) diff --git a/bridgev2/cmdprocessor.go b/bridgev2/cmdprocessor.go index efeb11a2..7b064fda 100644 --- a/bridgev2/cmdprocessor.go +++ b/bridgev2/cmdprocessor.go @@ -38,6 +38,7 @@ func NewProcessor(bridge *Bridge) *CommandProcessor { } proc.AddHandlers( CommandHelp, CommandCancel, + CommandRegisterPush, CommandLogin, CommandLogout, CommandSetPreferredLogin, ) return proc diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index efb61657..48781b89 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -9,6 +9,7 @@ package bridgev2 import ( "context" "fmt" + "strings" "time" "github.com/rs/zerolog" @@ -162,6 +163,19 @@ func (pt PushType) String() string { return pt.GoString() } +func PushTypeFromString(str string) PushType { + switch strings.TrimPrefix(strings.ToLower(str), "pushtype") { + case "web": + return PushTypeWeb + case "apns": + return PushTypeAPNs + case "fcm": + return PushTypeFCM + default: + return PushTypeUnknown + } +} + func (pt PushType) GoString() string { switch pt { case PushTypeUnknown: From 833995832be37a0c604cb99a1fbf66a0dfd9daf3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 17 Jun 2024 16:00:07 +0300 Subject: [PATCH 0305/1647] bridgev2: add clean shutdown --- bridgev2/bridge.go | 18 ++++++++++++++++++ bridgev2/matrix/connector.go | 8 ++++++++ bridgev2/matrix/mxmain/main.go | 3 +-- bridgev2/matrixinterface.go | 1 + bridgev2/networkinterface.go | 1 + bridgev2/userlogin.go | 18 ++++++++++++++++++ 6 files changed, 47 insertions(+), 2 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index a12da24b..8eaf6a1e 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -120,3 +120,21 @@ func (br *Bridge) Start() error { br.Log.Info().Msg("Bridge started") return nil } + +func (br *Bridge) Stop() { + br.Log.Info().Msg("Shutting down bridge") + br.Matrix.Stop() + br.cacheLock.Lock() + var wg sync.WaitGroup + wg.Add(len(br.userLoginsByID)) + for _, login := range br.userLoginsByID { + go login.Disconnect(wg.Done) + } + wg.Wait() + br.cacheLock.Unlock() + err := br.DB.Close() + if err != nil { + br.Log.Warn().Err(err).Msg("Failed to close database") + } + br.Log.Info().Msg("Shutdown complete") +} diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 3493e7a3..5773a1bf 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -141,6 +141,14 @@ func (br *Connector) Start(ctx context.Context) error { return nil } +func (br *Connector) Stop() { + br.AS.Stop() + br.EventProcessor.Stop() + if br.Crypto != nil { + br.Crypto.Stop() + } +} + var MinSpecVersion = mautrix.SpecV14 func (br *Connector) ensureConnection(ctx context.Context) { diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index dd4c0328..3b2bc460 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -354,8 +354,7 @@ func (br *BridgeMain) WaitForInterrupt() { // Stop cleanly stops the bridge. This is called by [Run] and does not need to be called manually. func (br *BridgeMain) Stop() { - br.Log.Info().Msg("Shutting down bridge") - // TODO actually stop cleanly + br.Bridge.Stop() } // InitVersion formats the bridge version and build time nicely for things like diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index da8c92a8..56fbde38 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -20,6 +20,7 @@ import ( type MatrixConnector interface { Init(*Bridge) Start(ctx context.Context) error + Stop() ParseGhostMXID(userID id.UserID) (networkid.UserID, bool) FormatGhostMXID(userID networkid.UserID) id.UserID diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 48781b89..d82dd7ff 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -141,6 +141,7 @@ type MaxFileSizeingNetwork interface { // NetworkAPI is an interface representing a remote network client for a single user login. type NetworkAPI interface { Connect(ctx context.Context) error + Disconnect() IsLoggedIn() bool LogoutRemote(ctx context.Context) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 3a980c01..beaa03c7 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -9,6 +9,7 @@ package bridgev2 import ( "context" "fmt" + "time" "github.com/rs/zerolog" @@ -161,3 +162,20 @@ func (ul *UserLogin) GetRemoteName() string { name, _ := ul.Metadata["remote_name"].(string) return name } + +func (ul *UserLogin) Disconnect(done func()) { + defer done() + if ul.Client != nil { + disconnected := make(chan struct{}) + go func() { + ul.Client.Disconnect() + ul.Client = nil + close(disconnected) + }() + select { + case <-disconnected: + case <-time.After(5 * time.Second): + ul.Log.Warn().Msg("Client disconnection timed out") + } + } +} From afeadfb15feeab88aca99776ef3c0c9828ea6515 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 17 Jun 2024 18:15:34 +0300 Subject: [PATCH 0306/1647] crypto: fix m.relates_to copying --- crypto/decryptmegolm.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 1b714d09..99b584f5 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -39,7 +39,10 @@ type megolmEvent struct { Content event.Content `json:"content"` } -var relatesToPath = exgjson.Path("m.relates_to") +var ( + relatesToContentPath = exgjson.Path("m.relates_to") + relatesToTopLevelPath = exgjson.Path("content", "m.relates_to") +) // DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2 func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) { @@ -113,15 +116,15 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event } if content.RelatesTo != nil { - relation := gjson.GetBytes(evt.Content.VeryRaw, relatesToPath) - if relation.Exists() { + relation := gjson.GetBytes(evt.Content.VeryRaw, relatesToContentPath) + if relation.Exists() && !gjson.GetBytes(plaintext, relatesToTopLevelPath).IsObject() { var raw []byte if relation.Index > 0 { raw = evt.Content.VeryRaw[relation.Index : relation.Index+len(relation.Raw)] } else { raw = []byte(relation.Raw) } - updatedPlaintext, err := sjson.SetRawBytes(plaintext, relatesToPath, raw) + updatedPlaintext, err := sjson.SetRawBytes(plaintext, relatesToTopLevelPath, raw) if err != nil { log.Warn().Msg("Failed to copy m.relates_to to decrypted payload") } else if updatedPlaintext != nil { From 7d68995c85673fc6c3f7be49abc93688b1ff3616 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 18 Jun 2024 16:55:32 +0300 Subject: [PATCH 0307/1647] bridgev2: add structs for metadata of all db tables --- bridgev2/cmdlogin.go | 3 +- bridgev2/database/database.go | 58 ++++++++++++++++++++++++++++++++++ bridgev2/database/ghost.go | 15 +++++++-- bridgev2/database/message.go | 25 ++++++++++++++- bridgev2/database/portal.go | 26 ++++++++++++--- bridgev2/database/reaction.go | 27 +++++++++++++--- bridgev2/database/userlogin.go | 27 +++++++++++++--- bridgev2/networkinterface.go | 6 +++- bridgev2/portal.go | 41 +++++++++--------------- bridgev2/userlogin.go | 3 +- 10 files changed, 182 insertions(+), 49 deletions(-) diff --git a/bridgev2/cmdlogin.go b/bridgev2/cmdlogin.go index b09e9dfb..76727906 100644 --- a/bridgev2/cmdlogin.go +++ b/bridgev2/cmdlogin.go @@ -304,8 +304,7 @@ func (user *User) GetFormattedUserLogins() string { user.Bridge.cacheLock.Lock() logins := make([]string, len(user.logins)) for key, val := range user.logins { - remoteName, _ := val.Metadata["remote_name"].(string) - logins = append(logins, fmt.Sprintf("* `%s` (%s)", key, remoteName)) + logins = append(logins, fmt.Sprintf("* `%s` (%s)", key, val.Metadata.RemoteName)) } user.Bridge.cacheLock.Unlock() return strings.Join(logins, "\n") diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go index c4e2598f..c910498a 100644 --- a/bridgev2/database/database.go +++ b/bridgev2/database/database.go @@ -7,7 +7,13 @@ package database import ( + "encoding/json" + "reflect" + "strings" + "go.mau.fi/util/dbutil" + "golang.org/x/exp/constraints" + "golang.org/x/exp/maps" "maunium.net/go/mautrix/bridgev2/networkid" @@ -49,3 +55,55 @@ func ensureBridgeIDMatches(ptr *networkid.BridgeID, expected networkid.BridgeID) panic("bridge ID mismatch") } } + +func GetNumberFromMap[T constraints.Integer | constraints.Float](m map[string]any, key string) (T, bool) { + if val, found := m[key]; found { + floatVal, ok := val.(float64) + if ok { + return T(floatVal), true + } + tVal, ok := val.(T) + if ok { + return tVal, true + } + } + return 0, false +} + +func unmarshalMerge(input []byte, data any, extra *map[string]any) error { + err := json.Unmarshal(input, data) + if err != nil { + return err + } + err = json.Unmarshal(input, extra) + if err != nil { + return err + } + if *extra == nil { + *extra = make(map[string]any) + } + return nil +} + +func marshalMerge(data any, extra map[string]any) ([]byte, error) { + if extra == nil { + return json.Marshal(data) + } + merged := make(map[string]any) + maps.Copy(merged, extra) + dataRef := reflect.ValueOf(data).Elem() + dataType := dataRef.Type() + for _, field := range reflect.VisibleFields(dataType) { + parts := strings.Split(field.Tag.Get("json"), ",") + if len(parts) == 0 || len(parts[0]) == 0 || parts[0] == "-" { + continue + } + fieldVal := dataRef.FieldByIndex(field.Index) + if fieldVal.IsZero() { + delete(merged, parts[0]) + } else { + merged[parts[0]] = fieldVal.Interface() + } + } + return json.Marshal(merged) +} diff --git a/bridgev2/database/ghost.go b/bridgev2/database/ghost.go index a814e7c4..e6383c0e 100644 --- a/bridgev2/database/ghost.go +++ b/bridgev2/database/ghost.go @@ -21,12 +21,23 @@ type GhostQuery struct { *dbutil.QueryHelper[*Ghost] } -type GhostMetadata struct { +type StandardGhostMetadata struct { IsBot bool `json:"is_bot,omitempty"` Identifiers []string `json:"identifiers,omitempty"` ContactInfoSet bool `json:"contact_info_set,omitempty"` +} - Extra map[string]any `json:"extra"` +type GhostMetadata struct { + StandardGhostMetadata + Extra map[string]any +} + +func (gm *GhostMetadata) UnmarshalJSON(data []byte) error { + return unmarshalMerge(data, &gm.StandardGhostMetadata, &gm.Extra) +} + +func (gm *GhostMetadata) MarshalJSON() ([]byte, error) { + return marshalMerge(&gm.StandardGhostMetadata, gm.Extra) } type Ghost struct { diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index c1261be7..363f8e74 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -22,6 +22,23 @@ type MessageQuery struct { *dbutil.QueryHelper[*Message] } +type StandardMessageMetadata struct { + SenderMXID id.UserID `json:"sender_mxid,omitempty"` +} + +type MessageMetadata struct { + StandardMessageMetadata + Extra map[string]any +} + +func (mm *MessageMetadata) UnmarshalJSON(data []byte) error { + return unmarshalMerge(data, &mm.StandardMessageMetadata, &mm.Extra) +} + +func (mm *MessageMetadata) MarshalJSON() ([]byte, error) { + return marshalMerge(&mm.StandardMessageMetadata, mm.Extra) +} + type Message struct { RowID int64 BridgeID networkid.BridgeID @@ -35,7 +52,7 @@ type Message struct { RelatesToRowID int64 - Metadata map[string]any + Metadata MessageMetadata } func newMessage(_ *dbutil.QueryHelper[*Message]) *Message { @@ -134,12 +151,18 @@ func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { if err != nil { return nil, err } + if m.Metadata.Extra == nil { + m.Metadata.Extra = make(map[string]any) + } m.Timestamp = time.Unix(0, timestamp) m.RelatesToRowID = relatesTo.Int64 return m, nil } func (m *Message) sqlVariables() []any { + if m.Metadata.Extra == nil { + m.Metadata.Extra = make(map[string]any) + } return []any{ m.BridgeID, m.ID, m.PartID, m.MXID, m.Room.ID, m.Room.Receiver, m.SenderID, m.Timestamp.UnixNano(), dbutil.NumPtr(m.RelatesToRowID), dbutil.JSON{Data: m.Metadata}, diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 2456b9ca..db11a78f 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -22,6 +22,22 @@ type PortalQuery struct { *dbutil.QueryHelper[*Portal] } +type StandardPortalMetadata struct { +} + +type PortalMetadata struct { + StandardPortalMetadata + Extra map[string]any +} + +func (pm *PortalMetadata) UnmarshalJSON(data []byte) error { + return unmarshalMerge(data, &pm.StandardPortalMetadata, &pm.Extra) +} + +func (pm *PortalMetadata) MarshalJSON() ([]byte, error) { + return marshalMerge(&pm.StandardPortalMetadata, pm.Extra) +} + type Portal struct { BridgeID networkid.BridgeID networkid.PortalKey @@ -37,7 +53,7 @@ type Portal struct { TopicSet bool AvatarSet bool InSpace bool - Metadata map[string]any + Metadata PortalMetadata } func newPortal(_ *dbutil.QueryHelper[*Portal]) *Portal { @@ -115,8 +131,8 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { if err != nil { return nil, err } - if p.Metadata == nil { - p.Metadata = make(map[string]any) + if p.Metadata.Extra == nil { + p.Metadata.Extra = make(map[string]any) } if avatarHash != "" { data, _ := hex.DecodeString(avatarHash) @@ -130,8 +146,8 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { } func (p *Portal) sqlVariables() []any { - if p.Metadata == nil { - p.Metadata = make(map[string]any) + if p.Metadata.Extra == nil { + p.Metadata.Extra = make(map[string]any) } var avatarHash string if p.AvatarHash != [32]byte{} { diff --git a/bridgev2/database/reaction.go b/bridgev2/database/reaction.go index 5b01459b..d5cb798f 100644 --- a/bridgev2/database/reaction.go +++ b/bridgev2/database/reaction.go @@ -21,6 +21,23 @@ type ReactionQuery struct { *dbutil.QueryHelper[*Reaction] } +type StandardReactionMetadata struct { + Emoji string `json:"emoji,omitempty"` +} + +type ReactionMetadata struct { + StandardReactionMetadata + Extra map[string]any +} + +func (rm *ReactionMetadata) UnmarshalJSON(data []byte) error { + return unmarshalMerge(data, &rm.StandardReactionMetadata, &rm.Extra) +} + +func (rm *ReactionMetadata) MarshalJSON() ([]byte, error) { + return marshalMerge(&rm.StandardReactionMetadata, rm.Extra) +} + type Reaction struct { BridgeID networkid.BridgeID Room networkid.PortalKey @@ -31,7 +48,7 @@ type Reaction struct { MXID id.EventID Timestamp time.Time - Metadata map[string]any + Metadata ReactionMetadata } func newReaction(_ *dbutil.QueryHelper[*Reaction]) *Reaction { @@ -97,16 +114,16 @@ func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) { if err != nil { return nil, err } - if r.Metadata == nil { - r.Metadata = make(map[string]any) + if r.Metadata.Extra == nil { + r.Metadata.Extra = make(map[string]any) } r.Timestamp = time.Unix(0, timestamp) return r, nil } func (r *Reaction) sqlVariables() []any { - if r.Metadata == nil { - r.Metadata = make(map[string]any) + if r.Metadata.Extra == nil { + r.Metadata.Extra = make(map[string]any) } return []any{ r.BridgeID, r.MessageID, r.MessagePartID, r.SenderID, r.EmojiID, diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go index 65ed4bf8..ea8e5838 100644 --- a/bridgev2/database/userlogin.go +++ b/bridgev2/database/userlogin.go @@ -21,12 +21,29 @@ type UserLoginQuery struct { *dbutil.QueryHelper[*UserLogin] } +type StandardUserLoginMetadata struct { + RemoteName string `json:"remote_name,omitempty"` +} + +type UserLoginMetadata struct { + StandardUserLoginMetadata + Extra map[string]any +} + +func (ulm *UserLoginMetadata) UnmarshalJSON(data []byte) error { + return unmarshalMerge(data, &ulm.StandardUserLoginMetadata, &ulm.Extra) +} + +func (ulm *UserLoginMetadata) MarshalJSON() ([]byte, error) { + return marshalMerge(&ulm.StandardUserLoginMetadata, ulm.Extra) +} + type UserLogin struct { BridgeID networkid.BridgeID UserMXID id.UserID ID networkid.UserLoginID SpaceRoom id.RoomID - Metadata map[string]any + Metadata UserLoginMetadata } func newUserLogin(_ *dbutil.QueryHelper[*UserLogin]) *UserLogin { @@ -89,16 +106,16 @@ func (u *UserLogin) Scan(row dbutil.Scannable) (*UserLogin, error) { if err != nil { return nil, err } - if u.Metadata == nil { - u.Metadata = make(map[string]any) + if u.Metadata.Extra == nil { + u.Metadata.Extra = make(map[string]any) } u.SpaceRoom = id.RoomID(spaceRoom.String) return u, nil } func (u *UserLogin) sqlVariables() []any { - if u.Metadata == nil { - u.Metadata = make(map[string]any) + if u.Metadata.Extra == nil { + u.Metadata.Extra = make(map[string]any) } return []any{u.BridgeID, u.UserMXID, u.ID, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: u.Metadata}} } diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index d82dd7ff..66328004 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -138,6 +138,10 @@ type MaxFileSizeingNetwork interface { SetMaxFileSize(maxSize int64) } +type MatrixMessageResponse struct { + DB *database.Message +} + // NetworkAPI is an interface representing a remote network client for a single user login. type NetworkAPI interface { Connect(ctx context.Context) error @@ -149,7 +153,7 @@ type NetworkAPI interface { GetChatInfo(ctx context.Context, portal *Portal) (*PortalInfo, error) GetUserInfo(ctx context.Context, ghost *Ghost) (*UserInfo, error) - HandleMatrixMessage(ctx context.Context, msg *MatrixMessage) (message *database.Message, err error) + HandleMatrixMessage(ctx context.Context, msg *MatrixMessage) (message *MatrixMessageResponse, err error) HandleMatrixEdit(ctx context.Context, msg *MatrixEdit) error PreHandleMatrixReaction(ctx context.Context, msg *MatrixReaction) (MatrixReactionPreResponse, error) HandleMatrixReaction(ctx context.Context, msg *MatrixReaction) (reaction *database.Reaction, err error) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index d1b35664..86667d9b 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -433,7 +433,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } } - message, err := sender.Client.HandleMatrixMessage(ctx, &MatrixMessage{ + resp, err := sender.Client.HandleMatrixMessage(ctx, &MatrixMessage{ MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{ Event: evt, Content: content, @@ -448,6 +448,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin portal.sendErrorStatus(ctx, evt, err) return } + message := resp.DB if message.MXID == "" { message.MXID = evt.ID } @@ -457,10 +458,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin if message.Timestamp.IsZero() { message.Timestamp = time.UnixMilli(evt.Timestamp) } - if message.Metadata == nil { - message.Metadata = make(map[string]any) - } - message.Metadata["sender_mxid"] = evt.Sender + message.Metadata.SenderMXID = evt.Sender // Hack to ensure the ghost row exists // TODO move to better place (like login) portal.Bridge.GetGhostByID(ctx, message.SenderID) @@ -557,7 +555,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi log.Err(err).Msg("Failed to check if reaction is a duplicate") return } else if existing != nil { - if existing.EmojiID != "" || existing.Metadata["emoji"] == preResp.Emoji { + if existing.EmojiID != "" || existing.Metadata.Emoji == preResp.Emoji { log.Debug().Msg("Ignoring duplicate reaction") portal.sendSuccessStatus(ctx, evt) return @@ -620,12 +618,9 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if dbReaction.Timestamp.IsZero() { dbReaction.Timestamp = time.UnixMilli(evt.Timestamp) } - if dbReaction.Metadata == nil { - dbReaction.Metadata = make(map[string]any) - } if preResp.EmojiID == "" && dbReaction.EmojiID == "" { - if _, alreadySet := dbReaction.Metadata["emoji"]; !alreadySet { - dbReaction.Metadata["emoji"] = preResp.Emoji + if dbReaction.Metadata.Emoji == "" { + dbReaction.Metadata.Emoji = preResp.Emoji } } else if dbReaction.EmojiID == "" { dbReaction.EmojiID = preResp.EmojiID @@ -815,9 +810,8 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin if part.Content.Mentions == nil { part.Content.Mentions = &event.Mentions{} } - replyTargetSenderMXID, ok := replyTo.Metadata["sender_mxid"].(string) - if ok && !slices.Contains(part.Content.Mentions.UserIDs, id.UserID(replyTargetSenderMXID)) { - part.Content.Mentions.UserIDs = append(part.Content.Mentions.UserIDs, id.UserID(replyTargetSenderMXID)) + if !slices.Contains(part.Content.Mentions.UserIDs, replyTo.Metadata.SenderMXID) { + part.Content.Mentions.UserIDs = append(part.Content.Mentions.UserIDs, replyTo.Metadata.SenderMXID) } } resp, err := intent.SendMessage(ctx, portal.MXID, part.Type, &event.Content{ @@ -832,11 +826,6 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin Stringer("event_id", resp.EventID). Str("part_id", string(part.ID)). Msg("Sent message part to Matrix") - if part.DBMetadata == nil { - part.DBMetadata = make(map[string]any) - } - // TODO make metadata fields less hacky - part.DBMetadata["sender_mxid"] = intent.GetMXID() dbMessage := &database.Message{ ID: evt.GetID(), PartID: part.ID, @@ -845,8 +834,9 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin SenderID: evt.GetSender().Sender, Timestamp: ts, RelatesToRowID: relatesToRowID, - Metadata: part.DBMetadata, } + dbMessage.Metadata.SenderMXID = intent.GetMXID() + dbMessage.Metadata.Extra = part.DBMetadata err = portal.Bridge.DB.Message.Insert(ctx, dbMessage) if err != nil { log.Err(err).Str("part_id", string(part.ID)).Msg("Failed to save message part to database") @@ -980,7 +970,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi if err != nil { log.Err(err).Msg("Failed to check if reaction is a duplicate") return - } else if existingReaction != nil && (emojiID != "" || existingReaction.Metadata["emoji"] == emoji) { + } else if existingReaction != nil && (emojiID != "" || existingReaction.Metadata.Emoji == emoji) { log.Debug().Msg("Ignoring duplicate reaction") return } @@ -1012,11 +1002,10 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi Timestamp: ts, } if metaProvider, ok := evt.(RemoteReactionWithMeta); ok { - dbReaction.Metadata = metaProvider.GetReactionDBMetadata() - } else if emojiID == "" { - dbReaction.Metadata = map[string]any{ - "emoji": emoji, - } + dbReaction.Metadata.Extra = metaProvider.GetReactionDBMetadata() + } + if emojiID == "" { + dbReaction.Metadata.Emoji = emoji } err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) if err != nil { diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index beaa03c7..5329d464 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -159,8 +159,7 @@ func (ul *UserLogin) GetRemoteID() string { } func (ul *UserLogin) GetRemoteName() string { - name, _ := ul.Metadata["remote_name"].(string) - return name + return ul.Metadata.RemoteName } func (ul *UserLogin) Disconnect(done func()) { From 2c6ca02eeb56eeacfafc8dbfa442eeac052f472e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 18 Jun 2024 17:31:48 +0300 Subject: [PATCH 0308/1647] bridgev2: add support for disappearing messages --- bridgev2/bridge.go | 5 + bridgev2/database/database.go | 36 +++--- bridgev2/database/disappear.go | 106 +++++++++++++++++ bridgev2/database/portal.go | 3 + bridgev2/database/upgrades/00-latest.sql | 13 ++- .../upgrades/02-disappearing-messages.sql | 11 ++ bridgev2/disappear.go | 110 ++++++++++++++++++ bridgev2/networkinterface.go | 1 + bridgev2/portal.go | 33 +++++- 9 files changed, 297 insertions(+), 21 deletions(-) create mode 100644 bridgev2/database/disappear.go create mode 100644 bridgev2/database/upgrades/02-disappearing-messages.sql create mode 100644 bridgev2/disappear.go diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 8eaf6a1e..9a73bc2a 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -35,6 +35,8 @@ type Bridge struct { Commands *CommandProcessor Config *bridgeconfig.BridgeConfig + DisappearLoop *DisappearLoop + usersByMXID map[id.UserID]*User userLoginsByID map[networkid.UserLoginID]*UserLogin portalsByKey map[networkid.PortalKey]*Portal @@ -66,6 +68,7 @@ func NewBridge(bridgeID networkid.BridgeID, db *dbutil.Database, log zerolog.Log br.Matrix.Init(br) br.Bot = br.Matrix.BotIntent() br.Network.Init(br) + br.DisappearLoop = &DisappearLoop{br: br} return br } @@ -100,6 +103,8 @@ func (br *Bridge) Start() error { if err != nil { return fmt.Errorf("failed to start network connector: %w", err) } + // TODO only start if the network supports disappearing messages? + go br.DisappearLoop.Start() logins, err := br.GetAllUserLogins(ctx) if err != nil { diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go index c910498a..c6d1e4eb 100644 --- a/bridgev2/database/database.go +++ b/bridgev2/database/database.go @@ -23,28 +23,30 @@ import ( type Database struct { *dbutil.Database - BridgeID networkid.BridgeID - Portal *PortalQuery - Ghost *GhostQuery - Message *MessageQuery - Reaction *ReactionQuery - User *UserQuery - UserLogin *UserLoginQuery - UserPortal *UserPortalQuery + BridgeID networkid.BridgeID + Portal *PortalQuery + Ghost *GhostQuery + Message *MessageQuery + DisappearingMessage *DisappearingMessageQuery + Reaction *ReactionQuery + User *UserQuery + UserLogin *UserLoginQuery + UserPortal *UserPortalQuery } func New(bridgeID networkid.BridgeID, db *dbutil.Database) *Database { db.UpgradeTable = upgrades.Table return &Database{ - Database: db, - BridgeID: bridgeID, - Portal: &PortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newPortal)}, - Ghost: &GhostQuery{bridgeID, dbutil.MakeQueryHelper(db, newGhost)}, - Message: &MessageQuery{bridgeID, dbutil.MakeQueryHelper(db, newMessage)}, - Reaction: &ReactionQuery{bridgeID, dbutil.MakeQueryHelper(db, newReaction)}, - User: &UserQuery{bridgeID, dbutil.MakeQueryHelper(db, newUser)}, - UserLogin: &UserLoginQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserLogin)}, - UserPortal: &UserPortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserPortal)}, + Database: db, + BridgeID: bridgeID, + Portal: &PortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newPortal)}, + Ghost: &GhostQuery{bridgeID, dbutil.MakeQueryHelper(db, newGhost)}, + Message: &MessageQuery{bridgeID, dbutil.MakeQueryHelper(db, newMessage)}, + DisappearingMessage: &DisappearingMessageQuery{bridgeID, dbutil.MakeQueryHelper(db, newDisappearingMessage)}, + Reaction: &ReactionQuery{bridgeID, dbutil.MakeQueryHelper(db, newReaction)}, + User: &UserQuery{bridgeID, dbutil.MakeQueryHelper(db, newUser)}, + UserLogin: &UserLoginQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserLogin)}, + UserPortal: &UserPortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserPortal)}, } } diff --git a/bridgev2/database/disappear.go b/bridgev2/database/disappear.go new file mode 100644 index 00000000..22a5be5c --- /dev/null +++ b/bridgev2/database/disappear.go @@ -0,0 +1,106 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "time" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +// DisappearingType represents the type of a disappearing message timer. +type DisappearingType string + +const ( + DisappearingTypeNone DisappearingType = "" + DisappearingTypeAfterRead DisappearingType = "after_read" + DisappearingTypeAfterSend DisappearingType = "after_send" +) + +// DisappearingSetting represents a disappearing message timer setting +// by combining a type with a timer and an optional start timestamp. +type DisappearingSetting struct { + Type DisappearingType + Timer time.Duration + DisappearAt time.Time +} + +type DisappearingMessageQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*DisappearingMessage] +} + +type DisappearingMessage struct { + BridgeID networkid.BridgeID + RoomID id.RoomID + EventID id.EventID + DisappearingSetting +} + +func newDisappearingMessage(_ *dbutil.QueryHelper[*DisappearingMessage]) *DisappearingMessage { + return &DisappearingMessage{} +} + +const ( + upsertDisappearingMessageQuery = ` + INSERT INTO disappearing_message (bridge_id, mx_room, mxid, type, timer, disappear_at) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (bridge_id, mxid) DO UPDATE SET timer=excluded.timer, disappear_at=excluded.disappear_at + ` + startDisappearingMessagesQuery = ` + UPDATE disappearing_message + SET disappear_at=$1 + timer + WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read' + RETURNING bridge_id, mx_room, mxid, type, timer, disappear_at + ` + getUpcomingDisappearingMessagesQuery = ` + SELECT bridge_id, mx_room, mxid, type, timer, disappear_at + FROM disappearing_message WHERE bridge_id = $1 AND disappear_at IS NOT NULL AND disappear_at < $2 + ORDER BY disappear_at + ` + deleteDisappearingMessageQuery = ` + DELETE FROM disappearing_message WHERE bridge_id=$1 AND mxid=$2 + ` +) + +func (dmq *DisappearingMessageQuery) Put(ctx context.Context, dm *DisappearingMessage) error { + ensureBridgeIDMatches(&dm.BridgeID, dmq.BridgeID) + return dmq.Exec(ctx, upsertDisappearingMessageQuery, dm.sqlVariables()...) +} + +func (dmq *DisappearingMessageQuery) StartAll(ctx context.Context, roomID id.RoomID) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID) +} + +func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, getUpcomingDisappearingMessagesQuery, dmq.BridgeID, time.Now().Add(duration).UnixNano()) +} + +func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.EventID) error { + return dmq.Exec(ctx, deleteDisappearingMessageQuery, dmq.BridgeID, eventID) +} + +func (d *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) { + var disappearAt sql.NullInt64 + err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, &d.Type, &d.Timer, &disappearAt) + if err != nil { + return nil, err + } + if disappearAt.Valid { + d.DisappearAt = time.Unix(0, disappearAt.Int64) + } + return d, nil +} + +func (d *DisappearingMessage) sqlVariables() []any { + return []any{d.BridgeID, d.RoomID, d.EventID, d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)} +} diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index db11a78f..ded53177 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -10,6 +10,7 @@ import ( "context" "database/sql" "encoding/hex" + "time" "go.mau.fi/util/dbutil" @@ -23,6 +24,8 @@ type PortalQuery struct { } type StandardPortalMetadata struct { + DisappearType DisappearingType `json:"disappear_type,omitempty"` + DisappearTimer time.Duration `json:"disappear_timer,omitempty"` } type PortalMetadata struct { diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 8d9d150e..adceab3c 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v1: Latest revision +-- v0 -> v2 (compatible with v1+): Latest revision CREATE TABLE portal ( bridge_id TEXT NOT NULL, id TEXT NOT NULL, @@ -76,6 +76,17 @@ CREATE TABLE message ( CONSTRAINT message_real_pkey UNIQUE (bridge_id, id, part_id) ); +CREATE TABLE disappearing_message ( + bridge_id TEXT NOT NULL, + mx_room TEXT NOT NULL, + mxid TEXT NOT NULL, + type TEXT NOT NULL, + timer BIGINT NOT NULL, + disappear_at BIGINT, + + PRIMARY KEY (bridge_id, mxid) +); + CREATE TABLE reaction ( bridge_id TEXT NOT NULL, message_id TEXT NOT NULL, diff --git a/bridgev2/database/upgrades/02-disappearing-messages.sql b/bridgev2/database/upgrades/02-disappearing-messages.sql new file mode 100644 index 00000000..e1425e75 --- /dev/null +++ b/bridgev2/database/upgrades/02-disappearing-messages.sql @@ -0,0 +1,11 @@ +-- v2 (compatible with v1+): Add disappearing messages table +CREATE TABLE disappearing_message ( + bridge_id TEXT NOT NULL, + mx_room TEXT NOT NULL, + mxid TEXT NOT NULL, + type TEXT NOT NULL, + timer BIGINT NOT NULL, + disappear_at BIGINT, + + PRIMARY KEY (bridge_id, mxid) +); diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go new file mode 100644 index 00000000..971a6c39 --- /dev/null +++ b/bridgev2/disappear.go @@ -0,0 +1,110 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "time" + + "github.com/rs/zerolog" + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type DisappearLoop struct { + br *Bridge + NextCheck time.Time + stop context.CancelFunc +} + +const DisappearCheckInterval = 1 * time.Hour + +func (dl *DisappearLoop) Start() { + log := dl.br.Log.With().Str("component", "disappear loop").Logger() + ctx := log.WithContext(context.Background()) + ctx, dl.stop = context.WithCancel(ctx) + log.Debug().Msg("Disappearing message loop starting") + for { + dl.NextCheck = time.Now().Add(DisappearCheckInterval) + messages, err := dl.br.DB.DisappearingMessage.GetUpcoming(ctx, DisappearCheckInterval) + if err != nil { + log.Err(err).Msg("Failed to get upcoming disappearing messages") + } else if len(messages) > 0 { + go dl.sleepAndDisappear(ctx, messages...) + } + select { + case <-time.After(time.Until(dl.NextCheck)): + case <-ctx.Done(): + log.Debug().Msg("Disappearing message loop stopping") + return + } + } +} + +func (dl *DisappearLoop) Stop() { + if dl.stop != nil { + dl.stop() + } +} + +func (dl *DisappearLoop) StartAll(ctx context.Context, roomID id.RoomID) { + startedMessages, err := dl.br.DB.DisappearingMessage.StartAll(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to start disappearing messages") + return + } + slices.SortFunc(startedMessages, func(a, b *database.DisappearingMessage) int { + return a.DisappearAt.Compare(b.DisappearAt) + }) + slices.DeleteFunc(startedMessages, func(dm *database.DisappearingMessage) bool { + return dm.DisappearAt.After(dl.NextCheck) + }) + if len(startedMessages) > 0 { + go dl.sleepAndDisappear(ctx, startedMessages...) + } +} + +func (dl *DisappearLoop) Add(ctx context.Context, dm *database.DisappearingMessage) { + err := dl.br.DB.DisappearingMessage.Put(ctx, dm) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("event_id", dm.EventID). + Msg("Failed to save disappearing message") + } + if !dm.DisappearAt.IsZero() && dm.DisappearAt.Before(dl.NextCheck) { + go dl.sleepAndDisappear(context.WithoutCancel(ctx), dm) + } +} + +func (dl *DisappearLoop) sleepAndDisappear(ctx context.Context, dms ...*database.DisappearingMessage) { + for _, msg := range dms { + time.Sleep(time.Until(msg.DisappearAt)) + resp, err := dl.br.Bot.SendMessage(ctx, msg.RoomID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: msg.EventID, + Reason: "Message disappeared", + }, + }, time.Now()) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("target_event_id", msg.EventID).Msg("Failed to disappear message") + } else { + zerolog.Ctx(ctx).Debug(). + Stringer("target_event_id", msg.EventID). + Stringer("redaction_event_id", resp.EventID). + Msg("Disappeared message") + } + err = dl.br.DB.DisappearingMessage.Delete(ctx, msg.EventID) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("event_id", msg.EventID). + Msg("Failed to delete disappearing message entry from database") + } + } +} diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 66328004..65406cbe 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -39,6 +39,7 @@ type ConvertedMessage struct { ReplyTo *networkid.MessageOptionalPartID ThreadRoot *networkid.MessageOptionalPartID Parts []*ConvertedMessagePart + Disappear database.DisappearingSetting } type ConvertedEditPart struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 86667d9b..213993f2 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -355,6 +355,7 @@ func (portal *Portal) handleMatrixReadReceipt(user *User, eventID id.EventID, re if err != nil { log.Err(err).Msg("Failed to save user portal metadata") } + portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) } func (portal *Portal) handleMatrixTyping(evt *event.Event) { @@ -466,6 +467,17 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin if err != nil { log.Err(err).Msg("Failed to save message to database") } + if portal.Metadata.DisappearType != database.DisappearingTypeNone { + go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ + RoomID: portal.MXID, + EventID: message.MXID, + DisappearingSetting: database.DisappearingSetting{ + Type: portal.Metadata.DisappearType, + Timer: portal.Metadata.DisappearTimer, + DisappearAt: message.Timestamp.Add(portal.Metadata.DisappearTimer), + }, + }) + } portal.sendSuccessStatus(ctx, evt) } @@ -841,6 +853,13 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin if err != nil { log.Err(err).Str("part_id", string(part.ID)).Msg("Failed to save message part to database") } + if converted.Disappear.Type != database.DisappearingTypeNone { + go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ + RoomID: portal.MXID, + EventID: dbMessage.MXID, + DisappearingSetting: converted.Disappear, + }) + } if prevThreadEvent != nil { prevThreadEvent = dbMessage } @@ -1111,13 +1130,17 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL log.Warn().Msg("No target message found for read receipt") return } - intent := portal.getIntentFor(ctx, evt.GetSender(), source) + sender := evt.GetSender() + intent := portal.getIntentFor(ctx, sender, source) err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt)) if err != nil { log.Err(err).Stringer("target_mxid", lastTarget.MXID).Msg("Failed to bridge read receipt") } else { log.Debug().Stringer("target_mxid", lastTarget.MXID).Msg("Bridged read receipt") } + if sender.IsFromMe { + portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) + } } func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteReceipt) { @@ -1367,7 +1390,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, sender * //} if changed { portal.UpdateBridgeInfo(ctx) - err := portal.Bridge.DB.Portal.Update(ctx, portal.Portal) + err := portal.Save(ctx) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal to database after updating info") } @@ -1473,7 +1496,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) e portal.Bridge.portalsByMXID[roomID] = portal portal.Bridge.cacheLock.Unlock() portal.updateLogger() - err = portal.Bridge.DB.Portal.Update(ctx, portal.Portal) + err = portal.Save(ctx) if err != nil { log.Err(err).Msg("Failed to save portal to database after creating Matrix room") return err @@ -1489,3 +1512,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) e } return nil } + +func (portal *Portal) Save(ctx context.Context) error { + return portal.Bridge.DB.Portal.Update(ctx, portal.Portal) +} From 7b845c947f3e390392aec3acf867a2c76f3e0f24 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Jun 2024 11:28:49 +0300 Subject: [PATCH 0309/1647] bridgev2: add direct media interface (#244) Co-authored-by: Sumner Evans --- bridgev2/bridgeconfig/config.go | 9 + bridgev2/bridgeconfig/upgrade.go | 13 + bridgev2/matrix/connector.go | 9 +- bridgev2/matrix/directmedia.go | 98 ++++++ bridgev2/matrix/mxmain/example-config.yaml | 20 ++ bridgev2/matrixinterface.go | 2 + bridgev2/networkid/bridgeid.go | 7 + bridgev2/networkinterface.go | 16 + mediaproxy/mediaproxy.go | 388 +++++++++++++++++++++ 9 files changed, 561 insertions(+), 1 deletion(-) create mode 100644 bridgev2/matrix/directmedia.go create mode 100644 mediaproxy/mediaproxy.go diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 1153407a..448c2f89 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -20,6 +20,7 @@ type Config struct { AppService AppserviceConfig `yaml:"appservice"` Matrix MatrixConfig `yaml:"matrix"` Provisioning ProvisioningConfig `yaml:"provisioning"` + DirectMedia DirectMediaConfig `yaml:"direct_media"` DoublePuppet DoublePuppetConfig `yaml:"double_puppet"` Encryption EncryptionConfig `yaml:"encryption"` Logging zeroconfig.Config `yaml:"logging"` @@ -46,6 +47,14 @@ type ProvisioningConfig struct { DebugEndpoints bool `yaml:"debug_endpoints"` } +type DirectMediaConfig struct { + Enabled bool `yaml:"enabled"` + AllowProxy bool `yaml:"allow_proxy"` + ServerName string `yaml:"server_name"` + WellKnownResponse string `yaml:"well_known_response"` + ServerKey string `yaml:"server_key"` +} + type DoublePuppetConfig struct { Servers map[string]string `yaml:"servers"` AllowDiscovery bool `yaml:"allow_discovery"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 570f2c95..f554973f 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -12,6 +12,8 @@ import ( up "go.mau.fi/util/configupgrade" "go.mau.fi/util/random" + + "maunium.net/go/mautrix/federation" ) func doUpgrade(helper up.Helper) { @@ -66,6 +68,17 @@ func doUpgrade(helper up.Helper) { } helper.Copy(up.Bool, "provisioning", "debug_endpoints") + helper.Copy(up.Bool, "direct_media", "enabled") + helper.Copy(up.Str, "direct_media", "server_name") + helper.Copy(up.Str|up.Null, "direct_media", "well_known_response") + helper.Copy(up.Bool, "direct_media", "allow_proxy") + if serverKey, ok := helper.Get(up.Str, "direct_media", "server_key"); !ok || serverKey == "generate" { + serverKey = federation.GenerateSigningKey().SynapseString() + helper.Set(up.Str, serverKey, "direct_media", "server_key") + } else { + helper.Copy(up.Str, "direct_media", "server_key") + } + helper.Copy(up.Map, "double_puppet", "servers") helper.Copy(up.Bool, "double_puppet", "allow_discovery") helper.Copy(up.Map, "double_puppet", "secrets") diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 5773a1bf..5da605dc 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -29,6 +29,7 @@ import ( "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/mediaproxy" "maunium.net/go/mautrix/sqlstatestore" ) @@ -57,6 +58,8 @@ type Connector struct { Bridge *bridgev2.Bridge Provisioning *ProvisioningAPI DoublePuppet *doublePuppetUtil + MediaProxy *mediaproxy.MediaProxy + dmaSigKey [32]byte MediaConfig mautrix.RespMediaConfig SpecVersions *mautrix.RespVersions @@ -119,7 +122,11 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { func (br *Connector) Start(ctx context.Context) error { br.Provisioning.Init() - err := br.StateStore.Upgrade(ctx) + err := br.initDirectMedia() + if err != nil { + return err + } + err = br.StateStore.Upgrade(ctx) if err != nil { return bridgev2.DBUpgradeError{Section: "matrix_state", Err: err} } diff --git a/bridgev2/matrix/directmedia.go b/bridgev2/matrix/directmedia.go new file mode 100644 index 00000000..e4090ea0 --- /dev/null +++ b/bridgev2/matrix/directmedia.go @@ -0,0 +1,98 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "bytes" + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "fmt" + "net/http" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/mediaproxy" +) + +const MediaIDPrefix = "\U0001F408" +const MediaIDTruncatedHashLength = 16 +const ContentURIMaxLength = 255 + +func (br *Connector) initDirectMedia() error { + if !br.Config.DirectMedia.Enabled { + return nil + } + dmn, ok := br.Bridge.Network.(bridgev2.DirectMediableNetwork) + if !ok { + return fmt.Errorf("direct media is enabled in config, but the network connector does not support it") + } + var err error + br.MediaProxy, err = mediaproxy.New(br.Config.DirectMedia.ServerName, br.Config.DirectMedia.ServerKey, br.getDirectMedia) + if err != nil { + return fmt.Errorf("failed to initialize media proxy: %w", err) + } + if br.Config.DirectMedia.WellKnownResponse != "" { + br.MediaProxy.KeyServer.WellKnownTarget = br.Config.DirectMedia.WellKnownResponse + } + if !br.Config.DirectMedia.AllowProxy { + br.MediaProxy.DisallowProxying() + } + br.MediaProxy.RegisterRoutes(br.AS.Router) + br.dmaSigKey = sha256.Sum256(br.MediaProxy.GetServerKey().Priv.Seed()) + dmn.SetUseDirectMedia() + br.Log.Debug().Str("server_name", br.MediaProxy.GetServerName()).Msg("Enabled direct media access") + return nil +} + +func (br *Connector) hashMediaID(data []byte) []byte { + hasher := hmac.New(sha256.New, br.dmaSigKey[:]) + hasher.Write(data) + return hasher.Sum(nil)[:MediaIDTruncatedHashLength] +} + +func (br *Connector) GenerateContentURI(ctx context.Context, mediaID networkid.MediaID) (id.ContentURIString, error) { + if br.MediaProxy == nil { + return "", bridgev2.ErrDirectMediaNotEnabled + } + buf := make([]byte, len(MediaIDPrefix)+len(mediaID)+MediaIDTruncatedHashLength) + copy(buf, MediaIDPrefix) + copy(buf[len(MediaIDPrefix):], mediaID) + truncatedHash := br.hashMediaID(buf[:len(MediaIDPrefix)+len(mediaID)]) + copy(buf[len(MediaIDPrefix)+len(mediaID):], truncatedHash) + mxc := id.ContentURI{ + Homeserver: br.MediaProxy.GetServerName(), + FileID: base64.RawURLEncoding.EncodeToString(buf), + }.CUString() + if len(mxc) > ContentURIMaxLength { + return "", fmt.Errorf("content URI too long (%d > %d)", len(mxc), ContentURIMaxLength) + } + return mxc, nil +} + +func (br *Connector) getDirectMedia(ctx context.Context, mediaIDStr string) (response mediaproxy.GetMediaResponse, err error) { + mediaID, err := base64.RawURLEncoding.DecodeString(mediaIDStr) + if err != nil || !bytes.HasPrefix(mediaID, []byte(MediaIDPrefix)) || len(mediaID) < len(MediaIDPrefix)+MediaIDTruncatedHashLength+1 { + return nil, mediaproxy.ErrInvalidMediaIDSyntax + } + receivedHash := mediaID[len(mediaID)-MediaIDTruncatedHashLength:] + expectedHash := br.hashMediaID(mediaID[:len(mediaID)-MediaIDTruncatedHashLength]) + if !hmac.Equal(receivedHash, expectedHash) { + return nil, &mediaproxy.ResponseError{ + Status: http.StatusNotFound, + Data: &mautrix.RespError{ + ErrCode: mautrix.MNotFound.ErrCode, + Err: "Invalid checksum in media ID part", + }, + } + } + remoteMediaID := networkid.MediaID(mediaID[len(MediaIDPrefix) : len(mediaID)-MediaIDTruncatedHashLength]) + return br.Bridge.Network.(bridgev2.DirectMediableNetwork).Download(ctx, remoteMediaID) +} diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index fdfd3e4a..425eba17 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -117,6 +117,26 @@ provisioning: # Enable debug API at /debug with provisioning authentication. debug_endpoints: false +# Settings for converting remote media to custom mxc:// URIs instead of reuploading. +# More details can be found at https://docs.mau.fi/bridges/go/discord/direct-media.html +direct_media: + # Should custom mxc:// URIs be used instead of reuploading media? + enabled: false + # The server name to use for the custom mxc:// URIs. + # This server name will effectively be a real Matrix server, it just won't implement anything other than media. + # You must either set up .well-known delegation from this domain to the bridge, or proxy the domain directly to the bridge. + server_name: discord-media.example.com + # Optionally a custom .well-known response. This defaults to `server_name:443` + well_known_response: + # If the remote network supports media downloads over HTTP, then the bridge will use MSC3860/MSC3916 + # media download redirects if the requester supports it. Optionally, you can force redirects + # and not allow proxying at all by setting this to false. + # This option does nothing if the remote network does not support media downloads over HTTP. + allow_proxy: true + # Matrix server signing key to make the federation tester pass, same format as synapse's .signing.key file. + # This key is also used to sign the mxc:// URIs to ensure only the bridge can generate them. + server_key: generate + # Settings for enabling double puppeting double_puppet: # Servers to always allow double puppeting from. diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 56fbde38..abbaa382 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -25,6 +25,8 @@ type MatrixConnector interface { ParseGhostMXID(userID id.UserID) (networkid.UserID, bool) FormatGhostMXID(userID networkid.UserID) id.UserID + GenerateContentURI(ctx context.Context, mediaID networkid.MediaID) (id.ContentURIString, error) + GhostIntent(userID id.UserID) MatrixAPI NewUserIntent(ctx context.Context, userID id.UserID, accessToken string) (MatrixAPI, string, error) BotIntent() MatrixAPI diff --git a/bridgev2/networkid/bridgeid.go b/bridgev2/networkid/bridgeid.go index f57e74bf..3b55b67b 100644 --- a/bridgev2/networkid/bridgeid.go +++ b/bridgev2/networkid/bridgeid.go @@ -117,3 +117,10 @@ type AvatarID string // to apply the unique constraints in the database appropriately. // On networks that allow multiple emojis, this is the unicode emoji or a network-specific shortcode. type EmojiID string + +// MediaID represents a media identifier that can be downloaded from the remote network at any point in the future. +// +// This is used to implement on-demand media downloads. The network connector can ask the Matrix connector +// to generate a content URI from a media ID. Then, when the Matrix connector wants to download the media, +// it will parse the content URI and ask the network connector for the data using the media ID. +type MediaID []byte diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 65406cbe..ca02d12a 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -8,6 +8,7 @@ package bridgev2 import ( "context" + "errors" "fmt" "strings" "time" @@ -19,6 +20,7 @@ import ( "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/mediaproxy" ) type ConvertedMessagePart struct { @@ -117,6 +119,20 @@ type NetworkConnector interface { CreateLogin(ctx context.Context, user *User, flowID string) (LoginProcess, error) } +var ErrDirectMediaNotEnabled = errors.New("direct media is not enabled") + +// DirectMediableNetwork is an optional interface that network connectors can implement to support direct media access. +// +// If the Matrix connector has direct media enabled, SetUseDirectMedia will be called +// before the Start method of the network connector. Download will then be called +// whenever someone wants to download a direct media `mxc://` URI which was generated +// by calling GenerateContentURI on the Matrix connector. +type DirectMediableNetwork interface { + NetworkConnector + SetUseDirectMedia() + Download(ctx context.Context, mediaID networkid.MediaID) (mediaproxy.GetMediaResponse, error) +} + // ConfigValidatingNetwork is an optional interface that network connectors can implement to validate config fields // before the bridge is started. // diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go new file mode 100644 index 00000000..702910fa --- /dev/null +++ b/mediaproxy/mediaproxy.go @@ -0,0 +1,388 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mediaproxy + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "mime" + "mime/multipart" + "net" + "net/http" + "net/textproto" + "strconv" + "strings" + "time" + + "github.com/gorilla/mux" + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/federation" +) + +type GetMediaResponse interface { + isGetMediaResponse() +} + +func (*GetMediaResponseURL) isGetMediaResponse() {} +func (*GetMediaResponseData) isGetMediaResponse() {} + +type GetMediaResponseURL struct { + URL string + ExpiresAt time.Time +} + +type GetMediaResponseData struct { + Reader io.ReadCloser + ContentType string + ContentLength int64 +} + +type GetMediaFunc = func(ctx context.Context, mediaID string) (response GetMediaResponse, err error) + +type MediaProxy struct { + KeyServer *federation.KeyServer + ProxyClient *http.Client + + GetMedia GetMediaFunc + PrepareProxyRequest func(*http.Request) + + serverName string + serverKey *federation.SigningKey + + FederationRouter *mux.Router + LegacyMediaRouter *mux.Router + ClientMediaRouter *mux.Router +} + +func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProxy, error) { + parsed, err := federation.ParseSynapseKey(serverKey) + if err != nil { + return nil, err + } + return &MediaProxy{ + serverName: serverName, + serverKey: parsed, + GetMedia: getMedia, + ProxyClient: &http.Client{ + Transport: &http.Transport{ + DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext, + TLSHandshakeTimeout: 10 * time.Second, + ForceAttemptHTTP2: false, + }, + Timeout: 60 * time.Second, + }, + KeyServer: &federation.KeyServer{ + KeyProvider: &federation.StaticServerKey{ + ServerName: serverName, + Key: parsed, + }, + WellKnownTarget: fmt.Sprintf("%s:443", serverName), + Version: federation.ServerVersion{ + Name: "mautrix-go media proxy", + Version: mautrix.Version, + }, + }, + }, nil +} + +func (mp *MediaProxy) GetServerName() string { + return mp.serverName +} + +func (mp *MediaProxy) GetServerKey() *federation.SigningKey { + return mp.serverKey +} + +func (mp *MediaProxy) DisallowProxying() { + mp.ProxyClient = nil +} + +func (mp *MediaProxy) RegisterRoutes(router *mux.Router) { + if mp.FederationRouter == nil { + mp.FederationRouter = router.PathPrefix("/_matrix/federation").Subrouter() + } + if mp.LegacyMediaRouter == nil { + mp.LegacyMediaRouter = router.PathPrefix("/_matrix/media").Subrouter() + } + if mp.ClientMediaRouter == nil { + mp.ClientMediaRouter = router.PathPrefix("/_matrix/client/v1/media").Subrouter() + } + + mp.FederationRouter.HandleFunc("/v1/media/download/{mediaID}", mp.DownloadMediaFederation).Methods(http.MethodGet) + mp.FederationRouter.HandleFunc("/v1/version", mp.KeyServer.GetServerVersion).Methods(http.MethodGet) + addClientRoutes := func(router *mux.Router, prefix string) { + router.HandleFunc(prefix+"/download/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet) + router.HandleFunc(prefix+"/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia).Methods(http.MethodGet) + router.HandleFunc(prefix+"/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet) + router.HandleFunc(prefix+"/upload/{serverName}/{mediaID}", mp.UploadNotSupported).Methods(http.MethodPut) + router.HandleFunc(prefix+"/upload", mp.UploadNotSupported).Methods(http.MethodPost) + router.HandleFunc(prefix+"/create", mp.UploadNotSupported).Methods(http.MethodPost) + router.HandleFunc(prefix+"/config", mp.UploadNotSupported).Methods(http.MethodGet) + router.HandleFunc(prefix+"/preview_url", mp.PreviewURLNotSupported).Methods(http.MethodGet) + } + addClientRoutes(mp.LegacyMediaRouter, "/v3") + addClientRoutes(mp.LegacyMediaRouter, "/r0") + addClientRoutes(mp.LegacyMediaRouter, "/v1") + addClientRoutes(mp.ClientMediaRouter, "") + mp.LegacyMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) + mp.LegacyMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod) + mp.FederationRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) + mp.FederationRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod) + mp.ClientMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) + mp.ClientMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod) + corsMiddleware := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization") + w.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none'; plugin-types application/pdf; style-src 'unsafe-inline'; object-src 'self';") + next.ServeHTTP(w, r) + }) + } + mp.LegacyMediaRouter.Use(corsMiddleware) + mp.ClientMediaRouter.Use(corsMiddleware) + mp.KeyServer.Register(router) +} + +func (mp *MediaProxy) proxyDownload(ctx context.Context, w http.ResponseWriter, url, fileName string) { + log := zerolog.Ctx(ctx) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + log.Err(err).Str("url", url).Msg("Failed to create proxy request") + jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ + ErrCode: "M_UNKNOWN", + Err: "Failed to create proxy request", + }) + return + } + if mp.PrepareProxyRequest != nil { + mp.PrepareProxyRequest(req) + } + resp, err := mp.ProxyClient.Do(req) + defer func() { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + }() + if err != nil { + log.Err(err).Str("url", url).Msg("Failed to proxy download") + jsonResponse(w, http.StatusServiceUnavailable, &mautrix.RespError{ + ErrCode: "M_UNKNOWN", + Err: "Failed to proxy download", + }) + return + } else if resp.StatusCode != http.StatusOK { + log.Warn().Str("url", url).Int("status", resp.StatusCode).Msg("Unexpected status code proxying download") + jsonResponse(w, resp.StatusCode, &mautrix.RespError{ + ErrCode: "M_UNKNOWN", + Err: "Unexpected status code proxying download", + }) + return + } + w.Header()["Content-Type"] = resp.Header["Content-Type"] + w.Header()["Content-Length"] = resp.Header["Content-Length"] + w.Header()["Last-Modified"] = resp.Header["Last-Modified"] + w.Header()["Cache-Control"] = resp.Header["Cache-Control"] + contentDisposition := "attachment" + switch resp.Header.Get("Content-Type") { + case "text/css", "text/plain", "text/csv", "application/json", "application/ld+json", "image/jpeg", "image/gif", + "image/png", "image/apng", "image/webp", "image/avif", "video/mp4", "video/webm", "video/ogg", "video/quicktime", + "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", "audio/wav", "audio/x-wav", + "audio/x-pn-wav", "audio/flac", "audio/x-flac", "application/pdf": + contentDisposition = "inline" + } + if fileName != "" { + contentDisposition = mime.FormatMediaType(contentDisposition, map[string]string{ + "filename": fileName, + }) + } + w.Header().Set("Content-Disposition", contentDisposition) + w.WriteHeader(http.StatusOK) + _, err = io.Copy(w, resp.Body) + if err != nil { + log.Debug().Err(err).Msg("Failed to write proxy response") + } +} + +type ResponseError struct { + Status int + Data any +} + +func (err *ResponseError) Error() string { + return fmt.Sprintf("HTTP %d: %v", err.Status, err.Data) +} + +var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax") + +func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse { + mediaID := mux.Vars(r)["mediaID"] + resp, err := mp.GetMedia(r.Context(), mediaID) + if err != nil { + var respError *ResponseError + if errors.Is(err, ErrInvalidMediaIDSyntax) { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MNotFound.ErrCode, + Err: fmt.Sprintf("This is a media proxy at %q, other media downloads are not available here", mp.serverName), + }) + } else if errors.As(err, &respError) { + jsonResponse(w, respError.Status, respError.Data) + } else { + zerolog.Ctx(r.Context()).Err(err).Str("media_id", mediaID).Msg("Failed to get media URL") + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MNotFound.ErrCode, + Err: "Media not found", + }) + } + return nil + } + return resp +} + +func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + log := zerolog.Ctx(ctx) + // TODO check destination header in X-Matrix auth + + resp := mp.getMedia(w, r) + if resp == nil { + return + } + + mpw := multipart.NewWriter(w) + w.Header().Set("Content-Type", strings.Replace(mpw.FormDataContentType(), "form-data", "mixed", 1)) + w.WriteHeader(http.StatusOK) + metaPart, err := mpw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {"application/json"}, + }) + if err != nil { + log.Err(err).Msg("Failed to create multipart metadata field") + return + } + _, err = metaPart.Write([]byte(`{}`)) + if err != nil { + log.Err(err).Msg("Failed to write multipart metadata field") + return + } + if urlResp, ok := resp.(*GetMediaResponseURL); ok { + _, err = mpw.CreatePart(textproto.MIMEHeader{ + "Location": {urlResp.URL}, + }) + if err != nil { + log.Err(err).Msg("Failed to create multipart redirect field") + return + } + } else if dataResp, ok := resp.(*GetMediaResponseData); ok { + dataPart, err := mpw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {dataResp.ContentType}, + }) + if err != nil { + log.Err(err).Msg("Failed to create multipart data field") + return + } + _, err = io.Copy(dataPart, dataResp.Reader) + if err != nil { + log.Err(err).Msg("Failed to write multipart data field") + return + } + } else { + panic("unknown GetMediaResponse type") + } + err = mpw.Close() + if err != nil { + log.Err(err).Msg("Failed to close multipart writer") + return + } +} + +func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + log := zerolog.Ctx(ctx) + vars := mux.Vars(r) + if vars["serverName"] != mp.serverName { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MNotFound.ErrCode, + Err: fmt.Sprintf("This is a media proxy at %q, other media downloads are not available here", mp.serverName), + }) + return + } + resp := mp.getMedia(w, r) + if resp == nil { + return + } + + if urlResp, ok := resp.(*GetMediaResponseURL); ok { + // Proxy if the config allows proxying and the request doesn't allow redirects. + // In any other case, redirect to the URL. + if mp.ProxyClient != nil && r.URL.Query().Get("allow_redirect") != "true" { + mp.proxyDownload(ctx, w, urlResp.URL, vars["fileName"]) + return + } + w.Header().Set("Location", urlResp.URL) + expirySeconds := (time.Until(urlResp.ExpiresAt) - 5*time.Minute).Seconds() + if urlResp.ExpiresAt.IsZero() { + w.Header().Set("Cache-Control", "public, max-age=31536000, immutable") + } else if expirySeconds > 0 { + cacheControl := fmt.Sprintf("public, max-age=%d, immutable", int(expirySeconds)) + w.Header().Set("Cache-Control", cacheControl) + } else { + w.Header().Set("Cache-Control", "no-store") + } + w.WriteHeader(http.StatusTemporaryRedirect) + } else if dataResp, ok := resp.(*GetMediaResponseData); ok { + w.Header().Set("Content-Type", dataResp.ContentType) + if dataResp.ContentLength != 0 { + w.Header().Set("Content-Length", strconv.FormatInt(dataResp.ContentLength, 10)) + } + w.WriteHeader(http.StatusOK) + _, err := io.Copy(w, dataResp.Reader) + if err != nil { + log.Err(err).Msg("Failed to write media data") + } + } else { + panic("unknown GetMediaResponse type") + } +} + +func jsonResponse(w http.ResponseWriter, status int, response interface{}) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(response) +} + +func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) { + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "This is a media proxy and does not support media uploads.", + }) +} + +func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) { + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "This is a media proxy and does not support URL previews.", + }) +} + +func (mp *MediaProxy) UnknownEndpoint(w http.ResponseWriter, r *http.Request) { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "Unrecognized endpoint", + }) +} + +func (mp *MediaProxy) UnsupportedMethod(w http.ResponseWriter, r *http.Request) { + jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{ + ErrCode: mautrix.MUnrecognized.ErrCode, + Err: "Invalid method for endpoint", + }) +} From 3e302fb46fdbfd86fded1fb040dcf813088e34c2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Jun 2024 11:46:03 +0300 Subject: [PATCH 0310/1647] mediaproxy: add basic config structs --- bridgev2/bridgeconfig/config.go | 9 ++++----- bridgev2/matrix/directmedia.go | 8 +------- mediaproxy/mediaproxy.go | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 12 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 448c2f89..d7ef575b 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -10,6 +10,8 @@ import ( "go.mau.fi/util/dbutil" "go.mau.fi/zeroconfig" "gopkg.in/yaml.v3" + + "maunium.net/go/mautrix/mediaproxy" ) type Config struct { @@ -48,11 +50,8 @@ type ProvisioningConfig struct { } type DirectMediaConfig struct { - Enabled bool `yaml:"enabled"` - AllowProxy bool `yaml:"allow_proxy"` - ServerName string `yaml:"server_name"` - WellKnownResponse string `yaml:"well_known_response"` - ServerKey string `yaml:"server_key"` + Enabled bool `yaml:"enabled"` + mediaproxy.BasicConfig `yaml:",inline"` } type DoublePuppetConfig struct { diff --git a/bridgev2/matrix/directmedia.go b/bridgev2/matrix/directmedia.go index e4090ea0..58b461bb 100644 --- a/bridgev2/matrix/directmedia.go +++ b/bridgev2/matrix/directmedia.go @@ -35,16 +35,10 @@ func (br *Connector) initDirectMedia() error { return fmt.Errorf("direct media is enabled in config, but the network connector does not support it") } var err error - br.MediaProxy, err = mediaproxy.New(br.Config.DirectMedia.ServerName, br.Config.DirectMedia.ServerKey, br.getDirectMedia) + br.MediaProxy, err = mediaproxy.NewFromConfig(br.Config.DirectMedia.BasicConfig, br.getDirectMedia) if err != nil { return fmt.Errorf("failed to initialize media proxy: %w", err) } - if br.Config.DirectMedia.WellKnownResponse != "" { - br.MediaProxy.KeyServer.WellKnownTarget = br.Config.DirectMedia.WellKnownResponse - } - if !br.Config.DirectMedia.AllowProxy { - br.MediaProxy.DisallowProxying() - } br.MediaProxy.RegisterRoutes(br.AS.Router) br.dmaSigKey = sha256.Sum256(br.MediaProxy.GetServerKey().Priv.Seed()) dmn.SetUseDirectMedia() diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index 702910fa..f56bad5b 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -94,6 +94,38 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx }, nil } +type BasicConfig struct { + ServerName string `yaml:"server_name" json:"server_name"` + ServerKey string `yaml:"server_key" json:"server_key"` + AllowProxy bool `yaml:"allow_proxy" json:"allow_proxy"` + WellKnownResponse string `yaml:"well_known_response" json:"well_known_response"` +} + +func NewFromConfig(cfg BasicConfig, getMedia GetMediaFunc) (*MediaProxy, error) { + mp, err := New(cfg.ServerName, cfg.ServerKey, getMedia) + if err != nil { + return nil, err + } + if !cfg.AllowProxy { + mp.DisallowProxying() + } + if cfg.WellKnownResponse != "" { + mp.KeyServer.WellKnownTarget = cfg.WellKnownResponse + } + return mp, nil +} + +type ServerConfig struct { + Hostname string `yaml:"hostname" json:"hostname"` + Port uint16 `yaml:"port" json:"port"` +} + +func (mp *MediaProxy) Listen(cfg ServerConfig) error { + router := mux.NewRouter() + mp.RegisterRoutes(router) + return http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router) +} + func (mp *MediaProxy) GetServerName() string { return mp.serverName } From bd2c40e815bdbc757814a5e1483db24b9c53ecf7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Jun 2024 12:19:41 +0300 Subject: [PATCH 0311/1647] mediaproxy: adjust default /version response --- mediaproxy/mediaproxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index f56bad5b..c17be6ed 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -88,7 +88,7 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx WellKnownTarget: fmt.Sprintf("%s:443", serverName), Version: federation.ServerVersion{ Name: "mautrix-go media proxy", - Version: mautrix.Version, + Version: strings.TrimPrefix(mautrix.VersionWithCommit, "v"), }, }, }, nil From 4516583742808ea467ee3306262590698a25ae4a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Jun 2024 12:28:12 +0300 Subject: [PATCH 0312/1647] mediaproxy: add option to force proxying legacy federated downloads --- mediaproxy/mediaproxy.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index c17be6ed..f2591428 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -52,6 +52,8 @@ type MediaProxy struct { KeyServer *federation.KeyServer ProxyClient *http.Client + ForceProxyLegacyFederation bool + GetMedia GetMediaFunc PrepareProxyRequest func(*http.Request) @@ -196,6 +198,7 @@ func (mp *MediaProxy) proxyDownload(ctx context.Context, w http.ResponseWriter, }) return } + req.Header.Set("User-Agent", mautrix.DefaultUserAgent+" (media proxy)") if mp.PrepareProxyRequest != nil { mp.PrepareProxyRequest(req) } @@ -355,7 +358,8 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { if urlResp, ok := resp.(*GetMediaResponseURL); ok { // Proxy if the config allows proxying and the request doesn't allow redirects. // In any other case, redirect to the URL. - if mp.ProxyClient != nil && r.URL.Query().Get("allow_redirect") != "true" { + isFederated := strings.HasPrefix(r.Header.Get("Authorization"), "X-Matrix") + if mp.ProxyClient != nil && (r.URL.Query().Get("allow_redirect") != "true" || (mp.ForceProxyLegacyFederation && isFederated)) { mp.proxyDownload(ctx, w, urlResp.URL, vars["fileName"]) return } From 2e23f6372f661a0d6ad9f56e2a91dad28d9f115d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Jun 2024 15:35:44 +0300 Subject: [PATCH 0313/1647] bridgev2: add support for typing notifications --- bridgev2/matrix/intent.go | 11 +++++ bridgev2/matrixinterface.go | 1 + bridgev2/networkinterface.go | 20 +++++++- bridgev2/portal.go | 96 +++++++++++++++++++++++++++++++----- 4 files changed, 114 insertions(+), 14 deletions(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index fbe1e261..877d5ee8 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -82,6 +82,17 @@ func (as *ASIntent) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.E return as.Matrix.SetReadMarkers(ctx, roomID, &req) } +func (as *ASIntent) MarkTyping(ctx context.Context, roomID id.RoomID, typingType bridgev2.TypingType, timeout time.Duration) error { + if typingType != bridgev2.TypingTypeText { + return nil + } else if as.Matrix.IsCustomPuppet { + // Don't send double puppeted typing notifications, there's no good way to prevent echoing them + return nil + } + _, err := as.Matrix.UserTyping(ctx, roomID, timeout > 0, timeout) + return err +} + func (as *ASIntent) DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error) { if file != nil { uri = file.URL diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index abbaa382..41799f3a 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -46,6 +46,7 @@ type MatrixAPI interface { SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) SendState(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, ts time.Time) error + MarkTyping(ctx context.Context, roomID id.RoomID, typingType TypingType, timeout time.Duration) error DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error) UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index ca02d12a..267be890 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -177,6 +177,7 @@ type NetworkAPI interface { HandleMatrixReactionRemove(ctx context.Context, msg *MatrixReactionRemove) error HandleMatrixMessageRemove(ctx context.Context, msg *MatrixMessageRemove) error HandleMatrixReadReceipt(ctx context.Context, msg *MatrixReadReceipt) error + HandleMatrixTyping(ctx context.Context, msg *MatrixTyping) error } type PushType int @@ -329,6 +330,19 @@ type RemoteTyping interface { GetTimeout() time.Duration } +type TypingType int + +const ( + TypingTypeText TypingType = iota + TypingTypeUploadingMedia + TypingTypeRecordingMedia +) + +type RemoteTypingWithType interface { + RemoteTyping + GetTypingType() TypingType +} + // SimpleRemoteEvent is a simple implementation of RemoteEvent that can be used with struct fields and some callbacks. type SimpleRemoteEvent[T any] struct { Type RemoteEventType @@ -483,4 +497,8 @@ type MatrixReadReceipt struct { Receipt event.ReadReceipt } -type MatrixTyping struct{} +type MatrixTyping struct { + Portal *Portal + IsTyping bool + Type TypingType +} diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 213993f2..15cd7018 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -50,8 +50,9 @@ type Portal struct { Parent *Portal Relay *UserLogin - currentlyTyping []id.UserID - currentlyTypingLock sync.Mutex + currentlyTyping []id.UserID + currentlyTypingLogins map[id.UserID]*UserLogin + currentlyTypingLock sync.Mutex roomCreateLock sync.Mutex @@ -60,17 +61,17 @@ type Portal struct { const PortalEventBuffer = 64 -func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, queryErr error, id *networkid.PortalKey) (*Portal, error) { +func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, queryErr error, key *networkid.PortalKey) (*Portal, error) { if queryErr != nil { return nil, fmt.Errorf("failed to query db: %w", queryErr) } if dbPortal == nil { - if id == nil { + if key == nil { return nil, nil } dbPortal = &database.Portal{ BridgeID: br.ID, - PortalKey: *id, + PortalKey: *key, } err := br.DB.Portal.Insert(ctx, dbPortal) if err != nil { @@ -82,6 +83,8 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que Bridge: br, events: make(chan portalEvent, PortalEventBuffer), + + currentlyTypingLogins: make(map[id.UserID]*UserLogin), } br.portalsByKey[portal.PortalKey] = portal if portal.MXID != "" { @@ -369,16 +372,57 @@ func (portal *Portal) handleMatrixTyping(evt *event.Event) { stoppedTyping, startedTyping := exslices.SortedDiff(portal.currentlyTyping, content.UserIDs, func(a, b id.UserID) int { return strings.Compare(string(a), string(b)) }) - for range stoppedTyping { - // TODO send typing stop events - } - for range startedTyping { - // TODO send typing start events - } + ctx := portal.Log.WithContext(context.TODO()) + portal.sendTypings(ctx, stoppedTyping, false) + portal.sendTypings(ctx, startedTyping, true) portal.currentlyTyping = content.UserIDs } +func (portal *Portal) sendTypings(ctx context.Context, userIDs []id.UserID, typing bool) { + for _, userID := range userIDs { + login, ok := portal.currentlyTypingLogins[userID] + if !ok && !typing { + continue + } else if !ok { + user, err := portal.Bridge.GetUserByMXID(ctx, userID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("user_id", userID).Msg("Failed to get user to send typing event") + continue + } else if user == nil { + continue + } + login, _, err = portal.FindPreferredLogin(ctx, user, false) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("user_id", userID).Msg("Failed to get user login to send typing event") + continue + } else if login == nil { + continue + } + portal.currentlyTypingLogins[userID] = login + } + if !typing { + delete(portal.currentlyTypingLogins, userID) + } + err := login.Client.HandleMatrixTyping(ctx, &MatrixTyping{ + Portal: portal, + IsTyping: typing, + Type: TypingTypeText, + }) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("user_id", userID).Msg("Failed to bridge Matrix typing event") + } else { + zerolog.Ctx(ctx).Debug(). + Stringer("user_id", userID). + Bool("typing", typing). + Msg("Sent typing event") + } + } +} + func (portal *Portal) periodicTypingUpdater() { + // TODO actually call this function + log := portal.Log.With().Str("component", "typing updater").Logger() + ctx := log.WithContext(context.Background()) for { // TODO make delay configurable by network connector time.Sleep(5 * time.Second) @@ -387,7 +431,25 @@ func (portal *Portal) periodicTypingUpdater() { portal.currentlyTypingLock.Unlock() continue } - // TODO send typing events + for _, userID := range portal.currentlyTyping { + login, ok := portal.currentlyTypingLogins[userID] + if !ok { + continue + } + err := login.Client.HandleMatrixTyping(ctx, &MatrixTyping{ + Portal: portal, + IsTyping: true, + Type: TypingTypeText, + }) + if err != nil { + log.Err(err).Stringer("user_id", userID).Msg("Failed to repeat Matrix typing event") + } else { + log.Debug(). + Stringer("user_id", userID). + Bool("typing", true). + Msg("Sent repeatedtyping event") + } + } portal.currentlyTypingLock.Unlock() } } @@ -1148,7 +1210,15 @@ func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *U } func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) { - + var typingType TypingType + if typedEvt, ok := evt.(RemoteTypingWithType); ok { + typingType = typedEvt.GetTypingType() + } + intent := portal.getIntentFor(ctx, evt.GetSender(), source) + err := intent.MarkTyping(ctx, portal.MXID, typingType, evt.GetTimeout()) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge typing event") + } } var stateElementFunctionalMembers = event.Type{Class: event.StateEventType, Type: "io.element.functional_members"} From 69e2b42d857adda6cecd4a7e96d8a1a17775cd93 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 19 Jun 2024 08:03:17 -0600 Subject: [PATCH 0314/1647] bridgev2/directmedia: add configurable media ID prefix Signed-off-by: Sumner Evans --- bridgev2/bridgeconfig/config.go | 3 ++- bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/matrix/directmedia.go | 5 +++-- bridgev2/matrix/mxmain/example-config.yaml | 2 ++ 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index d7ef575b..e5ebbc01 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -50,7 +50,8 @@ type ProvisioningConfig struct { } type DirectMediaConfig struct { - Enabled bool `yaml:"enabled"` + Enabled bool `yaml:"enabled"` + MediaIDPrefix string `yaml:"media_id_prefix"` mediaproxy.BasicConfig `yaml:",inline"` } diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index f554973f..560d5381 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -69,6 +69,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "provisioning", "debug_endpoints") helper.Copy(up.Bool, "direct_media", "enabled") + helper.Copy(up.Str|up.Null, "direct_media", "media_id_prefix") helper.Copy(up.Str, "direct_media", "server_name") helper.Copy(up.Str|up.Null, "direct_media", "well_known_response") helper.Copy(up.Bool, "direct_media", "allow_proxy") diff --git a/bridgev2/matrix/directmedia.go b/bridgev2/matrix/directmedia.go index 58b461bb..15af0263 100644 --- a/bridgev2/matrix/directmedia.go +++ b/bridgev2/matrix/directmedia.go @@ -14,6 +14,7 @@ import ( "encoding/base64" "fmt" "net/http" + "strings" "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2" @@ -63,7 +64,7 @@ func (br *Connector) GenerateContentURI(ctx context.Context, mediaID networkid.M copy(buf[len(MediaIDPrefix)+len(mediaID):], truncatedHash) mxc := id.ContentURI{ Homeserver: br.MediaProxy.GetServerName(), - FileID: base64.RawURLEncoding.EncodeToString(buf), + FileID: br.Config.DirectMedia.MediaIDPrefix + base64.RawURLEncoding.EncodeToString(buf), }.CUString() if len(mxc) > ContentURIMaxLength { return "", fmt.Errorf("content URI too long (%d > %d)", len(mxc), ContentURIMaxLength) @@ -72,7 +73,7 @@ func (br *Connector) GenerateContentURI(ctx context.Context, mediaID networkid.M } func (br *Connector) getDirectMedia(ctx context.Context, mediaIDStr string) (response mediaproxy.GetMediaResponse, err error) { - mediaID, err := base64.RawURLEncoding.DecodeString(mediaIDStr) + mediaID, err := base64.RawURLEncoding.DecodeString(strings.TrimPrefix(mediaIDStr, br.Config.DirectMedia.MediaIDPrefix)) if err != nil || !bytes.HasPrefix(mediaID, []byte(MediaIDPrefix)) || len(mediaID) < len(MediaIDPrefix)+MediaIDTruncatedHashLength+1 { return nil, mediaproxy.ErrInvalidMediaIDSyntax } diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 425eba17..8c18934c 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -128,6 +128,8 @@ direct_media: server_name: discord-media.example.com # Optionally a custom .well-known response. This defaults to `server_name:443` well_known_response: + # Optionally specify a custom prefix for the media ID part of the MXC URI. + media_id_prefix: # If the remote network supports media downloads over HTTP, then the bridge will use MSC3860/MSC3916 # media download redirects if the requester supports it. Optionally, you can force redirects # and not allow proxying at all by setting this to false. From c53c0c1860df560505fa8662352ac456c1662034 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Jun 2024 15:46:09 +0300 Subject: [PATCH 0315/1647] bridgev2: fix typo --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 15cd7018..4e32b0bc 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -447,7 +447,7 @@ func (portal *Portal) periodicTypingUpdater() { log.Debug(). Stringer("user_id", userID). Bool("typing", true). - Msg("Sent repeatedtyping event") + Msg("Sent repeated typing event") } } portal.currentlyTypingLock.Unlock() From 28d81a2b60bb3416d7598ee9df9fbc4a696c90a2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Jun 2024 20:55:38 +0300 Subject: [PATCH 0316/1647] bridgev2: implement re-ID'ing portals properly --- bridgev2/database/portal.go | 14 ++++- bridgev2/matrix/intent.go | 35 ++++++++++- bridgev2/matrixinterface.go | 2 +- bridgev2/portal.go | 60 +++++++++++++----- bridgev2/portalreid.go | 121 ++++++++++++++++++++++++++++++++++++ 5 files changed, 211 insertions(+), 21 deletions(-) create mode 100644 bridgev2/portalreid.go diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index ded53177..05a22593 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -89,7 +89,11 @@ const ( name_set=$11, avatar_set=$12, topic_set=$13, in_space=$14, metadata=$15 WHERE bridge_id=$1 AND id=$2 AND receiver=$3 ` - reIDPortalQuery = `UPDATE portal SET id=$3 WHERE bridge_id=$1 AND id=$2` + deletePortalQuery = ` + DELETE FROM portal + WHERE bridge_id=$1 AND id=$2 AND receiver=$3 + ` + reIDPortalQuery = `UPDATE portal SET id=$4, receiver=$5 WHERE bridge_id=$1 AND id=$2 AND receiver=$3` ) func (pq *PortalQuery) GetByID(ctx context.Context, key networkid.PortalKey) (*Portal, error) { @@ -108,8 +112,8 @@ func (pq *PortalQuery) GetChildren(ctx context.Context, parentID networkid.Porta return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentID) } -func (pq *PortalQuery) ReID(ctx context.Context, oldID, newID networkid.PortalID) error { - return pq.Exec(ctx, reIDPortalQuery, pq.BridgeID, oldID, newID) +func (pq *PortalQuery) ReID(ctx context.Context, oldID, newID networkid.PortalKey) error { + return pq.Exec(ctx, reIDPortalQuery, pq.BridgeID, oldID.ID, oldID.Receiver, newID.ID, newID.Receiver) } func (pq *PortalQuery) Insert(ctx context.Context, p *Portal) error { @@ -122,6 +126,10 @@ func (pq *PortalQuery) Update(ctx context.Context, p *Portal) error { return pq.Exec(ctx, updatePortalQuery, p.sqlVariables()...) } +func (pq *PortalQuery) Delete(ctx context.Context, key networkid.PortalKey) error { + return pq.Exec(ctx, deletePortalQuery, pq.BridgeID, key.ID, key.Receiver) +} + func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { var mxid, parentID sql.NullString var avatarHash string diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 877d5ee8..73c11658 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -11,6 +11,8 @@ import ( "fmt" "time" + "github.com/rs/zerolog" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/bridgev2" @@ -216,7 +218,34 @@ func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) return resp.RoomID, nil } -func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID) error { - // TODO implement non-beeper delete - return as.Matrix.BeeperDeleteRoom(ctx, roomID) +func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error { + if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) { + return as.Matrix.BeeperDeleteRoom(ctx, roomID) + } + members, err := as.Matrix.JoinedMembers(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get portal members for cleanup: %w", err) + } + for member := range members.Joined { + if member == as.Matrix.UserID { + continue + } + _, isGhost := as.Connector.ParseGhostMXID(member) + if isGhost { + _, err = as.Connector.AS.Intent(member).LeaveRoom(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("user_id", member).Msg("Failed to leave room while cleaning up portal") + } + } else if !puppetsOnly { + _, err = as.Matrix.KickUser(ctx, roomID, &mautrix.ReqKickUser{UserID: member, Reason: "Deleting portal"}) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("user_id", member).Msg("Failed to kick user while cleaning up portal") + } + } + } + _, err = as.Matrix.LeaveRoom(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to leave room while cleaning up portal") + } + return nil } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 41799f3a..c98404cf 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -55,7 +55,7 @@ type MatrixAPI interface { SetExtraProfileMeta(ctx context.Context, data any) error CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) - DeleteRoom(ctx context.Context, roomID id.RoomID) error + DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error InviteUser(ctx context.Context, roomID id.RoomID, userID id.UserID) error EnsureJoined(ctx context.Context, roomID id.RoomID) error } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 4e32b0bc..ffcd13c5 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -778,7 +778,7 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { if !ok || !mcp.ShouldCreatePortal() { return } - err := portal.CreateMatrixRoom(ctx, source) + err := portal.CreateMatrixRoom(ctx, source, nil) if err != nil { log.Err(err).Msg("Failed to create portal to handle event") // TODO error @@ -1441,7 +1441,7 @@ func (portal *Portal) SyncParticipants(ctx context.Context, members []networkid. return expectedUserIDs, extraFunctionalMembers, nil } -func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, sender *Ghost, ts time.Time) { +func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, source *UserLogin, sender *Ghost, ts time.Time) { changed := false if info.Name != nil { changed = portal.UpdateName(ctx, *info.Name, sender, ts) || changed @@ -1452,12 +1452,13 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, sender * if info.Avatar != nil { changed = portal.UpdateAvatar(ctx, info.Avatar, sender, ts) || changed } - //if info.Members != nil && portal.MXID != "" { - // _, err := portal.SyncParticipants(ctx, info.Members, source) - // if err != nil { - // zerolog.Ctx(ctx).Err(err).Msg("Failed to sync room members") - // } - //} + if info.Members != nil && portal.MXID != "" && source != nil { + _, _, err := portal.SyncParticipants(ctx, info.Members, source) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to sync room members") + } + // TODO detect changes to functional members list? + } if changed { portal.UpdateBridgeInfo(ctx) err := portal.Save(ctx) @@ -1467,7 +1468,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, sender * } } -func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) error { +func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, info *PortalInfo) error { portal.roomCreateLock.Lock() defer portal.roomCreateLock.Unlock() if portal.MXID != "" { @@ -1479,12 +1480,15 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) e ctx = log.WithContext(ctx) log.Info().Msg("Creating Matrix room") - info, err := source.Client.GetChatInfo(ctx, portal) - if err != nil { - log.Err(err).Msg("Failed to update portal info for creation") - return err + var err error + if info == nil { + info, err = source.Client.GetChatInfo(ctx, portal) + if err != nil { + log.Err(err).Msg("Failed to update portal info for creation") + return err + } } - portal.UpdateInfo(ctx, info, nil, time.Time{}) + portal.UpdateInfo(ctx, info, source, nil, time.Time{}) initialMembers, extraFunctionalMembers, err := portal.SyncParticipants(ctx, info.Members, source) if err != nil { log.Err(err).Msg("Failed to process participant list for portal creation") @@ -1583,6 +1587,34 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin) e return nil } +func (portal *Portal) Delete(ctx context.Context) error { + err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) + if err != nil { + return err + } + portal.Bridge.cacheLock.Lock() + defer portal.Bridge.cacheLock.Unlock() + portal.unlockedDeleteCache() + return nil +} + +func (portal *Portal) unlockedDelete(ctx context.Context) error { + // TODO delete child portals? + err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) + if err != nil { + return err + } + portal.unlockedDeleteCache() + return nil +} + +func (portal *Portal) unlockedDeleteCache() { + delete(portal.Bridge.portalsByKey, portal.PortalKey) + if portal.MXID != "" { + delete(portal.Bridge.portalsByMXID, portal.MXID) + } +} + func (portal *Portal) Save(ctx context.Context) error { return portal.Bridge.DB.Portal.Update(ctx, portal.Portal) } diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go new file mode 100644 index 00000000..fea818dd --- /dev/null +++ b/bridgev2/portalreid.go @@ -0,0 +1,121 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "fmt" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" +) + +type ReIDResult int + +const ( + ReIDResultError ReIDResult = iota + ReIDResultNoOp + ReIDResultSourceDeleted + ReIDResultSourceReIDd + ReIDResultTargetDeletedAndSourceReIDd + ReIDResultSourceTombstonedIntoTarget +) + +func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.PortalKey) (ReIDResult, *Portal, error) { + log := zerolog.Ctx(ctx) + log.Debug().Msg("Re-ID'ing portal") + defer func() { + log.Debug().Msg("Finished handling portal re-ID") + }() + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + sourcePortal, err := br.unlockedGetPortalByID(ctx, source, true) + if err != nil { + return ReIDResultError, nil, fmt.Errorf("failed to get source portal: %w", err) + } else if sourcePortal == nil { + log.Debug().Msg("Source portal not found, re-ID is no-op") + return ReIDResultNoOp, nil, nil + } + sourcePortal.roomCreateLock.Lock() + defer sourcePortal.roomCreateLock.Unlock() + if sourcePortal.MXID == "" { + log.Info().Msg("Source portal doesn't have Matrix room, deleting row") + err = sourcePortal.unlockedDelete(ctx) + if err != nil { + return ReIDResultError, nil, fmt.Errorf("failed to delete source portal: %w", err) + } + return ReIDResultSourceDeleted, nil, nil + } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Stringer("source_portal_mxid", sourcePortal.MXID) + }) + targetPortal, err := br.unlockedGetPortalByID(ctx, target, true) + if err != nil { + return ReIDResultError, nil, fmt.Errorf("failed to get target portal: %w", err) + } + if targetPortal == nil { + log.Info().Msg("Target portal doesn't exist, re-ID'ing source portal") + err = sourcePortal.unlockedReID(ctx, target) + if err != nil { + return ReIDResultError, nil, fmt.Errorf("failed to re-ID source portal: %w", err) + } + return ReIDResultSourceReIDd, sourcePortal, nil + } + targetPortal.roomCreateLock.Lock() + defer targetPortal.roomCreateLock.Unlock() + if targetPortal.MXID == "" { + log.Info().Msg("Target portal row exists, but doesn't have a Matrix room. Deleting target portal row and re-ID'ing source portal") + err = targetPortal.unlockedDelete(ctx) + if err != nil { + return ReIDResultError, nil, fmt.Errorf("failed to delete target portal: %w", err) + } + err = sourcePortal.unlockedReID(ctx, target) + if err != nil { + return ReIDResultError, nil, fmt.Errorf("failed to re-ID source portal after deleting target: %w", err) + } + return ReIDResultTargetDeletedAndSourceReIDd, sourcePortal, nil + } else { + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Stringer("target_portal_mxid", targetPortal.MXID) + }) + log.Info().Msg("Both target and source portals have Matrix rooms, tombstoning source portal") + err = sourcePortal.unlockedDelete(ctx) + if err != nil { + return ReIDResultError, nil, fmt.Errorf("failed to delete source portal row: %w", err) + } + go func() { + _, err := br.Bot.SendState(ctx, sourcePortal.MXID, event.StateTombstone, "", &event.Content{ + Parsed: &event.TombstoneEventContent{ + Body: fmt.Sprintf("This room has been merged"), + ReplacementRoom: targetPortal.MXID, + }, + }, time.Now()) + if err != nil { + log.Err(err).Msg("Failed to send tombstone to source portal room") + } + err = br.Bot.DeleteRoom(ctx, sourcePortal.MXID, err == nil) + if err != nil { + log.Err(err).Msg("Failed to delete source portal room") + } + }() + return ReIDResultSourceTombstonedIntoTarget, targetPortal, nil + } +} + +func (portal *Portal) unlockedReID(ctx context.Context, target networkid.PortalKey) error { + err := portal.Bridge.DB.Portal.ReID(ctx, portal.PortalKey, target) + if err != nil { + return err + } + delete(portal.Bridge.portalsByKey, portal.PortalKey) + portal.Bridge.portalsByKey[target] = portal + portal.PortalKey = target + return nil +} From e182928df710e990a251112a02f14f0df0e26773 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Jun 2024 21:32:20 +0300 Subject: [PATCH 0317/1647] bridgev2: fix panic when starting disappearing messages --- bridgev2/disappear.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go index 971a6c39..089c8aef 100644 --- a/bridgev2/disappear.go +++ b/bridgev2/disappear.go @@ -60,12 +60,12 @@ func (dl *DisappearLoop) StartAll(ctx context.Context, roomID id.RoomID) { zerolog.Ctx(ctx).Err(err).Msg("Failed to start disappearing messages") return } + startedMessages = slices.DeleteFunc(startedMessages, func(dm *database.DisappearingMessage) bool { + return dm.DisappearAt.After(dl.NextCheck) + }) slices.SortFunc(startedMessages, func(a, b *database.DisappearingMessage) int { return a.DisappearAt.Compare(b.DisappearAt) }) - slices.DeleteFunc(startedMessages, func(dm *database.DisappearingMessage) bool { - return dm.DisappearAt.After(dl.NextCheck) - }) if len(startedMessages) > 0 { go dl.sleepAndDisappear(ctx, startedMessages...) } From eefa21918360606c8a7ab605e5eb822684316adf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Jun 2024 21:33:02 +0300 Subject: [PATCH 0318/1647] bridgev2: add support for starting DMs --- bridgev2/cmdhelp.go | 1 + bridgev2/cmdlogin.go | 10 --- bridgev2/cmdprocessor.go | 1 + bridgev2/cmdstartchat.go | 128 +++++++++++++++++++++++++++++++++++ bridgev2/login.go | 24 ++++--- bridgev2/networkinterface.go | 31 +++++++++ bridgev2/portal.go | 10 +++ bridgev2/portalreid.go | 3 + bridgev2/user.go | 24 +++++++ 9 files changed, 211 insertions(+), 21 deletions(-) create mode 100644 bridgev2/cmdstartchat.go diff --git a/bridgev2/cmdhelp.go b/bridgev2/cmdhelp.go index 043d487c..80b8e972 100644 --- a/bridgev2/cmdhelp.go +++ b/bridgev2/cmdhelp.go @@ -29,6 +29,7 @@ var ( HelpSectionGeneral = HelpSection{"General", 0} HelpSectionAuth = HelpSection{"Authentication", 10} + HelpSectionChats = HelpSection{"Starting and managing chats", 20} HelpSectionAdmin = HelpSection{"Administration", 50} ) diff --git a/bridgev2/cmdlogin.go b/bridgev2/cmdlogin.go index 76727906..dfa66319 100644 --- a/bridgev2/cmdlogin.go +++ b/bridgev2/cmdlogin.go @@ -300,16 +300,6 @@ var CommandLogout = &FullHandler{ }, } -func (user *User) GetFormattedUserLogins() string { - user.Bridge.cacheLock.Lock() - logins := make([]string, len(user.logins)) - for key, val := range user.logins { - logins = append(logins, fmt.Sprintf("* `%s` (%s)", key, val.Metadata.RemoteName)) - } - user.Bridge.cacheLock.Unlock() - return strings.Join(logins, "\n") -} - func fnLogout(ce *CommandEvent) { if len(ce.Args) == 0 { ce.Reply("Usage: `$cmdprefix logout `\n\nYour logins:\n\n%s", ce.User.GetFormattedUserLogins()) diff --git a/bridgev2/cmdprocessor.go b/bridgev2/cmdprocessor.go index 7b064fda..15604af6 100644 --- a/bridgev2/cmdprocessor.go +++ b/bridgev2/cmdprocessor.go @@ -40,6 +40,7 @@ func NewProcessor(bridge *Bridge) *CommandProcessor { CommandHelp, CommandCancel, CommandRegisterPush, CommandLogin, CommandLogout, CommandSetPreferredLogin, + CommandResolveIdentifier, CommandStartChat, ) return proc } diff --git a/bridgev2/cmdstartchat.go b/bridgev2/cmdstartchat.go new file mode 100644 index 00000000..d8746d25 --- /dev/null +++ b/bridgev2/cmdstartchat.go @@ -0,0 +1,128 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "fmt" + "strings" + "time" + + "golang.org/x/net/html" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +var CommandResolveIdentifier = &FullHandler{ + Func: fnResolveIdentifier, + Name: "resolve-identifier", + Help: HelpMeta{ + Section: HelpSectionChats, + Description: "Check if a given identifier is on the remote network", + Args: "[_login ID_] <_identifier_>", + }, + RequiresLogin: true, +} + +var CommandStartChat = &FullHandler{ + Func: fnResolveIdentifier, + Name: "start-chat", + Help: HelpMeta{ + Section: HelpSectionChats, + Description: "Start a direct chat with the given user", + Args: "[_login ID_] <_identifier_>", + }, + RequiresLogin: true, +} + +func getClientForStartingChat[T IdentifierResolvingNetworkAPI](ce *CommandEvent, thing string) (*UserLogin, T, []string) { + remainingArgs := ce.Args[1:] + login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + if login == nil || login.UserMXID != ce.User.MXID { + remainingArgs = ce.Args + login = ce.User.GetDefaultLogin() + } + api, ok := login.Client.(T) + if !ok { + ce.Reply("This bridge does not support %s", thing) + } + return login, api, remainingArgs +} + +func fnResolveIdentifier(ce *CommandEvent) { + login, api, identifierParts := getClientForStartingChat[IdentifierResolvingNetworkAPI](ce, "resolving identifiers") + if api == nil { + return + } + createChat := ce.Command == "start-chat" + identifier := strings.Join(identifierParts, " ") + resp, err := api.ResolveIdentifier(ce.Ctx, identifier, createChat) + if err != nil { + ce.Reply("Failed to resolve identifier: %v", err) + return + } else if resp == nil { + ce.ReplyAdvanced(fmt.Sprintf("Identifier %s not found", html.EscapeString(identifier)), false, true) + return + } + var targetName string + var targetMXID id.UserID + if resp.Ghost != nil { + if resp.UserInfo != nil { + resp.Ghost.UpdateInfo(ce.Ctx, resp.UserInfo) + } + targetName = resp.Ghost.Name + targetMXID = resp.Ghost.MXID + if !createChat { + ce.Reply("Found `%s` / [%s](%s)", resp.Ghost.ID, resp.Ghost.Name, resp.Ghost.MXID.URI().MatrixToURL()) + } + } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { + targetName = *resp.UserInfo.Name + } + var formattedName string + if targetMXID != "" { + formattedName = fmt.Sprintf("`%s` / [%s](%s)", resp.UserID, targetName, targetMXID.URI().MatrixToURL()) + } else if targetName != "" { + formattedName = fmt.Sprintf("`%s` / %s", resp.UserID, targetName) + } else { + formattedName = fmt.Sprintf("`%s`", resp.UserID) + } + if createChat { + if resp.Chat == nil { + ce.Reply("Interface error: network connector did not return chat for create chat request") + return + } + portal := resp.Chat.Portal + if portal == nil { + portal, err = ce.Bridge.GetPortalByID(ce.Ctx, resp.Chat.PortalID) + if err != nil { + ce.Reply("Failed to get portal: %v", err) + return + } + } + if portal.MXID != "" { + name := portal.Name + if name == "" { + name = portal.MXID.String() + } + portal.UpdateInfo(ce.Ctx, resp.Chat.PortalInfo, login, nil, time.Time{}) + ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL()) + } else { + err = portal.CreateMatrixRoom(ce.Ctx, login, resp.Chat.PortalInfo) + if err != nil { + ce.Reply("Failed to create room: %v", err) + return + } + name := portal.Name + if name == "" { + name = portal.MXID.String() + } + ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL()) + } + } else { + ce.Reply("Found %s", formattedName) + } +} diff --git a/bridgev2/login.go b/bridgev2/login.go index 636619ce..a1dc9e91 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -146,6 +146,18 @@ func isOnlyNumbers(input string) bool { return true } +func CleanPhoneNumber(phone string) (string, error) { + phone = numberCleaner.Replace(phone) + if len(phone) < 2 { + return "", fmt.Errorf("phone number must start with + and contain numbers") + } else if phone[0] != '+' { + return "", fmt.Errorf("phone number must start with +") + } else if !isOnlyNumbers(phone[1:]) { + return "", fmt.Errorf("phone number must only contain numbers") + } + return phone, nil +} + func (f *LoginInputDataField) FillDefaultValidate() { noopValidate := func(input string) (string, error) { return input, nil } if f.Validate != nil { @@ -153,17 +165,7 @@ func (f *LoginInputDataField) FillDefaultValidate() { } switch f.Type { case LoginInputFieldTypePhoneNumber: - f.Validate = func(phone string) (string, error) { - phone = numberCleaner.Replace(phone) - if len(phone) < 2 { - return "", fmt.Errorf("phone number must start with + and contain numbers") - } else if phone[0] != '+' { - return "", fmt.Errorf("phone number must start with +") - } else if !isOnlyNumbers(phone[1:]) { - return "", fmt.Errorf("phone number must only contain numbers") - } - return phone, nil - } + f.Validate = CleanPhoneNumber case LoginInputFieldTypeEmail: f.Validate = func(email string) (string, error) { if !strings.ContainsRune(email, '@') { diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 267be890..06c8196e 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -180,6 +180,37 @@ type NetworkAPI interface { HandleMatrixTyping(ctx context.Context, msg *MatrixTyping) error } +type ResolveIdentifierResponse struct { + Ghost *Ghost + + UserID networkid.UserID + UserInfo *UserInfo + + Chat *CreateChatResponse +} + +type CreateChatResponse struct { + Portal *Portal + + PortalID networkid.PortalKey + PortalInfo *PortalInfo +} + +type IdentifierResolvingNetworkAPI interface { + NetworkAPI + ResolveIdentifier(ctx context.Context, identifier string, createChat bool) (*ResolveIdentifierResponse, error) +} + +type UserSearchingNetworkAPI interface { + IdentifierResolvingNetworkAPI + SearchUsers(ctx context.Context, query string) ([]*ResolveIdentifierResponse, error) +} + +type GroupCreatingNetworkAPI interface { + IdentifierResolvingNetworkAPI + CreateGroup(ctx context.Context, name string, users ...networkid.UserID) (*CreateChatResponse, error) +} + type PushType int func (pt PushType) String() string { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index ffcd13c5..7c2c27dd 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1357,6 +1357,9 @@ func (portal *Portal) SyncParticipants(ctx context.Context, members []networkid. if err != nil { return nil, nil, fmt.Errorf("failed to get user logins in portal: %w", err) } + if !slices.Contains(loginsInPortal, source) { + loginsInPortal = append(loginsInPortal, source) + } expectedUserIDs := make([]id.UserID, 0, len(members)) expectedExtraUsers := make([]id.UserID, 0) expectedIntents := make([]MatrixAPI, len(members)) @@ -1459,6 +1462,13 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, source * } // TODO detect changes to functional members list? } + if source != nil { + // TODO is this a good place for this call? there's another one in QueueRemoteEvent + err := portal.Bridge.DB.UserPortal.EnsureExists(ctx, source.UserLogin, portal.PortalKey) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to ensure user portal row exists") + } + } if changed { portal.UpdateBridgeInfo(ctx) err := portal.Save(ctx) diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go index fea818dd..0622aef6 100644 --- a/bridgev2/portalreid.go +++ b/bridgev2/portalreid.go @@ -29,6 +29,9 @@ const ( ) func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.PortalKey) (ReIDResult, *Portal, error) { + if source == target { + return ReIDResultError, nil, fmt.Errorf("illegal re-ID call: source and target are the same") + } log := zerolog.Ctx(ctx) log.Debug().Msg("Re-ID'ing portal") defer func() { diff --git a/bridgev2/user.go b/bridgev2/user.go index bf8eaf13..86268ec1 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -9,10 +9,13 @@ package bridgev2 import ( "context" "fmt" + "strings" "sync" "sync/atomic" "github.com/rs/zerolog" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -149,6 +152,27 @@ func (user *User) DoublePuppet(ctx context.Context) MatrixAPI { return intent } +func (user *User) GetFormattedUserLogins() string { + user.Bridge.cacheLock.Lock() + logins := make([]string, len(user.logins)) + for key, val := range user.logins { + logins = append(logins, fmt.Sprintf("* `%s` (%s)", key, val.Metadata.RemoteName)) + } + user.Bridge.cacheLock.Unlock() + return strings.Join(logins, "\n") +} + +func (user *User) GetDefaultLogin() *UserLogin { + user.Bridge.cacheLock.Lock() + defer user.Bridge.cacheLock.Unlock() + if len(user.logins) == 0 { + return nil + } + loginKeys := maps.Keys(user.logins) + slices.Sort(loginKeys) + return user.logins[loginKeys[0]] +} + func (user *User) Save(ctx context.Context) error { return user.Bridge.DB.User.Update(ctx, user.User) } From 8134d17a028d62457e3dd66e0af96a9545d0e359 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Jun 2024 22:35:50 +0300 Subject: [PATCH 0319/1647] bridgev2/disappear: fill disappear at timer for after_send type if it's unset --- bridgev2/portal.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 7c2c27dd..26cd30e8 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -916,6 +916,9 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin log.Err(err).Str("part_id", string(part.ID)).Msg("Failed to save message part to database") } if converted.Disappear.Type != database.DisappearingTypeNone { + if converted.Disappear.Type == database.DisappearingTypeAfterSend && converted.Disappear.DisappearAt.IsZero() { + converted.Disappear.DisappearAt = dbMessage.Timestamp.Add(converted.Disappear.Timer) + } go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ RoomID: portal.MXID, EventID: dbMessage.MXID, From 2a7a5070fb32cf6440cd801e9e012431fae7077a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Jun 2024 22:58:36 +0300 Subject: [PATCH 0320/1647] bridgev2/matrix: add provisioning API for starting DMs --- bridgev2/matrix/provisioning.go | 124 ++++++++++++++++++++++++++++++-- 1 file changed, 117 insertions(+), 7 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 864edbfe..a0b06dfe 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -23,6 +23,7 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/id" ) @@ -54,7 +55,8 @@ type provisioningContextKey int const ( provisioningUserKey provisioningContextKey = iota - provisioningLoginKey + provisioningUserLoginKey + provisioningLoginProcessKey ) func (prov *ProvisioningAPI) Init() { @@ -68,8 +70,11 @@ func (prov *ProvisioningAPI) Init() { router.Use(prov.AuthMiddleware) router.Path("/v3/login/flows").Methods(http.MethodGet).HandlerFunc(prov.GetLoginFlows) router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginStart) - router.Path("/v3/login/step/{loginID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginSubmitInput) - router.Path("/v3/login/step/{loginID}/{stepID}/{stepType:wait}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginWait) + router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginSubmitInput) + router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:wait}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginWait) + router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet).HandlerFunc(prov.GetResolveIdentifier) + router.Path("/v3/create_dm").Methods(http.MethodPost).HandlerFunc(prov.PostCreateDM) + router.Path("/v3/create_group").Methods(http.MethodPost).HandlerFunc(prov.PostCreateGroup) if prov.br.Config.Provisioning.DebugEndpoints { prov.log.Debug().Msg("Enabling debug API at /debug") @@ -133,7 +138,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { // TODO handle user being nil? ctx := context.WithValue(r.Context(), provisioningUserKey, user) - if loginID, ok := mux.Vars(r)["loginID"]; ok { + if loginID, ok := mux.Vars(r)["loginProcessID"]; ok { prov.loginsLock.RLock() login, ok := prov.logins[loginID] prov.loginsLock.RUnlock() @@ -172,7 +177,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { }) return } - ctx = context.WithValue(r.Context(), provisioningLoginKey, login) + ctx = context.WithValue(r.Context(), provisioningLoginProcessKey, login) } h.ServeHTTP(w, r.WithContext(ctx)) }) @@ -238,7 +243,7 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http }) return } - login := r.Context().Value(provisioningLoginKey).(*ProvLogin) + login := r.Context().Value(provisioningLoginProcessKey).(*ProvLogin) var nextStep *bridgev2.LoginStep switch login.NextStep.Type { case bridgev2.LoginStepTypeUserInput: @@ -261,7 +266,7 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http } func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Request) { - login := r.Context().Value(provisioningLoginKey).(*ProvLogin) + login := r.Context().Value(provisioningLoginProcessKey).(*ProvLogin) nextStep, err := login.Process.(bridgev2.LoginProcessDisplayAndWait).Wait(r.Context()) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input") @@ -274,3 +279,108 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques login.NextStep = nextStep jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } + +func (prov *ProvisioningAPI) getLoginForCall(w http.ResponseWriter, r *http.Request) *bridgev2.UserLogin { + user := r.Context().Value(provisioningUserKey).(*bridgev2.User) + userLogin := prov.br.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(r.URL.Query().Get("login_id"))) + if userLogin == nil || userLogin.UserMXID != user.MXID { + userLogin = user.GetDefaultLogin() + } + if userLogin == nil { + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + Err: "Not logged in", + ErrCode: "FI.MAU.NOT_LOGGED_IN", + }) + return nil + } + return userLogin +} + +type RespResolveIdentifier struct { + ID networkid.UserID `json:"id,omitempty"` + Name string `json:"name,omitempty"` + AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` + MXID id.UserID `json:"mxid,omitempty"` + DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"` +} + +func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.Request, createChat bool) { + login := prov.getLoginForCall(w, r) + if login == nil { + return + } + api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI) + if !ok { + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + Err: "This bridge does not support resolving identifiers", + ErrCode: mautrix.MUnrecognized.ErrCode, + }) + return + } + resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier") + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + Err: fmt.Sprintf("Failed to resolve identifier: %v", err), + ErrCode: "M_UNKNOWN", + }) + } + apiResp := &RespResolveIdentifier{} + status := http.StatusOK + if resp.Ghost != nil { + if resp.UserInfo != nil { + resp.Ghost.UpdateInfo(r.Context(), resp.UserInfo) + } + apiResp.Name = resp.Ghost.Name + apiResp.AvatarURL = resp.Ghost.AvatarMXC + apiResp.MXID = resp.Ghost.MXID + } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { + apiResp.Name = *resp.UserInfo.Name + } + if resp.Chat != nil { + if resp.Chat.Portal == nil { + resp.Chat.Portal, err = prov.br.Bridge.GetPortalByID(r.Context(), resp.Chat.PortalID) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal") + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + Err: "Failed to get portal", + ErrCode: "M_UNKNOWN", + }) + return + } + } + if createChat && resp.Chat.Portal.MXID == "" { + status = http.StatusCreated + err = resp.Chat.Portal.CreateMatrixRoom(r.Context(), login, resp.Chat.PortalInfo) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create portal room") + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + Err: "Failed to create portal room", + ErrCode: "M_UNKNOWN", + }) + return + } + } + apiResp.DMRoomID = resp.Chat.Portal.MXID + } + jsonResponse(w, status, resp) +} + +func (prov *ProvisioningAPI) GetResolveIdentifier(w http.ResponseWriter, r *http.Request) { + prov.doResolveIdentifier(w, r, false) +} + +func (prov *ProvisioningAPI) PostCreateDM(w http.ResponseWriter, r *http.Request) { + prov.doResolveIdentifier(w, r, true) +} + +func (prov *ProvisioningAPI) PostCreateGroup(w http.ResponseWriter, r *http.Request) { + login := prov.getLoginForCall(w, r) + if login == nil { + return + } + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + Err: "Creating groups is not yet implemented", + ErrCode: mautrix.MUnrecognized.ErrCode, + }) +} From fb4aff7608cde6a6dc36e9b7d274ea38c0efa7e7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Jun 2024 23:29:56 +0300 Subject: [PATCH 0321/1647] bridgev2: allow pre-uploaded avatars --- bridgev2/ghost.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index 762856d2..53d457c8 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -95,9 +95,16 @@ type Avatar struct { ID networkid.AvatarID Get func(ctx context.Context) ([]byte, error) Remove bool + + // For pre-uploaded avatars, the MXC URI and hash can be provided directly + MXC id.ContentURIString + Hash [32]byte } func (a *Avatar) Reupload(ctx context.Context, intent MatrixAPI, currentHash [32]byte) (id.ContentURIString, [32]byte, error) { + if a.MXC != "" { + return a.MXC, a.Hash, nil + } data, err := a.Get(ctx) if err != nil { return "", [32]byte{}, err From 557e53b5cd93e9280be2e38276f16a2131e57861 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Jun 2024 23:37:14 +0300 Subject: [PATCH 0322/1647] bridgev2: add bridge info in room create requests --- bridgev2/portal.go | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 26cd30e8..b259095a 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1224,8 +1224,6 @@ func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, } } -var stateElementFunctionalMembers = event.Type{Class: event.StateEventType, Type: "io.element.functional_members"} - type PortalInfo struct { Name *string Topic *string @@ -1286,8 +1284,8 @@ func (portal *Portal) UpdateAvatar(ctx context.Context, avatar *Avatar, sender * func (portal *Portal) GetTopLevelParent() *Portal { // TODO ensure there's no infinite recursion? if portal.Parent == nil { - // TODO only return self if this is a space portal - return portal + // TODO return self if this is a space portal? + return nil } return portal.Parent.GetTopLevelParent() } @@ -1513,7 +1511,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i Name: portal.Name, Topic: portal.Topic, CreationContent: make(map[string]any), - InitialState: make([]*event.Event, 0, 4), + InitialState: make([]*event.Event, 0, 6), Preset: "private_chat", IsDirect: *info.IsDirectChat, PowerLevelOverride: &event.PowerLevelsEventContent{ @@ -1534,13 +1532,23 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i if *info.IsSpace { req.CreationContent["type"] = event.RoomTypeSpace } + bridgeInfoStateKey, bridgeInfo := portal.getBridgeInfo() emptyString := "" + req.InitialState = append(req.InitialState, &event.Event{ StateKey: &emptyString, - Type: stateElementFunctionalMembers, - Content: event.Content{Raw: map[string]any{ - "service_members": append(extraFunctionalMembers, portal.Bridge.Bot.GetMXID()), + Type: event.StateElementFunctionalMembers, + Content: event.Content{Parsed: &event.ElementFunctionalMembersContent{ + FunctionalMembers: append(extraFunctionalMembers, portal.Bridge.Bot.GetMXID()), }}, + }, &event.Event{ + StateKey: &bridgeInfoStateKey, + Type: event.StateHalfShotBridge, + Content: event.Content{Parsed: &bridgeInfo}, + }, &event.Event{ + StateKey: &bridgeInfoStateKey, + Type: event.StateBridge, + Content: event.Content{Parsed: &bridgeInfo}, }) if req.Topic == "" { // Add explicit topic event if topic is empty to ensure the event is set. From 68d8ab6896fde4a7430bdbca6fb37b6155581cb0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Jun 2024 23:41:09 +0300 Subject: [PATCH 0323/1647] changelog: update --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d3b968fe..5b8213ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ ## v0.19.0 (unreleased) +* *(crypto)* Fixed bug with copying `m.relates_to` from wire content to + decrypted content. +* *(bridgev2)* Added more features. +* *(mediaproxy)* Added module for implementing simple media repos that proxy + requests elsewhere. + ### beta.1 (2024-06-16) * *(bridgev2)* Added experimental high-level bridge framework. From 2eb51d35e25f45dff73284b64941e5f8c4df445b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 20 Jun 2024 13:03:55 +0300 Subject: [PATCH 0324/1647] bridgev2: remove duplicate response in resolve-identifier command --- bridgev2/cmdstartchat.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/bridgev2/cmdstartchat.go b/bridgev2/cmdstartchat.go index d8746d25..3d0a3ced 100644 --- a/bridgev2/cmdstartchat.go +++ b/bridgev2/cmdstartchat.go @@ -76,9 +76,6 @@ func fnResolveIdentifier(ce *CommandEvent) { } targetName = resp.Ghost.Name targetMXID = resp.Ghost.MXID - if !createChat { - ce.Reply("Found `%s` / [%s](%s)", resp.Ghost.ID, resp.Ghost.Name, resp.Ghost.MXID.URI().MatrixToURL()) - } } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { targetName = *resp.UserInfo.Name } From 59b99dee7099e5087419955111fcd4e31a4b7ffa Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 20 Jun 2024 13:26:11 +0300 Subject: [PATCH 0325/1647] bridgev2: add remote->matrix room tagging and muting interfaces --- bridgev2/matrix/intent.go | 30 ++++++++++++++++ bridgev2/matrixinterface.go | 3 ++ bridgev2/networkinterface.go | 14 ++++++++ bridgev2/portal.go | 66 ++++++++++++++++++++++++++++++++++++ client.go | 16 ++++----- event/accountdata.go | 42 +++++++++++++++++++++-- 6 files changed, 160 insertions(+), 11 deletions(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 73c11658..42d33fdb 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -8,6 +8,7 @@ package matrix import ( "context" + "errors" "fmt" "time" @@ -19,6 +20,7 @@ import ( "maunium.net/go/mautrix/crypto/attachment" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/pushrules" ) // ASIntent implements the bridge ghost API interface using a real Matrix homeserver as the backend. @@ -249,3 +251,31 @@ func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnl } return nil } + +func (as *ASIntent) TagRoom(ctx context.Context, roomID id.RoomID, tag event.RoomTag, isTagged bool) error { + if isTagged { + return as.Matrix.AddTagWithCustomData(ctx, roomID, tag, &event.TagMetadata{ + MauDoublePuppetSource: as.Connector.AS.DoublePuppetValue, + }) + } else { + if tag == "" { + // TODO clear all tags? + } + return as.Matrix.RemoveTag(ctx, roomID, tag) + } +} + +func (as *ASIntent) MuteRoom(ctx context.Context, roomID id.RoomID, until time.Time) error { + if !until.IsZero() && until.Before(time.Now()) { + err := as.Matrix.DeletePushRule(ctx, "global", pushrules.RoomRule, string(roomID)) + // If the push rule doesn't exist, everything is fine + if errors.Is(err, mautrix.MNotFound) { + err = nil + } + return err + } else { + return as.Matrix.PutPushRule(ctx, "global", pushrules.RoomRule, string(roomID), &mautrix.ReqPutPushRule{ + Actions: []pushrules.PushActionType{pushrules.ActionDontNotify}, + }) + } +} diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index c98404cf..5e78cc67 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -58,4 +58,7 @@ type MatrixAPI interface { DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error InviteUser(ctx context.Context, roomID id.RoomID, userID id.UserID) error EnsureJoined(ctx context.Context, roomID id.RoomID) error + + TagRoom(ctx context.Context, roomID id.RoomID, tag event.RoomTag, isTagged bool) error + MuteRoom(ctx context.Context, roomID id.RoomID, until time.Time) error } diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 06c8196e..c9aae04b 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -287,6 +287,8 @@ const ( RemoteEventReadReceipt RemoteEventDeliveryReceipt RemoteEventTyping + RemoteEventChatTag + RemoteEventChatMute ) // RemoteEvent represents a single event from the remote network, such as a message or a reaction. @@ -374,6 +376,18 @@ type RemoteTypingWithType interface { GetTypingType() TypingType } +type RemoteChatTag interface { + RemoteEvent + GetTag() (tag event.RoomTag, remove bool) +} + +var Unmuted = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + +type RemoteChatMute interface { + RemoteEvent + GetMutedUntil() time.Time +} + // SimpleRemoteEvent is a simple implementation of RemoteEvent that can be used with struct fields and some callbacks. type SimpleRemoteEvent[T any] struct { Type RemoteEventType diff --git a/bridgev2/portal.go b/bridgev2/portal.go index b259095a..b771a801 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -804,6 +804,10 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { portal.handleRemoteDeliveryReceipt(ctx, source, evt.(RemoteReceipt)) case RemoteEventTyping: portal.handleRemoteTyping(ctx, source, evt.(RemoteTyping)) + case RemoteEventChatTag: + portal.handleRemoteChatTag(ctx, source, evt.(RemoteChatTag)) + case RemoteEventChatMute: + portal.handleRemoteChatMute(ctx, source, evt.(RemoteChatMute)) default: log.Warn().Int("type", int(evt.GetType())).Msg("Got remote event with unknown type") } @@ -1224,6 +1228,37 @@ func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, } } +func (portal *Portal) handleRemoteChatTag(ctx context.Context, source *UserLogin, evt RemoteChatTag) { + if !evt.GetSender().IsFromMe { + zerolog.Ctx(ctx).Warn().Msg("Ignoring chat tag event from non-self user") + return + } + dp := source.User.DoublePuppet(ctx) + if dp == nil { + return + } + tag, isTagged := evt.GetTag() + err := dp.TagRoom(ctx, portal.MXID, tag, isTagged) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge chat tag event") + } +} + +func (portal *Portal) handleRemoteChatMute(ctx context.Context, source *UserLogin, evt RemoteChatMute) { + if !evt.GetSender().IsFromMe { + zerolog.Ctx(ctx).Warn().Msg("Ignoring chat mute event from non-self user") + return + } + dp := source.User.DoublePuppet(ctx) + if dp == nil { + return + } + err := dp.MuteRoom(ctx, portal.MXID, evt.GetMutedUntil()) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge chat mute event") + } +} + type PortalInfo struct { Name *string Topic *string @@ -1233,6 +1268,13 @@ type PortalInfo struct { IsDirectChat *bool IsSpace *bool + + UserLocal *UserLocalPortalInfo +} + +type UserLocalPortalInfo struct { + MutedUntil *time.Time + Tag *event.RoomTag } func (portal *Portal) UpdateName(ctx context.Context, name string, sender *Ghost, ts time.Time) bool { @@ -1445,6 +1487,28 @@ func (portal *Portal) SyncParticipants(ctx context.Context, members []networkid. return expectedUserIDs, extraFunctionalMembers, nil } +func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPortalInfo, source *UserLogin) { + if portal.MXID == "" || info == nil { + return + } + dp := source.User.DoublePuppet(ctx) + if dp == nil { + return + } + if info.MutedUntil != nil { + err := dp.MuteRoom(ctx, portal.MXID, *info.MutedUntil) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to mute room") + } + } + if info.Tag != nil { + err := dp.TagRoom(ctx, portal.MXID, *info.Tag, *info.Tag != "") + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to tag room") + } + } +} + func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, source *UserLogin, sender *Ghost, ts time.Time) { changed := false if info.Name != nil { @@ -1469,6 +1533,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, source * if err != nil { zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to ensure user portal row exists") } + portal.updateUserLocalInfo(ctx, info.UserLocal, source) } if changed { portal.UpdateBridgeInfo(ctx) @@ -1599,6 +1664,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i if portal.Parent != nil { // TODO add m.space.child event } + portal.updateUserLocalInfo(ctx, info.UserLocal, source) if !isBeeper { _, _, err = portal.SyncParticipants(ctx, info.Members, source) if err != nil { diff --git a/client.go b/client.go index 3fb9919e..5b4236c2 100644 --- a/client.go +++ b/client.go @@ -1925,15 +1925,13 @@ func (cli *Client) SetReadMarkers(ctx context.Context, roomID id.RoomID, content return } -func (cli *Client) AddTag(ctx context.Context, roomID id.RoomID, tag string, order float64) error { - var tagData event.Tag - if order == order { - tagData.Order = json.Number(strconv.FormatFloat(order, 'e', -1, 64)) - } - return cli.AddTagWithCustomData(ctx, roomID, tag, tagData) +func (cli *Client) AddTag(ctx context.Context, roomID id.RoomID, tag event.RoomTag, order float64) error { + return cli.AddTagWithCustomData(ctx, roomID, tag, &event.TagMetadata{ + Order: json.Number(strconv.FormatFloat(order, 'e', -1, 64)), + }) } -func (cli *Client) AddTagWithCustomData(ctx context.Context, roomID id.RoomID, tag string, data interface{}) (err error) { +func (cli *Client) AddTagWithCustomData(ctx context.Context, roomID id.RoomID, tag event.RoomTag, data any) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags", tag) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, data, nil) return @@ -1944,13 +1942,13 @@ func (cli *Client) GetTags(ctx context.Context, roomID id.RoomID) (tags event.Ta return } -func (cli *Client) GetTagsWithCustomData(ctx context.Context, roomID id.RoomID, resp interface{}) (err error) { +func (cli *Client) GetTagsWithCustomData(ctx context.Context, roomID id.RoomID, resp any) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags") _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } -func (cli *Client) RemoveTag(ctx context.Context, roomID id.RoomID, tag string) (err error) { +func (cli *Client) RemoveTag(ctx context.Context, roomID id.RoomID, tag event.RoomTag) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags", tag) _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil) return diff --git a/event/accountdata.go b/event/accountdata.go index 6637fcfe..f4b05802 100644 --- a/event/accountdata.go +++ b/event/accountdata.go @@ -8,6 +8,7 @@ package event import ( "encoding/json" + "strings" "maunium.net/go/mautrix/id" ) @@ -18,10 +19,47 @@ type TagEventContent struct { Tags Tags `json:"tags"` } -type Tags map[string]Tag +type Tags map[RoomTag]TagMetadata -type Tag struct { +type RoomTag string + +const ( + RoomTagFavourite RoomTag = "m.favourite" + RoomTagLowPriority RoomTag = "m.lowpriority" + RoomTagServerNotice RoomTag = "m.server_notice" +) + +func (rt RoomTag) IsUserDefined() bool { + return strings.HasPrefix(string(rt), "u.") +} + +func (rt RoomTag) String() string { + return string(rt) +} + +func (rt RoomTag) Name() string { + if rt.IsUserDefined() { + return string(rt[2:]) + } + switch rt { + case RoomTagFavourite: + return "Favourite" + case RoomTagLowPriority: + return "Low priority" + case RoomTagServerNotice: + return "Server notice" + default: + return "" + } +} + +// Deprecated: type alias +type Tag = TagMetadata + +type TagMetadata struct { Order json.Number `json:"order,omitempty"` + + MauDoublePuppetSource string `json:"fi.mau.double_puppet_source,omitempty"` } // DirectChatsEventContent represents the content of a m.direct account data event. From 7c2fdc703d2b9efff7e836e37766a2d9b43cd953 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 20 Jun 2024 15:21:17 +0300 Subject: [PATCH 0326/1647] bridgev2: add remote->matrix mark unread interfaces --- bridgev2/matrix/intent.go | 20 +++++++++++++++++++- bridgev2/matrixinterface.go | 1 + bridgev2/networkinterface.go | 6 ++++++ bridgev2/portal.go | 17 +++++++++++++++++ event/accountdata.go | 4 ++++ event/content.go | 8 ++++++++ event/type.go | 2 ++ 7 files changed, 57 insertions(+), 1 deletion(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 42d33fdb..325287db 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -83,7 +83,25 @@ func (as *ASIntent) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.E req.FullyRead = eventID req.BeeperFullyReadExtra = extraData } - return as.Matrix.SetReadMarkers(ctx, roomID, &req) + err := as.Matrix.SetReadMarkers(ctx, roomID, &req) + if err != nil { + return err + } + if as.Matrix.IsCustomPuppet { + err = as.Matrix.SetRoomAccountData(ctx, roomID, event.AccountDataMarkedUnread.Type, &event.MarkedUnreadEventContent{ + Unread: false, + }) + if err != nil { + return err + } + } + return nil +} + +func (as *ASIntent) MarkUnread(ctx context.Context, roomID id.RoomID, unread bool) error { + return as.Matrix.SetRoomAccountData(ctx, roomID, event.AccountDataMarkedUnread.Type, &event.MarkedUnreadEventContent{ + Unread: unread, + }) } func (as *ASIntent) MarkTyping(ctx context.Context, roomID id.RoomID, typingType bridgev2.TypingType, timeout time.Duration) error { diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 5e78cc67..7ac6002c 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -46,6 +46,7 @@ type MatrixAPI interface { SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) SendState(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, ts time.Time) error + MarkUnread(ctx context.Context, roomID id.RoomID, unread bool) error MarkTyping(ctx context.Context, roomID id.RoomID, typingType TypingType, timeout time.Duration) error DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error) UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index c9aae04b..5c9ac40e 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -286,6 +286,7 @@ const ( RemoteEventMessageRemove RemoteEventReadReceipt RemoteEventDeliveryReceipt + RemoteEventMarkUnread RemoteEventTyping RemoteEventChatTag RemoteEventChatMute @@ -358,6 +359,11 @@ type RemoteReceipt interface { GetReceiptTargets() []networkid.MessageID } +type RemoteMarkUnread interface { + RemoteEvent + GetUnread() bool +} + type RemoteTyping interface { RemoteEvent GetTimeout() time.Duration diff --git a/bridgev2/portal.go b/bridgev2/portal.go index b771a801..b5d19f5a 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -800,6 +800,8 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { portal.handleRemoteMessageRemove(ctx, source, evt.(RemoteMessageRemove)) case RemoteEventReadReceipt: portal.handleRemoteReadReceipt(ctx, source, evt.(RemoteReceipt)) + case RemoteEventMarkUnread: + portal.handleRemoteMarkUnread(ctx, source, evt.(RemoteMarkUnread)) case RemoteEventDeliveryReceipt: portal.handleRemoteDeliveryReceipt(ctx, source, evt.(RemoteReceipt)) case RemoteEventTyping: @@ -1212,6 +1214,21 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL } } +func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLogin, evt RemoteMarkUnread) { + if !evt.GetSender().IsFromMe { + zerolog.Ctx(ctx).Warn().Msg("Ignoring mark unread event from non-self user") + return + } + dp := source.User.DoublePuppet(ctx) + if dp == nil { + return + } + err := dp.MarkUnread(ctx, portal.MXID, evt.GetUnread()) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge mark unread event") + } +} + func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteReceipt) { } diff --git a/event/accountdata.go b/event/accountdata.go index f4b05802..2d37e0bd 100644 --- a/event/accountdata.go +++ b/event/accountdata.go @@ -81,3 +81,7 @@ type IgnoredUserListEventContent struct { type IgnoredUser struct { // This is an empty object } + +type MarkedUnreadEventContent struct { + Unread bool `json:"unread"` +} diff --git a/event/content.go b/event/content.go index e22b6435..c24de56b 100644 --- a/event/content.go +++ b/event/content.go @@ -54,6 +54,7 @@ var TypeMap = map[Type]reflect.Type{ AccountDataDirectChats: reflect.TypeOf(DirectChatsEventContent{}), AccountDataFullyRead: reflect.TypeOf(FullyReadEventContent{}), AccountDataIgnoredUserList: reflect.TypeOf(IgnoredUserListEventContent{}), + AccountDataMarkedUnread: reflect.TypeOf(MarkedUnreadEventContent{}), EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}), EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}), @@ -418,6 +419,13 @@ func (content *Content) AsIgnoredUserList() *IgnoredUserListEventContent { } return casted } +func (content *Content) AsMarkedUnread() *MarkedUnreadEventContent { + casted, ok := content.Parsed.(*MarkedUnreadEventContent) + if !ok { + return &MarkedUnreadEventContent{} + } + return casted +} func (content *Content) AsTyping() *TypingEventContent { casted, ok := content.Parsed.(*TypingEventContent) if !ok { diff --git a/event/type.go b/event/type.go index 56752bc3..e5c6498a 100644 --- a/event/type.go +++ b/event/type.go @@ -117,6 +117,7 @@ func (et *Type) GuessClass() TypeClass { case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type: return EphemeralEventType case AccountDataDirectChats.Type, AccountDataPushRules.Type, AccountDataRoomTags.Type, + AccountDataFullyRead.Type, AccountDataIgnoredUserList.Type, AccountDataMarkedUnread.Type, AccountDataSecretStorageKey.Type, AccountDataSecretStorageDefaultKey.Type, AccountDataCrossSigningMaster.Type, AccountDataCrossSigningSelf.Type, AccountDataCrossSigningUser.Type, AccountDataFullyRead.Type, AccountDataMegolmBackupKey.Type: @@ -240,6 +241,7 @@ var ( AccountDataRoomTags = Type{"m.tag", AccountDataEventType} AccountDataFullyRead = Type{"m.fully_read", AccountDataEventType} AccountDataIgnoredUserList = Type{"m.ignored_user_list", AccountDataEventType} + AccountDataMarkedUnread = Type{"m.marked_unread", AccountDataEventType} AccountDataSecretStorageDefaultKey = Type{"m.secret_storage.default_key", AccountDataEventType} AccountDataSecretStorageKey = Type{"m.secret_storage.key", AccountDataEventType} From 44610ce65ea547eee0e748874074963534738764 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 20 Jun 2024 15:49:41 +0300 Subject: [PATCH 0327/1647] bridgev2/matrix: add logout provisioning API --- bridgev2/database/userlogin.go | 5 +++ bridgev2/matrix/provisioning.go | 58 ++++++++++++++++++++++++++------- bridgev2/userlogin.go | 10 ++++++ 3 files changed, 62 insertions(+), 11 deletions(-) diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go index ea8e5838..77de3122 100644 --- a/bridgev2/database/userlogin.go +++ b/bridgev2/database/userlogin.go @@ -54,6 +54,7 @@ const ( getUserLoginBaseQuery = ` SELECT bridge_id, user_mxid, id, space_room, metadata FROM user_login ` + getLoginByIDQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND id=$2` getAllLoginsQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1` getAllLoginsForUserQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND user_mxid=$2` getAllLoginsInPortalQuery = ` @@ -74,6 +75,10 @@ const ( ` ) +func (uq *UserLoginQuery) GetByID(ctx context.Context, id networkid.UserLoginID) (*UserLogin, error) { + return uq.QueryOne(ctx, getLoginByIDQuery, uq.BridgeID, id) +} + func (uq *UserLoginQuery) GetAll(ctx context.Context) ([]*UserLogin, error) { return uq.QueryMany(ctx, getAllLoginsQuery, uq.BridgeID) } diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index a0b06dfe..1ff9dd06 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -72,6 +72,7 @@ func (prov *ProvisioningAPI) Init() { router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginStart) router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginSubmitInput) router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:wait}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginWait) + router.Path("/v3/logout/{loginID}").Methods(http.MethodPost).HandlerFunc(prov.PostLogout) router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet).HandlerFunc(prov.GetResolveIdentifier) router.Path("/v3/create_dm").Methods(http.MethodPost).HandlerFunc(prov.PostCreateDM) router.Path("/v3/create_group").Methods(http.MethodPost).HandlerFunc(prov.PostCreateGroup) @@ -280,20 +281,55 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } +func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value(provisioningUserKey).(*bridgev2.User) + userLoginID := networkid.UserLoginID(mux.Vars(r)["loginID"]) + if userLoginID == "all" { + for { + login := user.GetDefaultLogin() + if login == nil { + break + } + login.Logout(r.Context()) + } + } else { + userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID) + if userLogin == nil || userLogin.UserMXID != user.MXID { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + Err: "Login not found", + ErrCode: mautrix.MNotFound.ErrCode, + }) + return + } + userLogin.Logout(r.Context()) + } + jsonResponse(w, http.StatusOK, json.RawMessage("{}")) +} + func (prov *ProvisioningAPI) getLoginForCall(w http.ResponseWriter, r *http.Request) *bridgev2.UserLogin { user := r.Context().Value(provisioningUserKey).(*bridgev2.User) - userLogin := prov.br.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(r.URL.Query().Get("login_id"))) - if userLogin == nil || userLogin.UserMXID != user.MXID { - userLogin = user.GetDefaultLogin() + userLoginID := networkid.UserLoginID(r.URL.Query().Get("login_id")) + if userLoginID != "" { + userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID) + if userLogin == nil || userLogin.UserMXID != user.MXID { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + Err: "Login not found", + ErrCode: mautrix.MNotFound.ErrCode, + }) + return nil + } + return userLogin + } else { + userLogin := user.GetDefaultLogin() + if userLogin == nil { + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + Err: "Not logged in", + ErrCode: "FI.MAU.NOT_LOGGED_IN", + }) + return nil + } + return userLogin } - if userLogin == nil { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Not logged in", - ErrCode: "FI.MAU.NOT_LOGGED_IN", - }) - return nil - } - return userLogin } type RespResolveIdentifier struct { diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 5329d464..5eb10ac3 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -98,6 +98,16 @@ func (br *Bridge) GetUserLoginsInPortal(ctx context.Context, portal networkid.Po return br.loadManyUserLogins(ctx, nil, logins) } +func (br *Bridge) GetUserLoginByID(ctx context.Context, id networkid.UserLoginID) (*UserLogin, error) { + login, err := br.DB.UserLogin.GetByID(ctx, id) + if err != nil { + return nil, err + } + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.loadUserLogin(ctx, nil, login) +} + func (br *Bridge) GetCachedUserLoginByID(id networkid.UserLoginID) *UserLogin { br.cacheLock.Lock() defer br.cacheLock.Unlock() From 05320aee0c7a7b8e07ce6d6d740299878e51122f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 20 Jun 2024 16:11:22 +0300 Subject: [PATCH 0328/1647] bridgev2/matrix: add provisioning API to get all login IDs --- bridgev2/matrix/provisioning.go | 10 ++++++++++ bridgev2/user.go | 6 ++++++ 2 files changed, 16 insertions(+) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 1ff9dd06..474df818 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -73,6 +73,7 @@ func (prov *ProvisioningAPI) Init() { router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginSubmitInput) router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:wait}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginWait) router.Path("/v3/logout/{loginID}").Methods(http.MethodPost).HandlerFunc(prov.PostLogout) + router.Path("/v3/logins").Methods(http.MethodGet).HandlerFunc(prov.GetLogins) router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet).HandlerFunc(prov.GetResolveIdentifier) router.Path("/v3/create_dm").Methods(http.MethodPost).HandlerFunc(prov.PostCreateDM) router.Path("/v3/create_group").Methods(http.MethodPost).HandlerFunc(prov.PostCreateGroup) @@ -306,6 +307,15 @@ func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) jsonResponse(w, http.StatusOK, json.RawMessage("{}")) } +type RespGetLogins struct { + LoginIDs []networkid.UserLoginID `json:"login_ids"` +} + +func (prov *ProvisioningAPI) GetLogins(w http.ResponseWriter, r *http.Request) { + user := r.Context().Value(provisioningUserKey).(*bridgev2.User) + jsonResponse(w, http.StatusOK, &RespGetLogins{LoginIDs: user.GetUserLoginIDs()}) +} + func (prov *ProvisioningAPI) getLoginForCall(w http.ResponseWriter, r *http.Request) *bridgev2.UserLogin { user := r.Context().Value(provisioningUserKey).(*bridgev2.User) userLoginID := networkid.UserLoginID(r.URL.Query().Get("login_id")) diff --git a/bridgev2/user.go b/bridgev2/user.go index 86268ec1..500a51c0 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -152,6 +152,12 @@ func (user *User) DoublePuppet(ctx context.Context) MatrixAPI { return intent } +func (user *User) GetUserLoginIDs() []networkid.UserLoginID { + user.Bridge.cacheLock.Lock() + defer user.Bridge.cacheLock.Unlock() + return maps.Keys(user.logins) +} + func (user *User) GetFormattedUserLogins() string { user.Bridge.cacheLock.Lock() logins := make([]string, len(user.logins)) From 7f901a26880a998ce827ddd719730e1dd75c26a7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 20 Jun 2024 16:51:02 +0300 Subject: [PATCH 0329/1647] bridgev2/matrix: fix bugs in provisioning API login --- bridgev2/login.go | 8 ++++---- bridgev2/matrix/provisioning.go | 3 +++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/bridgev2/login.go b/bridgev2/login.go index a1dc9e91..d438b183 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -83,10 +83,10 @@ type LoginStep struct { // Exactly one of the following structs must be filled depending on the step type. - DisplayAndWaitParams *LoginDisplayAndWaitParams `json:"display_and_wait"` - CookiesParams *LoginCookiesParams `json:"cookies"` - UserInputParams *LoginUserInputParams `json:"user_input"` - CompleteParams *LoginCompleteParams `json:"complete"` + DisplayAndWaitParams *LoginDisplayAndWaitParams `json:"display_and_wait,omitempty"` + CookiesParams *LoginCookiesParams `json:"cookies,omitempty"` + UserInputParams *LoginUserInputParams `json:"user_input,omitempty"` + CompleteParams *LoginCompleteParams `json:"complete,omitempty"` } type LoginDisplayAndWaitParams struct { diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 474df818..97bd7f65 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -168,6 +168,9 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return } stepType := mux.Vars(r)["stepType"] + if stepType == "wait" { + stepType = "display_and_wait" + } if login.NextStep.Type != bridgev2.LoginStepType(stepType) { zerolog.Ctx(r.Context()).Warn(). Str("request_step_type", stepType). From 0418273bdbb13281a33e25d114155438436fb17b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 20 Jun 2024 16:51:16 +0300 Subject: [PATCH 0330/1647] bridgev2/database: maybe fix saving custom metadata fields --- bridgev2/database/message.go | 2 +- bridgev2/database/portal.go | 2 +- bridgev2/database/reaction.go | 2 +- bridgev2/database/userlogin.go | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 363f8e74..4ad2b488 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -165,7 +165,7 @@ func (m *Message) sqlVariables() []any { } return []any{ m.BridgeID, m.ID, m.PartID, m.MXID, m.Room.ID, m.Room.Receiver, m.SenderID, - m.Timestamp.UnixNano(), dbutil.NumPtr(m.RelatesToRowID), dbutil.JSON{Data: m.Metadata}, + m.Timestamp.UnixNano(), dbutil.NumPtr(m.RelatesToRowID), dbutil.JSON{Data: &m.Metadata}, } } diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 05a22593..0d578ea4 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -168,6 +168,6 @@ func (p *Portal) sqlVariables() []any { p.BridgeID, p.ID, p.Receiver, dbutil.StrPtr(p.MXID), dbutil.StrPtr(p.ParentID), p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC, p.NameSet, p.TopicSet, p.AvatarSet, p.InSpace, - dbutil.JSON{Data: p.Metadata}, + dbutil.JSON{Data: &p.Metadata}, } } diff --git a/bridgev2/database/reaction.go b/bridgev2/database/reaction.go index d5cb798f..b7d675b7 100644 --- a/bridgev2/database/reaction.go +++ b/bridgev2/database/reaction.go @@ -127,6 +127,6 @@ func (r *Reaction) sqlVariables() []any { } return []any{ r.BridgeID, r.MessageID, r.MessagePartID, r.SenderID, r.EmojiID, - r.Room.ID, r.Room.Receiver, r.MXID, r.Timestamp.UnixNano(), dbutil.JSON{Data: r.Metadata}, + r.Room.ID, r.Room.Receiver, r.MXID, r.Timestamp.UnixNano(), dbutil.JSON{Data: &r.Metadata}, } } diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go index 77de3122..b371483c 100644 --- a/bridgev2/database/userlogin.go +++ b/bridgev2/database/userlogin.go @@ -122,5 +122,5 @@ func (u *UserLogin) sqlVariables() []any { if u.Metadata.Extra == nil { u.Metadata.Extra = make(map[string]any) } - return []any{u.BridgeID, u.UserMXID, u.ID, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: u.Metadata}} + return []any{u.BridgeID, u.UserMXID, u.ID, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: &u.Metadata}} } From 7b6f3ba0541d2e2d4f70305e7b85a30e98e8f111 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 20 Jun 2024 17:28:53 +0300 Subject: [PATCH 0331/1647] bridgev2/matrix: add provisioning API for listing contacts --- bridgev2/matrix/provisioning.go | 94 +++++++++++++++++++++++++++++---- bridgev2/networkinterface.go | 5 ++ 2 files changed, 90 insertions(+), 9 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 97bd7f65..aa339b67 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -74,6 +74,7 @@ func (prov *ProvisioningAPI) Init() { router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:wait}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginWait) router.Path("/v3/logout/{loginID}").Methods(http.MethodPost).HandlerFunc(prov.PostLogout) router.Path("/v3/logins").Methods(http.MethodGet).HandlerFunc(prov.GetLogins) + router.Path("/v3/contacts").Methods(http.MethodGet).HandlerFunc(prov.GetContactList) router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet).HandlerFunc(prov.GetResolveIdentifier) router.Path("/v3/create_dm").Methods(http.MethodPost).HandlerFunc(prov.PostCreateDM) router.Path("/v3/create_group").Methods(http.MethodPost).HandlerFunc(prov.PostCreateGroup) @@ -346,11 +347,12 @@ func (prov *ProvisioningAPI) getLoginForCall(w http.ResponseWriter, r *http.Requ } type RespResolveIdentifier struct { - ID networkid.UserID `json:"id,omitempty"` - Name string `json:"name,omitempty"` - AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` - MXID id.UserID `json:"mxid,omitempty"` - DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"` + ID networkid.UserID `json:"id,omitempty"` + Name string `json:"name,omitempty"` + AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` + Identifiers []string `json:"identifiers,omitempty"` + MXID id.UserID `json:"mxid,omitempty"` + DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"` } func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.Request, createChat bool) { @@ -369,12 +371,14 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier") - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ Err: fmt.Sprintf("Failed to resolve identifier: %v", err), ErrCode: "M_UNKNOWN", }) } - apiResp := &RespResolveIdentifier{} + apiResp := &RespResolveIdentifier{ + ID: resp.UserID, + } status := http.StatusOK if resp.Ghost != nil { if resp.UserInfo != nil { @@ -382,6 +386,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. } apiResp.Name = resp.Ghost.Name apiResp.AvatarURL = resp.Ghost.AvatarMXC + apiResp.Identifiers = resp.Ghost.Metadata.Identifiers apiResp.MXID = resp.Ghost.MXID } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { apiResp.Name = *resp.UserInfo.Name @@ -391,7 +396,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. resp.Chat.Portal, err = prov.br.Bridge.GetPortalByID(r.Context(), resp.Chat.PortalID) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal") - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ Err: "Failed to get portal", ErrCode: "M_UNKNOWN", }) @@ -403,7 +408,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. err = resp.Chat.Portal.CreateMatrixRoom(r.Context(), login, resp.Chat.PortalInfo) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create portal room") - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ Err: "Failed to create portal room", ErrCode: "M_UNKNOWN", }) @@ -415,6 +420,77 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. jsonResponse(w, status, resp) } +type RespGetContactList struct { + Contacts []*RespResolveIdentifier `json:"contacts"` +} + +func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Request) { + login := prov.getLoginForCall(w, r) + if login == nil { + return + } + api, ok := login.Client.(bridgev2.ContactListingNetworkAPI) + if !ok { + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + Err: "This bridge does not support listing contacts", + ErrCode: mautrix.MUnrecognized.ErrCode, + }) + return + } + resp, err := api.GetContactList(r.Context()) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list") + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + Err: fmt.Sprintf("Failed to get contact list: %v", err), + ErrCode: "M_UNKNOWN", + }) + return + } + apiResp := &RespGetContactList{ + Contacts: make([]*RespResolveIdentifier, len(resp)), + } + for i, contact := range resp { + apiContact := &RespResolveIdentifier{ + ID: contact.UserID, + } + fmt.Println(contact.UserInfo.Identifiers) + apiResp.Contacts[i] = apiContact + if contact.UserInfo != nil { + if contact.UserInfo.Name != nil { + apiContact.Name = *contact.UserInfo.Name + } + if contact.UserInfo.Identifiers != nil { + apiContact.Identifiers = contact.UserInfo.Identifiers + } + } + if contact.Ghost != nil { + if contact.Ghost.Name != "" { + apiContact.Name = contact.Ghost.Name + } + if len(contact.Ghost.Metadata.Identifiers) >= len(apiContact.Identifiers) { + apiContact.Identifiers = contact.Ghost.Metadata.Identifiers + } + apiContact.AvatarURL = contact.Ghost.AvatarMXC + apiContact.MXID = contact.Ghost.MXID + } + if contact.Chat != nil { + if contact.Chat.Portal == nil { + contact.Chat.Portal, err = prov.br.Bridge.GetPortalByID(r.Context(), contact.Chat.PortalID) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal") + jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ + Err: "Failed to get portal", + ErrCode: "M_UNKNOWN", + }) + return + } + } + apiContact.DMRoomID = contact.Chat.Portal.MXID + } + } + jsonResponse(w, http.StatusOK, apiResp) +} + func (prov *ProvisioningAPI) GetResolveIdentifier(w http.ResponseWriter, r *http.Request) { prov.doResolveIdentifier(w, r, false) } diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 5c9ac40e..4ec98119 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -201,6 +201,11 @@ type IdentifierResolvingNetworkAPI interface { ResolveIdentifier(ctx context.Context, identifier string, createChat bool) (*ResolveIdentifierResponse, error) } +type ContactListingNetworkAPI interface { + NetworkAPI + GetContactList(ctx context.Context) ([]*ResolveIdentifierResponse, error) +} + type UserSearchingNetworkAPI interface { IdentifierResolvingNetworkAPI SearchUsers(ctx context.Context, query string) ([]*ResolveIdentifierResponse, error) From 8e1fdfda2c1eb822db44e6d2766ebf9633a1515a Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 20 Jun 2024 09:51:01 -0600 Subject: [PATCH 0332/1647] event: add unstable audio and voice fields Signed-off-by: Sumner Evans --- event/audio.go | 8 ++++++++ event/message.go | 3 +++ 2 files changed, 11 insertions(+) create mode 100644 event/audio.go diff --git a/event/audio.go b/event/audio.go new file mode 100644 index 00000000..0fc0818b --- /dev/null +++ b/event/audio.go @@ -0,0 +1,8 @@ +package event + +type MSC1767Audio struct { + Duration int `json:"duration,omitempty"` + Waveform []int `json:"waveform,omitempty"` +} + +type MSC3245Voice struct{} diff --git a/event/message.go b/event/message.go index d8b27c3d..21f5240b 100644 --- a/event/message.go +++ b/event/message.go @@ -118,6 +118,9 @@ type MessageEventContent struct { BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"` BeeperLinkPreviews []*BeeperLinkPreview `json:"com.beeper.linkpreviews,omitempty"` + + MSC1767Audio *MSC1767Audio `json:"org.matrix.msc1767.audio,omitempty"` + MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"` } func (content *MessageEventContent) GetRelatesTo() *RelatesTo { From 7aea403e00e628a028def4a26b8ebfd27e36183b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 21 Jun 2024 12:33:48 +0300 Subject: [PATCH 0333/1647] bridgev2/matrix: expose functions to allow custom provisioning API endpoints --- bridgev2/matrix/mxmain/main.go | 6 +++- bridgev2/matrix/provisioning.go | 55 ++++++++++++++++++++++----------- 2 files changed, 42 insertions(+), 19 deletions(-) diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 3b2bc460..1fde50b8 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -62,7 +62,8 @@ type BridgeMain struct { Version string // PostInit is a function that will be called after the bridge has been initialized but before it is started. - PostInit func() + PostInit func() + PostStart func() // Connector is the network connector for the bridge. Connector bridgev2.NetworkConnector @@ -343,6 +344,9 @@ func (br *BridgeMain) Start() { br.Log.Fatal().Err(err).Msg("Failed to start bridge") } } + if br.PostStart != nil { + br.PostStart() + } } // WaitForInterrupt waits for a SIGINT or SIGTERM signal. diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index aa339b67..60a87469 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -33,6 +33,8 @@ type matrixAuthCacheEntry struct { } type ProvisioningAPI struct { + Router *mux.Router + br *Connector log zerolog.Logger net bridgev2.NetworkConnector @@ -59,25 +61,42 @@ const ( provisioningLoginProcessKey ) +func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User { + return r.Context().Value(provisioningUserKey).(*bridgev2.User) +} + +func (prov *ProvisioningAPI) GetRouter() *mux.Router { + return prov.Router +} + +type IProvisioningAPI interface { + GetRouter() *mux.Router + GetUser(r *http.Request) *bridgev2.User +} + +func (br *Connector) GetProvisioning() IProvisioningAPI { + return br.Provisioning +} + func (prov *ProvisioningAPI) Init() { prov.matrixAuthCache = make(map[string]matrixAuthCacheEntry) prov.logins = make(map[string]*ProvLogin) prov.net = prov.br.Bridge.Network prov.log = prov.br.Log.With().Str("component", "provisioning").Logger() - router := prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter() - router.Use(hlog.NewHandler(prov.log)) - router.Use(requestlog.AccessLogger(false)) - router.Use(prov.AuthMiddleware) - router.Path("/v3/login/flows").Methods(http.MethodGet).HandlerFunc(prov.GetLoginFlows) - router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginStart) - router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginSubmitInput) - router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:wait}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginWait) - router.Path("/v3/logout/{loginID}").Methods(http.MethodPost).HandlerFunc(prov.PostLogout) - router.Path("/v3/logins").Methods(http.MethodGet).HandlerFunc(prov.GetLogins) - router.Path("/v3/contacts").Methods(http.MethodGet).HandlerFunc(prov.GetContactList) - router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet).HandlerFunc(prov.GetResolveIdentifier) - router.Path("/v3/create_dm").Methods(http.MethodPost).HandlerFunc(prov.PostCreateDM) - router.Path("/v3/create_group").Methods(http.MethodPost).HandlerFunc(prov.PostCreateGroup) + prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter() + prov.Router.Use(hlog.NewHandler(prov.log)) + prov.Router.Use(requestlog.AccessLogger(false)) + prov.Router.Use(prov.AuthMiddleware) + prov.Router.Path("/v3/login/flows").Methods(http.MethodGet).HandlerFunc(prov.GetLoginFlows) + prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginStart) + prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginSubmitInput) + prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:wait}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginWait) + prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost).HandlerFunc(prov.PostLogout) + prov.Router.Path("/v3/logins").Methods(http.MethodGet).HandlerFunc(prov.GetLogins) + prov.Router.Path("/v3/contacts").Methods(http.MethodGet).HandlerFunc(prov.GetContactList) + prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet).HandlerFunc(prov.GetResolveIdentifier) + prov.Router.Path("/v3/create_dm").Methods(http.MethodPost).HandlerFunc(prov.PostCreateDM) + prov.Router.Path("/v3/create_group").Methods(http.MethodPost).HandlerFunc(prov.PostCreateGroup) if prov.br.Config.Provisioning.DebugEndpoints { prov.log.Debug().Msg("Enabling debug API at /debug") @@ -207,7 +226,7 @@ func (prov *ProvisioningAPI) GetLoginFlows(w http.ResponseWriter, r *http.Reques func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Request) { login, err := prov.net.CreateLogin( r.Context(), - r.Context().Value(provisioningUserKey).(*bridgev2.User), + prov.GetUser(r), mux.Vars(r)["flowID"], ) if err != nil { @@ -287,7 +306,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques } func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value(provisioningUserKey).(*bridgev2.User) + user := prov.GetUser(r) userLoginID := networkid.UserLoginID(mux.Vars(r)["loginID"]) if userLoginID == "all" { for { @@ -316,12 +335,12 @@ type RespGetLogins struct { } func (prov *ProvisioningAPI) GetLogins(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value(provisioningUserKey).(*bridgev2.User) + user := prov.GetUser(r) jsonResponse(w, http.StatusOK, &RespGetLogins{LoginIDs: user.GetUserLoginIDs()}) } func (prov *ProvisioningAPI) getLoginForCall(w http.ResponseWriter, r *http.Request) *bridgev2.UserLogin { - user := r.Context().Value(provisioningUserKey).(*bridgev2.User) + user := prov.GetUser(r) userLoginID := networkid.UserLoginID(r.URL.Query().Get("login_id")) if userLoginID != "" { userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID) From c746b86741b3caa8a0c76854d814e2aaf36be0f9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 21 Jun 2024 14:07:50 +0300 Subject: [PATCH 0334/1647] bridgev2: add delete method for UserLogin --- bridgev2/bridgestate.go | 4 +++ bridgev2/database/userportal.go | 7 +++++ bridgev2/userlogin.go | 51 +++++++++++++++++++++++++++------ 3 files changed, 53 insertions(+), 9 deletions(-) diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index 961d9e31..578dfee3 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -49,6 +49,10 @@ func (br *Bridge) NewBridgeStateQueue(user status.BridgeStateFiller) *BridgeStat return bsq } +func (bsq *BridgeStateQueue) Destroy() { + close(bsq.ch) +} + func (bsq *BridgeStateQueue) loop() { defer func() { err := recover() diff --git a/bridgev2/database/userportal.go b/bridgev2/database/userportal.go index 2a41fc91..49d0bf5a 100644 --- a/bridgev2/database/userportal.go +++ b/bridgev2/database/userportal.go @@ -48,6 +48,9 @@ const ( WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$3 AND portal_receiver=$4 ORDER BY CASE WHEN preferred THEN 0 ELSE 1 END, login_id ` + getAllPortalsForLoginQuery = getUserPortalBaseQuery + ` + WHERE bridge_id=$1 AND user_mxid=$2 AND login_id=$3 + ` insertUserPortalQuery = ` INSERT INTO user_portal (bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred) VALUES ($1, $2, $3, $4, $5, false, false) @@ -79,6 +82,10 @@ func (upq *UserPortalQuery) GetAllByUser(ctx context.Context, userID id.UserID, return upq.QueryMany(ctx, findUserLoginsByPortalIDQuery, upq.BridgeID, userID, portal.ID, portal.Receiver) } +func (upq *UserPortalQuery) GetAllForLogin(ctx context.Context, login *UserLogin) ([]*UserPortal, error) { + return upq.QueryMany(ctx, getUserPortalQuery, upq.BridgeID, login.UserMXID, login.ID) +} + func (upq *UserPortalQuery) Get(ctx context.Context, login *UserLogin, portal networkid.PortalKey) (*UserPortal, error) { return upq.QueryOne(ctx, getUserPortalQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) } diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 5eb10ac3..3e4b35d9 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -141,17 +141,47 @@ func (ul *UserLogin) Save(ctx context.Context) error { } func (ul *UserLogin) Logout(ctx context.Context) { - ul.Client.LogoutRemote(ctx) - err := ul.Bridge.DB.UserLogin.Delete(ctx, ul.ID) + ul.Delete(ctx, status.BridgeState{StateEvent: status.StateLoggedOut}, true) +} + +func (ul *UserLogin) Delete(ctx context.Context, state status.BridgeState, logoutRemote bool) { + if logoutRemote { + ul.Client.LogoutRemote(ctx) + } else { + ul.Disconnect(nil) + } + portals, err := ul.Bridge.DB.UserPortal.GetAllForLogin(ctx, ul.UserLogin) + if err != nil { + ul.Log.Err(err).Msg("Failed to get user portals") + } + err = ul.Bridge.DB.UserLogin.Delete(ctx, ul.ID) if err != nil { ul.Log.Err(err).Msg("Failed to delete user login") } ul.Bridge.cacheLock.Lock() - defer ul.Bridge.cacheLock.Unlock() delete(ul.User.logins, ul.ID) delete(ul.Bridge.userLoginsByID, ul.ID) - // TODO kick user out of rooms? - ul.BridgeState.Send(status.BridgeState{StateEvent: status.StateLoggedOut}) + ul.Bridge.cacheLock.Unlock() + go ul.deleteSpace(ctx) + go ul.kickUserFromPortals(ctx, portals) + if state.StateEvent != "" { + ul.BridgeState.Send(state) + } + ul.BridgeState.Destroy() +} + +func (ul *UserLogin) deleteSpace(ctx context.Context) { + if ul.SpaceRoom == "" { + return + } + err := ul.Bridge.Bot.DeleteRoom(ctx, ul.SpaceRoom, false) + if err != nil { + ul.Log.Err(err).Msg("Failed to delete space room") + } +} + +func (ul *UserLogin) kickUserFromPortals(ctx context.Context, portals []*database.UserPortal) { + // TODO kick user out of rooms } func (ul *UserLogin) MarkAsPreferredIn(ctx context.Context, portal *Portal) error { @@ -173,12 +203,15 @@ func (ul *UserLogin) GetRemoteName() string { } func (ul *UserLogin) Disconnect(done func()) { - defer done() - if ul.Client != nil { + if done != nil { + defer done() + } + client := ul.Client + if client != nil { + ul.Client = nil disconnected := make(chan struct{}) go func() { - ul.Client.Disconnect() - ul.Client = nil + client.Disconnect() close(disconnected) }() select { From c12ac71a39ba6e21d12d1809377b95d3ca610c17 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 21 Jun 2024 16:56:55 +0300 Subject: [PATCH 0335/1647] bridgev2: split NetworkAPI interface and add capabilities --- bridgev2/database/message.go | 1 + bridgev2/matrix/connector.go | 7 +++ bridgev2/matrixinterface.go | 6 ++ bridgev2/messagestatus.go | 8 +++ bridgev2/networkinterface.go | 54 ++++++++++++++++ bridgev2/portal.go | 119 +++++++++++++++++++++++++++++------ 6 files changed, 175 insertions(+), 20 deletions(-) diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 4ad2b488..d083eb98 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -24,6 +24,7 @@ type MessageQuery struct { type StandardMessageMetadata struct { SenderMXID id.UserID `json:"sender_mxid,omitempty"` + EditCount int `json:"edit_count,omitempty"` } type MessageMetadata struct { diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 5da605dc..a06aeadb 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -63,6 +63,7 @@ type Connector struct { MediaConfig mautrix.RespMediaConfig SpecVersions *mautrix.RespVersions + Capabilities *bridgev2.MatrixCapabilities IgnoreUnsupportedServer bool EventProcessor *appservice.EventProcessor @@ -85,6 +86,7 @@ func NewConnector(cfg *bridgeconfig.Config) *Connector { c.Config = cfg c.userIDRegex = cfg.MakeUserIDRegex("(.+)") c.MediaConfig.UploadSize = 50 * 1024 * 1024 + c.Capabilities = &bridgev2.MatrixCapabilities{} return c } @@ -148,6 +150,10 @@ func (br *Connector) Start(ctx context.Context) error { return nil } +func (br *Connector) GetCapabilities() *bridgev2.MatrixCapabilities { + return br.Capabilities +} + func (br *Connector) Stop() { br.AS.Stop() br.EventProcessor.Stop() @@ -167,6 +173,7 @@ func (br *Connector) ensureConnection(ctx context.Context) { } else { br.SpecVersions = versions *br.AS.SpecVersions = *versions + br.Capabilities.AutoJoinInvites = br.SpecVersions.Supports(mautrix.BeeperFeatureAutojoinInvites) break } } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 7ac6002c..ead306a6 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -17,11 +17,17 @@ import ( "maunium.net/go/mautrix/id" ) +type MatrixCapabilities struct { + AutoJoinInvites bool +} + type MatrixConnector interface { Init(*Bridge) Start(ctx context.Context) error Stop() + GetCapabilities() *MatrixCapabilities + ParseGhostMXID(userID id.UserID) (networkid.UserID, bool) FormatGhostMXID(userID networkid.UserID) id.UserID diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 16107422..f642f73b 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -18,6 +18,14 @@ import ( var ( ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true) + ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true) + ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true) + ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true) + ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true) + ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true) + ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true) + ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true) ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 4ec98119..314ec20d 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -159,6 +159,39 @@ type MatrixMessageResponse struct { DB *database.Message } +type FileRestriction struct { + MaxSize int64 + MimeTypes []string +} + +type NetworkRoomCapabilities struct { + FormattedText bool + UserMentions bool + RoomMentions bool + + LocationMessages bool + Captions bool + MaxTextLength int + MaxCaptionLength int + + Threads bool + Replies bool + Edits bool + EditMaxCount int + EditMaxAge time.Duration + Deletes bool + DeleteMaxAge time.Duration + + DefaultFileRestriction *FileRestriction + Files map[event.MessageType]FileRestriction + + ReadReceipts bool + + Reactions bool + ReactionCount int + AllowedReactions []string +} + // NetworkAPI is an interface representing a remote network client for a single user login. type NetworkAPI interface { Connect(ctx context.Context) error @@ -169,14 +202,35 @@ type NetworkAPI interface { IsThisUser(ctx context.Context, userID networkid.UserID) bool GetChatInfo(ctx context.Context, portal *Portal) (*PortalInfo, error) GetUserInfo(ctx context.Context, ghost *Ghost) (*UserInfo, error) + GetCapabilities(ctx context.Context, portal *Portal) *NetworkRoomCapabilities HandleMatrixMessage(ctx context.Context, msg *MatrixMessage) (message *MatrixMessageResponse, err error) +} + +type EditHandlingNetworkAPI interface { + NetworkAPI HandleMatrixEdit(ctx context.Context, msg *MatrixEdit) error +} + +type ReactionHandlingNetworkAPI interface { + NetworkAPI PreHandleMatrixReaction(ctx context.Context, msg *MatrixReaction) (MatrixReactionPreResponse, error) HandleMatrixReaction(ctx context.Context, msg *MatrixReaction) (reaction *database.Reaction, err error) HandleMatrixReactionRemove(ctx context.Context, msg *MatrixReactionRemove) error +} + +type RedactionHandlingNetworkAPI interface { + NetworkAPI HandleMatrixMessageRemove(ctx context.Context, msg *MatrixMessageRemove) error +} + +type ReadReceiptHandlingNetworkAPI interface { + NetworkAPI HandleMatrixReadReceipt(ctx context.Context, msg *MatrixReadReceipt) error +} + +type TypingHandlingNetworkAPI interface { + NetworkAPI HandleMatrixTyping(ctx context.Context, msg *MatrixTyping) error } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index b5d19f5a..382917c3 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -325,6 +325,10 @@ func (portal *Portal) handleMatrixReadReceipt(user *User, eventID id.EventID, re log.Err(err).Msg("Failed to get preferred login for user") return } + rrClient, ok := login.Client.(ReadReceiptHandlingNetworkAPI) + if !ok { + return + } evt := &MatrixReadReceipt{ Portal: portal, EventID: eventID, @@ -343,7 +347,7 @@ func (portal *Portal) handleMatrixReadReceipt(user *User, eventID id.EventID, re } else { evt.ReadUpTo = receipt.Timestamp } - err = login.Client.HandleMatrixReadReceipt(ctx, evt) + err = rrClient.HandleMatrixReadReceipt(ctx, evt) if err != nil { log.Err(err).Msg("Failed to handle read receipt") return @@ -397,13 +401,19 @@ func (portal *Portal) sendTypings(ctx context.Context, userIDs []id.UserID, typi continue } else if login == nil { continue + } else if _, ok = login.Client.(TypingHandlingNetworkAPI); !ok { + continue } portal.currentlyTypingLogins[userID] = login } if !typing { delete(portal.currentlyTypingLogins, userID) } - err := login.Client.HandleMatrixTyping(ctx, &MatrixTyping{ + typingAPI, ok := login.Client.(TypingHandlingNetworkAPI) + if !ok { + continue + } + err := typingAPI.HandleMatrixTyping(ctx, &MatrixTyping{ Portal: portal, IsTyping: typing, Type: TypingTypeText, @@ -436,7 +446,11 @@ func (portal *Portal) periodicTypingUpdater() { if !ok { continue } - err := login.Client.HandleMatrixTyping(ctx, &MatrixTyping{ + typingAPI, ok := login.Client.(TypingHandlingNetworkAPI) + if !ok { + continue + } + err := typingAPI.HandleMatrixTyping(ctx, &MatrixTyping{ Portal: portal, IsTyping: true, Type: TypingTypeText, @@ -454,6 +468,27 @@ func (portal *Portal) periodicTypingUpdater() { } } +func (portal *Portal) checkMessageContentCaps(ctx context.Context, caps *NetworkRoomCapabilities, content *event.MessageEventContent, evt *event.Event) bool { + switch content.MsgType { + case event.MsgText, event.MsgNotice, event.MsgEmote: + // No checks for now, message length is safer to check after conversion inside connector + case event.MsgLocation: + if !caps.LocationMessages { + portal.sendErrorStatus(ctx, evt, ErrLocationMessagesNotAllowed) + return false + } + case event.MsgImage, event.MsgAudio, event.MsgVideo, event.MsgFile: + if content.FileName != "" && content.Body != content.FileName { + if !caps.Captions { + portal.sendErrorStatus(ctx, evt, ErrCaptionsNotAllowed) + return false + } + } + default: + } + return true +} + func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.MessageEventContent) @@ -462,17 +497,19 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) return } + caps := sender.Client.GetCapabilities(ctx, portal) + if content.RelatesTo.GetReplaceID() != "" { - portal.handleMatrixEdit(ctx, sender, origSender, evt, content) + portal.handleMatrixEdit(ctx, sender, origSender, evt, content, caps) + return + } + if !portal.checkMessageContentCaps(ctx, caps, content, evt) { return } - // TODO get capabilities from network connector - threadsSupported := true - repliesSupported := true var threadRoot, replyTo *database.Message var err error - if threadsSupported { + if caps.Threads { threadRootID := content.RelatesTo.GetThreadParent() if threadRootID != "" { threadRoot, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, threadRootID) @@ -481,9 +518,9 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } } } - if repliesSupported { + if caps.Replies { var replyToID id.EventID - if threadsSupported { + if caps.Threads { replyToID = content.RelatesTo.GetNonFallbackReplyTo() } else { replyToID = content.RelatesTo.GetReplyTo() @@ -543,15 +580,28 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin portal.sendSuccessStatus(ctx, evt) } -func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent) { - editTargetID := content.RelatesTo.GetReplaceID() +func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *NetworkRoomCapabilities) { log := zerolog.Ctx(ctx) + editTargetID := content.RelatesTo.GetReplaceID() log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Stringer("edit_target_mxid", editTargetID) }) if content.NewContent != nil { content = content.NewContent } + + editingAPI, ok := sender.Client.(EditHandlingNetworkAPI) + if !ok { + log.Debug().Msg("Ignoring edit as network connector doesn't implement EditHandlingNetworkAPI") + portal.sendErrorStatus(ctx, evt, ErrEditsNotSupported) + return + } else if !caps.Edits { + log.Debug().Msg("Ignoring edit as room doesn't support edits") + portal.sendErrorStatus(ctx, evt, ErrEditsNotSupportedInPortal) + return + } else if !portal.checkMessageContentCaps(ctx, caps, content, evt) { + return + } editTarget, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, editTargetID) if err != nil { log.Err(err).Msg("Failed to get edit target message from database") @@ -561,11 +611,17 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o log.Warn().Msg("Edit target message not found in database") portal.sendErrorStatus(ctx, evt, fmt.Errorf("edit %w", ErrTargetMessageNotFound)) return + } else if caps.EditMaxAge > 0 && time.Since(editTarget.Timestamp) > caps.EditMaxAge { + portal.sendErrorStatus(ctx, evt, ErrEditTargetTooOld) + return + } else if caps.EditMaxCount > 0 && editTarget.Metadata.EditCount >= caps.EditMaxCount { + portal.sendErrorStatus(ctx, evt, ErrEditTargetTooManyEdits) + return } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("edit_target_remote_id", string(editTarget.ID)) }) - err = sender.Client.HandleMatrixEdit(ctx, &MatrixEdit{ + err = editingAPI.HandleMatrixEdit(ctx, &MatrixEdit{ MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{ Event: evt, Content: content, @@ -588,6 +644,12 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) { log := zerolog.Ctx(ctx) + reactingAPI, ok := sender.Client.(ReactionHandlingNetworkAPI) + if !ok { + log.Debug().Msg("Ignoring reaction as network connector doesn't implement ReactionHandlingNetworkAPI") + portal.sendErrorStatus(ctx, evt, ErrReactionsNotSupported) + return + } content, ok := evt.Content.Parsed.(*event.ReactionEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") @@ -618,7 +680,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi }, TargetMessage: reactionTarget, } - preResp, err := sender.Client.PreHandleMatrixReaction(ctx, react) + preResp, err := reactingAPI.PreHandleMatrixReaction(ctx, react) if err != nil { log.Err(err).Msg("Failed to pre-handle Matrix reaction") portal.sendErrorStatus(ctx, evt, err) @@ -672,7 +734,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi } } } - dbReaction, err := sender.Client.HandleMatrixReaction(ctx, react) + dbReaction, err := reactingAPI.HandleMatrixReaction(ctx, react) if err != nil { log.Err(err).Msg("Failed to handle Matrix reaction") portal.sendErrorStatus(ctx, evt, err) @@ -723,13 +785,25 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Stringer("redaction_target_mxid", content.Redacts) }) + deletingAPI, deleteOK := sender.Client.(RedactionHandlingNetworkAPI) + reactingAPI, reactOK := sender.Client.(ReactionHandlingNetworkAPI) + if !deleteOK && !reactOK { + log.Debug().Msg("Ignoring redaction without checking target as network connector doesn't implement RedactionHandlingNetworkAPI nor ReactionHandlingNetworkAPI") + portal.sendErrorStatus(ctx, evt, ErrRedactionsNotSupported) + return + } redactionTargetMsg, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, content.Redacts) if err != nil { log.Err(err).Msg("Failed to get redaction target message from database") portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get redaction target message: %w", ErrDatabaseError, err)) return } else if redactionTargetMsg != nil { - err = sender.Client.HandleMatrixMessageRemove(ctx, &MatrixMessageRemove{ + if !deleteOK { + log.Debug().Msg("Ignoring message redaction event as network connector doesn't implement RedactionHandlingNetworkAPI") + portal.sendErrorStatus(ctx, evt, ErrRedactionsNotSupported) + return + } + err = deletingAPI.HandleMatrixMessageRemove(ctx, &MatrixMessageRemove{ MatrixEventBase: MatrixEventBase[*event.RedactionEventContent]{ Event: evt, Content: content, @@ -743,7 +817,12 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get redaction target message reaction: %w", ErrDatabaseError, err)) return } else if redactionTargetReaction != nil { - err = sender.Client.HandleMatrixReactionRemove(ctx, &MatrixReactionRemove{ + if !reactOK { + log.Debug().Msg("Ignoring reaction redaction event as network connector doesn't implement ReactionHandlingNetworkAPI") + portal.sendErrorStatus(ctx, evt, ErrReactionsNotSupported) + return + } + err = reactingAPI.HandleMatrixReactionRemove(ctx, &MatrixReactionRemove{ MatrixEventBase: MatrixEventBase[*event.RedactionEventContent]{ Event: evt, Content: content, @@ -1605,9 +1684,9 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i BeeperInitialMembers: initialMembers, } // TODO find this properly from the matrix connector - isBeeper := true + autoJoinInvites := portal.Bridge.Matrix.GetCapabilities().AutoJoinInvites // TODO remove this after initial_members is supported in hungryserv - if isBeeper { + if autoJoinInvites { req.BeeperAutoJoinInvites = true req.Invite = initialMembers } @@ -1682,7 +1761,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i // TODO add m.space.child event } portal.updateUserLocalInfo(ctx, info.UserLocal, source) - if !isBeeper { + if !autoJoinInvites { _, _, err = portal.SyncParticipants(ctx, info.Members, source) if err != nil { log.Err(err).Msg("Failed to sync participants after room creation") From b23c553580353d25c5e57e394f3557e966033c97 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 21 Jun 2024 17:23:45 +0300 Subject: [PATCH 0336/1647] bridgev2: merge redundant FormatGhostMXID into GhostIntent --- bridgev2/cmdstartchat.go | 2 +- bridgev2/ghost.go | 5 +---- bridgev2/matrix/connector.go | 4 ++-- bridgev2/matrix/provisioning.go | 4 ++-- bridgev2/matrixinterface.go | 8 +++----- 5 files changed, 9 insertions(+), 14 deletions(-) diff --git a/bridgev2/cmdstartchat.go b/bridgev2/cmdstartchat.go index 3d0a3ced..0a553cae 100644 --- a/bridgev2/cmdstartchat.go +++ b/bridgev2/cmdstartchat.go @@ -75,7 +75,7 @@ func fnResolveIdentifier(ce *CommandEvent) { resp.Ghost.UpdateInfo(ce.Ctx, resp.UserInfo) } targetName = resp.Ghost.Name - targetMXID = resp.Ghost.MXID + targetMXID = resp.Ghost.Intent.GetMXID() } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { targetName = *resp.UserInfo.Name } diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index 53d457c8..88cb2f85 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -27,7 +27,6 @@ type Ghost struct { Bridge *Bridge Log zerolog.Logger Intent MatrixAPI - MXID id.UserID } func (br *Bridge) loadGhost(ctx context.Context, dbGhost *database.Ghost, queryErr error, id *networkid.UserID) (*Ghost, error) { @@ -47,13 +46,11 @@ func (br *Bridge) loadGhost(ctx context.Context, dbGhost *database.Ghost, queryE return nil, fmt.Errorf("failed to insert new ghost: %w", err) } } - mxid := br.Matrix.FormatGhostMXID(dbGhost.ID) ghost := &Ghost{ Ghost: dbGhost, Bridge: br, Log: br.Log.With().Str("ghost_id", string(dbGhost.ID)).Logger(), - Intent: br.Matrix.GhostIntent(mxid), - MXID: mxid, + Intent: br.Matrix.GhostIntent(dbGhost.ID), } br.ghostsByID[ghost.ID] = ghost return ghost, nil diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index a06aeadb..9134b6b1 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -322,9 +322,9 @@ func (br *Connector) UpdateBotProfile(ctx context.Context) { } } -func (br *Connector) GhostIntent(userID id.UserID) bridgev2.MatrixAPI { +func (br *Connector) GhostIntent(userID networkid.UserID) bridgev2.MatrixAPI { return &ASIntent{ - Matrix: br.AS.Intent(userID), + Matrix: br.AS.Intent(br.FormatGhostMXID(userID)), Connector: br, } } diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 60a87469..b9074dec 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -406,7 +406,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. apiResp.Name = resp.Ghost.Name apiResp.AvatarURL = resp.Ghost.AvatarMXC apiResp.Identifiers = resp.Ghost.Metadata.Identifiers - apiResp.MXID = resp.Ghost.MXID + apiResp.MXID = resp.Ghost.Intent.GetMXID() } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { apiResp.Name = *resp.UserInfo.Name } @@ -490,7 +490,7 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque apiContact.Identifiers = contact.Ghost.Metadata.Identifiers } apiContact.AvatarURL = contact.Ghost.AvatarMXC - apiContact.MXID = contact.Ghost.MXID + apiContact.MXID = contact.Ghost.Intent.GetMXID() } if contact.Chat != nil { if contact.Chat.Portal == nil { diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index ead306a6..f69cda5e 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -29,17 +29,15 @@ type MatrixConnector interface { GetCapabilities() *MatrixCapabilities ParseGhostMXID(userID id.UserID) (networkid.UserID, bool) - FormatGhostMXID(userID networkid.UserID) id.UserID - - GenerateContentURI(ctx context.Context, mediaID networkid.MediaID) (id.ContentURIString, error) - - GhostIntent(userID id.UserID) MatrixAPI + GhostIntent(userID networkid.UserID) MatrixAPI NewUserIntent(ctx context.Context, userID id.UserID, accessToken string) (MatrixAPI, string, error) BotIntent() MatrixAPI SendBridgeStatus(ctx context.Context, state *status.BridgeState) error SendMessageStatus(ctx context.Context, status *MessageStatus, evt *MessageStatusEventInfo) + GenerateContentURI(ctx context.Context, mediaID networkid.MediaID) (id.ContentURIString, error) + GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) GetMemberInfo(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) From 78f2fda2d454aadc0c5c67c93c0a2806177f570d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 21 Jun 2024 20:08:31 +0300 Subject: [PATCH 0337/1647] bridgev2: rename and fix GetUserLoginByID --- bridgev2/userlogin.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 3e4b35d9..04eb2f38 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -30,6 +30,9 @@ type UserLogin struct { } func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *database.UserLogin) (*UserLogin, error) { + if dbUserLogin == nil { + return nil, nil + } if user == nil { var err error user, err = br.unlockedGetUserByMXID(ctx, dbUserLogin.UserMXID, true) @@ -98,7 +101,7 @@ func (br *Bridge) GetUserLoginsInPortal(ctx context.Context, portal networkid.Po return br.loadManyUserLogins(ctx, nil, logins) } -func (br *Bridge) GetUserLoginByID(ctx context.Context, id networkid.UserLoginID) (*UserLogin, error) { +func (br *Bridge) GetExistingUserLoginByID(ctx context.Context, id networkid.UserLoginID) (*UserLogin, error) { login, err := br.DB.UserLogin.GetByID(ctx, id) if err != nil { return nil, err From 921240d99bf7d2adeb1aadf9206202016ca9cb6b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 21 Jun 2024 20:09:53 +0300 Subject: [PATCH 0338/1647] bridgev2: add method to create space for user login --- bridgev2/space.go | 58 +++++++++++++++++++++++++++++++++++++++++++ bridgev2/userlogin.go | 3 +++ 2 files changed, 61 insertions(+) create mode 100644 bridgev2/space.go diff --git a/bridgev2/space.go b/bridgev2/space.go new file mode 100644 index 00000000..c9018802 --- /dev/null +++ b/bridgev2/space.go @@ -0,0 +1,58 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "fmt" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) { + ul.spaceCreateLock.Lock() + defer ul.spaceCreateLock.Unlock() + if ul.SpaceRoom != "" { + return ul.SpaceRoom, nil + } + netName := ul.Bridge.Network.GetName() + var err error + req := &mautrix.ReqCreateRoom{ + Visibility: "private", + Name: fmt.Sprintf("%s (%s)", netName.DisplayName, ul.Metadata.RemoteName), + Topic: fmt.Sprintf("Your %s bridged chats - %s", netName.DisplayName, ul.Metadata.RemoteName), + InitialState: []*event.Event{{ + Type: event.StateRoomAvatar, + Content: event.Content{ + Parsed: &event.RoomAvatarEventContent{ + URL: netName.NetworkIcon, + }, + }, + }}, + CreationContent: map[string]any{ + "type": event.RoomTypeSpace, + }, + PowerLevelOverride: &event.PowerLevelsEventContent{ + Users: map[id.UserID]int{ + ul.Bridge.Bot.GetMXID(): 9001, + ul.UserMXID: 50, + }, + }, + } + ul.SpaceRoom, err = ul.Bridge.Bot.CreateRoom(ctx, req) + if err != nil { + return "", fmt.Errorf("failed to create space room: %w", err) + } + ul.User.DoublePuppet(ctx).EnsureJoined(ctx, ul.SpaceRoom) + err = ul.Save(ctx) + if err != nil { + return "", fmt.Errorf("failed to save space room ID: %w", err) + } + return ul.SpaceRoom, nil +} diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 04eb2f38..5578578b 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -9,6 +9,7 @@ package bridgev2 import ( "context" "fmt" + "sync" "time" "github.com/rs/zerolog" @@ -27,6 +28,8 @@ type UserLogin struct { Client NetworkAPI BridgeState *BridgeStateQueue + + spaceCreateLock sync.Mutex } func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *database.UserLogin) (*UserLogin, error) { From e13b62807ff7ab59d8e16c591cfae3e87145533c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 24 Jun 2024 15:33:28 +0300 Subject: [PATCH 0339/1647] bridgev2: add capability flag for disappearing messages --- bridgev2/bridge.go | 5 +++-- bridgev2/networkinterface.go | 6 ++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 9a73bc2a..765041d2 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -103,8 +103,9 @@ func (br *Bridge) Start() error { if err != nil { return fmt.Errorf("failed to start network connector: %w", err) } - // TODO only start if the network supports disappearing messages? - go br.DisappearLoop.Start() + if br.Network.GetCapabilities().DisappearingMessages { + go br.DisappearLoop.Start() + } logins, err := br.GetAllUserLogins(ctx) if err != nil { diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 314ec20d..21479562 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -103,6 +103,8 @@ type NetworkConnector interface { // such as connecting to the network. Instead, connecting should happen when [NetworkAPI.Connect] is called later. LoadUserLogin(ctx context.Context, login *UserLogin) error + GetCapabilities() *NetworkGeneralCapabilities + GetName() BridgeName // GetConfig returns all the parts of the network connector's config file. Specifically: // - example: a string containing an example config file @@ -164,6 +166,10 @@ type FileRestriction struct { MimeTypes []string } +type NetworkGeneralCapabilities struct { + DisappearingMessages bool +} + type NetworkRoomCapabilities struct { FormattedText bool UserMentions bool From 855715bbed92577df2e9a4c44340c921170e1ef5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 24 Jun 2024 15:56:13 +0300 Subject: [PATCH 0340/1647] bridgev2: add capability flag for refetching ghost info more often --- bridgev2/ghost.go | 21 +++++++++++++++++---- bridgev2/networkinterface.go | 5 +++++ bridgev2/portal.go | 20 ++++++++++---------- 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index 88cb2f85..a43190f1 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -200,15 +200,28 @@ func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, return true } -func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin) { - if ghost.Name != "" && ghost.NameSet { +func (br *Bridge) allowAggressiveUpdateForType(evtType RemoteEventType) bool { + if !br.Network.GetCapabilities().AggressiveUpdateInfo { + return false + } + switch evtType { + case RemoteEventUnknown, RemoteEventMessage, RemoteEventEdit, RemoteEventReaction: + return true + default: + return false + } +} + +func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin, evtType RemoteEventType) { + if ghost.Name != "" && ghost.NameSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) { return } info, err := source.Client.GetUserInfo(ctx, ghost) if err != nil { - zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get info to update ghost") + zerolog.Ctx(ctx).Err(err).Msg("Failed to get info to update ghost") + } else if info != nil { + ghost.UpdateInfo(ctx, info) } - ghost.UpdateInfo(ctx, info) } func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) { diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 21479562..8c72002c 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -167,7 +167,12 @@ type FileRestriction struct { } type NetworkGeneralCapabilities struct { + // Does the network connector support disappearing messages? + // This flag enables the message disappearing loop in the bridge. DisappearingMessages bool + // Should the bridge re-request user info on incoming messages even if the ghost already has info? + // By default, info is only requested for ghosts with no name, and other updating is left to events. + AggressiveUpdateInfo bool } type NetworkRoomCapabilities struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 382917c3..c574c46c 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -894,7 +894,7 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { } } -func (portal *Portal) getIntentFor(ctx context.Context, sender EventSender, source *UserLogin) MatrixAPI { +func (portal *Portal) getIntentFor(ctx context.Context, sender EventSender, source *UserLogin, evtType RemoteEventType) MatrixAPI { var intent MatrixAPI if sender.IsFromMe { intent = source.User.DoublePuppet(ctx) @@ -911,7 +911,7 @@ func (portal *Portal) getIntentFor(ctx context.Context, sender EventSender, sour zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for message sender") return nil } - ghost.UpdateInfoIfNecessary(ctx, source) + ghost.UpdateInfoIfNecessary(ctx, source, evtType) intent = ghost.Intent } return intent @@ -926,7 +926,7 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin log.Debug().Stringer("existing_mxid", existing.MXID).Msg("Ignoring duplicate message") return } - intent := portal.getIntentFor(ctx, evt.GetSender(), source) + intent := portal.getIntentFor(ctx, evt.GetSender(), source, RemoteEventMessage) if intent == nil { return } @@ -1044,7 +1044,7 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e log.Warn().Msg("Edit target message not found") return } - intent := portal.getIntentFor(ctx, evt.GetSender(), source) + intent := portal.getIntentFor(ctx, evt.GetSender(), source, RemoteEventEdit) if intent == nil { return } @@ -1144,7 +1144,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi return } ts := getEventTS(evt) - intent := portal.getIntentFor(ctx, evt.GetSender(), source) + intent := portal.getIntentFor(ctx, evt.GetSender(), source, RemoteEventReaction) resp, err := intent.SendMessage(ctx, portal.MXID, event.EventReaction, &event.Content{ Parsed: &event.ReactionEventContent{ RelatesTo: event.RelatesTo{ @@ -1202,7 +1202,7 @@ func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *Us log.Warn().Msg("Target reaction not found") return } - intent := portal.getIntentFor(ctx, evt.GetSender(), source) + intent := portal.getIntentFor(ctx, evt.GetSender(), source, RemoteEventReactionRemove) ts := getEventTS(evt) _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ @@ -1225,7 +1225,7 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use log.Err(err).Msg("Failed to get target message for removal") return } - intent := portal.getIntentFor(ctx, evt.GetSender(), source) + intent := portal.getIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageRemove) ts := getEventTS(evt) for _, part := range targetParts { resp, err := intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ @@ -1281,7 +1281,7 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL return } sender := evt.GetSender() - intent := portal.getIntentFor(ctx, sender, source) + intent := portal.getIntentFor(ctx, sender, source, RemoteEventReadReceipt) err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt)) if err != nil { log.Err(err).Stringer("target_mxid", lastTarget.MXID).Msg("Failed to bridge read receipt") @@ -1317,7 +1317,7 @@ func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, if typedEvt, ok := evt.(RemoteTypingWithType); ok { typingType = typedEvt.GetTypingType() } - intent := portal.getIntentFor(ctx, evt.GetSender(), source) + intent := portal.getIntentFor(ctx, evt.GetSender(), source, RemoteEventTyping) err := intent.MarkTyping(ctx, portal.MXID, typingType, evt.GetTimeout()) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge typing event") @@ -1522,7 +1522,7 @@ func (portal *Portal) SyncParticipants(ctx context.Context, members []networkid. if err != nil { return nil, nil, fmt.Errorf("failed to get ghost for %s: %w", member, err) } - ghost.UpdateInfoIfNecessary(ctx, source) + ghost.UpdateInfoIfNecessary(ctx, source, 0) if expectedIntents[i] == nil { expectedIntents[i] = ghost.Intent if isLoggedInUser { From 287547297eab538ad3f0988896db16d64e8842c3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 24 Jun 2024 19:58:03 +0300 Subject: [PATCH 0341/1647] bridgev2: fill room type in `m.bridge` --- bridgev2/database/portal.go | 2 ++ bridgev2/portal.go | 12 +++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 0d578ea4..85e6f9c1 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -26,6 +26,8 @@ type PortalQuery struct { type StandardPortalMetadata struct { DisappearType DisappearingType `json:"disappear_type,omitempty"` DisappearTimer time.Duration `json:"disappear_timer,omitempty"` + IsDirect bool `json:"is_direct,omitempty"` + IsSpace bool `json:"is_space,omitempty"` } type PortalMetadata struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c574c46c..f1b8f5f8 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1441,6 +1441,11 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { }, // TODO room type } + if portal.Metadata.IsDirect { + bridgeInfo.BeeperRoomType = "dm" + } else if portal.Metadata.IsSpace { + bridgeInfo.BeeperRoomType = "space" + } parent := portal.GetTopLevelParent() if parent != nil { bridgeInfo.Network = &event.BridgeInfoSection{ @@ -1623,6 +1628,10 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, source * } // TODO detect changes to functional members list? } + if info.IsDirectChat != nil && portal.Metadata.IsDirect != *info.IsDirectChat { + changed = true + portal.Metadata.IsDirect = *info.IsDirectChat + } if source != nil { // TODO is this a good place for this call? there's another one in QueueRemoteEvent err := portal.Bridge.DB.UserPortal.EnsureExists(ctx, source.UserLogin, portal.PortalKey) @@ -1674,7 +1683,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i CreationContent: make(map[string]any), InitialState: make([]*event.Event, 0, 6), Preset: "private_chat", - IsDirect: *info.IsDirectChat, + IsDirect: portal.Metadata.IsDirect, PowerLevelOverride: &event.PowerLevelsEventContent{ Users: map[id.UserID]int{ portal.Bridge.Bot.GetMXID(): 9001, @@ -1692,6 +1701,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } if *info.IsSpace { req.CreationContent["type"] = event.RoomTypeSpace + portal.Metadata.IsSpace = true } bridgeInfoStateKey, bridgeInfo := portal.getBridgeInfo() emptyString := "" From 09a8a5104a6c93107b1501c40ada435f4f4202d0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 24 Jun 2024 20:10:09 +0300 Subject: [PATCH 0342/1647] bridgev2/mxmain: adjust database owner --- bridgev2/matrix/mxmain/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 1fde50b8..6edb0604 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -260,7 +260,7 @@ func (br *BridgeMain) initDB() { Msg("Using SQLite without _txlock=immediate is not recommended") } var err error - br.DB, err = dbutil.NewFromConfig(br.Name, dbConfig, dbutil.ZeroLogger(br.Log.With().Str("db_section", "main").Logger())) + br.DB, err = dbutil.NewFromConfig("megabridge/"+br.Name, dbConfig, dbutil.ZeroLogger(br.Log.With().Str("db_section", "main").Logger())) if err != nil { br.Log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to initialize database connection") if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { From a1ec390dc0a05a85b9eecdab9d6d53b6a4ab6d84 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 25 Jun 2024 11:54:16 +0300 Subject: [PATCH 0343/1647] bridgev2/mxmain: add utility for legacy db migrations --- bridgev2/matrix/mxmain/dberror.go | 4 +- bridgev2/matrix/mxmain/legacymigrate.go | 89 +++++++++++++++++++++++++ bridgev2/matrix/mxmain/main.go | 2 +- go.mod | 2 +- go.sum | 4 +- 5 files changed, 95 insertions(+), 6 deletions(-) create mode 100644 bridgev2/matrix/mxmain/legacymigrate.go diff --git a/bridgev2/matrix/mxmain/dberror.go b/bridgev2/matrix/mxmain/dberror.go index eb34ccfa..1c0f6381 100644 --- a/bridgev2/matrix/mxmain/dberror.go +++ b/bridgev2/matrix/mxmain/dberror.go @@ -48,7 +48,7 @@ func (zpe *zerologPQError) MarshalZerologObject(evt *zerolog.Event) { maybeStr("routine", zpe.Routine) } -func (br *BridgeMain) LogDBUpgradeErrorAndExit(name string, err error) { +func (br *BridgeMain) LogDBUpgradeErrorAndExit(name string, err error, message string) { logEvt := br.Log.WithLevel(zerolog.FatalLevel). Err(err). Str("db_section", name) @@ -60,7 +60,7 @@ func (br *BridgeMain) LogDBUpgradeErrorAndExit(name string, err error) { if errors.As(err, &pqe) { logEvt.Object("pq_error", (*zerologPQError)(pqe)) } - logEvt.Msg("Failed to initialize database") + logEvt.Msg(message) if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { os.Exit(18) } else if errors.Is(err, dbutil.ErrForeignTables) { diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go new file mode 100644 index 00000000..fc32a6a9 --- /dev/null +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -0,0 +1,89 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mxmain + +import ( + "context" + "database/sql" + "errors" +) + +func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery string) func(ctx context.Context) error { + return func(ctx context.Context) error { + _, err := br.DB.Exec(ctx, renameTablesQuery) + if err != nil { + return err + } + upgradesTo, compat, err := br.DB.UpgradeTable[0].DangerouslyRun(ctx, br.DB) + if err != nil { + return err + } + _, err = br.DB.Exec(ctx, copyDataQuery) + if err != nil { + return err + } + _, err = br.DB.Exec(ctx, "UPDATE database_owner SET owner = $1 WHERE key = 0", br.DB.Owner) + if err != nil { + return err + } + _, err = br.DB.Exec(ctx, "UPDATE version SET version = $1, compat = $2", upgradesTo, compat) + if err != nil { + return err + } + + return nil + } +} + +func (br *BridgeMain) CheckLegacyDB(expectedVersion int, minBridgeVersion, firstMegaVersion string, migrator func(context.Context) error, transaction bool) { + log := br.Log.With().Str("action", "migrate legacy db").Logger() + ctx := log.WithContext(context.Background()) + exists, err := br.DB.TableExists(ctx, "database_owner") + if err != nil { + log.Err(err).Msg("Failed to check if database_owner table exists") + return + } else if !exists { + return + } + var owner string + err = br.DB.QueryRow(ctx, "SELECT owner FROM database_owner WHERE key=0").Scan(&owner) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + log.Err(err).Msg("Failed to get database owner") + return + } else if owner != br.Name { + if owner != "megabridge/"+br.Name && owner != "" { + log.Warn().Str("db_owner", owner).Msg("Unexpected database owner, not migrating database") + } + return + } + var dbVersion int + err = br.DB.QueryRow(ctx, "SELECT version FROM version").Scan(&dbVersion) + if dbVersion < expectedVersion { + log.Fatal(). + Int("expected_version", expectedVersion). + Int("version", dbVersion). + Msgf("Unsupported database version. Please upgrade to %s %s or higher before upgrading to %s.", br.Name, minBridgeVersion, firstMegaVersion) // zerolog-allow-msgf + return + } else if dbVersion > expectedVersion { + log.Fatal(). + Int("expected_version", expectedVersion). + Int("version", dbVersion). + Msg("Unsupported database version (higher than expected)") + return + } + log.Info().Msg("Detected legacy database, migrating...") + if transaction { + err = br.DB.DoTxn(ctx, nil, migrator) + } else { + err = migrator(ctx) + } + if err != nil { + br.LogDBUpgradeErrorAndExit("main", err, "Failed to migrate legacy database") + } else { + log.Info().Msg("Successfully migrated legacy database") + } +} diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 6edb0604..9a5326a6 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -339,7 +339,7 @@ func (br *BridgeMain) Start() { if err != nil { var dbUpgradeErr bridgev2.DBUpgradeError if errors.As(err, &dbUpgradeErr) { - br.LogDBUpgradeErrorAndExit(dbUpgradeErr.Section, dbUpgradeErr.Err) + br.LogDBUpgradeErrorAndExit(dbUpgradeErr.Section, dbUpgradeErr.Err, "Failed to initialize database") } else { br.Log.Fatal().Err(err).Msg("Failed to start bridge") } diff --git a/go.mod b/go.mod index 5c264e03..ab83f0a9 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.2 - go.mau.fi/util v0.5.0 + go.mau.fi/util v0.5.1-0.20240625085258-678695edd51c go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.24.0 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 diff --git a/go.sum b/go.sum index 88e9ec63..ec612fe3 100644 --- a/go.sum +++ b/go.sum @@ -46,8 +46,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.2 h1:NjGd7lO7zrUn/A7eKwn5PEOt4ONYGqpxSEeZuduvgxc= github.com/yuin/goldmark v1.7.2/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.5.0 h1:8yELAl+1CDRrwGe9NUmREgVclSs26Z68pTWePHVxuDo= -go.mau.fi/util v0.5.0/go.mod h1:DsJzUrJAG53lCZnnYvq9/mOyLuPScWwYhvETiTrpdP4= +go.mau.fi/util v0.5.1-0.20240625085258-678695edd51c h1:LAXHOnupWCFvTyx4ZAu5t+6n7zADldeRHIk1s+2luow= +go.mau.fi/util v0.5.1-0.20240625085258-678695edd51c/go.mod h1:DsJzUrJAG53lCZnnYvq9/mOyLuPScWwYhvETiTrpdP4= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= From 54ff874fac728e79ac0fd880886badc080e50a3e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 25 Jun 2024 13:19:43 +0300 Subject: [PATCH 0344/1647] bridgev2/config: add legacy bridge config migration --- bridgev2/bridgeconfig/upgrade.go | 106 ++++++++++++++++++++++++++++++- 1 file changed, 105 insertions(+), 1 deletion(-) diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 560d5381..5fccf938 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -17,6 +17,11 @@ import ( ) func doUpgrade(helper up.Helper) { + if _, isLegacyConfig := helper.Get(up.Str, "appservice", "database", "uri"); isLegacyConfig { + doMigrateLegacy(helper) + return + } + helper.Copy(up.Str, "bridge", "command_prefix") if dbType, ok := helper.Get(up.Str, "database", "type"); ok && dbType == "sqlite3" { @@ -110,13 +115,111 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Int, "encryption", "rotation", "messages") helper.Copy(up.Bool, "encryption", "rotation", "disable_device_change_key_rotation") + helper.Copy(up.Map, "logging") +} + +func CopyToOtherLocation(helper up.Helper, fieldType up.YAMLType, source, dest []string) { + val, ok := helper.Get(fieldType, source...) + if ok { + helper.Set(fieldType, val, dest...) + } +} + +func CopyMapToOtherLocation(helper up.Helper, source, dest []string) { + val := helper.GetNode(source...) + if val != nil && val.Map != nil { + helper.SetMap(val.Map, dest...) + } +} + +var HackyMigrateLegacyNetworkConfig func(up.Helper) + +func doMigrateLegacy(helper up.Helper) { + if HackyMigrateLegacyNetworkConfig == nil { + _, _ = fmt.Fprintln(os.Stderr, "Legacy bridge config detected, but hacky network config migrator is not set") + os.Exit(1) + } + _, _ = fmt.Fprintln(os.Stderr, "Migrating legacy bridge config") + + helper.Copy(up.Str, "homeserver", "address") + helper.Copy(up.Str, "homeserver", "domain") + helper.Copy(up.Str, "homeserver", "software") + helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint") + helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint") + helper.Copy(up.Bool, "homeserver", "async_media") + helper.Copy(up.Str|up.Null, "homeserver", "websocket_proxy") + helper.Copy(up.Bool, "homeserver", "websocket") + helper.Copy(up.Int, "homeserver", "ping_interval_seconds") + + helper.Copy(up.Str|up.Null, "appservice", "address") + helper.Copy(up.Str|up.Null, "appservice", "hostname") + helper.Copy(up.Int|up.Null, "appservice", "port") + helper.Copy(up.Str, "appservice", "id") + helper.Copy(up.Str, "appservice", "bot", "username") + helper.Copy(up.Str, "appservice", "bot", "displayname") + helper.Copy(up.Str, "appservice", "bot", "avatar") + helper.Copy(up.Bool, "appservice", "ephemeral_events") + helper.Copy(up.Bool, "appservice", "async_transactions") + helper.Copy(up.Str, "appservice", "as_token") + helper.Copy(up.Str, "appservice", "hs_token") + + helper.Copy(up.Str, "bridge", "command_prefix") + + CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "type"}, []string{"database", "type"}) + CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "uri"}, []string{"database", "uri"}) + CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_open_conns"}, []string{"database", "max_open_conns"}) + CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_idle_conns"}, []string{"database", "max_idle_conns"}) + CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_conn_idle_time"}, []string{"database", "max_conn_idle_time"}) + CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_conn_lifetime"}, []string{"database", "max_conn_lifetime"}) + + CopyToOtherLocation(helper, up.Str, []string{"bridge", "username_template"}, []string{"appservice", "username_template"}) + + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "message_status_events"}, []string{"matrix", "message_status_events"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "delivery_receipts"}, []string{"matrix", "delivery_receipts"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "message_error_notices"}, []string{"matrix", "message_error_notices"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "sync_direct_chat_list"}, []string{"matrix", "sync_direct_chat_list"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "federate_rooms"}, []string{"matrix", "federate_rooms"}) + + CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "prefix"}, []string{"provisioning", "prefix"}) + CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) + CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "prefix"}, []string{"provisioning", "prefix"}) + CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "provisioning", "debug_endpoints"}, []string{"provisioning", "debug_endpoints"}) + + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "double_puppet_allow_discovery"}, []string{"double_puppet", "allow_discovery"}) + CopyMapToOtherLocation(helper, []string{"bridge", "double_puppet_server_map"}, []string{"double_puppet", "servers"}) + CopyMapToOtherLocation(helper, []string{"bridge", "login_shared_secret_map"}, []string{"double_puppet", "secrets"}) + + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "allow"}, []string{"encryption", "allow"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "default"}, []string{"encryption", "default"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "require"}, []string{"encryption", "require"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "appservice"}, []string{"encryption", "appservice"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "allow_key_sharing"}, []string{"encryption", "allow_key_sharing"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_outbound_on_ack"}, []string{"encryption", "delete_keys", "delete_outbound_on_ack"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "dont_store_outbound"}, []string{"encryption", "delete_keys", "dont_store_outbound"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "ratchet_on_decrypt"}, []string{"encryption", "delete_keys", "ratchet_on_decrypt"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_fully_used_on_decrypt"}, []string{"encryption", "delete_keys", "delete_fully_used_on_decrypt"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_prev_on_new_session"}, []string{"encryption", "delete_keys", "delete_prev_on_new_session"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_on_device_delete"}, []string{"encryption", "delete_keys", "delete_on_device_delete"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "periodically_delete_expired"}, []string{"encryption", "delete_keys", "periodically_delete_expired"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_outdated_inbound"}, []string{"encryption", "delete_keys", "delete_outdated_inbound"}) + CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "receive"}, []string{"encryption", "verification_levels", "receive"}) + CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "send"}, []string{"encryption", "verification_levels", "send"}) + CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "share"}, []string{"encryption", "verification_levels", "share"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "rotation", "enable_custom"}, []string{"encryption", "rotation", "enable_custom"}) + CopyToOtherLocation(helper, up.Int, []string{"bridge", "encryption", "rotation", "milliseconds"}, []string{"encryption", "rotation", "milliseconds"}) + CopyToOtherLocation(helper, up.Int, []string{"bridge", "encryption", "rotation", "messages"}, []string{"encryption", "rotation", "messages"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "rotation", "disable_device_change_key_rotation"}, []string{"encryption", "rotation", "disable_device_change_key_rotation"}) + if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "print_level") != nil || helper.GetNode("logging", "file_name_format") != nil) { - _, _ = fmt.Fprintln(os.Stderr, "Migrating legacy log configs is not supported") + _, _ = fmt.Fprintln(os.Stderr, "Migrating maulogger configs is not supported") } else if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "handlers") != nil) { _, _ = fmt.Fprintln(os.Stderr, "Migrating Python log configs is not supported") } else { helper.Copy(up.Map, "logging") } + + HackyMigrateLegacyNetworkConfig(helper) } var SpacedBlocks = [][]string{ @@ -133,6 +236,7 @@ var SpacedBlocks = [][]string{ {"appservice", "username_template"}, {"matrix"}, {"provisioning"}, + {"direct_media"}, {"double_puppet"}, {"encryption"}, {"logging"}, From 13b2d62753029927fe853f9867fd4879564a93ee Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 25 Jun 2024 15:04:14 +0300 Subject: [PATCH 0345/1647] bridgev2/matrix: implement resolve identifier not found responses --- bridgev2/matrix/provisioning.go | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index b9074dec..88420bde 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -339,7 +339,7 @@ func (prov *ProvisioningAPI) GetLogins(w http.ResponseWriter, r *http.Request) { jsonResponse(w, http.StatusOK, &RespGetLogins{LoginIDs: user.GetUserLoginIDs()}) } -func (prov *ProvisioningAPI) getLoginForCall(w http.ResponseWriter, r *http.Request) *bridgev2.UserLogin { +func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.Request) *bridgev2.UserLogin { user := prov.GetUser(r) userLoginID := networkid.UserLoginID(r.URL.Query().Get("login_id")) if userLoginID != "" { @@ -375,7 +375,7 @@ type RespResolveIdentifier struct { } func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.Request, createChat bool) { - login := prov.getLoginForCall(w, r) + login := prov.GetLoginForRequest(w, r) if login == nil { return } @@ -394,6 +394,13 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. Err: fmt.Sprintf("Failed to resolve identifier: %v", err), ErrCode: "M_UNKNOWN", }) + return + } else if resp == nil { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + ErrCode: mautrix.MNotFound.ErrCode, + Err: "Identifier not found", + }) + return } apiResp := &RespResolveIdentifier{ ID: resp.UserID, @@ -444,7 +451,7 @@ type RespGetContactList struct { } func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Request) { - login := prov.getLoginForCall(w, r) + login := prov.GetLoginForRequest(w, r) if login == nil { return } @@ -519,7 +526,7 @@ func (prov *ProvisioningAPI) PostCreateDM(w http.ResponseWriter, r *http.Request } func (prov *ProvisioningAPI) PostCreateGroup(w http.ResponseWriter, r *http.Request) { - login := prov.getLoginForCall(w, r) + login := prov.GetLoginForRequest(w, r) if login == nil { return } From c51be40fbe1fa3ef46756b42440ab9c657ee5522 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 25 Jun 2024 15:55:27 +0300 Subject: [PATCH 0346/1647] bridgev2: don't kick bridge bot when syncing members --- bridgev2/portal.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index f1b8f5f8..0d8d779e 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1540,6 +1540,7 @@ func (portal *Portal) SyncParticipants(ctx context.Context, members []networkid. return expectedUserIDs, extraFunctionalMembers, nil } currentMembers, err := portal.Bridge.Matrix.GetMembers(ctx, portal.MXID) + delete(currentMembers, portal.Bridge.Bot.GetMXID()) for _, intent := range expectedIntents { mxid := intent.GetMXID() memberEvt, ok := currentMembers[mxid] From 34cbfc2601128facc947c20d33bdfe0495c7615b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 25 Jun 2024 16:05:59 +0300 Subject: [PATCH 0347/1647] event: fix content of `io.element.functional_members` --- bridgev2/portal.go | 2 +- event/state.go | 2 +- hicli/sync.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 0d8d779e..e0243453 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1711,7 +1711,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i StateKey: &emptyString, Type: event.StateElementFunctionalMembers, Content: event.Content{Parsed: &event.ElementFunctionalMembersContent{ - FunctionalMembers: append(extraFunctionalMembers, portal.Bridge.Bot.GetMXID()), + ServiceMembers: append(extraFunctionalMembers, portal.Bridge.Bot.GetMXID()), }}, }, &event.Event{ StateKey: &bridgeInfoStateKey, diff --git a/event/state.go b/event/state.go index 809951cf..20a383d5 100644 --- a/event/state.go +++ b/event/state.go @@ -195,5 +195,5 @@ type InsertionMarkerContent struct { } type ElementFunctionalMembersContent struct { - FunctionalMembers []id.UserID `json:"functional_members"` + ServiceMembers []id.UserID `json:"service_members"` } diff --git a/hicli/sync.go b/hicli/sync.go index 24f0cfed..c3f30a72 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -435,7 +435,7 @@ func (h *HiClient) calculateRoomParticipantName(ctx context.Context, roomID id.R _ = mautrixEvt.Content.ParseRaw(mautrixEvt.Type) content, ok := mautrixEvt.Content.Parsed.(*event.ElementFunctionalMembersContent) if ok { - functionalMembers = content.FunctionalMembers + functionalMembers = content.ServiceMembers } } var members, leftMembers []string From bcb0eb5874bbab87cea70cd2d1736e66a7b67aaf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 25 Jun 2024 16:07:50 +0300 Subject: [PATCH 0348/1647] bridgev2: add delete-portal command --- bridgev2/cmdprocessor.go | 2 +- bridgev2/cmdstartchat.go | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/bridgev2/cmdprocessor.go b/bridgev2/cmdprocessor.go index 15604af6..32428fc5 100644 --- a/bridgev2/cmdprocessor.go +++ b/bridgev2/cmdprocessor.go @@ -38,7 +38,7 @@ func NewProcessor(bridge *Bridge) *CommandProcessor { } proc.AddHandlers( CommandHelp, CommandCancel, - CommandRegisterPush, + CommandRegisterPush, CommandDeletePortal, CommandLogin, CommandLogout, CommandSetPreferredLogin, CommandResolveIdentifier, CommandStartChat, ) diff --git a/bridgev2/cmdstartchat.go b/bridgev2/cmdstartchat.go index 0a553cae..a0530cdb 100644 --- a/bridgev2/cmdstartchat.go +++ b/bridgev2/cmdstartchat.go @@ -123,3 +123,23 @@ func fnResolveIdentifier(ce *CommandEvent) { ce.Reply("Found %s", formattedName) } } + +var CommandDeletePortal = &FullHandler{ + Func: func(ce *CommandEvent) { + err := ce.Portal.Delete(ce.Ctx) + if err != nil { + ce.Reply("Failed to delete portal: %v", err) + } + err = ce.Bot.DeleteRoom(ce.Ctx, ce.Portal.MXID, false) + if err != nil { + ce.Reply("Failed to clean up room: %v", err) + } + }, + Name: "delete-portal", + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Delete the current portal room", + }, + RequiresAdmin: true, + RequiresPortal: true, +} From 5daf7a74a312d74d2672a66e0f73e27f9912e7b5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 25 Jun 2024 21:29:48 +0300 Subject: [PATCH 0349/1647] client: parse event content in Members --- client.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/client.go b/client.go index 5b4236c2..1d6433ac 100644 --- a/client.go +++ b/client.go @@ -1788,6 +1788,11 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb } u := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "members"}, query) _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) + if err == nil { + for _, evt := range resp.Chunk { + _ = evt.Content.ParseRaw(evt.Type) + } + } if err == nil && cli.StateStore != nil { var clearMemberships []event.Membership if extra.Membership != "" { From ae054177794b5577103157a323aeedf0d54a17f5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 25 Jun 2024 21:31:46 +0300 Subject: [PATCH 0350/1647] bridgev2: remove outdated comment --- bridgev2/portal.go | 1 - 1 file changed, 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index e0243453..cf5e6cae 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1693,7 +1693,6 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i BeeperLocalRoomID: id.RoomID(fmt.Sprintf("!%s:%s", portal.ID, portal.Bridge.Matrix.ServerName())), BeeperInitialMembers: initialMembers, } - // TODO find this properly from the matrix connector autoJoinInvites := portal.Bridge.Matrix.GetCapabilities().AutoJoinInvites // TODO remove this after initial_members is supported in hungryserv if autoJoinInvites { From 3f8bb2fd54b16fe69d2926915143aadaee5b6cea Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 25 Jun 2024 21:45:43 +0300 Subject: [PATCH 0351/1647] bridgev2: don't panic if reply/thread target can't be found --- bridgev2/portal.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index cf5e6cae..147815f6 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -943,6 +943,8 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin replyTo, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, *converted.ReplyTo) if err != nil { log.Err(err).Msg("Failed to get reply target message from database") + } else if replyTo == nil { + log.Warn().Any("reply_to", converted.ReplyTo).Msg("Reply target message not found in database") } else { relatesToRowID = replyTo.RowID } @@ -951,6 +953,8 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin threadRoot, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, *converted.ThreadRoot) if err != nil { log.Err(err).Msg("Failed to get thread root message from database") + } else if threadRoot == nil { + log.Warn().Any("thread_root", converted.ThreadRoot).Msg("Thread root message not found in database") } else { relatesToRowID = threadRoot.RowID } From 6996bd1087ebda31b85312367473923050388d57 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 25 Jun 2024 15:06:19 -0600 Subject: [PATCH 0352/1647] bridgev2/portal: allow sender to be empty If the sender is empty, then default to the portal's bridge bot. Signed-off-by: Sumner Evans --- bridgev2/portal.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 147815f6..b7b11406 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -905,7 +905,7 @@ func (portal *Portal) getIntentFor(ctx context.Context, sender EventSender, sour intent = senderLogin.User.DoublePuppet(ctx) } } - if intent == nil { + if intent == nil && sender.Sender != "" { ghost, err := portal.Bridge.GetGhostByID(ctx, sender.Sender) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for message sender") @@ -914,6 +914,9 @@ func (portal *Portal) getIntentFor(ctx context.Context, sender EventSender, sour ghost.UpdateInfoIfNecessary(ctx, source, evtType) intent = ghost.Intent } + if intent == nil { + intent = portal.Bridge.Bot + } return intent } From f246e7041420abff4b0cb442fb5e8bdb51e04691 Mon Sep 17 00:00:00 2001 From: Simon Ruderich Date: Tue, 25 Jun 2024 08:10:51 +0200 Subject: [PATCH 0353/1647] verificationhelper: fix deadlock when ignoring an unknown cancellation vh.activeTransactionsLock must be unlocked before leaving the function. The return when ignoring an unknown cancellation was the only one missing the unlock. --- crypto/verificationhelper/verificationhelper.go | 1 + 1 file changed, 1 insertion(+) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index fb6b1b40..c9fd7407 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -262,6 +262,7 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { // can just ignore it. if evt.Type == event.ToDeviceVerificationCancel || evt.Type == event.InRoomVerificationCancel { log.Info().Msg("Ignoring verification cancellation event for an unknown transaction") + vh.activeTransactionsLock.Unlock() return } From d9a8b7ddbc34e5c8a3139e74ad61f3c5a899a878 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 26 Jun 2024 13:20:06 +0300 Subject: [PATCH 0354/1647] bridgev2/mxmain: specify expected megabridge db version in legacy db migrator --- bridgev2/matrix/mxmain/legacymigrate.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index fc32a6a9..0aa8a597 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -10,9 +10,10 @@ import ( "context" "database/sql" "errors" + "fmt" ) -func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery string) func(ctx context.Context) error { +func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery string, newDBVersion int) func(ctx context.Context) error { return func(ctx context.Context) error { _, err := br.DB.Exec(ctx, renameTablesQuery) if err != nil { @@ -22,6 +23,9 @@ func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery strin if err != nil { return err } + if upgradesTo < newDBVersion || compat > newDBVersion { + return fmt.Errorf("unexpected new database version (%d/c:%d, expected %d)", upgradesTo, compat, newDBVersion) + } _, err = br.DB.Exec(ctx, copyDataQuery) if err != nil { return err From 5dbfd7093e58879adb43dc36e77b7d0baf8bc395 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 26 Jun 2024 14:42:48 +0300 Subject: [PATCH 0355/1647] bridgev2/login: allow returning whole UserLogin in LoginCompleteParams --- bridgev2/login.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/login.go b/bridgev2/login.go index d438b183..1f5e3267 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -185,6 +185,7 @@ type LoginUserInputParams struct { type LoginCompleteParams struct { UserLoginID networkid.UserLoginID `json:"user_login_id"` + UserLogin *UserLogin `json:"-"` } type LoginSubmit struct { From 1a18d9ee55f198901eb8985a0474149e479b4085 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 26 Jun 2024 19:50:06 +0300 Subject: [PATCH 0356/1647] bridgev2/login: add display nothing type for wait step --- bridgev2/cmdlogin.go | 2 ++ bridgev2/login.go | 11 ++++++----- bridgev2/matrix/provisioning.go | 9 +++------ bridgev2/unorganized-docs/login-step.schema.json | 4 ++-- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/bridgev2/cmdlogin.go b/bridgev2/cmdlogin.go index dfa66319..160e6a96 100644 --- a/bridgev2/cmdlogin.go +++ b/bridgev2/cmdlogin.go @@ -171,6 +171,8 @@ func doLoginDisplayAndWait(ce *CommandEvent, login LoginProcessDisplayAndWait, s ce.ReplyAdvanced(step.DisplayAndWaitParams.Data, false, false) case LoginDisplayTypeCode: ce.ReplyAdvanced(fmt.Sprintf("%s", html.EscapeString(step.DisplayAndWaitParams.Data)), false, true) + case LoginDisplayTypeNothing: + // Do nothing default: ce.Reply("Unsupported display type %q", step.DisplayAndWaitParams.Type) login.Cancel() diff --git a/bridgev2/login.go b/bridgev2/login.go index 1f5e3267..5ec9c12a 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -64,9 +64,10 @@ const ( type LoginDisplayType string const ( - LoginDisplayTypeQR LoginDisplayType = "qr" - LoginDisplayTypeEmoji LoginDisplayType = "emoji" - LoginDisplayTypeCode LoginDisplayType = "code" + LoginDisplayTypeQR LoginDisplayType = "qr" + LoginDisplayTypeEmoji LoginDisplayType = "emoji" + LoginDisplayTypeCode LoginDisplayType = "code" + LoginDisplayTypeNothing LoginDisplayType = "nothing" ) type LoginStep struct { @@ -92,8 +93,8 @@ type LoginStep struct { type LoginDisplayAndWaitParams struct { // The type of thing to display (QR, emoji or text code) Type LoginDisplayType `json:"type"` - // The thing to display (raw data for QR, unicode emoji for emoji, plain string for code) - Data string `json:"data"` + // The thing to display (raw data for QR, unicode emoji for emoji, plain string for code, omitted for nothing) + Data string `json:"data,omitempty"` // An image containing the thing to display. If present, this is recommended over using data directly. // For emojis, the URL to the canonical image representation of the emoji ImageURL string `json:"image_url,omitempty"` diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 88420bde..5e0b2ca7 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -90,7 +90,7 @@ func (prov *ProvisioningAPI) Init() { prov.Router.Path("/v3/login/flows").Methods(http.MethodGet).HandlerFunc(prov.GetLoginFlows) prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginStart) prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginSubmitInput) - prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:wait}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginWait) + prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginWait) prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost).HandlerFunc(prov.PostLogout) prov.Router.Path("/v3/logins").Methods(http.MethodGet).HandlerFunc(prov.GetLogins) prov.Router.Path("/v3/contacts").Methods(http.MethodGet).HandlerFunc(prov.GetContactList) @@ -188,9 +188,6 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return } stepType := mux.Vars(r)["stepType"] - if stepType == "wait" { - stepType = "display_and_wait" - } if login.NextStep.Type != bridgev2.LoginStepType(stepType) { zerolog.Ctx(r.Context()).Warn(). Str("request_step_type", stepType). @@ -294,9 +291,9 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques login := r.Context().Value(provisioningLoginProcessKey).(*ProvLogin) nextStep, err := login.Process.(bridgev2.LoginProcessDisplayAndWait).Wait(r.Context()) if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input") + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to wait") jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - Err: "Failed to submit input", + Err: "Failed to wait", ErrCode: "M_UNKNOWN", }) return diff --git a/bridgev2/unorganized-docs/login-step.schema.json b/bridgev2/unorganized-docs/login-step.schema.json index 42edb56f..38f85e6b 100644 --- a/bridgev2/unorganized-docs/login-step.schema.json +++ b/bridgev2/unorganized-docs/login-step.schema.json @@ -112,7 +112,7 @@ "type": { "type": "string", "description": "The type of thing to display", - "enum": ["qr", "emoji", "code"] + "enum": ["qr", "emoji", "code", "nothing"] }, "data": { "type": "string", @@ -123,7 +123,7 @@ "description": "An image containing the thing to display. If present, this is recommended over using data directly. For emojis, the URL to the canonical image representation of the emoji" } }, - "required": ["type", "data"] + "required": ["type"] }, "complete": { "type": "object", From c4cb5dad047bb7ee86c430d63a2606e49d8a1400 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 26 Jun 2024 21:44:39 +0300 Subject: [PATCH 0357/1647] bridgev2/mxmain/legacymigrate: apply SQL upgrade filters to copy data query --- bridgev2/matrix/mxmain/legacymigrate.go | 5 +++++ go.mod | 2 +- go.sum | 4 ++-- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index 0aa8a597..8b62708d 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -7,6 +7,7 @@ package mxmain import ( + "bytes" "context" "database/sql" "errors" @@ -26,6 +27,10 @@ func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery strin if upgradesTo < newDBVersion || compat > newDBVersion { return fmt.Errorf("unexpected new database version (%d/c:%d, expected %d)", upgradesTo, compat, newDBVersion) } + copyDataQuery, err = br.DB.Internals().FilterSQLUpgrade(bytes.Split([]byte(copyDataQuery), []byte("\n"))) + if err != nil { + return err + } _, err = br.DB.Exec(ctx, copyDataQuery) if err != nil { return err diff --git a/go.mod b/go.mod index ab83f0a9..d381a97d 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.2 - go.mau.fi/util v0.5.1-0.20240625085258-678695edd51c + go.mau.fi/util v0.5.1-0.20240626184357-b3f4d78c25cf go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.24.0 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 diff --git a/go.sum b/go.sum index ec612fe3..483ae912 100644 --- a/go.sum +++ b/go.sum @@ -46,8 +46,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.2 h1:NjGd7lO7zrUn/A7eKwn5PEOt4ONYGqpxSEeZuduvgxc= github.com/yuin/goldmark v1.7.2/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.5.1-0.20240625085258-678695edd51c h1:LAXHOnupWCFvTyx4ZAu5t+6n7zADldeRHIk1s+2luow= -go.mau.fi/util v0.5.1-0.20240625085258-678695edd51c/go.mod h1:DsJzUrJAG53lCZnnYvq9/mOyLuPScWwYhvETiTrpdP4= +go.mau.fi/util v0.5.1-0.20240626184357-b3f4d78c25cf h1:ceXQTB6IqjqGBGhzOTEBGbxQu7xDyuT9YR06gxr9Ncw= +go.mau.fi/util v0.5.1-0.20240626184357-b3f4d78c25cf/go.mod h1:DsJzUrJAG53lCZnnYvq9/mOyLuPScWwYhvETiTrpdP4= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= From c9314c6a63f8648f7b67e5b497e1b003c9486dc3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 26 Jun 2024 21:44:59 +0300 Subject: [PATCH 0358/1647] bridgev2/database: fix UserPortal.GetAllForLogin --- bridgev2/database/userportal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/database/userportal.go b/bridgev2/database/userportal.go index 49d0bf5a..cfa4887a 100644 --- a/bridgev2/database/userportal.go +++ b/bridgev2/database/userportal.go @@ -83,7 +83,7 @@ func (upq *UserPortalQuery) GetAllByUser(ctx context.Context, userID id.UserID, } func (upq *UserPortalQuery) GetAllForLogin(ctx context.Context, login *UserLogin) ([]*UserPortal, error) { - return upq.QueryMany(ctx, getUserPortalQuery, upq.BridgeID, login.UserMXID, login.ID) + return upq.QueryMany(ctx, getAllPortalsForLoginQuery, upq.BridgeID, login.UserMXID, login.ID) } func (upq *UserPortalQuery) Get(ctx context.Context, login *UserLogin, portal networkid.PortalKey) (*UserPortal, error) { From dbefc6e49c635491020aa5d057abd34af07075b2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 26 Jun 2024 23:41:35 +0300 Subject: [PATCH 0359/1647] appservice/intent: set will_auto_accept for double puppets in EnsureJoined --- appservice/intent.go | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/appservice/intent.go b/appservice/intent.go index 31ba4732..cddac965 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -112,9 +112,21 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext if !errors.Is(err, mautrix.MForbidden) || bot == nil { return fmt.Errorf("failed to ensure joined: %w", err) } - _, inviteErr := bot.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{ - UserID: intent.UserID, - }) + var inviteErr error + if intent.IsCustomPuppet { + _, err = bot.SendStateEvent(ctx, roomID, event.StateMember, intent.UserID.String(), &event.Content{ + Raw: map[string]any{ + "fi.mau.will_auto_accept": true, + }, + Parsed: &event.MemberEventContent{ + Membership: event.MembershipInvite, + }, + }) + } else { + _, inviteErr = bot.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{ + UserID: intent.UserID, + }) + } if inviteErr != nil { return fmt.Errorf("failed to invite in ensure joined: %w", inviteErr) } From c8b03b087ea6acc3f13881c81ae6c3eb3fa92f14 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 27 Jun 2024 00:16:03 +0300 Subject: [PATCH 0360/1647] bridgev2: add portals to per-userlogin space --- bridgev2/bridgeconfig/config.go | 3 +- bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/database/userportal.go | 24 ++++---- bridgev2/matrix/mxmain/example-config.yaml | 2 + bridgev2/portal.go | 8 +-- bridgev2/queue.go | 5 +- bridgev2/space.go | 69 +++++++++++++++++++++- bridgev2/userlogin.go | 5 ++ 8 files changed, 95 insertions(+), 22 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index e5ebbc01..fc426320 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -32,7 +32,8 @@ type Config struct { } type BridgeConfig struct { - CommandPrefix string `yaml:"command_prefix"` + CommandPrefix string `yaml:"command_prefix"` + PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` } type MatrixConfig struct { diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 5fccf938..6f0610a9 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -164,6 +164,7 @@ func doMigrateLegacy(helper up.Helper) { helper.Copy(up.Str, "appservice", "hs_token") helper.Copy(up.Str, "bridge", "command_prefix") + helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "type"}, []string{"database", "type"}) CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "uri"}, []string{"database", "uri"}) diff --git a/bridgev2/database/userportal.go b/bridgev2/database/userportal.go index cfa4887a..5c34ad51 100644 --- a/bridgev2/database/userportal.go +++ b/bridgev2/database/userportal.go @@ -51,10 +51,11 @@ const ( getAllPortalsForLoginQuery = getUserPortalBaseQuery + ` WHERE bridge_id=$1 AND user_mxid=$2 AND login_id=$3 ` - insertUserPortalQuery = ` + getOrCreateUserPortalQuery = ` INSERT INTO user_portal (bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred) VALUES ($1, $2, $3, $4, $5, false, false) - ON CONFLICT (bridge_id, user_mxid, login_id, portal_id, portal_receiver) DO NOTHING + ON CONFLICT (bridge_id, user_mxid, login_id, portal_id, portal_receiver) DO UPDATE SET portal_id=user_portal.portal_id + RETURNING bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred, last_read ` upsertUserPortalQuery = ` INSERT INTO user_portal (bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred, last_read) @@ -90,15 +91,15 @@ func (upq *UserPortalQuery) Get(ctx context.Context, login *UserLogin, portal ne return upq.QueryOne(ctx, getUserPortalQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) } +func (upq *UserPortalQuery) GetOrCreate(ctx context.Context, login *UserLogin, portal networkid.PortalKey) (*UserPortal, error) { + return upq.QueryOne(ctx, getOrCreateUserPortalQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) +} + func (upq *UserPortalQuery) Put(ctx context.Context, up *UserPortal) error { ensureBridgeIDMatches(&up.BridgeID, upq.BridgeID) return upq.Exec(ctx, upsertUserPortalQuery, up.sqlVariables()...) } -func (upq *UserPortalQuery) EnsureExists(ctx context.Context, login *UserLogin, portal networkid.PortalKey) error { - return upq.Exec(ctx, insertUserPortalQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) -} - func (upq *UserPortalQuery) MarkAsPreferred(ctx context.Context, login *UserLogin, portal networkid.PortalKey) error { return upq.Exec(ctx, markLoginAsPreferredQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) } @@ -127,8 +128,11 @@ func (up *UserPortal) sqlVariables() []any { } } -func (up *UserPortal) ResetValues() { - up.InSpace = nil - up.Preferred = nil - up.LastRead = time.Time{} +func (up *UserPortal) CopyWithoutValues() *UserPortal { + return &UserPortal{ + BridgeID: up.BridgeID, + UserMXID: up.UserMXID, + LoginID: up.LoginID, + Portal: up.Portal, + } } diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 8c18934c..1f691d20 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -2,6 +2,8 @@ bridge: # The prefix for commands. Only required in non-management rooms. command_prefix: '$<>' + # Should the bridge create a space for each login containing the rooms that account is in? + personal_filtering_spaces: true # Config for the bridge's database. database: diff --git a/bridgev2/portal.go b/bridgev2/portal.go index b7b11406..850c2613 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -337,6 +337,7 @@ func (portal *Portal) handleMatrixReadReceipt(user *User, eventID id.EventID, re if userPortal == nil { userPortal = database.UserPortalFor(login.UserLogin, portal.PortalKey) } else { + userPortal = userPortal.CopyWithoutValues() evt.LastRead = userPortal.LastRead } evt.ExactMessage, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, eventID) @@ -352,7 +353,6 @@ func (portal *Portal) handleMatrixReadReceipt(user *User, eventID id.EventID, re log.Err(err).Msg("Failed to handle read receipt") return } - userPortal.ResetValues() if evt.ExactMessage != nil { userPortal.LastRead = evt.ExactMessage.Timestamp } else { @@ -1641,11 +1641,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, source * portal.Metadata.IsDirect = *info.IsDirectChat } if source != nil { - // TODO is this a good place for this call? there's another one in QueueRemoteEvent - err := portal.Bridge.DB.UserPortal.EnsureExists(ctx, source.UserLogin, portal.PortalKey) - if err != nil { - zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to ensure user portal row exists") - } + source.MarkInPortal(ctx, portal) portal.updateUserLocalInfo(ctx, info.UserLocal, source) } if changed { diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 4ec9fb30..10c47960 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -92,10 +92,7 @@ func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { return } // TODO put this in a better place, and maybe cache to avoid constant db queries - err = br.DB.UserPortal.EnsureExists(ctx, login.UserLogin, portal.PortalKey) - if err != nil { - log.Warn().Err(err).Msg("Failed to ensure user portal row exists") - } + login.MarkInPortal(ctx, portal) portal.queueEvent(ctx, &portalRemoteEvent{ evt: evt, source: login, diff --git a/bridgev2/space.go b/bridgev2/space.go index c9018802..55dcaf32 100644 --- a/bridgev2/space.go +++ b/bridgev2/space.go @@ -9,13 +9,67 @@ package bridgev2 import ( "context" "fmt" + "time" + + "github.com/rs/zerolog" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) +func (ul *UserLogin) MarkInPortal(ctx context.Context, portal *Portal) { + if ul.inPortalCache.Has(portal.PortalKey) { + return + } + userPortal, err := ul.Bridge.DB.UserPortal.GetOrCreate(ctx, ul.UserLogin, portal.PortalKey) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to ensure user portal row exists") + return + } + ul.inPortalCache.Add(portal.PortalKey) + if ul.Bridge.Config.PersonalFilteringSpaces && (userPortal.InSpace == nil || !*userPortal.InSpace) { + go ul.tryAddPortalToSpace(ctx, portal, userPortal.CopyWithoutValues()) + } +} + +func (ul *UserLogin) tryAddPortalToSpace(ctx context.Context, portal *Portal, userPortal *database.UserPortal) { + err := ul.AddPortalToSpace(ctx, portal, userPortal) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to add portal to space") + } +} + +func (ul *UserLogin) AddPortalToSpace(ctx context.Context, portal *Portal, userPortal *database.UserPortal) error { + spaceRoom, err := ul.GetSpaceRoom(ctx) + if err != nil { + return fmt.Errorf("failed to get space room: %w", err) + } else if spaceRoom == "" { + return nil + } + _, err = ul.Bridge.Bot.SendState(ctx, spaceRoom, event.StateSpaceChild, portal.MXID.String(), &event.Content{ + Parsed: &event.SpaceChildEventContent{ + Via: []string{ul.Bridge.Matrix.ServerName()}, + }, + }, time.Now()) + if err != nil { + return fmt.Errorf("failed to add portal to space: %w", err) + } + inSpace := true + userPortal.InSpace = &inSpace + err = ul.Bridge.DB.UserPortal.Put(ctx, userPortal) + if err != nil { + return fmt.Errorf("failed to save user portal row: %w", err) + } + zerolog.Ctx(ctx).Debug().Stringer("space_room_id", spaceRoom).Msg("Added portal to space") + return nil +} + func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) { + if !ul.Bridge.Config.PersonalFilteringSpaces { + return ul.SpaceRoom, nil + } ul.spaceCreateLock.Lock() defer ul.spaceCreateLock.Unlock() if ul.SpaceRoom != "" { @@ -23,6 +77,8 @@ func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) { } netName := ul.Bridge.Network.GetName() var err error + autoJoin := ul.Bridge.Matrix.GetCapabilities().AutoJoinInvites + doublePuppet := ul.User.DoublePuppet(ctx) req := &mautrix.ReqCreateRoom{ Visibility: "private", Name: fmt.Sprintf("%s (%s)", netName.DisplayName, ul.Metadata.RemoteName), @@ -44,12 +100,23 @@ func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) { ul.UserMXID: 50, }, }, + Invite: []id.UserID{ul.UserMXID}, + } + if autoJoin { + req.BeeperInitialMembers = []id.UserID{ul.UserMXID} + // TODO remove this after initial_members is supported in hungryserv + req.BeeperAutoJoinInvites = true } ul.SpaceRoom, err = ul.Bridge.Bot.CreateRoom(ctx, req) if err != nil { return "", fmt.Errorf("failed to create space room: %w", err) } - ul.User.DoublePuppet(ctx).EnsureJoined(ctx, ul.SpaceRoom) + if !autoJoin && doublePuppet != nil { + err = doublePuppet.EnsureJoined(ctx, ul.SpaceRoom) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to auto-join created space room with double puppet") + } + } err = ul.Save(ctx) if err != nil { return "", fmt.Errorf("failed to save space room ID: %w", err) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 5578578b..e18c2b25 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -13,6 +13,7 @@ import ( "time" "github.com/rs/zerolog" + "go.mau.fi/util/exsync" "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/database" @@ -29,6 +30,8 @@ type UserLogin struct { Client NetworkAPI BridgeState *BridgeStateQueue + inPortalCache *exsync.Set[networkid.PortalKey] + spaceCreateLock sync.Mutex } @@ -48,6 +51,8 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da Bridge: br, User: user, Log: user.Log.With().Str("login_id", string(dbUserLogin.ID)).Logger(), + + inPortalCache: exsync.NewSet[networkid.PortalKey](), } err := br.Network.LoadUserLogin(ctx, userLogin) if err != nil { From e25578d435a2a5bf972e4c59fde99ba78c2ce112 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 27 Jun 2024 11:32:50 +0300 Subject: [PATCH 0361/1647] bridgev2: improve handling of user logins in bad credentials --- bridgev2/bridge.go | 25 ++++++++++++++++++------- bridgev2/database/userlogin.go | 13 +++++++------ bridgev2/user.go | 6 ++++++ bridgev2/userlogin.go | 24 ++++++++---------------- 4 files changed, 39 insertions(+), 29 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 765041d2..36f7aa06 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -107,18 +107,29 @@ func (br *Bridge) Start() error { go br.DisappearLoop.Start() } - logins, err := br.GetAllUserLogins(ctx) + userIDs, err := br.DB.UserLogin.GetAllUserIDsWithLogins(ctx) if err != nil { - return fmt.Errorf("failed to get user logins: %w", err) + return fmt.Errorf("failed to get users with logins: %w", err) } - for _, login := range logins { - br.Log.Info().Str("id", string(login.ID)).Msg("Starting user login") - err = login.Client.Connect(login.Log.WithContext(ctx)) + startedAny := false + for _, userID := range userIDs { + br.Log.Info().Stringer("user_id", userID).Msg("Loading user") + var user *User + user, err = br.GetUserByMXID(ctx, userID) if err != nil { - br.Log.Err(err).Msg("Failed to connect existing client") + br.Log.Err(err).Stringer("user_id", userID).Msg("Failed to load user") + } else { + for _, login := range user.GetCachedUserLogins() { + startedAny = true + br.Log.Info().Str("id", string(login.ID)).Msg("Starting user login") + err = login.Client.Connect(login.Log.WithContext(ctx)) + if err != nil { + br.Log.Err(err).Msg("Failed to connect existing client") + } + } } } - if len(logins) == 0 { + if !startedAny { br.Log.Info().Msg("No user logins found") br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured}) } diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go index b371483c..cc92e7d4 100644 --- a/bridgev2/database/userlogin.go +++ b/bridgev2/database/userlogin.go @@ -54,10 +54,10 @@ const ( getUserLoginBaseQuery = ` SELECT bridge_id, user_mxid, id, space_room, metadata FROM user_login ` - getLoginByIDQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND id=$2` - getAllLoginsQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1` - getAllLoginsForUserQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND user_mxid=$2` - getAllLoginsInPortalQuery = ` + getLoginByIDQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND id=$2` + getAllUsersWithLoginsQuery = `SELECT DISTINCT user_mxid FROM user_login WHERE bridge_id=$1` + getAllLoginsForUserQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND user_mxid=$2` + getAllLoginsInPortalQuery = ` SELECT ul.bridge_id, ul.user_mxid, ul.id, ul.space_room, ul.metadata FROM user_portal LEFT JOIN user_login ul ON user_portal.bridge_id=ul.bridge_id AND user_portal.user_mxid=ul.user_mxid AND user_portal.login_id=ul.id WHERE user_portal.bridge_id=$1 AND user_portal.portal_id=$2 AND user_portal.portal_receiver=$3 @@ -79,8 +79,9 @@ func (uq *UserLoginQuery) GetByID(ctx context.Context, id networkid.UserLoginID) return uq.QueryOne(ctx, getLoginByIDQuery, uq.BridgeID, id) } -func (uq *UserLoginQuery) GetAll(ctx context.Context) ([]*UserLogin, error) { - return uq.QueryMany(ctx, getAllLoginsQuery, uq.BridgeID) +func (uq *UserLoginQuery) GetAllUserIDsWithLogins(ctx context.Context) ([]id.UserID, error) { + rows, err := uq.GetDB().Query(ctx, getAllUsersWithLoginsQuery, uq.BridgeID) + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() } func (uq *UserLoginQuery) GetAllInPortal(ctx context.Context, portal networkid.PortalKey) ([]*UserLogin, error) { diff --git a/bridgev2/user.go b/bridgev2/user.go index 500a51c0..9fca8de3 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -158,6 +158,12 @@ func (user *User) GetUserLoginIDs() []networkid.UserLoginID { return maps.Keys(user.logins) } +func (user *User) GetCachedUserLogins() []*UserLogin { + user.Bridge.cacheLock.Lock() + defer user.Bridge.cacheLock.Unlock() + return maps.Values(user.logins) +} + func (user *User) GetFormattedUserLogins() string { user.Bridge.cacheLock.Lock() logins := make([]string, len(user.logins)) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index e18c2b25..1bc81190 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -56,7 +56,8 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da } err := br.Network.LoadUserLogin(ctx, userLogin) if err != nil { - return nil, fmt.Errorf("failed to prepare: %w", err) + userLogin.Log.Err(err).Msg("Failed to load user login") + return nil, nil } user.logins[userLogin.ID] = userLogin br.userLoginsByID[userLogin.ID] = userLogin @@ -65,16 +66,17 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da } func (br *Bridge) loadManyUserLogins(ctx context.Context, user *User, logins []*database.UserLogin) ([]*UserLogin, error) { - output := make([]*UserLogin, len(logins)) - for i, dbLogin := range logins { + output := make([]*UserLogin, 0, len(logins)) + for _, dbLogin := range logins { if cached, ok := br.userLoginsByID[dbLogin.ID]; ok { - output[i] = cached + output = append(output, cached) } else { loaded, err := br.loadUserLogin(ctx, user, dbLogin) if err != nil { - return nil, fmt.Errorf("failed to load user login: %w", err) + return nil, err + } else if loaded != nil { + output = append(output, loaded) } - output[i] = loaded } } return output, nil @@ -89,16 +91,6 @@ func (br *Bridge) unlockedLoadUserLoginsByMXID(ctx context.Context, user *User) return err } -func (br *Bridge) GetAllUserLogins(ctx context.Context) ([]*UserLogin, error) { - logins, err := br.DB.UserLogin.GetAll(ctx) - if err != nil { - return nil, err - } - br.cacheLock.Lock() - defer br.cacheLock.Unlock() - return br.loadManyUserLogins(ctx, nil, logins) -} - func (br *Bridge) GetUserLoginsInPortal(ctx context.Context, portal networkid.PortalKey) ([]*UserLogin, error) { logins, err := br.DB.UserLogin.GetAllInPortal(ctx, portal) if err != nil { From b8837f1d8daec846f71f3337682b788b012c8543 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 27 Jun 2024 17:39:30 +0300 Subject: [PATCH 0362/1647] bridgev2: allow updating disappearing timer via PortalInfo --- bridgev2/portal.go | 47 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 850c2613..ee7861d7 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -15,6 +15,7 @@ import ( "time" "github.com/rs/zerolog" + "go.mau.fi/util/exfmt" "go.mau.fi/util/exslices" "go.mau.fi/util/variationselector" "golang.org/x/exp/slices" @@ -1371,6 +1372,7 @@ type PortalInfo struct { IsDirectChat *bool IsSpace *bool + Disappear *database.DisappearingSetting UserLocal *UserLocalPortalInfo } @@ -1618,6 +1620,48 @@ func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPo } } +func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting database.DisappearingSetting, sender *Ghost, ts time.Time, implicit, save bool) bool { + if portal.Metadata.DisappearTimer == setting.Timer { + return false + } + portal.Metadata.DisappearType = setting.Type + portal.Metadata.DisappearTimer = setting.Timer + if setting.Timer == 0 { + portal.Metadata.DisappearType = "" + } + if save { + err := portal.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal to database after updating disappearing setting") + } + } + content := &event.MessageEventContent{ + MsgType: event.MsgNotice, + Body: fmt.Sprintf("Disappearing messages set to %s", exfmt.Duration(setting.Timer)), + } + if implicit { + content.Body = fmt.Sprintf("Automatically enabled disappearing message timer (%s) because incoming message is disappearing", exfmt.Duration(setting.Timer)) + } else if setting.Timer == 0 { + content.Body = "Disappearing messages disabled" + } + intent := portal.Bridge.Bot + if sender != nil { + intent = sender.IntentFor(portal) + } + _, err := intent.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{ + Parsed: content, + }, ts) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to send disappearing messages notice") + } else { + zerolog.Ctx(ctx).Debug(). + Dur("new_timer", portal.Metadata.DisappearTimer). + Bool("implicit", implicit). + Msg("Sent disappearing messages notice") + } + return true +} + func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, source *UserLogin, sender *Ghost, ts time.Time) { changed := false if info.Name != nil { @@ -1629,6 +1673,9 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, source * if info.Avatar != nil { changed = portal.UpdateAvatar(ctx, info.Avatar, sender, ts) || changed } + if info.Disappear != nil { + changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, sender, ts, false, false) || changed + } if info.Members != nil && portal.MXID != "" && source != nil { _, _, err := portal.SyncParticipants(ctx, info.Members, source) if err != nil { From 35f8d837b5f5ff4f6675f18e3b8467e141c114a4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 27 Jun 2024 18:51:10 +0300 Subject: [PATCH 0363/1647] bridgev2: add remote chat info change events --- bridgev2/networkinterface.go | 11 +++++ bridgev2/portal.go | 93 +++++++++++++++++++++++------------- 2 files changed, 72 insertions(+), 32 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 8c72002c..413c4f24 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -360,6 +360,7 @@ const ( RemoteEventTyping RemoteEventChatTag RemoteEventChatMute + RemoteEventChatInfoChange ) // RemoteEvent represents a single event from the remote network, such as a message or a reaction. @@ -373,6 +374,16 @@ type RemoteEvent interface { GetSender() EventSender } +type RemotePreHandler interface { + RemoteEvent + PreHandle(ctx context.Context, portal *Portal) +} + +type RemoteChatInfoChange interface { + RemoteEvent + GetChatInfoChange(ctx context.Context) (*ChatInfoChange, error) +} + type RemoteEventThatMayCreatePortal interface { RemoteEvent ShouldCreatePortal() bool diff --git a/bridgev2/portal.go b/bridgev2/portal.go index ee7861d7..1fa3ccdc 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -865,6 +865,10 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { return } } + preHandler, ok := evt.(RemotePreHandler) + if ok { + preHandler.PreHandle(ctx, portal) + } switch evt.GetType() { case RemoteEventUnknown: log.Debug().Msg("Ignoring remote event with type unknown") @@ -890,12 +894,14 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { portal.handleRemoteChatTag(ctx, source, evt.(RemoteChatTag)) case RemoteEventChatMute: portal.handleRemoteChatMute(ctx, source, evt.(RemoteChatMute)) + case RemoteEventChatInfoChange: + portal.handleRemoteChatInfoChange(ctx, source, evt.(RemoteChatInfoChange)) default: log.Warn().Int("type", int(evt.GetType())).Msg("Got remote event with unknown type") } } -func (portal *Portal) getIntentFor(ctx context.Context, sender EventSender, source *UserLogin, evtType RemoteEventType) MatrixAPI { +func (portal *Portal) GetIntentFor(ctx context.Context, sender EventSender, source *UserLogin, evtType RemoteEventType) MatrixAPI { var intent MatrixAPI if sender.IsFromMe { intent = source.User.DoublePuppet(ctx) @@ -930,7 +936,7 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin log.Debug().Stringer("existing_mxid", existing.MXID).Msg("Ignoring duplicate message") return } - intent := portal.getIntentFor(ctx, evt.GetSender(), source, RemoteEventMessage) + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessage) if intent == nil { return } @@ -1052,7 +1058,7 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e log.Warn().Msg("Edit target message not found") return } - intent := portal.getIntentFor(ctx, evt.GetSender(), source, RemoteEventEdit) + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventEdit) if intent == nil { return } @@ -1152,7 +1158,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi return } ts := getEventTS(evt) - intent := portal.getIntentFor(ctx, evt.GetSender(), source, RemoteEventReaction) + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReaction) resp, err := intent.SendMessage(ctx, portal.MXID, event.EventReaction, &event.Content{ Parsed: &event.ReactionEventContent{ RelatesTo: event.RelatesTo{ @@ -1210,7 +1216,7 @@ func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *Us log.Warn().Msg("Target reaction not found") return } - intent := portal.getIntentFor(ctx, evt.GetSender(), source, RemoteEventReactionRemove) + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReactionRemove) ts := getEventTS(evt) _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ @@ -1233,7 +1239,7 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use log.Err(err).Msg("Failed to get target message for removal") return } - intent := portal.getIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageRemove) + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageRemove) ts := getEventTS(evt) for _, part := range targetParts { resp, err := intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ @@ -1289,7 +1295,7 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL return } sender := evt.GetSender() - intent := portal.getIntentFor(ctx, sender, source, RemoteEventReadReceipt) + intent := portal.GetIntentFor(ctx, sender, source, RemoteEventReadReceipt) err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt)) if err != nil { log.Err(err).Stringer("target_mxid", lastTarget.MXID).Msg("Failed to bridge read receipt") @@ -1325,7 +1331,7 @@ func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, if typedEvt, ok := evt.(RemoteTypingWithType); ok { typingType = typedEvt.GetTypingType() } - intent := portal.getIntentFor(ctx, evt.GetSender(), source, RemoteEventTyping) + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventTyping) err := intent.MarkTyping(ctx, portal.MXID, typingType, evt.GetTimeout()) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge typing event") @@ -1363,6 +1369,27 @@ func (portal *Portal) handleRemoteChatMute(ctx context.Context, source *UserLogi } } +func (portal *Portal) handleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) { + info, err := evt.GetChatInfoChange(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get chat info change") + return + } + portal.ProcessChatInfoChange(ctx, evt.GetSender(), source, info, getEventTS(evt)) +} + +type ChatInfoChange struct { + PortalInfo *PortalInfo + // TODO member event changes +} + +func (portal *Portal) ProcessChatInfoChange(ctx context.Context, sender EventSender, source *UserLogin, change *ChatInfoChange, ts time.Time) { + intent := portal.GetIntentFor(ctx, sender, source, RemoteEventChatInfoChange) + if change.PortalInfo != nil { + portal.UpdateInfo(ctx, change.PortalInfo, source, intent, ts) + } +} + type PortalInfo struct { Name *string Topic *string @@ -1375,6 +1402,8 @@ type PortalInfo struct { Disappear *database.DisappearingSetting UserLocal *UserLocalPortalInfo + + ExtraUpdates func(context.Context, *Portal) bool } type UserLocalPortalInfo struct { @@ -1382,7 +1411,7 @@ type UserLocalPortalInfo struct { Tag *event.RoomTag } -func (portal *Portal) UpdateName(ctx context.Context, name string, sender *Ghost, ts time.Time) bool { +func (portal *Portal) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool { if portal.Name == name && (portal.NameSet || portal.MXID == "") { return false } @@ -1391,7 +1420,7 @@ func (portal *Portal) UpdateName(ctx context.Context, name string, sender *Ghost return true } -func (portal *Portal) UpdateTopic(ctx context.Context, topic string, sender *Ghost, ts time.Time) bool { +func (portal *Portal) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool { if portal.Topic == topic && (portal.TopicSet || portal.MXID == "") { return false } @@ -1400,20 +1429,19 @@ func (portal *Portal) UpdateTopic(ctx context.Context, topic string, sender *Gho return true } -func (portal *Portal) UpdateAvatar(ctx context.Context, avatar *Avatar, sender *Ghost, ts time.Time) bool { +func (portal *Portal) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool { if portal.AvatarID == avatar.ID && (portal.AvatarSet || portal.MXID == "") { return false } portal.AvatarID = avatar.ID - intent := portal.Bridge.Bot - if sender != nil { - intent = sender.IntentFor(portal) + if sender == nil { + sender = portal.Bridge.Bot } if avatar.Remove { portal.AvatarMXC = "" portal.AvatarHash = [32]byte{} } else { - newMXC, newHash, err := avatar.Reupload(ctx, intent, portal.AvatarHash) + newMXC, newHash, err := avatar.Reupload(ctx, sender, portal.AvatarHash) if err != nil { portal.AvatarSet = false zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload room avatar") @@ -1479,20 +1507,19 @@ func (portal *Portal) UpdateBridgeInfo(ctx context.Context) { portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo) } -func (portal *Portal) sendRoomMeta(ctx context.Context, sender *Ghost, ts time.Time, eventType event.Type, stateKey string, content any) bool { +func (portal *Portal) sendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any) bool { if portal.MXID == "" { return false } - intent := portal.Bridge.Bot - if sender != nil { - intent = sender.IntentFor(portal) + if sender == nil { + sender = portal.Bridge.Bot } wrappedContent := &event.Content{Parsed: content} - _, err := intent.SendState(ctx, portal.MXID, eventType, stateKey, wrappedContent, ts) - if errors.Is(err, mautrix.MForbidden) && intent != portal.Bridge.Bot { + _, err := sender.SendState(ctx, portal.MXID, eventType, stateKey, wrappedContent, ts) + if errors.Is(err, mautrix.MForbidden) && sender != portal.Bridge.Bot { wrappedContent.Raw = map[string]any{ - "fi.mau.bridge.set_by": intent.GetMXID(), + "fi.mau.bridge.set_by": sender.GetMXID(), } _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateRoomName, "", wrappedContent, ts) } @@ -1620,15 +1647,15 @@ func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPo } } -func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting database.DisappearingSetting, sender *Ghost, ts time.Time, implicit, save bool) bool { - if portal.Metadata.DisappearTimer == setting.Timer { +func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting database.DisappearingSetting, sender MatrixAPI, ts time.Time, implicit, save bool) bool { + if setting.Timer == 0 { + setting.Type = "" + } + if portal.Metadata.DisappearTimer == setting.Timer && portal.Metadata.DisappearType == setting.Type { return false } portal.Metadata.DisappearType = setting.Type portal.Metadata.DisappearTimer = setting.Timer - if setting.Timer == 0 { - portal.Metadata.DisappearType = "" - } if save { err := portal.Save(ctx) if err != nil { @@ -1644,11 +1671,10 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat } else if setting.Timer == 0 { content.Body = "Disappearing messages disabled" } - intent := portal.Bridge.Bot - if sender != nil { - intent = sender.IntentFor(portal) + if sender == nil { + sender = portal.Bridge.Bot } - _, err := intent.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{ + _, err := sender.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{ Parsed: content, }, ts) if err != nil { @@ -1662,7 +1688,7 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat return true } -func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, source *UserLogin, sender *Ghost, ts time.Time) { +func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, source *UserLogin, sender MatrixAPI, ts time.Time) { changed := false if info.Name != nil { changed = portal.UpdateName(ctx, *info.Name, sender, ts) || changed @@ -1691,6 +1717,9 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, source * source.MarkInPortal(ctx, portal) portal.updateUserLocalInfo(ctx, info.UserLocal, source) } + if info.ExtraUpdates != nil { + changed = info.ExtraUpdates(ctx, portal) || changed + } if changed { portal.UpdateBridgeInfo(ctx) err := portal.Save(ctx) From 943c33d4ab91a4f12a78be48d0cb472e3970b553 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 27 Jun 2024 19:57:38 +0300 Subject: [PATCH 0364/1647] bridgev2/ghost: add ExtraUpdates for UserInfo too --- bridgev2/ghost.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index a43190f1..efa257e7 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -124,6 +124,8 @@ type UserInfo struct { Name *string Avatar *Avatar IsBot *bool + + ExtraUpdates func(context.Context, *Ghost) bool } func (ghost *Ghost) UpdateName(ctx context.Context, name string) bool { @@ -235,6 +237,9 @@ func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) { if info.Identifiers != nil || info.IsBot != nil { update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot) || update } + if info.ExtraUpdates != nil { + update = info.ExtraUpdates(ctx, ghost) || update + } if update { err := ghost.Bridge.DB.Ghost.Update(ctx, ghost.Ghost) if err != nil { From 6ef736b5202cd6ae69371693b7adbc028c12e725 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 27 Jun 2024 20:10:24 +0300 Subject: [PATCH 0365/1647] bridgev2: add generic common ExtraUpdates functions --- bridgev2/ghost.go | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index efa257e7..1e89b660 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -14,6 +14,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/exmime" + "golang.org/x/exp/constraints" "golang.org/x/exp/slices" "maunium.net/go/mautrix/bridgev2/database" @@ -226,6 +227,45 @@ func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin } } +func MergeExtraUpdaters[T any](funcs ...func(context.Context, T) bool) func(context.Context, T) bool { + return func(ctx context.Context, obj T) bool { + update := false + for _, f := range funcs { + update = f(ctx, obj) || update + } + return update + } +} + +func NumberMetadataUpdater[T *Ghost | *Portal, MetaType constraints.Integer | constraints.Float](key string, value MetaType) func(context.Context, T) bool { + return simpleMetadataUpdater[T, MetaType](key, value, database.GetNumberFromMap[MetaType]) +} + +func SimpleMetadataUpdater[T *Ghost | *Portal, MetaType comparable](key string, value MetaType) func(context.Context, T) bool { + return simpleMetadataUpdater[T, MetaType](key, value, func(m map[string]any, key string) (MetaType, bool) { + val, ok := m[key].(MetaType) + return val, ok + }) +} + +func simpleMetadataUpdater[T *Ghost | *Portal, MetaType comparable](key string, value MetaType, getter func(map[string]any, string) (MetaType, bool)) func(context.Context, T) bool { + return func(ctx context.Context, obj T) bool { + var meta map[string]any + switch typedObj := any(obj).(type) { + case *Ghost: + meta = typedObj.Metadata.Extra + case *Portal: + meta = typedObj.Metadata.Extra + } + currentVal, ok := getter(meta, key) + if ok && currentVal == value { + return false + } + meta[key] = value + return true + } +} + func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) { update := false if info.Name != nil { From 3d621312d0bb9ee6c2ba3be6715e47b25d9f498b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 28 Jun 2024 00:37:56 +0300 Subject: [PATCH 0366/1647] bridgev2/matrix/intent: remove other tags in TagRoom --- bridgev2/matrix/intent.go | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 325287db..87ad509c 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -271,16 +271,31 @@ func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnl } func (as *ASIntent) TagRoom(ctx context.Context, roomID id.RoomID, tag event.RoomTag, isTagged bool) error { + tags, err := as.Matrix.GetTags(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get room tags: %w", err) + } if isTagged { - return as.Matrix.AddTagWithCustomData(ctx, roomID, tag, &event.TagMetadata{ + _, alreadyTagged := tags.Tags[tag] + if alreadyTagged { + return nil + } + err = as.Matrix.AddTagWithCustomData(ctx, roomID, tag, &event.TagMetadata{ MauDoublePuppetSource: as.Connector.AS.DoublePuppetValue, }) - } else { - if tag == "" { - // TODO clear all tags? + if err != nil { + return err } - return as.Matrix.RemoveTag(ctx, roomID, tag) } + for extraTag := range tags.Tags { + if extraTag == event.RoomTagFavourite || extraTag == event.RoomTagLowPriority { + err = as.Matrix.RemoveTag(ctx, roomID, extraTag) + if err != nil { + return fmt.Errorf("failed to remove extra tag %s: %w", extraTag, err) + } + } + } + return nil } func (as *ASIntent) MuteRoom(ctx context.Context, roomID id.RoomID, until time.Time) error { From fce5d19205ad3f78fc699e004d9852be93988cea Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 28 Jun 2024 00:42:32 +0300 Subject: [PATCH 0367/1647] bridgev2: remove chat mute and tag events The chat info change event can be used to do both actions --- bridgev2/networkinterface.go | 14 -------------- bridgev2/portal.go | 37 ++---------------------------------- 2 files changed, 2 insertions(+), 49 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 413c4f24..3dcaee0b 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -358,8 +358,6 @@ const ( RemoteEventDeliveryReceipt RemoteEventMarkUnread RemoteEventTyping - RemoteEventChatTag - RemoteEventChatMute RemoteEventChatInfoChange ) @@ -463,18 +461,6 @@ type RemoteTypingWithType interface { GetTypingType() TypingType } -type RemoteChatTag interface { - RemoteEvent - GetTag() (tag event.RoomTag, remove bool) -} - -var Unmuted = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) - -type RemoteChatMute interface { - RemoteEvent - GetMutedUntil() time.Time -} - // SimpleRemoteEvent is a simple implementation of RemoteEvent that can be used with struct fields and some callbacks. type SimpleRemoteEvent[T any] struct { Type RemoteEventType diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 1fa3ccdc..a33f08c4 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -890,10 +890,6 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { portal.handleRemoteDeliveryReceipt(ctx, source, evt.(RemoteReceipt)) case RemoteEventTyping: portal.handleRemoteTyping(ctx, source, evt.(RemoteTyping)) - case RemoteEventChatTag: - portal.handleRemoteChatTag(ctx, source, evt.(RemoteChatTag)) - case RemoteEventChatMute: - portal.handleRemoteChatMute(ctx, source, evt.(RemoteChatMute)) case RemoteEventChatInfoChange: portal.handleRemoteChatInfoChange(ctx, source, evt.(RemoteChatInfoChange)) default: @@ -1338,37 +1334,6 @@ func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, } } -func (portal *Portal) handleRemoteChatTag(ctx context.Context, source *UserLogin, evt RemoteChatTag) { - if !evt.GetSender().IsFromMe { - zerolog.Ctx(ctx).Warn().Msg("Ignoring chat tag event from non-self user") - return - } - dp := source.User.DoublePuppet(ctx) - if dp == nil { - return - } - tag, isTagged := evt.GetTag() - err := dp.TagRoom(ctx, portal.MXID, tag, isTagged) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge chat tag event") - } -} - -func (portal *Portal) handleRemoteChatMute(ctx context.Context, source *UserLogin, evt RemoteChatMute) { - if !evt.GetSender().IsFromMe { - zerolog.Ctx(ctx).Warn().Msg("Ignoring chat mute event from non-self user") - return - } - dp := source.User.DoublePuppet(ctx) - if dp == nil { - return - } - err := dp.MuteRoom(ctx, portal.MXID, evt.GetMutedUntil()) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge chat mute event") - } -} - func (portal *Portal) handleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) { info, err := evt.GetChatInfoChange(ctx) if err != nil { @@ -1406,6 +1371,8 @@ type PortalInfo struct { ExtraUpdates func(context.Context, *Portal) bool } +var Unmuted = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + type UserLocalPortalInfo struct { MutedUntil *time.Time Tag *event.RoomTag From a2353aaef704ebe52325dfa5d8deadd0f5248818 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 28 Jun 2024 00:44:49 +0300 Subject: [PATCH 0368/1647] bridgev2: rename PortalInfo to ChatInfo --- bridgev2/networkinterface.go | 4 ++-- bridgev2/portal.go | 15 +++++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 3dcaee0b..91ce9f54 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -211,7 +211,7 @@ type NetworkAPI interface { LogoutRemote(ctx context.Context) IsThisUser(ctx context.Context, userID networkid.UserID) bool - GetChatInfo(ctx context.Context, portal *Portal) (*PortalInfo, error) + GetChatInfo(ctx context.Context, portal *Portal) (*ChatInfo, error) GetUserInfo(ctx context.Context, ghost *Ghost) (*UserInfo, error) GetCapabilities(ctx context.Context, portal *Portal) *NetworkRoomCapabilities @@ -258,7 +258,7 @@ type CreateChatResponse struct { Portal *Portal PortalID networkid.PortalKey - PortalInfo *PortalInfo + PortalInfo *ChatInfo } type IdentifierResolvingNetworkAPI interface { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a33f08c4..a1d9c541 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1344,18 +1344,21 @@ func (portal *Portal) handleRemoteChatInfoChange(ctx context.Context, source *Us } type ChatInfoChange struct { - PortalInfo *PortalInfo + ChatInfo *ChatInfo // TODO member event changes } func (portal *Portal) ProcessChatInfoChange(ctx context.Context, sender EventSender, source *UserLogin, change *ChatInfoChange, ts time.Time) { intent := portal.GetIntentFor(ctx, sender, source, RemoteEventChatInfoChange) - if change.PortalInfo != nil { - portal.UpdateInfo(ctx, change.PortalInfo, source, intent, ts) + if change.ChatInfo != nil { + portal.UpdateInfo(ctx, change.ChatInfo, source, intent, ts) } } -type PortalInfo struct { +// Deprecated: Renamed to ChatInfo +type PortalInfo = ChatInfo + +type ChatInfo struct { Name *string Topic *string Avatar *Avatar @@ -1655,7 +1658,7 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat return true } -func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, source *UserLogin, sender MatrixAPI, ts time.Time) { +func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *UserLogin, sender MatrixAPI, ts time.Time) { changed := false if info.Name != nil { changed = portal.UpdateName(ctx, *info.Name, sender, ts) || changed @@ -1696,7 +1699,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *PortalInfo, source * } } -func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, info *PortalInfo) error { +func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, info *ChatInfo) error { portal.roomCreateLock.Lock() defer portal.roomCreateLock.Unlock() if portal.MXID != "" { From d7251a4c695056e535ad8c9ded79840bc33ed36d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 28 Jun 2024 02:55:07 +0300 Subject: [PATCH 0369/1647] bridgev2: add support for all member actions, power levels and join rules --- bridgev2/ghost.go | 5 - bridgev2/matrix/connector.go | 5 +- bridgev2/matrix/intent.go | 36 +++ bridgev2/matrixinterface.go | 1 + bridgev2/portal.go | 415 +++++++++++++++++++++++++++-------- event/state.go | 11 +- 6 files changed, 368 insertions(+), 105 deletions(-) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index 1e89b660..c41b79c9 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -84,11 +84,6 @@ func (br *Bridge) GetGhostByID(ctx context.Context, id networkid.UserID) (*Ghost return br.unlockedGetGhostByID(ctx, id, false) } -func (ghost *Ghost) IntentFor(portal *Portal) MatrixAPI { - // TODO use user double puppet intent if appropriate - return ghost.Intent -} - type Avatar struct { ID networkid.AvatarID Get func(ctx context.Context) ([]byte, error) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 9134b6b1..972e5100 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -435,6 +435,10 @@ func (br *Connector) BotIntent() bridgev2.MatrixAPI { return &ASIntent{Connector: br, Matrix: br.Bot} } +func (br *Connector) GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error) { + return br.Bot.PowerLevels(ctx, roomID) +} + func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { // TODO use cache? members, err := br.Bot.Members(ctx, roomID) @@ -443,7 +447,6 @@ func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.U } output := make(map[id.UserID]*event.MemberEventContent, len(members.Chunk)) for _, evt := range members.Chunk { - _ = evt.Content.ParseRaw(evt.Type) output[id.UserID(evt.GetStateKey())] = evt.Content.AsMember() } return output, nil diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 87ad509c..45927ef0 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -61,7 +61,43 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType } } +func (as *ASIntent) fillMemberEvent(ctx context.Context, roomID id.RoomID, userID id.UserID, content *event.Content) { + targetContent := content.Parsed.(*event.MemberEventContent) + if targetContent.Displayname != "" || targetContent.AvatarURL != "" { + return + } + memberContent, err := as.Matrix.StateStore.TryGetMember(ctx, roomID, userID) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("target_user_id", userID). + Str("membership", string(targetContent.Membership)). + Msg("Failed to get old member content from state store to fill new membership event") + } else if memberContent != nil { + targetContent.Displayname = memberContent.Displayname + targetContent.AvatarURL = memberContent.AvatarURL + } else if ghost, err := as.Connector.Bridge.GetGhostByMXID(ctx, userID); err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("target_user_id", userID). + Str("membership", string(targetContent.Membership)). + Msg("Failed to get ghost to fill new membership event") + } else if ghost != nil { + targetContent.Displayname = ghost.Name + targetContent.AvatarURL = ghost.AvatarMXC + } else if profile, err := as.Matrix.GetProfile(ctx, userID); err != nil { + zerolog.Ctx(ctx).Debug().Err(err). + Stringer("target_user_id", userID). + Str("membership", string(targetContent.Membership)). + Msg("Failed to get profile to fill new membership event") + } else if profile != nil { + targetContent.Displayname = profile.DisplayName + targetContent.AvatarURL = profile.AvatarURL.CUString() + } +} + func (as *ASIntent) SendState(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) { + if eventType == event.StateMember { + as.fillMemberEvent(ctx, roomID, id.UserID(stateKey), content) + } if ts.IsZero() { return as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content) } else { diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index f69cda5e..7689011d 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -38,6 +38,7 @@ type MatrixConnector interface { GenerateContentURI(ctx context.Context, mediaID networkid.MediaID) (id.ContentURIString, error) + GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) GetMemberInfo(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a1d9c541..c299e048 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -897,26 +897,47 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { } } -func (portal *Portal) GetIntentFor(ctx context.Context, sender EventSender, source *UserLogin, evtType RemoteEventType) MatrixAPI { - var intent MatrixAPI +func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID) { if sender.IsFromMe { intent = source.User.DoublePuppet(ctx) - } - if intent == nil && sender.SenderLogin != "" { + if intent != nil { + return + } + extraUserID = source.UserMXID + } else if sender.SenderLogin != "" { senderLogin := portal.Bridge.GetCachedUserLoginByID(sender.SenderLogin) if senderLogin != nil { intent = senderLogin.User.DoublePuppet(ctx) + if intent != nil { + return + } + extraUserID = senderLogin.UserMXID } } - if intent == nil && sender.Sender != "" { + if sender.Sender != "" { + for _, login := range otherLogins { + if login.Client.IsThisUser(ctx, sender.Sender) { + intent = login.User.DoublePuppet(ctx) + if intent != nil { + return + } + extraUserID = login.UserMXID + } + } ghost, err := portal.Bridge.GetGhostByID(ctx, sender.Sender) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for message sender") - return nil + return + } else { + ghost.UpdateInfoIfNecessary(ctx, source, evtType) + intent = ghost.Intent } - ghost.UpdateInfoIfNecessary(ctx, source, evtType) - intent = ghost.Intent } + return +} + +func (portal *Portal) GetIntentFor(ctx context.Context, sender EventSender, source *UserLogin, evtType RemoteEventType) MatrixAPI { + intent, _ := portal.getIntentAndUserMXIDFor(ctx, sender, source, nil, evtType) if intent == nil { intent = portal.Bridge.Bot } @@ -1344,8 +1365,12 @@ func (portal *Portal) handleRemoteChatInfoChange(ctx context.Context, source *Us } type ChatInfoChange struct { + // The chat info that changed. Any fields that did not change can be left as nil. ChatInfo *ChatInfo - // TODO member event changes + // A list of member changes. + // This list should only include changes, not the whole member list. + // To resync the whole list, use the field inside ChatInfo. + MemberChanges *ChatMemberList } func (portal *Portal) ProcessChatInfoChange(ctx context.Context, sender EventSender, source *UserLogin, change *ChatInfoChange, ts time.Time) { @@ -1353,17 +1378,99 @@ func (portal *Portal) ProcessChatInfoChange(ctx context.Context, sender EventSen if change.ChatInfo != nil { portal.UpdateInfo(ctx, change.ChatInfo, source, intent, ts) } + if change.MemberChanges != nil { + err := portal.SyncParticipants(ctx, change.MemberChanges, source, intent, ts) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to sync room members") + } + } } // Deprecated: Renamed to ChatInfo type PortalInfo = ChatInfo +type ChatMember struct { + EventSender + Membership event.Membership + Nickname string + PowerLevel int + + PrevMembership event.Membership +} + +type ChatMemberList struct { + // Whether this is the full member list. + // If true, any extra members not listed here will be removed from the portal. + IsFull bool + // Should the bridge call IsThisUser for every member in the list? + // This should be used when SenderLogin can't be filled accurately. + CheckAllLogins bool + + Members []ChatMember + PowerLevels *PowerLevelChanges +} + +type PowerLevelChanges struct { + Events map[event.Type]int + UsersDefault *int + EventsDefault *int + StateDefault *int + Invite *int + Kick *int + Ban *int + Redact *int + + Custom func(*event.PowerLevelsEventContent) bool +} + +func (plc *PowerLevelChanges) Apply(content *event.PowerLevelsEventContent) (changed bool) { + if plc == nil || content == nil { + return + } + for evtType, level := range plc.Events { + changed = content.EnsureEventLevel(evtType, level) || changed + } + if plc.UsersDefault != nil { + changed = content.UsersDefault != *plc.UsersDefault + content.UsersDefault = *plc.UsersDefault + } + if plc.EventsDefault != nil { + changed = content.EventsDefault != *plc.EventsDefault + content.EventsDefault = *plc.EventsDefault + } + if plc.StateDefault != nil { + changed = content.StateDefault() != *plc.StateDefault + content.StateDefaultPtr = plc.StateDefault + } + if plc.Invite != nil { + changed = content.Invite() != *plc.Invite + content.InvitePtr = plc.Invite + } + if plc.Kick != nil { + changed = content.Kick() != *plc.Kick + content.KickPtr = plc.Kick + } + if plc.Ban != nil { + changed = content.Ban() != *plc.Ban + content.BanPtr = plc.Ban + } + if plc.Redact != nil { + changed = content.Redact() != *plc.Redact + content.RedactPtr = plc.Redact + } + if plc.Custom != nil { + changed = plc.Custom(content) || changed + } + return changed +} + type ChatInfo struct { Name *string Topic *string Avatar *Avatar - Members []networkid.UserID + Members *ChatMemberList + JoinRule *event.JoinRulesEventContent IsDirectChat *bool IsSpace *bool @@ -1477,22 +1584,26 @@ func (portal *Portal) UpdateBridgeInfo(ctx context.Context) { portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo) } +func (portal *Portal) sendStateWithIntentOrBot(ctx context.Context, sender MatrixAPI, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (resp *mautrix.RespSendEvent, err error) { + if sender == nil { + sender = portal.Bridge.Bot + } + resp, err = sender.SendState(ctx, portal.MXID, eventType, stateKey, content, ts) + if errors.Is(err, mautrix.MForbidden) && sender != portal.Bridge.Bot { + if content.Raw == nil { + content.Raw = make(map[string]any) + } + content.Raw["fi.mau.bridge.set_by"] = sender.GetMXID() + resp, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateRoomName, "", content, ts) + } + return +} + func (portal *Portal) sendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any) bool { if portal.MXID == "" { return false } - - if sender == nil { - sender = portal.Bridge.Bot - } - wrappedContent := &event.Content{Parsed: content} - _, err := sender.SendState(ctx, portal.MXID, eventType, stateKey, wrappedContent, ts) - if errors.Is(err, mautrix.MForbidden) && sender != portal.Bridge.Bot { - wrappedContent.Raw = map[string]any{ - "fi.mau.bridge.set_by": sender.GetMXID(), - } - _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateRoomName, "", wrappedContent, ts) - } + _, err := portal.sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, &event.Content{Parsed: content}, ts) if err != nil { zerolog.Ctx(ctx).Err(err). Str("event_type", eventType.Type). @@ -1502,81 +1613,173 @@ func (portal *Portal) sendRoomMeta(ctx context.Context, sender MatrixAPI, ts tim return true } -func (portal *Portal) SyncParticipants(ctx context.Context, members []networkid.UserID, source *UserLogin) ([]id.UserID, []id.UserID, error) { - loginsInPortal, err := portal.Bridge.GetUserLoginsInPortal(ctx, portal.PortalKey) - if err != nil { - return nil, nil, fmt.Errorf("failed to get user logins in portal: %w", err) +func (portal *Portal) GetInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) { + if members == nil { + invite = []id.UserID{source.UserMXID} + return } - if !slices.Contains(loginsInPortal, source) { - loginsInPortal = append(loginsInPortal, source) - } - expectedUserIDs := make([]id.UserID, 0, len(members)) - expectedExtraUsers := make([]id.UserID, 0) - expectedIntents := make([]MatrixAPI, len(members)) - extraFunctionalMembers := make([]id.UserID, 0) - for i, member := range members { - isLoggedInUser := false - for _, login := range loginsInPortal { - if login.Client.IsThisUser(ctx, member) { - isLoggedInUser = true - userIntent := login.User.DoublePuppet(ctx) - if userIntent != nil { - expectedIntents[i] = userIntent - } else { - expectedExtraUsers = append(expectedExtraUsers, login.UserMXID) - expectedUserIDs = append(expectedUserIDs, login.UserMXID) - } - break - } - } - ghost, err := portal.Bridge.GetGhostByID(ctx, member) + var loginsInPortal []*UserLogin + if members.CheckAllLogins { + loginsInPortal, err = portal.Bridge.GetUserLoginsInPortal(ctx, portal.PortalKey) if err != nil { - return nil, nil, fmt.Errorf("failed to get ghost for %s: %w", member, err) + err = fmt.Errorf("failed to get user logins in portal: %w", err) + return } - ghost.UpdateInfoIfNecessary(ctx, source, 0) - if expectedIntents[i] == nil { - expectedIntents[i] = ghost.Intent - if isLoggedInUser { - extraFunctionalMembers = append(extraFunctionalMembers, ghost.Intent.GetMXID()) + } + members.PowerLevels.Apply(pl) + for _, member := range members.Members { + intent, extraUserID := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) + if extraUserID != "" { + invite = append(invite, extraUserID) + pl.EnsureUserLevel(extraUserID, member.PowerLevel) + if intent != nil { + // If intent is present along with a user ID, it's the ghost of a logged-in user, + // so add it to the functional members list + functional = append(functional, intent.GetMXID()) } } - expectedUserIDs = append(expectedUserIDs, expectedIntents[i].GetMXID()) + if intent != nil { + invite = append(invite, intent.GetMXID()) + pl.EnsureUserLevel(intent.GetMXID(), member.PowerLevel) + } } - if portal.MXID == "" { - return expectedUserIDs, extraFunctionalMembers, nil + return +} + +func (portal *Portal) SyncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error { + var loginsInPortal []*UserLogin + var err error + if members.CheckAllLogins { + loginsInPortal, err = portal.Bridge.GetUserLoginsInPortal(ctx, portal.PortalKey) + if err != nil { + return fmt.Errorf("failed to get user logins in portal: %w", err) + } + } + if sender == nil { + sender = portal.Bridge.Bot + } + log := zerolog.Ctx(ctx) + currentPower, err := portal.Bridge.Matrix.GetPowerLevels(ctx, portal.MXID) + if err != nil { + return fmt.Errorf("failed to get current power levels: %w", err) } currentMembers, err := portal.Bridge.Matrix.GetMembers(ctx, portal.MXID) + if err != nil { + return fmt.Errorf("failed to get current members: %w", err) + } delete(currentMembers, portal.Bridge.Bot.GetMXID()) - for _, intent := range expectedIntents { - mxid := intent.GetMXID() - memberEvt, ok := currentMembers[mxid] - delete(currentMembers, mxid) - if !ok || memberEvt.Membership != event.MembershipJoin { + powerChanged := members.PowerLevels.Apply(currentPower) + syncUser := func(extraUserID id.UserID, member ChatMember, hasIntent bool) bool { + powerChanged = currentPower.EnsureUserLevel(extraUserID, member.PowerLevel) || powerChanged + currentMember, ok := currentMembers[extraUserID] + delete(currentMembers, extraUserID) + if ok && currentMember.Membership == member.Membership { + return false + } + if currentMember == nil { + currentMember = &event.MemberEventContent{Membership: event.MembershipLeave} + } + if member.PrevMembership != "" && member.PrevMembership != currentMember.Membership { + log.Trace(). + Stringer("user_id", extraUserID). + Str("expected_prev_membership", string(member.PrevMembership)). + Str("actual_prev_membership", string(currentMember.Membership)). + Str("target_membership", string(member.Membership)). + Msg("Not updating membership: prev membership mismatch") + return false + } + content := &event.MemberEventContent{ + Membership: member.Membership, + Displayname: currentMember.Displayname, + AvatarURL: currentMember.AvatarURL, + } + wrappedContent := &event.Content{Parsed: content, Raw: make(map[string]any)} + thisEvtSender := sender + if member.Membership == event.MembershipJoin { + content.Membership = event.MembershipInvite + if hasIntent { + wrappedContent.Raw["fi.mau.will_auto_accept"] = true + } + if thisEvtSender.GetMXID() == extraUserID { + thisEvtSender = portal.Bridge.Bot + } + } + if currentMember != nil && currentMember.Membership == event.MembershipBan && member.Membership != event.MembershipLeave { + unbanContent := *content + unbanContent.Membership = event.MembershipLeave + wrappedUnbanContent := &event.Content{Parsed: &unbanContent} + _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedUnbanContent, ts) + if err != nil { + log.Err(err). + Stringer("target_user_id", extraUserID). + Stringer("sender_user_id", thisEvtSender.GetMXID()). + Str("prev_membership", string(currentMember.Membership)). + Str("membership", string(member.Membership)). + Msg("Failed to unban user to update membership") + } else { + log.Trace(). + Stringer("target_user_id", extraUserID). + Stringer("sender_user_id", thisEvtSender.GetMXID()). + Str("prev_membership", string(currentMember.Membership)). + Str("membership", string(member.Membership)). + Msg("Unbanned user to update membership") + } + } + _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedContent, ts) + if err != nil { + log.Err(err). + Stringer("target_user_id", extraUserID). + Stringer("sender_user_id", thisEvtSender.GetMXID()). + Str("prev_membership", string(currentMember.Membership)). + Str("membership", string(member.Membership)). + Msg("Failed to update user membership") + } else { + log.Trace(). + Stringer("target_user_id", extraUserID). + Stringer("sender_user_id", thisEvtSender.GetMXID()). + Str("prev_membership", string(currentMember.Membership)). + Str("membership", string(member.Membership)). + Msg("Updating membership in room") + } + return true + } + syncIntent := func(intent MatrixAPI, member ChatMember) { + if !syncUser(intent.GetMXID(), member, true) { + return + } + if member.Membership == event.MembershipJoin { err = intent.EnsureJoined(ctx, portal.MXID) if err != nil { - zerolog.Ctx(ctx).Err(err). - Stringer("user_id", mxid). + log.Err(err). + Stringer("user_id", intent.GetMXID()). Msg("Failed to ensure user is joined to room") } } } - for _, mxid := range expectedExtraUsers { - memberEvt, ok := currentMembers[mxid] - delete(currentMembers, mxid) - if !ok || (memberEvt.Membership != event.MembershipJoin && memberEvt.Membership != event.MembershipInvite) { - err = portal.Bridge.Bot.InviteUser(ctx, portal.MXID, mxid) - if err != nil { - zerolog.Ctx(ctx).Err(err). - Stringer("user_id", mxid). - Msg("Failed to invite user to room") - } + for _, member := range members.Members { + intent, extraUserID := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) + if intent != nil { + syncIntent(intent, member) + } + if extraUserID != "" { + syncUser(extraUserID, member, false) } } - if portal.Relay == nil { + if powerChanged { + _, err = portal.sendStateWithIntentOrBot(ctx, sender, event.StatePowerLevels, "", &event.Content{Parsed: currentPower}, ts) + if err != nil { + log.Err(err).Msg("Failed to update power levels") + } + } + if members.IsFull { for extraMember, memberEvt := range currentMembers { if memberEvt.Membership == event.MembershipLeave || memberEvt.Membership == event.MembershipBan { continue } + _, isGhost := portal.Bridge.Matrix.ParseGhostMXID(extraMember) + if !isGhost && portal.Relay != nil { + continue + } _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateMember, extraMember.String(), &event.Content{ Parsed: &event.MemberEventContent{ Membership: event.MembershipLeave, @@ -1592,7 +1795,7 @@ func (portal *Portal) SyncParticipants(ctx context.Context, members []networkid. } } } - return expectedUserIDs, extraFunctionalMembers, nil + return nil } func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPortalInfo, source *UserLogin) { @@ -1672,8 +1875,12 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us if info.Disappear != nil { changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, sender, ts, false, false) || changed } + if info.JoinRule != nil { + // TODO change detection instead of spamming this every time? + portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule) + } if info.Members != nil && portal.MXID != "" && source != nil { - _, _, err := portal.SyncParticipants(ctx, info.Members, source) + err := portal.SyncParticipants(ctx, info.Members, source, nil, time.Time{}) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to sync room members") } @@ -1720,25 +1927,29 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } } portal.UpdateInfo(ctx, info, source, nil, time.Time{}) - initialMembers, extraFunctionalMembers, err := portal.SyncParticipants(ctx, info.Members, source) + powerLevels := &event.PowerLevelsEventContent{ + Events: map[string]int{ + event.StateTombstone.Type: 100, + event.StateServerACL.Type: 100, + event.StateEncryption.Type: 100, + }, + } + initialMembers, extraFunctionalMembers, err := portal.GetInitialMemberList(ctx, info.Members, source, powerLevels) if err != nil { log.Err(err).Msg("Failed to process participant list for portal creation") return err } + powerLevels.EnsureUserLevel(portal.Bridge.Bot.GetMXID(), 9001) req := mautrix.ReqCreateRoom{ - Visibility: "private", - Name: portal.Name, - Topic: portal.Topic, - CreationContent: make(map[string]any), - InitialState: make([]*event.Event, 0, 6), - Preset: "private_chat", - IsDirect: portal.Metadata.IsDirect, - PowerLevelOverride: &event.PowerLevelsEventContent{ - Users: map[id.UserID]int{ - portal.Bridge.Bot.GetMXID(): 9001, - }, - }, + Visibility: "private", + Name: portal.Name, + Topic: portal.Topic, + CreationContent: make(map[string]any), + InitialState: make([]*event.Event, 0, 6), + Preset: "private_chat", + IsDirect: portal.Metadata.IsDirect, + PowerLevelOverride: powerLevels, BeeperLocalRoomID: id.RoomID(fmt.Sprintf("!%s:%s", portal.ID, portal.Bridge.Matrix.ServerName())), BeeperInitialMembers: initialMembers, } @@ -1797,6 +2008,12 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i }}, }) } + if info.JoinRule != nil { + req.InitialState = append(req.InitialState, &event.Event{ + Type: event.StateJoinRules, + Content: event.Content{Parsed: info.JoinRule}, + }) + } roomID, err := portal.Bridge.Bot.CreateRoom(ctx, &req) if err != nil { log.Err(err).Msg("Failed to create Matrix room") @@ -1821,9 +2038,19 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } portal.updateUserLocalInfo(ctx, info.UserLocal, source) if !autoJoinInvites { - _, _, err = portal.SyncParticipants(ctx, info.Members, source) - if err != nil { - log.Err(err).Msg("Failed to sync participants after room creation") + if info.Members == nil { + dp := source.User.DoublePuppet(ctx) + if dp != nil { + err = dp.EnsureJoined(ctx, portal.MXID) + if err != nil { + log.Err(err).Msg("Failed to ensure user is joined to room after creation") + } + } + } else { + err = portal.SyncParticipants(ctx, info.Members, source, nil, time.Time{}) + if err != nil { + log.Err(err).Msg("Failed to sync participants after room creation") + } } } return nil diff --git a/event/state.go b/event/state.go index 20a383d5..c47a91ca 100644 --- a/event/state.go +++ b/event/state.go @@ -87,11 +87,12 @@ type CreateEventContent struct { type JoinRule string const ( - JoinRulePublic JoinRule = "public" - JoinRuleKnock JoinRule = "knock" - JoinRuleInvite JoinRule = "invite" - JoinRuleRestricted JoinRule = "restricted" - JoinRulePrivate JoinRule = "private" + JoinRulePublic JoinRule = "public" + JoinRuleKnock JoinRule = "knock" + JoinRuleInvite JoinRule = "invite" + JoinRuleRestricted JoinRule = "restricted" + JoinRuleKnockRestricted JoinRule = "knock_restricted" + JoinRulePrivate JoinRule = "private" ) // JoinRulesEventContent represents the content of a m.room.join_rules state event. From b0bc6165e7659eb1d07b2653e4cf68ea4dd19b22 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 28 Jun 2024 22:55:11 +0300 Subject: [PATCH 0370/1647] bridgev2: add support for room name/topic/avatar changes --- bridgev2/matrix/connector.go | 3 ++ bridgev2/messagestatus.go | 1 + bridgev2/networkinterface.go | 28 ++++++++++++++ bridgev2/portal.go | 72 ++++++++++++++++++++++++++++++++++++ 4 files changed, 104 insertions(+) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 972e5100..4b913c21 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -110,6 +110,9 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.EventProcessor.On(event.EventRedaction, br.handleRoomEvent) br.EventProcessor.On(event.EventEncrypted, br.handleEncryptedEvent) br.EventProcessor.On(event.StateMember, br.handleRoomEvent) + br.EventProcessor.On(event.StateRoomName, br.handleRoomEvent) + br.EventProcessor.On(event.StateRoomAvatar, br.handleRoomEvent) + br.EventProcessor.On(event.StateTopic, br.handleRoomEvent) br.EventProcessor.On(event.EphemeralEventReceipt, br.handleEphemeralEvent) br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent) br.Bot = br.AS.BotIntent() diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index f642f73b..201b4080 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -25,6 +25,7 @@ var ( ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true) ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true) ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true) + ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithSendNotice(false) ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true) ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 91ce9f54..0d14e309 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -245,6 +245,21 @@ type TypingHandlingNetworkAPI interface { HandleMatrixTyping(ctx context.Context, msg *MatrixTyping) error } +type RoomNameHandlingNetworkAPI interface { + NetworkAPI + HandleMatrixRoomName(ctx context.Context, msg *MatrixRoomName) (bool, error) +} + +type RoomAvatarHandlingNetworkAPI interface { + NetworkAPI + HandleMatrixRoomAvatar(ctx context.Context, msg *MatrixRoomAvatar) (bool, error) +} + +type RoomTopicHandlingNetworkAPI interface { + NetworkAPI + HandleMatrixRoomTopic(ctx context.Context, msg *MatrixRoomTopic) (bool, error) +} + type ResolveIdentifierResponse struct { Ghost *Ghost @@ -600,6 +615,19 @@ type MatrixMessageRemove struct { TargetMessage *database.Message } +type RoomMetaEventContent interface { + *event.RoomNameEventContent | *event.RoomAvatarEventContent | *event.TopicEventContent +} + +type MatrixRoomMeta[ContentType RoomMetaEventContent] struct { + MatrixEventBase[ContentType] + PrevContent ContentType +} + +type MatrixRoomName = MatrixRoomMeta[*event.RoomNameEventContent] +type MatrixRoomAvatar = MatrixRoomMeta[*event.RoomAvatarEventContent] +type MatrixRoomTopic = MatrixRoomMeta[*event.TopicEventContent] + type MatrixReadReceipt struct { Portal *Portal // The event ID that the receipt is targeting diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c299e048..0b18a535 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -247,6 +247,7 @@ func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { } log := portal.Log.With(). Str("action", "handle matrix event"). + Str("event_type", evt.Type.Type). Stringer("event_id", evt.ID). Stringer("sender", sender.MXID). Logger() @@ -287,8 +288,11 @@ func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { case event.EventRedaction: portal.handleMatrixRedaction(ctx, login, origSender, evt) case event.StateRoomName: + handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomNameHandlingNetworkAPI.HandleMatrixRoomName) case event.StateTopic: + handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic) case event.StateRoomAvatar: + handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar) case event.StateEncryption: } } @@ -772,6 +776,73 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi portal.sendSuccessStatus(ctx, evt) } +func handleMatrixRoomMeta[APIType any, ContentType RoomMetaEventContent]( + portal *Portal, + ctx context.Context, + sender *UserLogin, + origSender *OrigSender, + evt *event.Event, + fn func(APIType, context.Context, *MatrixRoomMeta[ContentType]) (bool, error), +) { + api, ok := sender.Client.(APIType) + if !ok { + portal.sendErrorStatus(ctx, evt, ErrRoomMetadataNotSupported) + return + } + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(ContentType) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + return + } + switch typedContent := evt.Content.Parsed.(type) { + case *event.RoomNameEventContent: + if typedContent.Name == portal.Name { + portal.sendSuccessStatus(ctx, evt) + return + } + case *event.TopicEventContent: + if typedContent.Topic == portal.Topic { + portal.sendSuccessStatus(ctx, evt) + return + } + case *event.RoomAvatarEventContent: + if typedContent.URL == portal.AvatarMXC { + portal.sendSuccessStatus(ctx, evt) + return + } + } + var prevContent ContentType + if evt.Unsigned.PrevContent != nil { + _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) + prevContent, _ = evt.Unsigned.PrevContent.Parsed.(ContentType) + } + + changed, err := fn(api, ctx, &MatrixRoomMeta[ContentType]{ + MatrixEventBase: MatrixEventBase[ContentType]{ + Event: evt, + Content: content, + Portal: portal, + OrigSender: origSender, + }, + PrevContent: prevContent, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix room metadata") + portal.sendErrorStatus(ctx, evt, err) + return + } + if changed { + portal.UpdateBridgeInfo(ctx) + err = portal.Save(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal after updating room metadata") + } + } + portal.sendSuccessStatus(ctx, evt) +} + func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.RedactionEventContent) @@ -823,6 +894,7 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog portal.sendErrorStatus(ctx, evt, ErrReactionsNotSupported) return } + // TODO ignore if sender doesn't match? err = reactingAPI.HandleMatrixReactionRemove(ctx, &MatrixReactionRemove{ MatrixEventBase: MatrixEventBase[*event.RedactionEventContent]{ Event: evt, From 86ac5c340b353dfcf10444987656205e6866fc26 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 28 Jun 2024 23:01:58 +0300 Subject: [PATCH 0371/1647] event/beeper: omit empty network value in MSS events --- event/beeper.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/event/beeper.go b/event/beeper.go index 3287e494..1394a6ce 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -32,7 +32,7 @@ const ( ) type BeeperMessageStatusEventContent struct { - Network string `json:"network"` + Network string `json:"network,omitempty"` RelatesTo RelatesTo `json:"m.relates_to"` Status MessageStatus `json:"status"` Reason MessageStatusReason `json:"reason,omitempty"` From 46b4ab4c9d40144a6b78c8597a10aea1ba0e607c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 29 Jun 2024 11:23:14 +0300 Subject: [PATCH 0372/1647] bridgev2: start adding relay support --- bridgev2/bridgeconfig/config.go | 5 +- bridgev2/bridgeconfig/relay.go | 94 ++++++++++++++++ bridgev2/bridgeconfig/upgrade.go | 7 ++ bridgev2/database/portal.go | 49 +++++---- bridgev2/database/upgrades/00-latest.sql | 16 ++- .../upgrades/03-portal-relay-postgres.sql | 13 +++ .../upgrades/04-portal-relay-sqlite.sql | 100 ++++++++++++++++++ bridgev2/matrix/mxmain/example-config.yaml | 19 ++++ bridgev2/portal.go | 27 ++++- bridgev2/userlogin.go | 12 ++- 10 files changed, 311 insertions(+), 31 deletions(-) create mode 100644 bridgev2/bridgeconfig/relay.go create mode 100644 bridgev2/database/upgrades/03-portal-relay-postgres.sql create mode 100644 bridgev2/database/upgrades/04-portal-relay-sqlite.sql diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index fc426320..90cf7b29 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -32,8 +32,9 @@ type Config struct { } type BridgeConfig struct { - CommandPrefix string `yaml:"command_prefix"` - PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` + CommandPrefix string `yaml:"command_prefix"` + PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` + Relay RelayConfig `yaml:"relay"` } type MatrixConfig struct { diff --git a/bridgev2/bridgeconfig/relay.go b/bridgev2/bridgeconfig/relay.go new file mode 100644 index 00000000..e71a7d1d --- /dev/null +++ b/bridgev2/bridgeconfig/relay.go @@ -0,0 +1,94 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +import ( + "fmt" + "strings" + "text/template" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" +) + +type RelayConfig struct { + Enabled bool `yaml:"enabled"` + AdminOnly bool `yaml:"admin_only"` + DefaultRelays []networkid.UserLoginID `yaml:"default_relays"` + MessageFormats map[event.MessageType]string `yaml:"message_formats"` + messageTemplates *template.Template `yaml:"-"` +} + +type umRelayConfig RelayConfig + +func (rc *RelayConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + err := unmarshal((*umRelayConfig)(rc)) + if err != nil { + return err + } + + rc.messageTemplates = template.New("messageTemplates") + for key, format := range rc.MessageFormats { + _, err = rc.messageTemplates.New(string(key)).Parse(format) + if err != nil { + return err + } + } + + return nil +} + +type formatData struct { + Sender any + Content *event.MessageEventContent + Caption string + Message string + FileName string +} + +func isMedia(msgType event.MessageType) bool { + switch msgType { + case event.MsgImage, event.MsgVideo, event.MsgAudio, event.MsgFile: + return true + default: + return false + } +} + +func (rc *RelayConfig) FormatMessage(content *event.MessageEventContent, sender any) (*event.MessageEventContent, error) { + _, isSupported := rc.MessageFormats[content.MsgType] + if !isSupported { + return nil, fmt.Errorf("unsupported msgtype for relaying") + } + contentCopy := *content + content = &contentCopy + content.EnsureHasHTML() + fd := &formatData{ + Sender: sender, + Content: content, + Message: content.FormattedBody, + } + fd.Message = content.FormattedBody + if content.FileName != "" { + fd.FileName = content.FileName + if content.FileName != content.Body { + fd.Caption = fd.Message + } + } else if isMedia(content.MsgType) { + content.FileName = content.Body + fd.FileName = content.Body + } + var output strings.Builder + err := rc.messageTemplates.ExecuteTemplate(&output, string(content.MsgType), fd) + if err != nil { + return nil, err + } + content.FormattedBody = output.String() + content.Body = format.HTMLToText(content.FormattedBody) + return content, nil +} diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 6f0610a9..766d6fdf 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -23,6 +23,11 @@ func doUpgrade(helper up.Helper) { } helper.Copy(up.Str, "bridge", "command_prefix") + helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") + helper.Copy(up.Bool, "bridge", "relay", "enabled") + helper.Copy(up.Bool, "bridge", "relay", "admin_only") + helper.Copy(up.List, "bridge", "relay", "default_relays") + helper.Copy(up.Map, "bridge", "relay", "message_formats") if dbType, ok := helper.Get(up.Str, "database", "type"); ok && dbType == "sqlite3" { helper.Set(up.Str, "sqlite3-fk-wal", "database", "type") @@ -165,6 +170,8 @@ func doMigrateLegacy(helper up.Helper) { helper.Copy(up.Str, "bridge", "command_prefix") helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") + helper.Copy(up.Bool, "bridge", "relay", "enabled") + helper.Copy(up.Bool, "bridge", "relay", "admin_only") CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "type"}, []string{"database", "type"}) CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "uri"}, []string{"database", "uri"}) diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 85e6f9c1..0e8c3665 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -48,17 +48,18 @@ type Portal struct { networkid.PortalKey MXID id.RoomID - ParentID networkid.PortalID - Name string - Topic string - AvatarID networkid.AvatarID - AvatarHash [32]byte - AvatarMXC id.ContentURIString - NameSet bool - TopicSet bool - AvatarSet bool - InSpace bool - Metadata PortalMetadata + ParentID networkid.PortalID + RelayLoginID networkid.UserLoginID + Name string + Topic string + AvatarID networkid.AvatarID + AvatarHash [32]byte + AvatarMXC id.ContentURIString + NameSet bool + TopicSet bool + AvatarSet bool + InSpace bool + Metadata PortalMetadata } func newPortal(_ *dbutil.QueryHelper[*Portal]) *Portal { @@ -67,7 +68,8 @@ func newPortal(_ *dbutil.QueryHelper[*Portal]) *Portal { const ( getPortalBaseQuery = ` - SELECT bridge_id, id, receiver, mxid, parent_id, name, topic, avatar_id, avatar_hash, avatar_mxc, + SELECT bridge_id, id, receiver, mxid, parent_id, relay_login_id, + name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, topic_set, avatar_set, in_space, metadata FROM portal @@ -80,15 +82,20 @@ const ( insertPortalQuery = ` INSERT INTO portal ( bridge_id, id, receiver, mxid, - parent_id, name, topic, avatar_id, avatar_hash, avatar_mxc, + parent_id, relay_login_id, + name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, avatar_set, topic_set, in_space, - metadata - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) + metadata, relay_bridge_id + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, + CASE WHEN $6 IS NULL THEN NULL ELSE $1 END + ) ` updatePortalQuery = ` UPDATE portal - SET mxid=$4, parent_id=$5, name=$6, topic=$7, avatar_id=$8, avatar_hash=$9, avatar_mxc=$10, - name_set=$11, avatar_set=$12, topic_set=$13, in_space=$14, metadata=$15 + SET mxid=$4, parent_id=$5, relay_bridge_id=CASE WHEN $6 IS NULL THEN NULL ELSE bridge_id END, relay_login_id=$6, + name=$7, topic=$8, avatar_id=$9, avatar_hash=$10, avatar_mxc=$11, + name_set=$12, avatar_set=$13, topic_set=$14, in_space=$15, metadata=$16 WHERE bridge_id=$1 AND id=$2 AND receiver=$3 ` deletePortalQuery = ` @@ -133,11 +140,11 @@ func (pq *PortalQuery) Delete(ctx context.Context, key networkid.PortalKey) erro } func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { - var mxid, parentID sql.NullString + var mxid, parentID, relayLoginID sql.NullString var avatarHash string err := row.Scan( &p.BridgeID, &p.ID, &p.Receiver, &mxid, - &parentID, &p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC, + &parentID, &relayLoginID, &p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC, &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.InSpace, dbutil.JSON{Data: &p.Metadata}, ) @@ -155,6 +162,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { } p.MXID = id.RoomID(mxid.String) p.ParentID = networkid.PortalID(parentID.String) + p.RelayLoginID = networkid.UserLoginID(relayLoginID.String) return p, nil } @@ -168,7 +176,8 @@ func (p *Portal) sqlVariables() []any { } return []any{ p.BridgeID, p.ID, p.Receiver, dbutil.StrPtr(p.MXID), - dbutil.StrPtr(p.ParentID), p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC, + dbutil.StrPtr(p.ParentID), dbutil.StrPtr(p.RelayLoginID), + p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC, p.NameSet, p.TopicSet, p.AvatarSet, p.InSpace, dbutil.JSON{Data: &p.Metadata}, } diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index adceab3c..f60e5ea8 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v2 (compatible with v1+): Latest revision +-- v0 -> v4 (compatible with v1+): Latest revision CREATE TABLE portal ( bridge_id TEXT NOT NULL, id TEXT NOT NULL, @@ -10,6 +10,9 @@ CREATE TABLE portal ( -- Parent groups are probably never DMs, so they don't need a receiver. parent_receiver TEXT NOT NULL DEFAULT '', + relay_bridge_id TEXT, + relay_login_id TEXT, + name TEXT NOT NULL, topic TEXT NOT NULL, avatar_id TEXT NOT NULL, @@ -25,7 +28,10 @@ CREATE TABLE portal ( CONSTRAINT portal_parent_fkey FOREIGN KEY (bridge_id, parent_id, parent_receiver) -- Deletes aren't allowed to cascade here: -- children should be re-parented or cleaned up manually - REFERENCES portal (bridge_id, id, receiver) ON UPDATE CASCADE + REFERENCES portal (bridge_id, id, receiver) ON UPDATE CASCADE, + CONSTRAINT portal_relay_fkey FOREIGN KEY (relay_bridge_id, relay_login_id) + REFERENCES user_login (bridge_id, id) + ON DELETE SET NULL ON UPDATE CASCADE ); CREATE TABLE ghost ( @@ -129,7 +135,7 @@ CREATE TABLE user_login ( space_room TEXT, metadata jsonb NOT NULL, - PRIMARY KEY (bridge_id, user_mxid, id), + PRIMARY KEY (bridge_id, id), CONSTRAINT user_login_user_fkey FOREIGN KEY (bridge_id, user_mxid) REFERENCES "user" (bridge_id, mxid) ON DELETE CASCADE ON UPDATE CASCADE @@ -146,8 +152,8 @@ CREATE TABLE user_portal ( last_read BIGINT, PRIMARY KEY (bridge_id, user_mxid, login_id, portal_id, portal_receiver), - CONSTRAINT user_portal_user_login_fkey FOREIGN KEY (bridge_id, user_mxid, login_id) - REFERENCES user_login (bridge_id, user_mxid, id) + CONSTRAINT user_portal_user_login_fkey FOREIGN KEY (bridge_id, login_id) + REFERENCES user_login (bridge_id, id) ON DELETE CASCADE ON UPDATE CASCADE, CONSTRAINT user_portal_portal_fkey FOREIGN KEY (bridge_id, portal_id, portal_receiver) REFERENCES portal (bridge_id, id, receiver) diff --git a/bridgev2/database/upgrades/03-portal-relay-postgres.sql b/bridgev2/database/upgrades/03-portal-relay-postgres.sql new file mode 100644 index 00000000..4ea52ac6 --- /dev/null +++ b/bridgev2/database/upgrades/03-portal-relay-postgres.sql @@ -0,0 +1,13 @@ +-- v3 (compatible with v1+): Add relay column for portals (Postgres) +-- only: postgres +ALTER TABLE portal ADD COLUMN relay_bridge_id TEXT; +ALTER TABLE portal ADD COLUMN relay_login_id TEXT; +ALTER TABLE user_portal DROP CONSTRAINT user_portal_user_login_fkey; +ALTER TABLE user_login DROP CONSTRAINT user_login_pkey; +ALTER TABLE user_login ADD CONSTRAINT user_login_pkey PRIMARY KEY (bridge_id, id); +ALTER TABLE user_portal ADD CONSTRAINT user_portal_user_login_fkey FOREIGN KEY (bridge_id, login_id) + REFERENCES user_login (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE; +ALTER TABLE portal ADD CONSTRAINT portal_relay_fkey FOREIGN KEY (relay_bridge_id, relay_login_id) + REFERENCES user_login (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/bridgev2/database/upgrades/04-portal-relay-sqlite.sql b/bridgev2/database/upgrades/04-portal-relay-sqlite.sql new file mode 100644 index 00000000..04385958 --- /dev/null +++ b/bridgev2/database/upgrades/04-portal-relay-sqlite.sql @@ -0,0 +1,100 @@ +-- v4 (compatible with v1+): Add relay column for portals (SQLite) +-- transaction: off +-- only: sqlite + +PRAGMA foreign_keys = OFF; +BEGIN; + +CREATE TABLE user_login_new ( + bridge_id TEXT NOT NULL, + user_mxid TEXT NOT NULL, + id TEXT NOT NULL, + space_room TEXT, + metadata jsonb NOT NULL, + + PRIMARY KEY (bridge_id, id), + CONSTRAINT user_login_user_fkey FOREIGN KEY (bridge_id, user_mxid) + REFERENCES "user" (bridge_id, mxid) + ON DELETE CASCADE ON UPDATE CASCADE +); + +INSERT INTO user_login_new +SELECT bridge_id, user_mxid, id, space_room, metadata +FROM user_login; + +DROP TABLE user_login; +ALTER TABLE user_login_new RENAME TO user_login; + + +CREATE TABLE user_portal_new ( + bridge_id TEXT NOT NULL, + user_mxid TEXT NOT NULL, + login_id TEXT NOT NULL, + portal_id TEXT NOT NULL, + portal_receiver TEXT NOT NULL, + in_space BOOLEAN NOT NULL, + preferred BOOLEAN NOT NULL, + last_read BIGINT, + + PRIMARY KEY (bridge_id, user_mxid, login_id, portal_id, portal_receiver), + CONSTRAINT user_portal_user_login_fkey FOREIGN KEY (bridge_id, login_id) + REFERENCES user_login (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT user_portal_portal_fkey FOREIGN KEY (bridge_id, portal_id, portal_receiver) + REFERENCES portal (bridge_id, id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE +); + +INSERT INTO user_portal_new +SELECT bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred, last_read +FROM user_portal; + +DROP TABLE user_portal; +ALTER TABLE user_portal_new RENAME TO user_portal; + +CREATE TABLE portal_new ( + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, + receiver TEXT NOT NULL, + mxid TEXT, + + parent_id TEXT, + -- This is not accessed by the bridge, it's only used for the portal parent foreign key. + -- Parent groups are probably never DMs, so they don't need a receiver. + parent_receiver TEXT NOT NULL DEFAULT '', + + relay_bridge_id TEXT, + relay_login_id TEXT, + + name TEXT NOT NULL, + topic TEXT NOT NULL, + avatar_id TEXT NOT NULL, + avatar_hash TEXT NOT NULL, + avatar_mxc TEXT NOT NULL, + name_set BOOLEAN NOT NULL, + avatar_set BOOLEAN NOT NULL, + topic_set BOOLEAN NOT NULL, + in_space BOOLEAN NOT NULL, + metadata jsonb NOT NULL, + + PRIMARY KEY (bridge_id, id, receiver), + CONSTRAINT portal_parent_fkey FOREIGN KEY (bridge_id, parent_id, parent_receiver) + -- Deletes aren't allowed to cascade here: + -- children should be re-parented or cleaned up manually + REFERENCES portal (bridge_id, id, receiver) ON UPDATE CASCADE, + CONSTRAINT portal_relay_fkey FOREIGN KEY (relay_bridge_id, relay_login_id) + REFERENCES user_login (bridge_id, id) + ON DELETE SET NULL ON UPDATE CASCADE +); + +INSERT INTO portal_new +SELECT bridge_id, id, receiver, mxid, parent_id, parent_receiver, NULL, NULL, + name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, avatar_set, topic_set, in_space, metadata +FROM portal; + +DROP TABLE portal; +ALTER TABLE portal_new RENAME TO portal; + +PRAGMA foreign_key_check; +COMMIT; +PRAGMA foreign_keys = ON; diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 1f691d20..b6559657 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -4,6 +4,25 @@ bridge: command_prefix: '$<>' # Should the bridge create a space for each login containing the rooms that account is in? personal_filtering_spaces: true + # Settings for relay mode + relay: + # Whether relay mode should be allowed. If allowed, the set-relay command can be used to turn any + # authenticated user into a relaybot for that chat. + enabled: false + # Should only admins be allowed to set themselves as relay users? + admin_only: true + # List of user login IDs which anyone can set as a relay using set-default-relay as long as they're in the room. + default_relays: [] + # The formats to use when sending messages via the relaybot. + message_formats: + m.text: "{{ .Sender.Displayname }}: {{ .Message }}" + m.notice: "{{ .Sender.Displayname }}: {{ .Message }}" + m.emote: "* {{ .Sender.Displayname }} {{ .Message }}" + m.file: "{{ .Sender.Displayname }} sent a file{{ if .Caption }}: {{ .Caption }}{{ end }}" + m.image: "{{ .Sender.Displayname }} sent an image{{ if .Caption }}: {{ .Caption }}{{ end }}" + m.audio: "{{ .Sender.Displayname }} sent an audio file{{ if .Caption }}: {{ .Caption }}{{ end }}" + m.video: "{{ .Sender.Displayname }} sent a video{{ if .Caption }}: {{ .Caption }}{{ end }}" + m.location: "{{ .Sender.Displayname }} sent a location{{ if .Caption }}: {{ .Caption }}{{ end }}" # Config for the bridge's database. database: diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 0b18a535..7a618897 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -91,13 +91,19 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que if portal.MXID != "" { br.portalsByMXID[portal.MXID] = portal } + var err error if portal.ParentID != "" { - var err error portal.Parent, err = br.unlockedGetPortalByID(ctx, networkid.PortalKey{ID: portal.ParentID}, false) if err != nil { return nil, fmt.Errorf("failed to load parent portal (%s): %w", portal.ParentID, err) } } + if portal.RelayLoginID != "" { + portal.Relay, err = br.unlockedGetExistingUserLoginByID(ctx, portal.RelayLoginID) + if err != nil { + return nil, fmt.Errorf("failed to load relay login (%s): %w", portal.RelayLoginID, err) + } + } portal.updateLogger() go portal.eventLoop() return portal, nil @@ -508,12 +514,20 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin portal.handleMatrixEdit(ctx, sender, origSender, evt, content, caps) return } + var err error + if origSender != nil { + content, err = portal.Bridge.Config.Relay.FormatMessage(content, origSender) + if err != nil { + log.Err(err).Msg("Failed to format message for relaying") + portal.sendErrorStatus(ctx, evt, err) + return + } + } if !portal.checkMessageContentCaps(ctx, caps, content, evt) { return } var threadRoot, replyTo *database.Message - var err error if caps.Threads { threadRootID := content.RelatesTo.GetThreadParent() if threadRootID != "" { @@ -594,6 +608,15 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o if content.NewContent != nil { content = content.NewContent } + if origSender != nil { + var err error + content, err = portal.Bridge.Config.Relay.FormatMessage(content, origSender) + if err != nil { + log.Err(err).Msg("Failed to format message for relaying") + portal.sendErrorStatus(ctx, evt, err) + return + } + } editingAPI, ok := sender.Client.(EditHandlingNetworkAPI) if !ok { diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 1bc81190..22f7504c 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -102,12 +102,20 @@ func (br *Bridge) GetUserLoginsInPortal(ctx context.Context, portal networkid.Po } func (br *Bridge) GetExistingUserLoginByID(ctx context.Context, id networkid.UserLoginID) (*UserLogin, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.unlockedGetExistingUserLoginByID(ctx, id) +} + +func (br *Bridge) unlockedGetExistingUserLoginByID(ctx context.Context, id networkid.UserLoginID) (*UserLogin, error) { + cached, ok := br.userLoginsByID[id] + if ok { + return cached, nil + } login, err := br.DB.UserLogin.GetByID(ctx, id) if err != nil { return nil, err } - br.cacheLock.Lock() - defer br.cacheLock.Unlock() return br.loadUserLogin(ctx, nil, login) } From 5782506e9ead609ab964f2e2c8af00b2345da7e8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 29 Jun 2024 11:50:43 +0300 Subject: [PATCH 0373/1647] bridgev2: move commands to subpackage --- bridgev2/bridge.go | 18 ++++- bridgev2/cmdmeta.go | 41 ------------ bridgev2/{cmddebug.go => commands/debug.go} | 11 ++-- bridgev2/{cmdevent.go => commands/event.go} | 28 ++++---- .../{cmdhandler.go => commands/handler.go} | 16 ++--- bridgev2/{cmdhelp.go => commands/help.go} | 17 ++++- bridgev2/{cmdlogin.go => commands/login.go} | 64 +++++++++--------- .../processor.go} | 65 +++++++++++++++---- .../startchat.go} | 12 ++-- bridgev2/matrix/cmdadmin.go | 18 ++--- bridgev2/matrix/cmddoublepuppet.go | 26 ++++---- bridgev2/matrix/connector.go | 3 +- bridgev2/matrix/mxmain/main.go | 11 ++-- bridgev2/user.go | 4 +- 14 files changed, 182 insertions(+), 152 deletions(-) delete mode 100644 bridgev2/cmdmeta.go rename bridgev2/{cmddebug.go => commands/debug.go} (86%) rename bridgev2/{cmdevent.go => commands/event.go} (81%) rename bridgev2/{cmdhandler.go => commands/handler.go} (85%) rename bridgev2/{cmdhelp.go => commands/help.go} (92%) rename bridgev2/{cmdlogin.go => commands/login.go} (84%) rename bridgev2/{cmdprocessor.go => commands/processor.go} (68%) rename bridgev2/{cmdstartchat.go => commands/startchat.go} (91%) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 36f7aa06..65d39c80 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -24,6 +24,10 @@ import ( var ErrNotLoggedIn = errors.New("not logged in") +type CommandProcessor interface { + Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user *User, message string, replyTo id.EventID) +} + type Bridge struct { ID networkid.BridgeID DB *database.Database @@ -32,7 +36,7 @@ type Bridge struct { Matrix MatrixConnector Bot MatrixAPI Network NetworkConnector - Commands *CommandProcessor + Commands CommandProcessor Config *bridgeconfig.BridgeConfig DisappearLoop *DisappearLoop @@ -45,7 +49,15 @@ type Bridge struct { cacheLock sync.Mutex } -func NewBridge(bridgeID networkid.BridgeID, db *dbutil.Database, log zerolog.Logger, cfg *bridgeconfig.BridgeConfig, matrix MatrixConnector, network NetworkConnector) *Bridge { +func NewBridge( + bridgeID networkid.BridgeID, + db *dbutil.Database, + log zerolog.Logger, + cfg *bridgeconfig.BridgeConfig, + matrix MatrixConnector, + network NetworkConnector, + newCommandProcessor func(*Bridge) CommandProcessor, +) *Bridge { br := &Bridge{ ID: bridgeID, DB: database.New(bridgeID, db), @@ -64,7 +76,7 @@ func NewBridge(bridgeID networkid.BridgeID, db *dbutil.Database, log zerolog.Log if br.Config == nil { br.Config = &bridgeconfig.BridgeConfig{CommandPrefix: "!bridge"} } - br.Commands = NewProcessor(br) + br.Commands = newCommandProcessor(br) br.Matrix.Init(br) br.Bot = br.Matrix.BotIntent() br.Network.Init(br) diff --git a/bridgev2/cmdmeta.go b/bridgev2/cmdmeta.go deleted file mode 100644 index d866998f..00000000 --- a/bridgev2/cmdmeta.go +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) 2022 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridgev2 - -var CommandHelp = &FullHandler{ - Func: func(ce *CommandEvent) { - ce.Reply(FormatHelp(ce)) - }, - Name: "help", - Help: HelpMeta{ - Section: HelpSectionGeneral, - Description: "Show this help message.", - }, -} - -var CommandCancel = &FullHandler{ - Func: func(ce *CommandEvent) { - state := ce.User.CommandState.Swap(nil) - if state != nil { - action := state.Action - if action == "" { - action = "Unknown action" - } - if state.Cancel != nil { - state.Cancel() - } - ce.Reply("%s cancelled.", action) - } else { - ce.Reply("No ongoing command.") - } - }, - Name: "cancel", - Help: HelpMeta{ - Section: HelpSectionGeneral, - Description: "Cancel an ongoing action.", - }, -} diff --git a/bridgev2/cmddebug.go b/bridgev2/commands/debug.go similarity index 86% rename from bridgev2/cmddebug.go rename to bridgev2/commands/debug.go index 400470ed..d00697ee 100644 --- a/bridgev2/cmddebug.go +++ b/bridgev2/commands/debug.go @@ -4,22 +4,23 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -package bridgev2 +package commands import ( "strings" + "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" ) var CommandRegisterPush = &FullHandler{ - Func: func(ce *CommandEvent) { + Func: func(ce *Event) { if len(ce.Args) < 3 { ce.Reply("Usage: `$cmdprefix debug-register-push `\n\nYour logins:\n\n%s", ce.User.GetFormattedUserLogins()) return } - pushType := PushTypeFromString(ce.Args[1]) - if pushType == PushTypeUnknown { + pushType := bridgev2.PushTypeFromString(ce.Args[1]) + if pushType == bridgev2.PushTypeUnknown { ce.Reply("Unknown push type `%s`. Allowed types: `web`, `apns`, `fcm`", ce.Args[1]) return } @@ -28,7 +29,7 @@ var CommandRegisterPush = &FullHandler{ ce.Reply("Login `%s` not found", ce.Args[0]) return } - pushable, ok := login.Client.(PushableNetworkAPI) + pushable, ok := login.Client.(bridgev2.PushableNetworkAPI) if !ok { ce.Reply("This network connector does not support push registration") return diff --git a/bridgev2/cmdevent.go b/bridgev2/commands/event.go similarity index 81% rename from bridgev2/cmdevent.go rename to bridgev2/commands/event.go index 0c80330d..d03c7e68 100644 --- a/bridgev2/cmdevent.go +++ b/bridgev2/commands/event.go @@ -4,7 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -package bridgev2 +package commands import ( "context" @@ -14,22 +14,24 @@ import ( "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/id" ) -// CommandEvent stores all data which might be used to handle commands -type CommandEvent struct { - Bot MatrixAPI - Bridge *Bridge - Portal *Portal - Processor *CommandProcessor +// Event stores all data which might be used to handle commands +type Event struct { + Bot bridgev2.MatrixAPI + Bridge *bridgev2.Bridge + Portal *bridgev2.Portal + Processor *Processor Handler MinimalCommandHandler RoomID id.RoomID EventID id.EventID - User *User + User *bridgev2.User Command string Args []string RawArgs string @@ -39,7 +41,7 @@ type CommandEvent struct { } // Reply sends a reply to command as notice, with optional string formatting and automatic $cmdprefix replacement. -func (ce *CommandEvent) Reply(msg string, args ...any) { +func (ce *Event) Reply(msg string, args ...any) { msg = strings.ReplaceAll(msg, "$cmdprefix ", ce.Bridge.Config.CommandPrefix+" ") if len(args) > 0 { msg = fmt.Sprintf(msg, args...) @@ -49,7 +51,7 @@ func (ce *CommandEvent) Reply(msg string, args ...any) { // ReplyAdvanced sends a reply to command as notice. It allows using HTML and disabling markdown, // but doesn't have built-in string formatting. -func (ce *CommandEvent) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { +func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { content := format.RenderMarkdown(msg, allowMarkdown, allowHTML) content.MsgType = event.MsgNotice _, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, time.Now()) @@ -59,7 +61,7 @@ func (ce *CommandEvent) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) } // React sends a reaction to the command. -func (ce *CommandEvent) React(key string) { +func (ce *Event) React(key string) { _, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventReaction, &event.Content{ Parsed: &event.ReactionEventContent{ RelatesTo: event.RelatesTo{ @@ -75,7 +77,7 @@ func (ce *CommandEvent) React(key string) { } // Redact redacts the command. -func (ce *CommandEvent) Redact(req ...mautrix.ReqRedact) { +func (ce *Event) Redact(req ...mautrix.ReqRedact) { _, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ Redacts: ce.EventID, @@ -87,7 +89,7 @@ func (ce *CommandEvent) Redact(req ...mautrix.ReqRedact) { } // MarkRead marks the command event as read. -func (ce *CommandEvent) MarkRead() { +func (ce *Event) MarkRead() { // TODO //err := ce.Bot.SendReceipt(ce.Ctx, ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil) //if err != nil { diff --git a/bridgev2/cmdhandler.go b/bridgev2/commands/handler.go similarity index 85% rename from bridgev2/cmdhandler.go rename to bridgev2/commands/handler.go index 55db056f..b8ff7019 100644 --- a/bridgev2/cmdhandler.go +++ b/bridgev2/commands/handler.go @@ -4,19 +4,19 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -package bridgev2 +package commands import ( "maunium.net/go/mautrix/event" ) type MinimalCommandHandler interface { - Run(*CommandEvent) + Run(*Event) } -type MinimalCommandHandlerFunc func(*CommandEvent) +type MinimalCommandHandlerFunc func(*Event) -func (mhf MinimalCommandHandlerFunc) Run(ce *CommandEvent) { +func (mhf MinimalCommandHandlerFunc) Run(ce *Event) { mhf(ce) } @@ -38,7 +38,7 @@ type AliasedCommandHandler interface { } type FullHandler struct { - Func func(*CommandEvent) + Func func(*Event) Name string Aliases []string @@ -64,12 +64,12 @@ func (fh *FullHandler) GetAliases() []string { return fh.Aliases } -func (fh *FullHandler) ShowInHelp(ce *CommandEvent) bool { +func (fh *FullHandler) ShowInHelp(ce *Event) bool { return true //return !fh.RequiresAdmin || ce.User.GetPermissionLevel() >= bridgeconfig.PermissionLevelAdmin } -func (fh *FullHandler) userHasRoomPermission(ce *CommandEvent) bool { +func (fh *FullHandler) userHasRoomPermission(ce *Event) bool { return true //levels, err := ce.MainIntent().PowerLevels(ce.Ctx, ce.RoomID) //if err != nil { @@ -80,7 +80,7 @@ func (fh *FullHandler) userHasRoomPermission(ce *CommandEvent) bool { //return levels.GetUserLevel(ce.User.GetMXID()) >= levels.GetEventLevel(fh.RequiresEventLevel) } -func (fh *FullHandler) Run(ce *CommandEvent) { +func (fh *FullHandler) Run(ce *Event) { //if fh.RequiresAdmin && ce.User.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin { // ce.Reply("That command is limited to bridge administrators.") //} else if fh.RequiresEventLevel.Type != "" && ce.User.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin && !fh.userHasRoomPermission(ce) { diff --git a/bridgev2/cmdhelp.go b/bridgev2/commands/help.go similarity index 92% rename from bridgev2/cmdhelp.go rename to bridgev2/commands/help.go index 80b8e972..5c91a4d1 100644 --- a/bridgev2/cmdhelp.go +++ b/bridgev2/commands/help.go @@ -4,7 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -package bridgev2 +package commands import ( "fmt" @@ -15,7 +15,7 @@ import ( type HelpfulHandler interface { CommandHandler GetHelp() HelpMeta - ShowInHelp(*CommandEvent) bool + ShowInHelp(*Event) bool } type HelpSection struct { @@ -78,7 +78,7 @@ func (h helpMetaList) Swap(i, j int) { var _ sort.Interface = (helpSectionList)(nil) var _ sort.Interface = (helpMetaList)(nil) -func FormatHelp(ce *CommandEvent) string { +func FormatHelp(ce *Event) string { sections := make(map[HelpSection]helpMetaList) for _, handler := range ce.Processor.handlers { helpfulHandler, ok := handler.(HelpfulHandler) @@ -128,3 +128,14 @@ func FormatHelp(ce *CommandEvent) string { } return output.String() } + +var CommandHelp = &FullHandler{ + Func: func(ce *Event) { + ce.Reply(FormatHelp(ce)) + }, + Name: "help", + Help: HelpMeta{ + Section: HelpSectionGeneral, + Description: "Show this help message.", + }, +} diff --git a/bridgev2/cmdlogin.go b/bridgev2/commands/login.go similarity index 84% rename from bridgev2/cmdlogin.go rename to bridgev2/commands/login.go index 160e6a96..b38bb21b 100644 --- a/bridgev2/cmdlogin.go +++ b/bridgev2/commands/login.go @@ -4,7 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -package bridgev2 +package commands import ( "context" @@ -18,6 +18,8 @@ import ( "github.com/skip2/go-qrcode" "golang.org/x/net/html" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -33,7 +35,7 @@ var CommandLogin = &FullHandler{ }, } -func formatFlowsReply(flows []LoginFlow) string { +func formatFlowsReply(flows []bridgev2.LoginFlow) string { var buf strings.Builder for _, flow := range flows { _, _ = fmt.Fprintf(&buf, "* `%s` - %s\n", flow.ID, flow.Description) @@ -41,7 +43,7 @@ func formatFlowsReply(flows []LoginFlow) string { return buf.String() } -func fnLogin(ce *CommandEvent) { +func fnLogin(ce *Event) { flows := ce.Bridge.Network.GetLoginFlows() var chosenFlowID string if len(ce.Args) > 0 { @@ -77,14 +79,14 @@ func fnLogin(ce *CommandEvent) { } type userInputLoginCommandState struct { - Login LoginProcessUserInput + Login bridgev2.LoginProcessUserInput Data map[string]string - RemainingFields []LoginInputDataField + RemainingFields []bridgev2.LoginInputDataField } -func (uilcs *userInputLoginCommandState) promptNext(ce *CommandEvent) { +func (uilcs *userInputLoginCommandState) promptNext(ce *Event) { // TODO reply prompting field - ce.User.CommandState.Store(&CommandState{ + StoreCommandState(ce.User, &CommandState{ Next: MinimalCommandHandlerFunc(uilcs.submitNext), Action: "Login", Meta: uilcs, @@ -92,7 +94,7 @@ func (uilcs *userInputLoginCommandState) promptNext(ce *CommandEvent) { }) } -func (uilcs *userInputLoginCommandState) submitNext(ce *CommandEvent) { +func (uilcs *userInputLoginCommandState) submitNext(ce *Event) { field := uilcs.RemainingFields[0] field.FillDefaultValidate() var err error @@ -105,7 +107,7 @@ func (uilcs *userInputLoginCommandState) submitNext(ce *CommandEvent) { uilcs.promptNext(ce) return } - ce.User.CommandState.Store(nil) + StoreCommandState(ce.User, nil) if nextStep, err := uilcs.Login.SubmitUserInput(ce.Ctx, uilcs.Data); err != nil { ce.Reply("Failed to submit input: %v", err) } else { @@ -115,7 +117,7 @@ func (uilcs *userInputLoginCommandState) submitNext(ce *CommandEvent) { const qrSizePx = 512 -func sendQR(ce *CommandEvent, qr string, prevEventID *id.EventID) error { +func sendQR(ce *Event, qr string, prevEventID *id.EventID) error { qrData, err := qrcode.Encode(qr, qrcode.Low, qrSizePx) if err != nil { return fmt.Errorf("failed to encode QR code: %w", err) @@ -153,25 +155,25 @@ const ( contextKeyPrevEventID contextKey = iota ) -func doLoginDisplayAndWait(ce *CommandEvent, login LoginProcessDisplayAndWait, step *LoginStep) { +func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, step *bridgev2.LoginStep) { prevEvent, ok := ce.Ctx.Value(contextKeyPrevEventID).(*id.EventID) if !ok { prevEvent = new(id.EventID) ce.Ctx = context.WithValue(ce.Ctx, contextKeyPrevEventID, prevEvent) } switch step.DisplayAndWaitParams.Type { - case LoginDisplayTypeQR: + case bridgev2.LoginDisplayTypeQR: err := sendQR(ce, step.DisplayAndWaitParams.Data, prevEvent) if err != nil { ce.Reply("Failed to send QR code: %v", err) login.Cancel() return } - case LoginDisplayTypeEmoji: + case bridgev2.LoginDisplayTypeEmoji: ce.ReplyAdvanced(step.DisplayAndWaitParams.Data, false, false) - case LoginDisplayTypeCode: + case bridgev2.LoginDisplayTypeCode: ce.ReplyAdvanced(fmt.Sprintf("%s", html.EscapeString(step.DisplayAndWaitParams.Data)), false, true) - case LoginDisplayTypeNothing: + case bridgev2.LoginDisplayTypeNothing: // Do nothing default: ce.Reply("Unsupported display type %q", step.DisplayAndWaitParams.Type) @@ -196,12 +198,12 @@ func doLoginDisplayAndWait(ce *CommandEvent, login LoginProcessDisplayAndWait, s } type cookieLoginCommandState struct { - Login LoginProcessCookies - Data *LoginCookiesParams + Login bridgev2.LoginProcessCookies + Data *bridgev2.LoginCookiesParams } -func (clcs *cookieLoginCommandState) prompt(ce *CommandEvent) { - ce.User.CommandState.Store(&CommandState{ +func (clcs *cookieLoginCommandState) prompt(ce *Event) { + StoreCommandState(ce.User, &CommandState{ Next: MinimalCommandHandlerFunc(clcs.submit), Action: "Login", Meta: clcs, @@ -220,7 +222,7 @@ func missingKeys(required []string, data map[string]string) (missing []string) { return } -func (clcs *cookieLoginCommandState) submit(ce *CommandEvent) { +func (clcs *cookieLoginCommandState) submit(ce *Event) { ce.Redact() cookies := make(map[string]string) @@ -260,7 +262,7 @@ func (clcs *cookieLoginCommandState) submit(ce *CommandEvent) { ce.Reply("Missing required special keys: %+v", missingSpecial) return } - ce.User.CommandState.Store(nil) + StoreCommandState(ce.User, nil) nextStep, err := clcs.Login.SubmitCookies(ce.Ctx, cookies) if err != nil { ce.Reply("Login failed: %v", err) @@ -268,24 +270,24 @@ func (clcs *cookieLoginCommandState) submit(ce *CommandEvent) { doLoginStep(ce, clcs.Login, nextStep) } -func doLoginStep(ce *CommandEvent, login LoginProcess, step *LoginStep) { +func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep) { ce.Reply(step.Instructions) switch step.Type { - case LoginStepTypeDisplayAndWait: - doLoginDisplayAndWait(ce, login.(LoginProcessDisplayAndWait), step) - case LoginStepTypeCookies: + case bridgev2.LoginStepTypeDisplayAndWait: + doLoginDisplayAndWait(ce, login.(bridgev2.LoginProcessDisplayAndWait), step) + case bridgev2.LoginStepTypeCookies: (&cookieLoginCommandState{ - Login: login.(LoginProcessCookies), + Login: login.(bridgev2.LoginProcessCookies), Data: step.CookiesParams, }).prompt(ce) - case LoginStepTypeUserInput: + case bridgev2.LoginStepTypeUserInput: (&userInputLoginCommandState{ - Login: login.(LoginProcessUserInput), + Login: login.(bridgev2.LoginProcessUserInput), RemainingFields: step.UserInputParams.Fields, Data: make(map[string]string), }).promptNext(ce) - case LoginStepTypeComplete: + case bridgev2.LoginStepTypeComplete: // Nothing to do other than instructions default: panic(fmt.Errorf("unknown login step type %q", step.Type)) @@ -302,7 +304,7 @@ var CommandLogout = &FullHandler{ }, } -func fnLogout(ce *CommandEvent) { +func fnLogout(ce *Event) { if len(ce.Args) == 0 { ce.Reply("Usage: `$cmdprefix logout `\n\nYour logins:\n\n%s", ce.User.GetFormattedUserLogins()) return @@ -328,7 +330,7 @@ var CommandSetPreferredLogin = &FullHandler{ RequiresPortal: true, } -func fnSetPreferredLogin(ce *CommandEvent) { +func fnSetPreferredLogin(ce *Event) { if len(ce.Args) == 0 { ce.Reply("Usage: `$cmdprefix set-preferred-login `\n\nYour logins:\n\n%s", ce.User.GetFormattedUserLogins()) return diff --git a/bridgev2/cmdprocessor.go b/bridgev2/commands/processor.go similarity index 68% rename from bridgev2/cmdprocessor.go rename to bridgev2/commands/processor.go index 32428fc5..eb08cae2 100644 --- a/bridgev2/cmdprocessor.go +++ b/bridgev2/commands/processor.go @@ -4,32 +4,36 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -package bridgev2 +package commands import ( "context" "fmt" "runtime/debug" "strings" + "sync/atomic" + "unsafe" "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) -type CommandProcessor struct { - bridge *Bridge +type Processor struct { + bridge *bridgev2.Bridge log *zerolog.Logger handlers map[string]CommandHandler aliases map[string]string } -// NewProcessor creates a CommandProcessor -func NewProcessor(bridge *Bridge) *CommandProcessor { - proc := &CommandProcessor{ +// NewProcessor creates a Processor +func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor { + proc := &Processor{ bridge: bridge, log: &bridge.Log, @@ -45,13 +49,13 @@ func NewProcessor(bridge *Bridge) *CommandProcessor { return proc } -func (proc *CommandProcessor) AddHandlers(handlers ...CommandHandler) { +func (proc *Processor) AddHandlers(handlers ...CommandHandler) { for _, handler := range handlers { proc.AddHandler(handler) } } -func (proc *CommandProcessor) AddHandler(handler CommandHandler) { +func (proc *Processor) AddHandler(handler CommandHandler) { proc.handlers[handler.GetName()] = handler aliased, ok := handler.(AliasedCommandHandler) if ok { @@ -62,15 +66,15 @@ func (proc *CommandProcessor) AddHandler(handler CommandHandler) { } // Handle handles messages to the bridge -func (proc *CommandProcessor) Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user *User, message string, replyTo id.EventID) { +func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user *bridgev2.User, message string, replyTo id.EventID) { defer func() { - statusInfo := &MessageStatusEventInfo{ + statusInfo := &bridgev2.MessageStatusEventInfo{ RoomID: roomID, EventID: eventID, EventType: event.EventMessage, Sender: user.MXID, } - ms := MessageStatus{ + ms := bridgev2.MessageStatus{ Step: status.MsgStepCommand, Status: event.MessageStatusSuccess, } @@ -101,7 +105,7 @@ func (proc *CommandProcessor) Handle(ctx context.Context, roomID id.RoomID, even if err != nil { // :( } - ce := &CommandEvent{ + ce := &Event{ Bot: proc.bridge.Bot, Bridge: proc.bridge, Portal: portal, @@ -124,7 +128,7 @@ func (proc *CommandProcessor) Handle(ctx context.Context, roomID id.RoomID, even var handler MinimalCommandHandler handler, ok = proc.handlers[realCommand] if !ok { - state := ce.User.CommandState.Load() + state := LoadCommandState(ce.User) if state != nil && state.Next != nil { ce.Command = "" ce.RawArgs = message @@ -149,3 +153,38 @@ func (proc *CommandProcessor) Handle(ctx context.Context, roomID id.RoomID, even handler.Run(ce) } } + +func LoadCommandState(user *bridgev2.User) *CommandState { + return (*CommandState)(atomic.LoadPointer(&user.CommandState)) +} + +func StoreCommandState(user *bridgev2.User, cs *CommandState) { + atomic.StorePointer(&user.CommandState, unsafe.Pointer(cs)) +} + +func SwapCommandState(user *bridgev2.User, cs *CommandState) *CommandState { + return (*CommandState)(atomic.SwapPointer(&user.CommandState, unsafe.Pointer(cs))) +} + +var CommandCancel = &FullHandler{ + Func: func(ce *Event) { + state := SwapCommandState(ce.User, nil) + if state != nil { + action := state.Action + if action == "" { + action = "Unknown action" + } + if state.Cancel != nil { + state.Cancel() + } + ce.Reply("%s cancelled.", action) + } else { + ce.Reply("No ongoing command.") + } + }, + Name: "cancel", + Help: HelpMeta{ + Section: HelpSectionGeneral, + Description: "Cancel an ongoing action.", + }, +} diff --git a/bridgev2/cmdstartchat.go b/bridgev2/commands/startchat.go similarity index 91% rename from bridgev2/cmdstartchat.go rename to bridgev2/commands/startchat.go index a0530cdb..903fc17f 100644 --- a/bridgev2/cmdstartchat.go +++ b/bridgev2/commands/startchat.go @@ -4,7 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -package bridgev2 +package commands import ( "fmt" @@ -13,6 +13,8 @@ import ( "golang.org/x/net/html" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/id" ) @@ -39,7 +41,7 @@ var CommandStartChat = &FullHandler{ RequiresLogin: true, } -func getClientForStartingChat[T IdentifierResolvingNetworkAPI](ce *CommandEvent, thing string) (*UserLogin, T, []string) { +func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) { remainingArgs := ce.Args[1:] login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) if login == nil || login.UserMXID != ce.User.MXID { @@ -53,8 +55,8 @@ func getClientForStartingChat[T IdentifierResolvingNetworkAPI](ce *CommandEvent, return login, api, remainingArgs } -func fnResolveIdentifier(ce *CommandEvent) { - login, api, identifierParts := getClientForStartingChat[IdentifierResolvingNetworkAPI](ce, "resolving identifiers") +func fnResolveIdentifier(ce *Event) { + login, api, identifierParts := getClientForStartingChat[bridgev2.IdentifierResolvingNetworkAPI](ce, "resolving identifiers") if api == nil { return } @@ -125,7 +127,7 @@ func fnResolveIdentifier(ce *CommandEvent) { } var CommandDeletePortal = &FullHandler{ - Func: func(ce *CommandEvent) { + Func: func(ce *Event) { err := ce.Portal.Delete(ce.Ctx) if err != nil { ce.Reply("Failed to delete portal: %v", err) diff --git a/bridgev2/matrix/cmdadmin.go b/bridgev2/matrix/cmdadmin.go index 45a83b4f..0bd3eb82 100644 --- a/bridgev2/matrix/cmdadmin.go +++ b/bridgev2/matrix/cmdadmin.go @@ -9,12 +9,12 @@ package matrix import ( "strconv" - "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/commands" "maunium.net/go/mautrix/id" ) -var CommandDiscardMegolmSession = &bridgev2.FullHandler{ - Func: func(ce *bridgev2.CommandEvent) { +var CommandDiscardMegolmSession = &commands.FullHandler{ + Func: func(ce *commands.Event) { matrix := ce.Bridge.Matrix.(*Connector) if matrix.Crypto == nil { ce.Reply("This bridge instance doesn't have end-to-bridge encryption enabled") @@ -25,14 +25,14 @@ var CommandDiscardMegolmSession = &bridgev2.FullHandler{ }, Name: "discard-megolm-session", Aliases: []string{"discard-session"}, - Help: bridgev2.HelpMeta{ - Section: bridgev2.HelpSectionAdmin, + Help: commands.HelpMeta{ + Section: commands.HelpSectionAdmin, Description: "Discard the Megolm session in the room", }, RequiresAdmin: true, } -func fnSetPowerLevel(ce *bridgev2.CommandEvent) { +func fnSetPowerLevel(ce *commands.Event) { var level int var userID id.UserID var err error @@ -65,12 +65,12 @@ func fnSetPowerLevel(ce *bridgev2.CommandEvent) { } } -var CommandSetPowerLevel = &bridgev2.FullHandler{ +var CommandSetPowerLevel = &commands.FullHandler{ Func: fnSetPowerLevel, Name: "set-pl", Aliases: []string{"set-power-level"}, - Help: bridgev2.HelpMeta{ - Section: bridgev2.HelpSectionAdmin, + Help: commands.HelpMeta{ + Section: commands.HelpSectionAdmin, Description: "Change the power level in a portal room.", Args: "[_user ID_] <_power level_>", }, diff --git a/bridgev2/matrix/cmddoublepuppet.go b/bridgev2/matrix/cmddoublepuppet.go index 13d24f54..29175138 100644 --- a/bridgev2/matrix/cmddoublepuppet.go +++ b/bridgev2/matrix/cmddoublepuppet.go @@ -7,21 +7,21 @@ package matrix import ( - "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/commands" ) -var CommandLoginMatrix = &bridgev2.FullHandler{ +var CommandLoginMatrix = &commands.FullHandler{ Func: fnLoginMatrix, Name: "login-matrix", - Help: bridgev2.HelpMeta{ - Section: bridgev2.HelpSectionAuth, + Help: commands.HelpMeta{ + Section: commands.HelpSectionAuth, Description: "Enable double puppeting.", Args: "<_access token_>", }, RequiresLogin: true, } -func fnLoginMatrix(ce *bridgev2.CommandEvent) { +func fnLoginMatrix(ce *commands.Event) { if len(ce.Args) == 0 { ce.Reply("**Usage:** `login-matrix `") return @@ -34,16 +34,16 @@ func fnLoginMatrix(ce *bridgev2.CommandEvent) { } } -var CommandPingMatrix = &bridgev2.FullHandler{ +var CommandPingMatrix = &commands.FullHandler{ Func: fnPingMatrix, Name: "ping-matrix", - Help: bridgev2.HelpMeta{ - Section: bridgev2.HelpSectionAuth, + Help: commands.HelpMeta{ + Section: commands.HelpSectionAuth, Description: "Ping the Matrix server with the double puppet.", }, } -func fnPingMatrix(ce *bridgev2.CommandEvent) { +func fnPingMatrix(ce *commands.Event) { intent := ce.User.DoublePuppet(ce.Ctx) if intent == nil { ce.Reply("You don't have double puppeting enabled.") @@ -62,17 +62,17 @@ func fnPingMatrix(ce *bridgev2.CommandEvent) { } } -var CommandLogoutMatrix = &bridgev2.FullHandler{ +var CommandLogoutMatrix = &commands.FullHandler{ Func: fnLogoutMatrix, Name: "logout-matrix", - Help: bridgev2.HelpMeta{ - Section: bridgev2.HelpSectionAuth, + Help: commands.HelpMeta{ + Section: commands.HelpSectionAuth, Description: "Disable double puppeting.", }, RequiresLogin: true, } -func fnLogoutMatrix(ce *bridgev2.CommandEvent) { +func fnLogoutMatrix(ce *commands.Event) { if ce.User.AccessToken == "" { ce.Reply("You don't have double puppeting enabled.") return diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 4b913c21..1910c48d 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -26,6 +26,7 @@ import ( "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/bridgeconfig" + "maunium.net/go/mautrix/bridgev2/commands" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -117,7 +118,7 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent) br.Bot = br.AS.BotIntent() br.Crypto = NewCryptoHelper(br) - br.Bridge.Commands.AddHandlers( + br.Bridge.Commands.(*commands.Processor).AddHandlers( CommandDiscardMegolmSession, CommandSetPowerLevel, CommandLoginMatrix, CommandPingMatrix, CommandLogoutMatrix, ) diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 9a5326a6..8434e9c7 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -31,6 +31,7 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/bridgeconfig" + "maunium.net/go/mautrix/bridgev2/commands" "maunium.net/go/mautrix/bridgev2/matrix" ) @@ -226,15 +227,15 @@ func (br *BridgeMain) Init() { br.initDB() br.Matrix = matrix.NewConnector(br.Config) br.Matrix.IgnoreUnsupportedServer = *ignoreUnsupportedServer - br.Bridge = bridgev2.NewBridge("", br.DB, *br.Log, &br.Config.Bridge, br.Matrix, br.Connector) + br.Bridge = bridgev2.NewBridge("", br.DB, *br.Log, &br.Config.Bridge, br.Matrix, br.Connector, commands.NewProcessor) br.Matrix.AS.DoublePuppetValue = br.Name - br.Bridge.Commands.AddHandler(&bridgev2.FullHandler{ - Func: func(ce *bridgev2.CommandEvent) { + br.Bridge.Commands.(*commands.Processor).AddHandler(&commands.FullHandler{ + Func: func(ce *commands.Event) { ce.Reply("[%s](%s) %s (%s)", br.Name, br.URL, br.LinkifiedVersion, br.BuildTime.Format(time.RFC1123)) }, Name: "version", - Help: bridgev2.HelpMeta{ - Section: bridgev2.HelpSectionGeneral, + Help: commands.HelpMeta{ + Section: commands.HelpSectionGeneral, Description: "Get the bridge version.", }, }) diff --git a/bridgev2/user.go b/bridgev2/user.go index 9fca8de3..50054a77 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -11,7 +11,7 @@ import ( "fmt" "strings" "sync" - "sync/atomic" + "unsafe" "github.com/rs/zerolog" "golang.org/x/exp/maps" @@ -27,7 +27,7 @@ type User struct { Bridge *Bridge Log zerolog.Logger - CommandState atomic.Pointer[CommandState] + CommandState unsafe.Pointer doublePuppetIntent MatrixAPI doublePuppetInitialized bool From 07f7849c287d9cbc6b0df70015a6c37c6070d0f7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 29 Jun 2024 13:23:22 +0300 Subject: [PATCH 0374/1647] bridgev2: add permissions --- bridgev2/bridgeconfig/appservice.go | 5 +- bridgev2/bridgeconfig/config.go | 8 +- bridgev2/bridgeconfig/permissions.go | 121 +++++++++++++-------- bridgev2/bridgeconfig/relay.go | 10 +- bridgev2/bridgeconfig/upgrade.go | 4 + bridgev2/commands/handler.go | 49 +++++---- bridgev2/commands/login.go | 4 +- bridgev2/matrix/cmddoublepuppet.go | 8 ++ bridgev2/matrix/mxmain/example-config.yaml | 16 +++ bridgev2/matrix/mxmain/main.go | 4 + bridgev2/queue.go | 71 ++++++++---- bridgev2/user.go | 11 +- 12 files changed, 208 insertions(+), 103 deletions(-) diff --git a/bridgev2/bridgeconfig/appservice.go b/bridgev2/bridgeconfig/appservice.go index 37ed9306..7466636b 100644 --- a/bridgev2/bridgeconfig/appservice.go +++ b/bridgev2/bridgeconfig/appservice.go @@ -14,6 +14,7 @@ import ( "go.mau.fi/util/exerrors" "go.mau.fi/util/random" + "gopkg.in/yaml.v3" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/id" @@ -114,9 +115,9 @@ type BotUserConfig struct { type serializableBUC BotUserConfig -func (buc *BotUserConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { +func (buc *BotUserConfig) UnmarshalYAML(node *yaml.Node) error { var sbuc serializableBUC - err := unmarshal(&sbuc) + err := node.Decode(&sbuc) if err != nil { return err } diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 90cf7b29..623ee446 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -27,14 +27,14 @@ type Config struct { Encryption EncryptionConfig `yaml:"encryption"` Logging zeroconfig.Config `yaml:"logging"` - Permissions PermissionConfig `yaml:"permissions"` ManagementRoomTexts ManagementRoomTexts `yaml:"management_room_texts"` } type BridgeConfig struct { - CommandPrefix string `yaml:"command_prefix"` - PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` - Relay RelayConfig `yaml:"relay"` + CommandPrefix string `yaml:"command_prefix"` + PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` + Relay RelayConfig `yaml:"relay"` + Permissions PermissionConfig `yaml:"permissions"` } type MatrixConfig struct { diff --git a/bridgev2/bridgeconfig/permissions.go b/bridgev2/bridgeconfig/permissions.go index 198e140e..15b4561d 100644 --- a/bridgev2/bridgeconfig/permissions.go +++ b/bridgev2/bridgeconfig/permissions.go @@ -7,65 +7,100 @@ package bridgeconfig import ( - "strconv" + "fmt" "strings" + "gopkg.in/yaml.v3" + "maunium.net/go/mautrix/id" ) -type PermissionConfig map[string]PermissionLevel - -type PermissionLevel int - -const ( - PermissionLevelBlock PermissionLevel = 0 - PermissionLevelRelay PermissionLevel = 5 - PermissionLevelUser PermissionLevel = 10 - PermissionLevelAdmin PermissionLevel = 100 -) - -var namesToLevels = map[string]PermissionLevel{ - "block": PermissionLevelBlock, - "relay": PermissionLevelRelay, - "user": PermissionLevelUser, - "admin": PermissionLevelAdmin, +type Permissions struct { + SendEvents bool `yaml:"send_events"` + Commands bool `yaml:"commands"` + Login bool `yaml:"login"` + DoublePuppet bool `yaml:"double_puppet"` + Admin bool `yaml:"admin"` } -func RegisterPermissionLevel(name string, level PermissionLevel) { - namesToLevels[name] = level +type PermissionConfig map[string]*Permissions + +func boolToInt(val bool) int { + if val { + return 1 + } + return 0 } -func (pc *PermissionConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - rawPC := make(map[string]string) - err := unmarshal(&rawPC) - if err != nil { - return err +func (pc PermissionConfig) IsConfigured() bool { + _, hasWildcard := pc["*"] + _, hasExampleDomain := pc["example.com"] + _, hasExampleUser := pc["@admin:example.com"] + exampleLen := boolToInt(hasWildcard) + boolToInt(hasExampleUser) + boolToInt(hasExampleDomain) + if len(pc) <= exampleLen { + return false } - - if *pc == nil { - *pc = make(map[string]PermissionLevel) - } - for key, value := range rawPC { - level, ok := namesToLevels[strings.ToLower(value)] - if ok { - (*pc)[key] = level - } else if val, err := strconv.Atoi(value); err == nil { - (*pc)[key] = PermissionLevel(val) - } else { - (*pc)[key] = PermissionLevelBlock - } - } - return nil + return true } -func (pc PermissionConfig) Get(userID id.UserID) PermissionLevel { +func (pc PermissionConfig) Get(userID id.UserID) Permissions { if level, ok := pc[string(userID)]; ok { - return level + return *level } else if level, ok = pc[userID.Homeserver()]; len(userID.Homeserver()) > 0 && ok { - return level + return *level } else if level, ok = pc["*"]; ok { - return level + return *level } else { return PermissionLevelBlock } } + +var ( + PermissionLevelBlock = Permissions{} + PermissionLevelRelay = Permissions{SendEvents: true} + PermissionLevelCommands = Permissions{SendEvents: true, Commands: true} + PermissionLevelUser = Permissions{SendEvents: true, Commands: true, Login: true, DoublePuppet: true} + PermissionLevelAdmin = Permissions{SendEvents: true, Commands: true, Login: true, DoublePuppet: true, Admin: true} +) + +var namesToLevels = map[string]Permissions{ + "block": PermissionLevelBlock, + "relay": PermissionLevelRelay, + "commands": PermissionLevelCommands, + "user": PermissionLevelUser, + "admin": PermissionLevelAdmin, +} + +var levelsToNames = map[Permissions]string{ + PermissionLevelBlock: "block", + PermissionLevelRelay: "relay", + PermissionLevelCommands: "commands", + PermissionLevelUser: "user", + PermissionLevelAdmin: "admin", +} + +type umPerm Permissions + +func (p *Permissions) UnmarshalYAML(perm *yaml.Node) error { + switch perm.Tag { + case "!!str": + var ok bool + *p, ok = namesToLevels[strings.ToLower(perm.Value)] + if !ok { + return fmt.Errorf("invalid permissions level %s", perm.Value) + } + return nil + case "!!map": + err := perm.Decode((*umPerm)(p)) + return err + default: + return fmt.Errorf("invalid permissions type %s", perm.Tag) + } +} + +func (p *Permissions) MarshalYAML() (any, error) { + if level, ok := levelsToNames[*p]; ok { + return level, nil + } + return umPerm(*p), nil +} diff --git a/bridgev2/bridgeconfig/relay.go b/bridgev2/bridgeconfig/relay.go index e71a7d1d..7daf8e38 100644 --- a/bridgev2/bridgeconfig/relay.go +++ b/bridgev2/bridgeconfig/relay.go @@ -11,6 +11,8 @@ import ( "strings" "text/template" + "gopkg.in/yaml.v3" + "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" @@ -26,15 +28,15 @@ type RelayConfig struct { type umRelayConfig RelayConfig -func (rc *RelayConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - err := unmarshal((*umRelayConfig)(rc)) +func (rc *RelayConfig) UnmarshalYAML(node *yaml.Node) error { + err := node.Decode((*umRelayConfig)(rc)) if err != nil { return err } rc.messageTemplates = template.New("messageTemplates") - for key, format := range rc.MessageFormats { - _, err = rc.messageTemplates.New(string(key)).Parse(format) + for key, template := range rc.MessageFormats { + _, err = rc.messageTemplates.New(string(key)).Parse(template) if err != nil { return err } diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 766d6fdf..7908c268 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -28,6 +28,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "relay", "admin_only") helper.Copy(up.List, "bridge", "relay", "default_relays") helper.Copy(up.Map, "bridge", "relay", "message_formats") + helper.Copy(up.Map, "bridge", "permissions") if dbType, ok := helper.Get(up.Str, "database", "type"); ok && dbType == "sqlite3" { helper.Set(up.Str, "sqlite3-fk-wal", "database", "type") @@ -172,6 +173,7 @@ func doMigrateLegacy(helper up.Helper) { helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") helper.Copy(up.Bool, "bridge", "relay", "enabled") helper.Copy(up.Bool, "bridge", "relay", "admin_only") + helper.Copy(up.Map, "bridge", "permissions") CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "type"}, []string{"database", "type"}) CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "uri"}, []string{"database", "uri"}) @@ -232,6 +234,8 @@ func doMigrateLegacy(helper up.Helper) { var SpacedBlocks = [][]string{ {"bridge"}, + {"bridge", "relay"}, + {"bridge", "permissions"}, {"database"}, {"homeserver"}, {"homeserver", "software"}, diff --git a/bridgev2/commands/handler.go b/bridgev2/commands/handler.go index b8ff7019..c1daf1af 100644 --- a/bridgev2/commands/handler.go +++ b/bridgev2/commands/handler.go @@ -44,11 +44,11 @@ type FullHandler struct { Aliases []string Help HelpMeta - RequiresAdmin bool - RequiresPortal bool - RequiresLogin bool - - RequiresEventLevel event.Type + RequiresAdmin bool + RequiresPortal bool + RequiresLogin bool + RequiresEventLevel event.Type + RequiresLoginPermission bool } func (fh *FullHandler) GetHelp() HelpMeta { @@ -70,26 +70,27 @@ func (fh *FullHandler) ShowInHelp(ce *Event) bool { } func (fh *FullHandler) userHasRoomPermission(ce *Event) bool { - return true - //levels, err := ce.MainIntent().PowerLevels(ce.Ctx, ce.RoomID) - //if err != nil { - // ce.ZLog.Warn().Err(err).Msg("Failed to check room power levels") - // ce.Reply("Failed to get room power levels to see if you're allowed to use that command") - // return false - //} - //return levels.GetUserLevel(ce.User.GetMXID()) >= levels.GetEventLevel(fh.RequiresEventLevel) + levels, err := ce.Bridge.Matrix.GetPowerLevels(ce.Ctx, ce.RoomID) + if err != nil { + ce.Log.Warn().Err(err).Msg("Failed to check room power levels") + ce.Reply("Failed to get room power levels to see if you're allowed to use that command") + return false + } + return levels.GetUserLevel(ce.User.MXID) >= levels.GetEventLevel(fh.RequiresEventLevel) } func (fh *FullHandler) Run(ce *Event) { - //if fh.RequiresAdmin && ce.User.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin { - // ce.Reply("That command is limited to bridge administrators.") - //} else if fh.RequiresEventLevel.Type != "" && ce.User.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin && !fh.userHasRoomPermission(ce) { - // ce.Reply("That command requires room admin rights.") - //} else if fh.RequiresPortal && ce.Portal == nil { - // ce.Reply("That command can only be ran in portal rooms.") - //} else if fh.RequiresLogin && !ce.User.IsLoggedIn() { - // ce.Reply("That command requires you to be logged in.") - //} else { - fh.Func(ce) - //} + if fh.RequiresAdmin && !ce.User.Permissions.Admin { + ce.Reply("That command is limited to bridge administrators.") + } else if fh.RequiresLoginPermission && !ce.User.Permissions.Login { + ce.Reply("You do not have permissions to log into this bridge.") + } else if fh.RequiresEventLevel.Type != "" && !ce.User.Permissions.Admin && !fh.userHasRoomPermission(ce) { + ce.Reply("That command requires room admin rights.") + } else if fh.RequiresPortal && ce.Portal == nil { + ce.Reply("That command can only be ran in portal rooms.") + } else if fh.RequiresLogin && ce.User.GetDefaultLogin() == nil { + ce.Reply("That command requires you to be logged in.") + } else { + fh.Func(ce) + } } diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index b38bb21b..233e3d80 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -33,6 +33,7 @@ var CommandLogin = &FullHandler{ Description: "Log into the bridge", Args: "[_flow ID_]", }, + RequiresLoginPermission: true, } func formatFlowsReply(flows []bridgev2.LoginFlow) string { @@ -327,7 +328,8 @@ var CommandSetPreferredLogin = &FullHandler{ Description: "Set the preferred login ID for sending messages to this portal (only relevant when logged into multiple accounts via the bridge)", Args: "<_login ID_>", }, - RequiresPortal: true, + RequiresPortal: true, + RequiresLoginPermission: true, } func fnSetPreferredLogin(ce *Event) { diff --git a/bridgev2/matrix/cmddoublepuppet.go b/bridgev2/matrix/cmddoublepuppet.go index 29175138..2f3a3dc2 100644 --- a/bridgev2/matrix/cmddoublepuppet.go +++ b/bridgev2/matrix/cmddoublepuppet.go @@ -22,6 +22,10 @@ var CommandLoginMatrix = &commands.FullHandler{ } func fnLoginMatrix(ce *commands.Event) { + if !ce.User.Permissions.DoublePuppet { + ce.Reply("You don't have permission to manage double puppeting.") + return + } if len(ce.Args) == 0 { ce.Reply("**Usage:** `login-matrix `") return @@ -73,6 +77,10 @@ var CommandLogoutMatrix = &commands.FullHandler{ } func fnLogoutMatrix(ce *commands.Event) { + if !ce.User.Permissions.DoublePuppet { + ce.Reply("You don't have permission to manage double puppeting.") + return + } if ce.User.AccessToken == "" { ce.Reply("You don't have double puppeting enabled.") return diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index b6559657..e040f7df 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -4,6 +4,7 @@ bridge: command_prefix: '$<>' # Should the bridge create a space for each login containing the rooms that account is in? personal_filtering_spaces: true + # Settings for relay mode relay: # Whether relay mode should be allowed. If allowed, the set-relay command can be used to turn any @@ -24,6 +25,21 @@ bridge: m.video: "{{ .Sender.Displayname }} sent a video{{ if .Caption }}: {{ .Caption }}{{ end }}" m.location: "{{ .Sender.Displayname }} sent a location{{ if .Caption }}: {{ .Caption }}{{ end }}" + # Permissions for using the bridge. + # Permitted values: + # relay - Talk through the relaybot (if enabled), no access otherwise + # commands - Access to use commands in the bridge, but not login. + # user - Access to use the bridge with puppeting. + # admin - Full access, user level with some additional administration tools. + # Permitted keys: + # * - All Matrix users + # domain - All users on that homeserver + # mxid - Specific user + permissions: + "*": relay + "example.com": user + "@admin:example.com": admin + # Config for the bridge's database. database: # The database type. "sqlite3-fk-wal" and "postgres" are supported. diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 8434e9c7..d786e215 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -287,6 +287,10 @@ func (br *BridgeMain) validateConfig() error { return errors.New("appservice.hs_token not configured. Did you forget to generate the registration? ") case br.Config.Database.URI == "postgres://user:password@host/database?sslmode=disable": return errors.New("appservice.database not configured") + case !br.Config.Bridge.Permissions.IsConfigured(): + return errors.New("bridge.permissions not configured") + case !strings.Contains(br.Config.AppService.FormatUsername("1234567890"), "1234567890"): + return errors.New("username template is missing user ID placeholder") default: cfgValidator, ok := br.Connector.(bridgev2.ConfigValidatingNetwork) if ok { diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 10c47960..1c874d10 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "strings" + "time" "github.com/rs/zerolog" @@ -30,38 +31,66 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { status := WrapErrorInStatus(fmt.Errorf("%w: failed to get sender user: %w", ErrDatabaseError, err)) br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) return + } else if sender == nil { + log.Error().Msg("Couldn't get sender for incoming non-ephemeral Matrix event") + status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + return + } else if !sender.Permissions.SendEvents { + status := WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + return } - } - if sender == nil && evt.Type.Class != event.EphemeralEventType { + } else if evt.Type.Class != event.EphemeralEventType { log.Error().Msg("Missing sender for incoming non-ephemeral Matrix event") status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) return } - if evt.Type == event.EventMessage { + if evt.Type == event.EventMessage && sender != nil { msg := evt.Content.AsMessage() - if msg != nil { - msg.RemoveReplyFallback() - - if strings.HasPrefix(msg.Body, br.Config.CommandPrefix) || evt.RoomID == sender.ManagementRoom { - br.Commands.Handle( - ctx, - evt.RoomID, - evt.ID, - sender, - strings.TrimPrefix(msg.Body, br.Config.CommandPrefix+" "), - msg.RelatesTo.GetReplyTo(), - ) + msg.RemoveReplyFallback() + if strings.HasPrefix(msg.Body, br.Config.CommandPrefix) || evt.RoomID == sender.ManagementRoom { + if !sender.Permissions.Commands { + status := WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) return } + br.Commands.Handle( + ctx, + evt.RoomID, + evt.ID, + sender, + strings.TrimPrefix(msg.Body, br.Config.CommandPrefix+" "), + msg.RelatesTo.GetReplyTo(), + ) + return } } - if evt.Type == event.StateMember && evt.GetStateKey() == br.Bot.GetMXID().String() && evt.Content.AsMember().Membership == event.MembershipInvite { - br.Bot.EnsureJoined(ctx, evt.RoomID) - // TODO handle errors - if sender.ManagementRoom == "" { - sender.ManagementRoom = evt.RoomID - br.DB.User.Update(ctx, sender.User) + if evt.Type == event.StateMember && evt.GetStateKey() == br.Bot.GetMXID().String() && evt.Content.AsMember().Membership == event.MembershipInvite && sender != nil { + if !sender.Permissions.Commands { + _, err := br.Bot.SendState(ctx, evt.RoomID, event.StateMember, br.Bot.GetMXID().String(), &event.Content{ + Parsed: &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: "You don't have permission to send commands to this bridge", + }, + }, time.Time{}) + if err != nil { + log.Err(err).Msg("Failed to reject invite from user with no permission") + } else { + log.Debug().Msg("Rejected invite from user with no permission") + } + } else if err := br.Bot.EnsureJoined(ctx, evt.RoomID); err != nil { + log.Err(err).Msg("Failed to accept invite to room") + } else { + log.Debug().Msg("Accepted invite to room as bot") + if sender.ManagementRoom == "" { + sender.ManagementRoom = evt.RoomID + err = br.DB.User.Update(ctx, sender.User) + if err != nil { + log.Err(err).Msg("Failed to update user's management room in database") + } + } } return } diff --git a/bridgev2/user.go b/bridgev2/user.go index 50054a77..b9ea462b 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -17,6 +17,7 @@ import ( "golang.org/x/exp/maps" "golang.org/x/exp/slices" + "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/id" @@ -28,6 +29,7 @@ type User struct { Log zerolog.Logger CommandState unsafe.Pointer + Permissions bridgeconfig.Permissions doublePuppetIntent MatrixAPI doublePuppetInitialized bool @@ -54,10 +56,11 @@ func (br *Bridge) loadUser(ctx context.Context, dbUser *database.User, queryErr } } user := &User{ - User: dbUser, - Bridge: br, - Log: br.Log.With().Stringer("user_mxid", dbUser.MXID).Logger(), - logins: make(map[networkid.UserLoginID]*UserLogin), + User: dbUser, + Bridge: br, + Log: br.Log.With().Stringer("user_mxid", dbUser.MXID).Logger(), + logins: make(map[networkid.UserLoginID]*UserLogin), + Permissions: br.Config.Permissions.Get(dbUser.MXID), } br.usersByMXID[user.MXID] = user err := br.unlockedLoadUserLoginsByMXID(ctx, user) From 2b668652aba5b7a5f4e12bb21a992029ca1420d0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 29 Jun 2024 15:52:31 +0300 Subject: [PATCH 0375/1647] bridgev2: implement un/set-relay commands --- bridgev2/commands/login.go | 21 +++- bridgev2/commands/processor.go | 3 +- bridgev2/commands/relay.go | 134 +++++++++++++++++++++ bridgev2/database/portal.go | 6 +- bridgev2/matrix/mxmain/example-config.yaml | 2 +- bridgev2/portal.go | 14 +++ 6 files changed, 174 insertions(+), 6 deletions(-) create mode 100644 bridgev2/commands/relay.go diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index 233e3d80..c0e92c27 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -295,6 +295,25 @@ func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginSte } } +var CommandListLogins = &FullHandler{ + Func: fnListLogins, + Name: "list-logins", + Help: HelpMeta{ + Section: HelpSectionAuth, + Description: "List your logins", + }, + RequiresLoginPermission: true, +} + +func fnListLogins(ce *Event) { + logins := ce.User.GetFormattedUserLogins() + if len(logins) == 0 { + ce.Reply("You're not logged in") + } else { + ce.Reply("%s", logins) + } +} + var CommandLogout = &FullHandler{ Func: fnLogout, Name: "logout", diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index eb08cae2..a1c09d7e 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -43,7 +43,8 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor { proc.AddHandlers( CommandHelp, CommandCancel, CommandRegisterPush, CommandDeletePortal, - CommandLogin, CommandLogout, CommandSetPreferredLogin, + CommandLogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, + CommandSetRelay, CommandUnsetRelay, CommandResolveIdentifier, CommandStartChat, ) return proc diff --git a/bridgev2/commands/relay.go b/bridgev2/commands/relay.go new file mode 100644 index 00000000..7d2e10ba --- /dev/null +++ b/bridgev2/commands/relay.go @@ -0,0 +1,134 @@ +// Copyright (c) 2022 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" +) + +var fakeEvtSetRelay = event.Type{Type: "fi.mau.bridge.set_relay", Class: event.StateEventType} + +var CommandSetRelay = &FullHandler{ + Func: fnSetRelay, + Name: "set-relay", + Help: HelpMeta{ + Section: HelpSectionAuth, + Description: "Use your account to relay messages sent by users who haven't logged in", + Args: "[_login ID_]", + }, + RequiresPortal: true, +} + +func fnSetRelay(ce *Event) { + if !ce.Bridge.Config.Relay.Enabled { + ce.Reply("This bridge does not allow relay mode") + return + } else if !canManageRelay(ce) { + ce.Reply("You don't have permission to manage the relay in this room") + return + } + var relay *bridgev2.UserLogin + if len(ce.Args) == 0 { + relay = ce.User.GetDefaultLogin() + if relay == nil { + if len(ce.Bridge.Config.Relay.DefaultRelays) == 0 { + ce.Reply("You're not logged in and there are no default relay users configured") + return + } + logins, err := ce.Bridge.GetUserLoginsInPortal(ce.Ctx, ce.Portal.PortalKey) + if err != nil { + ce.Log.Err(err).Msg("Failed to get user logins in portal") + ce.Reply("Failed to get logins in portal to find default relay") + return + } + Outer: + for _, loginID := range ce.Bridge.Config.Relay.DefaultRelays { + for _, login := range logins { + if login.ID == loginID { + relay = login + break Outer + } + } + } + if relay == nil { + ce.Reply("You're not logged in and none of the default relay users are in the chat") + return + } + } + } else { + relay = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + if relay == nil { + ce.Reply("User login with ID %q not found", ce.Args[0]) + return + } else if !slices.Contains(ce.Bridge.Config.Relay.DefaultRelays, relay.ID) && relay.UserMXID != ce.User.MXID && !ce.User.Permissions.Admin { + ce.Reply("Only bridge admins can set another user's login as the relay") + return + } + } + err := ce.Portal.SetRelay(ce.Ctx, relay) + if err != nil { + ce.Log.Err(err).Msg("Failed to unset relay") + ce.Reply("Failed to save relay settings") + } else { + ce.Reply( + "Messages sent by users who haven't logged in will now be relayed through %s ([%s](%s)'s login)", + relay.Metadata.RemoteName, + relay.UserMXID, + // TODO this will need to stop linkifying if we ever allow UserLogins that aren't bound to a real user. + relay.UserMXID.URI().MatrixToURL(), + ) + } +} + +var CommandUnsetRelay = &FullHandler{ + Func: fnUnsetRelay, + Name: "unset-relay", + Help: HelpMeta{ + Section: HelpSectionAuth, + Description: "Stop relaying messages sent by users who haven't logged in", + }, + RequiresPortal: true, +} + +func fnUnsetRelay(ce *Event) { + if ce.Portal.Relay == nil { + ce.Reply("This portal doesn't have a relay set.") + return + } else if !canManageRelay(ce) { + ce.Reply("You don't have permission to manage the relay in this room") + return + } + err := ce.Portal.SetRelay(ce.Ctx, nil) + if err != nil { + ce.Log.Err(err).Msg("Failed to unset relay") + ce.Reply("Failed to save relay settings") + } else { + ce.Reply("Stopped relaying messages for users who haven't logged in") + } +} + +func canManageRelay(ce *Event) bool { + if ce.Bridge.Config.Relay.AdminOnly { + return ce.User.Permissions.Admin + } + return ce.User.Permissions.Admin || + (ce.Portal.Relay != nil && ce.Portal.Relay.UserMXID == ce.User.MXID) || + hasRelayRoomPermissions(ce) +} + +func hasRelayRoomPermissions(ce *Event) bool { + levels, err := ce.Bridge.Matrix.GetPowerLevels(ce.Ctx, ce.RoomID) + if err != nil { + ce.Log.Err(err).Msg("Failed to check room power levels") + return false + } + return levels.GetUserLevel(ce.User.MXID) >= levels.GetEventLevel(fakeEvtSetRelay) +} diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 0e8c3665..47c39a0b 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -87,13 +87,13 @@ const ( name_set, avatar_set, topic_set, in_space, metadata, relay_bridge_id ) VALUES ( - $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, - CASE WHEN $6 IS NULL THEN NULL ELSE $1 END + $1, $2, $3, $4, $5, cast($6 AS TEXT), $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, + CASE WHEN cast($6 AS TEXT) IS NULL THEN NULL ELSE $1 END ) ` updatePortalQuery = ` UPDATE portal - SET mxid=$4, parent_id=$5, relay_bridge_id=CASE WHEN $6 IS NULL THEN NULL ELSE bridge_id END, relay_login_id=$6, + SET mxid=$4, parent_id=$5, relay_login_id=cast($6 AS TEXT), relay_bridge_id=CASE WHEN cast($6 AS TEXT) IS NULL THEN NULL ELSE bridge_id END, name=$7, topic=$8, avatar_id=$9, avatar_hash=$10, avatar_mxc=$11, name_set=$12, avatar_set=$13, topic_set=$14, in_space=$15, metadata=$16 WHERE bridge_id=$1 AND id=$2 AND receiver=$3 diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index e040f7df..9b84056f 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -12,7 +12,7 @@ bridge: enabled: false # Should only admins be allowed to set themselves as relay users? admin_only: true - # List of user login IDs which anyone can set as a relay using set-default-relay as long as they're in the room. + # List of user login IDs which anyone can set as a relay, as long as the relay user is in the room. default_relays: [] # The formats to use when sending messages via the relaybot. message_formats: diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 7a618897..8780c4ec 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2182,3 +2182,17 @@ func (portal *Portal) unlockedDeleteCache() { func (portal *Portal) Save(ctx context.Context) error { return portal.Bridge.DB.Portal.Update(ctx, portal.Portal) } + +func (portal *Portal) SetRelay(ctx context.Context, relay *UserLogin) error { + portal.Relay = relay + if relay == nil { + portal.RelayLoginID = "" + } else { + portal.RelayLoginID = relay.ID + } + err := portal.Save(ctx) + if err != nil { + return err + } + return nil +} From cf5284b9b634df7b8903ffbcbe56b2b722e3ca66 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 29 Jun 2024 22:30:12 +0300 Subject: [PATCH 0376/1647] bridgev2: fix relation type of MSS events --- bridgev2/messagestatus.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 201b4080..2d9a6ddb 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -181,7 +181,7 @@ func (ms *MessageStatus) ToCheckpoint(evt *MessageStatusEventInfo) *status.Messa func (ms *MessageStatus) ToMSSEvent(evt *MessageStatusEventInfo) *event.BeeperMessageStatusEventContent { content := &event.BeeperMessageStatusEventContent{ RelatesTo: event.RelatesTo{ - Type: event.RelAnnotation, + Type: event.RelReference, EventID: evt.EventID, }, Status: ms.Status, From 9dcaacae0771dd1dd96cd416273a639a583410d8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 29 Jun 2024 22:35:39 +0300 Subject: [PATCH 0377/1647] event: initialize map in Set(User|Event)Level if it's nil --- bridgev2/portal.go | 1 + event/powerlevels.go | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 8780c4ec..a05db54c 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2028,6 +2028,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i event.StateServerACL.Type: 100, event.StateEncryption.Type: 100, }, + Users: map[id.UserID]int{}, } initialMembers, extraFunctionalMembers, err := portal.GetInitialMemberList(ctx, info.Members, source, powerLevels) if err != nil { diff --git a/event/powerlevels.go b/event/powerlevels.go index 91d56611..d291eacd 100644 --- a/event/powerlevels.go +++ b/event/powerlevels.go @@ -143,6 +143,9 @@ func (pl *PowerLevelsEventContent) SetUserLevel(userID id.UserID, level int) { if level == pl.UsersDefault { delete(pl.Users, userID) } else { + if pl.Users == nil { + pl.Users = make(map[id.UserID]int) + } pl.Users[userID] = level } } @@ -175,6 +178,9 @@ func (pl *PowerLevelsEventContent) SetEventLevel(eventType Type, level int) { if (eventType.IsState() && level == pl.StateDefault()) || (!eventType.IsState() && level == pl.EventsDefault) { delete(pl.Events, eventType.String()) } else { + if pl.Events == nil { + pl.Events = make(map[string]int) + } pl.Events[eventType.String()] = level } } From 302bbb739ba485627b458c57e2f4306064ad239b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 29 Jun 2024 22:45:40 +0300 Subject: [PATCH 0378/1647] bridgev2: fix adding portals to spaces when creating --- bridgev2/database/userportal.go | 13 +++++++++--- bridgev2/portal.go | 35 ++++++++++++++++++++++----------- bridgev2/space.go | 5 ++++- 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/bridgev2/database/userportal.go b/bridgev2/database/userportal.go index 5c34ad51..71235d2a 100644 --- a/bridgev2/database/userportal.go +++ b/bridgev2/database/userportal.go @@ -44,10 +44,13 @@ const ( getUserPortalQuery = getUserPortalBaseQuery + ` WHERE bridge_id=$1 AND user_mxid=$2 AND login_id=$3 AND portal_id=$4 AND portal_receiver=$5 ` - findUserLoginsByPortalIDQuery = getUserPortalBaseQuery + ` + findUserLoginsOfUserByPortalIDQuery = getUserPortalBaseQuery + ` WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$3 AND portal_receiver=$4 ORDER BY CASE WHEN preferred THEN 0 ELSE 1 END, login_id ` + getAllUserLoginsInPortalQuery = getUserPortalBaseQuery + ` + WHERE bridge_id=$1 AND portal_id=$2 AND portal_receiver=$3 + ` getAllPortalsForLoginQuery = getUserPortalBaseQuery + ` WHERE bridge_id=$1 AND user_mxid=$2 AND login_id=$3 ` @@ -79,14 +82,18 @@ func UserPortalFor(ul *UserLogin, portal networkid.PortalKey) *UserPortal { } } -func (upq *UserPortalQuery) GetAllByUser(ctx context.Context, userID id.UserID, portal networkid.PortalKey) ([]*UserPortal, error) { - return upq.QueryMany(ctx, findUserLoginsByPortalIDQuery, upq.BridgeID, userID, portal.ID, portal.Receiver) +func (upq *UserPortalQuery) GetAllForUserInPortal(ctx context.Context, userID id.UserID, portal networkid.PortalKey) ([]*UserPortal, error) { + return upq.QueryMany(ctx, findUserLoginsOfUserByPortalIDQuery, upq.BridgeID, userID, portal.ID, portal.Receiver) } func (upq *UserPortalQuery) GetAllForLogin(ctx context.Context, login *UserLogin) ([]*UserPortal, error) { return upq.QueryMany(ctx, getAllPortalsForLoginQuery, upq.BridgeID, login.UserMXID, login.ID) } +func (upq *UserPortalQuery) GetAllInPortal(ctx context.Context, portal networkid.PortalKey) ([]*UserPortal, error) { + return upq.QueryMany(ctx, getAllUserLoginsInPortalQuery, upq.BridgeID, portal.ID, portal.Receiver) +} + func (upq *UserPortalQuery) Get(ctx context.Context, login *UserLogin, portal networkid.PortalKey) (*UserPortal, error) { return upq.QueryOne(ctx, getUserPortalQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a05db54c..23bd5870 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -189,7 +189,7 @@ func (portal *Portal) eventLoop() { } func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) { - logins, err := portal.Bridge.DB.UserPortal.GetAllByUser(ctx, user.MXID, portal.PortalKey) + logins, err := portal.Bridge.DB.UserPortal.GetAllForUserInPortal(ctx, user.MXID, portal.PortalKey) if err != nil { return nil, nil, err } @@ -2038,20 +2038,20 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i powerLevels.EnsureUserLevel(portal.Bridge.Bot.GetMXID(), 9001) req := mautrix.ReqCreateRoom{ - Visibility: "private", - Name: portal.Name, - Topic: portal.Topic, - CreationContent: make(map[string]any), - InitialState: make([]*event.Event, 0, 6), - Preset: "private_chat", - IsDirect: portal.Metadata.IsDirect, - PowerLevelOverride: powerLevels, - BeeperLocalRoomID: id.RoomID(fmt.Sprintf("!%s:%s", portal.ID, portal.Bridge.Matrix.ServerName())), - BeeperInitialMembers: initialMembers, + Visibility: "private", + Name: portal.Name, + Topic: portal.Topic, + CreationContent: make(map[string]any), + InitialState: make([]*event.Event, 0, 6), + Preset: "private_chat", + IsDirect: portal.Metadata.IsDirect, + PowerLevelOverride: powerLevels, + BeeperLocalRoomID: id.RoomID(fmt.Sprintf("!%s:%s", portal.ID, portal.Bridge.Matrix.ServerName())), } autoJoinInvites := portal.Bridge.Matrix.GetCapabilities().AutoJoinInvites - // TODO remove this after initial_members is supported in hungryserv if autoJoinInvites { + req.BeeperInitialMembers = initialMembers + // TODO remove this after initial_members is supported in hungryserv req.BeeperAutoJoinInvites = true req.Invite = initialMembers } @@ -2149,6 +2149,17 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } } } + userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to get user logins in portal to add portal to spaces") + } else { + for _, up := range userPortals { + login := portal.Bridge.GetCachedUserLoginByID(up.LoginID) + if login != nil { + go login.tryAddPortalToSpace(ctx, portal, up.CopyWithoutValues()) + } + } + } return nil } diff --git a/bridgev2/space.go b/bridgev2/space.go index 55dcaf32..a778fc9b 100644 --- a/bridgev2/space.go +++ b/bridgev2/space.go @@ -29,7 +29,7 @@ func (ul *UserLogin) MarkInPortal(ctx context.Context, portal *Portal) { return } ul.inPortalCache.Add(portal.PortalKey) - if ul.Bridge.Config.PersonalFilteringSpaces && (userPortal.InSpace == nil || !*userPortal.InSpace) { + if ul.Bridge.Config.PersonalFilteringSpaces && (userPortal.InSpace == nil || !*userPortal.InSpace) && portal.MXID != "" { go ul.tryAddPortalToSpace(ctx, portal, userPortal.CopyWithoutValues()) } } @@ -42,6 +42,9 @@ func (ul *UserLogin) tryAddPortalToSpace(ctx context.Context, portal *Portal, us } func (ul *UserLogin) AddPortalToSpace(ctx context.Context, portal *Portal, userPortal *database.UserPortal) error { + if portal.MXID == "" { + return nil + } spaceRoom, err := ul.GetSpaceRoom(ctx) if err != nil { return fmt.Errorf("failed to get space room: %w", err) From 012d542a07b2a6d3a01f9ad325ce1b2cf5768890 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 29 Jun 2024 23:18:17 +0300 Subject: [PATCH 0379/1647] bridgev2: improve handling bot invites --- bridgev2/matrix/intent.go | 12 ++++++- bridgev2/queue.go | 73 ++++++++++++++++++++++++++------------- 2 files changed, 60 insertions(+), 25 deletions(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 45927ef0..26f2a4bb 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -244,7 +244,17 @@ func (as *ASIntent) InviteUser(ctx context.Context, roomID id.RoomID, userID id. } func (as *ASIntent) EnsureJoined(ctx context.Context, roomID id.RoomID) error { - return as.Matrix.EnsureJoined(ctx, roomID) + err := as.Matrix.EnsureJoined(ctx, roomID) + if err != nil { + return err + } + if as.Connector.Bot.UserID == as.Matrix.UserID { + _, err = as.Matrix.State(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get state after joining room with bot") + } + } + return nil } func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) { diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 1c874d10..ec60cbb8 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -16,6 +16,7 @@ import ( "github.com/rs/zerolog" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" ) func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { @@ -68,30 +69,7 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { } } if evt.Type == event.StateMember && evt.GetStateKey() == br.Bot.GetMXID().String() && evt.Content.AsMember().Membership == event.MembershipInvite && sender != nil { - if !sender.Permissions.Commands { - _, err := br.Bot.SendState(ctx, evt.RoomID, event.StateMember, br.Bot.GetMXID().String(), &event.Content{ - Parsed: &event.MemberEventContent{ - Membership: event.MembershipLeave, - Reason: "You don't have permission to send commands to this bridge", - }, - }, time.Time{}) - if err != nil { - log.Err(err).Msg("Failed to reject invite from user with no permission") - } else { - log.Debug().Msg("Rejected invite from user with no permission") - } - } else if err := br.Bot.EnsureJoined(ctx, evt.RoomID); err != nil { - log.Err(err).Msg("Failed to accept invite to room") - } else { - log.Debug().Msg("Accepted invite to room as bot") - if sender.ManagementRoom == "" { - sender.ManagementRoom = evt.RoomID - err = br.DB.User.Update(ctx, sender.User) - if err != nil { - log.Err(err).Msg("Failed to update user's management room in database") - } - } - } + br.handleBotInvite(ctx, evt, sender) return } portal, err := br.GetPortalByMXID(ctx, evt.RoomID) @@ -111,6 +89,53 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { } } +func (br *Bridge) handleBotInvite(ctx context.Context, evt *event.Event, sender *User) { + log := zerolog.Ctx(ctx) + if !sender.Permissions.Commands { + _, err := br.Bot.SendState(ctx, evt.RoomID, event.StateMember, br.Bot.GetMXID().String(), &event.Content{ + Parsed: &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: "You don't have permission to send commands to this bridge", + }, + }, time.Time{}) + if err != nil { + log.Err(err).Msg("Failed to reject invite from user with no permission") + } else { + log.Debug().Msg("Rejected invite from user with no permission") + } + return + } + err := br.Bot.EnsureJoined(ctx, evt.RoomID) + if err != nil { + log.Err(err).Msg("Failed to accept invite to room") + return + } + log.Debug().Msg("Accepted invite to room as bot") + members, err := br.Matrix.GetMembers(ctx, evt.RoomID) + if err != nil { + log.Err(err).Msg("Failed to get members of room after accepting invite") + } + if len(members) == 2 { + var message string + if sender.ManagementRoom == "" { + message = fmt.Sprintf("Hello, I'm a %s bridge bot.\n\nUse `help` for help or `login` to log in.\n\nThis room has been marked as your management room.", br.Network.GetName().DisplayName) + sender.ManagementRoom = evt.RoomID + err = br.DB.User.Update(ctx, sender.User) + if err != nil { + log.Err(err).Msg("Failed to update user's management room in database") + } + } else { + message = fmt.Sprintf("Hello, I'm a %s bridge bot.\n\nUse `%s help` for help.", br.Network.GetName().DisplayName, br.Config.CommandPrefix) + } + _, err = br.Bot.SendMessage(ctx, evt.RoomID, event.EventMessage, &event.Content{ + Parsed: format.RenderMarkdown(message, true, false), + }, time.Time{}) + if err != nil { + log.Err(err).Msg("Failed to send welcome message to room") + } + } +} + func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { log := login.Log ctx := log.WithContext(context.TODO()) From de5a5607ad01732cf7964b65c35c0906980435e9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 29 Jun 2024 23:24:04 +0300 Subject: [PATCH 0380/1647] bridgev2/matrix: ignore ephemeral events from bridge bot and ghosts --- bridgev2/matrix/matrix.go | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index 4fde9941..9d1c8f9d 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -10,6 +10,7 @@ import ( "context" "errors" "fmt" + "slices" "time" "github.com/rs/zerolog" @@ -34,12 +35,13 @@ func (br *Connector) handleRoomEvent(ctx context.Context, evt *event.Event) { } func (br *Connector) handleEphemeralEvent(ctx context.Context, evt *event.Event) { - if evt.Type == event.EphemeralEventReceipt { + switch evt.Type { + case event.EphemeralEventReceipt: receiptContent := *evt.Content.AsReceipt() for eventID, receipts := range receiptContent { for receiptType, userReceipts := range receipts { for userID, receipt := range userReceipts { - if br.AS.DoublePuppetValue != "" && receipt.Extra[appservice.DoublePuppetKey] == br.AS.DoublePuppetValue { + if br.shouldIgnoreEventFromUser(userID) || (br.AS.DoublePuppetValue != "" && receipt.Extra[appservice.DoublePuppetKey] == br.AS.DoublePuppetValue) { delete(userReceipts, userID) } } @@ -54,6 +56,9 @@ func (br *Connector) handleEphemeralEvent(ctx context.Context, evt *event.Event) if len(receiptContent) == 0 { return } + case event.EphemeralEventTyping: + typingContent := evt.Content.AsTyping() + typingContent.UserIDs = slices.DeleteFunc(typingContent.UserIDs, br.shouldIgnoreEventFromUser) } br.Bridge.QueueMatrixEvent(ctx, evt) } @@ -154,14 +159,21 @@ func (br *Connector) sendBridgeCheckpoint(ctx context.Context, evt *event.Event) } } -func (br *Connector) shouldIgnoreEvent(evt *event.Event) bool { - if evt.Sender == br.Bot.UserID { +func (br *Connector) shouldIgnoreEventFromUser(userID id.UserID) bool { + if userID == br.Bot.UserID { return true } - _, isGhost := br.ParseGhostMXID(evt.Sender) + _, isGhost := br.ParseGhostMXID(userID) if isGhost { return true } + return false +} + +func (br *Connector) shouldIgnoreEvent(evt *event.Event) bool { + if br.shouldIgnoreEventFromUser(evt.Sender) { + return true + } dpVal, ok := evt.Content.Raw[appservice.DoublePuppetKey] if ok && dpVal == br.AS.DoublePuppetValue { dpTS, ok := evt.Content.Raw[appservice.DoublePuppetTSKey].(float64) From 9d082e1e2b23eeff03f78cc2d642f2127db7c836 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 29 Jun 2024 23:29:05 +0300 Subject: [PATCH 0381/1647] bridgev2/commands: don't send MSS event for delete-portal command --- bridgev2/commands/event.go | 2 ++ bridgev2/commands/processor.go | 12 +++++++----- bridgev2/commands/startchat.go | 1 + 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/bridgev2/commands/event.go b/bridgev2/commands/event.go index d03c7e68..2a4b26a5 100644 --- a/bridgev2/commands/event.go +++ b/bridgev2/commands/event.go @@ -38,6 +38,8 @@ type Event struct { ReplyTo id.EventID Ctx context.Context Log *zerolog.Logger + + MessageStatus *bridgev2.MessageStatus } // Reply sends a reply to command as notice, with optional string formatting and automatic $cmdprefix replacement. diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index a1c09d7e..fce13d09 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -68,6 +68,10 @@ func (proc *Processor) AddHandler(handler CommandHandler) { // Handle handles messages to the bridge func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user *bridgev2.User, message string, replyTo id.EventID) { + ms := &bridgev2.MessageStatus{ + Step: status.MsgStepCommand, + Status: event.MessageStatusSuccess, + } defer func() { statusInfo := &bridgev2.MessageStatusEventInfo{ RoomID: roomID, @@ -75,10 +79,6 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id. EventType: event.EventMessage, Sender: user.MXID, } - ms := bridgev2.MessageStatus{ - Step: status.MsgStepCommand, - Status: event.MessageStatusSuccess, - } err := recover() if err != nil { zerolog.Ctx(ctx).Error(). @@ -94,7 +94,7 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id. } ms.ErrorAsMessage = true } - proc.bridge.Matrix.SendMessageStatus(ctx, &ms, statusInfo) + proc.bridge.Matrix.SendMessageStatus(ctx, ms, statusInfo) }() args := strings.Fields(message) if len(args) == 0 { @@ -119,6 +119,8 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id. RawArgs: rawArgs, ReplyTo: replyTo, Ctx: ctx, + + MessageStatus: ms, } realCommand, ok := proc.aliases[ce.Command] diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index 903fc17f..3054a1a1 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -136,6 +136,7 @@ var CommandDeletePortal = &FullHandler{ if err != nil { ce.Reply("Failed to clean up room: %v", err) } + ce.MessageStatus.DisableMSS = true }, Name: "delete-portal", Help: HelpMeta{ From f2585f7bcce171d187b14bfbc0dfd3012b489377 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 29 Jun 2024 23:29:51 +0300 Subject: [PATCH 0382/1647] bridgev2: don't use double puppet of other user in DMs --- bridgev2/portal.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 23bd5870..fc745388 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -999,7 +999,7 @@ func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventS return } extraUserID = source.UserMXID - } else if sender.SenderLogin != "" { + } else if sender.SenderLogin != "" && portal.Receiver == "" { senderLogin := portal.Bridge.GetCachedUserLoginByID(sender.SenderLogin) if senderLogin != nil { intent = senderLogin.User.DoublePuppet(ctx) @@ -1010,13 +1010,15 @@ func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventS } } if sender.Sender != "" { - for _, login := range otherLogins { - if login.Client.IsThisUser(ctx, sender.Sender) { - intent = login.User.DoublePuppet(ctx) - if intent != nil { - return + if portal.Receiver == "" { + for _, login := range otherLogins { + if login.Client.IsThisUser(ctx, sender.Sender) { + intent = login.User.DoublePuppet(ctx) + if intent != nil { + return + } + extraUserID = login.UserMXID } - extraUserID = login.UserMXID } } ghost, err := portal.Bridge.GetGhostByID(ctx, sender.Sender) From 4257c5edd3d435a14d296053bf36ee75c21153b0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 29 Jun 2024 23:56:06 +0300 Subject: [PATCH 0383/1647] bridgev2: add room receiver to message primary key --- bridgev2/database/message.go | 36 ++++----- bridgev2/database/upgrades/00-latest.sql | 10 +-- .../05-message-receiver-pkey-postgres.sql | 10 +++ .../06-message-receiver-pkey-sqlite.sql | 75 +++++++++++++++++++ bridgev2/portal.go | 20 ++--- 5 files changed, 118 insertions(+), 33 deletions(-) create mode 100644 bridgev2/database/upgrades/05-message-receiver-pkey-postgres.sql create mode 100644 bridgev2/database/upgrades/06-message-receiver-pkey-sqlite.sql diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index d083eb98..9129c0b7 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -64,12 +64,12 @@ const ( getMessageBaseQuery = ` SELECT rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, relates_to, metadata FROM message ` - getAllMessagePartsByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND id=$2` - getMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND part_id=$3` + getAllMessagePartsByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3` + getMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 AND part_id=$4` getMessagePartByRowIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND rowid=$2` getMessageByMXIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` - getLastMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND id=$2 ORDER BY part_id DESC LIMIT 1` - getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND id=$2 ORDER BY part_id ASC LIMIT 1` + getLastMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id DESC LIMIT 1` + getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1` getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5` insertMessageQuery = ` INSERT INTO message (bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, relates_to, metadata) @@ -81,42 +81,42 @@ const ( WHERE bridge_id=$1 AND rowid=$11 ` deleteAllMessagePartsByIDQuery = ` - DELETE FROM message WHERE bridge_id=$1 AND id=$2 + DELETE FROM message WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ` deleteMessagePartByRowIDQuery = ` DELETE FROM message WHERE bridge_id=$1 AND rowid=$2 ` ) -func (mq *MessageQuery) GetAllPartsByID(ctx context.Context, id networkid.MessageID) ([]*Message, error) { - return mq.QueryMany(ctx, getAllMessagePartsByIDQuery, mq.BridgeID, id) +func (mq *MessageQuery) GetAllPartsByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) ([]*Message, error) { + return mq.QueryMany(ctx, getAllMessagePartsByIDQuery, mq.BridgeID, receiver, id) } -func (mq *MessageQuery) GetPartByID(ctx context.Context, id networkid.MessageID, partID networkid.PartID) (*Message, error) { - return mq.QueryOne(ctx, getMessagePartByIDQuery, mq.BridgeID, id, partID) +func (mq *MessageQuery) GetPartByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID, partID networkid.PartID) (*Message, error) { + return mq.QueryOne(ctx, getMessagePartByIDQuery, mq.BridgeID, receiver, id, partID) } func (mq *MessageQuery) GetPartByMXID(ctx context.Context, mxid id.EventID) (*Message, error) { return mq.QueryOne(ctx, getMessageByMXIDQuery, mq.BridgeID, mxid) } -func (mq *MessageQuery) GetLastPartByID(ctx context.Context, id networkid.MessageID) (*Message, error) { - return mq.QueryOne(ctx, getLastMessagePartByIDQuery, mq.BridgeID, id) +func (mq *MessageQuery) GetLastPartByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) (*Message, error) { + return mq.QueryOne(ctx, getLastMessagePartByIDQuery, mq.BridgeID, receiver, id) } -func (mq *MessageQuery) GetFirstPartByID(ctx context.Context, id networkid.MessageID) (*Message, error) { - return mq.QueryOne(ctx, getFirstMessagePartByIDQuery, mq.BridgeID, id) +func (mq *MessageQuery) GetFirstPartByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) (*Message, error) { + return mq.QueryOne(ctx, getFirstMessagePartByIDQuery, mq.BridgeID, receiver, id) } func (mq *MessageQuery) GetByRowID(ctx context.Context, rowID int64) (*Message, error) { return mq.QueryOne(ctx, getMessagePartByRowIDQuery, mq.BridgeID, rowID) } -func (mq *MessageQuery) GetFirstOrSpecificPartByID(ctx context.Context, id networkid.MessageOptionalPartID) (*Message, error) { +func (mq *MessageQuery) GetFirstOrSpecificPartByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageOptionalPartID) (*Message, error) { if id.PartID == nil { - return mq.GetFirstPartByID(ctx, id.MessageID) + return mq.GetFirstPartByID(ctx, receiver, id.MessageID) } else { - return mq.GetPartByID(ctx, id.MessageID, *id.PartID) + return mq.GetPartByID(ctx, receiver, id.MessageID, *id.PartID) } } @@ -134,8 +134,8 @@ func (mq *MessageQuery) Update(ctx context.Context, msg *Message) error { return mq.Exec(ctx, updateMessageQuery, msg.updateSQLVariables()...) } -func (mq *MessageQuery) DeleteAllParts(ctx context.Context, id networkid.MessageID) error { - return mq.Exec(ctx, deleteAllMessagePartsByIDQuery, mq.BridgeID, id) +func (mq *MessageQuery) DeleteAllParts(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) error { + return mq.Exec(ctx, deleteAllMessagePartsByIDQuery, mq.BridgeID, receiver, id) } func (mq *MessageQuery) Delete(ctx context.Context, rowID int64) error { diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index f60e5ea8..b84f0549 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v4 (compatible with v1+): Latest revision +-- v0 -> v6 (compatible with v1+): Latest revision CREATE TABLE portal ( bridge_id TEXT NOT NULL, id TEXT NOT NULL, @@ -79,7 +79,7 @@ CREATE TABLE message ( CONSTRAINT message_sender_fkey FOREIGN KEY (bridge_id, sender_id) REFERENCES ghost (bridge_id, id) ON DELETE CASCADE ON UPDATE CASCADE, - CONSTRAINT message_real_pkey UNIQUE (bridge_id, id, part_id) + CONSTRAINT message_real_pkey UNIQUE (bridge_id, room_receiver, id, part_id) ); CREATE TABLE disappearing_message ( @@ -106,12 +106,12 @@ CREATE TABLE reaction ( timestamp BIGINT NOT NULL, metadata jsonb NOT NULL, - PRIMARY KEY (bridge_id, message_id, message_part_id, sender_id, emoji_id), + PRIMARY KEY (bridge_id, room_receiver, message_id, message_part_id, sender_id, emoji_id), CONSTRAINT reaction_room_fkey FOREIGN KEY (bridge_id, room_id, room_receiver) REFERENCES portal (bridge_id, id, receiver) ON DELETE CASCADE ON UPDATE CASCADE, - CONSTRAINT reaction_message_fkey FOREIGN KEY (bridge_id, message_id, message_part_id) - REFERENCES message (bridge_id, id, part_id) + CONSTRAINT reaction_message_fkey FOREIGN KEY (bridge_id, room_receiver, message_id, message_part_id) + REFERENCES message (bridge_id, room_receiver, id, part_id) ON DELETE CASCADE ON UPDATE CASCADE, CONSTRAINT reaction_sender_fkey FOREIGN KEY (bridge_id, sender_id) REFERENCES ghost (bridge_id, id) diff --git a/bridgev2/database/upgrades/05-message-receiver-pkey-postgres.sql b/bridgev2/database/upgrades/05-message-receiver-pkey-postgres.sql new file mode 100644 index 00000000..a19d9d57 --- /dev/null +++ b/bridgev2/database/upgrades/05-message-receiver-pkey-postgres.sql @@ -0,0 +1,10 @@ +-- v5 (compatible with v1+): Add room_receiver to message unique key +-- only: postgres +ALTER TABLE reaction DROP CONSTRAINT reaction_message_fkey; +ALTER TABLE reaction DROP CONSTRAINT reaction_pkey1; +ALTER TABLE reaction ADD PRIMARY KEY (bridge_id, room_receiver, message_id, message_part_id, sender_id, emoji_id); +ALTER TABLE message DROP CONSTRAINT message_real_pkey; +ALTER TABLE message ADD CONSTRAINT message_real_pkey UNIQUE (bridge_id, room_receiver, id, part_id); +ALTER TABLE reaction ADD CONSTRAINT reaction_message_fkey FOREIGN KEY (bridge_id, room_receiver, message_id, message_part_id) + REFERENCES message (bridge_id, room_receiver, id, part_id) + ON DELETE CASCADE ON UPDATE CASCADE; diff --git a/bridgev2/database/upgrades/06-message-receiver-pkey-sqlite.sql b/bridgev2/database/upgrades/06-message-receiver-pkey-sqlite.sql new file mode 100644 index 00000000..d1be1030 --- /dev/null +++ b/bridgev2/database/upgrades/06-message-receiver-pkey-sqlite.sql @@ -0,0 +1,75 @@ +-- v6 (compatible with v1+): Add room_receiver to message unique key +-- transaction: off +-- only: sqlite + +PRAGMA foreign_keys = OFF; +BEGIN; + +CREATE TABLE message_new ( + rowid INTEGER PRIMARY KEY, + + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, + part_id TEXT NOT NULL, + mxid TEXT NOT NULL, + + room_id TEXT NOT NULL, + room_receiver TEXT NOT NULL, + sender_id TEXT NOT NULL, + timestamp BIGINT NOT NULL, + relates_to BIGINT, + metadata jsonb NOT NULL, + + CONSTRAINT message_relation_fkey FOREIGN KEY (relates_to) + REFERENCES message (rowid) ON DELETE SET NULL, + CONSTRAINT message_room_fkey FOREIGN KEY (bridge_id, room_id, room_receiver) + REFERENCES portal (bridge_id, id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT message_sender_fkey FOREIGN KEY (bridge_id, sender_id) + REFERENCES ghost (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT message_real_pkey UNIQUE (bridge_id, room_receiver, id, part_id) +); + +INSERT INTO message_new (rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, relates_to, metadata) +SELECT rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, relates_to, metadata +FROM message; + +DROP TABLE message; +ALTER TABLE message_new RENAME TO message; + +CREATE TABLE reaction_new ( + bridge_id TEXT NOT NULL, + message_id TEXT NOT NULL, + message_part_id TEXT NOT NULL, + sender_id TEXT NOT NULL, + emoji_id TEXT NOT NULL, + room_id TEXT NOT NULL, + room_receiver TEXT NOT NULL, + mxid TEXT NOT NULL, + + timestamp BIGINT NOT NULL, + metadata jsonb NOT NULL, + + PRIMARY KEY (bridge_id, room_receiver, message_id, message_part_id, sender_id, emoji_id), + CONSTRAINT reaction_room_fkey FOREIGN KEY (bridge_id, room_id, room_receiver) + REFERENCES portal (bridge_id, id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT reaction_message_fkey FOREIGN KEY (bridge_id, room_receiver, message_id, message_part_id) + REFERENCES message (bridge_id, room_receiver, id, part_id) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT reaction_sender_fkey FOREIGN KEY (bridge_id, sender_id) + REFERENCES ghost (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE +); + +INSERT INTO reaction_new +SELECT bridge_id, message_id, message_part_id, sender_id, emoji_id, room_id, room_receiver, mxid, timestamp, metadata +FROM reaction; + +DROP TABLE reaction; +ALTER TABLE reaction_new RENAME TO reaction; + +PRAGMA foreign_key_check; +COMMIT; +PRAGMA foreign_keys = ON; diff --git a/bridgev2/portal.go b/bridgev2/portal.go index fc745388..a12987ff 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1043,7 +1043,7 @@ func (portal *Portal) GetIntentFor(ctx context.Context, sender EventSender, sour func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) { log := zerolog.Ctx(ctx) - existing, err := portal.Bridge.DB.Message.GetFirstPartByID(ctx, evt.GetID()) + existing, err := portal.Bridge.DB.Message.GetFirstPartByID(ctx, portal.Receiver, evt.GetID()) if err != nil { log.Err(err).Msg("Failed to check if message is a duplicate") } else if existing != nil { @@ -1064,7 +1064,7 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin var relatesToRowID int64 var replyTo, threadRoot, prevThreadEvent *database.Message if converted.ReplyTo != nil { - replyTo, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, *converted.ReplyTo) + replyTo, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, portal.Receiver, *converted.ReplyTo) if err != nil { log.Err(err).Msg("Failed to get reply target message from database") } else if replyTo == nil { @@ -1074,7 +1074,7 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin } } if converted.ThreadRoot != nil { - threadRoot, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, *converted.ThreadRoot) + threadRoot, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, portal.Receiver, *converted.ThreadRoot) if err != nil { log.Err(err).Msg("Failed to get thread root message from database") } else if threadRoot == nil { @@ -1164,7 +1164,7 @@ func (portal *Portal) sendRemoteErrorNotice(ctx context.Context, intent MatrixAP func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) { log := zerolog.Ctx(ctx) - existing, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, evt.GetTargetMessage()) + existing, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, evt.GetTargetMessage()) if err != nil { log.Err(err).Msg("Failed to get edit target message") return @@ -1234,9 +1234,9 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e func (portal *Portal) getTargetMessagePart(ctx context.Context, evt RemoteEventWithTargetMessage) (*database.Message, error) { if partTargeter, ok := evt.(RemoteEventWithTargetPart); ok { - return portal.Bridge.DB.Message.GetPartByID(ctx, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart()) + return portal.Bridge.DB.Message.GetPartByID(ctx, portal.Receiver, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart()) } else { - return portal.Bridge.DB.Message.GetFirstPartByID(ctx, evt.GetTargetMessage()) + return portal.Bridge.DB.Message.GetFirstPartByID(ctx, portal.Receiver, evt.GetTargetMessage()) } } @@ -1348,7 +1348,7 @@ func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *Us func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) { log := zerolog.Ctx(ctx) - targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, evt.GetTargetMessage()) + targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, evt.GetTargetMessage()) if err != nil { log.Err(err).Msg("Failed to get target message for removal") return @@ -1371,7 +1371,7 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use Msg("Sent redaction of message part to Matrix") } } - err = portal.Bridge.DB.Message.DeleteAllParts(ctx, evt.GetTargetMessage()) + err = portal.Bridge.DB.Message.DeleteAllParts(ctx, portal.Receiver, evt.GetTargetMessage()) if err != nil { log.Err(err).Msg("Failed to delete target message from database") } @@ -1382,7 +1382,7 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL var err error var lastTarget *database.Message if lastTargetID := evt.GetLastReceiptTarget(); lastTargetID != "" { - lastTarget, err = portal.Bridge.DB.Message.GetLastPartByID(ctx, lastTargetID) + lastTarget, err = portal.Bridge.DB.Message.GetLastPartByID(ctx, portal.Receiver, lastTargetID) if err != nil { log.Err(err).Str("last_target_id", string(lastTargetID)). Msg("Failed to get last target message for read receipt") @@ -1394,7 +1394,7 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL } if lastTarget == nil { for _, targetID := range evt.GetReceiptTargets() { - target, err := portal.Bridge.DB.Message.GetLastPartByID(ctx, targetID) + target, err := portal.Bridge.DB.Message.GetLastPartByID(ctx, portal.Receiver, targetID) if err != nil { log.Err(err).Str("target_id", string(targetID)). Msg("Failed to get target message for read receipt") From db0ec1ebf9d0ba19db83be1ae6f5d7ec98cc93de Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 30 Jun 2024 00:06:31 +0300 Subject: [PATCH 0384/1647] bridgev2: ensure user is joined when marking in portal --- bridgev2/matrix/intent.go | 4 ++++ bridgev2/matrixinterface.go | 1 + bridgev2/portal.go | 1 + bridgev2/space.go | 18 ++++++++++++++++-- 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 26f2a4bb..da5e63ce 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -257,6 +257,10 @@ func (as *ASIntent) EnsureJoined(ctx context.Context, roomID id.RoomID) error { return nil } +func (as *ASIntent) EnsureInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) error { + return as.Matrix.EnsureInvited(ctx, roomID, userID) +} + func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) { if as.Connector.Config.Encryption.Default { content := &event.EncryptionEventContent{Algorithm: id.AlgorithmMegolmV1} diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 7689011d..cc6b4a98 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -64,6 +64,7 @@ type MatrixAPI interface { DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error InviteUser(ctx context.Context, roomID id.RoomID, userID id.UserID) error EnsureJoined(ctx context.Context, roomID id.RoomID) error + EnsureInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) error TagRoom(ctx context.Context, roomID id.RoomID, tag event.RoomTag, isTagged bool) error MuteRoom(ctx context.Context, roomID id.RoomID, until time.Time) error diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a12987ff..397d6140 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2158,6 +2158,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i for _, up := range userPortals { login := portal.Bridge.GetCachedUserLoginByID(up.LoginID) if login != nil { + login.inPortalCache.Remove(portal.PortalKey) go login.tryAddPortalToSpace(ctx, portal, up.CopyWithoutValues()) } } diff --git a/bridgev2/space.go b/bridgev2/space.go index a778fc9b..f5066e07 100644 --- a/bridgev2/space.go +++ b/bridgev2/space.go @@ -29,8 +29,22 @@ func (ul *UserLogin) MarkInPortal(ctx context.Context, portal *Portal) { return } ul.inPortalCache.Add(portal.PortalKey) - if ul.Bridge.Config.PersonalFilteringSpaces && (userPortal.InSpace == nil || !*userPortal.InSpace) && portal.MXID != "" { - go ul.tryAddPortalToSpace(ctx, portal, userPortal.CopyWithoutValues()) + if portal.MXID != "" { + dp := ul.User.DoublePuppet(ctx) + if dp != nil { + err = dp.EnsureJoined(ctx, portal.MXID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to ensure double puppet is joined to portal") + } + } else { + err = ul.Bridge.Bot.EnsureInvited(ctx, portal.MXID, ul.UserMXID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to ensure user is invited to portal") + } + } + if ul.Bridge.Config.PersonalFilteringSpaces && (userPortal.InSpace == nil || !*userPortal.InSpace) { + go ul.tryAddPortalToSpace(ctx, portal, userPortal.CopyWithoutValues()) + } } } From 9af9101c27dfe7f535fb9e4feb3f83051f7f2038 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 30 Jun 2024 00:13:58 +0300 Subject: [PATCH 0385/1647] bridgev2: don't log warnings about not being logged in in ephemeral event handlers --- bridgev2/commands/relay.go | 2 +- bridgev2/portal.go | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/bridgev2/commands/relay.go b/bridgev2/commands/relay.go index 7d2e10ba..9093a52c 100644 --- a/bridgev2/commands/relay.go +++ b/bridgev2/commands/relay.go @@ -66,7 +66,7 @@ func fnSetRelay(ce *Event) { } else { relay = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) if relay == nil { - ce.Reply("User login with ID %q not found", ce.Args[0]) + ce.Reply("User login with ID `%s` not found", ce.Args[0]) return } else if !slices.Contains(ce.Bridge.Config.Relay.DefaultRelays, relay.ID) && relay.UserMXID != ce.User.MXID && !ce.User.Permissions.Admin { ce.Reply("Only bridge admins can set another user's login as the relay") diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 397d6140..786c6640 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -261,7 +261,11 @@ func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { login, _, err := portal.FindPreferredLogin(ctx, sender, true) if err != nil { log.Err(err).Msg("Failed to get user login to handle Matrix event") - portal.sendErrorStatus(ctx, evt, WrapErrorInStatus(err).WithMessage("Failed to get login to handle event").WithIsCertain(true).WithSendNotice(true)) + if errors.Is(err, ErrNotLoggedIn) { + portal.sendErrorStatus(ctx, evt, WrapErrorInStatus(err).WithMessage("You're not logged in").WithIsCertain(true).WithSendNotice(true)) + } else { + portal.sendErrorStatus(ctx, evt, WrapErrorInStatus(err).WithMessage("Failed to get login to handle event").WithIsCertain(true).WithSendNotice(true)) + } return } var origSender *OrigSender @@ -333,7 +337,11 @@ func (portal *Portal) handleMatrixReadReceipt(user *User, eventID id.EventID, re ctx := log.WithContext(context.TODO()) login, userPortal, err := portal.FindPreferredLogin(ctx, user, false) if err != nil { - log.Err(err).Msg("Failed to get preferred login for user") + if !errors.Is(err, ErrNotLoggedIn) { + log.Err(err).Msg("Failed to get preferred login for user") + } + return + } else if login == nil { return } rrClient, ok := login.Client.(ReadReceiptHandlingNetworkAPI) @@ -408,7 +416,9 @@ func (portal *Portal) sendTypings(ctx context.Context, userIDs []id.UserID, typi } login, _, err = portal.FindPreferredLogin(ctx, user, false) if err != nil { - zerolog.Ctx(ctx).Err(err).Stringer("user_id", userID).Msg("Failed to get user login to send typing event") + if !errors.Is(err, ErrNotLoggedIn) { + zerolog.Ctx(ctx).Err(err).Stringer("user_id", userID).Msg("Failed to get user login to send typing event") + } continue } else if login == nil { continue From 47debaae166a1433aaed5fcebd434a37a6624e37 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 30 Jun 2024 00:18:02 +0300 Subject: [PATCH 0386/1647] bridgev2/database: fix inserting reactions --- bridgev2/database/reaction.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/database/reaction.go b/bridgev2/database/reaction.go index b7d675b7..03e9f521 100644 --- a/bridgev2/database/reaction.go +++ b/bridgev2/database/reaction.go @@ -67,7 +67,7 @@ const ( upsertReactionQuery = ` INSERT INTO reaction (bridge_id, message_id, message_part_id, sender_id, emoji_id, room_id, room_receiver, mxid, timestamp, metadata) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) - ON CONFLICT (bridge_id, message_id, message_part_id, sender_id, emoji_id) + ON CONFLICT (bridge_id, room_receiver, message_id, message_part_id, sender_id, emoji_id) DO UPDATE SET mxid=excluded.mxid, timestamp=excluded.timestamp, metadata=excluded.metadata ` deleteReactionQuery = ` From 63c49bf8400dc8a8b87e26a5f07a5fbe5f784766 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 30 Jun 2024 00:25:45 +0300 Subject: [PATCH 0387/1647] bridgev2: send member events to olm machine --- bridgev2/matrix/matrix.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index 9d1c8f9d..d5a76ceb 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -31,6 +31,9 @@ func (br *Connector) handleRoomEvent(ctx context.Context, evt *event.Event) { br.sendCryptoStatusError(ctx, evt, errMessageNotEncrypted, nil, 0, true) return } + if evt.Type == event.StateMember && br.Crypto != nil { + br.Crypto.HandleMemberEvent(ctx, evt) + } br.Bridge.QueueMatrixEvent(ctx, evt) } From b09e31d0b499e8e75f4c815164cf07d7dfb058e9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 30 Jun 2024 00:35:36 +0300 Subject: [PATCH 0388/1647] bridgev2: fix v5/v6 db migration name --- .../database/upgrades/05-message-receiver-pkey-postgres.sql | 2 +- bridgev2/database/upgrades/06-message-receiver-pkey-sqlite.sql | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bridgev2/database/upgrades/05-message-receiver-pkey-postgres.sql b/bridgev2/database/upgrades/05-message-receiver-pkey-postgres.sql index a19d9d57..1cdbcccf 100644 --- a/bridgev2/database/upgrades/05-message-receiver-pkey-postgres.sql +++ b/bridgev2/database/upgrades/05-message-receiver-pkey-postgres.sql @@ -1,4 +1,4 @@ --- v5 (compatible with v1+): Add room_receiver to message unique key +-- v5 (compatible with v1+): Add room_receiver to message unique key (Postgres) -- only: postgres ALTER TABLE reaction DROP CONSTRAINT reaction_message_fkey; ALTER TABLE reaction DROP CONSTRAINT reaction_pkey1; diff --git a/bridgev2/database/upgrades/06-message-receiver-pkey-sqlite.sql b/bridgev2/database/upgrades/06-message-receiver-pkey-sqlite.sql index d1be1030..b88c5052 100644 --- a/bridgev2/database/upgrades/06-message-receiver-pkey-sqlite.sql +++ b/bridgev2/database/upgrades/06-message-receiver-pkey-sqlite.sql @@ -1,4 +1,4 @@ --- v6 (compatible with v1+): Add room_receiver to message unique key +-- v6 (compatible with v1+): Add room_receiver to message unique key (SQLite) -- transaction: off -- only: sqlite From 0443daef0e6ae5dc7b3a5410857e6442a4048ec6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 30 Jun 2024 01:14:27 +0300 Subject: [PATCH 0389/1647] crypto: use exzerolog.ArrayOfStrs instead of custom function --- crypto/cross_sign_store.go | 4 +++- crypto/devicelist.go | 5 +++-- crypto/encryptmegolm.go | 15 ++++----------- crypto/machine.go | 17 +++++------------ 4 files changed, 15 insertions(+), 26 deletions(-) diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go index 456ab6ed..968a52a1 100644 --- a/crypto/cross_sign_store.go +++ b/crypto/cross_sign_store.go @@ -10,6 +10,8 @@ package crypto import ( "context" + "go.mau.fi/util/exzerolog" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/id" @@ -47,7 +49,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK } for _, key := range userKeys.Keys { - log := log.With().Str("key", key.String()).Strs("usages", strishArray(userKeys.Usage)).Logger() + log := log.With().Str("key", key.String()).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger() for _, usage := range userKeys.Usage { log.Debug().Str("usage", string(usage)).Msg("Storing cross-signing key") if err = mach.CryptoStore.PutCrossSigningKey(ctx, userID, usage, key); err != nil { diff --git a/crypto/devicelist.go b/crypto/devicelist.go index e98ba45a..de6c21f3 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -12,6 +12,7 @@ import ( "fmt" "github.com/rs/zerolog" + "go.mau.fi/util/exzerolog" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/signatures" @@ -115,7 +116,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ for _, userID := range users { req.DeviceKeys[userID] = mautrix.DeviceIDList{} } - log.Debug().Strs("users", strishArray(users)).Msg("Querying keys for users") + log.Debug().Array("users", exzerolog.ArrayOfStrs(users)).Msg("Querying keys for users") resp, err := mach.Client.QueryKeys(ctx, req) if err != nil { return nil, fmt.Errorf("failed to query keys: %w", err) @@ -181,7 +182,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ log.Err(err).Msg("Failed to redact megolm sessions from deleted device") } else { log.Info(). - Strs("session_ids", stringifyArray(sessionIDs)). + Array("session_ids", exzerolog.ArrayOfStrs(sessionIDs)). Msg("Redacted megolm sessions from deleted device") } } diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index fd7b8ea2..d8d5c7c9 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -15,6 +15,7 @@ import ( "fmt" "github.com/rs/zerolog" + "go.mau.fi/util/exzerolog" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" @@ -160,14 +161,6 @@ type deviceSessionWrapper struct { identity *id.Device } -func strishArray[T ~string](arr []T) []string { - out := make([]string, len(arr)) - for i, item := range arr { - out[i] = string(item) - } - return out -} - // ShareGroupSession shares a group session for a specific room with all the devices of the given user list. // // For devices with TrustStateBlacklisted, a m.room_key.withheld event with code=m.blacklisted is sent. @@ -193,7 +186,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, } log = log.With().Str("session_id", session.ID().String()).Logger() ctx = log.WithContext(ctx) - log.Debug().Strs("users", strishArray(users)).Msg("Sharing group session for room") + log.Debug().Array("users", exzerolog.ArrayOfStrs(users)).Msg("Sharing group session for room") withheldCount := 0 toDeviceWithheld := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)} @@ -235,10 +228,10 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, } if len(fetchKeysForUsers) > 0 { - log.Debug().Strs("users", strishArray(fetchKeysForUsers)).Msg("Fetching missing keys") + log.Debug().Array("users", exzerolog.ArrayOfStrs(fetchKeysForUsers)).Msg("Fetching missing keys") keys, err := mach.FetchKeys(ctx, fetchKeysForUsers, true) if err != nil { - log.Err(err).Strs("users", strishArray(fetchKeysForUsers)).Msg("Failed to fetch missing keys") + log.Err(err).Array("users", exzerolog.ArrayOfStrs(fetchKeysForUsers)).Msg("Failed to fetch missing keys") return fmt.Errorf("failed to fetch missing keys: %w", err) } for userID, devices := range keys { diff --git a/crypto/machine.go b/crypto/machine.go index 8e9a6c66..2477b9e1 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -15,11 +15,12 @@ import ( "github.com/rs/zerolog" - "maunium.net/go/mautrix/crypto/ssss" - "maunium.net/go/mautrix/id" + "go.mau.fi/util/exzerolog" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto/ssss" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" ) // OlmMachine is the main struct for handling Matrix end-to-end encryption. @@ -575,14 +576,6 @@ func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, se } } -func stringifyArray[T ~string](arr []T) []string { - strs := make([]string, len(arr)) - for i, v := range arr { - strs[i] = string(v) - } - return strs -} - func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEvent, content *event.RoomKeyEventContent) { log := zerolog.Ctx(ctx).With(). Str("algorithm", string(content.Algorithm)). @@ -623,7 +616,7 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve log.Err(err).Msg("Failed to redact previous megolm sessions") } else { log.Info(). - Strs("session_ids", stringifyArray(sessionIDs)). + Array("session_ids", exzerolog.ArrayOfStrs(sessionIDs)). Msg("Redacted previous megolm sessions") } } @@ -713,7 +706,7 @@ func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) { if err != nil { log.Err(err).Msg("Failed to redact expired megolm sessions") } else if len(sessionIDs) > 0 { - log.Info().Strs("session_ids", stringifyArray(sessionIDs)).Msg("Redacted expired megolm sessions") + log.Info().Array("session_ids", exzerolog.ArrayOfStrs(sessionIDs)).Msg("Redacted expired megolm sessions") } else { log.Debug().Msg("Didn't find any expired megolm sessions") } From 2b96826aa0d574d22fc5143b7104dcbd7c8d9da7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 30 Jun 2024 01:15:31 +0300 Subject: [PATCH 0390/1647] main: update dependencies --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index d381a97d..8a90192a 100644 --- a/go.mod +++ b/go.mod @@ -14,8 +14,8 @@ require ( github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.7.2 - go.mau.fi/util v0.5.1-0.20240626184357-b3f4d78c25cf + github.com/yuin/goldmark v1.7.4 + go.mau.fi/util v0.5.1-0.20240629220711-4fa40bf64652 go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.24.0 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 diff --git a/go.sum b/go.sum index 483ae912..e005cbf1 100644 --- a/go.sum +++ b/go.sum @@ -44,10 +44,10 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.2 h1:NjGd7lO7zrUn/A7eKwn5PEOt4ONYGqpxSEeZuduvgxc= -github.com/yuin/goldmark v1.7.2/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.5.1-0.20240626184357-b3f4d78c25cf h1:ceXQTB6IqjqGBGhzOTEBGbxQu7xDyuT9YR06gxr9Ncw= -go.mau.fi/util v0.5.1-0.20240626184357-b3f4d78c25cf/go.mod h1:DsJzUrJAG53lCZnnYvq9/mOyLuPScWwYhvETiTrpdP4= +github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= +github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +go.mau.fi/util v0.5.1-0.20240629220711-4fa40bf64652 h1:/wY7vpwOE6he5Qlf6lICHuOUs+nAQdsC7qwRsDbsh14= +go.mau.fi/util v0.5.1-0.20240629220711-4fa40bf64652/go.mod h1:DsJzUrJAG53lCZnnYvq9/mOyLuPScWwYhvETiTrpdP4= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= From 407695e37a69f3b0fd3f81ad6b7f05dee7dd1cac Mon Sep 17 00:00:00 2001 From: rudis Date: Sun, 30 Jun 2024 09:30:44 +0000 Subject: [PATCH 0391/1647] client: don't log warning in State() when StateStore is set (#249) State logs a warning that ClearCachedMembers() fails even when nil is returned as error. Looks like this was forgotten in 581aa80. --- client.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 1d6433ac..997d7363 100644 --- a/client.go +++ b/client.go @@ -1411,9 +1411,11 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt }) if err == nil && cli.StateStore != nil { clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID) - cli.cliOrContextLog(ctx).Warn().Err(clearErr). - Stringer("room_id", roomID). - Msg("Failed to clear cached member list after fetching state") + if clearErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(clearErr). + Stringer("room_id", roomID). + Msg("Failed to clear cached member list after fetching state") + } for _, evts := range stateMap { for _, evt := range evts { UpdateStateStore(ctx, cli.StateStore, evt) From c6e87a260c952ae587f7e55dec89d3ab98f2e465 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 30 Jun 2024 19:25:37 +0300 Subject: [PATCH 0392/1647] bridgev2/matrix: expose appservice HTTP server to network connectors --- bridgev2/bridgeconfig/appservice.go | 7 ++++--- bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/matrix/connector.go | 20 +++++++++++++++++++- bridgev2/matrix/mxmain/example-config.yaml | 6 +++++- bridgev2/matrixinterface.go | 7 +++++++ 5 files changed, 36 insertions(+), 5 deletions(-) diff --git a/bridgev2/bridgeconfig/appservice.go b/bridgev2/bridgeconfig/appservice.go index 7466636b..9ff333e9 100644 --- a/bridgev2/bridgeconfig/appservice.go +++ b/bridgev2/bridgeconfig/appservice.go @@ -21,9 +21,10 @@ import ( ) type AppserviceConfig struct { - Address string `yaml:"address"` - Hostname string `yaml:"hostname"` - Port uint16 `yaml:"port"` + Address string `yaml:"address"` + PublicAddress string `yaml:"public_address"` + Hostname string `yaml:"hostname"` + Port uint16 `yaml:"port"` ID string `yaml:"id"` Bot BotUserConfig `yaml:"bot"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 7908c268..e74e2368 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -52,6 +52,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Int, "homeserver", "ping_interval_seconds") helper.Copy(up.Str|up.Null, "appservice", "address") + helper.Copy(up.Str|up.Null, "appservice", "public_address") helper.Copy(up.Str|up.Null, "appservice", "hostname") helper.Copy(up.Int|up.Null, "appservice", "port") helper.Copy(up.Str, "appservice", "id") diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 1910c48d..e5b9777f 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -16,6 +16,7 @@ import ( "sync" "time" + "github.com/gorilla/mux" _ "github.com/lib/pq" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" @@ -80,7 +81,10 @@ type Connector struct { latestState *status.BridgeState } -var _ bridgev2.MatrixConnector = (*Connector)(nil) +var ( + _ bridgev2.MatrixConnector = (*Connector)(nil) + _ bridgev2.MatrixConnectorWithServer = (*Connector)(nil) +) func NewConnector(cfg *bridgeconfig.Config) *Connector { c := &Connector{} @@ -154,6 +158,20 @@ func (br *Connector) Start(ctx context.Context) error { return nil } +func (br *Connector) GetPublicAddress() string { + if br.Config.AppService.PublicAddress == "https://bridge.example.com" { + return "" + } + return br.Config.AppService.PublicAddress +} + +func (br *Connector) GetRouter() *mux.Router { + if br.GetPublicAddress() != "" { + return br.AS.Router + } + return nil +} + func (br *Connector) GetCapabilities() *bridgev2.MatrixCapabilities { return br.Capabilities } diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 9b84056f..86de8916 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -91,10 +91,13 @@ homeserver: ping_interval_seconds: 0 # Application service host/registration related details. -# Changing these values requires regeneration of the registration. +# Changing these values requires regeneration of the registration (except when noted otherwise) appservice: # The address that the homeserver can use to connect to this appservice. address: http://localhost:$<> + # A public address that external services can use to reach this appservice. + # This value doesn't affect the registration file. + public_address: https://bridge.example.com # The hostname and port where this appservice should listen. # For Docker, you generally have to change the hostname to 0.0.0.0. @@ -117,6 +120,7 @@ appservice: # Should incoming events be handled asynchronously? # This may be necessary for large public instances with lots of messages going through. # However, messages will not be guaranteed to be bridged in the same order they were sent in. + # This value doesn't affect the registration file. async_transactions: false # Authentication tokens for AS <-> HS communication. Autogenerated; do not modify. diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index cc6b4a98..0ba5f212 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -10,6 +10,8 @@ import ( "context" "time" + "github.com/gorilla/mux" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/networkid" @@ -45,6 +47,11 @@ type MatrixConnector interface { ServerName() string } +type MatrixConnectorWithServer interface { + GetPublicAddress() string + GetRouter() *mux.Router +} + type MatrixAPI interface { GetMXID() id.UserID From d86f710aea3a28bf4511ffd3be01d782628b4171 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 1 Jul 2024 09:39:08 +0300 Subject: [PATCH 0393/1647] bridgev2/mxmain: handle GetConfig returning a nil upgrader --- bridgev2/matrix/mxmain/main.go | 3 +++ go.mod | 2 +- go.sum | 4 ++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index d786e215..15d48b32 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -306,6 +306,9 @@ func (br *BridgeMain) validateConfig() error { func (br *BridgeMain) getConfigUpgrader() (configupgrade.BaseUpgrader, any) { networkExample, networkData, networkUpgrader := br.Connector.GetConfig() baseConfig := br.makeFullExampleConfig(networkExample) + if networkUpgrader == nil { + networkUpgrader = configupgrade.NoopUpgrader + } networkUpgraderProxied := &configupgrade.ProxyUpgrader{Target: networkUpgrader, Prefix: []string{"network"}} upgrader := configupgrade.MergeUpgraders(baseConfig, networkUpgraderProxied, bridgeconfig.Upgrader) return upgrader, networkData diff --git a/go.mod b/go.mod index 8a90192a..7111c199 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.5.1-0.20240629220711-4fa40bf64652 + go.mau.fi/util v0.5.1-0.20240701063757-6126777abba3 go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.24.0 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 diff --git a/go.sum b/go.sum index e005cbf1..6f3e8f5d 100644 --- a/go.sum +++ b/go.sum @@ -46,8 +46,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.5.1-0.20240629220711-4fa40bf64652 h1:/wY7vpwOE6he5Qlf6lICHuOUs+nAQdsC7qwRsDbsh14= -go.mau.fi/util v0.5.1-0.20240629220711-4fa40bf64652/go.mod h1:DsJzUrJAG53lCZnnYvq9/mOyLuPScWwYhvETiTrpdP4= +go.mau.fi/util v0.5.1-0.20240701063757-6126777abba3 h1:EeFfeO2CheFO3HU6SCVQiP6dY8Wwv6dUyJ2SKtPyE70= +go.mau.fi/util v0.5.1-0.20240701063757-6126777abba3/go.mod h1:DsJzUrJAG53lCZnnYvq9/mOyLuPScWwYhvETiTrpdP4= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= From 69d7d7902256b0ef07bb82f8e38a70409422d892 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 1 Jul 2024 09:45:53 +0300 Subject: [PATCH 0394/1647] bridgev2: add godocs for most NetworkAPI functions --- bridgev2/networkinterface.go | 109 ++++++++++++++++++++++++++++++++--- 1 file changed, 100 insertions(+), 9 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 0d14e309..d566517d 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -71,7 +71,8 @@ type BridgeName struct { NetworkIcon id.ContentURIString // An identifier uniquely identifying the network, e.g. `discord` NetworkID string - // An identifier uniquely identifying the bridge software, e.g. `discordgo` + // An identifier uniquely identifying the bridge software. + // The Go import path is a good choice here (e.g. github.com/octocat/discordbridge) BeeperBridgeType string // The default appservice port to use in the example config, defaults to 8080 if unset DefaultPort uint16 @@ -97,21 +98,30 @@ type NetworkConnector interface { // The connector should do any non-user-specific startup actions necessary. // User logins will be loaded separately, so the connector should not load them here. Start(context.Context) error - // LoadUserLogin is called when a UserLogin is loaded from the database in order to fill the [UserLogin.Client] field. + + // GetName returns the name of the bridge and some additional metadata, + // which is used to fill `m.bridge` events among other things. // - // This is called within the bridge's global cache lock, so it must not do any slow operations, - // such as connecting to the network. Instead, connecting should happen when [NetworkAPI.Connect] is called later. - LoadUserLogin(ctx context.Context, login *UserLogin) error - - GetCapabilities() *NetworkGeneralCapabilities - + // The first call happens *before* the config is loaded, because the data here is also used to + // fill parts of the example config (like the default username template and bot localpart). + // The output can still be adjusted based on config variables, but the function must have + // default values when called without a config. GetName() BridgeName + // GetCapabilities returns the general capabilities of the network connector. + // Note that most capabilities are scoped to rooms and are returned by [NetworkAPI.GetCapabilities] instead. + GetCapabilities() *NetworkGeneralCapabilities // GetConfig returns all the parts of the network connector's config file. Specifically: // - example: a string containing an example config file // - data: an interface to unmarshal the actual config into // - upgrader: a config upgrader to ensure all fields are present and to do any migrations from old configs GetConfig() (example string, data any, upgrader configupgrade.Upgrader) + // LoadUserLogin is called when a UserLogin is loaded from the database in order to fill the [UserLogin.Client] field. + // + // This is called within the bridge's global cache lock, so it must not do any slow operations, + // such as connecting to the network. Instead, connecting should happen when [NetworkAPI.Connect] is called later. + LoadUserLogin(ctx context.Context, login *UserLogin) error + // GetLoginFlows returns a list of login flows that the network supports. GetLoginFlows() []LoginFlow // CreateLogin is called when a user wants to log in to the network. @@ -204,68 +214,144 @@ type NetworkRoomCapabilities struct { } // NetworkAPI is an interface representing a remote network client for a single user login. +// +// Implementations of this interface are stored in [UserLogin.Client]. +// The [NetworkConnector.LoadUserLogin] method is responsible for filling the Client field with a NetworkAPI. type NetworkAPI interface { + // Connect is called to actually connect to the remote network. + // If there's no persistent connection, this may just check access token validity, or even do nothing at all. Connect(ctx context.Context) error + // Disconnect should disconnect from the remote network. + // A clean disconnection is preferred, but it should not take too long. Disconnect() + // IsLoggedIn should return whether the access tokens in this NetworkAPI are valid. + // This should not do any IO operations, it should only return cached data which is updated elsewhere. IsLoggedIn() bool + // LogoutRemote should invalidate the access tokens in this NetworkAPI if possible + // and disconnect from the remote network. LogoutRemote(ctx context.Context) + // IsThisUser should return whether the given remote network user ID is the same as this login. + // This is used when the bridge wants to convert a user login ID to a user ID. IsThisUser(ctx context.Context, userID networkid.UserID) bool + // GetChatInfo returns info for a given chat. Any fields that are nil will be ignored and not processed at all, + // while empty strings will change the relevant value in the room to be an empty string. + // For example, a nil name will mean the room name is not changed, while an empty string name will remove the name. GetChatInfo(ctx context.Context, portal *Portal) (*ChatInfo, error) + // GetUserInfo returns info for a given user. Like chat info, fields can be nil to skip them. GetUserInfo(ctx context.Context, ghost *Ghost) (*UserInfo, error) + // GetCapabilities returns the bridging capabilities in a given room. + // This can simply return a static list if the remote network has no per-chat capability differences, + // but all calls will include the portal, because some networks do have per-chat differences. GetCapabilities(ctx context.Context, portal *Portal) *NetworkRoomCapabilities + // HandleMatrixMessage is called when a message is sent from Matrix in an existing portal room. + // This function should convert the message as appropriate, send it over to the remote network, + // and return the info so the central bridge can store it in the database. + // + // This is only called for normal non-edit messages. For other types of events, see the optional extra interfaces (`XHandlingNetworkAPI`). HandleMatrixMessage(ctx context.Context, msg *MatrixMessage) (message *MatrixMessageResponse, err error) } +// EditHandlingNetworkAPI is an optional interface that network connectors can implement to handle message edits. type EditHandlingNetworkAPI interface { NetworkAPI + // HandleMatrixEdit is called when a previously bridged message is edited in a portal room. + // The central bridge module will save the [*database.Message] after this function returns, + // so the network connector is allowed to mutate the provided object. HandleMatrixEdit(ctx context.Context, msg *MatrixEdit) error } +// ReactionHandlingNetworkAPI is an optional interface that network connectors can implement to handle message reactions. type ReactionHandlingNetworkAPI interface { NetworkAPI + // PreHandleMatrixReaction is called as the first step of handling a reaction. It returns the emoji ID, + // sender user ID and max reaction count to allow the central bridge module to de-duplicate the reaction + // if appropriate. PreHandleMatrixReaction(ctx context.Context, msg *MatrixReaction) (MatrixReactionPreResponse, error) + // HandleMatrixReaction is called after confirming that the reaction is not a duplicate. + // This is the method that should actually send the reaction to the remote network. + // The returned [database.Reaction] object may be empty: the central bridge module already has + // all the required fields and will fill them automatically if they're empty. However, network + // connectors are allowed to set fields themselves if any extra fields are necessary. HandleMatrixReaction(ctx context.Context, msg *MatrixReaction) (reaction *database.Reaction, err error) + // HandleMatrixReactionRemove is called when a redaction event is received pointing at a previously + // bridged reaction. The network connector should remove the reaction from the remote network. HandleMatrixReactionRemove(ctx context.Context, msg *MatrixReactionRemove) error } +// RedactionHandlingNetworkAPI is an optional interface that network connectors can implement to handle message deletions. type RedactionHandlingNetworkAPI interface { NetworkAPI + // HandleMatrixMessageRemove is called when a previously bridged message is deleted in a portal room. HandleMatrixMessageRemove(ctx context.Context, msg *MatrixMessageRemove) error } +// ReadReceiptHandlingNetworkAPI is an optional interface that network connectors can implement to handle read receipts. type ReadReceiptHandlingNetworkAPI interface { NetworkAPI + // HandleMatrixReadReceipt is called when a read receipt is sent in a portal room. + // This will be called even if the target message is not a bridged message. + // Network connectors must gracefully handle [MatrixReadReceipt.ExactMessage] being nil. + // The exact handling is up to the network connector. HandleMatrixReadReceipt(ctx context.Context, msg *MatrixReadReceipt) error } +// TypingHandlingNetworkAPI is an optional interface that network connectors can implement to handle typing events. type TypingHandlingNetworkAPI interface { NetworkAPI + // HandleMatrixTyping is called when a user starts typing in a portal room. + // In the future, the central bridge module will likely get a loop to automatically repeat + // calls to this function until the user stops typing. HandleMatrixTyping(ctx context.Context, msg *MatrixTyping) error } +// RoomNameHandlingNetworkAPI is an optional interface that network connectors can implement to handle room name changes. type RoomNameHandlingNetworkAPI interface { NetworkAPI + // HandleMatrixRoomName is called when the name of a portal room is changed. + // This method should update the Name and NameSet fields of the Portal with + // the new name and return true if the change was successful. + // If the change is not successful, then the fields should not be updated. HandleMatrixRoomName(ctx context.Context, msg *MatrixRoomName) (bool, error) } +// RoomAvatarHandlingNetworkAPI is an optional interface that network connectors can implement to handle room avatar changes. type RoomAvatarHandlingNetworkAPI interface { NetworkAPI + // HandleMatrixRoomAvatar is called when the avatar of a portal room is changed. + // This method should update the AvatarID, AvatarHash and AvatarMXC fields + // with the new avatar details and return true if the change was successful. + // If the change is not successful, then the fields should not be updated. HandleMatrixRoomAvatar(ctx context.Context, msg *MatrixRoomAvatar) (bool, error) } +// RoomTopicHandlingNetworkAPI is an optional interface that network connectors can implement to handle room topic changes. type RoomTopicHandlingNetworkAPI interface { NetworkAPI + // HandleMatrixRoomTopic is called when the topic of a portal room is changed. + // This method should update the Topic and TopicSet fields of the Portal with + // the new topic and return true if the change was successful. + // If the change is not successful, then the fields should not be updated. HandleMatrixRoomTopic(ctx context.Context, msg *MatrixRoomTopic) (bool, error) } type ResolveIdentifierResponse struct { + // Ghost is the ghost of the user that the identifier resolves to. + // This field should be set whenever possible. However, it is not required, + // and the central bridge module will not try to create a ghost if it is not set. Ghost *Ghost - UserID networkid.UserID + // UserID is the user ID of the user that the identifier resolves to. + UserID networkid.UserID + // UserInfo contains the info of the user that the identifier resolves to. + // If both this and the Ghost field are set, the central bridge module will + // automatically update the ghost's info with the data here. UserInfo *UserInfo + // Chat contains info about the direct chat with the resolved user. + // This field is required when createChat is true in the ResolveIdentifier call, + // and optional otherwise. Chat *CreateChatResponse } @@ -276,11 +362,16 @@ type CreateChatResponse struct { PortalInfo *ChatInfo } +// IdentifierResolvingNetworkAPI is an optional interface that network connectors can implement to support starting new direct chats. type IdentifierResolvingNetworkAPI interface { NetworkAPI + // ResolveIdentifier is called when the user wants to start a new chat. + // This can happen via the `resolve-identifier` or `start-chat` bridge bot commands, + // or the corresponding provisioning API endpoints. ResolveIdentifier(ctx context.Context, identifier string, createChat bool) (*ResolveIdentifierResponse, error) } +// ContactListingNetworkAPI is an optional interface that network connectors can implement to provide the user's contact list. type ContactListingNetworkAPI interface { NetworkAPI GetContactList(ctx context.Context) ([]*ResolveIdentifierResponse, error) From bec59868d5360c4bc1413818618dbbf740889cbf Mon Sep 17 00:00:00 2001 From: Simon Ruderich Date: Tue, 2 Jul 2024 06:14:21 +0000 Subject: [PATCH 0395/1647] event/powerlevels: use 0 as default required level for invite (#250) --- event/powerlevels.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/event/powerlevels.go b/event/powerlevels.go index d291eacd..c60910d9 100644 --- a/event/powerlevels.go +++ b/event/powerlevels.go @@ -96,7 +96,7 @@ func (pl *PowerLevelsEventContent) Invite() int { if pl.InvitePtr != nil { return *pl.InvitePtr } - return 50 + return 0 } func (pl *PowerLevelsEventContent) Kick() int { From 9b647fe945599a5cc2c1beb739e0f9b17b048645 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 1 Jul 2024 10:52:16 +0300 Subject: [PATCH 0396/1647] bridgev2: make User.NewLogin smarter The function now knows how to reuse existing logins instead of only inserting new ones --- bridgev2/userlogin.go | 69 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 56 insertions(+), 13 deletions(-) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 22f7504c..289c9fc8 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -9,6 +9,7 @@ package bridgev2 import ( "context" "fmt" + "maps" "sync" "time" @@ -125,25 +126,67 @@ func (br *Bridge) GetCachedUserLoginByID(id networkid.UserLoginID) *UserLogin { return br.userLoginsByID[id] } -func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, client NetworkAPI) (*UserLogin, error) { +type NewLoginParams struct { + LoadUserLogin func(context.Context, *UserLogin) error + DeleteOnConflict bool + DontReuseExisting bool +} + +func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params NewLoginParams) (*UserLogin, error) { data.BridgeID = user.BridgeID data.UserMXID = user.MXID - ul := &UserLogin{ - UserLogin: data, - Bridge: user.Bridge, - User: user, - Log: user.Log.With().Str("login_id", string(data.ID)).Logger(), - Client: client, + if params.LoadUserLogin == nil { + params.LoadUserLogin = user.Bridge.Network.LoadUserLogin } - err := user.Bridge.DB.UserLogin.Insert(ctx, ul.UserLogin) + ul, err := user.Bridge.GetExistingUserLoginByID(ctx, data.ID) + if err != nil { + return nil, fmt.Errorf("failed to check if login already exists: %w", err) + } + var doInsert bool + if ul != nil && ul.UserMXID != user.MXID { + if params.DeleteOnConflict { + ul.Delete(ctx, status.BridgeState{StateEvent: status.StateLoggedOut, Error: "overridden-by-another-user"}, false) + ul = nil + } else { + return nil, fmt.Errorf("%s is already logged in with that account", ul.UserMXID) + } + } + if ul != nil { + if params.DontReuseExisting { + return nil, fmt.Errorf("login already exists") + } + doInsert = false + ul.Metadata.RemoteName = data.Metadata.RemoteName + maps.Copy(ul.Metadata.Extra, data.Metadata.Extra) + } else { + doInsert = true + ul = &UserLogin{ + UserLogin: data, + Bridge: user.Bridge, + User: user, + Log: user.Log.With().Str("login_id", string(data.ID)).Logger(), + } + ul.BridgeState = user.Bridge.NewBridgeStateQueue(ul) + } + err = params.LoadUserLogin(ul.Log.WithContext(context.Background()), ul) if err != nil { return nil, err } - ul.BridgeState = user.Bridge.NewBridgeStateQueue(ul) - user.Bridge.cacheLock.Lock() - defer user.Bridge.cacheLock.Unlock() - user.Bridge.userLoginsByID[ul.ID] = ul - user.logins[ul.ID] = ul + if doInsert { + err = user.Bridge.DB.UserLogin.Insert(ctx, ul.UserLogin) + if err != nil { + return nil, err + } + user.Bridge.cacheLock.Lock() + user.Bridge.userLoginsByID[ul.ID] = ul + user.logins[ul.ID] = ul + user.Bridge.cacheLock.Unlock() + } else { + err = ul.Save(ctx) + if err != nil { + return nil, err + } + } return ul, nil } From 82a579f10e59ef0d4a4e7189d81f1d65bfa9c664 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 2 Jul 2024 10:29:03 +0300 Subject: [PATCH 0397/1647] bridgev2: add godocs for NewLogin --- bridgev2/userlogin.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 289c9fc8..f20f4657 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -132,9 +132,22 @@ type NewLoginParams struct { DontReuseExisting bool } -func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params NewLoginParams) (*UserLogin, error) { +// NewLogin creates a UserLogin object for this user with the given parameters. +// +// If a login already exists with the same ID, it is reused after updating the remote name +// and metadata from the provided data, unless DontReuseExisting is set in params. +// +// If the existing login belongs to another user, this returns an error, +// unless DeleteOnConflict is set in the params, in which case the existing login is deleted. +// +// This will automatically call LoadUserLogin after creating the UserLogin object. +// The load method defaults to the network connector's LoadUserLogin method, but it can be overridden in params. +func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params *NewLoginParams) (*UserLogin, error) { data.BridgeID = user.BridgeID data.UserMXID = user.MXID + if params == nil { + params = &NewLoginParams{} + } if params.LoadUserLogin == nil { params.LoadUserLogin = user.Bridge.Network.LoadUserLogin } From 74c0110ee0e4cc47706e5708c966645a3c07610e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 2 Jul 2024 11:12:38 +0300 Subject: [PATCH 0398/1647] misc: remove some local functions in favor of generic ones --- crypto/keyexport.go | 17 +++----------- crypto/olm/inboundgroupsession.go | 9 ++------ crypto/sql_store.go | 9 +------- crypto/ssss/key.go | 25 ++++++-------------- event/powerlevels.go | 38 +++++++++---------------------- go.mod | 2 +- go.sum | 4 ++-- 7 files changed, 27 insertions(+), 77 deletions(-) diff --git a/crypto/keyexport.go b/crypto/keyexport.go index bb373f4d..3d126db4 100644 --- a/crypto/keyexport.go +++ b/crypto/keyexport.go @@ -11,7 +11,6 @@ import ( "crypto/aes" "crypto/cipher" "crypto/hmac" - "crypto/rand" "crypto/sha256" "crypto/sha512" "encoding/base64" @@ -20,9 +19,9 @@ import ( "fmt" "math" + "go.mau.fi/util/random" "golang.org/x/crypto/pbkdf2" - "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -66,25 +65,15 @@ func computeKey(passphrase string, salt []byte, rounds int) (encryptionKey, hash } func makeExportIV() []byte { - iv := make([]byte, 16) - _, err := rand.Read(iv) - if err != nil { - panic(olm.NotEnoughGoRandom) - } + iv := random.Bytes(16) // Set bit 63 to zero iv[7] &= 0b11111110 return iv } func makeExportKeys(passphrase string) (encryptionKey, hashKey, salt, iv []byte) { - salt = make([]byte, 16) - _, err := rand.Read(salt) - if err != nil { - panic(olm.NotEnoughGoRandom) - } - + salt = random.Bytes(16) encryptionKey, hashKey = computeKey(passphrase, salt, defaultPassphraseRounds) - iv = makeExportIV() return } diff --git a/crypto/olm/inboundgroupsession.go b/crypto/olm/inboundgroupsession.go index a3bd3b65..cac49d18 100644 --- a/crypto/olm/inboundgroupsession.go +++ b/crypto/olm/inboundgroupsession.go @@ -7,6 +7,7 @@ package olm import "C" import ( + "bytes" "encoding/base64" "unsafe" @@ -190,12 +191,6 @@ func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { return s.Unpickle(data[1:len(data)-1], pickleKey) } -func clone(original []byte) []byte { - clone := make([]byte, len(original)) - copy(clone, original) - return clone -} - // decryptMaxPlaintextLen returns the maximum number of bytes of plain-text a // given message could decode to. The actual size could be different due to // padding. Returns error on failure. If the message base64 couldn't be @@ -208,7 +203,7 @@ func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, erro return 0, EmptyInput } // olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it - message = clone(message) + message = bytes.Clone(message) r := C.olm_group_decrypt_max_plaintext_length( (*C.OlmInboundGroupSession)(s.int), (*C.uint8_t)(&message[0]), diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 0d824364..0b71e36d 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -261,13 +261,6 @@ func (store *SQLCryptoStore) UpdateSession(ctx context.Context, _ id.SenderKey, return err } -func intishPtr[T int | int64](i T) *T { - if i == 0 { - return nil - } - return &i -} - func datePtr(t time.Time) *time.Time { if t.IsZero() { return nil @@ -308,7 +301,7 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, session *Inbou key_backup_version=excluded.key_backup_version `, session.ID(), session.SenderKey, session.SigningKey, session.RoomID, sessionBytes, forwardingChains, - ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages), + ratchetSafety, datePtr(session.ReceivedAt), dbutil.NumPtr(session.MaxAge), dbutil.NumPtr(session.MaxMessages), session.IsScheduled, session.KeyBackupVersion, store.AccountID, ) return err diff --git a/crypto/ssss/key.go b/crypto/ssss/key.go index 3c38d3cd..c973c1fe 100644 --- a/crypto/ssss/key.go +++ b/crypto/ssss/key.go @@ -7,11 +7,12 @@ package ssss import ( - "crypto/rand" "encoding/base64" "fmt" "strings" + "go.mau.fi/util/random" + "maunium.net/go/mautrix/crypto/utils" ) @@ -33,10 +34,7 @@ func NewKey(passphrase string) (*Key, error) { if len(passphrase) > 0 { // There's a passphrase. We need to generate a salt for it, set the metadata // and then compute the key using the passphrase and the metadata. - saltBytes := make([]byte, 24) - if _, err := rand.Read(saltBytes); err != nil { - return nil, fmt.Errorf("failed to get random bytes for salt: %w", err) - } + saltBytes := random.Bytes(24) keyData.Passphrase = &PassphraseMetadata{ Algorithm: PassphraseAlgorithmPBKDF2, Iterations: 500000, @@ -50,24 +48,15 @@ func NewKey(passphrase string) (*Key, error) { } } else { // No passphrase, just generate a random key - ssssKey = make([]byte, 32) - if _, err := rand.Read(ssssKey); err != nil { - return nil, fmt.Errorf("failed to get random bytes for key: %w", err) - } + ssssKey = random.Bytes(32) } // Generate a random ID for the key. It's what identifies the key in account data. - keyIDBytes := make([]byte, 24) - if _, err := rand.Read(keyIDBytes); err != nil { - return nil, fmt.Errorf("failed to get random bytes for key ID: %w", err) - } + keyIDBytes := random.Bytes(24) // We store a certain hash in the key metadata so that clients can check if the user entered the correct key. - var ivBytes [utils.AESCTRIVLength]byte - if _, err := rand.Read(ivBytes[:]); err != nil { - return nil, fmt.Errorf("failed to get random bytes for IV: %w", err) - } - keyData.IV = base64.RawStdEncoding.EncodeToString(ivBytes[:]) + ivBytes := random.Bytes(utils.AESCTRIVLength) + keyData.IV = base64.RawStdEncoding.EncodeToString(ivBytes) keyData.MAC = keyData.calculateHash(ssssKey) return &Key{ diff --git a/event/powerlevels.go b/event/powerlevels.go index c60910d9..1882f1e9 100644 --- a/event/powerlevels.go +++ b/event/powerlevels.go @@ -9,6 +9,9 @@ package event import ( "sync" + "go.mau.fi/util/ptr" + "golang.org/x/exp/maps" + "maunium.net/go/mautrix/id" ) @@ -33,42 +36,23 @@ type PowerLevelsEventContent struct { RedactPtr *int `json:"redact,omitempty"` } -func copyPtr(ptr *int) *int { - if ptr == nil { - return nil - } - val := *ptr - return &val -} - -func copyMap[Key comparable](m map[Key]int) map[Key]int { - if m == nil { - return nil - } - copied := make(map[Key]int, len(m)) - for k, v := range m { - copied[k] = v - } - return copied -} - func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { if pl == nil { return nil } return &PowerLevelsEventContent{ - Users: copyMap(pl.Users), + Users: maps.Clone(pl.Users), UsersDefault: pl.UsersDefault, - Events: copyMap(pl.Events), + Events: maps.Clone(pl.Events), EventsDefault: pl.EventsDefault, - StateDefaultPtr: copyPtr(pl.StateDefaultPtr), + StateDefaultPtr: ptr.Clone(pl.StateDefaultPtr), Notifications: pl.Notifications.Clone(), - InvitePtr: copyPtr(pl.InvitePtr), - KickPtr: copyPtr(pl.KickPtr), - BanPtr: copyPtr(pl.BanPtr), - RedactPtr: copyPtr(pl.RedactPtr), + InvitePtr: ptr.Clone(pl.InvitePtr), + KickPtr: ptr.Clone(pl.KickPtr), + BanPtr: ptr.Clone(pl.BanPtr), + RedactPtr: ptr.Clone(pl.RedactPtr), } } @@ -81,7 +65,7 @@ func (npl *NotificationPowerLevels) Clone() *NotificationPowerLevels { return nil } return &NotificationPowerLevels{ - RoomPtr: copyPtr(npl.RoomPtr), + RoomPtr: ptr.Clone(npl.RoomPtr), } } diff --git a/go.mod b/go.mod index 7111c199..2610717d 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.5.1-0.20240701063757-6126777abba3 + go.mau.fi/util v0.5.1-0.20240702075351-577617730cb7 go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.24.0 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 diff --git a/go.sum b/go.sum index 6f3e8f5d..b3fd55a7 100644 --- a/go.sum +++ b/go.sum @@ -46,8 +46,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.5.1-0.20240701063757-6126777abba3 h1:EeFfeO2CheFO3HU6SCVQiP6dY8Wwv6dUyJ2SKtPyE70= -go.mau.fi/util v0.5.1-0.20240701063757-6126777abba3/go.mod h1:DsJzUrJAG53lCZnnYvq9/mOyLuPScWwYhvETiTrpdP4= +go.mau.fi/util v0.5.1-0.20240702075351-577617730cb7 h1:1avw60QZMpzzMMisf6Jqm+WSycZ59OHJA5IlSXHCCPE= +go.mau.fi/util v0.5.1-0.20240702075351-577617730cb7/go.mod h1:DsJzUrJAG53lCZnnYvq9/mOyLuPScWwYhvETiTrpdP4= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= From b5324dffde45e154a7df38db507bfdb357858c88 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 6 Jul 2024 10:11:44 +0300 Subject: [PATCH 0399/1647] crypto/attachment: implement io.Seeker in EncryptStream (#243) --- crypto/attachment/attachments.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/crypto/attachment/attachments.go b/crypto/attachment/attachments.go index e516fded..5f1e3be9 100644 --- a/crypto/attachment/attachments.go +++ b/crypto/attachment/attachments.go @@ -12,6 +12,7 @@ import ( "crypto/sha256" "encoding/base64" "errors" + "fmt" "hash" "io" @@ -136,6 +137,27 @@ type encryptingReader struct { isDecrypting bool } +func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) { + if r.closed { + return 0, ReaderClosed + } + if offset != 0 || whence != io.SeekStart { + return 0, fmt.Errorf("attachments.EncryptStream: only seeking to the beginning is supported") + } + seeker, ok := r.source.(io.ReadSeeker) + if !ok { + return 0, fmt.Errorf("attachments.EncryptStream: source reader (%T) is not an io.ReadSeeker", r.source) + } + n, err := seeker.Seek(offset, whence) + if err != nil { + return 0, err + } + block, _ := aes.NewCipher(r.file.decoded.key[:]) + r.stream = cipher.NewCTR(block, r.file.decoded.iv[:]) + r.hash.Reset() + return n, nil +} + func (r *encryptingReader) Read(dst []byte) (n int, err error) { if r.closed { return 0, ReaderClosed From 65ae2fce425cdb6d2d51e2d346b4265f47feb6b7 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Sat, 6 Jul 2024 01:11:55 -0600 Subject: [PATCH 0400/1647] appservice/intent: fix error handling on double-puppet invites (#252) Signed-off-by: Sumner Evans --- appservice/intent.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appservice/intent.go b/appservice/intent.go index cddac965..9d6b55e5 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -114,7 +114,7 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext } var inviteErr error if intent.IsCustomPuppet { - _, err = bot.SendStateEvent(ctx, roomID, event.StateMember, intent.UserID.String(), &event.Content{ + _, inviteErr = bot.SendStateEvent(ctx, roomID, event.StateMember, intent.UserID.String(), &event.Content{ Raw: map[string]any{ "fi.mau.will_auto_accept": true, }, From 017ca6222317f4ccf9c7612d7152669ab9b19bc2 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Sat, 6 Jul 2024 01:12:02 -0600 Subject: [PATCH 0401/1647] bridgev2/portal: don't shadow redaction remove error (#251) Signed-off-by: Sumner Evans --- bridgev2/portal.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 786c6640..73ee81c9 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -897,6 +897,7 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog portal.sendErrorStatus(ctx, evt, ErrRedactionsNotSupported) return } + var redactionTargetReaction *database.Reaction redactionTargetMsg, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, content.Redacts) if err != nil { log.Err(err).Msg("Failed to get redaction target message from database") @@ -917,7 +918,7 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog }, TargetMessage: redactionTargetMsg, }) - } else if redactionTargetReaction, err := portal.Bridge.DB.Reaction.GetByMXID(ctx, content.Redacts); err != nil { + } else if redactionTargetReaction, err = portal.Bridge.DB.Reaction.GetByMXID(ctx, content.Redacts); err != nil { log.Err(err).Msg("Failed to get redaction target reaction from database") portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get redaction target message reaction: %w", ErrDatabaseError, err)) return From e9545f96678d62e5dab18b524852e6dd3bdbc065 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 6 Jul 2024 14:32:37 +0300 Subject: [PATCH 0402/1647] bridgev2/portal: catch panics in event handlers --- bridgev2/messagestatus.go | 1 + bridgev2/portal.go | 78 +++++++++++++++++++++++++++------------ 2 files changed, 56 insertions(+), 23 deletions(-) diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 2d9a6ddb..90b3d023 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -16,6 +16,7 @@ import ( ) var ( + ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true) ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 73ee81c9..1a27a8cb 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -10,6 +10,7 @@ import ( "context" "errors" "fmt" + "runtime/debug" "strings" "sync" "time" @@ -242,22 +243,41 @@ func (portal *Portal) sendErrorStatus(ctx context.Context, evt *event.Event, err } func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { - if evt.Mautrix.EventSource&event.SourceEphemeral != 0 { - switch evt.Type { - case event.EphemeralEventReceipt: - portal.handleMatrixReceipts(evt) - case event.EphemeralEventTyping: - portal.handleMatrixTyping(evt) - } - return - } log := portal.Log.With(). Str("action", "handle matrix event"). Str("event_type", evt.Type.Type). - Stringer("event_id", evt.ID). - Stringer("sender", sender.MXID). Logger() ctx := log.WithContext(context.TODO()) + defer func() { + if err := recover(); err != nil { + logEvt := log.Error() + if realErr, ok := err.(error); ok { + logEvt = logEvt.Err(realErr) + } else { + logEvt = logEvt.Any(zerolog.ErrorFieldName, err) + } + logEvt. + Bytes("stack", debug.Stack()). + Msg("Matrix event handler panicked") + if evt.ID != "" { + go portal.sendErrorStatus(ctx, evt, ErrPanicInEventHandler) + } + } + }() + if evt.Mautrix.EventSource&event.SourceEphemeral != 0 { + switch evt.Type { + case event.EphemeralEventReceipt: + portal.handleMatrixReceipts(ctx, evt) + case event.EphemeralEventTyping: + portal.handleMatrixTyping(ctx, evt) + } + return + } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c. + Stringer("event_id", evt.ID). + Stringer("sender", sender.MXID) + }) login, _, err := portal.FindPreferredLogin(ctx, sender, true) if err != nil { log.Err(err).Msg("Failed to get user login to handle Matrix event") @@ -307,7 +327,7 @@ func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { } } -func (portal *Portal) handleMatrixReceipts(evt *event.Event) { +func (portal *Portal) handleMatrixReceipts(ctx context.Context, evt *event.Event) { content, ok := evt.Content.Parsed.(*event.ReceiptEventContent) if !ok { return @@ -318,23 +338,23 @@ func (portal *Portal) handleMatrixReceipts(evt *event.Event) { continue } for userID, receipt := range readReceipts { - sender, err := portal.Bridge.GetUserByMXID(context.TODO(), userID) + sender, err := portal.Bridge.GetUserByMXID(ctx, userID) if err != nil { // TODO log return } - portal.handleMatrixReadReceipt(sender, evtID, receipt) + portal.handleMatrixReadReceipt(ctx, sender, evtID, receipt) } } } -func (portal *Portal) handleMatrixReadReceipt(user *User, eventID id.EventID, receipt event.ReadReceipt) { - log := portal.Log.With(). - Str("action", "handle matrix read receipt"). - Stringer("event_id", eventID). - Stringer("user_id", user.MXID). - Logger() - ctx := log.WithContext(context.TODO()) +func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, eventID id.EventID, receipt event.ReadReceipt) { + log := zerolog.Ctx(ctx) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c. + Stringer("event_id", eventID). + Stringer("user_id", user.MXID) + }) login, userPortal, err := portal.FindPreferredLogin(ctx, user, false) if err != nil { if !errors.Is(err, ErrNotLoggedIn) { @@ -384,7 +404,7 @@ func (portal *Portal) handleMatrixReadReceipt(user *User, eventID id.EventID, re portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) } -func (portal *Portal) handleMatrixTyping(evt *event.Event) { +func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) { content, ok := evt.Content.Parsed.(*event.TypingEventContent) if !ok { return @@ -395,7 +415,6 @@ func (portal *Portal) handleMatrixTyping(evt *event.Event) { stoppedTyping, startedTyping := exslices.SortedDiff(portal.currentlyTyping, content.UserIDs, func(a, b id.UserID) int { return strings.Compare(string(a), string(b)) }) - ctx := portal.Log.WithContext(context.TODO()) portal.sendTypings(ctx, stoppedTyping, false) portal.sendTypings(ctx, startedTyping, true) portal.currentlyTyping = content.UserIDs @@ -957,6 +976,19 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { Str("source_id", string(source.ID)). Str("action", "handle remote event"). Logger() + defer func() { + if err := recover(); err != nil { + logEvt := log.Error() + if realErr, ok := err.(error); ok { + logEvt = logEvt.Err(realErr) + } else { + logEvt = logEvt.Any(zerolog.ErrorFieldName, err) + } + logEvt. + Bytes("stack", debug.Stack()). + Msg("Remote event handler panicked") + } + }() log.UpdateContext(evt.AddLogContext) ctx := log.WithContext(context.TODO()) if portal.MXID == "" { From 11421ec6430236464f73adce5c88f6ccb2dbbf0a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 6 Jul 2024 14:32:52 +0300 Subject: [PATCH 0403/1647] bridgev2/participants: treat empty membership as join --- bridgev2/portal.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 1a27a8cb..1d147f68 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1768,6 +1768,9 @@ func (portal *Portal) GetInitialMemberList(ctx context.Context, members *ChatMem } members.PowerLevels.Apply(pl) for _, member := range members.Members { + if member.Membership != event.MembershipJoin && member.Membership != "" { + continue + } intent, extraUserID := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) if extraUserID != "" { invite = append(invite, extraUserID) @@ -1810,6 +1813,9 @@ func (portal *Portal) SyncParticipants(ctx context.Context, members *ChatMemberL delete(currentMembers, portal.Bridge.Bot.GetMXID()) powerChanged := members.PowerLevels.Apply(currentPower) syncUser := func(extraUserID id.UserID, member ChatMember, hasIntent bool) bool { + if member.Membership == "" { + member.Membership = event.MembershipJoin + } powerChanged = currentPower.EnsureUserLevel(extraUserID, member.PowerLevel) || powerChanged currentMember, ok := currentMembers[extraUserID] delete(currentMembers, extraUserID) From e9034dc9f1b80c0c85924c5b95c14bca74cd244d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 6 Jul 2024 14:51:45 +0300 Subject: [PATCH 0404/1647] bridgev2/mxmain: refuse to write example config to existing file --- bridgev2/matrix/mxmain/main.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 15d48b32..1e5cab4e 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -159,8 +159,13 @@ func (br *BridgeMain) PreInit() { _ = json.NewEncoder(os.Stdout).Encode(output) os.Exit(0) } else if *writeExampleConfig { + if _, err = os.Stat(*configPath); !errors.Is(err, os.ErrNotExist) { + _, _ = fmt.Fprintln(os.Stderr, *configPath, "already exists, please remove it if you want to generate a new example") + os.Exit(1) + } networkExample, _, _ := br.Connector.GetConfig() exerrors.PanicIfNotNil(os.WriteFile(*configPath, []byte(br.makeFullExampleConfig(networkExample)), 0600)) + fmt.Println("Wrote example config to", *configPath) os.Exit(0) } br.LoadConfig() From e716b1ca0837e87f57e736579eadd9a6696f9aef Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 6 Jul 2024 14:59:05 +0300 Subject: [PATCH 0405/1647] bridgev2/mxmain: allow network connector to have no config --- bridgev2/matrix/mxmain/main.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 1e5cab4e..4c25fb33 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -337,10 +337,12 @@ func (br *BridgeMain) LoadConfig() { _, _ = fmt.Fprintln(os.Stderr, "Failed to parse config:", err) os.Exit(10) } - err = cfg.Network.Decode(networkData) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to parse network config:", err) - os.Exit(10) + if networkData != nil { + err = cfg.Network.Decode(networkData) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to parse network config:", err) + os.Exit(10) + } } br.Config = &cfg } From be24616d9f03bd0f3b819cc5f7e9e218903594e3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 6 Jul 2024 15:00:32 +0300 Subject: [PATCH 0406/1647] bridgev2/database: fix order of tables in initial schema --- bridgev2/database/upgrades/00-latest.sql | 46 ++++++++++++------------ 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index b84f0549..51201aba 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,27 @@ -- v0 -> v6 (compatible with v1+): Latest revision +CREATE TABLE "user" ( + bridge_id TEXT NOT NULL, + mxid TEXT NOT NULL, + + management_room TEXT, + access_token TEXT, + + PRIMARY KEY (bridge_id, mxid) +); + +CREATE TABLE user_login ( + bridge_id TEXT NOT NULL, + user_mxid TEXT NOT NULL, + id TEXT NOT NULL, + space_room TEXT, + metadata jsonb NOT NULL, + + PRIMARY KEY (bridge_id, id), + CONSTRAINT user_login_user_fkey FOREIGN KEY (bridge_id, user_mxid) + REFERENCES "user" (bridge_id, mxid) + ON DELETE CASCADE ON UPDATE CASCADE +); + CREATE TABLE portal ( bridge_id TEXT NOT NULL, id TEXT NOT NULL, @@ -118,29 +141,6 @@ CREATE TABLE reaction ( ON DELETE CASCADE ON UPDATE CASCADE ); -CREATE TABLE "user" ( - bridge_id TEXT NOT NULL, - mxid TEXT NOT NULL, - - management_room TEXT, - access_token TEXT, - - PRIMARY KEY (bridge_id, mxid) -); - -CREATE TABLE user_login ( - bridge_id TEXT NOT NULL, - user_mxid TEXT NOT NULL, - id TEXT NOT NULL, - space_room TEXT, - metadata jsonb NOT NULL, - - PRIMARY KEY (bridge_id, id), - CONSTRAINT user_login_user_fkey FOREIGN KEY (bridge_id, user_mxid) - REFERENCES "user" (bridge_id, mxid) - ON DELETE CASCADE ON UPDATE CASCADE -); - CREATE TABLE user_portal ( bridge_id TEXT NOT NULL, user_mxid TEXT NOT NULL, From 5f510014f0f97abe964fbd636af2000f5da6dce3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 6 Jul 2024 15:07:30 +0300 Subject: [PATCH 0407/1647] bridgev2/login: improve user input handling in command login --- bridgev2/commands/login.go | 7 ++++++- bridgev2/login.go | 18 +++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index c0e92c27..77244505 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -86,7 +86,12 @@ type userInputLoginCommandState struct { } func (uilcs *userInputLoginCommandState) promptNext(ce *Event) { - // TODO reply prompting field + field := uilcs.RemainingFields[0] + if field.Description != "" { + ce.Reply("Please enter your %s\n%s", field.Name, field.Description) + } else { + ce.Reply("Please enter your %s", field.Name) + } StoreCommandState(ce.User, &CommandState{ Next: MinimalCommandHandlerFunc(uilcs.submitNext), Action: "Login", diff --git a/bridgev2/login.go b/bridgev2/login.go index 5ec9c12a..a2593d3d 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -9,6 +9,7 @@ package bridgev2 import ( "context" "fmt" + "regexp" "strings" "maunium.net/go/mautrix/bridgev2/networkid" @@ -159,8 +160,11 @@ func CleanPhoneNumber(phone string) (string, error) { return phone, nil } +func noopValidate(input string) (string, error) { + return input, nil +} + func (f *LoginInputDataField) FillDefaultValidate() { - noopValidate := func(input string) (string, error) { return input, nil } if f.Validate != nil { return } @@ -175,6 +179,18 @@ func (f *LoginInputDataField) FillDefaultValidate() { return email, nil } default: + if f.Pattern != "" { + f.Validate = func(s string) (string, error) { + match, err := regexp.MatchString(f.Pattern, s) + if err != nil { + return "", err + } else if !match { + return "", fmt.Errorf("invalid input") + } else { + return s, nil + } + } + } f.Validate = noopValidate } } From 977eb24233517ff289d92f18139d666916c338a4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 6 Jul 2024 15:08:11 +0300 Subject: [PATCH 0408/1647] bridgev2/login: don't reply with empty instructions --- bridgev2/commands/login.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index 77244505..8137ccaf 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -277,7 +277,9 @@ func (clcs *cookieLoginCommandState) submit(ce *Event) { } func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep) { - ce.Reply(step.Instructions) + if step.Instructions != "" { + ce.Reply(step.Instructions) + } switch step.Type { case bridgev2.LoginStepTypeDisplayAndWait: From 7eb0e962ce29a4603c93efe3627b17746c7613e7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 6 Jul 2024 15:10:15 +0300 Subject: [PATCH 0409/1647] bridgev2/login: further improve user input command handling --- bridgev2/commands/login.go | 3 +++ bridgev2/login.go | 5 +++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index 8137ccaf..f8c0e402 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -103,6 +103,9 @@ func (uilcs *userInputLoginCommandState) promptNext(ce *Event) { func (uilcs *userInputLoginCommandState) submitNext(ce *Event) { field := uilcs.RemainingFields[0] field.FillDefaultValidate() + if field.Type == bridgev2.LoginInputFieldTypePassword { + ce.Redact() + } var err error uilcs.Data[field.ID], err = field.Validate(ce.RawArgs) if err != nil { diff --git a/bridgev2/login.go b/bridgev2/login.go index a2593d3d..775f018b 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -185,13 +185,14 @@ func (f *LoginInputDataField) FillDefaultValidate() { if err != nil { return "", err } else if !match { - return "", fmt.Errorf("invalid input") + return "", fmt.Errorf("doesn't match regex `%s`", f.Pattern) } else { return s, nil } } + } else { + f.Validate = noopValidate } - f.Validate = noopValidate } } From f32320883172b12f8e128722256ecd15e4779eb5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 6 Jul 2024 15:46:46 +0300 Subject: [PATCH 0410/1647] bridgev2: link to mau.fi/ports for DefaultPort doc --- bridgev2/networkinterface.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index d566517d..cd9ff655 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -75,6 +75,7 @@ type BridgeName struct { // The Go import path is a good choice here (e.g. github.com/octocat/discordbridge) BeeperBridgeType string // The default appservice port to use in the example config, defaults to 8080 if unset + // Official mautrix bridges will use ports defined in https://mau.fi/ports DefaultPort uint16 // The default command prefix to use in the example config, defaults to NetworkID if unset. Must include the ! prefix. DefaultCommandPrefix string From b4057a26c3edc904bfeac220f72904cfed57f65a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 6 Jul 2024 15:46:59 +0300 Subject: [PATCH 0411/1647] bridgev2/portal: don't panic if IsSpace is unset in CreateMatrixRoom --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 1d147f68..e4bba7d6 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2106,7 +2106,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i req.BeeperAutoJoinInvites = true req.Invite = initialMembers } - if *info.IsSpace { + if info.IsSpace != nil && *info.IsSpace { req.CreationContent["type"] = event.RoomTypeSpace portal.Metadata.IsSpace = true } From 0aa773b97305ca4bd76e6cea8cb3e482a951a9e4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 9 Jul 2024 19:08:41 +0300 Subject: [PATCH 0412/1647] bridgev2: add more features to reaction and receipt events --- bridgev2/database/message.go | 9 ++++++++- bridgev2/networkinterface.go | 12 ++++++++++++ bridgev2/portal.go | 12 ++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 9129c0b7..0a337d93 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -71,7 +71,10 @@ const ( getLastMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id DESC LIMIT 1` getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1` getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5` - insertMessageQuery = ` + + getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1` + + insertMessageQuery = ` INSERT INTO message (bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, relates_to, metadata) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING rowid @@ -120,6 +123,10 @@ func (mq *MessageQuery) GetFirstOrSpecificPartByID(ctx context.Context, receiver } } +func (mq *MessageQuery) GetLastPartAtOrBeforeTime(ctx context.Context, portal networkid.PortalKey, maxTS time.Time) (*Message, error) { + return mq.QueryOne(ctx, getLastMessagePartAtOrBeforeTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, maxTS.UnixNano()) +} + func (mq *MessageQuery) GetMessagesBetweenTimeQuery(ctx context.Context, portal networkid.PortalKey, start, end time.Time) ([]*Message, error) { return mq.QueryMany(ctx, getMessagesBetweenTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, start.UnixNano(), end.UnixNano()) } diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index cd9ff655..65886ba5 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -525,6 +525,11 @@ type RemoteReaction interface { GetReactionEmoji() (string, networkid.EmojiID) } +type RemoteReactionWithExtraContent interface { + RemoteReaction + GetReactionExtraContent() map[string]any +} + type RemoteReactionWithMeta interface { RemoteReaction GetReactionDBMetadata() map[string]any @@ -543,6 +548,7 @@ type RemoteReceipt interface { RemoteEvent GetLastReceiptTarget() networkid.MessageID GetReceiptTargets() []networkid.MessageID + GetReadUpTo() time.Time } type RemoteMarkUnread interface { @@ -583,6 +589,7 @@ type SimpleRemoteEvent[T any] struct { Emoji string ReactionDBMeta map[string]any Timestamp time.Time + ChatInfoChange *ChatInfoChange ConvertMessageFunc func(ctx context.Context, portal *Portal, intent MatrixAPI, data T) (*ConvertedMessage, error) ConvertEditFunc func(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message, data T) (*ConvertedEdit, error) @@ -596,6 +603,7 @@ var ( _ RemoteReactionWithMeta = (*SimpleRemoteEvent[any])(nil) _ RemoteReactionRemove = (*SimpleRemoteEvent[any])(nil) _ RemoteMessageRemove = (*SimpleRemoteEvent[any])(nil) + _ RemoteChatInfoChange = (*SimpleRemoteEvent[any])(nil) ) func (sre *SimpleRemoteEvent[T]) AddLogContext(c zerolog.Context) zerolog.Context { @@ -645,6 +653,10 @@ func (sre *SimpleRemoteEvent[T]) GetReactionDBMetadata() map[string]any { return sre.ReactionDBMeta } +func (sre *SimpleRemoteEvent[T]) GetChatInfoChange(ctx context.Context) (*ChatInfoChange, error) { + return sre.ChatInfoChange, nil +} + func (sre *SimpleRemoteEvent[T]) GetType() RemoteEventType { return sre.Type } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index e4bba7d6..a973eef8 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1316,6 +1316,10 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi } ts := getEventTS(evt) intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReaction) + var extra map[string]any + if extraContentProvider, ok := evt.(RemoteReactionWithExtraContent); ok { + extra = extraContentProvider.GetReactionExtraContent() + } resp, err := intent.SendMessage(ctx, portal.MXID, event.EventReaction, &event.Content{ Parsed: &event.ReactionEventContent{ RelatesTo: event.RelatesTo{ @@ -1324,6 +1328,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi Key: variationselector.Add(emoji), }, }, + Raw: extra, }, ts) if err != nil { log.Err(err).Msg("Failed to send reaction to Matrix") @@ -1447,6 +1452,13 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL } } } + readUpTo := evt.GetReadUpTo() + if lastTarget == nil && !readUpTo.IsZero() { + lastTarget, err = portal.Bridge.DB.Message.GetLastPartAtOrBeforeTime(ctx, portal.PortalKey, readUpTo) + if err != nil { + log.Err(err).Time("read_up_to", readUpTo).Msg("Failed to get target message for read receipt") + } + } if lastTarget == nil { log.Warn().Msg("No target message found for read receipt") return From fc7ed77e26300334db3752e753e4c1a7f9b6293e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 9 Jul 2024 19:09:00 +0300 Subject: [PATCH 0413/1647] bridgev2: add helper for finding existing portal receiver --- bridgev2/database/portal.go | 11 +++++++++++ bridgev2/networkid/bridgeid.go | 4 ++++ bridgev2/portal.go | 29 +++++++++++++++++++++++++++++ 3 files changed, 44 insertions(+) diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 47c39a0b..ca277455 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -10,6 +10,7 @@ import ( "context" "database/sql" "encoding/hex" + "errors" "time" "go.mau.fi/util/dbutil" @@ -79,6 +80,8 @@ const ( getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2` + findPortalReceiverQuery = `SELECT id, receiver FROM portal WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='') LIMIT 1` + insertPortalQuery = ` INSERT INTO portal ( bridge_id, id, receiver, mxid, @@ -109,6 +112,14 @@ func (pq *PortalQuery) GetByID(ctx context.Context, key networkid.PortalKey) (*P return pq.QueryOne(ctx, getPortalByIDQuery, pq.BridgeID, key.ID, key.Receiver) } +func (pq *PortalQuery) FindReceiver(ctx context.Context, id networkid.PortalID, maybeReceiver networkid.UserLoginID) (key networkid.PortalKey, err error) { + err = pq.GetDB().QueryRow(ctx, findPortalReceiverQuery, pq.BridgeID, id, maybeReceiver).Scan(&key.ID, &key.Receiver) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + func (pq *PortalQuery) GetByIDWithUncertainReceiver(ctx context.Context, key networkid.PortalKey) (*Portal, error) { return pq.QueryOne(ctx, getPortalByIDWithUncertainReceiverQuery, pq.BridgeID, key.ID, key.Receiver) } diff --git a/bridgev2/networkid/bridgeid.go b/bridgev2/networkid/bridgeid.go index 3b55b67b..08e49f29 100644 --- a/bridgev2/networkid/bridgeid.go +++ b/bridgev2/networkid/bridgeid.go @@ -48,6 +48,10 @@ type PortalKey struct { Receiver UserLoginID } +func (pk PortalKey) IsEmpty() bool { + return pk.ID == "" && pk.Receiver == "" +} + func (pk PortalKey) String() string { if pk.Receiver == "" { return string(pk.ID) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a973eef8..bbcea8ab 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -131,6 +131,35 @@ func (br *Bridge) unlockedGetPortalByID(ctx context.Context, id networkid.Portal return br.loadPortal(ctx, db, err, idPtr) } +func (br *Bridge) FindPortalReceiver(ctx context.Context, id networkid.PortalID, maybeReceiver networkid.UserLoginID) (networkid.PortalKey, error) { + key := br.FindCachedPortalReceiver(id, maybeReceiver) + if !key.IsEmpty() { + return key, nil + } + key, err := br.DB.Portal.FindReceiver(ctx, id, maybeReceiver) + if err != nil { + return networkid.PortalKey{}, err + } + return key, nil +} + +func (br *Bridge) FindCachedPortalReceiver(id networkid.PortalID, maybeReceiver networkid.UserLoginID) networkid.PortalKey { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + portal, ok := br.portalsByKey[networkid.PortalKey{ + ID: id, + Receiver: maybeReceiver, + }] + if ok { + return portal.PortalKey + } + portal, ok = br.portalsByKey[networkid.PortalKey{ID: id}] + if ok { + return portal.PortalKey + } + return networkid.PortalKey{} +} + func (br *Bridge) GetPortalByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() From 06f9e82d7cc079a0eed7f787fac84f1f671cba42 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Jul 2024 11:40:40 +0300 Subject: [PATCH 0414/1647] bridgev2: add total member count field to member list --- bridgev2/portal.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index bbcea8ab..b500f305 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1587,6 +1587,9 @@ type ChatMemberList struct { // This should be used when SenderLogin can't be filled accurately. CheckAllLogins bool + // The total number of members in the chat, regardless of how many of those members are included in Members. + TotalMemberCount int + Members []ChatMember PowerLevels *PowerLevelChanges } From 989edc61a812d907d2988f6bdf3af092abff13ac Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Jul 2024 11:41:46 +0300 Subject: [PATCH 0415/1647] bridgev2: refetch info on room create if it's missing member list --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index b500f305..dd52b49c 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2109,7 +2109,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i log.Info().Msg("Creating Matrix room") var err error - if info == nil { + if info == nil || info.Members == nil { info, err = source.Client.GetChatInfo(ctx, portal) if err != nil { log.Err(err).Msg("Failed to update portal info for creation") From cd334a3815896cc7d546748fa27961cda0e0a334 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Jul 2024 13:17:32 +0300 Subject: [PATCH 0416/1647] bridgev2: expose unlocked get functions and lock entire NewLogin --- bridgev2/portal.go | 8 ++++---- bridgev2/portalreid.go | 4 ++-- bridgev2/userlogin.go | 20 ++++++++++++++------ 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index dd52b49c..a6b0494b 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -94,7 +94,7 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que } var err error if portal.ParentID != "" { - portal.Parent, err = br.unlockedGetPortalByID(ctx, networkid.PortalKey{ID: portal.ParentID}, false) + portal.Parent, err = br.UnlockedGetPortalByID(ctx, networkid.PortalKey{ID: portal.ParentID}, false) if err != nil { return nil, fmt.Errorf("failed to load parent portal (%s): %w", portal.ParentID, err) } @@ -118,7 +118,7 @@ func (portal *Portal) updateLogger() { portal.Log = logWith.Logger() } -func (br *Bridge) unlockedGetPortalByID(ctx context.Context, id networkid.PortalKey, onlyIfExists bool) (*Portal, error) { +func (br *Bridge) UnlockedGetPortalByID(ctx context.Context, id networkid.PortalKey, onlyIfExists bool) (*Portal, error) { cached, ok := br.portalsByKey[id] if ok { return cached, nil @@ -174,14 +174,14 @@ func (br *Bridge) GetPortalByMXID(ctx context.Context, mxid id.RoomID) (*Portal, func (br *Bridge) GetPortalByID(ctx context.Context, id networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() - return br.unlockedGetPortalByID(ctx, id, false) + return br.UnlockedGetPortalByID(ctx, id, false) } func (br *Bridge) GetExistingPortalByID(ctx context.Context, id networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() if id.Receiver == "" { - return br.unlockedGetPortalByID(ctx, id, true) + return br.UnlockedGetPortalByID(ctx, id, true) } cached, ok := br.portalsByKey[id] if ok { diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go index 0622aef6..c4f7a69b 100644 --- a/bridgev2/portalreid.go +++ b/bridgev2/portalreid.go @@ -39,7 +39,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta }() br.cacheLock.Lock() defer br.cacheLock.Unlock() - sourcePortal, err := br.unlockedGetPortalByID(ctx, source, true) + sourcePortal, err := br.UnlockedGetPortalByID(ctx, source, true) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to get source portal: %w", err) } else if sourcePortal == nil { @@ -59,7 +59,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Stringer("source_portal_mxid", sourcePortal.MXID) }) - targetPortal, err := br.unlockedGetPortalByID(ctx, target, true) + targetPortal, err := br.UnlockedGetPortalByID(ctx, target, true) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to get target portal: %w", err) } diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index f20f4657..9e9f605b 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -143,6 +143,8 @@ type NewLoginParams struct { // This will automatically call LoadUserLogin after creating the UserLogin object. // The load method defaults to the network connector's LoadUserLogin method, but it can be overridden in params. func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params *NewLoginParams) (*UserLogin, error) { + user.Bridge.cacheLock.Lock() + defer user.Bridge.cacheLock.Unlock() data.BridgeID = user.BridgeID data.UserMXID = user.MXID if params == nil { @@ -151,14 +153,14 @@ func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params if params.LoadUserLogin == nil { params.LoadUserLogin = user.Bridge.Network.LoadUserLogin } - ul, err := user.Bridge.GetExistingUserLoginByID(ctx, data.ID) + ul, err := user.Bridge.unlockedGetExistingUserLoginByID(ctx, data.ID) if err != nil { return nil, fmt.Errorf("failed to check if login already exists: %w", err) } var doInsert bool if ul != nil && ul.UserMXID != user.MXID { if params.DeleteOnConflict { - ul.Delete(ctx, status.BridgeState{StateEvent: status.StateLoggedOut, Error: "overridden-by-another-user"}, false) + ul.delete(ctx, status.BridgeState{StateEvent: status.StateLoggedOut, Error: "overridden-by-another-user"}, false, true) ul = nil } else { return nil, fmt.Errorf("%s is already logged in with that account", ul.UserMXID) @@ -190,10 +192,8 @@ func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params if err != nil { return nil, err } - user.Bridge.cacheLock.Lock() user.Bridge.userLoginsByID[ul.ID] = ul user.logins[ul.ID] = ul - user.Bridge.cacheLock.Unlock() } else { err = ul.Save(ctx) if err != nil { @@ -212,6 +212,10 @@ func (ul *UserLogin) Logout(ctx context.Context) { } func (ul *UserLogin) Delete(ctx context.Context, state status.BridgeState, logoutRemote bool) { + ul.delete(ctx, state, logoutRemote, false) +} + +func (ul *UserLogin) delete(ctx context.Context, state status.BridgeState, logoutRemote, unlocked bool) { if logoutRemote { ul.Client.LogoutRemote(ctx) } else { @@ -225,10 +229,14 @@ func (ul *UserLogin) Delete(ctx context.Context, state status.BridgeState, logou if err != nil { ul.Log.Err(err).Msg("Failed to delete user login") } - ul.Bridge.cacheLock.Lock() + if !unlocked { + ul.Bridge.cacheLock.Lock() + } delete(ul.User.logins, ul.ID) delete(ul.Bridge.userLoginsByID, ul.ID) - ul.Bridge.cacheLock.Unlock() + if !unlocked { + ul.Bridge.cacheLock.Unlock() + } go ul.deleteSpace(ctx) go ul.kickUserFromPortals(ctx, portals) if state.StateEvent != "" { From e9097ad3a2c9714e60f85da784d4c7796ad5766b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Jul 2024 13:17:52 +0300 Subject: [PATCH 0417/1647] bridgev2: implement portal parent spaces --- bridgev2/portal.go | 68 +++++++++++++++++++++++++++++++++++++--------- bridgev2/space.go | 57 ++++++++++++++++++++++++++++++++++---- 2 files changed, 107 insertions(+), 18 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a6b0494b..f5c9a1d0 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -18,6 +18,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/exfmt" "go.mau.fi/util/exslices" + "go.mau.fi/util/ptr" "go.mau.fi/util/variationselector" "golang.org/x/exp/slices" @@ -1659,6 +1660,7 @@ type ChatInfo struct { IsDirectChat *bool IsSpace *bool Disappear *database.DisappearingSetting + ParentID *networkid.PortalID UserLocal *UserLocalPortalInfo @@ -2051,6 +2053,40 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat return true } +func (portal *Portal) UpdateParent(ctx context.Context, newParent networkid.PortalID, source *UserLogin) bool { + if portal.ParentID == newParent { + return false + } + var err error + if portal.MXID != "" && portal.InSpace && portal.Parent != nil && portal.Parent.MXID != "" { + err = portal.toggleSpace(ctx, portal.Parent.MXID, false, true) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("old_space_mxid", portal.Parent.MXID).Msg("Failed to remove portal from old space") + } + } + portal.ParentID = newParent + portal.InSpace = false + if newParent != "" { + portal.Parent, err = portal.Bridge.GetPortalByID(ctx, networkid.PortalKey{ID: newParent}) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get new parent portal") + } + } + if portal.MXID != "" && portal.Parent != nil && (source != nil || portal.Parent.MXID != "") { + if portal.Parent.MXID == "" { + zerolog.Ctx(ctx).Info().Msg("Parent portal doesn't exist, creating") + err = portal.Parent.CreateMatrixRoom(ctx, source, nil) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create parent portal") + } + } + if portal.Parent.MXID != "" { + portal.addToParentSpaceAndSave(ctx, false) + } + } + return true +} + func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *UserLogin, sender MatrixAPI, ts time.Time) { changed := false if info.Name != nil { @@ -2065,6 +2101,9 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us if info.Disappear != nil { changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, sender, ts, false, false) || changed } + if info.ParentID != nil { + changed = portal.UpdateParent(ctx, *info.ParentID, source) || changed + } if info.JoinRule != nil { // TODO change detection instead of spamming this every time? portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule) @@ -2100,6 +2139,9 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i portal.roomCreateLock.Lock() defer portal.roomCreateLock.Unlock() if portal.MXID != "" { + if source != nil { + source.MarkInPortal(ctx, portal) + } return nil } log := zerolog.Ctx(ctx).With(). @@ -2155,11 +2197,9 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i portal.Metadata.IsSpace = true } bridgeInfoStateKey, bridgeInfo := portal.getBridgeInfo() - emptyString := "" req.InitialState = append(req.InitialState, &event.Event{ - StateKey: &emptyString, - Type: event.StateElementFunctionalMembers, + Type: event.StateElementFunctionalMembers, Content: event.Content{Parsed: &event.ElementFunctionalMembersContent{ ServiceMembers: append(extraFunctionalMembers, portal.Bridge.Bot.GetMXID()), }}, @@ -2176,22 +2216,19 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i // Add explicit topic event if topic is empty to ensure the event is set. // This ensures that there won't be an extra event later if PUT /state/... is called. req.InitialState = append(req.InitialState, &event.Event{ - StateKey: &emptyString, - Type: event.StateTopic, - Content: event.Content{Parsed: &event.TopicEventContent{Topic: ""}}, + Type: event.StateTopic, + Content: event.Content{Parsed: &event.TopicEventContent{Topic: ""}}, }) } if portal.AvatarMXC != "" { req.InitialState = append(req.InitialState, &event.Event{ - StateKey: &emptyString, - Type: event.StateRoomAvatar, - Content: event.Content{Parsed: &event.RoomAvatarEventContent{URL: portal.AvatarMXC}}, + Type: event.StateRoomAvatar, + Content: event.Content{Parsed: &event.RoomAvatarEventContent{URL: portal.AvatarMXC}}, }) } - if portal.Parent != nil { - // TODO create parent portal if it doesn't exist? + if portal.Parent != nil && portal.Parent.MXID != "" { req.InitialState = append(req.InitialState, &event.Event{ - StateKey: (*string)(&portal.Parent.MXID), + StateKey: ptr.Ptr(portal.Parent.MXID.String()), Type: event.StateSpaceParent, Content: event.Content{Parsed: &event.SpaceParentEventContent{ Via: []string{portal.Bridge.Matrix.ServerName()}, @@ -2225,7 +2262,12 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i return err } if portal.Parent != nil { - // TODO add m.space.child event + if portal.Parent.MXID != "" { + portal.addToParentSpaceAndSave(ctx, true) + } else { + log.Info().Msg("Parent portal doesn't exist, creating in background") + go portal.createParentAndAddToSpace(ctx, source) + } } portal.updateUserLocalInfo(ctx, info.UserLocal, source) if !autoJoinInvites { diff --git a/bridgev2/space.go b/bridgev2/space.go index f5066e07..7cf570ec 100644 --- a/bridgev2/space.go +++ b/bridgev2/space.go @@ -65,11 +65,7 @@ func (ul *UserLogin) AddPortalToSpace(ctx context.Context, portal *Portal, userP } else if spaceRoom == "" { return nil } - _, err = ul.Bridge.Bot.SendState(ctx, spaceRoom, event.StateSpaceChild, portal.MXID.String(), &event.Content{ - Parsed: &event.SpaceChildEventContent{ - Via: []string{ul.Bridge.Matrix.ServerName()}, - }, - }, time.Now()) + err = portal.toggleSpace(ctx, spaceRoom, false, false) if err != nil { return fmt.Errorf("failed to add portal to space: %w", err) } @@ -83,6 +79,57 @@ func (ul *UserLogin) AddPortalToSpace(ctx context.Context, portal *Portal, userP return nil } +func (portal *Portal) createParentAndAddToSpace(ctx context.Context, source *UserLogin) { + err := portal.Parent.CreateMatrixRoom(ctx, source, nil) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create parent portal") + } else { + portal.addToParentSpaceAndSave(ctx, true) + } +} + +func (portal *Portal) addToParentSpaceAndSave(ctx context.Context, save bool) { + err := portal.toggleSpace(ctx, portal.Parent.MXID, true, false) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("space_mxid", portal.Parent.MXID).Msg("Failed to add portal to space") + } else { + portal.InSpace = true + if save { + err = portal.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal to database after adding to space") + } + } + } +} + +func (portal *Portal) toggleSpace(ctx context.Context, spaceID id.RoomID, canonical, remove bool) error { + via := []string{portal.Bridge.Matrix.ServerName()} + if remove { + via = nil + } + _, err := portal.Bridge.Bot.SendState(ctx, spaceID, event.StateSpaceChild, portal.MXID.String(), &event.Content{ + Parsed: &event.SpaceChildEventContent{ + Via: via, + }, + }, time.Now()) + if err != nil { + return err + } + if canonical { + _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateSpaceParent, spaceID.String(), &event.Content{ + Parsed: &event.SpaceParentEventContent{ + Via: via, + Canonical: !remove, + }, + }, time.Now()) + if err != nil { + return err + } + } + return nil +} + func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) { if !ul.Bridge.Config.PersonalFilteringSpaces { return ul.SpaceRoom, nil From 890b23a332da039b2c3461c248b39a6c109a87f1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Jul 2024 14:13:21 +0300 Subject: [PATCH 0418/1647] bridgev2: don't add portals with parent into user login space --- bridgev2/portal.go | 20 +++++++++++--------- bridgev2/space.go | 2 +- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index f5c9a1d0..57258bb6 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2286,15 +2286,17 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } } } - userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) - if err != nil { - log.Err(err).Msg("Failed to get user logins in portal to add portal to spaces") - } else { - for _, up := range userPortals { - login := portal.Bridge.GetCachedUserLoginByID(up.LoginID) - if login != nil { - login.inPortalCache.Remove(portal.PortalKey) - go login.tryAddPortalToSpace(ctx, portal, up.CopyWithoutValues()) + if portal.Parent == nil { + userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to get user logins in portal to add portal to spaces") + } else { + for _, up := range userPortals { + login := portal.Bridge.GetCachedUserLoginByID(up.LoginID) + if login != nil { + login.inPortalCache.Remove(portal.PortalKey) + go login.tryAddPortalToSpace(ctx, portal, up.CopyWithoutValues()) + } } } } diff --git a/bridgev2/space.go b/bridgev2/space.go index 7cf570ec..41ef3c2b 100644 --- a/bridgev2/space.go +++ b/bridgev2/space.go @@ -56,7 +56,7 @@ func (ul *UserLogin) tryAddPortalToSpace(ctx context.Context, portal *Portal, us } func (ul *UserLogin) AddPortalToSpace(ctx context.Context, portal *Portal, userPortal *database.UserPortal) error { - if portal.MXID == "" { + if portal.MXID == "" || portal.Parent != nil { return nil } spaceRoom, err := ul.GetSpaceRoom(ctx) From 2675993fc23309d0b325a8efb60ca3c3900202ad Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Jul 2024 14:15:41 +0300 Subject: [PATCH 0419/1647] bridgev2: fill global bridge state before sending --- bridgev2/bridgestate.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index 578dfee3..a9f57d37 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -24,6 +24,7 @@ type BridgeStateQueue struct { } func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { + state.Fill(nil) for { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) if err := br.Matrix.SendBridgeStatus(ctx, &state); err != nil { From b6bc70e10170317e563a2c88ac80974b2c696dc2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Jul 2024 14:16:41 +0300 Subject: [PATCH 0420/1647] bridgev2: fix filling global bridge state --- bridgev2/bridgestate.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index a9f57d37..e7d18d5e 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -24,7 +24,7 @@ type BridgeStateQueue struct { } func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { - state.Fill(nil) + state = state.Fill(nil) for { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) if err := br.Matrix.SendBridgeStatus(ctx, &state); err != nil { From b5f24a8b500bd45f99c04f23d43d661c00441fdb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Jul 2024 15:37:01 +0300 Subject: [PATCH 0421/1647] bridgev2/commands: fix logging error message in command panic handler --- bridgev2/commands/processor.go | 12 ++++++++---- bridgev2/ghost.go | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index fce13d09..d14e9781 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -81,10 +81,14 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id. } err := recover() if err != nil { - zerolog.Ctx(ctx).Error(). - Bytes(zerolog.ErrorStackFieldName, debug.Stack()). - Any(zerolog.ErrorFieldName, err). - Msg("Panic in Matrix command handler") + logEvt := zerolog.Ctx(ctx).Error(). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()) + if realErr, ok := err.(error); ok { + logEvt = logEvt.Err(realErr) + } else { + logEvt = logEvt.Any(zerolog.ErrorFieldName, err) + } + logEvt.Msg("Panic in Matrix command handler") ms.Status = event.MessageStatusFail ms.IsCertain = true if realErr, ok := err.(error); ok { diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index c41b79c9..cf1a68d9 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -216,7 +216,7 @@ func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin } info, err := source.Client.GetUserInfo(ctx, ghost) if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get info to update ghost") + zerolog.Ctx(ctx).Err(err).Str("ghost_id", string(ghost.ID)).Msg("Failed to get info to update ghost") } else if info != nil { ghost.UpdateInfo(ctx, info) } From 9ff2c21fa7a505c62e08d7423e3169f7387190b6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Jul 2024 15:37:28 +0300 Subject: [PATCH 0422/1647] bridgev2: fix auto-joining if membership is blank --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 57258bb6..cdd17f07 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1939,7 +1939,7 @@ func (portal *Portal) SyncParticipants(ctx context.Context, members *ChatMemberL if !syncUser(intent.GetMXID(), member, true) { return } - if member.Membership == event.MembershipJoin { + if member.Membership == event.MembershipJoin || member.Membership == "" { err = intent.EnsureJoined(ctx, portal.MXID) if err != nil { log.Err(err). From 6ce34d7819bf87a4cbe3671e7aff1ceb6711153d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Jul 2024 17:14:30 +0300 Subject: [PATCH 0423/1647] bridgev2: set `network` section in `m.bridge` events properly --- bridgev2/portal.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index cdd17f07..9c564c03 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1720,10 +1720,11 @@ func (portal *Portal) UpdateAvatar(ctx context.Context, avatar *Avatar, sender M } func (portal *Portal) GetTopLevelParent() *Portal { - // TODO ensure there's no infinite recursion? if portal.Parent == nil { - // TODO return self if this is a space portal? - return nil + if !portal.Metadata.IsSpace { + return nil + } + return portal } return portal.Parent.GetTopLevelParent() } @@ -1739,9 +1740,9 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { AvatarURL: portal.AvatarMXC, // TODO external URL? }, - // TODO room type } if portal.Metadata.IsDirect { + // TODO group dm type? bridgeInfo.BeeperRoomType = "dm" } else if portal.Metadata.IsSpace { bridgeInfo.BeeperRoomType = "space" From 0cbe236550d79e07f12a12fec02243c8300977da Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 27 May 2024 14:30:35 +0300 Subject: [PATCH 0424/1647] crypto/sqlstore: fill account_id when updating crypto_secrets schema --- crypto/sql_store_upgrade/15-fix-secrets.sql | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/crypto/sql_store_upgrade/15-fix-secrets.sql b/crypto/sql_store_upgrade/15-fix-secrets.sql index 47235397..d49cffae 100644 --- a/crypto/sql_store_upgrade/15-fix-secrets.sql +++ b/crypto/sql_store_upgrade/15-fix-secrets.sql @@ -14,3 +14,8 @@ FROM crypto_secrets; DROP TABLE crypto_secrets; ALTER TABLE crypto_secrets_new RENAME TO crypto_secrets; + +-- only: sqlite +UPDATE crypto_secrets SET account_id=(SELECT account_id FROM crypto_account ORDER BY rowid DESC LIMIT 1); +-- only: postgres +UPDATE crypto_secrets SET account_id=(SELECT account_id FROM crypto_account LIMIT 1); From 8ebeb5e3abcf6bf5b512fd1c73cd485f1ec56cf4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Jul 2024 18:28:29 +0300 Subject: [PATCH 0425/1647] bridgev2: add some more logs to handling remote events --- bridgev2/networkinterface.go | 29 +++++++++++++++++++++++++++++ bridgev2/portal.go | 7 ++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 65886ba5..53ca7d1e 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -454,6 +454,35 @@ type PushableNetworkAPI interface { type RemoteEventType int +func (ret RemoteEventType) String() string { + switch ret { + case RemoteEventUnknown: + return "RemoteEventUnknown" + case RemoteEventMessage: + return "RemoteEventMessage" + case RemoteEventEdit: + return "RemoteEventEdit" + case RemoteEventReaction: + return "RemoteEventReaction" + case RemoteEventReactionRemove: + return "RemoteEventReactionRemove" + case RemoteEventMessageRemove: + return "RemoteEventMessageRemove" + case RemoteEventReadReceipt: + return "RemoteEventReadReceipt" + case RemoteEventDeliveryReceipt: + return "RemoteEventDeliveryReceipt" + case RemoteEventMarkUnread: + return "RemoteEventMarkUnread" + case RemoteEventTyping: + return "RemoteEventTyping" + case RemoteEventChatInfoChange: + return "RemoteEventChatInfoChange" + default: + return fmt.Sprintf("RemoteEventType(%d)", int(ret)) + } +} + const ( RemoteEventUnknown RemoteEventType = iota RemoteEventMessage diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 9c564c03..a717af7b 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1037,7 +1037,9 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { if ok { preHandler.PreHandle(ctx, portal) } - switch evt.GetType() { + evtType := evt.GetType() + log.Debug().Stringer("bridge_evt_type", evtType).Msg("Handling remote event") + switch evtType { case RemoteEventUnknown: log.Debug().Msg("Ignoring remote event with type unknown") case RemoteEventMessage: @@ -1430,6 +1432,9 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use if err != nil { log.Err(err).Msg("Failed to get target message for removal") return + } else if len(targetParts) == 0 { + log.Debug().Msg("Target message not found") + return } intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageRemove) ts := getEventTS(evt) From 5e50b6a87b05cf436803e7189607b34f0545bad4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Jul 2024 19:07:21 +0300 Subject: [PATCH 0426/1647] crypto: remove incorrect warning log when `m.relates_to` is in both contents --- crypto/decryptmegolm.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 99b584f5..ff5b82f3 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -130,7 +130,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event } else if updatedPlaintext != nil { plaintext = updatedPlaintext } - } else { + } else if !relation.Exists() { log.Warn().Msg("Failed to find m.relates_to in raw encrypted event even though it was present in parsed content") } } From 3ebe0e18ce7d03fbbd89afc220c8894172e73570 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Jul 2024 19:07:38 +0300 Subject: [PATCH 0427/1647] bridgev2: allow returning nil in HandleMatrixReaction --- bridgev2/messagestatus.go | 2 +- bridgev2/portal.go | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 90b3d023..1a04e5af 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -16,7 +16,7 @@ import ( ) var ( - ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true) + ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a717af7b..543307ea 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -827,6 +827,9 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi portal.sendErrorStatus(ctx, evt, err) return } + if dbReaction == nil { + dbReaction = &database.Reaction{} + } // Fill all fields that are known to allow omitting them in connector code if dbReaction.Room.ID == "" { dbReaction.Room = portal.PortalKey From 672ded60f999073dae7da91c6724931ba531f6e9 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 10 Jul 2024 12:01:17 -0600 Subject: [PATCH 0428/1647] pre-commit: update, enforce go mod tidy Signed-off-by: Sumner Evans --- .pre-commit-config.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5fffa9fb..1ef1b112 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.6.0 hooks: - id: trailing-whitespace exclude_types: [markdown] @@ -17,6 +17,7 @@ repos: - "maunium.net/go/mautrix" - "-w" - id: go-vet-repo-mod + - id: go-mod-tidy - repo: https://github.com/beeper/pre-commit-go rev: v0.3.1 From dd16a8d1d90b9ea1c35e1cfabd56ddfe8394535b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Jul 2024 23:41:10 +0300 Subject: [PATCH 0429/1647] bridgev2: replace relates_to with thread_root and reply_to columns --- bridgev2/database/message.go | 39 +++++++++++---- bridgev2/database/upgrades/00-latest.sql | 27 ++++++----- .../07-message-relation-without-fkey.sql | 4 ++ bridgev2/networkinterface.go | 2 +- bridgev2/portal.go | 47 ++++++++++++------- 5 files changed, 78 insertions(+), 41 deletions(-) create mode 100644 bridgev2/database/upgrades/07-message-relation-without-fkey.sql diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 0a337d93..d9f78b1e 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -51,7 +51,8 @@ type Message struct { SenderID networkid.UserID Timestamp time.Time - RelatesToRowID int64 + ThreadRoot networkid.MessageID + ReplyTo networkid.MessageOptionalPartID Metadata MessageMetadata } @@ -62,7 +63,7 @@ func newMessage(_ *dbutil.QueryHelper[*Message]) *Message { const ( getMessageBaseQuery = ` - SELECT rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, relates_to, metadata FROM message + SELECT rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, thread_root_id, reply_to_id, reply_to_part_id, metadata FROM message ` getAllMessagePartsByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3` getMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 AND part_id=$4` @@ -71,17 +72,20 @@ const ( getLastMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id DESC LIMIT 1` getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1` getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5` + getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp ASC, part_id ASC LIMIT 1` + getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp DESC, part_id DESC LIMIT 1` getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1` insertMessageQuery = ` - INSERT INTO message (bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, relates_to, metadata) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + INSERT INTO message (bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, thread_root_id, reply_to_id, reply_to_part_id, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) RETURNING rowid ` updateMessageQuery = ` - UPDATE message SET id=$2, part_id=$3, mxid=$4, room_id=$5, room_receiver=$6, sender_id=$7, timestamp=$8, relates_to=$9, metadata=$10 - WHERE bridge_id=$1 AND rowid=$11 + UPDATE message SET id=$2, part_id=$3, mxid=$4, room_id=$5, room_receiver=$6, sender_id=$7, timestamp=$8, + thread_root_id=$9, reply_to_id=$10, reply_to_part_id=$11, metadata=$12 + WHERE bridge_id=$1 AND rowid=$13 ` deleteAllMessagePartsByIDQuery = ` DELETE FROM message WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 @@ -131,6 +135,14 @@ func (mq *MessageQuery) GetMessagesBetweenTimeQuery(ctx context.Context, portal return mq.QueryMany(ctx, getMessagesBetweenTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, start.UnixNano(), end.UnixNano()) } +func (mq *MessageQuery) GetFirstThreadMessage(ctx context.Context, portal networkid.PortalKey, threadRoot networkid.MessageID) (*Message, error) { + return mq.QueryOne(ctx, getFirstMessageInThread, mq.BridgeID, portal.ID, portal.Receiver, threadRoot) +} + +func (mq *MessageQuery) GetLastThreadMessage(ctx context.Context, portal networkid.PortalKey, threadRoot networkid.MessageID) (*Message, error) { + return mq.QueryOne(ctx, getLastMessageInThread, mq.BridgeID, portal.ID, portal.Receiver, threadRoot) +} + func (mq *MessageQuery) Insert(ctx context.Context, msg *Message) error { ensureBridgeIDMatches(&msg.BridgeID, mq.BridgeID) return mq.GetDB().QueryRow(ctx, insertMessageQuery, msg.sqlVariables()...).Scan(&msg.RowID) @@ -151,10 +163,10 @@ func (mq *MessageQuery) Delete(ctx context.Context, rowID int64) error { func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { var timestamp int64 - var relatesTo sql.NullInt64 + var threadRootID, replyToID, replyToPartID sql.NullString err := row.Scan( &m.RowID, &m.BridgeID, &m.ID, &m.PartID, &m.MXID, &m.Room.ID, &m.Room.Receiver, &m.SenderID, - ×tamp, &relatesTo, dbutil.JSON{Data: &m.Metadata}, + ×tamp, &threadRootID, &replyToID, &replyToPartID, dbutil.JSON{Data: &m.Metadata}, ) if err != nil { return nil, err @@ -163,7 +175,13 @@ func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { m.Metadata.Extra = make(map[string]any) } m.Timestamp = time.Unix(0, timestamp) - m.RelatesToRowID = relatesTo.Int64 + m.ThreadRoot = networkid.MessageID(threadRootID.String) + if replyToID.Valid { + m.ReplyTo.MessageID = networkid.MessageID(replyToID.String) + if replyToPartID.Valid { + m.ReplyTo.PartID = (*networkid.PartID)(&replyToPartID.String) + } + } return m, nil } @@ -173,7 +191,8 @@ func (m *Message) sqlVariables() []any { } return []any{ m.BridgeID, m.ID, m.PartID, m.MXID, m.Room.ID, m.Room.Receiver, m.SenderID, - m.Timestamp.UnixNano(), dbutil.NumPtr(m.RelatesToRowID), dbutil.JSON{Data: &m.Metadata}, + m.Timestamp.UnixNano(), dbutil.StrPtr(m.ThreadRoot), dbutil.StrPtr(m.ReplyTo.MessageID), m.ReplyTo.PartID, + dbutil.JSON{Data: &m.Metadata}, } } diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 51201aba..303dcf8d 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v6 (compatible with v1+): Latest revision +-- v0 -> v7 (compatible with v1+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -80,19 +80,22 @@ CREATE TABLE message ( -- only: sqlite (line commented) -- rowid INTEGER PRIMARY KEY, -- only: postgres - rowid BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, + rowid BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, - bridge_id TEXT NOT NULL, - id TEXT NOT NULL, - part_id TEXT NOT NULL, - mxid TEXT NOT NULL, + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, + part_id TEXT NOT NULL, + mxid TEXT NOT NULL, - room_id TEXT NOT NULL, - room_receiver TEXT NOT NULL, - sender_id TEXT NOT NULL, - timestamp BIGINT NOT NULL, - relates_to BIGINT, - metadata jsonb NOT NULL, + room_id TEXT NOT NULL, + room_receiver TEXT NOT NULL, + sender_id TEXT NOT NULL, + timestamp BIGINT NOT NULL, + thread_root_id TEXT, + reply_to_id TEXT, + reply_to_part_id TEXT, + relates_to BIGINT, -- unused column, TODO: remove + metadata jsonb NOT NULL, CONSTRAINT message_relation_fkey FOREIGN KEY (relates_to) REFERENCES message (rowid) ON DELETE SET NULL, diff --git a/bridgev2/database/upgrades/07-message-relation-without-fkey.sql b/bridgev2/database/upgrades/07-message-relation-without-fkey.sql new file mode 100644 index 00000000..9c4c9fd5 --- /dev/null +++ b/bridgev2/database/upgrades/07-message-relation-without-fkey.sql @@ -0,0 +1,4 @@ +-- v7: Add new relation columns to messages +ALTER TABLE message ADD COLUMN thread_root_id TEXT; +ALTER TABLE message ADD COLUMN reply_to_id TEXT; +ALTER TABLE message ADD COLUMN reply_to_part_id TEXT; diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 53ca7d1e..4c193585 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -39,7 +39,7 @@ type EventSender struct { type ConvertedMessage struct { ReplyTo *networkid.MessageOptionalPartID - ThreadRoot *networkid.MessageOptionalPartID + ThreadRoot *networkid.MessageID Parts []*ConvertedMessagePart Disappear database.DisappearingSetting } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 543307ea..3c1f6b5c 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -636,6 +636,16 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin if message.Timestamp.IsZero() { message.Timestamp = time.UnixMilli(evt.Timestamp) } + if message.ReplyTo.MessageID == "" && replyTo != nil { + message.ReplyTo.MessageID = replyTo.ID + message.ReplyTo.PartID = &replyTo.PartID + } + if message.ThreadRoot == "" && threadRoot != nil { + message.ThreadRoot = threadRoot.ID + if threadRoot.ThreadRoot != "" { + message.ThreadRoot = threadRoot.ThreadRoot + } + } message.Metadata.SenderMXID = evt.Sender // Hack to ensure the ghost row exists // TODO move to better place (like login) @@ -1139,32 +1149,32 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin portal.sendRemoteErrorNotice(ctx, intent, err, ts, "message") return } - var relatesToRowID int64 + var threadRootID networkid.MessageID + var replyToID networkid.MessageOptionalPartID var replyTo, threadRoot, prevThreadEvent *database.Message if converted.ReplyTo != nil { + replyToID = *converted.ReplyTo replyTo, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, portal.Receiver, *converted.ReplyTo) if err != nil { log.Err(err).Msg("Failed to get reply target message from database") } else if replyTo == nil { log.Warn().Any("reply_to", converted.ReplyTo).Msg("Reply target message not found in database") - } else { - relatesToRowID = replyTo.RowID } } if converted.ThreadRoot != nil { - threadRoot, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, portal.Receiver, *converted.ThreadRoot) + threadRootID = *converted.ThreadRoot + threadRoot, err = portal.Bridge.DB.Message.GetFirstThreadMessage(ctx, portal.PortalKey, threadRootID) if err != nil { log.Err(err).Msg("Failed to get thread root message from database") } else if threadRoot == nil { log.Warn().Any("thread_root", converted.ThreadRoot).Msg("Thread root message not found in database") - } else { - relatesToRowID = threadRoot.RowID } - // TODO thread roots need to be saved in the database in a way that allows fetching - // the first bridged thread message even if the original one isn't bridged - - // TODO 2 fetch last event in thread properly - prevThreadEvent = threadRoot + prevThreadEvent, err = portal.Bridge.DB.Message.GetLastThreadMessage(ctx, portal.PortalKey, threadRootID) + if err != nil { + log.Err(err).Msg("Failed to get last thread message from database") + } else if prevThreadEvent == nil { + prevThreadEvent = threadRoot + } } for _, part := range converted.Parts { if threadRoot != nil && prevThreadEvent != nil { @@ -1192,13 +1202,14 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin Str("part_id", string(part.ID)). Msg("Sent message part to Matrix") dbMessage := &database.Message{ - ID: evt.GetID(), - PartID: part.ID, - MXID: resp.EventID, - Room: portal.PortalKey, - SenderID: evt.GetSender().Sender, - Timestamp: ts, - RelatesToRowID: relatesToRowID, + ID: evt.GetID(), + PartID: part.ID, + MXID: resp.EventID, + Room: portal.PortalKey, + SenderID: evt.GetSender().Sender, + Timestamp: ts, + ThreadRoot: threadRootID, + ReplyTo: replyToID, } dbMessage.Metadata.SenderMXID = intent.GetMXID() dbMessage.Metadata.Extra = part.DBMetadata From 8893695f8404e3b7414a2f9968481e58a4d03c5a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 11 Jul 2024 10:49:28 +0300 Subject: [PATCH 0430/1647] bridgev2/commands: fix panic and improve logs in start-chat --- bridgev2/commands/startchat.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index 3054a1a1..275b2051 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -56,6 +56,10 @@ func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Even } func fnResolveIdentifier(ce *Event) { + if len(ce.Args) == 0 { + ce.Reply("Usage: `$cmdprefix %s `", ce.Command) + return + } login, api, identifierParts := getClientForStartingChat[bridgev2.IdentifierResolvingNetworkAPI](ce, "resolving identifiers") if api == nil { return @@ -64,6 +68,7 @@ func fnResolveIdentifier(ce *Event) { identifier := strings.Join(identifierParts, " ") resp, err := api.ResolveIdentifier(ce.Ctx, identifier, createChat) if err != nil { + ce.Log.Err(err).Msg("Failed to resolve identifier") ce.Reply("Failed to resolve identifier: %v", err) return } else if resp == nil { @@ -98,10 +103,19 @@ func fnResolveIdentifier(ce *Event) { if portal == nil { portal, err = ce.Bridge.GetPortalByID(ce.Ctx, resp.Chat.PortalID) if err != nil { + ce.Log.Err(err).Msg("Failed to get portal") ce.Reply("Failed to get portal: %v", err) return } } + if resp.Chat.PortalInfo == nil { + resp.Chat.PortalInfo, err = api.GetChatInfo(ce.Ctx, portal) + if err != nil { + ce.Log.Err(err).Msg("Failed to get portal info") + ce.Reply("Failed to get portal info: %v", err) + return + } + } if portal.MXID != "" { name := portal.Name if name == "" { @@ -112,6 +126,7 @@ func fnResolveIdentifier(ce *Event) { } else { err = portal.CreateMatrixRoom(ce.Ctx, login, resp.Chat.PortalInfo) if err != nil { + ce.Log.Err(err).Msg("Failed to create room") ce.Reply("Failed to create room: %v", err) return } From 7f18d6b7358e8a3aa06ac0fb21b7d68fc2a43885 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 11 Jul 2024 11:38:23 +0300 Subject: [PATCH 0431/1647] bridgev2: add caption merging utilities --- bridgev2/networkinterface.go | 65 ++++++++++++++++++++++++++++++++++-- event/message.go | 18 ++++++++++ 2 files changed, 80 insertions(+), 3 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 4c193585..dfde8bd0 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -15,6 +15,8 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/configupgrade" + "go.mau.fi/util/ptr" + "golang.org/x/exp/maps" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -31,6 +33,21 @@ type ConvertedMessagePart struct { DBMetadata map[string]any } +func (cmp *ConvertedMessagePart) ToEditPart(part *database.Message) *ConvertedEditPart { + if cmp == nil { + return nil + } + if cmp.DBMetadata != nil { + maps.Copy(part.Metadata.Extra, cmp.DBMetadata) + } + return &ConvertedEditPart{ + Part: part, + Type: cmp.Type, + Content: cmp.Content, + Extra: cmp.Extra, + } +} + type EventSender struct { IsFromMe bool SenderLogin networkid.UserLoginID @@ -44,6 +61,48 @@ type ConvertedMessage struct { Disappear database.DisappearingSetting } +func MergeCaption(textPart, mediaPart *ConvertedMessagePart) *ConvertedMessagePart { + if textPart == nil { + return mediaPart + } else if mediaPart == nil { + return textPart + } + mediaPart = ptr.Clone(mediaPart) + if mediaPart.Content.Body != "" && mediaPart.Content.FileName != "" && mediaPart.Content.Body != mediaPart.Content.FileName { + textPart = ptr.Clone(textPart) + textPart.Content.EnsureHasHTML() + mediaPart.Content.EnsureHasHTML() + mediaPart.Content.Body += "\n\n" + textPart.Content.Body + mediaPart.Content.FormattedBody += "

" + textPart.Content.FormattedBody + } else { + mediaPart.Content.FileName = mediaPart.Content.Body + mediaPart.Content.Body = textPart.Content.Body + mediaPart.Content.Format = textPart.Content.Format + mediaPart.Content.FormattedBody = textPart.Content.FormattedBody + } + mediaPart.ID = textPart.ID + return mediaPart +} + +func (cm *ConvertedMessage) MergeCaption() bool { + if len(cm.Parts) != 2 { + return false + } + textPart, mediaPart := cm.Parts[0], cm.Parts[1] + if textPart.Content.MsgType.IsMedia() { + textPart, mediaPart = mediaPart, textPart + } + if !mediaPart.Content.MsgType.IsMedia() || !textPart.Content.MsgType.IsText() { + return false + } + merged := MergeCaption(textPart, mediaPart) + if merged != nil { + cm.Parts = []*ConvertedMessagePart{merged} + return true + } + return false +} + type ConvertedEditPart struct { Part *database.Message @@ -357,9 +416,9 @@ type ResolveIdentifierResponse struct { } type CreateChatResponse struct { - Portal *Portal - - PortalID networkid.PortalKey + PortalID networkid.PortalKey + // Portal and PortalInfo are not required, the caller will fetch them automatically based on PortalID if necessary. + Portal *Portal PortalInfo *ChatInfo } diff --git a/event/message.go b/event/message.go index 21f5240b..5330d006 100644 --- a/event/message.go +++ b/event/message.go @@ -21,6 +21,24 @@ import ( // https://spec.matrix.org/v1.2/client-server-api/#mroommessage-msgtypes type MessageType string +func (mt MessageType) IsText() bool { + switch mt { + case MsgText, MsgNotice, MsgEmote: + return true + default: + return false + } +} + +func (mt MessageType) IsMedia() bool { + switch mt { + case MsgImage, MsgVideo, MsgAudio, MsgFile, MessageType(EventSticker.Type): + return true + default: + return false + } +} + // Msgtypes const ( MsgText MessageType = "m.text" From 32e6f25c34b91cd07c7d8bc59282e7422541c762 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 11 Jul 2024 14:01:35 +0300 Subject: [PATCH 0432/1647] event: add helper to append user ID to Mentions --- event/message.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/event/message.go b/event/message.go index 5330d006..6ce405d6 100644 --- a/event/message.go +++ b/event/message.go @@ -8,6 +8,7 @@ package event import ( "encoding/json" + "slices" "strconv" "strings" @@ -212,6 +213,12 @@ type Mentions struct { Room bool `json:"room,omitempty"` } +func (m *Mentions) Add(userID id.UserID) { + if !slices.Contains(m.UserIDs, userID) { + m.UserIDs = append(m.UserIDs, userID) + } +} + type EncryptedFileInfo struct { attachment.EncryptedFile URL id.ContentURIString `json:"url"` From ebcdde0c97634b97b875c145ac520005e1942c7d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 11 Jul 2024 17:28:31 +0300 Subject: [PATCH 0433/1647] bridgev2: add support for legacy replies to thread messages --- bridgev2/portal.go | 46 ++++++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 3c1f6b5c..ebe7a22a 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -587,26 +587,40 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } var threadRoot, replyTo *database.Message + var replyToID id.EventID if caps.Threads { - threadRootID := content.RelatesTo.GetThreadParent() - if threadRootID != "" { - threadRoot, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, threadRootID) - if err != nil { - log.Err(err).Msg("Failed to get thread root message from database") - } + replyToID = content.RelatesTo.GetNonFallbackReplyTo() + } else { + replyToID = content.RelatesTo.GetReplyTo() + } + threadRootID := content.RelatesTo.GetThreadParent() + if caps.Threads && threadRootID != "" { + threadRoot, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, threadRootID) + if err != nil { + log.Err(err).Msg("Failed to get thread root message from database") } } - if caps.Replies { - var replyToID id.EventID - if caps.Threads { - replyToID = content.RelatesTo.GetNonFallbackReplyTo() + if replyToID != "" && (caps.Replies || caps.Threads) { + replyTo, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, replyToID) + if err != nil { + log.Err(err).Msg("Failed to get reply target message from database") } else { - replyToID = content.RelatesTo.GetReplyTo() - } - if replyToID != "" { - replyTo, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, replyToID) - if err != nil { - log.Err(err).Msg("Failed to get reply target message from database") + // Support replying to threads from non-thread-capable clients. + // The fallback happens if the message is not a Matrix thread and either + // * the replied-to message is in a thread, or + // * the network only supports threads (assume the user wants to start a new thread) + if caps.Threads && threadRoot == nil && (replyTo.ThreadRoot != "" || !caps.Replies) { + threadRootRemoteID := replyTo.ThreadRoot + if threadRootRemoteID == "" { + threadRootRemoteID = replyTo.ID + } + threadRoot, err = portal.Bridge.DB.Message.GetFirstThreadMessage(ctx, portal.PortalKey, threadRootRemoteID) + if err != nil { + log.Err(err).Msg("Failed to get thread root message from database (via reply fallback)") + } + } + if !caps.Replies { + replyTo = nil } } } From 9e4bce17e70b0be997fd9ae1408c6edb2e3c5200 Mon Sep 17 00:00:00 2001 From: Adam Van Ymeren Date: Thu, 11 Jul 2024 13:17:44 -0700 Subject: [PATCH 0434/1647] decryptmegolm: Use ResolveTrustContext to ensure any DB transactions are carried forward (#254) - also make verificationhelper interfaces public so client code can assert conformance --- crypto/decryptmegolm.go | 10 ++++++++-- crypto/verificationhelper/verificationhelper.go | 8 ++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index ff5b82f3..ba2811ab 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -98,14 +98,20 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event } else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey { return nil, DeviceKeyMismatch } else { - trustLevel = mach.ResolveTrust(device) + trustLevel, err = mach.ResolveTrustContext(ctx, device) + if err != nil { + return nil, err + } } } else { forwardedKeys = true lastChainItem := sess.ForwardingChains[len(sess.ForwardingChains)-1] device, _ = mach.CryptoStore.FindDeviceByKey(ctx, evt.Sender, id.IdentityKey(lastChainItem)) if device != nil { - trustLevel = mach.ResolveTrust(device) + trustLevel, err = mach.ResolveTrustContext(ctx, device) + if err != nil { + return nil, err + } } else { log.Debug(). Str("forward_last_sender_key", lastChainItem). diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index c9fd7407..e7ea53c5 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -119,14 +119,14 @@ type RequiredCallbacks interface { VerificationDone(ctx context.Context, txnID id.VerificationTransactionID) } -type showSASCallbacks interface { +type ShowSASCallbacks interface { // ShowSAS is a callback that is called when the SAS verification has // generated a short authentication string to show. It is guaranteed that // either the emojis list, or the decimals list, or both will be present. ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) } -type showQRCodeCallbacks interface { +type ShowQRCodeCallbacks interface { // ScanQRCode is called when another device has sent a // m.key.verification.ready event and indicated that they are capable of // showing a QR code. @@ -183,11 +183,11 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, call } supportedMethods := map[event.VerificationMethod]struct{}{} - if c, ok := callbacks.(showSASCallbacks); ok { + if c, ok := callbacks.(ShowSASCallbacks); ok { supportedMethods[event.VerificationMethodSAS] = struct{}{} helper.showSAS = c.ShowSAS } - if c, ok := callbacks.(showQRCodeCallbacks); ok { + if c, ok := callbacks.(ShowQRCodeCallbacks); ok { supportedMethods[event.VerificationMethodQRCodeShow] = struct{}{} supportedMethods[event.VerificationMethodReciprocate] = struct{}{} helper.scanQRCode = c.ScanQRCode From 98918d7ab71b3f6477d89dc474c9f540fcb802ab Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 12 Jul 2024 14:59:46 +0300 Subject: [PATCH 0435/1647] bridgev2: add timestamp to message checkpoints --- bridgev2/messagestatus.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 1a04e5af..7be0c188 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -10,6 +10,8 @@ import ( "errors" "fmt" + "go.mau.fi/util/jsontime" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -165,6 +167,7 @@ func (ms *MessageStatus) ToCheckpoint(evt *MessageStatusEventInfo) *status.Messa RoomID: evt.RoomID, EventID: evt.EventID, Step: step, + Timestamp: jsontime.UnixMilliNow(), Status: ms.checkpointStatus(), RetryNum: ms.RetryNum, ReportedBy: status.MsgReportedByBridge, From 0f9f923378c5cb77ed03e63538cb4604cb0ad8c6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 12 Jul 2024 15:00:58 +0300 Subject: [PATCH 0436/1647] bridgev2/matrix: also add timestamp to BRIDGE checkpoints --- bridgev2/matrix/matrix.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index d5a76ceb..a7261a61 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -14,6 +14,7 @@ import ( "time" "github.com/rs/zerolog" + "go.mau.fi/util/jsontime" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/bridge/status" @@ -147,6 +148,7 @@ func (br *Connector) sendSuccessCheckpoint(ctx context.Context, evt *event.Event EventType: evt.Type, MessageType: evt.Content.AsMessage().MsgType, Step: step, + Timestamp: jsontime.UnixMilliNow(), Status: status.MsgStatusSuccess, ReportedBy: status.MsgReportedByBridge, RetryNum: retryNum, From 85cead80342f27845892096f28f4a1d96f34c2b2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 12 Jul 2024 15:52:32 +0300 Subject: [PATCH 0437/1647] bridgev2: add support for forward backfilling --- bridgev2/matrix/connector.go | 55 +++++++++ bridgev2/matrixinterface.go | 4 + bridgev2/networkid/bridgeid.go | 3 + bridgev2/networkinterface.go | 105 ++++++++++++++++ bridgev2/portal.go | 187 ++++++++++++++++++++--------- bridgev2/portalbackfill.go | 211 +++++++++++++++++++++++++++++++++ 6 files changed, 512 insertions(+), 53 deletions(-) create mode 100644 bridgev2/portalbackfill.go diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index e5b9777f..04a7ea1f 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -8,19 +8,25 @@ package matrix import ( "context" + "crypto/sha256" + "encoding/base64" "encoding/json" "errors" + "fmt" + "net/url" "os" "regexp" "strings" "sync" "time" + "unsafe" "github.com/gorilla/mux" _ "github.com/lib/pq" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" _ "go.mau.fi/util/dbutil/litestream" + "go.mau.fi/util/exsync" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" @@ -63,6 +69,10 @@ type Connector struct { MediaProxy *mediaproxy.MediaProxy dmaSigKey [32]byte + doublePuppetIntents *exsync.Map[id.UserID, *appservice.IntentAPI] + + deterministicEventIDServer string + MediaConfig mautrix.RespMediaConfig SpecVersions *mautrix.RespVersions Capabilities *bridgev2.MatrixCapabilities @@ -92,6 +102,7 @@ func NewConnector(cfg *bridgeconfig.Config) *Connector { c.userIDRegex = cfg.MakeUserIDRegex("(.+)") c.MediaConfig.UploadSize = 50 * 1024 * 1024 c.Capabilities = &bridgev2.MatrixCapabilities{} + c.doublePuppetIntents = exsync.NewMap[id.UserID, *appservice.IntentAPI]() return c } @@ -128,6 +139,7 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { ) br.Provisioning = &ProvisioningAPI{br: br} br.DoublePuppet = newDoublePuppetUtil(br) + br.deterministicEventIDServer = "backfill." + br.Config.Homeserver.Domain } func (br *Connector) Start(ctx context.Context) error { @@ -154,6 +166,10 @@ func (br *Connector) Start(ctx context.Context) error { if br.Crypto != nil { go br.Crypto.Start() } + parsed, _ := url.Parse(br.Bridge.Network.GetName().NetworkURL) + if parsed != nil { + br.deterministicEventIDServer = parsed.Hostname() + } br.AS.Ready = true return nil } @@ -196,6 +212,7 @@ func (br *Connector) ensureConnection(ctx context.Context) { br.SpecVersions = versions *br.AS.SpecVersions = *versions br.Capabilities.AutoJoinInvites = br.SpecVersions.Supports(mautrix.BeeperFeatureAutojoinInvites) + br.Capabilities.BatchSending = br.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending) break } } @@ -450,6 +467,7 @@ func (br *Connector) NewUserIntent(ctx context.Context, userID id.UserID, access } return nil, accessToken, err } + br.doublePuppetIntents.Set(userID, intent) return &ASIntent{Connector: br, Matrix: intent}, newToken, nil } @@ -479,6 +497,43 @@ func (br *Connector) GetMemberInfo(ctx context.Context, roomID id.RoomID, userID return br.AS.StateStore.GetMember(ctx, roomID, userID) } +func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBeeperBatchSend) (*mautrix.RespBeeperBatchSend, error) { + if encrypted, err := br.StateStore.IsEncrypted(ctx, roomID); err != nil { + return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) + } else if encrypted { + for _, evt := range req.Events { + intent, _ := br.doublePuppetIntents.Get(evt.Sender) + if intent != nil { + intent.AddDoublePuppetValueWithTS(evt.ID, evt.Timestamp) + } + err = br.Crypto.Encrypt(ctx, roomID, evt.Type, &evt.Content) + if err != nil { + return nil, err + } + evt.Type = event.EventEncrypted + } + } + return br.Bot.BeeperBatchSend(ctx, roomID, req) +} + +func (br *Connector) GenerateDeterministicEventID(roomID id.RoomID, _ networkid.PortalKey, messageID networkid.MessageID, partID networkid.PartID) id.EventID { + data := make([]byte, 0, len(roomID)+len(messageID)+len(partID)) + data = append(data, roomID...) + data = append(data, messageID...) + data = append(data, partID...) + + hash := sha256.Sum256(data) + hashB64Len := base64.RawURLEncoding.EncodedLen(len(hash)) + + eventID := make([]byte, 1+hashB64Len+1+len(br.deterministicEventIDServer)) + eventID[0] = '$' + base64.RawURLEncoding.Encode(eventID[1:1+hashB64Len], hash[:]) + eventID[1+hashB64Len] = ':' + copy(eventID[1+hashB64Len+1:], br.deterministicEventIDServer) + + return id.EventID(unsafe.String(unsafe.SliceData(eventID), len(eventID))) +} + func (br *Connector) ServerName() string { return br.Config.Homeserver.Domain } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 0ba5f212..49f03c07 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -21,6 +21,7 @@ import ( type MatrixCapabilities struct { AutoJoinInvites bool + BatchSending bool } type MatrixConnector interface { @@ -44,6 +45,9 @@ type MatrixConnector interface { GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) GetMemberInfo(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) + BatchSend(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBeeperBatchSend) (*mautrix.RespBeeperBatchSend, error) + GenerateDeterministicEventID(roomID id.RoomID, portalKey networkid.PortalKey, messageID networkid.MessageID, partID networkid.PartID) id.EventID + ServerName() string } diff --git a/bridgev2/networkid/bridgeid.go b/bridgev2/networkid/bridgeid.go index 08e49f29..65d34609 100644 --- a/bridgev2/networkid/bridgeid.go +++ b/bridgev2/networkid/bridgeid.go @@ -105,6 +105,9 @@ type MessageOptionalPartID struct { PartID *PartID } +// PaginationCursor is a cursor used for paginating message history. +type PaginationCursor string + // AvatarID is the ID of a user or room avatar on the remote network. // // It may be a real URL, an opaque identifier, or anything in between. It should be an identifier that diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index dfde8bd0..f28f2250 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -313,6 +313,86 @@ type NetworkAPI interface { HandleMatrixMessage(ctx context.Context, msg *MatrixMessage) (message *MatrixMessageResponse, err error) } +// FetchMessagesParams contains the parameters for a message history pagination request. +type FetchMessagesParams struct { + // The portal to fetch messages in. Always present. + Portal *Portal + // When fetching messages inside a thread, the ID of the thread. + ThreadRoot networkid.MessageID + // Whether to fetch new messages instead of old ones. + Forward bool + // The oldest known message in the thread or the portal. If Forward is true, this is the newest known message instead. + // If the portal doesn't have any bridged messages, this will be nil. + AnchorMessage *database.Message + // The cursor returned by the previous call to FetchMessages with the same portal and thread root. + // This will not be present in Forward calls. + Cursor networkid.PaginationCursor + // The preferred number of messages to return. The returned batch can be bigger or smaller + // without any side effects, but the network connector should aim for this number. + Count int +} + +// BackfillReaction is an individual reaction to a message in a history pagination request. +// +// The target message is always the BackfillMessage that contains this item. +// Optionally, the reaction can target a specific part by specifying TargetPart. +// If not specified, the first part (sorted lexicographically) is targeted. +type BackfillReaction struct { + // Optional part of the message that the reaction targets. + // If nil, the reaction targets the first part of the message. + TargetPart *networkid.PartID + // Optional timestamp for the reaction. + // If unset, the reaction will have a fake timestamp that is slightly after the message timestamp. + Timestamp time.Time + + Sender EventSender + EmojiID networkid.EmojiID + Emoji string + ExtraContent map[string]any + DBMetadata map[string]any +} + +// BackfillMessage is an individual message in a history pagination request. +type BackfillMessage struct { + *ConvertedMessage + Sender EventSender + ID networkid.MessageID + Timestamp time.Time + Reactions []*BackfillReaction +} + +// FetchMessagesResponse contains the response for a message history pagination request. +type FetchMessagesResponse struct { + // The messages to backfill. Messages should always be sorted in chronological order (oldest to newest). + Messages []*BackfillMessage + // The next cursor to use for fetching more messages. + Cursor networkid.PaginationCursor + // Whether there are more messages that can be backfilled. + // This field is required. If it is false, FetchMessages will not be called again. + HasMore bool + // Whether the batch contains new messages rather than old ones. + // Cursor, HasMore and the progress fields will be ignored when this is present. + Forward bool + // When sending forward backfill (or the first batch in a room), this field can be set + // to mark the messages as read immediately after backfilling. + MarkRead bool + + // When HasMore is true, one of the following fields can be set to report backfill progress: + + // Approximate backfill progress as a number between 0 and 1. + ApproxProgress float64 + // Approximate number of messages remaining that can be backfilled. + ApproxRemainingCount int + // Approximate total number of messages in the chat. + ApproxTotalCount int +} + +// BackfillingNetworkAPI is an optional interface that network connectors can implement to support backfilling message history. +type BackfillingNetworkAPI interface { + NetworkAPI + FetchMessages(ctx context.Context, fetchParams FetchMessagesParams) (*FetchMessagesResponse, error) +} + // EditHandlingNetworkAPI is an optional interface that network connectors can implement to handle message edits. type EditHandlingNetworkAPI interface { NetworkAPI @@ -537,6 +617,10 @@ func (ret RemoteEventType) String() string { return "RemoteEventTyping" case RemoteEventChatInfoChange: return "RemoteEventChatInfoChange" + case RemoteEventChatResync: + return "RemoteEventChatResync" + case RemoteEventBackfill: + return "RemoteEventBackfill" default: return fmt.Sprintf("RemoteEventType(%d)", int(ret)) } @@ -554,6 +638,8 @@ const ( RemoteEventMarkUnread RemoteEventTyping RemoteEventChatInfoChange + RemoteEventChatResync + RemoteEventBackfill ) // RemoteEvent represents a single event from the remote network, such as a message or a reaction. @@ -577,6 +663,20 @@ type RemoteChatInfoChange interface { GetChatInfoChange(ctx context.Context) (*ChatInfoChange, error) } +type RemoteChatResync interface { + RemoteEvent +} + +type RemoteChatResyncWithInfo interface { + RemoteChatResync + GetChatInfo(ctx context.Context, portal *Portal) (*ChatInfo, error) +} + +type RemoteChatResyncBackfill interface { + RemoteChatResync + CheckNeedsBackfill(ctx context.Context, latestMessage *database.Message) (bool, error) +} + type RemoteEventThatMayCreatePortal interface { RemoteEvent ShouldCreatePortal() bool @@ -649,6 +749,11 @@ type RemoteTyping interface { GetTimeout() time.Duration } +type RemoteBackfill interface { + RemoteEvent + GetBackfillData(ctx context.Context, portal *Portal) (*FetchMessagesResponse, error) +} + type TypingType int const ( diff --git a/bridgev2/portal.go b/bridgev2/portal.go index ebe7a22a..ddd8862b 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -383,7 +383,8 @@ func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, e log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c. Stringer("event_id", eventID). - Stringer("user_id", user.MXID) + Stringer("user_id", user.MXID). + Stringer("receipt_ts", receipt.Timestamp) }) login, userPortal, err := portal.FindPreferredLogin(ctx, user, false) if err != nil { @@ -398,6 +399,9 @@ func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, e if !ok { return } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("user_login_id", string(login.ID)) + }) evt := &MatrixReadReceipt{ Portal: portal, EventID: eventID, @@ -413,6 +417,9 @@ func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, e if err != nil { log.Err(err).Msg("Failed to get exact message from database") } else if evt.ExactMessage != nil { + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("exact_message_id", string(evt.ExactMessage.ID)).Time("exact_message_ts", evt.ExactMessage.Timestamp) + }) evt.ReadUpTo = evt.ExactMessage.Timestamp } else { evt.ReadUpTo = receipt.Timestamp @@ -1053,7 +1060,16 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { if !ok || !mcp.ShouldCreatePortal() { return } - err := portal.CreateMatrixRoom(ctx, source, nil) + infoProvider, ok := mcp.(RemoteChatResyncWithInfo) + var info *ChatInfo + var err error + if ok { + info, err = infoProvider.GetChatInfo(ctx, portal) + if err != nil { + log.Err(err).Msg("Failed to get chat info for portal creation from chat resync event") + } + } + err = portal.CreateMatrixRoom(ctx, source, info) if err != nil { log.Err(err).Msg("Failed to create portal to handle event") // TODO error @@ -1089,6 +1105,10 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { portal.handleRemoteTyping(ctx, source, evt.(RemoteTyping)) case RemoteEventChatInfoChange: portal.handleRemoteChatInfoChange(ctx, source, evt.(RemoteChatInfoChange)) + case RemoteEventChatResync: + portal.handleRemoteChatResync(ctx, source, evt.(RemoteChatResync)) + case RemoteEventBackfill: + portal.handleRemoteBackfill(ctx, source, evt.(RemoteBackfill)) default: log.Warn().Int("type", int(evt.GetType())).Msg("Got remote event with unknown type") } @@ -1143,66 +1163,70 @@ func (portal *Portal) GetIntentFor(ctx context.Context, sender EventSender, sour return intent } -func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) { +func (portal *Portal) getRelationMeta(ctx context.Context, replyToPtr *networkid.MessageOptionalPartID, threadRootPtr *networkid.MessageID, isBatchSend bool) (replyTo, threadRoot, prevThreadEvent *database.Message) { log := zerolog.Ctx(ctx) - existing, err := portal.Bridge.DB.Message.GetFirstPartByID(ctx, portal.Receiver, evt.GetID()) - if err != nil { - log.Err(err).Msg("Failed to check if message is a duplicate") - } else if existing != nil { - log.Debug().Stringer("existing_mxid", existing.MXID).Msg("Ignoring duplicate message") - return - } - intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessage) - if intent == nil { - return - } - ts := getEventTS(evt) - converted, err := evt.ConvertMessage(ctx, portal, intent) - if err != nil { - log.Err(err).Msg("Failed to convert remote message") - portal.sendRemoteErrorNotice(ctx, intent, err, ts, "message") - return - } - var threadRootID networkid.MessageID - var replyToID networkid.MessageOptionalPartID - var replyTo, threadRoot, prevThreadEvent *database.Message - if converted.ReplyTo != nil { - replyToID = *converted.ReplyTo - replyTo, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, portal.Receiver, *converted.ReplyTo) + var err error + if replyToPtr != nil { + replyTo, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, portal.Receiver, *replyToPtr) if err != nil { log.Err(err).Msg("Failed to get reply target message from database") } else if replyTo == nil { - log.Warn().Any("reply_to", converted.ReplyTo).Msg("Reply target message not found in database") + if isBatchSend { + // This is somewhat evil + replyTo = &database.Message{ + MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, replyToPtr.MessageID, ptr.Val(replyToPtr.PartID)), + } + } else { + log.Warn().Any("reply_to", *replyToPtr).Msg("Reply target message not found in database") + } } } - if converted.ThreadRoot != nil { - threadRootID = *converted.ThreadRoot - threadRoot, err = portal.Bridge.DB.Message.GetFirstThreadMessage(ctx, portal.PortalKey, threadRootID) + if threadRootPtr != nil { + threadRoot, err = portal.Bridge.DB.Message.GetFirstThreadMessage(ctx, portal.PortalKey, *threadRootPtr) if err != nil { log.Err(err).Msg("Failed to get thread root message from database") } else if threadRoot == nil { - log.Warn().Any("thread_root", converted.ThreadRoot).Msg("Thread root message not found in database") - } - prevThreadEvent, err = portal.Bridge.DB.Message.GetLastThreadMessage(ctx, portal.PortalKey, threadRootID) - if err != nil { + if isBatchSend { + threadRoot = &database.Message{ + MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, *threadRootPtr, ""), + } + } else { + log.Warn().Str("thread_root", string(*threadRootPtr)).Msg("Thread root message not found in database") + } + } else if prevThreadEvent, err = portal.Bridge.DB.Message.GetLastThreadMessage(ctx, portal.PortalKey, *threadRootPtr); err != nil { log.Err(err).Msg("Failed to get last thread message from database") - } else if prevThreadEvent == nil { + } + if prevThreadEvent == nil { prevThreadEvent = threadRoot } } + return +} + +func (portal *Portal) applyRelationMeta(content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { + if threadRoot != nil && prevThreadEvent != nil { + content.GetRelatesTo().SetThread(threadRoot.MXID, prevThreadEvent.MXID) + } + if replyTo != nil { + content.GetRelatesTo().SetReplyTo(replyTo.MXID) + if content.Mentions == nil { + content.Mentions = &event.Mentions{} + } + content.Mentions.Add(replyTo.Metadata.SenderMXID) + } +} + +func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.MessageID, intent MatrixAPI, sender EventSender, converted *ConvertedMessage, ts time.Time, logContext func(*zerolog.Event) *zerolog.Event) []*database.Message { + if logContext == nil { + logContext = func(e *zerolog.Event) *zerolog.Event { + return e + } + } + log := zerolog.Ctx(ctx) + replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, converted.ReplyTo, converted.ThreadRoot, false) + output := make([]*database.Message, 0, len(converted.Parts)) for _, part := range converted.Parts { - if threadRoot != nil && prevThreadEvent != nil { - part.Content.GetRelatesTo().SetThread(threadRoot.MXID, prevThreadEvent.MXID) - } - if replyTo != nil { - part.Content.GetRelatesTo().SetReplyTo(replyTo.MXID) - if part.Content.Mentions == nil { - part.Content.Mentions = &event.Mentions{} - } - if !slices.Contains(part.Content.Mentions.UserIDs, replyTo.Metadata.SenderMXID) { - part.Content.Mentions.UserIDs = append(part.Content.Mentions.UserIDs, replyTo.Metadata.SenderMXID) - } - } + portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) resp, err := intent.SendMessage(ctx, portal.MXID, part.Type, &event.Content{ Parsed: part.Content, Raw: part.Extra, @@ -1216,14 +1240,14 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin Str("part_id", string(part.ID)). Msg("Sent message part to Matrix") dbMessage := &database.Message{ - ID: evt.GetID(), + ID: id, PartID: part.ID, MXID: resp.EventID, Room: portal.PortalKey, - SenderID: evt.GetSender().Sender, + SenderID: sender.Sender, Timestamp: ts, - ThreadRoot: threadRootID, - ReplyTo: replyToID, + ThreadRoot: ptr.Val(converted.ThreadRoot), + ReplyTo: ptr.Val(converted.ReplyTo), } dbMessage.Metadata.SenderMXID = intent.GetMXID() dbMessage.Metadata.Extra = part.DBMetadata @@ -1244,7 +1268,32 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin if prevThreadEvent != nil { prevThreadEvent = dbMessage } + output = append(output, dbMessage) } + return output +} + +func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) { + log := zerolog.Ctx(ctx) + existing, err := portal.Bridge.DB.Message.GetFirstPartByID(ctx, portal.Receiver, evt.GetID()) + if err != nil { + log.Err(err).Msg("Failed to check if message is a duplicate") + } else if existing != nil { + log.Debug().Stringer("existing_mxid", existing.MXID).Msg("Ignoring duplicate message") + return + } + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessage) + if intent == nil { + return + } + ts := getEventTS(evt) + converted, err := evt.ConvertMessage(ctx, portal, intent) + if err != nil { + log.Err(err).Msg("Failed to convert remote message") + portal.sendRemoteErrorNotice(ctx, intent, err, ts, "message") + return + } + portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender(), converted, ts, nil) } func (portal *Portal) sendRemoteErrorNotice(ctx context.Context, intent MatrixAPI, err error, ts time.Time, evtTypeName string) { @@ -1555,7 +1604,7 @@ func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLo } func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteReceipt) { - + // TODO implement } func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) { @@ -1579,6 +1628,38 @@ func (portal *Portal) handleRemoteChatInfoChange(ctx context.Context, source *Us portal.ProcessChatInfoChange(ctx, evt.GetSender(), source, info, getEventTS(evt)) } +func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLogin, evt RemoteChatResync) { + log := zerolog.Ctx(ctx) + infoProvider, ok := evt.(RemoteChatResyncWithInfo) + if ok { + info, err := infoProvider.GetChatInfo(ctx, portal) + if err != nil { + log.Err(err).Msg("Failed to get chat info from resync event") + } else if info != nil { + portal.UpdateInfo(ctx, info, source, nil, time.Time{}) + } + } + backfillChecker, ok := evt.(RemoteChatResyncBackfill) + if ok { + latestMessage, err := portal.Bridge.DB.Message.GetLastPartAtOrBeforeTime(ctx, portal.PortalKey, time.Now().Add(10*time.Second)) + if err != nil { + log.Err(err).Msg("Failed to get last message in portal to check if backfill is necessary") + } else if needsBackfill, err := backfillChecker.CheckNeedsBackfill(ctx, latestMessage); err != nil { + log.Err(err).Msg("Failed to check if backfill is needed") + } else if needsBackfill { + portal.doForwardBackfill(ctx, source, latestMessage) + } + } +} + +func (portal *Portal) handleRemoteBackfill(ctx context.Context, source *UserLogin, backfill RemoteBackfill) { + //data, err := backfill.GetBackfillData(ctx, portal) + //if err != nil { + // zerolog.Ctx(ctx).Err(err).Msg("Failed to get backfill data") + // return + //} +} + type ChatInfoChange struct { // The chat info that changed. Any fields that did not change can be left as nil. ChatInfo *ChatInfo diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go new file mode 100644 index 00000000..cef9e140 --- /dev/null +++ b/bridgev2/portalbackfill.go @@ -0,0 +1,211 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "time" + + "github.com/rs/zerolog" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, lastMessage *database.Message) { + log := zerolog.Ctx(ctx).With().Str("action", "forward backfill").Logger() + ctx = log.WithContext(ctx) + api, ok := source.Client.(BackfillingNetworkAPI) + if !ok { + log.Debug().Msg("Network API does not support backfilling") + return + } + log.Info().Str("latest_message_id", string(lastMessage.ID)).Msg("Fetching messages for forward backfill") + resp, err := api.FetchMessages(ctx, FetchMessagesParams{ + Portal: portal, + ThreadRoot: "", + Forward: true, + AnchorMessage: lastMessage, + Count: 100, + }) + if err != nil { + log.Err(err).Msg("Failed to fetch messages for forward backfill") + return + } + portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, lastMessage) +} + +func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin) { + //log := zerolog.Ctx(ctx) + //api, ok := source.Client.(BackfillingNetworkAPI) + //if !ok { + // log.Debug().Msg("Network API does not support backfilling") + // return + //} + //resp, err := api.FetchMessages(ctx, FetchMessagesParams{ + // Portal: portal, + // ThreadRoot: "", + // Forward: true, + // AnchorMessage: lastMessage, + // Count: 100, + //}) + //if err != nil { + // log.Err(err).Msg("Failed to fetch messages for forward backfill") + // return + //} + //portal.sendBackfill(ctx, source, resp.Messages, false, resp.MarkRead, lastMessage) +} + +func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead bool, lastMessage *database.Message) { + if forceForward { + var cutoff int + for i, msg := range messages { + if msg.Timestamp.Before(lastMessage.Timestamp) { + cutoff = i + } + } + if cutoff != 0 { + zerolog.Ctx(ctx).Debug(). + Int("cutoff_count", cutoff). + Int("total_count", len(messages)). + Time("last_bridged_ts", lastMessage.Timestamp). + Msg("Cutting off forward backfill messages older than latest bridged message") + messages = messages[cutoff:] + } + } + canBatchSend := portal.Bridge.Matrix.GetCapabilities().BatchSending + zerolog.Ctx(ctx).Info().Int("message_count", len(messages)).Bool("batch_send", canBatchSend).Msg("Sending backfill messages") + if canBatchSend { + portal.sendBatch(ctx, source, messages, forceForward, markRead) + } else { + portal.sendLegacyBackfill(ctx, source, messages, markRead) + } + zerolog.Ctx(ctx).Debug().Msg("Backfill finished") +} + +func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead bool) { + req := &mautrix.ReqBeeperBatchSend{ + ForwardIfNoMessages: !forceForward, + Forward: forceForward, + Events: make([]*event.Event, 0, len(messages)), + } + if markRead { + req.MarkReadBy = source.UserMXID + } else { + req.SendNotification = forceForward + } + prevThreadEvents := make(map[networkid.MessageID]id.EventID) + dbMessages := make([]*database.Message, 0, len(messages)) + var disappearingMessages []*database.DisappearingMessage + for _, msg := range messages { + intent := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) + replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, msg.ReplyTo, msg.ThreadRoot, true) + if threadRoot != nil && prevThreadEvents[*msg.ThreadRoot] != "" { + prevThreadEvent.MXID = prevThreadEvents[*msg.ThreadRoot] + } + for _, part := range msg.Parts { + portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) + evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID) + req.Events = append(req.Events, &event.Event{ + Sender: intent.GetMXID(), + Type: part.Type, + Timestamp: msg.Timestamp.UnixMilli(), + ID: evtID, + RoomID: portal.MXID, + Content: event.Content{ + Parsed: part.Content, + Raw: part.Extra, + }, + }) + dbMessages = append(dbMessages, &database.Message{ + ID: msg.ID, + PartID: part.ID, + MXID: evtID, + Room: portal.PortalKey, + SenderID: msg.Sender.Sender, + Timestamp: msg.Timestamp, + ThreadRoot: ptr.Val(msg.ThreadRoot), + ReplyTo: ptr.Val(msg.ReplyTo), + Metadata: database.MessageMetadata{ + StandardMessageMetadata: database.StandardMessageMetadata{ + SenderMXID: intent.GetMXID(), + }, + Extra: part.DBMetadata, + }, + }) + if prevThreadEvent != nil { + prevThreadEvent.MXID = evtID + prevThreadEvents[*msg.ThreadRoot] = evtID + } + if msg.Disappear.Type != database.DisappearingTypeNone { + if msg.Disappear.Type == database.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() { + msg.Disappear.DisappearAt = msg.Timestamp.Add(msg.Disappear.Timer) + } + disappearingMessages = append(disappearingMessages, &database.DisappearingMessage{ + RoomID: portal.MXID, + EventID: evtID, + DisappearingSetting: msg.Disappear, + }) + } + } + // TODO handle reactions + //for _, reaction := range msg.Reactions { + //} + } + _, err := portal.Bridge.Matrix.BatchSend(ctx, portal.MXID, req) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to send backfill messages") + } + if len(disappearingMessages) > 0 { + go func() { + for _, msg := range disappearingMessages { + portal.Bridge.DisappearLoop.Add(ctx, msg) + } + }() + } + for _, msg := range dbMessages { + err := portal.Bridge.DB.Message.Insert(ctx, msg) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Str("message_id", string(msg.ID)). + Str("part_id", string(msg.PartID)). + Msg("Failed to insert backfilled message to database") + } + } +} + +func (portal *Portal) sendLegacyBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, markRead bool) { + var lastPart id.EventID + for _, msg := range messages { + intent := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) + dbMessages := portal.sendConvertedMessage(ctx, msg.ID, intent, msg.Sender, msg.ConvertedMessage, msg.Timestamp, func(z *zerolog.Event) *zerolog.Event { + return z. + Str("message_id", string(msg.ID)). + Any("sender_id", msg.Sender). + Time("message_ts", msg.Timestamp) + }) + if len(dbMessages) > 0 { + lastPart = dbMessages[len(dbMessages)-1].MXID + } + // TODO handle reactions + //for _, reaction := range msg.Reactions { + //} + } + if markRead { + dp := source.User.DoublePuppet(ctx) + if dp != nil { + err := dp.MarkRead(ctx, portal.MXID, lastPart, time.Now()) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to mark room as read after backfill") + } + } + } +} From 681b5449d54574f7fd15b232c166848f0856563a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 12 Jul 2024 16:10:06 +0300 Subject: [PATCH 0438/1647] bridgev2/backfill: fix handling forward backfills in empty rooms --- bridgev2/portalbackfill.go | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index cef9e140..f1ec3110 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -28,7 +28,13 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, log.Debug().Msg("Network API does not support backfilling") return } - log.Info().Str("latest_message_id", string(lastMessage.ID)).Msg("Fetching messages for forward backfill") + logEvt := log.Info() + if lastMessage != nil { + logEvt = logEvt.Str("latest_message_id", string(lastMessage.ID)) + } else { + logEvt = logEvt.Str("latest_message_id", "") + } + logEvt.Msg("Fetching messages for forward backfill") resp, err := api.FetchMessages(ctx, FetchMessagesParams{ Portal: portal, ThreadRoot: "", @@ -65,20 +71,22 @@ func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin } func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead bool, lastMessage *database.Message) { - if forceForward { - var cutoff int - for i, msg := range messages { - if msg.Timestamp.Before(lastMessage.Timestamp) { - cutoff = i + if lastMessage != nil { + if forceForward { + var cutoff int + for i, msg := range messages { + if msg.Timestamp.Before(lastMessage.Timestamp) { + cutoff = i + } + } + if cutoff != 0 { + zerolog.Ctx(ctx).Debug(). + Int("cutoff_count", cutoff). + Int("total_count", len(messages)). + Time("last_bridged_ts", lastMessage.Timestamp). + Msg("Cutting off forward backfill messages older than latest bridged message") + messages = messages[cutoff:] } - } - if cutoff != 0 { - zerolog.Ctx(ctx).Debug(). - Int("cutoff_count", cutoff). - Int("total_count", len(messages)). - Time("last_bridged_ts", lastMessage.Timestamp). - Msg("Cutting off forward backfill messages older than latest bridged message") - messages = messages[cutoff:] } } canBatchSend := portal.Bridge.Matrix.GetCapabilities().BatchSending From 7c9b8cb2877341191e82e7aa87be6f88dcf0b9d9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 12 Jul 2024 16:48:42 +0300 Subject: [PATCH 0439/1647] bridgev2/matrix: add support for appservice websockets --- bridgev2/matrix/connector.go | 31 +++++++++++--- bridgev2/matrix/mxmain/main.go | 25 ++++++++++-- .../matrix/{websocket.go.dis => websocket.go} | 40 +++++++++++-------- bridgev2/userlogin.go | 17 ++++++++ 4 files changed, 88 insertions(+), 25 deletions(-) rename bridgev2/matrix/{websocket.go.dis => websocket.go} (84%) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 04a7ea1f..b2e62d62 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -88,7 +88,9 @@ type Connector struct { wsStopped chan struct{} wsShortCircuitReconnectBackoff chan struct{} wsStartupWait *sync.WaitGroup - latestState *status.BridgeState + stopping bool + hasSentAnyStates bool + OnWebsocketReplaced func() } var ( @@ -152,7 +154,23 @@ func (br *Connector) Start(ctx context.Context) error { if err != nil { return bridgev2.DBUpgradeError{Section: "matrix_state", Err: err} } - go br.AS.Start() + if br.Config.Homeserver.Websocket || len(br.Config.Homeserver.WSProxy) > 0 { + br.Websocket = true + br.Log.Debug().Msg("Starting appservice websocket") + var wg sync.WaitGroup + wg.Add(1) + br.wsStartupWait = &wg + br.wsShortCircuitReconnectBackoff = make(chan struct{}) + go br.startWebsocket(&wg) + } else if br.AS.Host.IsConfigured() { + br.Log.Debug().Msg("Starting appservice HTTP server") + go br.AS.Start() + } else { + br.Log.WithLevel(zerolog.FatalLevel).Msg("Neither appservice HTTP listener nor websocket is enabled") + os.Exit(23) + } + + br.Log.Debug().Msg("Checking connection to homeserver") br.ensureConnection(ctx) go br.fetchMediaConfig(ctx) if br.Crypto != nil { @@ -171,6 +189,10 @@ func (br *Connector) Start(ctx context.Context) error { br.deterministicEventIDServer = parsed.Hostname() } br.AS.Ready = true + if br.Websocket && br.Config.Homeserver.WSPingInterval > 0 { + br.wsStopPinger = make(chan struct{}, 1) + go br.websocketServerPinger() + } return nil } @@ -193,6 +215,7 @@ func (br *Connector) GetCapabilities() *bridgev2.MatrixCapabilities { } func (br *Connector) Stop() { + br.stopping = true br.AS.Stop() br.EventProcessor.Stop() if br.Crypto != nil { @@ -370,9 +393,7 @@ func (br *Connector) GhostIntent(userID networkid.UserID) bridgev2.MatrixAPI { func (br *Connector) SendBridgeStatus(ctx context.Context, state *status.BridgeState) error { if br.Websocket { - // FIXME this doesn't account for multiple users - br.latestState = state - + br.hasSentAnyStates = true return br.AS.SendWebsocket(&appservice.WebsocketRequest{ Command: "bridge_status", Data: state, diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 4c25fb33..16a54b29 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -89,6 +89,8 @@ type BridgeMain struct { AdditionalShortFlags string AdditionalLongFlags string + + manualStop chan int } type VersionJSONOutput struct { @@ -115,14 +117,16 @@ func (br *BridgeMain) Run() { br.PreInit() br.Init() br.Start() - br.WaitForInterrupt() + exitCode := br.WaitForInterrupt() br.Stop() + os.Exit(exitCode) } // PreInit parses CLI flags and loads the config file. This is called by [Run] and does not need to be called manually. // // This also handles all flags that cause the bridge to exit immediately (e.g. `--version` and `--generate-registration`). func (br *BridgeMain) PreInit() { + br.manualStop = make(chan int, 1) flag.SetHelpTitles( fmt.Sprintf("%s - %s", br.Name, br.Description), fmt.Sprintf("%s [-hgvn%s] [-c ] [-r ]%s", br.Name, br.AdditionalShortFlags, br.AdditionalLongFlags)) @@ -231,6 +235,9 @@ func (br *BridgeMain) Init() { br.initDB() br.Matrix = matrix.NewConnector(br.Config) + br.Matrix.OnWebsocketReplaced = func() { + br.TriggerStop(0) + } br.Matrix.IgnoreUnsupportedServer = *ignoreUnsupportedServer br.Bridge = bridgev2.NewBridge("", br.DB, *br.Log, &br.Config.Bridge, br.Matrix, br.Connector, commands.NewProcessor) br.Matrix.AS.DoublePuppetValue = br.Name @@ -365,10 +372,22 @@ func (br *BridgeMain) Start() { } // WaitForInterrupt waits for a SIGINT or SIGTERM signal. -func (br *BridgeMain) WaitForInterrupt() { +func (br *BridgeMain) WaitForInterrupt() int { c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt, syscall.SIGTERM) - <-c + select { + case <-c: + return 0 + case exitCode := <-br.manualStop: + return exitCode + } +} + +func (br *BridgeMain) TriggerStop(exitCode int) { + select { + case br.manualStop <- exitCode: + default: + } } // Stop cleanly stops the bridge. This is called by [Run] and does not need to be called manually. diff --git a/bridgev2/matrix/websocket.go.dis b/bridgev2/matrix/websocket.go similarity index 84% rename from bridgev2/matrix/websocket.go.dis rename to bridgev2/matrix/websocket.go index cf4b0517..36b8bca4 100644 --- a/bridgev2/matrix/websocket.go.dis +++ b/bridgev2/matrix/websocket.go @@ -1,14 +1,19 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + package matrix import ( "context" "errors" "fmt" + "os" "sync" "time" - "go.mau.fi/util/jsontime" - "maunium.net/go/mautrix/appservice" ) @@ -20,20 +25,17 @@ func (br *Connector) startWebsocket(wg *sync.WaitGroup) { log := br.Log.With().Str("action", "appservice websocket").Logger() var wgOnce sync.Once onConnect := func() { - wssBr, ok := br.Child.(WebsocketStartingBridge) - if ok { - wssBr.OnWebsocketConnect() - } - if br.latestState != nil { + if br.hasSentAnyStates { go func() { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() - br.latestState.Timestamp = jsontime.UnixNow() - err := br.SendBridgeState(ctx, br.latestState) - if err != nil { - log.Err(err).Msg("Failed to resend latest bridge state after websocket reconnect") - } else { - log.Debug().Any("bridge_state", br.latestState).Msg("Resent bridge state after websocket reconnect") + for _, state := range br.Bridge.GetCurrentBridgeStates() { + err := br.SendBridgeStatus(ctx, &state) + if err != nil { + log.Err(err).Msg("Failed to resend latest bridge state after websocket reconnect") + } else { + log.Debug().Any("bridge_state", state).Msg("Resent bridge state after websocket reconnect") + } } }() } @@ -60,12 +62,16 @@ func (br *Connector) startWebsocket(wg *sync.WaitGroup) { return } else if closeCommand := (&appservice.CloseCommand{}); errors.As(err, &closeCommand) && closeCommand.Status == appservice.MeowConnectionReplaced { log.Info().Msg("Appservice websocket closed by another instance of the bridge, shutting down...") - br.ManualStop(0) + if br.OnWebsocketReplaced != nil { + br.OnWebsocketReplaced() + } else { + os.Exit(1) + } return } else if err != nil { log.Err(err).Msg("Error in appservice websocket") } - if br.Stopping { + if br.stopping { return } now := time.Now().UnixNano() @@ -86,7 +92,7 @@ func (br *Connector) startWebsocket(wg *sync.WaitGroup) { log.Debug().Msg("Reconnect backoff was short-circuited") case <-time.After(reconnectBackoff): } - if br.Stopping { + if br.stopping { return } } @@ -156,7 +162,7 @@ func (br *Connector) websocketServerPinger() { case <-br.wsStopPinger: return } - if br.Stopping { + if br.stopping { return } } diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 9e9f605b..46162ad8 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -126,6 +126,23 @@ func (br *Bridge) GetCachedUserLoginByID(id networkid.UserLoginID) *UserLogin { return br.userLoginsByID[id] } +func (br *Bridge) GetCurrentBridgeStates() (states []status.BridgeState) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + if len(br.userLoginsByID) == 0 { + return []status.BridgeState{{ + StateEvent: status.StateUnconfigured, + }} + } + states = make([]status.BridgeState, len(br.userLoginsByID)) + i := 0 + for _, login := range br.userLoginsByID { + states[i] = login.BridgeState.GetPrev() + i++ + } + return +} + type NewLoginParams struct { LoadUserLogin func(context.Context, *UserLogin) error DeleteOnConflict bool From 88f4da34334e7c9a5350ea2440c725b6066d5608 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 12 Jul 2024 17:31:15 +0300 Subject: [PATCH 0440/1647] bridgev2: add method to get or create management room --- bridgev2/user.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/bridgev2/user.go b/bridgev2/user.go index b9ea462b..323e8dbf 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -17,9 +17,11 @@ import ( "golang.org/x/exp/maps" "golang.org/x/exp/slices" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -35,6 +37,8 @@ type User struct { doublePuppetInitialized bool doublePuppetLock sync.Mutex + managementCreateLock sync.Mutex + logins map[networkid.UserLoginID]*UserLogin } @@ -188,6 +192,59 @@ func (user *User) GetDefaultLogin() *UserLogin { return user.logins[loginKeys[0]] } +func (user *User) GetManagementRoom(ctx context.Context) (id.RoomID, error) { + user.managementCreateLock.Lock() + defer user.managementCreateLock.Unlock() + if user.ManagementRoom != "" { + return user.ManagementRoom, nil + } + netName := user.Bridge.Network.GetName() + var err error + autoJoin := user.Bridge.Matrix.GetCapabilities().AutoJoinInvites + doublePuppet := user.DoublePuppet(ctx) + req := &mautrix.ReqCreateRoom{ + Visibility: "private", + Name: netName.DisplayName, + Topic: fmt.Sprintf("%s bridge management room", netName.DisplayName), + InitialState: []*event.Event{{ + Type: event.StateRoomAvatar, + Content: event.Content{ + Parsed: &event.RoomAvatarEventContent{ + URL: netName.NetworkIcon, + }, + }, + }}, + PowerLevelOverride: &event.PowerLevelsEventContent{ + Users: map[id.UserID]int{ + user.Bridge.Bot.GetMXID(): 9001, + user.MXID: 50, + }, + }, + Invite: []id.UserID{user.MXID}, + IsDirect: true, + } + if autoJoin { + req.BeeperInitialMembers = []id.UserID{user.MXID} + // TODO remove this after initial_members is supported in hungryserv + req.BeeperAutoJoinInvites = true + } + user.ManagementRoom, err = user.Bridge.Bot.CreateRoom(ctx, req) + if err != nil { + return "", fmt.Errorf("failed to create management room: %w", err) + } + if !autoJoin && doublePuppet != nil { + err = doublePuppet.EnsureJoined(ctx, user.ManagementRoom) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to auto-join created management room with double puppet") + } + } + err = user.Save(ctx) + if err != nil { + return "", fmt.Errorf("failed to save management room ID: %w", err) + } + return user.ManagementRoom, nil +} + func (user *User) Save(ctx context.Context) error { return user.Bridge.DB.User.Update(ctx, user.User) } From 98a842c075cc03ba718396b2860c6efd7f0e8a92 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 12 Jul 2024 19:42:11 +0300 Subject: [PATCH 0441/1647] bridge: register if /versions fails with M_FORBIDDEN --- bridge/bridge.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/bridge/bridge.go b/bridge/bridge.go index 40d4c615..6f608089 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -324,8 +324,16 @@ func (br *Bridge) ensureConnection(ctx context.Context) { for { versions, err := br.Bot.Versions(ctx) if err != nil { - br.ZLog.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...") - time.Sleep(10 * time.Second) + if errors.Is(err, mautrix.MForbidden) { + br.ZLog.Debug().Msg("M_FORBIDDEN in /versions, trying to register before retrying") + err = br.Bot.EnsureRegistered(ctx) + if err != nil { + br.ZLog.Err(err).Msg("Failed to register after /versions failed") + } + } else { + br.ZLog.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...") + time.Sleep(10 * time.Second) + } } else { br.SpecVersions = *versions *br.AS.SpecVersions = *versions From 85e0664cb441058513549a6e1e198f073bca9af4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 12 Jul 2024 19:43:09 +0300 Subject: [PATCH 0442/1647] bridgev2: register if /versions fails with M_FORBIDDEN --- bridgev2/matrix/connector.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index b2e62d62..43be55b4 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -229,8 +229,16 @@ func (br *Connector) ensureConnection(ctx context.Context) { for { versions, err := br.Bot.Versions(ctx) if err != nil { - br.Log.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...") - time.Sleep(10 * time.Second) + if errors.Is(err, mautrix.MForbidden) { + br.Log.Debug().Msg("M_FORBIDDEN in /versions, trying to register before retrying") + err = br.Bot.EnsureRegistered(ctx) + if err != nil { + br.Log.Err(err).Msg("Failed to register after /versions failed") + } + } else { + br.Log.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...") + time.Sleep(10 * time.Second) + } } else { br.SpecVersions = versions *br.AS.SpecVersions = *versions From 9fdf94132a3d5a1be18765bf863ff50b2cbed8df Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 13 Jul 2024 12:09:52 +0300 Subject: [PATCH 0443/1647] bridgev2/database: move standard metadata fields to columns, add typing for custom metadata --- bridgev2/bridge.go | 2 +- bridgev2/commands/relay.go | 2 +- bridgev2/database/database.go | 104 ++++++++++++++++-- bridgev2/database/disappear.go | 4 - bridgev2/database/ghost.go | 76 ++++++------- bridgev2/database/message.go | 77 ++++++------- bridgev2/database/portal.go | 76 ++++++------- bridgev2/database/reaction.go | 56 ++++------ bridgev2/database/upgrades/00-latest.sql | 61 +++++----- .../08-drop-message-relates-to.postgres.sql | 3 + .../08-drop-message-relates-to.sqlite.sql | 41 +++++++ .../upgrades/09-remove-standard-metadata.sql | 45 ++++++++ ...10-fix-signal-portal-revision.postgres.sql | 4 + .../10-fix-signal-portal-revision.sqlite.sql | 4 + bridgev2/database/user.go | 4 - bridgev2/database/userlogin.go | 64 ++++------- bridgev2/database/userportal.go | 4 - bridgev2/ghost.go | 56 ++-------- bridgev2/matrix/provisioning.go | 6 +- bridgev2/networkinterface.go | 20 +++- bridgev2/portal.go | 74 +++++++------ bridgev2/portalbackfill.go | 8 +- bridgev2/space.go | 4 +- bridgev2/user.go | 2 +- bridgev2/userlogin.go | 11 +- 25 files changed, 446 insertions(+), 362 deletions(-) create mode 100644 bridgev2/database/upgrades/08-drop-message-relates-to.postgres.sql create mode 100644 bridgev2/database/upgrades/08-drop-message-relates-to.sqlite.sql create mode 100644 bridgev2/database/upgrades/09-remove-standard-metadata.sql create mode 100644 bridgev2/database/upgrades/10-fix-signal-portal-revision.postgres.sql create mode 100644 bridgev2/database/upgrades/10-fix-signal-portal-revision.sqlite.sql diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 65d39c80..62dc532b 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -60,7 +60,7 @@ func NewBridge( ) *Bridge { br := &Bridge{ ID: bridgeID, - DB: database.New(bridgeID, db), + DB: database.New(bridgeID, network.GetDBMetaTypes(), db), Log: log, Matrix: matrix, diff --git a/bridgev2/commands/relay.go b/bridgev2/commands/relay.go index 9093a52c..e2d77a2d 100644 --- a/bridgev2/commands/relay.go +++ b/bridgev2/commands/relay.go @@ -80,7 +80,7 @@ func fnSetRelay(ce *Event) { } else { ce.Reply( "Messages sent by users who haven't logged in will now be relayed through %s ([%s](%s)'s login)", - relay.Metadata.RemoteName, + relay.RemoteName, relay.UserMXID, // TODO this will need to stop linkifying if we ever allow UserLogins that aren't bound to a real user. relay.UserMXID.URI().MatrixToURL(), diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go index c6d1e4eb..47858fba 100644 --- a/bridgev2/database/database.go +++ b/bridgev2/database/database.go @@ -34,19 +34,101 @@ type Database struct { UserPortal *UserPortalQuery } -func New(bridgeID networkid.BridgeID, db *dbutil.Database) *Database { +type MetaMerger interface { + CopyFrom(other any) +} + +type MetaTypeCreator func() any + +type MetaTypes struct { + Portal MetaTypeCreator + Ghost MetaTypeCreator + Message MetaTypeCreator + Reaction MetaTypeCreator + UserLogin MetaTypeCreator +} + +type blankMeta struct{} + +var blankMetaItem = &blankMeta{} + +func blankMetaCreator() any { + return blankMetaItem +} + +func New(bridgeID networkid.BridgeID, mt MetaTypes, db *dbutil.Database) *Database { + if mt.Portal == nil { + mt.Portal = blankMetaCreator + } + if mt.Ghost == nil { + mt.Ghost = blankMetaCreator + } + if mt.Message == nil { + mt.Message = blankMetaCreator + } + if mt.Reaction == nil { + mt.Reaction = blankMetaCreator + } + if mt.UserLogin == nil { + mt.UserLogin = blankMetaCreator + } db.UpgradeTable = upgrades.Table return &Database{ - Database: db, - BridgeID: bridgeID, - Portal: &PortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newPortal)}, - Ghost: &GhostQuery{bridgeID, dbutil.MakeQueryHelper(db, newGhost)}, - Message: &MessageQuery{bridgeID, dbutil.MakeQueryHelper(db, newMessage)}, - DisappearingMessage: &DisappearingMessageQuery{bridgeID, dbutil.MakeQueryHelper(db, newDisappearingMessage)}, - Reaction: &ReactionQuery{bridgeID, dbutil.MakeQueryHelper(db, newReaction)}, - User: &UserQuery{bridgeID, dbutil.MakeQueryHelper(db, newUser)}, - UserLogin: &UserLoginQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserLogin)}, - UserPortal: &UserPortalQuery{bridgeID, dbutil.MakeQueryHelper(db, newUserPortal)}, + Database: db, + BridgeID: bridgeID, + Portal: &PortalQuery{ + BridgeID: bridgeID, + MetaType: mt.Portal, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*Portal]) *Portal { + return (&Portal{}).ensureHasMetadata(mt.Portal) + }), + }, + Ghost: &GhostQuery{ + BridgeID: bridgeID, + MetaType: mt.Ghost, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*Ghost]) *Ghost { + return (&Ghost{}).ensureHasMetadata(mt.Ghost) + }), + }, + Message: &MessageQuery{ + BridgeID: bridgeID, + MetaType: mt.Message, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*Message]) *Message { + return (&Message{}).ensureHasMetadata(mt.Message) + }), + }, + DisappearingMessage: &DisappearingMessageQuery{ + BridgeID: bridgeID, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*DisappearingMessage]) *DisappearingMessage { + return &DisappearingMessage{} + }), + }, + Reaction: &ReactionQuery{ + BridgeID: bridgeID, + MetaType: mt.Reaction, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*Reaction]) *Reaction { + return (&Reaction{}).ensureHasMetadata(mt.Reaction) + }), + }, + User: &UserQuery{ + BridgeID: bridgeID, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*User]) *User { + return &User{} + }), + }, + UserLogin: &UserLoginQuery{ + BridgeID: bridgeID, + MetaType: mt.UserLogin, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*UserLogin]) *UserLogin { + return (&UserLogin{}).ensureHasMetadata(mt.UserLogin) + }), + }, + UserPortal: &UserPortalQuery{ + BridgeID: bridgeID, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*UserPortal]) *UserPortal { + return &UserPortal{} + }), + }, } } diff --git a/bridgev2/database/disappear.go b/bridgev2/database/disappear.go index 22a5be5c..23db1448 100644 --- a/bridgev2/database/disappear.go +++ b/bridgev2/database/disappear.go @@ -46,10 +46,6 @@ type DisappearingMessage struct { DisappearingSetting } -func newDisappearingMessage(_ *dbutil.QueryHelper[*DisappearingMessage]) *DisappearingMessage { - return &DisappearingMessage{} -} - const ( upsertDisappearingMessageQuery = ` INSERT INTO disappearing_message (bridge_id, mx_room, mxid, type, timer, disappear_at) diff --git a/bridgev2/database/ghost.go b/bridgev2/database/ghost.go index e6383c0e..916051a4 100644 --- a/bridgev2/database/ghost.go +++ b/bridgev2/database/ghost.go @@ -18,56 +18,43 @@ import ( type GhostQuery struct { BridgeID networkid.BridgeID + MetaType MetaTypeCreator *dbutil.QueryHelper[*Ghost] } -type StandardGhostMetadata struct { - IsBot bool `json:"is_bot,omitempty"` - Identifiers []string `json:"identifiers,omitempty"` - ContactInfoSet bool `json:"contact_info_set,omitempty"` -} - -type GhostMetadata struct { - StandardGhostMetadata - Extra map[string]any -} - -func (gm *GhostMetadata) UnmarshalJSON(data []byte) error { - return unmarshalMerge(data, &gm.StandardGhostMetadata, &gm.Extra) -} - -func (gm *GhostMetadata) MarshalJSON() ([]byte, error) { - return marshalMerge(&gm.StandardGhostMetadata, gm.Extra) -} - type Ghost struct { BridgeID networkid.BridgeID ID networkid.UserID - Name string - AvatarID networkid.AvatarID - AvatarHash [32]byte - AvatarMXC id.ContentURIString - NameSet bool - AvatarSet bool - Metadata GhostMetadata -} - -func newGhost(_ *dbutil.QueryHelper[*Ghost]) *Ghost { - return &Ghost{} + Name string + AvatarID networkid.AvatarID + AvatarHash [32]byte + AvatarMXC id.ContentURIString + NameSet bool + AvatarSet bool + ContactInfoSet bool + IsBot bool + Identifiers []string + Metadata any } const ( getGhostBaseQuery = ` - SELECT bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, name_set, avatar_set, metadata FROM ghost + SELECT bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, + name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata + FROM ghost ` getGhostByIDQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND id=$2` insertGhostQuery = ` - INSERT INTO ghost (bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, name_set, avatar_set, metadata) + INSERT INTO ghost ( + bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, + name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) ` updateGhostQuery = ` - UPDATE ghost SET name=$3, avatar_id=$4, avatar_hash=$5, avatar_mxc=$6, name_set=$7, avatar_set=$8, metadata=$9 + UPDATE ghost SET name=$3, avatar_id=$4, avatar_hash=$5, avatar_mxc=$6, + name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10, identifiers=$11, metadata=$12 WHERE bridge_id=$1 AND id=$2 ` ) @@ -78,12 +65,12 @@ func (gq *GhostQuery) GetByID(ctx context.Context, id networkid.UserID) (*Ghost, func (gq *GhostQuery) Insert(ctx context.Context, ghost *Ghost) error { ensureBridgeIDMatches(&ghost.BridgeID, gq.BridgeID) - return gq.Exec(ctx, insertGhostQuery, ghost.sqlVariables()...) + return gq.Exec(ctx, insertGhostQuery, ghost.ensureHasMetadata(gq.MetaType).sqlVariables()...) } func (gq *GhostQuery) Update(ctx context.Context, ghost *Ghost) error { ensureBridgeIDMatches(&ghost.BridgeID, gq.BridgeID) - return gq.Exec(ctx, updateGhostQuery, ghost.sqlVariables()...) + return gq.Exec(ctx, updateGhostQuery, ghost.ensureHasMetadata(gq.MetaType).sqlVariables()...) } func (g *Ghost) Scan(row dbutil.Scannable) (*Ghost, error) { @@ -91,14 +78,12 @@ func (g *Ghost) Scan(row dbutil.Scannable) (*Ghost, error) { err := row.Scan( &g.BridgeID, &g.ID, &g.Name, &g.AvatarID, &avatarHash, &g.AvatarMXC, - &g.NameSet, &g.AvatarSet, dbutil.JSON{Data: &g.Metadata}, + &g.NameSet, &g.AvatarSet, &g.ContactInfoSet, &g.IsBot, + dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata}, ) if err != nil { return nil, err } - if g.Metadata.Extra == nil { - g.Metadata.Extra = make(map[string]any) - } if avatarHash != "" { data, _ := hex.DecodeString(avatarHash) if len(data) == 32 { @@ -108,10 +93,14 @@ func (g *Ghost) Scan(row dbutil.Scannable) (*Ghost, error) { return g, nil } -func (g *Ghost) sqlVariables() []any { - if g.Metadata.Extra == nil { - g.Metadata.Extra = make(map[string]any) +func (g *Ghost) ensureHasMetadata(metaType MetaTypeCreator) *Ghost { + if g.Metadata == nil { + g.Metadata = metaType() } + return g +} + +func (g *Ghost) sqlVariables() []any { var avatarHash string if g.AvatarHash != [32]byte{} { avatarHash = hex.EncodeToString(g.AvatarHash[:]) @@ -119,6 +108,7 @@ func (g *Ghost) sqlVariables() []any { return []any{ g.BridgeID, g.ID, g.Name, g.AvatarID, avatarHash, g.AvatarMXC, - g.NameSet, g.AvatarSet, dbutil.JSON{Data: &g.Metadata}, + g.NameSet, g.AvatarSet, g.ContactInfoSet, g.IsBot, + dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata}, } } diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index d9f78b1e..504f91b2 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -19,27 +19,10 @@ import ( type MessageQuery struct { BridgeID networkid.BridgeID + MetaType MetaTypeCreator *dbutil.QueryHelper[*Message] } -type StandardMessageMetadata struct { - SenderMXID id.UserID `json:"sender_mxid,omitempty"` - EditCount int `json:"edit_count,omitempty"` -} - -type MessageMetadata struct { - StandardMessageMetadata - Extra map[string]any -} - -func (mm *MessageMetadata) UnmarshalJSON(data []byte) error { - return unmarshalMerge(data, &mm.StandardMessageMetadata, &mm.Extra) -} - -func (mm *MessageMetadata) MarshalJSON() ([]byte, error) { - return marshalMerge(&mm.StandardMessageMetadata, mm.Extra) -} - type Message struct { RowID int64 BridgeID networkid.BridgeID @@ -47,23 +30,23 @@ type Message struct { PartID networkid.PartID MXID id.EventID - Room networkid.PortalKey - SenderID networkid.UserID - Timestamp time.Time + Room networkid.PortalKey + SenderID networkid.UserID + SenderMXID id.UserID + Timestamp time.Time + EditCount int ThreadRoot networkid.MessageID ReplyTo networkid.MessageOptionalPartID - Metadata MessageMetadata -} - -func newMessage(_ *dbutil.QueryHelper[*Message]) *Message { - return &Message{} + Metadata any } const ( getMessageBaseQuery = ` - SELECT rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, thread_root_id, reply_to_id, reply_to_part_id, metadata FROM message + SELECT rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, sender_mxid, + timestamp, edit_count, thread_root_id, reply_to_id, reply_to_part_id, metadata + FROM message ` getAllMessagePartsByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3` getMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 AND part_id=$4` @@ -78,14 +61,17 @@ const ( getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1` insertMessageQuery = ` - INSERT INTO message (bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, thread_root_id, reply_to_id, reply_to_part_id, metadata) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + INSERT INTO message ( + bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, sender_mxid, + timestamp, edit_count, thread_root_id, reply_to_id, reply_to_part_id, metadata + ) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) RETURNING rowid ` updateMessageQuery = ` - UPDATE message SET id=$2, part_id=$3, mxid=$4, room_id=$5, room_receiver=$6, sender_id=$7, timestamp=$8, - thread_root_id=$9, reply_to_id=$10, reply_to_part_id=$11, metadata=$12 - WHERE bridge_id=$1 AND rowid=$13 + UPDATE message SET id=$2, part_id=$3, mxid=$4, room_id=$5, room_receiver=$6, sender_id=$7, sender_mxid=$8, + timestamp=$9, edit_count=$10, thread_root_id=$11, reply_to_id=$12, reply_to_part_id=$13, metadata=$14 + WHERE bridge_id=$1 AND rowid=$15 ` deleteAllMessagePartsByIDQuery = ` DELETE FROM message WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 @@ -145,12 +131,12 @@ func (mq *MessageQuery) GetLastThreadMessage(ctx context.Context, portal network func (mq *MessageQuery) Insert(ctx context.Context, msg *Message) error { ensureBridgeIDMatches(&msg.BridgeID, mq.BridgeID) - return mq.GetDB().QueryRow(ctx, insertMessageQuery, msg.sqlVariables()...).Scan(&msg.RowID) + return mq.GetDB().QueryRow(ctx, insertMessageQuery, msg.ensureHasMetadata(mq.MetaType).sqlVariables()...).Scan(&msg.RowID) } func (mq *MessageQuery) Update(ctx context.Context, msg *Message) error { ensureBridgeIDMatches(&msg.BridgeID, mq.BridgeID) - return mq.Exec(ctx, updateMessageQuery, msg.updateSQLVariables()...) + return mq.Exec(ctx, updateMessageQuery, msg.ensureHasMetadata(mq.MetaType).updateSQLVariables()...) } func (mq *MessageQuery) DeleteAllParts(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) error { @@ -165,15 +151,12 @@ func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { var timestamp int64 var threadRootID, replyToID, replyToPartID sql.NullString err := row.Scan( - &m.RowID, &m.BridgeID, &m.ID, &m.PartID, &m.MXID, &m.Room.ID, &m.Room.Receiver, &m.SenderID, - ×tamp, &threadRootID, &replyToID, &replyToPartID, dbutil.JSON{Data: &m.Metadata}, + &m.RowID, &m.BridgeID, &m.ID, &m.PartID, &m.MXID, &m.Room.ID, &m.Room.Receiver, &m.SenderID, &m.SenderMXID, + &m.EditCount, ×tamp, &threadRootID, &replyToID, &replyToPartID, dbutil.JSON{Data: m.Metadata}, ) if err != nil { return nil, err } - if m.Metadata.Extra == nil { - m.Metadata.Extra = make(map[string]any) - } m.Timestamp = time.Unix(0, timestamp) m.ThreadRoot = networkid.MessageID(threadRootID.String) if replyToID.Valid { @@ -185,14 +168,18 @@ func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { return m, nil } -func (m *Message) sqlVariables() []any { - if m.Metadata.Extra == nil { - m.Metadata.Extra = make(map[string]any) +func (m *Message) ensureHasMetadata(metaType MetaTypeCreator) *Message { + if m.Metadata == nil { + m.Metadata = metaType() } + return m +} + +func (m *Message) sqlVariables() []any { return []any{ - m.BridgeID, m.ID, m.PartID, m.MXID, m.Room.ID, m.Room.Receiver, m.SenderID, - m.Timestamp.UnixNano(), dbutil.StrPtr(m.ThreadRoot), dbutil.StrPtr(m.ReplyTo.MessageID), m.ReplyTo.PartID, - dbutil.JSON{Data: &m.Metadata}, + m.BridgeID, m.ID, m.PartID, m.MXID, m.Room.ID, m.Room.Receiver, m.SenderID, m.SenderMXID, + m.EditCount, m.Timestamp.UnixNano(), dbutil.StrPtr(m.ThreadRoot), dbutil.StrPtr(m.ReplyTo.MessageID), m.ReplyTo.PartID, + dbutil.JSON{Data: m.Metadata}, } } diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index ca277455..503d8a62 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -19,31 +19,21 @@ import ( "maunium.net/go/mautrix/id" ) +type RoomType string + +const ( + RoomTypeDefault RoomType = "" + RoomTypeDM RoomType = "dm" + RoomTypeGroupDM RoomType = "group_dm" + RoomTypeSpace RoomType = "space" +) + type PortalQuery struct { BridgeID networkid.BridgeID + MetaType MetaTypeCreator *dbutil.QueryHelper[*Portal] } -type StandardPortalMetadata struct { - DisappearType DisappearingType `json:"disappear_type,omitempty"` - DisappearTimer time.Duration `json:"disappear_timer,omitempty"` - IsDirect bool `json:"is_direct,omitempty"` - IsSpace bool `json:"is_space,omitempty"` -} - -type PortalMetadata struct { - StandardPortalMetadata - Extra map[string]any -} - -func (pm *PortalMetadata) UnmarshalJSON(data []byte) error { - return unmarshalMerge(data, &pm.StandardPortalMetadata, &pm.Extra) -} - -func (pm *PortalMetadata) MarshalJSON() ([]byte, error) { - return marshalMerge(&pm.StandardPortalMetadata, pm.Extra) -} - type Portal struct { BridgeID networkid.BridgeID networkid.PortalKey @@ -60,11 +50,9 @@ type Portal struct { TopicSet bool AvatarSet bool InSpace bool - Metadata PortalMetadata -} - -func newPortal(_ *dbutil.QueryHelper[*Portal]) *Portal { - return &Portal{} + RoomType RoomType + Disappear DisappearingSetting + Metadata any } const ( @@ -72,6 +60,7 @@ const ( SELECT bridge_id, id, receiver, mxid, parent_id, relay_login_id, name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, topic_set, avatar_set, in_space, + room_type, disappear_type, disappear_timer, metadata FROM portal ` @@ -88,9 +77,10 @@ const ( parent_id, relay_login_id, name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, avatar_set, topic_set, in_space, + room_type, disappear_type, disappear_timer, metadata, relay_bridge_id ) VALUES ( - $1, $2, $3, $4, $5, cast($6 AS TEXT), $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, + $1, $2, $3, $4, $5, cast($6 AS TEXT), $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, CASE WHEN cast($6 AS TEXT) IS NULL THEN NULL ELSE $1 END ) ` @@ -98,7 +88,8 @@ const ( UPDATE portal SET mxid=$4, parent_id=$5, relay_login_id=cast($6 AS TEXT), relay_bridge_id=CASE WHEN cast($6 AS TEXT) IS NULL THEN NULL ELSE bridge_id END, name=$7, topic=$8, avatar_id=$9, avatar_hash=$10, avatar_mxc=$11, - name_set=$12, avatar_set=$13, topic_set=$14, in_space=$15, metadata=$16 + name_set=$12, avatar_set=$13, topic_set=$14, in_space=$15, + room_type=$16, disappear_type=$17, disappear_timer=$18, metadata=$19 WHERE bridge_id=$1 AND id=$2 AND receiver=$3 ` deletePortalQuery = ` @@ -138,12 +129,12 @@ func (pq *PortalQuery) ReID(ctx context.Context, oldID, newID networkid.PortalKe func (pq *PortalQuery) Insert(ctx context.Context, p *Portal) error { ensureBridgeIDMatches(&p.BridgeID, pq.BridgeID) - return pq.Exec(ctx, insertPortalQuery, p.sqlVariables()...) + return pq.Exec(ctx, insertPortalQuery, p.ensureHasMetadata(pq.MetaType).sqlVariables()...) } func (pq *PortalQuery) Update(ctx context.Context, p *Portal) error { ensureBridgeIDMatches(&p.BridgeID, pq.BridgeID) - return pq.Exec(ctx, updatePortalQuery, p.sqlVariables()...) + return pq.Exec(ctx, updatePortalQuery, p.ensureHasMetadata(pq.MetaType).sqlVariables()...) } func (pq *PortalQuery) Delete(ctx context.Context, key networkid.PortalKey) error { @@ -151,36 +142,45 @@ func (pq *PortalQuery) Delete(ctx context.Context, key networkid.PortalKey) erro } func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { - var mxid, parentID, relayLoginID sql.NullString + var mxid, parentID, relayLoginID, disappearType sql.NullString + var disappearTimer sql.NullInt64 var avatarHash string err := row.Scan( &p.BridgeID, &p.ID, &p.Receiver, &mxid, &parentID, &relayLoginID, &p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC, &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.InSpace, - dbutil.JSON{Data: &p.Metadata}, + &p.RoomType, &disappearType, &disappearTimer, + dbutil.JSON{Data: p.Metadata}, ) if err != nil { return nil, err } - if p.Metadata.Extra == nil { - p.Metadata.Extra = make(map[string]any) - } if avatarHash != "" { data, _ := hex.DecodeString(avatarHash) if len(data) == 32 { p.AvatarHash = *(*[32]byte)(data) } } + if disappearType.Valid { + p.Disappear = DisappearingSetting{ + Type: DisappearingType(disappearType.String), + Timer: time.Duration(disappearTimer.Int64), + } + } p.MXID = id.RoomID(mxid.String) p.ParentID = networkid.PortalID(parentID.String) p.RelayLoginID = networkid.UserLoginID(relayLoginID.String) return p, nil } -func (p *Portal) sqlVariables() []any { - if p.Metadata.Extra == nil { - p.Metadata.Extra = make(map[string]any) +func (p *Portal) ensureHasMetadata(metaType MetaTypeCreator) *Portal { + if p.Metadata == nil { + p.Metadata = metaType() } + return p +} + +func (p *Portal) sqlVariables() []any { var avatarHash string if p.AvatarHash != [32]byte{} { avatarHash = hex.EncodeToString(p.AvatarHash[:]) @@ -190,6 +190,6 @@ func (p *Portal) sqlVariables() []any { dbutil.StrPtr(p.ParentID), dbutil.StrPtr(p.RelayLoginID), p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC, p.NameSet, p.TopicSet, p.AvatarSet, p.InSpace, - dbutil.JSON{Data: &p.Metadata}, + dbutil.JSON{Data: p.Metadata}, } } diff --git a/bridgev2/database/reaction.go b/bridgev2/database/reaction.go index 03e9f521..eaa6ecd6 100644 --- a/bridgev2/database/reaction.go +++ b/bridgev2/database/reaction.go @@ -18,26 +18,10 @@ import ( type ReactionQuery struct { BridgeID networkid.BridgeID + MetaType MetaTypeCreator *dbutil.QueryHelper[*Reaction] } -type StandardReactionMetadata struct { - Emoji string `json:"emoji,omitempty"` -} - -type ReactionMetadata struct { - StandardReactionMetadata - Extra map[string]any -} - -func (rm *ReactionMetadata) UnmarshalJSON(data []byte) error { - return unmarshalMerge(data, &rm.StandardReactionMetadata, &rm.Extra) -} - -func (rm *ReactionMetadata) MarshalJSON() ([]byte, error) { - return marshalMerge(&rm.StandardReactionMetadata, rm.Extra) -} - type Reaction struct { BridgeID networkid.BridgeID Room networkid.PortalKey @@ -48,16 +32,13 @@ type Reaction struct { MXID id.EventID Timestamp time.Time - Metadata ReactionMetadata -} - -func newReaction(_ *dbutil.QueryHelper[*Reaction]) *Reaction { - return &Reaction{} + Emoji string + Metadata any } const ( getReactionBaseQuery = ` - SELECT bridge_id, message_id, message_part_id, sender_id, emoji_id, room_id, room_receiver, mxid, timestamp, metadata FROM reaction + SELECT bridge_id, message_id, message_part_id, sender_id, emoji_id, emoji, room_id, room_receiver, mxid, timestamp, metadata FROM reaction ` getReactionByIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3 AND sender_id=$4 AND emoji_id=$5` getReactionByIDWithoutMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND sender_id=$3 AND emoji_id=$4 ORDER BY message_part_id ASC LIMIT 1` @@ -65,10 +46,10 @@ const ( getAllReactionsToMessageQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2` getReactionByMXIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` upsertReactionQuery = ` - INSERT INTO reaction (bridge_id, message_id, message_part_id, sender_id, emoji_id, room_id, room_receiver, mxid, timestamp, metadata) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + INSERT INTO reaction (bridge_id, message_id, message_part_id, sender_id, emoji_id, emoji, room_id, room_receiver, mxid, timestamp, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) ON CONFLICT (bridge_id, room_receiver, message_id, message_part_id, sender_id, emoji_id) - DO UPDATE SET mxid=excluded.mxid, timestamp=excluded.timestamp, metadata=excluded.metadata + DO UPDATE SET mxid=excluded.mxid, timestamp=excluded.timestamp, emoji=excluded.emoji, metadata=excluded.metadata ` deleteReactionQuery = ` DELETE FROM reaction WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3 AND sender_id=$4 AND emoji_id=$5 @@ -97,7 +78,7 @@ func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*React func (rq *ReactionQuery) Upsert(ctx context.Context, reaction *Reaction) error { ensureBridgeIDMatches(&reaction.BridgeID, rq.BridgeID) - return rq.Exec(ctx, upsertReactionQuery, reaction.sqlVariables()...) + return rq.Exec(ctx, upsertReactionQuery, reaction.ensureHasMetadata(rq.MetaType).sqlVariables()...) } func (rq *ReactionQuery) Delete(ctx context.Context, reaction *Reaction) error { @@ -108,25 +89,26 @@ func (rq *ReactionQuery) Delete(ctx context.Context, reaction *Reaction) error { func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) { var timestamp int64 err := row.Scan( - &r.BridgeID, &r.MessageID, &r.MessagePartID, &r.SenderID, &r.EmojiID, - &r.Room.ID, &r.Room.Receiver, &r.MXID, ×tamp, dbutil.JSON{Data: &r.Metadata}, + &r.BridgeID, &r.MessageID, &r.MessagePartID, &r.SenderID, &r.EmojiID, &r.Emoji, + &r.Room.ID, &r.Room.Receiver, &r.MXID, ×tamp, dbutil.JSON{Data: r.Metadata}, ) if err != nil { return nil, err } - if r.Metadata.Extra == nil { - r.Metadata.Extra = make(map[string]any) - } r.Timestamp = time.Unix(0, timestamp) return r, nil } -func (r *Reaction) sqlVariables() []any { - if r.Metadata.Extra == nil { - r.Metadata.Extra = make(map[string]any) +func (r *Reaction) ensureHasMetadata(metaType MetaTypeCreator) *Reaction { + if r.Metadata == nil { + r.Metadata = metaType() } + return r +} + +func (r *Reaction) sqlVariables() []any { return []any{ - r.BridgeID, r.MessageID, r.MessagePartID, r.SenderID, r.EmojiID, - r.Room.ID, r.Room.Receiver, r.MXID, r.Timestamp.UnixNano(), dbutil.JSON{Data: &r.Metadata}, + r.BridgeID, r.MessageID, r.MessagePartID, r.SenderID, r.EmojiID, r.Emoji, + r.Room.ID, r.Room.Receiver, r.MXID, r.Timestamp.UnixNano(), dbutil.JSON{Data: r.Metadata}, } } diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 303dcf8d..2f76e00a 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v7 (compatible with v1+): Latest revision +-- v0 -> v10 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -10,11 +10,12 @@ CREATE TABLE "user" ( ); CREATE TABLE user_login ( - bridge_id TEXT NOT NULL, - user_mxid TEXT NOT NULL, - id TEXT NOT NULL, - space_room TEXT, - metadata jsonb NOT NULL, + bridge_id TEXT NOT NULL, + user_mxid TEXT NOT NULL, + id TEXT NOT NULL, + remote_name TEXT NOT NULL, + space_room TEXT, + metadata jsonb NOT NULL, PRIMARY KEY (bridge_id, id), CONSTRAINT user_login_user_fkey FOREIGN KEY (bridge_id, user_mxid) @@ -45,6 +46,9 @@ CREATE TABLE portal ( avatar_set BOOLEAN NOT NULL, topic_set BOOLEAN NOT NULL, in_space BOOLEAN NOT NULL, + room_type TEXT NOT NULL, + disappear_type TEXT, + disappear_timer BIGINT, metadata jsonb NOT NULL, PRIMARY KEY (bridge_id, id, receiver), @@ -58,16 +62,19 @@ CREATE TABLE portal ( ); CREATE TABLE ghost ( - bridge_id TEXT NOT NULL, - id TEXT NOT NULL, + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, - name TEXT NOT NULL, - avatar_id TEXT NOT NULL, - avatar_hash TEXT NOT NULL, - avatar_mxc TEXT NOT NULL, - name_set BOOLEAN NOT NULL, - avatar_set BOOLEAN NOT NULL, - metadata jsonb NOT NULL, + name TEXT NOT NULL, + avatar_id TEXT NOT NULL, + avatar_hash TEXT NOT NULL, + avatar_mxc TEXT NOT NULL, + name_set BOOLEAN NOT NULL, + avatar_set BOOLEAN NOT NULL, + contact_info_set BOOLEAN NOT NULL, + is_bot BOOLEAN NOT NULL, + identifiers jsonb NOT NULL, + metadata jsonb NOT NULL, PRIMARY KEY (bridge_id, id) ); @@ -82,23 +89,22 @@ CREATE TABLE message ( -- only: postgres rowid BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, - bridge_id TEXT NOT NULL, - id TEXT NOT NULL, - part_id TEXT NOT NULL, - mxid TEXT NOT NULL, + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, + part_id TEXT NOT NULL, + mxid TEXT NOT NULL, - room_id TEXT NOT NULL, - room_receiver TEXT NOT NULL, - sender_id TEXT NOT NULL, - timestamp BIGINT NOT NULL, + room_id TEXT NOT NULL, + room_receiver TEXT NOT NULL, + sender_id TEXT NOT NULL, + sender_mxid TEXT NOT NULL, + timestamp BIGINT NOT NULL, + edit_count INTEGER NOT NULL, thread_root_id TEXT, reply_to_id TEXT, reply_to_part_id TEXT, - relates_to BIGINT, -- unused column, TODO: remove - metadata jsonb NOT NULL, + metadata jsonb NOT NULL, - CONSTRAINT message_relation_fkey FOREIGN KEY (relates_to) - REFERENCES message (rowid) ON DELETE SET NULL, CONSTRAINT message_room_fkey FOREIGN KEY (bridge_id, room_id, room_receiver) REFERENCES portal (bridge_id, id, receiver) ON DELETE CASCADE ON UPDATE CASCADE, @@ -130,6 +136,7 @@ CREATE TABLE reaction ( mxid TEXT NOT NULL, timestamp BIGINT NOT NULL, + emoji TEXT NOT NULL, metadata jsonb NOT NULL, PRIMARY KEY (bridge_id, room_receiver, message_id, message_part_id, sender_id, emoji_id), diff --git a/bridgev2/database/upgrades/08-drop-message-relates-to.postgres.sql b/bridgev2/database/upgrades/08-drop-message-relates-to.postgres.sql new file mode 100644 index 00000000..284f6b0e --- /dev/null +++ b/bridgev2/database/upgrades/08-drop-message-relates-to.postgres.sql @@ -0,0 +1,3 @@ +-- v8: Drop relates_to column in messages +-- transaction: off +ALTER TABLE message DROP COLUMN relates_to; diff --git a/bridgev2/database/upgrades/08-drop-message-relates-to.sqlite.sql b/bridgev2/database/upgrades/08-drop-message-relates-to.sqlite.sql new file mode 100644 index 00000000..307a876e --- /dev/null +++ b/bridgev2/database/upgrades/08-drop-message-relates-to.sqlite.sql @@ -0,0 +1,41 @@ +-- v8: Drop relates_to column in messages +-- transaction: off +PRAGMA foreign_keys = OFF; +BEGIN; + +CREATE TABLE message_new ( + rowid INTEGER PRIMARY KEY, + + bridge_id TEXT NOT NULL, + id TEXT NOT NULL, + part_id TEXT NOT NULL, + mxid TEXT NOT NULL, + + room_id TEXT NOT NULL, + room_receiver TEXT NOT NULL, + sender_id TEXT NOT NULL, + timestamp BIGINT NOT NULL, + thread_root_id TEXT, + reply_to_id TEXT, + reply_to_part_id TEXT, + metadata jsonb NOT NULL, + + CONSTRAINT message_room_fkey FOREIGN KEY (bridge_id, room_id, room_receiver) + REFERENCES portal (bridge_id, id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT message_sender_fkey FOREIGN KEY (bridge_id, sender_id) + REFERENCES ghost (bridge_id, id) + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT message_real_pkey UNIQUE (bridge_id, room_receiver, id, part_id) +); + +INSERT INTO message_new (rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, metadata) +SELECT rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, timestamp, metadata +FROM message; + +DROP TABLE message; +ALTER TABLE message_new RENAME TO message; + +PRAGMA foreign_key_check; +COMMIT; +PRAGMA foreign_keys = ON; diff --git a/bridgev2/database/upgrades/09-remove-standard-metadata.sql b/bridgev2/database/upgrades/09-remove-standard-metadata.sql new file mode 100644 index 00000000..3f348007 --- /dev/null +++ b/bridgev2/database/upgrades/09-remove-standard-metadata.sql @@ -0,0 +1,45 @@ +-- v9: Move standard metadata to separate columns +ALTER TABLE message ADD COLUMN sender_mxid TEXT NOT NULL DEFAULT ''; +UPDATE message SET sender_mxid=COALESCE((metadata->>'sender_mxid'), ''); + +ALTER TABLE message ADD COLUMN edit_count INTEGER NOT NULL DEFAULT 0; +UPDATE message SET edit_count=COALESCE(CAST((metadata->>'edit_count') AS INTEGER), 0); + +ALTER TABLE portal ADD COLUMN disappear_type TEXT; +UPDATE portal SET disappear_type=(metadata->>'disappear_type'); + +ALTER TABLE portal ADD COLUMN disappear_timer BIGINT; +-- only: postgres +UPDATE portal SET disappear_timer=(metadata->>'disappear_timer')::BIGINT; +-- only: sqlite +UPDATE portal SET disappear_timer=CAST(metadata->>'disappear_timer' AS INTEGER); + +ALTER TABLE portal ADD COLUMN room_type TEXT NOT NULL DEFAULT ''; +UPDATE portal SET room_type='dm' WHERE CAST(metadata->>'is_direct' AS BOOLEAN) IS true; +UPDATE portal SET room_type='space' WHERE CAST(metadata->>'is_space' AS BOOLEAN) IS true; + +ALTER TABLE reaction ADD COLUMN emoji TEXT NOT NULL DEFAULT ''; +UPDATE reaction SET emoji=COALESCE((metadata->>'emoji'), ''); + +ALTER TABLE user_login ADD COLUMN remote_name TEXT NOT NULL DEFAULT ''; +UPDATE user_login SET remote_name=COALESCE((metadata->>'remote_name'), ''); + +ALTER TABLE ghost ADD COLUMN contact_info_set BOOLEAN NOT NULL DEFAULT false; +UPDATE ghost SET contact_info_set=COALESCE(CAST((metadata->>'contact_info_set') AS BOOLEAN), false); + +ALTER TABLE ghost ADD COLUMN is_bot BOOLEAN NOT NULL DEFAULT false; +UPDATE ghost SET is_bot=COALESCE(CAST((metadata->>'is_bot') AS BOOLEAN), false); + +ALTER TABLE ghost ADD COLUMN identifiers jsonb NOT NULL DEFAULT '[]'; +UPDATE ghost SET identifiers=COALESCE((metadata->'identifiers'), '[]'); + +-- only: postgres until "end only" +ALTER TABLE message ALTER COLUMN sender_mxid DROP DEFAULT; +ALTER TABLE message ALTER COLUMN edit_count DROP DEFAULT; +ALTER TABLE portal ALTER COLUMN room_type DROP DEFAULT; +ALTER TABLE reaction ALTER COLUMN emoji DROP DEFAULT; +ALTER TABLE user_login ALTER COLUMN remote_name DROP DEFAULT; +ALTER TABLE ghost ALTER COLUMN contact_info_set DROP DEFAULT; +ALTER TABLE ghost ALTER COLUMN is_bot DROP DEFAULT; +ALTER TABLE ghost ALTER COLUMN identifiers DROP DEFAULT; +-- end only postgres diff --git a/bridgev2/database/upgrades/10-fix-signal-portal-revision.postgres.sql b/bridgev2/database/upgrades/10-fix-signal-portal-revision.postgres.sql new file mode 100644 index 00000000..f42402f3 --- /dev/null +++ b/bridgev2/database/upgrades/10-fix-signal-portal-revision.postgres.sql @@ -0,0 +1,4 @@ +-- v10 (compatible with v9+): Fix Signal portal revisions +UPDATE portal +SET metadata=jsonb_set(metadata, '{revision}', CAST((metadata->>'revision') AS jsonb)) +WHERE jsonb_typeof(metadata->'revision')='string'; diff --git a/bridgev2/database/upgrades/10-fix-signal-portal-revision.sqlite.sql b/bridgev2/database/upgrades/10-fix-signal-portal-revision.sqlite.sql new file mode 100644 index 00000000..0fd67c80 --- /dev/null +++ b/bridgev2/database/upgrades/10-fix-signal-portal-revision.sqlite.sql @@ -0,0 +1,4 @@ +-- v10 (compatible with v9+): Fix Signal portal revisions +UPDATE portal +SET metadata=json_set(metadata, '$.revision', CAST(json_extract(metadata, '$.revision') AS INTEGER)) +WHERE json_type(metadata, '$.revision')='text'; diff --git a/bridgev2/database/user.go b/bridgev2/database/user.go index de3b316a..00eae7ca 100644 --- a/bridgev2/database/user.go +++ b/bridgev2/database/user.go @@ -29,10 +29,6 @@ type User struct { AccessToken string } -func newUser(_ *dbutil.QueryHelper[*User]) *User { - return &User{} -} - const ( getUserBaseQuery = ` SELECT bridge_id, mxid, management_room, access_token FROM "user" diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go index cc92e7d4..d994d270 100644 --- a/bridgev2/database/userlogin.go +++ b/bridgev2/database/userlogin.go @@ -18,56 +18,37 @@ import ( type UserLoginQuery struct { BridgeID networkid.BridgeID + MetaType MetaTypeCreator *dbutil.QueryHelper[*UserLogin] } -type StandardUserLoginMetadata struct { - RemoteName string `json:"remote_name,omitempty"` -} - -type UserLoginMetadata struct { - StandardUserLoginMetadata - Extra map[string]any -} - -func (ulm *UserLoginMetadata) UnmarshalJSON(data []byte) error { - return unmarshalMerge(data, &ulm.StandardUserLoginMetadata, &ulm.Extra) -} - -func (ulm *UserLoginMetadata) MarshalJSON() ([]byte, error) { - return marshalMerge(&ulm.StandardUserLoginMetadata, ulm.Extra) -} - type UserLogin struct { - BridgeID networkid.BridgeID - UserMXID id.UserID - ID networkid.UserLoginID - SpaceRoom id.RoomID - Metadata UserLoginMetadata -} - -func newUserLogin(_ *dbutil.QueryHelper[*UserLogin]) *UserLogin { - return &UserLogin{} + BridgeID networkid.BridgeID + UserMXID id.UserID + ID networkid.UserLoginID + RemoteName string + SpaceRoom id.RoomID + Metadata any } const ( getUserLoginBaseQuery = ` - SELECT bridge_id, user_mxid, id, space_room, metadata FROM user_login + SELECT bridge_id, user_mxid, id, remote_name, space_room, metadata FROM user_login ` getLoginByIDQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND id=$2` getAllUsersWithLoginsQuery = `SELECT DISTINCT user_mxid FROM user_login WHERE bridge_id=$1` getAllLoginsForUserQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND user_mxid=$2` getAllLoginsInPortalQuery = ` - SELECT ul.bridge_id, ul.user_mxid, ul.id, ul.space_room, ul.metadata FROM user_portal + SELECT ul.bridge_id, ul.user_mxid, ul.id, ul.remote_name, ul.space_room, ul.metadata FROM user_portal LEFT JOIN user_login ul ON user_portal.bridge_id=ul.bridge_id AND user_portal.user_mxid=ul.user_mxid AND user_portal.login_id=ul.id WHERE user_portal.bridge_id=$1 AND user_portal.portal_id=$2 AND user_portal.portal_receiver=$3 ` insertUserLoginQuery = ` - INSERT INTO user_login (bridge_id, user_mxid, id, space_room, metadata) - VALUES ($1, $2, $3, $4, $5) + INSERT INTO user_login (bridge_id, user_mxid, id, remote_name, space_room, metadata) + VALUES ($1, $2, $3, $4, $5, $6) ` updateUserLoginQuery = ` - UPDATE user_login SET space_room=$4, metadata=$5 + UPDATE user_login SET remote_name=$4, space_room=$5, metadata=$6 WHERE bridge_id=$1 AND user_mxid=$2 AND id=$3 ` deleteUserLoginQuery = ` @@ -94,12 +75,12 @@ func (uq *UserLoginQuery) GetAllForUser(ctx context.Context, userID id.UserID) ( func (uq *UserLoginQuery) Insert(ctx context.Context, login *UserLogin) error { ensureBridgeIDMatches(&login.BridgeID, uq.BridgeID) - return uq.Exec(ctx, insertUserLoginQuery, login.sqlVariables()...) + return uq.Exec(ctx, insertUserLoginQuery, login.ensureHasMetadata(uq.MetaType).sqlVariables()...) } func (uq *UserLoginQuery) Update(ctx context.Context, login *UserLogin) error { ensureBridgeIDMatches(&login.BridgeID, uq.BridgeID) - return uq.Exec(ctx, updateUserLoginQuery, login.sqlVariables()...) + return uq.Exec(ctx, updateUserLoginQuery, login.ensureHasMetadata(uq.MetaType).sqlVariables()...) } func (uq *UserLoginQuery) Delete(ctx context.Context, loginID networkid.UserLoginID) error { @@ -108,20 +89,21 @@ func (uq *UserLoginQuery) Delete(ctx context.Context, loginID networkid.UserLogi func (u *UserLogin) Scan(row dbutil.Scannable) (*UserLogin, error) { var spaceRoom sql.NullString - err := row.Scan(&u.BridgeID, &u.UserMXID, &u.ID, &spaceRoom, dbutil.JSON{Data: &u.Metadata}) + err := row.Scan(&u.BridgeID, &u.UserMXID, &u.ID, &u.RemoteName, &spaceRoom, dbutil.JSON{Data: u.Metadata}) if err != nil { return nil, err } - if u.Metadata.Extra == nil { - u.Metadata.Extra = make(map[string]any) - } u.SpaceRoom = id.RoomID(spaceRoom.String) return u, nil } -func (u *UserLogin) sqlVariables() []any { - if u.Metadata.Extra == nil { - u.Metadata.Extra = make(map[string]any) +func (u *UserLogin) ensureHasMetadata(metaType MetaTypeCreator) *UserLogin { + if u.Metadata == nil { + u.Metadata = metaType() } - return []any{u.BridgeID, u.UserMXID, u.ID, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: &u.Metadata}} + return u +} + +func (u *UserLogin) sqlVariables() []any { + return []any{u.BridgeID, u.UserMXID, u.ID, u.RemoteName, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: u.Metadata}} } diff --git a/bridgev2/database/userportal.go b/bridgev2/database/userportal.go index 71235d2a..eeda4ba3 100644 --- a/bridgev2/database/userportal.go +++ b/bridgev2/database/userportal.go @@ -32,10 +32,6 @@ type UserPortal struct { LastRead time.Time } -func newUserPortal(_ *dbutil.QueryHelper[*UserPortal]) *UserPortal { - return &UserPortal{} -} - const ( getUserPortalBaseQuery = ` SELECT bridge_id, user_mxid, login_id, portal_id, portal_receiver, in_space, preferred, last_read diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index cf1a68d9..da1e81f0 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -14,7 +14,6 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/exmime" - "golang.org/x/exp/constraints" "golang.org/x/exp/slices" "maunium.net/go/mautrix/bridgev2/database" @@ -169,31 +168,31 @@ func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, if identifiers != nil { slices.Sort(identifiers) } - if ghost.Metadata.ContactInfoSet && - (identifiers == nil || slices.Equal(identifiers, ghost.Metadata.Identifiers)) && - (isBot == nil || *isBot == ghost.Metadata.IsBot) { + if ghost.ContactInfoSet && + (identifiers == nil || slices.Equal(identifiers, ghost.Identifiers)) && + (isBot == nil || *isBot == ghost.IsBot) { return false } if identifiers != nil { - ghost.Metadata.Identifiers = identifiers + ghost.Identifiers = identifiers } if isBot != nil { - ghost.Metadata.IsBot = *isBot + ghost.IsBot = *isBot } bridgeName := ghost.Bridge.Network.GetName() meta := &event.BeeperProfileExtra{ RemoteID: string(ghost.ID), - Identifiers: ghost.Metadata.Identifiers, + Identifiers: ghost.Identifiers, Service: bridgeName.BeeperBridgeType, Network: bridgeName.NetworkID, IsBridgeBot: false, - IsNetworkBot: ghost.Metadata.IsBot, + IsNetworkBot: ghost.IsBot, } err := ghost.Intent.SetExtraProfileMeta(ctx, meta) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to set extra profile metadata") } else { - ghost.Metadata.ContactInfoSet = true + ghost.ContactInfoSet = true } return true } @@ -222,45 +221,6 @@ func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin } } -func MergeExtraUpdaters[T any](funcs ...func(context.Context, T) bool) func(context.Context, T) bool { - return func(ctx context.Context, obj T) bool { - update := false - for _, f := range funcs { - update = f(ctx, obj) || update - } - return update - } -} - -func NumberMetadataUpdater[T *Ghost | *Portal, MetaType constraints.Integer | constraints.Float](key string, value MetaType) func(context.Context, T) bool { - return simpleMetadataUpdater[T, MetaType](key, value, database.GetNumberFromMap[MetaType]) -} - -func SimpleMetadataUpdater[T *Ghost | *Portal, MetaType comparable](key string, value MetaType) func(context.Context, T) bool { - return simpleMetadataUpdater[T, MetaType](key, value, func(m map[string]any, key string) (MetaType, bool) { - val, ok := m[key].(MetaType) - return val, ok - }) -} - -func simpleMetadataUpdater[T *Ghost | *Portal, MetaType comparable](key string, value MetaType, getter func(map[string]any, string) (MetaType, bool)) func(context.Context, T) bool { - return func(ctx context.Context, obj T) bool { - var meta map[string]any - switch typedObj := any(obj).(type) { - case *Ghost: - meta = typedObj.Metadata.Extra - case *Portal: - meta = typedObj.Metadata.Extra - } - currentVal, ok := getter(meta, key) - if ok && currentVal == value { - return false - } - meta[key] = value - return true - } -} - func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) { update := false if info.Name != nil { diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 5e0b2ca7..197670d4 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -409,7 +409,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. } apiResp.Name = resp.Ghost.Name apiResp.AvatarURL = resp.Ghost.AvatarMXC - apiResp.Identifiers = resp.Ghost.Metadata.Identifiers + apiResp.Identifiers = resp.Ghost.Identifiers apiResp.MXID = resp.Ghost.Intent.GetMXID() } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { apiResp.Name = *resp.UserInfo.Name @@ -490,8 +490,8 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque if contact.Ghost.Name != "" { apiContact.Name = contact.Ghost.Name } - if len(contact.Ghost.Metadata.Identifiers) >= len(apiContact.Identifiers) { - apiContact.Identifiers = contact.Ghost.Metadata.Identifiers + if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) { + apiContact.Identifiers = contact.Ghost.Identifiers } apiContact.AvatarURL = contact.Ghost.AvatarMXC apiContact.MXID = contact.Ghost.Intent.GetMXID() diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index f28f2250..6ee84e4a 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -16,7 +16,6 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/configupgrade" "go.mau.fi/util/ptr" - "golang.org/x/exp/maps" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -30,7 +29,7 @@ type ConvertedMessagePart struct { Type event.Type Content *event.MessageEventContent Extra map[string]any - DBMetadata map[string]any + DBMetadata any } func (cmp *ConvertedMessagePart) ToEditPart(part *database.Message) *ConvertedEditPart { @@ -38,7 +37,12 @@ func (cmp *ConvertedMessagePart) ToEditPart(part *database.Message) *ConvertedEd return nil } if cmp.DBMetadata != nil { - maps.Copy(part.Metadata.Extra, cmp.DBMetadata) + merger, ok := part.Metadata.(database.MetaMerger) + if ok { + merger.CopyFrom(cmp.DBMetadata) + } else { + part.Metadata = cmp.DBMetadata + } } return &ConvertedEditPart{ Part: part, @@ -167,6 +171,10 @@ type NetworkConnector interface { // The output can still be adjusted based on config variables, but the function must have // default values when called without a config. GetName() BridgeName + // GetDBMetaTypes returns struct types that are used to store connector-specific metadata in various tables. + // All fields are optional. If a field isn't provided, then the corresponding table will have no custom metadata. + // This will be called before Init, it should have a hardcoded response. + GetDBMetaTypes() database.MetaTypes // GetCapabilities returns the general capabilities of the network connector. // Note that most capabilities are scoped to rooms and are returned by [NetworkAPI.GetCapabilities] instead. GetCapabilities() *NetworkGeneralCapabilities @@ -720,7 +728,7 @@ type RemoteReactionWithExtraContent interface { type RemoteReactionWithMeta interface { RemoteReaction - GetReactionDBMetadata() map[string]any + GetReactionDBMetadata() any } type RemoteReactionRemove interface { @@ -780,7 +788,7 @@ type SimpleRemoteEvent[T any] struct { TargetMessage networkid.MessageID EmojiID networkid.EmojiID Emoji string - ReactionDBMeta map[string]any + ReactionDBMeta any Timestamp time.Time ChatInfoChange *ChatInfoChange @@ -842,7 +850,7 @@ func (sre *SimpleRemoteEvent[T]) GetRemovedEmojiID() networkid.EmojiID { return sre.EmojiID } -func (sre *SimpleRemoteEvent[T]) GetReactionDBMetadata() map[string]any { +func (sre *SimpleRemoteEvent[T]) GetReactionDBMetadata() any { return sre.ReactionDBMeta } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index ddd8862b..c5054022 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -667,7 +667,9 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin message.ThreadRoot = threadRoot.ThreadRoot } } - message.Metadata.SenderMXID = evt.Sender + if message.SenderMXID == "" { + message.SenderMXID = evt.Sender + } // Hack to ensure the ghost row exists // TODO move to better place (like login) portal.Bridge.GetGhostByID(ctx, message.SenderID) @@ -675,14 +677,14 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin if err != nil { log.Err(err).Msg("Failed to save message to database") } - if portal.Metadata.DisappearType != database.DisappearingTypeNone { + if portal.Disappear.Type != database.DisappearingTypeNone { go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ RoomID: portal.MXID, EventID: message.MXID, DisappearingSetting: database.DisappearingSetting{ - Type: portal.Metadata.DisappearType, - Timer: portal.Metadata.DisappearTimer, - DisappearAt: message.Timestamp.Add(portal.Metadata.DisappearTimer), + Type: portal.Disappear.Type, + Timer: portal.Disappear.Timer, + DisappearAt: message.Timestamp.Add(portal.Disappear.Timer), }, }) } @@ -732,7 +734,7 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o } else if caps.EditMaxAge > 0 && time.Since(editTarget.Timestamp) > caps.EditMaxAge { portal.sendErrorStatus(ctx, evt, ErrEditTargetTooOld) return - } else if caps.EditMaxCount > 0 && editTarget.Metadata.EditCount >= caps.EditMaxCount { + } else if caps.EditMaxCount > 0 && editTarget.EditCount >= caps.EditMaxCount { portal.sendErrorStatus(ctx, evt, ErrEditTargetTooManyEdits) return } @@ -809,7 +811,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi log.Err(err).Msg("Failed to check if reaction is a duplicate") return } else if existing != nil { - if existing.EmojiID != "" || existing.Metadata.Emoji == preResp.Emoji { + if existing.EmojiID != "" || existing.Emoji == preResp.Emoji { log.Debug().Msg("Ignoring duplicate reaction") portal.sendSuccessStatus(ctx, evt) return @@ -876,8 +878,8 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi dbReaction.Timestamp = time.UnixMilli(evt.Timestamp) } if preResp.EmojiID == "" && dbReaction.EmojiID == "" { - if dbReaction.Metadata.Emoji == "" { - dbReaction.Metadata.Emoji = preResp.Emoji + if dbReaction.Emoji == "" { + dbReaction.Emoji = preResp.Emoji } } else if dbReaction.EmojiID == "" { dbReaction.EmojiID = preResp.EmojiID @@ -1212,7 +1214,7 @@ func (portal *Portal) applyRelationMeta(content *event.MessageEventContent, repl if content.Mentions == nil { content.Mentions = &event.Mentions{} } - content.Mentions.Add(replyTo.Metadata.SenderMXID) + content.Mentions.Add(replyTo.SenderMXID) } } @@ -1245,12 +1247,12 @@ func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.Mes MXID: resp.EventID, Room: portal.PortalKey, SenderID: sender.Sender, + SenderMXID: intent.GetMXID(), Timestamp: ts, ThreadRoot: ptr.Val(converted.ThreadRoot), ReplyTo: ptr.Val(converted.ReplyTo), + Metadata: part.DBMetadata, } - dbMessage.Metadata.SenderMXID = intent.GetMXID() - dbMessage.Metadata.Extra = part.DBMetadata err = portal.Bridge.DB.Message.Insert(ctx, dbMessage) if err != nil { log.Err(err).Str("part_id", string(part.ID)).Msg("Failed to save message part to database") @@ -1419,7 +1421,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi if err != nil { log.Err(err).Msg("Failed to check if reaction is a duplicate") return - } else if existingReaction != nil && (emojiID != "" || existingReaction.Metadata.Emoji == emoji) { + } else if existingReaction != nil && (emojiID != "" || existingReaction.Emoji == emoji) { log.Debug().Msg("Ignoring duplicate reaction") return } @@ -1456,10 +1458,10 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi Timestamp: ts, } if metaProvider, ok := evt.(RemoteReactionWithMeta); ok { - dbReaction.Metadata.Extra = metaProvider.GetReactionDBMetadata() + dbReaction.Metadata = metaProvider.GetReactionDBMetadata() } if emojiID == "" { - dbReaction.Metadata.Emoji = emoji + dbReaction.Emoji = emoji } err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) if err != nil { @@ -1771,10 +1773,9 @@ type ChatInfo struct { Members *ChatMemberList JoinRule *event.JoinRulesEventContent - IsDirectChat *bool - IsSpace *bool - Disappear *database.DisappearingSetting - ParentID *networkid.PortalID + Type *database.RoomType + Disappear *database.DisappearingSetting + ParentID *networkid.PortalID UserLocal *UserLocalPortalInfo @@ -1835,7 +1836,7 @@ func (portal *Portal) UpdateAvatar(ctx context.Context, avatar *Avatar, sender M func (portal *Portal) GetTopLevelParent() *Portal { if portal.Parent == nil { - if !portal.Metadata.IsSpace { + if portal.RoomType != database.RoomTypeSpace { return nil } return portal @@ -1854,12 +1855,7 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { AvatarURL: portal.AvatarMXC, // TODO external URL? }, - } - if portal.Metadata.IsDirect { - // TODO group dm type? - bridgeInfo.BeeperRoomType = "dm" - } else if portal.Metadata.IsSpace { - bridgeInfo.BeeperRoomType = "space" + BeeperRoomType: string(portal.RoomType), } parent := portal.GetTopLevelParent() if parent != nil { @@ -2131,11 +2127,11 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat if setting.Timer == 0 { setting.Type = "" } - if portal.Metadata.DisappearTimer == setting.Timer && portal.Metadata.DisappearType == setting.Type { + if portal.Disappear.Timer == setting.Timer && portal.Disappear.Type == setting.Type { return false } - portal.Metadata.DisappearType = setting.Type - portal.Metadata.DisappearTimer = setting.Timer + portal.Disappear.Type = setting.Type + portal.Disappear.Timer = setting.Timer if save { err := portal.Save(ctx) if err != nil { @@ -2161,7 +2157,7 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat zerolog.Ctx(ctx).Err(err).Msg("Failed to send disappearing messages notice") } else { zerolog.Ctx(ctx).Debug(). - Dur("new_timer", portal.Metadata.DisappearTimer). + Dur("new_timer", portal.Disappear.Timer). Bool("implicit", implicit). Msg("Sent disappearing messages notice") } @@ -2230,9 +2226,16 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us } // TODO detect changes to functional members list? } - if info.IsDirectChat != nil && portal.Metadata.IsDirect != *info.IsDirectChat { - changed = true - portal.Metadata.IsDirect = *info.IsDirectChat + if info.Type != nil && portal.RoomType != *info.Type { + if portal.MXID != "" && (*info.Type == database.RoomTypeSpace || portal.RoomType == database.RoomTypeSpace) { + zerolog.Ctx(ctx).Warn(). + Str("current_type", string(portal.RoomType)). + Str("target_type", string(*info.Type)). + Msg("Tried to change existing room type from/to space") + } else { + changed = true + portal.RoomType = *info.Type + } } if source != nil { source.MarkInPortal(ctx, portal) @@ -2296,7 +2299,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i CreationContent: make(map[string]any), InitialState: make([]*event.Event, 0, 6), Preset: "private_chat", - IsDirect: portal.Metadata.IsDirect, + IsDirect: portal.RoomType == database.RoomTypeDM, PowerLevelOverride: powerLevels, BeeperLocalRoomID: id.RoomID(fmt.Sprintf("!%s:%s", portal.ID, portal.Bridge.Matrix.ServerName())), } @@ -2307,9 +2310,8 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i req.BeeperAutoJoinInvites = true req.Invite = initialMembers } - if info.IsSpace != nil && *info.IsSpace { + if portal.RoomType == database.RoomTypeSpace { req.CreationContent["type"] = event.RoomTypeSpace - portal.Metadata.IsSpace = true } bridgeInfoStateKey, bridgeInfo := portal.getBridgeInfo() diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index f1ec3110..f6828c32 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -139,15 +139,11 @@ func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages MXID: evtID, Room: portal.PortalKey, SenderID: msg.Sender.Sender, + SenderMXID: intent.GetMXID(), Timestamp: msg.Timestamp, ThreadRoot: ptr.Val(msg.ThreadRoot), ReplyTo: ptr.Val(msg.ReplyTo), - Metadata: database.MessageMetadata{ - StandardMessageMetadata: database.StandardMessageMetadata{ - SenderMXID: intent.GetMXID(), - }, - Extra: part.DBMetadata, - }, + Metadata: part.DBMetadata, }) if prevThreadEvent != nil { prevThreadEvent.MXID = evtID diff --git a/bridgev2/space.go b/bridgev2/space.go index 41ef3c2b..17388f3e 100644 --- a/bridgev2/space.go +++ b/bridgev2/space.go @@ -145,8 +145,8 @@ func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) { doublePuppet := ul.User.DoublePuppet(ctx) req := &mautrix.ReqCreateRoom{ Visibility: "private", - Name: fmt.Sprintf("%s (%s)", netName.DisplayName, ul.Metadata.RemoteName), - Topic: fmt.Sprintf("Your %s bridged chats - %s", netName.DisplayName, ul.Metadata.RemoteName), + Name: fmt.Sprintf("%s (%s)", netName.DisplayName, ul.RemoteName), + Topic: fmt.Sprintf("Your %s bridged chats - %s", netName.DisplayName, ul.RemoteName), InitialState: []*event.Event{{ Type: event.StateRoomAvatar, Content: event.Content{ diff --git a/bridgev2/user.go b/bridgev2/user.go index 323e8dbf..7dc9959a 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -175,7 +175,7 @@ func (user *User) GetFormattedUserLogins() string { user.Bridge.cacheLock.Lock() logins := make([]string, len(user.logins)) for key, val := range user.logins { - logins = append(logins, fmt.Sprintf("* `%s` (%s)", key, val.Metadata.RemoteName)) + logins = append(logins, fmt.Sprintf("* `%s` (%s)", key, val.RemoteName)) } user.Bridge.cacheLock.Unlock() return strings.Join(logins, "\n") diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 46162ad8..9f271639 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -9,7 +9,6 @@ package bridgev2 import ( "context" "fmt" - "maps" "sync" "time" @@ -188,8 +187,12 @@ func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params return nil, fmt.Errorf("login already exists") } doInsert = false - ul.Metadata.RemoteName = data.Metadata.RemoteName - maps.Copy(ul.Metadata.Extra, data.Metadata.Extra) + ul.RemoteName = data.RemoteName + if merger, ok := ul.Metadata.(database.MetaMerger); ok { + merger.CopyFrom(data.Metadata) + } else { + ul.Metadata = data.Metadata + } } else { doInsert = true ul = &UserLogin{ @@ -291,7 +294,7 @@ func (ul *UserLogin) GetRemoteID() string { } func (ul *UserLogin) GetRemoteName() string { - return ul.Metadata.RemoteName + return ul.RemoteName } func (ul *UserLogin) Disconnect(done func()) { From 3a6249bf081e87c245f244d68ae8072faba87ad8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 13 Jul 2024 16:45:02 +0300 Subject: [PATCH 0444/1647] dependencies: update go-util --- crypto/sql_store.go | 4 ++-- crypto/sql_store_upgrade/upgrade.go | 2 +- go.mod | 2 +- go.sum | 4 ++-- sqlstatestore/v05-mark-encryption-state-resync.go | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 0b71e36d..d93ee0ca 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -77,8 +77,8 @@ func (store *SQLCryptoStore) PutNextBatch(ctx context.Context, nextBatch string) // GetNextBatch retrieves the next sync batch token for the current account. func (store *SQLCryptoStore) GetNextBatch(ctx context.Context) (string, error) { if store.SyncToken == "" { - err := store.DB.Conn(ctx). - QueryRowContext(ctx, "SELECT sync_token FROM crypto_account WHERE account_id=$1", store.AccountID). + err := store.DB. + QueryRow(ctx, "SELECT sync_token FROM crypto_account WHERE account_id=$1", store.AccountID). Scan(&store.SyncToken) if !errors.Is(err, sql.ErrNoRows) { return "", err diff --git a/crypto/sql_store_upgrade/upgrade.go b/crypto/sql_store_upgrade/upgrade.go index 08c995da..10c0c0c0 100644 --- a/crypto/sql_store_upgrade/upgrade.go +++ b/crypto/sql_store_upgrade/upgrade.go @@ -22,7 +22,7 @@ const VersionTableName = "crypto_version" var fs embed.FS func init() { - Table.Register(-1, 3, 0, "Unsupported version", false, func(ctx context.Context, database *dbutil.Database) error { + Table.Register(-1, 3, 0, "Unsupported version", dbutil.TxnModeOff, func(ctx context.Context, database *dbutil.Database) error { return fmt.Errorf("upgrading from versions 1 and 2 of the crypto store is no longer supported in mautrix-go v0.12+") }) Table.RegisterFS(fs) diff --git a/go.mod b/go.mod index 2610717d..4748b10d 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.5.1-0.20240702075351-577617730cb7 + go.mau.fi/util v0.5.1-0.20240713134429-03648b3ede41 go.mau.fi/zeroconfig v0.1.2 golang.org/x/crypto v0.24.0 golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 diff --git a/go.sum b/go.sum index b3fd55a7..192cc79a 100644 --- a/go.sum +++ b/go.sum @@ -46,8 +46,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.5.1-0.20240702075351-577617730cb7 h1:1avw60QZMpzzMMisf6Jqm+WSycZ59OHJA5IlSXHCCPE= -go.mau.fi/util v0.5.1-0.20240702075351-577617730cb7/go.mod h1:DsJzUrJAG53lCZnnYvq9/mOyLuPScWwYhvETiTrpdP4= +go.mau.fi/util v0.5.1-0.20240713134429-03648b3ede41 h1:suJqVZoWuiqmMo/xojAGSxz04fOYYu0oE7sFPrf2L5c= +go.mau.fi/util v0.5.1-0.20240713134429-03648b3ede41/go.mod h1:DsJzUrJAG53lCZnnYvq9/mOyLuPScWwYhvETiTrpdP4= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= diff --git a/sqlstatestore/v05-mark-encryption-state-resync.go b/sqlstatestore/v05-mark-encryption-state-resync.go index bf44d308..b7f2b1c2 100644 --- a/sqlstatestore/v05-mark-encryption-state-resync.go +++ b/sqlstatestore/v05-mark-encryption-state-resync.go @@ -8,7 +8,7 @@ import ( ) func init() { - UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", true, func(ctx context.Context, db *dbutil.Database) error { + UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", dbutil.TxnModeOn, func(ctx context.Context, db *dbutil.Database) error { portalExists, err := db.TableExists(ctx, "portal") if err != nil { return fmt.Errorf("failed to check if portal table exists") From 51aad3c0d728be92600ef75741108821c7fd9d23 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 13 Jul 2024 19:59:18 +0300 Subject: [PATCH 0445/1647] bridgev2/database: add indexes for some foreign keys --- bridgev2/database/upgrades/00-latest.sql | 6 +++++- bridgev2/database/upgrades/11-room-fkey-idx.sql | 5 +++++ 2 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 bridgev2/database/upgrades/11-room-fkey-idx.sql diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 2f76e00a..2c170c3e 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v10 (compatible with v9+): Latest revision +-- v0 -> v11 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -113,6 +113,7 @@ CREATE TABLE message ( ON DELETE CASCADE ON UPDATE CASCADE, CONSTRAINT message_real_pkey UNIQUE (bridge_id, room_receiver, id, part_id) ); +CREATE INDEX message_room_idx ON message (bridge_id, room_id, room_receiver); CREATE TABLE disappearing_message ( bridge_id TEXT NOT NULL, @@ -150,6 +151,7 @@ CREATE TABLE reaction ( REFERENCES ghost (bridge_id, id) ON DELETE CASCADE ON UPDATE CASCADE ); +CREATE INDEX reaction_room_idx ON reaction (bridge_id, room_id, room_receiver); CREATE TABLE user_portal ( bridge_id TEXT NOT NULL, @@ -169,3 +171,5 @@ CREATE TABLE user_portal ( REFERENCES portal (bridge_id, id, receiver) ON DELETE CASCADE ON UPDATE CASCADE ); +CREATE INDEX user_portal_login_idx ON user_portal (bridge_id, login_id); +CREATE INDEX user_portal_portal_idx ON user_portal (bridge_id, portal_id, portal_receiver); diff --git a/bridgev2/database/upgrades/11-room-fkey-idx.sql b/bridgev2/database/upgrades/11-room-fkey-idx.sql new file mode 100644 index 00000000..d6a67713 --- /dev/null +++ b/bridgev2/database/upgrades/11-room-fkey-idx.sql @@ -0,0 +1,5 @@ +-- v11: Add indexes for some foreign keys +CREATE INDEX message_room_idx ON message (bridge_id, room_id, room_receiver); +CREATE INDEX reaction_room_idx ON reaction (bridge_id, room_id, room_receiver); +CREATE INDEX user_portal_portal_idx ON user_portal (bridge_id, portal_id, portal_receiver); +CREATE INDEX user_portal_login_idx ON user_portal (bridge_id, login_id); From c6da49328327d0eb173e3a2542c149b0c549db36 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 13 Jul 2024 19:59:29 +0300 Subject: [PATCH 0446/1647] event: ignore calls to Mentions.Add with empty user ID --- event/message.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/event/message.go b/event/message.go index 6ce405d6..003f1fcc 100644 --- a/event/message.go +++ b/event/message.go @@ -214,7 +214,7 @@ type Mentions struct { } func (m *Mentions) Add(userID id.UserID) { - if !slices.Contains(m.UserIDs, userID) { + if userID != "" && !slices.Contains(m.UserIDs, userID) { m.UserIDs = append(m.UserIDs, userID) } } From d1905f62321512ae578427243bb694cfb108b76a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 14 Jul 2024 11:06:19 +0300 Subject: [PATCH 0447/1647] bridgev2: rename some uses of ID to Key in reference to portal keys --- bridgev2/commands/startchat.go | 2 +- bridgev2/database/portal.go | 6 +++--- bridgev2/matrix/provisioning.go | 4 ++-- bridgev2/networkinterface.go | 4 ++-- bridgev2/portal.go | 32 ++++++++++++++++---------------- bridgev2/portalreid.go | 4 ++-- bridgev2/queue.go | 2 +- 7 files changed, 27 insertions(+), 27 deletions(-) diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index 275b2051..9ad5f77c 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -101,7 +101,7 @@ func fnResolveIdentifier(ce *Event) { } portal := resp.Chat.Portal if portal == nil { - portal, err = ce.Bridge.GetPortalByID(ce.Ctx, resp.Chat.PortalID) + portal, err = ce.Bridge.GetPortalByKey(ce.Ctx, resp.Chat.PortalKey) if err != nil { ce.Log.Err(err).Msg("Failed to get portal") ce.Reply("Failed to get portal: %v", err) diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 503d8a62..417035f0 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -64,7 +64,7 @@ const ( metadata FROM portal ` - getPortalByIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND receiver=$3` + getPortalByKeyQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND receiver=$3` getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')` getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2` @@ -99,8 +99,8 @@ const ( reIDPortalQuery = `UPDATE portal SET id=$4, receiver=$5 WHERE bridge_id=$1 AND id=$2 AND receiver=$3` ) -func (pq *PortalQuery) GetByID(ctx context.Context, key networkid.PortalKey) (*Portal, error) { - return pq.QueryOne(ctx, getPortalByIDQuery, pq.BridgeID, key.ID, key.Receiver) +func (pq *PortalQuery) GetByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { + return pq.QueryOne(ctx, getPortalByKeyQuery, pq.BridgeID, key.ID, key.Receiver) } func (pq *PortalQuery) FindReceiver(ctx context.Context, id networkid.PortalID, maybeReceiver networkid.UserLoginID) (key networkid.PortalKey, err error) { diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 197670d4..c0eb16e4 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -416,7 +416,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. } if resp.Chat != nil { if resp.Chat.Portal == nil { - resp.Chat.Portal, err = prov.br.Bridge.GetPortalByID(r.Context(), resp.Chat.PortalID) + resp.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(r.Context(), resp.Chat.PortalKey) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal") jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ @@ -498,7 +498,7 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque } if contact.Chat != nil { if contact.Chat.Portal == nil { - contact.Chat.Portal, err = prov.br.Bridge.GetPortalByID(r.Context(), contact.Chat.PortalID) + contact.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(r.Context(), contact.Chat.PortalKey) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal") jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 6ee84e4a..43842bdd 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -504,8 +504,8 @@ type ResolveIdentifierResponse struct { } type CreateChatResponse struct { - PortalID networkid.PortalKey - // Portal and PortalInfo are not required, the caller will fetch them automatically based on PortalID if necessary. + PortalKey networkid.PortalKey + // Portal and PortalInfo are not required, the caller will fetch them automatically based on PortalKey if necessary. Portal *Portal PortalInfo *ChatInfo } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c5054022..81b65ad2 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -95,7 +95,7 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que } var err error if portal.ParentID != "" { - portal.Parent, err = br.UnlockedGetPortalByID(ctx, networkid.PortalKey{ID: portal.ParentID}, false) + portal.Parent, err = br.UnlockedGetPortalByKey(ctx, networkid.PortalKey{ID: portal.ParentID}, false) if err != nil { return nil, fmt.Errorf("failed to load parent portal (%s): %w", portal.ParentID, err) } @@ -119,17 +119,17 @@ func (portal *Portal) updateLogger() { portal.Log = logWith.Logger() } -func (br *Bridge) UnlockedGetPortalByID(ctx context.Context, id networkid.PortalKey, onlyIfExists bool) (*Portal, error) { - cached, ok := br.portalsByKey[id] +func (br *Bridge) UnlockedGetPortalByKey(ctx context.Context, key networkid.PortalKey, onlyIfExists bool) (*Portal, error) { + cached, ok := br.portalsByKey[key] if ok { return cached, nil } - idPtr := &id + keyPtr := &key if onlyIfExists { - idPtr = nil + keyPtr = nil } - db, err := br.DB.Portal.GetByID(ctx, id) - return br.loadPortal(ctx, db, err, idPtr) + db, err := br.DB.Portal.GetByKey(ctx, key) + return br.loadPortal(ctx, db, err, keyPtr) } func (br *Bridge) FindPortalReceiver(ctx context.Context, id networkid.PortalID, maybeReceiver networkid.UserLoginID) (networkid.PortalKey, error) { @@ -172,27 +172,27 @@ func (br *Bridge) GetPortalByMXID(ctx context.Context, mxid id.RoomID) (*Portal, return br.loadPortal(ctx, db, err, nil) } -func (br *Bridge) GetPortalByID(ctx context.Context, id networkid.PortalKey) (*Portal, error) { +func (br *Bridge) GetPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() - return br.UnlockedGetPortalByID(ctx, id, false) + return br.UnlockedGetPortalByKey(ctx, key, false) } -func (br *Bridge) GetExistingPortalByID(ctx context.Context, id networkid.PortalKey) (*Portal, error) { +func (br *Bridge) GetExistingPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() - if id.Receiver == "" { - return br.UnlockedGetPortalByID(ctx, id, true) + if key.Receiver == "" { + return br.UnlockedGetPortalByKey(ctx, key, true) } - cached, ok := br.portalsByKey[id] + cached, ok := br.portalsByKey[key] if ok { return cached, nil } - cached, ok = br.portalsByKey[networkid.PortalKey{ID: id.ID}] + cached, ok = br.portalsByKey[networkid.PortalKey{ID: key.ID}] if ok { return cached, nil } - db, err := br.DB.Portal.GetByIDWithUncertainReceiver(ctx, id) + db, err := br.DB.Portal.GetByIDWithUncertainReceiver(ctx, key) return br.loadPortal(ctx, db, err, nil) } @@ -2178,7 +2178,7 @@ func (portal *Portal) UpdateParent(ctx context.Context, newParent networkid.Port portal.ParentID = newParent portal.InSpace = false if newParent != "" { - portal.Parent, err = portal.Bridge.GetPortalByID(ctx, networkid.PortalKey{ID: newParent}) + portal.Parent, err = portal.Bridge.GetPortalByKey(ctx, networkid.PortalKey{ID: newParent}) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get new parent portal") } diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go index c4f7a69b..a25fe820 100644 --- a/bridgev2/portalreid.go +++ b/bridgev2/portalreid.go @@ -39,7 +39,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta }() br.cacheLock.Lock() defer br.cacheLock.Unlock() - sourcePortal, err := br.UnlockedGetPortalByID(ctx, source, true) + sourcePortal, err := br.UnlockedGetPortalByKey(ctx, source, true) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to get source portal: %w", err) } else if sourcePortal == nil { @@ -59,7 +59,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Stringer("source_portal_mxid", sourcePortal.MXID) }) - targetPortal, err := br.UnlockedGetPortalByID(ctx, target, true) + targetPortal, err := br.UnlockedGetPortalByKey(ctx, target, true) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to get target portal: %w", err) } diff --git a/bridgev2/queue.go b/bridgev2/queue.go index ec60cbb8..6254fd62 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -139,7 +139,7 @@ func (br *Bridge) handleBotInvite(ctx context.Context, evt *event.Event, sender func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { log := login.Log ctx := log.WithContext(context.TODO()) - portal, err := br.GetPortalByID(ctx, evt.GetPortalKey()) + portal, err := br.GetPortalByKey(ctx, evt.GetPortalKey()) if err != nil { log.Err(err).Object("portal_id", evt.GetPortalKey()). Msg("Failed to get portal to handle remote event") From ffceb93f0f047624e4e1062441e1ed67e38a9d3b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 14 Jul 2024 11:24:57 +0300 Subject: [PATCH 0448/1647] changelog: update --- CHANGELOG.md | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b8213ae..30fff36d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,29 @@ ## v0.19.0 (unreleased) +* *(bridgev2)* Added more features. * *(crypto)* Fixed bug with copying `m.relates_to` from wire content to decrypted content. -* *(bridgev2)* Added more features. * *(mediaproxy)* Added module for implementing simple media repos that proxy requests elsewhere. +* *(client)* Changed `Members()` to automatically parse event content for all + returned events. +* *(bridge)* Added `/register` call if `/versions` fails with `M_FORBIDDEN`. +* *(crypto)* Fixed `DecryptMegolmEvent` sometimes calling database without + transaction by using the non-context version of `ResolveTrust`. +* *(crypto/attachment)* Implemented `io.Seeker` in `EncryptStream` to allow + using it in retriable HTTP requests. +* *(event)* Added helper method to add user ID to a `Mentions` object. +* *(event)* Fixed default power level for invites + (thanks to [@rudis] in [#250]). +* *(client)* Fixed incorrect warning log in `State()` when state store returns + no error (thanks to [@rudis] in [#249]). +* *(crypto/verificationhelper)* Fixed deadlock when ignoring unknown + cancellation events (thanks to [@rudis] in [#247]). + +[@rudis]: https://github.com/rudis +[#250]: https://github.com/mautrix/go/pull/250 +[#249]: https://github.com/mautrix/go/pull/249 +[#247]: https://github.com/mautrix/go/pull/247 ### beta.1 (2024-06-16) From 921f8fdfc48dac517ec28933915cd8e04896e1c1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 14 Jul 2024 11:28:03 +0300 Subject: [PATCH 0449/1647] main: rename master branch to main --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 30fff36d..9629a997 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ ## v0.19.0 (unreleased) +* Renamed `master` branch to `main`. * *(bridgev2)* Added more features. * *(crypto)* Fixed bug with copying `m.relates_to` from wire content to decrypted content. From edf1a8d8d02255a1915fee10b10903edda28afa1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 14 Jul 2024 14:45:57 +0300 Subject: [PATCH 0450/1647] bridge2/database: fix bugs in metadata move --- bridgev2/database/ghost.go | 2 +- bridgev2/database/message.go | 2 +- bridgev2/database/portal.go | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/bridgev2/database/ghost.go b/bridgev2/database/ghost.go index 916051a4..c4c626f0 100644 --- a/bridgev2/database/ghost.go +++ b/bridgev2/database/ghost.go @@ -50,7 +50,7 @@ const ( bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) ` updateGhostQuery = ` UPDATE ghost SET name=$3, avatar_id=$4, avatar_hash=$5, avatar_mxc=$6, diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 504f91b2..19cc7bf7 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -178,7 +178,7 @@ func (m *Message) ensureHasMetadata(metaType MetaTypeCreator) *Message { func (m *Message) sqlVariables() []any { return []any{ m.BridgeID, m.ID, m.PartID, m.MXID, m.Room.ID, m.Room.Receiver, m.SenderID, m.SenderMXID, - m.EditCount, m.Timestamp.UnixNano(), dbutil.StrPtr(m.ThreadRoot), dbutil.StrPtr(m.ReplyTo.MessageID), m.ReplyTo.PartID, + m.Timestamp.UnixNano(), m.EditCount, dbutil.StrPtr(m.ThreadRoot), dbutil.StrPtr(m.ReplyTo.MessageID), m.ReplyTo.PartID, dbutil.JSON{Data: m.Metadata}, } } diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 417035f0..b27e973f 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -190,6 +190,7 @@ func (p *Portal) sqlVariables() []any { dbutil.StrPtr(p.ParentID), dbutil.StrPtr(p.RelayLoginID), p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC, p.NameSet, p.TopicSet, p.AvatarSet, p.InSpace, + p.RoomType, dbutil.StrPtr(p.Disappear.Type), dbutil.NumPtr(p.Disappear.Timer), dbutil.JSON{Data: p.Metadata}, } } From fb9fb5ae44056636f340e9cacc6f11a55fca0ead Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 14 Jul 2024 15:10:51 +0300 Subject: [PATCH 0451/1647] bridgev2: add method for getting all portals with Matrix room --- bridgev2/database/portal.go | 5 +++++ bridgev2/portal.go | 27 +++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index b27e973f..77bb7e81 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -67,6 +67,7 @@ const ( getPortalByKeyQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND receiver=$3` getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')` getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` + getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL` getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2` findPortalReceiverQuery = `SELECT id, receiver FROM portal WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='') LIMIT 1` @@ -119,6 +120,10 @@ func (pq *PortalQuery) GetByMXID(ctx context.Context, mxid id.RoomID) (*Portal, return pq.QueryOne(ctx, getPortalByMXIDQuery, pq.BridgeID, mxid) } +func (pq *PortalQuery) GetAllWithMXID(ctx context.Context) ([]*Portal, error) { + return pq.QueryMany(ctx, getAllPortalsWithMXIDQuery, pq.BridgeID) +} + func (pq *PortalQuery) GetChildren(ctx context.Context, parentID networkid.PortalID) ([]*Portal, error) { return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentID) } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 81b65ad2..58322310 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -119,6 +119,23 @@ func (portal *Portal) updateLogger() { portal.Log = logWith.Logger() } +func (br *Bridge) loadManyPortals(ctx context.Context, portals []*database.Portal) ([]*Portal, error) { + output := make([]*Portal, 0, len(portals)) + for _, dbPortal := range portals { + if cached, ok := br.portalsByKey[dbPortal.PortalKey]; ok { + output = append(output, cached) + } else { + loaded, err := br.loadPortal(ctx, dbPortal, nil, nil) + if err != nil { + return nil, err + } else if loaded != nil { + output = append(output, loaded) + } + } + } + return output, nil +} + func (br *Bridge) UnlockedGetPortalByKey(ctx context.Context, key networkid.PortalKey, onlyIfExists bool) (*Portal, error) { cached, ok := br.portalsByKey[key] if ok { @@ -172,6 +189,16 @@ func (br *Bridge) GetPortalByMXID(ctx context.Context, mxid id.RoomID) (*Portal, return br.loadPortal(ctx, db, err, nil) } +func (br *Bridge) GetAllPortalsWithMXID(ctx context.Context) ([]*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + rows, err := br.DB.Portal.GetAllWithMXID(ctx) + if err != nil { + return nil, err + } + return br.loadManyPortals(ctx, rows) +} + func (br *Bridge) GetPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() From cb850e3f029371961b43b34f5a9e99c07d36a800 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 15 Jul 2024 15:35:57 +0300 Subject: [PATCH 0452/1647] dependencies: update --- go.mod | 12 ++++++------ go.sum | 24 ++++++++++++------------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/go.mod b/go.mod index 4748b10d..e83e162e 100644 --- a/go.mod +++ b/go.mod @@ -15,11 +15,11 @@ require ( github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.5.1-0.20240713134429-03648b3ede41 - go.mau.fi/zeroconfig v0.1.2 - golang.org/x/crypto v0.24.0 - golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 - golang.org/x/net v0.26.0 + go.mau.fi/util v0.5.1-0.20240714204302-8d7c8742a899 + go.mau.fi/zeroconfig v0.1.3 + golang.org/x/crypto v0.25.0 + golang.org/x/exp v0.0.0-20240707233637-46b078467d37 + golang.org/x/net v0.27.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) @@ -32,6 +32,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.21.0 // indirect + golang.org/x/sys v0.22.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 192cc79a..b30d9d01 100644 --- a/go.sum +++ b/go.sum @@ -46,22 +46,22 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.5.1-0.20240713134429-03648b3ede41 h1:suJqVZoWuiqmMo/xojAGSxz04fOYYu0oE7sFPrf2L5c= -go.mau.fi/util v0.5.1-0.20240713134429-03648b3ede41/go.mod h1:DsJzUrJAG53lCZnnYvq9/mOyLuPScWwYhvETiTrpdP4= -go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= -go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= -golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY= -golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +go.mau.fi/util v0.5.1-0.20240714204302-8d7c8742a899 h1:6/4XgDIvH2/4+aQ1WADo7UOmQCiHjx7wd0jjezew7JE= +go.mau.fi/util v0.5.1-0.20240714204302-8d7c8742a899/go.mod h1:DsJzUrJAG53lCZnnYvq9/mOyLuPScWwYhvETiTrpdP4= +go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= +go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/exp v0.0.0-20240707233637-46b078467d37 h1:uLDX+AfeFCct3a2C7uIWBKMJIR3CJMhcgfrUAqjRK6w= +golang.org/x/exp v0.0.0-20240707233637-46b078467d37/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= +golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= +golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From ccb40ff7b496074342dad4f0c17a7652a9c842dd Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Jul 2024 11:13:11 +0300 Subject: [PATCH 0453/1647] Bump version to v0.19.0 --- CHANGELOG.md | 2 +- go.mod | 2 +- go.sum | 4 ++-- version.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9629a997..feaedcd1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -## v0.19.0 (unreleased) +## v0.19.0 (2024-07-16) * Renamed `master` branch to `main`. * *(bridgev2)* Added more features. diff --git a/go.mod b/go.mod index e83e162e..4724d937 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.5.1-0.20240714204302-8d7c8742a899 + go.mau.fi/util v0.6.0 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.25.0 golang.org/x/exp v0.0.0-20240707233637-46b078467d37 diff --git a/go.sum b/go.sum index b30d9d01..5c4d19ec 100644 --- a/go.sum +++ b/go.sum @@ -46,8 +46,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.5.1-0.20240714204302-8d7c8742a899 h1:6/4XgDIvH2/4+aQ1WADo7UOmQCiHjx7wd0jjezew7JE= -go.mau.fi/util v0.5.1-0.20240714204302-8d7c8742a899/go.mod h1:DsJzUrJAG53lCZnnYvq9/mOyLuPScWwYhvETiTrpdP4= +go.mau.fi/util v0.6.0 h1:W6SyB3Bm/GjenQ5iq8Z8WWdN85Gy2xS6L0wmnR7SVjg= +go.mau.fi/util v0.6.0/go.mod h1:ljYdq3sPfpICc3zMU+/mHV/sa4z0nKxc67hSBwnrk8U= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= diff --git a/version.go b/version.go index e00141ae..d98634ec 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.18.1" +const Version = "v0.19.0" var GoModVersion = "" var Commit = "" From a3406120713a9a4c69d859aa85fb976b32594679 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Jul 2024 16:29:34 +0300 Subject: [PATCH 0454/1647] bridgev2: save other user ID in DM portals --- bridgev2/database/portal.go | 21 +++++++------ bridgev2/database/upgrades/00-latest.sql | 4 ++- .../upgrades/12-dm-portal-other-user.sql | 2 ++ bridgev2/portal.go | 30 +++++++++++++++++++ 4 files changed, 47 insertions(+), 10 deletions(-) create mode 100644 bridgev2/database/upgrades/12-dm-portal-other-user.sql diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 77bb7e81..af0f6a9f 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -41,6 +41,7 @@ type Portal struct { ParentID networkid.PortalID RelayLoginID networkid.UserLoginID + OtherUserID networkid.UserID Name string Topic string AvatarID networkid.AvatarID @@ -57,7 +58,7 @@ type Portal struct { const ( getPortalBaseQuery = ` - SELECT bridge_id, id, receiver, mxid, parent_id, relay_login_id, + SELECT bridge_id, id, receiver, mxid, parent_id, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, topic_set, avatar_set, in_space, room_type, disappear_type, disappear_timer, @@ -75,22 +76,22 @@ const ( insertPortalQuery = ` INSERT INTO portal ( bridge_id, id, receiver, mxid, - parent_id, relay_login_id, + parent_id, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, avatar_set, topic_set, in_space, room_type, disappear_type, disappear_timer, metadata, relay_bridge_id ) VALUES ( - $1, $2, $3, $4, $5, cast($6 AS TEXT), $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, + $1, $2, $3, $4, $5, cast($6 AS TEXT), $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, CASE WHEN cast($6 AS TEXT) IS NULL THEN NULL ELSE $1 END ) ` updatePortalQuery = ` UPDATE portal SET mxid=$4, parent_id=$5, relay_login_id=cast($6 AS TEXT), relay_bridge_id=CASE WHEN cast($6 AS TEXT) IS NULL THEN NULL ELSE bridge_id END, - name=$7, topic=$8, avatar_id=$9, avatar_hash=$10, avatar_mxc=$11, - name_set=$12, avatar_set=$13, topic_set=$14, in_space=$15, - room_type=$16, disappear_type=$17, disappear_timer=$18, metadata=$19 + other_user_id=$7, name=$8, topic=$9, avatar_id=$10, avatar_hash=$11, avatar_mxc=$12, + name_set=$13, avatar_set=$14, topic_set=$15, in_space=$16, + room_type=$17, disappear_type=$18, disappear_timer=$19, metadata=$20 WHERE bridge_id=$1 AND id=$2 AND receiver=$3 ` deletePortalQuery = ` @@ -147,12 +148,13 @@ func (pq *PortalQuery) Delete(ctx context.Context, key networkid.PortalKey) erro } func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { - var mxid, parentID, relayLoginID, disappearType sql.NullString + var mxid, parentID, relayLoginID, otherUserID, disappearType sql.NullString var disappearTimer sql.NullInt64 var avatarHash string err := row.Scan( &p.BridgeID, &p.ID, &p.Receiver, &mxid, - &parentID, &relayLoginID, &p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC, + &parentID, &relayLoginID, &otherUserID, + &p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC, &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.InSpace, &p.RoomType, &disappearType, &disappearTimer, dbutil.JSON{Data: p.Metadata}, @@ -173,6 +175,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { } } p.MXID = id.RoomID(mxid.String) + p.OtherUserID = networkid.UserID(otherUserID.String) p.ParentID = networkid.PortalID(parentID.String) p.RelayLoginID = networkid.UserLoginID(relayLoginID.String) return p, nil @@ -192,7 +195,7 @@ func (p *Portal) sqlVariables() []any { } return []any{ p.BridgeID, p.ID, p.Receiver, dbutil.StrPtr(p.MXID), - dbutil.StrPtr(p.ParentID), dbutil.StrPtr(p.RelayLoginID), + dbutil.StrPtr(p.ParentID), dbutil.StrPtr(p.RelayLoginID), dbutil.StrPtr(p.OtherUserID), p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC, p.NameSet, p.TopicSet, p.AvatarSet, p.InSpace, p.RoomType, dbutil.StrPtr(p.Disappear.Type), dbutil.NumPtr(p.Disappear.Timer), diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 2c170c3e..943f7a59 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v11 (compatible with v9+): Latest revision +-- v0 -> v12 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -37,6 +37,8 @@ CREATE TABLE portal ( relay_bridge_id TEXT, relay_login_id TEXT, + other_user_id TEXT, + name TEXT NOT NULL, topic TEXT NOT NULL, avatar_id TEXT NOT NULL, diff --git a/bridgev2/database/upgrades/12-dm-portal-other-user.sql b/bridgev2/database/upgrades/12-dm-portal-other-user.sql new file mode 100644 index 00000000..60ae4e3a --- /dev/null +++ b/bridgev2/database/upgrades/12-dm-portal-other-user.sql @@ -0,0 +1,2 @@ +-- v12: Save other user ID in DM portals +ALTER TABLE portal ADD COLUMN other_user_id TEXT; diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 58322310..cbc5f8bd 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1734,6 +1734,10 @@ type ChatMemberList struct { // The total number of members in the chat, regardless of how many of those members are included in Members. TotalMemberCount int + // For DM portals, the ID of the recipient user. + // This field is optional and will be automatically filled from Members if there are only 2 entries in the list. + OtherUserID networkid.UserID + Members []ChatMember PowerLevels *PowerLevelChanges } @@ -1970,9 +1974,34 @@ func (portal *Portal) GetInitialMemberList(ctx context.Context, members *ChatMem pl.EnsureUserLevel(intent.GetMXID(), member.PowerLevel) } } + portal.updateOtherUser(ctx, members) return } +func (portal *Portal) updateOtherUser(ctx context.Context, members *ChatMemberList) (changed bool) { + var expectedUserID networkid.UserID + if portal.RoomType != database.RoomTypeDM { + // expected user ID is empty + } else if members.OtherUserID != "" { + expectedUserID = members.OtherUserID + } else if len(members.Members) == 2 && members.IsFull { + if members.Members[0].IsFromMe && !members.Members[1].IsFromMe { + expectedUserID = members.Members[1].Sender + } else if members.Members[1].IsFromMe && !members.Members[0].IsFromMe { + expectedUserID = members.Members[0].Sender + } + } + if portal.OtherUserID != expectedUserID { + zerolog.Ctx(ctx).Debug(). + Str("old_other_user_id", string(portal.OtherUserID)). + Str("new_other_user_id", string(expectedUserID)). + Msg("Updating other user ID in DM portal") + portal.OtherUserID = expectedUserID + return true + } + return false +} + func (portal *Portal) SyncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error { var loginsInPortal []*UserLogin var err error @@ -2101,6 +2130,7 @@ func (portal *Portal) SyncParticipants(ctx context.Context, members *ChatMemberL log.Err(err).Msg("Failed to update power levels") } } + portal.updateOtherUser(ctx, members) if members.IsFull { for extraMember, memberEvt := range currentMembers { if memberEvt.Membership == event.MembershipLeave || memberEvt.Membership == event.MembershipBan { From c24fd786af9b6dd235e8c052974a58970bbba211 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Jul 2024 16:36:29 +0300 Subject: [PATCH 0455/1647] bridgev2: add basic support for backfilling threads --- bridgev2/networkinterface.go | 3 +++ bridgev2/portalbackfill.go | 32 +++++++++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 43842bdd..432041b2 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -367,6 +367,9 @@ type BackfillMessage struct { ID networkid.MessageID Timestamp time.Time Reactions []*BackfillReaction + + ShouldBackfillThread bool + LastThreadMessage networkid.MessageID } // FetchMessagesResponse contains the response for a message history pagination request. diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index f6828c32..33c2fc6d 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -40,7 +40,7 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, ThreadRoot: "", Forward: true, AnchorMessage: lastMessage, - Count: 100, + Count: 100, // TODO make count configurable }) if err != nil { log.Err(err).Msg("Failed to fetch messages for forward backfill") @@ -70,6 +70,31 @@ func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin //portal.sendBackfill(ctx, source, resp.Messages, false, resp.MarkRead, lastMessage) } +func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, threadID networkid.MessageID) { + log := zerolog.Ctx(ctx).With(). + Str("subaction", "thread backfill"). + Str("thread_id", string(threadID)). + Logger() + log.Info().Msg("Backfilling thread inside other backfill") + anchorMessage, err := portal.Bridge.DB.Message.GetLastThreadMessage(ctx, portal.PortalKey, threadID) + if err != nil { + log.Err(err).Msg("Failed to get last thread message") + return + } + resp, err := source.Client.(BackfillingNetworkAPI).FetchMessages(ctx, FetchMessagesParams{ + Portal: portal, + ThreadRoot: threadID, + Forward: true, + AnchorMessage: anchorMessage, + Count: 100, // TODO make count configurable + }) + if err != nil { + log.Err(err).Msg("Failed to fetch messages for thread backfill") + return + } + portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, anchorMessage) +} + func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead bool, lastMessage *database.Message) { if lastMessage != nil { if forceForward { @@ -97,6 +122,11 @@ func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messa portal.sendLegacyBackfill(ctx, source, messages, markRead) } zerolog.Ctx(ctx).Debug().Msg("Backfill finished") + for _, msg := range messages { + if msg.ShouldBackfillThread { + portal.doThreadBackfill(ctx, source, msg.ID) + } + } } func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead bool) { From 0d122e5bb2361a792ede62b50e93cd511c99e0d5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Jul 2024 16:53:21 +0300 Subject: [PATCH 0456/1647] bridgev2: implement backwards backfilling method --- bridgev2/database/message.go | 5 +++ bridgev2/portalbackfill.go | 60 +++++++++++++++++++++++++----------- 2 files changed, 47 insertions(+), 18 deletions(-) diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 19cc7bf7..184ce3d8 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -55,6 +55,7 @@ const ( getLastMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id DESC LIMIT 1` getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1` getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5` + getOldestMessageInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp ASC, part_id ASC LIMIT 1` getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp ASC, part_id ASC LIMIT 1` getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp DESC, part_id DESC LIMIT 1` @@ -121,6 +122,10 @@ func (mq *MessageQuery) GetMessagesBetweenTimeQuery(ctx context.Context, portal return mq.QueryMany(ctx, getMessagesBetweenTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, start.UnixNano(), end.UnixNano()) } +func (mq *MessageQuery) GetFirstPortalMessage(ctx context.Context, portal networkid.PortalKey) (*Message, error) { + return mq.QueryOne(ctx, getOldestMessageInPortal, mq.BridgeID, portal.ID, portal.Receiver) +} + func (mq *MessageQuery) GetFirstThreadMessage(ctx context.Context, portal networkid.PortalKey, threadRoot networkid.MessageID) (*Message, error) { return mq.QueryOne(ctx, getFirstMessageInThread, mq.BridgeID, portal.ID, portal.Receiver, threadRoot) } diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 33c2fc6d..8e3bdef7 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -50,24 +50,29 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, } func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin) { - //log := zerolog.Ctx(ctx) - //api, ok := source.Client.(BackfillingNetworkAPI) - //if !ok { - // log.Debug().Msg("Network API does not support backfilling") - // return - //} - //resp, err := api.FetchMessages(ctx, FetchMessagesParams{ - // Portal: portal, - // ThreadRoot: "", - // Forward: true, - // AnchorMessage: lastMessage, - // Count: 100, - //}) - //if err != nil { - // log.Err(err).Msg("Failed to fetch messages for forward backfill") - // return - //} - //portal.sendBackfill(ctx, source, resp.Messages, false, resp.MarkRead, lastMessage) + log := zerolog.Ctx(ctx) + api, ok := source.Client.(BackfillingNetworkAPI) + if !ok { + log.Debug().Msg("Network API does not support backfilling") + return + } + firstMessage, err := portal.Bridge.DB.Message.GetFirstPortalMessage(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to get oldest portal message") + return + } + resp, err := api.FetchMessages(ctx, FetchMessagesParams{ + Portal: portal, + ThreadRoot: "", + Forward: false, + AnchorMessage: firstMessage, + Count: 100, // TODO make count configurable + }) + if err != nil { + log.Err(err).Msg("Failed to fetch messages for forward backfill") + return + } + portal.sendBackfill(ctx, source, resp.Messages, false, resp.MarkRead, firstMessage) } func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, threadID networkid.MessageID) { @@ -102,6 +107,8 @@ func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messa for i, msg := range messages { if msg.Timestamp.Before(lastMessage.Timestamp) { cutoff = i + } else { + break } } if cutoff != 0 { @@ -112,6 +119,23 @@ func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messa Msg("Cutting off forward backfill messages older than latest bridged message") messages = messages[cutoff:] } + } else { + cutoff := -1 + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Timestamp.After(lastMessage.Timestamp) { + cutoff = i + } else { + break + } + } + if cutoff != -1 { + zerolog.Ctx(ctx).Debug(). + Int("cutoff_count", len(messages)-cutoff). + Int("total_count", len(messages)). + Time("oldest_bridged_ts", lastMessage.Timestamp). + Msg("Cutting off backward backfill messages newer than oldest bridged message") + messages = messages[cutoff:] + } } } canBatchSend := portal.Bridge.Matrix.GetCapabilities().BatchSending From 128781cffe32a7b1b674fb57ee8e31e10a9cbd0f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Jul 2024 17:33:45 +0300 Subject: [PATCH 0457/1647] bridgev2: fix some things in backfill --- bridgev2/ghost.go | 4 ++++ bridgev2/portal.go | 12 ++++++------ bridgev2/portalbackfill.go | 22 ++++++++++++---------- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index da1e81f0..78fe7b00 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -213,6 +213,10 @@ func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin if ghost.Name != "" && ghost.NameSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) { return } + zerolog.Ctx(ctx).Debug(). + Bool("has_name", ghost.Name != ""). + Bool("name_set", ghost.NameSet). + Msg("Updating ghost info in IfNecessary call") info, err := source.Client.GetUserInfo(ctx, ghost) if err != nil { zerolog.Ctx(ctx).Err(err).Str("ghost_id", string(ghost.ID)).Msg("Failed to get info to update ghost") diff --git a/bridgev2/portal.go b/bridgev2/portal.go index cbc5f8bd..21654e9a 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1192,7 +1192,7 @@ func (portal *Portal) GetIntentFor(ctx context.Context, sender EventSender, sour return intent } -func (portal *Portal) getRelationMeta(ctx context.Context, replyToPtr *networkid.MessageOptionalPartID, threadRootPtr *networkid.MessageID, isBatchSend bool) (replyTo, threadRoot, prevThreadEvent *database.Message) { +func (portal *Portal) getRelationMeta(ctx context.Context, currentMsg networkid.MessageID, replyToPtr *networkid.MessageOptionalPartID, threadRootPtr *networkid.MessageID, isBatchSend bool) (replyTo, threadRoot, prevThreadEvent *database.Message) { log := zerolog.Ctx(ctx) var err error if replyToPtr != nil { @@ -1210,7 +1210,7 @@ func (portal *Portal) getRelationMeta(ctx context.Context, replyToPtr *networkid } } } - if threadRootPtr != nil { + if threadRootPtr != nil && *threadRootPtr != currentMsg { threadRoot, err = portal.Bridge.DB.Message.GetFirstThreadMessage(ctx, portal.PortalKey, *threadRootPtr) if err != nil { log.Err(err).Msg("Failed to get thread root message from database") @@ -1252,7 +1252,7 @@ func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.Mes } } log := zerolog.Ctx(ctx) - replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, converted.ReplyTo, converted.ThreadRoot, false) + replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, id, converted.ReplyTo, converted.ThreadRoot, false) output := make([]*database.Message, 0, len(converted.Parts)) for _, part := range converted.Parts { portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) @@ -1261,10 +1261,10 @@ func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.Mes Raw: part.Extra, }, ts) if err != nil { - log.Err(err).Str("part_id", string(part.ID)).Msg("Failed to send message part to Matrix") + logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to send message part to Matrix") continue } - log.Debug(). + logContext(log.Debug()). Stringer("event_id", resp.EventID). Str("part_id", string(part.ID)). Msg("Sent message part to Matrix") @@ -1282,7 +1282,7 @@ func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.Mes } err = portal.Bridge.DB.Message.Insert(ctx, dbMessage) if err != nil { - log.Err(err).Str("part_id", string(part.ID)).Msg("Failed to save message part to database") + logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to save message part to database") } if converted.Disappear.Type != database.DisappearingTypeNone { if converted.Disappear.Type == database.DisappearingTypeAfterSend && converted.Disappear.DisappearAt.IsZero() { diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 8e3bdef7..06eabbf2 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -46,7 +46,7 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, log.Err(err).Msg("Failed to fetch messages for forward backfill") return } - portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, lastMessage) + portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, false, lastMessage) } func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin) { @@ -72,7 +72,7 @@ func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin log.Err(err).Msg("Failed to fetch messages for forward backfill") return } - portal.sendBackfill(ctx, source, resp.Messages, false, resp.MarkRead, firstMessage) + portal.sendBackfill(ctx, source, resp.Messages, false, resp.MarkRead, false, firstMessage) } func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, threadID networkid.MessageID) { @@ -97,15 +97,15 @@ func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, t log.Err(err).Msg("Failed to fetch messages for thread backfill") return } - portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, anchorMessage) + portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, true, anchorMessage) } -func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead bool, lastMessage *database.Message) { +func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool, lastMessage *database.Message) { if lastMessage != nil { if forceForward { var cutoff int for i, msg := range messages { - if msg.Timestamp.Before(lastMessage.Timestamp) { + if msg.ID == lastMessage.ID || msg.Timestamp.Before(lastMessage.Timestamp) { cutoff = i } else { break @@ -122,7 +122,7 @@ func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messa } else { cutoff := -1 for i := len(messages) - 1; i >= 0; i-- { - if messages[i].Timestamp.After(lastMessage.Timestamp) { + if messages[i].ID == lastMessage.ID || messages[i].Timestamp.After(lastMessage.Timestamp) { cutoff = i } else { break @@ -146,9 +146,11 @@ func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messa portal.sendLegacyBackfill(ctx, source, messages, markRead) } zerolog.Ctx(ctx).Debug().Msg("Backfill finished") - for _, msg := range messages { - if msg.ShouldBackfillThread { - portal.doThreadBackfill(ctx, source, msg.ID) + if !inThread { + for _, msg := range messages { + if msg.ShouldBackfillThread { + portal.doThreadBackfill(ctx, source, msg.ID) + } } } } @@ -169,7 +171,7 @@ func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages var disappearingMessages []*database.DisappearingMessage for _, msg := range messages { intent := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) - replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, msg.ReplyTo, msg.ThreadRoot, true) + replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, msg.ID, msg.ReplyTo, msg.ThreadRoot, true) if threadRoot != nil && prevThreadEvents[*msg.ThreadRoot] != "" { prevThreadEvent.MXID = prevThreadEvents[*msg.ThreadRoot] } From 1bdadae180209135c909f4878bfb3a4fe03e2105 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Jul 2024 18:19:46 +0300 Subject: [PATCH 0458/1647] Ensure `forwarding_curve25519_key_chain` is not null when sharing keys --- crypto/keysharing.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crypto/keysharing.go b/crypto/keysharing.go index 362dee81..f4407cbb 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -337,6 +337,9 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body) return } + if igs.ForwardingChains == nil { + igs.ForwardingChains = []string{} + } forwardedRoomKey := event.Content{ Parsed: &event.ForwardedRoomKeyEventContent{ From 085859bfdd7c528021c85af4392c4279b3ff6926 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Jul 2024 20:59:30 +0300 Subject: [PATCH 0459/1647] bridgev2: add UserInfo to ChatMember to allow updating ghost info easily --- bridgev2/portal.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 21654e9a..b2535f1a 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1719,6 +1719,7 @@ type ChatMember struct { Membership event.Membership Nickname string PowerLevel int + UserInfo *UserInfo PrevMembership event.Membership } @@ -1959,6 +1960,14 @@ func (portal *Portal) GetInitialMemberList(ctx context.Context, members *ChatMem if member.Membership != event.MembershipJoin && member.Membership != "" { continue } + if member.Sender != "" && member.UserInfo != nil { + ghost, err := portal.Bridge.GetGhostByID(ctx, member.Sender) + if err != nil { + zerolog.Ctx(ctx).Err(err).Str("ghost_id", string(member.Sender)).Msg("Failed to get ghost from member list to update info") + } else { + ghost.UpdateInfo(ctx, member.UserInfo) + } + } intent, extraUserID := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) if extraUserID != "" { invite = append(invite, extraUserID) @@ -2116,6 +2125,14 @@ func (portal *Portal) SyncParticipants(ctx context.Context, members *ChatMemberL } } for _, member := range members.Members { + if member.Sender != "" && member.UserInfo != nil { + ghost, err := portal.Bridge.GetGhostByID(ctx, member.Sender) + if err != nil { + zerolog.Ctx(ctx).Err(err).Str("ghost_id", string(member.Sender)).Msg("Failed to get ghost from member list to update info") + } else { + ghost.UpdateInfo(ctx, member.UserInfo) + } + } intent, extraUserID := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) if intent != nil { syncIntent(intent, member) From f120ac6b7e1eab2715a0395c0d7f2043e2fbf1f1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 17 Jul 2024 11:44:56 +0300 Subject: [PATCH 0460/1647] bridgev2: use add user-visible message to more errors --- bridgev2/messagestatus.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 7be0c188..c49dbf1c 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -21,15 +21,15 @@ var ( ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true) - ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true) - ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true) - ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true) - ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true) - ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true) - ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true) - ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithSendNotice(false) - ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true) + ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage() + ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage() + ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage() + ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage() + ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage() + ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage() + ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage() + ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) + ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage() ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) From 0d81a91c9febfba606ac057159d9c162fd61945d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 17 Jul 2024 11:57:51 +0300 Subject: [PATCH 0461/1647] bridgev2: fix scanning message timestamp --- bridgev2/database/message.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 184ce3d8..1403c9bc 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -157,7 +157,7 @@ func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { var threadRootID, replyToID, replyToPartID sql.NullString err := row.Scan( &m.RowID, &m.BridgeID, &m.ID, &m.PartID, &m.MXID, &m.Room.ID, &m.Room.Receiver, &m.SenderID, &m.SenderMXID, - &m.EditCount, ×tamp, &threadRootID, &replyToID, &replyToPartID, dbutil.JSON{Data: m.Metadata}, + ×tamp, &m.EditCount, &threadRootID, &replyToID, &replyToPartID, dbutil.JSON{Data: m.Metadata}, ) if err != nil { return nil, err From 9e8d3050b0ccc0b9b2331dfb4519cc866ff6f8be Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Wed, 17 Jul 2024 13:29:13 +0300 Subject: [PATCH 0462/1647] Add MSC4144 per message profile types (#256) --- event/message.go | 9 +++++---- event/profile.go | 10 ++++++++++ 2 files changed, 15 insertions(+), 4 deletions(-) create mode 100644 event/profile.go diff --git a/event/message.go b/event/message.go index 003f1fcc..3c6edfdd 100644 --- a/event/message.go +++ b/event/message.go @@ -131,10 +131,11 @@ type MessageEventContent struct { replyFallbackRemoved bool - MessageSendRetry *BeeperRetryMetadata `json:"com.beeper.message_send_retry,omitempty"` - BeeperGalleryImages []*MessageEventContent `json:"com.beeper.gallery.images,omitempty"` - BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"` - BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"` + MessageSendRetry *BeeperRetryMetadata `json:"com.beeper.message_send_retry,omitempty"` + BeeperGalleryImages []*MessageEventContent `json:"com.beeper.gallery.images,omitempty"` + BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"` + BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"` + BeeperPerMessageProfile *BeeperPerMessageProfile `json:"com.beeper.per_message_profile,omitempty"` BeeperLinkPreviews []*BeeperLinkPreview `json:"com.beeper.linkpreviews,omitempty"` diff --git a/event/profile.go b/event/profile.go new file mode 100644 index 00000000..6dc4314a --- /dev/null +++ b/event/profile.go @@ -0,0 +1,10 @@ +package event + +import "maunium.net/go/mautrix/id" + +type BeeperPerMessageProfile struct { + ID string `json:"id"` + Displayname string `json:"displayname,omitempty"` + AvatarURL *id.ContentURIString `json:"avatar_url,omitempty"` + AvatarFile *EncryptedFileInfo `json:"avatar_file,omitempty"` +} From 62e36db08db4a33bf1d76c073aadbe5231e65733 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Wed, 17 Jul 2024 14:49:04 +0300 Subject: [PATCH 0463/1647] bridgev2: Use pointer type for parsed content in replies (#257) --- bridgev2/commands/event.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/commands/event.go b/bridgev2/commands/event.go index 2a4b26a5..52d74512 100644 --- a/bridgev2/commands/event.go +++ b/bridgev2/commands/event.go @@ -56,7 +56,7 @@ func (ce *Event) Reply(msg string, args ...any) { func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { content := format.RenderMarkdown(msg, allowMarkdown, allowHTML) content.MsgType = event.MsgNotice - _, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, time.Now()) + _, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: &content}, time.Now()) if err != nil { ce.Log.Err(err).Msgf("Failed to reply to command") } From c48630b4f3caf6207ed53cf88ea792322fffb3d5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 17 Jul 2024 17:17:46 +0300 Subject: [PATCH 0464/1647] bridgev2/backfill: add config --- bridgev2/bridgeconfig/backfill.go | 29 ++++++++++++++++++++ bridgev2/bridgeconfig/config.go | 2 ++ bridgev2/bridgeconfig/upgrade.go | 11 ++++++++ bridgev2/matrix/mxmain/example-config.yaml | 31 ++++++++++++++++++++++ bridgev2/matrix/mxmain/main.go | 1 + bridgev2/portal.go | 8 +++++- bridgev2/portalbackfill.go | 15 ++++++++--- 7 files changed, 92 insertions(+), 5 deletions(-) create mode 100644 bridgev2/bridgeconfig/backfill.go diff --git a/bridgev2/bridgeconfig/backfill.go b/bridgev2/bridgeconfig/backfill.go new file mode 100644 index 00000000..34218cc4 --- /dev/null +++ b/bridgev2/bridgeconfig/backfill.go @@ -0,0 +1,29 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +type BackfillConfig struct { + Enabled bool `yaml:"enabled"` + MaxInitialMessages int `yaml:"max_initial_messages"` + MaxCatchupMessages int `yaml:"max_catchup_messages"` + UnreadHoursThreshold int `yaml:"unread_hours_threshold"` + + Threads BackfillThreadsConfig `yaml:"threads"` + Queue BackfillQueueConfig `yaml:"queue"` +} + +type BackfillThreadsConfig struct { + MaxInitialMessages int `yaml:"max_initial_messages"` +} + +type BackfillQueueConfig struct { + BatchSize int `yaml:"batch_size"` + BatchDelay int `yaml:"batch_delay"` + MaxBatches int `yaml:"max_batches"` + + MaxBatchesOverride map[string]int `yaml:"max_batches_override"` +} diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 623ee446..8c899ad7 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -23,6 +23,7 @@ type Config struct { Matrix MatrixConfig `yaml:"matrix"` Provisioning ProvisioningConfig `yaml:"provisioning"` DirectMedia DirectMediaConfig `yaml:"direct_media"` + Backfill BackfillConfig `yaml:"backfill"` DoublePuppet DoublePuppetConfig `yaml:"double_puppet"` Encryption EncryptionConfig `yaml:"encryption"` Logging zeroconfig.Config `yaml:"logging"` @@ -35,6 +36,7 @@ type BridgeConfig struct { PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` Relay RelayConfig `yaml:"relay"` Permissions PermissionConfig `yaml:"permissions"` + Backfill BackfillConfig `yaml:"backfill"` } type MatrixConfig struct { diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index e74e2368..1d5ee0ae 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -92,6 +92,16 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Str, "direct_media", "server_key") } + helper.Copy(up.Bool, "backfill", "enabled") + helper.Copy(up.Int, "backfill", "max_initial_messages") + helper.Copy(up.Int, "backfill", "max_catchup_messages") + helper.Copy(up.Int, "backfill", "unread_hours_threshold") + helper.Copy(up.Int, "backfill", "threads", "max_initial_messages") + helper.Copy(up.Int, "backfill", "queue", "batch_size") + helper.Copy(up.Int, "backfill", "queue", "batch_delay") + helper.Copy(up.Int, "backfill", "queue", "max_batches") + helper.Copy(up.Map, "backfill", "queue", "max_batches_override") + helper.Copy(up.Map, "double_puppet", "servers") helper.Copy(up.Bool, "double_puppet", "allow_discovery") helper.Copy(up.Map, "double_puppet", "secrets") @@ -249,6 +259,7 @@ var SpacedBlocks = [][]string{ {"appservice", "username_template"}, {"matrix"}, {"provisioning"}, + {"backfill"}, {"direct_media"}, {"double_puppet"}, {"encryption"}, diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 86de8916..3399297c 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -180,6 +180,37 @@ direct_media: # This key is also used to sign the mxc:// URIs to ensure only the bridge can generate them. server_key: generate +# Settings for backfilling messages. +# Note that the exact way settings are applied depends on the network connector. +# See https://docs.mau.fi/bridges/general/backfill.html for more details. +backfill: + # Whether to do backfilling at all. + enabled: false + # Maximum number of messages to backfill in empty rooms. + max_initial_messages: 50 + # Maximum number of missed messages to backfill after bridge restarts. + max_catchup_messages: 500 + # If a backfilled chat is older than this number of hours, + # mark it as read even if it's unread on the remote network. + unread_hours_threshold: 720 + # Settings for backfilling threads within other backfills. + threads: + # Maximum number of messages to backfill in a new thread. + max_initial_messages: 50 + # Settings for the backwards backfill queue. This only applies when connecting to + # Beeper as standard Matrix servers don't support inserting messages into history. + queue: + # Number of messages to backfill in one batch. + batch_size: 100 + # Delay between batches in seconds. + batch_delay: 20 + # Maximum number of batches to backfill per portal. + # If set to -1, all available messages will be backfilled. + max_batches: -1 + # Optional network-specific overrides for max batches. + # Interpretation of this field depends on the network connector. + max_batches_override: {} + # Settings for enabling double puppeting double_puppet: # Servers to always allow double puppeting from. diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 16a54b29..1a7c5217 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -351,6 +351,7 @@ func (br *BridgeMain) LoadConfig() { os.Exit(10) } } + cfg.Bridge.Backfill = cfg.Backfill br.Config = &cfg } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index b2535f1a..1f2c752f 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1084,6 +1084,7 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { }() log.UpdateContext(evt.AddLogContext) ctx := log.WithContext(context.TODO()) + evtType := evt.GetType() if portal.MXID == "" { mcp, ok := evt.(RemoteEventThatMayCreatePortal) if !ok || !mcp.ShouldCreatePortal() { @@ -1104,12 +1105,16 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { // TODO error return } + // TODO if CreateMatrixRoom is changed to backfill immediately, there's no need to handle chat resyncs further + //if evtType == RemoteEventChatResync { + // log.Debug().Msg("Not handling chat resync event further as portal was created by it") + // return + //} } preHandler, ok := evt.(RemotePreHandler) if ok { preHandler.PreHandle(ctx, portal) } - evtType := evt.GetType() log.Debug().Stringer("bridge_evt_type", evtType).Msg("Handling remote event") switch evtType { case RemoteEventUnknown: @@ -2477,6 +2482,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } } } + // TODO backfill portal? if portal.Parent == nil { userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) if err != nil { diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 06eabbf2..9d25ccb8 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -29,10 +29,17 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, return } logEvt := log.Info() + var limit int if lastMessage != nil { logEvt = logEvt.Str("latest_message_id", string(lastMessage.ID)) + limit = portal.Bridge.Config.Backfill.MaxCatchupMessages } else { logEvt = logEvt.Str("latest_message_id", "") + limit = portal.Bridge.Config.Backfill.MaxInitialMessages + } + if limit <= 0 { + logEvt.Discard().Send() + return } logEvt.Msg("Fetching messages for forward backfill") resp, err := api.FetchMessages(ctx, FetchMessagesParams{ @@ -40,7 +47,7 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, ThreadRoot: "", Forward: true, AnchorMessage: lastMessage, - Count: 100, // TODO make count configurable + Count: limit, }) if err != nil { log.Err(err).Msg("Failed to fetch messages for forward backfill") @@ -66,7 +73,7 @@ func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin ThreadRoot: "", Forward: false, AnchorMessage: firstMessage, - Count: 100, // TODO make count configurable + Count: portal.Bridge.Config.Backfill.Queue.BatchSize, }) if err != nil { log.Err(err).Msg("Failed to fetch messages for forward backfill") @@ -91,7 +98,7 @@ func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, t ThreadRoot: threadID, Forward: true, AnchorMessage: anchorMessage, - Count: 100, // TODO make count configurable + Count: portal.Bridge.Config.Backfill.Threads.MaxInitialMessages, }) if err != nil { log.Err(err).Msg("Failed to fetch messages for thread backfill") @@ -146,7 +153,7 @@ func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messa portal.sendLegacyBackfill(ctx, source, messages, markRead) } zerolog.Ctx(ctx).Debug().Msg("Backfill finished") - if !inThread { + if !inThread && portal.Bridge.Config.Backfill.Threads.MaxInitialMessages > 0 { for _, msg := range messages { if msg.ShouldBackfillThread { portal.doThreadBackfill(ctx, source, msg.ID) From 328be908b5bef6bc9032e1fb12a8e9808c905d26 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 17 Jul 2024 17:32:42 +0300 Subject: [PATCH 0465/1647] bridgev2/backfill: add stub backfill queue --- bridgev2/backfillqueue.go | 146 ++++++++++++++++++ bridgev2/bridge.go | 4 + bridgev2/database/backfillqueue.go | 138 +++++++++++++++++ bridgev2/database/database.go | 7 + bridgev2/database/upgrades/00-latest.sql | 24 ++- .../upgrades/12-dm-portal-other-user.sql | 2 +- .../database/upgrades/13-backfill-queue.sql | 20 +++ 7 files changed, 338 insertions(+), 3 deletions(-) create mode 100644 bridgev2/backfillqueue.go create mode 100644 bridgev2/database/backfillqueue.go create mode 100644 bridgev2/database/upgrades/13-backfill-queue.sql diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go new file mode 100644 index 00000000..d36f6def --- /dev/null +++ b/bridgev2/backfillqueue.go @@ -0,0 +1,146 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "fmt" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2/database" +) + +func (br *Bridge) WakeupBackfillQueue() { + select { + case br.wakeupBackfillQueue <- struct{}{}: + default: + } +} + +func (br *Bridge) RunBackfillQueue() { + if !br.Matrix.GetCapabilities().BatchSending { + return + } + log := br.Log.With().Str("component", "backfill queue").Logger() + ctx := log.WithContext(context.Background()) + batchDelay := time.Duration(br.Config.Backfill.Queue.BatchDelay) * time.Second + afterTimer := time.NewTimer(batchDelay) + for { + backfillTask, err := br.DB.BackfillQueue.GetNext(ctx) + if err != nil { + log.Err(err).Msg("Failed to get next backfill queue entry") + time.Sleep(1 * time.Minute) + continue + } else if backfillTask != nil { + br.doBackfillTask(ctx, backfillTask) + } + nextDelay := batchDelay + if backfillTask == nil { + nextDelay = max(10*time.Minute, batchDelay) + } + if !afterTimer.Stop() { + <-afterTimer.C + } + afterTimer.Reset(nextDelay) + select { + case <-br.wakeupBackfillQueue: + case <-afterTimer.C: + } + } +} + +func (br *Bridge) doBackfillTask(ctx context.Context, task *database.BackfillTask) { + log := zerolog.Ctx(ctx).With(). + Object("portal_key", task.PortalKey). + Str("login_id", string(task.UserLoginID)). + Logger() + err := br.DB.BackfillQueue.MarkDispatched(ctx, task) + if err != nil { + log.Err(err).Msg("Failed to mark backfill task as dispatched") + time.Sleep(1 * time.Minute) + return + } + completed, err := br.actuallyDoBackfillTask(ctx, task) + if err != nil { + log.Err(err).Msg("Failed to do backfill task") + time.Sleep(1 * time.Minute) + return + } else if completed { + log.Info().Msg("Backfill task completed successfully") + } else { + log.Info().Msg("Backfill task canceled") + } + err = br.DB.BackfillQueue.Update(ctx, task) + if err != nil { + log.Err(err).Msg("Failed to update backfill task") + time.Sleep(1 * time.Minute) + } +} + +func (br *Bridge) actuallyDoBackfillTask(ctx context.Context, task *database.BackfillTask) (bool, error) { + log := zerolog.Ctx(ctx) + portal, err := br.GetExistingPortalByKey(ctx, task.PortalKey) + if err != nil { + return false, fmt.Errorf("failed to get portal for backfill task: %w", err) + } else if portal == nil { + log.Warn().Msg("Portal not found for backfill task") + err = br.DB.BackfillQueue.Delete(ctx, task.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to delete backfill task after portal wasn't found") + time.Sleep(1 * time.Minute) + } + return false, nil + } else if portal.MXID == "" { + log.Debug().Msg("Portal for backfill task doesn't exist") + task.NextDispatchMinTS = database.BackfillNextDispatchNever + task.UserLoginID = "" + return false, nil + } + login, err := br.GetExistingUserLoginByID(ctx, task.UserLoginID) + if err != nil { + return false, fmt.Errorf("failed to get user login for backfill task: %w", err) + } else if login == nil { + log.Warn().Msg("User login not found for backfill task") + logins, err := br.GetUserLoginsInPortal(ctx, portal.PortalKey) + if err != nil { + return false, fmt.Errorf("failed to get user portals for backfill task: %w", err) + } else if len(logins) == 0 { + log.Debug().Msg("No user logins found for backfill task") + task.NextDispatchMinTS = database.BackfillNextDispatchNever + task.UserLoginID = "" + return false, nil + } + task.UserLoginID = "" + for _, login = range logins { + if login.Client.IsLoggedIn() { + task.UserLoginID = login.ID + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("overridden_login_id", string(login.ID)) + }) + log.Debug().Msg("Found user login for backfill task") + break + } + } + if task.UserLoginID == "" { + log.Debug().Msg("No logged in user logins found for backfill task") + task.NextDispatchMinTS = database.BackfillNextDispatchNever + return false, nil + } + } + maxBatches := br.Config.Backfill.Queue.MaxBatches + // TODO apply max batch overrides + // TODO actually backfill + hasMoreMessages := true + task.BatchCount++ + task.IsDone = task.BatchCount >= maxBatches || !hasMoreMessages + batchDelay := time.Duration(br.Config.Backfill.Queue.BatchDelay) * time.Second + task.CompletedAt = time.Now() + task.NextDispatchMinTS = task.CompletedAt.Add(batchDelay) + return true, nil +} diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 62dc532b..945a6f90 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -47,6 +47,8 @@ type Bridge struct { portalsByMXID map[id.RoomID]*Portal ghostsByID map[networkid.UserID]*Ghost cacheLock sync.Mutex + + wakeupBackfillQueue chan struct{} } func NewBridge( @@ -72,6 +74,8 @@ func NewBridge( portalsByKey: make(map[networkid.PortalKey]*Portal), portalsByMXID: make(map[id.RoomID]*Portal), ghostsByID: make(map[networkid.UserID]*Ghost), + + wakeupBackfillQueue: make(chan struct{}), } if br.Config == nil { br.Config = &bridgeconfig.BridgeConfig{CommandPrefix: "!bridge"} diff --git a/bridgev2/database/backfillqueue.go b/bridgev2/database/backfillqueue.go new file mode 100644 index 00000000..50d4a113 --- /dev/null +++ b/bridgev2/database/backfillqueue.go @@ -0,0 +1,138 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "time" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" +) + +type BackfillQueueQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*BackfillTask] +} + +type BackfillTask struct { + BridgeID networkid.BridgeID + PortalKey networkid.PortalKey + UserLoginID networkid.UserLoginID + + BatchCount int + IsDone bool + Cursor networkid.PaginationCursor + OldestMessageID networkid.MessageID + DispatchedAt time.Time + CompletedAt time.Time + NextDispatchMinTS time.Time +} + +var BackfillNextDispatchNever = time.Unix(0, (1<<63)-1) + +const ( + ensureBackfillExistsQuery = ` + INSERT INTO backfill_queue (bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, next_dispatch_min_ts) + VALUES ($1, $2, $3, $4, 0, false, $5) + ON CONFLICT DO UPDATE + SET user_login_id=excluded.user_login_id, + next_dispatch_min_ts=CASE + WHEN next_dispatch_min_ts=9223372036854775807 + THEN excluded.next_dispatch_min_ts + ELSE next_dispatch_min_ts + END + ` + markBackfillDispatchedQuery = ` + UPDATE backfill_queue SET dispatched_at=$4, completed_at=NULL, next_dispatch_min_ts=$5 + WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 + ` + updateBackfillQueueQuery = ` + UPDATE backfill_queue + SET user_login_id=$4, batch_count=$5, is_done=$6, cursor=$7, oldest_message_id=$8, + dispatched_at=$9, completed_at=$10, next_dispatch_min_ts=$11 + WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 + ` + getNextBackfillQuery = ` + SELECT + bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, + cursor, oldest_message_id, dispatched_at, completed_at, next_dispatch_min_ts + FROM backfill_queue + WHERE bridge_id = $1 AND next_dispatch_min_ts < $2 AND is_done = false AND user_login_id <> '' + ORDER BY next_dispatch_min_ts LIMIT 1 + ` + deleteBackfillQueueQuery = ` + DELETE FROM backfill_queue + WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 + ` +) + +func (bqq *BackfillQueueQuery) EnsureExists(ctx context.Context, portal networkid.PortalKey) error { + return bqq.Exec(ctx, ensureBackfillExistsQuery, bqq.BridgeID, portal.ID, portal.Receiver, time.Now().UnixNano()) +} + +const UnfinishedBackfillBackoff = 1 * time.Hour + +func (bqq *BackfillQueueQuery) MarkDispatched(ctx context.Context, bq *BackfillTask) error { + ensureBridgeIDMatches(&bq.BridgeID, bqq.BridgeID) + bq.DispatchedAt = time.Now() + bq.CompletedAt = time.Time{} + bq.NextDispatchMinTS = bq.DispatchedAt.Add(UnfinishedBackfillBackoff) + return bqq.Exec( + ctx, markBackfillDispatchedQuery, + bq.BridgeID, bq.PortalKey.ID, bq.PortalKey.Receiver, + bq.DispatchedAt.UnixNano(), bq.NextDispatchMinTS.UnixNano(), + ) +} + +func (bqq *BackfillQueueQuery) Update(ctx context.Context, bq *BackfillTask) error { + ensureBridgeIDMatches(&bq.BridgeID, bqq.BridgeID) + return bqq.Exec(ctx, updateBackfillQueueQuery, bq.sqlVariables()...) +} + +func (bqq *BackfillQueueQuery) GetNext(ctx context.Context) (*BackfillTask, error) { + return bqq.QueryOne(ctx, getNextBackfillQuery, bqq.BridgeID, time.Now().UnixNano()) +} + +func (bqq *BackfillQueueQuery) Delete(ctx context.Context, portalKey networkid.PortalKey) error { + return bqq.Exec(ctx, deleteBackfillQueueQuery, bqq.BridgeID, portalKey.ID, portalKey.Receiver) +} + +func (bt *BackfillTask) Scan(row dbutil.Scannable) (*BackfillTask, error) { + var cursor, oldestMessageID sql.NullString + var dispatchedAt, completedAt, nextDispatchMinTS sql.NullInt64 + err := row.Scan( + &bt.BridgeID, &bt.PortalKey.ID, &bt.PortalKey.Receiver, &bt.UserLoginID, &bt.BatchCount, &bt.IsDone, + &cursor, &oldestMessageID, &dispatchedAt, &completedAt, &nextDispatchMinTS) + if err != nil { + return nil, err + } + bt.Cursor = networkid.PaginationCursor(cursor.String) + bt.OldestMessageID = networkid.MessageID(oldestMessageID.String) + if dispatchedAt.Valid { + bt.DispatchedAt = time.Unix(0, dispatchedAt.Int64) + } + if completedAt.Valid { + bt.CompletedAt = time.Unix(0, completedAt.Int64) + } + if nextDispatchMinTS.Valid { + bt.NextDispatchMinTS = time.Unix(0, nextDispatchMinTS.Int64) + } + return bt, nil +} + +func (bt *BackfillTask) sqlVariables() []any { + return []any{ + bt.BridgeID, bt.PortalKey.ID, bt.PortalKey.Receiver, bt.UserLoginID, bt.BatchCount, bt.IsDone, + dbutil.StrPtr(bt.Cursor), dbutil.StrPtr(bt.OldestMessageID), + dbutil.ConvertedPtr(bt.DispatchedAt, time.Time.UnixNano), + dbutil.ConvertedPtr(bt.CompletedAt, time.Time.UnixNano), + bt.NextDispatchMinTS.UnixNano(), + } +} diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go index 47858fba..16e556cc 100644 --- a/bridgev2/database/database.go +++ b/bridgev2/database/database.go @@ -32,6 +32,7 @@ type Database struct { User *UserQuery UserLogin *UserLoginQuery UserPortal *UserPortalQuery + BackfillQueue *BackfillQueueQuery } type MetaMerger interface { @@ -129,6 +130,12 @@ func New(bridgeID networkid.BridgeID, mt MetaTypes, db *dbutil.Database) *Databa return &UserPortal{} }), }, + BackfillQueue: &BackfillQueueQuery{ + BridgeID: bridgeID, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*BackfillTask]) *BackfillTask { + return &BackfillTask{} + }), + }, } } diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 943f7a59..0d388f7f 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v12 (compatible with v9+): Latest revision +-- v0 -> v13 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -37,7 +37,7 @@ CREATE TABLE portal ( relay_bridge_id TEXT, relay_login_id TEXT, - other_user_id TEXT, + other_user_id TEXT, name TEXT NOT NULL, topic TEXT NOT NULL, @@ -175,3 +175,23 @@ CREATE TABLE user_portal ( ); CREATE INDEX user_portal_login_idx ON user_portal (bridge_id, login_id); CREATE INDEX user_portal_portal_idx ON user_portal (bridge_id, portal_id, portal_receiver); + +CREATE TABLE backfill_queue ( + bridge_id TEXT NOT NULL, + portal_id TEXT NOT NULL, + portal_receiver TEXT NOT NULL, + user_login_id TEXT NOT NULL, + + batch_count INTEGER NOT NULL, + is_done BOOLEAN NOT NULL, + cursor TEXT, + oldest_message_id TEXT, + dispatched_at BIGINT, + completed_at BIGINT, + next_dispatch_min_ts BIGINT NOT NULL, + + PRIMARY KEY (bridge_id, portal_id, portal_receiver), + CONSTRAINT backfill_queue_portal_fkey FOREIGN KEY (bridge_id, portal_id, portal_receiver) + REFERENCES portal (bridge_id, id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE +); diff --git a/bridgev2/database/upgrades/12-dm-portal-other-user.sql b/bridgev2/database/upgrades/12-dm-portal-other-user.sql index 60ae4e3a..2d2cb900 100644 --- a/bridgev2/database/upgrades/12-dm-portal-other-user.sql +++ b/bridgev2/database/upgrades/12-dm-portal-other-user.sql @@ -1,2 +1,2 @@ --- v12: Save other user ID in DM portals +-- v12 (compatible with v9+): Save other user ID in DM portals ALTER TABLE portal ADD COLUMN other_user_id TEXT; diff --git a/bridgev2/database/upgrades/13-backfill-queue.sql b/bridgev2/database/upgrades/13-backfill-queue.sql new file mode 100644 index 00000000..b8f511e6 --- /dev/null +++ b/bridgev2/database/upgrades/13-backfill-queue.sql @@ -0,0 +1,20 @@ +-- v13 (compatible with v9+): Add backfill queue +CREATE TABLE backfill_queue ( + bridge_id TEXT NOT NULL, + portal_id TEXT NOT NULL, + portal_receiver TEXT NOT NULL, + user_login_id TEXT NOT NULL, + + batch_count INTEGER NOT NULL, + is_done BOOLEAN NOT NULL, + cursor TEXT, + oldest_message_id TEXT, + dispatched_at BIGINT, + completed_at BIGINT, + next_dispatch_min_ts BIGINT NOT NULL, + + PRIMARY KEY (bridge_id, portal_id, portal_receiver), + CONSTRAINT backfill_queue_portal_fkey FOREIGN KEY (bridge_id, portal_id, portal_receiver) + REFERENCES portal (bridge_id, id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE +); From f80e3d68386d3e1e4b1a5ec02b9581791a951b56 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 18 Jul 2024 14:56:37 +0300 Subject: [PATCH 0466/1647] bridgev2/backfill: actually call backfill function --- bridgev2/backfillqueue.go | 43 ++++++++---- bridgev2/portalbackfill.go | 134 ++++++++++++++++++++++++------------- 2 files changed, 119 insertions(+), 58 deletions(-) diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go index d36f6def..d234b60d 100644 --- a/bridgev2/backfillqueue.go +++ b/bridgev2/backfillqueue.go @@ -16,6 +16,10 @@ import ( "maunium.net/go/mautrix/bridgev2/database" ) +const BackfillMinBackoffAfterRoomCreate = 1 * time.Minute +const BackfillQueueErrorBackoff = 1 * time.Minute +const BackfillQueueMinEmptyBackoff = 10 * time.Minute + func (br *Bridge) WakeupBackfillQueue() { select { case br.wakeupBackfillQueue <- struct{}{}: @@ -35,14 +39,14 @@ func (br *Bridge) RunBackfillQueue() { backfillTask, err := br.DB.BackfillQueue.GetNext(ctx) if err != nil { log.Err(err).Msg("Failed to get next backfill queue entry") - time.Sleep(1 * time.Minute) + time.Sleep(BackfillQueueErrorBackoff) continue } else if backfillTask != nil { br.doBackfillTask(ctx, backfillTask) } nextDelay := batchDelay if backfillTask == nil { - nextDelay = max(10*time.Minute, batchDelay) + nextDelay = max(BackfillQueueMinEmptyBackoff, batchDelay) } if !afterTimer.Stop() { <-afterTimer.C @@ -63,13 +67,13 @@ func (br *Bridge) doBackfillTask(ctx context.Context, task *database.BackfillTas err := br.DB.BackfillQueue.MarkDispatched(ctx, task) if err != nil { log.Err(err).Msg("Failed to mark backfill task as dispatched") - time.Sleep(1 * time.Minute) + time.Sleep(BackfillQueueErrorBackoff) return } completed, err := br.actuallyDoBackfillTask(ctx, task) if err != nil { log.Err(err).Msg("Failed to do backfill task") - time.Sleep(1 * time.Minute) + time.Sleep(BackfillQueueErrorBackoff) return } else if completed { log.Info().Msg("Backfill task completed successfully") @@ -79,10 +83,25 @@ func (br *Bridge) doBackfillTask(ctx context.Context, task *database.BackfillTas err = br.DB.BackfillQueue.Update(ctx, task) if err != nil { log.Err(err).Msg("Failed to update backfill task") - time.Sleep(1 * time.Minute) + time.Sleep(BackfillQueueErrorBackoff) } } +func (portal *Portal) deleteBackfillQueueTaskIfRoomDoesNotExist(ctx context.Context) bool { + // Acquire the room create lock to ensure that task deletion doesn't race with room creation + portal.roomCreateLock.Lock() + defer portal.roomCreateLock.Unlock() + if portal.MXID == "" { + zerolog.Ctx(ctx).Debug().Msg("Portal for backfill task doesn't exist, deleting entry") + err := portal.Bridge.DB.BackfillQueue.Delete(ctx, portal.PortalKey) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to delete backfill task after portal wasn't found") + } + return true + } + return false +} + func (br *Bridge) actuallyDoBackfillTask(ctx context.Context, task *database.BackfillTask) (bool, error) { log := zerolog.Ctx(ctx) portal, err := br.GetExistingPortalByKey(ctx, task.PortalKey) @@ -93,13 +112,11 @@ func (br *Bridge) actuallyDoBackfillTask(ctx context.Context, task *database.Bac err = br.DB.BackfillQueue.Delete(ctx, task.PortalKey) if err != nil { log.Err(err).Msg("Failed to delete backfill task after portal wasn't found") - time.Sleep(1 * time.Minute) + time.Sleep(BackfillQueueErrorBackoff) } return false, nil } else if portal.MXID == "" { - log.Debug().Msg("Portal for backfill task doesn't exist") - task.NextDispatchMinTS = database.BackfillNextDispatchNever - task.UserLoginID = "" + portal.deleteBackfillQueueTaskIfRoomDoesNotExist(ctx) return false, nil } login, err := br.GetExistingUserLoginByID(ctx, task.UserLoginID) @@ -135,10 +152,12 @@ func (br *Bridge) actuallyDoBackfillTask(ctx context.Context, task *database.Bac } maxBatches := br.Config.Backfill.Queue.MaxBatches // TODO apply max batch overrides - // TODO actually backfill - hasMoreMessages := true + err = portal.DoBackwardsBackfill(ctx, login, task) + if err != nil { + return false, fmt.Errorf("failed to backfill: %w", err) + } task.BatchCount++ - task.IsDone = task.BatchCount >= maxBatches || !hasMoreMessages + task.IsDone = task.IsDone || task.BatchCount >= maxBatches batchDelay := time.Duration(br.Config.Backfill.Queue.BatchDelay) * time.Second task.CompletedAt = time.Now() task.NextDispatchMinTS = task.CompletedAt.Add(batchDelay) diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 9d25ccb8..2c7f39be 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -8,6 +8,7 @@ package bridgev2 import ( "context" + "fmt" "time" "github.com/rs/zerolog" @@ -52,34 +53,60 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, if err != nil { log.Err(err).Msg("Failed to fetch messages for forward backfill") return + } else if len(resp.Messages) == 0 { + log.Debug().Msg("No messages to backfill") + return } - portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, false, lastMessage) + resp.Messages = cutoffMessages(&log, resp.Messages, true, lastMessage) + if len(resp.Messages) == 0 { + log.Warn().Msg("No messages left to backfill after cutting off old messages") + return + } + portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, false) } -func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin) { +func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin, task *database.BackfillTask) error { log := zerolog.Ctx(ctx) api, ok := source.Client.(BackfillingNetworkAPI) if !ok { - log.Debug().Msg("Network API does not support backfilling") - return + return fmt.Errorf("network API does not support backfilling") } firstMessage, err := portal.Bridge.DB.Message.GetFirstPortalMessage(ctx, portal.PortalKey) if err != nil { - log.Err(err).Msg("Failed to get oldest portal message") - return + return fmt.Errorf("failed to get first portal message: %w", err) } resp, err := api.FetchMessages(ctx, FetchMessagesParams{ Portal: portal, ThreadRoot: "", Forward: false, + Cursor: task.Cursor, AnchorMessage: firstMessage, Count: portal.Bridge.Config.Backfill.Queue.BatchSize, }) if err != nil { - log.Err(err).Msg("Failed to fetch messages for forward backfill") - return + return fmt.Errorf("failed to fetch messages for backward backfill: %w", err) } - portal.sendBackfill(ctx, source, resp.Messages, false, resp.MarkRead, false, firstMessage) + task.Cursor = resp.Cursor + if !resp.HasMore { + task.IsDone = true + } + if len(resp.Messages) == 0 { + if !resp.HasMore { + log.Debug().Msg("No messages to backfill, marking backfill task as done") + } else { + log.Warn().Msg("No messages to backfill, but HasMore is true") + } + return nil + } + resp.Messages = cutoffMessages(log, resp.Messages, false, firstMessage) + if len(resp.Messages) == 0 { + return fmt.Errorf("no messages left to backfill after cutting off too new messages") + } + portal.sendBackfill(ctx, source, resp.Messages, false, resp.MarkRead, false) + if len(resp.Messages) > 0 { + task.OldestMessageID = resp.Messages[0].ID + } + return nil } func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, threadID networkid.MessageID) { @@ -103,48 +130,61 @@ func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, t if err != nil { log.Err(err).Msg("Failed to fetch messages for thread backfill") return + } else if len(resp.Messages) == 0 { + log.Debug().Msg("No messages to backfill") + return } - portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, true, anchorMessage) + resp.Messages = cutoffMessages(&log, resp.Messages, true, anchorMessage) + if len(resp.Messages) == 0 { + log.Warn().Msg("No messages left to backfill after cutting off old messages") + return + } + portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, true) } -func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool, lastMessage *database.Message) { - if lastMessage != nil { - if forceForward { - var cutoff int - for i, msg := range messages { - if msg.ID == lastMessage.ID || msg.Timestamp.Before(lastMessage.Timestamp) { - cutoff = i - } else { - break - } - } - if cutoff != 0 { - zerolog.Ctx(ctx).Debug(). - Int("cutoff_count", cutoff). - Int("total_count", len(messages)). - Time("last_bridged_ts", lastMessage.Timestamp). - Msg("Cutting off forward backfill messages older than latest bridged message") - messages = messages[cutoff:] - } - } else { - cutoff := -1 - for i := len(messages) - 1; i >= 0; i-- { - if messages[i].ID == lastMessage.ID || messages[i].Timestamp.After(lastMessage.Timestamp) { - cutoff = i - } else { - break - } - } - if cutoff != -1 { - zerolog.Ctx(ctx).Debug(). - Int("cutoff_count", len(messages)-cutoff). - Int("total_count", len(messages)). - Time("oldest_bridged_ts", lastMessage.Timestamp). - Msg("Cutting off backward backfill messages newer than oldest bridged message") - messages = messages[cutoff:] +func cutoffMessages(log *zerolog.Logger, messages []*BackfillMessage, forward bool, lastMessage *database.Message) []*BackfillMessage { + if lastMessage == nil { + return messages + } + if forward { + var cutoff int + for i, msg := range messages { + if msg.ID == lastMessage.ID || msg.Timestamp.Before(lastMessage.Timestamp) { + cutoff = i + } else { + break } } + if cutoff != 0 { + log.Debug(). + Int("cutoff_count", cutoff). + Int("total_count", len(messages)). + Time("last_bridged_ts", lastMessage.Timestamp). + Msg("Cutting off forward backfill messages older than latest bridged message") + messages = messages[cutoff:] + } + } else { + cutoff := -1 + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].ID == lastMessage.ID || messages[i].Timestamp.After(lastMessage.Timestamp) { + cutoff = i + } else { + break + } + } + if cutoff != -1 { + log.Debug(). + Int("cutoff_count", len(messages)-cutoff). + Int("total_count", len(messages)). + Time("oldest_bridged_ts", lastMessage.Timestamp). + Msg("Cutting off backward backfill messages newer than oldest bridged message") + messages = messages[cutoff:] + } } + return messages +} + +func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) { canBatchSend := portal.Bridge.Matrix.GetCapabilities().BatchSending zerolog.Ctx(ctx).Info().Int("message_count", len(messages)).Bool("batch_send", canBatchSend).Msg("Sending backfill messages") if canBatchSend { @@ -232,14 +272,16 @@ func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages zerolog.Ctx(ctx).Err(err).Msg("Failed to send backfill messages") } if len(disappearingMessages) > 0 { + // TODO mass insert disappearing messages go func() { for _, msg := range disappearingMessages { portal.Bridge.DisappearLoop.Add(ctx, msg) } }() } + // TODO mass insert db messages for _, msg := range dbMessages { - err := portal.Bridge.DB.Message.Insert(ctx, msg) + err = portal.Bridge.DB.Message.Insert(ctx, msg) if err != nil { zerolog.Ctx(ctx).Err(err). Str("message_id", string(msg.ID)). From 18bca337a5f6d12ed75c78d6f734db05dc9c652f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 18 Jul 2024 14:56:56 +0300 Subject: [PATCH 0467/1647] bridgev2/backfill: insert backfill queue task when creating portal --- bridgev2/database/backfillqueue.go | 36 +++++++++++++++++++++++++----- bridgev2/portal.go | 18 +++++++++++++++ 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/bridgev2/database/backfillqueue.go b/bridgev2/database/backfillqueue.go index 50d4a113..db90bf8d 100644 --- a/bridgev2/database/backfillqueue.go +++ b/bridgev2/database/backfillqueue.go @@ -41,14 +41,33 @@ const ( ensureBackfillExistsQuery = ` INSERT INTO backfill_queue (bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, next_dispatch_min_ts) VALUES ($1, $2, $3, $4, 0, false, $5) - ON CONFLICT DO UPDATE - SET user_login_id=excluded.user_login_id, + ON CONFLICT (bridge_id, portal_id, portal_receiver) DO UPDATE + SET user_login_id=CASE + WHEN backfill_queue.user_login_id='' + THEN excluded.user_login_id + ELSE backfill_queue.user_login_id + END, next_dispatch_min_ts=CASE - WHEN next_dispatch_min_ts=9223372036854775807 + WHEN backfill_queue.next_dispatch_min_ts=9223372036854775807 THEN excluded.next_dispatch_min_ts - ELSE next_dispatch_min_ts + ELSE backfill_queue.next_dispatch_min_ts END ` + upsertBackfillQueueQuery = ` + INSERT INTO backfill_queue ( + bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, cursor, + oldest_message_id, dispatched_at, completed_at, next_dispatch_min_ts + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + ON CONFLICT (bridge_id, portal_id, portal_receiver) DO UPDATE + SET user_login_id=excluded.user_login_id, + batch_count=excluded.batch_count, + is_done=excluded.is_done, + cursor=excluded.cursor, + oldest_message_id=excluded.oldest_message_id, + dispatched_at=excluded.dispatched_at, + completed_at=excluded.completed_at, + next_dispatch_min_ts=excluded.next_dispatch_min_ts + ` markBackfillDispatchedQuery = ` UPDATE backfill_queue SET dispatched_at=$4, completed_at=NULL, next_dispatch_min_ts=$5 WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 @@ -73,8 +92,13 @@ const ( ` ) -func (bqq *BackfillQueueQuery) EnsureExists(ctx context.Context, portal networkid.PortalKey) error { - return bqq.Exec(ctx, ensureBackfillExistsQuery, bqq.BridgeID, portal.ID, portal.Receiver, time.Now().UnixNano()) +func (bqq *BackfillQueueQuery) EnsureExists(ctx context.Context, portal networkid.PortalKey, loginID networkid.UserLoginID) error { + return bqq.Exec(ctx, ensureBackfillExistsQuery, bqq.BridgeID, portal.ID, portal.Receiver, loginID, time.Now().UnixNano()) +} + +func (bqq *BackfillQueueQuery) Upsert(ctx context.Context, bq *BackfillTask) error { + ensureBridgeIDMatches(&bq.BridgeID, bqq.BridgeID) + return bqq.Exec(ctx, upsertBackfillQueueQuery, bq.sqlVariables()...) } const UnfinishedBackfillBackoff = 1 * time.Hour diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 1f2c752f..5e867665 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1816,6 +1816,8 @@ type ChatInfo struct { UserLocal *UserLocalPortalInfo + CanBackfill bool + ExtraUpdates func(context.Context, *Portal) bool } @@ -2320,6 +2322,12 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us source.MarkInPortal(ctx, portal) portal.updateUserLocalInfo(ctx, info.UserLocal, source) } + if info.CanBackfill && source != nil { + err := portal.Bridge.DB.BackfillQueue.EnsureExists(ctx, portal.PortalKey, source.ID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to ensure backfill queue task exists") + } + } if info.ExtraUpdates != nil { changed = info.ExtraUpdates(ctx, portal) || changed } @@ -2457,6 +2465,16 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i log.Err(err).Msg("Failed to save portal to database after creating Matrix room") return err } + if info.CanBackfill { + err = portal.Bridge.DB.BackfillQueue.Upsert(ctx, &database.BackfillTask{ + PortalKey: portal.PortalKey, + UserLoginID: source.ID, + NextDispatchMinTS: time.Now().Add(BackfillMinBackoffAfterRoomCreate), + }) + if err != nil { + log.Err(err).Msg("Failed to create backfill queue task after creating room") + } + } if portal.Parent != nil { if portal.Parent.MXID != "" { portal.addToParentSpaceAndSave(ctx, true) From edc71a5ee386dbc1025b093dc9368bc7a6a5e48e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 18 Jul 2024 16:30:17 +0300 Subject: [PATCH 0468/1647] bridgev2/backfill: actually run backfill queue --- bridgev2/backfillqueue.go | 16 ++++++++++++++-- bridgev2/bridge.go | 4 ++++ bridgev2/bridgeconfig/backfill.go | 7 ++++--- bridgev2/bridgeconfig/upgrade.go | 3 ++- bridgev2/matrix/mxmain/example-config.yaml | 2 ++ bridgev2/portal.go | 2 +- bridgev2/portalbackfill.go | 1 + 7 files changed, 28 insertions(+), 7 deletions(-) diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go index d234b60d..085513f7 100644 --- a/bridgev2/backfillqueue.go +++ b/bridgev2/backfillqueue.go @@ -28,11 +28,19 @@ func (br *Bridge) WakeupBackfillQueue() { } func (br *Bridge) RunBackfillQueue() { - if !br.Matrix.GetCapabilities().BatchSending { + if !br.Config.Backfill.Queue.Enabled || !br.Config.Backfill.Enabled { return } log := br.Log.With().Str("component", "backfill queue").Logger() - ctx := log.WithContext(context.Background()) + if !br.Matrix.GetCapabilities().BatchSending { + log.Warn().Msg("Backfill queue is enabled in config, but Matrix server doesn't support batch sending") + return + } + ctx, cancel := context.WithCancel(log.WithContext(context.Background())) + go func() { + <-br.stopBackfillQueue + cancel() + }() batchDelay := time.Duration(br.Config.Backfill.Queue.BatchDelay) * time.Second afterTimer := time.NewTimer(batchDelay) for { @@ -54,6 +62,10 @@ func (br *Bridge) RunBackfillQueue() { afterTimer.Reset(nextDelay) select { case <-br.wakeupBackfillQueue: + case <-br.stopBackfillQueue: + afterTimer.Stop() + log.Info().Msg("Stopping backfill queue") + return case <-afterTimer.C: } } diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 945a6f90..76a5d2c8 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -49,6 +49,7 @@ type Bridge struct { cacheLock sync.Mutex wakeupBackfillQueue chan struct{} + stopBackfillQueue chan struct{} } func NewBridge( @@ -76,6 +77,7 @@ func NewBridge( ghostsByID: make(map[networkid.UserID]*Ghost), wakeupBackfillQueue: make(chan struct{}), + stopBackfillQueue: make(chan struct{}), } if br.Config == nil { br.Config = &bridgeconfig.BridgeConfig{CommandPrefix: "!bridge"} @@ -149,6 +151,7 @@ func (br *Bridge) Start() error { br.Log.Info().Msg("No user logins found") br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured}) } + go br.RunBackfillQueue() br.Log.Info().Msg("Bridge started") return nil @@ -156,6 +159,7 @@ func (br *Bridge) Start() error { func (br *Bridge) Stop() { br.Log.Info().Msg("Shutting down bridge") + close(br.stopBackfillQueue) br.Matrix.Stop() br.cacheLock.Lock() var wg sync.WaitGroup diff --git a/bridgev2/bridgeconfig/backfill.go b/bridgev2/bridgeconfig/backfill.go index 34218cc4..fe464569 100644 --- a/bridgev2/bridgeconfig/backfill.go +++ b/bridgev2/bridgeconfig/backfill.go @@ -21,9 +21,10 @@ type BackfillThreadsConfig struct { } type BackfillQueueConfig struct { - BatchSize int `yaml:"batch_size"` - BatchDelay int `yaml:"batch_delay"` - MaxBatches int `yaml:"max_batches"` + Enabled bool `yaml:"enabled"` + BatchSize int `yaml:"batch_size"` + BatchDelay int `yaml:"batch_delay"` + MaxBatches int `yaml:"max_batches"` MaxBatchesOverride map[string]int `yaml:"max_batches_override"` } diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 1d5ee0ae..6b7493d2 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -96,7 +96,8 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Int, "backfill", "max_initial_messages") helper.Copy(up.Int, "backfill", "max_catchup_messages") helper.Copy(up.Int, "backfill", "unread_hours_threshold") - helper.Copy(up.Int, "backfill", "threads", "max_initial_messages") + helper.Copy(up.Bool, "backfill", "threads", "max_initial_messages") + helper.Copy(up.Int, "backfill", "queue", "enabled") helper.Copy(up.Int, "backfill", "queue", "batch_size") helper.Copy(up.Int, "backfill", "queue", "batch_delay") helper.Copy(up.Int, "backfill", "queue", "max_batches") diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 3399297c..92d4647c 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -200,6 +200,8 @@ backfill: # Settings for the backwards backfill queue. This only applies when connecting to # Beeper as standard Matrix servers don't support inserting messages into history. queue: + # Should the backfill queue be enabled? + enabled: false # Number of messages to backfill in one batch. batch_size: 100 # Delay between batches in seconds. diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 5e867665..38159e9c 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1674,7 +1674,7 @@ func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLo } } backfillChecker, ok := evt.(RemoteChatResyncBackfill) - if ok { + if portal.Bridge.Config.Backfill.Enabled && ok { latestMessage, err := portal.Bridge.DB.Message.GetLastPartAtOrBeforeTime(ctx, portal.PortalKey, time.Now().Add(10*time.Second)) if err != nil { log.Err(err).Msg("Failed to get last message in portal to check if backfill is necessary") diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 2c7f39be..5b0c2361 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -57,6 +57,7 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, log.Debug().Msg("No messages to backfill") return } + // TODO mark backfill queue task as done if last message is nil (-> room was empty) and HasMore is false? resp.Messages = cutoffMessages(&log, resp.Messages, true, lastMessage) if len(resp.Messages) == 0 { log.Warn().Msg("No messages left to backfill after cutting off old messages") From c0aa5898d8adcc3c4b58b96aa08359eaa698bcea Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 18 Jul 2024 16:30:33 +0300 Subject: [PATCH 0469/1647] bridgev2/ghost: adjust UpdateInfoIfNecessary logs --- bridgev2/ghost.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index 78fe7b00..125cd9c0 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -213,15 +213,20 @@ func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin if ghost.Name != "" && ghost.NameSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) { return } - zerolog.Ctx(ctx).Debug(). - Bool("has_name", ghost.Name != ""). - Bool("name_set", ghost.NameSet). - Msg("Updating ghost info in IfNecessary call") info, err := source.Client.GetUserInfo(ctx, ghost) if err != nil { zerolog.Ctx(ctx).Err(err).Str("ghost_id", string(ghost.ID)).Msg("Failed to get info to update ghost") } else if info != nil { + zerolog.Ctx(ctx).Debug(). + Bool("has_name", ghost.Name != ""). + Bool("name_set", ghost.NameSet). + Msg("Updating ghost info in IfNecessary call") ghost.UpdateInfo(ctx, info) + } else { + zerolog.Ctx(ctx).Trace(). + Bool("has_name", ghost.Name != ""). + Bool("name_set", ghost.NameSet). + Msg("No ghost info received in IfNecessary call") } } From 6509b11d9c37027d38f7866353e54e391a9a2317 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 18 Jul 2024 16:40:52 +0300 Subject: [PATCH 0470/1647] bridgev2/backfill: add more logs --- bridgev2/backfillqueue.go | 1 + bridgev2/bridgeconfig/upgrade.go | 4 ++-- bridgev2/portalbackfill.go | 20 +++++++++++++++++++- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go index 085513f7..dd736085 100644 --- a/bridgev2/backfillqueue.go +++ b/bridgev2/backfillqueue.go @@ -43,6 +43,7 @@ func (br *Bridge) RunBackfillQueue() { }() batchDelay := time.Duration(br.Config.Backfill.Queue.BatchDelay) * time.Second afterTimer := time.NewTimer(batchDelay) + log.Info().Stringer("batch_delay", batchDelay).Msg("Backfill queue starting") for { backfillTask, err := br.DB.BackfillQueue.GetNext(ctx) if err != nil { diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 6b7493d2..04f7dab3 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -96,8 +96,8 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Int, "backfill", "max_initial_messages") helper.Copy(up.Int, "backfill", "max_catchup_messages") helper.Copy(up.Int, "backfill", "unread_hours_threshold") - helper.Copy(up.Bool, "backfill", "threads", "max_initial_messages") - helper.Copy(up.Int, "backfill", "queue", "enabled") + helper.Copy(up.Int, "backfill", "threads", "max_initial_messages") + helper.Copy(up.Bool, "backfill", "queue", "enabled") helper.Copy(up.Int, "backfill", "queue", "batch_size") helper.Copy(up.Int, "backfill", "queue", "batch_delay") helper.Copy(up.Int, "backfill", "queue", "max_batches") diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 5b0c2361..75c1f163 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -76,6 +76,16 @@ func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin if err != nil { return fmt.Errorf("failed to get first portal message: %w", err) } + logEvt := log.Info(). + Str("cursor", string(task.Cursor)). + Str("task_oldest_message_id", string(task.OldestMessageID)). + Int("current_batch_count", task.BatchCount) + if firstMessage != nil { + logEvt = logEvt.Str("db_oldest_message_id", string(firstMessage.ID)) + } else { + logEvt = logEvt.Str("db_oldest_message_id", "") + } + logEvt.Msg("Fetching messages for backward backfill") resp, err := api.FetchMessages(ctx, FetchMessagesParams{ Portal: portal, ThreadRoot: "", @@ -87,6 +97,11 @@ func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin if err != nil { return fmt.Errorf("failed to fetch messages for backward backfill: %w", err) } + log.Debug(). + Str("new_cursor", string(resp.Cursor)). + Bool("has_more", resp.HasMore). + Int("message_count", len(resp.Messages)). + Msg("Fetched messages for backward backfill") task.Cursor = resp.Cursor if !resp.HasMore { task.IsDone = true @@ -187,7 +202,10 @@ func cutoffMessages(log *zerolog.Logger, messages []*BackfillMessage, forward bo func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) { canBatchSend := portal.Bridge.Matrix.GetCapabilities().BatchSending - zerolog.Ctx(ctx).Info().Int("message_count", len(messages)).Bool("batch_send", canBatchSend).Msg("Sending backfill messages") + zerolog.Ctx(ctx).Info(). + Int("message_count", len(messages)). + Bool("batch_send", canBatchSend). + Msg("Sending backfill messages") if canBatchSend { portal.sendBatch(ctx, source, messages, forceForward, markRead) } else { From 3fe5071c3fefe30e7146f5f12b25b34432f5891e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 18 Jul 2024 16:51:01 +0300 Subject: [PATCH 0471/1647] bridgev2/database: rename backfill_queue to backfill_task --- bridgev2/backfillqueue.go | 10 ++-- bridgev2/database/backfillqueue.go | 52 +++++++++---------- bridgev2/database/database.go | 4 +- bridgev2/database/upgrades/00-latest.sql | 2 +- .../database/upgrades/13-backfill-queue.sql | 2 +- bridgev2/portal.go | 4 +- 6 files changed, 37 insertions(+), 37 deletions(-) diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go index dd736085..f9cb8010 100644 --- a/bridgev2/backfillqueue.go +++ b/bridgev2/backfillqueue.go @@ -45,7 +45,7 @@ func (br *Bridge) RunBackfillQueue() { afterTimer := time.NewTimer(batchDelay) log.Info().Stringer("batch_delay", batchDelay).Msg("Backfill queue starting") for { - backfillTask, err := br.DB.BackfillQueue.GetNext(ctx) + backfillTask, err := br.DB.BackfillTask.GetNext(ctx) if err != nil { log.Err(err).Msg("Failed to get next backfill queue entry") time.Sleep(BackfillQueueErrorBackoff) @@ -77,7 +77,7 @@ func (br *Bridge) doBackfillTask(ctx context.Context, task *database.BackfillTas Object("portal_key", task.PortalKey). Str("login_id", string(task.UserLoginID)). Logger() - err := br.DB.BackfillQueue.MarkDispatched(ctx, task) + err := br.DB.BackfillTask.MarkDispatched(ctx, task) if err != nil { log.Err(err).Msg("Failed to mark backfill task as dispatched") time.Sleep(BackfillQueueErrorBackoff) @@ -93,7 +93,7 @@ func (br *Bridge) doBackfillTask(ctx context.Context, task *database.BackfillTas } else { log.Info().Msg("Backfill task canceled") } - err = br.DB.BackfillQueue.Update(ctx, task) + err = br.DB.BackfillTask.Update(ctx, task) if err != nil { log.Err(err).Msg("Failed to update backfill task") time.Sleep(BackfillQueueErrorBackoff) @@ -106,7 +106,7 @@ func (portal *Portal) deleteBackfillQueueTaskIfRoomDoesNotExist(ctx context.Cont defer portal.roomCreateLock.Unlock() if portal.MXID == "" { zerolog.Ctx(ctx).Debug().Msg("Portal for backfill task doesn't exist, deleting entry") - err := portal.Bridge.DB.BackfillQueue.Delete(ctx, portal.PortalKey) + err := portal.Bridge.DB.BackfillTask.Delete(ctx, portal.PortalKey) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to delete backfill task after portal wasn't found") } @@ -122,7 +122,7 @@ func (br *Bridge) actuallyDoBackfillTask(ctx context.Context, task *database.Bac return false, fmt.Errorf("failed to get portal for backfill task: %w", err) } else if portal == nil { log.Warn().Msg("Portal not found for backfill task") - err = br.DB.BackfillQueue.Delete(ctx, task.PortalKey) + err = br.DB.BackfillTask.Delete(ctx, task.PortalKey) if err != nil { log.Err(err).Msg("Failed to delete backfill task after portal wasn't found") time.Sleep(BackfillQueueErrorBackoff) diff --git a/bridgev2/database/backfillqueue.go b/bridgev2/database/backfillqueue.go index db90bf8d..5d7cf854 100644 --- a/bridgev2/database/backfillqueue.go +++ b/bridgev2/database/backfillqueue.go @@ -16,7 +16,7 @@ import ( "maunium.net/go/mautrix/bridgev2/networkid" ) -type BackfillQueueQuery struct { +type BackfillTaskQuery struct { BridgeID networkid.BridgeID *dbutil.QueryHelper[*BackfillTask] } @@ -39,22 +39,22 @@ var BackfillNextDispatchNever = time.Unix(0, (1<<63)-1) const ( ensureBackfillExistsQuery = ` - INSERT INTO backfill_queue (bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, next_dispatch_min_ts) + INSERT INTO backfill_task (bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, next_dispatch_min_ts) VALUES ($1, $2, $3, $4, 0, false, $5) ON CONFLICT (bridge_id, portal_id, portal_receiver) DO UPDATE SET user_login_id=CASE - WHEN backfill_queue.user_login_id='' + WHEN backfill_task.user_login_id='' THEN excluded.user_login_id - ELSE backfill_queue.user_login_id + ELSE backfill_task.user_login_id END, next_dispatch_min_ts=CASE - WHEN backfill_queue.next_dispatch_min_ts=9223372036854775807 + WHEN backfill_task.next_dispatch_min_ts=9223372036854775807 THEN excluded.next_dispatch_min_ts - ELSE backfill_queue.next_dispatch_min_ts + ELSE backfill_task.next_dispatch_min_ts END ` upsertBackfillQueueQuery = ` - INSERT INTO backfill_queue ( + INSERT INTO backfill_task ( bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, cursor, oldest_message_id, dispatched_at, completed_at, next_dispatch_min_ts ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) @@ -69,11 +69,11 @@ const ( next_dispatch_min_ts=excluded.next_dispatch_min_ts ` markBackfillDispatchedQuery = ` - UPDATE backfill_queue SET dispatched_at=$4, completed_at=NULL, next_dispatch_min_ts=$5 + UPDATE backfill_task SET dispatched_at=$4, completed_at=NULL, next_dispatch_min_ts=$5 WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 ` updateBackfillQueueQuery = ` - UPDATE backfill_queue + UPDATE backfill_task SET user_login_id=$4, batch_count=$5, is_done=$6, cursor=$7, oldest_message_id=$8, dispatched_at=$9, completed_at=$10, next_dispatch_min_ts=$11 WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 @@ -82,50 +82,50 @@ const ( SELECT bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, cursor, oldest_message_id, dispatched_at, completed_at, next_dispatch_min_ts - FROM backfill_queue + FROM backfill_task WHERE bridge_id = $1 AND next_dispatch_min_ts < $2 AND is_done = false AND user_login_id <> '' ORDER BY next_dispatch_min_ts LIMIT 1 ` deleteBackfillQueueQuery = ` - DELETE FROM backfill_queue + DELETE FROM backfill_task WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 ` ) -func (bqq *BackfillQueueQuery) EnsureExists(ctx context.Context, portal networkid.PortalKey, loginID networkid.UserLoginID) error { - return bqq.Exec(ctx, ensureBackfillExistsQuery, bqq.BridgeID, portal.ID, portal.Receiver, loginID, time.Now().UnixNano()) +func (btq *BackfillTaskQuery) EnsureExists(ctx context.Context, portal networkid.PortalKey, loginID networkid.UserLoginID) error { + return btq.Exec(ctx, ensureBackfillExistsQuery, btq.BridgeID, portal.ID, portal.Receiver, loginID, time.Now().UnixNano()) } -func (bqq *BackfillQueueQuery) Upsert(ctx context.Context, bq *BackfillTask) error { - ensureBridgeIDMatches(&bq.BridgeID, bqq.BridgeID) - return bqq.Exec(ctx, upsertBackfillQueueQuery, bq.sqlVariables()...) +func (btq *BackfillTaskQuery) Upsert(ctx context.Context, bq *BackfillTask) error { + ensureBridgeIDMatches(&bq.BridgeID, btq.BridgeID) + return btq.Exec(ctx, upsertBackfillQueueQuery, bq.sqlVariables()...) } const UnfinishedBackfillBackoff = 1 * time.Hour -func (bqq *BackfillQueueQuery) MarkDispatched(ctx context.Context, bq *BackfillTask) error { - ensureBridgeIDMatches(&bq.BridgeID, bqq.BridgeID) +func (btq *BackfillTaskQuery) MarkDispatched(ctx context.Context, bq *BackfillTask) error { + ensureBridgeIDMatches(&bq.BridgeID, btq.BridgeID) bq.DispatchedAt = time.Now() bq.CompletedAt = time.Time{} bq.NextDispatchMinTS = bq.DispatchedAt.Add(UnfinishedBackfillBackoff) - return bqq.Exec( + return btq.Exec( ctx, markBackfillDispatchedQuery, bq.BridgeID, bq.PortalKey.ID, bq.PortalKey.Receiver, bq.DispatchedAt.UnixNano(), bq.NextDispatchMinTS.UnixNano(), ) } -func (bqq *BackfillQueueQuery) Update(ctx context.Context, bq *BackfillTask) error { - ensureBridgeIDMatches(&bq.BridgeID, bqq.BridgeID) - return bqq.Exec(ctx, updateBackfillQueueQuery, bq.sqlVariables()...) +func (btq *BackfillTaskQuery) Update(ctx context.Context, bq *BackfillTask) error { + ensureBridgeIDMatches(&bq.BridgeID, btq.BridgeID) + return btq.Exec(ctx, updateBackfillQueueQuery, bq.sqlVariables()...) } -func (bqq *BackfillQueueQuery) GetNext(ctx context.Context) (*BackfillTask, error) { - return bqq.QueryOne(ctx, getNextBackfillQuery, bqq.BridgeID, time.Now().UnixNano()) +func (btq *BackfillTaskQuery) GetNext(ctx context.Context) (*BackfillTask, error) { + return btq.QueryOne(ctx, getNextBackfillQuery, btq.BridgeID, time.Now().UnixNano()) } -func (bqq *BackfillQueueQuery) Delete(ctx context.Context, portalKey networkid.PortalKey) error { - return bqq.Exec(ctx, deleteBackfillQueueQuery, bqq.BridgeID, portalKey.ID, portalKey.Receiver) +func (btq *BackfillTaskQuery) Delete(ctx context.Context, portalKey networkid.PortalKey) error { + return btq.Exec(ctx, deleteBackfillQueueQuery, btq.BridgeID, portalKey.ID, portalKey.Receiver) } func (bt *BackfillTask) Scan(row dbutil.Scannable) (*BackfillTask, error) { diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go index 16e556cc..aa77a232 100644 --- a/bridgev2/database/database.go +++ b/bridgev2/database/database.go @@ -32,7 +32,7 @@ type Database struct { User *UserQuery UserLogin *UserLoginQuery UserPortal *UserPortalQuery - BackfillQueue *BackfillQueueQuery + BackfillTask *BackfillTaskQuery } type MetaMerger interface { @@ -130,7 +130,7 @@ func New(bridgeID networkid.BridgeID, mt MetaTypes, db *dbutil.Database) *Databa return &UserPortal{} }), }, - BackfillQueue: &BackfillQueueQuery{ + BackfillTask: &BackfillTaskQuery{ BridgeID: bridgeID, QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*BackfillTask]) *BackfillTask { return &BackfillTask{} diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 0d388f7f..d02964a2 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -176,7 +176,7 @@ CREATE TABLE user_portal ( CREATE INDEX user_portal_login_idx ON user_portal (bridge_id, login_id); CREATE INDEX user_portal_portal_idx ON user_portal (bridge_id, portal_id, portal_receiver); -CREATE TABLE backfill_queue ( +CREATE TABLE backfill_task ( bridge_id TEXT NOT NULL, portal_id TEXT NOT NULL, portal_receiver TEXT NOT NULL, diff --git a/bridgev2/database/upgrades/13-backfill-queue.sql b/bridgev2/database/upgrades/13-backfill-queue.sql index b8f511e6..dada993c 100644 --- a/bridgev2/database/upgrades/13-backfill-queue.sql +++ b/bridgev2/database/upgrades/13-backfill-queue.sql @@ -1,5 +1,5 @@ -- v13 (compatible with v9+): Add backfill queue -CREATE TABLE backfill_queue ( +CREATE TABLE backfill_task ( bridge_id TEXT NOT NULL, portal_id TEXT NOT NULL, portal_receiver TEXT NOT NULL, diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 38159e9c..28603c74 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2323,7 +2323,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us portal.updateUserLocalInfo(ctx, info.UserLocal, source) } if info.CanBackfill && source != nil { - err := portal.Bridge.DB.BackfillQueue.EnsureExists(ctx, portal.PortalKey, source.ID) + err := portal.Bridge.DB.BackfillTask.EnsureExists(ctx, portal.PortalKey, source.ID) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to ensure backfill queue task exists") } @@ -2466,7 +2466,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i return err } if info.CanBackfill { - err = portal.Bridge.DB.BackfillQueue.Upsert(ctx, &database.BackfillTask{ + err = portal.Bridge.DB.BackfillTask.Upsert(ctx, &database.BackfillTask{ PortalKey: portal.PortalKey, UserLoginID: source.ID, NextDispatchMinTS: time.Now().Add(BackfillMinBackoffAfterRoomCreate), From 28d15fa7b063ce42cdf8e19fe383ef3de1a7e755 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 18 Jul 2024 17:10:48 +0300 Subject: [PATCH 0472/1647] bridgev2/commands: add command to delete all portals --- bridgev2/commands/cleanup.go | 76 ++++++++++++++++++++++++++++++++++ bridgev2/commands/processor.go | 2 +- bridgev2/commands/startchat.go | 21 ---------- bridgev2/database/portal.go | 5 +++ bridgev2/portal.go | 10 +++++ 5 files changed, 92 insertions(+), 22 deletions(-) create mode 100644 bridgev2/commands/cleanup.go diff --git a/bridgev2/commands/cleanup.go b/bridgev2/commands/cleanup.go new file mode 100644 index 00000000..55f34d14 --- /dev/null +++ b/bridgev2/commands/cleanup.go @@ -0,0 +1,76 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "cmp" + "slices" + + "maunium.net/go/mautrix/bridgev2" +) + +var CommandDeletePortal = &FullHandler{ + Func: func(ce *Event) { + // TODO clean up child portals? + err := ce.Portal.Delete(ce.Ctx) + if err != nil { + ce.Reply("Failed to delete portal: %v", err) + return + } + err = ce.Bot.DeleteRoom(ce.Ctx, ce.Portal.MXID, false) + if err != nil { + ce.Reply("Failed to clean up room: %v", err) + } + ce.MessageStatus.DisableMSS = true + }, + Name: "delete-portal", + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Delete the current portal room", + }, + RequiresAdmin: true, + RequiresPortal: true, +} + +var CommandDeleteAllPortals = &FullHandler{ + Func: func(ce *Event) { + portals, err := ce.Bridge.GetAllPortals(ce.Ctx) + if err != nil { + ce.Reply("Failed to get portals: %v", err) + return + } + getDepth := func(portal *bridgev2.Portal) int { + depth := 0 + for portal.Parent != nil { + depth++ + portal = portal.Parent + } + return depth + } + // Sort portals so parents are last (to avoid errors caused by deleting parent portals before children) + slices.SortFunc(portals, func(a, b *bridgev2.Portal) int { + return cmp.Compare(getDepth(b), getDepth(a)) + }) + for _, portal := range portals { + err = portal.Delete(ce.Ctx) + if err != nil { + ce.Reply("Failed to delete portal %s: %v", portal.MXID, err) + continue + } + err = ce.Bot.DeleteRoom(ce.Ctx, portal.MXID, false) + if err != nil { + ce.Reply("Failed to clean up room %s: %v", portal.MXID, err) + } + } + }, + Name: "delete-all-portals", + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Delete all portals the bridge knows about", + }, + RequiresAdmin: true, +} diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index d14e9781..774be16b 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -42,7 +42,7 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor { } proc.AddHandlers( CommandHelp, CommandCancel, - CommandRegisterPush, CommandDeletePortal, + CommandRegisterPush, CommandDeletePortal, CommandDeleteAllPortals, CommandLogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, CommandSetRelay, CommandUnsetRelay, CommandResolveIdentifier, CommandStartChat, diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index 9ad5f77c..0a4e6783 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -140,24 +140,3 @@ func fnResolveIdentifier(ce *Event) { ce.Reply("Found %s", formattedName) } } - -var CommandDeletePortal = &FullHandler{ - Func: func(ce *Event) { - err := ce.Portal.Delete(ce.Ctx) - if err != nil { - ce.Reply("Failed to delete portal: %v", err) - } - err = ce.Bot.DeleteRoom(ce.Ctx, ce.Portal.MXID, false) - if err != nil { - ce.Reply("Failed to clean up room: %v", err) - } - ce.MessageStatus.DisableMSS = true - }, - Name: "delete-portal", - Help: HelpMeta{ - Section: HelpSectionAdmin, - Description: "Delete the current portal room", - }, - RequiresAdmin: true, - RequiresPortal: true, -} diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index af0f6a9f..2f675593 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -69,6 +69,7 @@ const ( getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')` getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL` + getAllPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1` getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2` findPortalReceiverQuery = `SELECT id, receiver FROM portal WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='') LIMIT 1` @@ -125,6 +126,10 @@ func (pq *PortalQuery) GetAllWithMXID(ctx context.Context) ([]*Portal, error) { return pq.QueryMany(ctx, getAllPortalsWithMXIDQuery, pq.BridgeID) } +func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) { + return pq.QueryMany(ctx, getAllPortalsQuery, pq.BridgeID) +} + func (pq *PortalQuery) GetChildren(ctx context.Context, parentID networkid.PortalID) ([]*Portal, error) { return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentID) } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 28603c74..a790324f 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -199,6 +199,16 @@ func (br *Bridge) GetAllPortalsWithMXID(ctx context.Context) ([]*Portal, error) return br.loadManyPortals(ctx, rows) } +func (br *Bridge) GetAllPortals(ctx context.Context) ([]*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + rows, err := br.DB.Portal.GetAll(ctx) + if err != nil { + return nil, err + } + return br.loadManyPortals(ctx, rows) +} + func (br *Bridge) GetPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() From ceb664005477160875f4705ba5456d6ebf8bd128 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 18 Jul 2024 17:13:38 +0300 Subject: [PATCH 0473/1647] bridgev2/backfill: respect `unread_hours_threshold` config option --- bridgev2/portalbackfill.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 75c1f163..d23341dd 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -202,14 +202,18 @@ func cutoffMessages(log *zerolog.Logger, messages []*BackfillMessage, forward bo func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) { canBatchSend := portal.Bridge.Matrix.GetCapabilities().BatchSending + unreadThreshold := time.Duration(portal.Bridge.Config.Backfill.UnreadHoursThreshold) * time.Hour + forceMarkRead := unreadThreshold > 0 && time.Since(messages[len(messages)-1].Timestamp) > unreadThreshold zerolog.Ctx(ctx).Info(). Int("message_count", len(messages)). Bool("batch_send", canBatchSend). + Bool("mark_read", markRead). + Bool("mark_read_past_threshold", forceMarkRead). Msg("Sending backfill messages") if canBatchSend { - portal.sendBatch(ctx, source, messages, forceForward, markRead) + portal.sendBatch(ctx, source, messages, forceForward, markRead || forceMarkRead) } else { - portal.sendLegacyBackfill(ctx, source, messages, markRead) + portal.sendLegacyBackfill(ctx, source, messages, markRead || forceMarkRead) } zerolog.Ctx(ctx).Debug().Msg("Backfill finished") if !inThread && portal.Bridge.Config.Backfill.Threads.MaxInitialMessages > 0 { From e341bdf0e8125f69f126933133bb1a6b5822542a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 18 Jul 2024 17:19:09 +0300 Subject: [PATCH 0474/1647] bridgev2: implement more fields in SimpleRemoteEvent --- bridgev2/networkinterface.go | 91 ----------------------- bridgev2/simpleremoteevent.go | 131 ++++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 91 deletions(-) create mode 100644 bridgev2/simpleremoteevent.go diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 432041b2..fe7d2350 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -778,97 +778,6 @@ type RemoteTypingWithType interface { GetTypingType() TypingType } -// SimpleRemoteEvent is a simple implementation of RemoteEvent that can be used with struct fields and some callbacks. -type SimpleRemoteEvent[T any] struct { - Type RemoteEventType - LogContext func(c zerolog.Context) zerolog.Context - PortalKey networkid.PortalKey - Data T - CreatePortal bool - - ID networkid.MessageID - Sender EventSender - TargetMessage networkid.MessageID - EmojiID networkid.EmojiID - Emoji string - ReactionDBMeta any - Timestamp time.Time - ChatInfoChange *ChatInfoChange - - ConvertMessageFunc func(ctx context.Context, portal *Portal, intent MatrixAPI, data T) (*ConvertedMessage, error) - ConvertEditFunc func(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message, data T) (*ConvertedEdit, error) -} - -var ( - _ RemoteMessage = (*SimpleRemoteEvent[any])(nil) - _ RemoteEdit = (*SimpleRemoteEvent[any])(nil) - _ RemoteEventWithTimestamp = (*SimpleRemoteEvent[any])(nil) - _ RemoteReaction = (*SimpleRemoteEvent[any])(nil) - _ RemoteReactionWithMeta = (*SimpleRemoteEvent[any])(nil) - _ RemoteReactionRemove = (*SimpleRemoteEvent[any])(nil) - _ RemoteMessageRemove = (*SimpleRemoteEvent[any])(nil) - _ RemoteChatInfoChange = (*SimpleRemoteEvent[any])(nil) -) - -func (sre *SimpleRemoteEvent[T]) AddLogContext(c zerolog.Context) zerolog.Context { - return sre.LogContext(c) -} - -func (sre *SimpleRemoteEvent[T]) GetPortalKey() networkid.PortalKey { - return sre.PortalKey -} - -func (sre *SimpleRemoteEvent[T]) GetTimestamp() time.Time { - if sre.Timestamp.IsZero() { - return time.Now() - } - return sre.Timestamp -} - -func (sre *SimpleRemoteEvent[T]) ConvertMessage(ctx context.Context, portal *Portal, intent MatrixAPI) (*ConvertedMessage, error) { - return sre.ConvertMessageFunc(ctx, portal, intent, sre.Data) -} - -func (sre *SimpleRemoteEvent[T]) ConvertEdit(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message) (*ConvertedEdit, error) { - return sre.ConvertEditFunc(ctx, portal, intent, existing, sre.Data) -} - -func (sre *SimpleRemoteEvent[T]) GetID() networkid.MessageID { - return sre.ID -} - -func (sre *SimpleRemoteEvent[T]) GetSender() EventSender { - return sre.Sender -} - -func (sre *SimpleRemoteEvent[T]) GetTargetMessage() networkid.MessageID { - return sre.TargetMessage -} - -func (sre *SimpleRemoteEvent[T]) GetReactionEmoji() (string, networkid.EmojiID) { - return sre.Emoji, sre.EmojiID -} - -func (sre *SimpleRemoteEvent[T]) GetRemovedEmojiID() networkid.EmojiID { - return sre.EmojiID -} - -func (sre *SimpleRemoteEvent[T]) GetReactionDBMetadata() any { - return sre.ReactionDBMeta -} - -func (sre *SimpleRemoteEvent[T]) GetChatInfoChange(ctx context.Context) (*ChatInfoChange, error) { - return sre.ChatInfoChange, nil -} - -func (sre *SimpleRemoteEvent[T]) GetType() RemoteEventType { - return sre.Type -} - -func (sre *SimpleRemoteEvent[T]) ShouldCreatePortal() bool { - return sre.CreatePortal -} - type OrigSender struct { User *User event.MemberEventContent diff --git a/bridgev2/simpleremoteevent.go b/bridgev2/simpleremoteevent.go new file mode 100644 index 00000000..a45ff9c3 --- /dev/null +++ b/bridgev2/simpleremoteevent.go @@ -0,0 +1,131 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" +) + +// SimpleRemoteEvent is a simple implementation of RemoteEvent that can be used with struct fields and some callbacks. +// +// Using this type is only recommended for simple bridges. More advanced ones should implement +// the remote event interfaces themselves by wrapping the remote network library event types. +type SimpleRemoteEvent[T any] struct { + Type RemoteEventType + LogContext func(c zerolog.Context) zerolog.Context + PortalKey networkid.PortalKey + Data T + CreatePortal bool + + ID networkid.MessageID + Sender EventSender + TargetMessage networkid.MessageID + EmojiID networkid.EmojiID + Emoji string + ReactionDBMeta any + Timestamp time.Time + ChatInfoChange *ChatInfoChange + + ResyncChatInfo *ChatInfo + ResyncBackfillNeeded bool + + BackfillData *FetchMessagesResponse + + ConvertMessageFunc func(ctx context.Context, portal *Portal, intent MatrixAPI, data T) (*ConvertedMessage, error) + ConvertEditFunc func(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message, data T) (*ConvertedEdit, error) +} + +var ( + _ RemoteMessage = (*SimpleRemoteEvent[any])(nil) + _ RemoteEdit = (*SimpleRemoteEvent[any])(nil) + _ RemoteEventWithTimestamp = (*SimpleRemoteEvent[any])(nil) + _ RemoteReaction = (*SimpleRemoteEvent[any])(nil) + _ RemoteReactionWithMeta = (*SimpleRemoteEvent[any])(nil) + _ RemoteReactionRemove = (*SimpleRemoteEvent[any])(nil) + _ RemoteMessageRemove = (*SimpleRemoteEvent[any])(nil) + _ RemoteChatInfoChange = (*SimpleRemoteEvent[any])(nil) + _ RemoteChatResyncWithInfo = (*SimpleRemoteEvent[any])(nil) + _ RemoteChatResyncBackfill = (*SimpleRemoteEvent[any])(nil) + _ RemoteBackfill = (*SimpleRemoteEvent[any])(nil) +) + +func (sre *SimpleRemoteEvent[T]) AddLogContext(c zerolog.Context) zerolog.Context { + return sre.LogContext(c) +} + +func (sre *SimpleRemoteEvent[T]) GetPortalKey() networkid.PortalKey { + return sre.PortalKey +} + +func (sre *SimpleRemoteEvent[T]) GetTimestamp() time.Time { + if sre.Timestamp.IsZero() { + return time.Now() + } + return sre.Timestamp +} + +func (sre *SimpleRemoteEvent[T]) ConvertMessage(ctx context.Context, portal *Portal, intent MatrixAPI) (*ConvertedMessage, error) { + return sre.ConvertMessageFunc(ctx, portal, intent, sre.Data) +} + +func (sre *SimpleRemoteEvent[T]) ConvertEdit(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message) (*ConvertedEdit, error) { + return sre.ConvertEditFunc(ctx, portal, intent, existing, sre.Data) +} + +func (sre *SimpleRemoteEvent[T]) GetID() networkid.MessageID { + return sre.ID +} + +func (sre *SimpleRemoteEvent[T]) GetSender() EventSender { + return sre.Sender +} + +func (sre *SimpleRemoteEvent[T]) GetTargetMessage() networkid.MessageID { + return sre.TargetMessage +} + +func (sre *SimpleRemoteEvent[T]) GetReactionEmoji() (string, networkid.EmojiID) { + return sre.Emoji, sre.EmojiID +} + +func (sre *SimpleRemoteEvent[T]) GetRemovedEmojiID() networkid.EmojiID { + return sre.EmojiID +} + +func (sre *SimpleRemoteEvent[T]) GetReactionDBMetadata() any { + return sre.ReactionDBMeta +} + +func (sre *SimpleRemoteEvent[T]) GetChatInfoChange(ctx context.Context) (*ChatInfoChange, error) { + return sre.ChatInfoChange, nil +} + +func (sre *SimpleRemoteEvent[T]) GetType() RemoteEventType { + return sre.Type +} + +func (sre *SimpleRemoteEvent[T]) ShouldCreatePortal() bool { + return sre.CreatePortal +} + +func (sre *SimpleRemoteEvent[T]) GetBackfillData(ctx context.Context, portal *Portal) (*FetchMessagesResponse, error) { + return sre.BackfillData, nil +} + +func (sre *SimpleRemoteEvent[T]) CheckNeedsBackfill(ctx context.Context, latestMessage *database.Message) (bool, error) { + return sre.ResyncBackfillNeeded, nil +} + +func (sre *SimpleRemoteEvent[T]) GetChatInfo(ctx context.Context, portal *Portal) (*ChatInfo, error) { + return sre.ResyncChatInfo, nil +} From 5a7e002bcc4f6808c3291bebbcdb87a12a9d7ce4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 18 Jul 2024 17:37:22 +0300 Subject: [PATCH 0475/1647] bridgev2/backfill: do forward backfill after room creation --- bridgev2/portal.go | 66 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 58 insertions(+), 8 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a790324f..73282a40 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -39,8 +39,16 @@ type portalRemoteEvent struct { source *UserLogin } +type portalCreateEvent struct { + ctx context.Context + source *UserLogin + info *ChatInfo + cb func(error) +} + func (pme *portalMatrixEvent) isPortalEvent() {} func (pre *portalRemoteEvent) isPortalEvent() {} +func (pre *portalCreateEvent) isPortalEvent() {} type portalEvent interface { isPortalEvent() @@ -250,12 +258,32 @@ func (portal *Portal) eventLoop() { portal.handleMatrixEvent(evt.sender, evt.evt) case *portalRemoteEvent: portal.handleRemoteEvent(evt.source, evt.evt) + case *portalCreateEvent: + portal.handleCreateEvent(evt) default: panic(fmt.Errorf("illegal type %T in eventLoop", evt)) } } } +func (portal *Portal) handleCreateEvent(evt *portalCreateEvent) { + defer func() { + if err := recover(); err != nil { + logEvt := zerolog.Ctx(evt.ctx).Error() + if realErr, ok := err.(error); ok { + logEvt = logEvt.Err(realErr) + } else { + logEvt = logEvt.Any(zerolog.ErrorFieldName, err) + } + logEvt. + Bytes("stack", debug.Stack()). + Msg("Portal creation panicked") + evt.cb(fmt.Errorf("portal creation panicked")) + } + }() + evt.cb(portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info)) +} + func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) { logins, err := portal.Bridge.DB.UserPortal.GetAllForUserInPortal(ctx, user.MXID, portal.PortalKey) if err != nil { @@ -1109,17 +1137,16 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { log.Err(err).Msg("Failed to get chat info for portal creation from chat resync event") } } - err = portal.CreateMatrixRoom(ctx, source, info) + err = portal.createMatrixRoomInLoop(ctx, source, info) if err != nil { log.Err(err).Msg("Failed to create portal to handle event") // TODO error return } - // TODO if CreateMatrixRoom is changed to backfill immediately, there's no need to handle chat resyncs further - //if evtType == RemoteEventChatResync { - // log.Debug().Msg("Not handling chat resync event further as portal was created by it") - // return - //} + if evtType == RemoteEventChatResync { + log.Debug().Msg("Not handling chat resync event further as portal was created by it") + return + } } preHandler, ok := evt.(RemotePreHandler) if ok { @@ -2350,7 +2377,30 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us } } -func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, info *ChatInfo) error { +func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, info *ChatInfo) (retErr error) { + waiter := make(chan struct{}) + closed := false + portal.events <- &portalCreateEvent{ + ctx: ctx, + source: source, + info: info, + cb: func(err error) { + retErr = err + if !closed { + closed = true + close(waiter) + } + }, + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-waiter: + return + } +} + +func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLogin, info *ChatInfo) error { portal.roomCreateLock.Lock() defer portal.roomCreateLock.Unlock() if portal.MXID != "" { @@ -2510,7 +2560,6 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } } } - // TODO backfill portal? if portal.Parent == nil { userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) if err != nil { @@ -2525,6 +2574,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } } } + portal.doForwardBackfill(ctx, source, nil) return nil } From 910e3ee771c7a226686396d691a1034e7f88eb7c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 18 Jul 2024 17:53:20 +0300 Subject: [PATCH 0476/1647] bridgev2/backfill: create new timer every time --- bridgev2/backfillqueue.go | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go index f9cb8010..dd398dcc 100644 --- a/bridgev2/backfillqueue.go +++ b/bridgev2/backfillqueue.go @@ -18,7 +18,7 @@ import ( const BackfillMinBackoffAfterRoomCreate = 1 * time.Minute const BackfillQueueErrorBackoff = 1 * time.Minute -const BackfillQueueMinEmptyBackoff = 10 * time.Minute +const BackfillQueueMaxEmptyBackoff = 10 * time.Minute func (br *Bridge) WakeupBackfillQueue() { select { @@ -42,8 +42,8 @@ func (br *Bridge) RunBackfillQueue() { cancel() }() batchDelay := time.Duration(br.Config.Backfill.Queue.BatchDelay) * time.Second - afterTimer := time.NewTimer(batchDelay) log.Info().Stringer("batch_delay", batchDelay).Msg("Backfill queue starting") + noTasksFoundCount := 0 for { backfillTask, err := br.DB.BackfillTask.GetNext(ctx) if err != nil { @@ -52,22 +52,33 @@ func (br *Bridge) RunBackfillQueue() { continue } else if backfillTask != nil { br.doBackfillTask(ctx, backfillTask) + noTasksFoundCount = 0 } nextDelay := batchDelay - if backfillTask == nil { - nextDelay = max(BackfillQueueMinEmptyBackoff, batchDelay) + if noTasksFoundCount > 0 { + extraDelay := batchDelay * time.Duration(noTasksFoundCount) + nextDelay += min(BackfillQueueMaxEmptyBackoff, extraDelay) } - if !afterTimer.Stop() { - <-afterTimer.C - } - afterTimer.Reset(nextDelay) + timer := time.NewTimer(nextDelay) select { case <-br.wakeupBackfillQueue: + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + noTasksFoundCount = 0 case <-br.stopBackfillQueue: - afterTimer.Stop() + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } log.Info().Msg("Stopping backfill queue") return - case <-afterTimer.C: + case <-timer.C: } } } From b395abf62ed1d1f86058390b78544eb78fd1044e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 18 Jul 2024 18:22:12 +0300 Subject: [PATCH 0477/1647] bridgev2: add extra metadata to SendMessage calls --- bridgev2/commands/event.go | 7 ++- bridgev2/commands/login.go | 5 +-- bridgev2/disappear.go | 2 +- bridgev2/matrix/intent.go | 15 +++++-- bridgev2/matrixinterface.go | 9 +++- bridgev2/portal.go | 89 +++++++++++++++++++++---------------- bridgev2/queue.go | 2 +- 7 files changed, 77 insertions(+), 52 deletions(-) diff --git a/bridgev2/commands/event.go b/bridgev2/commands/event.go index 52d74512..258ae2f0 100644 --- a/bridgev2/commands/event.go +++ b/bridgev2/commands/event.go @@ -10,7 +10,6 @@ import ( "context" "fmt" "strings" - "time" "github.com/rs/zerolog" @@ -56,7 +55,7 @@ func (ce *Event) Reply(msg string, args ...any) { func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { content := format.RenderMarkdown(msg, allowMarkdown, allowHTML) content.MsgType = event.MsgNotice - _, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: &content}, time.Now()) + _, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: &content}, nil) if err != nil { ce.Log.Err(err).Msgf("Failed to reply to command") } @@ -72,7 +71,7 @@ func (ce *Event) React(key string) { Key: key, }, }, - }, time.Now()) + }, nil) if err != nil { ce.Log.Err(err).Msgf("Failed to react to command") } @@ -84,7 +83,7 @@ func (ce *Event) Redact(req ...mautrix.ReqRedact) { Parsed: &event.RedactionEventContent{ Redacts: ce.EventID, }, - }, time.Now()) + }, nil) if err != nil { ce.Log.Err(err).Msgf("Failed to redact command") } diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index f8c0e402..c4f4471b 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -13,7 +13,6 @@ import ( "net/http" "regexp" "strings" - "time" "github.com/skip2/go-qrcode" "golang.org/x/net/html" @@ -148,7 +147,7 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error { if *prevEventID != "" { content.SetEdit(*prevEventID) } - newEventID, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, time.Now()) + newEventID, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil) if err != nil { return err } @@ -196,7 +195,7 @@ func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, Parsed: &event.RedactionEventContent{ Redacts: *prevEvent, }, - }, time.Now()) + }, nil) *prevEvent = "" } if err != nil { diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go index 089c8aef..5f9900a5 100644 --- a/bridgev2/disappear.go +++ b/bridgev2/disappear.go @@ -91,7 +91,7 @@ func (dl *DisappearLoop) sleepAndDisappear(ctx context.Context, dms ...*database Redacts: msg.EventID, Reason: "Message disappeared", }, - }, time.Now()) + }, nil) if err != nil { zerolog.Ctx(ctx).Err(err).Stringer("target_event_id", msg.EventID).Msg("Failed to disappear message") } else { diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index da5e63ce..8f009f03 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -31,7 +31,10 @@ type ASIntent struct { var _ bridgev2.MatrixAPI = (*ASIntent)(nil) -func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) { +func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *bridgev2.MatrixSendExtra) (*mautrix.RespSendEvent, error) { + if extra == nil { + extra = &bridgev2.MatrixSendExtra{} + } // TODO remove this once hungryserv and synapse support sending m.room.redactions directly in all room versions if eventType == event.EventRedaction { parsedContent := content.Parsed.(*event.RedactionEventContent) @@ -45,7 +48,11 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) } else if encrypted { if as.Matrix.IsCustomPuppet { - as.Matrix.AddDoublePuppetValueWithTS(content, ts.UnixMilli()) + if extra.Timestamp.IsZero() { + as.Matrix.AddDoublePuppetValue(content) + } else { + as.Matrix.AddDoublePuppetValueWithTS(content, extra.Timestamp.UnixMilli()) + } } err = as.Connector.Crypto.Encrypt(ctx, roomID, eventType, content) if err != nil { @@ -54,10 +61,10 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType eventType = event.EventEncrypted } } - if ts.IsZero() { + if extra.Timestamp.IsZero() { return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content) } else { - return as.Matrix.SendMassagedMessageEvent(ctx, roomID, eventType, content, ts.UnixMilli()) + return as.Matrix.SendMassagedMessageEvent(ctx, roomID, eventType, content, extra.Timestamp.UnixMilli()) } } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 49f03c07..67a9e3e1 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -14,6 +14,7 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridge/status" + "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -56,10 +57,16 @@ type MatrixConnectorWithServer interface { GetRouter() *mux.Router } +type MatrixSendExtra struct { + Timestamp time.Time + MessageMeta *database.Message + ReactionMeta *database.Reaction +} + type MatrixAPI interface { GetMXID() id.UserID - SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) + SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *MatrixSendExtra) (*mautrix.RespSendEvent, error) SendState(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, ts time.Time) error MarkUnread(ctx context.Context, roomID id.RoomID, unread bool) error diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 73282a40..01c72d8c 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -885,7 +885,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi Parsed: &event.RedactionEventContent{ Redacts: existing.MXID, }, - }, time.Now()) + }, nil) if err != nil { log.Err(err).Msg("Failed to remove old reaction") } @@ -908,7 +908,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi Parsed: &event.RedactionEventContent{ Redacts: oldReaction.MXID, }, - }, time.Now()) + }, nil) if err != nil { log.Err(err).Msg("Failed to remove previous reaction after limit was exceeded") } @@ -1298,22 +1298,9 @@ func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.Mes output := make([]*database.Message, 0, len(converted.Parts)) for _, part := range converted.Parts { portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) - resp, err := intent.SendMessage(ctx, portal.MXID, part.Type, &event.Content{ - Parsed: part.Content, - Raw: part.Extra, - }, ts) - if err != nil { - logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to send message part to Matrix") - continue - } - logContext(log.Debug()). - Stringer("event_id", resp.EventID). - Str("part_id", string(part.ID)). - Msg("Sent message part to Matrix") dbMessage := &database.Message{ ID: id, PartID: part.ID, - MXID: resp.EventID, Room: portal.PortalKey, SenderID: sender.Sender, SenderMXID: intent.GetMXID(), @@ -1322,6 +1309,22 @@ func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.Mes ReplyTo: ptr.Val(converted.ReplyTo), Metadata: part.DBMetadata, } + resp, err := intent.SendMessage(ctx, portal.MXID, part.Type, &event.Content{ + Parsed: part.Content, + Raw: part.Extra, + }, &MatrixSendExtra{ + Timestamp: ts, + MessageMeta: dbMessage, + }) + if err != nil { + logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to send message part to Matrix") + continue + } + logContext(log.Debug()). + Stringer("event_id", resp.EventID). + Str("part_id", string(part.ID)). + Msg("Sent message part to Matrix") + dbMessage.MXID = resp.EventID err = portal.Bridge.DB.Message.Insert(ctx, dbMessage) if err != nil { logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to save message part to database") @@ -1377,7 +1380,9 @@ func (portal *Portal) sendRemoteErrorNotice(ctx context.Context, intent MatrixAP Raw: map[string]any{ "fi.mau.bridge.internal_error": err.Error(), }, - }, ts) + }, &MatrixSendExtra{ + Timestamp: ts, + }) if sendErr != nil { zerolog.Ctx(ctx).Err(sendErr).Msg("Failed to send error notice after remote event handling failed") } else { @@ -1418,7 +1423,10 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e Parsed: part.Content, Raw: part.TopLevelExtra, } - resp, err := intent.SendMessage(ctx, portal.MXID, part.Type, wrappedContent, ts) + resp, err := intent.SendMessage(ctx, portal.MXID, part.Type, wrappedContent, &MatrixSendExtra{ + Timestamp: ts, + MessageMeta: part.Part, + }) if err != nil { log.Err(err).Stringer("part_mxid", part.Part.MXID).Msg("Failed to edit message part") } else { @@ -1438,7 +1446,9 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e Redacts: part.MXID, }, } - resp, err := intent.SendMessage(ctx, portal.MXID, event.EventRedaction, redactContent, ts) + resp, err := intent.SendMessage(ctx, portal.MXID, event.EventRedaction, redactContent, &MatrixSendExtra{ + Timestamp: ts, + }) if err != nil { log.Err(err).Stringer("part_mxid", part.MXID).Msg("Failed to redact message part deleted in edit") } else { @@ -1500,6 +1510,20 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi if extraContentProvider, ok := evt.(RemoteReactionWithExtraContent); ok { extra = extraContentProvider.GetReactionExtraContent() } + dbReaction := &database.Reaction{ + Room: portal.PortalKey, + MessageID: targetMessage.ID, + MessagePartID: targetMessage.PartID, + SenderID: evt.GetSender().Sender, + EmojiID: emojiID, + Timestamp: ts, + } + if emojiID == "" { + dbReaction.Emoji = emoji + } + if metaProvider, ok := evt.(RemoteReactionWithMeta); ok { + dbReaction.Metadata = metaProvider.GetReactionDBMetadata() + } resp, err := intent.SendMessage(ctx, portal.MXID, event.EventReaction, &event.Content{ Parsed: &event.ReactionEventContent{ RelatesTo: event.RelatesTo{ @@ -1509,7 +1533,10 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi }, }, Raw: extra, - }, ts) + }, &MatrixSendExtra{ + Timestamp: ts, + ReactionMeta: dbReaction, + }) if err != nil { log.Err(err).Msg("Failed to send reaction to Matrix") return @@ -1517,21 +1544,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi log.Debug(). Stringer("event_id", resp.EventID). Msg("Sent reaction to Matrix") - dbReaction := &database.Reaction{ - Room: portal.PortalKey, - MessageID: targetMessage.ID, - MessagePartID: targetMessage.PartID, - SenderID: evt.GetSender().Sender, - EmojiID: emojiID, - MXID: resp.EventID, - Timestamp: ts, - } - if metaProvider, ok := evt.(RemoteReactionWithMeta); ok { - dbReaction.Metadata = metaProvider.GetReactionDBMetadata() - } - if emojiID == "" { - dbReaction.Emoji = emoji - } + dbReaction.MXID = resp.EventID err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) if err != nil { log.Err(err).Msg("Failed to save reaction to database") @@ -1541,7 +1554,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi Parsed: &event.RedactionEventContent{ Redacts: existingReaction.MXID, }, - }, ts) + }, &MatrixSendExtra{Timestamp: ts}) if err != nil { log.Err(err).Msg("Failed to redact old reaction") } @@ -1564,7 +1577,7 @@ func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *Us Parsed: &event.RedactionEventContent{ Redacts: targetReaction.MXID, }, - }, ts) + }, &MatrixSendExtra{Timestamp: ts}) if err != nil { log.Err(err).Stringer("reaction_mxid", targetReaction.MXID).Msg("Failed to redact reaction") } @@ -1591,7 +1604,7 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use Parsed: &event.RedactionEventContent{ Redacts: part.MXID, }, - }, ts) + }, &MatrixSendExtra{Timestamp: ts}) if err != nil { log.Err(err).Stringer("part_mxid", part.MXID).Msg("Failed to redact message part") } else { @@ -2270,7 +2283,7 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat } _, err := sender.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{ Parsed: content, - }, ts) + }, &MatrixSendExtra{Timestamp: ts}) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to send disappearing messages notice") } else { diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 6254fd62..e5220227 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -129,7 +129,7 @@ func (br *Bridge) handleBotInvite(ctx context.Context, evt *event.Event, sender } _, err = br.Bot.SendMessage(ctx, evt.RoomID, event.EventMessage, &event.Content{ Parsed: format.RenderMarkdown(message, true, false), - }, time.Time{}) + }, nil) if err != nil { log.Err(err).Msg("Failed to send welcome message to room") } From a4b0b55db29d686cbc8562f93ba1230e3c93cc5a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 18 Jul 2024 19:57:43 +0300 Subject: [PATCH 0478/1647] bridgev2/backfill: add support for reactions --- bridgev2/matrix/connector.go | 9 +++- bridgev2/matrixinterface.go | 3 +- bridgev2/networkinterface.go | 2 +- bridgev2/portal.go | 90 ++++++++++++++++++++--------------- bridgev2/portalbackfill.go | 92 ++++++++++++++++++++++++++++++++---- 5 files changed, 147 insertions(+), 49 deletions(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 43be55b4..ca459e98 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -27,6 +27,7 @@ import ( "go.mau.fi/util/dbutil" _ "go.mau.fi/util/dbutil/litestream" "go.mau.fi/util/exsync" + "go.mau.fi/util/random" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" @@ -34,6 +35,7 @@ import ( "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/bridgev2/commands" + "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -526,7 +528,7 @@ func (br *Connector) GetMemberInfo(ctx context.Context, roomID id.RoomID, userID return br.AS.StateStore.GetMember(ctx, roomID, userID) } -func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBeeperBatchSend) (*mautrix.RespBeeperBatchSend, error) { +func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBeeperBatchSend, extras []*bridgev2.MatrixSendExtra) (*mautrix.RespBeeperBatchSend, error) { if encrypted, err := br.StateStore.IsEncrypted(ctx, roomID); err != nil { return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) } else if encrypted { @@ -563,6 +565,11 @@ func (br *Connector) GenerateDeterministicEventID(roomID id.RoomID, _ networkid. return id.EventID(unsafe.String(unsafe.SliceData(eventID), len(eventID))) } +func (br *Connector) GenerateReactionEventID(roomID id.RoomID, targetMessage *database.Message, sender networkid.UserID, emojiID networkid.EmojiID) id.EventID { + // We don't care about determinism for reactions + return id.EventID(fmt.Sprintf("$%s:%s", base64.RawURLEncoding.EncodeToString(random.Bytes(32)), br.deterministicEventIDServer)) +} + func (br *Connector) ServerName() string { return br.Config.Homeserver.Domain } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 67a9e3e1..e25d0f06 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -46,8 +46,9 @@ type MatrixConnector interface { GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) GetMemberInfo(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) - BatchSend(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBeeperBatchSend) (*mautrix.RespBeeperBatchSend, error) + BatchSend(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBeeperBatchSend, extras []*MatrixSendExtra) (*mautrix.RespBeeperBatchSend, error) GenerateDeterministicEventID(roomID id.RoomID, portalKey networkid.PortalKey, messageID networkid.MessageID, partID networkid.PartID) id.EventID + GenerateReactionEventID(roomID id.RoomID, targetMessage *database.Message, sender networkid.UserID, emojiID networkid.EmojiID) id.EventID ServerName() string } diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index fe7d2350..48f6d29a 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -357,7 +357,7 @@ type BackfillReaction struct { EmojiID networkid.EmojiID Emoji string ExtraContent map[string]any - DBMetadata map[string]any + DBMetadata any } // BackfillMessage is an individual message in a history pagination request. diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 01c72d8c..b34b9a1c 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1510,45 +1510,11 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi if extraContentProvider, ok := evt.(RemoteReactionWithExtraContent); ok { extra = extraContentProvider.GetReactionExtraContent() } - dbReaction := &database.Reaction{ - Room: portal.PortalKey, - MessageID: targetMessage.ID, - MessagePartID: targetMessage.PartID, - SenderID: evt.GetSender().Sender, - EmojiID: emojiID, - Timestamp: ts, - } - if emojiID == "" { - dbReaction.Emoji = emoji - } + var dbMetadata any if metaProvider, ok := evt.(RemoteReactionWithMeta); ok { - dbReaction.Metadata = metaProvider.GetReactionDBMetadata() - } - resp, err := intent.SendMessage(ctx, portal.MXID, event.EventReaction, &event.Content{ - Parsed: &event.ReactionEventContent{ - RelatesTo: event.RelatesTo{ - Type: event.RelAnnotation, - EventID: targetMessage.MXID, - Key: variationselector.Add(emoji), - }, - }, - Raw: extra, - }, &MatrixSendExtra{ - Timestamp: ts, - ReactionMeta: dbReaction, - }) - if err != nil { - log.Err(err).Msg("Failed to send reaction to Matrix") - return - } - log.Debug(). - Stringer("event_id", resp.EventID). - Msg("Sent reaction to Matrix") - dbReaction.MXID = resp.EventID - err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) - if err != nil { - log.Err(err).Msg("Failed to save reaction to database") + dbMetadata = metaProvider.GetReactionDBMetadata() } + portal.sendConvertedReaction(ctx, evt.GetSender().Sender, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extra, nil) if existingReaction != nil { _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ @@ -1561,6 +1527,56 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi } } +func (portal *Portal) sendConvertedReaction( + ctx context.Context, senderID networkid.UserID, intent MatrixAPI, targetMessage *database.Message, + emojiID networkid.EmojiID, emoji string, ts time.Time, dbMetadata any, extraContent map[string]any, + logContext func(*zerolog.Event) *zerolog.Event, +) { + if logContext == nil { + logContext = func(e *zerolog.Event) *zerolog.Event { + return e + } + } + log := zerolog.Ctx(ctx) + dbReaction := &database.Reaction{ + Room: portal.PortalKey, + MessageID: targetMessage.ID, + MessagePartID: targetMessage.PartID, + SenderID: senderID, + EmojiID: emojiID, + Timestamp: ts, + Metadata: dbMetadata, + } + if emojiID == "" { + dbReaction.Emoji = emoji + } + resp, err := intent.SendMessage(ctx, portal.MXID, event.EventReaction, &event.Content{ + Parsed: &event.ReactionEventContent{ + RelatesTo: event.RelatesTo{ + Type: event.RelAnnotation, + EventID: targetMessage.MXID, + Key: variationselector.Add(emoji), + }, + }, + Raw: extraContent, + }, &MatrixSendExtra{ + Timestamp: ts, + ReactionMeta: dbReaction, + }) + if err != nil { + logContext(log.Err(err)).Msg("Failed to send reaction to Matrix") + return + } + logContext(log.Debug()). + Stringer("event_id", resp.EventID). + Msg("Sent reaction to Matrix") + dbReaction.MXID = resp.EventID + err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) + if err != nil { + logContext(log.Err(err)).Msg("Failed to save reaction to database") + } +} + func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) { log := zerolog.Ctx(ctx) targetReaction, err := portal.getTargetReaction(ctx, evt) diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index d23341dd..93e1aee5 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -9,10 +9,12 @@ package bridgev2 import ( "context" "fmt" + "slices" "time" "github.com/rs/zerolog" "go.mau.fi/util/ptr" + "go.mau.fi/util/variationselector" "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2/database" @@ -238,6 +240,8 @@ func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages } prevThreadEvents := make(map[networkid.MessageID]id.EventID) dbMessages := make([]*database.Message, 0, len(messages)) + dbReactions := make([]*database.Reaction, 0) + extras := make([]*MatrixSendExtra, 0, len(messages)) var disappearingMessages []*database.DisappearingMessage for _, msg := range messages { intent := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) @@ -245,7 +249,10 @@ func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages if threadRoot != nil && prevThreadEvents[*msg.ThreadRoot] != "" { prevThreadEvent.MXID = prevThreadEvents[*msg.ThreadRoot] } + var partIDs []networkid.PartID + partMap := make(map[networkid.PartID]*database.Message, len(msg.Parts)) for _, part := range msg.Parts { + partIDs = append(partIDs, part.ID) portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID) req.Events = append(req.Events, &event.Event{ @@ -259,7 +266,7 @@ func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages Raw: part.Extra, }, }) - dbMessages = append(dbMessages, &database.Message{ + dbMessage := &database.Message{ ID: msg.ID, PartID: part.ID, MXID: evtID, @@ -270,7 +277,10 @@ func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages ThreadRoot: ptr.Val(msg.ThreadRoot), ReplyTo: ptr.Val(msg.ReplyTo), Metadata: part.DBMetadata, - }) + } + partMap[part.ID] = dbMessage + extras = append(extras, &MatrixSendExtra{MessageMeta: dbMessage}) + dbMessages = append(dbMessages, dbMessage) if prevThreadEvent != nil { prevThreadEvent.MXID = evtID prevThreadEvents[*msg.ThreadRoot] = evtID @@ -286,11 +296,53 @@ func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages }) } } - // TODO handle reactions - //for _, reaction := range msg.Reactions { - //} + slices.Sort(partIDs) + for _, reaction := range msg.Reactions { + reactionIntent := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReactionRemove) + if reaction.TargetPart == nil { + reaction.TargetPart = &partIDs[0] + } + if reaction.Timestamp.IsZero() { + reaction.Timestamp = msg.Timestamp.Add(10 * time.Millisecond) + } + targetPart, ok := partMap[*reaction.TargetPart] + if !ok { + // TODO warning log and/or skip reaction? + } + reactionMXID := portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, targetPart, reaction.Sender.Sender, reaction.EmojiID) + dbReaction := &database.Reaction{ + Room: portal.PortalKey, + MessageID: msg.ID, + MessagePartID: *reaction.TargetPart, + SenderID: reaction.Sender.Sender, + EmojiID: reaction.EmojiID, + MXID: reactionMXID, + Timestamp: reaction.Timestamp, + Emoji: reaction.Emoji, + Metadata: reaction.DBMetadata, + } + req.Events = append(req.Events, &event.Event{ + Sender: reactionIntent.GetMXID(), + Type: event.EventReaction, + Timestamp: reaction.Timestamp.UnixMilli(), + ID: reactionMXID, + RoomID: portal.MXID, + Content: event.Content{ + Parsed: &event.ReactionEventContent{ + RelatesTo: event.RelatesTo{ + Type: event.RelAnnotation, + EventID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, *reaction.TargetPart), + Key: variationselector.Add(reaction.Emoji), + }, + }, + Raw: reaction.ExtraContent, + }, + }) + dbReactions = append(dbReactions, dbReaction) + extras = append(extras, &MatrixSendExtra{ReactionMeta: dbReaction}) + } } - _, err := portal.Bridge.Matrix.BatchSend(ctx, portal.MXID, req) + _, err := portal.Bridge.Matrix.BatchSend(ctx, portal.MXID, req, extras) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to send backfill messages") } @@ -326,10 +378,32 @@ func (portal *Portal) sendLegacyBackfill(ctx context.Context, source *UserLogin, }) if len(dbMessages) > 0 { lastPart = dbMessages[len(dbMessages)-1].MXID + for _, reaction := range msg.Reactions { + reactionIntent := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReaction) + targetPart := dbMessages[0] + if reaction.TargetPart != nil { + targetPartIdx := slices.IndexFunc(dbMessages, func(dbMsg *database.Message) bool { + return dbMsg.PartID == *reaction.TargetPart + }) + if targetPartIdx != -1 { + targetPart = dbMessages[targetPartIdx] + } else { + // TODO warning log and/or skip reaction? + } + } + portal.sendConvertedReaction( + ctx, reaction.Sender.Sender, reactionIntent, targetPart, reaction.EmojiID, reaction.Emoji, + reaction.Timestamp, reaction.DBMetadata, reaction.ExtraContent, + func(z *zerolog.Event) *zerolog.Event { + return z. + Str("target_message_id", string(msg.ID)). + Str("target_part_id", string(targetPart.PartID)). + Any("reaction_sender_id", reaction.Sender). + Time("reaction_ts", reaction.Timestamp) + }, + ) + } } - // TODO handle reactions - //for _, reaction := range msg.Reactions { - //} } if markRead { dp := source.User.DoublePuppet(ctx) From 26776481884e7f342cd57e7fcf90e46ab18d3445 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Jul 2024 12:48:07 +0300 Subject: [PATCH 0479/1647] bridgev2/backfill: wake up backfill queue after creating task --- bridgev2/portal.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index b34b9a1c..27eed5e0 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2393,6 +2393,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to ensure backfill queue task exists") } + // TODO wake up backfill queue if task was just created } if info.ExtraUpdates != nil { changed = info.ExtraUpdates(ctx, portal) || changed @@ -2563,6 +2564,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo if err != nil { log.Err(err).Msg("Failed to create backfill queue task after creating room") } + portal.Bridge.WakeupBackfillQueue() } if portal.Parent != nil { if portal.Parent.MXID != "" { From 804fd19bb9f3e081e1f3d7de3ef47ad434968adc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Jul 2024 13:18:30 +0300 Subject: [PATCH 0480/1647] bridgev2/legacymigrate: move post-migration DM portal fixing from slack --- bridgev2/bridge.go | 17 +++++ bridgev2/matrix/mxmain/legacymigrate.go | 95 +++++++++++++++++++++++++ bridgev2/matrix/mxmain/main.go | 11 ++- 3 files changed, 122 insertions(+), 1 deletion(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 76a5d2c8..1bc364a1 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -104,6 +104,18 @@ func (e DBUpgradeError) Unwrap() error { } func (br *Bridge) Start() error { + err := br.StartConnectors() + if err != nil { + return err + } + err = br.StartLogins() + if err != nil { + return err + } + return nil +} + +func (br *Bridge) StartConnectors() error { br.Log.Info().Msg("Starting bridge") ctx := br.Log.WithContext(context.Background()) @@ -124,6 +136,11 @@ func (br *Bridge) Start() error { if br.Network.GetCapabilities().DisappearingMessages { go br.DisappearLoop.Start() } + return nil +} + +func (br *Bridge) StartLogins() error { + ctx := br.Log.WithContext(context.Background()) userIDs, err := br.DB.UserLogin.GetAllUserIDsWithLogins(ctx) if err != nil { diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index 8b62708d..d1f0f279 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -12,6 +12,15 @@ import ( "database/sql" "errors" "fmt" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/appservice" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/matrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" ) func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery string, newDBVersion int) func(ctx context.Context) error { @@ -43,6 +52,10 @@ func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery strin if err != nil { return err } + _, err = br.DB.Exec(ctx, "CREATE TABLE database_was_migrated()") + if err != nil { + return err + } return nil } @@ -96,3 +109,85 @@ func (br *BridgeMain) CheckLegacyDB(expectedVersion int, minBridgeVersion, first log.Info().Msg("Successfully migrated legacy database") } } + +func (br *BridgeMain) postMigrateDMPortal(ctx context.Context, portal *bridgev2.Portal) error { + otherUserID := portal.OtherUserID + if otherUserID == "" { + zerolog.Ctx(ctx).Warn(). + Str("portal_id", string(portal.ID)). + Msg("DM portal has no other user ID") + return nil + } + ghost, err := br.Bridge.GetGhostByID(ctx, otherUserID) + if err != nil { + return fmt.Errorf("failed to get ghost for %s: %w", otherUserID, err) + } + mx := ghost.Intent.(*matrix.ASIntent).Matrix + err = br.Matrix.Bot.EnsureJoined(ctx, portal.MXID, appservice.EnsureJoinedParams{ + BotOverride: mx.Client, + }) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Str("portal_id", string(portal.ID)). + Stringer("room_id", portal.MXID). + Msg("Failed to ensure bot is joined to DM") + } + pls, err := mx.PowerLevels(ctx, portal.MXID) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Str("portal_id", string(portal.ID)). + Stringer("room_id", portal.MXID). + Msg("Failed to get power levels in room") + } else { + userLevel := pls.GetUserLevel(mx.UserID) + pls.EnsureUserLevel(br.Matrix.Bot.UserID, userLevel) + if userLevel > 50 { + pls.SetUserLevel(mx.UserID, 50) + } + _, err = mx.SetPowerLevels(ctx, portal.MXID, pls) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Str("portal_id", string(portal.ID)). + Stringer("room_id", portal.MXID). + Msg("Failed to set power levels") + } + } + return nil +} + +func (br *BridgeMain) PostMigrate(ctx context.Context) error { + wasMigrated, err := br.DB.TableExists(ctx, "database_was_migrated") + if err != nil { + return fmt.Errorf("failed to check if database_was_migrated table exists: %w", err) + } else if !wasMigrated { + return nil + } + zerolog.Ctx(ctx).Info().Msg("Doing post-migration updates to Matrix rooms") + + portals, err := br.Bridge.GetAllPortalsWithMXID(ctx) + if err != nil { + return fmt.Errorf("failed to get all portals: %w", err) + } + for _, portal := range portals { + switch portal.RoomType { + case database.RoomTypeDM: + err = br.postMigrateDMPortal(ctx, portal) + if err != nil { + return fmt.Errorf("failed to update DM portal %s: %w", portal.MXID, err) + } + } + _, err = br.Matrix.Bot.SendStateEvent(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.ElementFunctionalMembersContent{ + ServiceMembers: []id.UserID{br.Matrix.Bot.UserID}, + }) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err).Stringer("room_id", portal.MXID).Msg("Failed to set service members") + } + } + + _, err = br.DB.Exec(ctx, "DROP TABLE database_was_migrated") + if err != nil { + return fmt.Errorf("failed to drop database_was_migrated table: %w", err) + } + zerolog.Ctx(ctx).Info().Msg("Post-migration updates complete") + return nil +} diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 1a7c5217..f02decb6 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -8,6 +8,7 @@ package mxmain import ( + "context" _ "embed" "encoding/json" "errors" @@ -358,7 +359,7 @@ func (br *BridgeMain) LoadConfig() { // Start starts the bridge after everything has been initialized. // This is called by [Run] and does not need to be called manually. func (br *BridgeMain) Start() { - err := br.Bridge.Start() + err := br.Bridge.StartConnectors() if err != nil { var dbUpgradeErr bridgev2.DBUpgradeError if errors.As(err, &dbUpgradeErr) { @@ -367,6 +368,14 @@ func (br *BridgeMain) Start() { br.Log.Fatal().Err(err).Msg("Failed to start bridge") } } + err = br.PostMigrate(br.Log.WithContext(context.Background())) + if err != nil { + br.Log.Fatal().Err(err).Msg("Failed to run post-migration updates") + } + err = br.Bridge.StartLogins() + if err != nil { + br.Log.Fatal().Err(err).Msg("Failed to start existing user logins") + } if br.PostStart != nil { br.PostStart() } From b8a067206a08b285298983aaef1717b63809c7e4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Jul 2024 13:56:32 +0300 Subject: [PATCH 0481/1647] bridgev2: implement `private_chat_portal_meta` option --- bridgev2/bridgeconfig/config.go | 1 + bridgev2/bridgeconfig/upgrade.go | 6 +++ bridgev2/database/portal.go | 20 +++++--- bridgev2/database/upgrades/00-latest.sql | 3 +- .../upgrades/14-portal-name-custom.sql | 2 + bridgev2/ghost.go | 17 +++++++ bridgev2/matrix/mxmain/example-config.yaml | 3 ++ bridgev2/matrix/mxmain/legacymigrate.go | 1 + bridgev2/portal.go | 51 ++++++++++++++++++- 9 files changed, 95 insertions(+), 9 deletions(-) create mode 100644 bridgev2/database/upgrades/14-portal-name-custom.sql diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 8c899ad7..3bd173b1 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -34,6 +34,7 @@ type Config struct { type BridgeConfig struct { CommandPrefix string `yaml:"command_prefix"` PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` + PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` Relay RelayConfig `yaml:"relay"` Permissions PermissionConfig `yaml:"permissions"` Backfill BackfillConfig `yaml:"backfill"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 04f7dab3..07065fc6 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -24,6 +24,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Str, "bridge", "command_prefix") helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") + helper.Copy(up.Bool, "bridge", "private_chat_portal_meta") helper.Copy(up.Bool, "bridge", "relay", "enabled") helper.Copy(up.Bool, "bridge", "relay", "admin_only") helper.Copy(up.List, "bridge", "relay", "default_relays") @@ -183,6 +184,11 @@ func doMigrateLegacy(helper up.Helper) { helper.Copy(up.Str, "bridge", "command_prefix") helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") + if oldPM, ok := helper.Get(up.Str, "bridge", "private_chat_portal_meta"); ok && (oldPM == "default" || oldPM == "always") { + helper.Set(up.Bool, "true", "bridge", "private_chat_portal_meta") + } else { + helper.Set(up.Bool, "false", "bridge", "private_chat_portal_meta") + } helper.Copy(up.Bool, "bridge", "relay", "enabled") helper.Copy(up.Bool, "bridge", "relay", "admin_only") helper.Copy(up.Map, "bridge", "permissions") diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 2f675593..18e50780 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -50,6 +50,7 @@ type Portal struct { NameSet bool TopicSet bool AvatarSet bool + NameIsCustom bool InSpace bool RoomType RoomType Disappear DisappearingSetting @@ -60,7 +61,7 @@ const ( getPortalBaseQuery = ` SELECT bridge_id, id, receiver, mxid, parent_id, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, - name_set, topic_set, avatar_set, in_space, + name_set, topic_set, avatar_set, name_is_custom, in_space, room_type, disappear_type, disappear_timer, metadata FROM portal @@ -69,6 +70,7 @@ const ( getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')` getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL` + getAllDMPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND other_user_id<>''` getAllPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1` getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2` @@ -79,11 +81,11 @@ const ( bridge_id, id, receiver, mxid, parent_id, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, - name_set, avatar_set, topic_set, in_space, + name_set, avatar_set, topic_set, name_is_custom, in_space, room_type, disappear_type, disappear_timer, metadata, relay_bridge_id ) VALUES ( - $1, $2, $3, $4, $5, cast($6 AS TEXT), $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, + $1, $2, $3, $4, $5, cast($6 AS TEXT), $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, CASE WHEN cast($6 AS TEXT) IS NULL THEN NULL ELSE $1 END ) ` @@ -91,8 +93,8 @@ const ( UPDATE portal SET mxid=$4, parent_id=$5, relay_login_id=cast($6 AS TEXT), relay_bridge_id=CASE WHEN cast($6 AS TEXT) IS NULL THEN NULL ELSE bridge_id END, other_user_id=$7, name=$8, topic=$9, avatar_id=$10, avatar_hash=$11, avatar_mxc=$12, - name_set=$13, avatar_set=$14, topic_set=$15, in_space=$16, - room_type=$17, disappear_type=$18, disappear_timer=$19, metadata=$20 + name_set=$13, avatar_set=$14, topic_set=$15, name_is_custom=$16, in_space=$17, + room_type=$18, disappear_type=$19, disappear_timer=$20, metadata=$21 WHERE bridge_id=$1 AND id=$2 AND receiver=$3 ` deletePortalQuery = ` @@ -130,6 +132,10 @@ func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) { return pq.QueryMany(ctx, getAllPortalsQuery, pq.BridgeID) } +func (pq *PortalQuery) GetAllDMsWith(ctx context.Context, otherUserID networkid.UserID) ([]*Portal, error) { + return pq.QueryMany(ctx, getAllDMPortalsQuery, pq.BridgeID, otherUserID) +} + func (pq *PortalQuery) GetChildren(ctx context.Context, parentID networkid.PortalID) ([]*Portal, error) { return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentID) } @@ -160,7 +166,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { &p.BridgeID, &p.ID, &p.Receiver, &mxid, &parentID, &relayLoginID, &otherUserID, &p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC, - &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.InSpace, + &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, &p.RoomType, &disappearType, &disappearTimer, dbutil.JSON{Data: p.Metadata}, ) @@ -202,7 +208,7 @@ func (p *Portal) sqlVariables() []any { p.BridgeID, p.ID, p.Receiver, dbutil.StrPtr(p.MXID), dbutil.StrPtr(p.ParentID), dbutil.StrPtr(p.RelayLoginID), dbutil.StrPtr(p.OtherUserID), p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC, - p.NameSet, p.TopicSet, p.AvatarSet, p.InSpace, + p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, p.RoomType, dbutil.StrPtr(p.Disappear.Type), dbutil.NumPtr(p.Disappear.Timer), dbutil.JSON{Data: p.Metadata}, } diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index d02964a2..304bb00e 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v13 (compatible with v9+): Latest revision +-- v0 -> v14 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -47,6 +47,7 @@ CREATE TABLE portal ( name_set BOOLEAN NOT NULL, avatar_set BOOLEAN NOT NULL, topic_set BOOLEAN NOT NULL, + name_is_custom BOOLEAN NOT NULL DEFAULT false, in_space BOOLEAN NOT NULL, room_type TEXT NOT NULL, disappear_type TEXT, diff --git a/bridgev2/database/upgrades/14-portal-name-custom.sql b/bridgev2/database/upgrades/14-portal-name-custom.sql new file mode 100644 index 00000000..2c8dfc8f --- /dev/null +++ b/bridgev2/database/upgrades/14-portal-name-custom.sql @@ -0,0 +1,2 @@ +-- v14 (compatible with v9+): Save whether name is custom in portals +ALTER TABLE portal ADD COLUMN name_is_custom BOOLEAN NOT NULL DEFAULT false; diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index 125cd9c0..bcc905ab 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -230,6 +230,20 @@ func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin } } +func (ghost *Ghost) updateDMPortals(ctx context.Context) { + if !ghost.Bridge.Config.PrivateChatPortalMeta { + return + } + dmPortals, err := ghost.Bridge.GetDMPortalsWith(ctx, ghost.ID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get DM portals to update info") + return + } + for _, portal := range dmPortals { + go portal.lockedUpdateInfoFromGhost(ctx, ghost) + } +} + func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) { update := false if info.Name != nil { @@ -238,6 +252,9 @@ func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) { if info.Avatar != nil { update = ghost.UpdateAvatar(ctx, info.Avatar) || update } + if update { + ghost.updateDMPortals(ctx) + } if info.Identifiers != nil || info.IsBot != nil { update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot) || update } diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 92d4647c..16105cb0 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -4,6 +4,8 @@ bridge: command_prefix: '$<>' # Should the bridge create a space for each login containing the rooms that account is in? personal_filtering_spaces: true + # Whether the bridge should set names and avatars explicitly for DM portals. + private_chat_portal_meta: false # Settings for relay mode relay: @@ -139,6 +141,7 @@ matrix: delivery_receipts: false # Whether the bridge should send error notices via m.notice events when a message fails to bridge. message_error_notices: true + # Whether the bridge should update the m.direct account data event when double puppeting is enabled. sync_direct_chat_list: false # Whether created rooms should have federation enabled. If false, created portal rooms # will never be federated. Changing this option requires recreating rooms. diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index d1f0f279..302bf300 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -152,6 +152,7 @@ func (br *BridgeMain) postMigrateDMPortal(ctx context.Context, portal *bridgev2. Msg("Failed to set power levels") } } + portal.UpdateInfoFromGhost(ctx, ghost) return nil } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 27eed5e0..38a3ec5c 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -217,6 +217,16 @@ func (br *Bridge) GetAllPortals(ctx context.Context) ([]*Portal, error) { return br.loadManyPortals(ctx, rows) } +func (br *Bridge) GetDMPortalsWith(ctx context.Context, otherUserID networkid.UserID) ([]*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + rows, err := br.DB.Portal.GetAllDMsWith(ctx, otherUserID) + if err != nil { + return nil, err + } + return br.loadManyPortals(ctx, rows) +} + func (br *Bridge) GetPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() @@ -2005,7 +2015,16 @@ func (portal *Portal) sendRoomMeta(ctx context.Context, sender MatrixAPI, ts tim if portal.MXID == "" { return false } - _, err := portal.sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, &event.Content{Parsed: content}, ts) + var extra map[string]any + if !portal.NameIsCustom && (eventType == event.StateRoomName || eventType == event.StateRoomAvatar) { + extra = map[string]any{ + "fi.mau.implicit_name": true, + } + } + _, err := portal.sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, &event.Content{ + Parsed: content, + Raw: extra, + }, ts) if err != nil { zerolog.Ctx(ctx).Err(err). Str("event_type", eventType.Type). @@ -2345,17 +2364,47 @@ func (portal *Portal) UpdateParent(ctx context.Context, newParent networkid.Port return true } +func (portal *Portal) lockedUpdateInfoFromGhost(ctx context.Context, ghost *Ghost) { + portal.roomCreateLock.Lock() + defer portal.roomCreateLock.Unlock() + portal.UpdateInfoFromGhost(ctx, ghost) +} + +func (portal *Portal) UpdateInfoFromGhost(ctx context.Context, ghost *Ghost) (changed bool) { + if portal.NameIsCustom || !portal.Bridge.Config.PrivateChatPortalMeta || (portal.OtherUserID == "" && ghost == nil) || portal.RoomType != database.RoomTypeDM { + return + } + var err error + if ghost == nil { + ghost, err = portal.Bridge.GetGhostByID(ctx, portal.OtherUserID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost to update info from") + return + } + } + changed = portal.UpdateName(ctx, ghost.Name, nil, time.Time{}) || changed + changed = portal.UpdateAvatar(ctx, &Avatar{ + ID: ghost.AvatarID, + MXC: ghost.AvatarMXC, + Hash: ghost.AvatarHash, + }, nil, time.Time{}) || changed + return +} + func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *UserLogin, sender MatrixAPI, ts time.Time) { changed := false if info.Name != nil { + portal.NameIsCustom = true changed = portal.UpdateName(ctx, *info.Name, sender, ts) || changed } if info.Topic != nil { changed = portal.UpdateTopic(ctx, *info.Topic, sender, ts) || changed } if info.Avatar != nil { + portal.NameIsCustom = true changed = portal.UpdateAvatar(ctx, info.Avatar, sender, ts) || changed } + changed = portal.UpdateInfoFromGhost(ctx, nil) || changed if info.Disappear != nil { changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, sender, ts, false, false) || changed } From b881a7d4551114b45bfd62030526b86b78d7eb71 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Jul 2024 14:07:19 +0300 Subject: [PATCH 0482/1647] bridgev2: fix bugs in avatar handling --- bridgev2/ghost.go | 4 +++- bridgev2/portal.go | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index bcc905ab..bbfa1a19 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -146,11 +146,13 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool { if !avatar.Remove { newMXC, newHash, err := avatar.Reupload(ctx, ghost.Intent, ghost.AvatarHash) if err != nil { + ghost.AvatarSet = false zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload avatar") return true - } else if newHash == ghost.AvatarHash { + } else if newHash == ghost.AvatarHash && ghost.AvatarSet { return true } + ghost.AvatarHash = newHash ghost.AvatarMXC = newMXC } else { ghost.AvatarMXC = "" diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 38a3ec5c..5ac1a0a7 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1939,7 +1939,7 @@ func (portal *Portal) UpdateAvatar(ctx context.Context, avatar *Avatar, sender M portal.AvatarSet = false zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload room avatar") return true - } else if newHash == portal.AvatarHash { + } else if newHash == portal.AvatarHash && portal.AvatarSet { return true } portal.AvatarMXC = newMXC From 81028a6a08e34980c1e486237a7160fd947a5d32 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Jul 2024 16:04:55 +0300 Subject: [PATCH 0483/1647] bridgev2: add interfaces for mutes, tags and marked unread bridging --- bridgev2/matrix/connector.go | 2 +- bridgev2/matrix/intent.go | 15 +++++++++++- bridgev2/networkinterface.go | 25 ++++++++++++++++---- bridgev2/portal.go | 44 +++++++++++++++++++++++++++++++++++- event/accountdata.go | 20 ++++++++++++++++ event/content.go | 1 + event/type.go | 1 + versions.go | 1 + 8 files changed, 101 insertions(+), 8 deletions(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index ca459e98..00538cb4 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -420,7 +420,7 @@ func (br *Connector) SendMessageStatus(ctx context.Context, ms *bridgev2.Message } func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2.MessageStatus, evt *bridgev2.MessageStatusEventInfo, editEvent id.EventID) id.EventID { - if evt.EventType.IsEphemeral() { + if evt.EventType.IsEphemeral() || evt.EventID == "" { return "" } log := zerolog.Ctx(ctx) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 8f009f03..ab82da2a 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -356,7 +356,20 @@ func (as *ASIntent) TagRoom(ctx context.Context, roomID id.RoomID, tag event.Roo } func (as *ASIntent) MuteRoom(ctx context.Context, roomID id.RoomID, until time.Time) error { - if !until.IsZero() && until.Before(time.Now()) { + var mutedUntil int64 + if until.Before(time.Now()) { + mutedUntil = 0 + } else if until == event.MutedForever { + mutedUntil = -1 + } else { + mutedUntil = until.UnixMilli() + } + if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureAccountDataMute) { + return as.Matrix.SetRoomAccountData(ctx, roomID, event.AccountDataBeeperMute.Type, &event.BeeperMuteEventContent{ + MutedUntil: mutedUntil, + }) + } + if mutedUntil == 0 { err := as.Matrix.DeletePushRule(ctx, "global", pushrules.RoomRule, string(roomID)) // If the push rule doesn't exist, everything is fine if errors.Is(err, mautrix.MNotFound) { diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 48f6d29a..1b759236 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -457,6 +457,21 @@ type TypingHandlingNetworkAPI interface { HandleMatrixTyping(ctx context.Context, msg *MatrixTyping) error } +type MarkedUnreadHandlingNetworkAPI interface { + NetworkAPI + HandleMarkedUnread(ctx context.Context, msg *MatrixMarkedUnread) error +} + +type MuteHandlingNetworkAPI interface { + NetworkAPI + HandleMute(ctx context.Context, msg *MatrixMute) error +} + +type TagHandlingNetworkAPI interface { + NetworkAPI + HandleRoomTag(ctx context.Context, msg *MatrixRoomTag) error +} + // RoomNameHandlingNetworkAPI is an optional interface that network connectors can implement to handle room name changes. type RoomNameHandlingNetworkAPI interface { NetworkAPI @@ -832,11 +847,7 @@ type MatrixMessageRemove struct { TargetMessage *database.Message } -type RoomMetaEventContent interface { - *event.RoomNameEventContent | *event.RoomAvatarEventContent | *event.TopicEventContent -} - -type MatrixRoomMeta[ContentType RoomMetaEventContent] struct { +type MatrixRoomMeta[ContentType any] struct { MatrixEventBase[ContentType] PrevContent ContentType } @@ -865,3 +876,7 @@ type MatrixTyping struct { IsTyping bool Type TypingType } + +type MatrixMarkedUnread = MatrixRoomMeta[*event.MarkedUnreadEventContent] +type MatrixMute = MatrixRoomMeta[*event.BeeperMuteEventContent] +type MatrixRoomTag = MatrixRoomMeta[*event.TagEventContent] diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 5ac1a0a7..a108a110 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -429,6 +429,13 @@ func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { case event.StateRoomAvatar: handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar) case event.StateEncryption: + // TODO? + case event.AccountDataMarkedUnread: + handleMatrixAccountData(portal, ctx, login, evt, MarkedUnreadHandlingNetworkAPI.HandleMarkedUnread) + case event.AccountDataRoomTags: + handleMatrixAccountData(portal, ctx, login, evt, TagHandlingNetworkAPI.HandleRoomTag) + case event.AccountDataBeeperMute: + handleMatrixAccountData(portal, ctx, login, evt, MuteHandlingNetworkAPI.HandleMute) } } @@ -969,7 +976,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi portal.sendSuccessStatus(ctx, evt) } -func handleMatrixRoomMeta[APIType any, ContentType RoomMetaEventContent]( +func handleMatrixRoomMeta[APIType any, ContentType any]( portal *Portal, ctx context.Context, sender *UserLogin, @@ -1036,6 +1043,39 @@ func handleMatrixRoomMeta[APIType any, ContentType RoomMetaEventContent]( portal.sendSuccessStatus(ctx, evt) } +func handleMatrixAccountData[APIType any, ContentType any]( + portal *Portal, ctx context.Context, sender *UserLogin, evt *event.Event, + fn func(APIType, context.Context, *MatrixRoomMeta[ContentType]) error, +) { + api, ok := sender.Client.(APIType) + if !ok { + return + } + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(ContentType) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return + } + var prevContent ContentType + if evt.Unsigned.PrevContent != nil { + _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) + prevContent, _ = evt.Unsigned.PrevContent.Parsed.(ContentType) + } + + err := fn(api, ctx, &MatrixRoomMeta[ContentType]{ + MatrixEventBase: MatrixEventBase[ContentType]{ + Event: evt, + Content: content, + Portal: portal, + }, + PrevContent: prevContent, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix room account data") + } +} + func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.RedactionEventContent) @@ -1900,6 +1940,8 @@ type ChatInfo struct { var Unmuted = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) type UserLocalPortalInfo struct { + // To signal an indefinite mute, use [event.MutedForever] as the value here. + // To unmute, set any time before now, e.g. [bridgev2.Unmuted]. MutedUntil *time.Time Tag *event.RoomTag } diff --git a/event/accountdata.go b/event/accountdata.go index 2d37e0bd..30ca35a2 100644 --- a/event/accountdata.go +++ b/event/accountdata.go @@ -9,6 +9,7 @@ package event import ( "encoding/json" "strings" + "time" "maunium.net/go/mautrix/id" ) @@ -85,3 +86,22 @@ type IgnoredUser struct { type MarkedUnreadEventContent struct { Unread bool `json:"unread"` } + +type BeeperMuteEventContent struct { + MutedUntil int64 `json:"muted_until,omitempty"` +} + +func (bmec *BeeperMuteEventContent) IsMuted() bool { + return bmec.MutedUntil < 0 || (bmec.MutedUntil > 0 && bmec.GetMutedUntilTime().After(time.Now())) +} + +var MutedForever = time.Date(9999, 12, 31, 23, 59, 59, 999999999, time.UTC) + +func (bmec *BeeperMuteEventContent) GetMutedUntilTime() time.Time { + if bmec.MutedUntil < 0 { + return MutedForever + } else if bmec.MutedUntil > 0 { + return time.UnixMilli(bmec.MutedUntil) + } + return time.Time{} +} diff --git a/event/content.go b/event/content.go index c24de56b..d81a7bdd 100644 --- a/event/content.go +++ b/event/content.go @@ -55,6 +55,7 @@ var TypeMap = map[Type]reflect.Type{ AccountDataFullyRead: reflect.TypeOf(FullyReadEventContent{}), AccountDataIgnoredUserList: reflect.TypeOf(IgnoredUserListEventContent{}), AccountDataMarkedUnread: reflect.TypeOf(MarkedUnreadEventContent{}), + AccountDataBeeperMute: reflect.TypeOf(BeeperMuteEventContent{}), EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}), EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}), diff --git a/event/type.go b/event/type.go index e5c6498a..162e2ce7 100644 --- a/event/type.go +++ b/event/type.go @@ -242,6 +242,7 @@ var ( AccountDataFullyRead = Type{"m.fully_read", AccountDataEventType} AccountDataIgnoredUserList = Type{"m.ignored_user_list", AccountDataEventType} AccountDataMarkedUnread = Type{"m.marked_unread", AccountDataEventType} + AccountDataBeeperMute = Type{"com.beeper.mute", AccountDataEventType} AccountDataSecretStorageDefaultKey = Type{"m.secret_storage.default_key", AccountDataEventType} AccountDataSecretStorageKey = Type{"m.secret_storage.key", AccountDataEventType} diff --git a/versions.go b/versions.go index 246889c0..010d987b 100644 --- a/versions.go +++ b/versions.go @@ -68,6 +68,7 @@ var ( BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"} BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"} BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"} + BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"} ) func (versions *RespVersions) Supports(feature UnstableFeature) bool { From cc5f225bc61c526f4d47a62f9d4325bc2835b07a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Jul 2024 16:05:42 +0300 Subject: [PATCH 0484/1647] bridgev2/database: fix getting DM portals with user --- bridgev2/database/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 18e50780..bc1f2658 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -70,7 +70,7 @@ const ( getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')` getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL` - getAllDMPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND other_user_id<>''` + getAllDMPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND other_user_id=$2` getAllPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1` getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2` From ea591b0a2e2337efaf276e11f9c3d9019c4c8e38 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Jul 2024 20:40:45 +0300 Subject: [PATCH 0485/1647] bridgev2/login: redo cookie login params --- bridgev2/commands/login.go | 171 ++++++++++++++---- bridgev2/login.go | 48 ++++- .../unorganized-docs/login-step.schema.json | 48 ++--- bridgev2/userlogin.go | 1 + go.mod | 2 +- go.sum | 4 +- 6 files changed, 206 insertions(+), 68 deletions(-) diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index c4f4471b..df94c6ba 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -10,15 +10,15 @@ import ( "context" "encoding/json" "fmt" - "net/http" + "net/url" "regexp" "strings" "github.com/skip2/go-qrcode" + "go.mau.fi/util/curl" "golang.org/x/net/html" "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -48,6 +48,7 @@ func fnLogin(ce *Event) { var chosenFlowID string if len(ce.Args) > 0 { inputFlowID := strings.ToLower(ce.Args[0]) + ce.Args = ce.Args[1:] for _, flow := range flows { if flow.ID == inputFlowID { chosenFlowID = flow.ID @@ -75,7 +76,64 @@ func fnLogin(ce *Event) { ce.Reply("Failed to start login: %v", err) return } - doLoginStep(ce, login, nextStep) + + nextStep = checkLoginCommandDirectParams(ce, login, nextStep) + if nextStep != nil { + doLoginStep(ce, login, nextStep) + } +} + +func checkLoginCommandDirectParams(ce *Event, login bridgev2.LoginProcess, nextStep *bridgev2.LoginStep) *bridgev2.LoginStep { + if len(ce.Args) == 0 { + return nextStep + } + var ok bool + defer func() { + if !ok { + login.Cancel() + } + }() + var err error + switch nextStep.Type { + case bridgev2.LoginStepTypeDisplayAndWait: + ce.Reply("Invalid extra parameters for display and wait login step") + return nil + case bridgev2.LoginStepTypeUserInput: + if len(ce.Args) != len(nextStep.UserInputParams.Fields) { + ce.Reply("Invalid number of extra parameters (expected 0 or %d, got %d)", len(nextStep.UserInputParams.Fields), len(ce.Args)) + return nil + } + input := make(map[string]string) + for i, param := range nextStep.UserInputParams.Fields { + param.FillDefaultValidate() + input[param.ID], err = param.Validate(ce.Args[i]) + if err != nil { + ce.Reply("Invalid value for %s: %v", param.Name, err) + return nil + } + } + nextStep, err = login.(bridgev2.LoginProcessUserInput).SubmitUserInput(ce.Ctx, input) + case bridgev2.LoginStepTypeCookies: + if len(ce.Args) != len(nextStep.CookiesParams.Fields) { + ce.Reply("Invalid number of extra parameters (expected 0 or %d, got %d)", len(nextStep.CookiesParams.Fields), len(ce.Args)) + return nil + } + input := make(map[string]string) + for i, param := range nextStep.CookiesParams.Fields { + if match, _ := regexp.MatchString(param.Pattern, ce.Args[i]); !match { + ce.Reply("Invalid value for %s: doesn't match regex `%s`", param.ID, param.Pattern) + return nil + } + input[param.ID] = ce.Args[i] + } + nextStep, err = login.(bridgev2.LoginProcessCookies).SubmitCookies(ce.Ctx, input) + } + if err != nil { + ce.Reply("Failed to submit input: %v", err) + return nil + } + ok = true + return nextStep } type userInputLoginCommandState struct { @@ -219,61 +277,98 @@ func (clcs *cookieLoginCommandState) prompt(ce *Event) { }) } -var curlCookieRegex = regexp.MustCompile(`-H '[cC]ookie: ([^']*)'`) - -func missingKeys(required []string, data map[string]string) (missing []string) { - for _, requiredKey := range required { - if _, ok := data[requiredKey]; !ok { - missing = append(missing, requiredKey) - } - } - return -} - func (clcs *cookieLoginCommandState) submit(ce *Event) { ce.Redact() - cookies := make(map[string]string) + cookiesInput := make(map[string]string) if strings.HasPrefix(strings.TrimSpace(ce.RawArgs), "curl") { - if len(clcs.Data.LocalStorageKeys) > 0 || len(clcs.Data.SpecialKeys) > 0 { - ce.Reply("Special keys and localStorage keys can't be extracted from curl commands - please provide the data as JSON instead") + parsed, err := curl.Parse(ce.RawArgs) + if err != nil { + ce.Reply("Failed to parse curl: %v", err) return } - cookieHeader := curlCookieRegex.FindStringSubmatch(ce.RawArgs) - if len(cookieHeader) != 2 { - ce.Reply("Couldn't find `-H 'Cookie: ...'` in curl command") - return + reqCookies := make(map[string]string) + for _, cookie := range parsed.Cookies() { + reqCookies[cookie.Name], err = url.QueryUnescape(cookie.Value) + if err != nil { + ce.Reply("Failed to parse cookie %s: %v", cookie.Name, err) + return + } } - parsed := (&http.Request{Header: http.Header{"Cookie": {cookieHeader[1]}}}).Cookies() - for _, cookie := range parsed { - cookies[cookie.Name] = cookie.Value + var missingKeys, unsupportedKeys []string + for _, field := range clcs.Data.Fields { + var value string + var supported bool + for _, src := range field.Sources { + switch src.Type { + case bridgev2.LoginCookieTypeCookie: + supported = true + value = reqCookies[src.Name] + case bridgev2.LoginCookieTypeRequestHeader: + supported = true + value = parsed.Header.Get(src.Name) + case bridgev2.LoginCookieTypeRequestBody: + supported = true + switch { + case parsed.MultipartForm != nil: + values, ok := parsed.MultipartForm.Value[src.Name] + if ok && len(values) > 0 { + value = values[0] + } + case parsed.ParsedJSON != nil: + untypedValue, ok := parsed.ParsedJSON[src.Name] + if ok { + value = fmt.Sprintf("%v", untypedValue) + } + } + } + if value != "" { + cookiesInput[field.ID] = value + break + } + } + if value == "" && field.Required { + if supported { + missingKeys = append(missingKeys, field.ID) + } else { + unsupportedKeys = append(unsupportedKeys, field.ID) + } + } + } + if len(unsupportedKeys) > 0 { + ce.Reply("Some keys can't be extracted from a cURL request: %+v\n\nPlease provide a JSON object instead.", unsupportedKeys) + return + } else if len(missingKeys) > 0 { + ce.Reply("Missing some keys: %+v", missingKeys) + return } } else { - err := json.Unmarshal([]byte(ce.RawArgs), &cookies) + err := json.Unmarshal([]byte(ce.RawArgs), &cookiesInput) if err != nil { ce.Reply("Failed to parse input as JSON: %v", err) return } } - missingCookies := missingKeys(clcs.Data.CookieKeys, cookies) - if len(missingCookies) > 0 { - ce.Reply("Missing required cookies: %+v", missingCookies) - return + var missingKeys []string + for _, field := range clcs.Data.Fields { + val, ok := cookiesInput[field.ID] + if !ok && field.Required { + missingKeys = append(missingKeys, field.ID) + } + if match, _ := regexp.MatchString(field.Pattern, val); !match { + ce.Reply("Invalid value for %s: doesn't match regex `%s`", field.ID, field.Pattern) + return + } } - missingLocalStorage := missingKeys(clcs.Data.LocalStorageKeys, cookies) - if len(missingLocalStorage) > 0 { - ce.Reply("Missing required localStorage keys: %+v", missingLocalStorage) - return - } - missingSpecial := missingKeys(clcs.Data.SpecialKeys, cookies) - if len(missingSpecial) > 0 { - ce.Reply("Missing required special keys: %+v", missingSpecial) + if len(missingKeys) > 0 { + ce.Reply("Missing some keys: %+v", missingKeys) return } StoreCommandState(ce.User, nil) - nextStep, err := clcs.Login.SubmitCookies(ce.Ctx, cookies) + nextStep, err := clcs.Login.SubmitCookies(ce.Ctx, cookiesInput) if err != nil { ce.Reply("Login failed: %v", err) + return } doLoginStep(ce, clcs.Login, nextStep) } diff --git a/bridgev2/login.go b/bridgev2/login.go index 775f018b..0735049b 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -101,15 +101,53 @@ type LoginDisplayAndWaitParams struct { ImageURL string `json:"image_url,omitempty"` } +type LoginCookieFieldSourceType string + +const ( + LoginCookieTypeCookie LoginCookieFieldSourceType = "cookie" + LoginCookieTypeLocalStorage LoginCookieFieldSourceType = "local_storage" + LoginCookieTypeRequestHeader LoginCookieFieldSourceType = "request_header" + LoginCookieTypeRequestBody LoginCookieFieldSourceType = "request_body" + LoginCookieTypeSpecial LoginCookieFieldSourceType = "special" +) + +type LoginCookieFieldSource struct { + // The type of source. + Type LoginCookieFieldSourceType `json:"type"` + // The name of the field. The exact meaning depends on the type of source. + // Cookie: cookie name + // Local storage: key in local storage + // Request header: header name + // Request body: field name inside body after it's parsed (as JSON or multipart form data) + // Special: a namespaced identifier that clients can implement special handling for + Name string `json:"name"` + + // For request header & body types, a regex matching request URLs where the value can be extracted from. + RequestURLRegex string `json:"request_url_regex,omitempty"` + // For cookie types, the domain the cookie is present on. + CookieDomain string `json:"cookie_domain,omitempty"` +} + +type LoginCookieField struct { + // The key in the map that is submitted to the connector. + ID string `json:"id"` + Required bool `json:"required"` + // The sources that can be used to acquire the field value. Only one of these needs to be used. + Sources []LoginCookieFieldSource `json:"sources"` + // A regex pattern that the client can use to validate value client-side. + Pattern string `json:"pattern,omitempty"` +} + type LoginCookiesParams struct { URL string `json:"url"` UserAgent string `json:"user_agent,omitempty"` - CookieDomain string `json:"cookie_domain,omitempty"` - CookieKeys []string `json:"cookie_keys,omitempty"` - LocalStorageKeys []string `json:"local_storage_keys,omitempty"` - SpecialKeys []string `json:"special_keys,omitempty"` - SpecialExtractJS string `json:"special_extract_js,omitempty"` + // The fields that are needed for this cookie login. + Fields []LoginCookieField `json:"fields"` + // A JavaScript snippet that can extract some or all of the fields. + // The snippet will call `window.mautrixLoginCallback` with the extracted fields after they appear. + // Fields that are not present in the callback must be extracted another way. + ExtractJS string `json:"extract_js,omitempty"` } type LoginInputFieldType string diff --git a/bridgev2/unorganized-docs/login-step.schema.json b/bridgev2/unorganized-docs/login-step.schema.json index 38f85e6b..4dbf6d47 100644 --- a/bridgev2/unorganized-docs/login-step.schema.json +++ b/bridgev2/unorganized-docs/login-step.schema.json @@ -72,32 +72,36 @@ "type": "string", "description": "The user agent to use when opening the URL" }, - "cookie_domain": { - "type": "string", - "description": "The domain of the cookies to extract" - }, - "cookie_keys": { + "fields": { "type": "array", - "description": "The cookie names to extract", + "description": "The list of cookies (or other stored data) that must be extracted", "items": { - "type": "string" + "title": "Cookie Field", + "description": "A cookie (or other stored data) that must be extracted", + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "The type of data to extract", + "enum": ["cookie", "local_storage", "request_header", "request_body", "special"] + }, + "name": { + "type": "string", + "description": "The name of the cookie or key in the storage" + }, + "request_url_regex": { + "type": "string", + "description": "For the `request_header` and `request_body` types, a regex that matches the URLs from which the values can be extracted." + }, + "cookie_domain": { + "type": "string", + "description": "For the `cookie` type, the domain of the cookie" + } + }, + "required": ["type", "name"] } }, - "local_storage_keys": { - "type": "array", - "description": "The local storage keys to extract", - "items": { - "type": "string" - } - }, - "special_keys": { - "type": "array", - "description": "Special-cased extraction types that clients must support individually", - "items": { - "type": "string" - } - }, - "special_extract_js": { + "extract_js": { "type": "string", "description": "JavaScript code that can be evaluated inside the webview to extract the special keys" } diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 9f271639..3d798b93 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -263,6 +263,7 @@ func (ul *UserLogin) delete(ctx context.Context, state status.BridgeState, logou ul.BridgeState.Send(state) } ul.BridgeState.Destroy() + ul.BridgeState = nil } func (ul *UserLogin) deleteSpace(ctx context.Context) { diff --git a/go.mod b/go.mod index 4724d937..003cbc6a 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/tidwall/gjson v1.17.1 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.6.0 + go.mau.fi/util v0.6.1-0.20240719175439-20a6073e1dd4 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.25.0 golang.org/x/exp v0.0.0-20240707233637-46b078467d37 diff --git a/go.sum b/go.sum index 5c4d19ec..6d58567d 100644 --- a/go.sum +++ b/go.sum @@ -46,8 +46,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.6.0 h1:W6SyB3Bm/GjenQ5iq8Z8WWdN85Gy2xS6L0wmnR7SVjg= -go.mau.fi/util v0.6.0/go.mod h1:ljYdq3sPfpICc3zMU+/mHV/sa4z0nKxc67hSBwnrk8U= +go.mau.fi/util v0.6.1-0.20240719175439-20a6073e1dd4 h1:CYKYs5jwJ0bFJqh6pRoWtC9NIJ0lz0/6i2SC4qEBFaU= +go.mau.fi/util v0.6.1-0.20240719175439-20a6073e1dd4/go.mod h1:ljYdq3sPfpICc3zMU+/mHV/sa4z0nKxc67hSBwnrk8U= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= From cfd7cb775f85e340caaea445a837dee2fccb0c84 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 20 Jul 2024 15:16:14 +0300 Subject: [PATCH 0486/1647] bridgev2: update definition of extract_js --- bridgev2/login.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bridgev2/login.go b/bridgev2/login.go index 0735049b..2e2b1d84 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -145,8 +145,8 @@ type LoginCookiesParams struct { // The fields that are needed for this cookie login. Fields []LoginCookieField `json:"fields"` // A JavaScript snippet that can extract some or all of the fields. - // The snippet will call `window.mautrixLoginCallback` with the extracted fields after they appear. - // Fields that are not present in the callback must be extracted another way. + // The snippet will evaluate to a promise that resolves when the relevant fields are found. + // Fields that are not present in the promise result must be extracted another way. ExtractJS string `json:"extract_js,omitempty"` } From 5a3a88cd39e3ef1726f5cbf1559fe827e769bbde Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 20 Jul 2024 18:52:35 +0300 Subject: [PATCH 0487/1647] bridgev2/provisioning: add whoami endpoint --- bridgev2/bridgestate.go | 27 ++++++++++++------- bridgev2/matrix/provisioning.go | 47 ++++++++++++++++++++++++++++++++- bridgev2/networkinterface.go | 14 +++++----- 3 files changed, 71 insertions(+), 17 deletions(-) diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index e7d18d5e..bded88d3 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -17,10 +17,11 @@ import ( ) type BridgeStateQueue struct { - prev *status.BridgeState - ch chan status.BridgeState - bridge *Bridge - user status.BridgeStateFiller + prevUnsent *status.BridgeState + prevSent *status.BridgeState + ch chan status.BridgeState + bridge *Bridge + user status.BridgeStateFiller } func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { @@ -72,7 +73,7 @@ func (bsq *BridgeStateQueue) loop() { func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) { retryIn := 2 for { - if bsq.prev != nil && bsq.prev.ShouldDeduplicate(&state) { + if bsq.prevSent != nil && bsq.prevSent.ShouldDeduplicate(&state) { bsq.bridge.Log.Debug(). Str("state_event", string(state.StateEvent)). Msg("Not sending bridge state as it's a duplicate") @@ -93,7 +94,7 @@ func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) retryIn = 64 } } else { - bsq.prev = &state + bsq.prevSent = &state bsq.bridge.Log.Debug(). Any("bridge_state", state). Msg("Sent new bridge state") @@ -108,6 +109,7 @@ func (bsq *BridgeStateQueue) Send(state status.BridgeState) { } state = state.Fill(bsq.user) + bsq.prevUnsent = &state if len(bsq.ch) >= 8 { bsq.bridge.Log.Warn().Msg("Bridge state queue is nearly full, discarding an item") @@ -124,14 +126,21 @@ func (bsq *BridgeStateQueue) Send(state status.BridgeState) { } func (bsq *BridgeStateQueue) GetPrev() status.BridgeState { - if bsq != nil && bsq.prev != nil { - return *bsq.prev + if bsq != nil && bsq.prevSent != nil { + return *bsq.prevSent + } + return status.BridgeState{} +} + +func (bsq *BridgeStateQueue) GetPrevUnsent() status.BridgeState { + if bsq != nil && bsq.prevSent != nil { + return *bsq.prevUnsent } return status.BridgeState{} } func (bsq *BridgeStateQueue) SetPrev(prev status.BridgeState) { if bsq != nil { - bsq.prev = &prev + bsq.prevSent = &prev } } diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index c0eb16e4..817aa922 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -19,9 +19,11 @@ import ( "github.com/rs/xid" "github.com/rs/zerolog" "github.com/rs/zerolog/hlog" + "go.mau.fi/util/jsontime" "go.mau.fi/util/requestlog" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/id" @@ -87,6 +89,7 @@ func (prov *ProvisioningAPI) Init() { prov.Router.Use(hlog.NewHandler(prov.log)) prov.Router.Use(requestlog.AccessLogger(false)) prov.Router.Use(prov.AuthMiddleware) + prov.Router.Path("/v3/whoami").Methods(http.MethodGet).HandlerFunc(prov.GetWhoami) prov.Router.Path("/v3/login/flows").Methods(http.MethodGet).HandlerFunc(prov.GetLoginFlows) prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginStart) prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginSubmitInput) @@ -205,6 +208,48 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { }) } +type RespWhoami struct { + Network bridgev2.BridgeName `json:"network"` + Homeserver string `json:"homeserver"` + BridgeBot id.UserID `json:"bridge_bot"` + CommandPrefix string `json:"command_prefix"` + + ManagementRoom id.RoomID `json:"management_room"` + Logins []RespWhoamiLogin `json:"logins"` +} + +type RespWhoamiLogin struct { + StateEvent status.BridgeStateEvent `json:"state_event"` + StateTS jsontime.Unix `json:"state_ts"` + ID networkid.UserLoginID `json:"id"` + Name string `json:"name"` + SpaceRoom id.RoomID `json:"space_room"` +} + +func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { + user := prov.GetUser(r) + resp := &RespWhoami{ + Network: prov.br.Bridge.Network.GetName(), + Homeserver: prov.br.AS.HomeserverDomain, + BridgeBot: prov.br.Bot.UserID, + CommandPrefix: prov.br.Config.Bridge.CommandPrefix, + ManagementRoom: user.ManagementRoom, + } + logins := user.GetCachedUserLogins() + resp.Logins = make([]RespWhoamiLogin, len(logins)) + for i, login := range logins { + prevState := login.BridgeState.GetPrevUnsent() + resp.Logins[i] = RespWhoamiLogin{ + StateEvent: prevState.StateEvent, + StateTS: prevState.Timestamp, + ID: login.ID, + Name: login.RemoteName, + SpaceRoom: login.SpaceRoom, + } + } + jsonResponse(w, http.StatusOK, resp) +} + type RespLoginFlows struct { Flows []bridgev2.LoginFlow `json:"flows"` } @@ -363,7 +408,7 @@ func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.R } type RespResolveIdentifier struct { - ID networkid.UserID `json:"id,omitempty"` + ID networkid.UserID `json:"id"` Name string `json:"name,omitempty"` AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` Identifiers []string `json:"identifiers,omitempty"` diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 1b759236..1745940e 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -127,21 +127,21 @@ type ConvertedEdit struct { // BridgeName contains information about the network that a connector bridges to. type BridgeName struct { // The displayname of the network, e.g. `Discord` - DisplayName string + DisplayName string `json:"displayname"` // The URL to the website of the network, e.g. `https://discord.com` - NetworkURL string + NetworkURL string `json:"network_url"` // The icon of the network as a mxc:// URI - NetworkIcon id.ContentURIString + NetworkIcon id.ContentURIString `json:"network_icon"` // An identifier uniquely identifying the network, e.g. `discord` - NetworkID string + NetworkID string `json:"network_id"` // An identifier uniquely identifying the bridge software. // The Go import path is a good choice here (e.g. github.com/octocat/discordbridge) - BeeperBridgeType string + BeeperBridgeType string `json:"beeper_bridge_type"` // The default appservice port to use in the example config, defaults to 8080 if unset // Official mautrix bridges will use ports defined in https://mau.fi/ports - DefaultPort uint16 + DefaultPort uint16 `json:"default_port,omitempty"` // The default command prefix to use in the example config, defaults to NetworkID if unset. Must include the ! prefix. - DefaultCommandPrefix string + DefaultCommandPrefix string `json:"default_command_prefix,omitempty"` } func (bn BridgeName) AsBridgeInfoSection() event.BridgeInfoSection { From dff2164cd3618fdf51ea6024c84cbe379b8b8bfc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 20 Jul 2024 20:13:13 +0300 Subject: [PATCH 0488/1647] bridgev2/provisioning: add missing parameter to start_dm endpoint --- bridgev2/matrix/provisioning.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 817aa922..9b9e6ed1 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -98,7 +98,7 @@ func (prov *ProvisioningAPI) Init() { prov.Router.Path("/v3/logins").Methods(http.MethodGet).HandlerFunc(prov.GetLogins) prov.Router.Path("/v3/contacts").Methods(http.MethodGet).HandlerFunc(prov.GetContactList) prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet).HandlerFunc(prov.GetResolveIdentifier) - prov.Router.Path("/v3/create_dm").Methods(http.MethodPost).HandlerFunc(prov.PostCreateDM) + prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost).HandlerFunc(prov.PostCreateDM) prov.Router.Path("/v3/create_group").Methods(http.MethodPost).HandlerFunc(prov.PostCreateGroup) if prov.br.Config.Provisioning.DebugEndpoints { From 8cb5d5cc6915d71cc8b357253fd8919798ce5f81 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 20 Jul 2024 20:30:50 +0300 Subject: [PATCH 0489/1647] bridgev2/provisioning: fix invalid auth error code --- bridgev2/matrix/provisioning.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 9b9e6ed1..5ae03a13 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -144,9 +144,9 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { if err != nil { zerolog.Ctx(r.Context()).Warn().Err(err). Msg("Provisioning API request contained invalid auth") - jsonResponse(w, http.StatusForbidden, &mautrix.RespError{ + jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ Err: "Invalid auth token", - ErrCode: mautrix.MForbidden.ErrCode, + ErrCode: mautrix.MUnknownToken.ErrCode, }) return } From 24ead553b23bde881770f1582ee7075fd44bb80d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 20 Jul 2024 20:35:15 +0300 Subject: [PATCH 0490/1647] bridgev2/provisioning: add separate error for missing auth --- bridgev2/matrix/provisioning.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 5ae03a13..5b042c76 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -138,6 +138,13 @@ func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.User func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + if auth == "" { + jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ + Err: "Missing auth token", + ErrCode: mautrix.MMissingToken.ErrCode, + }) + return + } userID := id.UserID(r.URL.Query().Get("user_id")) if auth != prov.br.Config.Provisioning.SharedSecret { err := prov.checkMatrixAuth(r.Context(), userID, auth) From e878ab1315b8ae3853ade1688b86e426240e3499 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 21 Jul 2024 01:44:31 +0300 Subject: [PATCH 0491/1647] bridgev2/provisioning: add CORS headers --- bridgev2/matrix/provisioning.go | 36 +++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 5b042c76..9058b888 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -87,19 +87,20 @@ func (prov *ProvisioningAPI) Init() { prov.log = prov.br.Log.With().Str("component", "provisioning").Logger() prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter() prov.Router.Use(hlog.NewHandler(prov.log)) + prov.Router.Use(corsMiddleware) prov.Router.Use(requestlog.AccessLogger(false)) prov.Router.Use(prov.AuthMiddleware) - prov.Router.Path("/v3/whoami").Methods(http.MethodGet).HandlerFunc(prov.GetWhoami) - prov.Router.Path("/v3/login/flows").Methods(http.MethodGet).HandlerFunc(prov.GetLoginFlows) - prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginStart) - prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginSubmitInput) - prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}").Methods(http.MethodPost).HandlerFunc(prov.PostLoginWait) - prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost).HandlerFunc(prov.PostLogout) - prov.Router.Path("/v3/logins").Methods(http.MethodGet).HandlerFunc(prov.GetLogins) - prov.Router.Path("/v3/contacts").Methods(http.MethodGet).HandlerFunc(prov.GetContactList) - prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet).HandlerFunc(prov.GetResolveIdentifier) - prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost).HandlerFunc(prov.PostCreateDM) - prov.Router.Path("/v3/create_group").Methods(http.MethodPost).HandlerFunc(prov.PostCreateGroup) + prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami) + prov.Router.Path("/v3/login/flows").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLoginFlows) + prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginStart) + prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginSubmitInput) + prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginWait) + prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLogout) + prov.Router.Path("/v3/logins").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLogins) + prov.Router.Path("/v3/contacts").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetContactList) + prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetResolveIdentifier) + prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateDM) + prov.Router.Path("/v3/create_group").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateGroup) if prov.br.Config.Provisioning.DebugEndpoints { prov.log.Debug().Msg("Enabling debug API at /debug") @@ -109,6 +110,19 @@ func (prov *ProvisioningAPI) Init() { } } +func corsMiddleware(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization") + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return + } + handler.ServeHTTP(w, r) + }) +} + func jsonResponse(w http.ResponseWriter, status int, response any) { w.Header().Add("Content-Type", "application/json") w.WriteHeader(status) From 50f8bfac25a51c6def84b952725a197575c693a1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 21 Jul 2024 01:44:43 +0300 Subject: [PATCH 0492/1647] bridgev2/provisioning: add login flows to whoami endpoint --- bridgev2/matrix/provisioning.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 9058b888..8cdfe0a4 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -230,10 +230,11 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { } type RespWhoami struct { - Network bridgev2.BridgeName `json:"network"` - Homeserver string `json:"homeserver"` - BridgeBot id.UserID `json:"bridge_bot"` - CommandPrefix string `json:"command_prefix"` + Network bridgev2.BridgeName `json:"network"` + LoginFlows []bridgev2.LoginFlow `json:"login_flows"` + Homeserver string `json:"homeserver"` + BridgeBot id.UserID `json:"bridge_bot"` + CommandPrefix string `json:"command_prefix"` ManagementRoom id.RoomID `json:"management_room"` Logins []RespWhoamiLogin `json:"logins"` @@ -251,6 +252,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { user := prov.GetUser(r) resp := &RespWhoami{ Network: prov.br.Bridge.Network.GetName(), + LoginFlows: prov.br.Bridge.Network.GetLoginFlows(), Homeserver: prov.br.AS.HomeserverDomain, BridgeBot: prov.br.Bot.UserID, CommandPrefix: prov.br.Config.Bridge.CommandPrefix, From 5915bbfd5f60df8eb33f0f1bc411e060138dca81 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 22 Jul 2024 15:35:51 +0300 Subject: [PATCH 0493/1647] legacymigrate: add column to database_was_migrated table for SQLite compat --- bridgev2/matrix/mxmain/legacymigrate.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index 302bf300..32556de1 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -52,7 +52,7 @@ func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery strin if err != nil { return err } - _, err = br.DB.Exec(ctx, "CREATE TABLE database_was_migrated()") + _, err = br.DB.Exec(ctx, "CREATE TABLE database_was_migrated(empty INTEGER)") if err != nil { return err } From d6e6b66df107d13ac57fde048e6996be7372b515 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 22 Jul 2024 16:21:02 +0300 Subject: [PATCH 0494/1647] bridgev2: fix avatar deduplication when hash matches but mxc is empty --- bridgev2/ghost.go | 8 ++++---- bridgev2/portal.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index bbfa1a19..393a8d6a 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -93,7 +93,7 @@ type Avatar struct { Hash [32]byte } -func (a *Avatar) Reupload(ctx context.Context, intent MatrixAPI, currentHash [32]byte) (id.ContentURIString, [32]byte, error) { +func (a *Avatar) Reupload(ctx context.Context, intent MatrixAPI, currentHash [32]byte, currentMXC id.ContentURIString) (id.ContentURIString, [32]byte, error) { if a.MXC != "" { return a.MXC, a.Hash, nil } @@ -102,8 +102,8 @@ func (a *Avatar) Reupload(ctx context.Context, intent MatrixAPI, currentHash [32 return "", [32]byte{}, err } hash := sha256.Sum256(data) - if hash == currentHash { - return "", hash, nil + if hash == currentHash && currentMXC != "" { + return currentMXC, hash, nil } mime := http.DetectContentType(data) fileName := "avatar" + exmime.ExtensionFromMimetype(mime) @@ -144,7 +144,7 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool { } ghost.AvatarID = avatar.ID if !avatar.Remove { - newMXC, newHash, err := avatar.Reupload(ctx, ghost.Intent, ghost.AvatarHash) + newMXC, newHash, err := avatar.Reupload(ctx, ghost.Intent, ghost.AvatarHash, ghost.AvatarMXC) if err != nil { ghost.AvatarSet = false zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload avatar") diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a108a110..24451610 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1976,7 +1976,7 @@ func (portal *Portal) UpdateAvatar(ctx context.Context, avatar *Avatar, sender M portal.AvatarMXC = "" portal.AvatarHash = [32]byte{} } else { - newMXC, newHash, err := avatar.Reupload(ctx, sender, portal.AvatarHash) + newMXC, newHash, err := avatar.Reupload(ctx, sender, portal.AvatarHash, portal.AvatarMXC) if err != nil { portal.AvatarSet = false zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload room avatar") From f5beb85721fbac9ed4b64482334ba30360700471 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 22 Jul 2024 16:38:43 +0300 Subject: [PATCH 0495/1647] bridgev2/backfill: fix adding double puppet values --- bridgev2/matrix/connector.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 00538cb4..08aac337 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -535,13 +535,18 @@ func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautr for _, evt := range req.Events { intent, _ := br.doublePuppetIntents.Get(evt.Sender) if intent != nil { - intent.AddDoublePuppetValueWithTS(evt.ID, evt.Timestamp) + intent.AddDoublePuppetValueWithTS(&evt.Content, evt.Timestamp) } - err = br.Crypto.Encrypt(ctx, roomID, evt.Type, &evt.Content) - if err != nil { - return nil, err + if evt.Type != event.EventEncrypted { + err = br.Crypto.Encrypt(ctx, roomID, evt.Type, &evt.Content) + if err != nil { + return nil, err + } + evt.Type = event.EventEncrypted + if intent != nil { + intent.AddDoublePuppetValueWithTS(&evt.Content, evt.Timestamp) + } } - evt.Type = event.EventEncrypted } } return br.Bot.BeeperBatchSend(ctx, roomID, req) From dc8ebb2c65b1c376c7a0909380a60ac615953e8a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 22 Jul 2024 17:34:54 +0300 Subject: [PATCH 0496/1647] bridgev2/matrixinterface: include message/reaction meta for redactions --- bridgev2/portal.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 24451610..857f2de8 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1643,7 +1643,7 @@ func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *Us Parsed: &event.RedactionEventContent{ Redacts: targetReaction.MXID, }, - }, &MatrixSendExtra{Timestamp: ts}) + }, &MatrixSendExtra{Timestamp: ts, ReactionMeta: targetReaction}) if err != nil { log.Err(err).Stringer("reaction_mxid", targetReaction.MXID).Msg("Failed to redact reaction") } @@ -1670,7 +1670,7 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use Parsed: &event.RedactionEventContent{ Redacts: part.MXID, }, - }, &MatrixSendExtra{Timestamp: ts}) + }, &MatrixSendExtra{Timestamp: ts, MessageMeta: part}) if err != nil { log.Err(err).Stringer("part_mxid", part.MXID).Msg("Failed to redact message part") } else { From ac29c5e46136ec7650fe334b9a413770ddef85f1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 22 Jul 2024 18:10:13 +0300 Subject: [PATCH 0497/1647] bridgev2/portal: handle remote reactions to unknown messages --- bridgev2/portal.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 857f2de8..d9025f40 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1544,6 +1544,10 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi if err != nil { log.Err(err).Msg("Failed to get target message for reaction") return + } else if targetMessage == nil { + // TODO use deterministic event ID as target if applicable? + log.Warn().Msg("Target message for reaction not found") + return } emoji, emojiID := evt.GetReactionEmoji() existingReaction, err := portal.Bridge.DB.Reaction.GetByID(ctx, targetMessage.ID, targetMessage.PartID, evt.GetSender().Sender, emojiID) From edb026c8a35c3183e4196698f440836475d4ee1e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 22 Jul 2024 19:22:39 +0300 Subject: [PATCH 0498/1647] bridgev2: fix updating DM portal avatar when ghost has no avatar --- bridgev2/ghost.go | 4 +++- bridgev2/portal.go | 7 ++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index 393a8d6a..8f0e8882 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -94,8 +94,10 @@ type Avatar struct { } func (a *Avatar) Reupload(ctx context.Context, intent MatrixAPI, currentHash [32]byte, currentMXC id.ContentURIString) (id.ContentURIString, [32]byte, error) { - if a.MXC != "" { + if a.MXC != "" || a.Hash != [32]byte{} { return a.MXC, a.Hash, nil + } else if a.Get == nil { + return "", [32]byte{}, fmt.Errorf("no Get function provided for avatar") } data, err := a.Get(ctx) if err != nil { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index d9025f40..a129d068 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2430,9 +2430,10 @@ func (portal *Portal) UpdateInfoFromGhost(ctx context.Context, ghost *Ghost) (ch } changed = portal.UpdateName(ctx, ghost.Name, nil, time.Time{}) || changed changed = portal.UpdateAvatar(ctx, &Avatar{ - ID: ghost.AvatarID, - MXC: ghost.AvatarMXC, - Hash: ghost.AvatarHash, + ID: ghost.AvatarID, + MXC: ghost.AvatarMXC, + Hash: ghost.AvatarHash, + Remove: ghost.AvatarID == "", }, nil, time.Time{}) || changed return } From 1632e6c9ed42e28805fd1a8d346446355df76dc0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 22 Jul 2024 21:19:29 +0300 Subject: [PATCH 0499/1647] event: ignore is_falling_back in non-thread relations --- event/relations.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/event/relations.go b/event/relations.go index ecd7a959..ea40cc06 100644 --- a/event/relations.go +++ b/event/relations.go @@ -73,7 +73,7 @@ func (rel *RelatesTo) GetReplyTo() id.EventID { } func (rel *RelatesTo) GetNonFallbackReplyTo() id.EventID { - if rel != nil && rel.InReplyTo != nil && !rel.IsFallingBack { + if rel != nil && rel.InReplyTo != nil && (rel.Type != RelThread || !rel.IsFallingBack) { return rel.InReplyTo.EventID } return "" From bc0eb86d1837bf07f6428e55b9428b65f717e69d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 23 Jul 2024 19:37:07 +0300 Subject: [PATCH 0500/1647] bridgev2/portal: always get ghost when handling remote event Otherwise the ghost row may not exist --- bridgev2/portal.go | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a129d068..a6303b04 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1236,6 +1236,17 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { } func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID) { + var ghost *Ghost + if sender.Sender != "" { + var err error + ghost, err = portal.Bridge.GetGhostByID(ctx, sender.Sender) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for message sender") + return + } else { + ghost.UpdateInfoIfNecessary(ctx, source, evtType) + } + } if sender.IsFromMe { intent = source.User.DoublePuppet(ctx) if intent != nil { @@ -1252,26 +1263,19 @@ func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventS extraUserID = senderLogin.UserMXID } } - if sender.Sender != "" { - if portal.Receiver == "" { - for _, login := range otherLogins { - if login.Client.IsThisUser(ctx, sender.Sender) { - intent = login.User.DoublePuppet(ctx) - if intent != nil { - return - } - extraUserID = login.UserMXID + if sender.Sender != "" && portal.Receiver == "" && otherLogins != nil { + for _, login := range otherLogins { + if login.Client.IsThisUser(ctx, sender.Sender) { + intent = login.User.DoublePuppet(ctx) + if intent != nil { + return } + extraUserID = login.UserMXID } } - ghost, err := portal.Bridge.GetGhostByID(ctx, sender.Sender) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for message sender") - return - } else { - ghost.UpdateInfoIfNecessary(ctx, source, evtType) - intent = ghost.Intent - } + } + if ghost != nil { + intent = ghost.Intent } return } From 358f8702e4a3ee965118d2763e58eb9719bd8558 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 23 Jul 2024 19:39:06 +0300 Subject: [PATCH 0501/1647] bridgev2/login: fill Metadata if it's nil in NewLogin --- bridgev2/userlogin.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 3d798b93..1a88fc2b 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -163,6 +163,9 @@ func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params defer user.Bridge.cacheLock.Unlock() data.BridgeID = user.BridgeID data.UserMXID = user.MXID + if data.Metadata == nil { + data.Metadata = user.Bridge.Network.GetDBMetaTypes().UserLogin() + } if params == nil { params = &NewLoginParams{} } From 5704fa0b3cf58f4884fc18adf23395eedb756f28 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 23 Jul 2024 22:37:31 +0300 Subject: [PATCH 0502/1647] bridgev2: implement sync_direct_chat_list option --- bridgev2/matrix/intent.go | 39 +++++++++++++++++++++++++++++++++++++ bridgev2/matrixinterface.go | 4 ++++ bridgev2/portal.go | 14 ++++++++++++- 3 files changed, 56 insertions(+), 1 deletion(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index ab82da2a..6807e4af 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -10,9 +10,11 @@ import ( "context" "errors" "fmt" + "sync" "time" "github.com/rs/zerolog" + "golang.org/x/exp/slices" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" @@ -27,9 +29,13 @@ import ( type ASIntent struct { Matrix *appservice.IntentAPI Connector *Connector + + dmUpdateLock sync.Mutex + directChatsCache event.DirectChatsEventContent } var _ bridgev2.MatrixAPI = (*ASIntent)(nil) +var _ bridgev2.MarkAsDMMatrixAPI = (*ASIntent)(nil) func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *bridgev2.MatrixSendExtra) (*mautrix.RespSendEvent, error) { if extra == nil { @@ -295,6 +301,39 @@ func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) return resp.RoomID, nil } +func (as *ASIntent) MarkAsDM(ctx context.Context, roomID id.RoomID, withUser id.UserID) error { + if !as.Connector.Config.Matrix.SyncDirectChatList { + return nil + } + as.dmUpdateLock.Lock() + defer as.dmUpdateLock.Unlock() + cached, ok := as.directChatsCache[withUser] + if ok && slices.Contains(cached, roomID) { + return nil + } + var directChats event.DirectChatsEventContent + err := as.Matrix.GetAccountData(ctx, event.AccountDataDirectChats.Type, &directChats) + if err != nil { + return err + } + as.directChatsCache = directChats + rooms := directChats[withUser] + if slices.Contains(rooms, roomID) { + return nil + } + directChats[withUser] = append(rooms, roomID) + err = as.Matrix.SetAccountData(ctx, event.AccountDataDirectChats.Type, &directChats) + if err != nil { + if rooms == nil { + delete(directChats, withUser) + } else { + directChats[withUser] = rooms + } + return fmt.Errorf("failed to set direct chats account data: %w", err) + } + return nil +} + func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error { if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) { return as.Matrix.BeeperDeleteRoom(ctx, roomID) diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index e25d0f06..b6773810 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -88,3 +88,7 @@ type MatrixAPI interface { TagRoom(ctx context.Context, roomID id.RoomID, tag event.RoomTag, isTagged bool) error MuteRoom(ctx context.Context, roomID id.RoomID, until time.Time) error } + +type MarkAsDMMatrixAPI interface { + MarkAsDM(ctx context.Context, roomID id.RoomID, otherUser id.UserID) error +} diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a6303b04..b30628fa 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2318,13 +2318,25 @@ func (portal *Portal) SyncParticipants(ctx context.Context, members *ChatMemberL } func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPortalInfo, source *UserLogin) { - if portal.MXID == "" || info == nil { + if portal.MXID == "" { return } dp := source.User.DoublePuppet(ctx) if dp == nil { return } + dmMarkingMatrixAPI, canMarkDM := dp.(MarkAsDMMatrixAPI) + if canMarkDM && portal.OtherUserID != "" && portal.RoomType == database.RoomTypeDM { + dmGhost, err := portal.Bridge.GetGhostByID(ctx, portal.OtherUserID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get DM ghost to mark room as DM") + } else if err = dmMarkingMatrixAPI.MarkAsDM(ctx, portal.MXID, dmGhost.Intent.GetMXID()); err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to mark room as DM") + } + } + if info == nil { + return + } if info.MutedUntil != nil { err := dp.MuteRoom(ctx, portal.MXID, *info.MutedUntil) if err != nil { From 1ace7749bb12ce920b6d754bbb8eafb39f43bbb0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 23 Jul 2024 22:56:13 +0300 Subject: [PATCH 0503/1647] bridgev2/mxmain: make it easier to print example config to stdout --- bridgev2/matrix/mxmain/main.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index f02decb6..af0868bf 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -164,13 +164,20 @@ func (br *BridgeMain) PreInit() { _ = json.NewEncoder(os.Stdout).Encode(output) os.Exit(0) } else if *writeExampleConfig { - if _, err = os.Stat(*configPath); !errors.Is(err, os.ErrNotExist) { - _, _ = fmt.Fprintln(os.Stderr, *configPath, "already exists, please remove it if you want to generate a new example") - os.Exit(1) + if *configPath != "-" && *configPath != "/dev/stdout" && *configPath != "/dev/stderr" { + if _, err = os.Stat(*configPath); !errors.Is(err, os.ErrNotExist) { + _, _ = fmt.Fprintln(os.Stderr, *configPath, "already exists, please remove it if you want to generate a new example") + os.Exit(1) + } } networkExample, _, _ := br.Connector.GetConfig() - exerrors.PanicIfNotNil(os.WriteFile(*configPath, []byte(br.makeFullExampleConfig(networkExample)), 0600)) - fmt.Println("Wrote example config to", *configPath) + fullCfg := br.makeFullExampleConfig(networkExample) + if *configPath == "-" { + fmt.Print(fullCfg) + } else { + exerrors.PanicIfNotNil(os.WriteFile(*configPath, []byte(fullCfg), 0600)) + fmt.Println("Wrote example config to", *configPath) + } os.Exit(0) } br.LoadConfig() From 1b706d0e5cc93dd77f2fdf7bf8b62a2132332296 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 24 Jul 2024 02:26:45 +0300 Subject: [PATCH 0504/1647] bridgev2: add options for deleting portals on logout --- bridgev2/bridgeconfig/config.go | 24 +++ bridgev2/commands/cleanup.go | 27 +--- bridgev2/database/userportal.go | 7 + bridgev2/matrix/mxmain/example-config.yaml | 27 ++++ bridgev2/portal.go | 4 +- bridgev2/userlogin.go | 176 ++++++++++++++++++++- 6 files changed, 233 insertions(+), 32 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 3bd173b1..861805c6 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -31,10 +31,34 @@ type Config struct { ManagementRoomTexts ManagementRoomTexts `yaml:"management_room_texts"` } +type CleanupAction string + +const ( + CleanupActionNull CleanupAction = "" + CleanupActionNothing CleanupAction = "nothing" + CleanupActionKick CleanupAction = "kick" + CleanupActionUnbridge CleanupAction = "unbridge" + CleanupActionDelete CleanupAction = "delete" +) + +type CleanupOnLogout struct { + Private CleanupAction `yaml:"private"` + Relayed CleanupAction `yaml:"relayed"` + SharedNoUsers CleanupAction `yaml:"shared_no_users"` + SharedHasUsers CleanupAction `yaml:"shared_has_users"` +} + +type CleanupOnLogouts struct { + Enabled bool `yaml:"enabled"` + Manual CleanupOnLogout `yaml:"manual"` + BadCredentials CleanupOnLogout `yaml:"bad_credentials"` +} + type BridgeConfig struct { CommandPrefix string `yaml:"command_prefix"` PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` + CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` Relay RelayConfig `yaml:"relay"` Permissions PermissionConfig `yaml:"permissions"` Backfill BackfillConfig `yaml:"backfill"` diff --git a/bridgev2/commands/cleanup.go b/bridgev2/commands/cleanup.go index 55f34d14..f8ad1d23 100644 --- a/bridgev2/commands/cleanup.go +++ b/bridgev2/commands/cleanup.go @@ -7,9 +7,6 @@ package commands import ( - "cmp" - "slices" - "maunium.net/go/mautrix/bridgev2" ) @@ -43,29 +40,13 @@ var CommandDeleteAllPortals = &FullHandler{ ce.Reply("Failed to get portals: %v", err) return } - getDepth := func(portal *bridgev2.Portal) int { - depth := 0 - for portal.Parent != nil { - depth++ - portal = portal.Parent - } - return depth - } - // Sort portals so parents are last (to avoid errors caused by deleting parent portals before children) - slices.SortFunc(portals, func(a, b *bridgev2.Portal) int { - return cmp.Compare(getDepth(b), getDepth(a)) - }) - for _, portal := range portals { - err = portal.Delete(ce.Ctx) - if err != nil { + bridgev2.DeleteManyPortals(ce.Ctx, portals, func(portal *bridgev2.Portal, delete bool, err error) { + if !delete { ce.Reply("Failed to delete portal %s: %v", portal.MXID, err) - continue - } - err = ce.Bot.DeleteRoom(ce.Ctx, portal.MXID, false) - if err != nil { + } else { ce.Reply("Failed to clean up room %s: %v", portal.MXID, err) } - } + }) }, Name: "delete-all-portals", Help: HelpMeta{ diff --git a/bridgev2/database/userportal.go b/bridgev2/database/userportal.go index eeda4ba3..278b236b 100644 --- a/bridgev2/database/userportal.go +++ b/bridgev2/database/userportal.go @@ -67,6 +67,9 @@ const ( markLoginAsPreferredQuery = ` UPDATE user_portal SET preferred=(login_id=$3) WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$4 AND portal_receiver=$5 ` + deleteUserPortalQuery = ` + DELETE FROM user_portal WHERE bridge_id=$1 AND user_mxid=$2 AND login_id=$3 AND portal_id=$4 AND portal_receiver=$5 + ` ) func UserPortalFor(ul *UserLogin, portal networkid.PortalKey) *UserPortal { @@ -107,6 +110,10 @@ func (upq *UserPortalQuery) MarkAsPreferred(ctx context.Context, login *UserLogi return upq.Exec(ctx, markLoginAsPreferredQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) } +func (upq *UserPortalQuery) Delete(ctx context.Context, up *UserPortal) error { + return upq.Exec(ctx, deleteUserPortalQuery, up.BridgeID, up.UserMXID, up.LoginID, up.Portal.ID, up.Portal.Receiver) +} + func (up *UserPortal) Scan(row dbutil.Scannable) (*UserPortal, error) { var lastRead sql.NullInt64 err := row.Scan( diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 16105cb0..2f411c0a 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -7,6 +7,33 @@ bridge: # Whether the bridge should set names and avatars explicitly for DM portals. private_chat_portal_meta: false + # What should be done to portal rooms when a user logs out or is logged out? + # Permitted values: + # nothing - Do nothing, let the user stay in the portals + # kick - Remove the user from the portal rooms, but don't delete them + # unbridge - Remove all ghosts in the room and disassociate it from the remote chat + # delete - Remove all ghosts and users from the room (i.e. delete it) + cleanup_on_logout: + # Should cleanup on logout be enabled at all? + enabled: false + # Settings for manual logouts (explicitly initiated by the Matrix user) + manual: + # Action for private portals which will never be shared with other Matrix users. + private: nothing + # Action for portals with a relay user configured. + relayed: nothing + # Action for portals which may be shared, but don't currently have any other Matrix users. + shared_no_users: nothing + # Action for portals which have other logged-in Matrix users. + shared_has_users: nothing + # Settings for credentials being invalidated (initiated by the remote network, possibly through user action). + # Keys have the same meanings as in the manual section. + bad_credentials: + private: nothing + relayed: nothing + shared_no_users: nothing + shared_has_users: nothing + # Settings for relay mode relay: # Whether relay mode should be allowed. If allowed, the set-relay command can be used to turn any diff --git a/bridgev2/portal.go b/bridgev2/portal.go index b30628fa..f4493573 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -301,9 +301,9 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowR } portal.Bridge.cacheLock.Lock() defer portal.Bridge.cacheLock.Unlock() - for _, up := range logins { + for i, up := range logins { login, ok := user.logins[up.LoginID] - if ok && login.Client != nil { + if ok && login.Client != nil && (len(logins) == i-1 || login.Client.IsLoggedIn()) { return login, up, nil } } diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 1a88fc2b..2d0eaa20 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -7,8 +7,10 @@ package bridgev2 import ( + "cmp" "context" "fmt" + "slices" "sync" "time" @@ -16,8 +18,10 @@ import ( "go.mau.fi/util/exsync" "maunium.net/go/mautrix/bridge/status" + "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -244,9 +248,13 @@ func (ul *UserLogin) delete(ctx context.Context, state status.BridgeState, logou } else { ul.Disconnect(nil) } - portals, err := ul.Bridge.DB.UserPortal.GetAllForLogin(ctx, ul.UserLogin) - if err != nil { - ul.Log.Err(err).Msg("Failed to get user portals") + var portals []*database.UserPortal + var err error + if ul.Bridge.Config.CleanupOnLogout.Enabled { + portals, err = ul.Bridge.DB.UserPortal.GetAllForLogin(ctx, ul.UserLogin) + if err != nil { + ul.Log.Err(err).Msg("Failed to get user portals") + } } err = ul.Bridge.DB.UserLogin.Delete(ctx, ul.ID) if err != nil { @@ -260,8 +268,11 @@ func (ul *UserLogin) delete(ctx context.Context, state status.BridgeState, logou if !unlocked { ul.Bridge.cacheLock.Unlock() } - go ul.deleteSpace(ctx) - go ul.kickUserFromPortals(ctx, portals) + backgroundCtx := context.WithoutCancel(ctx) + go ul.deleteSpace(backgroundCtx) + if portals != nil { + go ul.kickUserFromPortals(backgroundCtx, portals, state.StateEvent == status.StateBadCredentials, false) + } if state.StateEvent != "" { ul.BridgeState.Send(state) } @@ -279,8 +290,159 @@ func (ul *UserLogin) deleteSpace(ctx context.Context) { } } -func (ul *UserLogin) kickUserFromPortals(ctx context.Context, portals []*database.UserPortal) { - // TODO kick user out of rooms +// KickUserFromPortalsForBadCredentials can be called to kick the user from portals without deleting the entire UserLogin object. +func (ul *UserLogin) KickUserFromPortalsForBadCredentials(ctx context.Context) { + log := zerolog.Ctx(ctx) + portals, err := ul.Bridge.DB.UserPortal.GetAllForLogin(ctx, ul.UserLogin) + if err != nil { + log.Err(err).Msg("Failed to get user portals") + } + ul.kickUserFromPortals(ctx, portals, true, true) +} + +func DeleteManyPortals(ctx context.Context, portals []*Portal, errorCallback func(portal *Portal, delete bool, err error)) { + // TODO is there a more sensible place/name for this function? + if len(portals) == 0 { + return + } + getDepth := func(portal *Portal) int { + depth := 0 + for portal.Parent != nil { + depth++ + portal = portal.Parent + } + return depth + } + // Sort portals so parents are last (to avoid errors caused by deleting parent portals before children) + slices.SortFunc(portals, func(a, b *Portal) int { + return cmp.Compare(getDepth(b), getDepth(a)) + }) + for _, portal := range portals { + err := portal.Delete(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("portal_mxid", portal.MXID). + Object("portal_key", portal.PortalKey). + Msg("Failed to delete portal row from database") + if errorCallback != nil { + errorCallback(portal, false, err) + } + continue + } + err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("portal_mxid", portal.MXID). + Msg("Failed to clean up portal room") + if errorCallback != nil { + errorCallback(portal, true, err) + } + } + } +} + +func (ul *UserLogin) kickUserFromPortals(ctx context.Context, portals []*database.UserPortal, badCredentials, deleteRow bool) { + var portalsToDelete []*Portal + for _, up := range portals { + portalToDelete, err := ul.kickUserFromPortal(ctx, up, badCredentials, deleteRow) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Object("portal_key", up.Portal). + Stringer("user_mxid", up.UserMXID). + Msg("Failed to apply logout action") + } else if portalToDelete != nil { + portalsToDelete = append(portalsToDelete, portalToDelete) + } + } + DeleteManyPortals(ctx, portalsToDelete, nil) +} + +func (ul *UserLogin) kickUserFromPortal(ctx context.Context, up *database.UserPortal, badCredentials, deleteRow bool) (*Portal, error) { + portal, action, reason, err := ul.getLogoutAction(ctx, up, badCredentials) + if err != nil { + return nil, err + } + zerolog.Ctx(ctx).Debug(). + Str("login_id", string(ul.ID)). + Stringer("user_mxid", ul.UserMXID). + Str("logout_action", string(action)). + Str("action_reason", reason). + Object("portal_key", portal.PortalKey). + Stringer("portal_mxid", portal.MXID). + Msg("Calculated portal action for logout processing") + switch action { + case bridgeconfig.CleanupActionNull, bridgeconfig.CleanupActionNothing: + // do nothing + case bridgeconfig.CleanupActionKick: + _, err = ul.Bridge.Bot.SendState(ctx, portal.MXID, event.StateMember, ul.UserMXID.String(), &event.Content{ + Parsed: &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: "Logged out of bridge", + }, + }, time.Time{}) + if err != nil { + return nil, fmt.Errorf("failed to kick user from portal: %w", err) + } + zerolog.Ctx(ctx).Debug(). + Str("login_id", string(ul.ID)). + Stringer("user_mxid", ul.UserMXID). + Stringer("portal_mxid", portal.MXID). + Msg("Kicked user from portal") + if deleteRow { + err = ul.Bridge.DB.UserPortal.Delete(ctx, up) + if err != nil { + zerolog.Ctx(ctx).Warn(). + Str("login_id", string(ul.ID)). + Stringer("user_mxid", ul.UserMXID). + Stringer("portal_mxid", portal.MXID). + Msg("Failed to delete user portal row") + } + } + case bridgeconfig.CleanupActionDelete, bridgeconfig.CleanupActionUnbridge: + // return portal instead of deleting here to allow sorting by depth + return portal, nil + } + return nil, nil +} + +func (ul *UserLogin) getLogoutAction(ctx context.Context, up *database.UserPortal, badCredentials bool) (*Portal, bridgeconfig.CleanupAction, string, error) { + portal, err := ul.Bridge.GetExistingPortalByKey(ctx, up.Portal) + if err != nil { + return nil, bridgeconfig.CleanupActionNull, "", fmt.Errorf("failed to get full portal: %w", err) + } else if portal == nil || portal.MXID == "" { + return nil, bridgeconfig.CleanupActionNull, "portal not found", nil + } + actionsSet := ul.Bridge.Config.CleanupOnLogout.Manual + if badCredentials { + actionsSet = ul.Bridge.Config.CleanupOnLogout.BadCredentials + } + if portal.Receiver == "" { + return portal, actionsSet.Private, "portal has receiver", nil + } + otherUPs, err := ul.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) + if err != nil { + return portal, bridgeconfig.CleanupActionNull, "", fmt.Errorf("failed to get other logins in portal: %w", err) + } + hasOtherUsers := false + for _, otherUP := range otherUPs { + if otherUP.LoginID == ul.ID { + continue + } + if otherUP.UserMXID == ul.UserMXID { + otherUL := ul.Bridge.GetCachedUserLoginByID(otherUP.LoginID) + if otherUL != nil && otherUL.Client.IsLoggedIn() { + return portal, bridgeconfig.CleanupActionNull, "user has another login in portal", nil + } + } else { + hasOtherUsers = true + } + } + if portal.RelayLoginID != "" { + return portal, actionsSet.Relayed, "portal has relay login", nil + } else if hasOtherUsers { + return portal, actionsSet.SharedHasUsers, "portal has logins of other users", nil + } + return portal, actionsSet.SharedNoUsers, "portal doesn't have logins of other users", nil } func (ul *UserLogin) MarkAsPreferredIn(ctx context.Context, portal *Portal) error { From c27b51a41ce3203adeb3e127e8f577ee3b73d443 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 24 Jul 2024 16:21:44 +0300 Subject: [PATCH 0505/1647] bridgev2/portal: don't return other logins in portals with receiver set --- bridgev2/portal.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index f4493573..d9759424 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -303,7 +303,11 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowR defer portal.Bridge.cacheLock.Unlock() for i, up := range logins { login, ok := user.logins[up.LoginID] - if ok && login.Client != nil && (len(logins) == i-1 || login.Client.IsLoggedIn()) { + if portal.Receiver != "" { + if login.ID == portal.Receiver { + return login, up, nil + } + } else if ok && login.Client != nil && (len(logins) == i-1 || login.Client.IsLoggedIn()) { return login, up, nil } } @@ -316,6 +320,9 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowR } var firstLogin *UserLogin for _, login := range user.logins { + if portal.Receiver != "" && login.ID != portal.Receiver { + continue + } firstLogin = login break } From e05f095a117ce6f9e96e4427db548d58db163930 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 24 Jul 2024 18:07:21 +0300 Subject: [PATCH 0506/1647] bridgev2/ghost: improve condition for updating DM portal metadata --- bridgev2/ghost.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index 8f0e8882..b066bdd1 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -250,21 +250,23 @@ func (ghost *Ghost) updateDMPortals(ctx context.Context) { func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) { update := false + oldName := ghost.Name + oldAvatar := ghost.AvatarMXC if info.Name != nil { update = ghost.UpdateName(ctx, *info.Name) || update } if info.Avatar != nil { update = ghost.UpdateAvatar(ctx, info.Avatar) || update } - if update { - ghost.updateDMPortals(ctx) - } if info.Identifiers != nil || info.IsBot != nil { update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot) || update } if info.ExtraUpdates != nil { update = info.ExtraUpdates(ctx, ghost) || update } + if oldName != ghost.Name || oldAvatar != ghost.AvatarMXC { + ghost.updateDMPortals(ctx) + } if update { err := ghost.Bridge.DB.Ghost.Update(ctx, ghost.Ghost) if err != nil { From 14d23589be75fc501d9e05de495fd8370eb8b602 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 24 Jul 2024 18:22:09 +0300 Subject: [PATCH 0507/1647] bridgev2/config: add missing copy statements --- bridgev2/bridgeconfig/upgrade.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 07065fc6..57040607 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -189,6 +189,15 @@ func doMigrateLegacy(helper up.Helper) { } else { helper.Set(up.Bool, "false", "bridge", "private_chat_portal_meta") } + helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "relayed") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "shared_no_users") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "shared_has_users") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "bad_credentials", "private") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "bad_credentials", "relayed") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "bad_credentials", "shared_no_users") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "bad_credentials", "shared_has_users") helper.Copy(up.Bool, "bridge", "relay", "enabled") helper.Copy(up.Bool, "bridge", "relay", "admin_only") helper.Copy(up.Map, "bridge", "permissions") From 669e30f390d2a9e71ba8da5d8880823a3e1ae5d1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 24 Jul 2024 18:37:30 +0300 Subject: [PATCH 0508/1647] bridgev2/commands: add user search command --- bridgev2/commands/processor.go | 2 +- bridgev2/commands/startchat.go | 92 ++++++++++++++++++++++++++-------- 2 files changed, 73 insertions(+), 21 deletions(-) diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index 774be16b..49769514 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -45,7 +45,7 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor { CommandRegisterPush, CommandDeletePortal, CommandDeleteAllPortals, CommandLogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, CommandSetRelay, CommandUnsetRelay, - CommandResolveIdentifier, CommandStartChat, + CommandResolveIdentifier, CommandStartChat, CommandSearch, ) return proc } diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index 0a4e6783..24c8a488 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -7,6 +7,7 @@ package commands import ( + "context" "fmt" "strings" "time" @@ -14,7 +15,6 @@ import ( "golang.org/x/net/html" "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/id" ) @@ -55,6 +55,27 @@ func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Even return login, api, remainingArgs } +func formatResolveIdentifierResult(ctx context.Context, resp *bridgev2.ResolveIdentifierResponse) string { + var targetName string + var targetMXID id.UserID + if resp.Ghost != nil { + if resp.UserInfo != nil { + resp.Ghost.UpdateInfo(ctx, resp.UserInfo) + } + targetName = resp.Ghost.Name + targetMXID = resp.Ghost.Intent.GetMXID() + } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { + targetName = *resp.UserInfo.Name + } + if targetMXID != "" { + return fmt.Sprintf("`%s` / [%s](%s)", resp.UserID, targetName, targetMXID.URI().MatrixToURL()) + } else if targetName != "" { + return fmt.Sprintf("`%s` / %s", resp.UserID, targetName) + } else { + return fmt.Sprintf("`%s`", resp.UserID) + } +} + func fnResolveIdentifier(ce *Event) { if len(ce.Args) == 0 { ce.Reply("Usage: `$cmdprefix %s `", ce.Command) @@ -75,25 +96,7 @@ func fnResolveIdentifier(ce *Event) { ce.ReplyAdvanced(fmt.Sprintf("Identifier %s not found", html.EscapeString(identifier)), false, true) return } - var targetName string - var targetMXID id.UserID - if resp.Ghost != nil { - if resp.UserInfo != nil { - resp.Ghost.UpdateInfo(ce.Ctx, resp.UserInfo) - } - targetName = resp.Ghost.Name - targetMXID = resp.Ghost.Intent.GetMXID() - } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { - targetName = *resp.UserInfo.Name - } - var formattedName string - if targetMXID != "" { - formattedName = fmt.Sprintf("`%s` / [%s](%s)", resp.UserID, targetName, targetMXID.URI().MatrixToURL()) - } else if targetName != "" { - formattedName = fmt.Sprintf("`%s` / %s", resp.UserID, targetName) - } else { - formattedName = fmt.Sprintf("`%s`", resp.UserID) - } + formattedName := formatResolveIdentifierResult(ce.Ctx, resp) if createChat { if resp.Chat == nil { ce.Reply("Interface error: network connector did not return chat for create chat request") @@ -140,3 +143,52 @@ func fnResolveIdentifier(ce *Event) { ce.Reply("Found %s", formattedName) } } + +var CommandSearch = &FullHandler{ + Func: fnSearch, + Name: "search", + Help: HelpMeta{ + Section: HelpSectionChats, + Description: "Search for users on the remote network", + Args: "<_query_>", + }, + RequiresLogin: true, +} + +func fnSearch(ce *Event) { + if len(ce.Args) == 0 { + ce.Reply("Usage: `$cmdprefix search `") + return + } + _, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users") + if api == nil { + return + } + results, err := api.SearchUsers(ce.Ctx, strings.Join(queryParts, " ")) + if err != nil { + ce.Log.Err(err).Msg("Failed to search for users") + ce.Reply("Failed to search for users: %v", err) + return + } + resultsString := make([]string, len(results)) + for i, res := range results { + formattedName := formatResolveIdentifierResult(ce.Ctx, res) + resultsString[i] = fmt.Sprintf("* %s", formattedName) + if res.Chat != nil { + if res.Chat.Portal == nil { + res.Chat.Portal, err = ce.Bridge.GetExistingPortalByKey(ce.Ctx, res.Chat.PortalKey) + if err != nil { + ce.Log.Err(err).Object("portal_key", res.Chat.PortalKey).Msg("Failed to get DM portal") + } + } + if res.Chat.Portal != nil { + portalName := res.Chat.Portal.Name + if portalName == "" { + portalName = res.Chat.Portal.MXID.String() + } + resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Chat.Portal.MXID.URI().MatrixToURL()) + } + } + } + ce.Reply("Search results:\n\n%s", strings.Join(resultsString, "\n")) +} From 3b8d8d8182ee4a36ae137daf0a73cf3db1b70506 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 25 Jul 2024 13:16:36 +0300 Subject: [PATCH 0509/1647] bridgev2: add chat delete event interface --- bridgev2/networkinterface.go | 8 ++++++++ bridgev2/portal.go | 15 +++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 1745940e..e468be00 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -645,6 +645,8 @@ func (ret RemoteEventType) String() string { return "RemoteEventChatInfoChange" case RemoteEventChatResync: return "RemoteEventChatResync" + case RemoteEventChatDelete: + return "RemoteEventChatDelete" case RemoteEventBackfill: return "RemoteEventBackfill" default: @@ -665,6 +667,7 @@ const ( RemoteEventTyping RemoteEventChatInfoChange RemoteEventChatResync + RemoteEventChatDelete RemoteEventBackfill ) @@ -703,6 +706,11 @@ type RemoteChatResyncBackfill interface { CheckNeedsBackfill(ctx context.Context, latestMessage *database.Message) (bool, error) } +type RemoteChatDelete interface { + RemoteEvent + DeleteOnlyForMe() bool +} + type RemoteEventThatMayCreatePortal interface { RemoteEvent ShouldCreatePortal() bool diff --git a/bridgev2/portal.go b/bridgev2/portal.go index d9759424..a06489d0 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2739,6 +2739,21 @@ func (portal *Portal) Delete(ctx context.Context) error { return nil } +func (portal *Portal) RemoveMXID(ctx context.Context) error { + if portal.MXID == "" { + return nil + } + portal.MXID = "" + err := portal.Save(ctx) + if err != nil { + return err + } + portal.Bridge.cacheLock.Lock() + defer portal.Bridge.cacheLock.Unlock() + delete(portal.Bridge.portalsByMXID, portal.MXID) + return nil +} + func (portal *Portal) unlockedDelete(ctx context.Context) error { // TODO delete child portals? err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) From 6b925e0f958e8963c6ae532f18397ac24a4a4be9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 25 Jul 2024 13:21:45 +0300 Subject: [PATCH 0510/1647] bridgev2/simplevent: add better default implementations of RemoteEvent --- bridgev2/simpleremoteevent.go | 2 + bridgev2/simplevent/events.go | 191 ++++++++++++++++++++++++++++++++++ 2 files changed, 193 insertions(+) create mode 100644 bridgev2/simplevent/events.go diff --git a/bridgev2/simpleremoteevent.go b/bridgev2/simpleremoteevent.go index a45ff9c3..66058e3e 100644 --- a/bridgev2/simpleremoteevent.go +++ b/bridgev2/simpleremoteevent.go @@ -20,6 +20,8 @@ import ( // // Using this type is only recommended for simple bridges. More advanced ones should implement // the remote event interfaces themselves by wrapping the remote network library event types. +// +// Deprecated: use the types in the simplevent package instead. type SimpleRemoteEvent[T any] struct { Type RemoteEventType LogContext func(c zerolog.Context) zerolog.Context diff --git a/bridgev2/simplevent/events.go b/bridgev2/simplevent/events.go new file mode 100644 index 00000000..c28e78ed --- /dev/null +++ b/bridgev2/simplevent/events.go @@ -0,0 +1,191 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package simplevent + +import ( + "context" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" +) + +// SimpleRemoteEventMeta is a struct containing metadata fields used by most event types. +type SimpleRemoteEventMeta struct { + Type bridgev2.RemoteEventType + LogContext func(c zerolog.Context) zerolog.Context + PortalKey networkid.PortalKey + Sender bridgev2.EventSender + CreatePortal bool + Timestamp time.Time +} + +var ( + _ bridgev2.RemoteEvent = (*SimpleRemoteEventMeta)(nil) + _ bridgev2.RemoteEventThatMayCreatePortal = (*SimpleRemoteEventMeta)(nil) + _ bridgev2.RemoteEventWithTimestamp = (*SimpleRemoteEventMeta)(nil) +) + +func (evt *SimpleRemoteEventMeta) AddLogContext(c zerolog.Context) zerolog.Context { + return evt.LogContext(c) +} + +func (evt *SimpleRemoteEventMeta) GetPortalKey() networkid.PortalKey { + return evt.PortalKey +} + +func (evt *SimpleRemoteEventMeta) GetTimestamp() time.Time { + if evt.Timestamp.IsZero() { + return time.Now() + } + return evt.Timestamp +} + +func (evt *SimpleRemoteEventMeta) GetSender() bridgev2.EventSender { + return evt.Sender +} + +func (evt *SimpleRemoteEventMeta) GetType() bridgev2.RemoteEventType { + return evt.Type +} + +func (evt *SimpleRemoteEventMeta) ShouldCreatePortal() bool { + return evt.CreatePortal +} + +// SimpleRemoteMessage is a simple implementation of [bridgev2.RemoteMessage] and [bridgev2.RemoteEdit]. +type SimpleRemoteMessage[T any] struct { + SimpleRemoteEventMeta + Data T + + ID networkid.MessageID + TargetMessage networkid.MessageID + + ConvertMessageFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, data T) (*bridgev2.ConvertedMessage, error) + ConvertEditFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message, data T) (*bridgev2.ConvertedEdit, error) +} + +var ( + _ bridgev2.RemoteMessage = (*SimpleRemoteMessage[any])(nil) + _ bridgev2.RemoteEdit = (*SimpleRemoteMessage[any])(nil) +) + +func (evt *SimpleRemoteMessage[T]) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { + return evt.ConvertMessageFunc(ctx, portal, intent, evt.Data) +} + +func (evt *SimpleRemoteMessage[T]) ConvertEdit(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (*bridgev2.ConvertedEdit, error) { + return evt.ConvertEditFunc(ctx, portal, intent, existing, evt.Data) +} + +func (evt *SimpleRemoteMessage[T]) GetID() networkid.MessageID { + return evt.ID +} + +func (evt *SimpleRemoteMessage[T]) GetTargetMessage() networkid.MessageID { + return evt.TargetMessage +} + +// SimpleRemoteReaction is a simple implementation of [bridgev2.RemoteReaction] and [bridgev2.RemoteReactionRemove]. +type SimpleRemoteReaction struct { + SimpleRemoteEventMeta + TargetMessage networkid.MessageID + EmojiID networkid.EmojiID + Emoji string + ReactionDBMeta any +} + +var ( + _ bridgev2.RemoteReaction = (*SimpleRemoteReaction)(nil) + _ bridgev2.RemoteReactionWithMeta = (*SimpleRemoteReaction)(nil) + _ bridgev2.RemoteReactionRemove = (*SimpleRemoteReaction)(nil) +) + +func (evt *SimpleRemoteReaction) GetTargetMessage() networkid.MessageID { + return evt.TargetMessage +} + +func (evt *SimpleRemoteReaction) GetReactionEmoji() (string, networkid.EmojiID) { + return evt.Emoji, evt.EmojiID +} + +func (evt *SimpleRemoteReaction) GetRemovedEmojiID() networkid.EmojiID { + return evt.EmojiID +} + +func (evt *SimpleRemoteReaction) GetReactionDBMetadata() any { + return evt.ReactionDBMeta +} + +// SimpleRemoteChatResync is a simple implementation of [bridgev2.RemoteChatResync]. +// +// If GetChatInfoFunc is set, it will be used to get the chat info. Otherwise, ChatInfo will be used. +// +// If CheckNeedsBackfillFunc is set, it will be used to determine if backfill is required. +// Otherwise, the latest database message timestamp is compared to LatestMessageTS. +// +// All four fields are optional. +type SimpleRemoteChatResync struct { + SimpleRemoteEventMeta + + ChatInfo *bridgev2.ChatInfo + GetChatInfoFunc func(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) + + LatestMessageTS time.Time + CheckNeedsBackfillFunc func(ctx context.Context, latestMessage *database.Message) (bool, error) +} + +var ( + _ bridgev2.RemoteChatResync = (*SimpleRemoteChatResync)(nil) + _ bridgev2.RemoteChatResyncWithInfo = (*SimpleRemoteChatResync)(nil) + _ bridgev2.RemoteChatResyncBackfill = (*SimpleRemoteChatResync)(nil) +) + +func (evt *SimpleRemoteChatResync) CheckNeedsBackfill(ctx context.Context, latestMessage *database.Message) (bool, error) { + if evt.CheckNeedsBackfillFunc != nil { + return evt.CheckNeedsBackfillFunc(ctx, latestMessage) + } else if latestMessage == nil { + return !evt.LatestMessageTS.IsZero(), nil + } else { + return !evt.LatestMessageTS.IsZero() && evt.LatestMessageTS.Before(latestMessage.Timestamp), nil + } +} + +func (evt *SimpleRemoteChatResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { + if evt.GetChatInfoFunc != nil { + return evt.GetChatInfoFunc(ctx, portal) + } + return evt.ChatInfo, nil +} + +// SimpleRemoteChatDelete is a simple implementation of [bridgev2.RemoteChatDelete]. +type SimpleRemoteChatDelete struct { + SimpleRemoteEventMeta + OnlyForMe bool +} + +var _ bridgev2.RemoteChatDelete = (*SimpleRemoteChatDelete)(nil) + +func (evt *SimpleRemoteChatDelete) DeleteOnlyForMe() bool { + return evt.OnlyForMe +} + +// SimpleRemoteChatInfoChange is a simple implementation of [bridgev2.RemoteChatInfoChange]. +type SimpleRemoteChatInfoChange struct { + SimpleRemoteEventMeta + + ChatInfoChange *bridgev2.ChatInfoChange +} + +var _ bridgev2.RemoteChatInfoChange = (*SimpleRemoteChatInfoChange)(nil) + +func (evt *SimpleRemoteChatInfoChange) GetChatInfoChange(ctx context.Context) (*bridgev2.ChatInfoChange, error) { + return evt.ChatInfoChange, nil +} From 0e7bc4711f085c3ca61fd1ef82a5d5fd88682992 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 25 Jul 2024 13:30:02 +0300 Subject: [PATCH 0511/1647] bridgev2/portal: handle remote chat delete event --- bridgev2/portal.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a06489d0..f4976d00 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1235,6 +1235,8 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { portal.handleRemoteChatInfoChange(ctx, source, evt.(RemoteChatInfoChange)) case RemoteEventChatResync: portal.handleRemoteChatResync(ctx, source, evt.(RemoteChatResync)) + case RemoteEventChatDelete: + portal.handleRemoteChatDelete(ctx, source, evt.(RemoteChatDelete)) case RemoteEventBackfill: portal.handleRemoteBackfill(ctx, source, evt.(RemoteBackfill)) default: @@ -1817,6 +1819,21 @@ func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLo } } +func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) { + if portal.Receiver == "" && evt.DeleteOnlyForMe() { + // TODO check if there are other users + } + err := portal.Delete(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to delete portal from database") + return + } + err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to delete Matrix room") + } +} + func (portal *Portal) handleRemoteBackfill(ctx context.Context, source *UserLogin, backfill RemoteBackfill) { //data, err := backfill.GetBackfillData(ctx, portal) //if err != nil { From 3731f89525e2d4c10ace39af9a5c073ab58d5517 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 25 Jul 2024 13:30:20 +0300 Subject: [PATCH 0512/1647] bridgev2/portal: add remote event type to log context by default --- bridgev2/portal.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index f4976d00..20a75509 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1177,9 +1177,12 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { Msg("Remote event handler panicked") } }() + evtType := evt.GetType() + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Stringer("bridge_evt_type", evtType) + }) log.UpdateContext(evt.AddLogContext) ctx := log.WithContext(context.TODO()) - evtType := evt.GetType() if portal.MXID == "" { mcp, ok := evt.(RemoteEventThatMayCreatePortal) if !ok || !mcp.ShouldCreatePortal() { @@ -1209,7 +1212,7 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { if ok { preHandler.PreHandle(ctx, portal) } - log.Debug().Stringer("bridge_evt_type", evtType).Msg("Handling remote event") + log.Debug().Msg("Handling remote event") switch evtType { case RemoteEventUnknown: log.Debug().Msg("Ignoring remote event with type unknown") @@ -1240,7 +1243,7 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { case RemoteEventBackfill: portal.handleRemoteBackfill(ctx, source, evt.(RemoteBackfill)) default: - log.Warn().Int("type", int(evt.GetType())).Msg("Got remote event with unknown type") + log.Warn().Msg("Got remote event with unknown type") } } From 82e5974d06d2acc5472c1bd8b0d70fa0b8293c82 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 25 Jul 2024 16:35:25 +0300 Subject: [PATCH 0513/1647] bridgev2: allow NetworkAPI to implement custom bridge state fillers --- bridge/status/bridgestate.go | 6 +++++- bridgev2/userlogin.go | 10 +++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/bridge/status/bridgestate.go b/bridge/status/bridgestate.go index bb98e283..e6047a1d 100644 --- a/bridge/status/bridgestate.go +++ b/bridge/status/bridgestate.go @@ -80,9 +80,13 @@ type BridgeStateFiller interface { GetRemoteName() string } +type StandaloneCustomBridgeStateFiller interface { + FillBridgeState(BridgeState) BridgeState +} + type CustomBridgeStateFiller interface { BridgeStateFiller - FillBridgeState(BridgeState) BridgeState + StandaloneCustomBridgeStateFiller } func (pong BridgeState) Fill(user BridgeStateFiller) BridgeState { diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 2d0eaa20..a7d60831 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -449,7 +449,7 @@ func (ul *UserLogin) MarkAsPreferredIn(ctx context.Context, portal *Portal) erro return ul.Bridge.DB.UserPortal.MarkAsPreferred(ctx, ul.UserLogin, portal.PortalKey) } -var _ status.BridgeStateFiller = (*UserLogin)(nil) +var _ status.CustomBridgeStateFiller = (*UserLogin)(nil) func (ul *UserLogin) GetMXID() id.UserID { return ul.UserMXID @@ -463,6 +463,14 @@ func (ul *UserLogin) GetRemoteName() string { return ul.RemoteName } +func (ul *UserLogin) FillBridgeState(state status.BridgeState) status.BridgeState { + filler, ok := ul.Client.(status.StandaloneCustomBridgeStateFiller) + if ok { + return filler.FillBridgeState(state) + } + return state +} + func (ul *UserLogin) Disconnect(done func()) { if done != nil { defer done() From a503da55e38fb468393717754497b16d086aae21 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 26 Jul 2024 00:49:46 +0300 Subject: [PATCH 0514/1647] bridgev2/backfill: never notify for thread batch sends --- bridgev2/portalbackfill.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 93e1aee5..e46785a2 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -213,7 +213,7 @@ func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messa Bool("mark_read_past_threshold", forceMarkRead). Msg("Sending backfill messages") if canBatchSend { - portal.sendBatch(ctx, source, messages, forceForward, markRead || forceMarkRead) + portal.sendBatch(ctx, source, messages, forceForward, markRead || forceMarkRead, !inThread) } else { portal.sendLegacyBackfill(ctx, source, messages, markRead || forceMarkRead) } @@ -227,16 +227,15 @@ func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messa } } -func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead bool) { +func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, allowNotification bool) { req := &mautrix.ReqBeeperBatchSend{ ForwardIfNoMessages: !forceForward, Forward: forceForward, Events: make([]*event.Event, 0, len(messages)), + SendNotification: !markRead && forceForward && allowNotification, } if markRead { req.MarkReadBy = source.UserMXID - } else { - req.SendNotification = forceForward } prevThreadEvents := make(map[networkid.MessageID]id.EventID) dbMessages := make([]*database.Message, 0, len(messages)) From f18a2c55c971ec4642e838e6d33dfe1d24b20a1c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 26 Jul 2024 16:36:21 +0300 Subject: [PATCH 0515/1647] crypto/attachments: add type assertion for encryptingReader --- crypto/attachment/attachments.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crypto/attachment/attachments.go b/crypto/attachment/attachments.go index 5f1e3be9..8008cad2 100644 --- a/crypto/attachment/attachments.go +++ b/crypto/attachment/attachments.go @@ -137,6 +137,8 @@ type encryptingReader struct { isDecrypting bool } +var _ io.ReadSeekCloser = (*encryptingReader)(nil) + func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) { if r.closed { return 0, ReaderClosed From ef99542dc1da0f79aa986e1324eeee06c56a1281 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 26 Jul 2024 16:37:05 +0300 Subject: [PATCH 0516/1647] hicli: fix context import --- hicli/database/event.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hicli/database/event.go b/hicli/database/event.go index cb9568f6..de21e317 100644 --- a/hicli/database/event.go +++ b/hicli/database/event.go @@ -7,6 +7,7 @@ package database import ( + "context" "database/sql" "encoding/json" "fmt" @@ -16,7 +17,6 @@ import ( "github.com/tidwall/gjson" "go.mau.fi/util/dbutil" "go.mau.fi/util/exgjson" - "golang.org/x/net/context" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" From 426921e00a7cd7aac862d7a6a5de1887b42b089b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 26 Jul 2024 20:12:46 +0300 Subject: [PATCH 0517/1647] bridgev2/portal: move exact room type to new field --- bridgev2/portal.go | 5 ++++- event/state.go | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 20a75509..e9371580 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2047,7 +2047,10 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { AvatarURL: portal.AvatarMXC, // TODO external URL? }, - BeeperRoomType: string(portal.RoomType), + BeeperRoomTypeV2: string(portal.RoomType), + } + if portal.RoomType == database.RoomTypeDM || portal.RoomType == database.RoomTypeGroupDM { + bridgeInfo.BeeperRoomType = "dm" } parent := portal.GetTopLevelParent() if parent != nil { diff --git a/event/state.go b/event/state.go index c47a91ca..6a067cae 100644 --- a/event/state.go +++ b/event/state.go @@ -167,7 +167,8 @@ type BridgeEventContent struct { Network *BridgeInfoSection `json:"network,omitempty"` Channel BridgeInfoSection `json:"channel"` - BeeperRoomType string `json:"com.beeper.room_type,omitempty"` + BeeperRoomType string `json:"com.beeper.room_type,omitempty"` + BeeperRoomTypeV2 string `json:"com.beeper.room_type.v2,omitempty"` } type SpaceChildEventContent struct { From 3ba566d182fe191cd45b6ad670b11299dc8a34ef Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 28 Jul 2024 01:08:44 +0300 Subject: [PATCH 0518/1647] bridgev2/login: fix filling metadata in NewLogin --- bridgev2/userlogin.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index a7d60831..78b732fe 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -168,7 +168,10 @@ func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params data.BridgeID = user.BridgeID data.UserMXID = user.MXID if data.Metadata == nil { - data.Metadata = user.Bridge.Network.GetDBMetaTypes().UserLogin() + metaTypes := user.Bridge.Network.GetDBMetaTypes() + if metaTypes.UserLogin != nil { + data.Metadata = metaTypes.UserLogin() + } } if params == nil { params = &NewLoginParams{} From 3b9f7c8f230774319e4d232cfd9f9658186b299f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 28 Jul 2024 01:22:28 +0300 Subject: [PATCH 0519/1647] appservice/wshttp: fill RequestURI and RemoteAddr fields --- appservice/wshttp.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/appservice/wshttp.go b/appservice/wshttp.go index 40ceda9d..c5f6a672 100644 --- a/appservice/wshttp.go +++ b/appservice/wshttp.go @@ -61,6 +61,11 @@ func (as *AppService) WebsocketHTTPProxy(cmd WebsocketCommand) (bool, interface{ if err != nil { return false, fmt.Errorf("failed to create fake HTTP request: %w", err) } + httpReq.RequestURI = req.Path + if req.Query != "" { + httpReq.RequestURI += "?" + req.Query + } + httpReq.RemoteAddr = "websocket" httpReq.Header = req.Headers var resp HTTPProxyResponse From b5c26a2fdb6817f32a7b1beaefabcbc749df81d7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 28 Jul 2024 15:18:04 +0300 Subject: [PATCH 0520/1647] federation: add utilities for server name resolution --- federation/request.go | 115 ++++++++++++++++++++++++++ federation/request_test.go | 35 ++++++++ federation/resolution.go | 151 ++++++++++++++++++++++++++++++++++ federation/resolution_test.go | 115 ++++++++++++++++++++++++++ federation/servername.go | 95 +++++++++++++++++++++ federation/servername_test.go | 64 ++++++++++++++ 6 files changed, 575 insertions(+) create mode 100644 federation/request.go create mode 100644 federation/request_test.go create mode 100644 federation/resolution.go create mode 100644 federation/resolution_test.go create mode 100644 federation/servername.go create mode 100644 federation/servername_test.go diff --git a/federation/request.go b/federation/request.go new file mode 100644 index 00000000..faeb16ad --- /dev/null +++ b/federation/request.go @@ -0,0 +1,115 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "context" + "fmt" + "net" + "net/http" + "sync" + "time" +) + +// ServerResolvingTransport is an http.RoundTripper that resolves Matrix server names before sending requests. +// It only allows requests using the "matrix-federation" scheme. +type ServerResolvingTransport struct { + ResolveOpts *ResolveServerNameOpts + Transport *http.Transport + Dialer *net.Dialer + + cache map[string]*ResolvedServerName + resolveLocks map[string]*sync.Mutex + cacheLock sync.Mutex +} + +func NewServerResolvingTransport() *ServerResolvingTransport { + srt := &ServerResolvingTransport{ + cache: make(map[string]*ResolvedServerName), + resolveLocks: make(map[string]*sync.Mutex), + + Dialer: &net.Dialer{}, + } + srt.Transport = &http.Transport{ + DialContext: srt.DialContext, + } + return srt +} + +func NewFederationHTTPClient() *http.Client { + return &http.Client{ + Transport: NewServerResolvingTransport(), + Timeout: 120 * time.Second, + } +} + +var _ http.RoundTripper = (*ServerResolvingTransport)(nil) + +func (srt *ServerResolvingTransport) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + addrs, ok := ctx.Value(contextKeyIPPort).([]string) + if !ok { + return nil, fmt.Errorf("no IP:port in context") + } + return srt.Dialer.DialContext(ctx, network, addrs[0]) +} + +type contextKey int + +const ( + contextKeyIPPort contextKey = iota +) + +func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Response, error) { + if request.URL.Scheme != "matrix-federation" { + return nil, fmt.Errorf("unsupported scheme: %s", request.URL.Scheme) + } + resolved, err := srt.resolve(request.Context(), request.URL.Host) + if err != nil { + return nil, fmt.Errorf("failed to resolve server name: %w", err) + } + request = request.WithContext(context.WithValue(request.Context(), contextKeyIPPort, resolved.IPPort)) + request.URL.Scheme = "https" + request.URL.Host = resolved.HostHeader + request.Host = resolved.HostHeader + return srt.Transport.RoundTrip(request) +} + +func (srt *ServerResolvingTransport) resolve(ctx context.Context, serverName string) (*ResolvedServerName, error) { + res, lock := srt.getResolveCache(serverName) + if res != nil { + return res, nil + } + lock.Lock() + defer lock.Unlock() + res, _ = srt.getResolveCache(serverName) + if res != nil { + return res, nil + } + var err error + res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts) + if err != nil { + return nil, err + } + srt.cacheLock.Lock() + srt.cache[serverName] = res + srt.cacheLock.Unlock() + return res, nil +} + +func (srt *ServerResolvingTransport) getResolveCache(serverName string) (*ResolvedServerName, *sync.Mutex) { + srt.cacheLock.Lock() + defer srt.cacheLock.Unlock() + if val, ok := srt.cache[serverName]; ok && time.Until(val.Expires) > 0 { + return val, nil + } + rl, ok := srt.resolveLocks[serverName] + if !ok { + rl = &sync.Mutex{} + srt.resolveLocks[serverName] = rl + } + return nil, rl +} diff --git a/federation/request_test.go b/federation/request_test.go new file mode 100644 index 00000000..e9037f2d --- /dev/null +++ b/federation/request_test.go @@ -0,0 +1,35 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation_test + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/federation" +) + +type serverVersionResp struct { + Server struct { + Name string `json:"name"` + Version string `json:"version"` + } `json:"server"` +} + +func TestNewFederationClient(t *testing.T) { + cli := federation.NewFederationHTTPClient() + resp, err := cli.Get("matrix-federation://maunium.net/_matrix/federation/v1/version") + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + var respData serverVersionResp + err = json.NewDecoder(resp.Body).Decode(&respData) + require.NoError(t, err) + require.Equal(t, "Synapse", respData.Server.Name) +} diff --git a/federation/resolution.go b/federation/resolution.go new file mode 100644 index 00000000..e6785988 --- /dev/null +++ b/federation/resolution.go @@ -0,0 +1,151 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/rs/zerolog" +) + +type ResolvedServerName struct { + ServerName string `json:"server_name"` + HostHeader string `json:"host_header"` + IPPort []string `json:"ip_port"` + Expires time.Time `json:"expires"` +} + +type ResolveServerNameOpts struct { + HTTPClient *http.Client + DNSClient *net.Resolver +} + +var ( + ErrInvalidServerName = errors.New("invalid server name") +) + +// ResolveServerName implements the full server discovery algorithm as specified in https://spec.matrix.org/v1.11/server-server-api/#resolving-server-names +func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveServerNameOpts) (*ResolvedServerName, error) { + var opt ResolveServerNameOpts + if len(opts) > 0 && opts[0] != nil { + opt = *opts[0] + } + if opt.HTTPClient == nil { + opt.HTTPClient = http.DefaultClient + } + if opt.DNSClient == nil { + opt.DNSClient = net.DefaultResolver + } + output := ResolvedServerName{ + ServerName: serverName, + HostHeader: serverName, + IPPort: []string{serverName}, + Expires: time.Now().Add(24 * time.Hour), + } + hostname, port, ok := ParseServerName(serverName) + if !ok { + return nil, ErrInvalidServerName + } + // Steps 1 and 2: handle IP literals and hostnames with port + if net.ParseIP(hostname) != nil || port != 0 { + if port == 0 { + port = 8448 + } + output.IPPort = []string{net.JoinHostPort(hostname, strconv.Itoa(int(port)))} + return &output, nil + } + // Step 3: resolve .well-known + wellKnown, expiry, err := RequestWellKnown(ctx, opt.HTTPClient, hostname) + if err != nil { + zerolog.Ctx(ctx).Trace(). + Str("server_name", serverName). + Err(err). + Msg("Failed to get well-known data") + } else if wellKnown != nil { + output.Expires = expiry + output.HostHeader = wellKnown.Server + hostname, port, ok = ParseServerName(wellKnown.Server) + // Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known + if net.ParseIP(hostname) != nil || port != 0 { + if port == 0 { + port = 8448 + } + output.IPPort = []string{net.JoinHostPort(hostname, strconv.Itoa(int(port)))} + return &output, nil + } + } + // Step 3.3, 3.4, 4 and 5: resolve SRV records + srv, err := RequestSRV(ctx, opt.DNSClient, hostname) + if err != nil { + // TODO log more noisily for abnormal errors? + zerolog.Ctx(ctx).Trace(). + Str("server_name", serverName). + Str("hostname", hostname). + Err(err). + Msg("Failed to get SRV record") + } else if len(srv) > 0 { + output.IPPort = make([]string, len(srv)) + for i, record := range srv { + output.IPPort[i] = net.JoinHostPort(strings.TrimRight(record.Target, "."), strconv.Itoa(int(record.Port))) + } + return &output, nil + } + // Step 6 or 3.5: no SRV records were found, so default to port 8448 + output.IPPort = []string{net.JoinHostPort(hostname, "8448")} + return &output, nil +} + +// RequestSRV resolves the `_matrix-fed._tcp` SRV record for the given hostname. +// If the new matrix-fed record is not found, it falls back to the old `_matrix._tcp` record. +func RequestSRV(ctx context.Context, cli *net.Resolver, hostname string) ([]*net.SRV, error) { + _, target, err := cli.LookupSRV(ctx, "matrix-fed", "tcp", hostname) + var dnsErr *net.DNSError + if err != nil && errors.As(err, &dnsErr) && dnsErr.IsNotFound { + _, target, err = cli.LookupSRV(ctx, "matrix", "tcp", hostname) + } + return target, err +} + +// RequestWellKnown sends a request to the well-known endpoint of a server and returns the response, +// plus the time when the cache should expire. +func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*RespWellKnown, time.Time, error) { + wellKnownURL := url.URL{ + Scheme: "https", + Host: hostname, + Path: "/.well-known/matrix/server", + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnownURL.String(), nil) + if err != nil { + return nil, time.Time{}, fmt.Errorf("failed to prepare request: %w", err) + } + resp, err := cli.Do(req) + if err != nil { + return nil, time.Time{}, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode) + } + var respData RespWellKnown + err = json.NewDecoder(resp.Body).Decode(&respData) + if err != nil { + return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err) + } else if respData.Server == "" { + return nil, time.Time{}, errors.New("server name not found in response") + } + // TODO parse cache-control header + return &respData, time.Now().Add(24 * time.Hour), nil +} diff --git a/federation/resolution_test.go b/federation/resolution_test.go new file mode 100644 index 00000000..62200454 --- /dev/null +++ b/federation/resolution_test.go @@ -0,0 +1,115 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/federation" +) + +type resolveTestCase struct { + name string + serverName string + expected federation.ResolvedServerName +} + +func TestResolveServerName(t *testing.T) { + // See https://t2bot.io/docs/resolvematrix/ for more info on the RM test cases + testCases := []resolveTestCase{{ + "maunium", + "maunium.net", + federation.ResolvedServerName{ + HostHeader: "federation.mau.chat", + IPPort: []string{"meow.host.mau.fi:443"}, + }, + }, { + "IP literal", + "135.181.208.158", + federation.ResolvedServerName{ + HostHeader: "135.181.208.158", + IPPort: []string{"135.181.208.158:8448"}, + }, + }, { + "IP literal with port", + "135.181.208.158:8447", + federation.ResolvedServerName{ + HostHeader: "135.181.208.158:8447", + IPPort: []string{"135.181.208.158:8447"}, + }, + }, { + "RM Step 2", + "2.s.resolvematrix.dev:7652", + federation.ResolvedServerName{ + HostHeader: "2.s.resolvematrix.dev:7652", + IPPort: []string{"2.s.resolvematrix.dev:7652"}, + }, + }, { + "RM Step 3B", + "3b.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "wk.3b.s.resolvematrix.dev:7753", + IPPort: []string{"wk.3b.s.resolvematrix.dev:7753"}, + }, + }, { + "RM Step 3C", + "3c.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "wk.3c.s.resolvematrix.dev", + IPPort: []string{"srv.wk.3c.s.resolvematrix.dev:7754"}, + }, + }, { + "RM Step 3C MSC4040", + "3c.msc4040.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "wk.3c.msc4040.s.resolvematrix.dev", + IPPort: []string{"srv.wk.3c.msc4040.s.resolvematrix.dev:7053"}, + }, + }, { + "RM Step 3D", + "3d.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "wk.3d.s.resolvematrix.dev", + IPPort: []string{"wk.3d.s.resolvematrix.dev:8448"}, + }, + }, { + "RM Step 4", + "4.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "4.s.resolvematrix.dev", + IPPort: []string{"srv.4.s.resolvematrix.dev:7855"}, + }, + }, { + "RM Step 4 MSC4040", + "4.msc4040.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "4.msc4040.s.resolvematrix.dev", + IPPort: []string{"srv.4.msc4040.s.resolvematrix.dev:7054"}, + }, + }, { + "RM Step 5", + "5.s.resolvematrix.dev", + federation.ResolvedServerName{ + HostHeader: "5.s.resolvematrix.dev", + IPPort: []string{"5.s.resolvematrix.dev:8448"}, + }, + }} + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.expected.ServerName = tc.serverName + resp, err := federation.ResolveServerName(context.TODO(), tc.serverName) + require.NoError(t, err) + resp.Expires = time.Time{} + assert.Equal(t, tc.expected, *resp) + }) + } +} diff --git a/federation/servername.go b/federation/servername.go new file mode 100644 index 00000000..33590712 --- /dev/null +++ b/federation/servername.go @@ -0,0 +1,95 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "net" + "strconv" + "strings" +) + +func isSpecCompliantIPv6(host string) bool { + // IPv6address = 2*45IPv6char + // IPv6char = DIGIT / %x41-46 / %x61-66 / ":" / "." + // ; 0-9, A-F, a-f, :, . + if len(host) < 2 || len(host) > 45 { + return false + } + for _, ch := range host { + if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'f') && (ch < 'A' || ch > 'F') && ch != ':' && ch != '.' { + return false + } + } + return true +} + +func isValidIPv4Chunk(str string) bool { + if len(str) == 0 || len(str) > 3 { + return false + } + for _, ch := range str { + if ch < '0' || ch > '9' { + return false + } + } + return true + +} + +func isSpecCompliantIPv4(host string) bool { + // IPv4address = 1*3DIGIT "." 1*3DIGIT "." 1*3DIGIT "." 1*3DIGIT + if len(host) < 7 || len(host) > 15 { + return false + } + parts := strings.Split(host, ".") + return len(parts) == 4 && + isValidIPv4Chunk(parts[0]) && + isValidIPv4Chunk(parts[1]) && + isValidIPv4Chunk(parts[2]) && + isValidIPv4Chunk(parts[3]) +} + +func isSpecCompliantDNSName(host string) bool { + // dns-name = 1*255dns-char + // dns-char = DIGIT / ALPHA / "-" / "." + if len(host) == 0 || len(host) > 255 { + return false + } + for _, ch := range host { + if (ch < '0' || ch > '9') && (ch < 'a' || ch > 'z') && (ch < 'A' || ch > 'Z') && ch != '-' && ch != '.' { + return false + } + } + return true +} + +// ParseServerName parses the port and hostname from a Matrix server name and validates that +// it matches the grammar specified in https://spec.matrix.org/v1.11/appendices/#server-name +func ParseServerName(serverName string) (host string, port uint16, ok bool) { + if len(serverName) == 0 || len(serverName) > 255 { + return + } + colonIdx := strings.LastIndexByte(serverName, ':') + if colonIdx > 0 { + u64Port, err := strconv.ParseUint(serverName[colonIdx+1:], 10, 16) + if err == nil { + port = uint16(u64Port) + serverName = serverName[:colonIdx] + } + } + if serverName[0] == '[' { + if serverName[len(serverName)-1] != ']' { + return + } + host = serverName[1 : len(serverName)-1] + ok = isSpecCompliantIPv6(host) && net.ParseIP(host) != nil + } else { + host = serverName + ok = isSpecCompliantDNSName(host) || isSpecCompliantIPv4(host) + } + return +} diff --git a/federation/servername_test.go b/federation/servername_test.go new file mode 100644 index 00000000..156d692f --- /dev/null +++ b/federation/servername_test.go @@ -0,0 +1,64 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/federation" +) + +type parseTestCase struct { + name string + serverName string + hostname string + port uint16 +} + +func TestParseServerName(t *testing.T) { + testCases := []parseTestCase{{ + "Domain", + "matrix.org", + "matrix.org", + 0, + }, { + "Domain with port", + "matrix.org:8448", + "matrix.org", + 8448, + }, { + "IPv4 literal", + "1.2.3.4", + "1.2.3.4", + 0, + }, { + "IPv4 literal with port", + "1.2.3.4:8448", + "1.2.3.4", + 8448, + }, { + "IPv6 literal", + "[1234:5678::abcd]", + "1234:5678::abcd", + 0, + }, { + "IPv6 literal with port", + "[1234:5678::abcd]:8448", + "1234:5678::abcd", + 8448, + }} + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + hostname, port, ok := federation.ParseServerName(tc.serverName) + assert.True(t, ok) + assert.Equal(t, tc.hostname, hostname) + assert.Equal(t, tc.port, port) + }) + } +} From d237bab4904a64ec68b85eb17f0dcd4f9f57547c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 28 Jul 2024 15:54:39 +0300 Subject: [PATCH 0521/1647] bridgev2/provisioning: add support for authenticating federated users --- bridgev2/matrix/provisioning.go | 22 ++- bridgev2/matrix/provisioningfederation.go | 181 ++++++++++++++++++++++ 2 files changed, 202 insertions(+), 1 deletion(-) create mode 100644 bridgev2/matrix/provisioningfederation.go diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 8cdfe0a4..5f51fc2e 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -26,6 +26,7 @@ import ( "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/federation" "maunium.net/go/mautrix/id" ) @@ -41,6 +42,8 @@ type ProvisioningAPI struct { log zerolog.Logger net bridgev2.NetworkConnector + fedClient *http.Client + logins map[string]*ProvLogin loginsLock sync.RWMutex @@ -85,11 +88,18 @@ func (prov *ProvisioningAPI) Init() { prov.logins = make(map[string]*ProvLogin) prov.net = prov.br.Bridge.Network prov.log = prov.br.Log.With().Str("component", "provisioning").Logger() + prov.fedClient = federation.NewFederationHTTPClient() + tp := prov.fedClient.Transport.(*federation.ServerResolvingTransport) + prov.fedClient.Timeout = 20 * time.Second + tp.Dialer.Timeout = 10 * time.Second + tp.Transport.ResponseHeaderTimeout = 10 * time.Second + tp.Transport.TLSHandshakeTimeout = 10 * time.Second prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter() prov.Router.Use(hlog.NewHandler(prov.log)) prov.Router.Use(corsMiddleware) prov.Router.Use(requestlog.AccessLogger(false)) prov.Router.Use(prov.AuthMiddleware) + prov.Router.Path("/v3/exchange_token").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostExchangeToken) prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami) prov.Router.Path("/v3/login/flows").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLoginFlows) prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginStart) @@ -151,6 +161,11 @@ func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.User func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/v3/exchange_token" { + h.ServeHTTP(w, r) + return + } + auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") if auth == "" { jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ @@ -161,7 +176,12 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { } userID := id.UserID(r.URL.Query().Get("user_id")) if auth != prov.br.Config.Provisioning.SharedSecret { - err := prov.checkMatrixAuth(r.Context(), userID, auth) + var err error + if userID.Homeserver() == prov.br.AS.HomeserverDomain { + err = prov.checkMatrixAuth(r.Context(), userID, auth) + } else { + err = prov.checkJWTAuth(userID, auth) + } if err != nil { zerolog.Ctx(r.Context()).Warn().Err(err). Msg("Provisioning API request contained invalid auth") diff --git a/bridgev2/matrix/provisioningfederation.go b/bridgev2/matrix/provisioningfederation.go new file mode 100644 index 00000000..b33cf90c --- /dev/null +++ b/bridgev2/matrix/provisioningfederation.go @@ -0,0 +1,181 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + "unsafe" + + "github.com/rs/zerolog" + "go.mau.fi/util/jsontime" + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/id" +) + +type exchangeTokenData struct { + Token string `json:"token"` +} + +type jwtPayload struct { + Subject id.UserID `json:"sub"` + Expiration jsontime.Unix `json:"exp"` + Issuer string `json:"iss"` + Audience []string `json:"aud"` +} + +const defaultJWTHeader = `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.` // {"alg":"HS256","typ":"JWT"} + +func (prov *ProvisioningAPI) makeJWT(userID id.UserID, validity time.Duration) string { + payload, err := json.Marshal(&jwtPayload{ + Subject: userID, + Expiration: jsontime.U(time.Now().Add(validity)), + Issuer: prov.br.Bot.UserID.String(), + Audience: []string{prov.br.Bot.UserID.String()}, + }) + if err != nil { + return "" + } + payloadLen := base64.RawURLEncoding.EncodedLen(len(payload)) + data := make([]byte, len(defaultJWTHeader)+payloadLen+33) + copy(data, defaultJWTHeader) + base64.RawURLEncoding.Encode(data[len(defaultJWTHeader):], payload) + hasher := hmac.New(sha256.New, []byte(prov.br.Config.Provisioning.SharedSecret)) + hasher.Write(data[:len(defaultJWTHeader)+payloadLen]) + base64.RawURLEncoding.Encode(data[len(defaultJWTHeader)+payloadLen:], hasher.Sum(nil)) + return unsafe.String(unsafe.SliceData(data), len(data)) +} + +func (prov *ProvisioningAPI) validateJWT(jwt string) (id.UserID, error) { + parts := strings.SplitN(jwt, ".", 3) + if len(parts) != 3 { + return "", fmt.Errorf("invalid JWT") + } else if parts[0] != defaultJWTHeader[:len(defaultJWTHeader)-1] { + return "", fmt.Errorf("invalid JWT header") + } + checksum, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + return "", fmt.Errorf("failed to decode JWT checksum: %w", err) + } + hasher := hmac.New(sha256.New, []byte(prov.br.Config.Provisioning.SharedSecret)) + hasher.Write([]byte(jwt[:len(defaultJWTHeader)+len(parts[1])])) + if !hmac.Equal(checksum, hasher.Sum(nil)) { + return "", fmt.Errorf("invalid JWT checksum") + } + payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) + if err != nil { + return "", fmt.Errorf("failed to decode JWT payload: %w", err) + } + var payload jwtPayload + err = json.Unmarshal(payloadBytes, &payload) + if err != nil { + return "", fmt.Errorf("failed to unmarshal JWT payload: %w", err) + } else if payload.Expiration.Before(time.Now()) { + return "", fmt.Errorf("JWT has expired") + } else if !slices.Contains(payload.Audience, prov.br.Bot.UserID.String()) { + return "", fmt.Errorf("invalid JWT audience") + } + return payload.Subject, nil +} + +func (prov *ProvisioningAPI) checkJWTAuth(expectedUserID id.UserID, jwt string) error { + userID, err := prov.validateJWT(jwt) + if err != nil { + return err + } + if userID != expectedUserID { + return fmt.Errorf("mismatching user ID (%q != %q)", userID, expectedUserID) + } + return nil +} + +func (prov *ProvisioningAPI) PostExchangeToken(w http.ResponseWriter, r *http.Request) { + var reqData exchangeTokenData + err := json.NewDecoder(r.Body).Decode(&reqData) + if err != nil { + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + Err: "Failed to decode request body", + ErrCode: mautrix.MNotJSON.ErrCode, + }) + return + } + userID := id.UserID(r.URL.Query().Get("user_id")) + homeserver := userID.Homeserver() + if homeserver == prov.br.AS.HomeserverDomain { + jsonResponse(w, http.StatusForbidden, &mautrix.RespError{ + Err: "Local users can't exchange tokens", + ErrCode: mautrix.MForbidden.ErrCode, + }) + return + } + perms := prov.br.Config.Bridge.Permissions.Get(userID) + // TODO separate permissions for provisioning API? + if !perms.Commands { + jsonResponse(w, http.StatusForbidden, &mautrix.RespError{ + Err: "User does not have permission to use the provisioning API", + ErrCode: mautrix.MForbidden.ErrCode, + }) + return + } + err = prov.validateOpenIDToken(r.Context(), userID, reqData.Token) + if err != nil { + zerolog.Ctx(r.Context()).Warn().Err(err).Msg("Failed to validate OpenID token") + jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ + Err: "Failed to validate token", + ErrCode: mautrix.MUnknownToken.ErrCode, + }) + return + } + jsonResponse(w, http.StatusOK, &exchangeTokenData{ + Token: prov.makeJWT(userID, 24*time.Hour), + }) +} + +type respOpenIDUserInfo struct { + Sub id.UserID `json:"sub"` +} + +func (prov *ProvisioningAPI) validateOpenIDToken(ctx context.Context, userID id.UserID, token string) error { + reqURL := url.URL{ + Scheme: "matrix-federation", + Host: userID.Homeserver(), + Path: "/_matrix/federation/v1/openid/userinfo", + RawQuery: (&url.Values{ + "access_token": {token}, + }).Encode(), + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil) + if err != nil { + return fmt.Errorf("failed to prepare request: %w", err) + } + resp, err := prov.fedClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code %d", resp.StatusCode) + } + var respData respOpenIDUserInfo + err = json.NewDecoder(resp.Body).Decode(&respData) + if err != nil { + return fmt.Errorf("failed to decode response: %w", err) + } else if respData.Sub != userID { + return fmt.Errorf("mismatching user ID (%q != %q)", respData.Sub, userID) + } + return nil +} From 2733a97c28a8ce743e8f726950fa28e235e43012 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 28 Jul 2024 18:37:47 +0300 Subject: [PATCH 0522/1647] bridgev2/provisioning: use openid tokens directly for federated auth --- bridgev2/matrix/provisioning.go | 63 +++++++- bridgev2/matrix/provisioningfederation.go | 181 ---------------------- 2 files changed, 56 insertions(+), 188 deletions(-) delete mode 100644 bridgev2/matrix/provisioningfederation.go diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 5f51fc2e..0683cb7a 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -11,6 +11,7 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" "strings" "sync" "time" @@ -99,7 +100,6 @@ func (prov *ProvisioningAPI) Init() { prov.Router.Use(corsMiddleware) prov.Router.Use(requestlog.AccessLogger(false)) prov.Router.Use(prov.AuthMiddleware) - prov.Router.Path("/v3/exchange_token").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostExchangeToken) prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami) prov.Router.Path("/v3/login/flows").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLoginFlows) prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginStart) @@ -159,13 +159,62 @@ func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.User } } +type respOpenIDUserInfo struct { + Sub id.UserID `json:"sub"` +} + +func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userID id.UserID, token string) error { + homeserver := userID.Homeserver() + wrappedToken := fmt.Sprintf("%s:%s", homeserver, token) + // TODO smarter locking + prov.matrixAuthCacheLock.Lock() + defer prov.matrixAuthCacheLock.Unlock() + if cached, ok := prov.matrixAuthCache[wrappedToken]; ok && cached.Expires.After(time.Now()) && cached.UserID == userID { + return nil + } else if validationResult, err := prov.validateOpenIDToken(ctx, homeserver, token); err != nil { + return fmt.Errorf("failed to validate OpenID token: %w", err) + } else if validationResult != userID { + return fmt.Errorf("mismatching user ID (%q != %q)", validationResult, userID) + } else { + prov.matrixAuthCache[wrappedToken] = matrixAuthCacheEntry{ + Expires: time.Now().Add(1 * time.Hour), + UserID: userID, + } + return nil + } +} + +func (prov *ProvisioningAPI) validateOpenIDToken(ctx context.Context, server string, token string) (id.UserID, error) { + reqURL := url.URL{ + Scheme: "matrix-federation", + Host: server, + Path: "/_matrix/federation/v1/openid/userinfo", + RawQuery: (&url.Values{ + "access_token": {token}, + }).Encode(), + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil) + if err != nil { + return "", fmt.Errorf("failed to prepare request: %w", err) + } + resp, err := prov.fedClient.Do(req) + if err != nil { + return "", fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("unexpected status code %d", resp.StatusCode) + } + var respData respOpenIDUserInfo + err = json.NewDecoder(resp.Body).Decode(&respData) + if err != nil { + return "", fmt.Errorf("failed to decode response: %w", err) + } + return respData.Sub, nil +} + func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/v3/exchange_token" { - h.ServeHTTP(w, r) - return - } - auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") if auth == "" { jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ @@ -180,7 +229,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { if userID.Homeserver() == prov.br.AS.HomeserverDomain { err = prov.checkMatrixAuth(r.Context(), userID, auth) } else { - err = prov.checkJWTAuth(userID, auth) + err = prov.checkFederatedMatrixAuth(r.Context(), userID, auth) } if err != nil { zerolog.Ctx(r.Context()).Warn().Err(err). diff --git a/bridgev2/matrix/provisioningfederation.go b/bridgev2/matrix/provisioningfederation.go deleted file mode 100644 index b33cf90c..00000000 --- a/bridgev2/matrix/provisioningfederation.go +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package matrix - -import ( - "context" - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "encoding/json" - "fmt" - "net/http" - "net/url" - "strings" - "time" - "unsafe" - - "github.com/rs/zerolog" - "go.mau.fi/util/jsontime" - "golang.org/x/exp/slices" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/id" -) - -type exchangeTokenData struct { - Token string `json:"token"` -} - -type jwtPayload struct { - Subject id.UserID `json:"sub"` - Expiration jsontime.Unix `json:"exp"` - Issuer string `json:"iss"` - Audience []string `json:"aud"` -} - -const defaultJWTHeader = `eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.` // {"alg":"HS256","typ":"JWT"} - -func (prov *ProvisioningAPI) makeJWT(userID id.UserID, validity time.Duration) string { - payload, err := json.Marshal(&jwtPayload{ - Subject: userID, - Expiration: jsontime.U(time.Now().Add(validity)), - Issuer: prov.br.Bot.UserID.String(), - Audience: []string{prov.br.Bot.UserID.String()}, - }) - if err != nil { - return "" - } - payloadLen := base64.RawURLEncoding.EncodedLen(len(payload)) - data := make([]byte, len(defaultJWTHeader)+payloadLen+33) - copy(data, defaultJWTHeader) - base64.RawURLEncoding.Encode(data[len(defaultJWTHeader):], payload) - hasher := hmac.New(sha256.New, []byte(prov.br.Config.Provisioning.SharedSecret)) - hasher.Write(data[:len(defaultJWTHeader)+payloadLen]) - base64.RawURLEncoding.Encode(data[len(defaultJWTHeader)+payloadLen:], hasher.Sum(nil)) - return unsafe.String(unsafe.SliceData(data), len(data)) -} - -func (prov *ProvisioningAPI) validateJWT(jwt string) (id.UserID, error) { - parts := strings.SplitN(jwt, ".", 3) - if len(parts) != 3 { - return "", fmt.Errorf("invalid JWT") - } else if parts[0] != defaultJWTHeader[:len(defaultJWTHeader)-1] { - return "", fmt.Errorf("invalid JWT header") - } - checksum, err := base64.RawURLEncoding.DecodeString(parts[2]) - if err != nil { - return "", fmt.Errorf("failed to decode JWT checksum: %w", err) - } - hasher := hmac.New(sha256.New, []byte(prov.br.Config.Provisioning.SharedSecret)) - hasher.Write([]byte(jwt[:len(defaultJWTHeader)+len(parts[1])])) - if !hmac.Equal(checksum, hasher.Sum(nil)) { - return "", fmt.Errorf("invalid JWT checksum") - } - payloadBytes, err := base64.RawURLEncoding.DecodeString(parts[0]) - if err != nil { - return "", fmt.Errorf("failed to decode JWT payload: %w", err) - } - var payload jwtPayload - err = json.Unmarshal(payloadBytes, &payload) - if err != nil { - return "", fmt.Errorf("failed to unmarshal JWT payload: %w", err) - } else if payload.Expiration.Before(time.Now()) { - return "", fmt.Errorf("JWT has expired") - } else if !slices.Contains(payload.Audience, prov.br.Bot.UserID.String()) { - return "", fmt.Errorf("invalid JWT audience") - } - return payload.Subject, nil -} - -func (prov *ProvisioningAPI) checkJWTAuth(expectedUserID id.UserID, jwt string) error { - userID, err := prov.validateJWT(jwt) - if err != nil { - return err - } - if userID != expectedUserID { - return fmt.Errorf("mismatching user ID (%q != %q)", userID, expectedUserID) - } - return nil -} - -func (prov *ProvisioningAPI) PostExchangeToken(w http.ResponseWriter, r *http.Request) { - var reqData exchangeTokenData - err := json.NewDecoder(r.Body).Decode(&reqData) - if err != nil { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Failed to decode request body", - ErrCode: mautrix.MNotJSON.ErrCode, - }) - return - } - userID := id.UserID(r.URL.Query().Get("user_id")) - homeserver := userID.Homeserver() - if homeserver == prov.br.AS.HomeserverDomain { - jsonResponse(w, http.StatusForbidden, &mautrix.RespError{ - Err: "Local users can't exchange tokens", - ErrCode: mautrix.MForbidden.ErrCode, - }) - return - } - perms := prov.br.Config.Bridge.Permissions.Get(userID) - // TODO separate permissions for provisioning API? - if !perms.Commands { - jsonResponse(w, http.StatusForbidden, &mautrix.RespError{ - Err: "User does not have permission to use the provisioning API", - ErrCode: mautrix.MForbidden.ErrCode, - }) - return - } - err = prov.validateOpenIDToken(r.Context(), userID, reqData.Token) - if err != nil { - zerolog.Ctx(r.Context()).Warn().Err(err).Msg("Failed to validate OpenID token") - jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ - Err: "Failed to validate token", - ErrCode: mautrix.MUnknownToken.ErrCode, - }) - return - } - jsonResponse(w, http.StatusOK, &exchangeTokenData{ - Token: prov.makeJWT(userID, 24*time.Hour), - }) -} - -type respOpenIDUserInfo struct { - Sub id.UserID `json:"sub"` -} - -func (prov *ProvisioningAPI) validateOpenIDToken(ctx context.Context, userID id.UserID, token string) error { - reqURL := url.URL{ - Scheme: "matrix-federation", - Host: userID.Homeserver(), - Path: "/_matrix/federation/v1/openid/userinfo", - RawQuery: (&url.Values{ - "access_token": {token}, - }).Encode(), - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil) - if err != nil { - return fmt.Errorf("failed to prepare request: %w", err) - } - resp, err := prov.fedClient.Do(req) - if err != nil { - return fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("unexpected status code %d", resp.StatusCode) - } - var respData respOpenIDUserInfo - err = json.NewDecoder(resp.Body).Decode(&respData) - if err != nil { - return fmt.Errorf("failed to decode response: %w", err) - } else if respData.Sub != userID { - return fmt.Errorf("mismatching user ID (%q != %q)", respData.Sub, userID) - } - return nil -} From 638fc771152423f79ea2b2b7e6dcf0eafde496a8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 28 Jul 2024 18:49:50 +0300 Subject: [PATCH 0523/1647] bridgev2/provisioning: use special token prefix instead of homeserver to detect federated auth --- bridgev2/matrix/provisioning.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 0683cb7a..f250ccd7 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -226,10 +226,10 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { userID := id.UserID(r.URL.Query().Get("user_id")) if auth != prov.br.Config.Provisioning.SharedSecret { var err error - if userID.Homeserver() == prov.br.AS.HomeserverDomain { - err = prov.checkMatrixAuth(r.Context(), userID, auth) + if strings.HasPrefix(auth, "openid:") { + err = prov.checkFederatedMatrixAuth(r.Context(), userID, strings.TrimPrefix(auth, "openid:")) } else { - err = prov.checkFederatedMatrixAuth(r.Context(), userID, auth) + err = prov.checkMatrixAuth(r.Context(), userID, auth) } if err != nil { zerolog.Ctx(r.Context()).Warn().Err(err). From 593ad86b80bdb32a29bb980942f4f856b8479dfa Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 28 Jul 2024 23:43:07 +0300 Subject: [PATCH 0524/1647] federation: add wrappers for some federation endpoints --- bridgev2/matrix/provisioning.go | 46 +-- federation/client.go | 373 +++++++++++++++++++++++ federation/client_test.go | 23 ++ federation/{request.go => httpclient.go} | 7 - federation/request_test.go | 35 --- federation/signingkey.go | 4 + 6 files changed, 406 insertions(+), 82 deletions(-) create mode 100644 federation/client.go create mode 100644 federation/client_test.go rename federation/{request.go => httpclient.go} (95%) delete mode 100644 federation/request_test.go diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index f250ccd7..4c87af0f 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -11,7 +11,6 @@ import ( "encoding/json" "fmt" "net/http" - "net/url" "strings" "sync" "time" @@ -43,7 +42,7 @@ type ProvisioningAPI struct { log zerolog.Logger net bridgev2.NetworkConnector - fedClient *http.Client + fedClient *federation.Client logins map[string]*ProvLogin loginsLock sync.RWMutex @@ -89,9 +88,9 @@ func (prov *ProvisioningAPI) Init() { prov.logins = make(map[string]*ProvLogin) prov.net = prov.br.Bridge.Network prov.log = prov.br.Log.With().Str("component", "provisioning").Logger() - prov.fedClient = federation.NewFederationHTTPClient() - tp := prov.fedClient.Transport.(*federation.ServerResolvingTransport) - prov.fedClient.Timeout = 20 * time.Second + prov.fedClient = federation.NewClient("", nil) + prov.fedClient.HTTP.Timeout = 20 * time.Second + tp := prov.fedClient.HTTP.Transport.(*federation.ServerResolvingTransport) tp.Dialer.Timeout = 10 * time.Second tp.Transport.ResponseHeaderTimeout = 10 * time.Second tp.Transport.TLSHandshakeTimeout = 10 * time.Second @@ -159,10 +158,6 @@ func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.User } } -type respOpenIDUserInfo struct { - Sub id.UserID `json:"sub"` -} - func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userID id.UserID, token string) error { homeserver := userID.Homeserver() wrappedToken := fmt.Sprintf("%s:%s", homeserver, token) @@ -171,9 +166,9 @@ func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userI defer prov.matrixAuthCacheLock.Unlock() if cached, ok := prov.matrixAuthCache[wrappedToken]; ok && cached.Expires.After(time.Now()) && cached.UserID == userID { return nil - } else if validationResult, err := prov.validateOpenIDToken(ctx, homeserver, token); err != nil { + } else if validationResult, err := prov.fedClient.GetOpenIDUserInfo(ctx, homeserver, token); err != nil { return fmt.Errorf("failed to validate OpenID token: %w", err) - } else if validationResult != userID { + } else if validationResult.Sub != userID { return fmt.Errorf("mismatching user ID (%q != %q)", validationResult, userID) } else { prov.matrixAuthCache[wrappedToken] = matrixAuthCacheEntry{ @@ -184,35 +179,6 @@ func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userI } } -func (prov *ProvisioningAPI) validateOpenIDToken(ctx context.Context, server string, token string) (id.UserID, error) { - reqURL := url.URL{ - Scheme: "matrix-federation", - Host: server, - Path: "/_matrix/federation/v1/openid/userinfo", - RawQuery: (&url.Values{ - "access_token": {token}, - }).Encode(), - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), nil) - if err != nil { - return "", fmt.Errorf("failed to prepare request: %w", err) - } - resp, err := prov.fedClient.Do(req) - if err != nil { - return "", fmt.Errorf("failed to send request: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("unexpected status code %d", resp.StatusCode) - } - var respData respOpenIDUserInfo - err = json.NewDecoder(resp.Body).Decode(&respData) - if err != nil { - return "", fmt.Errorf("failed to decode response: %w", err) - } - return respData.Sub, nil -} - func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") diff --git a/federation/client.go b/federation/client.go new file mode 100644 index 00000000..dc8c139c --- /dev/null +++ b/federation/client.go @@ -0,0 +1,373 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "time" + + "go.mau.fi/util/exslices" + "go.mau.fi/util/jsontime" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/id" +) + +type Client struct { + HTTP *http.Client + ServerName string + UserAgent string + Key *SigningKey +} + +func NewClient(serverName string, key *SigningKey) *Client { + return &Client{ + HTTP: &http.Client{ + Transport: NewServerResolvingTransport(), + Timeout: 120 * time.Second, + }, + UserAgent: mautrix.DefaultUserAgent, + ServerName: serverName, + Key: key, + } +} + +func (c *Client) Version(ctx context.Context, serverName string) (resp *RespServerVersion, err error) { + err = c.MakeRequest(ctx, serverName, false, http.MethodGet, URLPath{"v1", "version"}, nil, &resp) + return +} + +func (c *Client) ServerKeys(ctx context.Context, serverName string) (resp *ServerKeyResponse, err error) { + err = c.MakeRequest(ctx, serverName, false, http.MethodGet, KeyURLPath{"v2", "server"}, nil, &resp) + return +} + +func (c *Client) QueryKeys(ctx context.Context, serverName string, req *ReqQueryKeys) (resp *ServerKeyResponse, err error) { + err = c.MakeRequest(ctx, serverName, false, http.MethodPost, KeyURLPath{"v2", "query"}, req, &resp) + return +} + +type PDU = json.RawMessage +type EDU = json.RawMessage + +type ReqSendTransaction struct { + Destination string `json:"destination"` + TxnID string `json:"-"` + + Origin string `json:"origin"` + OriginServerTS jsontime.UnixMilli `json:"origin_server_ts"` + PDUs []PDU `json:"pdus"` + EDUs []EDU `json:"edus,omitempty"` +} + +type PDUProcessingResult struct { + Error string `json:"error,omitempty"` +} + +type RespSendTransaction struct { + PDUs map[id.EventID]PDUProcessingResult `json:"pdus"` +} + +func (c *Client) SendTransaction(ctx context.Context, req *ReqSendTransaction) (resp *RespSendTransaction, err error) { + err = c.MakeRequest(ctx, req.Destination, true, http.MethodPost, URLPath{"v1", "send", req.TxnID}, req, &resp) + return +} + +type RespGetEventAuthChain struct { + AuthChain []PDU `json:"auth_chain"` +} + +func (c *Client) GetEventAuthChain(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetEventAuthChain, err error) { + err = c.MakeRequest(ctx, serverName, true, http.MethodGet, URLPath{"v1", "event_auth", roomID, eventID}, nil, &resp) + return +} + +type ReqBackfill struct { + ServerName string + RoomID id.RoomID + Limit int + BackfillFrom []id.EventID +} + +type RespBackfill struct { + Origin string `json:"origin"` + OriginServerTS jsontime.UnixMilli `json:"origin_server_ts"` + PDUs []PDU `json:"pdus"` +} + +func (c *Client) Backfill(ctx context.Context, req *ReqBackfill) (resp *RespBackfill, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.ServerName, + Method: http.MethodGet, + Path: URLPath{"v1", "backfill", req.RoomID}, + Query: url.Values{ + "limit": {strconv.Itoa(req.Limit)}, + "v": exslices.CastToString[string](req.BackfillFrom), + }, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +type ReqGetMissingEvents struct { + ServerName string `json:"-"` + RoomID id.RoomID `json:"-"` + EarliestEvents []id.EventID `json:"earliest_events"` + LatestEvents []id.EventID `json:"latest_events"` + Limit int `json:"limit,omitempty"` + MinDepth int `json:"min_depth,omitempty"` +} + +type RespGetMissingEvents struct { + Events []PDU `json:"events"` +} + +func (c *Client) GetMissingEvents(ctx context.Context, req *ReqGetMissingEvents) (resp *RespGetMissingEvents, err error) { + err = c.MakeRequest(ctx, req.ServerName, true, http.MethodPost, URLPath{"v1", "get_missing_events", req.RoomID}, req, &resp) + return +} + +func (c *Client) GetEvent(ctx context.Context, serverName string, eventID id.EventID) (resp *RespBackfill, err error) { + err = c.MakeRequest(ctx, serverName, true, http.MethodGet, URLPath{"v1", "event", eventID}, nil, &resp) + return +} + +type RespGetState struct { + AuthChain []PDU `json:"auth_chain"` + PDUs []PDU `json:"pdus"` +} + +func (c *Client) GetState(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetState, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: http.MethodGet, + Path: URLPath{"v1", "state", roomID}, + Query: url.Values{ + "event_id": {string(eventID)}, + }, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +type RespGetStateIDs struct { + AuthChain []id.EventID `json:"auth_chain_ids"` + PDUs []id.EventID `json:"pdu_ids"` +} + +func (c *Client) GetStateIDs(ctx context.Context, serverName string, roomID id.RoomID, eventID id.EventID) (resp *RespGetStateIDs, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: http.MethodGet, + Path: URLPath{"v1", "state_ids", roomID}, + Query: url.Values{ + "event_id": {string(eventID)}, + }, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) TimestampToEvent(ctx context.Context, serverName string, roomID id.RoomID, timestamp time.Time, dir mautrix.Direction) (resp *mautrix.RespTimestampToEvent, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: http.MethodGet, + Path: URLPath{"v1", "timestamp_to_event", roomID}, + Query: url.Values{ + "dir": {string(dir)}, + "ts": {strconv.FormatInt(timestamp.UnixMilli(), 10)}, + }, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +type RespOpenIDUserInfo struct { + Sub id.UserID `json:"sub"` +} + +func (c *Client) GetOpenIDUserInfo(ctx context.Context, serverName, accessToken string) (resp *RespOpenIDUserInfo, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: http.MethodGet, + Path: URLPath{"v1", "openid", "userinfo"}, + Query: url.Values{"access_token": {accessToken}}, + ResponseJSON: &resp, + }) + return +} + +type URLPath []any + +func (fup URLPath) FullPath() []any { + return append([]any{"_matrix", "federation"}, []any(fup)...) +} + +type KeyURLPath []any + +func (fkup KeyURLPath) FullPath() []any { + return append([]any{"_matrix", "key"}, []any(fkup)...) +} + +type RequestParams struct { + ServerName string + Method string + Path mautrix.PrefixableURLPath + Query url.Values + Authenticate bool + RequestJSON any + + ResponseJSON any + DontReadBody bool +} + +func (c *Client) MakeRequest(ctx context.Context, serverName string, authenticate bool, method string, path mautrix.PrefixableURLPath, reqJSON, respJSON any) error { + _, _, err := c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: method, + Path: path, + Authenticate: authenticate, + RequestJSON: reqJSON, + ResponseJSON: respJSON, + }) + return err +} + +func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]byte, *http.Response, error) { + req, err := c.compileRequest(ctx, params) + if err != nil { + return nil, nil, err + } + resp, err := c.HTTP.Do(req) + if err != nil { + return nil, nil, mautrix.HTTPError{ + Request: req, + Response: resp, + + Message: "request error", + WrappedError: err, + } + } + defer func() { + _ = resp.Body.Close() + }() + var body []byte + if resp.StatusCode >= 400 { + body, err = mautrix.ParseErrorResponse(req, resp) + return body, resp, err + } else if params.ResponseJSON != nil || !params.DontReadBody { + body, err = io.ReadAll(resp.Body) + if err != nil { + return body, resp, mautrix.HTTPError{ + Request: req, + Response: resp, + + Message: "failed to read response body", + WrappedError: err, + } + } + if params.ResponseJSON != nil { + err = json.Unmarshal(body, params.ResponseJSON) + if err != nil { + return body, resp, mautrix.HTTPError{ + Request: req, + Response: resp, + + Message: "failed to unmarshal response JSON", + ResponseBody: string(body), + WrappedError: err, + } + } + } + } + return body, resp, nil +} + +func (c *Client) compileRequest(ctx context.Context, params RequestParams) (*http.Request, error) { + reqURL := mautrix.BuildURL(&url.URL{ + Scheme: "matrix-federation", + Host: params.ServerName, + }, params.Path.FullPath()...) + reqURL.RawQuery = params.Query.Encode() + var reqJSON json.RawMessage + var reqBody io.Reader + if params.RequestJSON != nil { + var err error + reqJSON, err = json.Marshal(params.RequestJSON) + if err != nil { + return nil, mautrix.HTTPError{ + Message: "failed to marshal JSON", + WrappedError: err, + } + } + reqBody = bytes.NewReader(reqJSON) + } + req, err := http.NewRequestWithContext(ctx, params.Method, reqURL.String(), reqBody) + if err != nil { + return nil, mautrix.HTTPError{ + Message: "failed to create request", + WrappedError: err, + } + } + req.Header.Set("User-Agent", c.UserAgent) + if params.Authenticate { + if c.ServerName == "" || c.Key == nil { + return nil, mautrix.HTTPError{ + Message: "client not configured for authentication", + } + } + auth, err := (&signableRequest{ + Method: req.Method, + URI: reqURL.RequestURI(), + Origin: c.ServerName, + Destination: params.ServerName, + Content: reqJSON, + }).Sign(c.Key) + if err != nil { + return nil, mautrix.HTTPError{ + Message: "failed to sign request", + WrappedError: err, + } + } + req.Header.Set("Authorization", auth) + } + return req, nil +} + +type signableRequest struct { + Method string `json:"method"` + URI string `json:"uri"` + Origin string `json:"origin"` + Destination string `json:"destination"` + Content any `json:"content,omitempty"` +} + +func (r *signableRequest) Sign(key *SigningKey) (string, error) { + sig, err := key.SignJSON(r) + if err != nil { + return "", err + } + return fmt.Sprintf( + `X-Matrix origin="%s",destination="%s",key="%s",sig="%s"`, + r.Origin, + r.Destination, + key.ID, + base64.RawURLEncoding.EncodeToString(sig), + ), nil +} diff --git a/federation/client_test.go b/federation/client_test.go new file mode 100644 index 00000000..ba3c3ed4 --- /dev/null +++ b/federation/client_test.go @@ -0,0 +1,23 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/federation" +) + +func TestClient_Version(t *testing.T) { + cli := federation.NewClient("", nil) + resp, err := cli.Version(context.TODO(), "maunium.net") + require.NoError(t, err) + require.Equal(t, "Synapse", resp.Server.Name) +} diff --git a/federation/request.go b/federation/httpclient.go similarity index 95% rename from federation/request.go rename to federation/httpclient.go index faeb16ad..d6d97280 100644 --- a/federation/request.go +++ b/federation/httpclient.go @@ -40,13 +40,6 @@ func NewServerResolvingTransport() *ServerResolvingTransport { return srt } -func NewFederationHTTPClient() *http.Client { - return &http.Client{ - Transport: NewServerResolvingTransport(), - Timeout: 120 * time.Second, - } -} - var _ http.RoundTripper = (*ServerResolvingTransport)(nil) func (srt *ServerResolvingTransport) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { diff --git a/federation/request_test.go b/federation/request_test.go deleted file mode 100644 index e9037f2d..00000000 --- a/federation/request_test.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package federation_test - -import ( - "encoding/json" - "net/http" - "testing" - - "github.com/stretchr/testify/require" - - "maunium.net/go/mautrix/federation" -) - -type serverVersionResp struct { - Server struct { - Name string `json:"name"` - Version string `json:"version"` - } `json:"server"` -} - -func TestNewFederationClient(t *testing.T) { - cli := federation.NewFederationHTTPClient() - resp, err := cli.Get("matrix-federation://maunium.net/_matrix/federation/v1/version") - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - var respData serverVersionResp - err = json.NewDecoder(resp.Body).Decode(&respData) - require.NoError(t, err) - require.Equal(t, "Synapse", respData.Server.Name) -} diff --git a/federation/signingkey.go b/federation/signingkey.go index 3d118233..67751b48 100644 --- a/federation/signingkey.go +++ b/federation/signingkey.go @@ -83,6 +83,10 @@ type ServerVerifyKey struct { Key id.SigningKey `json:"key"` } +func (svk *ServerVerifyKey) Decode() (ed25519.PublicKey, error) { + return base64.RawStdEncoding.DecodeString(string(svk.Key)) +} + type OldVerifyKey struct { Key id.SigningKey `json:"key"` ExpiredTS jsontime.UnixMilli `json:"expired_ts"` From 9ce81b543e79454bb39971929c2bef43ee414231 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 29 Jul 2024 14:41:28 +0300 Subject: [PATCH 0525/1647] event: add generic helper to cast parsed event content --- event/content.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/event/content.go b/event/content.go index d81a7bdd..e0026e9e 100644 --- a/event/content.go +++ b/event/content.go @@ -243,6 +243,15 @@ func init() { gob.Register(&RoomKeyWithheldEventContent{}) } +func CastOrDefault[T any](content *Content) *T { + casted, ok := content.Parsed.(*T) + if ok { + return casted + } + casted2, _ := content.Parsed.(T) + return &casted2 +} + // Helper cast functions below func (content *Content) AsMember() *MemberEventContent { From 86c2f02d32753b5be096ce02ee65afbe70b6e9ee Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Mon, 29 Jul 2024 17:37:29 +0300 Subject: [PATCH 0526/1647] Add rudimentary MSC3414 encrypted state support (#260) --- crypto/cryptohelper/cryptohelper.go | 8 ++++++-- crypto/decryptmegolm.go | 15 +++++++++++---- crypto/encryptmegolm.go | 23 +++++++++++++++++------ 3 files changed, 34 insertions(+), 12 deletions(-) diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index a0065012..7bb7037d 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -351,12 +351,16 @@ func (helper *CryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*eve } func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) { + return helper.EncryptWithStateKey(ctx, roomID, evtType, nil, content) +} + +func (helper *CryptoHelper) EncryptWithStateKey(ctx context.Context, roomID id.RoomID, evtType event.Type, stateKey *string, content any) (encrypted *event.EncryptedEventContent, err error) { if helper == nil { return nil, fmt.Errorf("crypto helper is nil") } helper.lock.RLock() defer helper.lock.RUnlock() - encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content) + encrypted, err = helper.mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, stateKey, content) if err != nil { if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) { return @@ -371,7 +375,7 @@ func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtTy err = fmt.Errorf("failed to get room member list: %w", err) } else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil { err = fmt.Errorf("failed to share group session: %w", err) - } else if encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content); err != nil { + } else if encrypted, err = helper.mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, stateKey, content); err != nil { err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err) } } diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index ba2811ab..00f99ce4 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -34,9 +34,10 @@ var ( ) type megolmEvent struct { - RoomID id.RoomID `json:"room_id"` - Type event.Type `json:"type"` - Content event.Content `json:"content"` + RoomID id.RoomID `json:"room_id"` + Type event.Type `json:"type"` + StateKey *string `json:"state_key"` + Content event.Content `json:"content"` } var ( @@ -148,7 +149,12 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event } else if megolmEvt.RoomID != encryptionRoomID { return nil, WrongRoom } - megolmEvt.Type.Class = evt.Type.Class + if evt.StateKey != nil && megolmEvt.StateKey != nil { + megolmEvt.Type.Class = event.StateEventType + } else { + megolmEvt.Type.Class = evt.Type.Class + megolmEvt.StateKey = nil + } log = log.With().Str("decrypted_event_type", megolmEvt.Type.Repr()).Logger() err = megolmEvt.Content.ParseRaw(megolmEvt.Type) if err != nil { @@ -163,6 +169,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event return &event.Event{ Sender: evt.Sender, Type: megolmEvt.Type, + StateKey: megolmEvt.StateKey, Timestamp: evt.Timestamp, ID: evt.ID, RoomID: evt.RoomID, diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index d8d5c7c9..93fe6409 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -52,9 +52,10 @@ func getMentions(content interface{}) *event.Mentions { } type rawMegolmEvent struct { - RoomID id.RoomID `json:"room_id"` - Type event.Type `json:"type"` - Content interface{} `json:"content"` + RoomID id.RoomID `json:"room_id"` + Type event.Type `json:"type"` + StateKey *string `json:"state_key,omitempty"` + Content interface{} `json:"content"` } // IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession @@ -83,6 +84,14 @@ func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) { // If you use the event.Content struct, make sure you pass a pointer to the struct, // as JSON serialization will not work correctly otherwise. func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) { + return mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, nil, content) +} + +// EncryptMegolmEventWithStateKey encrypts data with the m.megolm.v1.aes-sha2 algorithm. +// +// If you use the event.Content struct, make sure you pass a pointer to the struct, +// as JSON serialization will not work correctly otherwise. +func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, roomID id.RoomID, evtType event.Type, stateKey *string, content interface{}) (*event.EncryptedEventContent, error) { mach.megolmEncryptLock.Lock() defer mach.megolmEncryptLock.Unlock() session, err := mach.CryptoStore.GetOutboundGroupSession(ctx, roomID) @@ -92,15 +101,17 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID return nil, NoGroupSession } plaintext, err := json.Marshal(&rawMegolmEvent{ - RoomID: roomID, - Type: evtType, - Content: content, + RoomID: roomID, + Type: evtType, + StateKey: stateKey, + Content: content, }) if err != nil { return nil, err } log := mach.machOrContextLog(ctx).With(). Str("event_type", evtType.Type). + Any("state_key", stateKey). Str("room_id", roomID.String()). Str("session_id", session.ID().String()). Uint("expected_index", session.Internal.MessageIndex()). From 7a3b919723e9194d6472d9eac70f6abb549e9bc4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 29 Jul 2024 17:41:41 +0300 Subject: [PATCH 0527/1647] bridgev2/simplevent: remove redundant prefix in types --- bridgev2/simplevent/events.go | 96 +++++++++++++++++------------------ 1 file changed, 48 insertions(+), 48 deletions(-) diff --git a/bridgev2/simplevent/events.go b/bridgev2/simplevent/events.go index c28e78ed..605b18a8 100644 --- a/bridgev2/simplevent/events.go +++ b/bridgev2/simplevent/events.go @@ -17,8 +17,8 @@ import ( "maunium.net/go/mautrix/bridgev2/networkid" ) -// SimpleRemoteEventMeta is a struct containing metadata fields used by most event types. -type SimpleRemoteEventMeta struct { +// EventMeta is a struct containing metadata fields used by most event types. +type EventMeta struct { Type bridgev2.RemoteEventType LogContext func(c zerolog.Context) zerolog.Context PortalKey networkid.PortalKey @@ -28,41 +28,41 @@ type SimpleRemoteEventMeta struct { } var ( - _ bridgev2.RemoteEvent = (*SimpleRemoteEventMeta)(nil) - _ bridgev2.RemoteEventThatMayCreatePortal = (*SimpleRemoteEventMeta)(nil) - _ bridgev2.RemoteEventWithTimestamp = (*SimpleRemoteEventMeta)(nil) + _ bridgev2.RemoteEvent = (*EventMeta)(nil) + _ bridgev2.RemoteEventThatMayCreatePortal = (*EventMeta)(nil) + _ bridgev2.RemoteEventWithTimestamp = (*EventMeta)(nil) ) -func (evt *SimpleRemoteEventMeta) AddLogContext(c zerolog.Context) zerolog.Context { +func (evt *EventMeta) AddLogContext(c zerolog.Context) zerolog.Context { return evt.LogContext(c) } -func (evt *SimpleRemoteEventMeta) GetPortalKey() networkid.PortalKey { +func (evt *EventMeta) GetPortalKey() networkid.PortalKey { return evt.PortalKey } -func (evt *SimpleRemoteEventMeta) GetTimestamp() time.Time { +func (evt *EventMeta) GetTimestamp() time.Time { if evt.Timestamp.IsZero() { return time.Now() } return evt.Timestamp } -func (evt *SimpleRemoteEventMeta) GetSender() bridgev2.EventSender { +func (evt *EventMeta) GetSender() bridgev2.EventSender { return evt.Sender } -func (evt *SimpleRemoteEventMeta) GetType() bridgev2.RemoteEventType { +func (evt *EventMeta) GetType() bridgev2.RemoteEventType { return evt.Type } -func (evt *SimpleRemoteEventMeta) ShouldCreatePortal() bool { +func (evt *EventMeta) ShouldCreatePortal() bool { return evt.CreatePortal } -// SimpleRemoteMessage is a simple implementation of [bridgev2.RemoteMessage] and [bridgev2.RemoteEdit]. -type SimpleRemoteMessage[T any] struct { - SimpleRemoteEventMeta +// Message is a simple implementation of [bridgev2.RemoteMessage] and [bridgev2.RemoteEdit]. +type Message[T any] struct { + EventMeta Data T ID networkid.MessageID @@ -73,29 +73,29 @@ type SimpleRemoteMessage[T any] struct { } var ( - _ bridgev2.RemoteMessage = (*SimpleRemoteMessage[any])(nil) - _ bridgev2.RemoteEdit = (*SimpleRemoteMessage[any])(nil) + _ bridgev2.RemoteMessage = (*Message[any])(nil) + _ bridgev2.RemoteEdit = (*Message[any])(nil) ) -func (evt *SimpleRemoteMessage[T]) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { +func (evt *Message[T]) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { return evt.ConvertMessageFunc(ctx, portal, intent, evt.Data) } -func (evt *SimpleRemoteMessage[T]) ConvertEdit(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (*bridgev2.ConvertedEdit, error) { +func (evt *Message[T]) ConvertEdit(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (*bridgev2.ConvertedEdit, error) { return evt.ConvertEditFunc(ctx, portal, intent, existing, evt.Data) } -func (evt *SimpleRemoteMessage[T]) GetID() networkid.MessageID { +func (evt *Message[T]) GetID() networkid.MessageID { return evt.ID } -func (evt *SimpleRemoteMessage[T]) GetTargetMessage() networkid.MessageID { +func (evt *Message[T]) GetTargetMessage() networkid.MessageID { return evt.TargetMessage } -// SimpleRemoteReaction is a simple implementation of [bridgev2.RemoteReaction] and [bridgev2.RemoteReactionRemove]. -type SimpleRemoteReaction struct { - SimpleRemoteEventMeta +// Reaction is a simple implementation of [bridgev2.RemoteReaction] and [bridgev2.RemoteReactionRemove]. +type Reaction struct { + EventMeta TargetMessage networkid.MessageID EmojiID networkid.EmojiID Emoji string @@ -103,28 +103,28 @@ type SimpleRemoteReaction struct { } var ( - _ bridgev2.RemoteReaction = (*SimpleRemoteReaction)(nil) - _ bridgev2.RemoteReactionWithMeta = (*SimpleRemoteReaction)(nil) - _ bridgev2.RemoteReactionRemove = (*SimpleRemoteReaction)(nil) + _ bridgev2.RemoteReaction = (*Reaction)(nil) + _ bridgev2.RemoteReactionWithMeta = (*Reaction)(nil) + _ bridgev2.RemoteReactionRemove = (*Reaction)(nil) ) -func (evt *SimpleRemoteReaction) GetTargetMessage() networkid.MessageID { +func (evt *Reaction) GetTargetMessage() networkid.MessageID { return evt.TargetMessage } -func (evt *SimpleRemoteReaction) GetReactionEmoji() (string, networkid.EmojiID) { +func (evt *Reaction) GetReactionEmoji() (string, networkid.EmojiID) { return evt.Emoji, evt.EmojiID } -func (evt *SimpleRemoteReaction) GetRemovedEmojiID() networkid.EmojiID { +func (evt *Reaction) GetRemovedEmojiID() networkid.EmojiID { return evt.EmojiID } -func (evt *SimpleRemoteReaction) GetReactionDBMetadata() any { +func (evt *Reaction) GetReactionDBMetadata() any { return evt.ReactionDBMeta } -// SimpleRemoteChatResync is a simple implementation of [bridgev2.RemoteChatResync]. +// ChatResync is a simple implementation of [bridgev2.RemoteChatResync]. // // If GetChatInfoFunc is set, it will be used to get the chat info. Otherwise, ChatInfo will be used. // @@ -132,8 +132,8 @@ func (evt *SimpleRemoteReaction) GetReactionDBMetadata() any { // Otherwise, the latest database message timestamp is compared to LatestMessageTS. // // All four fields are optional. -type SimpleRemoteChatResync struct { - SimpleRemoteEventMeta +type ChatResync struct { + EventMeta ChatInfo *bridgev2.ChatInfo GetChatInfoFunc func(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) @@ -143,12 +143,12 @@ type SimpleRemoteChatResync struct { } var ( - _ bridgev2.RemoteChatResync = (*SimpleRemoteChatResync)(nil) - _ bridgev2.RemoteChatResyncWithInfo = (*SimpleRemoteChatResync)(nil) - _ bridgev2.RemoteChatResyncBackfill = (*SimpleRemoteChatResync)(nil) + _ bridgev2.RemoteChatResync = (*ChatResync)(nil) + _ bridgev2.RemoteChatResyncWithInfo = (*ChatResync)(nil) + _ bridgev2.RemoteChatResyncBackfill = (*ChatResync)(nil) ) -func (evt *SimpleRemoteChatResync) CheckNeedsBackfill(ctx context.Context, latestMessage *database.Message) (bool, error) { +func (evt *ChatResync) CheckNeedsBackfill(ctx context.Context, latestMessage *database.Message) (bool, error) { if evt.CheckNeedsBackfillFunc != nil { return evt.CheckNeedsBackfillFunc(ctx, latestMessage) } else if latestMessage == nil { @@ -158,34 +158,34 @@ func (evt *SimpleRemoteChatResync) CheckNeedsBackfill(ctx context.Context, lates } } -func (evt *SimpleRemoteChatResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { +func (evt *ChatResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { if evt.GetChatInfoFunc != nil { return evt.GetChatInfoFunc(ctx, portal) } return evt.ChatInfo, nil } -// SimpleRemoteChatDelete is a simple implementation of [bridgev2.RemoteChatDelete]. -type SimpleRemoteChatDelete struct { - SimpleRemoteEventMeta +// ChatDelete is a simple implementation of [bridgev2.RemoteChatDelete]. +type ChatDelete struct { + EventMeta OnlyForMe bool } -var _ bridgev2.RemoteChatDelete = (*SimpleRemoteChatDelete)(nil) +var _ bridgev2.RemoteChatDelete = (*ChatDelete)(nil) -func (evt *SimpleRemoteChatDelete) DeleteOnlyForMe() bool { +func (evt *ChatDelete) DeleteOnlyForMe() bool { return evt.OnlyForMe } -// SimpleRemoteChatInfoChange is a simple implementation of [bridgev2.RemoteChatInfoChange]. -type SimpleRemoteChatInfoChange struct { - SimpleRemoteEventMeta +// ChatInfoChange is a simple implementation of [bridgev2.RemoteChatInfoChange]. +type ChatInfoChange struct { + EventMeta ChatInfoChange *bridgev2.ChatInfoChange } -var _ bridgev2.RemoteChatInfoChange = (*SimpleRemoteChatInfoChange)(nil) +var _ bridgev2.RemoteChatInfoChange = (*ChatInfoChange)(nil) -func (evt *SimpleRemoteChatInfoChange) GetChatInfoChange(ctx context.Context) (*bridgev2.ChatInfoChange, error) { +func (evt *ChatInfoChange) GetChatInfoChange(ctx context.Context) (*bridgev2.ChatInfoChange, error) { return evt.ChatInfoChange, nil } From 2d0135d5c162dc01d6b45980143c5a7673448b19 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 29 Jul 2024 19:18:52 +0300 Subject: [PATCH 0528/1647] bridgev2/portal: make nickname and power level optional in ChatMember --- bridgev2/portal.go | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index e9371580..b69f97ad 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1873,8 +1873,8 @@ type PortalInfo = ChatInfo type ChatMember struct { EventSender Membership event.Membership - Nickname string - PowerLevel int + Nickname *string + PowerLevel *int UserInfo *UserInfo PrevMembership event.Membership @@ -2143,7 +2143,9 @@ func (portal *Portal) GetInitialMemberList(ctx context.Context, members *ChatMem intent, extraUserID := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) if extraUserID != "" { invite = append(invite, extraUserID) - pl.EnsureUserLevel(extraUserID, member.PowerLevel) + if member.PowerLevel != nil { + pl.EnsureUserLevel(extraUserID, *member.PowerLevel) + } if intent != nil { // If intent is present along with a user ID, it's the ghost of a logged-in user, // so add it to the functional members list @@ -2152,7 +2154,9 @@ func (portal *Portal) GetInitialMemberList(ctx context.Context, members *ChatMem } if intent != nil { invite = append(invite, intent.GetMXID()) - pl.EnsureUserLevel(intent.GetMXID(), member.PowerLevel) + if member.PowerLevel != nil { + pl.EnsureUserLevel(intent.GetMXID(), *member.PowerLevel) + } } } portal.updateOtherUser(ctx, members) @@ -2210,7 +2214,9 @@ func (portal *Portal) SyncParticipants(ctx context.Context, members *ChatMemberL if member.Membership == "" { member.Membership = event.MembershipJoin } - powerChanged = currentPower.EnsureUserLevel(extraUserID, member.PowerLevel) || powerChanged + if member.PowerLevel != nil { + powerChanged = currentPower.EnsureUserLevel(extraUserID, *member.PowerLevel) || powerChanged + } currentMember, ok := currentMembers[extraUserID] delete(currentMembers, extraUserID) if ok && currentMember.Membership == member.Membership { From a863c76ed88de3e34fa5eb5c0c8ada6c975bd3e9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 29 Jul 2024 19:20:10 +0300 Subject: [PATCH 0529/1647] bridgev2/portal: only consider Receiver for FindPreferredLogin when it's set --- bridgev2/portal.go | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index b69f97ad..f326a18b 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -295,6 +295,21 @@ func (portal *Portal) handleCreateEvent(evt *portalCreateEvent) { } func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) { + if portal.Receiver != "" { + login, err := portal.Bridge.GetExistingUserLoginByID(ctx, portal.Receiver) + if err != nil { + return nil, nil, err + } + if login.UserMXID != user.MXID { + if allowRelay && portal.Relay != nil { + return nil, nil, nil + } + // TODO different error for this case? + return nil, nil, ErrNotLoggedIn + } + up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey) + return login, up, err + } logins, err := portal.Bridge.DB.UserPortal.GetAllForUserInPortal(ctx, user.MXID, portal.PortalKey) if err != nil { return nil, nil, err @@ -303,11 +318,7 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowR defer portal.Bridge.cacheLock.Unlock() for i, up := range logins { login, ok := user.logins[up.LoginID] - if portal.Receiver != "" { - if login.ID == portal.Receiver { - return login, up, nil - } - } else if ok && login.Client != nil && (len(logins) == i-1 || login.Client.IsLoggedIn()) { + if ok && login.Client != nil && (len(logins) == i-1 || login.Client.IsLoggedIn()) { return login, up, nil } } @@ -320,9 +331,6 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowR } var firstLogin *UserLogin for _, login := range user.logins { - if portal.Receiver != "" && login.ID != portal.Receiver { - continue - } firstLogin = login break } From 7701ba1a49b2f8bd604df5c6f22fe9312efc8098 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 30 Jul 2024 01:01:18 +0300 Subject: [PATCH 0530/1647] bridge,v2: fix special error messages for `M_EXCLUSIVE` and `M_UNKNOWN_TOKEN` Regressed by 98a842c075cc03ba718396b2860c6efd7f0e8a92 --- bridge/bridge.go | 25 +++++++++++++++---------- bridgev2/matrix/connector.go | 25 +++++++++++++++---------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/bridge/bridge.go b/bridge/bridge.go index 6f608089..4af2470b 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -320,6 +320,18 @@ func (br *Bridge) InitVersion(tag, commit, buildTime string) { var MinSpecVersion = mautrix.SpecV14 +func (br *Bridge) logInitialRequestError(err error, defaultMessage string) { + if errors.Is(err, mautrix.MUnknownToken) { + br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?") + br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-token for more info") + } else if errors.Is(err, mautrix.MExclusive) { + br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was accepted, but the /register request was not. Are the homeserver domain, bot username and username template in the config correct, and do they match the values in the registration?") + br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-register for more info") + } else { + br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg(defaultMessage) + } +} + func (br *Bridge) ensureConnection(ctx context.Context) { for { versions, err := br.Bot.Versions(ctx) @@ -328,7 +340,8 @@ func (br *Bridge) ensureConnection(ctx context.Context) { br.ZLog.Debug().Msg("M_FORBIDDEN in /versions, trying to register before retrying") err = br.Bot.EnsureRegistered(ctx) if err != nil { - br.ZLog.Err(err).Msg("Failed to register after /versions failed") + br.logInitialRequestError(err, "Failed to register after /versions failed with M_FORBIDDEN") + os.Exit(16) } } else { br.ZLog.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...") @@ -367,15 +380,7 @@ func (br *Bridge) ensureConnection(ctx context.Context) { resp, err := br.Bot.Whoami(ctx) if err != nil { - if errors.Is(err, mautrix.MUnknownToken) { - br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?") - br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-token for more info") - } else if errors.Is(err, mautrix.MExclusive) { - br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was accepted, but the /register request was not. Are the homeserver domain, bot username and username template in the config correct, and do they match the values in the registration?") - br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-register for more info") - } else { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("/whoami request failed with unknown error") - } + br.logInitialRequestError(err, "/whoami request failed with unknown error") os.Exit(16) } else if resp.UserID != br.Bot.UserID { br.ZLog.WithLevel(zerolog.FatalLevel). diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 08aac337..282f1d3b 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -227,6 +227,18 @@ func (br *Connector) Stop() { var MinSpecVersion = mautrix.SpecV14 +func (br *Connector) logInitialRequestError(err error, defaultMessage string) { + if errors.Is(err, mautrix.MUnknownToken) { + br.Log.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?") + br.Log.Info().Msg("See https://docs.mau.fi/faq/as-token for more info") + } else if errors.Is(err, mautrix.MExclusive) { + br.Log.WithLevel(zerolog.FatalLevel).Msg("The as_token was accepted, but the /register request was not. Are the homeserver domain, bot username and username template in the config correct, and do they match the values in the registration?") + br.Log.Info().Msg("See https://docs.mau.fi/faq/as-register for more info") + } else { + br.Log.WithLevel(zerolog.FatalLevel).Err(err).Msg(defaultMessage) + } +} + func (br *Connector) ensureConnection(ctx context.Context) { for { versions, err := br.Bot.Versions(ctx) @@ -235,7 +247,8 @@ func (br *Connector) ensureConnection(ctx context.Context) { br.Log.Debug().Msg("M_FORBIDDEN in /versions, trying to register before retrying") err = br.Bot.EnsureRegistered(ctx) if err != nil { - br.Log.Err(err).Msg("Failed to register after /versions failed") + br.logInitialRequestError(err, "Failed to register after /versions failed with M_FORBIDDEN") + os.Exit(16) } } else { br.Log.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...") @@ -269,15 +282,7 @@ func (br *Connector) ensureConnection(ctx context.Context) { resp, err := br.Bot.Whoami(ctx) if err != nil { - if errors.Is(err, mautrix.MUnknownToken) { - br.Log.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?") - br.Log.Info().Msg("See https://docs.mau.fi/faq/as-token for more info") - } else if errors.Is(err, mautrix.MExclusive) { - br.Log.WithLevel(zerolog.FatalLevel).Msg("The as_token was accepted, but the /register request was not. Are the homeserver domain, bot username and username template in the config correct, and do they match the values in the registration?") - br.Log.Info().Msg("See https://docs.mau.fi/faq/as-register for more info") - } else { - br.Log.WithLevel(zerolog.FatalLevel).Err(err).Msg("/whoami request failed with unknown error") - } + br.logInitialRequestError(err, "/whoami request failed with unknown error") os.Exit(16) } else if resp.UserID != br.Bot.UserID { br.Log.WithLevel(zerolog.FatalLevel). From ff3ec002ee82f1e97cf8c74247790c51aafc1c76 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 30 Jul 2024 11:19:21 +0300 Subject: [PATCH 0531/1647] bridgev2/backfill: fix message cutoff --- bridgev2/portalbackfill.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index e46785a2..8b1bdeb1 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -165,7 +165,7 @@ func cutoffMessages(log *zerolog.Logger, messages []*BackfillMessage, forward bo return messages } if forward { - var cutoff int + cutoff := -1 for i, msg := range messages { if msg.ID == lastMessage.ID || msg.Timestamp.Before(lastMessage.Timestamp) { cutoff = i @@ -173,13 +173,13 @@ func cutoffMessages(log *zerolog.Logger, messages []*BackfillMessage, forward bo break } } - if cutoff != 0 { + if cutoff != -1 { log.Debug(). - Int("cutoff_count", cutoff). + Int("cutoff_count", cutoff+1). Int("total_count", len(messages)). Time("last_bridged_ts", lastMessage.Timestamp). Msg("Cutting off forward backfill messages older than latest bridged message") - messages = messages[cutoff:] + messages = messages[cutoff+1:] } } else { cutoff := -1 @@ -196,7 +196,7 @@ func cutoffMessages(log *zerolog.Logger, messages []*BackfillMessage, forward bo Int("total_count", len(messages)). Time("oldest_bridged_ts", lastMessage.Timestamp). Msg("Cutting off backward backfill messages newer than oldest bridged message") - messages = messages[cutoff:] + messages = messages[:cutoff] } } return messages From 4dda4114b367b3e47ba104cef0c21e1641a6a0ef Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 30 Jul 2024 12:19:34 +0300 Subject: [PATCH 0532/1647] bridgev2/matrix: mention MSC4171 in private_chat_portal_meta doc --- bridgev2/matrix/mxmain/example-config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 2f411c0a..75d7880a 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -5,6 +5,7 @@ bridge: # Should the bridge create a space for each login containing the rooms that account is in? personal_filtering_spaces: true # Whether the bridge should set names and avatars explicitly for DM portals. + # This is only necessary when using clients that don't support MSC4171. private_chat_portal_meta: false # What should be done to portal rooms when a user logs out or is logged out? From c0218184e419579d0f1695ea3a98f35817a6912a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 30 Jul 2024 16:16:53 +0300 Subject: [PATCH 0533/1647] bridgev2/database: add sender MXID for reactions --- bridgev2/database/reaction.go | 18 ++++++++---- bridgev2/database/upgrades/00-latest.sql | 3 +- .../upgrades/15-reaction-sender-mxid.sql | 2 ++ bridgev2/portal.go | 28 ++++++++++++++++++- 4 files changed, 43 insertions(+), 8 deletions(-) create mode 100644 bridgev2/database/upgrades/15-reaction-sender-mxid.sql diff --git a/bridgev2/database/reaction.go b/bridgev2/database/reaction.go index eaa6ecd6..08ab2c8e 100644 --- a/bridgev2/database/reaction.go +++ b/bridgev2/database/reaction.go @@ -28,6 +28,7 @@ type Reaction struct { MessageID networkid.MessageID MessagePartID networkid.PartID SenderID networkid.UserID + SenderMXID id.UserID EmojiID networkid.EmojiID MXID id.EventID @@ -38,18 +39,19 @@ type Reaction struct { const ( getReactionBaseQuery = ` - SELECT bridge_id, message_id, message_part_id, sender_id, emoji_id, emoji, room_id, room_receiver, mxid, timestamp, metadata FROM reaction + SELECT bridge_id, message_id, message_part_id, sender_id, sender_mxid, emoji_id, emoji, room_id, room_receiver, mxid, timestamp, metadata FROM reaction ` getReactionByIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3 AND sender_id=$4 AND emoji_id=$5` getReactionByIDWithoutMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND sender_id=$3 AND emoji_id=$4 ORDER BY message_part_id ASC LIMIT 1` getAllReactionsToMessageBySenderQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND sender_id=$3 ORDER BY timestamp DESC` getAllReactionsToMessageQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2` + getAllReactionsToMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3` getReactionByMXIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` upsertReactionQuery = ` - INSERT INTO reaction (bridge_id, message_id, message_part_id, sender_id, emoji_id, emoji, room_id, room_receiver, mxid, timestamp, metadata) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + INSERT INTO reaction (bridge_id, message_id, message_part_id, sender_id, sender_mxid, emoji_id, emoji, room_id, room_receiver, mxid, timestamp, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) ON CONFLICT (bridge_id, room_receiver, message_id, message_part_id, sender_id, emoji_id) - DO UPDATE SET mxid=excluded.mxid, timestamp=excluded.timestamp, emoji=excluded.emoji, metadata=excluded.metadata + DO UPDATE SET sender_mxid=excluded.sender_mxid, mxid=excluded.mxid, timestamp=excluded.timestamp, emoji=excluded.emoji, metadata=excluded.metadata ` deleteReactionQuery = ` DELETE FROM reaction WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3 AND sender_id=$4 AND emoji_id=$5 @@ -72,6 +74,10 @@ func (rq *ReactionQuery) GetAllToMessage(ctx context.Context, messageID networki return rq.QueryMany(ctx, getAllReactionsToMessageQuery, rq.BridgeID, messageID) } +func (rq *ReactionQuery) GetAllToMessagePart(ctx context.Context, messageID networkid.MessageID, partID networkid.PartID) ([]*Reaction, error) { + return rq.QueryMany(ctx, getAllReactionsToMessagePartQuery, rq.BridgeID, messageID, partID) +} + func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) { return rq.QueryOne(ctx, getReactionByMXIDQuery, rq.BridgeID, mxid) } @@ -89,7 +95,7 @@ func (rq *ReactionQuery) Delete(ctx context.Context, reaction *Reaction) error { func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) { var timestamp int64 err := row.Scan( - &r.BridgeID, &r.MessageID, &r.MessagePartID, &r.SenderID, &r.EmojiID, &r.Emoji, + &r.BridgeID, &r.MessageID, &r.MessagePartID, &r.SenderID, &r.SenderMXID, &r.EmojiID, &r.Emoji, &r.Room.ID, &r.Room.Receiver, &r.MXID, ×tamp, dbutil.JSON{Data: r.Metadata}, ) if err != nil { @@ -108,7 +114,7 @@ func (r *Reaction) ensureHasMetadata(metaType MetaTypeCreator) *Reaction { func (r *Reaction) sqlVariables() []any { return []any{ - r.BridgeID, r.MessageID, r.MessagePartID, r.SenderID, r.EmojiID, r.Emoji, + r.BridgeID, r.MessageID, r.MessagePartID, r.SenderID, r.SenderMXID, r.EmojiID, r.Emoji, r.Room.ID, r.Room.Receiver, r.MXID, r.Timestamp.UnixNano(), dbutil.JSON{Data: r.Metadata}, } } diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 304bb00e..16c701ff 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v14 (compatible with v9+): Latest revision +-- v0 -> v15 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -134,6 +134,7 @@ CREATE TABLE reaction ( message_id TEXT NOT NULL, message_part_id TEXT NOT NULL, sender_id TEXT NOT NULL, + sender_mxid TEXT NOT NULL DEFAULT '', emoji_id TEXT NOT NULL, room_id TEXT NOT NULL, room_receiver TEXT NOT NULL, diff --git a/bridgev2/database/upgrades/15-reaction-sender-mxid.sql b/bridgev2/database/upgrades/15-reaction-sender-mxid.sql new file mode 100644 index 00000000..e32bd832 --- /dev/null +++ b/bridgev2/database/upgrades/15-reaction-sender-mxid.sql @@ -0,0 +1,2 @@ +-- v15 (compatible with v9+): Save sender MXID for reactions +ALTER TABLE reaction ADD COLUMN sender_mxid TEXT NOT NULL DEFAULT ''; diff --git a/bridgev2/portal.go b/bridgev2/portal.go index f326a18b..3f51fad7 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -984,6 +984,9 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if dbReaction.SenderID == "" { dbReaction.SenderID = preResp.SenderID } + if dbReaction.SenderMXID == "" { + dbReaction.SenderMXID = evt.Sender + } err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) if err != nil { log.Err(err).Msg("Failed to save reaction to database") @@ -1621,6 +1624,7 @@ func (portal *Portal) sendConvertedReaction( MessageID: targetMessage.ID, MessagePartID: targetMessage.PartID, SenderID: senderID, + SenderMXID: intent.GetMXID(), EmojiID: emojiID, Timestamp: ts, Metadata: dbMetadata, @@ -1655,6 +1659,22 @@ func (portal *Portal) sendConvertedReaction( } } +func (portal *Portal) getIntentForMXID(ctx context.Context, userID id.UserID) (MatrixAPI, error) { + if userID == "" { + return nil, nil + } else if ghost, err := portal.Bridge.GetGhostByMXID(ctx, userID); err != nil { + return nil, fmt.Errorf("failed to get ghost: %w", err) + } else if ghost != nil { + return ghost.Intent, nil + } else if user, err := portal.Bridge.GetExistingUserByMXID(ctx, userID); err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } else if user != nil { + return user.DoublePuppet(ctx), nil + } else { + return nil, nil + } +} + func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) { log := zerolog.Ctx(ctx) targetReaction, err := portal.getTargetReaction(ctx, evt) @@ -1665,7 +1685,13 @@ func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *Us log.Warn().Msg("Target reaction not found") return } - intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReactionRemove) + intent, err := portal.getIntentForMXID(ctx, targetReaction.SenderMXID) + if err != nil { + log.Err(err).Stringer("sender_mxid", targetReaction.SenderMXID).Msg("Failed to get intent for removing reaction") + } + if intent == nil { + intent = portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReactionRemove) + } ts := getEventTS(evt) _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ From c47f6ea7b08809b2db69933e83caca45aac33d26 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 30 Jul 2024 16:19:17 +0300 Subject: [PATCH 0534/1647] bridgev2: add option to skip bridging message but save to database --- bridgev2/database/message.go | 14 ++++++++ bridgev2/networkinterface.go | 1 + bridgev2/portal.go | 68 +++++++++++++++++++++++------------- 3 files changed, 59 insertions(+), 24 deletions(-) diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 1403c9bc..b2e023d0 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -8,7 +8,10 @@ package database import ( "context" + "crypto/sha256" "database/sql" + "encoding/base64" + "strings" "time" "go.mau.fi/util/dbutil" @@ -191,3 +194,14 @@ func (m *Message) sqlVariables() []any { func (m *Message) updateSQLVariables() []any { return append(m.sqlVariables(), m.RowID) } + +const FakeMXIDPrefix = "~fake:" + +func (m *Message) SetFakeMXID() { + hash := sha256.Sum256([]byte(m.ID)) + m.MXID = id.EventID(FakeMXIDPrefix + base64.RawURLEncoding.EncodeToString(hash[:])) +} + +func (m *Message) HasFakeMXID() bool { + return strings.HasPrefix(m.MXID.String(), FakeMXIDPrefix) +} diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index e468be00..894379a0 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -30,6 +30,7 @@ type ConvertedMessagePart struct { Content *event.MessageEventContent Extra map[string]any DBMetadata any + DontBridge bool } func (cmp *ConvertedMessagePart) ToEditPart(part *database.Message) *ConvertedEditPart { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 3f51fad7..c374c5ca 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1386,27 +1386,31 @@ func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.Mes ReplyTo: ptr.Val(converted.ReplyTo), Metadata: part.DBMetadata, } - resp, err := intent.SendMessage(ctx, portal.MXID, part.Type, &event.Content{ - Parsed: part.Content, - Raw: part.Extra, - }, &MatrixSendExtra{ - Timestamp: ts, - MessageMeta: dbMessage, - }) - if err != nil { - logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to send message part to Matrix") - continue + if part.DontBridge { + dbMessage.SetFakeMXID() + } else { + resp, err := intent.SendMessage(ctx, portal.MXID, part.Type, &event.Content{ + Parsed: part.Content, + Raw: part.Extra, + }, &MatrixSendExtra{ + Timestamp: ts, + MessageMeta: dbMessage, + }) + if err != nil { + logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to send message part to Matrix") + continue + } + logContext(log.Debug()). + Stringer("event_id", resp.EventID). + Str("part_id", string(part.ID)). + Msg("Sent message part to Matrix") + dbMessage.MXID = resp.EventID } - logContext(log.Debug()). - Stringer("event_id", resp.EventID). - Str("part_id", string(part.ID)). - Msg("Sent message part to Matrix") - dbMessage.MXID = resp.EventID - err = portal.Bridge.DB.Message.Insert(ctx, dbMessage) + err := portal.Bridge.DB.Message.Insert(ctx, dbMessage) if err != nil { logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to save message part to database") } - if converted.Disappear.Type != database.DisappearingTypeNone { + if converted.Disappear.Type != database.DisappearingTypeNone && !dbMessage.HasFakeMXID() { if converted.Disappear.Type == database.DisappearingTypeAfterSend && converted.Disappear.DisappearAt.IsZero() { converted.Disappear.DisappearAt = dbMessage.Timestamp.Add(converted.Disappear.Timer) } @@ -1416,7 +1420,7 @@ func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.Mes DisappearingSetting: converted.Disappear, }) } - if prevThreadEvent != nil { + if prevThreadEvent != nil && !dbMessage.HasFakeMXID() { prevThreadEvent = dbMessage } output = append(output, dbMessage) @@ -1718,8 +1722,27 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use return } intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageRemove) - ts := getEventTS(evt) - for _, part := range targetParts { + if intent == portal.Bridge.Bot && len(targetParts) > 0 { + senderIntent, err := portal.getIntentForMXID(ctx, targetParts[0].SenderMXID) + if err != nil { + log.Err(err).Stringer("sender_mxid", targetParts[0].SenderMXID).Msg("Failed to get intent for removing message") + } else if senderIntent != nil { + intent = senderIntent + } + } + portal.redactMessageParts(ctx, targetParts, intent, getEventTS(evt)) + err = portal.Bridge.DB.Message.DeleteAllParts(ctx, portal.Receiver, evt.GetTargetMessage()) + if err != nil { + log.Err(err).Msg("Failed to delete target message from database") + } +} + +func (portal *Portal) redactMessageParts(ctx context.Context, parts []*database.Message, intent MatrixAPI, ts time.Time) { + log := zerolog.Ctx(ctx) + for _, part := range parts { + if part.HasFakeMXID() { + continue + } resp, err := intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ Redacts: part.MXID, @@ -1735,13 +1758,10 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use Msg("Sent redaction of message part to Matrix") } } - err = portal.Bridge.DB.Message.DeleteAllParts(ctx, portal.Receiver, evt.GetTargetMessage()) - if err != nil { - log.Err(err).Msg("Failed to delete target message from database") - } } func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReceipt) { + // TODO exclude fake mxids log := zerolog.Ctx(ctx) var err error var lastTarget *database.Message From 6042dbec538d00f78b9ee0eed243964cf3b4a9d9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 30 Jul 2024 16:20:04 +0300 Subject: [PATCH 0535/1647] bridgev2/caption: only merge if caption is m.text --- bridgev2/networkinterface.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 894379a0..b42e730f 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -73,7 +73,7 @@ func MergeCaption(textPart, mediaPart *ConvertedMessagePart) *ConvertedMessagePa return textPart } mediaPart = ptr.Clone(mediaPart) - if mediaPart.Content.Body != "" && mediaPart.Content.FileName != "" && mediaPart.Content.Body != mediaPart.Content.FileName { + if mediaPart.Content.MsgType == event.MsgNotice || (mediaPart.Content.Body != "" && mediaPart.Content.FileName != "" && mediaPart.Content.Body != mediaPart.Content.FileName) { textPart = ptr.Clone(textPart) textPart.Content.EnsureHasHTML() mediaPart.Content.EnsureHasHTML() @@ -85,6 +85,9 @@ func MergeCaption(textPart, mediaPart *ConvertedMessagePart) *ConvertedMessagePa mediaPart.Content.Format = textPart.Content.Format mediaPart.Content.FormattedBody = textPart.Content.FormattedBody } + if metaMerger, ok := mediaPart.DBMetadata.(database.MetaMerger); ok { + metaMerger.CopyFrom(textPart.DBMetadata) + } mediaPart.ID = textPart.ID return mediaPart } @@ -97,7 +100,7 @@ func (cm *ConvertedMessage) MergeCaption() bool { if textPart.Content.MsgType.IsMedia() { textPart, mediaPart = mediaPart, textPart } - if !mediaPart.Content.MsgType.IsMedia() || !textPart.Content.MsgType.IsText() { + if (!mediaPart.Content.MsgType.IsMedia() && mediaPart.Content.MsgType != event.MsgNotice) || textPart.Content.MsgType != event.MsgText { return false } merged := MergeCaption(textPart, mediaPart) From 3e01676eb5d8fcdfca600629a11f985e8614a302 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 30 Jul 2024 16:25:19 +0300 Subject: [PATCH 0536/1647] bridgev2/events: add reaction mass resync event --- bridgev2/networkinterface.go | 20 +++++++ bridgev2/portal.go | 105 +++++++++++++++++++++++++++++++++++ 2 files changed, 125 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index b42e730f..1e943d4d 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -635,6 +635,8 @@ func (ret RemoteEventType) String() string { return "RemoteEventReaction" case RemoteEventReactionRemove: return "RemoteEventReactionRemove" + case RemoteEventReactionSync: + return "RemoteEventReactionSync" case RemoteEventMessageRemove: return "RemoteEventMessageRemove" case RemoteEventReadReceipt: @@ -664,6 +666,7 @@ const ( RemoteEventEdit RemoteEventReaction RemoteEventReactionRemove + RemoteEventReactionSync RemoteEventMessageRemove RemoteEventReadReceipt RemoteEventDeliveryReceipt @@ -751,6 +754,23 @@ type RemoteReaction interface { GetReactionEmoji() (string, networkid.EmojiID) } +type ReactionSyncUser struct { + Reactions []*BackfillReaction + // Whether the list contains all reactions the user has sent + HasAllReactions bool +} + +type ReactionSyncData struct { + Users map[networkid.UserID]*ReactionSyncUser + // Whether the map contains all users who have reacted to the message + HasAllUsers bool +} + +type RemoteReactionSync interface { + RemoteEventWithTargetMessage + GetReactions() *ReactionSyncData +} + type RemoteReactionWithExtraContent interface { RemoteReaction GetReactionExtraContent() map[string]any diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c374c5ca..c02c6bf9 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1235,6 +1235,8 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { portal.handleRemoteReaction(ctx, source, evt.(RemoteReaction)) case RemoteEventReactionRemove: portal.handleRemoteReactionRemove(ctx, source, evt.(RemoteReactionRemove)) + case RemoteEventReactionSync: + portal.handleRemoteReactionSync(ctx, source, evt.(RemoteReactionSync)) case RemoteEventMessageRemove: portal.handleRemoteMessageRemove(ctx, source, evt.(RemoteMessageRemove)) case RemoteEventReadReceipt: @@ -1569,6 +1571,109 @@ func getEventTS(evt RemoteEvent) time.Time { return time.Now() } +func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *UserLogin, evt RemoteReactionSync) { + log := zerolog.Ctx(ctx) + eventTS := getEventTS(evt) + targetMessage, err := portal.getTargetMessagePart(ctx, evt) + if err != nil { + log.Err(err).Msg("Failed to get target message for reaction") + return + } else if targetMessage == nil { + // TODO use deterministic event ID as target if applicable? + log.Warn().Msg("Target message for reaction not found") + return + } + var existingReactions []*database.Reaction + if partTargeter, ok := evt.(RemoteEventWithTargetPart); ok { + existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessagePart(ctx, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart()) + } else { + existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessage(ctx, evt.GetTargetMessage()) + } + existing := make(map[networkid.UserID]map[networkid.EmojiID]*database.Reaction) + for _, existingReaction := range existingReactions { + if existing[existingReaction.SenderID] == nil { + existing[existingReaction.SenderID] = make(map[networkid.EmojiID]*database.Reaction) + } + existing[existingReaction.SenderID][existingReaction.EmojiID] = existingReaction + } + + doAddReaction := func(new *BackfillReaction) MatrixAPI { + intent := portal.GetIntentFor(ctx, new.Sender, source, RemoteEventReactionSync) + portal.sendConvertedReaction( + ctx, new.Sender.Sender, intent, targetMessage, new.EmojiID, new.Emoji, + new.Timestamp, new.DBMetadata, new.ExtraContent, + func(z *zerolog.Event) *zerolog.Event { + return z. + Any("reaction_sender_id", new.Sender). + Time("reaction_ts", new.Timestamp) + }, + ) + return intent + } + doRemoveReaction := func(old *database.Reaction, intent MatrixAPI) { + if intent == nil && old.SenderMXID != "" { + intent, err = portal.getIntentForMXID(ctx, old.SenderMXID) + if err != nil { + log.Err(err). + Stringer("reaction_sender_mxid", old.SenderMXID). + Msg("Failed to get intent for removing reaction") + } + } + if intent == nil { + log.Warn(). + Str("reaction_sender_id", string(old.SenderID)). + Stringer("reaction_sender_mxid", old.SenderMXID). + Msg("Didn't find intent for removing reaction, using bridge bot") + intent = portal.Bridge.Bot + } + _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: old.MXID, + }, + }, &MatrixSendExtra{Timestamp: eventTS}) + if err != nil { + log.Err(err).Msg("Failed to redact old reaction") + } + } + doOverwriteReaction := func(new *BackfillReaction, old *database.Reaction) { + intent := doAddReaction(new) + doRemoveReaction(old, intent) + } + + newData := evt.GetReactions() + for userID, reactions := range newData.Users { + existingUserReactions := existing[userID] + delete(existing, userID) + for _, reaction := range reactions.Reactions { + if reaction.Timestamp.IsZero() { + reaction.Timestamp = eventTS + } + existingReaction, ok := existingUserReactions[reaction.EmojiID] + if ok { + delete(existingUserReactions, reaction.EmojiID) + if reaction.EmojiID != "" { + continue + } + doOverwriteReaction(reaction, existingReaction) + } else { + doAddReaction(reaction) + } + } + if reactions.HasAllReactions { + for _, existingReaction := range existingUserReactions { + doRemoveReaction(existingReaction, nil) + } + } + } + if newData.HasAllUsers { + for _, userReactions := range existing { + for _, existingReaction := range userReactions { + doRemoveReaction(existingReaction, nil) + } + } + } +} + func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) { log := zerolog.Ctx(ctx) targetMessage, err := portal.getTargetMessagePart(ctx, evt) From a1f38e286705fc18ade2da670bf1c20658494d8c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 30 Jul 2024 16:28:02 +0300 Subject: [PATCH 0537/1647] bridgev2/events: add support for async sending and incoming event upserting --- bridgev2/messagestatus.go | 2 + bridgev2/networkid/bridgeid.go | 5 ++ bridgev2/networkinterface.go | 21 +++++ bridgev2/portal.go | 143 ++++++++++++++++++++++++++++----- bridgev2/portalbackfill.go | 2 +- 5 files changed, 153 insertions(+), 20 deletions(-) diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index c49dbf1c..d2969eda 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -18,6 +18,8 @@ import ( ) var ( + ErrIgnoringRemoteEvent error = errors.New("ignoring remote event") + ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) diff --git a/bridgev2/networkid/bridgeid.go b/bridgev2/networkid/bridgeid.go index 65d34609..46f82155 100644 --- a/bridgev2/networkid/bridgeid.go +++ b/bridgev2/networkid/bridgeid.go @@ -86,6 +86,11 @@ type UserLoginID string // Message IDs must be unique across rooms and consistent across users (i.e. globally unique within the bridge). type MessageID string +// TransactionID is a client-generated identifier for a message send operation on the remote network. +// +// Transaction IDs must be unique across users in a room, but don't need to be unique across different rooms. +type TransactionID string + // PartID is the ID of a message part on the remote network (e.g. index of image in album). // // Part IDs are only unique within a message, not globally. diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 1e943d4d..c932a4c1 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -241,6 +241,9 @@ type MaxFileSizeingNetwork interface { type MatrixMessageResponse struct { DB *database.Message + + Pending networkid.TransactionID + HandleEcho func(RemoteMessage, *database.Message) (bool, error) } type FileRestriction struct { @@ -629,6 +632,8 @@ func (ret RemoteEventType) String() string { return "RemoteEventUnknown" case RemoteEventMessage: return "RemoteEventMessage" + case RemoteEventMessageUpsert: + return "RemoteEventMessageUpsert" case RemoteEventEdit: return "RemoteEventEdit" case RemoteEventReaction: @@ -663,6 +668,7 @@ func (ret RemoteEventType) String() string { const ( RemoteEventUnknown RemoteEventType = iota RemoteEventMessage + RemoteEventMessageUpsert RemoteEventEdit RemoteEventReaction RemoteEventReactionRemove @@ -744,6 +750,21 @@ type RemoteMessage interface { ConvertMessage(ctx context.Context, portal *Portal, intent MatrixAPI) (*ConvertedMessage, error) } +type UpsertResult struct { + SubEvents []RemoteEvent + ContinueMessageHandling bool +} + +type RemoteMessageUpsert interface { + RemoteMessage + HandleExisting(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message) (UpsertResult, error) +} + +type RemoteMessageWithTransactionID interface { + RemoteMessage + GetTransactionID() networkid.TransactionID +} + type RemoteEdit interface { RemoteEventWithTargetMessage ConvertEdit(ctx context.Context, portal *Portal, intent MatrixAPI, existing []*database.Message) (*ConvertedEdit, error) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c02c6bf9..0d24b9c4 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -54,6 +54,12 @@ type portalEvent interface { isPortalEvent() } +type outgoingMessage struct { + db *database.Message + evt *event.Event + handle func(RemoteMessage, *database.Message) (bool, error) +} + type Portal struct { *database.Portal Bridge *Bridge @@ -65,6 +71,9 @@ type Portal struct { currentlyTypingLogins map[id.UserID]*UserLogin currentlyTypingLock sync.Mutex + outgoingMessages map[networkid.TransactionID]outgoingMessage + outgoingMessagesLock sync.Mutex + roomCreateLock sync.Mutex events chan portalEvent @@ -93,9 +102,9 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que Portal: dbPortal, Bridge: br, - events: make(chan portalEvent, PortalEventBuffer), - + events: make(chan portalEvent, PortalEventBuffer), currentlyTypingLogins: make(map[id.UserID]*UserLogin), + outgoingMessages: make(map[networkid.TransactionID]outgoingMessage), } br.portalsByKey[portal.PortalKey] = portal if portal.MXID != "" { @@ -767,12 +776,25 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin if message.SenderMXID == "" { message.SenderMXID = evt.Sender } - // Hack to ensure the ghost row exists - // TODO move to better place (like login) - portal.Bridge.GetGhostByID(ctx, message.SenderID) - err = portal.Bridge.DB.Message.Insert(ctx, message) - if err != nil { - log.Err(err).Msg("Failed to save message to database") + if resp.Pending != "" { + // TODO if the event queue is ever removed, this will have to be done by the network connector before sending the request + // (for now this is fine because incoming messages will wait in the queue for this function to return) + portal.outgoingMessagesLock.Lock() + portal.outgoingMessages[resp.Pending] = outgoingMessage{ + db: message, + evt: evt, + handle: resp.HandleEcho, + } + portal.outgoingMessagesLock.Unlock() + } else { + // Hack to ensure the ghost row exists + // TODO move to better place (like login) + portal.Bridge.GetGhostByID(ctx, message.SenderID) + err = portal.Bridge.DB.Message.Insert(ctx, message) + if err != nil { + log.Err(err).Msg("Failed to save message to database") + } + portal.sendSuccessStatus(ctx, evt) } if portal.Disappear.Type != database.DisappearingTypeNone { go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ @@ -785,7 +807,6 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin }, }) } - portal.sendSuccessStatus(ctx, evt) } func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *NetworkRoomCapabilities) { @@ -1227,7 +1248,7 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { switch evtType { case RemoteEventUnknown: log.Debug().Msg("Ignoring remote event with type unknown") - case RemoteEventMessage: + case RemoteEventMessage, RemoteEventMessageUpsert: portal.handleRemoteMessage(ctx, source, evt.(RemoteMessage)) case RemoteEventEdit: portal.handleRemoteEdit(ctx, source, evt.(RemoteEdit)) @@ -1366,7 +1387,7 @@ func (portal *Portal) applyRelationMeta(content *event.MessageEventContent, repl } } -func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.MessageID, intent MatrixAPI, sender EventSender, converted *ConvertedMessage, ts time.Time, logContext func(*zerolog.Event) *zerolog.Event) []*database.Message { +func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.MessageID, intent MatrixAPI, senderID networkid.UserID, converted *ConvertedMessage, ts time.Time, logContext func(*zerolog.Event) *zerolog.Event) []*database.Message { if logContext == nil { logContext = func(e *zerolog.Event) *zerolog.Event { return e @@ -1381,7 +1402,7 @@ func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.Mes ID: id, PartID: part.ID, Room: portal.PortalKey, - SenderID: sender.Sender, + SenderID: senderID, SenderMXID: intent.GetMXID(), Timestamp: ts, ThreadRoot: ptr.Val(converted.ThreadRoot), @@ -1430,14 +1451,94 @@ func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.Mes return output } +func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage) (bool, *database.Message) { + evtWithTxn, ok := evt.(RemoteMessageWithTransactionID) + if !ok { + return false, nil + } + txnID := evtWithTxn.GetTransactionID() + if txnID == "" { + return false, nil + } + portal.outgoingMessagesLock.Lock() + defer portal.outgoingMessagesLock.Unlock() + pending, ok := portal.outgoingMessages[txnID] + if !ok { + return false, nil + } + delete(portal.outgoingMessages, txnID) + pending.db.ID = evt.GetID() + if pending.db.SenderID == "" { + pending.db.SenderID = evt.GetSender().Sender + } + evtWithTimestamp, ok := evt.(RemoteEventWithTimestamp) + if ok { + pending.db.Timestamp = evtWithTimestamp.GetTimestamp() + } + var statusErr error + saveMessage := true + if pending.handle != nil { + saveMessage, statusErr = pending.handle(evt, pending.db) + } + if saveMessage { + // Hack to ensure the ghost row exists + // TODO move to better place (like login) + portal.Bridge.GetGhostByID(ctx, pending.db.SenderID) + err := portal.Bridge.DB.Message.Insert(ctx, pending.db) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save message to database after receiving remote echo") + } + } + if statusErr != nil { + portal.sendErrorStatus(ctx, pending.evt, statusErr) + } else { + portal.sendSuccessStatus(ctx, pending.evt) + } + zerolog.Ctx(ctx).Debug().Stringer("event_id", pending.evt.ID).Msg("Received remote echo for message") + return true, pending.db +} + +func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin, evt RemoteMessageUpsert, existing []*database.Message) bool { + log := zerolog.Ctx(ctx) + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageUpsert) + if intent == nil { + return false + } + res, err := evt.HandleExisting(ctx, portal, intent, existing) + if err != nil { + log.Err(err).Msg("Failed to handle existing message in upsert event after receiving remote echo") + } else if len(res.SubEvents) > 0 { + for _, subEvt := range res.SubEvents { + portal.handleRemoteEvent(source, subEvt) + } + } + return res.ContinueMessageHandling +} + func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) { log := zerolog.Ctx(ctx) - existing, err := portal.Bridge.DB.Message.GetFirstPartByID(ctx, portal.Receiver, evt.GetID()) + upsertEvt, isUpsert := evt.(RemoteMessageUpsert) + isUpsert = isUpsert && evt.GetType() == RemoteEventMessageUpsert + if wasPending, dbMessage := portal.checkPendingMessage(ctx, evt); wasPending { + if isUpsert { + portal.handleRemoteUpsert(ctx, source, upsertEvt, []*database.Message{dbMessage}) + } + return + } + existing, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, evt.GetID()) if err != nil { log.Err(err).Msg("Failed to check if message is a duplicate") - } else if existing != nil { - log.Debug().Stringer("existing_mxid", existing.MXID).Msg("Ignoring duplicate message") - return + } else if len(existing) > 0 { + if isUpsert { + if portal.handleRemoteUpsert(ctx, source, upsertEvt, existing) { + log.Debug().Msg("Upsert handler said to continue message handling normally") + } else { + return + } + } else { + log.Debug().Stringer("existing_mxid", existing[0].MXID).Msg("Ignoring duplicate message") + return + } } intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessage) if intent == nil { @@ -1446,11 +1547,15 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin ts := getEventTS(evt) converted, err := evt.ConvertMessage(ctx, portal, intent) if err != nil { - log.Err(err).Msg("Failed to convert remote message") - portal.sendRemoteErrorNotice(ctx, intent, err, ts, "message") + if errors.Is(err, ErrIgnoringRemoteEvent) { + log.Debug().Err(err).Msg("Remote event handling was cancelled by convert function") + } else { + log.Err(err).Msg("Failed to convert remote message") + portal.sendRemoteErrorNotice(ctx, intent, err, ts, "message") + } return } - portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender(), converted, ts, nil) + portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender().Sender, converted, ts, nil) } func (portal *Portal) sendRemoteErrorNotice(ctx context.Context, intent MatrixAPI, err error, ts time.Time, evtTypeName string) { diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 8b1bdeb1..b05b16e5 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -369,7 +369,7 @@ func (portal *Portal) sendLegacyBackfill(ctx context.Context, source *UserLogin, var lastPart id.EventID for _, msg := range messages { intent := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) - dbMessages := portal.sendConvertedMessage(ctx, msg.ID, intent, msg.Sender, msg.ConvertedMessage, msg.Timestamp, func(z *zerolog.Event) *zerolog.Event { + dbMessages := portal.sendConvertedMessage(ctx, msg.ID, intent, msg.Sender.Sender, msg.ConvertedMessage, msg.Timestamp, func(z *zerolog.Event) *zerolog.Event { return z. Str("message_id", string(msg.ID)). Any("sender_id", msg.Sender). From 74c2cd06a7d56fdfa768d6c7965cf8596872c550 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 30 Jul 2024 16:28:40 +0300 Subject: [PATCH 0538/1647] bridgev2/events: allow dangerously adding parts in edits --- bridgev2/networkinterface.go | 20 ++++++++--- bridgev2/portal.go | 65 ++++++++++++++++++++++++++---------- 2 files changed, 63 insertions(+), 22 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index c932a4c1..4ee7d20d 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -46,10 +46,11 @@ func (cmp *ConvertedMessagePart) ToEditPart(part *database.Message) *ConvertedEd } } return &ConvertedEditPart{ - Part: part, - Type: cmp.Type, - Content: cmp.Content, - Extra: cmp.Extra, + Part: part, + Type: cmp.Type, + Content: cmp.Content, + Extra: cmp.Extra, + DontBridge: cmp.DontBridge, } } @@ -121,11 +122,17 @@ type ConvertedEditPart struct { Extra map[string]any // TopLevelExtra can be used to specify custom fields at the top level of the content rather than inside `m.new_content`. TopLevelExtra map[string]any + + DontBridge bool } type ConvertedEdit struct { ModifiedParts []*ConvertedEditPart DeletedParts []*database.Message + // Warning: added parts will be sent at the end of the room. + // If other messages have been sent after the message being edited, + // these new parts will not be next to the existing parts. + AddedParts *ConvertedMessage } // BridgeName contains information about the network that a connector bridges to. @@ -734,6 +741,11 @@ type RemoteEventWithTargetMessage interface { GetTargetMessage() networkid.MessageID } +type RemoteEventWithBundledParts interface { + RemoteEventWithTargetMessage + GetTargetDBMessage() []*database.Message +} + type RemoteEventWithTargetPart interface { RemoteEventWithTargetMessage GetTargetMessagePart() networkid.PartID diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 0d24b9c4..aa7a02ea 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1580,11 +1580,20 @@ func (portal *Portal) sendRemoteErrorNotice(ctx context.Context, intent MatrixAP func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) { log := zerolog.Ctx(ctx) - existing, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, evt.GetTargetMessage()) - if err != nil { - log.Err(err).Msg("Failed to get edit target message") - return - } else if existing == nil { + var existing []*database.Message + if bundledEvt, ok := evt.(RemoteEventWithBundledParts); ok { + existing = bundledEvt.GetTargetDBMessage() + } + if existing == nil { + targetID := evt.GetTargetMessage() + var err error + existing, err = portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, targetID) + if err != nil { + log.Err(err).Msg("Failed to get edit target message") + return + } + } + if existing == nil { log.Warn().Msg("Edit target message not found") return } @@ -1599,8 +1608,19 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e portal.sendRemoteErrorNotice(ctx, intent, err, ts, "edit") return } + portal.sendConvertedEdit(ctx, existing[0].ID, evt.GetSender().Sender, converted, intent, ts) +} + +func (portal *Portal) sendConvertedEdit(ctx context.Context, targetID networkid.MessageID, senderID networkid.UserID, converted *ConvertedEdit, intent MatrixAPI, ts time.Time) { + log := zerolog.Ctx(ctx) for _, part := range converted.ModifiedParts { - part.Content.SetEdit(part.Part.MXID) + overrideMXID := true + if part.Part.Room != portal.PortalKey { + part.Part.Room = portal.PortalKey + } else if !part.Part.HasFakeMXID() { + part.Content.SetEdit(part.Part.MXID) + overrideMXID = false + } if part.TopLevelExtra == nil { part.TopLevelExtra = make(map[string]any) } @@ -1611,19 +1631,25 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e Parsed: part.Content, Raw: part.TopLevelExtra, } - resp, err := intent.SendMessage(ctx, portal.MXID, part.Type, wrappedContent, &MatrixSendExtra{ - Timestamp: ts, - MessageMeta: part.Part, - }) - if err != nil { - log.Err(err).Stringer("part_mxid", part.Part.MXID).Msg("Failed to edit message part") - } else { - log.Debug(). - Stringer("event_id", resp.EventID). - Str("part_id", string(part.Part.ID)). - Msg("Sent message part edit to Matrix") + if !part.DontBridge { + resp, err := intent.SendMessage(ctx, portal.MXID, part.Type, wrappedContent, &MatrixSendExtra{ + Timestamp: ts, + MessageMeta: part.Part, + }) + if err != nil { + log.Err(err).Stringer("part_mxid", part.Part.MXID).Msg("Failed to edit message part") + continue + } else { + log.Debug(). + Stringer("event_id", resp.EventID). + Str("part_id", string(part.Part.ID)). + Msg("Sent message part edit to Matrix") + if overrideMXID { + part.Part.MXID = resp.EventID + } + } } - err = portal.Bridge.DB.Message.Update(ctx, part.Part) + err := portal.Bridge.DB.Message.Update(ctx, part.Part) if err != nil { log.Err(err).Int64("part_rowid", part.Part.RowID).Msg("Failed to update message part in database") } @@ -1651,6 +1677,9 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e log.Err(err).Int64("part_rowid", part.RowID).Msg("Failed to delete message part from database") } } + if converted.AddedParts != nil { + portal.sendConvertedMessage(ctx, targetID, intent, senderID, converted.AddedParts, ts, nil) + } } func (portal *Portal) getTargetMessagePart(ctx context.Context, evt RemoteEventWithTargetMessage) (*database.Message, error) { From 27449f25625b45b6e54cd4f2dfd02643bc3fba07 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 30 Jul 2024 16:29:10 +0300 Subject: [PATCH 0539/1647] bridgev2/portal: allow access to private methods --- bridgev2/portal.go | 32 ++-- bridgev2/portalinternal.go | 266 ++++++++++++++++++++++++++++ bridgev2/portalinternal_generate.go | 160 +++++++++++++++++ 3 files changed, 442 insertions(+), 16 deletions(-) create mode 100644 bridgev2/portalinternal.go create mode 100644 bridgev2/portalinternal_generate.go diff --git a/bridgev2/portal.go b/bridgev2/portal.go index aa7a02ea..43a2e21e 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2153,7 +2153,7 @@ func (portal *Portal) ProcessChatInfoChange(ctx context.Context, sender EventSen portal.UpdateInfo(ctx, change.ChatInfo, source, intent, ts) } if change.MemberChanges != nil { - err := portal.SyncParticipants(ctx, change.MemberChanges, source, intent, ts) + err := portal.syncParticipants(ctx, change.MemberChanges, source, intent, ts) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to sync room members") } @@ -2274,7 +2274,7 @@ type UserLocalPortalInfo struct { Tag *event.RoomTag } -func (portal *Portal) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool { +func (portal *Portal) updateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool { if portal.Name == name && (portal.NameSet || portal.MXID == "") { return false } @@ -2283,7 +2283,7 @@ func (portal *Portal) UpdateName(ctx context.Context, name string, sender Matrix return true } -func (portal *Portal) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool { +func (portal *Portal) updateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool { if portal.Topic == topic && (portal.TopicSet || portal.MXID == "") { return false } @@ -2292,7 +2292,7 @@ func (portal *Portal) UpdateTopic(ctx context.Context, topic string, sender Matr return true } -func (portal *Portal) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool { +func (portal *Portal) updateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool { if portal.AvatarID == avatar.ID && (portal.AvatarSet || portal.MXID == "") { return false } @@ -2407,7 +2407,7 @@ func (portal *Portal) sendRoomMeta(ctx context.Context, sender MatrixAPI, ts tim return true } -func (portal *Portal) GetInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) { +func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) { if members == nil { invite = []id.UserID{source.UserMXID} return @@ -2480,7 +2480,7 @@ func (portal *Portal) updateOtherUser(ctx context.Context, members *ChatMemberLi return false } -func (portal *Portal) SyncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error { +func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error { var loginsInPortal []*UserLogin var err error if members.CheckAllLogins { @@ -2721,7 +2721,7 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat return true } -func (portal *Portal) UpdateParent(ctx context.Context, newParent networkid.PortalID, source *UserLogin) bool { +func (portal *Portal) updateParent(ctx context.Context, newParent networkid.PortalID, source *UserLogin) bool { if portal.ParentID == newParent { return false } @@ -2773,8 +2773,8 @@ func (portal *Portal) UpdateInfoFromGhost(ctx context.Context, ghost *Ghost) (ch return } } - changed = portal.UpdateName(ctx, ghost.Name, nil, time.Time{}) || changed - changed = portal.UpdateAvatar(ctx, &Avatar{ + changed = portal.updateName(ctx, ghost.Name, nil, time.Time{}) || changed + changed = portal.updateAvatar(ctx, &Avatar{ ID: ghost.AvatarID, MXC: ghost.AvatarMXC, Hash: ghost.AvatarHash, @@ -2787,28 +2787,28 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us changed := false if info.Name != nil { portal.NameIsCustom = true - changed = portal.UpdateName(ctx, *info.Name, sender, ts) || changed + changed = portal.updateName(ctx, *info.Name, sender, ts) || changed } if info.Topic != nil { - changed = portal.UpdateTopic(ctx, *info.Topic, sender, ts) || changed + changed = portal.updateTopic(ctx, *info.Topic, sender, ts) || changed } if info.Avatar != nil { portal.NameIsCustom = true - changed = portal.UpdateAvatar(ctx, info.Avatar, sender, ts) || changed + changed = portal.updateAvatar(ctx, info.Avatar, sender, ts) || changed } changed = portal.UpdateInfoFromGhost(ctx, nil) || changed if info.Disappear != nil { changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, sender, ts, false, false) || changed } if info.ParentID != nil { - changed = portal.UpdateParent(ctx, *info.ParentID, source) || changed + changed = portal.updateParent(ctx, *info.ParentID, source) || changed } if info.JoinRule != nil { // TODO change detection instead of spamming this every time? portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule) } if info.Members != nil && portal.MXID != "" && source != nil { - err := portal.SyncParticipants(ctx, info.Members, source, nil, time.Time{}) + err := portal.syncParticipants(ctx, info.Members, source, nil, time.Time{}) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to sync room members") } @@ -2903,7 +2903,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo }, Users: map[id.UserID]int{}, } - initialMembers, extraFunctionalMembers, err := portal.GetInitialMemberList(ctx, info.Members, source, powerLevels) + initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(ctx, info.Members, source, powerLevels) if err != nil { log.Err(err).Msg("Failed to process participant list for portal creation") return err @@ -3026,7 +3026,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } } } else { - err = portal.SyncParticipants(ctx, info.Members, source, nil, time.Time{}) + err = portal.syncParticipants(ctx, info.Members, source, nil, time.Time{}) if err != nil { log.Err(err).Msg("Failed to sync participants after room creation") } diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go new file mode 100644 index 00000000..c261bd2d --- /dev/null +++ b/bridgev2/portalinternal.go @@ -0,0 +1,266 @@ +// GENERATED BY portalinternal_generate.go; DO NOT EDIT + +//go:generate go run portalinternal_generate.go +//go:generate goimports -w portalinternal.go + +package bridgev2 + +import ( + "context" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type PortalInternals Portal + +// Deprecated: portal internals should be used carefully and only when necessary. +func (portal *Portal) Internal() *PortalInternals { + return (*PortalInternals)(portal) +} + +func (portal *PortalInternals) UpdateLogger() { + (*Portal)(portal).updateLogger() +} + +func (portal *PortalInternals) QueueEvent(ctx context.Context, evt portalEvent) { + (*Portal)(portal).queueEvent(ctx, evt) +} + +func (portal *PortalInternals) EventLoop() { + (*Portal)(portal).eventLoop() +} + +func (portal *PortalInternals) HandleCreateEvent(evt *portalCreateEvent) { + (*Portal)(portal).handleCreateEvent(evt) +} + +func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event) { + (*Portal)(portal).sendSuccessStatus(ctx, evt) +} + +func (portal *PortalInternals) SendErrorStatus(ctx context.Context, evt *event.Event, err error) { + (*Portal)(portal).sendErrorStatus(ctx, evt, err) +} + +func (portal *PortalInternals) HandleMatrixEvent(sender *User, evt *event.Event) { + (*Portal)(portal).handleMatrixEvent(sender, evt) +} + +func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) { + (*Portal)(portal).handleMatrixReceipts(ctx, evt) +} + +func (portal *PortalInternals) HandleMatrixReadReceipt(ctx context.Context, user *User, eventID id.EventID, receipt event.ReadReceipt) { + (*Portal)(portal).handleMatrixReadReceipt(ctx, user, eventID, receipt) +} + +func (portal *PortalInternals) HandleMatrixTyping(ctx context.Context, evt *event.Event) { + (*Portal)(portal).handleMatrixTyping(ctx, evt) +} + +func (portal *PortalInternals) SendTypings(ctx context.Context, userIDs []id.UserID, typing bool) { + (*Portal)(portal).sendTypings(ctx, userIDs, typing) +} + +func (portal *PortalInternals) PeriodicTypingUpdater() { + (*Portal)(portal).periodicTypingUpdater() +} + +func (portal *PortalInternals) CheckMessageContentCaps(ctx context.Context, caps *NetworkRoomCapabilities, content *event.MessageEventContent, evt *event.Event) bool { + return (*Portal)(portal).checkMessageContentCaps(ctx, caps, content, evt) +} + +func (portal *PortalInternals) HandleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { + (*Portal)(portal).handleMatrixMessage(ctx, sender, origSender, evt) +} + +func (portal *PortalInternals) HandleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *NetworkRoomCapabilities) { + (*Portal)(portal).handleMatrixEdit(ctx, sender, origSender, evt, content, caps) +} + +func (portal *PortalInternals) HandleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) { + (*Portal)(portal).handleMatrixReaction(ctx, sender, evt) +} + +func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { + (*Portal)(portal).handleMatrixRedaction(ctx, sender, origSender, evt) +} + +func (portal *PortalInternals) HandleRemoteEvent(source *UserLogin, evt RemoteEvent) { + (*Portal)(portal).handleRemoteEvent(source, evt) +} + +func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID) { + return (*Portal)(portal).getIntentAndUserMXIDFor(ctx, sender, source, otherLogins, evtType) +} + +func (portal *PortalInternals) GetRelationMeta(ctx context.Context, currentMsg networkid.MessageID, replyToPtr *networkid.MessageOptionalPartID, threadRootPtr *networkid.MessageID, isBatchSend bool) (replyTo, threadRoot, prevThreadEvent *database.Message) { + return (*Portal)(portal).getRelationMeta(ctx, currentMsg, replyToPtr, threadRootPtr, isBatchSend) +} + +func (portal *PortalInternals) ApplyRelationMeta(content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { + (*Portal)(portal).applyRelationMeta(content, replyTo, threadRoot, prevThreadEvent) +} + +func (portal *PortalInternals) SendConvertedMessage(ctx context.Context, id networkid.MessageID, intent MatrixAPI, senderID networkid.UserID, converted *ConvertedMessage, ts time.Time, logContext func(*zerolog.Event) *zerolog.Event) []*database.Message { + return (*Portal)(portal).sendConvertedMessage(ctx, id, intent, senderID, converted, ts, logContext) +} + +func (portal *PortalInternals) CheckPendingMessage(ctx context.Context, evt RemoteMessage) (bool, *database.Message) { + return (*Portal)(portal).checkPendingMessage(ctx, evt) +} + +func (portal *PortalInternals) HandleRemoteUpsert(ctx context.Context, source *UserLogin, evt RemoteMessageUpsert, existing []*database.Message) bool { + return (*Portal)(portal).handleRemoteUpsert(ctx, source, evt, existing) +} + +func (portal *PortalInternals) HandleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) { + (*Portal)(portal).handleRemoteMessage(ctx, source, evt) +} + +func (portal *PortalInternals) SendRemoteErrorNotice(ctx context.Context, intent MatrixAPI, err error, ts time.Time, evtTypeName string) { + (*Portal)(portal).sendRemoteErrorNotice(ctx, intent, err, ts, evtTypeName) +} + +func (portal *PortalInternals) HandleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) { + (*Portal)(portal).handleRemoteEdit(ctx, source, evt) +} + +func (portal *PortalInternals) SendConvertedEdit(ctx context.Context, targetID networkid.MessageID, senderID networkid.UserID, converted *ConvertedEdit, intent MatrixAPI, ts time.Time) { + (*Portal)(portal).sendConvertedEdit(ctx, targetID, senderID, converted, intent, ts) +} + +func (portal *PortalInternals) GetTargetMessagePart(ctx context.Context, evt RemoteEventWithTargetMessage) (*database.Message, error) { + return (*Portal)(portal).getTargetMessagePart(ctx, evt) +} + +func (portal *PortalInternals) GetTargetReaction(ctx context.Context, evt RemoteReactionRemove) (*database.Reaction, error) { + return (*Portal)(portal).getTargetReaction(ctx, evt) +} + +func (portal *PortalInternals) HandleRemoteReactionSync(ctx context.Context, source *UserLogin, evt RemoteReactionSync) { + (*Portal)(portal).handleRemoteReactionSync(ctx, source, evt) +} + +func (portal *PortalInternals) HandleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) { + (*Portal)(portal).handleRemoteReaction(ctx, source, evt) +} + +func (portal *PortalInternals) SendConvertedReaction(ctx context.Context, senderID networkid.UserID, intent MatrixAPI, targetMessage *database.Message, emojiID networkid.EmojiID, emoji string, ts time.Time, dbMetadata any, extraContent map[string]any, logContext func(*zerolog.Event) *zerolog.Event) { + (*Portal)(portal).sendConvertedReaction(ctx, senderID, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extraContent, logContext) +} + +func (portal *PortalInternals) GetIntentForMXID(ctx context.Context, userID id.UserID) (MatrixAPI, error) { + return (*Portal)(portal).getIntentForMXID(ctx, userID) +} + +func (portal *PortalInternals) HandleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) { + (*Portal)(portal).handleRemoteReactionRemove(ctx, source, evt) +} + +func (portal *PortalInternals) HandleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) { + (*Portal)(portal).handleRemoteMessageRemove(ctx, source, evt) +} + +func (portal *PortalInternals) RedactMessageParts(ctx context.Context, parts []*database.Message, intent MatrixAPI, ts time.Time) { + (*Portal)(portal).redactMessageParts(ctx, parts, intent, ts) +} + +func (portal *PortalInternals) HandleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReceipt) { + (*Portal)(portal).handleRemoteReadReceipt(ctx, source, evt) +} + +func (portal *PortalInternals) HandleRemoteMarkUnread(ctx context.Context, source *UserLogin, evt RemoteMarkUnread) { + (*Portal)(portal).handleRemoteMarkUnread(ctx, source, evt) +} + +func (portal *PortalInternals) HandleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteReceipt) { + (*Portal)(portal).handleRemoteDeliveryReceipt(ctx, source, evt) +} + +func (portal *PortalInternals) HandleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) { + (*Portal)(portal).handleRemoteTyping(ctx, source, evt) +} + +func (portal *PortalInternals) HandleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) { + (*Portal)(portal).handleRemoteChatInfoChange(ctx, source, evt) +} + +func (portal *PortalInternals) HandleRemoteChatResync(ctx context.Context, source *UserLogin, evt RemoteChatResync) { + (*Portal)(portal).handleRemoteChatResync(ctx, source, evt) +} + +func (portal *PortalInternals) HandleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) { + (*Portal)(portal).handleRemoteChatDelete(ctx, source, evt) +} + +func (portal *PortalInternals) HandleRemoteBackfill(ctx context.Context, source *UserLogin, backfill RemoteBackfill) { + (*Portal)(portal).handleRemoteBackfill(ctx, source, backfill) +} + +func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool { + return (*Portal)(portal).updateName(ctx, name, sender, ts) +} + +func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool { + return (*Portal)(portal).updateTopic(ctx, topic, sender, ts) +} + +func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool { + return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts) +} + +func (portal *PortalInternals) GetBridgeInfo() (string, event.BridgeEventContent) { + return (*Portal)(portal).getBridgeInfo() +} + +func (portal *PortalInternals) SendStateWithIntentOrBot(ctx context.Context, sender MatrixAPI, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (resp *mautrix.RespSendEvent, err error) { + return (*Portal)(portal).sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, content, ts) +} + +func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any) bool { + return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content) +} + +func (portal *PortalInternals) GetInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) { + return (*Portal)(portal).getInitialMemberList(ctx, members, source, pl) +} + +func (portal *PortalInternals) UpdateOtherUser(ctx context.Context, members *ChatMemberList) (changed bool) { + return (*Portal)(portal).updateOtherUser(ctx, members) +} + +func (portal *PortalInternals) SyncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error { + return (*Portal)(portal).syncParticipants(ctx, members, source, sender, ts) +} + +func (portal *PortalInternals) UpdateUserLocalInfo(ctx context.Context, info *UserLocalPortalInfo, source *UserLogin) { + (*Portal)(portal).updateUserLocalInfo(ctx, info, source) +} + +func (portal *PortalInternals) UpdateParent(ctx context.Context, newParent networkid.PortalID, source *UserLogin) bool { + return (*Portal)(portal).updateParent(ctx, newParent, source) +} + +func (portal *PortalInternals) LockedUpdateInfoFromGhost(ctx context.Context, ghost *Ghost) { + (*Portal)(portal).lockedUpdateInfoFromGhost(ctx, ghost) +} + +func (portal *PortalInternals) CreateMatrixRoomInLoop(ctx context.Context, source *UserLogin, info *ChatInfo) error { + return (*Portal)(portal).createMatrixRoomInLoop(ctx, source, info) +} + +func (portal *PortalInternals) UnlockedDelete(ctx context.Context) error { + return (*Portal)(portal).unlockedDelete(ctx) +} + +func (portal *PortalInternals) UnlockedDeleteCache() { + (*Portal)(portal).unlockedDeleteCache() +} diff --git a/bridgev2/portalinternal_generate.go b/bridgev2/portalinternal_generate.go new file mode 100644 index 00000000..8fd9e917 --- /dev/null +++ b/bridgev2/portalinternal_generate.go @@ -0,0 +1,160 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build ignore + +package main + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "os" + "strings" + + "go.mau.fi/util/exerrors" +) + +const header = `// GENERATED BY portalinternal_generate.go; DO NOT EDIT + +//go:generate go run portalinternal_generate.go +//go:generate goimports -w portalinternal.go + +package bridgev2 + +` +const postImportHeader = ` +type PortalInternals Portal + +// Deprecated: portal internals should be used carefully and only when necessary. +func (portal *Portal) Internal() *PortalInternals { + return (*PortalInternals)(portal) +} +` + +func getTypeName(expr ast.Expr) string { + switch e := expr.(type) { + case *ast.Ident: + return e.Name + case *ast.StarExpr: + return "*" + getTypeName(e.X) + case *ast.ArrayType: + return "[]" + getTypeName(e.Elt) + case *ast.MapType: + return fmt.Sprintf("map[%s]%s", getTypeName(e.Key), getTypeName(e.Value)) + case *ast.ChanType: + return fmt.Sprintf("chan %s", getTypeName(e.Value)) + case *ast.FuncType: + var params []string + for _, param := range e.Params.List { + params = append(params, getTypeName(param.Type)) + } + var results []string + if e.Results != nil { + for _, result := range e.Results.List { + results = append(results, getTypeName(result.Type)) + } + } + return fmt.Sprintf("func(%s) %s", strings.Join(params, ", "), strings.Join(results, ", ")) + case *ast.SelectorExpr: + return fmt.Sprintf("%s.%s", getTypeName(e.X), e.Sel.Name) + default: + panic(fmt.Errorf("unknown type %T", e)) + } +} + +func main() { + fset := token.NewFileSet() + f := exerrors.Must(parser.ParseFile(fset, "portal.go", nil, parser.SkipObjectResolution)) + file := exerrors.Must(os.OpenFile("portalinternal.go", os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)) + write := func(str string) { + exerrors.Must(file.WriteString(str)) + } + writef := func(format string, args ...any) { + exerrors.Must(fmt.Fprintf(file, format, args...)) + } + write(header) + write("import (\n") + for _, i := range f.Imports { + write("\t") + if i.Name != nil { + writef("%s ", i.Name.Name) + } + writef("%s\n", i.Path.Value) + } + write(")\n") + write(postImportHeader) + ast.Inspect(f, func(node ast.Node) (retVal bool) { + retVal = true + funcDecl, ok := node.(*ast.FuncDecl) + if !ok || funcDecl.Name.IsExported() { + return + } + if funcDecl.Recv == nil || len(funcDecl.Recv.List) == 0 || len(funcDecl.Recv.List[0].Names) == 0 || + funcDecl.Recv.List[0].Names[0].Name != "portal" { + return + } + writef("\nfunc (portal *PortalInternals) %s%s(", strings.ToUpper(funcDecl.Name.Name[0:1]), funcDecl.Name.Name[1:]) + for i, param := range funcDecl.Type.Params.List { + if i != 0 { + write(", ") + } + for j, name := range param.Names { + if j != 0 { + write(", ") + } + write(name.Name) + } + if len(param.Names) > 0 { + write(" ") + } + write(getTypeName(param.Type)) + } + write(") ") + if funcDecl.Type.Results != nil && len(funcDecl.Type.Results.List) > 0 { + needsParentheses := len(funcDecl.Type.Results.List) > 1 || len(funcDecl.Type.Results.List[0].Names) > 0 + if needsParentheses { + write("(") + } + for i, result := range funcDecl.Type.Results.List { + if i != 0 { + write(", ") + } + for j, name := range result.Names { + if j != 0 { + write(", ") + } + write(name.Name) + } + if len(result.Names) > 0 { + write(" ") + } + write(getTypeName(result.Type)) + } + if needsParentheses { + write(")") + } + write(" ") + } + write("{\n\t") + if funcDecl.Type.Results != nil { + write("return ") + } + writef("(*Portal)(portal).%s(", funcDecl.Name.Name) + for i, param := range funcDecl.Type.Params.List { + for j, name := range param.Names { + if i != 0 || j != 0 { + write(", ") + } + write(name.Name) + } + } + write(")\n}\n") + return + }) + exerrors.PanicIfNotNil(file.Close()) +} From c1a993719e67e06c98e08f591c94cfe3c3d65df6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 30 Jul 2024 16:29:42 +0300 Subject: [PATCH 0540/1647] bridgev2/messagestatus: add common media errors --- bridgev2/messagestatus.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index d2969eda..43a9822b 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -35,6 +35,11 @@ var ( ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) + + ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) + ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) + ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) ) type MessageStatusEventInfo struct { From 779f61ac9c697dc4a01c4f258cf4e9e3576c7424 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 30 Jul 2024 16:36:08 +0300 Subject: [PATCH 0541/1647] bridgev2/simplevent: add more events --- bridgev2/simplevent/chat.go | 81 ++++++++++++++ bridgev2/simplevent/events.go | 191 -------------------------------- bridgev2/simplevent/message.go | 66 +++++++++++ bridgev2/simplevent/meta.go | 59 ++++++++++ bridgev2/simplevent/reaction.go | 67 +++++++++++ bridgev2/simplevent/receipt.go | 64 +++++++++++ 6 files changed, 337 insertions(+), 191 deletions(-) create mode 100644 bridgev2/simplevent/chat.go delete mode 100644 bridgev2/simplevent/events.go create mode 100644 bridgev2/simplevent/message.go create mode 100644 bridgev2/simplevent/meta.go create mode 100644 bridgev2/simplevent/reaction.go create mode 100644 bridgev2/simplevent/receipt.go diff --git a/bridgev2/simplevent/chat.go b/bridgev2/simplevent/chat.go new file mode 100644 index 00000000..c3a62b85 --- /dev/null +++ b/bridgev2/simplevent/chat.go @@ -0,0 +1,81 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package simplevent + +import ( + "context" + "time" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" +) + +// ChatResync is a simple implementation of [bridgev2.RemoteChatResync]. +// +// If GetChatInfoFunc is set, it will be used to get the chat info. Otherwise, ChatInfo will be used. +// +// If CheckNeedsBackfillFunc is set, it will be used to determine if backfill is required. +// Otherwise, the latest database message timestamp is compared to LatestMessageTS. +// +// All four fields are optional. +type ChatResync struct { + EventMeta + + ChatInfo *bridgev2.ChatInfo + GetChatInfoFunc func(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) + + LatestMessageTS time.Time + CheckNeedsBackfillFunc func(ctx context.Context, latestMessage *database.Message) (bool, error) +} + +var ( + _ bridgev2.RemoteChatResync = (*ChatResync)(nil) + _ bridgev2.RemoteChatResyncWithInfo = (*ChatResync)(nil) + _ bridgev2.RemoteChatResyncBackfill = (*ChatResync)(nil) +) + +func (evt *ChatResync) CheckNeedsBackfill(ctx context.Context, latestMessage *database.Message) (bool, error) { + if evt.CheckNeedsBackfillFunc != nil { + return evt.CheckNeedsBackfillFunc(ctx, latestMessage) + } else if latestMessage == nil { + return !evt.LatestMessageTS.IsZero(), nil + } else { + return !evt.LatestMessageTS.IsZero() && evt.LatestMessageTS.Before(latestMessage.Timestamp), nil + } +} + +func (evt *ChatResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { + if evt.GetChatInfoFunc != nil { + return evt.GetChatInfoFunc(ctx, portal) + } + return evt.ChatInfo, nil +} + +// ChatDelete is a simple implementation of [bridgev2.RemoteChatDelete]. +type ChatDelete struct { + EventMeta + OnlyForMe bool +} + +var _ bridgev2.RemoteChatDelete = (*ChatDelete)(nil) + +func (evt *ChatDelete) DeleteOnlyForMe() bool { + return evt.OnlyForMe +} + +// ChatInfoChange is a simple implementation of [bridgev2.RemoteChatInfoChange]. +type ChatInfoChange struct { + EventMeta + + ChatInfoChange *bridgev2.ChatInfoChange +} + +var _ bridgev2.RemoteChatInfoChange = (*ChatInfoChange)(nil) + +func (evt *ChatInfoChange) GetChatInfoChange(ctx context.Context) (*bridgev2.ChatInfoChange, error) { + return evt.ChatInfoChange, nil +} diff --git a/bridgev2/simplevent/events.go b/bridgev2/simplevent/events.go deleted file mode 100644 index 605b18a8..00000000 --- a/bridgev2/simplevent/events.go +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package simplevent - -import ( - "context" - "time" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix/bridgev2" - "maunium.net/go/mautrix/bridgev2/database" - "maunium.net/go/mautrix/bridgev2/networkid" -) - -// EventMeta is a struct containing metadata fields used by most event types. -type EventMeta struct { - Type bridgev2.RemoteEventType - LogContext func(c zerolog.Context) zerolog.Context - PortalKey networkid.PortalKey - Sender bridgev2.EventSender - CreatePortal bool - Timestamp time.Time -} - -var ( - _ bridgev2.RemoteEvent = (*EventMeta)(nil) - _ bridgev2.RemoteEventThatMayCreatePortal = (*EventMeta)(nil) - _ bridgev2.RemoteEventWithTimestamp = (*EventMeta)(nil) -) - -func (evt *EventMeta) AddLogContext(c zerolog.Context) zerolog.Context { - return evt.LogContext(c) -} - -func (evt *EventMeta) GetPortalKey() networkid.PortalKey { - return evt.PortalKey -} - -func (evt *EventMeta) GetTimestamp() time.Time { - if evt.Timestamp.IsZero() { - return time.Now() - } - return evt.Timestamp -} - -func (evt *EventMeta) GetSender() bridgev2.EventSender { - return evt.Sender -} - -func (evt *EventMeta) GetType() bridgev2.RemoteEventType { - return evt.Type -} - -func (evt *EventMeta) ShouldCreatePortal() bool { - return evt.CreatePortal -} - -// Message is a simple implementation of [bridgev2.RemoteMessage] and [bridgev2.RemoteEdit]. -type Message[T any] struct { - EventMeta - Data T - - ID networkid.MessageID - TargetMessage networkid.MessageID - - ConvertMessageFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, data T) (*bridgev2.ConvertedMessage, error) - ConvertEditFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message, data T) (*bridgev2.ConvertedEdit, error) -} - -var ( - _ bridgev2.RemoteMessage = (*Message[any])(nil) - _ bridgev2.RemoteEdit = (*Message[any])(nil) -) - -func (evt *Message[T]) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { - return evt.ConvertMessageFunc(ctx, portal, intent, evt.Data) -} - -func (evt *Message[T]) ConvertEdit(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (*bridgev2.ConvertedEdit, error) { - return evt.ConvertEditFunc(ctx, portal, intent, existing, evt.Data) -} - -func (evt *Message[T]) GetID() networkid.MessageID { - return evt.ID -} - -func (evt *Message[T]) GetTargetMessage() networkid.MessageID { - return evt.TargetMessage -} - -// Reaction is a simple implementation of [bridgev2.RemoteReaction] and [bridgev2.RemoteReactionRemove]. -type Reaction struct { - EventMeta - TargetMessage networkid.MessageID - EmojiID networkid.EmojiID - Emoji string - ReactionDBMeta any -} - -var ( - _ bridgev2.RemoteReaction = (*Reaction)(nil) - _ bridgev2.RemoteReactionWithMeta = (*Reaction)(nil) - _ bridgev2.RemoteReactionRemove = (*Reaction)(nil) -) - -func (evt *Reaction) GetTargetMessage() networkid.MessageID { - return evt.TargetMessage -} - -func (evt *Reaction) GetReactionEmoji() (string, networkid.EmojiID) { - return evt.Emoji, evt.EmojiID -} - -func (evt *Reaction) GetRemovedEmojiID() networkid.EmojiID { - return evt.EmojiID -} - -func (evt *Reaction) GetReactionDBMetadata() any { - return evt.ReactionDBMeta -} - -// ChatResync is a simple implementation of [bridgev2.RemoteChatResync]. -// -// If GetChatInfoFunc is set, it will be used to get the chat info. Otherwise, ChatInfo will be used. -// -// If CheckNeedsBackfillFunc is set, it will be used to determine if backfill is required. -// Otherwise, the latest database message timestamp is compared to LatestMessageTS. -// -// All four fields are optional. -type ChatResync struct { - EventMeta - - ChatInfo *bridgev2.ChatInfo - GetChatInfoFunc func(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) - - LatestMessageTS time.Time - CheckNeedsBackfillFunc func(ctx context.Context, latestMessage *database.Message) (bool, error) -} - -var ( - _ bridgev2.RemoteChatResync = (*ChatResync)(nil) - _ bridgev2.RemoteChatResyncWithInfo = (*ChatResync)(nil) - _ bridgev2.RemoteChatResyncBackfill = (*ChatResync)(nil) -) - -func (evt *ChatResync) CheckNeedsBackfill(ctx context.Context, latestMessage *database.Message) (bool, error) { - if evt.CheckNeedsBackfillFunc != nil { - return evt.CheckNeedsBackfillFunc(ctx, latestMessage) - } else if latestMessage == nil { - return !evt.LatestMessageTS.IsZero(), nil - } else { - return !evt.LatestMessageTS.IsZero() && evt.LatestMessageTS.Before(latestMessage.Timestamp), nil - } -} - -func (evt *ChatResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { - if evt.GetChatInfoFunc != nil { - return evt.GetChatInfoFunc(ctx, portal) - } - return evt.ChatInfo, nil -} - -// ChatDelete is a simple implementation of [bridgev2.RemoteChatDelete]. -type ChatDelete struct { - EventMeta - OnlyForMe bool -} - -var _ bridgev2.RemoteChatDelete = (*ChatDelete)(nil) - -func (evt *ChatDelete) DeleteOnlyForMe() bool { - return evt.OnlyForMe -} - -// ChatInfoChange is a simple implementation of [bridgev2.RemoteChatInfoChange]. -type ChatInfoChange struct { - EventMeta - - ChatInfoChange *bridgev2.ChatInfoChange -} - -var _ bridgev2.RemoteChatInfoChange = (*ChatInfoChange)(nil) - -func (evt *ChatInfoChange) GetChatInfoChange(ctx context.Context) (*bridgev2.ChatInfoChange, error) { - return evt.ChatInfoChange, nil -} diff --git a/bridgev2/simplevent/message.go b/bridgev2/simplevent/message.go new file mode 100644 index 00000000..928bffc9 --- /dev/null +++ b/bridgev2/simplevent/message.go @@ -0,0 +1,66 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package simplevent + +import ( + "context" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" +) + +// Message is a simple implementation of [bridgev2.RemoteMessage], [bridgev2.RemoteEdit] and [bridgev2.RemoteMessageUpsert]. +type Message[T any] struct { + EventMeta + Data T + + ID networkid.MessageID + TargetMessage networkid.MessageID + + ConvertMessageFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, data T) (*bridgev2.ConvertedMessage, error) + ConvertEditFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message, data T) (*bridgev2.ConvertedEdit, error) + HandleExistingFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message, data T) (bridgev2.UpsertResult, error) +} + +var ( + _ bridgev2.RemoteMessage = (*Message[any])(nil) + _ bridgev2.RemoteEdit = (*Message[any])(nil) + _ bridgev2.RemoteMessageUpsert = (*Message[any])(nil) +) + +func (evt *Message[T]) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { + return evt.ConvertMessageFunc(ctx, portal, intent, evt.Data) +} + +func (evt *Message[T]) ConvertEdit(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (*bridgev2.ConvertedEdit, error) { + return evt.ConvertEditFunc(ctx, portal, intent, existing, evt.Data) +} + +func (evt *Message[T]) HandleExisting(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error) { + return evt.HandleExistingFunc(ctx, portal, intent, existing, evt.Data) +} + +func (evt *Message[T]) GetID() networkid.MessageID { + return evt.ID +} + +func (evt *Message[T]) GetTargetMessage() networkid.MessageID { + return evt.TargetMessage +} + +type MessageRemove struct { + EventMeta + + TargetMessage networkid.MessageID +} + +var _ bridgev2.RemoteMessageRemove = (*MessageRemove)(nil) + +func (evt *MessageRemove) GetTargetMessage() networkid.MessageID { + return evt.TargetMessage +} diff --git a/bridgev2/simplevent/meta.go b/bridgev2/simplevent/meta.go new file mode 100644 index 00000000..d61c6e91 --- /dev/null +++ b/bridgev2/simplevent/meta.go @@ -0,0 +1,59 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package simplevent + +import ( + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" +) + +// EventMeta is a struct containing metadata fields used by most event types. +type EventMeta struct { + Type bridgev2.RemoteEventType + LogContext func(c zerolog.Context) zerolog.Context + PortalKey networkid.PortalKey + Sender bridgev2.EventSender + CreatePortal bool + Timestamp time.Time +} + +var ( + _ bridgev2.RemoteEvent = (*EventMeta)(nil) + _ bridgev2.RemoteEventThatMayCreatePortal = (*EventMeta)(nil) + _ bridgev2.RemoteEventWithTimestamp = (*EventMeta)(nil) +) + +func (evt *EventMeta) AddLogContext(c zerolog.Context) zerolog.Context { + return evt.LogContext(c) +} + +func (evt *EventMeta) GetPortalKey() networkid.PortalKey { + return evt.PortalKey +} + +func (evt *EventMeta) GetTimestamp() time.Time { + if evt.Timestamp.IsZero() { + return time.Now() + } + return evt.Timestamp +} + +func (evt *EventMeta) GetSender() bridgev2.EventSender { + return evt.Sender +} + +func (evt *EventMeta) GetType() bridgev2.RemoteEventType { + return evt.Type +} + +func (evt *EventMeta) ShouldCreatePortal() bool { + return evt.CreatePortal +} diff --git a/bridgev2/simplevent/reaction.go b/bridgev2/simplevent/reaction.go new file mode 100644 index 00000000..34e0b025 --- /dev/null +++ b/bridgev2/simplevent/reaction.go @@ -0,0 +1,67 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package simplevent + +import ( + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" +) + +// Reaction is a simple implementation of [bridgev2.RemoteReaction] and [bridgev2.RemoteReactionRemove]. +type Reaction struct { + EventMeta + TargetMessage networkid.MessageID + EmojiID networkid.EmojiID + Emoji string + ExtraContent map[string]any + ReactionDBMeta any +} + +var ( + _ bridgev2.RemoteReaction = (*Reaction)(nil) + _ bridgev2.RemoteReactionWithMeta = (*Reaction)(nil) + _ bridgev2.RemoteReactionWithExtraContent = (*Reaction)(nil) + _ bridgev2.RemoteReactionRemove = (*Reaction)(nil) +) + +func (evt *Reaction) GetTargetMessage() networkid.MessageID { + return evt.TargetMessage +} + +func (evt *Reaction) GetReactionEmoji() (string, networkid.EmojiID) { + return evt.Emoji, evt.EmojiID +} + +func (evt *Reaction) GetRemovedEmojiID() networkid.EmojiID { + return evt.EmojiID +} + +func (evt *Reaction) GetReactionDBMetadata() any { + return evt.ReactionDBMeta +} + +func (evt *Reaction) GetReactionExtraContent() map[string]any { + return evt.ExtraContent +} + +type ReactionSync struct { + EventMeta + TargetMessage networkid.MessageID + Reactions *bridgev2.ReactionSyncData +} + +var ( + _ bridgev2.RemoteReactionSync = (*ReactionSync)(nil) +) + +func (evt *ReactionSync) GetTargetMessage() networkid.MessageID { + return evt.TargetMessage +} + +func (evt *ReactionSync) GetReactions() *bridgev2.ReactionSyncData { + return evt.Reactions +} diff --git a/bridgev2/simplevent/receipt.go b/bridgev2/simplevent/receipt.go new file mode 100644 index 00000000..dfd68730 --- /dev/null +++ b/bridgev2/simplevent/receipt.go @@ -0,0 +1,64 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package simplevent + +import ( + "time" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" +) + +type Receipt struct { + EventMeta + + LastTarget networkid.MessageID + Targets []networkid.MessageID + ReadUpTo time.Time +} + +var ( + _ bridgev2.RemoteReceipt = (*Receipt)(nil) +) + +func (evt *Receipt) GetLastReceiptTarget() networkid.MessageID { + return evt.LastTarget +} + +func (evt *Receipt) GetReceiptTargets() []networkid.MessageID { + return evt.Targets +} + +func (evt *Receipt) GetReadUpTo() time.Time { + return evt.ReadUpTo +} + +type MarkUnread struct { + EventMeta + Unread bool +} + +var ( + _ bridgev2.RemoteMarkUnread = (*MarkUnread)(nil) +) + +func (evt *MarkUnread) GetUnread() bool { + return evt.Unread +} + +type Typing struct { + EventMeta + Timeout time.Duration +} + +var ( + _ bridgev2.RemoteTyping = (*Typing)(nil) +) + +func (evt *Typing) GetTimeout() time.Duration { + return evt.Timeout +} From ede6630d79460eb0ae5acab9025499ca8c330dd5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 30 Jul 2024 17:43:59 +0300 Subject: [PATCH 0542/1647] bridgev2/portal: add MaxCount option to mass reaction sync events --- bridgev2/networkinterface.go | 3 +++ bridgev2/portal.go | 15 +++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 4ee7d20d..7468d339 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -791,6 +791,9 @@ type ReactionSyncUser struct { Reactions []*BackfillReaction // Whether the list contains all reactions the user has sent HasAllReactions bool + // If the list doesn't contain all reactions from the user, + // then this field can be set to remove old reactions if there are more than a certain number. + MaxCount int } type ReactionSyncData struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 43a2e21e..1fdfa2b0 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -7,6 +7,7 @@ package bridgev2 import ( + "cmp" "context" "errors" "fmt" @@ -20,6 +21,7 @@ import ( "go.mau.fi/util/exslices" "go.mau.fi/util/ptr" "go.mau.fi/util/variationselector" + "golang.org/x/exp/maps" "golang.org/x/exp/slices" "maunium.net/go/mautrix" @@ -1797,6 +1799,19 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User for _, existingReaction := range existingUserReactions { doRemoveReaction(existingReaction, nil) } + } else if reactions.MaxCount > 0 && len(existingUserReactions)+len(reactions.Reactions) > reactions.MaxCount { + remainingReactionList := maps.Values(existingUserReactions) + slices.SortFunc(remainingReactionList, func(a, b *database.Reaction) int { + diff := a.Timestamp.Compare(b.Timestamp) + if diff == 0 { + return cmp.Compare(a.EmojiID, b.EmojiID) + } + return diff + }) + numberToRemove := max(reactions.MaxCount-len(reactions.Reactions), len(remainingReactionList)) + for i := 0; i < numberToRemove; i++ { + doRemoveReaction(remainingReactionList[i], nil) + } } } if newData.HasAllUsers { From 87d540348ceb1a8feae9a877e2204a65c58d9efe Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 31 Jul 2024 15:56:37 +0300 Subject: [PATCH 0543/1647] bridgev2/networkinterface: include reaction being overridden in HandleMatrixReaction --- bridgev2/networkinterface.go | 2 ++ bridgev2/portal.go | 1 + 2 files changed, 3 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 7468d339..ddc83c52 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -894,6 +894,8 @@ type MatrixReaction struct { TargetMessage *database.Message PreHandleResp *MatrixReactionPreResponse + // When EmojiID is blank and there's already an existing reaction, this is the old reaction that is being overridden. + ReactionToOverride *database.Reaction // When MaxReactions is >0 in the pre-response, this is the list of previous reactions that should be preserved. ExistingReactionsToKeep []*database.Reaction } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 1fdfa2b0..0e28e82d 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -936,6 +936,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi portal.sendSuccessStatus(ctx, evt) return } + react.ReactionToOverride = existing _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ Redacts: existing.MXID, From c76ebc6947322e0e1f5777b8683390a77db6f79a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 31 Jul 2024 15:57:05 +0300 Subject: [PATCH 0544/1647] bridgev2/networkinterface: add option to not send MSS event when handling remote echo --- bridgev2/messagestatus.go | 3 ++- bridgev2/portal.go | 10 ++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 43a9822b..97128314 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -18,7 +18,8 @@ import ( ) var ( - ErrIgnoringRemoteEvent error = errors.New("ignoring remote event") + ErrIgnoringRemoteEvent = errors.New("ignoring remote event") + ErrNoStatus = errors.New("omit message status") ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 0e28e82d..10c830c5 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1492,10 +1492,12 @@ func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage zerolog.Ctx(ctx).Err(err).Msg("Failed to save message to database after receiving remote echo") } } - if statusErr != nil { - portal.sendErrorStatus(ctx, pending.evt, statusErr) - } else { - portal.sendSuccessStatus(ctx, pending.evt) + if !errors.Is(statusErr, ErrNoStatus) { + if statusErr != nil { + portal.sendErrorStatus(ctx, pending.evt, statusErr) + } else { + portal.sendSuccessStatus(ctx, pending.evt) + } } zerolog.Ctx(ctx).Debug().Stringer("event_id", pending.evt.ID).Msg("Received remote echo for message") return true, pending.db From 11e8d6296914881ae715ae7a120efcec825a7be2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 31 Jul 2024 15:57:23 +0300 Subject: [PATCH 0545/1647] bridgev2/networkinterface: add option to save database parts after handling upsert --- bridgev2/networkinterface.go | 1 + bridgev2/portal.go | 11 ++++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index ddc83c52..26f06d3b 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -764,6 +764,7 @@ type RemoteMessage interface { type UpsertResult struct { SubEvents []RemoteEvent + SaveParts bool ContinueMessageHandling bool } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 10c830c5..6c2fa3ce 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1512,7 +1512,16 @@ func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin, res, err := evt.HandleExisting(ctx, portal, intent, existing) if err != nil { log.Err(err).Msg("Failed to handle existing message in upsert event after receiving remote echo") - } else if len(res.SubEvents) > 0 { + } + if res.SaveParts { + for _, part := range existing { + err = portal.Bridge.DB.Message.Update(ctx, part) + if err != nil { + log.Err(err).Str("part_id", string(part.PartID)).Msg("Failed to update message part in database") + } + } + } + if len(res.SubEvents) > 0 { for _, subEvt := range res.SubEvents { portal.handleRemoteEvent(source, subEvt) } From 62671f147fdf974701ee575e32b61478d30b8fa2 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 31 Jul 2024 08:05:15 -0600 Subject: [PATCH 0546/1647] crypto/backup: update comment on computing MAC for encrypted session data Signed-off-by: Sumner Evans --- crypto/backup/encryptedsessiondata.go | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/crypto/backup/encryptedsessiondata.go b/crypto/backup/encryptedsessiondata.go index 37b0a6c8..ec551dbe 100644 --- a/crypto/backup/encryptedsessiondata.go +++ b/crypto/backup/encryptedsessiondata.go @@ -47,28 +47,22 @@ func calculateEncryptionParameters(sharedSecret []byte) (key, macKey, iv []byte, return encryptionParams[:32], encryptionParams[32:64], encryptionParams[64:], nil } -// calculateCompatMAC calculates the MAC for compatibility with Olm and -// Vodozemac which do not actually write the ciphertext when computing the MAC. +// calculateCompatMAC calculates the MAC as described in step 5 of according to +// [Section 11.12.3.2.2] of the Spec which was updated in spec version 1.10 to +// reflect what is actually implemented in libolm and Vodozemac. // -// Deprecated: Use [calculateMAC] instead. +// Libolm implemented the MAC functionality incorrectly. The MAC is computed +// over an empty string rather than the ciphertext. Vodozemac implemented this +// functionality the same way as libolm for compatibility. In version 1.10 of +// the spec, the description of step 5 was updated to reflect the de-facto +// standard of libolm and Vodozemac. +// +// [Section 11.12.3.2.2]: https://spec.matrix.org/v1.11/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 func calculateCompatMAC(macKey []byte) []byte { hash := hmac.New(sha256.New, macKey) return hash.Sum(nil)[:8] } -// calculateMAC calculates the MAC as described in step 5 of according to -// [Section 11.12.3.2.2] of the Spec. -// -// [Section 11.12.3.2.2]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 -func calculateMAC(macKey, ciphertext []byte) []byte { - hash := hmac.New(sha256.New, macKey) - _, err := hash.Write(ciphertext) - if err != nil { - panic(err) - } - return hash.Sum(nil)[:8] -} - // EncryptSessionData encrypts the given session data with the given recovery // key as defined in [Section 11.12.3.2.2 of the Spec]. // From 057b2ee61dfb21e2536d4fbf87823b49a88b3878 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 13:10:44 +0300 Subject: [PATCH 0547/1647] bridgev2: add option to force all messages to be sent as DM user --- bridgev2/networkinterface.go | 21 +++++++++++++++++++-- bridgev2/portal.go | 7 +++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 26f06d3b..898b58ab 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -54,10 +54,27 @@ func (cmp *ConvertedMessagePart) ToEditPart(part *database.Message) *ConvertedEd } } +// EventSender represents a specific user in a chat. type EventSender struct { - IsFromMe bool + // If IsFromMe is true, the UserLogin who the event was received through is used as the sender. + // Double puppeting will be used if available. + IsFromMe bool + // SenderLogin is the ID of the UserLogin who sent the event. This may be different from the + // login the event was received through. It is used to ensure double puppeting can still be + // used even if the event is received through another login. SenderLogin networkid.UserLoginID - Sender networkid.UserID + // Sender is the remote user ID of the user who sent the event. + // For new events, this will not be used for double puppeting. + // + // However, in the member list, [ChatMemberList.CheckAllLogins] can be specified to go through every login + // and call [NetworkAPI.IsThisUser] to check if this ID belongs to that login. This method is not recommended, + // it is better to fill the IsFromMe and SenderLogin fields appropriately. + Sender networkid.UserID + + // ForceDMUser can be set if the event should be sent as the DM user even if the Sender is different. + // This only applies in DM rooms where [database.Portal.OtherUserID] is set and is ignored if IsFromMe is true. + // A warning will be logged if the sender is overridden due to this flag. + ForceDMUser bool } type ConvertedMessage struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 6c2fa3ce..7e025dda 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1286,6 +1286,13 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID) { var ghost *Ghost + if !sender.IsFromMe && sender.ForceDMUser && portal.OtherUserID != "" && sender.Sender != portal.OtherUserID { + zerolog.Ctx(ctx).Warn(). + Str("original_id", string(sender.Sender)). + Str("default_other_user", string(portal.OtherUserID)). + Msg("Overriding event sender with primary other user in DM portal") + sender.Sender = portal.OtherUserID + } if sender.Sender != "" { var err error ghost, err = portal.Bridge.GetGhostByID(ctx, sender.Sender) From 9827af3a8f60bd9a8ff910ef03e8cf06342ba422 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 13:12:11 +0300 Subject: [PATCH 0548/1647] bridgev2/networkinterface: add new type for delivery receipts --- bridgev2/networkinterface.go | 10 +++++++++- bridgev2/portal.go | 8 ++++---- bridgev2/portalinternal.go | 6 +++--- bridgev2/portalinternal_generate.go | 2 +- bridgev2/simplevent/meta.go | 3 +++ bridgev2/simplevent/receipt.go | 3 ++- 6 files changed, 22 insertions(+), 10 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 898b58ab..b62dfcd9 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -844,13 +844,21 @@ type RemoteMessageRemove interface { RemoteEventWithTargetMessage } -type RemoteReceipt interface { +// Deprecated: Renamed to RemoteReadReceipt. +type RemoteReceipt = RemoteReadReceipt + +type RemoteReadReceipt interface { RemoteEvent GetLastReceiptTarget() networkid.MessageID GetReceiptTargets() []networkid.MessageID GetReadUpTo() time.Time } +type RemoteDeliveryReceipt interface { + RemoteEvent + GetReceiptTargets() []networkid.MessageID +} + type RemoteMarkUnread interface { RemoteEvent GetUnread() bool diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 7e025dda..7e5ce03f 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1264,11 +1264,11 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { case RemoteEventMessageRemove: portal.handleRemoteMessageRemove(ctx, source, evt.(RemoteMessageRemove)) case RemoteEventReadReceipt: - portal.handleRemoteReadReceipt(ctx, source, evt.(RemoteReceipt)) + portal.handleRemoteReadReceipt(ctx, source, evt.(RemoteReadReceipt)) case RemoteEventMarkUnread: portal.handleRemoteMarkUnread(ctx, source, evt.(RemoteMarkUnread)) case RemoteEventDeliveryReceipt: - portal.handleRemoteDeliveryReceipt(ctx, source, evt.(RemoteReceipt)) + portal.handleRemoteDeliveryReceipt(ctx, source, evt.(RemoteDeliveryReceipt)) case RemoteEventTyping: portal.handleRemoteTyping(ctx, source, evt.(RemoteTyping)) case RemoteEventChatInfoChange: @@ -2033,7 +2033,7 @@ func (portal *Portal) redactMessageParts(ctx context.Context, parts []*database. } } -func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReceipt) { +func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReadReceipt) { // TODO exclude fake mxids log := zerolog.Ctx(ctx) var err error @@ -2100,7 +2100,7 @@ func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLo } } -func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteReceipt) { +func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) { // TODO implement } diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index c261bd2d..68ad6046 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -1,7 +1,7 @@ // GENERATED BY portalinternal_generate.go; DO NOT EDIT //go:generate go run portalinternal_generate.go -//go:generate goimports -w portalinternal.go +//go:generate goimports -local maunium.net/go/mautrix -w portalinternal.go package bridgev2 @@ -173,7 +173,7 @@ func (portal *PortalInternals) RedactMessageParts(ctx context.Context, parts []* (*Portal)(portal).redactMessageParts(ctx, parts, intent, ts) } -func (portal *PortalInternals) HandleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReceipt) { +func (portal *PortalInternals) HandleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReadReceipt) { (*Portal)(portal).handleRemoteReadReceipt(ctx, source, evt) } @@ -181,7 +181,7 @@ func (portal *PortalInternals) HandleRemoteMarkUnread(ctx context.Context, sourc (*Portal)(portal).handleRemoteMarkUnread(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteReceipt) { +func (portal *PortalInternals) HandleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) { (*Portal)(portal).handleRemoteDeliveryReceipt(ctx, source, evt) } diff --git a/bridgev2/portalinternal_generate.go b/bridgev2/portalinternal_generate.go index 8fd9e917..4438c112 100644 --- a/bridgev2/portalinternal_generate.go +++ b/bridgev2/portalinternal_generate.go @@ -22,7 +22,7 @@ import ( const header = `// GENERATED BY portalinternal_generate.go; DO NOT EDIT //go:generate go run portalinternal_generate.go -//go:generate goimports -w portalinternal.go +//go:generate goimports -local maunium.net/go/mautrix -w portalinternal.go package bridgev2 diff --git a/bridgev2/simplevent/meta.go b/bridgev2/simplevent/meta.go index d61c6e91..afdbf65a 100644 --- a/bridgev2/simplevent/meta.go +++ b/bridgev2/simplevent/meta.go @@ -32,6 +32,9 @@ var ( ) func (evt *EventMeta) AddLogContext(c zerolog.Context) zerolog.Context { + if evt.LogContext == nil { + return c + } return evt.LogContext(c) } diff --git a/bridgev2/simplevent/receipt.go b/bridgev2/simplevent/receipt.go index dfd68730..e9835a66 100644 --- a/bridgev2/simplevent/receipt.go +++ b/bridgev2/simplevent/receipt.go @@ -22,7 +22,8 @@ type Receipt struct { } var ( - _ bridgev2.RemoteReceipt = (*Receipt)(nil) + _ bridgev2.RemoteReadReceipt = (*Receipt)(nil) + _ bridgev2.RemoteDeliveryReceipt = (*Receipt)(nil) ) func (evt *Receipt) GetLastReceiptTarget() networkid.MessageID { From cf35e92ab8a56942b7fe5cd4f39ede7d0fb762d5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 13:23:28 +0300 Subject: [PATCH 0549/1647] bridgev2/portal: only set non-zero timestamps when handling pending message --- bridgev2/portal.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 7e5ce03f..b3501c48 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1483,7 +1483,10 @@ func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage } evtWithTimestamp, ok := evt.(RemoteEventWithTimestamp) if ok { - pending.db.Timestamp = evtWithTimestamp.GetTimestamp() + ts := evtWithTimestamp.GetTimestamp() + if !ts.IsZero() { + pending.db.Timestamp = ts + } } var statusErr error saveMessage := true From 1deb9642881f452933121eed80b17bae14a34cc7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 13:32:58 +0300 Subject: [PATCH 0550/1647] bridgev2/portal: implement basic delivery receipts in DMs --- bridgev2/messagestatus.go | 2 +- bridgev2/portal.go | 27 ++++++++++++++++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 97128314..4b438b64 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -146,7 +146,7 @@ func (ms MessageStatus) Unwrap() error { func (ms *MessageStatus) checkpointStatus() status.MessageCheckpointStatus { switch ms.Status { case event.MessageStatusSuccess: - if ms.DeliveredTo != nil { + if len(ms.DeliveredTo) > 0 { return status.MsgStatusDelivered } return status.MsgStatusSuccess diff --git a/bridgev2/portal.go b/bridgev2/portal.go index b3501c48..6645ee7e 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2104,7 +2104,32 @@ func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLo } func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) { - // TODO implement + if portal.RoomType != database.RoomTypeDM || evt.GetSender().Sender != portal.OtherUserID { + return + } + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventDeliveryReceipt) + log := zerolog.Ctx(ctx) + for _, target := range evt.GetReceiptTargets() { + targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, target) + if err != nil { + log.Err(err).Str("target_id", string(target)).Msg("Failed to get target message for delivery receipt") + continue + } else if len(targetParts) == 0 { + continue + } else if _, sentByGhost := portal.Bridge.Matrix.ParseGhostMXID(targetParts[0].SenderMXID); sentByGhost { + continue + } + for _, part := range targetParts { + portal.Bridge.Matrix.SendMessageStatus(ctx, &MessageStatus{ + Status: event.MessageStatusSuccess, + DeliveredTo: []id.UserID{intent.GetMXID()}, + }, &MessageStatusEventInfo{ + RoomID: portal.MXID, + EventID: part.MXID, + Sender: part.SenderMXID, + }) + } + } } func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) { From 6d92dfa9acaf8539ca79056ae0cc23d03ce49341 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 14:39:46 +0300 Subject: [PATCH 0551/1647] bridgev2: add helper to convert reaction sync to backfill reactions --- bridgev2/networkinterface.go | 9 +++++++++ bridgev2/portalbackfill.go | 1 + 2 files changed, 10 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index b62dfcd9..b840b019 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -396,6 +396,7 @@ type BackfillMessage struct { *ConvertedMessage Sender EventSender ID networkid.MessageID + TxnID networkid.TransactionID Timestamp time.Time Reactions []*BackfillReaction @@ -820,6 +821,14 @@ type ReactionSyncData struct { HasAllUsers bool } +func (rsd *ReactionSyncData) ToBackfill() []*BackfillReaction { + var reactions []*BackfillReaction + for _, user := range rsd.Users { + reactions = append(reactions, user.Reactions...) + } + return reactions +} + type RemoteReactionSync interface { RemoteEventWithTargetMessage GetReactions() *ReactionSyncData diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index b05b16e5..77fe23b1 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -59,6 +59,7 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, log.Debug().Msg("No messages to backfill") return } + // TODO check pending messages // TODO mark backfill queue task as done if last message is nil (-> room was empty) and HasMore is false? resp.Messages = cutoffMessages(&log, resp.Messages, true, lastMessage) if len(resp.Messages) == 0 { From 52b5649abddc680097ba1e30dc15a3453b7f34e8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 14:43:01 +0300 Subject: [PATCH 0552/1647] bridgev2/portal: fix order of operations in room creation sync `UpdateInfoFromGhost` won't work correctly before participants are synced at least once. --- bridgev2/portal.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 6645ee7e..358b2764 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2858,7 +2858,6 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us portal.NameIsCustom = true changed = portal.updateAvatar(ctx, info.Avatar, sender, ts) || changed } - changed = portal.UpdateInfoFromGhost(ctx, nil) || changed if info.Disappear != nil { changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, sender, ts, false, false) || changed } @@ -2887,6 +2886,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us portal.RoomType = *info.Type } } + changed = portal.UpdateInfoFromGhost(ctx, nil) || changed if source != nil { source.MarkInPortal(ctx, portal) portal.updateUserLocalInfo(ctx, info.UserLocal, source) @@ -2950,13 +2950,15 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo var err error if info == nil || info.Members == nil { + if info != nil { + log.Warn().Msg("CreateMatrixRoom got info without members. Refetching info") + } info, err = source.Client.GetChatInfo(ctx, portal) if err != nil { log.Err(err).Msg("Failed to update portal info for creation") return err } } - portal.UpdateInfo(ctx, info, source, nil, time.Time{}) powerLevels := &event.PowerLevelsEventContent{ Events: map[string]int{ event.StateTombstone.Type: 100, @@ -2972,6 +2974,8 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } powerLevels.EnsureUserLevel(portal.Bridge.Bot.GetMXID(), 9001) + portal.UpdateInfo(ctx, info, source, nil, time.Time{}) + req := mautrix.ReqCreateRoom{ Visibility: "private", Name: portal.Name, From e939f164d2483f38c5cfedbdb0e64b04148bbe89 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 15:08:31 +0300 Subject: [PATCH 0553/1647] bridgev2/matrix: use beeper inbox state endpoint if available --- bridgev2/matrix/intent.go | 36 ++++++++++++++++++++++-------------- client.go | 6 ++++++ requests.go | 10 ++++++++++ versions.go | 1 + 4 files changed, 39 insertions(+), 14 deletions(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 6807e4af..b97ea1e2 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -14,6 +14,7 @@ import ( "time" "github.com/rs/zerolog" + "go.mau.fi/util/ptr" "golang.org/x/exp/slices" "maunium.net/go/mautrix" @@ -118,7 +119,7 @@ func (as *ASIntent) SendState(ctx context.Context, roomID id.RoomID, eventType e } } -func (as *ASIntent) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, ts time.Time) error { +func (as *ASIntent) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, ts time.Time) (err error) { extraData := map[string]any{} if !ts.IsZero() { extraData["ts"] = ts.UnixMilli() @@ -132,25 +133,32 @@ func (as *ASIntent) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.E req.FullyRead = eventID req.BeeperFullyReadExtra = extraData } - err := as.Matrix.SetReadMarkers(ctx, roomID, &req) - if err != nil { - return err - } - if as.Matrix.IsCustomPuppet { - err = as.Matrix.SetRoomAccountData(ctx, roomID, event.AccountDataMarkedUnread.Type, &event.MarkedUnreadEventContent{ - Unread: false, + if as.Matrix.IsCustomPuppet && as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureInboxState) { + err = as.Matrix.SetBeeperInboxState(ctx, roomID, &mautrix.ReqSetBeeperInboxState{ + MarkedUnread: ptr.Ptr(false), + ReadMarkers: &req, }) - if err != nil { - return err + } else { + err = as.Matrix.SetReadMarkers(ctx, roomID, &req) + if err == nil && as.Matrix.IsCustomPuppet { + err = as.Matrix.SetRoomAccountData(ctx, roomID, event.AccountDataMarkedUnread.Type, &event.MarkedUnreadEventContent{ + Unread: false, + }) } } - return nil + return } func (as *ASIntent) MarkUnread(ctx context.Context, roomID id.RoomID, unread bool) error { - return as.Matrix.SetRoomAccountData(ctx, roomID, event.AccountDataMarkedUnread.Type, &event.MarkedUnreadEventContent{ - Unread: unread, - }) + if as.Matrix.IsCustomPuppet && as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureInboxState) { + return as.Matrix.SetBeeperInboxState(ctx, roomID, &mautrix.ReqSetBeeperInboxState{ + MarkedUnread: ptr.Ptr(unread), + }) + } else { + return as.Matrix.SetRoomAccountData(ctx, roomID, event.AccountDataMarkedUnread.Type, &event.MarkedUnreadEventContent{ + Unread: unread, + }) + } } func (as *ASIntent) MarkTyping(ctx context.Context, roomID id.RoomID, typingType bridgev2.TypingType, timeout time.Duration) error { diff --git a/client.go b/client.go index 997d7363..9c5ad194 100644 --- a/client.go +++ b/client.go @@ -1932,6 +1932,12 @@ func (cli *Client) SetReadMarkers(ctx context.Context, roomID id.RoomID, content return } +func (cli *Client) SetBeeperInboxState(ctx context.Context, roomID id.RoomID, content *ReqSetBeeperInboxState) (err error) { + urlPath := cli.BuildClientURL("unstable", "com.beeper.inbox", "user", cli.UserID, "rooms", roomID, "inbox_state") + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, content, nil) + return +} + func (cli *Client) AddTag(ctx context.Context, roomID id.RoomID, tag event.RoomTag, order float64) error { return cli.AddTagWithCustomData(ctx, roomID, tag, &event.TagMetadata{ Order: json.Number(strconv.FormatFloat(order, 'e', -1, 64)), diff --git a/requests.go b/requests.go index b6c2f895..d4b634af 100644 --- a/requests.go +++ b/requests.go @@ -365,6 +365,16 @@ type ReqSetReadMarkers struct { BeeperFullyReadExtra interface{} `json:"com.beeper.fully_read.extra,omitempty"` } +type BeeperInboxDone struct { + Delta int64 `json:"at_delta"` +} + +type ReqSetBeeperInboxState struct { + MarkedUnread *bool `json:"marked_unread,omitempty"` + Done *BeeperInboxDone `json:"done,omitempty"` + ReadMarkers *ReqSetReadMarkers `json:"read_markers,omitempty"` +} + type ReqSendReceipt struct { ThreadID string `json:"thread_id,omitempty"` } diff --git a/versions.go b/versions.go index 010d987b..60eb0f30 100644 --- a/versions.go +++ b/versions.go @@ -69,6 +69,7 @@ var ( BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"} BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"} BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"} + BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"} ) func (versions *RespVersions) Supports(feature UnstableFeature) bool { From 1050d07624c0b11a00a4e72b730ba868be242f11 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 15:23:49 +0300 Subject: [PATCH 0554/1647] bridgev2/portal: actually fix order of operations for filling other user ID --- bridgev2/portal.go | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 358b2764..60e4bfd1 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2868,13 +2868,6 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us // TODO change detection instead of spamming this every time? portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule) } - if info.Members != nil && portal.MXID != "" && source != nil { - err := portal.syncParticipants(ctx, info.Members, source, nil, time.Time{}) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to sync room members") - } - // TODO detect changes to functional members list? - } if info.Type != nil && portal.RoomType != *info.Type { if portal.MXID != "" && (*info.Type == database.RoomTypeSpace || portal.RoomType == database.RoomTypeSpace) { zerolog.Ctx(ctx).Warn(). @@ -2886,6 +2879,15 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us portal.RoomType = *info.Type } } + if info.Members != nil && portal.MXID != "" && source != nil { + err := portal.syncParticipants(ctx, info.Members, source, nil, time.Time{}) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to sync room members") + } + // TODO detect changes to functional members list? + } else if info.Members != nil { + portal.updateOtherUser(ctx, info.Members) + } changed = portal.UpdateInfoFromGhost(ctx, nil) || changed if source != nil { source.MarkInPortal(ctx, portal) @@ -2959,6 +2961,9 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo return err } } + + portal.UpdateInfo(ctx, info, source, nil, time.Time{}) + powerLevels := &event.PowerLevelsEventContent{ Events: map[string]int{ event.StateTombstone.Type: 100, @@ -2974,8 +2979,6 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } powerLevels.EnsureUserLevel(portal.Bridge.Bot.GetMXID(), 9001) - portal.UpdateInfo(ctx, info, source, nil, time.Time{}) - req := mautrix.ReqCreateRoom{ Visibility: "private", Name: portal.Name, From 2e0dd48c5d7fb71f96b8949bd358d1eddddecb64 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 15:58:00 +0300 Subject: [PATCH 0555/1647] bridgev2/backfillqueue: sleep on first start --- bridgev2/backfillqueue.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go index dd398dcc..af79fc98 100644 --- a/bridgev2/backfillqueue.go +++ b/bridgev2/backfillqueue.go @@ -45,15 +45,6 @@ func (br *Bridge) RunBackfillQueue() { log.Info().Stringer("batch_delay", batchDelay).Msg("Backfill queue starting") noTasksFoundCount := 0 for { - backfillTask, err := br.DB.BackfillTask.GetNext(ctx) - if err != nil { - log.Err(err).Msg("Failed to get next backfill queue entry") - time.Sleep(BackfillQueueErrorBackoff) - continue - } else if backfillTask != nil { - br.doBackfillTask(ctx, backfillTask) - noTasksFoundCount = 0 - } nextDelay := batchDelay if noTasksFoundCount > 0 { extraDelay := batchDelay * time.Duration(noTasksFoundCount) @@ -80,6 +71,15 @@ func (br *Bridge) RunBackfillQueue() { return case <-timer.C: } + backfillTask, err := br.DB.BackfillTask.GetNext(ctx) + if err != nil { + log.Err(err).Msg("Failed to get next backfill queue entry") + time.Sleep(BackfillQueueErrorBackoff) + continue + } else if backfillTask != nil { + br.doBackfillTask(ctx, backfillTask) + noTasksFoundCount = 0 + } } } From 7d9e60dfdfcb19fa6b6dcc39bcb56ae55698ca52 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 16:13:58 +0300 Subject: [PATCH 0556/1647] client: fix beeper inbox method --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index 9c5ad194..3c91cb65 100644 --- a/client.go +++ b/client.go @@ -1934,7 +1934,7 @@ func (cli *Client) SetReadMarkers(ctx context.Context, roomID id.RoomID, content func (cli *Client) SetBeeperInboxState(ctx context.Context, roomID id.RoomID, content *ReqSetBeeperInboxState) (err error) { urlPath := cli.BuildClientURL("unstable", "com.beeper.inbox", "user", cli.UserID, "rooms", roomID, "inbox_state") - _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, content, nil) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, content, nil) return } From ddbf9098f491369fd6d2e4dfce8db8e8e73e593a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 16:21:15 +0300 Subject: [PATCH 0557/1647] bridgev2/backfillqueue: fix tasks being marked as done too soon --- bridgev2/backfillqueue.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go index af79fc98..345c3391 100644 --- a/bridgev2/backfillqueue.go +++ b/bridgev2/backfillqueue.go @@ -100,9 +100,15 @@ func (br *Bridge) doBackfillTask(ctx context.Context, task *database.BackfillTas time.Sleep(BackfillQueueErrorBackoff) return } else if completed { - log.Info().Msg("Backfill task completed successfully") + log.Info(). + Int("batch_count", task.BatchCount). + Bool("is_done", task.IsDone). + Msg("Backfill task completed successfully") } else { - log.Info().Msg("Backfill task canceled") + log.Info(). + Int("batch_count", task.BatchCount). + Bool("is_done", task.IsDone). + Msg("Backfill task canceled") } err = br.DB.BackfillTask.Update(ctx, task) if err != nil { @@ -181,7 +187,7 @@ func (br *Bridge) actuallyDoBackfillTask(ctx context.Context, task *database.Bac return false, fmt.Errorf("failed to backfill: %w", err) } task.BatchCount++ - task.IsDone = task.IsDone || task.BatchCount >= maxBatches + task.IsDone = task.IsDone || (maxBatches > 0 && task.BatchCount >= maxBatches) batchDelay := time.Duration(br.Config.Backfill.Queue.BatchDelay) * time.Second task.CompletedAt = time.Now() task.NextDispatchMinTS = task.CompletedAt.Add(batchDelay) From 865606c4407ff5601771d022226d4292ef70c914 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 16:57:50 +0300 Subject: [PATCH 0558/1647] crypto/attachment: return `io.ReadSeekCloser` from stream functions --- crypto/attachment/attachments.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crypto/attachment/attachments.go b/crypto/attachment/attachments.go index 8008cad2..344db4f0 100644 --- a/crypto/attachment/attachments.go +++ b/crypto/attachment/attachments.go @@ -197,7 +197,7 @@ func (r *encryptingReader) Close() (err error) { // The Close() method of the returned io.ReadCloser must be called for the SHA256 hash // in the EncryptedFile struct to be updated. The metadata is not valid before the hash // is filled. -func (ef *EncryptedFile) EncryptStream(reader io.Reader) io.ReadCloser { +func (ef *EncryptedFile) EncryptStream(reader io.Reader) io.ReadSeekCloser { ef.decodeKeys(false) block, _ := aes.NewCipher(ef.decoded.key[:]) return &encryptingReader{ @@ -252,7 +252,7 @@ func (ef *EncryptedFile) DecryptInPlace(data []byte) error { // // The Close call will validate the hash and return an error if it doesn't match. // In this case, the written data should be considered compromised and should not be used further. -func (ef *EncryptedFile) DecryptStream(reader io.Reader) io.ReadCloser { +func (ef *EncryptedFile) DecryptStream(reader io.Reader) io.ReadSeekCloser { block, _ := aes.NewCipher(ef.decoded.key[:]) return &encryptingReader{ stream: cipher.NewCTR(block, ef.decoded.iv[:]), From a1a245be10db073c3e59b9cb9028e83dd0a439f2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 18:19:52 +0300 Subject: [PATCH 0559/1647] bridgev2/matrix: temporarily disable mark unread bridging on Beeper --- bridgev2/matrix/intent.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index b97ea1e2..55648f0b 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -20,6 +20,7 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/crypto/attachment" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -135,12 +136,12 @@ func (as *ASIntent) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.E } if as.Matrix.IsCustomPuppet && as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureInboxState) { err = as.Matrix.SetBeeperInboxState(ctx, roomID, &mautrix.ReqSetBeeperInboxState{ - MarkedUnread: ptr.Ptr(false), - ReadMarkers: &req, + //MarkedUnread: ptr.Ptr(false), + ReadMarkers: &req, }) } else { err = as.Matrix.SetReadMarkers(ctx, roomID, &req) - if err == nil && as.Matrix.IsCustomPuppet { + if err == nil && as.Matrix.IsCustomPuppet && as.Connector.Config.Homeserver.Software != bridgeconfig.SoftwareHungry { err = as.Matrix.SetRoomAccountData(ctx, roomID, event.AccountDataMarkedUnread.Type, &event.MarkedUnreadEventContent{ Unread: false, }) @@ -150,6 +151,9 @@ func (as *ASIntent) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.E } func (as *ASIntent) MarkUnread(ctx context.Context, roomID id.RoomID, unread bool) error { + if as.Connector.Config.Homeserver.Software == bridgeconfig.SoftwareHungry { + return nil + } if as.Matrix.IsCustomPuppet && as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureInboxState) { return as.Matrix.SetBeeperInboxState(ctx, roomID, &mautrix.ReqSetBeeperInboxState{ MarkedUnread: ptr.Ptr(unread), From d644a962c8707107d27990fb2948fbcc46d652a5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 19:13:32 +0300 Subject: [PATCH 0560/1647] bridgev2/provisioning: add option to delete previous login when making new one --- bridgev2/matrix/provisioning.go | 74 +++++++++++++++++++++++---------- bridgev2/userlogin.go | 10 ++++- 2 files changed, 61 insertions(+), 23 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 4c87af0f..4489a3e3 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -55,6 +55,7 @@ type ProvLogin struct { ID string Process bridgev2.LoginProcess NextStep *bridgev2.LoginStep + Override *bridgev2.UserLogin Lock sync.Mutex } @@ -324,6 +325,10 @@ func (prov *ProvisioningAPI) GetLoginFlows(w http.ResponseWriter, r *http.Reques } func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Request) { + overrideLogin, failed := prov.GetExplicitLoginForRequest(w, r) + if failed { + return + } login, err := prov.net.CreateLogin( r.Context(), prov.GetUser(r), @@ -352,11 +357,26 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque ID: loginID, Process: login, NextStep: firstStep, + Override: overrideLogin, } prov.loginsLock.Unlock() jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep}) } +func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *ProvLogin, step *bridgev2.LoginStep) { + if login.Override == nil || login.Override.ID == step.CompleteParams.UserLoginID { + return + } + zerolog.Ctx(ctx).Info(). + Str("old_login_id", string(login.Override.ID)). + Str("new_login_id", string(step.CompleteParams.UserLoginID)). + Msg("Login resulted in different remote ID than what was being overridden. Deleting previous login") + login.Override.Delete(ctx, status.BridgeState{ + StateEvent: status.StateLoggedOut, + Reason: "LOGIN_OVERRIDDEN", + }, true) +} + func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http.Request) { var params map[string]string err := json.NewDecoder(r.Body).Decode(¶ms) @@ -387,6 +407,9 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http return } login.NextStep = nextStep + if nextStep.Type == bridgev2.LoginStepTypeComplete { + prov.handleCompleteStep(r.Context(), login, nextStep) + } jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } @@ -402,6 +425,9 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques return } login.NextStep = nextStep + if nextStep.Type == bridgev2.LoginStepTypeComplete { + prov.handleCompleteStep(r.Context(), login, nextStep) + } jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } @@ -439,30 +465,36 @@ func (prov *ProvisioningAPI) GetLogins(w http.ResponseWriter, r *http.Request) { jsonResponse(w, http.StatusOK, &RespGetLogins{LoginIDs: user.GetUserLoginIDs()}) } -func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.Request) *bridgev2.UserLogin { - user := prov.GetUser(r) +func (prov *ProvisioningAPI) GetExplicitLoginForRequest(w http.ResponseWriter, r *http.Request) (*bridgev2.UserLogin, bool) { userLoginID := networkid.UserLoginID(r.URL.Query().Get("login_id")) - if userLoginID != "" { - userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID) - if userLogin == nil || userLogin.UserMXID != user.MXID { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - Err: "Login not found", - ErrCode: mautrix.MNotFound.ErrCode, - }) - return nil - } - return userLogin - } else { - userLogin := user.GetDefaultLogin() - if userLogin == nil { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Not logged in", - ErrCode: "FI.MAU.NOT_LOGGED_IN", - }) - return nil - } + if userLoginID == "" { + return nil, false + } + userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID) + if userLogin == nil || userLogin.UserMXID != prov.GetUser(r).MXID { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + Err: "Login not found", + ErrCode: mautrix.MNotFound.ErrCode, + }) + return nil, true + } + return userLogin, false +} + +func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.Request) *bridgev2.UserLogin { + userLogin, failed := prov.GetExplicitLoginForRequest(w, r) + if userLogin != nil || failed { return userLogin } + userLogin = prov.GetUser(r).GetDefaultLogin() + if userLogin == nil { + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + Err: "Not logged in", + ErrCode: "FI.MAU.NOT_LOGGED_IN", + }) + return nil + } + return userLogin } type RespResolveIdentifier struct { diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 78b732fe..61036438 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -37,6 +37,7 @@ type UserLogin struct { inPortalCache *exsync.Set[networkid.PortalKey] spaceCreateLock sync.Mutex + deleteLock sync.Mutex } func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *database.UserLogin) (*UserLogin, error) { @@ -63,9 +64,9 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da userLogin.Log.Err(err).Msg("Failed to load user login") return nil, nil } + userLogin.BridgeState = br.NewBridgeStateQueue(userLogin) user.logins[userLogin.ID] = userLogin br.userLoginsByID[userLogin.ID] = userLogin - userLogin.BridgeState = br.NewBridgeStateQueue(userLogin) return userLogin, nil } @@ -186,7 +187,7 @@ func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params var doInsert bool if ul != nil && ul.UserMXID != user.MXID { if params.DeleteOnConflict { - ul.delete(ctx, status.BridgeState{StateEvent: status.StateLoggedOut, Error: "overridden-by-another-user"}, false, true) + ul.delete(ctx, status.BridgeState{StateEvent: status.StateLoggedOut, Reason: "LOGIN_OVERRIDDEN_ANOTHER_USER"}, false, true) ul = nil } else { return nil, fmt.Errorf("%s is already logged in with that account", ul.UserMXID) @@ -246,6 +247,11 @@ func (ul *UserLogin) Delete(ctx context.Context, state status.BridgeState, logou } func (ul *UserLogin) delete(ctx context.Context, state status.BridgeState, logoutRemote, unlocked bool) { + ul.deleteLock.Lock() + defer ul.deleteLock.Unlock() + if ul.BridgeState == nil { + return + } if logoutRemote { ul.Client.LogoutRemote(ctx) } else { From 69fbdaf5ac32dc3b2ced32d0ae0357683cfe7a80 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 20:35:36 +0300 Subject: [PATCH 0561/1647] bridgev2/ghost: add GetExistingGhostByID --- bridgev2/ghost.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index b066bdd1..32e8e0c9 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -83,6 +83,12 @@ func (br *Bridge) GetGhostByID(ctx context.Context, id networkid.UserID) (*Ghost return br.unlockedGetGhostByID(ctx, id, false) } +func (br *Bridge) GetExistingGhostByID(ctx context.Context, id networkid.UserID) (*Ghost, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return br.unlockedGetGhostByID(ctx, id, true) +} + type Avatar struct { ID networkid.AvatarID Get func(ctx context.Context) ([]byte, error) From 853273fabfd4e8ab0cacc0774676f73395081120 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 20:35:52 +0300 Subject: [PATCH 0562/1647] bridgev2/util: add function to clean non-international phone numbers --- bridgev2/login.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/bridgev2/login.go b/bridgev2/login.go index 2e2b1d84..32bf6e67 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -186,6 +186,14 @@ func isOnlyNumbers(input string) bool { return true } +func CleanNonInternationalPhoneNumber(phone string) (string, error) { + phone = numberCleaner.Replace(phone) + if !isOnlyNumbers(strings.TrimPrefix(phone, "+")) { + return "", fmt.Errorf("phone number must only contain numbers") + } + return phone, nil +} + func CleanPhoneNumber(phone string) (string, error) { phone = numberCleaner.Replace(phone) if len(phone) < 2 { From 07d30e77df104a898591d50c0ffa64021ca82ce8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 21:31:48 +0300 Subject: [PATCH 0563/1647] bridgev2/portal: don't queue event in CreateMatrixRoom if portal already exists --- bridgev2/portal.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 60e4bfd1..ab7f27e1 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2913,6 +2913,12 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us } func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, info *ChatInfo) (retErr error) { + if portal.MXID != "" { + if source != nil { + source.MarkInPortal(ctx, portal) + } + return nil + } waiter := make(chan struct{}) closed := false portal.events <- &portalCreateEvent{ From 7402f5a705016d96a1f65bc83020ccf329e9a923 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 Aug 2024 22:13:14 +0300 Subject: [PATCH 0564/1647] bridgev2/matrix: disable using inbox state endpoint --- bridgev2/matrix/intent.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 55648f0b..a4fa1b14 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -134,7 +134,7 @@ func (as *ASIntent) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.E req.FullyRead = eventID req.BeeperFullyReadExtra = extraData } - if as.Matrix.IsCustomPuppet && as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureInboxState) { + if as.Matrix.IsCustomPuppet && as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureInboxState) && as.Connector.Config.Homeserver.Software != bridgeconfig.SoftwareHungry { err = as.Matrix.SetBeeperInboxState(ctx, roomID, &mautrix.ReqSetBeeperInboxState{ //MarkedUnread: ptr.Ptr(false), ReadMarkers: &req, From 0a17ac1cbef3bf7e2840802aca1d6a6a72f79f13 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 1 Aug 2024 18:45:43 -0600 Subject: [PATCH 0565/1647] crypto/ssss: remove id from key metadata Instead, we will pass it into the key constructor functions directly. This avoids the footgun where you don't set the key ID on the metadata and then the ID is not properly propagated to the Key that is returned. Signed-off-by: Sumner Evans --- crypto/ssss/client.go | 2 +- crypto/ssss/meta.go | 12 +++++------- crypto/ssss/meta_test.go | 18 +++++++++--------- hicli/verify.go | 4 ++-- 4 files changed, 17 insertions(+), 19 deletions(-) diff --git a/crypto/ssss/client.go b/crypto/ssss/client.go index 0cfdd24f..e30925d9 100644 --- a/crypto/ssss/client.go +++ b/crypto/ssss/client.go @@ -53,7 +53,7 @@ func (mach *Machine) SetDefaultKeyID(ctx context.Context, keyID string) error { // GetKeyData gets the details about the given key ID. func (mach *Machine) GetKeyData(ctx context.Context, keyID string) (keyData *KeyMetadata, err error) { - keyData = &KeyMetadata{id: keyID} + keyData = &KeyMetadata{} err = mach.Client.GetAccountData(ctx, fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData) return } diff --git a/crypto/ssss/meta.go b/crypto/ssss/meta.go index e752cf0c..210bcdcf 100644 --- a/crypto/ssss/meta.go +++ b/crypto/ssss/meta.go @@ -17,8 +17,6 @@ import ( // KeyMetadata represents server-side metadata about a SSSS key. The metadata can be used to get // the actual SSSS key from a passphrase or recovery key. type KeyMetadata struct { - id string - Name string `json:"name"` Algorithm Algorithm `json:"algorithm"` @@ -31,7 +29,7 @@ type KeyMetadata struct { } // VerifyRecoveryKey verifies that the given passphrase is valid and returns the computed SSSS key. -func (kd *KeyMetadata) VerifyPassphrase(passphrase string) (*Key, error) { +func (kd *KeyMetadata) VerifyPassphrase(keyID, passphrase string) (*Key, error) { ssssKey, err := kd.Passphrase.GetKey(passphrase) if err != nil { return nil, err @@ -40,15 +38,15 @@ func (kd *KeyMetadata) VerifyPassphrase(passphrase string) (*Key, error) { } return &Key{ - ID: kd.id, + ID: keyID, Key: ssssKey, Metadata: kd, }, nil } // VerifyRecoveryKey verifies that the given recovery key is valid and returns the decoded SSSS key. -func (kd *KeyMetadata) VerifyRecoveryKey(recoverKey string) (*Key, error) { - ssssKey := utils.DecodeBase58RecoveryKey(recoverKey) +func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error) { + ssssKey := utils.DecodeBase58RecoveryKey(recoveryKey) if ssssKey == nil { return nil, ErrInvalidRecoveryKey } else if !kd.VerifyKey(ssssKey) { @@ -56,7 +54,7 @@ func (kd *KeyMetadata) VerifyRecoveryKey(recoverKey string) (*Key, error) { } return &Key{ - ID: kd.id, + ID: keyID, Key: ssssKey, Metadata: kd, }, nil diff --git a/crypto/ssss/meta_test.go b/crypto/ssss/meta_test.go index 2ad8f62a..96c97282 100644 --- a/crypto/ssss/meta_test.go +++ b/crypto/ssss/meta_test.go @@ -55,7 +55,7 @@ func getKey1Meta() *ssss.KeyMetadata { func getKey1() *ssss.Key { km := getKey1Meta() - key, err := km.VerifyRecoveryKey(key1RecoveryKey) + key, err := km.VerifyRecoveryKey(key1ID, key1RecoveryKey) if err != nil { panic(err) } @@ -74,7 +74,7 @@ func getKey2Meta() *ssss.KeyMetadata { func getKey2() *ssss.Key { km := getKey2Meta() - key, err := km.VerifyRecoveryKey(key2RecoveryKey) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) if err != nil { panic(err) } @@ -84,7 +84,7 @@ func getKey2() *ssss.Key { func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) { km := getKey1Meta() - key, err := km.VerifyRecoveryKey(key1RecoveryKey) + key, err := km.VerifyRecoveryKey(key1ID, key1RecoveryKey) assert.NoError(t, err) assert.NotNil(t, key) assert.Equal(t, key1RecoveryKey, key.RecoveryKey()) @@ -92,7 +92,7 @@ func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) { func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) { km := getKey2Meta() - key, err := km.VerifyRecoveryKey(key2RecoveryKey) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) assert.NoError(t, err) assert.NotNil(t, key) assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) @@ -100,21 +100,21 @@ func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) { func TestKeyMetadata_VerifyRecoveryKey_Invalid(t *testing.T) { km := getKey1Meta() - key, err := km.VerifyRecoveryKey("foo") + key, err := km.VerifyRecoveryKey(key1ID, "foo") assert.True(t, errors.Is(err, ssss.ErrInvalidRecoveryKey), "unexpected error: %v", err) assert.Nil(t, key) } func TestKeyMetadata_VerifyRecoveryKey_Incorrect(t *testing.T) { km := getKey1Meta() - key, err := km.VerifyRecoveryKey(key2RecoveryKey) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error: %v", err) assert.Nil(t, key) } func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) { km := getKey1Meta() - key, err := km.VerifyPassphrase(key1Passphrase) + key, err := km.VerifyPassphrase(key1ID, key1Passphrase) assert.NoError(t, err) assert.NotNil(t, key) assert.Equal(t, key1RecoveryKey, key.RecoveryKey()) @@ -122,14 +122,14 @@ func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) { func TestKeyMetadata_VerifyPassphrase_Incorrect(t *testing.T) { km := getKey1Meta() - key, err := km.VerifyPassphrase("incorrect horse battery staple") + key, err := km.VerifyPassphrase(key1ID, "incorrect horse battery staple") assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error %v", err) assert.Nil(t, key) } func TestKeyMetadata_VerifyPassphrase_NotSet(t *testing.T) { km := getKey2Meta() - key, err := km.VerifyPassphrase("hmm") + key, err := km.VerifyPassphrase(key2ID, "hmm") assert.True(t, errors.Is(err, ssss.ErrNoPassphrase), "unexpected error %v", err) assert.Nil(t, key) } diff --git a/hicli/verify.go b/hicli/verify.go index 2062519a..905be052 100644 --- a/hicli/verify.go +++ b/hicli/verify.go @@ -125,11 +125,11 @@ func (h *HiClient) storeCrossSigningPrivateKeys(ctx context.Context) error { } func (h *HiClient) VerifyWithRecoveryCode(ctx context.Context, code string) error { - _, keyData, err := h.Crypto.SSSS.GetDefaultKeyData(ctx) + keyID, keyData, err := h.Crypto.SSSS.GetDefaultKeyData(ctx) if err != nil { return fmt.Errorf("failed to get default SSSS key data: %w", err) } - key, err := keyData.VerifyRecoveryKey(code) + key, err := keyData.VerifyRecoveryKey(keyID, code) if err != nil { return err } From 83d3a0de5b9be8be9c7cc0147e06fff681492967 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 Aug 2024 21:33:29 +0300 Subject: [PATCH 0566/1647] bridgev2: add disambiguation for relayed user displaynames --- bridgev2/matrix/connector.go | 4 ++ bridgev2/matrix/mxmain/example-config.yaml | 26 ++++++--- bridgev2/matrixinterface.go | 4 ++ bridgev2/networkinterface.go | 7 ++- bridgev2/portal.go | 32 ++++++++++- go.mod | 7 ++- go.sum | 14 +++-- hicli/database/statestore.go | 5 ++ sqlstatestore/statestore.go | 29 +++++++++- sqlstatestore/v00-latest-revision.sql | 8 ++- .../v06-displayname-disambiguation.go | 55 +++++++++++++++++++ statestore.go | 6 ++ 12 files changed, 174 insertions(+), 23 deletions(-) create mode 100644 sqlstatestore/v06-displayname-disambiguation.go diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 282f1d3b..d167af9c 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -533,6 +533,10 @@ func (br *Connector) GetMemberInfo(ctx context.Context, roomID id.RoomID, userID return br.AS.StateStore.GetMember(ctx, roomID, userID) } +func (br *Connector) IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error) { + return br.AS.StateStore.IsConfusableName(ctx, roomID, userID, name) +} + func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBeeperBatchSend, extras []*bridgev2.MatrixSendExtra) (*mautrix.RespBeeperBatchSend, error) { if encrypted, err := br.StateStore.IsEncrypted(ctx, roomID); err != nil { return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 75d7880a..4da3ca50 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -45,15 +45,25 @@ bridge: # List of user login IDs which anyone can set as a relay, as long as the relay user is in the room. default_relays: [] # The formats to use when sending messages via the relaybot. + # Available variables: + # .Sender.UserID - The Matrix user ID of the sender. + # .Sender.Displayname - The display name of the sender (if set). + # .Sender.RequiresDisambiguation - Whether the sender's name may be confused with the name of another user in the room. + # .Sender.DisambiguatedName - The disambiguated name of the sender. This will be the displayname if set, + # plus the user ID in parentheses if the displayname is not unique. + # If the displayname is not set, this is just the user ID. + # .Message - The `formatted_body` field of the message. + # .Caption - The `formatted_body` field of the message, if it's a caption. Otherwise an empty string. + # .FileName - The name of the file being sent. message_formats: - m.text: "{{ .Sender.Displayname }}: {{ .Message }}" - m.notice: "{{ .Sender.Displayname }}: {{ .Message }}" - m.emote: "* {{ .Sender.Displayname }} {{ .Message }}" - m.file: "{{ .Sender.Displayname }} sent a file{{ if .Caption }}: {{ .Caption }}{{ end }}" - m.image: "{{ .Sender.Displayname }} sent an image{{ if .Caption }}: {{ .Caption }}{{ end }}" - m.audio: "{{ .Sender.Displayname }} sent an audio file{{ if .Caption }}: {{ .Caption }}{{ end }}" - m.video: "{{ .Sender.Displayname }} sent a video{{ if .Caption }}: {{ .Caption }}{{ end }}" - m.location: "{{ .Sender.Displayname }} sent a location{{ if .Caption }}: {{ .Caption }}{{ end }}" + m.text: "{{ .Sender.DisambiguatedName }}: {{ .Message }}" + m.notice: "{{ .Sender.DisambiguatedName }}: {{ .Message }}" + m.emote: "* {{ .Sender.DisambiguatedName }} {{ .Message }}" + m.file: "{{ .Sender.DisambiguatedName }} sent a file{{ if .Caption }}: {{ .Caption }}{{ end }}" + m.image: "{{ .Sender.DisambiguatedName }} sent an image{{ if .Caption }}: {{ .Caption }}{{ end }}" + m.audio: "{{ .Sender.DisambiguatedName }} sent an audio file{{ if .Caption }}: {{ .Caption }}{{ end }}" + m.video: "{{ .Sender.DisambiguatedName }} sent a video{{ if .Caption }}: {{ .Caption }}{{ end }}" + m.location: "{{ .Sender.DisambiguatedName }} sent a location{{ if .Caption }}: {{ .Caption }}{{ end }}" # Permissions for using the bridge. # Permitted values: diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index b6773810..51fee503 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -58,6 +58,10 @@ type MatrixConnectorWithServer interface { GetRouter() *mux.Router } +type MatrixConnectorWithNameDisambiguation interface { + IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error) +} + type MatrixSendExtra struct { Timestamp time.Time MessageMeta *database.Message diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index b840b019..d73d6c4d 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -897,7 +897,12 @@ type RemoteTypingWithType interface { } type OrigSender struct { - User *User + User *User + UserID id.UserID + + RequiresDisambiguation bool + DisambiguatedName string + event.MemberEventContent } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index ab7f27e1..82a6db46 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -373,6 +373,26 @@ func (portal *Portal) sendErrorStatus(ctx context.Context, evt *event.Event, err portal.Bridge.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) } +func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID, name string) bool { + conn, ok := portal.Bridge.Matrix.(MatrixConnectorWithNameDisambiguation) + if !ok { + return false + } + confusableWith, err := conn.IsConfusableName(ctx, portal.MXID, userID, name) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to check if name is confusable") + return true + } + for _, confusable := range confusableWith { + // Don't disambiguate names that only conflict with ghosts of this bridge + _, isGhost := portal.Bridge.Matrix.ParseGhostMXID(confusable) + if !isGhost { + return true + } + } + return false +} + func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { log := portal.Log.With(). Str("action", "handle matrix event"). @@ -423,7 +443,8 @@ func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { if login == nil { login = portal.Relay origSender = &OrigSender{ - User: sender, + User: sender, + UserID: sender.MXID, } memberInfo, err := portal.Bridge.Matrix.GetMemberInfo(ctx, portal.MXID, sender.MXID) @@ -431,6 +452,15 @@ func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { log.Warn().Err(err).Msg("Failed to get member info for user being relayed") } else if memberInfo != nil { origSender.MemberEventContent = *memberInfo + if memberInfo.Displayname == "" { + origSender.DisambiguatedName = sender.MXID.String() + } else if origSender.RequiresDisambiguation = portal.checkConfusableName(ctx, sender.MXID, memberInfo.Displayname); origSender.RequiresDisambiguation { + origSender.DisambiguatedName = fmt.Sprintf("%s (%s)", memberInfo.Displayname, sender.MXID) + } else { + origSender.DisambiguatedName = memberInfo.Displayname + } + } else { + origSender.DisambiguatedName = sender.MXID.String() } } log.UpdateContext(func(c zerolog.Context) zerolog.Context { diff --git a/go.mod b/go.mod index 003cbc6a..e3d6cfe3 100644 --- a/go.mod +++ b/go.mod @@ -12,13 +12,13 @@ require ( github.com/rs/zerolog v1.33.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stretchr/testify v1.9.0 - github.com/tidwall/gjson v1.17.1 + github.com/tidwall/gjson v1.17.3 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.6.1-0.20240719175439-20a6073e1dd4 + go.mau.fi/util v0.6.1-0.20240802175451-b430ebbffc98 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.25.0 - golang.org/x/exp v0.0.0-20240707233637-46b078467d37 + golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 golang.org/x/net v0.27.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 @@ -33,5 +33,6 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect golang.org/x/sys v0.22.0 // indirect + golang.org/x/text v0.16.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 6d58567d..ed4ca496 100644 --- a/go.sum +++ b/go.sum @@ -36,8 +36,8 @@ github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDq github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= -github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= +github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= @@ -46,14 +46,14 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.6.1-0.20240719175439-20a6073e1dd4 h1:CYKYs5jwJ0bFJqh6pRoWtC9NIJ0lz0/6i2SC4qEBFaU= -go.mau.fi/util v0.6.1-0.20240719175439-20a6073e1dd4/go.mod h1:ljYdq3sPfpICc3zMU+/mHV/sa4z0nKxc67hSBwnrk8U= +go.mau.fi/util v0.6.1-0.20240802175451-b430ebbffc98 h1:gJ0peWecBm6TtlxKFVIc1KbooXSCHtPfsfb2Eha5A0A= +go.mau.fi/util v0.6.1-0.20240802175451-b430ebbffc98/go.mod h1:S1juuPWGau2GctPY3FR/4ec/MDLhAG2QPhdnUwpzWIo= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= -golang.org/x/exp v0.0.0-20240707233637-46b078467d37 h1:uLDX+AfeFCct3a2C7uIWBKMJIR3CJMhcgfrUAqjRK6w= -golang.org/x/exp v0.0.0-20240707233637-46b078467d37/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= +golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= +golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -62,6 +62,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= +golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/hicli/database/statestore.go b/hicli/database/statestore.go index baf84df1..e8050e93 100644 --- a/hicli/database/statestore.go +++ b/hicli/database/statestore.go @@ -92,6 +92,11 @@ func (c *ClientStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, u return } +func (c *ClientStateStore) IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) { + //TODO implement me + panic("implement me") +} + func (c *ClientStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (content *event.PowerLevelsEventContent, err error) { err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StatePowerLevels.Type, "").Scan(&dbutil.JSON{Data: &content}) if errors.Is(err, sql.ErrNoRows) { diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index cd94215d..0e5c4184 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -17,6 +17,7 @@ import ( "strings" "github.com/rs/zerolog" + "go.mau.fi/util/confusable" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/event" @@ -37,6 +38,8 @@ const VersionTableName = "mx_version" type SQLStateStore struct { *dbutil.Database IsBridge bool + + DisableNameDisambiguation bool } func NewSQLStateStore(db *dbutil.Database, log dbutil.DatabaseLogger, isBridge bool) *SQLStateStore { @@ -65,6 +68,7 @@ func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID type Member struct { id.UserID event.MemberEventContent + NameSkeleton [32]byte } func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) (map[id.UserID]*event.MemberEventContent, error) { @@ -191,13 +195,32 @@ func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, } func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error { + var nameSkeleton []byte + if !store.DisableNameDisambiguation && len(member.Displayname) > 0 { + nameSkeletonArr := confusable.SkeletonHash(member.Displayname) + nameSkeleton = nameSkeletonArr[:] + } _, err := store.Exec(ctx, ` - INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url - `, roomID, userID, member.Membership, member.Displayname, member.AvatarURL) + INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url, name_skeleton) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (room_id, user_id) DO UPDATE + SET membership=excluded.membership, + displayname=excluded.displayname, + avatar_url=excluded.avatar_url, + name_skeleton=excluded.name_skeleton + `, roomID, userID, member.Membership, member.Displayname, member.AvatarURL, nameSkeleton) return err } +func (store *SQLStateStore) IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) { + if store.DisableNameDisambiguation { + return nil, nil + } + skeleton := confusable.SkeletonHash(name) + rows, err := store.Query(ctx, "SELECT user_id FROM mx_user_profile WHERE room_id=$1 AND name_skeleton=$2 AND user_id<>$3", roomID, skeleton[:], currentUser) + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() +} + func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error { query := "DELETE FROM mx_user_profile WHERE room_id=$1" params := make([]any, len(memberships)+1) diff --git a/sqlstatestore/v00-latest-revision.sql b/sqlstatestore/v00-latest-revision.sql index 41c2b9a1..b2bb2ae6 100644 --- a/sqlstatestore/v00-latest-revision.sql +++ b/sqlstatestore/v00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v5: Latest revision +-- v0 -> v6 (compatible with v3+): Latest revision CREATE TABLE mx_registrations ( user_id TEXT PRIMARY KEY @@ -13,9 +13,15 @@ CREATE TABLE mx_user_profile ( membership membership NOT NULL, displayname TEXT NOT NULL DEFAULT '', avatar_url TEXT NOT NULL DEFAULT '', + + name_skeleton bytea, + PRIMARY KEY (room_id, user_id) ); +CREATE INDEX mx_user_profile_membership_idx ON mx_user_profile (room_id, membership); +CREATE INDEX mx_user_profile_name_skeleton_idx ON mx_user_profile (room_id, name_skeleton); + CREATE TABLE mx_room_state ( room_id TEXT PRIMARY KEY, power_levels jsonb, diff --git a/sqlstatestore/v06-displayname-disambiguation.go b/sqlstatestore/v06-displayname-disambiguation.go new file mode 100644 index 00000000..d0d1d502 --- /dev/null +++ b/sqlstatestore/v06-displayname-disambiguation.go @@ -0,0 +1,55 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package sqlstatestore + +import ( + "context" + + "go.mau.fi/util/confusable" + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/id" +) + +type roomUserName struct { + RoomID id.RoomID + UserID id.UserID + Name string +} + +func init() { + UpgradeTable.Register(-1, 6, 3, "Add disambiguation column for user profiles", dbutil.TxnModeOn, func(ctx context.Context, db *dbutil.Database) error { + _, err := db.Exec(ctx, ` + ALTER TABLE mx_user_profile ADD COLUMN name_skeleton bytea; + CREATE INDEX mx_user_profile_membership_idx ON mx_user_profile (room_id, membership); + CREATE INDEX mx_user_profile_name_skeleton_idx ON mx_user_profile (room_id, name_skeleton); + `) + if err != nil { + return err + } + const ChunkSize = 1000 + const GetEntriesChunkQuery = "SELECT room_id, user_id, displayname FROM mx_user_profile WHERE displayname<>'' LIMIT $1 OFFSET $2" + const SetSkeletonHashQuery = `UPDATE mx_user_profile SET name_skeleton = $3 WHERE room_id = $1 AND user_id = $2` + for offset := 0; ; offset += ChunkSize { + entries, err := dbutil.NewSimpleReflectRowIter[roomUserName](db.Query(ctx, GetEntriesChunkQuery, ChunkSize, offset)).AsList() + if err != nil { + return err + } + for _, entry := range entries { + skel := confusable.SkeletonHash(entry.Name) + _, err = db.Exec(ctx, SetSkeletonHashQuery, entry.RoomID, entry.UserID, skel[:]) + if err != nil { + return err + } + } + if len(entries) < ChunkSize { + break + } + } + return nil + }) +} diff --git a/statestore.go b/statestore.go index 7b4e2d2d..fd8f81e5 100644 --- a/statestore.go +++ b/statestore.go @@ -26,6 +26,7 @@ type StateStore interface { TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error + IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error @@ -151,6 +152,11 @@ func (store *MemoryStateStore) GetMember(ctx context.Context, roomID id.RoomID, return member, err } +func (store *MemoryStateStore) IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) { + // TODO implement? + return nil, nil +} + func (store *MemoryStateStore) TryGetMember(_ context.Context, roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, err error) { store.membersLock.RLock() defer store.membersLock.RUnlock() From ea3cd96e253f7f84bb5cc9433005eff13ab501b3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 Aug 2024 23:53:34 +0300 Subject: [PATCH 0567/1647] format/mdext: add single-character bold, italic and strikethrough parsers --- format/mdext/shortemphasis.go | 96 +++++++++++++++++++++++++++++++++++ format/mdext/shortstrike.go | 76 +++++++++++++++++++++++++++ 2 files changed, 172 insertions(+) create mode 100644 format/mdext/shortemphasis.go create mode 100644 format/mdext/shortstrike.go diff --git a/format/mdext/shortemphasis.go b/format/mdext/shortemphasis.go new file mode 100644 index 00000000..62190326 --- /dev/null +++ b/format/mdext/shortemphasis.go @@ -0,0 +1,96 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mdext + +import ( + "github.com/yuin/goldmark" + "github.com/yuin/goldmark/ast" + "github.com/yuin/goldmark/parser" + "github.com/yuin/goldmark/text" + "github.com/yuin/goldmark/util" +) + +var ShortEmphasis goldmark.Extender = &shortEmphasisExtender{} + +type shortEmphasisExtender struct{} + +func (s *shortEmphasisExtender) Extend(m goldmark.Markdown) { + m.Parser().AddOptions(parser.WithInlineParsers( + util.Prioritized(&italicsParser{}, 500), + util.Prioritized(&boldParser{}, 500), + )) +} + +type italicsDelimiterProcessor struct{} + +func (p *italicsDelimiterProcessor) IsDelimiter(b byte) bool { + return b == '_' +} + +func (p *italicsDelimiterProcessor) CanOpenCloser(opener, closer *parser.Delimiter) bool { + return opener.Char == closer.Char +} + +func (p *italicsDelimiterProcessor) OnMatch(consumes int) ast.Node { + return ast.NewEmphasis(1) +} + +var defaultItalicsDelimiterProcessor = &italicsDelimiterProcessor{} + +type italicsParser struct{} + +func (s *italicsParser) Trigger() []byte { + return []byte{'_'} +} + +func (s *italicsParser) Parse(parent ast.Node, block text.Reader, pc parser.Context) ast.Node { + before := block.PrecendingCharacter() + line, segment := block.PeekLine() + node := parser.ScanDelimiter(line, before, 1, defaultItalicsDelimiterProcessor) + if node == nil || node.OriginalLength > 1 || before == '_' { + return nil + } + node.Segment = segment.WithStop(segment.Start + node.OriginalLength) + block.Advance(node.OriginalLength) + pc.PushDelimiter(node) + return node +} + +type boldDelimiterProcessor struct{} + +func (p *boldDelimiterProcessor) IsDelimiter(b byte) bool { + return b == '*' +} + +func (p *boldDelimiterProcessor) CanOpenCloser(opener, closer *parser.Delimiter) bool { + return opener.Char == closer.Char +} + +func (p *boldDelimiterProcessor) OnMatch(consumes int) ast.Node { + return ast.NewEmphasis(2) +} + +var defaultBoldDelimiterProcessor = &boldDelimiterProcessor{} + +type boldParser struct{} + +func (s *boldParser) Trigger() []byte { + return []byte{'*'} +} + +func (s *boldParser) Parse(parent ast.Node, block text.Reader, pc parser.Context) ast.Node { + before := block.PrecendingCharacter() + line, segment := block.PeekLine() + node := parser.ScanDelimiter(line, before, 1, defaultBoldDelimiterProcessor) + if node == nil || node.OriginalLength > 1 || before == '*' { + return nil + } + node.Segment = segment.WithStop(segment.Start + node.OriginalLength) + block.Advance(node.OriginalLength) + pc.PushDelimiter(node) + return node +} diff --git a/format/mdext/shortstrike.go b/format/mdext/shortstrike.go new file mode 100644 index 00000000..00328f22 --- /dev/null +++ b/format/mdext/shortstrike.go @@ -0,0 +1,76 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mdext + +import ( + "github.com/yuin/goldmark" + gast "github.com/yuin/goldmark/ast" + "github.com/yuin/goldmark/extension" + "github.com/yuin/goldmark/extension/ast" + "github.com/yuin/goldmark/parser" + "github.com/yuin/goldmark/renderer" + "github.com/yuin/goldmark/text" + "github.com/yuin/goldmark/util" +) + +var ShortStrike goldmark.Extender = &shortStrikeExtender{length: 1} +var LongStrike goldmark.Extender = &shortStrikeExtender{length: 2} + +type shortStrikeExtender struct { + length int +} + +func (s *shortStrikeExtender) Extend(m goldmark.Markdown) { + m.Parser().AddOptions(parser.WithInlineParsers( + util.Prioritized(&strikethroughParser{length: s.length}, 500), + )) + m.Renderer().AddOptions(renderer.WithNodeRenderers( + util.Prioritized(extension.NewStrikethroughHTMLRenderer(), 500), + )) +} + +type strikethroughDelimiterProcessor struct{} + +func (p *strikethroughDelimiterProcessor) IsDelimiter(b byte) bool { + return b == '~' +} + +func (p *strikethroughDelimiterProcessor) CanOpenCloser(opener, closer *parser.Delimiter) bool { + return opener.Char == closer.Char +} + +func (p *strikethroughDelimiterProcessor) OnMatch(consumes int) gast.Node { + return ast.NewStrikethrough() +} + +var defaultStrikethroughDelimiterProcessor = &strikethroughDelimiterProcessor{} + +type strikethroughParser struct { + length int +} + +func (s *strikethroughParser) Trigger() []byte { + return []byte{'~'} +} + +func (s *strikethroughParser) Parse(parent gast.Node, block text.Reader, pc parser.Context) gast.Node { + before := block.PrecendingCharacter() + line, segment := block.PeekLine() + node := parser.ScanDelimiter(line, before, 1, defaultStrikethroughDelimiterProcessor) + if node == nil || node.OriginalLength != s.length || before == '~' { + return nil + } + + node.Segment = segment.WithStop(segment.Start + node.OriginalLength) + block.Advance(node.OriginalLength) + pc.PushDelimiter(node) + return node +} + +func (s *strikethroughParser) CloseBlock(parent gast.Node, pc parser.Context) { + // nothing to do +} From c6bc42f16cc44f759b7b4e52e903ea507a52e7dc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 3 Aug 2024 18:00:53 +0300 Subject: [PATCH 0568/1647] bridgev2/matrix: add support for generating public media URLs --- bridgev2/bridgeconfig/config.go | 8 ++ bridgev2/bridgeconfig/upgrade.go | 9 ++ bridgev2/matrix/connector.go | 8 +- bridgev2/matrix/mxmain/example-config.yaml | 15 +++ bridgev2/matrix/publicmedia.go | 128 +++++++++++++++++++++ bridgev2/matrixinterface.go | 4 + id/contenturi.go | 19 +++ 7 files changed, 190 insertions(+), 1 deletion(-) create mode 100644 bridgev2/matrix/publicmedia.go diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 861805c6..74104dec 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -22,6 +22,7 @@ type Config struct { AppService AppserviceConfig `yaml:"appservice"` Matrix MatrixConfig `yaml:"matrix"` Provisioning ProvisioningConfig `yaml:"provisioning"` + PublicMedia PublicMediaConfig `yaml:"public_media"` DirectMedia DirectMediaConfig `yaml:"direct_media"` Backfill BackfillConfig `yaml:"backfill"` DoublePuppet DoublePuppetConfig `yaml:"double_puppet"` @@ -84,6 +85,13 @@ type DirectMediaConfig struct { mediaproxy.BasicConfig `yaml:",inline"` } +type PublicMediaConfig struct { + Enabled bool `yaml:"enabled"` + SigningKey string `yaml:"signing_key"` + HashLength int `yaml:"hash_length"` + Expiry int `yaml:"expiry"` +} + type DoublePuppetConfig struct { Servers map[string]string `yaml:"servers"` AllowDiscovery bool `yaml:"allow_discovery"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 57040607..c96a01c7 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -93,6 +93,15 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Str, "direct_media", "server_key") } + helper.Copy(up.Bool, "public_media", "enabled") + if signingKey, ok := helper.Get(up.Str, "public_media", "signing_key"); !ok || signingKey == "generate" { + helper.Set(up.Str, random.String(32), "public_media", "signing_key") + } else { + helper.Copy(up.Str, "public_media", "signing_key") + } + helper.Copy(up.Int, "public_media", "expiry") + helper.Copy(up.Int, "public_media", "hash_length") + helper.Copy(up.Bool, "backfill", "enabled") helper.Copy(up.Int, "backfill", "max_initial_messages") helper.Copy(up.Int, "backfill", "max_catchup_messages") diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index d167af9c..9ea1cd3a 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -69,7 +69,9 @@ type Connector struct { Provisioning *ProvisioningAPI DoublePuppet *doublePuppetUtil MediaProxy *mediaproxy.MediaProxy - dmaSigKey [32]byte + + dmaSigKey [32]byte + pubMediaSigKey []byte doublePuppetIntents *exsync.Map[id.UserID, *appservice.IntentAPI] @@ -152,6 +154,10 @@ func (br *Connector) Start(ctx context.Context) error { if err != nil { return err } + err = br.initPublicMedia() + if err != nil { + return err + } err = br.StateStore.Upgrade(ctx) if err != nil { return bridgev2.DBUpgradeError{Section: "matrix_state", Err: err} diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 4da3ca50..935b8e82 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -199,6 +199,21 @@ provisioning: # Enable debug API at /debug with provisioning authentication. debug_endpoints: false +# Some networks require publicly accessible media download links (e.g. for user avatars when using Discord webhooks). +# These settings control whether the bridge will provide such public media access. +public_media: + # Should public media be enabled at all? + # The public_address field under the appservice section MUST be set when enabling public media. + enabled: false + # A key for signing public media URLs. + # If set to "generate", a random key will be generated. + signing_key: generate + # Number of seconds that public media URLs are valid for. + # If set to 0, URLs will never expire. + expiry: 0 + # Length of hash to use for public media URLs. + hash_length: 32 + # Settings for converting remote media to custom mxc:// URIs instead of reuploading. # More details can be found at https://docs.mau.fi/bridges/go/discord/direct-media.html direct_media: diff --git a/bridgev2/matrix/publicmedia.go b/bridgev2/matrix/publicmedia.go new file mode 100644 index 00000000..9db5f442 --- /dev/null +++ b/bridgev2/matrix/publicmedia.go @@ -0,0 +1,128 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package matrix + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "net/http" + "time" + + "github.com/gorilla/mux" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/id" +) + +var _ bridgev2.MatrixConnectorWithPublicMedia = (*Connector)(nil) + +func (br *Connector) initPublicMedia() error { + if !br.Config.PublicMedia.Enabled { + return nil + } else if br.GetPublicAddress() == "" { + return fmt.Errorf("public media is enabled in config, but no public address is set") + } else if br.Config.PublicMedia.HashLength > 32 { + return fmt.Errorf("public media hash length is too long") + } else if br.Config.PublicMedia.HashLength < 0 { + return fmt.Errorf("public media hash length is negative") + } + br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey) + br.AS.Router.HandleFunc("/_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia).Methods(http.MethodGet) + return nil +} + +func (br *Connector) hashContentURI(uri id.ContentURI, expiry []byte) []byte { + hasher := hmac.New(sha256.New, br.pubMediaSigKey) + hasher.Write([]byte(uri.String())) + hasher.Write(expiry) + return hasher.Sum(expiry)[:br.Config.PublicMedia.HashLength+len(expiry)] +} + +func (br *Connector) makePublicMediaChecksum(uri id.ContentURI) []byte { + var expiresAt []byte + if br.Config.PublicMedia.Expiry > 0 { + expiresAtInt := time.Now().Add(time.Duration(br.Config.PublicMedia.Expiry) * time.Second).Unix() + expiresAt = binary.BigEndian.AppendUint64(nil, uint64(expiresAtInt)) + } + return br.hashContentURI(uri, expiresAt) +} + +func (br *Connector) verifyPublicMediaChecksum(uri id.ContentURI, checksum []byte) (valid, expired bool) { + var expiryBytes []byte + if br.Config.PublicMedia.Expiry > 0 { + if len(checksum) < 8 { + return + } + expiryBytes = checksum[:8] + expiresAtInt := binary.BigEndian.Uint64(expiryBytes) + expired = time.Now().Unix() > int64(expiresAtInt) + } + valid = hmac.Equal(checksum, br.hashContentURI(uri, expiryBytes)) + return +} + +var proxyHeadersToCopy = []string{ + "Content-Type", "Content-Disposition", "Content-Length", "Content-Security-Policy", + "Access-Control-Allow-Origin", "Access-Control-Allow-Methods", "Access-Control-Allow-Headers", + "Cache-Control", "Cross-Origin-Resource-Policy", +} + +func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + contentURI := id.ContentURI{ + Homeserver: vars["server"], + FileID: vars["mediaID"], + } + if !contentURI.IsValid() { + http.Error(w, "invalid content URI", http.StatusBadRequest) + return + } + checksum, err := base64.RawURLEncoding.DecodeString(vars["checksum"]) + if err != nil || !hmac.Equal(checksum, br.makePublicMediaChecksum(contentURI)) { + http.Error(w, "invalid base64 in checksum", http.StatusBadRequest) + return + } else if valid, expired := br.verifyPublicMediaChecksum(contentURI, checksum); !valid { + http.Error(w, "invalid checksum", http.StatusNotFound) + return + } else if expired { + http.Error(w, "checksum expired", http.StatusGone) + return + } + resp, err := br.Bot.Download(r.Context(), contentURI) + if err != nil { + br.Log.Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy") + http.Error(w, "failed to download media", http.StatusInternalServerError) + return + } + defer resp.Body.Close() + for _, hdr := range proxyHeadersToCopy { + w.Header()[hdr] = resp.Header[hdr] + } + w.WriteHeader(http.StatusOK) + _, _ = io.Copy(w, resp.Body) +} + +func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) string { + if br.pubMediaSigKey == nil { + return "" + } + parsed, err := contentURI.Parse() + if err != nil || !parsed.IsValid() { + return "" + } + return fmt.Sprintf( + "%s/_mautrix/publicmedia/%s/%s/%s", + br.GetPublicAddress(), + parsed.Homeserver, + parsed.FileID, + base64.RawURLEncoding.EncodeToString(br.makePublicMediaChecksum(parsed)), + ) +} diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 51fee503..714d5fc7 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -58,6 +58,10 @@ type MatrixConnectorWithServer interface { GetRouter() *mux.Router } +type MatrixConnectorWithPublicMedia interface { + GetPublicMediaAddress(contentURI id.ContentURIString) string +} + type MatrixConnectorWithNameDisambiguation interface { IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error) } diff --git a/id/contenturi.go b/id/contenturi.go index cfd00c3e..df02f54b 100644 --- a/id/contenturi.go +++ b/id/contenturi.go @@ -12,6 +12,7 @@ import ( "encoding/json" "errors" "fmt" + "regexp" "strings" ) @@ -156,3 +157,21 @@ func (uri ContentURI) CUString() ContentURIString { func (uri ContentURI) IsEmpty() bool { return len(uri.Homeserver) == 0 || len(uri.FileID) == 0 } + +var simpleHomeserverRegex = regexp.MustCompile(`^[a-zA-Z0-9.:-]+$`) + +func (uri ContentURI) IsValid() bool { + return IsValidMediaID(uri.Homeserver) && uri.Homeserver != "" && simpleHomeserverRegex.MatchString(uri.Homeserver) +} + +func IsValidMediaID(mediaID string) bool { + if len(mediaID) == 0 { + return false + } + for _, char := range mediaID { + if (char < 'A' || char > 'Z') && (char < 'a' || char > 'z') && (char < '0' || char > '9') && char != '-' && char != '_' { + return false + } + } + return true +} From e6586537d3327825afd270fe901a67d3b77b3baa Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 3 Aug 2024 18:01:53 +0300 Subject: [PATCH 0569/1647] bridgev2: add option for name formatting when using fancy relays (like Discord webhooks and Slack bots that can set custom names) --- bridgev2/bridgeconfig/relay.go | 23 +++++++++++++++++----- bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/matrix/mxmain/example-config.yaml | 4 ++++ bridgev2/networkinterface.go | 1 + bridgev2/portal.go | 1 + 5 files changed, 25 insertions(+), 5 deletions(-) diff --git a/bridgev2/bridgeconfig/relay.go b/bridgev2/bridgeconfig/relay.go index 7daf8e38..c802f85e 100644 --- a/bridgev2/bridgeconfig/relay.go +++ b/bridgev2/bridgeconfig/relay.go @@ -19,11 +19,13 @@ import ( ) type RelayConfig struct { - Enabled bool `yaml:"enabled"` - AdminOnly bool `yaml:"admin_only"` - DefaultRelays []networkid.UserLoginID `yaml:"default_relays"` - MessageFormats map[event.MessageType]string `yaml:"message_formats"` - messageTemplates *template.Template `yaml:"-"` + Enabled bool `yaml:"enabled"` + AdminOnly bool `yaml:"admin_only"` + DefaultRelays []networkid.UserLoginID `yaml:"default_relays"` + MessageFormats map[event.MessageType]string `yaml:"message_formats"` + DisplaynameFormat string `yaml:"displayname_format"` + messageTemplates *template.Template `yaml:"-"` + nameTemplate *template.Template `yaml:"-"` } type umRelayConfig RelayConfig @@ -42,6 +44,11 @@ func (rc *RelayConfig) UnmarshalYAML(node *yaml.Node) error { } } + rc.nameTemplate, err = template.New("nameTemplate").Parse(rc.DisplaynameFormat) + if err != nil { + return err + } + return nil } @@ -94,3 +101,9 @@ func (rc *RelayConfig) FormatMessage(content *event.MessageEventContent, sender content.Body = format.HTMLToText(content.FormattedBody) return content, nil } + +func (rc *RelayConfig) FormatName(sender any) string { + var buf strings.Builder + _ = rc.nameTemplate.Execute(&buf, sender) + return strings.TrimSpace(buf.String()) +} diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index c96a01c7..07026b6e 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -29,6 +29,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "relay", "admin_only") helper.Copy(up.List, "bridge", "relay", "default_relays") helper.Copy(up.Map, "bridge", "relay", "message_formats") + helper.Copy(up.Str, "bridge", "relay", "displayname_format") helper.Copy(up.Map, "bridge", "permissions") if dbType, ok := helper.Get(up.Str, "database", "type"); ok && dbType == "sqlite3" { diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 935b8e82..7c3911ab 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -64,6 +64,10 @@ bridge: m.audio: "{{ .Sender.DisambiguatedName }} sent an audio file{{ if .Caption }}: {{ .Caption }}{{ end }}" m.video: "{{ .Sender.DisambiguatedName }} sent a video{{ if .Caption }}: {{ .Caption }}{{ end }}" m.location: "{{ .Sender.DisambiguatedName }} sent a location{{ if .Caption }}: {{ .Caption }}{{ end }}" + # For networks that support per-message displaynames (i.e. Slack and Discord), the template for those names. + # This has all the Sender variables available under message_formats (but without the .Sender prefix). + # Note that you need to manually remove the displayname from message_formats above. + displayname_format: "{{ .DisambiguatedName }}" # Permissions for using the bridge. # Permitted values: diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index d73d6c4d..a28b13d2 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -902,6 +902,7 @@ type OrigSender struct { RequiresDisambiguation bool DisambiguatedName string + FormattedName string event.MemberEventContent } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 82a6db46..31439ee6 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -462,6 +462,7 @@ func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { } else { origSender.DisambiguatedName = sender.MXID.String() } + origSender.FormattedName = portal.Bridge.Config.Relay.FormatName(origSender) } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("login_id", string(login.ID)) From b71b32d0d6d619a2c16cc18ccb1dc5307045f755 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 3 Aug 2024 18:09:39 +0300 Subject: [PATCH 0570/1647] bridgev2: redefine relay admin-only setting Now users can still set relays from the `default_relays` list even if `admin_only` is true. --- bridgev2/bridgeconfig/permissions.go | 7 +++--- bridgev2/commands/relay.go | 28 +++++++++++++++------- bridgev2/matrix/mxmain/example-config.yaml | 1 + 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/bridgev2/bridgeconfig/permissions.go b/bridgev2/bridgeconfig/permissions.go index 15b4561d..e76046f5 100644 --- a/bridgev2/bridgeconfig/permissions.go +++ b/bridgev2/bridgeconfig/permissions.go @@ -21,6 +21,7 @@ type Permissions struct { Login bool `yaml:"login"` DoublePuppet bool `yaml:"double_puppet"` Admin bool `yaml:"admin"` + ManageRelay bool `yaml:"manage_relay"` } type PermissionConfig map[string]*Permissions @@ -58,9 +59,9 @@ func (pc PermissionConfig) Get(userID id.UserID) Permissions { var ( PermissionLevelBlock = Permissions{} PermissionLevelRelay = Permissions{SendEvents: true} - PermissionLevelCommands = Permissions{SendEvents: true, Commands: true} - PermissionLevelUser = Permissions{SendEvents: true, Commands: true, Login: true, DoublePuppet: true} - PermissionLevelAdmin = Permissions{SendEvents: true, Commands: true, Login: true, DoublePuppet: true, Admin: true} + PermissionLevelCommands = Permissions{SendEvents: true, Commands: true, ManageRelay: true} + PermissionLevelUser = Permissions{SendEvents: true, Commands: true, ManageRelay: true, Login: true, DoublePuppet: true} + PermissionLevelAdmin = Permissions{SendEvents: true, Commands: true, ManageRelay: true, Login: true, DoublePuppet: true, Admin: true} ) var namesToLevels = map[string]Permissions{ diff --git a/bridgev2/commands/relay.go b/bridgev2/commands/relay.go index e2d77a2d..af756c87 100644 --- a/bridgev2/commands/relay.go +++ b/bridgev2/commands/relay.go @@ -35,9 +35,14 @@ func fnSetRelay(ce *Event) { ce.Reply("You don't have permission to manage the relay in this room") return } + onlySetDefaultRelays := !ce.User.Permissions.Admin && ce.Bridge.Config.Relay.AdminOnly var relay *bridgev2.UserLogin if len(ce.Args) == 0 { relay = ce.User.GetDefaultLogin() + isLoggedIn := relay != nil + if onlySetDefaultRelays { + relay = nil + } if relay == nil { if len(ce.Bridge.Config.Relay.DefaultRelays) == 0 { ce.Reply("You're not logged in and there are no default relay users configured") @@ -59,7 +64,11 @@ func fnSetRelay(ce *Event) { } } if relay == nil { - ce.Reply("You're not logged in and none of the default relay users are in the chat") + if isLoggedIn { + ce.Reply("You're not allowed to use yourself as relay and none of the default relay users are in the chat") + } else { + ce.Reply("You're not logged in and none of the default relay users are in the chat") + } return } } @@ -68,9 +77,14 @@ func fnSetRelay(ce *Event) { if relay == nil { ce.Reply("User login with ID `%s` not found", ce.Args[0]) return - } else if !slices.Contains(ce.Bridge.Config.Relay.DefaultRelays, relay.ID) && relay.UserMXID != ce.User.MXID && !ce.User.Permissions.Admin { + } else if slices.Contains(ce.Bridge.Config.Relay.DefaultRelays, relay.ID) { + // All good + } else if relay.UserMXID != ce.User.MXID && !ce.User.Permissions.Admin { ce.Reply("Only bridge admins can set another user's login as the relay") return + } else if onlySetDefaultRelays { + ce.Reply("You're not allowed to use yourself as relay") + return } } err := ce.Portal.SetRelay(ce.Ctx, relay) @@ -116,12 +130,10 @@ func fnUnsetRelay(ce *Event) { } func canManageRelay(ce *Event) bool { - if ce.Bridge.Config.Relay.AdminOnly { - return ce.User.Permissions.Admin - } - return ce.User.Permissions.Admin || - (ce.Portal.Relay != nil && ce.Portal.Relay.UserMXID == ce.User.MXID) || - hasRelayRoomPermissions(ce) + return ce.User.Permissions.ManageRelay && + (ce.User.Permissions.Admin || + (ce.Portal.Relay != nil && ce.Portal.Relay.UserMXID == ce.User.MXID) || + hasRelayRoomPermissions(ce)) } func hasRelayRoomPermissions(ce *Event) bool { diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 7c3911ab..16952c35 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -41,6 +41,7 @@ bridge: # authenticated user into a relaybot for that chat. enabled: false # Should only admins be allowed to set themselves as relay users? + # If true, non-admins can only set users listed in default_relays as relays in a room. admin_only: true # List of user login IDs which anyone can set as a relay, as long as the relay user is in the room. default_relays: [] From ed7ec6a066aae705e3cd9c083abe7967d7d8ef60 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 3 Aug 2024 20:23:59 +0300 Subject: [PATCH 0571/1647] bridgev2: don't try to delete non-existent rooms --- bridgev2/userlogin.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 61036438..7e76b203 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -338,13 +338,15 @@ func DeleteManyPortals(ctx context.Context, portals []*Portal, errorCallback fun } continue } - err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false) - if err != nil { - zerolog.Ctx(ctx).Err(err). - Stringer("portal_mxid", portal.MXID). - Msg("Failed to clean up portal room") - if errorCallback != nil { - errorCallback(portal, true, err) + if portal.MXID != "" { + err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("portal_mxid", portal.MXID). + Msg("Failed to clean up portal room") + if errorCallback != nil { + errorCallback(portal, true, err) + } } } } From 9a1e84d74e7b804967ac77c625d2ff713c588581 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 3 Aug 2024 20:31:20 +0300 Subject: [PATCH 0572/1647] bridgev2/config: adjust public media section --- bridgev2/bridgeconfig/upgrade.go | 5 +++-- bridgev2/matrix/mxmain/example-config.yaml | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 07026b6e..15eeada0 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -96,7 +96,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "public_media", "enabled") if signingKey, ok := helper.Get(up.Str, "public_media", "signing_key"); !ok || signingKey == "generate" { - helper.Set(up.Str, random.String(32), "public_media", "signing_key") + helper.Set(up.Str, random.String(64), "public_media", "signing_key") } else { helper.Copy(up.Str, "public_media", "signing_key") } @@ -285,8 +285,9 @@ var SpacedBlocks = [][]string{ {"appservice", "username_template"}, {"matrix"}, {"provisioning"}, - {"backfill"}, + {"public_media"}, {"direct_media"}, + {"backfill"}, {"double_puppet"}, {"encryption"}, {"logging"}, diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 16952c35..deabe3c8 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -216,7 +216,7 @@ public_media: # Number of seconds that public media URLs are valid for. # If set to 0, URLs will never expire. expiry: 0 - # Length of hash to use for public media URLs. + # Length of hash to use for public media URLs. Must be between 0 and 32. hash_length: 32 # Settings for converting remote media to custom mxc:// URIs instead of reuploading. From dc35792d756831892209504d31ca1113fcfc2e09 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 3 Aug 2024 21:04:21 +0300 Subject: [PATCH 0573/1647] bridgev2/login: add token type for user input --- bridgev2/login.go | 1 + bridgev2/unorganized-docs/login-step.schema.json | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/bridgev2/login.go b/bridgev2/login.go index 32bf6e67..7acccd9a 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -158,6 +158,7 @@ const ( LoginInputFieldTypePhoneNumber LoginInputFieldType = "phone_number" LoginInputFieldTypeEmail LoginInputFieldType = "email" LoginInputFieldType2FACode LoginInputFieldType = "2fa_code" + LoginInputFieldTypeToken LoginInputFieldType = "token" ) type LoginInputDataField struct { diff --git a/bridgev2/unorganized-docs/login-step.schema.json b/bridgev2/unorganized-docs/login-step.schema.json index 4dbf6d47..b039354f 100644 --- a/bridgev2/unorganized-docs/login-step.schema.json +++ b/bridgev2/unorganized-docs/login-step.schema.json @@ -31,7 +31,7 @@ "properties": { "type": { "type": "string", - "enum": ["username", "phone_number", "email", "password", "2fa_code"] + "enum": ["username", "phone_number", "email", "password", "2fa_code", "token"] }, "id": { "type": "string", From 956c13761ebb3aac6e40b7b1ecbde3d102a2b69d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 3 Aug 2024 22:06:39 +0300 Subject: [PATCH 0574/1647] id: fix typo in ContentURI.IsValid --- id/contenturi.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/id/contenturi.go b/id/contenturi.go index df02f54b..e6a313f5 100644 --- a/id/contenturi.go +++ b/id/contenturi.go @@ -161,7 +161,7 @@ func (uri ContentURI) IsEmpty() bool { var simpleHomeserverRegex = regexp.MustCompile(`^[a-zA-Z0-9.:-]+$`) func (uri ContentURI) IsValid() bool { - return IsValidMediaID(uri.Homeserver) && uri.Homeserver != "" && simpleHomeserverRegex.MatchString(uri.Homeserver) + return IsValidMediaID(uri.FileID) && uri.Homeserver != "" && simpleHomeserverRegex.MatchString(uri.Homeserver) } func IsValidMediaID(mediaID string) bool { From 0e048db4f7700744e0c3b99cf2a95b0fda4df45b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 4 Aug 2024 20:52:52 +0300 Subject: [PATCH 0575/1647] bridgev2: add remote profile field in bridge states --- bridge/status/bridgestate.go | 31 ++++++++++---- bridgev2/bridgestate.go | 4 +- bridgev2/database/upgrades/00-latest.sql | 15 +++---- .../upgrades/16-user-login-profile.sql | 2 + bridgev2/database/userlogin.go | 40 +++++++++++++------ bridgev2/userlogin.go | 19 +++------ 6 files changed, 67 insertions(+), 44 deletions(-) create mode 100644 bridgev2/database/upgrades/16-user-login-profile.sql diff --git a/bridge/status/bridgestate.go b/bridge/status/bridgestate.go index e6047a1d..143aaeb0 100644 --- a/bridge/status/bridgestate.go +++ b/bridge/status/bridgestate.go @@ -52,6 +52,18 @@ const ( StateLoggedOut BridgeStateEvent = "LOGGED_OUT" ) +type RemoteProfile struct { + Phone string `json:"phone,omitempty"` + Email string `json:"email,omitempty"` + Username string `json:"username,omitempty"` + Name string `json:"name,omitempty"` + Avatar id.ContentURIString `json:"avatar,omitempty"` +} + +func (rp *RemoteProfile) IsEmpty() bool { + return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "") +} + type BridgeState struct { StateEvent BridgeStateEvent `json:"state_event"` Timestamp jsontime.Unix `json:"timestamp"` @@ -61,9 +73,10 @@ type BridgeState struct { Error BridgeStateErrorCode `json:"error,omitempty"` Message string `json:"message,omitempty"` - UserID id.UserID `json:"user_id,omitempty"` - RemoteID string `json:"remote_id,omitempty"` - RemoteName string `json:"remote_name,omitempty"` + UserID id.UserID `json:"user_id,omitempty"` + RemoteID string `json:"remote_id,omitempty"` + RemoteName string `json:"remote_name,omitempty"` + RemoteProfile RemoteProfile `json:"remote_profile,omitempty"` Reason string `json:"reason,omitempty"` Info map[string]interface{} `json:"info,omitempty"` @@ -89,13 +102,15 @@ type CustomBridgeStateFiller interface { StandaloneCustomBridgeStateFiller } -func (pong BridgeState) Fill(user BridgeStateFiller) BridgeState { +func (pong BridgeState) Fill(user any) BridgeState { if user != nil { - pong.UserID = user.GetMXID() - pong.RemoteID = user.GetRemoteID() - pong.RemoteName = user.GetRemoteName() + if std, ok := user.(BridgeStateFiller); ok { + pong.UserID = std.GetMXID() + pong.RemoteID = std.GetRemoteID() + pong.RemoteName = std.GetRemoteName() + } - if custom, ok := user.(CustomBridgeStateFiller); ok { + if custom, ok := user.(StandaloneCustomBridgeStateFiller); ok { pong = custom.FillBridgeState(pong) } } diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index bded88d3..1cd6b0c5 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -21,7 +21,7 @@ type BridgeStateQueue struct { prevSent *status.BridgeState ch chan status.BridgeState bridge *Bridge - user status.BridgeStateFiller + user status.StandaloneCustomBridgeStateFiller } func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { @@ -41,7 +41,7 @@ func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { } } -func (br *Bridge) NewBridgeStateQueue(user status.BridgeStateFiller) *BridgeStateQueue { +func (br *Bridge) NewBridgeStateQueue(user status.StandaloneCustomBridgeStateFiller) *BridgeStateQueue { bsq := &BridgeStateQueue{ ch: make(chan status.BridgeState, 10), bridge: br, diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 16c701ff..aeb9522e 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v15 (compatible with v9+): Latest revision +-- v0 -> v16 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -10,12 +10,13 @@ CREATE TABLE "user" ( ); CREATE TABLE user_login ( - bridge_id TEXT NOT NULL, - user_mxid TEXT NOT NULL, - id TEXT NOT NULL, - remote_name TEXT NOT NULL, - space_room TEXT, - metadata jsonb NOT NULL, + bridge_id TEXT NOT NULL, + user_mxid TEXT NOT NULL, + id TEXT NOT NULL, + remote_name TEXT NOT NULL, + remote_profile jsonb, + space_room TEXT, + metadata jsonb NOT NULL, PRIMARY KEY (bridge_id, id), CONSTRAINT user_login_user_fkey FOREIGN KEY (bridge_id, user_mxid) diff --git a/bridgev2/database/upgrades/16-user-login-profile.sql b/bridgev2/database/upgrades/16-user-login-profile.sql new file mode 100644 index 00000000..e143fcee --- /dev/null +++ b/bridgev2/database/upgrades/16-user-login-profile.sql @@ -0,0 +1,2 @@ +-- v16 (compatible with v9+): Save remote profile in user logins +ALTER TABLE user_login ADD COLUMN remote_profile jsonb; diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go index d994d270..610e7d60 100644 --- a/bridgev2/database/userlogin.go +++ b/bridgev2/database/userlogin.go @@ -12,6 +12,7 @@ import ( "go.mau.fi/util/dbutil" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/id" ) @@ -23,32 +24,33 @@ type UserLoginQuery struct { } type UserLogin struct { - BridgeID networkid.BridgeID - UserMXID id.UserID - ID networkid.UserLoginID - RemoteName string - SpaceRoom id.RoomID - Metadata any + BridgeID networkid.BridgeID + UserMXID id.UserID + ID networkid.UserLoginID + RemoteName string + RemoteProfile status.RemoteProfile + SpaceRoom id.RoomID + Metadata any } const ( getUserLoginBaseQuery = ` - SELECT bridge_id, user_mxid, id, remote_name, space_room, metadata FROM user_login + SELECT bridge_id, user_mxid, id, remote_name, remote_profile, space_room, metadata FROM user_login ` getLoginByIDQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND id=$2` getAllUsersWithLoginsQuery = `SELECT DISTINCT user_mxid FROM user_login WHERE bridge_id=$1` getAllLoginsForUserQuery = getUserLoginBaseQuery + `WHERE bridge_id=$1 AND user_mxid=$2` getAllLoginsInPortalQuery = ` - SELECT ul.bridge_id, ul.user_mxid, ul.id, ul.remote_name, ul.space_room, ul.metadata FROM user_portal + SELECT ul.bridge_id, ul.user_mxid, ul.id, ul.remote_name, ul.remote_profile, ul.space_room, ul.metadata FROM user_portal LEFT JOIN user_login ul ON user_portal.bridge_id=ul.bridge_id AND user_portal.user_mxid=ul.user_mxid AND user_portal.login_id=ul.id WHERE user_portal.bridge_id=$1 AND user_portal.portal_id=$2 AND user_portal.portal_receiver=$3 ` insertUserLoginQuery = ` - INSERT INTO user_login (bridge_id, user_mxid, id, remote_name, space_room, metadata) - VALUES ($1, $2, $3, $4, $5, $6) + INSERT INTO user_login (bridge_id, user_mxid, id, remote_name, remote_profile, space_room, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7) ` updateUserLoginQuery = ` - UPDATE user_login SET remote_name=$4, space_room=$5, metadata=$6 + UPDATE user_login SET remote_name=$4, remote_profile=$5, space_room=$6, metadata=$7 WHERE bridge_id=$1 AND user_mxid=$2 AND id=$3 ` deleteUserLoginQuery = ` @@ -89,7 +91,15 @@ func (uq *UserLoginQuery) Delete(ctx context.Context, loginID networkid.UserLogi func (u *UserLogin) Scan(row dbutil.Scannable) (*UserLogin, error) { var spaceRoom sql.NullString - err := row.Scan(&u.BridgeID, &u.UserMXID, &u.ID, &u.RemoteName, &spaceRoom, dbutil.JSON{Data: u.Metadata}) + err := row.Scan( + &u.BridgeID, + &u.UserMXID, + &u.ID, + &u.RemoteName, + dbutil.JSON{Data: &u.RemoteProfile}, + &spaceRoom, + dbutil.JSON{Data: u.Metadata}, + ) if err != nil { return nil, err } @@ -105,5 +115,9 @@ func (u *UserLogin) ensureHasMetadata(metaType MetaTypeCreator) *UserLogin { } func (u *UserLogin) sqlVariables() []any { - return []any{u.BridgeID, u.UserMXID, u.ID, u.RemoteName, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: u.Metadata}} + var remoteProfile dbutil.JSON + if !u.RemoteProfile.IsEmpty() { + remoteProfile.Data = &u.RemoteProfile + } + return []any{u.BridgeID, u.UserMXID, u.ID, u.RemoteName, remoteProfile, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: u.Metadata}} } diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 7e76b203..0c3556bc 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -22,7 +22,6 @@ import ( "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" ) type UserLogin struct { @@ -460,21 +459,13 @@ func (ul *UserLogin) MarkAsPreferredIn(ctx context.Context, portal *Portal) erro return ul.Bridge.DB.UserPortal.MarkAsPreferred(ctx, ul.UserLogin, portal.PortalKey) } -var _ status.CustomBridgeStateFiller = (*UserLogin)(nil) - -func (ul *UserLogin) GetMXID() id.UserID { - return ul.UserMXID -} - -func (ul *UserLogin) GetRemoteID() string { - return string(ul.ID) -} - -func (ul *UserLogin) GetRemoteName() string { - return ul.RemoteName -} +var _ status.StandaloneCustomBridgeStateFiller = (*UserLogin)(nil) func (ul *UserLogin) FillBridgeState(state status.BridgeState) status.BridgeState { + state.UserID = ul.UserMXID + state.RemoteID = string(ul.ID) + state.RemoteName = ul.RemoteName + state.RemoteProfile = ul.RemoteProfile filler, ok := ul.Client.(status.StandaloneCustomBridgeStateFiller) if ok { return filler.FillBridgeState(state) From 3a25416c01cd195ebf8a72443304043ccd68e8f5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 4 Aug 2024 21:01:42 +0300 Subject: [PATCH 0576/1647] bridgev2/login: merge remote profile data when relogining --- bridge/status/bridgestate.go | 16 ++++++++++++++++ bridgev2/userlogin.go | 1 + 2 files changed, 17 insertions(+) diff --git a/bridge/status/bridgestate.go b/bridge/status/bridgestate.go index 143aaeb0..22a04fa4 100644 --- a/bridge/status/bridgestate.go +++ b/bridge/status/bridgestate.go @@ -60,6 +60,22 @@ type RemoteProfile struct { Avatar id.ContentURIString `json:"avatar,omitempty"` } +func coalesce[T ~string](a, b T) T { + if a != "" { + return a + } + return b +} + +func (rp *RemoteProfile) Merge(other RemoteProfile) RemoteProfile { + other.Phone = coalesce(rp.Phone, other.Phone) + other.Email = coalesce(rp.Email, other.Email) + other.Username = coalesce(rp.Username, other.Username) + other.Name = coalesce(rp.Name, other.Name) + other.Avatar = coalesce(rp.Avatar, other.Avatar) + return other +} + func (rp *RemoteProfile) IsEmpty() bool { return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "") } diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 0c3556bc..0b1bd9e0 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -198,6 +198,7 @@ func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params } doInsert = false ul.RemoteName = data.RemoteName + ul.RemoteProfile = ul.RemoteProfile.Merge(data.RemoteProfile) if merger, ok := ul.Metadata.(database.MetaMerger); ok { merger.CopyFrom(data.Metadata) } else { From e37f91b3b117f9142b64305ad6343ecbe125173e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 4 Aug 2024 21:17:50 +0300 Subject: [PATCH 0577/1647] bridgev2: include remote profile in provisioning API --- bridge/status/bridgestate.go | 1 + bridgev2/matrix/provisioning.go | 2 ++ 2 files changed, 3 insertions(+) diff --git a/bridge/status/bridgestate.go b/bridge/status/bridgestate.go index 22a04fa4..90f228d4 100644 --- a/bridge/status/bridgestate.go +++ b/bridge/status/bridgestate.go @@ -186,6 +186,7 @@ func (pong *BridgeState) SendHTTP(ctx context.Context, url, token string) error func (pong *BridgeState) ShouldDeduplicate(newPong *BridgeState) bool { return pong != nil && pong.StateEvent == newPong.StateEvent && + pong.RemoteProfile == newPong.RemoteProfile && pong.Error == newPong.Error && maps.EqualFunc(pong.Info, newPong.Info, reflect.DeepEqual) && pong.Timestamp.Add(time.Duration(pong.TTL)*time.Second).After(time.Now()) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 4489a3e3..4952c0bf 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -281,6 +281,7 @@ type RespWhoamiLogin struct { StateTS jsontime.Unix `json:"state_ts"` ID networkid.UserLoginID `json:"id"` Name string `json:"name"` + Profile status.RemoteProfile `json:"profile"` SpaceRoom id.RoomID `json:"space_room"` } @@ -303,6 +304,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { StateTS: prevState.Timestamp, ID: login.ID, Name: login.RemoteName, + Profile: login.RemoteProfile, SpaceRoom: login.SpaceRoom, } } From 7f22cb44105e84093b0759d53a488f18bf183b5f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 4 Aug 2024 21:42:40 +0300 Subject: [PATCH 0578/1647] bridgev2/provisioning: include remote state info in whoami --- bridgev2/matrix/provisioning.go | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 4952c0bf..69223b02 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -277,12 +277,14 @@ type RespWhoami struct { } type RespWhoamiLogin struct { - StateEvent status.BridgeStateEvent `json:"state_event"` - StateTS jsontime.Unix `json:"state_ts"` - ID networkid.UserLoginID `json:"id"` - Name string `json:"name"` - Profile status.RemoteProfile `json:"profile"` - SpaceRoom id.RoomID `json:"space_room"` + StateEvent status.BridgeStateEvent `json:"state_event"` + StateTS jsontime.Unix `json:"state_ts"` + StateReason string `json:"state_reason,omitempty"` + StateInfo map[string]any `json:"state_info,omitempty"` + ID networkid.UserLoginID `json:"id"` + Name string `json:"name"` + Profile status.RemoteProfile `json:"profile"` + SpaceRoom id.RoomID `json:"space_room"` } func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { @@ -300,12 +302,15 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { for i, login := range logins { prevState := login.BridgeState.GetPrevUnsent() resp.Logins[i] = RespWhoamiLogin{ - StateEvent: prevState.StateEvent, - StateTS: prevState.Timestamp, - ID: login.ID, - Name: login.RemoteName, - Profile: login.RemoteProfile, - SpaceRoom: login.SpaceRoom, + StateEvent: prevState.StateEvent, + StateTS: prevState.Timestamp, + StateReason: prevState.Reason, + StateInfo: prevState.Info, + + ID: login.ID, + Name: login.RemoteName, + Profile: login.RemoteProfile, + SpaceRoom: login.SpaceRoom, } } jsonResponse(w, http.StatusOK, resp) From 9d6622c29325ee9e4c3fbf94ca95503290949905 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 4 Aug 2024 22:22:10 +0300 Subject: [PATCH 0579/1647] bridgev2/simplevent: fix needs backfill check in ChatResync --- bridgev2/simplevent/chat.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/simplevent/chat.go b/bridgev2/simplevent/chat.go index c3a62b85..e7b13fef 100644 --- a/bridgev2/simplevent/chat.go +++ b/bridgev2/simplevent/chat.go @@ -44,7 +44,7 @@ func (evt *ChatResync) CheckNeedsBackfill(ctx context.Context, latestMessage *da } else if latestMessage == nil { return !evt.LatestMessageTS.IsZero(), nil } else { - return !evt.LatestMessageTS.IsZero() && evt.LatestMessageTS.Before(latestMessage.Timestamp), nil + return evt.LatestMessageTS.After(latestMessage.Timestamp), nil } } From 8daaabcc011a2bea23859148385e3e2a50234fc3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 5 Aug 2024 20:06:54 +0300 Subject: [PATCH 0580/1647] bridgev2/mxmain: don't suggest using ignore foreign tables flag --- bridge/bridge.go | 1 - bridgev2/matrix/mxmain/dberror.go | 1 - 2 files changed, 2 deletions(-) diff --git a/bridge/bridge.go b/bridge/bridge.go index 4af2470b..17a4a30c 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -710,7 +710,6 @@ func (br *Bridge) LogDBUpgradeErrorAndExit(name string, err error) { if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { os.Exit(18) } else if errors.Is(err, dbutil.ErrForeignTables) { - br.ZLog.Info().Msg("You can use --ignore-foreign-tables to ignore this error") br.ZLog.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info") } else if errors.Is(err, dbutil.ErrNotOwned) { br.ZLog.Info().Msg("Sharing the same database with different programs is not supported") diff --git a/bridgev2/matrix/mxmain/dberror.go b/bridgev2/matrix/mxmain/dberror.go index 1c0f6381..0f6aa68c 100644 --- a/bridgev2/matrix/mxmain/dberror.go +++ b/bridgev2/matrix/mxmain/dberror.go @@ -64,7 +64,6 @@ func (br *BridgeMain) LogDBUpgradeErrorAndExit(name string, err error, message s if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { os.Exit(18) } else if errors.Is(err, dbutil.ErrForeignTables) { - br.Log.Info().Msg("You can use --ignore-foreign-tables to ignore this error") br.Log.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info") } else if errors.Is(err, dbutil.ErrNotOwned) { br.Log.Info().Msg("Sharing the same database with different programs is not supported") From fd078283b5bdbab0e5bd4fb64ce604e18f06380c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 5 Aug 2024 20:13:18 +0300 Subject: [PATCH 0581/1647] bridgev2/userlogin: add option to not clean up rooms when deleting --- bridgev2/matrix/provisioning.go | 2 +- bridgev2/userlogin.go | 23 ++++++++++++++--------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 69223b02..244bca63 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -381,7 +381,7 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov login.Override.Delete(ctx, status.BridgeState{ StateEvent: status.StateLoggedOut, Reason: "LOGIN_OVERRIDDEN", - }, true) + }, bridgev2.DeleteOpts{LogoutRemote: true}) } func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http.Request) { diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 0b1bd9e0..2615125a 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -186,7 +186,10 @@ func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params var doInsert bool if ul != nil && ul.UserMXID != user.MXID { if params.DeleteOnConflict { - ul.delete(ctx, status.BridgeState{StateEvent: status.StateLoggedOut, Reason: "LOGIN_OVERRIDDEN_ANOTHER_USER"}, false, true) + ul.Delete(ctx, status.BridgeState{StateEvent: status.StateLoggedOut, Reason: "LOGIN_OVERRIDDEN_ANOTHER_USER"}, DeleteOpts{ + LogoutRemote: false, + unlocked: true, + }) ul = nil } else { return nil, fmt.Errorf("%s is already logged in with that account", ul.UserMXID) @@ -239,27 +242,29 @@ func (ul *UserLogin) Save(ctx context.Context) error { } func (ul *UserLogin) Logout(ctx context.Context) { - ul.Delete(ctx, status.BridgeState{StateEvent: status.StateLoggedOut}, true) + ul.Delete(ctx, status.BridgeState{StateEvent: status.StateLoggedOut}, DeleteOpts{LogoutRemote: true}) } -func (ul *UserLogin) Delete(ctx context.Context, state status.BridgeState, logoutRemote bool) { - ul.delete(ctx, state, logoutRemote, false) +type DeleteOpts struct { + LogoutRemote bool + DontCleanupRooms bool + unlocked bool } -func (ul *UserLogin) delete(ctx context.Context, state status.BridgeState, logoutRemote, unlocked bool) { +func (ul *UserLogin) Delete(ctx context.Context, state status.BridgeState, opts DeleteOpts) { ul.deleteLock.Lock() defer ul.deleteLock.Unlock() if ul.BridgeState == nil { return } - if logoutRemote { + if opts.LogoutRemote { ul.Client.LogoutRemote(ctx) } else { ul.Disconnect(nil) } var portals []*database.UserPortal var err error - if ul.Bridge.Config.CleanupOnLogout.Enabled { + if !opts.DontCleanupRooms && ul.Bridge.Config.CleanupOnLogout.Enabled { portals, err = ul.Bridge.DB.UserPortal.GetAllForLogin(ctx, ul.UserLogin) if err != nil { ul.Log.Err(err).Msg("Failed to get user portals") @@ -269,12 +274,12 @@ func (ul *UserLogin) delete(ctx context.Context, state status.BridgeState, logou if err != nil { ul.Log.Err(err).Msg("Failed to delete user login") } - if !unlocked { + if !opts.unlocked { ul.Bridge.cacheLock.Lock() } delete(ul.User.logins, ul.ID) delete(ul.Bridge.userLoginsByID, ul.ID) - if !unlocked { + if !opts.unlocked { ul.Bridge.cacheLock.Unlock() } backgroundCtx := context.WithoutCancel(ctx) From 87d8d92867561af8a018503cc0fdce2c6a5feeb0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 5 Aug 2024 20:42:12 +0300 Subject: [PATCH 0582/1647] bridgev2/userlogin: log when deleting --- bridgev2/userlogin.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 2615125a..e279df47 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -252,6 +252,11 @@ type DeleteOpts struct { } func (ul *UserLogin) Delete(ctx context.Context, state status.BridgeState, opts DeleteOpts) { + cleanupRooms := !opts.DontCleanupRooms && ul.Bridge.Config.CleanupOnLogout.Enabled + zerolog.Ctx(ctx).Info().Str("user_login_id", string(ul.ID)). + Bool("logout_remote", opts.LogoutRemote). + Bool("cleanup_rooms", cleanupRooms). + Msg("Deleting user login") ul.deleteLock.Lock() defer ul.deleteLock.Unlock() if ul.BridgeState == nil { @@ -264,7 +269,7 @@ func (ul *UserLogin) Delete(ctx context.Context, state status.BridgeState, opts } var portals []*database.UserPortal var err error - if !opts.DontCleanupRooms && ul.Bridge.Config.CleanupOnLogout.Enabled { + if cleanupRooms { portals, err = ul.Bridge.DB.UserPortal.GetAllForLogin(ctx, ul.UserLogin) if err != nil { ul.Log.Err(err).Msg("Failed to get user portals") From 0f6fa7d691d1fb00622b4b9f2bbc5c06639e54b5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 5 Aug 2024 21:37:41 +0300 Subject: [PATCH 0583/1647] bridgev2/config: fix cleanup_on_logout section --- bridgev2/bridgeconfig/upgrade.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 15eeada0..6650b9d1 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -25,6 +25,15 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Str, "bridge", "command_prefix") helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") helper.Copy(up.Bool, "bridge", "private_chat_portal_meta") + helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "relayed") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "shared_no_users") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "shared_has_users") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "bad_credentials", "private") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "bad_credentials", "relayed") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "bad_credentials", "shared_no_users") + helper.Copy(up.Str, "bridge", "cleanup_on_logout", "bad_credentials", "shared_has_users") helper.Copy(up.Bool, "bridge", "relay", "enabled") helper.Copy(up.Bool, "bridge", "relay", "admin_only") helper.Copy(up.List, "bridge", "relay", "default_relays") @@ -199,15 +208,6 @@ func doMigrateLegacy(helper up.Helper) { } else { helper.Set(up.Bool, "false", "bridge", "private_chat_portal_meta") } - helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled") - helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private") - helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "relayed") - helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "shared_no_users") - helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "shared_has_users") - helper.Copy(up.Str, "bridge", "cleanup_on_logout", "bad_credentials", "private") - helper.Copy(up.Str, "bridge", "cleanup_on_logout", "bad_credentials", "relayed") - helper.Copy(up.Str, "bridge", "cleanup_on_logout", "bad_credentials", "shared_no_users") - helper.Copy(up.Str, "bridge", "cleanup_on_logout", "bad_credentials", "shared_has_users") helper.Copy(up.Bool, "bridge", "relay", "enabled") helper.Copy(up.Bool, "bridge", "relay", "admin_only") helper.Copy(up.Map, "bridge", "permissions") From 9fffe6e54d7ec13dd281c865471af1cc0e44dbe4 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 5 Aug 2024 12:39:34 -0600 Subject: [PATCH 0584/1647] bridgev2/database: allow querying ghosts by metadata Signed-off-by: Sumner Evans --- bridgev2/database/ghost.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/bridgev2/database/ghost.go b/bridgev2/database/ghost.go index c4c626f0..c32929ad 100644 --- a/bridgev2/database/ghost.go +++ b/bridgev2/database/ghost.go @@ -44,8 +44,9 @@ const ( name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata FROM ghost ` - getGhostByIDQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND id=$2` - insertGhostQuery = ` + getGhostByIDQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND id=$2` + getGhostByMetadataQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND metadata->>$2=$3` + insertGhostQuery = ` INSERT INTO ghost ( bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata @@ -63,6 +64,12 @@ func (gq *GhostQuery) GetByID(ctx context.Context, id networkid.UserID) (*Ghost, return gq.QueryOne(ctx, getGhostByIDQuery, gq.BridgeID, id) } +// GetByMetadata returns the ghosts whose metadata field at the given JSON key +// matches the given value. +func (gq *GhostQuery) GetByMetadata(ctx context.Context, key string, value any) ([]*Ghost, error) { + return gq.QueryMany(ctx, getGhostByMetadataQuery, gq.BridgeID, key, value) +} + func (gq *GhostQuery) Insert(ctx context.Context, ghost *Ghost) error { ensureBridgeIDMatches(&ghost.BridgeID, gq.BridgeID) return gq.Exec(ctx, insertGhostQuery, ghost.ensureHasMetadata(gq.MetaType).sqlVariables()...) From 6ed1d410aa7d25e9c0c38992b90f60cc049f1f5e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 6 Aug 2024 00:39:00 +0300 Subject: [PATCH 0585/1647] bridgev2/portal: fix max count when syncing reactions --- bridgev2/portal.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 31439ee6..60f83e1a 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1848,11 +1848,12 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User doAddReaction(reaction) } } + totalReactionCount := len(existingUserReactions) + len(reactions.Reactions) if reactions.HasAllReactions { for _, existingReaction := range existingUserReactions { doRemoveReaction(existingReaction, nil) } - } else if reactions.MaxCount > 0 && len(existingUserReactions)+len(reactions.Reactions) > reactions.MaxCount { + } else if reactions.MaxCount > 0 && totalReactionCount > reactions.MaxCount { remainingReactionList := maps.Values(existingUserReactions) slices.SortFunc(remainingReactionList, func(a, b *database.Reaction) int { diff := a.Timestamp.Compare(b.Timestamp) @@ -1861,8 +1862,8 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User } return diff }) - numberToRemove := max(reactions.MaxCount-len(reactions.Reactions), len(remainingReactionList)) - for i := 0; i < numberToRemove; i++ { + numberToRemove := totalReactionCount - reactions.MaxCount + for i := 0; i < numberToRemove && i < len(remainingReactionList); i++ { doRemoveReaction(remainingReactionList[i], nil) } } From 9fffc05a7bcfd0ea3d9299bdd3b5feb35d350301 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 6 Aug 2024 18:58:18 +0300 Subject: [PATCH 0586/1647] bridgev2/portal: fix deleting database rows when syncing reactions --- bridgev2/portal.go | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 60f83e1a..7728e09d 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1799,7 +1799,7 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User ) return intent } - doRemoveReaction := func(old *database.Reaction, intent MatrixAPI) { + doRemoveReaction := func(old *database.Reaction, intent MatrixAPI, deleteRow bool) { if intent == nil && old.SenderMXID != "" { intent, err = portal.getIntentForMXID(ctx, old.SenderMXID) if err != nil { @@ -1823,10 +1823,16 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User if err != nil { log.Err(err).Msg("Failed to redact old reaction") } + if deleteRow { + err = portal.Bridge.DB.Reaction.Delete(ctx, old) + if err != nil { + log.Err(err).Msg("Failed to delete old reaction row") + } + } } doOverwriteReaction := func(new *BackfillReaction, old *database.Reaction) { intent := doAddReaction(new) - doRemoveReaction(old, intent) + doRemoveReaction(old, intent, false) } newData := evt.GetReactions() @@ -1851,7 +1857,7 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User totalReactionCount := len(existingUserReactions) + len(reactions.Reactions) if reactions.HasAllReactions { for _, existingReaction := range existingUserReactions { - doRemoveReaction(existingReaction, nil) + doRemoveReaction(existingReaction, nil, true) } } else if reactions.MaxCount > 0 && totalReactionCount > reactions.MaxCount { remainingReactionList := maps.Values(existingUserReactions) @@ -1864,14 +1870,14 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User }) numberToRemove := totalReactionCount - reactions.MaxCount for i := 0; i < numberToRemove && i < len(remainingReactionList); i++ { - doRemoveReaction(remainingReactionList[i], nil) + doRemoveReaction(remainingReactionList[i], nil, true) } } } if newData.HasAllUsers { for _, userReactions := range existing { for _, existingReaction := range userReactions { - doRemoveReaction(existingReaction, nil) + doRemoveReaction(existingReaction, nil, true) } } } From f6b0feab9566f922f59fcf8569f4fbc2be17d9ce Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 6 Aug 2024 18:58:36 +0300 Subject: [PATCH 0587/1647] bridgev2/chatinfo: add utility for merging ExtraUpdaters --- bridgev2/ghost.go | 2 +- bridgev2/portal.go | 22 +++++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index 32e8e0c9..ced34925 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -128,7 +128,7 @@ type UserInfo struct { Avatar *Avatar IsBot *bool - ExtraUpdates func(context.Context, *Ghost) bool + ExtraUpdates ExtraUpdater[*Ghost] } func (ghost *Ghost) UpdateName(ctx context.Context, name string) bool { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 7728e09d..dbeb0660 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2362,7 +2362,27 @@ type ChatInfo struct { CanBackfill bool - ExtraUpdates func(context.Context, *Portal) bool + ExtraUpdates ExtraUpdater[*Portal] +} + +type ExtraUpdater[T any] func(context.Context, T) bool + +func MergeExtraUpdaters[T any](funcs ...ExtraUpdater[T]) ExtraUpdater[T] { + funcs = slices.DeleteFunc(funcs, func(f ExtraUpdater[T]) bool { + return f == nil + }) + if len(funcs) == 0 { + return nil + } else if len(funcs) == 1 { + return funcs[0] + } + return func(ctx context.Context, p T) bool { + changed := false + for _, f := range funcs { + changed = f(ctx, p) || changed + } + return changed + } } var Unmuted = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) From 73a63a12fbe34d6e46fae8aba5a337a52f8aa97c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 6 Aug 2024 21:51:05 +0300 Subject: [PATCH 0588/1647] bridgev2/portal: fix event type when sending state with bot fallback --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index dbeb0660..87a520cd 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2499,7 +2499,7 @@ func (portal *Portal) sendStateWithIntentOrBot(ctx context.Context, sender Matri content.Raw = make(map[string]any) } content.Raw["fi.mau.bridge.set_by"] = sender.GetMXID() - resp, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateRoomName, "", content, ts) + resp, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, eventType, "", content, ts) } return } From 213f9df4a4674a08581f3d1ba13a2f1b974da9aa Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 6 Aug 2024 21:53:40 +0300 Subject: [PATCH 0589/1647] bridgev2/portal: also fix state key when sending with bot fallback --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 87a520cd..2a6f269d 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2499,7 +2499,7 @@ func (portal *Portal) sendStateWithIntentOrBot(ctx context.Context, sender Matri content.Raw = make(map[string]any) } content.Raw["fi.mau.bridge.set_by"] = sender.GetMXID() - resp, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, eventType, "", content, ts) + resp, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, eventType, stateKey, content, ts) } return } From e0f58dccf432b733e1cdaf088f252c6ff554eae7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 7 Aug 2024 00:42:51 +0300 Subject: [PATCH 0590/1647] bridgev2/provisioning: remove leftover debug print --- bridgev2/matrix/provisioning.go | 1 - 1 file changed, 1 deletion(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 244bca63..c3b8c3dc 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -618,7 +618,6 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque apiContact := &RespResolveIdentifier{ ID: contact.UserID, } - fmt.Println(contact.UserInfo.Identifiers) apiResp.Contacts[i] = apiContact if contact.UserInfo != nil { if contact.UserInfo.Name != nil { From 45527281cc1812cc1e1af9ad84ed1f2078b2ef57 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 7 Aug 2024 14:38:00 +0300 Subject: [PATCH 0591/1647] bridgev2/backfill: add support for batch limit overrides --- bridgev2/backfillqueue.go | 38 +++++++++++++++++++++++++----- bridgev2/bridgeconfig/backfill.go | 8 +++++++ bridgev2/database/backfillqueue.go | 2 +- bridgev2/database/message.go | 9 +++++++ bridgev2/networkinterface.go | 12 ++++++++-- bridgev2/portalbackfill.go | 1 + 6 files changed, 61 insertions(+), 9 deletions(-) diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go index 345c3391..63a01f68 100644 --- a/bridgev2/backfillqueue.go +++ b/bridgev2/backfillqueue.go @@ -180,13 +180,39 @@ func (br *Bridge) actuallyDoBackfillTask(ctx context.Context, task *database.Bac return false, nil } } - maxBatches := br.Config.Backfill.Queue.MaxBatches - // TODO apply max batch overrides - err = portal.DoBackwardsBackfill(ctx, login, task) - if err != nil { - return false, fmt.Errorf("failed to backfill: %w", err) + if task.BatchCount < 0 { + var msgCount int + msgCount, err = br.DB.Message.CountMessagesInPortal(ctx, task.PortalKey) + if err != nil { + return false, fmt.Errorf("failed to count messages in portal: %w", err) + } + task.BatchCount = msgCount / br.Config.Backfill.Queue.BatchSize + log.Debug(). + Int("message_count", msgCount). + Int("batch_count", task.BatchCount). + Msg("Calculated existing batch count") + } + maxBatches := br.Config.Backfill.Queue.MaxBatches + api, ok := login.Client.(BackfillingNetworkAPI) + if !ok { + return false, fmt.Errorf("network API does not support backfilling") + } + limiterAPI, ok := api.(BackfillingNetworkAPIWithLimits) + if ok { + maxBatches = limiterAPI.GetBackfillMaxBatchCount(ctx, portal, task) + } + if maxBatches < 0 || maxBatches > task.BatchCount { + err = portal.DoBackwardsBackfill(ctx, login, task) + if err != nil { + return false, fmt.Errorf("failed to backfill: %w", err) + } + task.BatchCount++ + } else { + log.Debug(). + Int("max_batches", maxBatches). + Int("batch_count", task.BatchCount). + Msg("Not actually backfilling: max batches reached") } - task.BatchCount++ task.IsDone = task.IsDone || (maxBatches > 0 && task.BatchCount >= maxBatches) batchDelay := time.Duration(br.Config.Backfill.Queue.BatchDelay) * time.Second task.CompletedAt = time.Now() diff --git a/bridgev2/bridgeconfig/backfill.go b/bridgev2/bridgeconfig/backfill.go index fe464569..44d2d588 100644 --- a/bridgev2/bridgeconfig/backfill.go +++ b/bridgev2/bridgeconfig/backfill.go @@ -28,3 +28,11 @@ type BackfillQueueConfig struct { MaxBatchesOverride map[string]int `yaml:"max_batches_override"` } + +func (bqc *BackfillQueueConfig) GetOverride(name string) int { + override, ok := bqc.MaxBatchesOverride[name] + if !ok { + return bqc.MaxBatches + } + return override +} diff --git a/bridgev2/database/backfillqueue.go b/bridgev2/database/backfillqueue.go index 5d7cf854..fed7452d 100644 --- a/bridgev2/database/backfillqueue.go +++ b/bridgev2/database/backfillqueue.go @@ -40,7 +40,7 @@ var BackfillNextDispatchNever = time.Unix(0, (1<<63)-1) const ( ensureBackfillExistsQuery = ` INSERT INTO backfill_task (bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, next_dispatch_min_ts) - VALUES ($1, $2, $3, $4, 0, false, $5) + VALUES ($1, $2, $3, $4, -1, false, $5) ON CONFLICT (bridge_id, portal_id, portal_receiver) DO UPDATE SET user_login_id=CASE WHEN backfill_task.user_login_id='' diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index b2e023d0..8173ad05 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -64,6 +64,10 @@ const ( getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1` + countMessagesInPortalQuery = ` + SELECT COUNT(*) FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 + ` + insertMessageQuery = ` INSERT INTO message ( bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, sender_mxid, @@ -155,6 +159,11 @@ func (mq *MessageQuery) Delete(ctx context.Context, rowID int64) error { return mq.Exec(ctx, deleteMessagePartByRowIDQuery, mq.BridgeID, rowID) } +func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid.PortalKey) (count int, err error) { + err = mq.GetDB().QueryRow(ctx, countMessagesInPortalQuery, mq.BridgeID, key.ID, key.Receiver).Scan(&count) + return +} + func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { var timestamp int64 var threadRootID, replyToID, replyToPartID sql.NullString diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index a28b13d2..54ec81a8 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -114,8 +114,8 @@ func (cm *ConvertedMessage) MergeCaption() bool { if len(cm.Parts) != 2 { return false } - textPart, mediaPart := cm.Parts[0], cm.Parts[1] - if textPart.Content.MsgType.IsMedia() { + textPart, mediaPart := cm.Parts[1], cm.Parts[0] + if textPart.Content.MsgType != event.MsgText { textPart, mediaPart = mediaPart, textPart } if (!mediaPart.Content.MsgType.IsMedia() && mediaPart.Content.MsgType != event.MsgNotice) || textPart.Content.MsgType != event.MsgText { @@ -369,6 +369,9 @@ type FetchMessagesParams struct { // The preferred number of messages to return. The returned batch can be bigger or smaller // without any side effects, but the network connector should aim for this number. Count int + + // When the messages are being fetched for a queued backfill, this is the task object. + Task *database.BackfillTask } // BackfillReaction is an individual reaction to a message in a history pagination request. @@ -436,6 +439,11 @@ type BackfillingNetworkAPI interface { FetchMessages(ctx context.Context, fetchParams FetchMessagesParams) (*FetchMessagesResponse, error) } +type BackfillingNetworkAPIWithLimits interface { + BackfillingNetworkAPI + GetBackfillMaxBatchCount(ctx context.Context, portal *Portal, task *database.BackfillTask) int +} + // EditHandlingNetworkAPI is an optional interface that network connectors can implement to handle message edits. type EditHandlingNetworkAPI interface { NetworkAPI diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 77fe23b1..6d4124e8 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -96,6 +96,7 @@ func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin Cursor: task.Cursor, AnchorMessage: firstMessage, Count: portal.Bridge.Config.Backfill.Queue.BatchSize, + Task: task, }) if err != nil { return fmt.Errorf("failed to fetch messages for backward backfill: %w", err) From 1bec37c942d30b82aa6caf0448b412e896b93461 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 7 Aug 2024 18:51:36 +0300 Subject: [PATCH 0592/1647] bridgev2/simplevent: add type to typing events --- bridgev2/simplevent/receipt.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/bridgev2/simplevent/receipt.go b/bridgev2/simplevent/receipt.go index e9835a66..3565986b 100644 --- a/bridgev2/simplevent/receipt.go +++ b/bridgev2/simplevent/receipt.go @@ -54,12 +54,18 @@ func (evt *MarkUnread) GetUnread() bool { type Typing struct { EventMeta Timeout time.Duration + Type bridgev2.TypingType } var ( - _ bridgev2.RemoteTyping = (*Typing)(nil) + _ bridgev2.RemoteTyping = (*Typing)(nil) + _ bridgev2.RemoteTypingWithType = (*Typing)(nil) ) func (evt *Typing) GetTimeout() time.Duration { return evt.Timeout } + +func (evt *Typing) GetTypingType() bridgev2.TypingType { + return evt.Type +} From eabab275895d69ca11c2c3e653f809418b5d451f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 7 Aug 2024 18:58:38 +0300 Subject: [PATCH 0593/1647] bridgev2/portal: log when event is dropped --- bridgev2/portal.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 2a6f269d..e8d5c7f1 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1252,6 +1252,7 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { if portal.MXID == "" { mcp, ok := evt.(RemoteEventThatMayCreatePortal) if !ok || !mcp.ShouldCreatePortal() { + log.Debug().Msg("Dropping event as portal doesn't exist") return } infoProvider, ok := mcp.(RemoteChatResyncWithInfo) From 6c836c6ebdabb770180b9ccfa1e73b6e2e911a0b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 7 Aug 2024 23:32:17 +0300 Subject: [PATCH 0594/1647] crypto: adjust log when rejecting duplicate message index --- crypto/sql_store.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/sql_store.go b/crypto/sql_store.go index d93ee0ca..a8ccab26 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -624,7 +624,7 @@ func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey Str("expected_event_id", expectedEventID.String()). Int64("expected_timestamp", expectedTimestamp). Int64("actual_timestamp", timestamp). - Msg("Failed to validate that message index wasn't duplicated") + Msg("Rejecting different event with duplicate message index") return false, nil } return true, nil From 5d4407950a07f04d01bcca1c466f2e4c0ee3be3f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 8 Aug 2024 01:58:30 +0300 Subject: [PATCH 0595/1647] bridgev2: add IsGhostMXID helper function --- bridgev2/ghost.go | 5 +++++ bridgev2/matrix/intent.go | 3 +-- bridgev2/matrix/matrix.go | 9 +-------- bridgev2/portal.go | 6 ++---- 4 files changed, 9 insertions(+), 14 deletions(-) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index ced34925..e4e007cd 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -69,6 +69,11 @@ func (br *Bridge) unlockedGetGhostByID(ctx context.Context, id networkid.UserID, return br.loadGhost(ctx, db, err, idPtr) } +func (br *Bridge) IsGhostMXID(userID id.UserID) bool { + _, isGhost := br.Matrix.ParseGhostMXID(userID) + return isGhost +} + func (br *Bridge) GetGhostByMXID(ctx context.Context, mxid id.UserID) (*Ghost, error) { ghostID, ok := br.Matrix.ParseGhostMXID(mxid) if !ok { diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index a4fa1b14..ac771734 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -358,8 +358,7 @@ func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnl if member == as.Matrix.UserID { continue } - _, isGhost := as.Connector.ParseGhostMXID(member) - if isGhost { + if as.Connector.Bridge.IsGhostMXID(member) { _, err = as.Connector.AS.Intent(member).LeaveRoom(ctx, roomID) if err != nil { zerolog.Ctx(ctx).Err(err).Stringer("user_id", member).Msg("Failed to leave room while cleaning up portal") diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index a7261a61..1117fca2 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -165,14 +165,7 @@ func (br *Connector) sendBridgeCheckpoint(ctx context.Context, evt *event.Event) } func (br *Connector) shouldIgnoreEventFromUser(userID id.UserID) bool { - if userID == br.Bot.UserID { - return true - } - _, isGhost := br.ParseGhostMXID(userID) - if isGhost { - return true - } - return false + return userID == br.Bot.UserID || br.Bridge.IsGhostMXID(userID) } func (br *Connector) shouldIgnoreEvent(evt *event.Event) bool { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index e8d5c7f1..c9066cba 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -385,8 +385,7 @@ func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID, } for _, confusable := range confusableWith { // Don't disambiguate names that only conflict with ghosts of this bridge - _, isGhost := portal.Bridge.Matrix.ParseGhostMXID(confusable) - if !isGhost { + if !portal.Bridge.IsGhostMXID(confusable) { return true } } @@ -2745,8 +2744,7 @@ func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberL if memberEvt.Membership == event.MembershipLeave || memberEvt.Membership == event.MembershipBan { continue } - _, isGhost := portal.Bridge.Matrix.ParseGhostMXID(extraMember) - if !isGhost && portal.Relay != nil { + if !portal.Bridge.IsGhostMXID(extraMember) && portal.Relay != nil { continue } _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateMember, extraMember.String(), &event.Content{ From 962ac6bf1768c957026ba99bea3f6d5a25b7f8b0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 8 Aug 2024 12:45:24 +0300 Subject: [PATCH 0596/1647] appservice: ensure registered before uploading media --- appservice/intent.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/appservice/intent.go b/appservice/intent.go index 9d6b55e5..6848f28c 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -436,6 +436,20 @@ func (intent *IntentAPI) SetRoomTopic(ctx context.Context, roomID id.RoomID, top }) } +func (intent *IntentAPI) UploadMedia(ctx context.Context, data mautrix.ReqUploadMedia) (*mautrix.RespMediaUpload, error) { + if err := intent.EnsureRegistered(ctx); err != nil { + return nil, err + } + return intent.Client.UploadMedia(ctx, data) +} + +func (intent *IntentAPI) UploadAsync(ctx context.Context, data mautrix.ReqUploadMedia) (*mautrix.RespCreateMXC, error) { + if err := intent.EnsureRegistered(ctx); err != nil { + return nil, err + } + return intent.Client.UploadAsync(ctx, data) +} + func (intent *IntentAPI) SetDisplayName(ctx context.Context, displayName string) error { if err := intent.EnsureRegistered(ctx); err != nil { return err From b5f968d8c386de0154fd8da8d83b827ba7cdd6a5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 8 Aug 2024 13:51:12 +0300 Subject: [PATCH 0597/1647] bridgev2/backfill: log more details when inserting message fails --- bridgev2/portalbackfill.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 6d4124e8..b5ac08bb 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -362,6 +362,9 @@ func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages zerolog.Ctx(ctx).Err(err). Str("message_id", string(msg.ID)). Str("part_id", string(msg.PartID)). + Str("sender_id", string(msg.SenderID)). + Str("portal_id", string(msg.Room.ID)). + Str("portal_receiver", string(msg.Room.Receiver)). Msg("Failed to insert backfilled message to database") } } From e188e7abc3eea2da16815e691d59f97855d6cc12 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 8 Aug 2024 20:44:25 +0300 Subject: [PATCH 0598/1647] bridgev2/portal: allow cancelling remote edit handling --- bridgev2/portal.go | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c9066cba..dea81336 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1244,9 +1244,9 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { }() evtType := evt.GetType() log.UpdateContext(func(c zerolog.Context) zerolog.Context { - return c.Stringer("bridge_evt_type", evtType) + c = c.Stringer("bridge_evt_type", evtType) + return evt.AddLogContext(c) }) - log.UpdateContext(evt.AddLogContext) ctx := log.WithContext(context.TODO()) if portal.MXID == "" { mcp, ok := evt.(RemoteEventThatMayCreatePortal) @@ -1603,7 +1603,7 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin converted, err := evt.ConvertMessage(ctx, portal, intent) if err != nil { if errors.Is(err, ErrIgnoringRemoteEvent) { - log.Debug().Err(err).Msg("Remote event handling was cancelled by convert function") + log.Debug().Err(err).Msg("Remote message handling was cancelled by convert function") } else { log.Err(err).Msg("Failed to convert remote message") portal.sendRemoteErrorNotice(ctx, intent, err, ts, "message") @@ -1658,7 +1658,10 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e } ts := getEventTS(evt) converted, err := evt.ConvertEdit(ctx, portal, intent, existing) - if err != nil { + if errors.Is(err, ErrIgnoringRemoteEvent) { + log.Debug().Err(err).Msg("Remote edit handling was cancelled by convert function") + return + } else if err != nil { log.Err(err).Msg("Failed to convert remote edit") portal.sendRemoteErrorNotice(ctx, intent, err, ts, "edit") return From 5edfcff2b7b6ab14cac81866e74f10c2d37054fc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 8 Aug 2024 20:44:55 +0300 Subject: [PATCH 0599/1647] bridgev2/networkinterface: define function for validating remote user IDs --- bridgev2/networkinterface.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 54ec81a8..7e73726c 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -241,6 +241,11 @@ type DirectMediableNetwork interface { Download(ctx context.Context, mediaID networkid.MediaID) (mediaproxy.GetMediaResponse, error) } +type IdentifierValidatingNetwork interface { + NetworkConnector + ValidateUserID(id networkid.UserID) bool +} + // ConfigValidatingNetwork is an optional interface that network connectors can implement to validate config fields // before the bridge is started. // From 0e4780cf1f65c72154fc527c7e10f97f4240ce1e Mon Sep 17 00:00:00 2001 From: Malte E <97891689+maltee1@users.noreply.github.com> Date: Fri, 9 Aug 2024 13:30:12 +0200 Subject: [PATCH 0600/1647] bridgev2/networkinterface: add support for handling Matrix member changes (#265) --- bridgev2/bridgeconfig/config.go | 1 + bridgev2/messagestatus.go | 10 ++--- bridgev2/networkinterface.go | 38 +++++++++++++++++++ bridgev2/portal.go | 66 +++++++++++++++++++++++++++++++++ 4 files changed, 110 insertions(+), 5 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 74104dec..b7a0ff37 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -63,6 +63,7 @@ type BridgeConfig struct { Relay RelayConfig `yaml:"relay"` Permissions PermissionConfig `yaml:"permissions"` Backfill BackfillConfig `yaml:"backfill"` + BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` } type MatrixConfig struct { diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 4b438b64..4c61a7a9 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -36,11 +36,11 @@ var ( ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) - - ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) - ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) - ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) - ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) + ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) + ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) + ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) + ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) ) type MessageStatusEventInfo struct { diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 7e73726c..a833749f 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -598,6 +598,44 @@ type GroupCreatingNetworkAPI interface { CreateGroup(ctx context.Context, name string, users ...networkid.UserID) (*CreateChatResponse, error) } +type MembershipChangeType struct { + From event.Membership + To event.Membership + IsSelf bool +} + +var ( + AcceptInvite = MembershipChangeType{From: event.MembershipInvite, To: event.MembershipJoin, IsSelf: true} + RevokeInvite = MembershipChangeType{From: event.MembershipInvite, To: event.MembershipLeave} + RejectInvite = MembershipChangeType{From: event.MembershipInvite, To: event.MembershipLeave, IsSelf: true} + BanInvited = MembershipChangeType{From: event.MembershipInvite, To: event.MembershipBan} + ProfileChange = MembershipChangeType{From: event.MembershipJoin, To: event.MembershipJoin, IsSelf: true} + Leave = MembershipChangeType{From: event.MembershipJoin, To: event.MembershipLeave, IsSelf: true} + Kick = MembershipChangeType{From: event.MembershipJoin, To: event.MembershipLeave} + BanJoined = MembershipChangeType{From: event.MembershipJoin, To: event.MembershipBan} + Invite = MembershipChangeType{From: event.MembershipLeave, To: event.MembershipInvite} + Join = MembershipChangeType{From: event.MembershipLeave, To: event.MembershipJoin} + BanLeft = MembershipChangeType{From: event.MembershipLeave, To: event.MembershipBan} + Knock = MembershipChangeType{From: event.MembershipLeave, To: event.MembershipKnock, IsSelf: true} + AcceptKnock = MembershipChangeType{From: event.MembershipKnock, To: event.MembershipInvite} + RejectKnock = MembershipChangeType{From: event.MembershipKnock, To: event.MembershipLeave} + RetractKnock = MembershipChangeType{From: event.MembershipKnock, To: event.MembershipLeave, IsSelf: true} + BanKnocked = MembershipChangeType{From: event.MembershipKnock, To: event.MembershipBan} + Unban = MembershipChangeType{From: event.MembershipBan, To: event.MembershipLeave} +) + +type MatrixMembershipChange struct { + MatrixEventBase[*event.MemberEventContent] + TargetGhost *Ghost + TargetUserLogin *UserLogin + Type MembershipChangeType +} + +type MembershipHandlingNetworkAPI interface { + NetworkAPI + HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (bool, error) +} + type PushType int func (pt PushType) String() string { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index dea81336..778bb75c 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -492,6 +492,8 @@ func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { handleMatrixAccountData(portal, ctx, login, evt, TagHandlingNetworkAPI.HandleRoomTag) case event.AccountDataBeeperMute: handleMatrixAccountData(portal, ctx, login, evt, MuteHandlingNetworkAPI.HandleMute) + case event.StateMember: + portal.handleMatrixMembership(ctx, login, origSender, evt) } } @@ -3245,3 +3247,67 @@ func (portal *Portal) SetRelay(ctx context.Context, relay *UserLogin) error { } return nil } + +func (portal *Portal) handleMatrixMembership( + ctx context.Context, + sender *UserLogin, + origSender *OrigSender, + evt *event.Event, +) { + api, ok := sender.Client.(MembershipHandlingNetworkAPI) + if !ok { + portal.sendErrorStatus(ctx, evt, ErrMembershipNotSupported) + return + } + log := zerolog.Ctx(ctx) + targetMXID := id.UserID(*evt.StateKey) + isSelf := sender.User.MXID == targetMXID + var err error + var targetUserLogin *UserLogin + targetGhost, err := portal.Bridge.GetGhostByMXID(ctx, targetMXID) + if err != nil { + log.Err(err).Stringer("mxid", targetMXID).Msg("Failed to get target ghost") + return + } + if targetGhost == nil { + targetUser, err := portal.Bridge.GetUserByMXID(ctx, targetMXID) + if err != nil { + log.Err(err).Stringer("mxid", targetMXID).Msg("Failed to get target user") + return + } + targetUserLogin, _, err = portal.FindPreferredLogin(ctx, targetUser, false) + if err != nil { + log.Err(err).Stringer("mxid", targetMXID).Msg("Failed to get target user login") + return + } + } + prevContent := &event.MemberEventContent{Membership: event.MembershipLeave} + if evt.Unsigned.PrevContent != nil { + _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) + prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent) + } + + content := evt.Content.AsMember() + membershipChangeType := MembershipChangeType{From: prevContent.Membership, To: content.Membership, IsSelf: isSelf} + if !portal.Bridge.Config.BridgeMatrixLeave && membershipChangeType == Leave { + log.Debug().Msg("Dropping leave event") + return + } + membershipChange := &MatrixMembershipChange{ + MatrixEventBase: MatrixEventBase[*event.MemberEventContent]{ + Event: evt, + Content: content, + Portal: portal, + OrigSender: origSender, + }, + TargetGhost: targetGhost, + TargetUserLogin: targetUserLogin, + Type: membershipChangeType, + } + _, err = api.HandleMatrixMembership(ctx, membershipChange) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix Membership Change") + portal.sendErrorStatus(ctx, evt, err) + return + } +} From 8c0f705ee90beb3dcf2c56770f1b2f420d5ab212 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 9 Aug 2024 15:33:04 +0300 Subject: [PATCH 0601/1647] bridgev2: add support for receiving events with uncertain portal receivers --- bridgev2/networkinterface.go | 5 +++++ bridgev2/queue.go | 20 ++++++++++++++++++-- bridgev2/simplevent/meta.go | 24 +++++++++++++++--------- 3 files changed, 38 insertions(+), 11 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index a833749f..f6aedbc6 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -771,6 +771,11 @@ type RemoteEvent interface { GetSender() EventSender } +type RemoteEventWithUncertainPortalReceiver interface { + RemoteEvent + PortalReceiverIsUncertain() bool +} + type RemotePreHandler interface { RemoteEvent PreHandle(ctx context.Context, portal *Portal) diff --git a/bridgev2/queue.go b/bridgev2/queue.go index e5220227..bcf092b8 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -139,11 +139,27 @@ func (br *Bridge) handleBotInvite(ctx context.Context, evt *event.Event, sender func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { log := login.Log ctx := log.WithContext(context.TODO()) - portal, err := br.GetPortalByKey(ctx, evt.GetPortalKey()) + maybeUncertain, ok := evt.(RemoteEventWithUncertainPortalReceiver) + isUncertain := ok && maybeUncertain.PortalReceiverIsUncertain() + key := evt.GetPortalKey() + var portal *Portal + var err error + if isUncertain { + portal, err = br.GetExistingPortalByKey(ctx, key) + } else { + portal, err = br.GetPortalByKey(ctx, key) + } if err != nil { - log.Err(err).Object("portal_id", evt.GetPortalKey()). + log.Err(err).Object("portal_key", key).Bool("uncertain_receiver", isUncertain). Msg("Failed to get portal to handle remote event") return + } else if portal == nil { + log.Warn(). + Stringer("event_type", evt.GetType()). + Object("portal_key", key). + Bool("uncertain_receiver", isUncertain). + Msg("Portal not found to handle remote event") + return } // TODO put this in a better place, and maybe cache to avoid constant db queries login.MarkInPortal(ctx, portal) diff --git a/bridgev2/simplevent/meta.go b/bridgev2/simplevent/meta.go index afdbf65a..15b97b8b 100644 --- a/bridgev2/simplevent/meta.go +++ b/bridgev2/simplevent/meta.go @@ -17,18 +17,20 @@ import ( // EventMeta is a struct containing metadata fields used by most event types. type EventMeta struct { - Type bridgev2.RemoteEventType - LogContext func(c zerolog.Context) zerolog.Context - PortalKey networkid.PortalKey - Sender bridgev2.EventSender - CreatePortal bool - Timestamp time.Time + Type bridgev2.RemoteEventType + LogContext func(c zerolog.Context) zerolog.Context + PortalKey networkid.PortalKey + UncertainReceiver bool + Sender bridgev2.EventSender + CreatePortal bool + Timestamp time.Time } var ( - _ bridgev2.RemoteEvent = (*EventMeta)(nil) - _ bridgev2.RemoteEventThatMayCreatePortal = (*EventMeta)(nil) - _ bridgev2.RemoteEventWithTimestamp = (*EventMeta)(nil) + _ bridgev2.RemoteEvent = (*EventMeta)(nil) + _ bridgev2.RemoteEventWithUncertainPortalReceiver = (*EventMeta)(nil) + _ bridgev2.RemoteEventThatMayCreatePortal = (*EventMeta)(nil) + _ bridgev2.RemoteEventWithTimestamp = (*EventMeta)(nil) ) func (evt *EventMeta) AddLogContext(c zerolog.Context) zerolog.Context { @@ -42,6 +44,10 @@ func (evt *EventMeta) GetPortalKey() networkid.PortalKey { return evt.PortalKey } +func (evt *EventMeta) PortalReceiverIsUncertain() bool { + return evt.UncertainReceiver +} + func (evt *EventMeta) GetTimestamp() time.Time { if evt.Timestamp.IsZero() { return time.Now() From eb84187368b7559bd2bc1b15df526adc4fdf6e31 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 9 Aug 2024 16:53:20 +0300 Subject: [PATCH 0602/1647] bridgev2/logout: fix has receiver check --- bridgev2/userlogin.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index e279df47..d1711b31 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -437,7 +437,7 @@ func (ul *UserLogin) getLogoutAction(ctx context.Context, up *database.UserPorta if badCredentials { actionsSet = ul.Bridge.Config.CleanupOnLogout.BadCredentials } - if portal.Receiver == "" { + if portal.Receiver != "" { return portal, actionsSet.Private, "portal has receiver", nil } otherUPs, err := ul.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) From 11f93d735ec550840c2451540c135daaea6b09ae Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 9 Aug 2024 22:42:23 +0300 Subject: [PATCH 0603/1647] bridgev2: don't change power levels without permission --- bridgev2/portal.go | 38 +++++++++++++++++++++++++------------- event/powerlevels.go | 26 +++++++++++++++++++++++--- 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 778bb75c..bed794a0 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2310,38 +2310,48 @@ type PowerLevelChanges struct { Custom func(*event.PowerLevelsEventContent) bool } -func (plc *PowerLevelChanges) Apply(content *event.PowerLevelsEventContent) (changed bool) { +func allowChange(newLevel, oldLevel, actorLevel int) bool { + return newLevel <= actorLevel && oldLevel <= actorLevel +} + +func (plc *PowerLevelChanges) Apply(actor id.UserID, content *event.PowerLevelsEventContent) (changed bool) { if plc == nil || content == nil { return } for evtType, level := range plc.Events { - changed = content.EnsureEventLevel(evtType, level) || changed + changed = content.EnsureEventLevelAs(actor, evtType, level) || changed } - if plc.UsersDefault != nil { + var actorLevel int + if actor != "" { + actorLevel = content.GetUserLevel(actor) + } else { + actorLevel = (1 << 31) - 1 + } + if plc.UsersDefault != nil && allowChange(*plc.UsersDefault, content.UsersDefault, actorLevel) { changed = content.UsersDefault != *plc.UsersDefault content.UsersDefault = *plc.UsersDefault } - if plc.EventsDefault != nil { + if plc.EventsDefault != nil && allowChange(*plc.EventsDefault, content.EventsDefault, actorLevel) { changed = content.EventsDefault != *plc.EventsDefault content.EventsDefault = *plc.EventsDefault } - if plc.StateDefault != nil { + if plc.StateDefault != nil && allowChange(*plc.StateDefault, content.StateDefault(), actorLevel) { changed = content.StateDefault() != *plc.StateDefault content.StateDefaultPtr = plc.StateDefault } - if plc.Invite != nil { + if plc.Invite != nil && allowChange(*plc.Invite, content.Invite(), actorLevel) { changed = content.Invite() != *plc.Invite content.InvitePtr = plc.Invite } - if plc.Kick != nil { + if plc.Kick != nil && allowChange(*plc.Kick, content.Kick(), actorLevel) { changed = content.Kick() != *plc.Kick content.KickPtr = plc.Kick } - if plc.Ban != nil { + if plc.Ban != nil && allowChange(*plc.Ban, content.Ban(), actorLevel) { changed = content.Ban() != *plc.Ban content.BanPtr = plc.Ban } - if plc.Redact != nil { + if plc.Redact != nil && allowChange(*plc.Redact, content.Redact(), actorLevel) { changed = content.Redact() != *plc.Redact content.RedactPtr = plc.Redact } @@ -2545,7 +2555,7 @@ func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMem return } } - members.PowerLevels.Apply(pl) + members.PowerLevels.Apply("", pl) for _, member := range members.Members { if member.Membership != event.MembershipJoin && member.Membership != "" { continue @@ -2627,13 +2637,13 @@ func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberL return fmt.Errorf("failed to get current members: %w", err) } delete(currentMembers, portal.Bridge.Bot.GetMXID()) - powerChanged := members.PowerLevels.Apply(currentPower) + powerChanged := members.PowerLevels.Apply(portal.Bridge.Bot.GetMXID(), currentPower) syncUser := func(extraUserID id.UserID, member ChatMember, hasIntent bool) bool { if member.Membership == "" { member.Membership = event.MembershipJoin } if member.PowerLevel != nil { - powerChanged = currentPower.EnsureUserLevel(extraUserID, *member.PowerLevel) || powerChanged + powerChanged = currentPower.EnsureUserLevelAs(portal.Bridge.Bot.GetMXID(), extraUserID, *member.PowerLevel) || powerChanged } currentMember, ok := currentMembers[extraUserID] delete(currentMembers, extraUserID) @@ -3038,7 +3048,9 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo event.StateServerACL.Type: 100, event.StateEncryption.Type: 100, }, - Users: map[id.UserID]int{}, + Users: map[id.UserID]int{ + portal.Bridge.Bot.GetMXID(): 9001, + }, } initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(ctx, info.Members, source, powerLevels) if err != nil { diff --git a/event/powerlevels.go b/event/powerlevels.go index 1882f1e9..2f4d4573 100644 --- a/event/powerlevels.go +++ b/event/powerlevels.go @@ -134,10 +134,20 @@ func (pl *PowerLevelsEventContent) SetUserLevel(userID id.UserID, level int) { } } -func (pl *PowerLevelsEventContent) EnsureUserLevel(userID id.UserID, level int) bool { - existingLevel := pl.GetUserLevel(userID) +func (pl *PowerLevelsEventContent) EnsureUserLevel(target id.UserID, level int) bool { + return pl.EnsureUserLevelAs("", target, level) +} + +func (pl *PowerLevelsEventContent) EnsureUserLevelAs(actor, target id.UserID, level int) bool { + existingLevel := pl.GetUserLevel(target) + if actor != "" { + actorLevel := pl.GetUserLevel(actor) + if actorLevel <= existingLevel || actorLevel < level { + return false + } + } if existingLevel != level { - pl.SetUserLevel(userID, level) + pl.SetUserLevel(target, level) return true } return false @@ -170,7 +180,17 @@ func (pl *PowerLevelsEventContent) SetEventLevel(eventType Type, level int) { } func (pl *PowerLevelsEventContent) EnsureEventLevel(eventType Type, level int) bool { + return pl.EnsureEventLevelAs("", eventType, level) +} + +func (pl *PowerLevelsEventContent) EnsureEventLevelAs(actor id.UserID, eventType Type, level int) bool { existingLevel := pl.GetEventLevel(eventType) + if actor != "" { + actorLevel := pl.GetUserLevel(actor) + if existingLevel > actorLevel || level > actorLevel { + return false + } + } if existingLevel != level { pl.SetEventLevel(eventType, level) return true From 5735ea342096c8807e2157ddce53fbc350259cd7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 9 Aug 2024 23:00:34 +0300 Subject: [PATCH 0604/1647] format: generate `m.mentions` when parsing markdown --- format/htmlparser.go | 30 +++++++++++++++++++++++------- format/markdown.go | 13 ++++++++----- 2 files changed, 31 insertions(+), 12 deletions(-) diff --git a/format/htmlparser.go b/format/htmlparser.go index 8ddd8818..d099e8a7 100644 --- a/format/htmlparser.go +++ b/format/htmlparser.go @@ -15,6 +15,7 @@ import ( "golang.org/x/net/html" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -66,9 +67,13 @@ type ColorConverter func(text, fg, bg string, ctx Context) string type CodeBlockConverter func(code, language string, ctx Context) string type PillConverter func(displayname, mxid, eventID string, ctx Context) string -func DefaultPillConverter(displayname, mxid, eventID string, _ Context) string { +const ContextKeyMentions = "_mentions" + +func DefaultPillConverter(displayname, mxid, eventID string, ctx Context) string { switch { case len(mxid) == 0, mxid[0] == '@': + existingMentions, _ := ctx.ReturnData[ContextKeyMentions].([]id.UserID) + ctx.ReturnData[ContextKeyMentions] = append(existingMentions, id.UserID(mxid)) // User link, always just show the displayname return displayname case len(eventID) > 0: @@ -417,11 +422,9 @@ func HTMLToText(html string) string { }).Parse(html, NewContext(context.TODO())) } -// HTMLToMarkdown converts Matrix HTML into markdown with the default settings. -// -// Currently, the only difference to HTMLToText is how links are formatted. -func HTMLToMarkdown(html string) string { - return (&HTMLParser{ +func HTMLToMarkdownAndMentions(html string) (parsed string, mentions *event.Mentions) { + ctx := NewContext(context.TODO()) + parsed = (&HTMLParser{ TabsToSpaces: 4, Newline: "\n", HorizontalLine: "\n---\n", @@ -432,5 +435,18 @@ func HTMLToMarkdown(html string) string { } return fmt.Sprintf("[%s](%s)", text, href) }, - }).Parse(html, NewContext(context.TODO())) + }).Parse(html, ctx) + mentionList, _ := ctx.ReturnData[ContextKeyMentions].([]id.UserID) + mentions = &event.Mentions{ + UserIDs: mentionList, + } + return +} + +// HTMLToMarkdown converts Matrix HTML into markdown with the default settings. +// +// Currently, the only difference to HTMLToText is how links are formatted. +func HTMLToMarkdown(html string) string { + parsed, _ := HTMLToMarkdownAndMentions(html) + return parsed } diff --git a/format/markdown.go b/format/markdown.go index fa2a8e8a..11f9f684 100644 --- a/format/markdown.go +++ b/format/markdown.go @@ -50,18 +50,20 @@ func RenderMarkdownCustom(text string, renderer goldmark.Markdown) event.Message } func HTMLToContent(html string) event.MessageEventContent { - text := HTMLToMarkdown(html) + text, mentions := HTMLToMarkdownAndMentions(html) if html != text { return event.MessageEventContent{ FormattedBody: html, Format: event.FormatHTML, MsgType: event.MsgText, Body: text, + Mentions: mentions, } } return event.MessageEventContent{ - MsgType: event.MsgText, - Body: text, + MsgType: event.MsgText, + Body: text, + Mentions: &event.Mentions{}, } } @@ -79,8 +81,9 @@ func RenderMarkdown(text string, allowMarkdown, allowHTML bool) event.MessageEve return HTMLToContent(htmlBody) } else { return event.MessageEventContent{ - MsgType: event.MsgText, - Body: text, + MsgType: event.MsgText, + Body: text, + Mentions: &event.Mentions{}, } } } From 6d5ae8858bd96e4be6d5f59b3806d9e3b78f1574 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 9 Aug 2024 23:00:57 +0300 Subject: [PATCH 0605/1647] bridgev2: add support for starting DM by inviting ghost --- bridgev2/matrix/connector.go | 19 ++- bridgev2/matrix/intent.go | 16 ++- bridgev2/matrixinterface.go | 4 + bridgev2/matrixinvite.go | 241 +++++++++++++++++++++++++++++++++++ bridgev2/networkinterface.go | 8 ++ bridgev2/queue.go | 103 +++++++-------- 6 files changed, 333 insertions(+), 58 deletions(-) create mode 100644 bridgev2/matrixinvite.go diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 9ea1cd3a..3349cfe1 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -98,8 +98,9 @@ type Connector struct { } var ( - _ bridgev2.MatrixConnector = (*Connector)(nil) - _ bridgev2.MatrixConnectorWithServer = (*Connector)(nil) + _ bridgev2.MatrixConnector = (*Connector)(nil) + _ bridgev2.MatrixConnectorWithServer = (*Connector)(nil) + _ bridgev2.MatrixConnectorWithPostRoomBridgeHandling = (*Connector)(nil) ) func NewConnector(cfg *bridgeconfig.Config) *Connector { @@ -593,3 +594,17 @@ func (br *Connector) GenerateReactionEventID(roomID id.RoomID, targetMessage *da func (br *Connector) ServerName() string { return br.Config.Homeserver.Domain } + +func (br *Connector) HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomID) error { + if !br.Config.Encryption.Default { + return nil + } + _, err := br.Bot.SendStateEvent(ctx, roomID, event.StateEncryption, "", &event.Content{ + Parsed: br.getDefaultEncryptionEvent(), + }) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to enable encryption in newly bridged room") + return fmt.Errorf("failed to enable encryption") + } + return nil +} diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index ac771734..91d955f9 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -286,17 +286,21 @@ func (as *ASIntent) EnsureInvited(ctx context.Context, roomID id.RoomID, userID return as.Matrix.EnsureInvited(ctx, roomID, userID) } +func (br *Connector) getDefaultEncryptionEvent() *event.EncryptionEventContent { + content := &event.EncryptionEventContent{Algorithm: id.AlgorithmMegolmV1} + if rot := br.Config.Encryption.Rotation; rot.EnableCustom { + content.RotationPeriodMillis = rot.Milliseconds + content.RotationPeriodMessages = rot.Messages + } + return content +} + func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) { if as.Connector.Config.Encryption.Default { - content := &event.EncryptionEventContent{Algorithm: id.AlgorithmMegolmV1} - if rot := as.Connector.Config.Encryption.Rotation; rot.EnableCustom { - content.RotationPeriodMillis = rot.Milliseconds - content.RotationPeriodMessages = rot.Messages - } req.InitialState = append(req.InitialState, &event.Event{ Type: event.StateEncryption, Content: event.Content{ - Parsed: content, + Parsed: as.Connector.getDefaultEncryptionEvent(), }, }) } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 714d5fc7..6d30891e 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -66,6 +66,10 @@ type MatrixConnectorWithNameDisambiguation interface { IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error) } +type MatrixConnectorWithPostRoomBridgeHandling interface { + HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomID) error +} + type MatrixSendExtra struct { Timestamp time.Time MessageMeta *database.Message diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go new file mode 100644 index 00000000..740743f6 --- /dev/null +++ b/bridgev2/matrixinvite.go @@ -0,0 +1,241 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/id" +) + +func (br *Bridge) handleBotInvite(ctx context.Context, evt *event.Event, sender *User) { + log := zerolog.Ctx(ctx) + // These invites should already be rejected in QueueMatrixEvent + if !sender.Permissions.Commands { + log.Warn().Msg("Received bot invite from user without permission to send commands") + return + } + err := br.Bot.EnsureJoined(ctx, evt.RoomID) + if err != nil { + log.Err(err).Msg("Failed to accept invite to room") + return + } + log.Debug().Msg("Accepted invite to room as bot") + members, err := br.Matrix.GetMembers(ctx, evt.RoomID) + if err != nil { + log.Err(err).Msg("Failed to get members of room after accepting invite") + } + if len(members) == 2 { + var message string + if sender.ManagementRoom == "" { + message = fmt.Sprintf("Hello, I'm a %s bridge bot.\n\nUse `help` for help or `login` to log in.\n\nThis room has been marked as your management room.", br.Network.GetName().DisplayName) + sender.ManagementRoom = evt.RoomID + err = br.DB.User.Update(ctx, sender.User) + if err != nil { + log.Err(err).Msg("Failed to update user's management room in database") + } + } else { + message = fmt.Sprintf("Hello, I'm a %s bridge bot.\n\nUse `%s help` for help.", br.Network.GetName().DisplayName, br.Config.CommandPrefix) + } + _, err = br.Bot.SendMessage(ctx, evt.RoomID, event.EventMessage, &event.Content{ + Parsed: format.RenderMarkdown(message, true, false), + }, nil) + if err != nil { + log.Err(err).Msg("Failed to send welcome message to room") + } + } +} + +func sendNotice(ctx context.Context, evt *event.Event, intent MatrixAPI, message string, args ...any) { + if len(args) > 0 { + message = fmt.Sprintf(message, args...) + } + content := format.RenderMarkdown(message, true, false) + content.MsgType = event.MsgNotice + resp, err := intent.SendMessage(ctx, evt.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("room_id", evt.RoomID). + Stringer("inviter_id", evt.Sender). + Stringer("invitee_id", intent.GetMXID()). + Str("notice_text", message). + Msg("Failed to send notice") + } else { + zerolog.Ctx(ctx).Debug(). + Stringer("notice_event_id", resp.EventID). + Stringer("room_id", evt.RoomID). + Stringer("inviter_id", evt.Sender). + Stringer("invitee_id", intent.GetMXID()). + Str("notice_text", message). + Msg("Sent notice") + } +} + +func sendErrorAndLeave(ctx context.Context, evt *event.Event, intent MatrixAPI, message string, args ...any) { + sendNotice(ctx, evt, intent, message, args...) + rejectInvite(ctx, evt, intent, "") +} + +func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sender *User) { + ghostID, _ := br.Matrix.ParseGhostMXID(id.UserID(evt.GetStateKey())) + validator, ok := br.Network.(IdentifierValidatingNetwork) + if ghostID == "" || (ok && !validator.ValidateUserID(ghostID)) { + rejectInvite(ctx, evt, br.Matrix.GhostIntent(ghostID), "Malformed user ID") + return + } + log := zerolog.Ctx(ctx).With(). + Str("invitee_network_id", string(ghostID)). + Stringer("room_id", evt.RoomID). + Logger() + // TODO sort in preference order + logins := sender.GetCachedUserLogins() + if len(logins) == 0 { + rejectInvite(ctx, evt, br.Matrix.GhostIntent(ghostID), "You're not logged in") + return + } + _, ok = logins[0].Client.(IdentifierResolvingNetworkAPI) + if !ok { + rejectInvite(ctx, evt, br.Matrix.GhostIntent(ghostID), "This bridge does not support starting chats") + return + } + invitedGhost, err := br.GetGhostByID(ctx, ghostID) + if err != nil { + log.Err(err).Msg("Failed to get invited ghost") + return + } + err = invitedGhost.Intent.EnsureJoined(ctx, evt.RoomID) + if err != nil { + log.Err(err).Msg("Failed to accept invite to room") + return + } + var resp *ResolveIdentifierResponse + var sourceLogin *UserLogin + // TODO this should somehow lock incoming event processing to avoid race conditions where a new portal room is created + // between ResolveIdentifier returning and the portal MXID being updated. + for _, login := range logins { + api, ok := login.Client.(IdentifierResolvingNetworkAPI) + if !ok { + continue + } + resp, err = api.ResolveIdentifier(ctx, string(ghostID), true) + if errors.Is(err, ErrResolveIdentifierTryNext) { + log.Debug().Err(err).Str("login_id", string(login.ID)).Msg("Failed to resolve identifier, trying next login") + continue + } else if err != nil { + log.Err(err).Msg("Failed to resolve identifier") + sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to create chat") + return + } else { + sourceLogin = login + break + } + } + if resp == nil { + log.Warn().Msg("No login could resolve the identifier") + sendErrorAndLeave(ctx, evt, br.Matrix.GhostIntent(ghostID), "Failed to create chat via any login") + return + } + portal := resp.Chat.Portal + if portal == nil { + portal, err = br.GetPortalByKey(ctx, resp.Chat.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to get portal by key") + sendErrorAndLeave(ctx, evt, br.Matrix.GhostIntent(ghostID), "Failed to create portal entry") + return + } + } + err = invitedGhost.Intent.EnsureInvited(ctx, evt.RoomID, br.Bot.GetMXID()) + if err != nil { + log.Err(err).Msg("Failed to ensure bot is invited to room") + sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to invite bridge bot") + return + } + err = br.Bot.EnsureJoined(ctx, evt.RoomID) + if err != nil { + log.Err(err).Msg("Failed to ensure bot is joined to room") + sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to join with bridge bot") + return + } + + didSetPortal := portal.setMXIDToExistingRoom(evt.RoomID) + if resp.Chat.PortalInfo != nil { + portal.UpdateInfo(ctx, resp.Chat.PortalInfo, sourceLogin, nil, time.Time{}) + } + if didSetPortal { + // TODO this might become unnecessary if UpdateInfo starts taking care of it + _, err = br.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{ + Parsed: &event.ElementFunctionalMembersContent{ + ServiceMembers: []id.UserID{br.Bot.GetMXID()}, + }, + }, time.Time{}) + if err != nil { + log.Warn().Err(err).Msg("Failed to set service members in room") + } + message := "Private chat portal created" + err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent) + hasWarning := false + if err != nil { + log.Warn().Err(err).Msg("Failed to give power to bot in new DM") + message += "\n\nWarning: failed to promote bot" + hasWarning = true + } + mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling) + if ok { + err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID) + if err != nil { + if hasWarning { + message += fmt.Sprintf(", %s", err.Error()) + } else { + message += fmt.Sprintf("\n\nWarning: %s", err.Error()) + } + } + } + sendNotice(ctx, evt, invitedGhost.Intent, message) + } else { + // TODO ensure user is invited even if PortalInfo wasn't provided? + sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portal.MXID, portal.MXID.URI(br.Matrix.ServerName()).MatrixToURL()) + rejectInvite(ctx, evt, br.Bot, "") + } +} + +func (br *Bridge) givePowerToBot(ctx context.Context, roomID id.RoomID, userWithPower MatrixAPI) error { + powers, err := br.Matrix.GetPowerLevels(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get power levels: %w", err) + } + userLevel := powers.GetUserLevel(userWithPower.GetMXID()) + if powers.EnsureUserLevelAs(userWithPower.GetMXID(), br.Bot.GetMXID(), userLevel) { + _, err = userWithPower.SendState(ctx, roomID, event.StatePowerLevels, "", &event.Content{ + Parsed: powers, + }, time.Time{}) + if err != nil { + return fmt.Errorf("failed to give power to bot: %w", err) + } + } + return nil +} + +func (portal *Portal) setMXIDToExistingRoom(roomID id.RoomID) bool { + portal.roomCreateLock.Lock() + defer portal.roomCreateLock.Unlock() + if portal.MXID != "" { + return false + } + portal.MXID = roomID + portal.updateLogger() + portal.Bridge.cacheLock.Lock() + portal.Bridge.portalsByMXID[portal.MXID] = portal + portal.Bridge.cacheLock.Unlock() + return true +} diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index f6aedbc6..1a36b583 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -573,6 +573,14 @@ type CreateChatResponse struct { PortalInfo *ChatInfo } +// ErrResolveIdentifierTryNext can be returned by ResolveIdentifier to signal that the identifier is valid, +// but can't be reached by the current login, and the caller should try the next login if there are more. +// +// This should generally only be returned when resolving internal IDs (which happens when initiating chats via Matrix). +// For example, Google Messages would return this when trying to resolve another login's user ID, +// and Telegram would return this when the access hash isn't available. +var ErrResolveIdentifierTryNext = errors.New("that identifier is not available via this login") + // IdentifierResolvingNetworkAPI is an optional interface that network connectors can implement to support starting new direct chats. type IdentifierResolvingNetworkAPI interface { NetworkAPI diff --git a/bridgev2/queue.go b/bridgev2/queue.go index bcf092b8..a79d56e3 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -16,9 +16,53 @@ import ( "github.com/rs/zerolog" "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/id" ) +func rejectInvite(ctx context.Context, evt *event.Event, intent MatrixAPI, reason string) { + resp, err := intent.SendState(ctx, evt.RoomID, event.StateMember, intent.GetMXID().String(), &event.Content{ + Parsed: &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: reason, + }, + }, time.Time{}) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("room_id", evt.RoomID). + Stringer("inviter_id", evt.Sender). + Stringer("invitee_id", intent.GetMXID()). + Str("reason", reason). + Msg("Failed to reject invite") + } else { + zerolog.Ctx(ctx).Debug(). + Stringer("leave_event_id", resp.EventID). + Stringer("room_id", evt.RoomID). + Stringer("inviter_id", evt.Sender). + Stringer("invitee_id", intent.GetMXID()). + Str("reason", reason). + Msg("Rejected invite") + } +} + +func (br *Bridge) rejectInviteOnNoPermission(ctx context.Context, evt *event.Event, permType string) bool { + if evt.Type != event.StateMember || evt.Content.AsMember().Membership != event.MembershipInvite { + return false + } + userID := id.UserID(evt.GetStateKey()) + parsed, isGhost := br.Matrix.ParseGhostMXID(userID) + if userID != br.Bot.GetMXID() && !isGhost { + return false + } + var intent MatrixAPI + if userID == br.Bot.GetMXID() { + intent = br.Bot + } else { + intent = br.Matrix.GhostIntent(parsed) + } + rejectInvite(ctx, evt, intent, "You don't have permission to "+permType+" this bridge") + return true +} + func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { // TODO maybe HandleMatrixEvent would be more appropriate as this also handles bot invites and commands @@ -38,8 +82,12 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) return } else if !sender.Permissions.SendEvents { - status := WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() - br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + if !br.rejectInviteOnNoPermission(ctx, evt, "interact with") { + status := WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() + br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + } + return + } else if !sender.Permissions.Commands && br.rejectInviteOnNoPermission(ctx, evt, "send commands to") { return } } else if evt.Type.Class != event.EphemeralEventType { @@ -83,59 +131,14 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { evt: evt, sender: sender, }) + } else if evt.Type == event.StateMember && br.IsGhostMXID(id.UserID(evt.GetStateKey())) && evt.Content.AsMember().Membership == event.MembershipInvite && evt.Content.AsMember().IsDirect { + br.handleGhostDMInvite(ctx, evt, sender) } else { status := WrapErrorInStatus(ErrNoPortal) br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) } } -func (br *Bridge) handleBotInvite(ctx context.Context, evt *event.Event, sender *User) { - log := zerolog.Ctx(ctx) - if !sender.Permissions.Commands { - _, err := br.Bot.SendState(ctx, evt.RoomID, event.StateMember, br.Bot.GetMXID().String(), &event.Content{ - Parsed: &event.MemberEventContent{ - Membership: event.MembershipLeave, - Reason: "You don't have permission to send commands to this bridge", - }, - }, time.Time{}) - if err != nil { - log.Err(err).Msg("Failed to reject invite from user with no permission") - } else { - log.Debug().Msg("Rejected invite from user with no permission") - } - return - } - err := br.Bot.EnsureJoined(ctx, evt.RoomID) - if err != nil { - log.Err(err).Msg("Failed to accept invite to room") - return - } - log.Debug().Msg("Accepted invite to room as bot") - members, err := br.Matrix.GetMembers(ctx, evt.RoomID) - if err != nil { - log.Err(err).Msg("Failed to get members of room after accepting invite") - } - if len(members) == 2 { - var message string - if sender.ManagementRoom == "" { - message = fmt.Sprintf("Hello, I'm a %s bridge bot.\n\nUse `help` for help or `login` to log in.\n\nThis room has been marked as your management room.", br.Network.GetName().DisplayName) - sender.ManagementRoom = evt.RoomID - err = br.DB.User.Update(ctx, sender.User) - if err != nil { - log.Err(err).Msg("Failed to update user's management room in database") - } - } else { - message = fmt.Sprintf("Hello, I'm a %s bridge bot.\n\nUse `%s help` for help.", br.Network.GetName().DisplayName, br.Config.CommandPrefix) - } - _, err = br.Bot.SendMessage(ctx, evt.RoomID, event.EventMessage, &event.Content{ - Parsed: format.RenderMarkdown(message, true, false), - }, nil) - if err != nil { - log.Err(err).Msg("Failed to send welcome message to room") - } - } -} - func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { log := login.Log ctx := log.WithContext(context.TODO()) From b83ac7d071a96391f0ca9b22d18897b66f0e87ad Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 9 Aug 2024 23:17:28 +0300 Subject: [PATCH 0606/1647] bridgev2/matrix: add better log when sending message status fails --- bridgev2/matrix/connector.go | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 3349cfe1..f3870c1a 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -441,9 +441,14 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 log.Err(err).Msg("Failed to send message checkpoint") } if !ms.DisableMSS && br.Config.Matrix.MessageStatusEvents { - _, err = br.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, ms.ToMSSEvent(evt)) + mssEvt := ms.ToMSSEvent(evt) + _, err = br.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, mssEvt) if err != nil { - log.Err(err).Msg("Failed to send MSS event") + log.Err(err). + Stringer("room_id", evt.RoomID). + Stringer("event_id", evt.EventID). + Any("mss_content", mssEvt). + Msg("Failed to send MSS event") } } if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) { @@ -453,7 +458,11 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 } resp, err := br.Bot.SendMessageEvent(ctx, evt.RoomID, event.EventMessage, content) if err != nil { - log.Err(err).Msg("Failed to send notice event") + log.Err(err). + Stringer("room_id", evt.RoomID). + Stringer("event_id", evt.EventID). + Str("notice_message", content.Body). + Msg("Failed to send notice event") } else { return resp.EventID } @@ -461,7 +470,10 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 if ms.Status == event.MessageStatusSuccess && br.Config.Matrix.DeliveryReceipts { err = br.Bot.SendReceipt(ctx, evt.RoomID, evt.EventID, event.ReceiptTypeRead, nil) if err != nil { - log.Err(err).Msg("Failed to send Matrix delivery receipt") + log.Err(err). + Stringer("room_id", evt.RoomID). + Stringer("event_id", evt.EventID). + Msg("Failed to send Matrix delivery receipt") } } return "" @@ -596,10 +608,14 @@ func (br *Connector) ServerName() string { } func (br *Connector) HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomID) error { + _, err := br.Bot.Members(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to fetch members in newly bridged room") + } if !br.Config.Encryption.Default { return nil } - _, err := br.Bot.SendStateEvent(ctx, roomID, event.StateEncryption, "", &event.Content{ + _, err = br.Bot.SendStateEvent(ctx, roomID, event.StateEncryption, "", &event.Content{ Parsed: br.getDefaultEncryptionEvent(), }) if err != nil { From bbe62ee977a527103041a8d40c1a088c4473c0a4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 9 Aug 2024 23:35:03 +0300 Subject: [PATCH 0607/1647] bridgev2/portal: move membership handler to be with the other handlers --- bridgev2/portal.go | 128 ++++++++++++++++++++++----------------------- 1 file changed, 64 insertions(+), 64 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index bed794a0..3c2db930 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1150,6 +1150,70 @@ func handleMatrixAccountData[APIType any, ContentType any]( } } +func (portal *Portal) handleMatrixMembership( + ctx context.Context, + sender *UserLogin, + origSender *OrigSender, + evt *event.Event, +) { + api, ok := sender.Client.(MembershipHandlingNetworkAPI) + if !ok { + portal.sendErrorStatus(ctx, evt, ErrMembershipNotSupported) + return + } + log := zerolog.Ctx(ctx) + targetMXID := id.UserID(*evt.StateKey) + isSelf := sender.User.MXID == targetMXID + var err error + var targetUserLogin *UserLogin + targetGhost, err := portal.Bridge.GetGhostByMXID(ctx, targetMXID) + if err != nil { + log.Err(err).Stringer("mxid", targetMXID).Msg("Failed to get target ghost") + return + } + if targetGhost == nil { + targetUser, err := portal.Bridge.GetUserByMXID(ctx, targetMXID) + if err != nil { + log.Err(err).Stringer("mxid", targetMXID).Msg("Failed to get target user") + return + } + targetUserLogin, _, err = portal.FindPreferredLogin(ctx, targetUser, false) + if err != nil { + log.Err(err).Stringer("mxid", targetMXID).Msg("Failed to get target user login") + return + } + } + prevContent := &event.MemberEventContent{Membership: event.MembershipLeave} + if evt.Unsigned.PrevContent != nil { + _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) + prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent) + } + + content := evt.Content.AsMember() + membershipChangeType := MembershipChangeType{From: prevContent.Membership, To: content.Membership, IsSelf: isSelf} + if !portal.Bridge.Config.BridgeMatrixLeave && membershipChangeType == Leave { + log.Debug().Msg("Dropping leave event") + return + } + membershipChange := &MatrixMembershipChange{ + MatrixEventBase: MatrixEventBase[*event.MemberEventContent]{ + Event: evt, + Content: content, + Portal: portal, + OrigSender: origSender, + }, + TargetGhost: targetGhost, + TargetUserLogin: targetUserLogin, + Type: membershipChangeType, + } + _, err = api.HandleMatrixMembership(ctx, membershipChange) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix membership change") + portal.sendErrorStatus(ctx, evt, err) + return + } +} + func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.RedactionEventContent) @@ -3259,67 +3323,3 @@ func (portal *Portal) SetRelay(ctx context.Context, relay *UserLogin) error { } return nil } - -func (portal *Portal) handleMatrixMembership( - ctx context.Context, - sender *UserLogin, - origSender *OrigSender, - evt *event.Event, -) { - api, ok := sender.Client.(MembershipHandlingNetworkAPI) - if !ok { - portal.sendErrorStatus(ctx, evt, ErrMembershipNotSupported) - return - } - log := zerolog.Ctx(ctx) - targetMXID := id.UserID(*evt.StateKey) - isSelf := sender.User.MXID == targetMXID - var err error - var targetUserLogin *UserLogin - targetGhost, err := portal.Bridge.GetGhostByMXID(ctx, targetMXID) - if err != nil { - log.Err(err).Stringer("mxid", targetMXID).Msg("Failed to get target ghost") - return - } - if targetGhost == nil { - targetUser, err := portal.Bridge.GetUserByMXID(ctx, targetMXID) - if err != nil { - log.Err(err).Stringer("mxid", targetMXID).Msg("Failed to get target user") - return - } - targetUserLogin, _, err = portal.FindPreferredLogin(ctx, targetUser, false) - if err != nil { - log.Err(err).Stringer("mxid", targetMXID).Msg("Failed to get target user login") - return - } - } - prevContent := &event.MemberEventContent{Membership: event.MembershipLeave} - if evt.Unsigned.PrevContent != nil { - _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) - prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent) - } - - content := evt.Content.AsMember() - membershipChangeType := MembershipChangeType{From: prevContent.Membership, To: content.Membership, IsSelf: isSelf} - if !portal.Bridge.Config.BridgeMatrixLeave && membershipChangeType == Leave { - log.Debug().Msg("Dropping leave event") - return - } - membershipChange := &MatrixMembershipChange{ - MatrixEventBase: MatrixEventBase[*event.MemberEventContent]{ - Event: evt, - Content: content, - Portal: portal, - OrigSender: origSender, - }, - TargetGhost: targetGhost, - TargetUserLogin: targetUserLogin, - Type: membershipChangeType, - } - _, err = api.HandleMatrixMembership(ctx, membershipChange) - if err != nil { - log.Err(err).Msg("Failed to handle Matrix Membership Change") - portal.sendErrorStatus(ctx, evt, err) - return - } -} From 49b1f240edfe781c05286027173e7c34c1d050e3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 9 Aug 2024 23:46:56 +0300 Subject: [PATCH 0608/1647] format: fix tests --- format/markdown_test.go | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/format/markdown_test.go b/format/markdown_test.go index 179de6b6..10ae270c 100644 --- a/format/markdown_test.go +++ b/format/markdown_test.go @@ -17,17 +17,20 @@ import ( "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/format/mdext" + "maunium.net/go/mautrix/id" ) func TestRenderMarkdown_PlainText(t *testing.T) { content := format.RenderMarkdown("hello world", true, true) - assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world"}, content) + assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world", Mentions: &event.Mentions{}}, content) content = format.RenderMarkdown("hello world", true, false) - assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world"}, content) + assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world", Mentions: &event.Mentions{}}, content) content = format.RenderMarkdown("hello world", false, true) - assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world"}, content) + assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world", Mentions: &event.Mentions{}}, content) content = format.RenderMarkdown("hello world", false, false) - assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world"}, content) + assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "hello world", Mentions: &event.Mentions{}}, content) + content = format.RenderMarkdown(`mention`, false, false) + assert.Equal(t, event.MessageEventContent{MsgType: event.MsgText, Body: "mention", Mentions: &event.Mentions{}}, content) } func TestRenderMarkdown_EscapeHTML(t *testing.T) { @@ -37,6 +40,7 @@ func TestRenderMarkdown_EscapeHTML(t *testing.T) { Body: "hello world", Format: event.FormatHTML, FormattedBody: "<b>hello world</b>", + Mentions: &event.Mentions{}, }, content) } @@ -47,6 +51,7 @@ func TestRenderMarkdown_HTML(t *testing.T) { Body: "**hello world**", Format: event.FormatHTML, FormattedBody: "hello world", + Mentions: &event.Mentions{}, }, content) content = format.RenderMarkdown("hello world", true, true) @@ -55,6 +60,18 @@ func TestRenderMarkdown_HTML(t *testing.T) { Body: "**hello world**", Format: event.FormatHTML, FormattedBody: "hello world", + Mentions: &event.Mentions{}, + }, content) + + content = format.RenderMarkdown(`[mention](https://matrix.to/#/@user:example.com)`, true, false) + assert.Equal(t, event.MessageEventContent{ + MsgType: event.MsgText, + Body: "mention", + Format: event.FormatHTML, + FormattedBody: `mention`, + Mentions: &event.Mentions{ + UserIDs: []id.UserID{"@user:example.com"}, + }, }, content) } From da4cbb554bc82a9cbfa57703fd062bce90f5f52e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 10 Aug 2024 14:00:06 +0300 Subject: [PATCH 0609/1647] bridgev2: update portalinternal --- bridgev2/portalinternal.go | 48 +++++++++++++++++++++++++++++ bridgev2/portalinternal_generate.go | 23 +++++++++++--- 2 files changed, 66 insertions(+), 5 deletions(-) diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index 68ad6046..f7fe658a 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -49,6 +49,10 @@ func (portal *PortalInternals) SendErrorStatus(ctx context.Context, evt *event.E (*Portal)(portal).sendErrorStatus(ctx, evt, err) } +func (portal *PortalInternals) CheckConfusableName(ctx context.Context, userID id.UserID, name string) bool { + return (*Portal)(portal).checkConfusableName(ctx, userID, name) +} + func (portal *PortalInternals) HandleMatrixEvent(sender *User, evt *event.Event) { (*Portal)(portal).handleMatrixEvent(sender, evt) } @@ -89,6 +93,10 @@ func (portal *PortalInternals) HandleMatrixReaction(ctx context.Context, sender (*Portal)(portal).handleMatrixReaction(ctx, sender, evt) } +func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { + (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt) +} + func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { (*Portal)(portal).handleMatrixRedaction(ctx, sender, origSender, evt) } @@ -264,3 +272,43 @@ func (portal *PortalInternals) UnlockedDelete(ctx context.Context) error { func (portal *PortalInternals) UnlockedDeleteCache() { (*Portal)(portal).unlockedDeleteCache() } + +func (portal *PortalInternals) DoForwardBackfill(ctx context.Context, source *UserLogin, lastMessage *database.Message) { + (*Portal)(portal).doForwardBackfill(ctx, source, lastMessage) +} + +func (portal *PortalInternals) DoThreadBackfill(ctx context.Context, source *UserLogin, threadID networkid.MessageID) { + (*Portal)(portal).doThreadBackfill(ctx, source, threadID) +} + +func (portal *PortalInternals) SendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) { + (*Portal)(portal).sendBackfill(ctx, source, messages, forceForward, markRead, inThread) +} + +func (portal *PortalInternals) SendBatch(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, allowNotification bool) { + (*Portal)(portal).sendBatch(ctx, source, messages, forceForward, markRead, allowNotification) +} + +func (portal *PortalInternals) SendLegacyBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, markRead bool) { + (*Portal)(portal).sendLegacyBackfill(ctx, source, messages, markRead) +} + +func (portal *PortalInternals) UnlockedReID(ctx context.Context, target networkid.PortalKey) error { + return (*Portal)(portal).unlockedReID(ctx, target) +} + +func (portal *PortalInternals) CreateParentAndAddToSpace(ctx context.Context, source *UserLogin) { + (*Portal)(portal).createParentAndAddToSpace(ctx, source) +} + +func (portal *PortalInternals) AddToParentSpaceAndSave(ctx context.Context, save bool) { + (*Portal)(portal).addToParentSpaceAndSave(ctx, save) +} + +func (portal *PortalInternals) ToggleSpace(ctx context.Context, spaceID id.RoomID, canonical, remove bool) error { + return (*Portal)(portal).toggleSpace(ctx, spaceID, canonical, remove) +} + +func (portal *PortalInternals) SetMXIDToExistingRoom(roomID id.RoomID) bool { + return (*Portal)(portal).setMXIDToExistingRoom(roomID) +} diff --git a/bridgev2/portalinternal_generate.go b/bridgev2/portalinternal_generate.go index 4438c112..2ac6c898 100644 --- a/bridgev2/portalinternal_generate.go +++ b/bridgev2/portalinternal_generate.go @@ -67,19 +67,26 @@ func getTypeName(expr ast.Expr) string { } } +var write func(str string) +var writef func(format string, args ...any) + func main() { fset := token.NewFileSet() - f := exerrors.Must(parser.ParseFile(fset, "portal.go", nil, parser.SkipObjectResolution)) + fileNames := []string{"portal.go", "portalbackfill.go", "portalreid.go", "space.go", "matrixinvite.go"} + files := make([]*ast.File, len(fileNames)) + for i, name := range fileNames { + files[i] = exerrors.Must(parser.ParseFile(fset, name, nil, parser.SkipObjectResolution)) + } file := exerrors.Must(os.OpenFile("portalinternal.go", os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)) - write := func(str string) { + write = func(str string) { exerrors.Must(file.WriteString(str)) } - writef := func(format string, args ...any) { + writef = func(format string, args ...any) { exerrors.Must(fmt.Fprintf(file, format, args...)) } write(header) write("import (\n") - for _, i := range f.Imports { + for _, i := range files[0].Imports { write("\t") if i.Name != nil { writef("%s ", i.Name.Name) @@ -88,6 +95,13 @@ func main() { } write(")\n") write(postImportHeader) + for _, f := range files { + processFile(f) + } + exerrors.PanicIfNotNil(file.Close()) +} + +func processFile(f *ast.File) { ast.Inspect(f, func(node ast.Node) (retVal bool) { retVal = true funcDecl, ok := node.(*ast.FuncDecl) @@ -156,5 +170,4 @@ func main() { write(")\n}\n") return }) - exerrors.PanicIfNotNil(file.Close()) } From 23c5446324a8658ed318f5391b270afe5df3bf86 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 10 Aug 2024 16:08:41 +0300 Subject: [PATCH 0610/1647] bridgev2/provisioning: nest bridge state in whoami response --- bridge/status/bridgestate.go | 11 ++++++----- bridgev2/matrix/provisioning.go | 30 +++++++++++++++++++++--------- bridgev2/userlogin.go | 2 +- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/bridge/status/bridgestate.go b/bridge/status/bridgestate.go index 90f228d4..72e61415 100644 --- a/bridge/status/bridgestate.go +++ b/bridge/status/bridgestate.go @@ -18,6 +18,7 @@ import ( "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" "golang.org/x/exp/maps" "maunium.net/go/mautrix" @@ -89,10 +90,10 @@ type BridgeState struct { Error BridgeStateErrorCode `json:"error,omitempty"` Message string `json:"message,omitempty"` - UserID id.UserID `json:"user_id,omitempty"` - RemoteID string `json:"remote_id,omitempty"` - RemoteName string `json:"remote_name,omitempty"` - RemoteProfile RemoteProfile `json:"remote_profile,omitempty"` + UserID id.UserID `json:"user_id,omitempty"` + RemoteID string `json:"remote_id,omitempty"` + RemoteName string `json:"remote_name,omitempty"` + RemoteProfile *RemoteProfile `json:"remote_profile,omitempty"` Reason string `json:"reason,omitempty"` Info map[string]interface{} `json:"info,omitempty"` @@ -186,7 +187,7 @@ func (pong *BridgeState) SendHTTP(ctx context.Context, url, token string) error func (pong *BridgeState) ShouldDeduplicate(newPong *BridgeState) bool { return pong != nil && pong.StateEvent == newPong.StateEvent && - pong.RemoteProfile == newPong.RemoteProfile && + ptr.Val(pong.RemoteProfile) == ptr.Val(newPong.RemoteProfile) && pong.Error == newPong.Error && maps.EqualFunc(pong.Info, newPong.Info, reflect.DeepEqual) && pong.Timestamp.Add(time.Duration(pong.TTL)*time.Second).After(time.Now()) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index c3b8c3dc..5d9e1b8e 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -272,19 +272,25 @@ type RespWhoami struct { BridgeBot id.UserID `json:"bridge_bot"` CommandPrefix string `json:"command_prefix"` - ManagementRoom id.RoomID `json:"management_room"` + ManagementRoom id.RoomID `json:"management_room,omitempty"` Logins []RespWhoamiLogin `json:"logins"` } type RespWhoamiLogin struct { - StateEvent status.BridgeStateEvent `json:"state_event"` - StateTS jsontime.Unix `json:"state_ts"` - StateReason string `json:"state_reason,omitempty"` - StateInfo map[string]any `json:"state_info,omitempty"` - ID networkid.UserLoginID `json:"id"` - Name string `json:"name"` - Profile status.RemoteProfile `json:"profile"` - SpaceRoom id.RoomID `json:"space_room"` + // Deprecated + StateEvent status.BridgeStateEvent `json:"state_event"` + // Deprecated + StateTS jsontime.Unix `json:"state_ts"` + // Deprecated + StateReason string `json:"state_reason,omitempty"` + // Deprecated + StateInfo map[string]any `json:"state_info,omitempty"` + + State status.BridgeState `json:"state"` + ID networkid.UserLoginID `json:"id"` + Name string `json:"name"` + Profile status.RemoteProfile `json:"profile"` + SpaceRoom id.RoomID `json:"space_room,omitempty"` } func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { @@ -301,11 +307,17 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { resp.Logins = make([]RespWhoamiLogin, len(logins)) for i, login := range logins { prevState := login.BridgeState.GetPrevUnsent() + // Clear redundant fields + prevState.UserID = "" + prevState.RemoteID = "" + prevState.RemoteName = "" + prevState.RemoteProfile = nil resp.Logins[i] = RespWhoamiLogin{ StateEvent: prevState.StateEvent, StateTS: prevState.Timestamp, StateReason: prevState.Reason, StateInfo: prevState.Info, + State: prevState, ID: login.ID, Name: login.RemoteName, diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index d1711b31..017df773 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -476,7 +476,7 @@ func (ul *UserLogin) FillBridgeState(state status.BridgeState) status.BridgeStat state.UserID = ul.UserMXID state.RemoteID = string(ul.ID) state.RemoteName = ul.RemoteName - state.RemoteProfile = ul.RemoteProfile + state.RemoteProfile = &ul.RemoteProfile filler, ok := ul.Client.(status.StandaloneCustomBridgeStateFiller) if ok { return filler.FillBridgeState(state) From 5bfed60a3771b5afc36985d6f0341e8043b9e0ac Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 10 Aug 2024 17:43:53 +0300 Subject: [PATCH 0611/1647] bridgev2/matrix: add OpenAPI spec for provisioning API --- .editorconfig | 3 + bridgev2/matrix/provisioning.yaml | 826 ++++++++++++++++++++++++++++++ 2 files changed, 829 insertions(+) create mode 100644 bridgev2/matrix/provisioning.yaml diff --git a/.editorconfig b/.editorconfig index 21d312a1..1a167e7e 100644 --- a/.editorconfig +++ b/.editorconfig @@ -10,3 +10,6 @@ insert_final_newline = true [*.{yaml,yml}] indent_style = space + +[provisioning.yaml] +indent_size = 2 diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml new file mode 100644 index 00000000..e195a3bd --- /dev/null +++ b/bridgev2/matrix/provisioning.yaml @@ -0,0 +1,826 @@ +openapi: 3.1.0 +info: + title: Megabridge provisioning + description: |- + This is the provisioning API implemented in mautrix-go's bridgev2 package. + It can be used with any bridge built on that package. + license: + name: Mozilla Public License Version 2.0 + url: https://github.com/mautrix/go/blob/main/LICENSE + version: v0.19.0 +externalDocs: + description: mautrix-go godocs + url: https://pkg.go.dev/maunium.net/go/mautrix/bridgev2 +servers: +- url: http://localhost:8080/_matrix/provision +tags: +- name: auth + description: Manage your logins and log into new remote accounts +- name: snc + description: Starting new chats +paths: + /v3/whoami: + get: + tags: [ auth ] + summary: Get info about the bridge and your logins. + description: | + Get all info that is useful for presenting this bridge in a manager interface. + * Server details: remote network details, available login flows, homeserver name, bridge bot user ID, command prefix + * User details: management room ID, list of logins with current state and info + operationId: whoami + responses: + 200: + description: Successfully fetched info + content: + application/json: + schema: + $ref: '#/components/schemas/Whoami' + 401: + $ref: '#/components/responses/Unauthorized' + 500: + $ref: '#/components/responses/InternalError' + security: + - matrix_auth: [ ] + /v3/login/flows: + get: + tags: [ auth ] + summary: Get the available login flows. + operationId: getLoginFlows + responses: + 200: + description: Successfully fetched flows + content: + application/json: + schema: + type: object + properties: + flows: + type: array + items: + $ref: '#/components/schemas/LoginFlow' + 401: + $ref: '#/components/responses/Unauthorized' + 500: + $ref: '#/components/responses/InternalError' + security: + - matrix_auth: [ ] + /v3/logins: + get: + tags: [ auth ] + summary: Get the login IDs of the current user. + operationId: getLoginIDs + responses: + 200: + description: Successfully fetched list of logins + content: + application/json: + schema: + type: object + properties: + login_ids: + type: array + items: + $ref: '#/components/schemas/UserLoginID' + 401: + $ref: '#/components/responses/Unauthorized' + 500: + $ref: '#/components/responses/InternalError' + security: + - matrix_auth: [ ] + /v3/login/start/{flowID}: + post: + tags: [ auth ] + summary: Start a new login process. + operationId: startLogin + parameters: + - name: login_id + in: query + description: An existing login ID to re-login as. If this is specified and the user logs into a different account, the provided ID will be logged out. + required: false + schema: + $ref: '#/components/schemas/UserLoginID' + - name: flowID + in: path + description: The login flow ID to use. + required: true + schema: + type: string + examples: [ qr ] + responses: + 200: + description: Login successfully started + content: + application/json: + schema: + $ref: '#/components/schemas/LoginStep' + 401: + $ref: '#/components/responses/Unauthorized' + 404: + $ref: '#/components/responses/LoginNotFound' + 500: + $ref: '#/components/responses/InternalError' + security: + - matrix_auth: [ ] + /v3/login/step/{loginProcessID}/{stepID}/user_input: + post: + tags: [ auth ] + summary: Submit user input in a login process. + operationId: submitLoginStepUserInput + parameters: + - $ref: '#/components/parameters/loginProcessID' + - $ref: '#/components/parameters/stepID' + requestBody: + description: The data entered by the user + content: + application/json: + schema: + type: object + additionalProperties: + type: string + responses: + 200: + $ref: '#/components/responses/LoginStepSubmitted' + 400: + $ref: '#/components/responses/BadRequest' + 401: + $ref: '#/components/responses/Unauthorized' + 404: + $ref: '#/components/responses/LoginProcessNotFound' + 500: + $ref: '#/components/responses/InternalError' + security: + - matrix_auth: [ ] + /v3/login/step/{loginProcessID}/{stepID}/cookies: + post: + tags: [ auth ] + summary: Submit extracted cookies in a login process. + operationId: submitLoginStepCookies + parameters: + - $ref: '#/components/parameters/loginProcessID' + - $ref: '#/components/parameters/stepID' + requestBody: + description: The cookies extracted from the website + content: + application/json: + schema: + type: object + additionalProperties: + type: string + responses: + 200: + $ref: '#/components/responses/LoginStepSubmitted' + 400: + $ref: '#/components/responses/BadRequest' + 401: + $ref: '#/components/responses/Unauthorized' + 404: + $ref: '#/components/responses/LoginProcessNotFound' + 500: + $ref: '#/components/responses/InternalError' + security: + - matrix_auth: [ ] + /v3/login/step/{loginProcessID}/{stepID}/display_and_wait: + post: + tags: [ auth ] + summary: Wait for the next step after displaying data to the user. + operationId: submitLoginStepDisplayAndWait + parameters: + - $ref: '#/components/parameters/loginProcessID' + - $ref: '#/components/parameters/stepID' + responses: + 200: + $ref: '#/components/responses/LoginStepSubmitted' + 400: + $ref: '#/components/responses/BadRequest' + 401: + $ref: '#/components/responses/Unauthorized' + 404: + $ref: '#/components/responses/LoginProcessNotFound' + 500: + $ref: '#/components/responses/InternalError' + security: + - matrix_auth: [ ] + /v3/logout/{loginID}: + post: + tags: [ auth ] + summary: Log out of an existing login. + operationId: logout + parameters: + - name: loginID + in: path + description: The ID of the login to log out. Use `all` to log out of all logins. + required: true + schema: + oneOf: + - $ref: '#/components/schemas/UserLoginID' + - type: string + const: all + description: Log out of all logins + responses: + 200: + description: Login was successfully deleted + content: + application/json: + schema: + type: object + description: Empty object + 401: + $ref: '#/components/responses/Unauthorized' + 404: + $ref: '#/components/responses/LoginNotFound' + 500: + $ref: '#/components/responses/InternalError' + /v3/contacts: + get: + tags: [ snc ] + summary: Get a list of contacts. + operationId: getContacts + parameters: + - $ref: "#/components/parameters/loginID" + responses: + 200: + description: Contact list fetched successfully + 401: + $ref: '#/components/responses/Unauthorized' + 404: + $ref: '#/components/responses/LoginNotFound' + 500: + $ref: '#/components/responses/InternalError' + 501: + $ref: '#/components/responses/NotSupported' + /v3/resolve_identifier/{identifier}: + get: + tags: [ snc ] + summary: Resolve an identifier to a user on the remote network. + operationId: resolveIdentifier + parameters: + - $ref: "#/components/parameters/loginID" + - $ref: "#/components/parameters/sncIdentifier" + responses: + 200: + description: Identifier resolved successfully + content: + application/json: + schema: + $ref: '#/components/schemas/ResolvedIdentifier' + 401: + $ref: '#/components/responses/Unauthorized' + 404: + # TODO identifier not found also returns 404 + $ref: '#/components/responses/LoginNotFound' + 500: + $ref: '#/components/responses/InternalError' + 501: + $ref: '#/components/responses/NotSupported' + /v3/create_dm/{identifier}: + post: + tags: [ snc ] + summary: Create a direct chat with a user on the remote network. + operationId: createDM + parameters: + - $ref: "#/components/parameters/loginID" + - $ref: "#/components/parameters/sncIdentifier" + responses: + 200: + description: Identifier resolved successfully + content: + application/json: + schema: + allOf: + - $ref: '#/components/schemas/ResolvedIdentifier' + - required: [id, mxid, dm_room_mxid] + 401: + $ref: '#/components/responses/Unauthorized' + 404: + # TODO identifier not found also returns 404 + $ref: '#/components/responses/LoginNotFound' + 500: + $ref: '#/components/responses/InternalError' + 501: + $ref: '#/components/responses/NotSupported' + /v3/create_group: + post: + tags: [ snc ] + summary: Create a group chat on the remote network. + operationId: createGroup + parameters: + - $ref: "#/components/parameters/loginID" + responses: + 401: + $ref: '#/components/responses/Unauthorized' + 404: + $ref: '#/components/responses/LoginNotFound' + 501: + $ref: '#/components/responses/NotSupported' +components: + parameters: + sncIdentifier: + name: identifier + in: path + description: The identifier to resolve or start a chat with. + required: true + schema: + type: string + examples: + - +12345678 + - username + - meow@example.com + loginID: + name: loginID + in: query + description: An optional explicit login ID to do the action through. + required: false + schema: + $ref: '#/components/schemas/UserLoginID' + loginProcessID: + name: loginProcessID + in: path + description: The ID of the login process, as returned in the `login_id` field of the start call. + required: true + schema: + type: string + stepID: + name: stepID + in: path + description: The ID of the step being submitted, as returned in the `step_id` field of the start call or the previous submit call. + required: true + schema: + type: string + stepType: + name: stepType + in: path + description: The type of step being submitted, as returned in the `type` field of the start call or the previous submit call. + required: true + schema: + type: string + enum: [ display_and_wait, user_input, cookies ] + responses: + BadRequest: + description: Something in the request was invalid + content: + application/json: + schema: + type: object + description: A Matrix-like error response + properties: + errcode: + type: string + enum: [ M_NOT_JSON, M_BAD_STATE ] + description: A Matrix-like error code + error: + type: string + description: A human-readable error message + examples: + - Failed to decode request body + - Step type does not match + Unauthorized: + description: The request contained an invalid token + content: + application/json: + schema: + type: object + description: A Matrix-like error response + properties: + errcode: + type: string + enum: [ M_UNKNOWN_TOKEN, M_MISSING_TOKEN ] + description: A Matrix-like error code + error: + type: string + description: A human-readable error message + examples: + - Invalid auth token + - Missing auth token + InternalError: + description: An unexpected error that doesn't have special handling yet + content: + application/json: + schema: + type: object + description: A Matrix-like error response + properties: + errcode: + type: string + enum: [ M_UNKNOWN ] + description: A Matrix-like error code + error: + type: string + description: A human-readable error message + examples: + - Failed to get user + - Failed to start login + LoginProcessNotFound: + description: The specified login process ID is unknown + content: + application/json: + schema: + type: object + description: A Matrix-like error response + properties: + errcode: + type: string + enum: [ M_NOT_FOUND ] + description: A Matrix-like error code + error: + type: string + description: A human-readable error message + examples: + - Login not found + LoginNotFound: + description: When explicitly specifying an existing user login, the specified login ID is unknown + content: + application/json: + schema: + type: object + description: A Matrix-like error response + properties: + errcode: + type: string + enum: [ M_NOT_FOUND ] + description: A Matrix-like error code + error: + type: string + description: A human-readable error message + examples: + - Login not found + NotSupported: + description: The given endpoint is not supported by this network connector. + content: + application/json: + schema: + type: object + description: A Matrix-like error response + properties: + errcode: + type: string + enum: [ M_UNRECOGNIZED ] + description: A Matrix-like error code + error: + type: string + description: A human-readable error message + examples: + - This bridge does not support listing contacts + LoginStepSubmitted: + description: Step submission successful + content: + application/json: + schema: + $ref: '#/components/schemas/LoginStep' + schemas: + ResolvedIdentifier: + type: object + description: A successfully resolved identifier. + required: [id] + properties: + id: + type: string + description: The internal user ID of the resolved user. + examples: + - c443c1a2-e9f7-48aa-890c-80336c300ba9 + name: + type: string + description: The name of the user on the remote network. + avatar_url: + type: string + format: mxc + description: The avatar of the user on the remote network. + pattern: mxc://[a-zA-Z0-9.:-]+/[a-zA-Z0-9_-]+ + examples: + - mxc://t2bot.io/JYDTofsS6V9aYfUiX7JueA36 + identifiers: + type: array + description: A list of identifiers for the user on the remote network. + items: + type: string + format: uri + examples: + - "tel:+123456789" + - "mailto:foo@example.com" + - "signal:username.123" + mxid: + type: string + format: matrix_user_id + description: The Matrix user ID of the ghost representing the user. + examples: + - '@signal_c443c1a2-e9f7-48aa-890c-80336c300ba9:t2bot.io' + dm_room_mxid: + type: string + format: matrix_room_id + description: The Matrix room ID of the direct chat with the user. + examples: + - '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io' + LoginStep: + type: object + description: A step in a login process. + properties: + login_id: + type: string + description: An identifier for the current login process. Must be passed to execute more steps of the login. + type: + type: string + description: The type of login step + enum: [ display_and_wait, user_input, cookies, complete ] + step_id: + type: string + description: An unique ID identifying this step. This can be used to implement special behavior in clients. + examples: [ fi.mau.signal.qr ] + instructions: + type: string + description: Human-readable instructions for completing this login step. + examples: [ Scan the QR code ] + oneOf: + - description: Display and wait login step + required: [ type, display_and_wait ] + properties: + type: + type: string + const: display_and_wait + display_and_wait: + type: object + description: Parameters for the display and wait login step + required: [ type ] + properties: + type: + type: string + description: The type of thing to display + enum: [ qr, emoji, code, nothing ] + data: + type: string + description: The thing to display (raw data for QR, unicode emoji for emoji, plain string for code) + image_url: + type: string + description: An image containing the thing to display. If present, this is recommended over using data directly. For emojis, the URL to the canonical image representation of the emoji + - description: User input login step + required: [ type, user_input ] + properties: + type: + type: string + const: user_input + user_input: + type: object + description: Parameters for the user input login step + required: [ fields ] + properties: + fields: + type: array + description: The list of fields that the user is requested to fill. + items: + type: object + description: A field that the user can fill. + required: [ type, id, name ] + properties: + type: + type: string + description: The type of field. + enum: [ username, phone_number, email, password, 2fa_code, token ] + id: + type: string + description: The internal ID of the field. This must be used as the key in the object when submitting the data back to the bridge. + examples: [ uid, email, 2fa_password, meow ] + name: + type: string + description: The name of the field shown to the user. + examples: [ Username, Password, Phone number, 2FA code, Meow ] + description: + type: string + description: A more detailed description of the field shown to the user. + examples: + - Include the country code with a + + pattern: + type: string + format: regex + description: A regular expression that the field value must match. + - description: Cookie login step + required: [ type, cookies ] + properties: + type: + type: string + const: cookies + cookies: + type: object + description: Parameters for the cookie login step + required: [ url, fields ] + properties: + url: + type: string + format: uri + description: The URL to open when using a webview to extract cookies. + user_agent: + type: string + description: An optional user agent that the webview should use. + fields: + type: array + description: The list of cookies or other stored data that must be extracted. + items: + type: object + description: An individual cookie or other stored data item that must be extracted. + required: [ type, name ] + properties: + type: + type: string + description: The type of data to extract. + enum: [ cookie, local_storage, request_header, request_body, special ] + name: + type: string + description: The name of the item to extract. + request_url_regex: + type: string + description: For the `request_header` and `request_body` types, a regex that matches the URLs from which the values can be extracted. + cookie_domain: + type: string + description: For the `cookie` type, the domain of the cookie. + - description: Login complete + required: [ type, complete ] + properties: + type: + type: string + const: complete + complete: + type: object + description: Information about the completed login + properties: + user_login_id: + $ref: '#/components/schemas/UserLoginID' + LoginFlow: + type: object + description: An individual login flow which can be used to sign into the remote network. + required: [ name, description, id ] + properties: + name: + type: string + description: A human-readable name for the login flow. + examples: + - QR code + description: + type: string + description: A human-readable description of the login flow. + examples: + - Log in by scanning a QR code on the Signal app + id: + type: string + description: An internal ID that is passed to the /login/start call to start a login with this flow. + examples: + - qr + BridgeName: + type: object + description: Info about the network that the bridge is bridging to. + required: [ displayname, network_url, network_icon, network_id, beeper_bridge_type ] + properties: + displayname: + type: string + description: The displayname of the network. + examples: + - Signal + network_url: + type: string + description: The URL to the website of the network. + examples: + - https://signal.org + network_icon: + type: string + description: The icon of the network as a `mxc://` URI. + format: mxc + pattern: mxc://[a-zA-Z0-9.:-]+/[a-zA-Z0-9_-]+ + examples: + - mxc://maunium.net/wPJgTQbZOtpBFmDNkiNEMDUp + network_id: + type: string + description: An identifier uniquely identifying the network. + examples: + - signal + beeper_bridge_type: + type: string + description: An identifier uniquely identifying the bridge software. + examples: + - com.example.fancysignalbridge + BridgeState: + type: object + description: The connection status of an individual login + required: [ state_event, timestamp ] + properties: + state_event: + type: string + description: The current state of this login. + enum: [ "CONNECTING", "CONNECTED", "TRANSIENT_DISCONNECT", "BAD_CREDENTIALS", "UNKNOWN_ERROR" ] + timestamp: + type: number + description: The time when the state was last updated. + format: unix milliseconds + examples: + - 1723294560531 + error: + type: string + description: An error code defined by the network connector. + message: + type: string + description: A human-readable error message defined by the network connector. + reason: + type: string + description: A reason code for non-error states that aren't exactly successes either. + info: + type: object + description: Additional arbitrary info provided by the network connector. + UserLoginID: + type: string + description: The unique ID of a login. Defined by the network connector. + examples: + - bcc68892-b180-414f-9516-b4aadf7d0496 + RemoteProfile: + type: object + description: The profile info of the logged-in user on the remote network. + properties: + phone: + type: string + format: phone + description: The user's phone number + examples: + - +123456789 + email: + type: string + format: email + description: The user's email address + examples: + - foo@example.com + username: + type: string + description: The user's username + examples: + - foo.123 + name: + type: string + description: The user's displayname + examples: + - Foo Bar + avatar: + type: string + format: mxc + description: The user's avatar + pattern: mxc://[a-zA-Z0-9.:-]+/[a-zA-Z0-9_-]+ + examples: + - mxc://t2bot.io/JYDTofsS6V9aYfUiX7JueA36 + WhoamiLogin: + type: object + description: The info of an individual login + required: [ state, id, name, profile ] + properties: + state: + $ref: '#/components/schemas/BridgeState' + id: + $ref: '#/components/schemas/UserLoginID' + name: + type: string + description: A human-readable name for the login. Defined by the network connector. + examples: + - +123456789 + profile: + $ref: '#/components/schemas/RemoteProfile' + space_room: + type: string + format: matrix_room_id + description: The personal filtering space room ID for this login. + examples: + - "!X9l5njn4Mx1BpdoV8MOkyWU1:t2bot.io" + Whoami: + type: object + description: Info about the bridge and user + required: [ network, login_flows, homeserver, bridge_bot, command_prefix, logins ] + properties: + network: + $ref: '#/components/schemas/BridgeName' + login_flows: + type: array + description: The login flows that the bridge supports. + items: + $ref: '#/components/schemas/LoginFlow' + homeserver: + type: string + description: The server name the bridge is running on. + examples: + - t2bot.io + bridge_bot: + type: string + format: matrix_user_id + description: The Matrix user ID of the bridge bot. + examples: + - "@signalbot:t2bot.io" + command_prefix: + type: string + description: The command prefix used by this bridge. + examples: + - "!signal" + management_room: + type: string + format: matrix_room_id + description: The Matrix management room ID of the user who made the /whoami call. + examples: + - '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io' + logins: + type: array + description: The logins of the user who made the /whoami call + items: + $ref: '#/components/schemas/WhoamiLogin' + securitySchemes: + matrix_auth: + type: http + scheme: bearer + description: Either a Matrix access token for users on the local server, or a [Matrix OpenID token](https://spec.matrix.org/v1.11/client-server-api/#openid) for users on other servers. From 55e96279b94c8bccaca27ff138e876f5d605177b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 10 Aug 2024 22:45:58 +0300 Subject: [PATCH 0612/1647] bridgev2/messagestatus: use internal error in notice message if custom message is not set --- bridgev2/messagestatus.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 4c61a7a9..5876d812 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -225,7 +225,7 @@ func (ms *MessageStatus) ToNoticeEvent(evt *MessageStatusEventInfo) *event.Messa evtType = "redaction" } msg := ms.Message - if ms.ErrorAsMessage { + if ms.ErrorAsMessage || msg == "" { msg = ms.InternalError.Error() } messagePrefix := fmt.Sprintf("Your %s %s bridged", evtType, certainty) From fb87a6851e737291f2d863b649b0b56d370c94c8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 10 Aug 2024 23:26:05 +0300 Subject: [PATCH 0613/1647] bridgev2/provisioning: allow network connectors to return custom HTTP errors --- bridgev2/bridge.go | 3 -- bridgev2/errors.go | 82 +++++++++++++++++++++++++++++++++ bridgev2/matrix/provisioning.go | 44 ++++++++++-------- bridgev2/messagestatus.go | 28 +---------- bridgev2/networkinterface.go | 11 ----- client.go | 4 +- error.go | 4 +- 7 files changed, 114 insertions(+), 62 deletions(-) create mode 100644 bridgev2/errors.go diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 1bc364a1..aadefb0a 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -8,7 +8,6 @@ package bridgev2 import ( "context" - "errors" "fmt" "sync" @@ -22,8 +21,6 @@ import ( "maunium.net/go/mautrix/id" ) -var ErrNotLoggedIn = errors.New("not logged in") - type CommandProcessor interface { Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user *User, message string, replyTo id.EventID) } diff --git a/bridgev2/errors.go b/bridgev2/errors.go new file mode 100644 index 00000000..effc21cd --- /dev/null +++ b/bridgev2/errors.go @@ -0,0 +1,82 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgev2 + +import ( + "errors" + + "maunium.net/go/mautrix" +) + +// ErrIgnoringRemoteEvent can be returned by [RemoteMessage.ConvertMessage] or [RemoteEdit.ConvertEdit] +// to indicate that the event should be ignored after all. Handling the event will be cancelled immediately. +var ErrIgnoringRemoteEvent = errors.New("ignoring remote event") + +// ErrNoStatus can be returned by [MatrixMessageResponse.HandleEcho] to indicate that the message is still in-flight +// and a status should not be sent yet. The message will still be saved into the database. +var ErrNoStatus = errors.New("omit message status") + +// ErrResolveIdentifierTryNext can be returned by ResolveIdentifier to signal that the identifier is valid, +// but can't be reached by the current login, and the caller should try the next login if there are more. +// +// This should generally only be returned when resolving internal IDs (which happens when initiating chats via Matrix). +// For example, Google Messages would return this when trying to resolve another login's user ID, +// and Telegram would return this when the access hash isn't available. +var ErrResolveIdentifierTryNext = errors.New("that identifier is not available via this login") + +var ErrNotLoggedIn = errors.New("not logged in") + +// ErrDirectMediaNotEnabled may be returned by Matrix connectors if [MatrixConnector.GenerateContentURI] is called, +// but direct media is not enabled. +var ErrDirectMediaNotEnabled = errors.New("direct media is not enabled") + +// Common message status errors +var ( + ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() + ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage() + ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage() + ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage() + ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage() + ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage() + ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage() + ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage() + ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) + ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage() + ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) + ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) + ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) + ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) + ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) + ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) +) + +// RespError is a class of error that certain network interface methods can return to ensure that the error +// is properly translated into an HTTP error when the method is called via the provisioning API. +// +// However, unlike mautrix.RespError, this does not include the error code +// in the message shown to users when used outside HTTP contexts. +type RespError mautrix.RespError + +func (re RespError) Error() string { + return re.Err +} + +func (re RespError) Is(err error) bool { + var e2 RespError + if errors.As(err, &e2) { + return e2.Err == re.Err + } + return errors.Is(err, mautrix.RespError(re)) +} + +func WrapRespErr(err error, code string, status int) RespError { + return RespError{ErrCode: code, Err: err.Error(), StatusCode: status} +} diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 5d9e1b8e..107837ef 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -9,6 +9,7 @@ package matrix import ( "context" "encoding/json" + "errors" "fmt" "net/http" "strings" @@ -355,19 +356,13 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque ) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - Err: "Failed to create login process", - ErrCode: "M_UNKNOWN", - }) + respondMaybeCustomError(w, err, "Internal error creating login process") return } firstStep, err := login.Start(r.Context()) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to start login") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - Err: "Failed to start login", - ErrCode: "M_UNKNOWN", - }) + respondMaybeCustomError(w, err, "Internal error starting login") return } loginID := xid.New().String() @@ -419,10 +414,7 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http } if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - Err: "Failed to submit input", - ErrCode: "M_UNKNOWN", - }) + respondMaybeCustomError(w, err, "Internal error submitting input") return } login.NextStep = nextStep @@ -516,6 +508,24 @@ func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.R return userLogin } +func respondMaybeCustomError(w http.ResponseWriter, err error, message string) { + var mautrixRespErr mautrix.RespError + var bv2RespErr bridgev2.RespError + if errors.As(err, &bv2RespErr) { + mautrixRespErr = mautrix.RespError(bv2RespErr) + } else if !errors.As(err, &mautrixRespErr) { + mautrixRespErr = mautrix.RespError{ + Err: message, + ErrCode: "M_UNKNOWN", + StatusCode: http.StatusInternalServerError, + } + } + if mautrixRespErr.StatusCode == 0 { + mautrixRespErr.StatusCode = http.StatusInternalServerError + } + jsonResponse(w, mautrixRespErr.StatusCode, mautrixRespErr) +} + type RespResolveIdentifier struct { ID networkid.UserID `json:"id"` Name string `json:"name,omitempty"` @@ -541,10 +551,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - Err: fmt.Sprintf("Failed to resolve identifier: %v", err), - ErrCode: "M_UNKNOWN", - }) + respondMaybeCustomError(w, err, "Internal error resolving identifier") return } else if resp == nil { jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ @@ -617,10 +624,7 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque resp, err := api.GetContactList(r.Context()) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list") - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - Err: fmt.Sprintf("Failed to get contact list: %v", err), - ErrCode: "M_UNKNOWN", - }) + respondMaybeCustomError(w, err, "Internal error fetching contact list") return } apiResp := &RespGetContactList{ diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 5876d812..04ee8eca 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -17,32 +17,6 @@ import ( "maunium.net/go/mautrix/id" ) -var ( - ErrIgnoringRemoteEvent = errors.New("ignoring remote event") - ErrNoStatus = errors.New("omit message status") - - ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() - ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) - ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage() - ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage() - ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage() - ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage() - ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage() - ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage() - ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage() - ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) - ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage() - ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) - ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) - ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) - ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) - ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) - ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) - ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) - ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) -) - type MessageStatusEventInfo struct { RoomID id.RoomID EventID id.EventID @@ -67,6 +41,8 @@ func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo { } } +// MessageStatus represents the status of a message. It also implements the error interface to allow network connectors +// to return errors which get translated into user-friendly error messages and/or status events. type MessageStatus struct { Step status.MessageCheckpointStep RetryNum int diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 1a36b583..4b6304f5 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -8,7 +8,6 @@ package bridgev2 import ( "context" - "errors" "fmt" "strings" "time" @@ -227,8 +226,6 @@ type NetworkConnector interface { CreateLogin(ctx context.Context, user *User, flowID string) (LoginProcess, error) } -var ErrDirectMediaNotEnabled = errors.New("direct media is not enabled") - // DirectMediableNetwork is an optional interface that network connectors can implement to support direct media access. // // If the Matrix connector has direct media enabled, SetUseDirectMedia will be called @@ -573,14 +570,6 @@ type CreateChatResponse struct { PortalInfo *ChatInfo } -// ErrResolveIdentifierTryNext can be returned by ResolveIdentifier to signal that the identifier is valid, -// but can't be reached by the current login, and the caller should try the next login if there are more. -// -// This should generally only be returned when resolving internal IDs (which happens when initiating chats via Matrix). -// For example, Google Messages would return this when trying to resolve another login's user ID, -// and Telegram would return this when the access hash isn't available. -var ErrResolveIdentifierTryNext = errors.New("that identifier is not available via this login") - // IdentifierResolvingNetworkAPI is an optional interface that network connectors can implement to support starting new direct chats. type IdentifierResolvingNetworkAPI interface { NetworkAPI diff --git a/client.go b/client.go index 3c91cb65..750e3c25 100644 --- a/client.go +++ b/client.go @@ -569,7 +569,9 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { return contents, err } - respErr := &RespError{} + respErr := &RespError{ + StatusCode: res.StatusCode, + } if _ = json.Unmarshal(contents, respErr); respErr.ErrCode == "" { respErr = nil } diff --git a/error.go b/error.go index bcd568d8..30bb8b73 100644 --- a/error.go +++ b/error.go @@ -117,7 +117,9 @@ func (e HTTPError) Unwrap() error { type RespError struct { ErrCode string Err string - ExtraData map[string]interface{} + ExtraData map[string]any + + StatusCode int } func (e *RespError) UnmarshalJSON(data []byte) error { From 48b08ad8e91d8360ae80b5d3ec65626563fc14dc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 10 Aug 2024 23:30:07 +0300 Subject: [PATCH 0614/1647] errors: add status codes to predefined error variables --- bridgev2/errors.go | 6 +++++- error.go | 28 ++++++++++++++-------------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index effc21cd..a683da51 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -77,6 +77,10 @@ func (re RespError) Is(err error) bool { return errors.Is(err, mautrix.RespError(re)) } -func WrapRespErr(err error, code string, status int) RespError { +func WrapRespErrManual(err error, code string, status int) RespError { return RespError{ErrCode: code, Err: err.Error(), StatusCode: status} } + +func WrapRespErr(err error, target mautrix.RespError) RespError { + return RespError{ErrCode: target.ErrCode, Err: err.Error(), StatusCode: target.StatusCode} +} diff --git a/error.go b/error.go index 30bb8b73..acd90892 100644 --- a/error.go +++ b/error.go @@ -25,43 +25,43 @@ import ( // } var ( // Forbidden access, e.g. joining a room without permission, failed login. - MForbidden = RespError{ErrCode: "M_FORBIDDEN"} + MForbidden = RespError{ErrCode: "M_FORBIDDEN", StatusCode: http.StatusForbidden} // Unrecognized request, e.g. the endpoint does not exist or is not implemented. - MUnrecognized = RespError{ErrCode: "M_UNRECOGNIZED"} + MUnrecognized = RespError{ErrCode: "M_UNRECOGNIZED", StatusCode: http.StatusNotFound} // The access token specified was not recognised. - MUnknownToken = RespError{ErrCode: "M_UNKNOWN_TOKEN"} + MUnknownToken = RespError{ErrCode: "M_UNKNOWN_TOKEN", StatusCode: http.StatusUnauthorized} // No access token was specified for the request. - MMissingToken = RespError{ErrCode: "M_MISSING_TOKEN"} + MMissingToken = RespError{ErrCode: "M_MISSING_TOKEN", StatusCode: http.StatusUnauthorized} // Request contained valid JSON, but it was malformed in some way, e.g. missing required keys, invalid values for keys. - MBadJSON = RespError{ErrCode: "M_BAD_JSON"} + MBadJSON = RespError{ErrCode: "M_BAD_JSON", StatusCode: http.StatusBadRequest} // Request did not contain valid JSON. - MNotJSON = RespError{ErrCode: "M_NOT_JSON"} + MNotJSON = RespError{ErrCode: "M_NOT_JSON", StatusCode: http.StatusBadRequest} // No resource was found for this request. - MNotFound = RespError{ErrCode: "M_NOT_FOUND"} + MNotFound = RespError{ErrCode: "M_NOT_FOUND", StatusCode: http.StatusNotFound} // Too many requests have been sent in a short period of time. Wait a while then try again. - MLimitExceeded = RespError{ErrCode: "M_LIMIT_EXCEEDED"} + MLimitExceeded = RespError{ErrCode: "M_LIMIT_EXCEEDED", StatusCode: http.StatusTooManyRequests} // The user ID associated with the request has been deactivated. // Typically for endpoints that prove authentication, such as /login. MUserDeactivated = RespError{ErrCode: "M_USER_DEACTIVATED"} // Encountered when trying to register a user ID which has been taken. - MUserInUse = RespError{ErrCode: "M_USER_IN_USE"} + MUserInUse = RespError{ErrCode: "M_USER_IN_USE", StatusCode: http.StatusBadRequest} // Encountered when trying to register a user ID which is not valid. - MInvalidUsername = RespError{ErrCode: "M_INVALID_USERNAME"} + MInvalidUsername = RespError{ErrCode: "M_INVALID_USERNAME", StatusCode: http.StatusBadRequest} // Sent when the room alias given to the createRoom API is already in use. - MRoomInUse = RespError{ErrCode: "M_ROOM_IN_USE"} + MRoomInUse = RespError{ErrCode: "M_ROOM_IN_USE", StatusCode: http.StatusBadRequest} // The state change requested cannot be performed, such as attempting to unban a user who is not banned. MBadState = RespError{ErrCode: "M_BAD_STATE"} // The request or entity was too large. - MTooLarge = RespError{ErrCode: "M_TOO_LARGE"} + MTooLarge = RespError{ErrCode: "M_TOO_LARGE", StatusCode: http.StatusRequestEntityTooLarge} // The resource being requested is reserved by an application service, or the application service making the request has not created the resource. - MExclusive = RespError{ErrCode: "M_EXCLUSIVE"} + MExclusive = RespError{ErrCode: "M_EXCLUSIVE", StatusCode: http.StatusBadRequest} // The client's request to create a room used a room version that the server does not support. MUnsupportedRoomVersion = RespError{ErrCode: "M_UNSUPPORTED_ROOM_VERSION"} // The client attempted to join a room that has a version the server does not support. // Inspect the room_version property of the error response for the room's version. MIncompatibleRoomVersion = RespError{ErrCode: "M_INCOMPATIBLE_ROOM_VERSION"} // The client specified a parameter that has the wrong value. - MInvalidParam = RespError{ErrCode: "M_INVALID_PARAM"} + MInvalidParam = RespError{ErrCode: "M_INVALID_PARAM", StatusCode: http.StatusBadRequest} MURLNotSet = RespError{ErrCode: "M_URL_NOT_SET"} MBadStatus = RespError{ErrCode: "M_BAD_STATUS"} From 4112286f551d74c2349deee6c8323d47990d4f62 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 11 Aug 2024 21:32:37 +0300 Subject: [PATCH 0615/1647] bridgev2/matrix: update nocrypto build tag Fixes mautrix/slack#55 --- bridgev2/matrix/no-crypto.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/bridgev2/matrix/no-crypto.go b/bridgev2/matrix/no-crypto.go index 5b05272c..fe942f83 100644 --- a/bridgev2/matrix/no-crypto.go +++ b/bridgev2/matrix/no-crypto.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -10,15 +10,13 @@ package matrix import ( "errors" - - "maunium.net/go/mautrix/bridge" ) -func NewCryptoHelper(bridge *bridge.Bridge) bridge.Crypto { - if bridge.Config.Bridge.GetEncryptionConfig().Allow { - bridge.ZLog.Warn().Msg("Bridge built without end-to-bridge encryption, but encryption is enabled in config") +func NewCryptoHelper(c *Connector) Crypto { + if c.Config.Encryption.Allow { + c.Log.Warn().Msg("Bridge built without end-to-bridge encryption, but encryption is enabled in config") } else { - bridge.ZLog.Debug().Msg("Bridge built without end-to-bridge encryption") + c.Log.Debug().Msg("Bridge built without end-to-bridge encryption") } return nil } From e13771ff615ed6e4058978b97e46b23a7d925e3e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 11 Aug 2024 21:47:35 +0300 Subject: [PATCH 0616/1647] dependencies: update --- go.mod | 13 +++++++------ go.sum | 26 ++++++++++++++------------ 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/go.mod b/go.mod index e3d6cfe3..563f7e98 100644 --- a/go.mod +++ b/go.mod @@ -15,11 +15,11 @@ require ( github.com/tidwall/gjson v1.17.3 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.6.1-0.20240802175451-b430ebbffc98 + go.mau.fi/util v0.6.1-0.20240811184504-b00aa5c5af3a go.mau.fi/zeroconfig v0.1.3 - golang.org/x/crypto v0.25.0 - golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 - golang.org/x/net v0.27.0 + golang.org/x/crypto v0.26.0 + golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa + golang.org/x/net v0.28.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) @@ -29,10 +29,11 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect + github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.22.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/sys v0.24.0 // indirect + golang.org/x/text v0.17.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index ed4ca496..f7b10211 100644 --- a/go.sum +++ b/go.sum @@ -24,6 +24,8 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6 h1:DUDJI8T/9NcGbbL+AWk6vIYlmQ8ZBS8LZqVre6zbkPQ= +github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -46,24 +48,24 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.6.1-0.20240802175451-b430ebbffc98 h1:gJ0peWecBm6TtlxKFVIc1KbooXSCHtPfsfb2Eha5A0A= -go.mau.fi/util v0.6.1-0.20240802175451-b430ebbffc98/go.mod h1:S1juuPWGau2GctPY3FR/4ec/MDLhAG2QPhdnUwpzWIo= +go.mau.fi/util v0.6.1-0.20240811184504-b00aa5c5af3a h1:A6AeueGxoDjSSf2X8Tz8X9nQ2S65uYWGVwlvTZa7Bjs= +go.mau.fi/util v0.6.1-0.20240811184504-b00aa5c5af3a/go.mod h1:ZRiX8FK4CsqVINI+3YK50nHnc+dKhfTZNf38zI31S/0= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= -golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= -golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= -golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= -golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= -golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= +golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDTwym6AYBu0qzT8AcHdI= +golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= +golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From 7a2b6a93bc4719f856b5c630f69a9d00cd19f221 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 12 Aug 2024 13:53:55 +0300 Subject: [PATCH 0617/1647] bridgev2/errors: add helper to append to RespError message --- bridgev2/errors.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index a683da51..809f0fba 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -8,6 +8,7 @@ package bridgev2 import ( "errors" + "fmt" "maunium.net/go/mautrix" ) @@ -77,6 +78,11 @@ func (re RespError) Is(err error) bool { return errors.Is(err, mautrix.RespError(re)) } +func (re RespError) AppendMessage(append string, args ...any) RespError { + re.Err += fmt.Sprintf(append, args...) + return re +} + func WrapRespErrManual(err error, code string, status int) RespError { return RespError{ErrCode: code, Err: err.Error(), StatusCode: status} } From 41f0abd38a0f539b108f2287d0c69c3fe7008339 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 12 Aug 2024 14:32:42 +0300 Subject: [PATCH 0618/1647] bridgev2/matrix: ignore already in room errors when sending invites --- bridgev2/matrix/intent.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 91d955f9..e789fa75 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -10,6 +10,7 @@ import ( "context" "errors" "fmt" + "strings" "sync" "time" @@ -109,15 +110,23 @@ func (as *ASIntent) fillMemberEvent(ctx context.Context, roomID id.RoomID, userI } } -func (as *ASIntent) SendState(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) { +func (as *ASIntent) SendState(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (resp *mautrix.RespSendEvent, err error) { if eventType == event.StateMember { as.fillMemberEvent(ctx, roomID, id.UserID(stateKey), content) } if ts.IsZero() { - return as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content) + resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content) } else { - return as.Matrix.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, content, ts.UnixMilli()) + resp, err = as.Matrix.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, content, ts.UnixMilli()) } + if err != nil && eventType == event.StateMember { + var httpErr mautrix.HTTPError + if errors.As(err, &httpErr) && httpErr.RespError != nil && + (strings.Contains(httpErr.RespError.Err, "is already in the room") || strings.Contains(httpErr.RespError.Err, "is already joined to room")) { + err = as.Matrix.StateStore.SetMembership(ctx, roomID, id.UserID(stateKey), event.MembershipJoin) + } + } + return resp, err } func (as *ASIntent) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, ts time.Time) (err error) { From 091a18d448de14ce78abb08f40da15a2cea303d7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 12 Aug 2024 19:08:58 +0300 Subject: [PATCH 0619/1647] bridgev2/backfill: allow resync events to have bundled backfill data as an optimization The network connector can provide arbitrary data in RemoteChatResync events, which is passed to FetchMessages if the event triggers a backfill. The network connector can then read the data and avoid refetching those bundled messages. --- bridgev2/networkinterface.go | 21 ++++++++++++++++----- bridgev2/portal.go | 20 +++++++++++++++----- bridgev2/portalbackfill.go | 13 ++++++++++++- bridgev2/portalinternal.go | 8 ++++---- bridgev2/simplevent/chat.go | 12 +++++++++--- 5 files changed, 56 insertions(+), 18 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 4b6304f5..70854158 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -372,6 +372,12 @@ type FetchMessagesParams struct { // without any side effects, but the network connector should aim for this number. Count int + // When a forward backfill is triggered by a [RemoteChatResyncBackfillBundle], this will contain + // the bundled data returned by the event. It can be used as an optimization to avoid fetching + // messages that were already provided by the remote network, while still supporting fetching + // more messages if the limit is higher. + BundledData any + // When the messages are being fetched for a queued backfill, this is the task object. Task *database.BackfillTask } @@ -797,6 +803,16 @@ type RemoteChatResyncBackfill interface { CheckNeedsBackfill(ctx context.Context, latestMessage *database.Message) (bool, error) } +type RemoteChatResyncBackfillBundle interface { + RemoteChatResyncBackfill + GetBundledBackfillData() any +} + +type RemoteBackfill interface { + RemoteEvent + GetBackfillData(ctx context.Context, portal *Portal) (*FetchMessagesResponse, error) +} + type RemoteChatDelete interface { RemoteEvent DeleteOnlyForMe() bool @@ -931,11 +947,6 @@ type RemoteTyping interface { GetTimeout() time.Duration } -type RemoteBackfill interface { - RemoteEvent - GetBackfillData(ctx context.Context, portal *Portal) (*FetchMessagesResponse, error) -} - type TypingType int const ( diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 3c2db930..c8e4d3c9 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -302,7 +302,7 @@ func (portal *Portal) handleCreateEvent(evt *portalCreateEvent) { evt.cb(fmt.Errorf("portal creation panicked")) } }() - evt.cb(portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info)) + evt.cb(portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info, nil)) } func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) { @@ -1329,7 +1329,12 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { log.Err(err).Msg("Failed to get chat info for portal creation from chat resync event") } } - err = portal.createMatrixRoomInLoop(ctx, source, info) + bundleProvider, ok := evt.(RemoteChatResyncBackfillBundle) + var bundle any + if ok { + bundle = bundleProvider.GetBundledBackfillData() + } + err = portal.createMatrixRoomInLoop(ctx, source, info, bundle) if err != nil { log.Err(err).Msg("Failed to create portal to handle event") // TODO error @@ -2279,7 +2284,12 @@ func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLo } else if needsBackfill, err := backfillChecker.CheckNeedsBackfill(ctx, latestMessage); err != nil { log.Err(err).Msg("Failed to check if backfill is needed") } else if needsBackfill { - portal.doForwardBackfill(ctx, source, latestMessage) + bundleProvider, ok := evt.(RemoteChatResyncBackfillBundle) + var bundle any + if ok { + bundle = bundleProvider.GetBundledBackfillData() + } + portal.doForwardBackfill(ctx, source, latestMessage, bundle) } } } @@ -3077,7 +3087,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } } -func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLogin, info *ChatInfo) error { +func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLogin, info *ChatInfo, backfillBundle any) error { portal.roomCreateLock.Lock() defer portal.roomCreateLock.Unlock() if portal.MXID != "" { @@ -3259,7 +3269,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } } } - portal.doForwardBackfill(ctx, source, nil) + portal.doForwardBackfill(ctx, source, nil, backfillBundle) return nil } diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index b5ac08bb..f8f4bdb2 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -23,7 +23,7 @@ import ( "maunium.net/go/mautrix/id" ) -func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, lastMessage *database.Message) { +func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, lastMessage *database.Message, bundledData any) { log := zerolog.Ctx(ctx).With().Str("action", "forward backfill").Logger() ctx = log.WithContext(ctx) api, ok := source.Client.(BackfillingNetworkAPI) @@ -51,10 +51,14 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, Forward: true, AnchorMessage: lastMessage, Count: limit, + BundledData: bundledData, }) if err != nil { log.Err(err).Msg("Failed to fetch messages for forward backfill") return + } else if resp == nil { + log.Debug().Msg("Didn't get backfill response") + return } else if len(resp.Messages) == 0 { log.Debug().Msg("No messages to backfill") return @@ -100,6 +104,10 @@ func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin }) if err != nil { return fmt.Errorf("failed to fetch messages for backward backfill: %w", err) + } else if resp == nil { + log.Debug().Msg("Didn't get backfill response, marking task as done") + task.IsDone = true + return nil } log.Debug(). Str("new_cursor", string(resp.Cursor)). @@ -150,6 +158,9 @@ func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, t if err != nil { log.Err(err).Msg("Failed to fetch messages for thread backfill") return + } else if resp == nil { + log.Debug().Msg("Didn't get backfill response") + return } else if len(resp.Messages) == 0 { log.Debug().Msg("No messages to backfill") return diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index f7fe658a..dcd9174b 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -261,8 +261,8 @@ func (portal *PortalInternals) LockedUpdateInfoFromGhost(ctx context.Context, gh (*Portal)(portal).lockedUpdateInfoFromGhost(ctx, ghost) } -func (portal *PortalInternals) CreateMatrixRoomInLoop(ctx context.Context, source *UserLogin, info *ChatInfo) error { - return (*Portal)(portal).createMatrixRoomInLoop(ctx, source, info) +func (portal *PortalInternals) CreateMatrixRoomInLoop(ctx context.Context, source *UserLogin, info *ChatInfo, backfillBundle any) error { + return (*Portal)(portal).createMatrixRoomInLoop(ctx, source, info, backfillBundle) } func (portal *PortalInternals) UnlockedDelete(ctx context.Context) error { @@ -273,8 +273,8 @@ func (portal *PortalInternals) UnlockedDeleteCache() { (*Portal)(portal).unlockedDeleteCache() } -func (portal *PortalInternals) DoForwardBackfill(ctx context.Context, source *UserLogin, lastMessage *database.Message) { - (*Portal)(portal).doForwardBackfill(ctx, source, lastMessage) +func (portal *PortalInternals) DoForwardBackfill(ctx context.Context, source *UserLogin, lastMessage *database.Message, bundledData any) { + (*Portal)(portal).doForwardBackfill(ctx, source, lastMessage, bundledData) } func (portal *PortalInternals) DoThreadBackfill(ctx context.Context, source *UserLogin, threadID networkid.MessageID) { diff --git a/bridgev2/simplevent/chat.go b/bridgev2/simplevent/chat.go index e7b13fef..c725141b 100644 --- a/bridgev2/simplevent/chat.go +++ b/bridgev2/simplevent/chat.go @@ -30,12 +30,14 @@ type ChatResync struct { LatestMessageTS time.Time CheckNeedsBackfillFunc func(ctx context.Context, latestMessage *database.Message) (bool, error) + BundledBackfillData any } var ( - _ bridgev2.RemoteChatResync = (*ChatResync)(nil) - _ bridgev2.RemoteChatResyncWithInfo = (*ChatResync)(nil) - _ bridgev2.RemoteChatResyncBackfill = (*ChatResync)(nil) + _ bridgev2.RemoteChatResync = (*ChatResync)(nil) + _ bridgev2.RemoteChatResyncWithInfo = (*ChatResync)(nil) + _ bridgev2.RemoteChatResyncBackfill = (*ChatResync)(nil) + _ bridgev2.RemoteChatResyncBackfillBundle = (*ChatResync)(nil) ) func (evt *ChatResync) CheckNeedsBackfill(ctx context.Context, latestMessage *database.Message) (bool, error) { @@ -48,6 +50,10 @@ func (evt *ChatResync) CheckNeedsBackfill(ctx context.Context, latestMessage *da } } +func (evt *ChatResync) GetBundledBackfillData() any { + return evt.BundledBackfillData +} + func (evt *ChatResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.ChatInfo, error) { if evt.GetChatInfoFunc != nil { return evt.GetChatInfoFunc(ctx, portal) From 1e98cb6a2ec8ec633b1f0a8c544dced0e4983774 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 12 Aug 2024 22:28:40 +0300 Subject: [PATCH 0620/1647] bridgev2: add support for handling Matrix power level changes --- bridgev2/errors.go | 1 + bridgev2/networkinterface.go | 46 ++++++++- bridgev2/portal.go | 188 +++++++++++++++++++++++++++++------ 3 files changed, 199 insertions(+), 36 deletions(-) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index 809f0fba..2834b298 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -57,6 +57,7 @@ var ( ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) + ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) ) // RespError is a class of error that certain network interface methods can return to ensure that the error diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 70854158..73789be0 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -627,11 +627,22 @@ var ( Unban = MembershipChangeType{From: event.MembershipBan, To: event.MembershipLeave} ) +type GhostOrUserLogin interface { + isGhostOrUserLogin() +} + +func (*Ghost) isGhostOrUserLogin() {} +func (*UserLogin) isGhostOrUserLogin() {} + type MatrixMembershipChange struct { - MatrixEventBase[*event.MemberEventContent] - TargetGhost *Ghost + MatrixRoomMeta[*event.MemberEventContent] + Target GhostOrUserLogin + Type MembershipChangeType + + // Deprecated: Use Target instead + TargetGhost *Ghost + // Deprecated: Use Target instead TargetUserLogin *UserLogin - Type MembershipChangeType } type MembershipHandlingNetworkAPI interface { @@ -639,6 +650,35 @@ type MembershipHandlingNetworkAPI interface { HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (bool, error) } +type SinglePowerLevelChange struct { + OrigLevel int + NewLevel int + NewIsSet bool +} + +type UserPowerLevelChange struct { + Target GhostOrUserLogin + SinglePowerLevelChange +} + +type MatrixPowerLevelChange struct { + MatrixRoomMeta[*event.PowerLevelsEventContent] + Users map[id.UserID]*UserPowerLevelChange + Events map[string]*SinglePowerLevelChange + UsersDefault *SinglePowerLevelChange + EventsDefault *SinglePowerLevelChange + StateDefault *SinglePowerLevelChange + Invite *SinglePowerLevelChange + Kick *SinglePowerLevelChange + Ban *SinglePowerLevelChange + Redact *SinglePowerLevelChange +} + +type PowerLevelHandlingNetworkAPI interface { + NetworkAPI + HandleMatrixPowerLevels(ctx context.Context, msg *MatrixPowerLevelChange) (bool, error) +} + type PushType int func (pt PushType) String() string { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c8e4d3c9..192efb44 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -494,6 +494,8 @@ func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { handleMatrixAccountData(portal, ctx, login, evt, MuteHandlingNetworkAPI.HandleMute) case event.StateMember: portal.handleMatrixMembership(ctx, login, origSender, evt) + case event.StatePowerLevels: + portal.handleMatrixPowerLevels(ctx, login, origSender, evt) } } @@ -1150,58 +1152,80 @@ func handleMatrixAccountData[APIType any, ContentType any]( } } +func (portal *Portal) getTargetUser(ctx context.Context, userID id.UserID) (GhostOrUserLogin, error) { + if targetGhost, err := portal.Bridge.GetGhostByMXID(ctx, userID); err != nil { + return nil, fmt.Errorf("failed to get ghost: %w", err) + } else if targetGhost != nil { + return targetGhost, nil + } else if targetUser, err := portal.Bridge.GetUserByMXID(ctx, userID); err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } else if targetUserLogin, _, err := portal.FindPreferredLogin(ctx, targetUser, false); err != nil { + return nil, fmt.Errorf("failed to find preferred login: %w", err) + } else if targetUserLogin != nil { + return targetUserLogin, nil + } else { + // Return raw nil as a separate case to ensure a typed nil isn't returned + return nil, nil + } +} + func (portal *Portal) handleMatrixMembership( ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, ) { - api, ok := sender.Client.(MembershipHandlingNetworkAPI) - if !ok { - portal.sendErrorStatus(ctx, evt, ErrMembershipNotSupported) - return - } log := zerolog.Ctx(ctx) - targetMXID := id.UserID(*evt.StateKey) - isSelf := sender.User.MXID == targetMXID - var err error - var targetUserLogin *UserLogin - targetGhost, err := portal.Bridge.GetGhostByMXID(ctx, targetMXID) - if err != nil { - log.Err(err).Stringer("mxid", targetMXID).Msg("Failed to get target ghost") + content, ok := evt.Content.Parsed.(*event.MemberEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) return } - if targetGhost == nil { - targetUser, err := portal.Bridge.GetUserByMXID(ctx, targetMXID) - if err != nil { - log.Err(err).Stringer("mxid", targetMXID).Msg("Failed to get target user") - return - } - targetUserLogin, _, err = portal.FindPreferredLogin(ctx, targetUser, false) - if err != nil { - log.Err(err).Stringer("mxid", targetMXID).Msg("Failed to get target user login") - return - } - } prevContent := &event.MemberEventContent{Membership: event.MembershipLeave} if evt.Unsigned.PrevContent != nil { _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent) } + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c. + Str("membership", string(content.Membership)). + Str("prev_membership", string(prevContent.Membership)). + Str("target_user_id", evt.GetStateKey()) + }) + api, ok := sender.Client.(MembershipHandlingNetworkAPI) + if !ok { + portal.sendErrorStatus(ctx, evt, ErrMembershipNotSupported) + return + } + targetMXID := id.UserID(*evt.StateKey) + isSelf := sender.User.MXID == targetMXID + target, err := portal.getTargetUser(ctx, targetMXID) + if err != nil { + log.Err(err).Msg("Failed to get member event target") + portal.sendErrorStatus(ctx, evt, err) + return + } - content := evt.Content.AsMember() membershipChangeType := MembershipChangeType{From: prevContent.Membership, To: content.Membership, IsSelf: isSelf} if !portal.Bridge.Config.BridgeMatrixLeave && membershipChangeType == Leave { log.Debug().Msg("Dropping leave event") + //portal.sendErrorStatus(ctx, evt, ErrIgnoringLeaveEvent) return } + targetGhost, _ := target.(*Ghost) + targetUserLogin, _ := target.(*UserLogin) membershipChange := &MatrixMembershipChange{ - MatrixEventBase: MatrixEventBase[*event.MemberEventContent]{ - Event: evt, - Content: content, - Portal: portal, - OrigSender: origSender, + MatrixRoomMeta: MatrixRoomMeta[*event.MemberEventContent]{ + MatrixEventBase: MatrixEventBase[*event.MemberEventContent]{ + Event: evt, + Content: content, + Portal: portal, + OrigSender: origSender, + }, + PrevContent: prevContent, }, + Target: target, TargetGhost: targetGhost, TargetUserLogin: targetUserLogin, Type: membershipChangeType, @@ -1214,6 +1238,101 @@ func (portal *Portal) handleMatrixMembership( } } +func makePLChange(old, new int, newIsSet bool) *SinglePowerLevelChange { + if old == new { + return nil + } + return &SinglePowerLevelChange{OrigLevel: old, NewLevel: new, NewIsSet: newIsSet} +} + +func getUniqueKeys[Key comparable, Value any](maps ...map[Key]Value) map[Key]struct{} { + unique := make(map[Key]struct{}) + for _, m := range maps { + for k := range m { + unique[k] = struct{}{} + } + } + return unique +} + +func (portal *Portal) handleMatrixPowerLevels( + ctx context.Context, + sender *UserLogin, + origSender *OrigSender, + evt *event.Event, +) { + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + return + } + api, ok := sender.Client.(PowerLevelHandlingNetworkAPI) + if !ok { + portal.sendErrorStatus(ctx, evt, ErrPowerLevelsNotSupported) + return + } + prevContent := &event.PowerLevelsEventContent{} + if evt.Unsigned.PrevContent != nil { + _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) + prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.PowerLevelsEventContent) + } + + plChange := &MatrixPowerLevelChange{ + MatrixRoomMeta: MatrixRoomMeta[*event.PowerLevelsEventContent]{ + MatrixEventBase: MatrixEventBase[*event.PowerLevelsEventContent]{ + Event: evt, + Content: content, + Portal: portal, + OrigSender: origSender, + }, + PrevContent: prevContent, + }, + Users: make(map[id.UserID]*UserPowerLevelChange), + Events: make(map[string]*SinglePowerLevelChange), + UsersDefault: makePLChange(prevContent.UsersDefault, content.UsersDefault, true), + EventsDefault: makePLChange(prevContent.EventsDefault, content.EventsDefault, true), + StateDefault: makePLChange(prevContent.StateDefault(), content.StateDefault(), content.StateDefaultPtr != nil), + Invite: makePLChange(prevContent.Invite(), content.Invite(), content.InvitePtr != nil), + Kick: makePLChange(prevContent.Kick(), content.Kick(), content.KickPtr != nil), + Ban: makePLChange(prevContent.Ban(), content.Ban(), content.BanPtr != nil), + Redact: makePLChange(prevContent.Redact(), content.Redact(), content.RedactPtr != nil), + } + for eventType := range getUniqueKeys(content.Events, prevContent.Events) { + newLevel, hasNewLevel := content.Events[eventType] + if !hasNewLevel { + // TODO this doesn't handle state events properly + newLevel = content.EventsDefault + } + if change := makePLChange(prevContent.Events[eventType], newLevel, hasNewLevel); change != nil { + plChange.Events[eventType] = change + } + } + for user := range getUniqueKeys(content.Users, prevContent.Users) { + _, hasNewLevel := content.Users[user] + change := makePLChange(prevContent.GetUserLevel(user), content.GetUserLevel(user), hasNewLevel) + if change == nil { + continue + } + target, err := portal.getTargetUser(ctx, user) + if err != nil { + log.Err(err).Stringer("target_user_id", user).Msg("Failed to get user for power level change") + } else { + plChange.Users[user] = &UserPowerLevelChange{ + Target: target, + SinglePowerLevelChange: *change, + } + } + } + _, err := api.HandleMatrixPowerLevels(ctx, plChange) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix power level change") + portal.sendErrorStatus(ctx, evt, err) + return + } +} + func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.RedactionEventContent) @@ -2368,10 +2487,10 @@ type ChatMemberList struct { OtherUserID networkid.UserID Members []ChatMember - PowerLevels *PowerLevelChanges + PowerLevels *PowerLevelOverrides } -type PowerLevelChanges struct { +type PowerLevelOverrides struct { Events map[event.Type]int UsersDefault *int EventsDefault *int @@ -2384,11 +2503,14 @@ type PowerLevelChanges struct { Custom func(*event.PowerLevelsEventContent) bool } +// Deprecated: renamed to PowerLevelOverrides +type PowerLevelChanges = PowerLevelOverrides + func allowChange(newLevel, oldLevel, actorLevel int) bool { return newLevel <= actorLevel && oldLevel <= actorLevel } -func (plc *PowerLevelChanges) Apply(actor id.UserID, content *event.PowerLevelsEventContent) (changed bool) { +func (plc *PowerLevelOverrides) Apply(actor id.UserID, content *event.PowerLevelsEventContent) (changed bool) { if plc == nil || content == nil { return } From 8f5b3ac66f8080b6ad3f68cad3285f66b5dcdf13 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 13 Aug 2024 15:05:28 +0300 Subject: [PATCH 0621/1647] bridgev2/portal: fix last read TS when bridging read receipts from Matrix --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 192efb44..c3e6f334 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -552,8 +552,8 @@ func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, e if userPortal == nil { userPortal = database.UserPortalFor(login.UserLogin, portal.PortalKey) } else { - userPortal = userPortal.CopyWithoutValues() evt.LastRead = userPortal.LastRead + userPortal = userPortal.CopyWithoutValues() } evt.ExactMessage, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, eventID) if err != nil { From b899ef773f8f9ca4ff216d89abe0cae278bbc64a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 13 Aug 2024 15:05:44 +0300 Subject: [PATCH 0622/1647] bridgev2/matrix: enable handling power level events --- bridgev2/matrix/connector.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index f3870c1a..ee4be11c 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -133,6 +133,7 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.EventProcessor.On(event.EventRedaction, br.handleRoomEvent) br.EventProcessor.On(event.EventEncrypted, br.handleEncryptedEvent) br.EventProcessor.On(event.StateMember, br.handleRoomEvent) + br.EventProcessor.On(event.StatePowerLevels, br.handleRoomEvent) br.EventProcessor.On(event.StateRoomName, br.handleRoomEvent) br.EventProcessor.On(event.StateRoomAvatar, br.handleRoomEvent) br.EventProcessor.On(event.StateTopic, br.handleRoomEvent) From 47bdb8b2f6f59e0826401aa2060d8475731047e6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 13 Aug 2024 15:06:01 +0300 Subject: [PATCH 0623/1647] bridgev2/matrix: send message status in background --- bridgev2/matrix/connector.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index ee4be11c..c2b3f7cd 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -429,7 +429,7 @@ func (br *Connector) SendBridgeStatus(ctx context.Context, state *status.BridgeS } func (br *Connector) SendMessageStatus(ctx context.Context, ms *bridgev2.MessageStatus, evt *bridgev2.MessageStatusEventInfo) { - br.internalSendMessageStatus(ctx, ms, evt, "") + go br.internalSendMessageStatus(ctx, ms, evt, "") } func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2.MessageStatus, evt *bridgev2.MessageStatusEventInfo, editEvent id.EventID) id.EventID { From 2883ac81726fef357572918ae543ee89474ab79b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 13 Aug 2024 15:48:25 +0300 Subject: [PATCH 0624/1647] bridgev2/backfill: ignore messages with no parts --- bridgev2/portalbackfill.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index f8f4bdb2..4a977a3e 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -256,6 +256,9 @@ func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages extras := make([]*MatrixSendExtra, 0, len(messages)) var disappearingMessages []*database.DisappearingMessage for _, msg := range messages { + if len(msg.Parts) == 0 { + continue + } intent := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, msg.ID, msg.ReplyTo, msg.ThreadRoot, true) if threadRoot != nil && prevThreadEvents[*msg.ThreadRoot] != "" { From 81be525ab6aaec5f833e4afebca60c28fd4477b6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 13 Aug 2024 15:55:04 +0300 Subject: [PATCH 0625/1647] bridgev2/backfill: include portal key in backfill queue context logger --- bridgev2/backfillqueue.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go index 63a01f68..6b9242c2 100644 --- a/bridgev2/backfillqueue.go +++ b/bridgev2/backfillqueue.go @@ -88,6 +88,7 @@ func (br *Bridge) doBackfillTask(ctx context.Context, task *database.BackfillTas Object("portal_key", task.PortalKey). Str("login_id", string(task.UserLoginID)). Logger() + ctx = log.WithContext(ctx) err := br.DB.BackfillTask.MarkDispatched(ctx, task) if err != nil { log.Err(err).Msg("Failed to mark backfill task as dispatched") From 755ba0f7d6dea9aa9bdd07f437cacd3eec1c6ea8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 13 Aug 2024 16:23:39 +0300 Subject: [PATCH 0626/1647] bridgev2: add config option to only bridge tag/mute on room create --- bridgev2/bridgeconfig/config.go | 4 +++- bridgev2/bridgeconfig/upgrade.go | 3 +++ bridgev2/matrix/mxmain/example-config.yaml | 9 +++++++++ bridgev2/portal.go | 10 +++++----- bridgev2/portalinternal.go | 12 ++++++++++-- 5 files changed, 30 insertions(+), 8 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index b7a0ff37..ab97c891 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -59,11 +59,13 @@ type BridgeConfig struct { CommandPrefix string `yaml:"command_prefix"` PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` + BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` + TagOnlyOnCreate bool `yaml:"tag_only_on_create"` + MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` Relay RelayConfig `yaml:"relay"` Permissions PermissionConfig `yaml:"permissions"` Backfill BackfillConfig `yaml:"backfill"` - BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` } type MatrixConfig struct { diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 6650b9d1..4eff205d 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -25,6 +25,9 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Str, "bridge", "command_prefix") helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") helper.Copy(up.Bool, "bridge", "private_chat_portal_meta") + helper.Copy(up.Bool, "bridge", "bridge_matrix_leave") + helper.Copy(up.Bool, "bridge", "tag_only_on_create") + helper.Copy(up.Bool, "bridge", "mute_only_on_create") helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "relayed") diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index deabe3c8..06ed010f 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -8,6 +8,15 @@ bridge: # This is only necessary when using clients that don't support MSC4171. private_chat_portal_meta: false + # Should leaving Matrix rooms be bridged as leaving groups on the remote network? + bridge_matrix_leave: false + # Should room tags only be synced when creating the portal? Tags mean things like favorite/pin and archive/low priority. + # Tags currently can't be synced back to the remote network, so a continuous sync means tagging from Matrix will be undone. + tag_only_on_create: true + # Should room mute status only be synced when creating the portal? + # Like tags, mutes can't currently be synced back to the remote network. + mute_only_on_create: true + # What should be done to portal rooms when a user logs out or is logged out? # Permitted values: # nothing - Do nothing, let the user stay in the portals diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c3e6f334..cbcc325e 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2976,7 +2976,7 @@ func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberL return nil } -func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPortalInfo, source *UserLogin) { +func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPortalInfo, source *UserLogin, didJustCreate bool) { if portal.MXID == "" { return } @@ -2996,13 +2996,13 @@ func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPo if info == nil { return } - if info.MutedUntil != nil { + if info.MutedUntil != nil && (didJustCreate || !portal.Bridge.Config.MuteOnlyOnCreate) { err := dp.MuteRoom(ctx, portal.MXID, *info.MutedUntil) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to mute room") } } - if info.Tag != nil { + if info.Tag != nil && (didJustCreate || !portal.Bridge.Config.TagOnlyOnCreate) { err := dp.TagRoom(ctx, portal.MXID, *info.Tag, *info.Tag != "") if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to tag room") @@ -3159,7 +3159,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us changed = portal.UpdateInfoFromGhost(ctx, nil) || changed if source != nil { source.MarkInPortal(ctx, portal) - portal.updateUserLocalInfo(ctx, info.UserLocal, source) + portal.updateUserLocalInfo(ctx, info.UserLocal, source, false) } if info.CanBackfill && source != nil { err := portal.Bridge.DB.BackfillTask.EnsureExists(ctx, portal.PortalKey, source.ID) @@ -3360,7 +3360,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo go portal.createParentAndAddToSpace(ctx, source) } } - portal.updateUserLocalInfo(ctx, info.UserLocal, source) + portal.updateUserLocalInfo(ctx, info.UserLocal, source, true) if !autoJoinInvites { if info.Members == nil { dp := source.User.DoublePuppet(ctx) diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index dcd9174b..ffad2ac9 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -93,10 +93,18 @@ func (portal *PortalInternals) HandleMatrixReaction(ctx context.Context, sender (*Portal)(portal).handleMatrixReaction(ctx, sender, evt) } +func (portal *PortalInternals) GetTargetUser(ctx context.Context, userID id.UserID) (GhostOrUserLogin, error) { + return (*Portal)(portal).getTargetUser(ctx, userID) +} + func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt) } +func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { + (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt) +} + func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { (*Portal)(portal).handleMatrixRedaction(ctx, sender, origSender, evt) } @@ -249,8 +257,8 @@ func (portal *PortalInternals) SyncParticipants(ctx context.Context, members *Ch return (*Portal)(portal).syncParticipants(ctx, members, source, sender, ts) } -func (portal *PortalInternals) UpdateUserLocalInfo(ctx context.Context, info *UserLocalPortalInfo, source *UserLogin) { - (*Portal)(portal).updateUserLocalInfo(ctx, info, source) +func (portal *PortalInternals) UpdateUserLocalInfo(ctx context.Context, info *UserLocalPortalInfo, source *UserLogin, didJustCreate bool) { + (*Portal)(portal).updateUserLocalInfo(ctx, info, source, didJustCreate) } func (portal *PortalInternals) UpdateParent(ctx context.Context, newParent networkid.PortalID, source *UserLogin) bool { From 0df3f47c273b6a5756253727319166935ff94460 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 13 Aug 2024 18:53:31 +0300 Subject: [PATCH 0627/1647] bridgev2: add hack for messages sent by nobody --- bridgev2/portal.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index cbcc325e..0e337c07 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1560,6 +1560,9 @@ func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventS func (portal *Portal) GetIntentFor(ctx context.Context, sender EventSender, source *UserLogin, evtType RemoteEventType) MatrixAPI { intent, _ := portal.getIntentAndUserMXIDFor(ctx, sender, source, nil, evtType) if intent == nil { + // TODO this is very hacky - we should either insert an empty ghost row automatically + // (and not fetch it at runtime) or make the message sender column nullable. + portal.Bridge.GetGhostByID(ctx, "") intent = portal.Bridge.Bot } return intent @@ -1642,6 +1645,10 @@ func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.Mes } if part.DontBridge { dbMessage.SetFakeMXID() + logContext(log.Debug()). + Stringer("event_id", dbMessage.MXID). + Str("part_id", string(part.ID)). + Msg("Not bridging message part with DontBridge flag to Matrix") } else { resp, err := intent.SendMessage(ctx, portal.MXID, part.Type, &event.Content{ Parsed: part.Content, From 4f5dea4ca22fea17cf6508ce45272e770a6958de Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 13 Aug 2024 19:43:51 +0300 Subject: [PATCH 0628/1647] bridgev2/networkinterface: allow network connector to customize m.bridge data --- bridgev2/networkinterface.go | 5 +++++ bridgev2/portal.go | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 73789be0..041af68a 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -243,6 +243,11 @@ type IdentifierValidatingNetwork interface { ValidateUserID(id networkid.UserID) bool } +type PortalBridgeInfoFillingNetwork interface { + NetworkConnector + FillPortalBridgeInfo(portal *Portal, content *event.BridgeEventContent) +} + // ConfigValidatingNetwork is an optional interface that network connectors can implement to validate config fields // before the bridge is started. // diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 0e337c07..c2347239 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2692,6 +2692,10 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { // TODO external URL? } } + filler, ok := portal.Bridge.Network.(PortalBridgeInfoFillingNetwork) + if ok { + filler.FillPortalBridgeInfo(portal, &bridgeInfo) + } // TODO use something globally unique instead of bridge ID? // maybe ask the matrix connector to use serverName+appserviceID+bridgeID stateKey := string(portal.BridgeID) From 9b517179dc9054dfd1d41d189aa353082a412748 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 14 Aug 2024 14:15:56 +0300 Subject: [PATCH 0629/1647] federation: fix signing requests with no body --- federation/client.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/federation/client.go b/federation/client.go index dc8c139c..d49ba560 100644 --- a/federation/client.go +++ b/federation/client.go @@ -332,12 +332,16 @@ func (c *Client) compileRequest(ctx context.Context, params RequestParams) (*htt Message: "client not configured for authentication", } } + var contentAny any + if reqJSON != nil { + contentAny = reqJSON + } auth, err := (&signableRequest{ Method: req.Method, URI: reqURL.RequestURI(), Origin: c.ServerName, Destination: params.ServerName, - Content: reqJSON, + Content: contentAny, }).Sign(c.Key) if err != nil { return nil, mautrix.HTTPError{ From d791a70ade4a427229d6010da667e590dcb34eb6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 14 Aug 2024 14:16:09 +0300 Subject: [PATCH 0630/1647] federation: add query profile and directory wrappers --- federation/client.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/federation/client.go b/federation/client.go index d49ba560..098df095 100644 --- a/federation/client.go +++ b/federation/client.go @@ -198,6 +198,28 @@ func (c *Client) TimestampToEvent(ctx context.Context, serverName string, roomID return } +func (c *Client) QueryProfile(ctx context.Context, serverName string, userID id.UserID) (resp *mautrix.RespUserProfile, err error) { + err = c.Query(ctx, serverName, "profile", url.Values{"user_id": {userID.String()}}, &resp) + return +} + +func (c *Client) QueryDirectory(ctx context.Context, serverName string, roomAlias id.RoomAlias) (resp *mautrix.RespAliasResolve, err error) { + err = c.Query(ctx, serverName, "directory", url.Values{"room_alias": {roomAlias.String()}}, &resp) + return +} + +func (c *Client) Query(ctx context.Context, serverName, queryType string, queryParams url.Values, respStruct any) (err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: http.MethodGet, + Path: URLPath{"v1", "query", queryType}, + Query: queryParams, + Authenticate: true, + ResponseJSON: respStruct, + }) + return +} + type RespOpenIDUserInfo struct { Sub id.UserID `json:"sub"` } From 169e2db7ed7cf6986624ea3ad5e55209c4c3fa45 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 14 Aug 2024 14:48:06 +0300 Subject: [PATCH 0631/1647] bridgev2/backfill: send thread messages in same batch as root --- bridgev2/portalbackfill.go | 337 ++++++++++++++++++++++--------------- 1 file changed, 199 insertions(+), 138 deletions(-) diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 4a977a3e..af71faa6 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -137,40 +137,49 @@ func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin return nil } +func (portal *Portal) fetchThreadBackfill(ctx context.Context, source *UserLogin, anchor *database.Message) *FetchMessagesResponse { + log := zerolog.Ctx(ctx) + resp, err := source.Client.(BackfillingNetworkAPI).FetchMessages(ctx, FetchMessagesParams{ + Portal: portal, + ThreadRoot: anchor.ID, + Forward: true, + AnchorMessage: anchor, + Count: portal.Bridge.Config.Backfill.Threads.MaxInitialMessages, + }) + if err != nil { + log.Err(err).Msg("Failed to fetch messages for thread backfill") + return nil + } else if resp == nil { + log.Debug().Msg("Didn't get backfill response") + return nil + } else if len(resp.Messages) == 0 { + log.Debug().Msg("No messages to backfill") + return nil + } + resp.Messages = cutoffMessages(log, resp.Messages, true, anchor) + if len(resp.Messages) == 0 { + log.Warn().Msg("No messages left to backfill after cutting off old messages") + return nil + } + return resp +} + func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, threadID networkid.MessageID) { log := zerolog.Ctx(ctx).With(). Str("subaction", "thread backfill"). Str("thread_id", string(threadID)). Logger() + ctx = log.WithContext(ctx) log.Info().Msg("Backfilling thread inside other backfill") anchorMessage, err := portal.Bridge.DB.Message.GetLastThreadMessage(ctx, portal.PortalKey, threadID) if err != nil { log.Err(err).Msg("Failed to get last thread message") return } - resp, err := source.Client.(BackfillingNetworkAPI).FetchMessages(ctx, FetchMessagesParams{ - Portal: portal, - ThreadRoot: threadID, - Forward: true, - AnchorMessage: anchorMessage, - Count: portal.Bridge.Config.Backfill.Threads.MaxInitialMessages, - }) - if err != nil { - log.Err(err).Msg("Failed to fetch messages for thread backfill") - return - } else if resp == nil { - log.Debug().Msg("Didn't get backfill response") - return - } else if len(resp.Messages) == 0 { - log.Debug().Msg("No messages to backfill") - return + resp := portal.fetchThreadBackfill(ctx, source, anchorMessage) + if resp != nil { + portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, true) } - resp.Messages = cutoffMessages(&log, resp.Messages, true, anchorMessage) - if len(resp.Messages) == 0 { - log.Warn().Msg("No messages left to backfill after cutting off old messages") - return - } - portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, true) } func cutoffMessages(log *zerolog.Logger, messages []*BackfillMessage, forward bool, lastMessage *database.Message) []*BackfillMessage { @@ -226,12 +235,12 @@ func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messa Bool("mark_read_past_threshold", forceMarkRead). Msg("Sending backfill messages") if canBatchSend { - portal.sendBatch(ctx, source, messages, forceForward, markRead || forceMarkRead, !inThread) + portal.sendBatch(ctx, source, messages, forceForward, markRead || forceMarkRead, inThread) } else { portal.sendLegacyBackfill(ctx, source, messages, markRead || forceMarkRead) } zerolog.Ctx(ctx).Debug().Msg("Backfill finished") - if !inThread && portal.Bridge.Config.Backfill.Threads.MaxInitialMessages > 0 { + if !canBatchSend && !inThread && portal.Bridge.Config.Backfill.Threads.MaxInitialMessages > 0 { for _, msg := range messages { if msg.ShouldBackfillThread { portal.doThreadBackfill(ctx, source, msg.ID) @@ -240,137 +249,176 @@ func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messa } } -func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, allowNotification bool) { +type compileBatchOutput struct { + PrevThreadEvents map[networkid.MessageID]id.EventID + + Events []*event.Event + Extras []*MatrixSendExtra + + DBMessages []*database.Message + DBReactions []*database.Reaction + Disappear []*database.DisappearingMessage +} + +func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin, msg *BackfillMessage, out *compileBatchOutput, inThread bool) { + if len(msg.Parts) == 0 { + return + } + intent := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) + replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, msg.ID, msg.ReplyTo, msg.ThreadRoot, true) + if threadRoot != nil && out.PrevThreadEvents[*msg.ThreadRoot] != "" { + prevThreadEvent.MXID = out.PrevThreadEvents[*msg.ThreadRoot] + } + var partIDs []networkid.PartID + partMap := make(map[networkid.PartID]*database.Message, len(msg.Parts)) + var firstPart *database.Message + for _, part := range msg.Parts { + partIDs = append(partIDs, part.ID) + portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) + evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID) + out.Events = append(out.Events, &event.Event{ + Sender: intent.GetMXID(), + Type: part.Type, + Timestamp: msg.Timestamp.UnixMilli(), + ID: evtID, + RoomID: portal.MXID, + Content: event.Content{ + Parsed: part.Content, + Raw: part.Extra, + }, + }) + dbMessage := &database.Message{ + ID: msg.ID, + PartID: part.ID, + MXID: evtID, + Room: portal.PortalKey, + SenderID: msg.Sender.Sender, + SenderMXID: intent.GetMXID(), + Timestamp: msg.Timestamp, + ThreadRoot: ptr.Val(msg.ThreadRoot), + ReplyTo: ptr.Val(msg.ReplyTo), + Metadata: part.DBMetadata, + } + if firstPart == nil { + firstPart = dbMessage + } + partMap[part.ID] = dbMessage + out.Extras = append(out.Extras, &MatrixSendExtra{MessageMeta: dbMessage}) + out.DBMessages = append(out.DBMessages, dbMessage) + if prevThreadEvent != nil { + prevThreadEvent.MXID = evtID + out.PrevThreadEvents[*msg.ThreadRoot] = evtID + } + if msg.Disappear.Type != database.DisappearingTypeNone { + if msg.Disappear.Type == database.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() { + msg.Disappear.DisappearAt = msg.Timestamp.Add(msg.Disappear.Timer) + } + out.Disappear = append(out.Disappear, &database.DisappearingMessage{ + RoomID: portal.MXID, + EventID: evtID, + DisappearingSetting: msg.Disappear, + }) + } + } + slices.Sort(partIDs) + for _, reaction := range msg.Reactions { + reactionIntent := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReactionRemove) + if reaction.TargetPart == nil { + reaction.TargetPart = &partIDs[0] + } + if reaction.Timestamp.IsZero() { + reaction.Timestamp = msg.Timestamp.Add(10 * time.Millisecond) + } + targetPart, ok := partMap[*reaction.TargetPart] + if !ok { + // TODO warning log and/or skip reaction? + } + reactionMXID := portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, targetPart, reaction.Sender.Sender, reaction.EmojiID) + dbReaction := &database.Reaction{ + Room: portal.PortalKey, + MessageID: msg.ID, + MessagePartID: *reaction.TargetPart, + SenderID: reaction.Sender.Sender, + EmojiID: reaction.EmojiID, + MXID: reactionMXID, + Timestamp: reaction.Timestamp, + Emoji: reaction.Emoji, + Metadata: reaction.DBMetadata, + } + out.Events = append(out.Events, &event.Event{ + Sender: reactionIntent.GetMXID(), + Type: event.EventReaction, + Timestamp: reaction.Timestamp.UnixMilli(), + ID: reactionMXID, + RoomID: portal.MXID, + Content: event.Content{ + Parsed: &event.ReactionEventContent{ + RelatesTo: event.RelatesTo{ + Type: event.RelAnnotation, + EventID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, *reaction.TargetPart), + Key: variationselector.Add(reaction.Emoji), + }, + }, + Raw: reaction.ExtraContent, + }, + }) + out.DBReactions = append(out.DBReactions, dbReaction) + out.Extras = append(out.Extras, &MatrixSendExtra{ReactionMeta: dbReaction}) + } + if firstPart != nil && !inThread && portal.Bridge.Config.Backfill.Threads.MaxInitialMessages > 0 { + portal.fetchThreadInsideBatch(ctx, source, firstPart, out) + } +} + +func (portal *Portal) fetchThreadInsideBatch(ctx context.Context, source *UserLogin, dbMsg *database.Message, out *compileBatchOutput) { + log := zerolog.Ctx(ctx).With(). + Str("subaction", "thread backfill in batch"). + Str("thread_id", string(dbMsg.ID)). + Logger() + ctx = log.WithContext(ctx) + resp := portal.fetchThreadBackfill(ctx, source, dbMsg) + if resp != nil { + for _, msg := range resp.Messages { + portal.compileBatchMessage(ctx, source, msg, out, true) + } + } +} + +func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) { + out := &compileBatchOutput{ + PrevThreadEvents: make(map[networkid.MessageID]id.EventID), + Events: make([]*event.Event, 0, len(messages)), + Extras: make([]*MatrixSendExtra, 0, len(messages)), + DBMessages: make([]*database.Message, 0, len(messages)), + DBReactions: make([]*database.Reaction, 0), + Disappear: make([]*database.DisappearingMessage, 0), + } + for _, msg := range messages { + portal.compileBatchMessage(ctx, source, msg, out, inThread) + } req := &mautrix.ReqBeeperBatchSend{ ForwardIfNoMessages: !forceForward, Forward: forceForward, - Events: make([]*event.Event, 0, len(messages)), - SendNotification: !markRead && forceForward && allowNotification, + SendNotification: !markRead && forceForward && !inThread, + Events: out.Events, } if markRead { req.MarkReadBy = source.UserMXID } - prevThreadEvents := make(map[networkid.MessageID]id.EventID) - dbMessages := make([]*database.Message, 0, len(messages)) - dbReactions := make([]*database.Reaction, 0) - extras := make([]*MatrixSendExtra, 0, len(messages)) - var disappearingMessages []*database.DisappearingMessage - for _, msg := range messages { - if len(msg.Parts) == 0 { - continue - } - intent := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) - replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, msg.ID, msg.ReplyTo, msg.ThreadRoot, true) - if threadRoot != nil && prevThreadEvents[*msg.ThreadRoot] != "" { - prevThreadEvent.MXID = prevThreadEvents[*msg.ThreadRoot] - } - var partIDs []networkid.PartID - partMap := make(map[networkid.PartID]*database.Message, len(msg.Parts)) - for _, part := range msg.Parts { - partIDs = append(partIDs, part.ID) - portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) - evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID) - req.Events = append(req.Events, &event.Event{ - Sender: intent.GetMXID(), - Type: part.Type, - Timestamp: msg.Timestamp.UnixMilli(), - ID: evtID, - RoomID: portal.MXID, - Content: event.Content{ - Parsed: part.Content, - Raw: part.Extra, - }, - }) - dbMessage := &database.Message{ - ID: msg.ID, - PartID: part.ID, - MXID: evtID, - Room: portal.PortalKey, - SenderID: msg.Sender.Sender, - SenderMXID: intent.GetMXID(), - Timestamp: msg.Timestamp, - ThreadRoot: ptr.Val(msg.ThreadRoot), - ReplyTo: ptr.Val(msg.ReplyTo), - Metadata: part.DBMetadata, - } - partMap[part.ID] = dbMessage - extras = append(extras, &MatrixSendExtra{MessageMeta: dbMessage}) - dbMessages = append(dbMessages, dbMessage) - if prevThreadEvent != nil { - prevThreadEvent.MXID = evtID - prevThreadEvents[*msg.ThreadRoot] = evtID - } - if msg.Disappear.Type != database.DisappearingTypeNone { - if msg.Disappear.Type == database.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() { - msg.Disappear.DisappearAt = msg.Timestamp.Add(msg.Disappear.Timer) - } - disappearingMessages = append(disappearingMessages, &database.DisappearingMessage{ - RoomID: portal.MXID, - EventID: evtID, - DisappearingSetting: msg.Disappear, - }) - } - } - slices.Sort(partIDs) - for _, reaction := range msg.Reactions { - reactionIntent := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReactionRemove) - if reaction.TargetPart == nil { - reaction.TargetPart = &partIDs[0] - } - if reaction.Timestamp.IsZero() { - reaction.Timestamp = msg.Timestamp.Add(10 * time.Millisecond) - } - targetPart, ok := partMap[*reaction.TargetPart] - if !ok { - // TODO warning log and/or skip reaction? - } - reactionMXID := portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, targetPart, reaction.Sender.Sender, reaction.EmojiID) - dbReaction := &database.Reaction{ - Room: portal.PortalKey, - MessageID: msg.ID, - MessagePartID: *reaction.TargetPart, - SenderID: reaction.Sender.Sender, - EmojiID: reaction.EmojiID, - MXID: reactionMXID, - Timestamp: reaction.Timestamp, - Emoji: reaction.Emoji, - Metadata: reaction.DBMetadata, - } - req.Events = append(req.Events, &event.Event{ - Sender: reactionIntent.GetMXID(), - Type: event.EventReaction, - Timestamp: reaction.Timestamp.UnixMilli(), - ID: reactionMXID, - RoomID: portal.MXID, - Content: event.Content{ - Parsed: &event.ReactionEventContent{ - RelatesTo: event.RelatesTo{ - Type: event.RelAnnotation, - EventID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, *reaction.TargetPart), - Key: variationselector.Add(reaction.Emoji), - }, - }, - Raw: reaction.ExtraContent, - }, - }) - dbReactions = append(dbReactions, dbReaction) - extras = append(extras, &MatrixSendExtra{ReactionMeta: dbReaction}) - } - } - _, err := portal.Bridge.Matrix.BatchSend(ctx, portal.MXID, req, extras) + _, err := portal.Bridge.Matrix.BatchSend(ctx, portal.MXID, req, out.Extras) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to send backfill messages") } - if len(disappearingMessages) > 0 { + if len(out.Disappear) > 0 { // TODO mass insert disappearing messages go func() { - for _, msg := range disappearingMessages { + for _, msg := range out.Disappear { portal.Bridge.DisappearLoop.Add(ctx, msg) } }() } // TODO mass insert db messages - for _, msg := range dbMessages { + for _, msg := range out.DBMessages { err = portal.Bridge.DB.Message.Insert(ctx, msg) if err != nil { zerolog.Ctx(ctx).Err(err). @@ -382,6 +430,19 @@ func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages Msg("Failed to insert backfilled message to database") } } + // TODO mass insert db reactions + for _, react := range out.DBReactions { + err = portal.Bridge.DB.Reaction.Upsert(ctx, react) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Str("message_id", string(react.MessageID)). + Str("part_id", string(react.MessagePartID)). + Str("sender_id", string(react.SenderID)). + Str("portal_id", string(react.Room.ID)). + Str("portal_receiver", string(react.Room.Receiver)). + Msg("Failed to insert backfilled reaction to database") + } + } } func (portal *Portal) sendLegacyBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, markRead bool) { From 3cc3f95017b33ccdbd24573b345bc04bdb836d8c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 14 Aug 2024 18:58:28 +0300 Subject: [PATCH 0632/1647] bridgev2: ensure m.mentions is always set --- bridgev2/networkinterface.go | 3 +++ bridgev2/portal.go | 14 +++++++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 041af68a..4fd8a603 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -138,6 +138,9 @@ type ConvertedEditPart struct { Extra map[string]any // TopLevelExtra can be used to specify custom fields at the top level of the content rather than inside `m.new_content`. TopLevelExtra map[string]any + // NewMentions can be used to specify new mentions that should ping the users again. + // Mentions inside the edited content will not ping. + NewMentions *event.Mentions DontBridge bool } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c2347239..667f98d2 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1609,14 +1609,14 @@ func (portal *Portal) getRelationMeta(ctx context.Context, currentMsg networkid. } func (portal *Portal) applyRelationMeta(content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { + if content.Mentions == nil { + content.Mentions = &event.Mentions{} + } if threadRoot != nil && prevThreadEvent != nil { content.GetRelatesTo().SetThread(threadRoot.MXID, prevThreadEvent.MXID) } if replyTo != nil { content.GetRelatesTo().SetReplyTo(replyTo.MXID) - if content.Mentions == nil { - content.Mentions = &event.Mentions{} - } content.Mentions.Add(replyTo.SenderMXID) } } @@ -1869,12 +1869,20 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e func (portal *Portal) sendConvertedEdit(ctx context.Context, targetID networkid.MessageID, senderID networkid.UserID, converted *ConvertedEdit, intent MatrixAPI, ts time.Time) { log := zerolog.Ctx(ctx) for _, part := range converted.ModifiedParts { + if part.Content.Mentions == nil { + part.Content.Mentions = &event.Mentions{} + } overrideMXID := true if part.Part.Room != portal.PortalKey { part.Part.Room = portal.PortalKey } else if !part.Part.HasFakeMXID() { part.Content.SetEdit(part.Part.MXID) overrideMXID = false + if part.NewMentions != nil { + part.Content.Mentions = part.NewMentions + } else { + part.Content.Mentions = &event.Mentions{} + } } if part.TopLevelExtra == nil { part.TopLevelExtra = make(map[string]any) From 0ea4b348fe2bde9db220576b07ff117f4556939b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 14 Aug 2024 21:28:11 +0300 Subject: [PATCH 0633/1647] bridgev2/backfill: add missing check to thread backfill --- bridgev2/portalbackfill.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index af71faa6..b814b1f3 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -365,7 +365,7 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin out.DBReactions = append(out.DBReactions, dbReaction) out.Extras = append(out.Extras, &MatrixSendExtra{ReactionMeta: dbReaction}) } - if firstPart != nil && !inThread && portal.Bridge.Config.Backfill.Threads.MaxInitialMessages > 0 { + if firstPart != nil && !inThread && portal.Bridge.Config.Backfill.Threads.MaxInitialMessages > 0 && msg.ShouldBackfillThread { portal.fetchThreadInsideBatch(ctx, source, firstPart, out) } } From e521ab675cac6eb54f525262a589148ba66e54bd Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 14 Aug 2024 22:46:17 +0300 Subject: [PATCH 0634/1647] crypto/keysharing: improve rejection message when recipient tracking is enabled --- crypto/keysharing.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/crypto/keysharing.go b/crypto/keysharing.go index f4407cbb..4d3b6f7e 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -33,6 +33,7 @@ var ( KeyShareRejectBlacklisted = KeyShareRejection{event.RoomKeyWithheldBlacklisted, "You have been blacklisted by this device"} KeyShareRejectUnverified = KeyShareRejection{event.RoomKeyWithheldUnverified, "This device does not share keys to unverified devices"} KeyShareRejectOtherUser = KeyShareRejection{event.RoomKeyWithheldUnauthorized, "This device does not share keys to other users"} + KeyShareRejectNotRecipient = KeyShareRejection{event.RoomKeyWithheldUnauthorized, "You were not in the original recipient list for that session, or that session didn't originate from this device"} KeyShareRejectUnavailable = KeyShareRejection{event.RoomKeyWithheldUnavailable, "Requested session ID not found on this device"} KeyShareRejectInternalError = KeyShareRejection{event.RoomKeyWithheldUnavailable, "An internal error occurred while trying to share the requested session"} ) @@ -249,13 +250,18 @@ func (mach *OlmMachine) sendToOneDevice(ctx context.Context, userID id.UserID, d func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Device, evt event.RequestedKeyInfo) *KeyShareRejection { log := mach.machOrContextLog(ctx) if mach.Client.UserID != device.UserID { + if mach.DisableSharedGroupSessionTracking { + log.Debug().Msg("Rejecting key request from another user as recipient list tracking is disabled") + return &KeyShareRejectOtherUser + } isShared, err := mach.CryptoStore.IsOutboundGroupSessionShared(ctx, device.UserID, device.IdentityKey, evt.SessionID) if err != nil { log.Err(err).Msg("Rejecting key request due to internal error when checking session sharing") return &KeyShareRejectNoResponse } else if !isShared { + // TODO differentiate session not shared with requester vs session not created by this device? log.Debug().Msg("Rejecting key request for unshared session") - return &KeyShareRejectOtherUser + return &KeyShareRejectNotRecipient } log.Debug().Msg("Accepting key request for shared session") return nil From 59efa808cb1f836c56abbf04a60f7b641ac1b322 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 15 Aug 2024 13:12:01 +0300 Subject: [PATCH 0635/1647] main: drop support for Go 1.21 --- .github/workflows/go.yml | 10 +++++----- go.mod | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 66f6aee1..8197d3a7 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -12,7 +12,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.22" + go-version: "1.23" cache: true - name: Install libolm @@ -34,8 +34,8 @@ jobs: strategy: fail-fast: false matrix: - go-version: ["1.21", "1.22"] - name: Build (${{ matrix.go-version == '1.22' && 'latest' || 'old' }}, libolm) + go-version: ["1.22", "1.23"] + name: Build (${{ matrix.go-version == '1.23' && 'latest' || 'old' }}, libolm) steps: - uses: actions/checkout@v4 @@ -65,8 +65,8 @@ jobs: strategy: fail-fast: false matrix: - go-version: ["1.21", "1.22"] - name: Build (${{ matrix.go-version == '1.22' && 'latest' || 'old' }}, goolm) + go-version: ["1.22", "1.23"] + name: Build (${{ matrix.go-version == '1.23' && 'latest' || 'old' }}, goolm) steps: - uses: actions/checkout@v4 diff --git a/go.mod b/go.mod index 563f7e98..86d8e5aa 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module maunium.net/go/mautrix -go 1.21 +go 1.22 require ( github.com/chzyer/readline v1.5.1 From 2d3862a65f236c9db834ac86f1fa34dcda049722 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 15 Aug 2024 13:40:14 +0300 Subject: [PATCH 0636/1647] bridgev2/portal: clear userportal cache when deleting portal --- bridgev2/portal.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 667f98d2..f4e69644 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3415,6 +3415,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } func (portal *Portal) Delete(ctx context.Context) error { + portal.removeInPortalCache(ctx) err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) if err != nil { return err @@ -3440,6 +3441,27 @@ func (portal *Portal) RemoveMXID(ctx context.Context) error { return nil } +func (portal *Portal) removeInPortalCache(ctx context.Context) { + if portal.Receiver != "" { + login := portal.Bridge.GetCachedUserLoginByID(portal.Receiver) + if login != nil { + login.inPortalCache.Remove(portal.PortalKey) + } + return + } + userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get user logins in portal to remove user portal cache") + } else { + for _, up := range userPortals { + login := portal.Bridge.GetCachedUserLoginByID(up.LoginID) + if login != nil { + login.inPortalCache.Remove(portal.PortalKey) + } + } + } +} + func (portal *Portal) unlockedDelete(ctx context.Context) error { // TODO delete child portals? err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) From 9e031496a01cffb6f04ff76aa7c2351e71d185e6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 15 Aug 2024 13:42:11 +0300 Subject: [PATCH 0637/1647] dependencies: update --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 86d8e5aa..c3d9cbe9 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/tidwall/gjson v1.17.3 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.6.1-0.20240811184504-b00aa5c5af3a + go.mau.fi/util v0.6.1-0.20240815104112-77362c9b05dd go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.26.0 golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa @@ -29,7 +29,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect - github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6 // indirect + github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect diff --git a/go.sum b/go.sum index f7b10211..5cf874b0 100644 --- a/go.sum +++ b/go.sum @@ -24,8 +24,8 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6 h1:DUDJI8T/9NcGbbL+AWk6vIYlmQ8ZBS8LZqVre6zbkPQ= -github.com/petermattis/goid v0.0.0-20240716203034-badd1c0974d6/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 h1:Dx7Ovyv/SFnMFw3fD4oEoeorXc6saIiQ23LrGLth0Gw= +github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -48,8 +48,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.6.1-0.20240811184504-b00aa5c5af3a h1:A6AeueGxoDjSSf2X8Tz8X9nQ2S65uYWGVwlvTZa7Bjs= -go.mau.fi/util v0.6.1-0.20240811184504-b00aa5c5af3a/go.mod h1:ZRiX8FK4CsqVINI+3YK50nHnc+dKhfTZNf38zI31S/0= +go.mau.fi/util v0.6.1-0.20240815104112-77362c9b05dd h1:rDu4R3axIbNzv/c2Izri81dMcDXOklQil7tUGivvfNs= +go.mau.fi/util v0.6.1-0.20240815104112-77362c9b05dd/go.mod h1:bWYreIoTULL/UiRbZdfddPh7uWDFW5yX4YCv5FB0eE0= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= From a3f3445657ce7d52f35c02b91111a1aea497a8ed Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 15 Aug 2024 14:32:39 +0300 Subject: [PATCH 0638/1647] changelog: update --- CHANGELOG.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index feaedcd1..6c1d56bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,22 @@ +## v0.20.0 (unreleased) + +* Bumped minimum Go version to 1.22. +* *(bridgev2)* Added more features and fixed bugs. +* *(event)* Added types for [MSC4144]: Per-message profiles. +* *(federation)* Added implementation of server name resolution and a basic + client for making federation requests. +* *(crypto/ssss)* Changed recovery key/passphrase verify functions to take the + key ID as a parameter to ensure it's correctly set even if the key metadata + wasn't fetched via `GetKeyData`. +* *(format/mdext)* Added goldmark extensions for single-character bold, italic + and strikethrough parsing (as in `*foo*` -> **foo**, `_foo_` -> _foo_ and + `~foo~` -> ~~foo~~) +* *(format)* Changed `RenderMarkdown` et al to always include `m.mentions` in + returned content. The mention list is filled with matrix.to URLs from the + input by default. + +[MSC4144]: https://github.com/matrix-org/matrix-spec-proposals/pull/4144 + ## v0.19.0 (2024-07-16) * Renamed `master` branch to `main`. From cb8583825d2397b5829cd3acec4f25d805f7fdc2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 15 Aug 2024 15:11:03 +0300 Subject: [PATCH 0639/1647] bridgev2: allow network connectors to provide message stream order --- bridgev2/matrixinterface.go | 2 ++ bridgev2/networkinterface.go | 16 +++++++++----- bridgev2/portal.go | 42 ++++++++++++++++++++++++++++++------ bridgev2/portalbackfill.go | 6 +++--- bridgev2/portalinternal.go | 28 ++++++++++++++++++------ bridgev2/simplevent/meta.go | 6 ++++++ 6 files changed, 79 insertions(+), 21 deletions(-) diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 6d30891e..0628f16d 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -74,6 +74,8 @@ type MatrixSendExtra struct { Timestamp time.Time MessageMeta *database.Message ReactionMeta *database.Reaction + StreamOrder int64 + PartIndex int } type MatrixAPI interface { diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 4fd8a603..b68ad0c9 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -413,11 +413,12 @@ type BackfillReaction struct { // BackfillMessage is an individual message in a history pagination request. type BackfillMessage struct { *ConvertedMessage - Sender EventSender - ID networkid.MessageID - TxnID networkid.TransactionID - Timestamp time.Time - Reactions []*BackfillReaction + Sender EventSender + ID networkid.MessageID + TxnID networkid.TransactionID + Timestamp time.Time + StreamOrder int64 + Reactions []*BackfillReaction ShouldBackfillThread bool LastThreadMessage networkid.MessageID @@ -891,6 +892,11 @@ type RemoteEventWithTimestamp interface { GetTimestamp() time.Time } +type RemoteEventWithStreamOrder interface { + RemoteEvent + GetStreamOrder() int64 +} + type RemoteMessage interface { RemoteEvent GetID() networkid.MessageID diff --git a/bridgev2/portal.go b/bridgev2/portal.go index f4e69644..fc047beb 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1621,7 +1621,16 @@ func (portal *Portal) applyRelationMeta(content *event.MessageEventContent, repl } } -func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.MessageID, intent MatrixAPI, senderID networkid.UserID, converted *ConvertedMessage, ts time.Time, logContext func(*zerolog.Event) *zerolog.Event) []*database.Message { +func (portal *Portal) sendConvertedMessage( + ctx context.Context, + id networkid.MessageID, + intent MatrixAPI, + senderID networkid.UserID, + converted *ConvertedMessage, + ts time.Time, + streamOrder int64, + logContext func(*zerolog.Event) *zerolog.Event, +) []*database.Message { if logContext == nil { logContext = func(e *zerolog.Event) *zerolog.Event { return e @@ -1630,7 +1639,7 @@ func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.Mes log := zerolog.Ctx(ctx) replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, id, converted.ReplyTo, converted.ThreadRoot, false) output := make([]*database.Message, 0, len(converted.Parts)) - for _, part := range converted.Parts { + for i, part := range converted.Parts { portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) dbMessage := &database.Message{ ID: id, @@ -1656,6 +1665,8 @@ func (portal *Portal) sendConvertedMessage(ctx context.Context, id networkid.Mes }, &MatrixSendExtra{ Timestamp: ts, MessageMeta: dbMessage, + StreamOrder: streamOrder, + PartIndex: i, }) if err != nil { logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to send message part to Matrix") @@ -1807,7 +1818,7 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin } return } - portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender().Sender, converted, ts, nil) + portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender().Sender, converted, ts, getStreamOrder(evt), nil) } func (portal *Portal) sendRemoteErrorNotice(ctx context.Context, intent MatrixAPI, err error, ts time.Time, evtTypeName string) { @@ -1863,12 +1874,20 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e portal.sendRemoteErrorNotice(ctx, intent, err, ts, "edit") return } - portal.sendConvertedEdit(ctx, existing[0].ID, evt.GetSender().Sender, converted, intent, ts) + portal.sendConvertedEdit(ctx, existing[0].ID, evt.GetSender().Sender, converted, intent, ts, getStreamOrder(evt)) } -func (portal *Portal) sendConvertedEdit(ctx context.Context, targetID networkid.MessageID, senderID networkid.UserID, converted *ConvertedEdit, intent MatrixAPI, ts time.Time) { +func (portal *Portal) sendConvertedEdit( + ctx context.Context, + targetID networkid.MessageID, + senderID networkid.UserID, + converted *ConvertedEdit, + intent MatrixAPI, + ts time.Time, + streamOrder int64, +) { log := zerolog.Ctx(ctx) - for _, part := range converted.ModifiedParts { + for i, part := range converted.ModifiedParts { if part.Content.Mentions == nil { part.Content.Mentions = &event.Mentions{} } @@ -1898,6 +1917,8 @@ func (portal *Portal) sendConvertedEdit(ctx context.Context, targetID networkid. resp, err := intent.SendMessage(ctx, portal.MXID, part.Type, wrappedContent, &MatrixSendExtra{ Timestamp: ts, MessageMeta: part.Part, + StreamOrder: streamOrder, + PartIndex: i, }) if err != nil { log.Err(err).Stringer("part_mxid", part.Part.MXID).Msg("Failed to edit message part") @@ -1941,7 +1962,7 @@ func (portal *Portal) sendConvertedEdit(ctx context.Context, targetID networkid. } } if converted.AddedParts != nil { - portal.sendConvertedMessage(ctx, targetID, intent, senderID, converted.AddedParts, ts, nil) + portal.sendConvertedMessage(ctx, targetID, intent, senderID, converted.AddedParts, ts, streamOrder, nil) } } @@ -1968,6 +1989,13 @@ func getEventTS(evt RemoteEvent) time.Time { return time.Now() } +func getStreamOrder(evt RemoteEvent) int64 { + if streamProvider, ok := evt.(RemoteEventWithStreamOrder); ok { + return streamProvider.GetStreamOrder() + } + return 0 +} + func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *UserLogin, evt RemoteReactionSync) { log := zerolog.Ctx(ctx) eventTS := getEventTS(evt) diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index b814b1f3..1bff29c1 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -272,7 +272,7 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin var partIDs []networkid.PartID partMap := make(map[networkid.PartID]*database.Message, len(msg.Parts)) var firstPart *database.Message - for _, part := range msg.Parts { + for i, part := range msg.Parts { partIDs = append(partIDs, part.ID) portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID) @@ -303,7 +303,7 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin firstPart = dbMessage } partMap[part.ID] = dbMessage - out.Extras = append(out.Extras, &MatrixSendExtra{MessageMeta: dbMessage}) + out.Extras = append(out.Extras, &MatrixSendExtra{MessageMeta: dbMessage, StreamOrder: msg.StreamOrder, PartIndex: i}) out.DBMessages = append(out.DBMessages, dbMessage) if prevThreadEvent != nil { prevThreadEvent.MXID = evtID @@ -449,7 +449,7 @@ func (portal *Portal) sendLegacyBackfill(ctx context.Context, source *UserLogin, var lastPart id.EventID for _, msg := range messages { intent := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) - dbMessages := portal.sendConvertedMessage(ctx, msg.ID, intent, msg.Sender.Sender, msg.ConvertedMessage, msg.Timestamp, func(z *zerolog.Event) *zerolog.Event { + dbMessages := portal.sendConvertedMessage(ctx, msg.ID, intent, msg.Sender.Sender, msg.ConvertedMessage, msg.Timestamp, msg.StreamOrder, func(z *zerolog.Event) *zerolog.Event { return z. Str("message_id", string(msg.ID)). Any("sender_id", msg.Sender). diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index ffad2ac9..1ee793a9 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -125,8 +125,8 @@ func (portal *PortalInternals) ApplyRelationMeta(content *event.MessageEventCont (*Portal)(portal).applyRelationMeta(content, replyTo, threadRoot, prevThreadEvent) } -func (portal *PortalInternals) SendConvertedMessage(ctx context.Context, id networkid.MessageID, intent MatrixAPI, senderID networkid.UserID, converted *ConvertedMessage, ts time.Time, logContext func(*zerolog.Event) *zerolog.Event) []*database.Message { - return (*Portal)(portal).sendConvertedMessage(ctx, id, intent, senderID, converted, ts, logContext) +func (portal *PortalInternals) SendConvertedMessage(ctx context.Context, id networkid.MessageID, intent MatrixAPI, senderID networkid.UserID, converted *ConvertedMessage, ts time.Time, streamOrder int64, logContext func(*zerolog.Event) *zerolog.Event) []*database.Message { + return (*Portal)(portal).sendConvertedMessage(ctx, id, intent, senderID, converted, ts, streamOrder, logContext) } func (portal *PortalInternals) CheckPendingMessage(ctx context.Context, evt RemoteMessage) (bool, *database.Message) { @@ -149,8 +149,8 @@ func (portal *PortalInternals) HandleRemoteEdit(ctx context.Context, source *Use (*Portal)(portal).handleRemoteEdit(ctx, source, evt) } -func (portal *PortalInternals) SendConvertedEdit(ctx context.Context, targetID networkid.MessageID, senderID networkid.UserID, converted *ConvertedEdit, intent MatrixAPI, ts time.Time) { - (*Portal)(portal).sendConvertedEdit(ctx, targetID, senderID, converted, intent, ts) +func (portal *PortalInternals) SendConvertedEdit(ctx context.Context, targetID networkid.MessageID, senderID networkid.UserID, converted *ConvertedEdit, intent MatrixAPI, ts time.Time, streamOrder int64) { + (*Portal)(portal).sendConvertedEdit(ctx, targetID, senderID, converted, intent, ts, streamOrder) } func (portal *PortalInternals) GetTargetMessagePart(ctx context.Context, evt RemoteEventWithTargetMessage) (*database.Message, error) { @@ -273,6 +273,10 @@ func (portal *PortalInternals) CreateMatrixRoomInLoop(ctx context.Context, sourc return (*Portal)(portal).createMatrixRoomInLoop(ctx, source, info, backfillBundle) } +func (portal *PortalInternals) RemoveInPortalCache(ctx context.Context) { + (*Portal)(portal).removeInPortalCache(ctx) +} + func (portal *PortalInternals) UnlockedDelete(ctx context.Context) error { return (*Portal)(portal).unlockedDelete(ctx) } @@ -285,6 +289,10 @@ func (portal *PortalInternals) DoForwardBackfill(ctx context.Context, source *Us (*Portal)(portal).doForwardBackfill(ctx, source, lastMessage, bundledData) } +func (portal *PortalInternals) FetchThreadBackfill(ctx context.Context, source *UserLogin, anchor *database.Message) *FetchMessagesResponse { + return (*Portal)(portal).fetchThreadBackfill(ctx, source, anchor) +} + func (portal *PortalInternals) DoThreadBackfill(ctx context.Context, source *UserLogin, threadID networkid.MessageID) { (*Portal)(portal).doThreadBackfill(ctx, source, threadID) } @@ -293,8 +301,16 @@ func (portal *PortalInternals) SendBackfill(ctx context.Context, source *UserLog (*Portal)(portal).sendBackfill(ctx, source, messages, forceForward, markRead, inThread) } -func (portal *PortalInternals) SendBatch(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, allowNotification bool) { - (*Portal)(portal).sendBatch(ctx, source, messages, forceForward, markRead, allowNotification) +func (portal *PortalInternals) CompileBatchMessage(ctx context.Context, source *UserLogin, msg *BackfillMessage, out *compileBatchOutput, inThread bool) { + (*Portal)(portal).compileBatchMessage(ctx, source, msg, out, inThread) +} + +func (portal *PortalInternals) FetchThreadInsideBatch(ctx context.Context, source *UserLogin, dbMsg *database.Message, out *compileBatchOutput) { + (*Portal)(portal).fetchThreadInsideBatch(ctx, source, dbMsg, out) +} + +func (portal *PortalInternals) SendBatch(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) { + (*Portal)(portal).sendBatch(ctx, source, messages, forceForward, markRead, inThread) } func (portal *PortalInternals) SendLegacyBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, markRead bool) { diff --git a/bridgev2/simplevent/meta.go b/bridgev2/simplevent/meta.go index 15b97b8b..a6b278fc 100644 --- a/bridgev2/simplevent/meta.go +++ b/bridgev2/simplevent/meta.go @@ -24,6 +24,7 @@ type EventMeta struct { Sender bridgev2.EventSender CreatePortal bool Timestamp time.Time + StreamOrder int64 } var ( @@ -31,6 +32,7 @@ var ( _ bridgev2.RemoteEventWithUncertainPortalReceiver = (*EventMeta)(nil) _ bridgev2.RemoteEventThatMayCreatePortal = (*EventMeta)(nil) _ bridgev2.RemoteEventWithTimestamp = (*EventMeta)(nil) + _ bridgev2.RemoteEventWithStreamOrder = (*EventMeta)(nil) ) func (evt *EventMeta) AddLogContext(c zerolog.Context) zerolog.Context { @@ -55,6 +57,10 @@ func (evt *EventMeta) GetTimestamp() time.Time { return evt.Timestamp } +func (evt *EventMeta) GetStreamOrder() int64 { + return evt.StreamOrder +} + func (evt *EventMeta) GetSender() bridgev2.EventSender { return evt.Sender } From e50a705cec6deec7bb5699ecc3209bddf426baa8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 15 Aug 2024 16:10:19 +0300 Subject: [PATCH 0640/1647] client: update beeper inbox request content --- requests.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requests.go b/requests.go index d4b634af..f91aaa79 100644 --- a/requests.go +++ b/requests.go @@ -366,7 +366,8 @@ type ReqSetReadMarkers struct { } type BeeperInboxDone struct { - Delta int64 `json:"at_delta"` + Delta int64 `json:"at_delta"` + AtOrder int64 `json:"at_order"` } type ReqSetBeeperInboxState struct { From 329157afde3f65bd517d4190bd0b92c884caf58c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 16 Aug 2024 11:54:00 +0300 Subject: [PATCH 0641/1647] Bump version to v0.20.0 --- CHANGELOG.md | 2 +- bridgev2/matrix/provisioning.yaml | 2 +- go.mod | 2 +- go.sum | 4 ++-- version.go | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c1d56bf..819f69d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -## v0.20.0 (unreleased) +## v0.20.0 (2024-08-16) * Bumped minimum Go version to 1.22. * *(bridgev2)* Added more features and fixed bugs. diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml index e195a3bd..1daf7b07 100644 --- a/bridgev2/matrix/provisioning.yaml +++ b/bridgev2/matrix/provisioning.yaml @@ -7,7 +7,7 @@ info: license: name: Mozilla Public License Version 2.0 url: https://github.com/mautrix/go/blob/main/LICENSE - version: v0.19.0 + version: v0.20.0 externalDocs: description: mautrix-go godocs url: https://pkg.go.dev/maunium.net/go/mautrix/bridgev2 diff --git a/go.mod b/go.mod index c3d9cbe9..4ad4143c 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/tidwall/gjson v1.17.3 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.6.1-0.20240815104112-77362c9b05dd + go.mau.fi/util v0.7.0 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.26.0 golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa diff --git a/go.sum b/go.sum index 5cf874b0..0adc7117 100644 --- a/go.sum +++ b/go.sum @@ -48,8 +48,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.6.1-0.20240815104112-77362c9b05dd h1:rDu4R3axIbNzv/c2Izri81dMcDXOklQil7tUGivvfNs= -go.mau.fi/util v0.6.1-0.20240815104112-77362c9b05dd/go.mod h1:bWYreIoTULL/UiRbZdfddPh7uWDFW5yX4YCv5FB0eE0= +go.mau.fi/util v0.7.0 h1:l31z+ivrSQw+cv/9eFebEqtQW2zhxivGypn+JT0h/ws= +go.mau.fi/util v0.7.0/go.mod h1:bWYreIoTULL/UiRbZdfddPh7uWDFW5yX4YCv5FB0eE0= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= diff --git a/version.go b/version.go index d98634ec..29c5eb46 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.19.0" +const Version = "v0.20.0" var GoModVersion = "" var Commit = "" From 6946d3cdb502d45fb29bef74264e8eeea4d52723 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 13 Aug 2024 09:25:01 -0600 Subject: [PATCH 0642/1647] verificationhelper: send cancellations to other devices if cancelled from one device Signed-off-by: Sumner Evans --- .../verificationhelper/verificationhelper.go | 47 +++++++++++++++---- .../verificationhelper_test.go | 39 ++++++++++++++- 2 files changed, 75 insertions(+), 11 deletions(-) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index e7ea53c5..2719ea78 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -746,13 +746,8 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri Reason: "The verification was accepted on another device.", }, } - devices, err := vh.mach.CryptoStore.GetDevices(ctx, txn.TheirUser) - if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to get devices for %s: %w", txn.TheirUser, err) - return - } req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} - for deviceID := range devices { + for _, deviceID := range txn.SentToDeviceIDs { if deviceID == txn.TheirDevice || deviceID == vh.client.DeviceID { // Don't ever send a cancellation to the device that accepted // the request or to our own device (which can happen if this @@ -762,7 +757,7 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri req.Messages[txn.TheirUser][deviceID] = content } - _, err = vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) + _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { log.Warn().Err(err).Msg("Failed to send cancellation requests") } @@ -878,14 +873,48 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verif func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *verificationTransaction, evt *event.Event) { cancelEvt := evt.Content.AsVerificationCancel() - vh.getLog(ctx).Info(). + log := vh.getLog(ctx).With(). Str("verification_action", "cancel"). Stringer("transaction_id", txn.TransactionID). Str("cancel_code", string(cancelEvt.Code)). Str("reason", cancelEvt.Reason). - Msg("Verification was cancelled") + Logger() + ctx = log.WithContext(ctx) + log.Info().Msg("Verification was cancelled") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() + + // Element (and at least the old desktop client) send cancellation events + // when the user rejects the verification request. This is really dumb, + // because they should just instead ignore the request and not send a + // cancellation. + // + // The above behavior causes a problem with the other devices that we sent + // the verification request to because they don't know that the request was + // cancelled. + // + // As a workaround, if we receive a cancellation event to a transaction + // that is currently in the REQUESTED state, then we will send + // cancellations to all of the devices that we sent the request to. This + // will ensure that all of the clients know that the request was cancelled. + if txn.VerificationState == verificationStateRequested && len(txn.SentToDeviceIDs) > 0 { + content := &event.Content{ + Parsed: &event.VerificationCancelEventContent{ + ToDeviceVerificationEvent: event.ToDeviceVerificationEvent{TransactionID: txn.TransactionID}, + Code: event.VerificationCancelCodeUser, + Reason: "The verification was rejected from another device.", + }, + } + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} + for _, deviceID := range txn.SentToDeviceIDs { + req.Messages[txn.TheirUser][deviceID] = content + } + _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) + if err != nil { + log.Warn().Err(err).Msg("Failed to send cancellation requests") + } + } + delete(vh.activeTransactions, txn.TransactionID) vh.verificationCancelledCallback(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) } diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index e8be5771..876e90f7 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -141,13 +141,22 @@ func TestVerification_Start(t *testing.T) { func TestVerification_StartThenCancel(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) + bystanderDeviceID := id.DeviceID("bystander") for _, sendingCancels := range []bool{true, false} { t.Run(fmt.Sprintf("sendingCancels=%t", sendingCancels), func(t *testing.T) { - ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) defer ts.Close() _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID) + bystanderMachine := bystanderClient.Crypto.(*cryptohelper.CryptoHelper).Machine() + bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, newAllVerificationCallbacks(), true) + require.NoError(t, bystanderHelper.Init(ctx)) + + require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, bystanderMachine.OwnIdentity())) + require.NoError(t, receivingCryptoStore.PutDevice(ctx, aliceUserID, bystanderMachine.OwnIdentity())) + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) @@ -159,7 +168,13 @@ func TestVerification_StartThenCancel(t *testing.T) { assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationRequest().TransactionID) ts.dispatchToDevice(t, ctx, receivingClient) - // Cancel the verification request on the sending device. + // Process the request event on the bystander device. + bystanderInbox := ts.DeviceInbox[aliceUserID][bystanderDeviceID] + assert.Len(t, bystanderInbox, 1) + assert.Equal(t, txnID, bystanderInbox[0].Content.AsVerificationRequest().TransactionID) + ts.dispatchToDevice(t, ctx, bystanderClient) + + // Cancel the verification request. var cancelEvt *event.VerificationCancelEventContent if sendingCancels { err = sendingHelper.CancelVerification(ctx, txnID, event.VerificationCancelCodeUser, "Recovery code preferred") @@ -171,6 +186,11 @@ func TestVerification_StartThenCancel(t *testing.T) { // Ensure that the cancellation event was sent to the receiving device. assert.Len(t, ts.DeviceInbox[aliceUserID][receivingDeviceID], 1) cancelEvt = ts.DeviceInbox[aliceUserID][receivingDeviceID][0].Content.AsVerificationCancel() + + // Ensure that the cancellation event was sent to the bystander device. + assert.Len(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID], 1) + bystanderCancelEvt := ts.DeviceInbox[aliceUserID][bystanderDeviceID][0].Content.AsVerificationCancel() + assert.Equal(t, cancelEvt, bystanderCancelEvt) } else { err = receivingHelper.CancelVerification(ctx, txnID, event.VerificationCancelCodeUser, "Recovery code preferred") assert.NoError(t, err) @@ -181,10 +201,25 @@ func TestVerification_StartThenCancel(t *testing.T) { // Ensure that the cancellation event was sent to the sending device. assert.Len(t, ts.DeviceInbox[aliceUserID][sendingDeviceID], 1) cancelEvt = ts.DeviceInbox[aliceUserID][sendingDeviceID][0].Content.AsVerificationCancel() + + // The bystander device should not have a cancellation event. + assert.Empty(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID]) } assert.Equal(t, txnID, cancelEvt.TransactionID) assert.Equal(t, event.VerificationCancelCodeUser, cancelEvt.Code) assert.Equal(t, "Recovery code preferred", cancelEvt.Reason) + + if !sendingCancels { + // Process the cancellation event on the sending device. + ts.dispatchToDevice(t, ctx, sendingClient) + + // Ensure that the cancellation event was sent to the bystander device. + assert.Len(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID], 1) + bystanderCancelEvt := ts.DeviceInbox[aliceUserID][bystanderDeviceID][0].Content.AsVerificationCancel() + assert.Equal(t, txnID, bystanderCancelEvt.TransactionID) + assert.Equal(t, event.VerificationCancelCodeUser, bystanderCancelEvt.Code) + assert.Equal(t, "The verification was rejected from another device.", bystanderCancelEvt.Reason) + } }) } } From d40aa8c7c6a12214f0fac1d8e3ea94414c78e0ab Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 15 Aug 2024 10:04:11 -0600 Subject: [PATCH 0643/1647] verificationhelper: add function to dismiss verification request without cancelling it Signed-off-by: Sumner Evans --- client.go | 21 +++++++++++++++++++ crypto/verificationhelper/reciprocate.go | 7 +++++-- crypto/verificationhelper/sas.go | 12 +++++++---- .../verificationhelper/verificationhelper.go | 18 +++++++++++++--- 4 files changed, 49 insertions(+), 9 deletions(-) diff --git a/client.go b/client.go index 750e3c25..636b355f 100644 --- a/client.go +++ b/client.go @@ -35,16 +35,37 @@ type CryptoHelper interface { } type VerificationHelper interface { + // Init initializes the helper. This should be called before any other + // methods. Init(context.Context) error + + // StartVerification starts an interactive verification flow with the given + // user via a to-device event. StartVerification(ctx context.Context, to id.UserID) (id.VerificationTransactionID, error) + // StartInRoomVerification starts an interactive verification flow with the + // given user in the given room. StartInRoomVerification(ctx context.Context, roomID id.RoomID, to id.UserID) (id.VerificationTransactionID, error) + + // AcceptVerification accepts a verification request. AcceptVerification(ctx context.Context, txnID id.VerificationTransactionID) error + // DismissVerification dismisses a verification request. This will not send + // a cancellation to the other device. This method should only be called + // *before* the request has been accepted and will error otherwise. + DismissVerification(ctx context.Context, txnID id.VerificationTransactionID) error + // CancelVerification cancels a verification request. This method should + // only be called *after* the request has been accepted, although it will + // not error if called beforehand. CancelVerification(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) error + // HandleScannedQRData handles the data from a QR code scan. HandleScannedQRData(ctx context.Context, data []byte) error + // ConfirmQRCodeScanned confirms that our QR code has been scanned. ConfirmQRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) error + // StartSAS starts a SAS verification flow. StartSAS(ctx context.Context, txnID id.VerificationTransactionID) error + // ConfirmSAS indicates that the user has confirmed that the SAS matches + // SAS shown on the other user's device. ConfirmSAS(ctx context.Context, txnID id.VerificationTransactionID) error } diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index 2ea0a0ed..21276218 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -183,8 +183,11 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by return nil } -// ConfirmQRCodeScanned confirms that our QR code has been scanned and sends the -// m.key.verification.done event to the other device. +// ConfirmQRCodeScanned confirms that our QR code has been scanned and sends +// the m.key.verification.done event to the other device for the given +// transaction ID. The transaction ID should be one received via the +// VerificationRequested callback in [RequiredCallbacks] or the +// [StartVerification] or [StartInRoomVerification] functions. func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) error { log := vh.getLog(ctx).With(). Str("verification_action", "confirm QR code scanned"). diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index bf8c6050..e28ec405 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -28,9 +28,10 @@ import ( "maunium.net/go/mautrix/id" ) -// StartSAS starts a SAS verification flow. The transaction ID should be the -// transaction ID of a verification request that was received via the -// VerificationRequested callback in [RequiredCallbacks]. +// StartSAS starts a SAS verification flow for the given transaction ID. The +// transaction ID should be one received via the VerificationRequested callback +// in [RequiredCallbacks] or the [StartVerification] or +// [StartInRoomVerification] functions. func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.VerificationTransactionID) error { log := vh.getLog(ctx).With(). Str("verification_action", "accept verification"). @@ -81,7 +82,10 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio } // ConfirmSAS indicates that the user has confirmed that the SAS matches SAS -// shown on the other user's device. +// shown on the other user's device for the given transaction ID. The +// transaction ID should be one received via the VerificationRequested callback +// in [RequiredCallbacks] or the [StartVerification] or +// [StartInRoomVerification] functions. func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.VerificationTransactionID) error { log := vh.getLog(ctx).With(). Str("verification_action", "confirm SAS"). diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 2719ea78..f4e5e2f5 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -482,9 +482,21 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V return vh.generateAndShowQRCode(ctx, txn) } -// CancelVerification cancels a verification request. The transaction ID should -// be the transaction ID of a verification request that was received via the -// VerificationRequested callback in [RequiredCallbacks]. +// DismissVerification dismisses the verification request with the given +// transaction ID. The transaction ID should be one received via the +// VerificationRequested callback in [RequiredCallbacks] or the +// [StartVerification] or [StartInRoomVerification] functions. +func (vh *VerificationHelper) DismissVerification(ctx context.Context, txnID id.VerificationTransactionID) error { + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + delete(vh.activeTransactions, txnID) + return nil +} + +// DismissVerification cancels the verification request with the given +// transaction ID. The transaction ID should be one received via the +// VerificationRequested callback in [RequiredCallbacks] or the +// [StartVerification] or [StartInRoomVerification] functions. func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) error { vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() From 2355d70426f4a41b2230b11a1a0ba76ac88df04a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 17 Aug 2024 14:10:04 +0300 Subject: [PATCH 0644/1647] bridgev2/matrix: return error if trying to encrypt message without encryption enabled --- bridgev2/matrix/intent.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index e789fa75..91748317 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -56,6 +56,9 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) } else if encrypted { + if as.Connector.Crypto == nil { + return nil, fmt.Errorf("room is encrypted, but bridge isn't configured to support encryption") + } if as.Matrix.IsCustomPuppet { if extra.Timestamp.IsZero() { as.Matrix.AddDoublePuppetValue(content) From 6444f9bcccaa1e4d29657ff898c01be9772f178a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 17 Aug 2024 23:30:46 +0300 Subject: [PATCH 0645/1647] bridgev2/matrix: add double puppet value for redactions --- bridgev2/matrix/intent.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 91748317..0c28ad50 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -47,6 +47,7 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType // TODO remove this once hungryserv and synapse support sending m.room.redactions directly in all room versions if eventType == event.EventRedaction { parsedContent := content.Parsed.(*event.RedactionEventContent) + as.Matrix.AddDoublePuppetValue(content) return as.Matrix.RedactEvent(ctx, roomID, parsedContent.Redacts, mautrix.ReqRedact{ Reason: parsedContent.Reason, Extra: content.Raw, From f813c4f8b0a44d6de49f1f3eb50ed81e09b5454a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 19 Aug 2024 11:29:06 +0300 Subject: [PATCH 0646/1647] statestore: log warning if UpdateStateStore has unexpected content type --- statestore.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/statestore.go b/statestore.go index fd8f81e5..35bfc6ab 100644 --- a/statestore.go +++ b/statestore.go @@ -62,6 +62,15 @@ func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) { err = store.SetPowerLevels(ctx, evt.RoomID, content) case *event.EncryptionEventContent: err = store.SetEncryptionEvent(ctx, evt.RoomID, content) + default: + switch evt.Type { + case event.StateMember, event.StatePowerLevels, event.StateEncryption: + zerolog.Ctx(ctx).Warn(). + Stringer("event_id", evt.ID). + Str("event_type", evt.Type.Type). + Type("content_type", evt.Content.Parsed). + Msg("Got known event type with unknown content type in UpdateStateStore") + } } if err != nil { zerolog.Ctx(ctx).Warn().Err(err). From b4927420ccad04d8b2ba512c6da35059a7eacf2e Mon Sep 17 00:00:00 2001 From: Brad Murray Date: Mon, 19 Aug 2024 09:23:25 -0400 Subject: [PATCH 0647/1647] client: don't retry requests if context is cancelled (#268) --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index 636b355f..97571219 100644 --- a/client.go +++ b/client.go @@ -613,7 +613,7 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof defer res.Body.Close() } if err != nil { - if retries > 0 { + if retries > 0 && !errors.Is(err, context.Canceled) { return cli.doRetry(req, err, retries, backoff, responseJSON, handler, dontReadResponse, client) } err = HTTPError{ From 7f392b17b656e3a8f3e9f588f5d05a4463d46a8d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 19 Aug 2024 17:52:54 +0300 Subject: [PATCH 0648/1647] bridgev2/matrix: prevent too many async uploads at once --- bridgev2/matrix/connector.go | 4 ++++ bridgev2/matrix/intent.go | 8 ++++++++ go.mod | 1 + go.sum | 2 ++ 4 files changed, 15 insertions(+) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index c2b3f7cd..8c35fb4f 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -28,6 +28,7 @@ import ( _ "go.mau.fi/util/dbutil/litestream" "go.mau.fi/util/exsync" "go.mau.fi/util/random" + "golang.org/x/sync/semaphore" "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" @@ -70,6 +71,7 @@ type Connector struct { DoublePuppet *doublePuppetUtil MediaProxy *mediaproxy.MediaProxy + uploadSema *semaphore.Weighted dmaSigKey [32]byte pubMediaSigKey []byte @@ -108,6 +110,7 @@ func NewConnector(cfg *bridgeconfig.Config) *Connector { c.Config = cfg c.userIDRegex = cfg.MakeUserIDRegex("(.+)") c.MediaConfig.UploadSize = 50 * 1024 * 1024 + c.uploadSema = semaphore.NewWeighted(c.MediaConfig.UploadSize * 2) c.Capabilities = &bridgev2.MatrixCapabilities{} c.doublePuppetIntents = exsync.NewMap[id.UserID, *appservice.IntentAPI]() return c @@ -366,6 +369,7 @@ func (br *Connector) fetchMediaConfig(ctx context.Context) { if ok { mfsn.SetMaxFileSize(br.MediaConfig.UploadSize) } + br.uploadSema = semaphore.NewWeighted(br.MediaConfig.UploadSize * 2) } } diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 0c28ad50..5eb31436 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -211,6 +211,9 @@ func (as *ASIntent) DownloadMedia(ctx context.Context, uri id.ContentURIString, } func (as *ASIntent) UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) { + if int64(len(data)) > as.Connector.MediaConfig.UploadSize { + return "", nil, fmt.Errorf("file too large (%.2f MB > %.2f MB)", float64(len(data))/1000/1000, float64(as.Connector.MediaConfig.UploadSize)/1000/1000) + } if roomID != "" { var encrypted bool if encrypted, err = as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { @@ -231,6 +234,11 @@ func (as *ASIntent) UploadMedia(ctx context.Context, roomID id.RoomID, data []by FileName: fileName, } if as.Connector.Config.Homeserver.AsyncMedia { + // Prevent too many background uploads at once + err = as.Connector.uploadSema.Acquire(ctx, int64(len(data))) + if err != nil { + return + } var resp *mautrix.RespCreateMXC resp, err = as.Matrix.UploadAsync(ctx, req) if resp != nil { diff --git a/go.mod b/go.mod index 4ad4143c..6ff2233d 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( golang.org/x/crypto v0.26.0 golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa golang.org/x/net v0.28.0 + golang.org/x/sync v0.8.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) diff --git a/go.sum b/go.sum index 0adc7117..53105fc7 100644 --- a/go.sum +++ b/go.sum @@ -58,6 +58,8 @@ golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDT golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From e217a5f8cdc41b603f48a657009e2113c3062e71 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 19 Aug 2024 18:09:02 +0300 Subject: [PATCH 0649/1647] bridgev2/matrix: add missing semaphore release --- bridgev2/matrix/intent.go | 3 +++ client.go | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 5eb31436..0ad6d34b 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -239,6 +239,9 @@ func (as *ASIntent) UploadMedia(ctx context.Context, roomID id.RoomID, data []by if err != nil { return } + req.DoneCallback = func() { + as.Connector.uploadSema.Release(int64(len(data))) + } var resp *mautrix.RespCreateMXC resp, err = as.Matrix.UploadAsync(ctx, req) if resp != nil { diff --git a/client.go b/client.go index 97571219..a0e86bdb 100644 --- a/client.go +++ b/client.go @@ -1593,6 +1593,7 @@ func (cli *Client) CreateMXC(ctx context.Context) (*RespCreateMXC, error) { func (cli *Client) UploadAsync(ctx context.Context, req ReqUploadMedia) (*RespCreateMXC, error) { resp, err := cli.CreateMXC(ctx) if err != nil { + req.DoneCallback() return nil, err } req.MXC = resp.ContentURI @@ -1636,6 +1637,8 @@ type ReqUploadMedia struct { ContentType string FileName string + DoneCallback func() + // MXC specifies an existing MXC URI which doesn't have content yet to upload into. // See https://spec.matrix.org/unstable/client-server-api/#put_matrixmediav3uploadservernamemediaid MXC id.ContentURI @@ -1711,6 +1714,9 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (* // UploadMedia uploads the given data to the content repository and returns an MXC URI. // See https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav3upload func (cli *Client) UploadMedia(ctx context.Context, data ReqUploadMedia) (*RespMediaUpload, error) { + if data.DoneCallback != nil { + defer data.DoneCallback() + } if data.UnstableUploadURL != "" { if data.MXC.IsEmpty() { return nil, errors.New("MXC must also be set when uploading to external URL") From 3481f29c1a8cdccb4970deccaaaa1b08f4636459 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 19 Aug 2024 19:02:14 +0300 Subject: [PATCH 0650/1647] bridgev2/matrix: disable megolm session destination tracking --- bridgev2/matrix/crypto.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index 427b369d..7383ddc7 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -96,6 +96,7 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { Str("device_id", helper.client.DeviceID.String()). Msg("Logged in as bridge bot") helper.mach = crypto.NewOlmMachine(helper.client, helper.log, helper.store, helper.bridge.StateStore) + helper.mach.DisableSharedGroupSessionTracking = true helper.mach.AllowKeyShare = helper.allowKeyShare encryptionConfig := helper.bridge.Config.Encryption From 420db4fefb7dacafb6b3d1907027f86c062f1aba Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 19 Aug 2024 19:04:02 +0300 Subject: [PATCH 0651/1647] bridgev2/matrix: reduce upload semaphore size --- bridgev2/matrix/connector.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 8c35fb4f..dee28b8d 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -110,7 +110,7 @@ func NewConnector(cfg *bridgeconfig.Config) *Connector { c.Config = cfg c.userIDRegex = cfg.MakeUserIDRegex("(.+)") c.MediaConfig.UploadSize = 50 * 1024 * 1024 - c.uploadSema = semaphore.NewWeighted(c.MediaConfig.UploadSize * 2) + c.uploadSema = semaphore.NewWeighted(c.MediaConfig.UploadSize + 1) c.Capabilities = &bridgev2.MatrixCapabilities{} c.doublePuppetIntents = exsync.NewMap[id.UserID, *appservice.IntentAPI]() return c @@ -369,7 +369,7 @@ func (br *Connector) fetchMediaConfig(ctx context.Context) { if ok { mfsn.SetMaxFileSize(br.MediaConfig.UploadSize) } - br.uploadSema = semaphore.NewWeighted(br.MediaConfig.UploadSize * 2) + br.uploadSema = semaphore.NewWeighted(br.MediaConfig.UploadSize + 1) } } From 20ce646435ebe4fd33511901842221c1d75cf1cd Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 19 Aug 2024 19:33:26 +0300 Subject: [PATCH 0652/1647] bridgev2/matrixinterface: add stream upload method --- bridgev2/matrix/intent.go | 105 +++++++++++++++++++++++++++++++++--- bridgev2/matrixinterface.go | 2 + 2 files changed, 99 insertions(+), 8 deletions(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 0ad6d34b..8846a30b 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -10,6 +10,8 @@ import ( "context" "errors" "fmt" + "io" + "os" "strings" "sync" "time" @@ -228,19 +230,106 @@ func (as *ASIntent) UploadMedia(ctx context.Context, roomID id.RoomID, data []by fileName = "" } } - req := mautrix.ReqUploadMedia{ + url, err = as.doUploadReq(ctx, file, mautrix.ReqUploadMedia{ ContentBytes: data, ContentType: mimeType, FileName: fileName, + }) + return +} + +const inMemoryUploadThreshold = 5 * 1024 * 1024 + +type writeToCapturer struct { + data []byte +} + +func (w *writeToCapturer) Write(p []byte) (n int, err error) { + if w.data == nil { + w.data = p + } else { + w.data = append(w.data, p...) } - if as.Connector.Config.Homeserver.AsyncMedia { - // Prevent too many background uploads at once - err = as.Connector.uploadSema.Acquire(ctx, int64(len(data))) - if err != nil { - return + return len(p), nil +} + +func (as *ASIntent) UploadMediaStream(ctx context.Context, roomID id.RoomID, data io.Reader, size int64, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) { + if size > as.Connector.MediaConfig.UploadSize { + return "", nil, fmt.Errorf("file too large (%.2f MB > %.2f MB)", float64(size)/1000/1000, float64(as.Connector.MediaConfig.UploadSize)/1000/1000) + } else if 0 < size && size < inMemoryUploadThreshold { + var dataBytes []byte + wt, ok := data.(io.WriterTo) + if ok { + capturer := &writeToCapturer{} + _, err = wt.WriteTo(capturer) + if err != nil { + return "", nil, err + } + dataBytes = capturer.data + } else { + dataBytes, err = io.ReadAll(data) + if err != nil { + return "", nil, err + } } - req.DoneCallback = func() { - as.Connector.uploadSema.Release(int64(len(data))) + return as.UploadMedia(ctx, roomID, dataBytes, fileName, mimeType) + } + tempFile, err := os.CreateTemp("", "mautrix-upload-*") + if err != nil { + return "", nil, fmt.Errorf("failed to create temp file: %w", err) + } + defer func() { + _ = tempFile.Close() + _ = os.Remove(tempFile.Name()) + }() + var realSize int64 + if roomID != "" { + var encrypted bool + if encrypted, err = as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { + err = fmt.Errorf("failed to check if room is encrypted: %w", err) + return + } else if encrypted { + file = &event.EncryptedFileInfo{ + EncryptedFile: *attachment.NewEncryptedFile(), + } + encryptStream := file.EncryptStream(data) + realSize, err = io.Copy(tempFile, encryptStream) + if err != nil { + return "", nil, fmt.Errorf("failed to write to temp file: %w", err) + } + err = encryptStream.Close() + if err != nil { + return "", nil, fmt.Errorf("failed to finalize encryption: %w", err) + } + mimeType = "application/octet-stream" + fileName = "" + } + } else { + realSize, err = io.Copy(tempFile, data) + if err != nil { + return "", nil, fmt.Errorf("failed to write to temp file: %w", err) + } + } + url, err = as.doUploadReq(ctx, file, mautrix.ReqUploadMedia{ + Content: tempFile, + ContentLength: realSize, + ContentType: mimeType, + FileName: fileName, + }) + return +} + +func (as *ASIntent) doUploadReq(ctx context.Context, file *event.EncryptedFileInfo, req mautrix.ReqUploadMedia) (url id.ContentURIString, err error) { + if as.Connector.Config.Homeserver.AsyncMedia { + if req.ContentBytes != nil { + // Prevent too many background uploads at once + err = as.Connector.uploadSema.Acquire(ctx, int64(len(req.ContentBytes))) + if err != nil { + return + } + req.DoneCallback = func() { + as.Connector.uploadSema.Release(int64(len(req.ContentBytes))) + } } var resp *mautrix.RespCreateMXC resp, err = as.Matrix.UploadAsync(ctx, req) diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 0628f16d..3de5e7df 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -8,6 +8,7 @@ package bridgev2 import ( "context" + "io" "time" "github.com/gorilla/mux" @@ -88,6 +89,7 @@ type MatrixAPI interface { MarkTyping(ctx context.Context, roomID id.RoomID, typingType TypingType, timeout time.Duration) error DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error) UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) + UploadMediaStream(ctx context.Context, roomID id.RoomID, data io.Reader, size int64, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) SetDisplayName(ctx context.Context, name string) error SetAvatarURL(ctx context.Context, avatarURL id.ContentURIString) error From 79527df26e8b80080fa21f78136e8aeff218ac49 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 19 Aug 2024 20:13:38 +0300 Subject: [PATCH 0653/1647] bridgev2/matrixinterface: temporarily remove stream upload from interface --- bridgev2/matrixinterface.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 3de5e7df..0628f16d 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -8,7 +8,6 @@ package bridgev2 import ( "context" - "io" "time" "github.com/gorilla/mux" @@ -89,7 +88,6 @@ type MatrixAPI interface { MarkTyping(ctx context.Context, roomID id.RoomID, typingType TypingType, timeout time.Duration) error DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error) UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) - UploadMediaStream(ctx context.Context, roomID id.RoomID, data io.Reader, size int64, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) SetDisplayName(ctx context.Context, name string) error SetAvatarURL(ctx context.Context, avatarURL id.ContentURIString) error From ce2ffd8232a5e525fe62c67d1a1659fd65426158 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 Aug 2024 00:52:54 +0300 Subject: [PATCH 0654/1647] bridgev2/matrix: add new stream upload that uses a writer instead of a reader (#269) --- bridgev2/bridgeconfig/config.go | 11 +- bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/matrix/intent.go | 111 ++++++++++++++------- bridgev2/matrix/mxmain/example-config.yaml | 3 + bridgev2/matrixinterface.go | 10 ++ crypto/attachment/attachments.go | 37 +++++++ 6 files changed, 130 insertions(+), 43 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index ab97c891..40a17622 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -69,11 +69,12 @@ type BridgeConfig struct { } type MatrixConfig struct { - MessageStatusEvents bool `yaml:"message_status_events"` - DeliveryReceipts bool `yaml:"delivery_receipts"` - MessageErrorNotices bool `yaml:"message_error_notices"` - SyncDirectChatList bool `yaml:"sync_direct_chat_list"` - FederateRooms bool `yaml:"federate_rooms"` + MessageStatusEvents bool `yaml:"message_status_events"` + DeliveryReceipts bool `yaml:"delivery_receipts"` + MessageErrorNotices bool `yaml:"message_error_notices"` + SyncDirectChatList bool `yaml:"sync_direct_chat_list"` + FederateRooms bool `yaml:"federate_rooms"` + UploadFileThreshold int64 `yaml:"upload_file_threshold"` } type ProvisioningConfig struct { diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 4eff205d..9597fa4f 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -84,6 +84,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "matrix", "message_error_notices") helper.Copy(up.Bool, "matrix", "sync_direct_chat_list") helper.Copy(up.Bool, "matrix", "federate_rooms") + helper.Copy(up.Int, "matrix", "upload_file_threshold") helper.Copy(up.Str, "provisioning", "prefix") if secret, ok := helper.Get(up.Str, "provisioning", "shared_secret"); !ok || secret == "generate" { diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 8846a30b..115d7393 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -238,13 +238,11 @@ func (as *ASIntent) UploadMedia(ctx context.Context, roomID id.RoomID, data []by return } -const inMemoryUploadThreshold = 5 * 1024 * 1024 - -type writeToCapturer struct { +type simpleBuffer struct { data []byte } -func (w *writeToCapturer) Write(p []byte) (n int, err error) { +func (w *simpleBuffer) Write(p []byte) (n int, err error) { if w.data == nil { w.data = p } else { @@ -253,36 +251,50 @@ func (w *writeToCapturer) Write(p []byte) (n int, err error) { return len(p), nil } -func (as *ASIntent) UploadMediaStream(ctx context.Context, roomID id.RoomID, data io.Reader, size int64, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) { +func (w *simpleBuffer) Seek(offset int64, whence int) (int64, error) { + if whence == io.SeekStart { + if offset == 0 { + w.data = nil + } else { + w.data = w.data[:offset] + } + return offset, nil + } + return 0, fmt.Errorf("unsupported whence value %d", whence) +} + +func (as *ASIntent) UploadMediaStream( + ctx context.Context, + roomID id.RoomID, + size int64, + requireFile bool, + fileName, + mimeType string, + cb bridgev2.FileStreamCallback, +) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) { if size > as.Connector.MediaConfig.UploadSize { return "", nil, fmt.Errorf("file too large (%.2f MB > %.2f MB)", float64(size)/1000/1000, float64(as.Connector.MediaConfig.UploadSize)/1000/1000) - } else if 0 < size && size < inMemoryUploadThreshold { - var dataBytes []byte - wt, ok := data.(io.WriterTo) - if ok { - capturer := &writeToCapturer{} - _, err = wt.WriteTo(capturer) - if err != nil { - return "", nil, err - } - dataBytes = capturer.data - } else { - dataBytes, err = io.ReadAll(data) - if err != nil { - return "", nil, err - } - } - return as.UploadMedia(ctx, roomID, dataBytes, fileName, mimeType) } - tempFile, err := os.CreateTemp("", "mautrix-upload-*") + if !requireFile && 0 < size && size < as.Connector.Config.Matrix.UploadFileThreshold { + var buf simpleBuffer + replPath, err := cb(&buf) + if err != nil { + return "", nil, err + } else if replPath != "" { + panic(fmt.Errorf("logic error: replacement path must only be returned if requireFile is true")) + } + return as.UploadMedia(ctx, roomID, buf.data, fileName, mimeType) + } + var tempFile *os.File + tempFile, err = os.CreateTemp("", "mautrix-upload-*") if err != nil { - return "", nil, fmt.Errorf("failed to create temp file: %w", err) + err = fmt.Errorf("failed to create temp file: %w", err) + return } defer func() { _ = tempFile.Close() _ = os.Remove(tempFile.Name()) }() - var realSize int64 if roomID != "" { var encrypted bool if encrypted, err = as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { @@ -292,27 +304,50 @@ func (as *ASIntent) UploadMediaStream(ctx context.Context, roomID id.RoomID, dat file = &event.EncryptedFileInfo{ EncryptedFile: *attachment.NewEncryptedFile(), } - encryptStream := file.EncryptStream(data) - realSize, err = io.Copy(tempFile, encryptStream) - if err != nil { - return "", nil, fmt.Errorf("failed to write to temp file: %w", err) - } - err = encryptStream.Close() - if err != nil { - return "", nil, fmt.Errorf("failed to finalize encryption: %w", err) - } mimeType = "application/octet-stream" fileName = "" } - } else { - realSize, err = io.Copy(tempFile, data) + } + var replPath string + replPath, err = cb(tempFile) + if err != nil { + err = fmt.Errorf("failed to write to temp file: %w", err) + } + var replFile *os.File + if replPath != "" { + replFile, err = os.OpenFile(replPath, os.O_RDWR, 0) if err != nil { - return "", nil, fmt.Errorf("failed to write to temp file: %w", err) + err = fmt.Errorf("failed to open replacement file: %w", err) + return } + } else { + replFile = tempFile + _, err = replFile.Seek(0, io.SeekStart) + if err != nil { + err = fmt.Errorf("failed to seek to start of temp file: %w", err) + return + } + } + if file != nil { + err = file.EncryptFile(replFile) + if err != nil { + err = fmt.Errorf("failed to encrypt file: %w", err) + return + } + _, err = replFile.Seek(0, io.SeekStart) + if err != nil { + err = fmt.Errorf("failed to seek to start of temp file after encrypting: %w", err) + return + } + } + info, err := replFile.Stat() + if err != nil { + err = fmt.Errorf("failed to get temp file info: %w", err) + return } url, err = as.doUploadReq(ctx, file, mautrix.ReqUploadMedia{ Content: tempFile, - ContentLength: realSize, + ContentLength: info.Size(), ContentType: mimeType, FileName: fileName, }) diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 06ed010f..e0a5ed87 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -198,6 +198,9 @@ matrix: # Whether created rooms should have federation enabled. If false, created portal rooms # will never be federated. Changing this option requires recreating rooms. federate_rooms: true + # The threshold as bytes after which the bridge should roundtrip uploads via the disk + # rather than keeping the whole file in memory. + upload_file_threshold: 5242880 # Settings for provisioning API provisioning: diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 0628f16d..02528fde 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -8,6 +8,7 @@ package bridgev2 import ( "context" + "io" "time" "github.com/gorilla/mux" @@ -78,6 +79,14 @@ type MatrixSendExtra struct { PartIndex int } +// FileStreamCallback is a callback function for file uploads that roundtrip via disk. +// +// The parameter is either a file or an in-memory buffer depending on the size of the file and whether the requireFile flag was set. +// +// The first return value can specify a file path to use instead of the original temp file. +// Returning a replacement path is only valid if the parameter is a file. +type FileStreamCallback func(file io.WriteSeeker) (string, error) + type MatrixAPI interface { GetMXID() id.UserID @@ -88,6 +97,7 @@ type MatrixAPI interface { MarkTyping(ctx context.Context, roomID id.RoomID, typingType TypingType, timeout time.Duration) error DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error) UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) + UploadMediaStream(ctx context.Context, roomID id.RoomID, size int64, requireFile bool, fileName, mimeType string, cb FileStreamCallback) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) SetDisplayName(ctx context.Context, name string) error SetAvatarURL(ctx context.Context, avatarURL id.ContentURIString) error diff --git a/crypto/attachment/attachments.go b/crypto/attachment/attachments.go index 344db4f0..cfa1c3e5 100644 --- a/crypto/attachment/attachments.go +++ b/crypto/attachment/attachments.go @@ -127,6 +127,43 @@ func (ef *EncryptedFile) EncryptInPlace(data []byte) { ef.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(checksum[:]) } +type ReadWriterAt interface { + io.WriterAt + io.Reader +} + +// EncryptFile encrypts the given file in-place and updates the SHA256 hash in the EncryptedFile struct. +func (ef *EncryptedFile) EncryptFile(file ReadWriterAt) error { + err := ef.decodeKeys(false) + if err != nil { + return err + } + block, _ := aes.NewCipher(ef.decoded.key[:]) + stream := cipher.NewCTR(block, ef.decoded.iv[:]) + hasher := sha256.New() + buf := make([]byte, 32*1024) + var writePtr int64 + var n int + for { + n, err = file.Read(buf) + if err != nil && !errors.Is(err, io.EOF) { + return err + } + if n == 0 { + break + } + stream.XORKeyStream(buf[:n], buf[:n]) + _, err = file.WriteAt(buf[:n], writePtr) + if err != nil { + return err + } + writePtr += int64(n) + hasher.Write(buf[:n]) + } + ef.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(hasher.Sum(nil)) + return nil +} + type encryptingReader struct { stream cipher.Stream hash hash.Hash From 38278ef37d199d3a9deba04b825a094eea6c1d10 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 Aug 2024 00:52:38 +0300 Subject: [PATCH 0655/1647] bridgev2/unorganized-docs: update features --- bridgev2/unorganized-docs/FEATURES.md | 52 +++++++++++++-------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/bridgev2/unorganized-docs/FEATURES.md b/bridgev2/unorganized-docs/FEATURES.md index 908ca975..73da364d 100644 --- a/bridgev2/unorganized-docs/FEATURES.md +++ b/bridgev2/unorganized-docs/FEATURES.md @@ -5,45 +5,45 @@ * [x] Attachments * [ ] Polls * [x] Replies - * [ ] Threads + * [x] Threads * [x] Edits * [x] Reactions - * [ ] Reaction mass-syncing + * [x] Reaction mass-syncing * [x] Deletions * [x] Message status events and error notices - * [ ] Backfilling history + * [x] Backfilling history * [x] Login * [x] Logout -* [ ] Re-login after credential expiry -* [ ] Disappearing messages +* [x] Re-login after credential expiry +* [x] Disappearing messages * [x] Read receipts * [ ] Presence -* [ ] Typing notifications -* [ ] Spaces -* [ ] Relay mode -* [ ] Chat metadata - * [ ] Archive/low priority - * [ ] Pin/favorite - * [ ] Mark unread - * [ ] Mute status - * [ ] Temporary mutes ("snooze") +* [x] Typing notifications +* [x] Spaces +* [x] Relay mode +* [x] Chat metadata + * [x] Archive/low priority + * [x] Pin/favorite + * [x] Mark unread + * [x] Mute status + * [x] Temporary mutes ("snooze") * [x] User metadata (name/avatar) -* [ ] Group metadata - * [ ] Initial meta and full resyncs +* [x] Group metadata + * [x] Initial meta and full resyncs * [x] Name, avatar, topic * [x] Members - * [ ] Permissions - * [ ] Change events - * [ ] Name, avatar, topic - * [ ] Members (join, leave, invite, kick, ban, knock) - * [ ] Permissions (promote, demote) + * [x] Permissions + * [x] Change events + * [x] Name, avatar, topic + * [x] Members (join, leave, invite, kick, ban, knock) + * [x] Permissions (promote, demote) * [ ] Misc actions * [ ] Invites / accepting message requests - * [ ] Create group - * [ ] Create DM - * [ ] Get contact list - * [ ] Check if identifier is on remote network - * [ ] Search users on remote network + * [x] Create group + * [x] Create DM + * [x] Get contact list + * [x] Check if identifier is on remote network + * [x] Search users on remote network * [ ] Delete chat * [ ] Report spam * [ ] Custom emojis From feff2d5886b07090569a1c55c9218735af6bb5cc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 Aug 2024 01:00:10 +0300 Subject: [PATCH 0656/1647] bridgev2/matrix: don't delete temp file before async upload completes --- bridgev2/matrix/intent.go | 41 +++++++++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 115d7393..daa07ea8 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -291,9 +291,15 @@ func (as *ASIntent) UploadMediaStream( err = fmt.Errorf("failed to create temp file: %w", err) return } + removeAndClose := func(f *os.File) { + _ = f.Close() + _ = os.Remove(f.Name()) + } + startedAsyncUpload := false defer func() { - _ = tempFile.Close() - _ = os.Remove(tempFile.Name()) + if !startedAsyncUpload { + removeAndClose(tempFile) + } }() if roomID != "" { var encrypted bool @@ -320,6 +326,11 @@ func (as *ASIntent) UploadMediaStream( err = fmt.Errorf("failed to open replacement file: %w", err) return } + defer func() { + if !startedAsyncUpload { + removeAndClose(replFile) + } + }() } else { replFile = tempFile _, err = replFile.Seek(0, io.SeekStart) @@ -345,12 +356,34 @@ func (as *ASIntent) UploadMediaStream( err = fmt.Errorf("failed to get temp file info: %w", err) return } - url, err = as.doUploadReq(ctx, file, mautrix.ReqUploadMedia{ + req := mautrix.ReqUploadMedia{ Content: tempFile, ContentLength: info.Size(), ContentType: mimeType, FileName: fileName, - }) + } + if as.Connector.Config.Homeserver.AsyncMedia { + req.DoneCallback = func() { + removeAndClose(replFile) + removeAndClose(tempFile) + } + startedAsyncUpload = true + var resp *mautrix.RespCreateMXC + resp, err = as.Matrix.UploadAsync(ctx, req) + if resp != nil { + url = resp.ContentURI.CUString() + } + } else { + var resp *mautrix.RespMediaUpload + resp, err = as.Matrix.UploadMedia(ctx, req) + if resp != nil { + url = resp.ContentURI.CUString() + } + } + if file != nil { + file.URL = url + url = "" + } return } From a614668174a9e80f5763f15a4fd791dcbbc2187c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 Aug 2024 01:20:39 +0300 Subject: [PATCH 0657/1647] bridgev2/matrix: remove custom buffer in stream upload --- bridgev2/matrix/intent.go | 30 +++--------------------------- bridgev2/matrixinterface.go | 2 +- 2 files changed, 4 insertions(+), 28 deletions(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index daa07ea8..43d23a4c 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -7,6 +7,7 @@ package matrix import ( + "bytes" "context" "errors" "fmt" @@ -238,31 +239,6 @@ func (as *ASIntent) UploadMedia(ctx context.Context, roomID id.RoomID, data []by return } -type simpleBuffer struct { - data []byte -} - -func (w *simpleBuffer) Write(p []byte) (n int, err error) { - if w.data == nil { - w.data = p - } else { - w.data = append(w.data, p...) - } - return len(p), nil -} - -func (w *simpleBuffer) Seek(offset int64, whence int) (int64, error) { - if whence == io.SeekStart { - if offset == 0 { - w.data = nil - } else { - w.data = w.data[:offset] - } - return offset, nil - } - return 0, fmt.Errorf("unsupported whence value %d", whence) -} - func (as *ASIntent) UploadMediaStream( ctx context.Context, roomID id.RoomID, @@ -276,14 +252,14 @@ func (as *ASIntent) UploadMediaStream( return "", nil, fmt.Errorf("file too large (%.2f MB > %.2f MB)", float64(size)/1000/1000, float64(as.Connector.MediaConfig.UploadSize)/1000/1000) } if !requireFile && 0 < size && size < as.Connector.Config.Matrix.UploadFileThreshold { - var buf simpleBuffer + var buf bytes.Buffer replPath, err := cb(&buf) if err != nil { return "", nil, err } else if replPath != "" { panic(fmt.Errorf("logic error: replacement path must only be returned if requireFile is true")) } - return as.UploadMedia(ctx, roomID, buf.data, fileName, mimeType) + return as.UploadMedia(ctx, roomID, buf.Bytes(), fileName, mimeType) } var tempFile *os.File tempFile, err = os.CreateTemp("", "mautrix-upload-*") diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 02528fde..893f0321 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -85,7 +85,7 @@ type MatrixSendExtra struct { // // The first return value can specify a file path to use instead of the original temp file. // Returning a replacement path is only valid if the parameter is a file. -type FileStreamCallback func(file io.WriteSeeker) (string, error) +type FileStreamCallback func(file io.Writer) (string, error) type MatrixAPI interface { GetMXID() id.UserID From 4523807e563803e82a7a88416451bd587410fa79 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 Aug 2024 13:56:19 +0300 Subject: [PATCH 0658/1647] bridgev2/portal: slightly refactor power level checks --- bridgev2/portal.go | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index fc047beb..7a5158b8 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2436,6 +2436,8 @@ func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLo log.Err(err).Msg("Failed to get chat info from resync event") } else if info != nil { portal.UpdateInfo(ctx, info, source, nil, time.Time{}) + } else { + log.Debug().Msg("No chat info provided in resync event") } } backfillChecker, ok := evt.(RemoteChatResyncBackfill) @@ -2549,8 +2551,10 @@ type PowerLevelOverrides struct { // Deprecated: renamed to PowerLevelOverrides type PowerLevelChanges = PowerLevelOverrides -func allowChange(newLevel, oldLevel, actorLevel int) bool { - return newLevel <= actorLevel && oldLevel <= actorLevel +func allowChange(newLevel *int, oldLevel, actorLevel int) bool { + return newLevel != nil && + *newLevel <= actorLevel && oldLevel <= actorLevel && + oldLevel != *newLevel } func (plc *PowerLevelOverrides) Apply(actor id.UserID, content *event.PowerLevelsEventContent) (changed bool) { @@ -2566,32 +2570,32 @@ func (plc *PowerLevelOverrides) Apply(actor id.UserID, content *event.PowerLevel } else { actorLevel = (1 << 31) - 1 } - if plc.UsersDefault != nil && allowChange(*plc.UsersDefault, content.UsersDefault, actorLevel) { - changed = content.UsersDefault != *plc.UsersDefault + if allowChange(plc.UsersDefault, content.UsersDefault, actorLevel) { + changed = true content.UsersDefault = *plc.UsersDefault } - if plc.EventsDefault != nil && allowChange(*plc.EventsDefault, content.EventsDefault, actorLevel) { - changed = content.EventsDefault != *plc.EventsDefault + if allowChange(plc.EventsDefault, content.EventsDefault, actorLevel) { + changed = true content.EventsDefault = *plc.EventsDefault } - if plc.StateDefault != nil && allowChange(*plc.StateDefault, content.StateDefault(), actorLevel) { - changed = content.StateDefault() != *plc.StateDefault + if allowChange(plc.StateDefault, content.StateDefault(), actorLevel) { + changed = true content.StateDefaultPtr = plc.StateDefault } - if plc.Invite != nil && allowChange(*plc.Invite, content.Invite(), actorLevel) { - changed = content.Invite() != *plc.Invite + if allowChange(plc.Invite, content.Invite(), actorLevel) { + changed = true content.InvitePtr = plc.Invite } - if plc.Kick != nil && allowChange(*plc.Kick, content.Kick(), actorLevel) { - changed = content.Kick() != *plc.Kick + if allowChange(plc.Kick, content.Kick(), actorLevel) { + changed = true content.KickPtr = plc.Kick } - if plc.Ban != nil && allowChange(*plc.Ban, content.Ban(), actorLevel) { - changed = content.Ban() != *plc.Ban + if allowChange(plc.Ban, content.Ban(), actorLevel) { + changed = true content.BanPtr = plc.Ban } - if plc.Redact != nil && allowChange(*plc.Redact, content.Redact(), actorLevel) { - changed = content.Redact() != *plc.Redact + if allowChange(plc.Redact, content.Redact(), actorLevel) { + changed = true content.RedactPtr = plc.Redact } if plc.Custom != nil { From f4d46376251852f8fa260f32d09287e977ca6ca1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 Aug 2024 13:56:45 +0300 Subject: [PATCH 0659/1647] bridgev2/portal: don't create backfill task before portal room --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 7a5158b8..158cd85b 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3212,7 +3212,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us source.MarkInPortal(ctx, portal) portal.updateUserLocalInfo(ctx, info.UserLocal, source, false) } - if info.CanBackfill && source != nil { + if info.CanBackfill && source != nil && portal.MXID != "" { err := portal.Bridge.DB.BackfillTask.EnsureExists(ctx, portal.PortalKey, source.ID) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to ensure backfill queue task exists") From 063d3742261921d1fd7617ad1a18f2c5c684e20b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 Aug 2024 15:47:39 +0300 Subject: [PATCH 0660/1647] bridgev2/backfillqueue: don't try to backfill with non-logged-in client --- bridgev2/backfillqueue.go | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go index 6b9242c2..0f4ee048 100644 --- a/bridgev2/backfillqueue.go +++ b/bridgev2/backfillqueue.go @@ -153,18 +153,26 @@ func (br *Bridge) actuallyDoBackfillTask(ctx context.Context, task *database.Bac login, err := br.GetExistingUserLoginByID(ctx, task.UserLoginID) if err != nil { return false, fmt.Errorf("failed to get user login for backfill task: %w", err) - } else if login == nil { - log.Warn().Msg("User login not found for backfill task") + } else if login == nil || !login.Client.IsLoggedIn() { + if login == nil { + log.Warn().Msg("User login not found for backfill task") + } else { + log.Warn().Msg("User login not logged in for backfill task") + } logins, err := br.GetUserLoginsInPortal(ctx, portal.PortalKey) if err != nil { return false, fmt.Errorf("failed to get user portals for backfill task: %w", err) } else if len(logins) == 0 { log.Debug().Msg("No user logins found for backfill task") task.NextDispatchMinTS = database.BackfillNextDispatchNever - task.UserLoginID = "" + if login == nil { + task.UserLoginID = "" + } return false, nil } - task.UserLoginID = "" + if login == nil { + task.UserLoginID = "" + } for _, login = range logins { if login.Client.IsLoggedIn() { task.UserLoginID = login.ID From 9e1a8cd56e31a99a005f9327ade5d2688c30422f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 Aug 2024 16:19:49 +0300 Subject: [PATCH 0661/1647] bridgev2/matrix: use cached member list if available --- bridgev2/matrix/connector.go | 7 ++++- bridgev2/matrix/intent.go | 4 +++ client.go | 12 +++++++ hicli/database/statestore.go | 13 ++++++++ sqlstatestore/statestore.go | 41 ++++++++++++++++++++++++ sqlstatestore/v00-latest-revision.sql | 19 ++++++------ sqlstatestore/v07-full-member-flag.sql | 2 ++ statestore.go | 43 +++++++++++++++++++++----- 8 files changed, 123 insertions(+), 18 deletions(-) create mode 100644 sqlstatestore/v07-full-member-flag.sql diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index dee28b8d..115250f2 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -540,7 +540,12 @@ func (br *Connector) GetPowerLevels(ctx context.Context, roomID id.RoomID) (*eve } func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { - // TODO use cache? + fetched, err := br.Bot.StateStore.HasFetchedMembers(ctx, roomID) + if err != nil { + return nil, err + } else if fetched { + return br.Bot.StateStore.GetAllMembers(ctx, roomID) + } members, err := br.Bot.Members(ctx, roomID) if err != nil { return nil, err diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 43d23a4c..0f668f8c 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -535,6 +535,10 @@ func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnl if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to leave room while cleaning up portal") } + err = as.Matrix.StateStore.ClearCachedMembers(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to clear cached members while cleaning up portal") + } return nil } diff --git a/client.go b/client.go index a0e86bdb..b3bd9158 100644 --- a/client.go +++ b/client.go @@ -1444,6 +1444,11 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt UpdateStateStore(ctx, cli.StateStore, evt) } } + clearErr = cli.StateStore.MarkMembersFetched(ctx, roomID) + if clearErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(clearErr). + Msg("Failed to mark members as fetched after fetching full room state") + } } return } @@ -1840,6 +1845,13 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb for _, evt := range resp.Chunk { UpdateStateStore(ctx, cli.StateStore, evt) } + if extra.NotMembership == "" && extra.Membership == "" { + markErr := cli.StateStore.MarkMembersFetched(ctx, roomID) + if markErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(markErr). + Msg("Failed to mark members as fetched after fetching full member list") + } + } } return } diff --git a/hicli/database/statestore.go b/hicli/database/statestore.go index e8050e93..cefe76d3 100644 --- a/hicli/database/statestore.go +++ b/hicli/database/statestore.go @@ -10,6 +10,7 @@ import ( "context" "database/sql" "errors" + "fmt" "go.mau.fi/util/dbutil" "golang.org/x/exp/slices" @@ -115,6 +116,18 @@ func (c *ClientStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, ro return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() } +func (c *ClientStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) { + return false, fmt.Errorf("not implemented") +} + +func (c *ClientStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error { + return fmt.Errorf("not implemented") +} + +func (c *ClientStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { + return nil, fmt.Errorf("not implemented") +} + func (c *ClientStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (isEncrypted bool, err error) { err = c.QueryRow(ctx, isRoomEncryptedQuery, roomID).Scan(&isEncrypted) if errors.Is(err, sql.ErrNoRows) { diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index 0e5c4184..2cfd1b97 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -234,9 +234,50 @@ func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.Ro query += fmt.Sprintf(" AND membership IN (%s)", strings.Join(placeholders, ",")) } _, err := store.Exec(ctx, query, params...) + if err != nil { + return err + } + _, err = store.Exec(ctx, "UPDATE mx_room_state SET members_fetched=false WHERE room_id=$1", roomID) return err } +func (store *SQLStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (fetched bool, err error) { + err = store.QueryRow(ctx, "SELECT members_fetched FROM mx_room_state WHERE room_id=$1", roomID).Scan(&fetched) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + +func (store *SQLStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error { + _, err := store.Exec(ctx, ` + INSERT INTO mx_room_state (room_id, members_fetched) VALUES ($1, true) + ON CONFLICT (room_id) DO UPDATE SET members_fetched=true + `, roomID) + return err +} + +type userAndMembership struct { + UserID id.UserID + event.MemberEventContent +} + +func (store *SQLStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { + rows, err := store.Query(ctx, "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1", roomID) + if err != nil { + return nil, err + } + output := make(map[id.UserID]*event.MemberEventContent) + err = dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (res userAndMembership, err error) { + err = row.Scan(&res.UserID, &res.Membership, &res.Displayname, &res.AvatarURL) + return + }, err).Iter(func(member userAndMembership) (bool, error) { + output[member.UserID] = &member.MemberEventContent + return true, nil + }) + return output, err +} + func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { contentBytes, err := json.Marshal(content) if err != nil { diff --git a/sqlstatestore/v00-latest-revision.sql b/sqlstatestore/v00-latest-revision.sql index b2bb2ae6..a58cc56a 100644 --- a/sqlstatestore/v00-latest-revision.sql +++ b/sqlstatestore/v00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v6 (compatible with v3+): Latest revision +-- v0 -> v7 (compatible with v3+): Latest revision CREATE TABLE mx_registrations ( user_id TEXT PRIMARY KEY @@ -8,11 +8,11 @@ CREATE TABLE mx_registrations ( CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock'); CREATE TABLE mx_user_profile ( - room_id TEXT, - user_id TEXT, - membership membership NOT NULL, - displayname TEXT NOT NULL DEFAULT '', - avatar_url TEXT NOT NULL DEFAULT '', + room_id TEXT, + user_id TEXT, + membership membership NOT NULL, + displayname TEXT NOT NULL DEFAULT '', + avatar_url TEXT NOT NULL DEFAULT '', name_skeleton bytea, @@ -23,7 +23,8 @@ CREATE INDEX mx_user_profile_membership_idx ON mx_user_profile (room_id, members CREATE INDEX mx_user_profile_name_skeleton_idx ON mx_user_profile (room_id, name_skeleton); CREATE TABLE mx_room_state ( - room_id TEXT PRIMARY KEY, - power_levels jsonb, - encryption jsonb + room_id TEXT PRIMARY KEY, + power_levels jsonb, + encryption jsonb, + members_fetched BOOLEAN NOT NULL DEFAULT false ); diff --git a/sqlstatestore/v07-full-member-flag.sql b/sqlstatestore/v07-full-member-flag.sql new file mode 100644 index 00000000..32f2ef6c --- /dev/null +++ b/sqlstatestore/v07-full-member-flag.sql @@ -0,0 +1,2 @@ +-- v7 (compatible with v3+): Add flag for whether the full member list has been fetched +ALTER TABLE mx_room_state ADD COLUMN members_fetched BOOLEAN NOT NULL DEFAULT false; diff --git a/statestore.go b/statestore.go index 35bfc6ab..5f210e4f 100644 --- a/statestore.go +++ b/statestore.go @@ -8,6 +8,7 @@ package mautrix import ( "context" + "maps" "sync" "github.com/rs/zerolog" @@ -32,6 +33,10 @@ type StateStore interface { SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error) + HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) + MarkMembersFetched(ctx context.Context, roomID id.RoomID) error + GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) + SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) @@ -90,10 +95,11 @@ func (cli *Client) StateStoreSyncHandler(ctx context.Context, evt *event.Event) } type MemoryStateStore struct { - Registrations map[id.UserID]bool `json:"registrations"` - Members map[id.RoomID]map[id.UserID]*event.MemberEventContent `json:"memberships"` - PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"` - Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"` + Registrations map[id.UserID]bool `json:"registrations"` + Members map[id.RoomID]map[id.UserID]*event.MemberEventContent `json:"memberships"` + MembersFetched map[id.RoomID]bool `json:"members_fetched"` + PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"` + Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"` registrationsLock sync.RWMutex membersLock sync.RWMutex @@ -103,10 +109,11 @@ type MemoryStateStore struct { func NewMemoryStateStore() StateStore { return &MemoryStateStore{ - Registrations: make(map[id.UserID]bool), - Members: make(map[id.RoomID]map[id.UserID]*event.MemberEventContent), - PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent), - Encryption: make(map[id.RoomID]*event.EncryptionEventContent), + Registrations: make(map[id.UserID]bool), + Members: make(map[id.RoomID]map[id.UserID]*event.MemberEventContent), + MembersFetched: make(map[id.RoomID]bool), + PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent), + Encryption: make(map[id.RoomID]*event.EncryptionEventContent), } } @@ -246,9 +253,29 @@ func (store *MemoryStateStore) ClearCachedMembers(_ context.Context, roomID id.R } } } + store.MembersFetched[roomID] = false return nil } +func (store *MemoryStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) { + store.membersLock.RLock() + defer store.membersLock.RUnlock() + return store.MembersFetched[roomID], nil +} + +func (store *MemoryStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error { + store.membersLock.Lock() + defer store.membersLock.Unlock() + store.MembersFetched[roomID] = true + return nil +} + +func (store *MemoryStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { + store.membersLock.Lock() + defer store.membersLock.Unlock() + return maps.Clone(store.Members[roomID]), nil +} + func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error { store.powerLevelsLock.Lock() store.PowerLevels[roomID] = levels From 4b7fa711cedb97960fe2e2f02d5ac43d661c71f7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 Aug 2024 16:29:41 +0300 Subject: [PATCH 0662/1647] changelog: update --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 819f69d0..d0da4c37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +## unreleased + +* *(bridgev2)* Added more features and fixed bugs. +* *(client)* Fixed requests being retried even after context is canceled. + ## v0.20.0 (2024-08-16) * Bumped minimum Go version to 1.22. From 591ac60f0caa2c0f1c14ed3ed03e9da85e6c26b7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 Aug 2024 17:37:21 +0300 Subject: [PATCH 0663/1647] bridgev2/portal: only forward backfill after room creation if enabled in config --- bridgev2/portal.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 158cd85b..c4c7592c 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3442,7 +3442,9 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } } } - portal.doForwardBackfill(ctx, source, nil, backfillBundle) + if portal.Bridge.Config.Backfill.Enabled { + portal.doForwardBackfill(ctx, source, nil, backfillBundle) + } return nil } From f99fb60f13a60d0dcf879c212451fc8747127b7d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 21 Aug 2024 01:10:55 +0300 Subject: [PATCH 0664/1647] bridgev2/matrixinterface: move upload stream file name/mime into callback return values --- bridgev2/matrix/intent.go | 24 +++++++++++------------- bridgev2/matrixinterface.go | 22 ++++++++++++++++++---- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 0f668f8c..08a7b940 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -244,8 +244,6 @@ func (as *ASIntent) UploadMediaStream( roomID id.RoomID, size int64, requireFile bool, - fileName, - mimeType string, cb bridgev2.FileStreamCallback, ) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) { if size > as.Connector.MediaConfig.UploadSize { @@ -253,13 +251,13 @@ func (as *ASIntent) UploadMediaStream( } if !requireFile && 0 < size && size < as.Connector.Config.Matrix.UploadFileThreshold { var buf bytes.Buffer - replPath, err := cb(&buf) + res, err := cb(&buf) if err != nil { return "", nil, err - } else if replPath != "" { + } else if res.ReplacementFile != "" { panic(fmt.Errorf("logic error: replacement path must only be returned if requireFile is true")) } - return as.UploadMedia(ctx, roomID, buf.Bytes(), fileName, mimeType) + return as.UploadMedia(ctx, roomID, buf.Bytes(), res.FileName, res.MimeType) } var tempFile *os.File tempFile, err = os.CreateTemp("", "mautrix-upload-*") @@ -286,18 +284,16 @@ func (as *ASIntent) UploadMediaStream( file = &event.EncryptedFileInfo{ EncryptedFile: *attachment.NewEncryptedFile(), } - mimeType = "application/octet-stream" - fileName = "" } } - var replPath string - replPath, err = cb(tempFile) + var res *bridgev2.FileStreamResult + res, err = cb(tempFile) if err != nil { err = fmt.Errorf("failed to write to temp file: %w", err) } var replFile *os.File - if replPath != "" { - replFile, err = os.OpenFile(replPath, os.O_RDWR, 0) + if res.ReplacementFile != "" { + replFile, err = os.OpenFile(res.ReplacementFile, os.O_RDWR, 0) if err != nil { err = fmt.Errorf("failed to open replacement file: %w", err) return @@ -316,6 +312,8 @@ func (as *ASIntent) UploadMediaStream( } } if file != nil { + res.FileName = "" + res.MimeType = "application/octet-stream" err = file.EncryptFile(replFile) if err != nil { err = fmt.Errorf("failed to encrypt file: %w", err) @@ -335,8 +333,8 @@ func (as *ASIntent) UploadMediaStream( req := mautrix.ReqUploadMedia{ Content: tempFile, ContentLength: info.Size(), - ContentType: mimeType, - FileName: fileName, + ContentType: res.MimeType, + FileName: res.FileName, } if as.Connector.Config.Homeserver.AsyncMedia { req.DoneCallback = func() { diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 893f0321..9fb0c82d 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -79,13 +79,27 @@ type MatrixSendExtra struct { PartIndex int } +// FileStreamResult is the result of a FileStreamCallback. +type FileStreamResult struct { + // ReplacementFile is the path to a new file that replaces the original file provided to the callback. + // Providing a replacement file is only allowed if the requireFile flag was set for the UploadMediaStream call. + ReplacementFile string + // FileName is the name of the file to be specified when uploading to the server. + // This should be the same as the file name that will be included in the Matrix event (body or filename field). + // If the file gets encrypted, this field will be ignored. + FileName string + // MimeType is the type of field to be specified when uploading to the server. + // This should be the same as the mime type that will be included in the Matrix event (info -> mimetype field). + // If the file gets encrypted, this field will be replaced with application/octet-stream. + MimeType string +} + // FileStreamCallback is a callback function for file uploads that roundtrip via disk. // // The parameter is either a file or an in-memory buffer depending on the size of the file and whether the requireFile flag was set. // -// The first return value can specify a file path to use instead of the original temp file. -// Returning a replacement path is only valid if the parameter is a file. -type FileStreamCallback func(file io.Writer) (string, error) +// The return value must be non-nil unless there's an error, and should always include FileName and MimeType. +type FileStreamCallback func(file io.Writer) (*FileStreamResult, error) type MatrixAPI interface { GetMXID() id.UserID @@ -97,7 +111,7 @@ type MatrixAPI interface { MarkTyping(ctx context.Context, roomID id.RoomID, typingType TypingType, timeout time.Duration) error DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error) UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) - UploadMediaStream(ctx context.Context, roomID id.RoomID, size int64, requireFile bool, fileName, mimeType string, cb FileStreamCallback) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) + UploadMediaStream(ctx context.Context, roomID id.RoomID, size int64, requireFile bool, cb FileStreamCallback) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) SetDisplayName(ctx context.Context, name string) error SetAvatarURL(ctx context.Context, avatarURL id.ContentURIString) error From 851a9485e7be1d4ba6b0b878a3d762b7a3a409aa Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 21 Aug 2024 10:45:05 +0300 Subject: [PATCH 0665/1647] bridgev2/matrix: fix replacement files in UploadMediaStream --- bridgev2/matrix/intent.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 08a7b940..5b27f906 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -331,7 +331,7 @@ func (as *ASIntent) UploadMediaStream( return } req := mautrix.ReqUploadMedia{ - Content: tempFile, + Content: replFile, ContentLength: info.Size(), ContentType: res.MimeType, FileName: res.FileName, From edef968c642d5e256c2cc657c316e5c4612cbd4e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 21 Aug 2024 10:46:24 +0300 Subject: [PATCH 0666/1647] event: remove omitemptys from MSC1767Audio --- event/audio.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/event/audio.go b/event/audio.go index 0fc0818b..798acc8c 100644 --- a/event/audio.go +++ b/event/audio.go @@ -1,8 +1,8 @@ package event type MSC1767Audio struct { - Duration int `json:"duration,omitempty"` - Waveform []int `json:"waveform,omitempty"` + Duration int `json:"duration"` + Waveform []int `json:"waveform"` } type MSC3245Voice struct{} From 8ab31c8c46829f5019d5939a93dbcccf07ba9f28 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 21 Aug 2024 16:46:12 +0300 Subject: [PATCH 0667/1647] bridgev2/userlogin: fix potential panic when kicking user from portals --- bridgev2/userlogin.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 017df773..2df43425 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -382,6 +382,8 @@ func (ul *UserLogin) kickUserFromPortal(ctx context.Context, up *database.UserPo portal, action, reason, err := ul.getLogoutAction(ctx, up, badCredentials) if err != nil { return nil, err + } else if portal == nil { + return nil, nil } zerolog.Ctx(ctx).Debug(). Str("login_id", string(ul.ID)). From ad20a9218fac270030abfd8b0634751d94f0cbea Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Wed, 21 Aug 2024 19:19:03 +0300 Subject: [PATCH 0668/1647] Beeper extension for client side media dedup (#267) The client can send a unique id (like a hash) of a file it intends to upload in the create request. If the server already has the file it will return a "completed" response. The unique id is meant to be opaque for the server. For privacy reasons it is recommended not to use the raw hash of a file. The returned MXC should be stable for the same unique id for the same user but is not guaranteed. The room id is used to tie the lifecycle of created media to an existing room on the homeserver. If a room is purged from the homeserver the media will be purged along with it. If the file has been created but not uploaded the response will not have a completed timestamp which allows the client to retry sending the file. If the upload has already been completed the upload URL will be empty. It is possible for multiple clients to send a create request simultaneously with the same unique id and upload the file at the same time. It is also possible for the server to forget the unique id and allow reuploading the same file again returning a new MXC. This commit also fixes UnusedExpiresAt type in the response which is a breaking change. --- client.go | 19 +++++++++++++++++-- responses.go | 8 ++++++-- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index b3bd9158..294595f8 100644 --- a/client.go +++ b/client.go @@ -1581,12 +1581,27 @@ func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]b return io.ReadAll(resp.Body) } +type ReqCreateMXC struct { + BeeperUniqueID string + BeeperRoomID id.RoomID +} + // CreateMXC creates a blank Matrix content URI to allow uploading the content asynchronously later. // // See https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav1create -func (cli *Client) CreateMXC(ctx context.Context) (*RespCreateMXC, error) { +func (cli *Client) CreateMXC(ctx context.Context, extra ...ReqCreateMXC) (*RespCreateMXC, error) { var m RespCreateMXC - _, err := cli.MakeRequest(ctx, http.MethodPost, cli.BuildURL(MediaURLPath{"v1", "create"}), nil, &m) + query := map[string]string{} + if len(extra) > 0 { + if extra[0].BeeperUniqueID != "" { + query["com.beeper.unique_id"] = extra[0].BeeperUniqueID + } + if extra[0].BeeperRoomID != "" { + query["com.beeper.room_id"] = string(extra[0].BeeperRoomID) + } + } + createURL := cli.BuildURLWithQuery(MediaURLPath{"v1", "create"}, query) + _, err := cli.MakeRequest(ctx, http.MethodPost, createURL, nil, &m) return &m, err } diff --git a/responses.go b/responses.go index 9e5fd0aa..26aaac77 100644 --- a/responses.go +++ b/responses.go @@ -111,10 +111,14 @@ type RespMediaUpload struct { // RespCreateMXC is the JSON response for https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav1create type RespCreateMXC struct { - ContentURI id.ContentURI `json:"content_uri"` - UnusedExpiresAt int `json:"unused_expires_at,omitempty"` + ContentURI id.ContentURI `json:"content_uri"` + UnusedExpiresAt jsontime.UnixMilli `json:"unused_expires_at,omitempty"` UnstableUploadURL string `json:"com.beeper.msc3870.upload_url,omitempty"` + + // Beeper extensions for uploading unique media only once + BeeperUniqueID string `json:"com.beeper.unique_id,omitempty"` + BeeperCompletedAt jsontime.UnixMilli `json:"com.beeper.completed_at,omitempty"` } // RespPreviewURL is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3preview_url From b09d3a99cbff155167bbb542c9ba5d8ca16c00fe Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 21 Aug 2024 21:41:04 +0300 Subject: [PATCH 0669/1647] bridgev2/portal: replace member list with map --- bridgev2/portal.go | 37 ++++++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 9 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c4c7592c..f9d5aa10 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2524,17 +2524,32 @@ type ChatMemberList struct { // This should be used when SenderLogin can't be filled accurately. CheckAllLogins bool - // The total number of members in the chat, regardless of how many of those members are included in Members. + // The total number of members in the chat, regardless of how many of those members are included in MemberMap. TotalMemberCount int // For DM portals, the ID of the recipient user. - // This field is optional and will be automatically filled from Members if there are only 2 entries in the list. + // This field is optional and will be automatically filled from MemberMap if there are only 2 entries in the map. OtherUserID networkid.UserID + // Deprecated: Use MemberMap instead to avoid duplicate entries Members []ChatMember + MemberMap map[networkid.UserID]ChatMember PowerLevels *PowerLevelOverrides } +func (cml *ChatMemberList) memberListToMap(ctx context.Context) { + if cml.Members == nil || cml.MemberMap != nil { + return + } + cml.MemberMap = make(map[networkid.UserID]ChatMember, len(cml.Members)) + for _, member := range cml.Members { + if _, alreadyExists := cml.MemberMap[member.Sender]; alreadyExists { + zerolog.Ctx(ctx).Warn().Str("member_id", string(member.Sender)).Msg("Duplicate member in list") + } + cml.MemberMap[member.Sender] = member + } +} + type PowerLevelOverrides struct { Events map[event.Type]int UsersDefault *int @@ -2803,7 +2818,8 @@ func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMem } } members.PowerLevels.Apply("", pl) - for _, member := range members.Members { + members.memberListToMap(ctx) + for _, member := range members.MemberMap { if member.Membership != event.MembershipJoin && member.Membership != "" { continue } @@ -2839,16 +2855,18 @@ func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMem } func (portal *Portal) updateOtherUser(ctx context.Context, members *ChatMemberList) (changed bool) { + members.memberListToMap(ctx) var expectedUserID networkid.UserID if portal.RoomType != database.RoomTypeDM { // expected user ID is empty } else if members.OtherUserID != "" { expectedUserID = members.OtherUserID - } else if len(members.Members) == 2 && members.IsFull { - if members.Members[0].IsFromMe && !members.Members[1].IsFromMe { - expectedUserID = members.Members[1].Sender - } else if members.Members[1].IsFromMe && !members.Members[0].IsFromMe { - expectedUserID = members.Members[0].Sender + } else if len(members.MemberMap) == 2 && members.IsFull { + vals := maps.Values(members.MemberMap) + if vals[0].IsFromMe && !vals[1].IsFromMe { + expectedUserID = vals[1].Sender + } else if vals[1].IsFromMe && !vals[0].IsFromMe { + expectedUserID = vals[0].Sender } } if portal.OtherUserID != expectedUserID { @@ -2863,6 +2881,7 @@ func (portal *Portal) updateOtherUser(ctx context.Context, members *ChatMemberLi } func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error { + members.memberListToMap(ctx) var loginsInPortal []*UserLogin var err error if members.CheckAllLogins { @@ -2977,7 +2996,7 @@ func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberL } } } - for _, member := range members.Members { + for _, member := range members.MemberMap { if member.Sender != "" && member.UserInfo != nil { ghost, err := portal.Bridge.GetGhostByID(ctx, member.Sender) if err != nil { From 675d176b4662e92606555da7cbda2f028301194b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 21 Aug 2024 22:40:10 +0300 Subject: [PATCH 0670/1647] bridgev2/user: rename GetCachedUserLogins to GetUserLogins Fixes #271 --- bridgev2/user.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bridgev2/user.go b/bridgev2/user.go index 7dc9959a..5c2344e8 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -165,7 +165,12 @@ func (user *User) GetUserLoginIDs() []networkid.UserLoginID { return maps.Keys(user.logins) } +// Deprecated: renamed to GetUserLogins func (user *User) GetCachedUserLogins() []*UserLogin { + return user.GetUserLogins() +} + +func (user *User) GetUserLogins() []*UserLogin { user.Bridge.cacheLock.Lock() defer user.Bridge.cacheLock.Unlock() return maps.Values(user.logins) From afc796861a44a315b59d02b5a1fff82b184df2e1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 22 Aug 2024 17:58:34 +0300 Subject: [PATCH 0671/1647] bridgev2/commands: fix error reply on invalid login flow --- bridgev2/commands/login.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index df94c6ba..e813de70 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -56,7 +56,7 @@ func fnLogin(ce *Event) { } } if chosenFlowID == "" { - ce.Reply("Invalid login flow `%s`. Available options:\n\n%s", ce.Args[0], formatFlowsReply(flows)) + ce.Reply("Invalid login flow `%s`. Available options:\n\n%s", inputFlowID, formatFlowsReply(flows)) return } } else if len(flows) == 1 { From 8ead76c67b9b6676d9b983de992660fa14b00adb Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 22 Aug 2024 13:07:56 -0600 Subject: [PATCH 0672/1647] provisioning: fix return value from doResolveIdentifer Signed-off-by: Sumner Evans --- bridgev2/matrix/provisioning.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 107837ef..b8a33fff 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -601,7 +601,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. } apiResp.DMRoomID = resp.Chat.Portal.MXID } - jsonResponse(w, status, resp) + jsonResponse(w, status, apiResp) } type RespGetContactList struct { From 7fa9e14a88c173a97ec05cdb3cd5be6e412275f7 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Sat, 1 Jun 2024 21:35:56 -0600 Subject: [PATCH 0673/1647] crypto/ed25519: reimplementation with libolm-compatible byte layout This PR adds a maunium.net/go/mautrix/crypto/ed25519 package to mautrix-go which is a mirror of the crypto/ed25519 package for generating Ed25519 signatures, but which uses a different private key format. This picture will help with the rest of the explanation: https://blog.mozilla.org/warner/files/2011/11/key-formats.png The private key in the [crypto/ed25519] package is a 64-byte value where the first 32-bytes are the seed and the last 32-bytes are the public key. The private key in this package is stored as a 64-byte value that results from the SHA512 of the seed. This is the format used by libolm, and is required for pickle/unpickle to work properly. Signed-off-by: Sumner Evans --- crypto/ed25519/ed25519.go | 294 +++++++++++++++++++++++++++++++++ crypto/ed25519/ed25519_test.go | 20 +++ go.mod | 1 + go.sum | 2 + 4 files changed, 317 insertions(+) create mode 100644 crypto/ed25519/ed25519.go create mode 100644 crypto/ed25519/ed25519_test.go diff --git a/crypto/ed25519/ed25519.go b/crypto/ed25519/ed25519.go new file mode 100644 index 00000000..6b294c67 --- /dev/null +++ b/crypto/ed25519/ed25519.go @@ -0,0 +1,294 @@ +// Copyright 2024 Sumner Evans. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package ed25519 implements the Ed25519 signature algorithm. See +// https://ed25519.cr.yp.to/. +// +// This package stores the private key in a different format than the +// [crypto/ed25519] package in the standard library. +// +// This picture will help with the rest of the explanation: +// https://blog.mozilla.org/warner/files/2011/11/key-formats.png +// +// The private key in the [crypto/ed25519] package is a 64-byte value where the +// first 32-bytes are the seed and the last 32-bytes are the public key. +// +// The private key in this package is stored as a 64-byte value that results +// from the SHA512 of the seed. +// +// The contents of this package are mostly copied from the standard library, +// and as such the source code is licensed under the BSD license of the +// standard library implementation. +// +// Other notable changes from the standard library include: +// +// - The Seed function of the standard library is not implemented in this +// package because there is no way to recover the seed after hashing it. +package ed25519 + +import ( + "crypto" + "crypto/ed25519" + cryptorand "crypto/rand" + "crypto/sha512" + "crypto/subtle" + "errors" + "io" + "strconv" + + "filippo.io/edwards25519" +) + +const ( + // PublicKeySize is the size, in bytes, of public keys as used in this package. + PublicKeySize = 32 + // PrivateKeySize is the size, in bytes, of private keys as used in this package. + PrivateKeySize = 64 + // SignatureSize is the size, in bytes, of signatures generated and verified by this package. + SignatureSize = 64 + // SeedSize is the size, in bytes, of private key seeds. These are the private key representations used by RFC 8032. + SeedSize = 32 +) + +// PublicKey is the type of Ed25519 public keys. +type PublicKey []byte + +// Any methods implemented on PublicKey might need to also be implemented on +// PrivateKey, as the latter embeds the former and will expose its methods. + +// Equal reports whether pub and x have the same value. +func (pub PublicKey) Equal(x crypto.PublicKey) bool { + switch x := x.(type) { + case PublicKey: + return subtle.ConstantTimeCompare(pub, x) == 1 + case ed25519.PublicKey: + return subtle.ConstantTimeCompare(pub, x) == 1 + default: + return false + } +} + +// PrivateKey is the type of Ed25519 private keys. It implements [crypto.Signer]. +type PrivateKey []byte + +// Public returns the [PublicKey] corresponding to priv. +// +// This method differs from the standard library because it calculates the +// public key instead of returning the right half of the private key (which +// contains the public key in the standard library). +func (priv PrivateKey) Public() crypto.PublicKey { + s, err := edwards25519.NewScalar().SetBytesWithClamping(priv[:32]) + if err != nil { + panic("ed25519: internal error: setting scalar failed") + } + return (&edwards25519.Point{}).ScalarBaseMult(s).Bytes() +} + +// Equal reports whether priv and x have the same value. +func (priv PrivateKey) Equal(x crypto.PrivateKey) bool { + // TODO do we have any need to check equality with standard library ed25519 + // private keys? + xx, ok := x.(PrivateKey) + if !ok { + return false + } + return subtle.ConstantTimeCompare(priv, xx) == 1 +} + +// Sign signs the given message with priv. rand is ignored and can be nil. +// +// If opts.HashFunc() is [crypto.SHA512], the pre-hashed variant Ed25519ph is used +// and message is expected to be a SHA-512 hash, otherwise opts.HashFunc() must +// be [crypto.Hash](0) and the message must not be hashed, as Ed25519 performs two +// passes over messages to be signed. +// +// A value of type [Options] can be used as opts, or crypto.Hash(0) or +// crypto.SHA512 directly to select plain Ed25519 or Ed25519ph, respectively. +func (priv PrivateKey) Sign(rand io.Reader, message []byte, opts crypto.SignerOpts) (signature []byte, err error) { + hash := opts.HashFunc() + context := "" + if opts, ok := opts.(*Options); ok { + context = opts.Context + } + switch { + case hash == crypto.SHA512: // Ed25519ph + if l := len(message); l != sha512.Size { + return nil, errors.New("ed25519: bad Ed25519ph message hash length: " + strconv.Itoa(l)) + } + if l := len(context); l > 255 { + return nil, errors.New("ed25519: bad Ed25519ph context length: " + strconv.Itoa(l)) + } + signature := make([]byte, SignatureSize) + sign(signature, priv, message, domPrefixPh, context) + return signature, nil + case hash == crypto.Hash(0) && context != "": // Ed25519ctx + if l := len(context); l > 255 { + return nil, errors.New("ed25519: bad Ed25519ctx context length: " + strconv.Itoa(l)) + } + signature := make([]byte, SignatureSize) + sign(signature, priv, message, domPrefixCtx, context) + return signature, nil + case hash == crypto.Hash(0): // Ed25519 + return Sign(priv, message), nil + default: + return nil, errors.New("ed25519: expected opts.HashFunc() zero (unhashed message, for standard Ed25519) or SHA-512 (for Ed25519ph)") + } +} + +// Options can be used with [PrivateKey.Sign] or [VerifyWithOptions] +// to select Ed25519 variants. +type Options struct { + // Hash can be zero for regular Ed25519, or crypto.SHA512 for Ed25519ph. + Hash crypto.Hash + + // Context, if not empty, selects Ed25519ctx or provides the context string + // for Ed25519ph. It can be at most 255 bytes in length. + Context string +} + +// HashFunc returns o.Hash. +func (o *Options) HashFunc() crypto.Hash { return o.Hash } + +// GenerateKey generates a public/private key pair using entropy from rand. +// If rand is nil, [crypto/rand.Reader] will be used. +// +// The output of this function is deterministic, and equivalent to reading +// [SeedSize] bytes from rand, and passing them to [NewKeyFromSeed]. +func GenerateKey(rand io.Reader) (PublicKey, PrivateKey, error) { + if rand == nil { + rand = cryptorand.Reader + } + + seed := make([]byte, SeedSize) + if _, err := io.ReadFull(rand, seed); err != nil { + return nil, nil, err + } + + privateKey := NewKeyFromSeed(seed) + return PublicKey(privateKey.Public().([]byte)), privateKey, nil +} + +// NewKeyFromSeed calculates a private key from a seed. It will panic if +// len(seed) is not [SeedSize]. This function is provided for interoperability +// with RFC 8032. RFC 8032's private keys correspond to seeds in this +// package. +func NewKeyFromSeed(seed []byte) PrivateKey { + // Outline the function body so that the returned key can be stack-allocated. + privateKey := make([]byte, PrivateKeySize) + newKeyFromSeed(privateKey, seed) + return privateKey +} + +func newKeyFromSeed(privateKey, seed []byte) { + if l := len(seed); l != SeedSize { + panic("ed25519: bad seed length: " + strconv.Itoa(l)) + } + + h := sha512.Sum512(seed) + copy(privateKey, h[:]) +} + +// Sign signs the message with privateKey and returns a signature. It will +// panic if len(privateKey) is not [PrivateKeySize]. +func Sign(privateKey PrivateKey, message []byte) []byte { + // Outline the function body so that the returned signature can be + // stack-allocated. + signature := make([]byte, SignatureSize) + sign(signature, privateKey, message, domPrefixPure, "") + return signature +} + +// Domain separation prefixes used to disambiguate Ed25519/Ed25519ph/Ed25519ctx. +// See RFC 8032, Section 2 and Section 5.1. +const ( + // domPrefixPure is empty for pure Ed25519. + domPrefixPure = "" + // domPrefixPh is dom2(phflag=1) for Ed25519ph. It must be followed by the + // uint8-length prefixed context. + domPrefixPh = "SigEd25519 no Ed25519 collisions\x01" + // domPrefixCtx is dom2(phflag=0) for Ed25519ctx. It must be followed by the + // uint8-length prefixed context. + domPrefixCtx = "SigEd25519 no Ed25519 collisions\x00" +) + +func sign(signature []byte, privateKey PrivateKey, message []byte, domPrefix, context string) { + if l := len(privateKey); l != PrivateKeySize { + panic("ed25519: bad private key length: " + strconv.Itoa(l)) + } + // We have to extract the public key from the private key. + publicKey := privateKey.Public().([]byte) + // The private key is already the hashed value of the seed. + h := privateKey + + s, err := edwards25519.NewScalar().SetBytesWithClamping(h[:32]) + if err != nil { + panic("ed25519: internal error: setting scalar failed") + } + prefix := h[32:] + + mh := sha512.New() + if domPrefix != domPrefixPure { + mh.Write([]byte(domPrefix)) + mh.Write([]byte{byte(len(context))}) + mh.Write([]byte(context)) + } + mh.Write(prefix) + mh.Write(message) + messageDigest := make([]byte, 0, sha512.Size) + messageDigest = mh.Sum(messageDigest) + r, err := edwards25519.NewScalar().SetUniformBytes(messageDigest) + if err != nil { + panic("ed25519: internal error: setting scalar failed") + } + + R := (&edwards25519.Point{}).ScalarBaseMult(r) + + kh := sha512.New() + if domPrefix != domPrefixPure { + kh.Write([]byte(domPrefix)) + kh.Write([]byte{byte(len(context))}) + kh.Write([]byte(context)) + } + kh.Write(R.Bytes()) + kh.Write(publicKey) + kh.Write(message) + hramDigest := make([]byte, 0, sha512.Size) + hramDigest = kh.Sum(hramDigest) + k, err := edwards25519.NewScalar().SetUniformBytes(hramDigest) + if err != nil { + panic("ed25519: internal error: setting scalar failed") + } + + S := edwards25519.NewScalar().MultiplyAdd(k, s, r) + + copy(signature[:32], R.Bytes()) + copy(signature[32:], S.Bytes()) +} + +// Verify reports whether sig is a valid signature of message by publicKey. It +// will panic if len(publicKey) is not [PublicKeySize]. +// +// This is just a wrapper around [ed25519.Verify] from the standard library. +func Verify(publicKey PublicKey, message, sig []byte) bool { + return ed25519.Verify(ed25519.PublicKey(publicKey), message, sig) +} + +// VerifyWithOptions reports whether sig is a valid signature of message by +// publicKey. A valid signature is indicated by returning a nil error. It will +// panic if len(publicKey) is not [PublicKeySize]. +// +// If opts.Hash is [crypto.SHA512], the pre-hashed variant Ed25519ph is used and +// message is expected to be a SHA-512 hash, otherwise opts.Hash must be +// [crypto.Hash](0) and the message must not be hashed, as Ed25519 performs two +// passes over messages to be signed. +// +// This is just a wrapper around [ed25519.VerifyWithOptions] from the standard +// library. +func VerifyWithOptions(publicKey PublicKey, message, sig []byte, opts *Options) error { + return ed25519.VerifyWithOptions(ed25519.PublicKey(publicKey), message, sig, &ed25519.Options{ + Hash: opts.Hash, + Context: opts.Context, + }) +} diff --git a/crypto/ed25519/ed25519_test.go b/crypto/ed25519/ed25519_test.go new file mode 100644 index 00000000..931c06f6 --- /dev/null +++ b/crypto/ed25519/ed25519_test.go @@ -0,0 +1,20 @@ +package ed25519_test + +import ( + stdlibed25519 "crypto/ed25519" + "testing" + + "github.com/stretchr/testify/assert" + "go.mau.fi/util/random" + + "maunium.net/go/mautrix/crypto/ed25519" +) + +func TestPubkeyEqual(t *testing.T) { + pubkeyBytes := random.Bytes(32) + pubkey := ed25519.PublicKey(pubkeyBytes) + pubkey2 := ed25519.PublicKey(pubkeyBytes) + stdlibPubkey := stdlibed25519.PublicKey(pubkeyBytes) + assert.True(t, pubkey.Equal(pubkey2)) + assert.True(t, pubkey.Equal(stdlibPubkey)) +} diff --git a/go.mod b/go.mod index 6ff2233d..835c7ef8 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module maunium.net/go/mautrix go 1.22 require ( + filippo.io/edwards25519 v1.1.0 github.com/chzyer/readline v1.5.1 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 diff --git a/go.sum b/go.sum index 53105fc7..d357ac92 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= From 213b6ec80d501fdd95f0a8ead2ffeff063054ebd Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 31 May 2024 13:59:55 -0600 Subject: [PATCH 0674/1647] crypto/olm: make everything into an interface This commit turns all of the crypto objects that are provided by olm into interfaces so that multiple implementations (libolm and goolm right now) can implement it. As part of this refactor, the libolm code has been moved to a separate package (goolm was already in its own package). Both packages now implement structs which implement the various interfaces. Additional changes: * goolm/goolmbase64: split into separate package (needed to avoid import loops) * olm/errors: unified all errors under the olm package * ci: remove libolm before building with goolm flag (this allows us to use ./... to build all of the packages under goolm) Signed-off-by: Sumner Evans Signed-off-by: Sumner Evans --- .github/workflows/go.yml | 7 +- crypto/account.go | 54 +- crypto/cross_sign_key.go | 2 +- crypto/cross_sign_signing.go | 2 +- crypto/encryptolm.go | 5 +- crypto/goolm/account/account.go | 107 ++-- crypto/goolm/account/account_test.go | 78 +-- crypto/goolm/account/register.go | 25 + crypto/goolm/cipher/pickle.go | 9 +- crypto/goolm/crypto/curve25519.go | 6 +- crypto/goolm/crypto/ed25519.go | 6 +- crypto/goolm/crypto/one_time_key.go | 4 +- crypto/goolm/errors.go | 28 - crypto/goolm/{ => goolmbase64}/base64.go | 6 +- crypto/goolm/libolmpickle/unpickle.go | 10 +- crypto/goolm/megolm/megolm.go | 11 +- crypto/goolm/message/decoder.go | 6 +- crypto/goolm/message/session_export.go | 6 +- crypto/goolm/message/session_sharing.go | 8 +- crypto/goolm/pk/decryption.go | 15 +- crypto/goolm/pk/encryption.go | 4 +- crypto/goolm/pk/pk_test.go | 7 +- crypto/goolm/pk/register.go | 21 + crypto/goolm/pk/signing.go | 4 +- crypto/goolm/{olm => ratchet}/chain.go | 12 +- crypto/goolm/{olm => ratchet}/olm.go | 30 +- crypto/goolm/{olm => ratchet}/olm_test.go | 32 +- .../goolm/{olm => ratchet}/skipped_message.go | 6 +- crypto/goolm/register.go | 25 + .../goolm/session/megolm_inbound_session.go | 82 ++- .../goolm/session/megolm_outbound_session.go | 54 +- crypto/goolm/session/megolm_session_test.go | 12 +- crypto/goolm/session/olm_session.go | 143 +++--- crypto/goolm/session/olm_session_test.go | 20 +- crypto/goolm/session/register.go | 69 +++ crypto/goolm/utilities/pickle.go | 8 +- crypto/keybackup.go | 2 +- crypto/keyimport.go | 2 +- crypto/keysharing.go | 2 +- crypto/libolm/account.go | 417 +++++++++++++++ crypto/libolm/error.go | 37 ++ crypto/libolm/inboundgroupsession.go | 328 ++++++++++++ crypto/libolm/libolm.go | 10 + crypto/libolm/outboundgroupsession.go | 249 +++++++++ crypto/{olm/pk_libolm.go => libolm/pk.go} | 73 +-- crypto/libolm/register.go | 21 + crypto/libolm/session.go | 388 ++++++++++++++ crypto/olm/account.go | 485 ++++-------------- crypto/olm/account_goolm.go | 154 ------ crypto/olm/error_goolm.go | 23 - crypto/olm/{error.go => errors.go} | 54 +- crypto/olm/inboundgroupsession.go | 349 +++---------- crypto/olm/inboundgroupsession_goolm.go | 149 ------ crypto/olm/olm.go | 26 +- crypto/olm/olm_goolm.go | 13 - crypto/olm/outboundgroupsession.go | 258 ++-------- crypto/olm/outboundgroupsession_goolm.go | 111 ---- crypto/olm/{pk_interface.go => pk.go} | 26 +- crypto/olm/pk_goolm.go | 29 -- crypto/olm/pk_test.go | 7 +- crypto/olm/session.go | 419 +++------------ crypto/olm/session_goolm.go | 110 ---- crypto/olm/session_test.go | 56 ++ crypto/registergoolm.go | 5 + crypto/registerlibolm.go | 5 + crypto/sessions.go | 16 +- crypto/sql_store.go | 54 +- crypto/store_test.go | 16 +- 68 files changed, 2522 insertions(+), 2296 deletions(-) create mode 100644 crypto/goolm/account/register.go delete mode 100644 crypto/goolm/errors.go rename crypto/goolm/{ => goolmbase64}/base64.go (82%) create mode 100644 crypto/goolm/pk/register.go rename crypto/goolm/{olm => ratchet}/chain.go (95%) rename crypto/goolm/{olm => ratchet}/olm.go (93%) rename crypto/goolm/{olm => ratchet}/olm_test.go (82%) rename crypto/goolm/{olm => ratchet}/skipped_message.go (92%) create mode 100644 crypto/goolm/register.go create mode 100644 crypto/goolm/session/register.go create mode 100644 crypto/libolm/account.go create mode 100644 crypto/libolm/error.go create mode 100644 crypto/libolm/inboundgroupsession.go create mode 100644 crypto/libolm/libolm.go create mode 100644 crypto/libolm/outboundgroupsession.go rename crypto/{olm/pk_libolm.go => libolm/pk.go} (71%) create mode 100644 crypto/libolm/register.go create mode 100644 crypto/libolm/session.go delete mode 100644 crypto/olm/account_goolm.go delete mode 100644 crypto/olm/error_goolm.go rename crypto/olm/{error.go => errors.go} (57%) delete mode 100644 crypto/olm/inboundgroupsession_goolm.go delete mode 100644 crypto/olm/olm_goolm.go delete mode 100644 crypto/olm/outboundgroupsession_goolm.go rename crypto/olm/{pk_interface.go => pk.go} (52%) delete mode 100644 crypto/olm/pk_goolm.go delete mode 100644 crypto/olm/session_goolm.go create mode 100644 crypto/olm/session_test.go create mode 100644 crypto/registergoolm.go create mode 100644 crypto/registerlibolm.go diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 8197d3a7..30f05d69 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -83,7 +83,6 @@ jobs: token: ${{ secrets.GITHUB_TOKEN }} - name: Build - run: go build -tags=goolm -v ./... - - - name: Test - run: go test -tags=goolm -json -v ./... 2>&1 | gotestfmt + run: | + rm -rf crypto/libolm + go build -tags=goolm -v ./... diff --git a/crypto/account.go b/crypto/account.go index 2f012e59..2f93280c 100644 --- a/crypto/account.go +++ b/crypto/account.go @@ -7,7 +7,12 @@ package crypto import ( + "encoding/json" + + "github.com/tidwall/sjson" + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/id" @@ -22,32 +27,61 @@ type OlmAccount struct { } func NewOlmAccount() *OlmAccount { + account, err := olm.NewAccount(nil) + if err != nil { + panic(err) + } return &OlmAccount{ - Internal: *olm.NewAccount(), + Internal: account, } } func (account *OlmAccount) Keys() (id.SigningKey, id.IdentityKey) { if len(account.signingKey) == 0 || len(account.identityKey) == 0 { - account.signingKey, account.identityKey = account.Internal.IdentityKeys() + var err error + account.signingKey, account.identityKey, err = account.Internal.IdentityKeys() + if err != nil { + panic(err) + } } return account.signingKey, account.identityKey } func (account *OlmAccount) SigningKey() id.SigningKey { if len(account.signingKey) == 0 { - account.signingKey, account.identityKey = account.Internal.IdentityKeys() + var err error + account.signingKey, account.identityKey, err = account.Internal.IdentityKeys() + if err != nil { + panic(err) + } } return account.signingKey } func (account *OlmAccount) IdentityKey() id.IdentityKey { if len(account.identityKey) == 0 { - account.signingKey, account.identityKey = account.Internal.IdentityKeys() + var err error + account.signingKey, account.identityKey, err = account.Internal.IdentityKeys() + if err != nil { + panic(err) + } } return account.identityKey } +// SignJSON signs the given JSON object following the Matrix specification: +// https://matrix.org/docs/spec/appendices#signing-json +func (account *OlmAccount) SignJSON(obj any) (string, error) { + objJSON, err := json.Marshal(obj) + if err != nil { + return "", err + } + objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") + objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") + signed, err := account.Internal.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON)) + return string(signed), err +} + func (account *OlmAccount) getInitialKeys(userID id.UserID, deviceID id.DeviceID) *mautrix.DeviceKeys { deviceKeys := &mautrix.DeviceKeys{ UserID: userID, @@ -59,7 +93,7 @@ func (account *OlmAccount) getInitialKeys(userID id.UserID, deviceID id.DeviceID }, } - signature, err := account.Internal.SignJSON(deviceKeys) + signature, err := account.SignJSON(deviceKeys) if err != nil { panic(err) } @@ -71,12 +105,16 @@ func (account *OlmAccount) getInitialKeys(userID id.UserID, deviceID id.DeviceID func (account *OlmAccount) getOneTimeKeys(userID id.UserID, deviceID id.DeviceID, currentOTKCount int) map[id.KeyID]mautrix.OneTimeKey { newCount := int(account.Internal.MaxNumberOfOneTimeKeys()/2) - currentOTKCount if newCount > 0 { - account.Internal.GenOneTimeKeys(uint(newCount)) + account.Internal.GenOneTimeKeys(nil, uint(newCount)) } oneTimeKeys := make(map[id.KeyID]mautrix.OneTimeKey) - for keyID, key := range account.Internal.OneTimeKeys() { + internalKeys, err := account.Internal.OneTimeKeys() + if err != nil { + panic(err) + } + for keyID, key := range internalKeys { key := mautrix.OneTimeKey{Key: key} - signature, _ := account.Internal.SignJSON(key) + signature, _ := account.SignJSON(key) key.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, deviceID.String(), signature) key.IsSigned = true oneTimeKeys[id.NewKeyID(id.KeyAlgorithmSignedCurve25519, keyID)] = key diff --git a/crypto/cross_sign_key.go b/crypto/cross_sign_key.go index 3d01fb99..97ecd865 100644 --- a/crypto/cross_sign_key.go +++ b/crypto/cross_sign_key.go @@ -101,7 +101,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross masterKeyID: keys.MasterKey.PublicKey(), }, } - masterSig, err := mach.account.Internal.SignJSON(masterKey) + masterSig, err := mach.account.SignJSON(masterKey) if err != nil { return fmt.Errorf("failed to sign master key: %w", err) } diff --git a/crypto/cross_sign_signing.go b/crypto/cross_sign_signing.go index 1d80cc91..ae3d1eb1 100644 --- a/crypto/cross_sign_signing.go +++ b/crypto/cross_sign_signing.go @@ -87,7 +87,7 @@ func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error { id.NewKeyID(id.KeyAlgorithmEd25519, masterKey.String()): masterKey.String(), }, } - signature, err := mach.account.Internal.SignJSON(masterKeyObj) + signature, err := mach.account.SignJSON(masterKeyObj) if err != nil { return fmt.Errorf("failed to sign JSON: %w", err) } diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go index 15e9df29..52e30166 100644 --- a/crypto/encryptolm.go +++ b/crypto/encryptolm.go @@ -37,7 +37,10 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession Str("olm_session_id", session.ID().String()). Str("olm_session_description", session.Describe()). Msg("Encrypting olm message") - msgType, ciphertext := session.Encrypt(plaintext) + msgType, ciphertext, err := session.Encrypt(plaintext) + if err != nil { + panic(err) + } err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session) if err != nil { log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting") diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index 4057543a..2b127ab5 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -10,12 +10,12 @@ import ( "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/session" "maunium.net/go/mautrix/crypto/goolm/utilities" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -39,10 +39,13 @@ type Account struct { NumFallbackKeys uint8 `json:"number_fallback_keys"` } +// Ensure that Account adheres to the olm.Account interface. +var _ olm.Account = (*Account)(nil) + // AccountFromJSONPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key. func AccountFromJSONPickled(pickled, key []byte) (*Account, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("accountFromPickled: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("accountFromPickled: %w", olm.ErrEmptyInput) } a := &Account{} err := a.UnpickleAsJSON(pickled, key) @@ -55,7 +58,7 @@ func AccountFromJSONPickled(pickled, key []byte) (*Account, error) { // AccountFromPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key. func AccountFromPickled(pickled, key []byte) (*Account, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("accountFromPickled: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("accountFromPickled: %w", olm.ErrEmptyInput) } a := &Account{} err := a.Unpickle(pickled, key) @@ -82,7 +85,7 @@ func NewAccount(reader io.Reader) (*Account, error) { } // PickleAsJSON returns an Account as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. -func (a Account) PickleAsJSON(key []byte) ([]byte, error) { +func (a *Account) PickleAsJSON(key []byte) ([]byte, error) { return utilities.PickleAsJSON(a, accountPickleVersionJSON, key) } @@ -92,29 +95,32 @@ func (a *Account) UnpickleAsJSON(pickled, key []byte) error { } // IdentityKeysJSON returns the public parts of the identity keys for the Account in a JSON string. -func (a Account) IdentityKeysJSON() ([]byte, error) { +func (a *Account) IdentityKeysJSON() ([]byte, error) { res := struct { Ed25519 string `json:"ed25519"` Curve25519 string `json:"curve25519"` }{} - ed25519, curve25519 := a.IdentityKeys() + ed25519, curve25519, err := a.IdentityKeys() + if err != nil { + return nil, err + } res.Ed25519 = string(ed25519) res.Curve25519 = string(curve25519) return json.Marshal(res) } // IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity keys for the Account. -func (a Account) IdentityKeys() (id.Ed25519, id.Curve25519) { +func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) { ed25519 := id.Ed25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.PublicKey)) curve25519 := id.Curve25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Curve25519.PublicKey)) - return ed25519, curve25519 + return ed25519, curve25519, nil } // Sign returns the base64-encoded signature of a message using the Ed25519 key // for this Account. -func (a Account) Sign(message []byte) ([]byte, error) { +func (a *Account) Sign(message []byte) ([]byte, error) { if len(message) == 0 { - return nil, fmt.Errorf("sign: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("sign: %w", olm.ErrEmptyInput) } return []byte(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.Sign(message))), nil } @@ -122,32 +128,14 @@ func (a Account) Sign(message []byte) ([]byte, error) { // OneTimeKeys returns the public parts of the unpublished one time keys of the Account. // // The returned data is a map with the mapping of key id to base64-encoded Curve25519 key. -func (a Account) OneTimeKeys() map[string]id.Curve25519 { +func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) { oneTimeKeys := make(map[string]id.Curve25519) for _, curKey := range a.OTKeys { if !curKey.Published { oneTimeKeys[curKey.KeyIDEncoded()] = id.Curve25519(curKey.PublicKeyEncoded()) } } - return oneTimeKeys -} - -//OneTimeKeysJSON returns the public parts of the unpublished one time keys of the Account as a JSON string. -// -//The returned JSON is of format: -/* - { - Curve25519: { - "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo", - "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU" - } - } -*/ -func (a Account) OneTimeKeysJSON() ([]byte, error) { - res := make(map[string]map[string]id.Curve25519) - otKeys := a.OneTimeKeys() - res["Curve25519"] = otKeys - return json.Marshal(res) + return oneTimeKeys, nil } // MarkKeysAsPublished marks the current set of one time keys and the fallback key as being @@ -186,9 +174,9 @@ func (a *Account) GenOneTimeKeys(reader io.Reader, num uint) error { // NewOutboundSession creates a new outbound session to a // given curve25519 identity Key and one time key. -func (a Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*session.OlmSession, error) { +func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) { if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { - return nil, fmt.Errorf("outbound session: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("outbound session: %w", olm.ErrEmptyInput) } theirIdentityKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(theirIdentityKey)) if err != nil { @@ -205,13 +193,18 @@ func (a Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25 return s, nil } -// NewInboundSession creates a new inbound session from an incoming PRE_KEY message. -func (a Account) NewInboundSession(theirIdentityKey *id.Curve25519, oneTimeKeyMsg []byte) (*session.OlmSession, error) { +// NewInboundSession creates a new in-bound session for sending/receiving +// messages from an incoming PRE_KEY message. Returns error on failure. +func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) { + return a.NewInboundSessionFrom(nil, oneTimeKeyMsg) +} + +// NewInboundSessionFrom creates a new inbound session from an incoming PRE_KEY message. +func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) { if len(oneTimeKeyMsg) == 0 { - return nil, fmt.Errorf("inbound session: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("inbound session: %w", olm.ErrEmptyInput) } var theirIdentityKeyDecoded *crypto.Curve25519PublicKey - var err error if theirIdentityKey != nil { theirIdentityKeyDecodedByte, err := base64.RawStdEncoding.DecodeString(string(*theirIdentityKey)) if err != nil { @@ -221,14 +214,10 @@ func (a Account) NewInboundSession(theirIdentityKey *id.Curve25519, oneTimeKeyMs theirIdentityKeyDecoded = &theirIdentityKeyCurve } - s, err := session.NewInboundOlmSession(theirIdentityKeyDecoded, oneTimeKeyMsg, a.searchOTKForOur, a.IdKeys.Curve25519) - if err != nil { - return nil, err - } - return s, nil + return session.NewInboundOlmSession(theirIdentityKeyDecoded, []byte(oneTimeKeyMsg), a.searchOTKForOur, a.IdKeys.Curve25519) } -func (a Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneTimeKey { +func (a *Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneTimeKey { for curIndex := range a.OTKeys { if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) { return &a.OTKeys[curIndex] @@ -244,16 +233,17 @@ func (a Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneT } // RemoveOneTimeKeys removes the one time key in this Account which matches the one time key in the session s. -func (a *Account) RemoveOneTimeKeys(s *session.OlmSession) { - toFind := s.BobOneTimeKey +func (a *Account) RemoveOneTimeKeys(s olm.Session) error { + toFind := s.(*session.OlmSession).BobOneTimeKey for curIndex := range a.OTKeys { if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) { //Remove and return a.OTKeys[curIndex] = a.OTKeys[len(a.OTKeys)-1] a.OTKeys = a.OTKeys[:len(a.OTKeys)-1] - return + return nil } } + return nil //if the key is a fallback or prevFallback, don't remove it } @@ -279,7 +269,7 @@ func (a *Account) GenFallbackKey(reader io.Reader) error { // FallbackKey returns the public part of the current fallback key of the Account. // The returned data is a map with the mapping of key id to base64-encoded Curve25519 key. -func (a Account) FallbackKey() map[string]id.Curve25519 { +func (a *Account) FallbackKey() map[string]id.Curve25519 { keys := make(map[string]id.Curve25519) if a.NumFallbackKeys >= 1 { keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded()) @@ -297,7 +287,7 @@ func (a Account) FallbackKey() map[string]id.Curve25519 { } } */ -func (a Account) FallbackKeyJSON() ([]byte, error) { +func (a *Account) FallbackKeyJSON() ([]byte, error) { res := make(map[string]map[string]id.Curve25519) fbk := a.FallbackKey() res["curve25519"] = fbk @@ -306,7 +296,7 @@ func (a Account) FallbackKeyJSON() ([]byte, error) { // FallbackKeyUnpublished returns the public part of the current fallback key of the Account only if it is unpublished. // The returned data is a map with the mapping of key id to base64-encoded Curve25519 key. -func (a Account) FallbackKeyUnpublished() map[string]id.Curve25519 { +func (a *Account) FallbackKeyUnpublished() map[string]id.Curve25519 { keys := make(map[string]id.Curve25519) if a.NumFallbackKeys >= 1 && !a.CurrentFallbackKey.Published { keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded()) @@ -324,7 +314,7 @@ func (a Account) FallbackKeyUnpublished() map[string]id.Curve25519 { } } */ -func (a Account) FallbackKeyUnpublishedJSON() ([]byte, error) { +func (a *Account) FallbackKeyUnpublishedJSON() ([]byte, error) { res := make(map[string]map[string]id.Curve25519) fbk := a.FallbackKeyUnpublished() res["curve25519"] = fbk @@ -360,7 +350,7 @@ func (a *Account) UnpickleLibOlm(value []byte) (int, error) { switch pickledVersion { case accountPickleVersionLibOLM, 3, 2: default: - return 0, fmt.Errorf("unpickle account: %w", goolm.ErrBadVersion) + return 0, fmt.Errorf("unpickle account: %w", olm.ErrBadVersion) } //read ed25519 key pair readBytes, err := a.IdKeys.Ed25519.UnpickleLibOlm(value[curPos:]) @@ -448,7 +438,10 @@ func (a *Account) UnpickleLibOlm(value []byte) (int, error) { } // Pickle returns a base64 encoded and with key encrypted pickled account using PickleLibOlm(). -func (a Account) Pickle(key []byte) ([]byte, error) { +func (a *Account) Pickle(key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, olm.ErrNoKeyProvided + } pickeledBytes := make([]byte, a.PickleLen()) written, err := a.PickleLibOlm(pickeledBytes) if err != nil { @@ -466,9 +459,9 @@ func (a Account) Pickle(key []byte) ([]byte, error) { // PickleLibOlm encodes the Account into target. target has to have a size of at least PickleLen() and is written to from index 0. // It returns the number of bytes written. -func (a Account) PickleLibOlm(target []byte) (int, error) { +func (a *Account) PickleLibOlm(target []byte) (int, error) { if len(target) < a.PickleLen() { - return 0, fmt.Errorf("pickle account: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle account: %w", olm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(accountPickleVersionLibOLM, target) writtenEdKey, err := a.IdKeys.Ed25519.PickleLibOlm(target[written:]) @@ -510,7 +503,7 @@ func (a Account) PickleLibOlm(target []byte) (int, error) { } // PickleLen returns the number of bytes the pickled Account will have. -func (a Account) PickleLen() int { +func (a *Account) PickleLen() int { length := libolmpickle.PickleUInt32Len(accountPickleVersionLibOLM) length += a.IdKeys.Ed25519.PickleLen() length += a.IdKeys.Curve25519.PickleLen() @@ -521,3 +514,9 @@ func (a Account) PickleLen() int { length += libolmpickle.PickleUInt32Len(a.NextOneTimeKeyID) return length } + +// MaxNumberOfOneTimeKeys returns the largest number of one time keys this +// Account can store. +func (a *Account) MaxNumberOfOneTimeKeys() uint { + return uint(MaxOneTimeKeys) +} diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go index 943d8570..2482d087 100644 --- a/crypto/goolm/account/account_test.go +++ b/crypto/goolm/account/account_test.go @@ -11,8 +11,8 @@ import ( "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/account" + "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures" ) @@ -71,7 +71,7 @@ func TestAccount(t *testing.T) { t.Fatal("IdentityKeys Ed25519 public unequal") } - if len(firstAccount.OneTimeKeys()) != 2 { + if otks, err := firstAccount.OneTimeKeys(); err != nil || len(otks) != 2 { t.Fatal("should get 2 unpublished oneTimeKeys") } if len(firstAccount.FallbackKeyUnpublished()) == 0 { @@ -84,7 +84,7 @@ func TestAccount(t *testing.T) { if len(firstAccount.FallbackKeyUnpublished()) != 0 { t.Fatal("should get no fallbackKey") } - if len(firstAccount.OneTimeKeys()) != 0 { + if otks, err := firstAccount.OneTimeKeys(); err != nil || len(otks) != 0 { t.Fatal("should get no oneTimeKeys") } } @@ -139,7 +139,7 @@ func TestSessions(t *testing.T) { t.Fatal(err) } plaintext := []byte("test message") - msgType, crypttext, err := aliceSession.Encrypt(plaintext, nil) + msgType, crypttext, err := aliceSession.Encrypt(plaintext) if err != nil { t.Fatal(err) } @@ -147,11 +147,11 @@ func TestSessions(t *testing.T) { t.Fatal("wrong message type") } - bobSession, err := bobAccount.NewInboundSession(nil, crypttext) + bobSession, err := bobAccount.NewInboundSession(string(crypttext)) if err != nil { t.Fatal(err) } - decodedText, err := bobSession.Decrypt(crypttext, msgType) + decodedText, err := bobSession.Decrypt(string(crypttext), msgType) if err != nil { t.Fatal(err) } @@ -225,7 +225,7 @@ func TestOldAccountPickle(t *testing.T) { if err == nil { t.Fatal("expected error") } else { - if !errors.Is(err, goolm.ErrBadVersion) { + if !errors.Is(err, olm.ErrBadVersion) { t.Fatal(err) } } @@ -252,7 +252,7 @@ func TestLoopback(t *testing.T) { } plainText := []byte("Hello, World") - msgType, message1, err := aliceSession.Encrypt(plainText, nil) + msgType, message1, err := aliceSession.Encrypt(plainText) if err != nil { t.Fatal(err) } @@ -260,12 +260,12 @@ func TestLoopback(t *testing.T) { t.Fatal("wrong message type") } - bobSession, err := accountB.NewInboundSession(nil, message1) + bobSession, err := accountB.NewInboundSession(string(message1)) if err != nil { t.Fatal(err) } // Check that the inbound session matches the message it was created from. - sessionIsOK, err := bobSession.MatchesInboundSessionFrom(nil, message1) + sessionIsOK, err := bobSession.MatchesInboundSessionFrom("", string(message1)) if err != nil { t.Fatal(err) } @@ -274,7 +274,7 @@ func TestLoopback(t *testing.T) { } // Check that the inbound session matches the key this message is supposed to be from. aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded() - sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&aIDKey, message1) + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(aIDKey), string(message1)) if err != nil { t.Fatal(err) } @@ -283,7 +283,7 @@ func TestLoopback(t *testing.T) { } // Check that the inbound session isn't from a different user. bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded() - sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&bIDKey, message1) + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(bIDKey), string(message1)) if err != nil { t.Fatal(err) } @@ -291,7 +291,7 @@ func TestLoopback(t *testing.T) { t.Fatal("session is sad to be from b but is from a") } // Check that we can decrypt the message. - decryptedMessage, err := bobSession.Decrypt(message1, msgType) + decryptedMessage, err := bobSession.Decrypt(string(message1), msgType) if err != nil { t.Fatal(err) } @@ -299,7 +299,7 @@ func TestLoopback(t *testing.T) { t.Fatal("messages are not the same") } - msgTyp2, message2, err := bobSession.Encrypt(plainText, nil) + msgTyp2, message2, err := bobSession.Encrypt(plainText) if err != nil { t.Fatal(err) } @@ -307,7 +307,7 @@ func TestLoopback(t *testing.T) { t.Fatal("wrong message type") } - decryptedMessage2, err := aliceSession.Decrypt(message2, msgTyp2) + decryptedMessage2, err := aliceSession.Decrypt(string(message2), msgTyp2) if err != nil { t.Fatal(err) } @@ -316,7 +316,7 @@ func TestLoopback(t *testing.T) { } //decrypting again should fail, as the chain moved on - _, err = aliceSession.Decrypt(message2, msgTyp2) + _, err = aliceSession.Decrypt(string(message2), msgTyp2) if err == nil { t.Fatal("expected error") } @@ -348,7 +348,7 @@ func TestMoreMessages(t *testing.T) { } plainText := []byte("Hello, World") - msgType, message1, err := aliceSession.Encrypt(plainText, nil) + msgType, message1, err := aliceSession.Encrypt(plainText) if err != nil { t.Fatal(err) } @@ -356,11 +356,11 @@ func TestMoreMessages(t *testing.T) { t.Fatal("wrong message type") } - bobSession, err := accountB.NewInboundSession(nil, message1) + bobSession, err := accountB.NewInboundSession(string(message1)) if err != nil { t.Fatal(err) } - decryptedMessage, err := bobSession.Decrypt(message1, msgType) + decryptedMessage, err := bobSession.Decrypt(string(message1), msgType) if err != nil { t.Fatal(err) } @@ -370,7 +370,7 @@ func TestMoreMessages(t *testing.T) { for i := 0; i < 8; i++ { //alice sends, bob reveices - msgType, message, err := aliceSession.Encrypt(plainText, nil) + msgType, message, err := aliceSession.Encrypt(plainText) if err != nil { t.Fatal(err) } @@ -384,7 +384,7 @@ func TestMoreMessages(t *testing.T) { t.Fatal("wrong message type") } } - decryptedMessage, err := bobSession.Decrypt(message, msgType) + decryptedMessage, err := bobSession.Decrypt(string(message), msgType) if err != nil { t.Fatal(err) } @@ -393,14 +393,14 @@ func TestMoreMessages(t *testing.T) { } //now bob sends, alice receives - msgType, message, err = bobSession.Encrypt(plainText, nil) + msgType, message, err = bobSession.Encrypt(plainText) if err != nil { t.Fatal(err) } if msgType == id.OlmMsgTypePreKey { t.Fatal("wrong message type") } - decryptedMessage, err = aliceSession.Decrypt(message, msgType) + decryptedMessage, err = aliceSession.Decrypt(string(message), msgType) if err != nil { t.Fatal(err) } @@ -435,7 +435,7 @@ func TestFallbackKey(t *testing.T) { } plainText := []byte("Hello, World") - msgType, message1, err := aliceSession.Encrypt(plainText, nil) + msgType, message1, err := aliceSession.Encrypt(plainText) if err != nil { t.Fatal(err) } @@ -443,12 +443,12 @@ func TestFallbackKey(t *testing.T) { t.Fatal("wrong message type") } - bobSession, err := accountB.NewInboundSession(nil, message1) + bobSession, err := accountB.NewInboundSession(string(message1)) if err != nil { t.Fatal(err) } // Check that the inbound session matches the message it was created from. - sessionIsOK, err := bobSession.MatchesInboundSessionFrom(nil, message1) + sessionIsOK, err := bobSession.MatchesInboundSessionFrom("", string(message1)) if err != nil { t.Fatal(err) } @@ -457,7 +457,7 @@ func TestFallbackKey(t *testing.T) { } // Check that the inbound session matches the key this message is supposed to be from. aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded() - sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&aIDKey, message1) + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(aIDKey), string(message1)) if err != nil { t.Fatal(err) } @@ -466,7 +466,7 @@ func TestFallbackKey(t *testing.T) { } // Check that the inbound session isn't from a different user. bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded() - sessionIsOK, err = bobSession.MatchesInboundSessionFrom(&bIDKey, message1) + sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(bIDKey), string(message1)) if err != nil { t.Fatal(err) } @@ -474,7 +474,7 @@ func TestFallbackKey(t *testing.T) { t.Fatal("session is sad to be from b but is from a") } // Check that we can decrypt the message. - decryptedMessage, err := bobSession.Decrypt(message1, msgType) + decryptedMessage, err := bobSession.Decrypt(string(message1), msgType) if err != nil { t.Fatal(err) } @@ -493,7 +493,7 @@ func TestFallbackKey(t *testing.T) { t.Fatal(err) } - msgType2, message2, err := aliceSession2.Encrypt(plainText, nil) + msgType2, message2, err := aliceSession2.Encrypt(plainText) if err != nil { t.Fatal(err) } @@ -502,19 +502,19 @@ func TestFallbackKey(t *testing.T) { } // bobSession should not be valid for the message2 // Check that the inbound session matches the message it was created from. - sessionIsOK, err = bobSession.MatchesInboundSessionFrom(nil, message2) + sessionIsOK, err = bobSession.MatchesInboundSessionFrom("", string(message2)) if err != nil { t.Fatal(err) } if sessionIsOK { t.Fatal("session was detected to be valid but should not") } - bobSession2, err := accountB.NewInboundSession(nil, message2) + bobSession2, err := accountB.NewInboundSession(string(message2)) if err != nil { t.Fatal(err) } // Check that the inbound session matches the message it was created from. - sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(nil, message2) + sessionIsOK, err = bobSession2.MatchesInboundSessionFrom("", string(message2)) if err != nil { t.Fatal(err) } @@ -522,7 +522,7 @@ func TestFallbackKey(t *testing.T) { t.Fatal("session was not detected to be valid") } // Check that the inbound session matches the key this message is supposed to be from. - sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(&aIDKey, message2) + sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(string(aIDKey), string(message2)) if err != nil { t.Fatal(err) } @@ -530,7 +530,7 @@ func TestFallbackKey(t *testing.T) { t.Fatal("session is sad to be not from a but it should") } // Check that the inbound session isn't from a different user. - sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(&bIDKey, message2) + sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(string(bIDKey), string(message2)) if err != nil { t.Fatal(err) } @@ -538,7 +538,7 @@ func TestFallbackKey(t *testing.T) { t.Fatal("session is sad to be from b but is from a") } // Check that we can decrypt the message. - decryptedMessage2, err := bobSession2.Decrypt(message2, msgType2) + decryptedMessage2, err := bobSession2.Decrypt(string(message2), msgType2) if err != nil { t.Fatal(err) } @@ -553,18 +553,18 @@ func TestFallbackKey(t *testing.T) { if err != nil { t.Fatal(err) } - msgType3, message3, err := aliceSession3.Encrypt(plainText, nil) + msgType3, message3, err := aliceSession3.Encrypt(plainText) if err != nil { t.Fatal(err) } if msgType3 != id.OlmMsgTypePreKey { t.Fatal("wrong message type") } - _, err = accountB.NewInboundSession(nil, message3) + _, err = accountB.NewInboundSession(string(message3)) if err == nil { t.Fatal("expected error") } - if !errors.Is(err, goolm.ErrBadMessageKeyID) { + if !errors.Is(err, olm.ErrBadMessageKeyID) { t.Fatal(err) } } diff --git a/crypto/goolm/account/register.go b/crypto/goolm/account/register.go new file mode 100644 index 00000000..ab0c598a --- /dev/null +++ b/crypto/goolm/account/register.go @@ -0,0 +1,25 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package account + +import ( + "io" + + "maunium.net/go/mautrix/crypto/olm" +) + +func init() { + olm.InitNewAccount = func(r io.Reader) (olm.Account, error) { + return NewAccount(r) + } + olm.InitBlankAccount = func() olm.Account { + return &Account{} + } + olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) { + return AccountFromPickled(pickled, key) + } +} diff --git a/crypto/goolm/cipher/pickle.go b/crypto/goolm/cipher/pickle.go index 670ff6ff..551f4356 100644 --- a/crypto/goolm/cipher/pickle.go +++ b/crypto/goolm/cipher/pickle.go @@ -3,7 +3,8 @@ package cipher import ( "fmt" - "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -28,14 +29,14 @@ func Pickle(key, input []byte) ([]byte, error) { return nil, err } ciphertext = append(ciphertext, mac[:pickleMACLength]...) - encoded := goolm.Base64Encode(ciphertext) + encoded := goolmbase64.Encode(ciphertext) return encoded, nil } // Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256. func Unpickle(key, input []byte) ([]byte, error) { pickleCipher := NewAESSHA256([]byte(kdfPickle)) - ciphertext, err := goolm.Base64Decode(input) + ciphertext, err := goolmbase64.Decode(input) if err != nil { return nil, err } @@ -45,7 +46,7 @@ func Unpickle(key, input []byte) ([]byte, error) { return nil, err } if !verified { - return nil, fmt.Errorf("decrypt pickle: %w", goolm.ErrBadMAC) + return nil, fmt.Errorf("decrypt pickle: %w", olm.ErrBadMAC) } //Set to next block size targetCipherText := make([]byte, int(len(ciphertext)/PickleBlockSize())*PickleBlockSize()) diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go index 125e1bfd..1c182caa 100644 --- a/crypto/goolm/crypto/curve25519.go +++ b/crypto/goolm/crypto/curve25519.go @@ -9,8 +9,8 @@ import ( "golang.org/x/crypto/curve25519" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" + "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -78,7 +78,7 @@ func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, err // It returns the number of bytes written. func (c Curve25519KeyPair) PickleLibOlm(target []byte) (int, error) { if len(target) < c.PickleLen() { - return 0, fmt.Errorf("pickle curve25519 key pair: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle curve25519 key pair: %w", olm.ErrValueTooShort) } written, err := c.PublicKey.PickleLibOlm(target) if err != nil { @@ -159,7 +159,7 @@ func (c Curve25519PublicKey) B64Encoded() id.Curve25519 { // It returns the number of bytes written. func (c Curve25519PublicKey) PickleLibOlm(target []byte) (int, error) { if len(target) < c.PickleLen() { - return 0, fmt.Errorf("pickle curve25519 public key: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle curve25519 public key: %w", olm.ErrValueTooShort) } if len(c) != curve25519PubKeyLength { return libolmpickle.PickleBytes(make([]byte, curve25519PubKeyLength), target), nil diff --git a/crypto/goolm/crypto/ed25519.go b/crypto/goolm/crypto/ed25519.go index f0c56297..0756d778 100644 --- a/crypto/goolm/crypto/ed25519.go +++ b/crypto/goolm/crypto/ed25519.go @@ -6,8 +6,8 @@ import ( "fmt" "io" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" + "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -69,7 +69,7 @@ func (c Ed25519KeyPair) Verify(message, givenSignature []byte) bool { // It returns the number of bytes written. func (c Ed25519KeyPair) PickleLibOlm(target []byte) (int, error) { if len(target) < c.PickleLen() { - return 0, fmt.Errorf("pickle ed25519 key pair: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle ed25519 key pair: %w", olm.ErrValueTooShort) } written, err := c.PublicKey.PickleLibOlm(target) if err != nil { @@ -153,7 +153,7 @@ func (c Ed25519PublicKey) Verify(message, givenSignature []byte) bool { // It returns the number of bytes written. func (c Ed25519PublicKey) PickleLibOlm(target []byte) (int, error) { if len(target) < c.PickleLen() { - return 0, fmt.Errorf("pickle ed25519 public key: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle ed25519 public key: %w", olm.ErrValueTooShort) } if len(c) != ed25519.PublicKeySize { return libolmpickle.PickleBytes(make([]byte, ed25519.PublicKeySize), target), nil diff --git a/crypto/goolm/crypto/one_time_key.go b/crypto/goolm/crypto/one_time_key.go index 67465563..aaa253d2 100644 --- a/crypto/goolm/crypto/one_time_key.go +++ b/crypto/goolm/crypto/one_time_key.go @@ -5,8 +5,8 @@ import ( "encoding/binary" "fmt" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" + "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -38,7 +38,7 @@ func (otk OneTimeKey) Equal(s OneTimeKey) bool { // It returns the number of bytes written. func (c OneTimeKey) PickleLibOlm(target []byte) (int, error) { if len(target) < c.PickleLen() { - return 0, fmt.Errorf("pickle one time key: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle one time key: %w", olm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(uint32(c.ID), target) written += libolmpickle.PickleBool(c.Published, target[written:]) diff --git a/crypto/goolm/errors.go b/crypto/goolm/errors.go deleted file mode 100644 index 6539b0f1..00000000 --- a/crypto/goolm/errors.go +++ /dev/null @@ -1,28 +0,0 @@ -package goolm - -import ( - "errors" -) - -// Those are the most common used errors -var ( - ErrBadSignature = errors.New("bad signature") - ErrBadMAC = errors.New("bad mac") - ErrBadMessageFormat = errors.New("bad message format") - ErrBadVerification = errors.New("bad verification") - ErrWrongProtocolVersion = errors.New("wrong protocol version") - ErrEmptyInput = errors.New("empty input") - ErrNoKeyProvided = errors.New("no key") - ErrBadMessageKeyID = errors.New("bad message key id") - ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key") - ErrMsgIndexTooHigh = errors.New("message index too high") - ErrProtocolViolation = errors.New("not protocol message order") - ErrMessageKeyNotFound = errors.New("message key not found") - ErrChainTooHigh = errors.New("chain index too high") - ErrBadInput = errors.New("bad input") - ErrBadVersion = errors.New("wrong version") - ErrWrongPickleVersion = errors.New("wrong pickle version") - ErrValueTooShort = errors.New("value too short") - ErrInputToSmall = errors.New("input too small (truncated?)") - ErrOverflow = errors.New("overflow") -) diff --git a/crypto/goolm/base64.go b/crypto/goolm/goolmbase64/base64.go similarity index 82% rename from crypto/goolm/base64.go rename to crypto/goolm/goolmbase64/base64.go index 229008cf..061a052a 100644 --- a/crypto/goolm/base64.go +++ b/crypto/goolm/goolmbase64/base64.go @@ -1,11 +1,11 @@ -package goolm +package goolmbase64 import ( "encoding/base64" ) // Deprecated: base64.RawStdEncoding should be used directly -func Base64Decode(input []byte) ([]byte, error) { +func Decode(input []byte) ([]byte, error) { decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input))) writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input) if err != nil { @@ -15,7 +15,7 @@ func Base64Decode(input []byte) ([]byte, error) { } // Deprecated: base64.RawStdEncoding should be used directly -func Base64Encode(input []byte) []byte { +func Encode(input []byte) []byte { encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input))) base64.RawStdEncoding.Encode(encoded, input) return encoded diff --git a/crypto/goolm/libolmpickle/unpickle.go b/crypto/goolm/libolmpickle/unpickle.go index 9a6a4b62..dbd275aa 100644 --- a/crypto/goolm/libolmpickle/unpickle.go +++ b/crypto/goolm/libolmpickle/unpickle.go @@ -3,7 +3,7 @@ package libolmpickle import ( "fmt" - "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/olm" ) func isZeroByteSlice(bytes []byte) bool { @@ -16,21 +16,21 @@ func isZeroByteSlice(bytes []byte) bool { func UnpickleUInt8(value []byte) (uint8, int, error) { if len(value) < 1 { - return 0, 0, fmt.Errorf("unpickle uint8: %w", goolm.ErrValueTooShort) + return 0, 0, fmt.Errorf("unpickle uint8: %w", olm.ErrValueTooShort) } return value[0], 1, nil } func UnpickleBool(value []byte) (bool, int, error) { if len(value) < 1 { - return false, 0, fmt.Errorf("unpickle bool: %w", goolm.ErrValueTooShort) + return false, 0, fmt.Errorf("unpickle bool: %w", olm.ErrValueTooShort) } return value[0] != uint8(0x00), 1, nil } func UnpickleBytes(value []byte, length int) ([]byte, int, error) { if len(value) < length { - return nil, 0, fmt.Errorf("unpickle bytes: %w", goolm.ErrValueTooShort) + return nil, 0, fmt.Errorf("unpickle bytes: %w", olm.ErrValueTooShort) } resp := value[:length] if isZeroByteSlice(resp) { @@ -41,7 +41,7 @@ func UnpickleBytes(value []byte, length int) ([]byte, int, error) { func UnpickleUInt32(value []byte) (uint32, int, error) { if len(value) < 4 { - return 0, 0, fmt.Errorf("unpickle uint32: %w", goolm.ErrValueTooShort) + return 0, 0, fmt.Errorf("unpickle uint32: %w", olm.ErrValueTooShort) } var res uint32 count := 0 diff --git a/crypto/goolm/megolm/megolm.go b/crypto/goolm/megolm/megolm.go index c3493f7b..c88583ee 100644 --- a/crypto/goolm/megolm/megolm.go +++ b/crypto/goolm/megolm/megolm.go @@ -5,12 +5,13 @@ import ( "crypto/rand" "fmt" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/message" "maunium.net/go/mautrix/crypto/goolm/utilities" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -158,7 +159,7 @@ func (r Ratchet) SessionSharingMessage(key crypto.Ed25519KeyPair) ([]byte, error m.Counter = r.Counter m.RatchetData = r.Data encoded := m.EncodeAndSign(key) - return goolm.Base64Encode(encoded), nil + return goolmbase64.Encode(encoded), nil } // SessionExportMessage creates a message in the session export format. @@ -168,7 +169,7 @@ func (r Ratchet) SessionExportMessage(key crypto.Ed25519PublicKey) ([]byte, erro m.RatchetData = r.Data m.PublicKey = key encoded := m.Encode() - return goolm.Base64Encode(encoded), nil + return goolmbase64.Encode(encoded), nil } // Decrypt decrypts the ciphertext and verifies the MAC but not the signature. @@ -179,7 +180,7 @@ func (r Ratchet) Decrypt(ciphertext []byte, signingkey *crypto.Ed25519PublicKey, return nil, err } if !verifiedMAC { - return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMAC) + return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMAC) } return RatchetCipher.Decrypt(r.Data[:], msg.Ciphertext) @@ -219,7 +220,7 @@ func (r *Ratchet) UnpickleLibOlm(unpickled []byte) (int, error) { // It returns the number of bytes written. func (r Ratchet) PickleLibOlm(target []byte) (int, error) { if len(target) < r.PickleLen() { - return 0, fmt.Errorf("pickle account: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle account: %w", olm.ErrValueTooShort) } written := libolmpickle.PickleBytes(r.Data[:], target) written += libolmpickle.PickleUInt32(r.Counter, target[written:]) diff --git a/crypto/goolm/message/decoder.go b/crypto/goolm/message/decoder.go index ba49f011..9ce426b5 100644 --- a/crypto/goolm/message/decoder.go +++ b/crypto/goolm/message/decoder.go @@ -3,17 +3,17 @@ package message import ( "encoding/binary" - "maunium.net/go/mautrix/crypto/goolm" + "maunium.net/go/mautrix/crypto/olm" ) // checkDecodeErr checks if there was an error during decode. func checkDecodeErr(readBytes int) error { if readBytes == 0 { //end reached - return goolm.ErrInputToSmall + return olm.ErrInputToSmall } if readBytes < 0 { - return goolm.ErrOverflow + return olm.ErrOverflow } return nil } diff --git a/crypto/goolm/message/session_export.go b/crypto/goolm/message/session_export.go index f539cce5..956868b2 100644 --- a/crypto/goolm/message/session_export.go +++ b/crypto/goolm/message/session_export.go @@ -4,8 +4,8 @@ import ( "encoding/binary" "fmt" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -32,10 +32,10 @@ func (s MegolmSessionExport) Encode() []byte { // Decode populates the struct with the data encoded in input. func (s *MegolmSessionExport) Decode(input []byte) error { if len(input) != 165 { - return fmt.Errorf("decrypt: %w", goolm.ErrBadInput) + return fmt.Errorf("decrypt: %w", olm.ErrBadInput) } if input[0] != sessionExportVersion { - return fmt.Errorf("decrypt: %w", goolm.ErrBadVersion) + return fmt.Errorf("decrypt: %w", olm.ErrBadVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/message/session_sharing.go b/crypto/goolm/message/session_sharing.go index c5393f50..85d5d20b 100644 --- a/crypto/goolm/message/session_sharing.go +++ b/crypto/goolm/message/session_sharing.go @@ -4,8 +4,8 @@ import ( "encoding/binary" "fmt" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -34,15 +34,15 @@ func (s MegolmSessionSharing) EncodeAndSign(key crypto.Ed25519KeyPair) []byte { // VerifyAndDecode verifies the input and populates the struct with the data encoded in input. func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error { if len(input) != 229 { - return fmt.Errorf("verify: %w", goolm.ErrBadInput) + return fmt.Errorf("verify: %w", olm.ErrBadInput) } publicKey := crypto.Ed25519PublicKey(input[133:165]) if !publicKey.Verify(input[:165], input[165:]) { - return fmt.Errorf("verify: %w", goolm.ErrBadVerification) + return fmt.Errorf("verify: %w", olm.ErrBadVerification) } s.PublicKey = publicKey if input[0] != sessionSharingVersion { - return fmt.Errorf("verify: %w", goolm.ErrBadVersion) + return fmt.Errorf("verify: %w", olm.ErrBadVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index d08e09f4..b24716e8 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -5,11 +5,12 @@ import ( "errors" "fmt" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/utilities" + "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -56,8 +57,8 @@ func (s Decryption) PrivateKey() crypto.Curve25519PrivateKey { } // Decrypt decrypts the ciphertext and verifies the MAC. The base64 encoded key is used to construct the shared secret. -func (s Decryption) Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error) { - keyDecoded, err := base64.RawStdEncoding.DecodeString(string(key)) +func (s Decryption) Decrypt(ephemeralKey, mac, ciphertext []byte) ([]byte, error) { + keyDecoded, err := base64.RawStdEncoding.DecodeString(string(ephemeralKey)) if err != nil { return nil, err } @@ -65,7 +66,7 @@ func (s Decryption) Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, if err != nil { return nil, err } - decodedMAC, err := goolm.Base64Decode(mac) + decodedMAC, err := goolmbase64.Decode(mac) if err != nil { return nil, err } @@ -75,7 +76,7 @@ func (s Decryption) Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, return nil, err } if !verified { - return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMAC) + return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMAC) } plaintext, err := cipher.Decrypt(sharedSecret, ciphertext) if err != nil { @@ -115,7 +116,7 @@ func (a *Decryption) UnpickleLibOlm(value []byte) (int, error) { switch pickledVersion { case decryptionPickleVersionLibOlm: default: - return 0, fmt.Errorf("unpickle olmSession: %w", goolm.ErrBadVersion) + return 0, fmt.Errorf("unpickle olmSession: %w", olm.ErrBadVersion) } readBytes, err := a.KeyPair.UnpickleLibOlm(value[curPos:]) if err != nil { @@ -146,7 +147,7 @@ func (a Decryption) Pickle(key []byte) ([]byte, error) { // It returns the number of bytes written. func (a Decryption) PickleLibOlm(target []byte) (int, error) { if len(target) < a.PickleLen() { - return 0, fmt.Errorf("pickle Decryption: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle Decryption: %w", olm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(decryptionPickleVersionLibOlm, target) writtenKey, err := a.KeyPair.PickleLibOlm(target[written:]) diff --git a/crypto/goolm/pk/encryption.go b/crypto/goolm/pk/encryption.go index dc50a6bb..54f15830 100644 --- a/crypto/goolm/pk/encryption.go +++ b/crypto/goolm/pk/encryption.go @@ -5,9 +5,9 @@ import ( "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" ) // Encryption is used to encrypt pk messages @@ -45,5 +45,5 @@ func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519Privat if err != nil { return nil, nil, err } - return ciphertext, goolm.Base64Encode(mac), nil + return ciphertext, goolmbase64.Encode(mac), nil } diff --git a/crypto/goolm/pk/pk_test.go b/crypto/goolm/pk/pk_test.go index 7ac524be..f2d9b108 100644 --- a/crypto/goolm/pk/pk_test.go +++ b/crypto/goolm/pk/pk_test.go @@ -5,10 +5,9 @@ import ( "encoding/base64" "testing" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/pk" - "maunium.net/go/mautrix/id" ) func TestEncryptionDecryption(t *testing.T) { @@ -48,7 +47,7 @@ func TestEncryptionDecryption(t *testing.T) { t.Fatal(err) } - decrypted, err := decryption.Decrypt(ciphertext, mac, id.Curve25519(bobPublic)) + decrypted, err := decryption.Decrypt(bobPublic, mac, ciphertext) if err != nil { t.Fatal(err) } @@ -70,7 +69,7 @@ func TestSigning(t *testing.T) { if err != nil { t.Fatal(err) } - signatureDecoded, err := goolm.Base64Decode(signature) + signatureDecoded, err := goolmbase64.Decode(signature) if err != nil { t.Fatal(err) } diff --git a/crypto/goolm/pk/register.go b/crypto/goolm/pk/register.go new file mode 100644 index 00000000..b7af6a5b --- /dev/null +++ b/crypto/goolm/pk/register.go @@ -0,0 +1,21 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package pk + +import "maunium.net/go/mautrix/crypto/olm" + +func init() { + olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) { + return NewSigningFromSeed(seed) + } + olm.InitNewPKSigning = func() (olm.PKSigning, error) { + return NewSigning() + } + olm.InitNewPKDecryptionFromPrivateKey = func(privateKey []byte) (olm.PKDecryption, error) { + return NewDecryptionFromPrivate(privateKey) + } +} diff --git a/crypto/goolm/pk/signing.go b/crypto/goolm/pk/signing.go index a98330d5..b22c76dc 100644 --- a/crypto/goolm/pk/signing.go +++ b/crypto/goolm/pk/signing.go @@ -7,8 +7,8 @@ import ( "github.com/tidwall/sjson" "maunium.net/go/mautrix/crypto/canonicaljson" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/id" ) @@ -49,7 +49,7 @@ func (s Signing) PublicKey() id.Ed25519 { // Sign returns the signature of the message base64 encoded. func (s Signing) Sign(message []byte) ([]byte, error) { signature := s.keyPair.Sign(message) - return goolm.Base64Encode(signature), nil + return goolmbase64.Encode(signature), nil } // SignJSON creates a signature for the given object after encoding it to diff --git a/crypto/goolm/olm/chain.go b/crypto/goolm/ratchet/chain.go similarity index 95% rename from crypto/goolm/olm/chain.go rename to crypto/goolm/ratchet/chain.go index 403637a4..2c2789b7 100644 --- a/crypto/goolm/olm/chain.go +++ b/crypto/goolm/ratchet/chain.go @@ -1,11 +1,11 @@ -package olm +package ratchet import ( "fmt" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -45,7 +45,7 @@ func (r *chainKey) UnpickleLibOlm(value []byte) (int, error) { // It returns the number of bytes written. func (r chainKey) PickleLibOlm(target []byte) (int, error) { if len(target) < r.PickleLen() { - return 0, fmt.Errorf("pickle chain key: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle chain key: %w", olm.ErrValueTooShort) } written, err := r.Key.PickleLibOlm(target) if err != nil { @@ -116,7 +116,7 @@ func (r *senderChain) UnpickleLibOlm(value []byte) (int, error) { // It returns the number of bytes written. func (r senderChain) PickleLibOlm(target []byte) (int, error) { if len(target) < r.PickleLen() { - return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle sender chain: %w", olm.ErrValueTooShort) } written, err := r.RKey.PickleLibOlm(target) if err != nil { @@ -189,7 +189,7 @@ func (r *receiverChain) UnpickleLibOlm(value []byte) (int, error) { // It returns the number of bytes written. func (r receiverChain) PickleLibOlm(target []byte) (int, error) { if len(target) < r.PickleLen() { - return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle sender chain: %w", olm.ErrValueTooShort) } written, err := r.RKey.PickleLibOlm(target) if err != nil { @@ -238,7 +238,7 @@ func (m *messageKey) UnpickleLibOlm(value []byte) (int, error) { // It returns the number of bytes written. func (m messageKey) PickleLibOlm(target []byte) (int, error) { if len(target) < m.PickleLen() { - return 0, fmt.Errorf("pickle message key: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle message key: %w", olm.ErrValueTooShort) } written := 0 if len(m.Key) != messageKeyLength { diff --git a/crypto/goolm/olm/olm.go b/crypto/goolm/ratchet/olm.go similarity index 93% rename from crypto/goolm/olm/olm.go rename to crypto/goolm/ratchet/olm.go index 299ec7c4..bf04c1cf 100644 --- a/crypto/goolm/olm/olm.go +++ b/crypto/goolm/ratchet/olm.go @@ -1,16 +1,16 @@ -// olm provides the ratchet used by the olm protocol -package olm +// Package ratchet provides the ratchet used by the olm protocol +package ratchet import ( "fmt" "io" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/message" "maunium.net/go/mautrix/crypto/goolm/utilities" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -95,10 +95,10 @@ func (r *Ratchet) InitializeAsAlice(sharedSecret []byte, ourRatchetKey crypto.Cu } // Encrypt encrypts the message in a message.Message with MAC. If reader is nil, crypto/rand is used for key generations. -func (r *Ratchet) Encrypt(plaintext []byte, reader io.Reader) ([]byte, error) { +func (r *Ratchet) Encrypt(plaintext []byte) ([]byte, error) { var err error if !r.SenderChains.IsSet { - newRatchetKey, err := crypto.Curve25519GenerateKey(reader) + newRatchetKey, err := crypto.Curve25519GenerateKey(nil) if err != nil { return nil, err } @@ -141,10 +141,10 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { return nil, err } if message.Version != protocolVersion { - return nil, fmt.Errorf("decrypt: %w", goolm.ErrWrongProtocolVersion) + return nil, fmt.Errorf("decrypt: %w", olm.ErrWrongProtocolVersion) } if !message.HasCounter || len(message.RatchetKey) == 0 || len(message.Ciphertext) == 0 { - return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat) + return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat) } var receiverChainFromMessage *receiverChain for curChainIndex := range r.ReceiverChains { @@ -173,7 +173,7 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { return nil, err } if !verified { - return nil, fmt.Errorf("decrypt from skipped message keys: %w", goolm.ErrBadMAC) + return nil, fmt.Errorf("decrypt from skipped message keys: %w", olm.ErrBadMAC) } result, err = RatchetCipher.Decrypt(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, message.Ciphertext) if err != nil { @@ -189,7 +189,7 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { } } if !foundSkippedKey { - return nil, fmt.Errorf("decrypt: %w", goolm.ErrMessageKeyNotFound) + return nil, fmt.Errorf("decrypt: %w", olm.ErrMessageKeyNotFound) } } else { //Advancing the chain is done in this method @@ -228,11 +228,11 @@ func (r Ratchet) createMessageKeys(chainKey chainKey) messageKey { // decryptForExistingChain returns the decrypted message by using the chain. The MAC of the rawMessage is verified. func (r *Ratchet) decryptForExistingChain(chain *receiverChain, message *message.Message, rawMessage []byte) ([]byte, error) { if message.Counter < chain.CKey.Index { - return nil, fmt.Errorf("decrypt: %w", goolm.ErrChainTooHigh) + return nil, fmt.Errorf("decrypt: %w", olm.ErrChainTooHigh) } // Limit the number of hashes we're prepared to compute if message.Counter-chain.CKey.Index > maxMessageGap { - return nil, fmt.Errorf("decrypt from existing chain: %w", goolm.ErrMsgIndexTooHigh) + return nil, fmt.Errorf("decrypt from existing chain: %w", olm.ErrMsgIndexTooHigh) } for chain.CKey.Index < message.Counter { messageKey := r.createMessageKeys(chain.chainKey()) @@ -250,7 +250,7 @@ func (r *Ratchet) decryptForExistingChain(chain *receiverChain, message *message return nil, err } if !verified { - return nil, fmt.Errorf("decrypt from existing chain: %w", goolm.ErrBadMAC) + return nil, fmt.Errorf("decrypt from existing chain: %w", olm.ErrBadMAC) } return RatchetCipher.Decrypt(messageKey.Key, message.Ciphertext) } @@ -260,11 +260,11 @@ func (r *Ratchet) decryptForNewChain(message *message.Message, rawMessage []byte // They shouldn't move to a new chain until we've sent them a message // acknowledging the last one if !r.SenderChains.IsSet { - return nil, fmt.Errorf("decrypt for new chain: %w", goolm.ErrProtocolViolation) + return nil, fmt.Errorf("decrypt for new chain: %w", olm.ErrProtocolViolation) } // Limit the number of hashes we're prepared to compute if message.Counter > maxMessageGap { - return nil, fmt.Errorf("decrypt for new chain: %w", goolm.ErrMsgIndexTooHigh) + return nil, fmt.Errorf("decrypt for new chain: %w", olm.ErrMsgIndexTooHigh) } newChainKey, err := r.advanceRootKey(r.SenderChains.ratchetKey(), message.RatchetKey) @@ -371,7 +371,7 @@ func (r *Ratchet) UnpickleLibOlm(value []byte, includesChainIndex bool) (int, er // It returns the number of bytes written. func (r Ratchet) PickleLibOlm(target []byte) (int, error) { if len(target) < r.PickleLen() { - return 0, fmt.Errorf("pickle ratchet: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle ratchet: %w", olm.ErrValueTooShort) } written, err := r.RootKey.PickleLibOlm(target) if err != nil { diff --git a/crypto/goolm/olm/olm_test.go b/crypto/goolm/ratchet/olm_test.go similarity index 82% rename from crypto/goolm/olm/olm_test.go rename to crypto/goolm/ratchet/olm_test.go index 974ffc5e..91549bd8 100644 --- a/crypto/goolm/olm/olm_test.go +++ b/crypto/goolm/ratchet/olm_test.go @@ -1,4 +1,4 @@ -package olm_test +package ratchet_test import ( "bytes" @@ -7,24 +7,24 @@ import ( "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/goolm/olm" + "maunium.net/go/mautrix/crypto/goolm/ratchet" ) var ( sharedSecret = []byte("A secret") ) -func initializeRatchets() (*olm.Ratchet, *olm.Ratchet, error) { - olm.KdfInfo = struct { +func initializeRatchets() (*ratchet.Ratchet, *ratchet.Ratchet, error) { + ratchet.KdfInfo = struct { Root []byte Ratchet []byte }{ Root: []byte("Olm"), Ratchet: []byte("OlmRatchet"), } - olm.RatchetCipher = cipher.NewAESSHA256([]byte("OlmMessageKeys")) - aliceRatchet := olm.New() - bobRatchet := olm.New() + ratchet.RatchetCipher = cipher.NewAESSHA256([]byte("OlmMessageKeys")) + aliceRatchet := ratchet.New() + bobRatchet := ratchet.New() aliceKey, err := crypto.Curve25519GenerateKey(nil) if err != nil { @@ -45,7 +45,7 @@ func TestSendReceive(t *testing.T) { plainText := []byte("Hello Bob") //Alice sends Bob a message - encryptedMessage, err := aliceRatchet.Encrypt(plainText, nil) + encryptedMessage, err := aliceRatchet.Encrypt(plainText) if err != nil { t.Fatal(err) } @@ -60,7 +60,7 @@ func TestSendReceive(t *testing.T) { //Bob sends Alice a message plainText = []byte("Hello Alice") - encryptedMessage, err = bobRatchet.Encrypt(plainText, nil) + encryptedMessage, err = bobRatchet.Encrypt(plainText) if err != nil { t.Fatal(err) } @@ -83,11 +83,11 @@ func TestOutOfOrder(t *testing.T) { plainText2 := []byte("Second Messsage. A bit longer than the first.") /* Alice sends Bob two messages and they arrive out of order */ - message1Encrypted, err := aliceRatchet.Encrypt(plainText1, nil) + message1Encrypted, err := aliceRatchet.Encrypt(plainText1) if err != nil { t.Fatal(err) } - message2Encrypted, err := aliceRatchet.Encrypt(plainText2, nil) + message2Encrypted, err := aliceRatchet.Encrypt(plainText2) if err != nil { t.Fatal(err) } @@ -115,7 +115,7 @@ func TestMoreMessages(t *testing.T) { } plainText := []byte("These 15 bytes") for i := 0; i < 8; i++ { - messageEncrypted, err := aliceRatchet.Encrypt(plainText, nil) + messageEncrypted, err := aliceRatchet.Encrypt(plainText) if err != nil { t.Fatal(err) } @@ -128,7 +128,7 @@ func TestMoreMessages(t *testing.T) { } } for i := 0; i < 8; i++ { - messageEncrypted, err := bobRatchet.Encrypt(plainText, nil) + messageEncrypted, err := bobRatchet.Encrypt(plainText) if err != nil { t.Fatal(err) } @@ -140,7 +140,7 @@ func TestMoreMessages(t *testing.T) { t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) } } - messageEncrypted, err := aliceRatchet.Encrypt(plainText, nil) + messageEncrypted, err := aliceRatchet.Encrypt(plainText) if err != nil { t.Fatal(err) } @@ -163,7 +163,7 @@ func TestJSONEncoding(t *testing.T) { t.Fatal(err) } - newRatcher := olm.Ratchet{} + newRatcher := ratchet.Ratchet{} err = json.Unmarshal(marshaled, &newRatcher) if err != nil { t.Fatal(err) @@ -171,7 +171,7 @@ func TestJSONEncoding(t *testing.T) { plainText := []byte("These 15 bytes") - messageEncrypted, err := newRatcher.Encrypt(plainText, nil) + messageEncrypted, err := newRatcher.Encrypt(plainText) if err != nil { t.Fatal(err) } diff --git a/crypto/goolm/olm/skipped_message.go b/crypto/goolm/ratchet/skipped_message.go similarity index 92% rename from crypto/goolm/olm/skipped_message.go rename to crypto/goolm/ratchet/skipped_message.go index 944337f6..79927480 100644 --- a/crypto/goolm/olm/skipped_message.go +++ b/crypto/goolm/ratchet/skipped_message.go @@ -1,10 +1,10 @@ -package olm +package ratchet import ( "fmt" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/olm" ) // skippedMessageKey stores a skipped message key @@ -33,7 +33,7 @@ func (r *skippedMessageKey) UnpickleLibOlm(value []byte) (int, error) { // It returns the number of bytes written. func (r skippedMessageKey) PickleLibOlm(target []byte) (int, error) { if len(target) < r.PickleLen() { - return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle sender chain: %w", olm.ErrValueTooShort) } written, err := r.RKey.PickleLibOlm(target) if err != nil { diff --git a/crypto/goolm/register.go b/crypto/goolm/register.go new file mode 100644 index 00000000..80ed206b --- /dev/null +++ b/crypto/goolm/register.go @@ -0,0 +1,25 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package goolm + +import ( + // Need to import these subpackages to ensure they are registered + _ "maunium.net/go/mautrix/crypto/goolm/account" + _ "maunium.net/go/mautrix/crypto/goolm/pk" + _ "maunium.net/go/mautrix/crypto/goolm/session" + + "maunium.net/go/mautrix/crypto/olm" +) + +func init() { + olm.GetVersion = func() (major, minor, patch uint8) { + return 3, 2, 15 + } + olm.SetPickleKeyImpl = func(key []byte) { + panic("gob and json encoding is deprecated and not supported with goolm") + } +} diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go index 165f7f16..f48698e7 100644 --- a/crypto/goolm/session/megolm_inbound_session.go +++ b/crypto/goolm/session/megolm_inbound_session.go @@ -5,13 +5,14 @@ import ( "errors" "fmt" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/megolm" "maunium.net/go/mautrix/crypto/goolm/message" "maunium.net/go/mautrix/crypto/goolm/utilities" + "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -28,10 +29,14 @@ type MegolmInboundSession struct { SigningKeyVerified bool `json:"signing_key_verified"` //not used for now } +// Ensure that MegolmInboundSession implements the [olm.InboundGroupSession] +// interface. +var _ olm.InboundGroupSession = (*MegolmInboundSession)(nil) + // NewMegolmInboundSession creates a new MegolmInboundSession from a base64 encoded session sharing message. func NewMegolmInboundSession(input []byte) (*MegolmInboundSession, error) { var err error - input, err = goolm.Base64Decode(input) + input, err = goolmbase64.Decode(input) if err != nil { return nil, err } @@ -55,7 +60,7 @@ func NewMegolmInboundSession(input []byte) (*MegolmInboundSession, error) { // NewMegolmInboundSessionFromExport creates a new MegolmInboundSession from a base64 encoded session export message. func NewMegolmInboundSessionFromExport(input []byte) (*MegolmInboundSession, error) { var err error - input, err = goolm.Base64Decode(input) + input, err = goolmbase64.Decode(input) if err != nil { return nil, err } @@ -78,7 +83,7 @@ func NewMegolmInboundSessionFromExport(input []byte) (*MegolmInboundSession, err // MegolmInboundSessionFromPickled loads the MegolmInboundSession details from a pickled base64 string. The input is decrypted with the supplied key. func MegolmInboundSessionFromPickled(pickled, key []byte) (*MegolmInboundSession, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("megolmInboundSessionFromPickled: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("megolmInboundSessionFromPickled: %w", olm.ErrEmptyInput) } a := &MegolmInboundSession{} err := a.Unpickle(pickled, key) @@ -89,7 +94,7 @@ func MegolmInboundSessionFromPickled(pickled, key []byte) (*MegolmInboundSession } // getRatchet tries to find the correct ratchet for a messageIndex. -func (o MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, error) { +func (o *MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, error) { // pick a megolm instance to use. if we are at or beyond the latest ratchet value, use that if (messageIndex - o.Ratchet.Counter) < uint32(1<<31) { o.Ratchet.AdvanceTo(messageIndex) @@ -97,7 +102,7 @@ func (o MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, } if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) { // the counter is before our initial ratchet - we can't decode this - return nil, fmt.Errorf("decrypt: %w", goolm.ErrRatchetNotAvailable) + return nil, fmt.Errorf("decrypt: %w", olm.ErrRatchetNotAvailable) } // otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet copiedRatchet := o.InitialRatchet @@ -107,11 +112,14 @@ func (o MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, } // Decrypt decrypts a base64 encoded group message. -func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error) { - if o.SigningKey == nil { - return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat) +func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint, error) { + if len(ciphertext) == 0 { + return nil, 0, olm.ErrEmptyInput } - decoded, err := goolm.Base64Decode(ciphertext) + if o.SigningKey == nil { + return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat) + } + decoded, err := goolmbase64.Decode(ciphertext) if err != nil { return nil, 0, err } @@ -121,16 +129,16 @@ func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error return nil, 0, err } if msg.Version != protocolVersion { - return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrWrongProtocolVersion) + return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrWrongProtocolVersion) } if msg.Ciphertext == nil || !msg.HasMessageIndex { - return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat) + return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat) } // verify signature verifiedSignature := msg.VerifySignatureInline(o.SigningKey, decoded) if !verifiedSignature { - return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadSignature) + return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrBadSignature) } targetRatch, err := o.getRatchet(msg.MessageIndex) @@ -143,17 +151,17 @@ func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error return nil, 0, err } o.SigningKeyVerified = true - return decrypted, msg.MessageIndex, nil + return decrypted, uint(msg.MessageIndex), nil } -// SessionID returns the base64 endoded signing key -func (o MegolmInboundSession) SessionID() id.SessionID { +// ID returns the base64 endoded signing key +func (o *MegolmInboundSession) ID() id.SessionID { return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey)) } // PickleAsJSON returns an MegolmInboundSession as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. -func (o MegolmInboundSession) PickleAsJSON(key []byte) ([]byte, error) { +func (o *MegolmInboundSession) PickleAsJSON(key []byte) ([]byte, error) { return utilities.PickleAsJSON(o, megolmInboundSessionPickleVersionJSON, key) } @@ -162,8 +170,14 @@ func (o *MegolmInboundSession) UnpickleAsJSON(pickled, key []byte) error { return utilities.UnpickleAsJSON(o, pickled, key, megolmInboundSessionPickleVersionJSON) } -// SessionExportMessage creates an base64 encoded export of the session. -func (o MegolmInboundSession) SessionExportMessage(messageIndex uint32) ([]byte, error) { +// Export returns the base64-encoded ratchet key for this session, at the given +// index, in a format which can be used by +// InboundGroupSession.InboundGroupSessionImport(). Encrypts the +// InboundGroupSession using the supplied key. Returns error on failure. +// if we do not have a session key corresponding to the given index (ie, it was +// sent before the session key was shared with us) the error will be +// returned. +func (o *MegolmInboundSession) Export(messageIndex uint32) ([]byte, error) { ratchet, err := o.getRatchet(messageIndex) if err != nil { return nil, err @@ -174,6 +188,11 @@ func (o MegolmInboundSession) SessionExportMessage(messageIndex uint32) ([]byte, // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. func (o *MegolmInboundSession) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return olm.ErrNoKeyProvided + } else if len(pickled) == 0 { + return olm.ErrEmptyInput + } decrypted, err := cipher.Unpickle(key, pickled) if err != nil { return err @@ -192,7 +211,7 @@ func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) (int, error) { switch pickledVersion { case megolmInboundSessionPickleVersionLibOlm, 1: default: - return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", goolm.ErrBadVersion) + return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", olm.ErrBadVersion) } readBytes, err := o.InitialRatchet.UnpickleLibOlm(value[curPos:]) if err != nil { @@ -223,7 +242,10 @@ func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) (int, error) { } // Pickle returns a base64 encoded and with key encrypted pickled MegolmInboundSession using PickleLibOlm(). -func (o MegolmInboundSession) Pickle(key []byte) ([]byte, error) { +func (o *MegolmInboundSession) Pickle(key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, olm.ErrNoKeyProvided + } pickeledBytes := make([]byte, o.PickleLen()) written, err := o.PickleLibOlm(pickeledBytes) if err != nil { @@ -241,9 +263,9 @@ func (o MegolmInboundSession) Pickle(key []byte) ([]byte, error) { // PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. // It returns the number of bytes written. -func (o MegolmInboundSession) PickleLibOlm(target []byte) (int, error) { +func (o *MegolmInboundSession) PickleLibOlm(target []byte) (int, error) { if len(target) < o.PickleLen() { - return 0, fmt.Errorf("pickle MegolmInboundSession: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle MegolmInboundSession: %w", olm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(megolmInboundSessionPickleVersionLibOlm, target) writtenInitRatchet, err := o.InitialRatchet.PickleLibOlm(target[written:]) @@ -266,7 +288,7 @@ func (o MegolmInboundSession) PickleLibOlm(target []byte) (int, error) { } // PickleLen returns the number of bytes the pickled session will have. -func (o MegolmInboundSession) PickleLen() int { +func (o *MegolmInboundSession) PickleLen() int { length := libolmpickle.PickleUInt32Len(megolmInboundSessionPickleVersionLibOlm) length += o.InitialRatchet.PickleLen() length += o.Ratchet.PickleLen() @@ -274,3 +296,15 @@ func (o MegolmInboundSession) PickleLen() int { length += libolmpickle.PickleBoolLen(o.SigningKeyVerified) return length } + +// FirstKnownIndex returns the first message index we know how to decrypt. +func (s *MegolmInboundSession) FirstKnownIndex() uint32 { + return s.InitialRatchet.Counter +} + +// IsVerified check if the session has been verified as a valid session. (A +// session is verified either because the original session share was signed, or +// because we have subsequently successfully decrypted a message.) +func (s *MegolmInboundSession) IsVerified() bool { + return s.SigningKeyVerified +} diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index e594258d..44d001d1 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -8,12 +8,13 @@ import ( "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/megolm" "maunium.net/go/mautrix/crypto/goolm/utilities" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -27,6 +28,8 @@ type MegolmOutboundSession struct { SigningKey crypto.Ed25519KeyPair `json:"signing_key"` } +var _ olm.OutboundGroupSession = (*MegolmOutboundSession)(nil) + // NewMegolmOutboundSession creates a new MegolmOutboundSession. func NewMegolmOutboundSession() (*MegolmOutboundSession, error) { o := &MegolmOutboundSession{} @@ -51,32 +54,32 @@ func NewMegolmOutboundSession() (*MegolmOutboundSession, error) { // MegolmOutboundSessionFromPickled loads the MegolmOutboundSession details from a pickled base64 string. The input is decrypted with the supplied key. func MegolmOutboundSessionFromPickled(pickled, key []byte) (*MegolmOutboundSession, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("megolmOutboundSessionFromPickled: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("megolmOutboundSessionFromPickled: %w", olm.ErrEmptyInput) } a := &MegolmOutboundSession{} err := a.Unpickle(pickled, key) - if err != nil { - return nil, err - } - return a, nil + return a, err } // Encrypt encrypts the plaintext as a base64 encoded group message. func (o *MegolmOutboundSession) Encrypt(plaintext []byte) ([]byte, error) { + if len(plaintext) == 0 { + return nil, olm.ErrEmptyInput + } encrypted, err := o.Ratchet.Encrypt(plaintext, &o.SigningKey) if err != nil { return nil, err } - return goolm.Base64Encode(encrypted), nil + return goolmbase64.Encode(encrypted), nil } // SessionID returns the base64 endoded public signing key -func (o MegolmOutboundSession) SessionID() id.SessionID { +func (o *MegolmOutboundSession) ID() id.SessionID { return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey.PublicKey)) } // PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. -func (o MegolmOutboundSession) PickleAsJSON(key []byte) ([]byte, error) { +func (o *MegolmOutboundSession) PickleAsJSON(key []byte) ([]byte, error) { return utilities.PickleAsJSON(o, megolmOutboundSessionPickleVersion, key) } @@ -88,6 +91,9 @@ func (o *MegolmOutboundSession) UnpickleAsJSON(pickled, key []byte) error { // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return olm.ErrNoKeyProvided + } decrypted, err := cipher.Unpickle(key, pickled) if err != nil { return err @@ -106,7 +112,7 @@ func (o *MegolmOutboundSession) UnpickleLibOlm(value []byte) (int, error) { switch pickledVersion { case megolmOutboundSessionPickleVersionLibOlm: default: - return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", goolm.ErrBadVersion) + return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", olm.ErrBadVersion) } readBytes, err := o.Ratchet.UnpickleLibOlm(value[curPos:]) if err != nil { @@ -122,7 +128,10 @@ func (o *MegolmOutboundSession) UnpickleLibOlm(value []byte) (int, error) { } // Pickle returns a base64 encoded and with key encrypted pickled MegolmOutboundSession using PickleLibOlm(). -func (o MegolmOutboundSession) Pickle(key []byte) ([]byte, error) { +func (o *MegolmOutboundSession) Pickle(key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, olm.ErrNoKeyProvided + } pickeledBytes := make([]byte, o.PickleLen()) written, err := o.PickleLibOlm(pickeledBytes) if err != nil { @@ -140,9 +149,9 @@ func (o MegolmOutboundSession) Pickle(key []byte) ([]byte, error) { // PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. // It returns the number of bytes written. -func (o MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) { +func (o *MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) { if len(target) < o.PickleLen() { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", olm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(megolmOutboundSessionPickleVersionLibOlm, target) writtenRatchet, err := o.Ratchet.PickleLibOlm(target[written:]) @@ -159,13 +168,28 @@ func (o MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) { } // PickleLen returns the number of bytes the pickled session will have. -func (o MegolmOutboundSession) PickleLen() int { +func (o *MegolmOutboundSession) PickleLen() int { length := libolmpickle.PickleUInt32Len(megolmOutboundSessionPickleVersionLibOlm) length += o.Ratchet.PickleLen() length += o.SigningKey.PickleLen() return length } -func (o MegolmOutboundSession) SessionSharingMessage() ([]byte, error) { +func (o *MegolmOutboundSession) SessionSharingMessage() ([]byte, error) { return o.Ratchet.SessionSharingMessage(o.SigningKey) } + +// MessageIndex returns the message index for this session. Each message is +// sent with an increasing index; this returns the index for the next message. +func (s *MegolmOutboundSession) MessageIndex() uint { + return uint(s.Ratchet.Counter) +} + +// Key returns the base64-encoded current ratchet key for this session. +func (s *MegolmOutboundSession) Key() string { + message, err := s.SessionSharingMessage() + if err != nil { + panic(err) + } + return string(message) +} diff --git a/crypto/goolm/session/megolm_session_test.go b/crypto/goolm/session/megolm_session_test.go index 9b3f56b5..936ce982 100644 --- a/crypto/goolm/session/megolm_session_test.go +++ b/crypto/goolm/session/megolm_session_test.go @@ -6,10 +6,10 @@ import ( "errors" "testing" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/megolm" "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/olm" ) func TestOutboundPickleJSON(t *testing.T) { @@ -33,7 +33,7 @@ func TestOutboundPickleJSON(t *testing.T) { if err != nil { t.Fatal(err) } - if sess.SessionID() != newSession.SessionID() { + if sess.ID() != newSession.ID() { t.Fatal("session ids not equal") } if !bytes.Equal(sess.SigningKey.PrivateKey, newSession.SigningKey.PrivateKey) { @@ -75,7 +75,7 @@ func TestInboundPickleJSON(t *testing.T) { if err != nil { t.Fatal(err) } - if sess.SessionID() != newSession.SessionID() { + if sess.ID() != newSession.ID() { t.Fatal("sess ids not equal") } if !bytes.Equal(sess.SigningKey, newSession.SigningKey) { @@ -128,7 +128,7 @@ func TestGroupSendReceive(t *testing.T) { if !inboundSession.SigningKeyVerified { t.Fatal("key not verified") } - if inboundSession.SessionID() != outboundSession.SessionID() { + if inboundSession.ID() != outboundSession.ID() { t.Fatal("session ids not equal") } @@ -174,7 +174,7 @@ func TestGroupSessionExportImport(t *testing.T) { } //Export the keys - exported, err := inboundSession.SessionExportMessage(0) + exported, err := inboundSession.Export(0) if err != nil { t.Fatal(err) } @@ -236,7 +236,7 @@ func TestBadSignatureGroupMessage(t *testing.T) { if err == nil { t.Fatal("Signature was changed but did not cause an error") } - if !errors.Is(err, goolm.ErrBadSignature) { + if !errors.Is(err, olm.ErrBadSignature) { t.Fatalf("wrong error %s", err.Error()) } } diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go index 6655e0a5..33908edc 100644 --- a/crypto/goolm/session/olm_session.go +++ b/crypto/goolm/session/olm_session.go @@ -5,15 +5,16 @@ import ( "encoding/base64" "errors" "fmt" - "io" + "strings" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/message" - "maunium.net/go/mautrix/crypto/goolm/olm" + "maunium.net/go/mautrix/crypto/goolm/ratchet" "maunium.net/go/mautrix/crypto/goolm/utilities" + "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -32,9 +33,11 @@ type OlmSession struct { AliceIdentityKey crypto.Curve25519PublicKey `json:"alice_id_key"` AliceBaseKey crypto.Curve25519PublicKey `json:"alice_base_key"` BobOneTimeKey crypto.Curve25519PublicKey `json:"bob_one_time_key"` - Ratchet olm.Ratchet `json:"ratchet"` + Ratchet ratchet.Ratchet `json:"ratchet"` } +var _ olm.Session = (*OlmSession)(nil) + // SearchOTKFunc is used to retrieve a crypto.OneTimeKey from a public key. type SearchOTKFunc = func(crypto.Curve25519PublicKey) *crypto.OneTimeKey @@ -42,7 +45,7 @@ type SearchOTKFunc = func(crypto.Curve25519PublicKey) *crypto.OneTimeKey // the Session using the supplied key. func OlmSessionFromJSONPickled(pickled, key []byte) (*OlmSession, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("sessionFromPickled: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("sessionFromPickled: %w", olm.ErrEmptyInput) } a := &OlmSession{} err := a.UnpickleAsJSON(pickled, key) @@ -55,7 +58,7 @@ func OlmSessionFromJSONPickled(pickled, key []byte) (*OlmSession, error) { // OlmSessionFromPickled loads the OlmSession details from a pickled base64 string. The input is decrypted with the supplied key. func OlmSessionFromPickled(pickled, key []byte) (*OlmSession, error) { if len(pickled) == 0 { - return nil, fmt.Errorf("sessionFromPickled: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("sessionFromPickled: %w", olm.ErrEmptyInput) } a := &OlmSession{} err := a.Unpickle(pickled, key) @@ -68,7 +71,7 @@ func OlmSessionFromPickled(pickled, key []byte) (*OlmSession, error) { // NewOlmSession creates a new Session. func NewOlmSession() *OlmSession { s := &OlmSession{} - s.Ratchet = *olm.New() + s.Ratchet = *ratchet.New() return s } @@ -117,7 +120,7 @@ func NewOutboundOlmSession(identityKeyAlice crypto.Curve25519KeyPair, identityKe // NewInboundOlmSession creates a new inbound session from receiving the first message. func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, receivedOTKMsg []byte, searchBobOTK SearchOTKFunc, identityKeyBob crypto.Curve25519KeyPair) (*OlmSession, error) { - decodedOTKMsg, err := goolm.Base64Decode(receivedOTKMsg) + decodedOTKMsg, err := goolmbase64.Decode(receivedOTKMsg) if err != nil { return nil, err } @@ -130,7 +133,7 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received return nil, fmt.Errorf("OneTimeKeyMessage decode: %w", err) } if !oneTimeMsg.CheckFields(identityKeyAlice) { - return nil, fmt.Errorf("OneTimeKeyMessage check fields: %w", goolm.ErrBadMessageFormat) + return nil, fmt.Errorf("OneTimeKeyMessage check fields: %w", olm.ErrBadMessageFormat) } //Either the identityKeyAlice is set and/or the oneTimeMsg.IdentityKey is set, which is checked @@ -138,7 +141,7 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received if identityKeyAlice != nil && len(oneTimeMsg.IdentityKey) != 0 { //if both are set, compare them if !identityKeyAlice.Equal(oneTimeMsg.IdentityKey) { - return nil, fmt.Errorf("OneTimeKeyMessage identity keys: %w", goolm.ErrBadMessageKeyID) + return nil, fmt.Errorf("OneTimeKeyMessage identity keys: %w", olm.ErrBadMessageKeyID) } } if identityKeyAlice == nil { @@ -148,7 +151,7 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received oneTimeKeyBob := searchBobOTK(oneTimeMsg.OneTimeKey) if oneTimeKeyBob == nil { - return nil, fmt.Errorf("ourOneTimeKey: %w", goolm.ErrBadMessageKeyID) + return nil, fmt.Errorf("ourOneTimeKey: %w", olm.ErrBadMessageKeyID) } //Calculate shared secret via Triple Diffie-Hellman @@ -179,7 +182,7 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received } if len(msg.RatchetKey) == 0 { - return nil, fmt.Errorf("Message missing ratchet key: %w", goolm.ErrBadMessageFormat) + return nil, fmt.Errorf("Message missing ratchet key: %w", olm.ErrBadMessageFormat) } //Init Ratchet s.Ratchet.InitializeAsBob(secret, msg.RatchetKey) @@ -204,30 +207,54 @@ func (a *OlmSession) UnpickleAsJSON(pickled, key []byte) error { // ID returns an identifier for this Session. Will be the same for both ends of the conversation. // Generated by hashing the public keys used to create the session. -func (s OlmSession) ID() id.SessionID { +func (s *OlmSession) ID() id.SessionID { message := make([]byte, 3*crypto.Curve25519KeyLength) copy(message, s.AliceIdentityKey) copy(message[crypto.Curve25519KeyLength:], s.AliceBaseKey) copy(message[2*crypto.Curve25519KeyLength:], s.BobOneTimeKey) hash := crypto.SHA256(message) - res := id.SessionID(goolm.Base64Encode(hash)) + res := id.SessionID(goolmbase64.Encode(hash)) return res } // HasReceivedMessage returns true if this session has received any message. -func (s OlmSession) HasReceivedMessage() bool { +func (s *OlmSession) HasReceivedMessage() bool { return s.ReceivedMessage } -// MatchesInboundSessionFrom checks if the oneTimeKeyMsg message is set for this inbound -// Session. This can happen if multiple messages are sent to this Account -// before this Account sends a message in reply. Returns true if the session -// matches. Returns false if the session does not match. -func (s OlmSession) MatchesInboundSessionFrom(theirIdentityKeyEncoded *id.Curve25519, receivedOTKMsg []byte) (bool, error) { - if len(receivedOTKMsg) == 0 { - return false, fmt.Errorf("inbound match: %w", goolm.ErrEmptyInput) +// MatchesInboundSession checks if the PRE_KEY message is for this in-bound +// Session. This can happen if multiple messages are sent to this Account +// before this Account sends a message in reply. Returns true if the session +// matches. Returns false if the session does not match. Returns error on +// failure. +func (s *OlmSession) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { + return s.matchesInboundSession(nil, []byte(oneTimeKeyMsg)) +} + +// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound +// Session. This can happen if multiple messages are sent to this Account +// before this Account sends a message in reply. Returns true if the session +// matches. Returns false if the session does not match. Returns error on +// failure. +func (s *OlmSession) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { + var theirKey *id.Curve25519 + if theirIdentityKey != "" { + theirs := id.Curve25519(theirIdentityKey) + theirKey = &theirs } - decodedOTKMsg, err := goolm.Base64Decode(receivedOTKMsg) + + return s.matchesInboundSession(theirKey, []byte(oneTimeKeyMsg)) +} + +// matchesInboundSession checks if the oneTimeKeyMsg message is set for this +// inbound Session. This can happen if multiple messages are sent to this +// Account before this Account sends a message in reply. Returns true if the +// session matches. Returns false if the session does not match. +func (s *OlmSession) matchesInboundSession(theirIdentityKeyEncoded *id.Curve25519, receivedOTKMsg []byte) (bool, error) { + if len(receivedOTKMsg) == 0 { + return false, fmt.Errorf("inbound match: %w", olm.ErrEmptyInput) + } + decodedOTKMsg, err := goolmbase64.Decode(receivedOTKMsg) if err != nil { return false, err } @@ -266,20 +293,20 @@ func (s OlmSession) MatchesInboundSessionFrom(theirIdentityKeyEncoded *id.Curve2 // EncryptMsgType returns the type of the next message that Encrypt will // return. Returns MsgTypePreKey if the message will be a oneTimeKeyMsg. // Returns MsgTypeMsg if the message will be a normal message. -func (s OlmSession) EncryptMsgType() id.OlmMsgType { +func (s *OlmSession) EncryptMsgType() id.OlmMsgType { if s.ReceivedMessage { return id.OlmMsgTypeMsg } return id.OlmMsgTypePreKey } -// Encrypt encrypts a message using the Session. Returns the encrypted message base64 encoded. If reader is nil, crypto/rand is used for key generations. -func (s *OlmSession) Encrypt(plaintext []byte, reader io.Reader) (id.OlmMsgType, []byte, error) { +// Encrypt encrypts a message using the Session. Returns the encrypted message base64 encoded. +func (s *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { if len(plaintext) == 0 { - return 0, nil, fmt.Errorf("encrypt: %w", goolm.ErrEmptyInput) + return 0, nil, fmt.Errorf("encrypt: %w", olm.ErrEmptyInput) } messageType := s.EncryptMsgType() - encrypted, err := s.Ratchet.Encrypt(plaintext, reader) + encrypted, err := s.Ratchet.Encrypt(plaintext) if err != nil { return 0, nil, err } @@ -300,15 +327,15 @@ func (s *OlmSession) Encrypt(plaintext []byte, reader io.Reader) (id.OlmMsgType, result = messageBody } - return messageType, goolm.Base64Encode(result), nil + return messageType, goolmbase64.Encode(result), nil } // Decrypt decrypts a base64 encoded message using the Session. -func (s *OlmSession) Decrypt(crypttext []byte, msgType id.OlmMsgType) ([]byte, error) { +func (s *OlmSession) Decrypt(crypttext string, msgType id.OlmMsgType) ([]byte, error) { if len(crypttext) == 0 { - return nil, fmt.Errorf("decrypt: %w", goolm.ErrEmptyInput) + return nil, fmt.Errorf("decrypt: %w", olm.ErrEmptyInput) } - decodedCrypttext, err := goolm.Base64Decode(crypttext) + decodedCrypttext, err := goolmbase64.Decode([]byte(crypttext)) if err != nil { return nil, err } @@ -333,6 +360,9 @@ func (s *OlmSession) Decrypt(crypttext []byte, msgType id.OlmMsgType) ([]byte, e // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. func (o *OlmSession) Unpickle(pickled, key []byte) error { + if len(pickled) == 0 { + return olm.ErrEmptyInput + } decrypted, err := cipher.Unpickle(key, pickled) if err != nil { return err @@ -355,7 +385,7 @@ func (o *OlmSession) UnpickleLibOlm(value []byte) (int, error) { case uint32(0x80000001): includesChainIndex = true default: - return 0, fmt.Errorf("unpickle olmSession: %w", goolm.ErrBadVersion) + return 0, fmt.Errorf("unpickle olmSession: %w", olm.ErrBadVersion) } var readBytes int o.ReceivedMessage, readBytes, err = libolmpickle.UnpickleBool(value[curPos:]) @@ -386,28 +416,28 @@ func (o *OlmSession) UnpickleLibOlm(value []byte) (int, error) { return curPos, nil } -// Pickle returns a base64 encoded and with key encrypted pickled olmSession using PickleLibOlm(). -func (o OlmSession) Pickle(key []byte) ([]byte, error) { - pickeledBytes := make([]byte, o.PickleLen()) - written, err := o.PickleLibOlm(pickeledBytes) +// Pickle returns a base64 encoded and with key encrypted pickled olmSession +// using PickleLibOlm(). +func (s *OlmSession) Pickle(key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, olm.ErrNoKeyProvided + } + pickeledBytes := make([]byte, s.PickleLen()) + written, err := s.PickleLibOlm(pickeledBytes) if err != nil { return nil, err } if written != len(pickeledBytes) { return nil, errors.New("number of written bytes not correct") } - encrypted, err := cipher.Pickle(key, pickeledBytes) - if err != nil { - return nil, err - } - return encrypted, nil + return cipher.Pickle(key, pickeledBytes) } // PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. // It returns the number of bytes written. -func (o OlmSession) PickleLibOlm(target []byte) (int, error) { +func (o *OlmSession) PickleLibOlm(target []byte) (int, error) { if len(target) < o.PickleLen() { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort) + return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", olm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(olmSessionPickleVersionLibOlm, target) written += libolmpickle.PickleBool(o.ReceivedMessage, target[written:]) @@ -435,7 +465,7 @@ func (o OlmSession) PickleLibOlm(target []byte) (int, error) { } // PickleLen returns the actual number of bytes the pickled session will have. -func (o OlmSession) PickleLen() int { +func (o *OlmSession) PickleLen() int { length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm) length += libolmpickle.PickleBoolLen(o.ReceivedMessage) length += o.AliceIdentityKey.PickleLen() @@ -446,7 +476,7 @@ func (o OlmSession) PickleLen() int { } // PickleLenMin returns the minimum number of bytes the pickled session must have. -func (o OlmSession) PickleLenMin() int { +func (o *OlmSession) PickleLenMin() int { length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm) length += libolmpickle.PickleBoolLen(o.ReceivedMessage) length += o.AliceIdentityKey.PickleLen() @@ -457,20 +487,17 @@ func (o OlmSession) PickleLenMin() int { } // Describe returns a string describing the current state of the session for debugging. -func (o OlmSession) Describe() string { - var res string - if o.Ratchet.SenderChains.IsSet { - res += fmt.Sprintf("sender chain index: %d ", o.Ratchet.SenderChains.CKey.Index) - } else { - res += "sender chain index: " - } - res += "receiver chain indicies:" +func (o *OlmSession) Describe() string { + var builder strings.Builder + builder.WriteString("sender chain index: ") + builder.WriteString(fmt.Sprint(o.Ratchet.SenderChains.CKey.Index)) + builder.WriteString(" receiver chain indices:") for _, curChain := range o.Ratchet.ReceiverChains { - res += fmt.Sprintf(" %d", curChain.CKey.Index) + builder.WriteString(fmt.Sprintf(" %d", curChain.CKey.Index)) } - res += " skipped message keys:" + builder.WriteString(" skipped message keys:") for _, curSkip := range o.Ratchet.SkippedMessageKeys { - res += fmt.Sprintf(" %d", curSkip.MKey.Index) + builder.WriteString(fmt.Sprintf(" %d", curSkip.MKey.Index)) } - return res + return builder.String() } diff --git a/crypto/goolm/session/olm_session_test.go b/crypto/goolm/session/olm_session_test.go index 11b13c32..b5ff4c32 100644 --- a/crypto/goolm/session/olm_session_test.go +++ b/crypto/goolm/session/olm_session_test.go @@ -6,9 +6,9 @@ import ( "errors" "testing" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -32,7 +32,7 @@ func TestOlmSession(t *testing.T) { } //create a message so that there are more keys to marshal plaintext := []byte("Test message from Alice to Bob") - msgType, message, err := aliceSession.Encrypt(plaintext, nil) + msgType, message, err := aliceSession.Encrypt(plaintext) if err != nil { t.Fatal(err) } @@ -55,7 +55,7 @@ func TestOlmSession(t *testing.T) { if err != nil { t.Fatal(err) } - decryptedMsg, err := bobSession.Decrypt(message, msgType) + decryptedMsg, err := bobSession.Decrypt(string(message), msgType) if err != nil { t.Fatal(err) } @@ -71,7 +71,7 @@ func TestOlmSession(t *testing.T) { //bob sends a message plaintext = []byte("A message from Bob to Alice") - msgType, message, err = bobSession.Encrypt(plaintext, nil) + msgType, message, err = bobSession.Encrypt(plaintext) if err != nil { t.Fatal(err) } @@ -86,7 +86,7 @@ func TestOlmSession(t *testing.T) { } //Alice receives message - decryptedMsg, err = newAliceSession.Decrypt(message, msgType) + decryptedMsg, err = newAliceSession.Decrypt(string(message), msgType) if err != nil { t.Fatal(err) } @@ -95,14 +95,14 @@ func TestOlmSession(t *testing.T) { } //Alice receives message again - _, err = newAliceSession.Decrypt(message, msgType) + _, err = newAliceSession.Decrypt(string(message), msgType) if err == nil { t.Fatal("should have gotten an error") } //Alice sends another message plaintext = []byte("A second message to Bob") - msgType, message, err = newAliceSession.Encrypt(plaintext, nil) + msgType, message, err = newAliceSession.Encrypt(plaintext) if err != nil { t.Fatal(err) } @@ -110,7 +110,7 @@ func TestOlmSession(t *testing.T) { t.Fatal("Wrong message type") } //bob receives message - decryptedMsg, err = bobSession.Decrypt(message, msgType) + decryptedMsg, err = bobSession.Decrypt(string(message), msgType) if err != nil { t.Fatal(err) } @@ -148,7 +148,7 @@ func TestDecrypts(t *testing.T) { {0xe9, 0xe9, 0xc9, 0xc1, 0xe9, 0xe9, 0xc9, 0xe9, 0xc9, 0xc1, 0xe9, 0xe9, 0xc9, 0xc1}, } expectedErr := []error{ - goolm.ErrInputToSmall, + olm.ErrInputToSmall, // Why are these being tested 🤔 base64.CorruptInputError(0), base64.CorruptInputError(0), @@ -165,7 +165,7 @@ func TestDecrypts(t *testing.T) { t.Fatal(err) } for curIndex, curMessage := range messages { - _, err := sess.Decrypt(curMessage, id.OlmMsgTypePreKey) + _, err := sess.Decrypt(string(curMessage), id.OlmMsgTypePreKey) if err != nil { if !errors.Is(err, expectedErr[curIndex]) { t.Fatal(err) diff --git a/crypto/goolm/session/register.go b/crypto/goolm/session/register.go new file mode 100644 index 00000000..0a8b3605 --- /dev/null +++ b/crypto/goolm/session/register.go @@ -0,0 +1,69 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package session + +import ( + "maunium.net/go/mautrix/crypto/olm" +) + +func init() { + // Inbound Session + olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { + if len(pickled) == 0 { + return nil, olm.EmptyInput + } + if len(key) == 0 { + key = []byte(" ") + } + return MegolmInboundSessionFromPickled(pickled, key) + } + olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) { + if len(sessionKey) == 0 { + return nil, olm.EmptyInput + } + return NewMegolmInboundSession(sessionKey) + } + olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) { + if len(sessionKey) == 0 { + return nil, olm.EmptyInput + } + return NewMegolmInboundSessionFromExport(sessionKey) + } + olm.InitBlankInboundGroupSession = func() olm.InboundGroupSession { + return &MegolmInboundSession{} + } + + // Outbound Session + olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { + if len(pickled) == 0 { + return nil, olm.EmptyInput + } + lenKey := len(key) + if lenKey == 0 { + key = []byte(" ") + } + return MegolmOutboundSessionFromPickled(pickled, key) + } + olm.InitNewOutboundGroupSession = func() olm.OutboundGroupSession { + session, err := NewMegolmOutboundSession() + if err != nil { + panic(err) + } + return session + } + olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { + return &MegolmOutboundSession{} + } + + // Olm Session + olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) { + return OlmSessionFromPickled(pickled, key) + } + olm.InitNewBlankSession = func() olm.Session { + return NewOlmSession() + } +} diff --git a/crypto/goolm/utilities/pickle.go b/crypto/goolm/utilities/pickle.go index 993366c8..6ce35efe 100644 --- a/crypto/goolm/utilities/pickle.go +++ b/crypto/goolm/utilities/pickle.go @@ -4,14 +4,14 @@ import ( "encoding/json" "fmt" - "maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/olm" ) // PickleAsJSON returns an object as a base64 string encrypted using the supplied key. The unencrypted representation of the object is in JSON format. func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) { if len(key) == 0 { - return nil, fmt.Errorf("pickle: %w", goolm.ErrNoKeyProvided) + return nil, fmt.Errorf("pickle: %w", olm.ErrNoKeyProvided) } marshaled, err := json.Marshal(object) if err != nil { @@ -36,7 +36,7 @@ func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) { // UnpickleAsJSON updates the object by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error { if len(key) == 0 { - return fmt.Errorf("unpickle: %w", goolm.ErrNoKeyProvided) + return fmt.Errorf("unpickle: %w", olm.ErrNoKeyProvided) } decrypted, err := cipher.Unpickle(key, pickled) if err != nil { @@ -50,7 +50,7 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error { } } if decrypted[0] != pickleVersion { - return fmt.Errorf("unpickle: %w", goolm.ErrWrongPickleVersion) + return fmt.Errorf("unpickle: %w", olm.ErrWrongPickleVersion) } err = json.Unmarshal(decrypted[1:], object) if err != nil { diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 3e65f4c1..4e9431bb 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -169,7 +169,7 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. } igs := &InboundGroupSession{ - Internal: *igsInternal, + Internal: igsInternal, SigningKey: keyBackupData.SenderClaimedKeys.Ed25519, SenderKey: keyBackupData.SenderKey, RoomID: roomID, diff --git a/crypto/keyimport.go b/crypto/keyimport.go index 693ff6b8..108c67ac 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -104,7 +104,7 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor return false, ErrMismatchingExportedSessionID } igs := &InboundGroupSession{ - Internal: *igsInternal, + Internal: igsInternal, SigningKey: session.SenderClaimedKeys.Ed25519, SenderKey: session.SenderKey, RoomID: session.RoomID, diff --git a/crypto/keysharing.go b/crypto/keysharing.go index 4d3b6f7e..38e015c6 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -174,7 +174,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session") } igs := &InboundGroupSession{ - Internal: *igsInternal, + Internal: igsInternal, SigningKey: evt.Keys.Ed25519, SenderKey: content.SenderKey, RoomID: content.RoomID, diff --git a/crypto/libolm/account.go b/crypto/libolm/account.go new file mode 100644 index 00000000..ad329fa3 --- /dev/null +++ b/crypto/libolm/account.go @@ -0,0 +1,417 @@ +package libolm + +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +import "C" + +import ( + "crypto/rand" + "encoding/base64" + "encoding/json" + "io" + "unsafe" + + "github.com/tidwall/gjson" + + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/id" +) + +// Account stores a device account for end to end encrypted messaging. +type Account struct { + int *C.OlmAccount + mem []byte +} + +func init() { + olm.InitNewAccount = func(r io.Reader) (olm.Account, error) { + return NewAccount(r) + } + olm.InitBlankAccount = func() olm.Account { + return NewBlankAccount() + } + olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) { + return AccountFromPickled(pickled, key) + } +} + +// Ensure that [Account] implements [olm.Account]. +var _ olm.Account = (*Account)(nil) + +// AccountFromPickled loads an Account from a pickled base64 string. Decrypts +// the Account using the supplied key. Returns error on failure. If the key +// doesn't match the one used to encrypt the Account then the error will be +// "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then the error will be +// "INVALID_BASE64". +func AccountFromPickled(pickled, key []byte) (*Account, error) { + if len(pickled) == 0 { + return nil, olm.EmptyInput + } + a := NewBlankAccount() + return a, a.Unpickle(pickled, key) +} + +func NewBlankAccount() *Account { + memory := make([]byte, accountSize()) + return &Account{ + int: C.olm_account(unsafe.Pointer(&memory[0])), + mem: memory, + } +} + +// NewAccount creates a new [Account]. +func NewAccount(r io.Reader) (*Account, error) { + a := NewBlankAccount() + random := make([]byte, a.createRandomLen()+1) + if r == nil { + r = rand.Reader + } + _, err := r.Read(random) + if err != nil { + panic(olm.NotEnoughGoRandom) + } + ret := C.olm_create_account( + (*C.OlmAccount)(a.int), + unsafe.Pointer(&random[0]), + C.size_t(len(random))) + if ret == errorVal() { + return nil, a.lastError() + } else { + return a, nil + } +} + +// accountSize returns the size of an account object in bytes. +func accountSize() uint { + return uint(C.olm_account_size()) +} + +// lastError returns an error describing the most recent error to happen to an +// account. +func (a *Account) lastError() error { + return convertError(C.GoString(C.olm_account_last_error((*C.OlmAccount)(a.int)))) +} + +// Clear clears the memory used to back this Account. +func (a *Account) Clear() error { + r := C.olm_clear_account((*C.OlmAccount)(a.int)) + if r == errorVal() { + return a.lastError() + } else { + return nil + } +} + +// pickleLen returns the number of bytes needed to store an Account. +func (a *Account) pickleLen() uint { + return uint(C.olm_pickle_account_length((*C.OlmAccount)(a.int))) +} + +// createRandomLen returns the number of random bytes needed to create an +// Account. +func (a *Account) createRandomLen() uint { + return uint(C.olm_create_account_random_length((*C.OlmAccount)(a.int))) +} + +// identityKeysLen returns the size of the output buffer needed to hold the +// identity keys. +func (a *Account) identityKeysLen() uint { + return uint(C.olm_account_identity_keys_length((*C.OlmAccount)(a.int))) +} + +// signatureLen returns the length of an ed25519 signature encoded as base64. +func (a *Account) signatureLen() uint { + return uint(C.olm_account_signature_length((*C.OlmAccount)(a.int))) +} + +// oneTimeKeysLen returns the size of the output buffer needed to hold the one +// time keys. +func (a *Account) oneTimeKeysLen() uint { + return uint(C.olm_account_one_time_keys_length((*C.OlmAccount)(a.int))) +} + +// genOneTimeKeysRandomLen returns the number of random bytes needed to +// generate a given number of new one time keys. +func (a *Account) genOneTimeKeysRandomLen(num uint) uint { + return uint(C.olm_account_generate_one_time_keys_random_length( + (*C.OlmAccount)(a.int), + C.size_t(num))) +} + +// Pickle returns an Account as a base64 string. Encrypts the Account using the +// supplied key. +func (a *Account) Pickle(key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, olm.NoKeyProvided + } + pickled := make([]byte, a.pickleLen()) + r := C.olm_pickle_account( + (*C.OlmAccount)(a.int), + unsafe.Pointer(&key[0]), + C.size_t(len(key)), + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) + if r == errorVal() { + return nil, a.lastError() + } + return pickled[:r], nil +} + +func (a *Account) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return olm.NoKeyProvided + } + r := C.olm_unpickle_account( + (*C.OlmAccount)(a.int), + unsafe.Pointer(&key[0]), + C.size_t(len(key)), + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) + if r == errorVal() { + return a.lastError() + } + return nil +} + +// Deprecated +func (a *Account) GobEncode() ([]byte, error) { + pickled, err := a.Pickle(pickleKey) + if err != nil { + return nil, err + } + length := base64.RawStdEncoding.DecodedLen(len(pickled)) + rawPickled := make([]byte, length) + _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) + return rawPickled, err +} + +// Deprecated +func (a *Account) GobDecode(rawPickled []byte) error { + if a.int == nil { + *a = *NewBlankAccount() + } + length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) + pickled := make([]byte, length) + base64.RawStdEncoding.Encode(pickled, rawPickled) + return a.Unpickle(pickled, pickleKey) +} + +// Deprecated +func (a *Account) MarshalJSON() ([]byte, error) { + pickled, err := a.Pickle(pickleKey) + if err != nil { + return nil, err + } + quotes := make([]byte, len(pickled)+2) + quotes[0] = '"' + quotes[len(quotes)-1] = '"' + copy(quotes[1:len(quotes)-1], pickled) + return quotes, nil +} + +// Deprecated +func (a *Account) UnmarshalJSON(data []byte) error { + if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { + return olm.InputNotJSONString + } + if a.int == nil { + *a = *NewBlankAccount() + } + return a.Unpickle(data[1:len(data)-1], pickleKey) +} + +// IdentityKeysJSON returns the public parts of the identity keys for the Account. +func (a *Account) IdentityKeysJSON() ([]byte, error) { + identityKeys := make([]byte, a.identityKeysLen()) + r := C.olm_account_identity_keys( + (*C.OlmAccount)(a.int), + unsafe.Pointer(&identityKeys[0]), + C.size_t(len(identityKeys))) + if r == errorVal() { + return nil, a.lastError() + } else { + return identityKeys, nil + } +} + +// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity +// keys for the Account. +func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) { + identityKeysJSON, err := a.IdentityKeysJSON() + if err != nil { + return "", "", err + } + results := gjson.GetManyBytes(identityKeysJSON, "ed25519", "curve25519") + return id.Ed25519(results[0].Str), id.Curve25519(results[1].Str), nil +} + +// Sign returns the signature of a message using the ed25519 key for this +// Account. +func (a *Account) Sign(message []byte) ([]byte, error) { + if len(message) == 0 { + panic(olm.EmptyInput) + } + signature := make([]byte, a.signatureLen()) + r := C.olm_account_sign( + (*C.OlmAccount)(a.int), + unsafe.Pointer(&message[0]), + C.size_t(len(message)), + unsafe.Pointer(&signature[0]), + C.size_t(len(signature))) + if r == errorVal() { + panic(a.lastError()) + } + return signature, nil +} + +// OneTimeKeys returns the public parts of the unpublished one time keys for +// the Account. +// +// The returned data is a struct with the single value "Curve25519", which is +// itself an object mapping key id to base64-encoded Curve25519 key. For +// example: +// +// { +// Curve25519: { +// "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo", +// "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU" +// } +// } +func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) { + oneTimeKeysJSON := make([]byte, a.oneTimeKeysLen()) + r := C.olm_account_one_time_keys( + (*C.OlmAccount)(a.int), + unsafe.Pointer(&oneTimeKeysJSON[0]), + C.size_t(len(oneTimeKeysJSON))) + if r == errorVal() { + return nil, a.lastError() + } + var oneTimeKeys struct { + Curve25519 map[string]id.Curve25519 `json:"curve25519"` + } + return oneTimeKeys.Curve25519, json.Unmarshal(oneTimeKeysJSON, &oneTimeKeys) +} + +// MarkKeysAsPublished marks the current set of one time keys as being +// published. +func (a *Account) MarkKeysAsPublished() { + C.olm_account_mark_keys_as_published((*C.OlmAccount)(a.int)) +} + +// MaxNumberOfOneTimeKeys returns the largest number of one time keys this +// Account can store. +func (a *Account) MaxNumberOfOneTimeKeys() uint { + return uint(C.olm_account_max_number_of_one_time_keys((*C.OlmAccount)(a.int))) +} + +// GenOneTimeKeys generates a number of new one time keys. If the total number +// of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old +// keys are discarded. +func (a *Account) GenOneTimeKeys(reader io.Reader, num uint) error { + random := make([]byte, a.genOneTimeKeysRandomLen(num)+1) + if reader == nil { + reader = rand.Reader + } + _, err := reader.Read(random) + if err != nil { + return olm.NotEnoughGoRandom + } + r := C.olm_account_generate_one_time_keys( + (*C.OlmAccount)(a.int), + C.size_t(num), + unsafe.Pointer(&random[0]), + C.size_t(len(random))) + if r == errorVal() { + return a.lastError() + } + return nil +} + +// NewOutboundSession creates a new out-bound session for sending messages to a +// given curve25519 identityKey and oneTimeKey. Returns error on failure. If the +// keys couldn't be decoded as base64 then the error will be "INVALID_BASE64" +func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) { + if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { + return nil, olm.EmptyInput + } + s := NewBlankSession() + random := make([]byte, s.createOutboundRandomLen()+1) + _, err := rand.Read(random) + if err != nil { + panic(olm.NotEnoughGoRandom) + } + r := C.olm_create_outbound_session( + (*C.OlmSession)(s.int), + (*C.OlmAccount)(a.int), + unsafe.Pointer(&([]byte(theirIdentityKey)[0])), + C.size_t(len(theirIdentityKey)), + unsafe.Pointer(&([]byte(theirOneTimeKey)[0])), + C.size_t(len(theirOneTimeKey)), + unsafe.Pointer(&random[0]), + C.size_t(len(random))) + if r == errorVal() { + return nil, s.lastError() + } + return s, nil +} + +// NewInboundSession creates a new in-bound session for sending/receiving +// messages from an incoming PRE_KEY message. Returns error on failure. If +// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If +// the message was for an unsupported protocol version then the error will be +// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the +// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one +// time key then the error will be "BAD_MESSAGE_KEY_ID". +func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) { + if len(oneTimeKeyMsg) == 0 { + return nil, olm.EmptyInput + } + s := NewBlankSession() + r := C.olm_create_inbound_session( + (*C.OlmSession)(s.int), + (*C.OlmAccount)(a.int), + unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])), + C.size_t(len(oneTimeKeyMsg))) + if r == errorVal() { + return nil, s.lastError() + } + return s, nil +} + +// NewInboundSessionFrom creates a new in-bound session for sending/receiving +// messages from an incoming PRE_KEY message. Returns error on failure. If +// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If +// the message was for an unsupported protocol version then the error will be +// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the +// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one +// time key then the error will be "BAD_MESSAGE_KEY_ID". +func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) { + if theirIdentityKey == nil || len(oneTimeKeyMsg) == 0 { + return nil, olm.EmptyInput + } + s := NewBlankSession() + r := C.olm_create_inbound_session_from( + (*C.OlmSession)(s.int), + (*C.OlmAccount)(a.int), + unsafe.Pointer(&([]byte(*theirIdentityKey)[0])), + C.size_t(len(*theirIdentityKey)), + unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])), + C.size_t(len(oneTimeKeyMsg))) + if r == errorVal() { + return nil, s.lastError() + } + return s, nil +} + +// RemoveOneTimeKeys removes the one time keys that the session used from the +// Account. Returns error on failure. If the Account doesn't have any +// matching one time keys then the error will be "BAD_MESSAGE_KEY_ID". +func (a *Account) RemoveOneTimeKeys(s olm.Session) error { + r := C.olm_remove_one_time_keys( + (*C.OlmAccount)(a.int), + (*C.OlmSession)(s.(*Session).int)) + if r == errorVal() { + return a.lastError() + } + return nil +} diff --git a/crypto/libolm/error.go b/crypto/libolm/error.go new file mode 100644 index 00000000..9ca415ee --- /dev/null +++ b/crypto/libolm/error.go @@ -0,0 +1,37 @@ +package libolm + +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +import "C" + +import ( + "fmt" + + "maunium.net/go/mautrix/crypto/olm" +) + +var errorMap = map[string]error{ + "NOT_ENOUGH_RANDOM": olm.NotEnoughRandom, + "OUTPUT_BUFFER_TOO_SMALL": olm.OutputBufferTooSmall, + "BAD_MESSAGE_VERSION": olm.BadMessageVersion, + "BAD_MESSAGE_FORMAT": olm.BadMessageFormat, + "BAD_MESSAGE_MAC": olm.BadMessageMAC, + "BAD_MESSAGE_KEY_ID": olm.BadMessageKeyID, + "INVALID_BASE64": olm.InvalidBase64, + "BAD_ACCOUNT_KEY": olm.BadAccountKey, + "UNKNOWN_PICKLE_VERSION": olm.UnknownPickleVersion, + "CORRUPTED_PICKLE": olm.CorruptedPickle, + "BAD_SESSION_KEY": olm.BadSessionKey, + "UNKNOWN_MESSAGE_INDEX": olm.UnknownMessageIndex, + "BAD_LEGACY_ACCOUNT_PICKLE": olm.BadLegacyAccountPickle, + "BAD_SIGNATURE": olm.BadSignature, + "INPUT_BUFFER_TOO_SMALL": olm.InputBufferTooSmall, +} + +func convertError(errCode string) error { + err, ok := errorMap[errCode] + if ok { + return err + } + return fmt.Errorf("unknown error: %s", errCode) +} diff --git a/crypto/libolm/inboundgroupsession.go b/crypto/libolm/inboundgroupsession.go new file mode 100644 index 00000000..1e25748d --- /dev/null +++ b/crypto/libolm/inboundgroupsession.go @@ -0,0 +1,328 @@ +package libolm + +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +import "C" + +import ( + "bytes" + "encoding/base64" + "unsafe" + + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/id" +) + +// InboundGroupSession stores an inbound encrypted messaging session for a +// group. +type InboundGroupSession struct { + int *C.OlmInboundGroupSession + mem []byte +} + +func init() { + olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { + return InboundGroupSessionFromPickled(pickled, key) + } + olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) { + return NewInboundGroupSession(sessionKey) + } + olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) { + return InboundGroupSessionImport(sessionKey) + } + olm.InitBlankInboundGroupSession = func() olm.InboundGroupSession { + return NewBlankInboundGroupSession() + } +} + +// Ensure that [InboundGroupSession] implements [olm.InboundGroupSession]. +var _ olm.InboundGroupSession = (*InboundGroupSession)(nil) + +// InboundGroupSessionFromPickled loads an InboundGroupSession from a pickled +// base64 string. Decrypts the InboundGroupSession using the supplied key. +// Returns error on failure. If the key doesn't match the one used to encrypt +// the InboundGroupSession then the error will be "BAD_SESSION_KEY". If the +// base64 couldn't be decoded then the error will be "INVALID_BASE64". +func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { + if len(pickled) == 0 { + return nil, olm.EmptyInput + } + lenKey := len(key) + if lenKey == 0 { + key = []byte(" ") + } + s := NewBlankInboundGroupSession() + return s, s.Unpickle(pickled, key) +} + +// NewInboundGroupSession creates a new inbound group session from a key +// exported from OutboundGroupSession.Key(). Returns error on failure. +// If the sessionKey is not valid base64 the error will be +// "OLM_INVALID_BASE64". If the session_key is invalid the error will be +// "OLM_BAD_SESSION_KEY". +func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { + if len(sessionKey) == 0 { + return nil, olm.EmptyInput + } + s := NewBlankInboundGroupSession() + r := C.olm_init_inbound_group_session( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(&sessionKey[0]), + C.size_t(len(sessionKey))) + if r == errorVal() { + return nil, s.lastError() + } + return s, nil +} + +// InboundGroupSessionImport imports an inbound group session from a previous +// export. Returns error on failure. If the sessionKey is not valid base64 +// the error will be "OLM_INVALID_BASE64". If the session_key is invalid the +// error will be "OLM_BAD_SESSION_KEY". +func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { + if len(sessionKey) == 0 { + return nil, olm.EmptyInput + } + s := NewBlankInboundGroupSession() + r := C.olm_import_inbound_group_session( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(&sessionKey[0]), + C.size_t(len(sessionKey))) + if r == errorVal() { + return nil, s.lastError() + } + return s, nil +} + +// inboundGroupSessionSize is the size of an inbound group session object in +// bytes. +func inboundGroupSessionSize() uint { + return uint(C.olm_inbound_group_session_size()) +} + +// newInboundGroupSession initialises an empty InboundGroupSession. +func NewBlankInboundGroupSession() *InboundGroupSession { + memory := make([]byte, inboundGroupSessionSize()) + return &InboundGroupSession{ + int: C.olm_inbound_group_session(unsafe.Pointer(&memory[0])), + mem: memory, + } +} + +// lastError returns an error describing the most recent error to happen to an +// inbound group session. +func (s *InboundGroupSession) lastError() error { + return convertError(C.GoString(C.olm_inbound_group_session_last_error((*C.OlmInboundGroupSession)(s.int)))) +} + +// Clear clears the memory used to back this InboundGroupSession. +func (s *InboundGroupSession) Clear() error { + r := C.olm_clear_inbound_group_session((*C.OlmInboundGroupSession)(s.int)) + if r == errorVal() { + return s.lastError() + } + return nil +} + +// pickleLen returns the number of bytes needed to store an inbound group +// session. +func (s *InboundGroupSession) pickleLen() uint { + return uint(C.olm_pickle_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int))) +} + +// Pickle returns an InboundGroupSession as a base64 string. Encrypts the +// InboundGroupSession using the supplied key. +func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, olm.NoKeyProvided + } + pickled := make([]byte, s.pickleLen()) + r := C.olm_pickle_inbound_group_session( + (*C.OlmInboundGroupSession)(s.int), + unsafe.Pointer(&key[0]), + C.size_t(len(key)), + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) + if r == errorVal() { + return nil, s.lastError() + } + return pickled[:r], nil +} + +func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return olm.NoKeyProvided + } else if len(pickled) == 0 { + return olm.EmptyInput + } + r := C.olm_unpickle_inbound_group_session( + (*C.OlmInboundGroupSession)(s.int), + unsafe.Pointer(&key[0]), + C.size_t(len(key)), + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) + if r == errorVal() { + return s.lastError() + } + return nil +} + +// Deprecated +func (s *InboundGroupSession) GobEncode() ([]byte, error) { + pickled, err := s.Pickle(pickleKey) + if err != nil { + return nil, err + } + length := base64.RawStdEncoding.DecodedLen(len(pickled)) + rawPickled := make([]byte, length) + _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) + return rawPickled, err +} + +// Deprecated +func (s *InboundGroupSession) GobDecode(rawPickled []byte) error { + if s == nil || s.int == nil { + *s = *NewBlankInboundGroupSession() + } + length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) + pickled := make([]byte, length) + base64.RawStdEncoding.Encode(pickled, rawPickled) + return s.Unpickle(pickled, pickleKey) +} + +// Deprecated +func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { + pickled, err := s.Pickle(pickleKey) + if err != nil { + return nil, err + } + quotes := make([]byte, len(pickled)+2) + quotes[0] = '"' + quotes[len(quotes)-1] = '"' + copy(quotes[1:len(quotes)-1], pickled) + return quotes, nil +} + +// Deprecated +func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { + if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { + return olm.InputNotJSONString + } + if s == nil || s.int == nil { + *s = *NewBlankInboundGroupSession() + } + return s.Unpickle(data[1:len(data)-1], pickleKey) +} + +// decryptMaxPlaintextLen returns the maximum number of bytes of plain-text a +// given message could decode to. The actual size could be different due to +// padding. Returns error on failure. If the message base64 couldn't be +// decoded then the error will be "INVALID_BASE64". If the message is for an +// unsupported version of the protocol then the error will be +// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error +// will be "BAD_MESSAGE_FORMAT". +func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) { + if len(message) == 0 { + return 0, olm.EmptyInput + } + // olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it + message = bytes.Clone(message) + r := C.olm_group_decrypt_max_plaintext_length( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(&message[0]), + C.size_t(len(message))) + if r == errorVal() { + return 0, s.lastError() + } + return uint(r), nil +} + +// Decrypt decrypts a message using the InboundGroupSession. Returns the the +// plain-text and message index on success. Returns error on failure. If the +// base64 couldn't be decoded then the error will be "INVALID_BASE64". If the +// message is for an unsupported version of the protocol then the error will be +// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error +// will be BAD_MESSAGE_FORMAT". If the MAC on the message was invalid then the +// error will be "BAD_MESSAGE_MAC". If we do not have a session key +// corresponding to the message's index (ie, it was sent before the session key +// was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX". +func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { + if len(message) == 0 { + return nil, 0, olm.EmptyInput + } + decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message) + if err != nil { + return nil, 0, err + } + messageCopy := make([]byte, len(message)) + copy(messageCopy, message) + plaintext := make([]byte, decryptMaxPlaintextLen) + var messageIndex uint32 + r := C.olm_group_decrypt( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(&messageCopy[0]), + C.size_t(len(messageCopy)), + (*C.uint8_t)(&plaintext[0]), + C.size_t(len(plaintext)), + (*C.uint32_t)(&messageIndex)) + if r == errorVal() { + return nil, 0, s.lastError() + } + return plaintext[:r], uint(messageIndex), nil +} + +// sessionIdLen returns the number of bytes needed to store a session ID. +func (s *InboundGroupSession) sessionIdLen() uint { + return uint(C.olm_inbound_group_session_id_length((*C.OlmInboundGroupSession)(s.int))) +} + +// ID returns a base64-encoded identifier for this session. +func (s *InboundGroupSession) ID() id.SessionID { + sessionID := make([]byte, s.sessionIdLen()) + r := C.olm_inbound_group_session_id( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(&sessionID[0]), + C.size_t(len(sessionID))) + if r == errorVal() { + panic(s.lastError()) + } + return id.SessionID(sessionID[:r]) +} + +// FirstKnownIndex returns the first message index we know how to decrypt. +func (s *InboundGroupSession) FirstKnownIndex() uint32 { + return uint32(C.olm_inbound_group_session_first_known_index((*C.OlmInboundGroupSession)(s.int))) +} + +// IsVerified check if the session has been verified as a valid session. (A +// session is verified either because the original session share was signed, or +// because we have subsequently successfully decrypted a message.) +func (s *InboundGroupSession) IsVerified() bool { + return uint(C.olm_inbound_group_session_is_verified((*C.OlmInboundGroupSession)(s.int))) == 1 +} + +// exportLen returns the number of bytes needed to export an inbound group +// session. +func (s *InboundGroupSession) exportLen() uint { + return uint(C.olm_export_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int))) +} + +// Export returns the base64-encoded ratchet key for this session, at the given +// index, in a format which can be used by +// InboundGroupSession.InboundGroupSessionImport(). Encrypts the +// InboundGroupSession using the supplied key. Returns error on failure. +// if we do not have a session key corresponding to the given index (ie, it was +// sent before the session key was shared with us) the error will be +// "OLM_UNKNOWN_MESSAGE_INDEX". +func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) { + key := make([]byte, s.exportLen()) + r := C.olm_export_inbound_group_session( + (*C.OlmInboundGroupSession)(s.int), + (*C.uint8_t)(&key[0]), + C.size_t(len(key)), + C.uint32_t(messageIndex)) + if r == errorVal() { + return nil, s.lastError() + } + return key[:r], nil +} diff --git a/crypto/libolm/libolm.go b/crypto/libolm/libolm.go new file mode 100644 index 00000000..18815767 --- /dev/null +++ b/crypto/libolm/libolm.go @@ -0,0 +1,10 @@ +package libolm + +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +import "C" + +// errorVal returns the value that olm functions return if there was an error. +func errorVal() C.size_t { + return C.olm_error() +} diff --git a/crypto/libolm/outboundgroupsession.go b/crypto/libolm/outboundgroupsession.go new file mode 100644 index 00000000..cb2ce38b --- /dev/null +++ b/crypto/libolm/outboundgroupsession.go @@ -0,0 +1,249 @@ +package libolm + +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +import "C" + +import ( + "crypto/rand" + "encoding/base64" + "unsafe" + + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/id" +) + +// OutboundGroupSession stores an outbound encrypted messaging session +// for a group. +type OutboundGroupSession struct { + int *C.OlmOutboundGroupSession + mem []byte +} + +func init() { + olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { + if len(pickled) == 0 { + return nil, olm.EmptyInput + } + s := NewBlankOutboundGroupSession() + return s, s.Unpickle(pickled, key) + } + olm.InitNewOutboundGroupSession = func() olm.OutboundGroupSession { + return NewOutboundGroupSession() + } + olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { + return NewBlankOutboundGroupSession() + } +} + +// Ensure that [OutboundGroupSession] implements [olm.OutboundGroupSession]. +var _ olm.OutboundGroupSession = (*OutboundGroupSession)(nil) + +func NewOutboundGroupSession() *OutboundGroupSession { + s := NewBlankOutboundGroupSession() + random := make([]byte, s.createRandomLen()+1) + _, err := rand.Read(random) + if err != nil { + panic(olm.NotEnoughGoRandom) + } + r := C.olm_init_outbound_group_session( + (*C.OlmOutboundGroupSession)(s.int), + (*C.uint8_t)(&random[0]), + C.size_t(len(random))) + if r == errorVal() { + panic(s.lastError()) + } + return s +} + +// outboundGroupSessionSize is the size of an outbound group session object in +// bytes. +func outboundGroupSessionSize() uint { + return uint(C.olm_outbound_group_session_size()) +} + +// NewBlankOutboundGroupSession initialises an empty [OutboundGroupSession]. +func NewBlankOutboundGroupSession() *OutboundGroupSession { + memory := make([]byte, outboundGroupSessionSize()) + return &OutboundGroupSession{ + int: C.olm_outbound_group_session(unsafe.Pointer(&memory[0])), + mem: memory, + } +} + +// lastError returns an error describing the most recent error to happen to an +// outbound group session. +func (s *OutboundGroupSession) lastError() error { + return convertError(C.GoString(C.olm_outbound_group_session_last_error((*C.OlmOutboundGroupSession)(s.int)))) +} + +// Clear clears the memory used to back this OutboundGroupSession. +func (s *OutboundGroupSession) Clear() error { + r := C.olm_clear_outbound_group_session((*C.OlmOutboundGroupSession)(s.int)) + if r == errorVal() { + return s.lastError() + } else { + return nil + } +} + +// pickleLen returns the number of bytes needed to store an outbound group +// session. +func (s *OutboundGroupSession) pickleLen() uint { + return uint(C.olm_pickle_outbound_group_session_length((*C.OlmOutboundGroupSession)(s.int))) +} + +// Pickle returns an OutboundGroupSession as a base64 string. Encrypts the +// OutboundGroupSession using the supplied key. +func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, olm.NoKeyProvided + } + pickled := make([]byte, s.pickleLen()) + r := C.olm_pickle_outbound_group_session( + (*C.OlmOutboundGroupSession)(s.int), + unsafe.Pointer(&key[0]), + C.size_t(len(key)), + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) + if r == errorVal() { + return nil, s.lastError() + } + return pickled[:r], nil +} + +func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return olm.NoKeyProvided + } + r := C.olm_unpickle_outbound_group_session( + (*C.OlmOutboundGroupSession)(s.int), + unsafe.Pointer(&key[0]), + C.size_t(len(key)), + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) + if r == errorVal() { + return s.lastError() + } + return nil +} + +// Deprecated +func (s *OutboundGroupSession) GobEncode() ([]byte, error) { + pickled, err := s.Pickle(pickleKey) + if err != nil { + return nil, err + } + length := base64.RawStdEncoding.DecodedLen(len(pickled)) + rawPickled := make([]byte, length) + _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) + return rawPickled, err +} + +// Deprecated +func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error { + if s == nil || s.int == nil { + *s = *NewBlankOutboundGroupSession() + } + length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) + pickled := make([]byte, length) + base64.RawStdEncoding.Encode(pickled, rawPickled) + return s.Unpickle(pickled, pickleKey) +} + +// Deprecated +func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { + pickled, err := s.Pickle(pickleKey) + if err != nil { + return nil, err + } + quotes := make([]byte, len(pickled)+2) + quotes[0] = '"' + quotes[len(quotes)-1] = '"' + copy(quotes[1:len(quotes)-1], pickled) + return quotes, nil +} + +// Deprecated +func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { + if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { + return olm.InputNotJSONString + } + if s == nil || s.int == nil { + *s = *NewBlankOutboundGroupSession() + } + return s.Unpickle(data[1:len(data)-1], pickleKey) +} + +// createRandomLen returns the number of random bytes needed to create an +// Account. +func (s *OutboundGroupSession) createRandomLen() uint { + return uint(C.olm_init_outbound_group_session_random_length((*C.OlmOutboundGroupSession)(s.int))) +} + +// encryptMsgLen returns the size of the next message in bytes for the given +// number of plain-text bytes. +func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint { + return uint(C.olm_group_encrypt_message_length((*C.OlmOutboundGroupSession)(s.int), C.size_t(plainTextLen))) +} + +// Encrypt encrypts a message using the Session. Returns the encrypted message +// as base64. +func (s *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { + if len(plaintext) == 0 { + return nil, olm.EmptyInput + } + message := make([]byte, s.encryptMsgLen(len(plaintext))) + r := C.olm_group_encrypt( + (*C.OlmOutboundGroupSession)(s.int), + (*C.uint8_t)(&plaintext[0]), + C.size_t(len(plaintext)), + (*C.uint8_t)(&message[0]), + C.size_t(len(message))) + if r == errorVal() { + return nil, s.lastError() + } + return message[:r], nil +} + +// sessionIdLen returns the number of bytes needed to store a session ID. +func (s *OutboundGroupSession) sessionIdLen() uint { + return uint(C.olm_outbound_group_session_id_length((*C.OlmOutboundGroupSession)(s.int))) +} + +// ID returns a base64-encoded identifier for this session. +func (s *OutboundGroupSession) ID() id.SessionID { + sessionID := make([]byte, s.sessionIdLen()) + r := C.olm_outbound_group_session_id( + (*C.OlmOutboundGroupSession)(s.int), + (*C.uint8_t)(&sessionID[0]), + C.size_t(len(sessionID))) + if r == errorVal() { + panic(s.lastError()) + } + return id.SessionID(sessionID[:r]) +} + +// MessageIndex returns the message index for this session. Each message is +// sent with an increasing index; this returns the index for the next message. +func (s *OutboundGroupSession) MessageIndex() uint { + return uint(C.olm_outbound_group_session_message_index((*C.OlmOutboundGroupSession)(s.int))) +} + +// sessionKeyLen returns the number of bytes needed to store a session key. +func (s *OutboundGroupSession) sessionKeyLen() uint { + return uint(C.olm_outbound_group_session_key_length((*C.OlmOutboundGroupSession)(s.int))) +} + +// Key returns the base64-encoded current ratchet key for this session. +func (s *OutboundGroupSession) Key() string { + sessionKey := make([]byte, s.sessionKeyLen()) + r := C.olm_outbound_group_session_key( + (*C.OlmOutboundGroupSession)(s.int), + (*C.uint8_t)(&sessionKey[0]), + C.size_t(len(sessionKey))) + if r == errorVal() { + panic(s.lastError()) + } + return string(sessionKey[:r]) +} diff --git a/crypto/olm/pk_libolm.go b/crypto/libolm/pk.go similarity index 71% rename from crypto/olm/pk_libolm.go rename to crypto/libolm/pk.go index 0854b4d1..db8d35c5 100644 --- a/crypto/olm/pk_libolm.go +++ b/crypto/libolm/pk.go @@ -4,9 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -//go:build !goolm - -package olm +package libolm // #cgo LDFLAGS: -lolm -lstdc++ // #include @@ -21,19 +19,30 @@ import ( "github.com/tidwall/sjson" "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) -// LibOlmPKSigning stores a key pair for signing messages. -type LibOlmPKSigning struct { +// PKSigning stores a key pair for signing messages. +type PKSigning struct { int *C.OlmPkSigning mem []byte publicKey id.Ed25519 seed []byte } -// Ensure that LibOlmPKSigning implements PKSigning. -var _ PKSigning = (*LibOlmPKSigning)(nil) +// Ensure that [PKSigning] implements [olm.PKSigning]. +var _ olm.PKSigning = (*PKSigning)(nil) + +func init() { + olm.InitNewPKSigning = func() (olm.PKSigning, error) { return NewPKSigning() } + olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) { + return NewPKSigningFromSeed(seed) + } + olm.InitNewPKDecryptionFromPrivateKey = func(privateKey []byte) (olm.PKDecryption, error) { + return NewPkDecryption(privateKey) + } +} func pkSigningSize() uint { return uint(C.olm_pk_signing_size()) @@ -51,16 +60,16 @@ func pkSigningSignatureLength() uint { return uint(C.olm_pk_signature_length()) } -func newBlankPKSigning() *LibOlmPKSigning { +func newBlankPKSigning() *PKSigning { memory := make([]byte, pkSigningSize()) - return &LibOlmPKSigning{ + return &PKSigning{ int: C.olm_pk_signing(unsafe.Pointer(&memory[0])), mem: memory, } } // NewPKSigningFromSeed creates a new [PKSigning] object using the given seed. -func NewPKSigningFromSeed(seed []byte) (PKSigning, error) { +func NewPKSigningFromSeed(seed []byte) (*PKSigning, error) { p := newBlankPKSigning() p.clear() pubKey := make([]byte, pkSigningPublicKeyLength()) @@ -74,34 +83,34 @@ func NewPKSigningFromSeed(seed []byte) (PKSigning, error) { return p, nil } -// NewPKSigning creates a new LibOlmPKSigning object, containing a key pair for +// NewPKSigning creates a new [PKSigning] object, containing a key pair for // signing messages. -func NewPKSigning() (PKSigning, error) { +func NewPKSigning() (*PKSigning, error) { // Generate the seed seed := make([]byte, pkSigningSeedLength()) _, err := rand.Read(seed) if err != nil { - panic(NotEnoughGoRandom) + panic(olm.NotEnoughGoRandom) } pk, err := NewPKSigningFromSeed(seed) return pk, err } -func (p *LibOlmPKSigning) PublicKey() id.Ed25519 { +func (p *PKSigning) PublicKey() id.Ed25519 { return p.publicKey } -func (p *LibOlmPKSigning) Seed() []byte { +func (p *PKSigning) Seed() []byte { return p.seed } -// clear clears the underlying memory of a LibOlmPKSigning object. -func (p *LibOlmPKSigning) clear() { +// clear clears the underlying memory of a [PKSigning] object. +func (p *PKSigning) clear() { C.olm_clear_pk_signing((*C.OlmPkSigning)(p.int)) } // Sign creates a signature for the given message using this key. -func (p *LibOlmPKSigning) Sign(message []byte) ([]byte, error) { +func (p *PKSigning) Sign(message []byte) ([]byte, error) { signature := make([]byte, pkSigningSignatureLength()) if C.olm_pk_sign((*C.OlmPkSigning)(p.int), (*C.uint8_t)(unsafe.Pointer(&message[0])), C.size_t(len(message)), (*C.uint8_t)(unsafe.Pointer(&signature[0])), C.size_t(len(signature))) == errorVal() { @@ -111,7 +120,7 @@ func (p *LibOlmPKSigning) Sign(message []byte) ([]byte, error) { } // SignJSON creates a signature for the given object after encoding it to canonical JSON. -func (p *LibOlmPKSigning) SignJSON(obj interface{}) (string, error) { +func (p *PKSigning) SignJSON(obj interface{}) (string, error) { objJSON, err := json.Marshal(obj) if err != nil { return "", err @@ -126,15 +135,15 @@ func (p *LibOlmPKSigning) SignJSON(obj interface{}) (string, error) { } // lastError returns the last error that happened in relation to this -// LibOlmPKSigning object. -func (p *LibOlmPKSigning) lastError() error { +// [PKSigning] object. +func (p *PKSigning) lastError() error { return convertError(C.GoString(C.olm_pk_signing_last_error((*C.OlmPkSigning)(p.int)))) } -type LibOlmPKDecryption struct { +type PKDecryption struct { int *C.OlmPkDecryption mem []byte - PublicKey []byte + publicKey []byte } func pkDecryptionSize() uint { @@ -145,9 +154,9 @@ func pkDecryptionPublicKeySize() uint { return uint(C.olm_pk_key_length()) } -func NewPkDecryption(privateKey []byte) (*LibOlmPKDecryption, error) { +func NewPkDecryption(privateKey []byte) (*PKDecryption, error) { memory := make([]byte, pkDecryptionSize()) - p := &LibOlmPKDecryption{ + p := &PKDecryption{ int: C.olm_pk_decryption(unsafe.Pointer(&memory[0])), mem: memory, } @@ -159,12 +168,16 @@ func NewPkDecryption(privateKey []byte) (*LibOlmPKDecryption, error) { unsafe.Pointer(&privateKey[0]), C.size_t(len(privateKey))) == errorVal() { return nil, p.lastError() } - p.PublicKey = pubKey + p.publicKey = pubKey return p, nil } -func (p *LibOlmPKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) { +func (p *PKDecryption) PublicKey() id.Curve25519 { + return id.Curve25519(p.publicKey) +} + +func (p *PKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) { maxPlaintextLength := uint(C.olm_pk_max_plaintext_length((*C.OlmPkDecryption)(p.int), C.size_t(len(ciphertext)))) plaintext := make([]byte, maxPlaintextLength) @@ -181,12 +194,12 @@ func (p *LibOlmPKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext } // Clear clears the underlying memory of a PkDecryption object. -func (p *LibOlmPKDecryption) clear() { +func (p *PKDecryption) clear() { C.olm_clear_pk_decryption((*C.OlmPkDecryption)(p.int)) } // lastError returns the last error that happened in relation to this -// LibOlmPKDecryption object. -func (p *LibOlmPKDecryption) lastError() error { +// [PKDecryption] object. +func (p *PKDecryption) lastError() error { return convertError(C.GoString(C.olm_pk_decryption_last_error((*C.OlmPkDecryption)(p.int)))) } diff --git a/crypto/libolm/register.go b/crypto/libolm/register.go new file mode 100644 index 00000000..a423a7d0 --- /dev/null +++ b/crypto/libolm/register.go @@ -0,0 +1,21 @@ +package libolm + +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +import "C" +import "maunium.net/go/mautrix/crypto/olm" + +var pickleKey = []byte("maunium.net/go/mautrix/crypto/olm") + +func init() { + olm.GetVersion = func() (major, minor, patch uint8) { + C.olm_get_library_version( + (*C.uint8_t)(&major), + (*C.uint8_t)(&minor), + (*C.uint8_t)(&patch)) + return 3, 2, 15 + } + olm.SetPickleKeyImpl = func(key []byte) { + pickleKey = key + } +} diff --git a/crypto/libolm/session.go b/crypto/libolm/session.go new file mode 100644 index 00000000..4cc22809 --- /dev/null +++ b/crypto/libolm/session.go @@ -0,0 +1,388 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package libolm + +// #cgo LDFLAGS: -lolm -lstdc++ +// #include +// #include +// #include +// void olm_session_describe(OlmSession * session, char *buf, size_t buflen) __attribute__((weak)); +// void meowlm_session_describe(OlmSession * session, char *buf, size_t buflen) { +// if (olm_session_describe) { +// olm_session_describe(session, buf, buflen); +// } else { +// sprintf(buf, "olm_session_describe not supported"); +// } +// } +import "C" + +import ( + "crypto/rand" + "encoding/base64" + "unsafe" + + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/id" +) + +// Session stores an end to end encrypted messaging session. +type Session struct { + int *C.OlmSession + mem []byte +} + +// Ensure that [Session] implements [olm.Session]. +var _ olm.Session = (*Session)(nil) + +func init() { + olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) { + return SessionFromPickled(pickled, key) + } + olm.InitNewBlankSession = func() olm.Session { + return NewBlankSession() + } +} + +// sessionSize is the size of a session object in bytes. +func sessionSize() uint { + return uint(C.olm_session_size()) +} + +// SessionFromPickled loads a Session from a pickled base64 string. Decrypts +// the Session using the supplied key. Returns error on failure. If the key +// doesn't match the one used to encrypt the Session then the error will be +// "BAD_SESSION_KEY". If the base64 couldn't be decoded then the error will be +// "INVALID_BASE64". +func SessionFromPickled(pickled, key []byte) (*Session, error) { + if len(pickled) == 0 { + return nil, olm.EmptyInput + } + s := NewBlankSession() + return s, s.Unpickle(pickled, key) +} + +func NewBlankSession() *Session { + memory := make([]byte, sessionSize()) + return &Session{ + int: C.olm_session(unsafe.Pointer(&memory[0])), + mem: memory, + } +} + +// lastError returns an error describing the most recent error to happen to a +// session. +func (s *Session) lastError() error { + return convertError(C.GoString(C.olm_session_last_error((*C.OlmSession)(s.int)))) +} + +// Clear clears the memory used to back this Session. +func (s *Session) Clear() error { + r := C.olm_clear_session((*C.OlmSession)(s.int)) + if r == errorVal() { + return s.lastError() + } + return nil +} + +// pickleLen returns the number of bytes needed to store a session. +func (s *Session) pickleLen() uint { + return uint(C.olm_pickle_session_length((*C.OlmSession)(s.int))) +} + +// createOutboundRandomLen returns the number of random bytes needed to create +// an outbound session. +func (s *Session) createOutboundRandomLen() uint { + return uint(C.olm_create_outbound_session_random_length((*C.OlmSession)(s.int))) +} + +// idLen returns the length of the buffer needed to return the id for this +// session. +func (s *Session) idLen() uint { + return uint(C.olm_session_id_length((*C.OlmSession)(s.int))) +} + +// encryptRandomLen returns the number of random bytes needed to encrypt the +// next message. +func (s *Session) encryptRandomLen() uint { + return uint(C.olm_encrypt_random_length((*C.OlmSession)(s.int))) +} + +// encryptMsgLen returns the size of the next message in bytes for the given +// number of plain-text bytes. +func (s *Session) encryptMsgLen(plainTextLen int) uint { + return uint(C.olm_encrypt_message_length((*C.OlmSession)(s.int), C.size_t(plainTextLen))) +} + +// decryptMaxPlaintextLen returns the maximum number of bytes of plain-text a +// given message could decode to. The actual size could be different due to +// padding. Returns error on failure. If the message base64 couldn't be +// decoded then the error will be "INVALID_BASE64". If the message is for an +// unsupported version of the protocol then the error will be +// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error +// will be "BAD_MESSAGE_FORMAT". +func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) { + if len(message) == 0 { + return 0, olm.EmptyInput + } + r := C.olm_decrypt_max_plaintext_length( + (*C.OlmSession)(s.int), + C.size_t(msgType), + unsafe.Pointer(C.CString(message)), + C.size_t(len(message))) + if r == errorVal() { + return 0, s.lastError() + } + return uint(r), nil +} + +// Pickle returns a Session as a base64 string. Encrypts the Session using the +// supplied key. +func (s *Session) Pickle(key []byte) ([]byte, error) { + if len(key) == 0 { + return nil, olm.NoKeyProvided + } + pickled := make([]byte, s.pickleLen()) + r := C.olm_pickle_session( + (*C.OlmSession)(s.int), + unsafe.Pointer(&key[0]), + C.size_t(len(key)), + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) + if r == errorVal() { + panic(s.lastError()) + } + return pickled[:r], nil +} + +// Unpickle unpickles the base64-encoded Olm session decrypting it with the +// provided key. This function mutates the input pickled data slice. +func (s *Session) Unpickle(pickled, key []byte) error { + if len(key) == 0 { + return olm.NoKeyProvided + } + r := C.olm_unpickle_session( + (*C.OlmSession)(s.int), + unsafe.Pointer(&key[0]), + C.size_t(len(key)), + unsafe.Pointer(&pickled[0]), + C.size_t(len(pickled))) + if r == errorVal() { + return s.lastError() + } + return nil +} + +// Deprecated +func (s *Session) GobEncode() ([]byte, error) { + pickled, err := s.Pickle(pickleKey) + if err != nil { + return nil, err + } + length := base64.RawStdEncoding.DecodedLen(len(pickled)) + rawPickled := make([]byte, length) + _, err = base64.RawStdEncoding.Decode(rawPickled, pickled) + return rawPickled, err +} + +// Deprecated +func (s *Session) GobDecode(rawPickled []byte) error { + if s == nil || s.int == nil { + *s = *NewBlankSession() + } + length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) + pickled := make([]byte, length) + base64.RawStdEncoding.Encode(pickled, rawPickled) + return s.Unpickle(pickled, pickleKey) +} + +// Deprecated +func (s *Session) MarshalJSON() ([]byte, error) { + pickled, err := s.Pickle(pickleKey) + if err != nil { + return nil, err + } + quotes := make([]byte, len(pickled)+2) + quotes[0] = '"' + quotes[len(quotes)-1] = '"' + copy(quotes[1:len(quotes)-1], pickled) + return quotes, nil +} + +// Deprecated +func (s *Session) UnmarshalJSON(data []byte) error { + if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { + return olm.InputNotJSONString + } + if s == nil || s.int == nil { + *s = *NewBlankSession() + } + return s.Unpickle(data[1:len(data)-1], pickleKey) +} + +// Id returns an identifier for this Session. Will be the same for both ends +// of the conversation. +func (s *Session) ID() id.SessionID { + sessionID := make([]byte, s.idLen()) + r := C.olm_session_id( + (*C.OlmSession)(s.int), + unsafe.Pointer(&sessionID[0]), + C.size_t(len(sessionID))) + if r == errorVal() { + panic(s.lastError()) + } + return id.SessionID(sessionID) +} + +// HasReceivedMessage returns true if this session has received any message. +func (s *Session) HasReceivedMessage() bool { + switch C.olm_session_has_received_message((*C.OlmSession)(s.int)) { + case 0: + return false + default: + return true + } +} + +// MatchesInboundSession checks if the PRE_KEY message is for this in-bound +// Session. This can happen if multiple messages are sent to this Account +// before this Account sends a message in reply. Returns true if the session +// matches. Returns false if the session does not match. Returns error on +// failure. If the base64 couldn't be decoded then the error will be +// "INVALID_BASE64". If the message was for an unsupported protocol version +// then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be +// decoded then then the error will be "BAD_MESSAGE_FORMAT". +func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { + if len(oneTimeKeyMsg) == 0 { + return false, olm.EmptyInput + } + r := C.olm_matches_inbound_session( + (*C.OlmSession)(s.int), + unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]), + C.size_t(len(oneTimeKeyMsg))) + if r == 1 { + return true, nil + } else if r == 0 { + return false, nil + } else { // if r == errorVal() + return false, s.lastError() + } +} + +// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound +// Session. This can happen if multiple messages are sent to this Account +// before this Account sends a message in reply. Returns true if the session +// matches. Returns false if the session does not match. Returns error on +// failure. If the base64 couldn't be decoded then the error will be +// "INVALID_BASE64". If the message was for an unsupported protocol version +// then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be +// decoded then then the error will be "BAD_MESSAGE_FORMAT". +func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { + if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { + return false, olm.EmptyInput + } + r := C.olm_matches_inbound_session_from( + (*C.OlmSession)(s.int), + unsafe.Pointer(&([]byte(theirIdentityKey))[0]), + C.size_t(len(theirIdentityKey)), + unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]), + C.size_t(len(oneTimeKeyMsg))) + if r == 1 { + return true, nil + } else if r == 0 { + return false, nil + } else { // if r == errorVal() + return false, s.lastError() + } +} + +// EncryptMsgType returns the type of the next message that Encrypt will +// return. Returns MsgTypePreKey if the message will be a PRE_KEY message. +// Returns MsgTypeMsg if the message will be a normal message. Returns error +// on failure. +func (s *Session) EncryptMsgType() id.OlmMsgType { + switch C.olm_encrypt_message_type((*C.OlmSession)(s.int)) { + case C.size_t(id.OlmMsgTypePreKey): + return id.OlmMsgTypePreKey + case C.size_t(id.OlmMsgTypeMsg): + return id.OlmMsgTypeMsg + default: + panic("olm_encrypt_message_type returned invalid result") + } +} + +// Encrypt encrypts a message using the Session. Returns the encrypted message +// as base64. +func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { + if len(plaintext) == 0 { + return 0, nil, olm.EmptyInput + } + // Make the slice be at least length 1 + random := make([]byte, s.encryptRandomLen()+1) + _, err := rand.Read(random) + if err != nil { + // TODO can we just return err here? + return 0, nil, olm.NotEnoughGoRandom + } + messageType := s.EncryptMsgType() + message := make([]byte, s.encryptMsgLen(len(plaintext))) + r := C.olm_encrypt( + (*C.OlmSession)(s.int), + unsafe.Pointer(&plaintext[0]), + C.size_t(len(plaintext)), + unsafe.Pointer(&random[0]), + C.size_t(len(random)), + unsafe.Pointer(&message[0]), + C.size_t(len(message))) + if r == errorVal() { + return 0, nil, s.lastError() + } + return messageType, message[:r], nil +} + +// Decrypt decrypts a message using the Session. Returns the the plain-text on +// success. Returns error on failure. If the base64 couldn't be decoded then +// the error will be "INVALID_BASE64". If the message is for an unsupported +// version of the protocol then the error will be "BAD_MESSAGE_VERSION". If +// the message couldn't be decoded then the error will be BAD_MESSAGE_FORMAT". +// If the MAC on the message was invalid then the error will be +// "BAD_MESSAGE_MAC". +func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { + if len(message) == 0 { + return nil, olm.EmptyInput + } + decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType) + if err != nil { + return nil, err + } + messageCopy := []byte(message) + plaintext := make([]byte, decryptMaxPlaintextLen) + r := C.olm_decrypt( + (*C.OlmSession)(s.int), + C.size_t(msgType), + unsafe.Pointer(&(messageCopy)[0]), + C.size_t(len(messageCopy)), + unsafe.Pointer(&plaintext[0]), + C.size_t(len(plaintext))) + if r == errorVal() { + return nil, s.lastError() + } + return plaintext[:r], nil +} + +// https://gitlab.matrix.org/matrix-org/olm/-/blob/3.2.8/include/olm/olm.h#L392-393 +const maxDescribeSize = 600 + +// Describe generates a string describing the internal state of an olm session for debugging and logging purposes. +func (s *Session) Describe() string { + desc := (*C.char)(C.malloc(C.size_t(maxDescribeSize))) + defer C.free(unsafe.Pointer(desc)) + C.meowlm_session_describe( + (*C.OlmSession)(s.int), + desc, + C.size_t(maxDescribeSize)) + return C.GoString(desc) +} diff --git a/crypto/olm/account.go b/crypto/olm/account.go index 37458d1b..3271b1c1 100644 --- a/crypto/olm/account.go +++ b/crypto/olm/account.go @@ -1,28 +1,106 @@ -//go:build !goolm +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. package olm -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -import "C" - import ( - "crypto/rand" - "encoding/base64" - "encoding/json" - "unsafe" + "io" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - - "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/id" ) -// Account stores a device account for end to end encrypted messaging. -type Account struct { - int *C.OlmAccount - mem []byte +type Account interface { + // Pickle returns an Account as a base64 string. Encrypts the Account using the + // supplied key. + Pickle(key []byte) ([]byte, error) + + // Unpickle loads an Account from a pickled base64 string. Decrypts the + // Account using the supplied key. Returns error on failure. + Unpickle(pickled, key []byte) error + + // IdentityKeysJSON returns the public parts of the identity keys for the Account. + IdentityKeysJSON() ([]byte, error) + + // IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity + // keys for the Account. + IdentityKeys() (id.Ed25519, id.Curve25519, error) + + // Sign returns the signature of a message using the ed25519 key for this + // Account. + Sign(message []byte) ([]byte, error) + + // OneTimeKeys returns the public parts of the unpublished one time keys for + // the Account. + // + // The returned data is a struct with the single value "Curve25519", which is + // itself an object mapping key id to base64-encoded Curve25519 key. For + // example: + // + // { + // Curve25519: { + // "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo", + // "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU" + // } + // } + OneTimeKeys() (map[string]id.Curve25519, error) + + // MarkKeysAsPublished marks the current set of one time keys as being + // published. + MarkKeysAsPublished() + + // MaxNumberOfOneTimeKeys returns the largest number of one time keys this + // Account can store. + MaxNumberOfOneTimeKeys() uint + + // GenOneTimeKeys generates a number of new one time keys. If the total + // number of keys stored by this Account exceeds MaxNumberOfOneTimeKeys + // then the old keys are discarded. Reads random data from the given + // reader, or if nil is passed, defaults to crypto/rand. + GenOneTimeKeys(reader io.Reader, num uint) error + + // NewOutboundSession creates a new out-bound session for sending messages to a + // given curve25519 identityKey and oneTimeKey. Returns error on failure. If the + // keys couldn't be decoded as base64 then the error will be "INVALID_BASE64" + NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (Session, error) + + // NewInboundSession creates a new in-bound session for sending/receiving + // messages from an incoming PRE_KEY message. Returns error on failure. If + // the base64 couldn't be decoded then the error will be "INVALID_BASE64". If + // the message was for an unsupported protocol version then the error will be + // "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the + // error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one + // time key then the error will be "BAD_MESSAGE_KEY_ID". + NewInboundSession(oneTimeKeyMsg string) (Session, error) + + // NewInboundSessionFrom creates a new in-bound session for sending/receiving + // messages from an incoming PRE_KEY message. Returns error on failure. If + // the base64 couldn't be decoded then the error will be "INVALID_BASE64". If + // the message was for an unsupported protocol version then the error will be + // "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the + // error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one + // time key then the error will be "BAD_MESSAGE_KEY_ID". + NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (Session, error) + + // RemoveOneTimeKeys removes the one time keys that the session used from the + // Account. Returns error on failure. If the Account doesn't have any + // matching one time keys then the error will be "BAD_MESSAGE_KEY_ID". + RemoveOneTimeKeys(s Session) error +} + +var InitBlankAccount func() Account +var InitNewAccount func(io.Reader) (Account, error) +var InitNewAccountFromPickled func(pickled, key []byte) (Account, error) + +// NewAccount creates a new Account. +func NewAccount(r io.Reader) (Account, error) { + return InitNewAccount(r) +} + +func NewBlankAccount() Account { + return InitBlankAccount() } // AccountFromPickled loads an Account from a pickled base64 string. Decrypts @@ -30,375 +108,6 @@ type Account struct { // doesn't match the one used to encrypt the Account then the error will be // "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then the error will be // "INVALID_BASE64". -func AccountFromPickled(pickled, key []byte) (*Account, error) { - if len(pickled) == 0 { - return nil, EmptyInput - } - a := NewBlankAccount() - return a, a.Unpickle(pickled, key) -} - -func NewBlankAccount() *Account { - memory := make([]byte, accountSize()) - return &Account{ - int: C.olm_account(unsafe.Pointer(&memory[0])), - mem: memory, - } -} - -// NewAccount creates a new Account. -func NewAccount() *Account { - a := NewBlankAccount() - random := make([]byte, a.createRandomLen()+1) - _, err := rand.Read(random) - if err != nil { - panic(NotEnoughGoRandom) - } - r := C.olm_create_account( - (*C.OlmAccount)(a.int), - unsafe.Pointer(&random[0]), - C.size_t(len(random))) - if r == errorVal() { - panic(a.lastError()) - } else { - return a - } -} - -// accountSize returns the size of an account object in bytes. -func accountSize() uint { - return uint(C.olm_account_size()) -} - -// lastError returns an error describing the most recent error to happen to an -// account. -func (a *Account) lastError() error { - return convertError(C.GoString(C.olm_account_last_error((*C.OlmAccount)(a.int)))) -} - -// Clear clears the memory used to back this Account. -func (a *Account) Clear() error { - r := C.olm_clear_account((*C.OlmAccount)(a.int)) - if r == errorVal() { - return a.lastError() - } else { - return nil - } -} - -// pickleLen returns the number of bytes needed to store an Account. -func (a *Account) pickleLen() uint { - return uint(C.olm_pickle_account_length((*C.OlmAccount)(a.int))) -} - -// createRandomLen returns the number of random bytes needed to create an -// Account. -func (a *Account) createRandomLen() uint { - return uint(C.olm_create_account_random_length((*C.OlmAccount)(a.int))) -} - -// identityKeysLen returns the size of the output buffer needed to hold the -// identity keys. -func (a *Account) identityKeysLen() uint { - return uint(C.olm_account_identity_keys_length((*C.OlmAccount)(a.int))) -} - -// signatureLen returns the length of an ed25519 signature encoded as base64. -func (a *Account) signatureLen() uint { - return uint(C.olm_account_signature_length((*C.OlmAccount)(a.int))) -} - -// oneTimeKeysLen returns the size of the output buffer needed to hold the one -// time keys. -func (a *Account) oneTimeKeysLen() uint { - return uint(C.olm_account_one_time_keys_length((*C.OlmAccount)(a.int))) -} - -// genOneTimeKeysRandomLen returns the number of random bytes needed to -// generate a given number of new one time keys. -func (a *Account) genOneTimeKeysRandomLen(num uint) uint { - return uint(C.olm_account_generate_one_time_keys_random_length( - (*C.OlmAccount)(a.int), - C.size_t(num))) -} - -// Pickle returns an Account as a base64 string. Encrypts the Account using the -// supplied key. -func (a *Account) Pickle(key []byte) []byte { - if len(key) == 0 { - panic(NoKeyProvided) - } - pickled := make([]byte, a.pickleLen()) - r := C.olm_pickle_account( - (*C.OlmAccount)(a.int), - unsafe.Pointer(&key[0]), - C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) - if r == errorVal() { - panic(a.lastError()) - } - return pickled[:r] -} - -func (a *Account) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return NoKeyProvided - } - r := C.olm_unpickle_account( - (*C.OlmAccount)(a.int), - unsafe.Pointer(&key[0]), - C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) - if r == errorVal() { - return a.lastError() - } - return nil -} - -// Deprecated -func (a *Account) GobEncode() ([]byte, error) { - pickled := a.Pickle(pickleKey) - length := base64.RawStdEncoding.DecodedLen(len(pickled)) - rawPickled := make([]byte, length) - _, err := base64.RawStdEncoding.Decode(rawPickled, pickled) - return rawPickled, err -} - -// Deprecated -func (a *Account) GobDecode(rawPickled []byte) error { - if a.int == nil { - *a = *NewBlankAccount() - } - length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) - pickled := make([]byte, length) - base64.RawStdEncoding.Encode(pickled, rawPickled) - return a.Unpickle(pickled, pickleKey) -} - -// Deprecated -func (a *Account) MarshalJSON() ([]byte, error) { - pickled := a.Pickle(pickleKey) - quotes := make([]byte, len(pickled)+2) - quotes[0] = '"' - quotes[len(quotes)-1] = '"' - copy(quotes[1:len(quotes)-1], pickled) - return quotes, nil -} - -// Deprecated -func (a *Account) UnmarshalJSON(data []byte) error { - if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return InputNotJSONString - } - if a.int == nil { - *a = *NewBlankAccount() - } - return a.Unpickle(data[1:len(data)-1], pickleKey) -} - -// IdentityKeysJSON returns the public parts of the identity keys for the Account. -func (a *Account) IdentityKeysJSON() []byte { - identityKeys := make([]byte, a.identityKeysLen()) - r := C.olm_account_identity_keys( - (*C.OlmAccount)(a.int), - unsafe.Pointer(&identityKeys[0]), - C.size_t(len(identityKeys))) - if r == errorVal() { - panic(a.lastError()) - } else { - return identityKeys - } -} - -// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity -// keys for the Account. -func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519) { - identityKeysJSON := a.IdentityKeysJSON() - results := gjson.GetManyBytes(identityKeysJSON, "ed25519", "curve25519") - return id.Ed25519(results[0].Str), id.Curve25519(results[1].Str) -} - -// Sign returns the signature of a message using the ed25519 key for this -// Account. -func (a *Account) Sign(message []byte) []byte { - if len(message) == 0 { - panic(EmptyInput) - } - signature := make([]byte, a.signatureLen()) - r := C.olm_account_sign( - (*C.OlmAccount)(a.int), - unsafe.Pointer(&message[0]), - C.size_t(len(message)), - unsafe.Pointer(&signature[0]), - C.size_t(len(signature))) - if r == errorVal() { - panic(a.lastError()) - } - return signature -} - -// SignJSON signs the given JSON object following the Matrix specification: -// https://matrix.org/docs/spec/appendices#signing-json -func (a *Account) SignJSON(obj interface{}) (string, error) { - objJSON, err := json.Marshal(obj) - if err != nil { - return "", err - } - objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") - objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") - return string(a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))), nil -} - -// OneTimeKeys returns the public parts of the unpublished one time keys for -// the Account. -// -// The returned data is a struct with the single value "Curve25519", which is -// itself an object mapping key id to base64-encoded Curve25519 key. For -// example: -// -// { -// Curve25519: { -// "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo", -// "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU" -// } -// } -func (a *Account) OneTimeKeys() map[string]id.Curve25519 { - oneTimeKeysJSON := make([]byte, a.oneTimeKeysLen()) - r := C.olm_account_one_time_keys( - (*C.OlmAccount)(a.int), - unsafe.Pointer(&oneTimeKeysJSON[0]), - C.size_t(len(oneTimeKeysJSON))) - if r == errorVal() { - panic(a.lastError()) - } - var oneTimeKeys struct { - Curve25519 map[string]id.Curve25519 `json:"curve25519"` - } - err := json.Unmarshal(oneTimeKeysJSON, &oneTimeKeys) - if err != nil { - panic(err) - } - return oneTimeKeys.Curve25519 -} - -// MarkKeysAsPublished marks the current set of one time keys as being -// published. -func (a *Account) MarkKeysAsPublished() { - C.olm_account_mark_keys_as_published((*C.OlmAccount)(a.int)) -} - -// MaxNumberOfOneTimeKeys returns the largest number of one time keys this -// Account can store. -func (a *Account) MaxNumberOfOneTimeKeys() uint { - return uint(C.olm_account_max_number_of_one_time_keys((*C.OlmAccount)(a.int))) -} - -// GenOneTimeKeys generates a number of new one time keys. If the total number -// of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old -// keys are discarded. -func (a *Account) GenOneTimeKeys(num uint) { - random := make([]byte, a.genOneTimeKeysRandomLen(num)+1) - _, err := rand.Read(random) - if err != nil { - panic(NotEnoughGoRandom) - } - r := C.olm_account_generate_one_time_keys( - (*C.OlmAccount)(a.int), - C.size_t(num), - unsafe.Pointer(&random[0]), - C.size_t(len(random))) - if r == errorVal() { - panic(a.lastError()) - } -} - -// NewOutboundSession creates a new out-bound session for sending messages to a -// given curve25519 identityKey and oneTimeKey. Returns error on failure. If the -// keys couldn't be decoded as base64 then the error will be "INVALID_BASE64" -func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*Session, error) { - if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { - return nil, EmptyInput - } - s := NewBlankSession() - random := make([]byte, s.createOutboundRandomLen()+1) - _, err := rand.Read(random) - if err != nil { - panic(NotEnoughGoRandom) - } - r := C.olm_create_outbound_session( - (*C.OlmSession)(s.int), - (*C.OlmAccount)(a.int), - unsafe.Pointer(&([]byte(theirIdentityKey)[0])), - C.size_t(len(theirIdentityKey)), - unsafe.Pointer(&([]byte(theirOneTimeKey)[0])), - C.size_t(len(theirOneTimeKey)), - unsafe.Pointer(&random[0]), - C.size_t(len(random))) - if r == errorVal() { - return nil, s.lastError() - } - return s, nil -} - -// NewInboundSession creates a new in-bound session for sending/receiving -// messages from an incoming PRE_KEY message. Returns error on failure. If -// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If -// the message was for an unsupported protocol version then the error will be -// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the -// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one -// time key then the error will be "BAD_MESSAGE_KEY_ID". -func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) { - if len(oneTimeKeyMsg) == 0 { - return nil, EmptyInput - } - s := NewBlankSession() - r := C.olm_create_inbound_session( - (*C.OlmSession)(s.int), - (*C.OlmAccount)(a.int), - unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])), - C.size_t(len(oneTimeKeyMsg))) - if r == errorVal() { - return nil, s.lastError() - } - return s, nil -} - -// NewInboundSessionFrom creates a new in-bound session for sending/receiving -// messages from an incoming PRE_KEY message. Returns error on failure. If -// the base64 couldn't be decoded then the error will be "INVALID_BASE64". If -// the message was for an unsupported protocol version then the error will be -// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the -// error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one -// time key then the error will be "BAD_MESSAGE_KEY_ID". -func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeKeyMsg string) (*Session, error) { - if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { - return nil, EmptyInput - } - s := NewBlankSession() - r := C.olm_create_inbound_session_from( - (*C.OlmSession)(s.int), - (*C.OlmAccount)(a.int), - unsafe.Pointer(&([]byte(theirIdentityKey)[0])), - C.size_t(len(theirIdentityKey)), - unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])), - C.size_t(len(oneTimeKeyMsg))) - if r == errorVal() { - return nil, s.lastError() - } - return s, nil -} - -// RemoveOneTimeKeys removes the one time keys that the session used from the -// Account. Returns error on failure. If the Account doesn't have any -// matching one time keys then the error will be "BAD_MESSAGE_KEY_ID". -func (a *Account) RemoveOneTimeKeys(s *Session) error { - r := C.olm_remove_one_time_keys( - (*C.OlmAccount)(a.int), - (*C.OlmSession)(s.int)) - if r == errorVal() { - return a.lastError() - } - return nil +func AccountFromPickled(pickled, key []byte) (Account, error) { + return InitNewAccountFromPickled(pickled, key) } diff --git a/crypto/olm/account_goolm.go b/crypto/olm/account_goolm.go deleted file mode 100644 index eeff54f9..00000000 --- a/crypto/olm/account_goolm.go +++ /dev/null @@ -1,154 +0,0 @@ -//go:build goolm - -package olm - -import ( - "encoding/json" - - "github.com/tidwall/sjson" - - "maunium.net/go/mautrix/crypto/canonicaljson" - "maunium.net/go/mautrix/crypto/goolm/account" - "maunium.net/go/mautrix/id" -) - -// Account stores a device account for end to end encrypted messaging. -type Account struct { - account.Account -} - -// NewAccount creates a new Account. -func NewAccount() *Account { - a, err := account.NewAccount(nil) - if err != nil { - panic(err) - } - ac := &Account{} - ac.Account = *a - return ac -} - -func NewBlankAccount() *Account { - return &Account{} -} - -// Clear clears the memory used to back this Account. -func (a *Account) Clear() error { - a.Account = account.Account{} - return nil -} - -// Pickle returns an Account as a base64 string. Encrypts the Account using the -// supplied key. -func (a *Account) Pickle(key []byte) []byte { - if len(key) == 0 { - panic(NoKeyProvided) - } - pickled, err := a.Account.Pickle(key) - if err != nil { - panic(err) - } - return pickled -} - -// IdentityKeysJSON returns the public parts of the identity keys for the Account. -func (a *Account) IdentityKeysJSON() []byte { - identityKeys, err := a.Account.IdentityKeysJSON() - if err != nil { - panic(err) - } - return identityKeys -} - -// Sign returns the signature of a message using the ed25519 key for this -// Account. -func (a *Account) Sign(message []byte) []byte { - if len(message) == 0 { - panic(EmptyInput) - } - signature, err := a.Account.Sign(message) - if err != nil { - panic(err) - } - return signature -} - -// SignJSON signs the given JSON object following the Matrix specification: -// https://matrix.org/docs/spec/appendices#signing-json -func (a *Account) SignJSON(obj interface{}) (string, error) { - objJSON, err := json.Marshal(obj) - if err != nil { - return "", err - } - objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") - objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") - return string(a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))), nil -} - -// MaxNumberOfOneTimeKeys returns the largest number of one time keys this -// Account can store. -func (a *Account) MaxNumberOfOneTimeKeys() uint { - return uint(account.MaxOneTimeKeys) -} - -// GenOneTimeKeys generates a number of new one time keys. If the total number -// of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old -// keys are discarded. -func (a *Account) GenOneTimeKeys(num uint) { - err := a.Account.GenOneTimeKeys(nil, num) - if err != nil { - panic(err) - } -} - -// NewOutboundSession creates a new out-bound session for sending messages to a -// given curve25519 identityKey and oneTimeKey. Returns error on failure. -func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*Session, error) { - if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { - return nil, EmptyInput - } - s := &Session{} - newSession, err := a.Account.NewOutboundSession(theirIdentityKey, theirOneTimeKey) - if err != nil { - return nil, err - } - s.OlmSession = *newSession - return s, nil -} - -// NewInboundSession creates a new in-bound session for sending/receiving -// messages from an incoming PRE_KEY message. Returns error on failure. -func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) { - if len(oneTimeKeyMsg) == 0 { - return nil, EmptyInput - } - s := &Session{} - newSession, err := a.Account.NewInboundSession(nil, []byte(oneTimeKeyMsg)) - if err != nil { - return nil, err - } - s.OlmSession = *newSession - return s, nil -} - -// NewInboundSessionFrom creates a new in-bound session for sending/receiving -// messages from an incoming PRE_KEY message. Returns error on failure. -func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeKeyMsg string) (*Session, error) { - if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { - return nil, EmptyInput - } - s := &Session{} - newSession, err := a.Account.NewInboundSession(&theirIdentityKey, []byte(oneTimeKeyMsg)) - if err != nil { - return nil, err - } - s.OlmSession = *newSession - return s, nil -} - -// RemoveOneTimeKeys removes the one time keys that the session used from the -// Account. Returns error on failure. -func (a *Account) RemoveOneTimeKeys(s *Session) error { - a.Account.RemoveOneTimeKeys(&s.OlmSession) - return nil -} diff --git a/crypto/olm/error_goolm.go b/crypto/olm/error_goolm.go deleted file mode 100644 index 0e54e566..00000000 --- a/crypto/olm/error_goolm.go +++ /dev/null @@ -1,23 +0,0 @@ -//go:build goolm - -package olm - -import ( - "errors" - - "maunium.net/go/mautrix/crypto/goolm" -) - -// Error codes from go-olm -var ( - EmptyInput = goolm.ErrEmptyInput - NoKeyProvided = goolm.ErrNoKeyProvided - NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") - SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") - InputNotJSONString = errors.New("input doesn't look like a JSON string") -) - -// Error codes from olm code -var ( - UnknownMessageIndex = goolm.ErrRatchetNotAvailable -) diff --git a/crypto/olm/error.go b/crypto/olm/errors.go similarity index 57% rename from crypto/olm/error.go rename to crypto/olm/errors.go index 63352e20..eb0cffff 100644 --- a/crypto/olm/error.go +++ b/crypto/olm/errors.go @@ -1,10 +1,28 @@ -//go:build !goolm - package olm -import ( - "errors" - "fmt" +import "errors" + +// Those are the most common used errors +var ( + ErrBadSignature = errors.New("bad signature") + ErrBadMAC = errors.New("bad mac") + ErrBadMessageFormat = errors.New("bad message format") + ErrBadVerification = errors.New("bad verification") + ErrWrongProtocolVersion = errors.New("wrong protocol version") + ErrEmptyInput = errors.New("empty input") + ErrNoKeyProvided = errors.New("no key") + ErrBadMessageKeyID = errors.New("bad message key id") + ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key") + ErrMsgIndexTooHigh = errors.New("message index too high") + ErrProtocolViolation = errors.New("not protocol message order") + ErrMessageKeyNotFound = errors.New("message key not found") + ErrChainTooHigh = errors.New("chain index too high") + ErrBadInput = errors.New("bad input") + ErrBadVersion = errors.New("wrong version") + ErrWrongPickleVersion = errors.New("wrong pickle version") + ErrValueTooShort = errors.New("value too short") + ErrInputToSmall = errors.New("input too small (truncated?)") + ErrOverflow = errors.New("overflow") ) // Error codes from go-olm @@ -34,29 +52,3 @@ var ( BadSignature = errors.New("received message had a bad signature") InputBufferTooSmall = errors.New("the input data was too small to be valid") ) - -var errorMap = map[string]error{ - "NOT_ENOUGH_RANDOM": NotEnoughRandom, - "OUTPUT_BUFFER_TOO_SMALL": OutputBufferTooSmall, - "BAD_MESSAGE_VERSION": BadMessageVersion, - "BAD_MESSAGE_FORMAT": BadMessageFormat, - "BAD_MESSAGE_MAC": BadMessageMAC, - "BAD_MESSAGE_KEY_ID": BadMessageKeyID, - "INVALID_BASE64": InvalidBase64, - "BAD_ACCOUNT_KEY": BadAccountKey, - "UNKNOWN_PICKLE_VERSION": UnknownPickleVersion, - "CORRUPTED_PICKLE": CorruptedPickle, - "BAD_SESSION_KEY": BadSessionKey, - "UNKNOWN_MESSAGE_INDEX": UnknownMessageIndex, - "BAD_LEGACY_ACCOUNT_PICKLE": BadLegacyAccountPickle, - "BAD_SIGNATURE": BadSignature, - "INPUT_BUFFER_TOO_SMALL": InputBufferTooSmall, -} - -func convertError(errCode string) error { - err, ok := errorMap[errCode] - if ok { - return err - } - return fmt.Errorf("unknown error: %s", errCode) -} diff --git a/crypto/olm/inboundgroupsession.go b/crypto/olm/inboundgroupsession.go index cac49d18..8839b48c 100644 --- a/crypto/olm/inboundgroupsession.go +++ b/crypto/olm/inboundgroupsession.go @@ -1,305 +1,80 @@ -//go:build !goolm +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. package olm -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -import "C" +import "maunium.net/go/mautrix/id" -import ( - "bytes" - "encoding/base64" - "unsafe" +type InboundGroupSession interface { + // Pickle returns an InboundGroupSession as a base64 string. Encrypts the + // InboundGroupSession using the supplied key. + Pickle(key []byte) ([]byte, error) - "maunium.net/go/mautrix/id" -) + // Unpickle loads an [InboundGroupSession] from a pickled base64 string. + // Decrypts the [InboundGroupSession] using the supplied key. + Unpickle(pickled, key []byte) error -// InboundGroupSession stores an inbound encrypted messaging session for a -// group. -type InboundGroupSession struct { - int *C.OlmInboundGroupSession - mem []byte + // Decrypt decrypts a message using the [InboundGroupSession]. Returns the + // plain-text and message index on success. Returns error on failure. If + // the base64 couldn't be decoded then the error will be "INVALID_BASE64". + // If the message is for an unsupported version of the protocol then the + // error will be "BAD_MESSAGE_VERSION". If the message couldn't be decoded + // then the error will be BAD_MESSAGE_FORMAT". If the MAC on the message + // was invalid then the error will be "BAD_MESSAGE_MAC". If we do not have + // a session key corresponding to the message's index (ie, it was sent + // before the session key was shared with us) the error will be + // "OLM_UNKNOWN_MESSAGE_INDEX". + Decrypt(message []byte) ([]byte, uint, error) + + // ID returns a base64-encoded identifier for this session. + ID() id.SessionID + + // FirstKnownIndex returns the first message index we know how to decrypt. + FirstKnownIndex() uint32 + + // IsVerified check if the session has been verified as a valid session. + // (A session is verified either because the original session share was + // signed, or because we have subsequently successfully decrypted a + // message.) + IsVerified() bool + + // Export returns the base64-encoded ratchet key for this session, at the + // given index, in a format which can be used by + // InboundGroupSession.InboundGroupSessionImport(). Encrypts the + // InboundGroupSession using the supplied key. Returns error on failure. + // if we do not have a session key corresponding to the given index (ie, it + // was sent before the session key was shared with us) the error will be + // "OLM_UNKNOWN_MESSAGE_INDEX". + Export(messageIndex uint32) ([]byte, error) } +var InitInboundGroupSessionFromPickled func(pickled, key []byte) (InboundGroupSession, error) +var InitNewInboundGroupSession func(sessionKey []byte) (InboundGroupSession, error) +var InitInboundGroupSessionImport func(sessionKey []byte) (InboundGroupSession, error) +var InitBlankInboundGroupSession func() InboundGroupSession + // InboundGroupSessionFromPickled loads an InboundGroupSession from a pickled -// base64 string. Decrypts the InboundGroupSession using the supplied key. -// Returns error on failure. If the key doesn't match the one used to encrypt -// the InboundGroupSession then the error will be "BAD_SESSION_KEY". If the -// base64 couldn't be decoded then the error will be "INVALID_BASE64". -func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { - if len(pickled) == 0 { - return nil, EmptyInput - } - lenKey := len(key) - if lenKey == 0 { - key = []byte(" ") - } - s := NewBlankInboundGroupSession() - return s, s.Unpickle(pickled, key) +// base64 string. Decrypts the InboundGroupSession using the supplied key. +// Returns error on failure. +func InboundGroupSessionFromPickled(pickled, key []byte) (InboundGroupSession, error) { + return InitInboundGroupSessionFromPickled(pickled, key) } // NewInboundGroupSession creates a new inbound group session from a key -// exported from OutboundGroupSession.Key(). Returns error on failure. -// If the sessionKey is not valid base64 the error will be -// "OLM_INVALID_BASE64". If the session_key is invalid the error will be -// "OLM_BAD_SESSION_KEY". -func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { - if len(sessionKey) == 0 { - return nil, EmptyInput - } - s := NewBlankInboundGroupSession() - r := C.olm_init_inbound_group_session( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&sessionKey[0]), - C.size_t(len(sessionKey))) - if r == errorVal() { - return nil, s.lastError() - } - return s, nil +// exported from OutboundGroupSession.Key(). Returns error on failure. +func NewInboundGroupSession(sessionKey []byte) (InboundGroupSession, error) { + return InitNewInboundGroupSession(sessionKey) } // InboundGroupSessionImport imports an inbound group session from a previous -// export. Returns error on failure. If the sessionKey is not valid base64 -// the error will be "OLM_INVALID_BASE64". If the session_key is invalid the -// error will be "OLM_BAD_SESSION_KEY". -func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { - if len(sessionKey) == 0 { - return nil, EmptyInput - } - s := NewBlankInboundGroupSession() - r := C.olm_import_inbound_group_session( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&sessionKey[0]), - C.size_t(len(sessionKey))) - if r == errorVal() { - return nil, s.lastError() - } - return s, nil +// export. Returns error on failure. +func InboundGroupSessionImport(sessionKey []byte) (InboundGroupSession, error) { + return InitInboundGroupSessionImport(sessionKey) } -// inboundGroupSessionSize is the size of an inbound group session object in -// bytes. -func inboundGroupSessionSize() uint { - return uint(C.olm_inbound_group_session_size()) -} - -// newInboundGroupSession initialises an empty InboundGroupSession. -func NewBlankInboundGroupSession() *InboundGroupSession { - memory := make([]byte, inboundGroupSessionSize()) - return &InboundGroupSession{ - int: C.olm_inbound_group_session(unsafe.Pointer(&memory[0])), - mem: memory, - } -} - -// lastError returns an error describing the most recent error to happen to an -// inbound group session. -func (s *InboundGroupSession) lastError() error { - return convertError(C.GoString(C.olm_inbound_group_session_last_error((*C.OlmInboundGroupSession)(s.int)))) -} - -// Clear clears the memory used to back this InboundGroupSession. -func (s *InboundGroupSession) Clear() error { - r := C.olm_clear_inbound_group_session((*C.OlmInboundGroupSession)(s.int)) - if r == errorVal() { - return s.lastError() - } - return nil -} - -// pickleLen returns the number of bytes needed to store an inbound group -// session. -func (s *InboundGroupSession) pickleLen() uint { - return uint(C.olm_pickle_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int))) -} - -// Pickle returns an InboundGroupSession as a base64 string. Encrypts the -// InboundGroupSession using the supplied key. -func (s *InboundGroupSession) Pickle(key []byte) []byte { - if len(key) == 0 { - panic(NoKeyProvided) - } - pickled := make([]byte, s.pickleLen()) - r := C.olm_pickle_inbound_group_session( - (*C.OlmInboundGroupSession)(s.int), - unsafe.Pointer(&key[0]), - C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) - if r == errorVal() { - panic(s.lastError()) - } - return pickled[:r] -} - -func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return NoKeyProvided - } else if len(pickled) == 0 { - return EmptyInput - } - r := C.olm_unpickle_inbound_group_session( - (*C.OlmInboundGroupSession)(s.int), - unsafe.Pointer(&key[0]), - C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) - if r == errorVal() { - return s.lastError() - } - return nil -} - -// Deprecated -func (s *InboundGroupSession) GobEncode() ([]byte, error) { - pickled := s.Pickle(pickleKey) - length := base64.RawStdEncoding.DecodedLen(len(pickled)) - rawPickled := make([]byte, length) - _, err := base64.RawStdEncoding.Decode(rawPickled, pickled) - return rawPickled, err -} - -// Deprecated -func (s *InboundGroupSession) GobDecode(rawPickled []byte) error { - if s == nil || s.int == nil { - *s = *NewBlankInboundGroupSession() - } - length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) - pickled := make([]byte, length) - base64.RawStdEncoding.Encode(pickled, rawPickled) - return s.Unpickle(pickled, pickleKey) -} - -// Deprecated -func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { - pickled := s.Pickle(pickleKey) - quotes := make([]byte, len(pickled)+2) - quotes[0] = '"' - quotes[len(quotes)-1] = '"' - copy(quotes[1:len(quotes)-1], pickled) - return quotes, nil -} - -// Deprecated -func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { - if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return InputNotJSONString - } - if s == nil || s.int == nil { - *s = *NewBlankInboundGroupSession() - } - return s.Unpickle(data[1:len(data)-1], pickleKey) -} - -// decryptMaxPlaintextLen returns the maximum number of bytes of plain-text a -// given message could decode to. The actual size could be different due to -// padding. Returns error on failure. If the message base64 couldn't be -// decoded then the error will be "INVALID_BASE64". If the message is for an -// unsupported version of the protocol then the error will be -// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error -// will be "BAD_MESSAGE_FORMAT". -func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) { - if len(message) == 0 { - return 0, EmptyInput - } - // olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it - message = bytes.Clone(message) - r := C.olm_group_decrypt_max_plaintext_length( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&message[0]), - C.size_t(len(message))) - if r == errorVal() { - return 0, s.lastError() - } - return uint(r), nil -} - -// Decrypt decrypts a message using the InboundGroupSession. Returns the the -// plain-text and message index on success. Returns error on failure. If the -// base64 couldn't be decoded then the error will be "INVALID_BASE64". If the -// message is for an unsupported version of the protocol then the error will be -// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error -// will be BAD_MESSAGE_FORMAT". If the MAC on the message was invalid then the -// error will be "BAD_MESSAGE_MAC". If we do not have a session key -// corresponding to the message's index (ie, it was sent before the session key -// was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX". -func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { - if len(message) == 0 { - return nil, 0, EmptyInput - } - decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message) - if err != nil { - return nil, 0, err - } - messageCopy := make([]byte, len(message)) - copy(messageCopy, message) - plaintext := make([]byte, decryptMaxPlaintextLen) - var messageIndex uint32 - r := C.olm_group_decrypt( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&messageCopy[0]), - C.size_t(len(messageCopy)), - (*C.uint8_t)(&plaintext[0]), - C.size_t(len(plaintext)), - (*C.uint32_t)(&messageIndex)) - if r == errorVal() { - return nil, 0, s.lastError() - } - return plaintext[:r], uint(messageIndex), nil -} - -// sessionIdLen returns the number of bytes needed to store a session ID. -func (s *InboundGroupSession) sessionIdLen() uint { - return uint(C.olm_inbound_group_session_id_length((*C.OlmInboundGroupSession)(s.int))) -} - -// ID returns a base64-encoded identifier for this session. -func (s *InboundGroupSession) ID() id.SessionID { - sessionID := make([]byte, s.sessionIdLen()) - r := C.olm_inbound_group_session_id( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&sessionID[0]), - C.size_t(len(sessionID))) - if r == errorVal() { - panic(s.lastError()) - } - return id.SessionID(sessionID[:r]) -} - -// FirstKnownIndex returns the first message index we know how to decrypt. -func (s *InboundGroupSession) FirstKnownIndex() uint32 { - return uint32(C.olm_inbound_group_session_first_known_index((*C.OlmInboundGroupSession)(s.int))) -} - -// IsVerified check if the session has been verified as a valid session. (A -// session is verified either because the original session share was signed, or -// because we have subsequently successfully decrypted a message.) -func (s *InboundGroupSession) IsVerified() uint { - return uint(C.olm_inbound_group_session_is_verified((*C.OlmInboundGroupSession)(s.int))) -} - -// exportLen returns the number of bytes needed to export an inbound group -// session. -func (s *InboundGroupSession) exportLen() uint { - return uint(C.olm_export_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int))) -} - -// Export returns the base64-encoded ratchet key for this session, at the given -// index, in a format which can be used by -// InboundGroupSession.InboundGroupSessionImport(). Encrypts the -// InboundGroupSession using the supplied key. Returns error on failure. -// if we do not have a session key corresponding to the given index (ie, it was -// sent before the session key was shared with us) the error will be -// "OLM_UNKNOWN_MESSAGE_INDEX". -func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) { - key := make([]byte, s.exportLen()) - r := C.olm_export_inbound_group_session( - (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&key[0]), - C.size_t(len(key)), - C.uint32_t(messageIndex)) - if r == errorVal() { - return nil, s.lastError() - } - return key[:r], nil +func NewBlankInboundGroupSession() InboundGroupSession { + return InitBlankInboundGroupSession() } diff --git a/crypto/olm/inboundgroupsession_goolm.go b/crypto/olm/inboundgroupsession_goolm.go deleted file mode 100644 index 4e561cf7..00000000 --- a/crypto/olm/inboundgroupsession_goolm.go +++ /dev/null @@ -1,149 +0,0 @@ -//go:build goolm - -package olm - -import ( - "maunium.net/go/mautrix/crypto/goolm/session" - "maunium.net/go/mautrix/id" -) - -// InboundGroupSession stores an inbound encrypted messaging session for a -// group. -type InboundGroupSession struct { - session.MegolmInboundSession -} - -// InboundGroupSessionFromPickled loads an InboundGroupSession from a pickled -// base64 string. Decrypts the InboundGroupSession using the supplied key. -// Returns error on failure. -func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { - if len(pickled) == 0 { - return nil, EmptyInput - } - lenKey := len(key) - if lenKey == 0 { - key = []byte(" ") - } - megolmSession, err := session.MegolmInboundSessionFromPickled(pickled, key) - if err != nil { - return nil, err - } - return &InboundGroupSession{ - MegolmInboundSession: *megolmSession, - }, nil -} - -// NewInboundGroupSession creates a new inbound group session from a key -// exported from OutboundGroupSession.Key(). Returns error on failure. -func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { - if len(sessionKey) == 0 { - return nil, EmptyInput - } - megolmSession, err := session.NewMegolmInboundSession(sessionKey) - if err != nil { - return nil, err - } - return &InboundGroupSession{ - MegolmInboundSession: *megolmSession, - }, nil -} - -// InboundGroupSessionImport imports an inbound group session from a previous -// export. Returns error on failure. -func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { - if len(sessionKey) == 0 { - return nil, EmptyInput - } - megolmSession, err := session.NewMegolmInboundSessionFromExport(sessionKey) - if err != nil { - return nil, err - } - return &InboundGroupSession{ - MegolmInboundSession: *megolmSession, - }, nil -} - -func NewBlankInboundGroupSession() *InboundGroupSession { - return &InboundGroupSession{} -} - -// Clear clears the memory used to back this InboundGroupSession. -func (s *InboundGroupSession) Clear() error { - s.MegolmInboundSession = session.MegolmInboundSession{} - return nil -} - -// Pickle returns an InboundGroupSession as a base64 string. Encrypts the -// InboundGroupSession using the supplied key. -func (s *InboundGroupSession) Pickle(key []byte) []byte { - if len(key) == 0 { - panic(NoKeyProvided) - } - pickled, err := s.MegolmInboundSession.Pickle(key) - if err != nil { - panic(err) - } - return pickled -} - -func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return NoKeyProvided - } else if len(pickled) == 0 { - return EmptyInput - } - sOlm, err := session.MegolmInboundSessionFromPickled(pickled, key) - if err != nil { - return err - } - s.MegolmInboundSession = *sOlm - return nil -} - -// Decrypt decrypts a message using the InboundGroupSession. Returns the the -// plain-text and message index on success. Returns error on failure. -func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { - if len(message) == 0 { - return nil, 0, EmptyInput - } - plaintext, messageIndex, err := s.MegolmInboundSession.Decrypt(message) - if err != nil { - return nil, 0, err - } - return plaintext, uint(messageIndex), nil -} - -// ID returns a base64-encoded identifier for this session. -func (s *InboundGroupSession) ID() id.SessionID { - return s.MegolmInboundSession.SessionID() -} - -// FirstKnownIndex returns the first message index we know how to decrypt. -func (s *InboundGroupSession) FirstKnownIndex() uint32 { - return s.MegolmInboundSession.InitialRatchet.Counter -} - -// IsVerified check if the session has been verified as a valid session. (A -// session is verified either because the original session share was signed, or -// because we have subsequently successfully decrypted a message.) -func (s *InboundGroupSession) IsVerified() uint { - if s.MegolmInboundSession.SigningKeyVerified { - return 1 - } - return 0 -} - -// Export returns the base64-encoded ratchet key for this session, at the given -// index, in a format which can be used by -// InboundGroupSession.InboundGroupSessionImport(). Encrypts the -// InboundGroupSession using the supplied key. Returns error on failure. -// if we do not have a session key corresponding to the given index (ie, it was -// sent before the session key was shared with us) the error will be -// returned. -func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) { - res, err := s.MegolmInboundSession.SessionExportMessage(messageIndex) - if err != nil { - return nil, err - } - return res, nil -} diff --git a/crypto/olm/olm.go b/crypto/olm/olm.go index fa1ae856..fa2345e1 100644 --- a/crypto/olm/olm.go +++ b/crypto/olm/olm.go @@ -1,28 +1,20 @@ -//go:build !goolm +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. package olm -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -import "C" +var GetVersion func() (major, minor, patch uint8) +var SetPickleKeyImpl func(key []byte) // Version returns the version number of the olm library. func Version() (major, minor, patch uint8) { - C.olm_get_library_version( - (*C.uint8_t)(&major), - (*C.uint8_t)(&minor), - (*C.uint8_t)(&patch)) - return + return GetVersion() } -// errorVal returns the value that olm functions return if there was an error. -func errorVal() C.size_t { - return C.olm_error() -} - -var pickleKey = []byte("maunium.net/go/mautrix/crypto/olm") - // SetPickleKey sets the global pickle key used when encoding structs with Gob or JSON. func SetPickleKey(key []byte) { - pickleKey = key + SetPickleKeyImpl(key) } diff --git a/crypto/olm/olm_goolm.go b/crypto/olm/olm_goolm.go deleted file mode 100644 index a1489ded..00000000 --- a/crypto/olm/olm_goolm.go +++ /dev/null @@ -1,13 +0,0 @@ -//go:build goolm - -package olm - -// Version returns the version number of the olm library. -func Version() (major, minor, patch uint8) { - return 3, 2, 15 -} - -// SetPickleKey sets the global pickle key used when encoding structs with Gob or JSON. -func SetPickleKey(key []byte) { - panic("gob and json encoding is deprecated and not supported with goolm") -} diff --git a/crypto/olm/outboundgroupsession.go b/crypto/olm/outboundgroupsession.go index b6a33d36..c5b7bcbf 100644 --- a/crypto/olm/outboundgroupsession.go +++ b/crypto/olm/outboundgroupsession.go @@ -1,239 +1,57 @@ -//go:build !goolm +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. package olm -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -import "C" +import "maunium.net/go/mautrix/id" -import ( - "crypto/rand" - "encoding/base64" - "unsafe" +type OutboundGroupSession interface { + // Pickle returns a Session as a base64 string. Encrypts the Session using + // the supplied key. + Pickle(key []byte) ([]byte, error) - "maunium.net/go/mautrix/id" -) + // Unpickle loads an [OutboundGroupSession] from a pickled base64 string. + // Decrypts the [OutboundGroupSession] using the supplied key. + Unpickle(pickled, key []byte) error -// OutboundGroupSession stores an outbound encrypted messaging session for a -// group. -type OutboundGroupSession struct { - int *C.OlmOutboundGroupSession - mem []byte + // Encrypt encrypts a message using the [OutboundGroupSession]. Returns the + // encrypted message as base64. + Encrypt(plaintext []byte) ([]byte, error) + + // ID returns a base64-encoded identifier for this session. + ID() id.SessionID + + // MessageIndex returns the message index for this session. Each message + // is sent with an increasing index; this returns the index for the next + // message. + MessageIndex() uint + + // Key returns the base64-encoded current ratchet key for this session. + Key() string } +var InitNewOutboundGroupSessionFromPickled func(pickled, key []byte) (OutboundGroupSession, error) +var InitNewOutboundGroupSession func() OutboundGroupSession +var InitNewBlankOutboundGroupSession func() OutboundGroupSession + // OutboundGroupSessionFromPickled loads an OutboundGroupSession from a pickled // base64 string. Decrypts the OutboundGroupSession using the supplied key. // Returns error on failure. If the key doesn't match the one used to encrypt // the OutboundGroupSession then the error will be "BAD_SESSION_KEY". If the // base64 couldn't be decoded then the error will be "INVALID_BASE64". -func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession, error) { - if len(pickled) == 0 { - return nil, EmptyInput - } - s := NewBlankOutboundGroupSession() - return s, s.Unpickle(pickled, key) +func OutboundGroupSessionFromPickled(pickled, key []byte) (OutboundGroupSession, error) { + return InitNewOutboundGroupSessionFromPickled(pickled, key) } // NewOutboundGroupSession creates a new outbound group session. -func NewOutboundGroupSession() *OutboundGroupSession { - s := NewBlankOutboundGroupSession() - random := make([]byte, s.createRandomLen()+1) - _, err := rand.Read(random) - if err != nil { - panic(NotEnoughGoRandom) - } - r := C.olm_init_outbound_group_session( - (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(&random[0]), - C.size_t(len(random))) - if r == errorVal() { - panic(s.lastError()) - } - return s +func NewOutboundGroupSession() OutboundGroupSession { + return InitNewOutboundGroupSession() } -// outboundGroupSessionSize is the size of an outbound group session object in -// bytes. -func outboundGroupSessionSize() uint { - return uint(C.olm_outbound_group_session_size()) -} - -// newOutboundGroupSession initialises an empty OutboundGroupSession. -func NewBlankOutboundGroupSession() *OutboundGroupSession { - memory := make([]byte, outboundGroupSessionSize()) - return &OutboundGroupSession{ - int: C.olm_outbound_group_session(unsafe.Pointer(&memory[0])), - mem: memory, - } -} - -// lastError returns an error describing the most recent error to happen to an -// outbound group session. -func (s *OutboundGroupSession) lastError() error { - return convertError(C.GoString(C.olm_outbound_group_session_last_error((*C.OlmOutboundGroupSession)(s.int)))) -} - -// Clear clears the memory used to back this OutboundGroupSession. -func (s *OutboundGroupSession) Clear() error { - r := C.olm_clear_outbound_group_session((*C.OlmOutboundGroupSession)(s.int)) - if r == errorVal() { - return s.lastError() - } else { - return nil - } -} - -// pickleLen returns the number of bytes needed to store an outbound group -// session. -func (s *OutboundGroupSession) pickleLen() uint { - return uint(C.olm_pickle_outbound_group_session_length((*C.OlmOutboundGroupSession)(s.int))) -} - -// Pickle returns an OutboundGroupSession as a base64 string. Encrypts the -// OutboundGroupSession using the supplied key. -func (s *OutboundGroupSession) Pickle(key []byte) []byte { - if len(key) == 0 { - panic(NoKeyProvided) - } - pickled := make([]byte, s.pickleLen()) - r := C.olm_pickle_outbound_group_session( - (*C.OlmOutboundGroupSession)(s.int), - unsafe.Pointer(&key[0]), - C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) - if r == errorVal() { - panic(s.lastError()) - } - return pickled[:r] -} - -func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return NoKeyProvided - } - r := C.olm_unpickle_outbound_group_session( - (*C.OlmOutboundGroupSession)(s.int), - unsafe.Pointer(&key[0]), - C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) - if r == errorVal() { - return s.lastError() - } - return nil -} - -// Deprecated -func (s *OutboundGroupSession) GobEncode() ([]byte, error) { - pickled := s.Pickle(pickleKey) - length := base64.RawStdEncoding.DecodedLen(len(pickled)) - rawPickled := make([]byte, length) - _, err := base64.RawStdEncoding.Decode(rawPickled, pickled) - return rawPickled, err -} - -// Deprecated -func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error { - if s == nil || s.int == nil { - *s = *NewBlankOutboundGroupSession() - } - length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) - pickled := make([]byte, length) - base64.RawStdEncoding.Encode(pickled, rawPickled) - return s.Unpickle(pickled, pickleKey) -} - -// Deprecated -func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { - pickled := s.Pickle(pickleKey) - quotes := make([]byte, len(pickled)+2) - quotes[0] = '"' - quotes[len(quotes)-1] = '"' - copy(quotes[1:len(quotes)-1], pickled) - return quotes, nil -} - -// Deprecated -func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { - if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return InputNotJSONString - } - if s == nil || s.int == nil { - *s = *NewBlankOutboundGroupSession() - } - return s.Unpickle(data[1:len(data)-1], pickleKey) -} - -// createRandomLen returns the number of random bytes needed to create an -// Account. -func (s *OutboundGroupSession) createRandomLen() uint { - return uint(C.olm_init_outbound_group_session_random_length((*C.OlmOutboundGroupSession)(s.int))) -} - -// encryptMsgLen returns the size of the next message in bytes for the given -// number of plain-text bytes. -func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint { - return uint(C.olm_group_encrypt_message_length((*C.OlmOutboundGroupSession)(s.int), C.size_t(plainTextLen))) -} - -// Encrypt encrypts a message using the Session. Returns the encrypted message -// as base64. -func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte { - if len(plaintext) == 0 { - panic(EmptyInput) - } - message := make([]byte, s.encryptMsgLen(len(plaintext))) - r := C.olm_group_encrypt( - (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(&plaintext[0]), - C.size_t(len(plaintext)), - (*C.uint8_t)(&message[0]), - C.size_t(len(message))) - if r == errorVal() { - panic(s.lastError()) - } - return message[:r] -} - -// sessionIdLen returns the number of bytes needed to store a session ID. -func (s *OutboundGroupSession) sessionIdLen() uint { - return uint(C.olm_outbound_group_session_id_length((*C.OlmOutboundGroupSession)(s.int))) -} - -// ID returns a base64-encoded identifier for this session. -func (s *OutboundGroupSession) ID() id.SessionID { - sessionID := make([]byte, s.sessionIdLen()) - r := C.olm_outbound_group_session_id( - (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(&sessionID[0]), - C.size_t(len(sessionID))) - if r == errorVal() { - panic(s.lastError()) - } - return id.SessionID(sessionID[:r]) -} - -// MessageIndex returns the message index for this session. Each message is -// sent with an increasing index; this returns the index for the next message. -func (s *OutboundGroupSession) MessageIndex() uint { - return uint(C.olm_outbound_group_session_message_index((*C.OlmOutboundGroupSession)(s.int))) -} - -// sessionKeyLen returns the number of bytes needed to store a session key. -func (s *OutboundGroupSession) sessionKeyLen() uint { - return uint(C.olm_outbound_group_session_key_length((*C.OlmOutboundGroupSession)(s.int))) -} - -// Key returns the base64-encoded current ratchet key for this session. -func (s *OutboundGroupSession) Key() string { - sessionKey := make([]byte, s.sessionKeyLen()) - r := C.olm_outbound_group_session_key( - (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(&sessionKey[0]), - C.size_t(len(sessionKey))) - if r == errorVal() { - panic(s.lastError()) - } - return string(sessionKey[:r]) +// NewBlankOutboundGroupSession initialises an empty [OutboundGroupSession]. +func NewBlankOutboundGroupSession() OutboundGroupSession { + return InitNewBlankOutboundGroupSession() } diff --git a/crypto/olm/outboundgroupsession_goolm.go b/crypto/olm/outboundgroupsession_goolm.go deleted file mode 100644 index 7c201213..00000000 --- a/crypto/olm/outboundgroupsession_goolm.go +++ /dev/null @@ -1,111 +0,0 @@ -//go:build goolm - -package olm - -import ( - "maunium.net/go/mautrix/crypto/goolm/session" - "maunium.net/go/mautrix/id" -) - -// OutboundGroupSession stores an outbound encrypted messaging session for a -// group. -type OutboundGroupSession struct { - session.MegolmOutboundSession -} - -// OutboundGroupSessionFromPickled loads an OutboundGroupSession from a pickled -// base64 string. Decrypts the OutboundGroupSession using the supplied key. -// Returns error on failure. If the key doesn't match the one used to encrypt -// the OutboundGroupSession then the error will be "BAD_SESSION_KEY". If the -// base64 couldn't be decoded then the error will be "INVALID_BASE64". -func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession, error) { - if len(pickled) == 0 { - return nil, EmptyInput - } - lenKey := len(key) - if lenKey == 0 { - key = []byte(" ") - } - megolmSession, err := session.MegolmOutboundSessionFromPickled(pickled, key) - if err != nil { - return nil, err - } - return &OutboundGroupSession{ - MegolmOutboundSession: *megolmSession, - }, nil -} - -// NewOutboundGroupSession creates a new outbound group session. -func NewOutboundGroupSession() *OutboundGroupSession { - megolmSession, err := session.NewMegolmOutboundSession() - if err != nil { - panic(err) - } - return &OutboundGroupSession{ - MegolmOutboundSession: *megolmSession, - } -} - -// newOutboundGroupSession initialises an empty OutboundGroupSession. -func NewBlankOutboundGroupSession() *OutboundGroupSession { - return &OutboundGroupSession{} -} - -// Clear clears the memory used to back this OutboundGroupSession. -func (s *OutboundGroupSession) Clear() error { - s.MegolmOutboundSession = session.MegolmOutboundSession{} - return nil -} - -// Pickle returns an OutboundGroupSession as a base64 string. Encrypts the -// OutboundGroupSession using the supplied key. -func (s *OutboundGroupSession) Pickle(key []byte) []byte { - if len(key) == 0 { - panic(NoKeyProvided) - } - pickled, err := s.MegolmOutboundSession.Pickle(key) - if err != nil { - panic(err) - } - return pickled -} - -func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return NoKeyProvided - } - return s.MegolmOutboundSession.Unpickle(pickled, key) -} - -// Encrypt encrypts a message using the Session. Returns the encrypted message -// as base64. -func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte { - if len(plaintext) == 0 { - panic(EmptyInput) - } - message, err := s.MegolmOutboundSession.Encrypt(plaintext) - if err != nil { - panic(err) - } - return message -} - -// ID returns a base64-encoded identifier for this session. -func (s *OutboundGroupSession) ID() id.SessionID { - return s.MegolmOutboundSession.SessionID() -} - -// MessageIndex returns the message index for this session. Each message is -// sent with an increasing index; this returns the index for the next message. -func (s *OutboundGroupSession) MessageIndex() uint { - return uint(s.MegolmOutboundSession.Ratchet.Counter) -} - -// Key returns the base64-encoded current ratchet key for this session. -func (s *OutboundGroupSession) Key() string { - message, err := s.MegolmOutboundSession.SessionSharingMessage() - if err != nil { - panic(err) - } - return string(message) -} diff --git a/crypto/olm/pk_interface.go b/crypto/olm/pk.go similarity index 52% rename from crypto/olm/pk_interface.go rename to crypto/olm/pk.go index 11c41431..70ee452d 100644 --- a/crypto/olm/pk_interface.go +++ b/crypto/olm/pk.go @@ -7,7 +7,6 @@ package olm import ( - "maunium.net/go/mautrix/crypto/goolm/pk" "maunium.net/go/mautrix/id" ) @@ -27,15 +26,32 @@ type PKSigning interface { SignJSON(obj any) (string, error) } -var _ PKSigning = (*pk.Signing)(nil) - // PKDecryption is an interface for decrypting messages. type PKDecryption interface { // PublicKey returns the public key. PublicKey() id.Curve25519 // Decrypt verifies and decrypts the given message. - Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error) + Decrypt(ephemeralKey, mac, ciphertext []byte) ([]byte, error) } -var _ PKDecryption = (*pk.Decryption)(nil) +var InitNewPKSigning func() (PKSigning, error) +var InitNewPKSigningFromSeed func(seed []byte) (PKSigning, error) +var InitNewPKDecryptionFromPrivateKey func(privateKey []byte) (PKDecryption, error) + +// NewPKSigning creates a new [PKSigning] object, containing a key pair for +// signing messages. +func NewPKSigning() (PKSigning, error) { + return InitNewPKSigning() +} + +// NewPKSigningFromSeed creates a new PKSigning object using the given seed. +func NewPKSigningFromSeed(seed []byte) (PKSigning, error) { + return InitNewPKSigningFromSeed(seed) +} + +// NewPKDecryptionFromPrivateKey creates a new [PKDecryption] from a +// base64-encoded private key. +func NewPKDecryptionFromPrivateKey(privateKey []byte) (PKDecryption, error) { + return InitNewPKDecryptionFromPrivateKey(privateKey) +} diff --git a/crypto/olm/pk_goolm.go b/crypto/olm/pk_goolm.go deleted file mode 100644 index 372c94fa..00000000 --- a/crypto/olm/pk_goolm.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) 2024 Sumner Evans -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -// When the goolm build flag is enabled, this file will make [PKSigning] -// constructors use the goolm constuctors. - -//go:build goolm - -package olm - -import "maunium.net/go/mautrix/crypto/goolm/pk" - -// NewPKSigningFromSeed creates a new PKSigning object using the given seed. -func NewPKSigningFromSeed(seed []byte) (PKSigning, error) { - return pk.NewSigningFromSeed(seed) -} - -// NewPKSigning creates a new [PKSigning] object, containing a key pair for -// signing messages. -func NewPKSigning() (PKSigning, error) { - return pk.NewSigning() -} - -func NewPKDecryption(privateKey []byte) (PKDecryption, error) { - return pk.NewDecryption() -} diff --git a/crypto/olm/pk_test.go b/crypto/olm/pk_test.go index b57e6571..99ac1e6b 100644 --- a/crypto/olm/pk_test.go +++ b/crypto/olm/pk_test.go @@ -4,8 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -// Only run this test if goo is disabled (that is, libolm is used). -//go:build !goolm +// Only run this test if goolm is disabled (that is, libolm is used). package olm_test @@ -16,7 +15,7 @@ import ( "github.com/stretchr/testify/require" "maunium.net/go/mautrix/crypto/goolm/pk" - "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/libolm" ) func FuzzSign(f *testing.F) { @@ -24,7 +23,7 @@ func FuzzSign(f *testing.F) { goolmPkSigning, err := pk.NewSigningFromSeed(seed) require.NoError(f, err) - libolmPkSigning, err := olm.NewPKSigningFromSeed(seed) + libolmPkSigning, err := libolm.NewPKSigningFromSeed(seed) require.NoError(f, err) f.Add([]byte("message")) diff --git a/crypto/olm/session.go b/crypto/olm/session.go index 185e0b3d..c4b91ffc 100644 --- a/crypto/olm/session.go +++ b/crypto/olm/session.go @@ -1,362 +1,83 @@ -//go:build !goolm +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. package olm -// #cgo LDFLAGS: -lolm -lstdc++ -// #include -// #include -// #include -// void olm_session_describe(OlmSession * session, char *buf, size_t buflen) __attribute__((weak)); -// void meowlm_session_describe(OlmSession * session, char *buf, size_t buflen) { -// if (olm_session_describe) { -// olm_session_describe(session, buf, buflen); -// } else { -// sprintf(buf, "olm_session_describe not supported"); -// } -// } -import "C" +import "maunium.net/go/mautrix/id" -import ( - "crypto/rand" - "encoding/base64" - "unsafe" +type Session interface { + // Pickle returns a Session as a base64 string. Encrypts the Session using + // the supplied key. + Pickle(key []byte) ([]byte, error) - "maunium.net/go/mautrix/id" -) + // Unpickle loads a Session from a pickled base64 string. Decrypts the + // Session using the supplied key. + Unpickle(pickled, key []byte) error -// Session stores an end to end encrypted messaging session. -type Session struct { - int *C.OlmSession - mem []byte + // ID returns an identifier for this Session. Will be the same for both + // ends of the conversation. + ID() id.SessionID + + // HasReceivedMessage returns true if this session has received any + // message. + HasReceivedMessage() bool + + // MatchesInboundSession checks if the PRE_KEY message is for this in-bound + // Session. This can happen if multiple messages are sent to this Account + // before this Account sends a message in reply. Returns true if the + // session matches. Returns false if the session does not match. Returns + // error on failure. If the base64 couldn't be decoded then the error will + // be "INVALID_BASE64". If the message was for an unsupported protocol + // version then the error will be "BAD_MESSAGE_VERSION". If the message + // couldn't be decoded then then the error will be "BAD_MESSAGE_FORMAT". + MatchesInboundSession(oneTimeKeyMsg string) (bool, error) + + // MatchesInboundSessionFrom checks if the PRE_KEY message is for this + // in-bound Session. This can happen if multiple messages are sent to this + // Account before this Account sends a message in reply. Returns true if + // the session matches. Returns false if the session does not match. + // Returns error on failure. If the base64 couldn't be decoded then the + // error will be "INVALID_BASE64". If the message was for an unsupported + // protocol version then the error will be "BAD_MESSAGE_VERSION". If the + // message couldn't be decoded then then the error will be + // "BAD_MESSAGE_FORMAT". + MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) + + // EncryptMsgType returns the type of the next message that Encrypt will + // return. Returns MsgTypePreKey if the message will be a PRE_KEY message. + // Returns MsgTypeMsg if the message will be a normal message. + EncryptMsgType() id.OlmMsgType + + // Encrypt encrypts a message using the Session. Returns the encrypted + // message as base64. + Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) + + // Decrypt decrypts a message using the Session. Returns the plain-text on + // success. Returns error on failure. If the base64 couldn't be decoded + // then the error will be "INVALID_BASE64". If the message is for an + // unsupported version of the protocol then the error will be + // "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error + // will be BAD_MESSAGE_FORMAT". If the MAC on the message was invalid then + // the error will be "BAD_MESSAGE_MAC". + Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) + + // Describe generates a string describing the internal state of an olm + // session for debugging and logging purposes. + Describe() string } -// sessionSize is the size of a session object in bytes. -func sessionSize() uint { - return uint(C.olm_session_size()) -} +var InitSessionFromPickled func(pickled, key []byte) (Session, error) +var InitNewBlankSession func() Session // SessionFromPickled loads a Session from a pickled base64 string. Decrypts -// the Session using the supplied key. Returns error on failure. If the key -// doesn't match the one used to encrypt the Session then the error will be -// "BAD_SESSION_KEY". If the base64 couldn't be decoded then the error will be -// "INVALID_BASE64". -func SessionFromPickled(pickled, key []byte) (*Session, error) { - if len(pickled) == 0 { - return nil, EmptyInput - } - s := NewBlankSession() - return s, s.Unpickle(pickled, key) +// the Session using the supplied key. Returns error on failure. +func SessionFromPickled(pickled, key []byte) (Session, error) { + return InitSessionFromPickled(pickled, key) } -func NewBlankSession() *Session { - memory := make([]byte, sessionSize()) - return &Session{ - int: C.olm_session(unsafe.Pointer(&memory[0])), - mem: memory, - } -} - -// lastError returns an error describing the most recent error to happen to a -// session. -func (s *Session) lastError() error { - return convertError(C.GoString(C.olm_session_last_error((*C.OlmSession)(s.int)))) -} - -// Clear clears the memory used to back this Session. -func (s *Session) Clear() error { - r := C.olm_clear_session((*C.OlmSession)(s.int)) - if r == errorVal() { - return s.lastError() - } - return nil -} - -// pickleLen returns the number of bytes needed to store a session. -func (s *Session) pickleLen() uint { - return uint(C.olm_pickle_session_length((*C.OlmSession)(s.int))) -} - -// createOutboundRandomLen returns the number of random bytes needed to create -// an outbound session. -func (s *Session) createOutboundRandomLen() uint { - return uint(C.olm_create_outbound_session_random_length((*C.OlmSession)(s.int))) -} - -// idLen returns the length of the buffer needed to return the id for this -// session. -func (s *Session) idLen() uint { - return uint(C.olm_session_id_length((*C.OlmSession)(s.int))) -} - -// encryptRandomLen returns the number of random bytes needed to encrypt the -// next message. -func (s *Session) encryptRandomLen() uint { - return uint(C.olm_encrypt_random_length((*C.OlmSession)(s.int))) -} - -// encryptMsgLen returns the size of the next message in bytes for the given -// number of plain-text bytes. -func (s *Session) encryptMsgLen(plainTextLen int) uint { - return uint(C.olm_encrypt_message_length((*C.OlmSession)(s.int), C.size_t(plainTextLen))) -} - -// decryptMaxPlaintextLen returns the maximum number of bytes of plain-text a -// given message could decode to. The actual size could be different due to -// padding. Returns error on failure. If the message base64 couldn't be -// decoded then the error will be "INVALID_BASE64". If the message is for an -// unsupported version of the protocol then the error will be -// "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error -// will be "BAD_MESSAGE_FORMAT". -func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) { - if len(message) == 0 { - return 0, EmptyInput - } - r := C.olm_decrypt_max_plaintext_length( - (*C.OlmSession)(s.int), - C.size_t(msgType), - unsafe.Pointer(C.CString(message)), - C.size_t(len(message))) - if r == errorVal() { - return 0, s.lastError() - } - return uint(r), nil -} - -// Pickle returns a Session as a base64 string. Encrypts the Session using the -// supplied key. -func (s *Session) Pickle(key []byte) []byte { - if len(key) == 0 { - panic(NoKeyProvided) - } - pickled := make([]byte, s.pickleLen()) - r := C.olm_pickle_session( - (*C.OlmSession)(s.int), - unsafe.Pointer(&key[0]), - C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) - if r == errorVal() { - panic(s.lastError()) - } - return pickled[:r] -} - -func (s *Session) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return NoKeyProvided - } - r := C.olm_unpickle_session( - (*C.OlmSession)(s.int), - unsafe.Pointer(&key[0]), - C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) - if r == errorVal() { - return s.lastError() - } - return nil -} - -// Deprecated -func (s *Session) GobEncode() ([]byte, error) { - pickled := s.Pickle(pickleKey) - length := base64.RawStdEncoding.DecodedLen(len(pickled)) - rawPickled := make([]byte, length) - _, err := base64.RawStdEncoding.Decode(rawPickled, pickled) - return rawPickled, err -} - -// Deprecated -func (s *Session) GobDecode(rawPickled []byte) error { - if s == nil || s.int == nil { - *s = *NewBlankSession() - } - length := base64.RawStdEncoding.EncodedLen(len(rawPickled)) - pickled := make([]byte, length) - base64.RawStdEncoding.Encode(pickled, rawPickled) - return s.Unpickle(pickled, pickleKey) -} - -// Deprecated -func (s *Session) MarshalJSON() ([]byte, error) { - pickled := s.Pickle(pickleKey) - quotes := make([]byte, len(pickled)+2) - quotes[0] = '"' - quotes[len(quotes)-1] = '"' - copy(quotes[1:len(quotes)-1], pickled) - return quotes, nil -} - -// Deprecated -func (s *Session) UnmarshalJSON(data []byte) error { - if len(data) == 0 || len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return InputNotJSONString - } - if s == nil || s.int == nil { - *s = *NewBlankSession() - } - return s.Unpickle(data[1:len(data)-1], pickleKey) -} - -// Id returns an identifier for this Session. Will be the same for both ends -// of the conversation. -func (s *Session) ID() id.SessionID { - sessionID := make([]byte, s.idLen()) - r := C.olm_session_id( - (*C.OlmSession)(s.int), - unsafe.Pointer(&sessionID[0]), - C.size_t(len(sessionID))) - if r == errorVal() { - panic(s.lastError()) - } - return id.SessionID(sessionID) -} - -// HasReceivedMessage returns true if this session has received any message. -func (s *Session) HasReceivedMessage() bool { - switch C.olm_session_has_received_message((*C.OlmSession)(s.int)) { - case 0: - return false - default: - return true - } -} - -// MatchesInboundSession checks if the PRE_KEY message is for this in-bound -// Session. This can happen if multiple messages are sent to this Account -// before this Account sends a message in reply. Returns true if the session -// matches. Returns false if the session does not match. Returns error on -// failure. If the base64 couldn't be decoded then the error will be -// "INVALID_BASE64". If the message was for an unsupported protocol version -// then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be -// decoded then then the error will be "BAD_MESSAGE_FORMAT". -func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { - if len(oneTimeKeyMsg) == 0 { - return false, EmptyInput - } - r := C.olm_matches_inbound_session( - (*C.OlmSession)(s.int), - unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]), - C.size_t(len(oneTimeKeyMsg))) - if r == 1 { - return true, nil - } else if r == 0 { - return false, nil - } else { // if r == errorVal() - return false, s.lastError() - } -} - -// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound -// Session. This can happen if multiple messages are sent to this Account -// before this Account sends a message in reply. Returns true if the session -// matches. Returns false if the session does not match. Returns error on -// failure. If the base64 couldn't be decoded then the error will be -// "INVALID_BASE64". If the message was for an unsupported protocol version -// then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be -// decoded then then the error will be "BAD_MESSAGE_FORMAT". -func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { - if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { - return false, EmptyInput - } - r := C.olm_matches_inbound_session_from( - (*C.OlmSession)(s.int), - unsafe.Pointer(&([]byte(theirIdentityKey))[0]), - C.size_t(len(theirIdentityKey)), - unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]), - C.size_t(len(oneTimeKeyMsg))) - if r == 1 { - return true, nil - } else if r == 0 { - return false, nil - } else { // if r == errorVal() - return false, s.lastError() - } -} - -// EncryptMsgType returns the type of the next message that Encrypt will -// return. Returns MsgTypePreKey if the message will be a PRE_KEY message. -// Returns MsgTypeMsg if the message will be a normal message. Returns error -// on failure. -func (s *Session) EncryptMsgType() id.OlmMsgType { - switch C.olm_encrypt_message_type((*C.OlmSession)(s.int)) { - case C.size_t(id.OlmMsgTypePreKey): - return id.OlmMsgTypePreKey - case C.size_t(id.OlmMsgTypeMsg): - return id.OlmMsgTypeMsg - default: - panic("olm_encrypt_message_type returned invalid result") - } -} - -// Encrypt encrypts a message using the Session. Returns the encrypted message -// as base64. -func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { - if len(plaintext) == 0 { - panic(EmptyInput) - } - // Make the slice be at least length 1 - random := make([]byte, s.encryptRandomLen()+1) - _, err := rand.Read(random) - if err != nil { - panic(NotEnoughGoRandom) - } - messageType := s.EncryptMsgType() - message := make([]byte, s.encryptMsgLen(len(plaintext))) - r := C.olm_encrypt( - (*C.OlmSession)(s.int), - unsafe.Pointer(&plaintext[0]), - C.size_t(len(plaintext)), - unsafe.Pointer(&random[0]), - C.size_t(len(random)), - unsafe.Pointer(&message[0]), - C.size_t(len(message))) - if r == errorVal() { - panic(s.lastError()) - } - return messageType, message[:r] -} - -// Decrypt decrypts a message using the Session. Returns the the plain-text on -// success. Returns error on failure. If the base64 couldn't be decoded then -// the error will be "INVALID_BASE64". If the message is for an unsupported -// version of the protocol then the error will be "BAD_MESSAGE_VERSION". If -// the message couldn't be decoded then the error will be BAD_MESSAGE_FORMAT". -// If the MAC on the message was invalid then the error will be -// "BAD_MESSAGE_MAC". -func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { - if len(message) == 0 { - return nil, EmptyInput - } - decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType) - if err != nil { - return nil, err - } - messageCopy := []byte(message) - plaintext := make([]byte, decryptMaxPlaintextLen) - r := C.olm_decrypt( - (*C.OlmSession)(s.int), - C.size_t(msgType), - unsafe.Pointer(&(messageCopy)[0]), - C.size_t(len(messageCopy)), - unsafe.Pointer(&plaintext[0]), - C.size_t(len(plaintext))) - if r == errorVal() { - return nil, s.lastError() - } - return plaintext[:r], nil -} - -// https://gitlab.matrix.org/matrix-org/olm/-/blob/3.2.8/include/olm/olm.h#L392-393 -const maxDescribeSize = 600 - -// Describe generates a string describing the internal state of an olm session for debugging and logging purposes. -func (s *Session) Describe() string { - desc := (*C.char)(C.malloc(C.size_t(maxDescribeSize))) - defer C.free(unsafe.Pointer(desc)) - C.meowlm_session_describe( - (*C.OlmSession)(s.int), - desc, - C.size_t(maxDescribeSize)) - return C.GoString(desc) +func NewBlankSession() Session { + return InitNewBlankSession() } diff --git a/crypto/olm/session_goolm.go b/crypto/olm/session_goolm.go deleted file mode 100644 index c77efaa2..00000000 --- a/crypto/olm/session_goolm.go +++ /dev/null @@ -1,110 +0,0 @@ -//go:build goolm - -package olm - -import ( - "maunium.net/go/mautrix/crypto/goolm/session" - "maunium.net/go/mautrix/id" -) - -// Session stores an end to end encrypted messaging session. -type Session struct { - session.OlmSession -} - -// SessionFromPickled loads a Session from a pickled base64 string. Decrypts -// the Session using the supplied key. Returns error on failure. -func SessionFromPickled(pickled, key []byte) (*Session, error) { - if len(pickled) == 0 { - return nil, EmptyInput - } - s := NewBlankSession() - return s, s.Unpickle(pickled, key) -} - -func NewBlankSession() *Session { - return &Session{} -} - -// Clear clears the memory used to back this Session. -func (s *Session) Clear() error { - s.OlmSession = session.OlmSession{} - return nil -} - -// Pickle returns a Session as a base64 string. Encrypts the Session using the -// supplied key. -func (s *Session) Pickle(key []byte) []byte { - if len(key) == 0 { - panic(NoKeyProvided) - } - pickled, err := s.OlmSession.Pickle(key) - if err != nil { - panic(err) - } - return pickled -} - -func (s *Session) Unpickle(pickled, key []byte) error { - if len(key) == 0 { - return NoKeyProvided - } else if len(pickled) == 0 { - return EmptyInput - } - sOlm, err := session.OlmSessionFromPickled(pickled, key) - if err != nil { - return err - } - s.OlmSession = *sOlm - return nil -} - -// MatchesInboundSession checks if the PRE_KEY message is for this in-bound -// Session. This can happen if multiple messages are sent to this Account -// before this Account sends a message in reply. Returns true if the session -// matches. Returns false if the session does not match. Returns error on -// failure. -func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { - return s.MatchesInboundSessionFrom("", oneTimeKeyMsg) -} - -// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound -// Session. This can happen if multiple messages are sent to this Account -// before this Account sends a message in reply. Returns true if the session -// matches. Returns false if the session does not match. Returns error on -// failure. -func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { - if theirIdentityKey != "" { - theirKey := id.Curve25519(theirIdentityKey) - return s.OlmSession.MatchesInboundSessionFrom(&theirKey, []byte(oneTimeKeyMsg)) - } - return s.OlmSession.MatchesInboundSessionFrom(nil, []byte(oneTimeKeyMsg)) - -} - -// Encrypt encrypts a message using the Session. Returns the encrypted message -// as base64. -func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { - if len(plaintext) == 0 { - panic(EmptyInput) - } - messageType, message, err := s.OlmSession.Encrypt(plaintext, nil) - if err != nil { - panic(err) - } - return messageType, message -} - -// Decrypt decrypts a message using the Session. Returns the the plain-text on -// success. Returns error on failure. -func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { - if len(message) == 0 { - return nil, EmptyInput - } - return s.OlmSession.Decrypt([]byte(message), msgType) -} - -// Describe generates a string describing the internal state of an olm session for debugging and logging purposes. -func (s *Session) Describe() string { - return s.OlmSession.Describe() -} diff --git a/crypto/olm/session_test.go b/crypto/olm/session_test.go new file mode 100644 index 00000000..ff9445d9 --- /dev/null +++ b/crypto/olm/session_test.go @@ -0,0 +1,56 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package olm_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/libolm" +) + +func TestBlankSession(t *testing.T) { + libolmSession := libolm.NewBlankSession() + session := session.NewOlmSession() + + assert.Equal(t, libolmSession.ID(), session.ID()) + assert.Equal(t, libolmSession.HasReceivedMessage(), session.HasReceivedMessage()) + assert.Equal(t, libolmSession.EncryptMsgType(), session.EncryptMsgType()) + assert.Equal(t, libolmSession.Describe(), session.Describe()) + + libolmPickled, err := libolmSession.Pickle([]byte("test")) + assert.NoError(t, err) + goolmPickled, err := session.Pickle([]byte("test")) + assert.NoError(t, err) + assert.Equal(t, goolmPickled, libolmPickled) +} + +func TestSessionPickle(t *testing.T) { + pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItVKR4ro0O9EAk6LLxJtSnRu5elSUk7YXT") + pickleKey := []byte("secret_key") + + goolmSession := session.NewOlmSession() + err := goolmSession.Unpickle(pickledDataFromLibOlm, pickleKey) + assert.NoError(t, err) + + libolmSession := libolm.NewBlankSession() + err = libolmSession.Unpickle(pickledDataFromLibOlm, pickleKey) + assert.NoError(t, err) + + // Reset the pickle data since libolmSession.Unpickle modifies it. + pickledDataFromLibOlm = []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItVKR4ro0O9EAk6LLxJtSnRu5elSUk7YXT") + + goolmPickled, err := goolmSession.Pickle(pickleKey) + assert.NoError(t, err) + assert.Equal(t, pickledDataFromLibOlm, goolmPickled) + + libolmPickled, err := libolmSession.Pickle(pickleKey) + assert.Equal(t, pickledDataFromLibOlm, libolmPickled) + assert.NoError(t, err) +} diff --git a/crypto/registergoolm.go b/crypto/registergoolm.go new file mode 100644 index 00000000..f5cecafc --- /dev/null +++ b/crypto/registergoolm.go @@ -0,0 +1,5 @@ +//go:build goolm + +package crypto + +import _ "maunium.net/go/mautrix/crypto/goolm" diff --git a/crypto/registerlibolm.go b/crypto/registerlibolm.go new file mode 100644 index 00000000..ab388a5c --- /dev/null +++ b/crypto/registerlibolm.go @@ -0,0 +1,5 @@ +//go:build !goolm + +package crypto + +import _ "maunium.net/go/mautrix/crypto/libolm" diff --git a/crypto/sessions.go b/crypto/sessions.go index 6075a644..4aac6cf7 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -54,9 +54,9 @@ func (session *OlmSession) Describe() string { return session.Internal.Describe() } -func wrapSession(session *olm.Session) *OlmSession { +func wrapSession(session olm.Session) *OlmSession { return &OlmSession{ - Internal: *session, + Internal: session, ExpirationMixin: ExpirationMixin{ TimeMixin: TimeMixin{ CreationTime: time.Now(), @@ -68,7 +68,7 @@ func wrapSession(session *olm.Session) *OlmSession { } func (account *OlmAccount) NewInboundSessionFrom(senderKey id.Curve25519, ciphertext string) (*OlmSession, error) { - session, err := account.Internal.NewInboundSessionFrom(senderKey, ciphertext) + session, err := account.Internal.NewInboundSessionFrom(&senderKey, ciphertext) if err != nil { return nil, err } @@ -76,7 +76,7 @@ func (account *OlmAccount) NewInboundSessionFrom(senderKey id.Curve25519, cipher return wrapSession(session), nil } -func (session *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { +func (session *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { session.LastEncryptedTime = time.Now() return session.Internal.Encrypt(plaintext) } @@ -120,7 +120,7 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI return nil, err } return &InboundGroupSession{ - Internal: *igs, + Internal: igs, SigningKey: signingKey, SenderKey: senderKey, RoomID: roomID, @@ -148,7 +148,7 @@ func (igs *InboundGroupSession) RatchetTo(index uint32) error { if err != nil { return err } - igs.Internal = *imported + igs.Internal = imported return nil } @@ -182,7 +182,7 @@ type OutboundGroupSession struct { func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.EncryptionEventContent) *OutboundGroupSession { ogs := &OutboundGroupSession{ - Internal: *olm.NewOutboundGroupSession(), + Internal: olm.NewOutboundGroupSession(), ExpirationMixin: ExpirationMixin{ TimeMixin: TimeMixin{ CreationTime: time.Now(), @@ -240,7 +240,7 @@ func (ogs *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { } ogs.MessageCount++ ogs.LastEncryptedTime = time.Now() - return ogs.Internal.Encrypt(plaintext), nil + return ogs.Internal.Encrypt(plaintext) } type TimeMixin struct { diff --git a/crypto/sql_store.go b/crypto/sql_store.go index a8ccab26..255247fd 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -123,8 +123,11 @@ func (store *SQLCryptoStore) FindDeviceID(ctx context.Context) (deviceID id.Devi // PutAccount stores an OlmAccount in the database. func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount) error { store.Account = account - bytes := account.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec(ctx, ` + bytes, err := account.Internal.Pickle(store.PickleKey) + if err != nil { + return err + } + _, err = store.DB.Exec(ctx, ` INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id, key_backup_version) VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token, account=excluded.account, account_id=excluded.account_id, @@ -137,7 +140,7 @@ func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) { if store.Account == nil { row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account, key_backup_version FROM crypto_account WHERE account_id=$1", store.AccountID) - acc := &OlmAccount{Internal: *olm.NewBlankAccount()} + acc := &OlmAccount{Internal: olm.NewBlankAccount()} var accountBytes []byte err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes, &acc.KeyBackupVersion) if err == sql.ErrNoRows { @@ -183,7 +186,7 @@ func (store *SQLCryptoStore) GetSessions(ctx context.Context, key id.SenderKey) defer store.olmSessionCacheLock.Unlock() cache := store.getOlmSessionCache(key) for rows.Next() { - sess := OlmSession{Internal: *olm.NewBlankSession()} + sess := OlmSession{Internal: olm.NewBlankSession()} var sessionBytes []byte var sessionID id.SessionID err = rows.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.LastEncryptedTime, &sess.LastDecryptedTime) @@ -220,7 +223,7 @@ func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.Sender row := store.DB.QueryRow(ctx, "SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1", key, store.AccountID) - sess := OlmSession{Internal: *olm.NewBlankSession()} + sess := OlmSession{Internal: olm.NewBlankSession()} var sessionBytes []byte var sessionID id.SessionID @@ -246,8 +249,11 @@ func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.Sender func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, session *OlmSession) error { store.olmSessionCacheLock.Lock() defer store.olmSessionCacheLock.Unlock() - sessionBytes := session.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)", + sessionBytes, err := session.Internal.Pickle(store.PickleKey) + if err != nil { + return err + } + _, err = store.DB.Exec(ctx, "INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)", session.ID(), key, sessionBytes, session.CreationTime, session.LastEncryptedTime, session.LastDecryptedTime, store.AccountID) store.getOlmSessionCache(key)[session.ID()] = session return err @@ -255,8 +261,11 @@ func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, s // UpdateSession replaces the Olm session for a sender in the database. func (store *SQLCryptoStore) UpdateSession(ctx context.Context, _ id.SenderKey, session *OlmSession) error { - sessionBytes := session.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec(ctx, "UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5", + sessionBytes, err := session.Internal.Pickle(store.PickleKey) + if err != nil { + return err + } + _, err = store.DB.Exec(ctx, "UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5", sessionBytes, session.LastEncryptedTime, session.LastDecryptedTime, session.ID(), store.AccountID) return err } @@ -270,7 +279,10 @@ func datePtr(t time.Time) *time.Time { // PutGroupSession stores an inbound Megolm group session for a room, sender and session. func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, session *InboundGroupSession) error { - sessionBytes := session.Internal.Pickle(store.PickleKey) + sessionBytes, err := session.Internal.Pickle(store.PickleKey) + if err != nil { + return err + } forwardingChains := strings.Join(session.ForwardingChains, ",") ratchetSafety, err := json.Marshal(&session.RatchetSafety) if err != nil { @@ -340,7 +352,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room return nil, err } return &InboundGroupSession{ - Internal: *igs, + Internal: igs, SigningKey: id.Ed25519(signingKey.String), SenderKey: id.Curve25519(senderKey.String), RoomID: roomID, @@ -455,7 +467,7 @@ func (store *SQLCryptoStore) GetWithheldGroupSession(ctx context.Context, roomID }, nil } -func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs *olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) { +func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) { igs = olm.NewBlankInboundGroupSession() err = igs.Unpickle(sessionBytes, store.PickleKey) if err != nil { @@ -491,7 +503,7 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In return nil, err } return &InboundGroupSession{ - Internal: *igs, + Internal: igs, SigningKey: id.Ed25519(signingKey.String), SenderKey: id.Curve25519(senderKey.String), RoomID: roomID, @@ -534,8 +546,11 @@ func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context // AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices. func (store *SQLCryptoStore) AddOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error { - sessionBytes := session.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec(ctx, ` + sessionBytes, err := session.Internal.Pickle(store.PickleKey) + if err != nil { + return err + } + _, err = store.DB.Exec(ctx, ` INSERT INTO crypto_megolm_outbound_session (room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) @@ -550,8 +565,11 @@ func (store *SQLCryptoStore) AddOutboundGroupSession(ctx context.Context, sessio // UpdateOutboundGroupSession replaces an outbound Megolm session with for same room and session ID. func (store *SQLCryptoStore) UpdateOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error { - sessionBytes := session.Internal.Pickle(store.PickleKey) - _, err := store.DB.Exec(ctx, "UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6", + sessionBytes, err := session.Internal.Pickle(store.PickleKey) + if err != nil { + return err + } + _, err = store.DB.Exec(ctx, "UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6", sessionBytes, session.MessageCount, session.LastEncryptedTime, session.RoomID, session.ID(), store.AccountID) return err } @@ -576,7 +594,7 @@ func (store *SQLCryptoStore) GetOutboundGroupSession(ctx context.Context, roomID if err != nil { return nil, err } - ogs.Internal = *intOGS + ogs.Internal = intOGS ogs.RoomID = roomID ogs.MaxAge = time.Duration(maxAgeMS) * time.Millisecond return &ogs, nil diff --git a/crypto/store_test.go b/crypto/store_test.go index 740273dd..08079f5e 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -115,7 +115,7 @@ func TestStoreOlmSession(t *testing.T) { olmSess := OlmSession{ id: olmSessID, - Internal: *olmInternal, + Internal: olmInternal, } err = store.AddSession(context.TODO(), olmSessID, &olmSess) if err != nil { @@ -133,7 +133,13 @@ func TestStoreOlmSession(t *testing.T) { if retrieved.ID() != olmSessID { t.Errorf("Expected session ID to be %v, got %v", olmSessID, retrieved.ID()) } - if pickled := string(retrieved.Internal.Pickle([]byte("test"))); pickled != olmPickled { + + pickled, err := retrieved.Internal.Pickle([]byte("test")) + if err != nil { + t.Fatalf("Error pickling Olm session: %v", err) + } + + if string(pickled) != olmPickled { t.Error("Pickled Olm session does not match original") } }) @@ -152,7 +158,7 @@ func TestStoreMegolmSession(t *testing.T) { } igs := &InboundGroupSession{ - Internal: *internal, + Internal: internal, SigningKey: acc.SigningKey(), SenderKey: acc.IdentityKey(), RoomID: "room1", @@ -168,7 +174,9 @@ func TestStoreMegolmSession(t *testing.T) { t.Errorf("Error retrieving inbound group session: %v", err) } - if pickled := string(retrieved.Internal.Pickle([]byte("test"))); pickled != groupSession { + if pickled, err := retrieved.Internal.Pickle([]byte("test")); err != nil { + t.Fatalf("Error pickling inbound group session: %v", err) + } else if string(pickled) != groupSession { t.Error("Pickled inbound group session does not match original") } }) From 66c417825bbbf843fd93920bd2e72ee8a39f0d0b Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 31 May 2024 10:11:04 -0600 Subject: [PATCH 0675/1647] crypto/olm: add tests comparing libolm and goolm, replace crypto/ed25519 -> maunium.net/go/mautrix/crypto/ed25519 The following tests were added that compare libolm and goolm against each other. * account: add (un)pickle tests * groupsession: add test for (en|de)cryption for group sessions * account: test IdentityKeysJSON and OneTimeKeys * session: add test for encrypt/decrypt * session: add test for private key format * outboundsession: add differential fuzz test for encryption Signed-off-by: Sumner Evans --- crypto/ed25519/ed25519.go | 16 ++- crypto/goolm/crypto/ed25519.go | 10 +- .../goolm/session/megolm_outbound_session.go | 9 +- crypto/olm/account_test.go | 122 +++++++++++++++++ crypto/olm/errors.go | 6 + crypto/olm/groupsession_test.go | 47 +++++++ crypto/olm/outboundgroupsession_test.go | 127 ++++++++++++++++++ crypto/olm/session_test.go | 81 +++++++++-- 8 files changed, 396 insertions(+), 22 deletions(-) create mode 100644 crypto/olm/account_test.go create mode 100644 crypto/olm/groupsession_test.go create mode 100644 crypto/olm/outboundgroupsession_test.go diff --git a/crypto/ed25519/ed25519.go b/crypto/ed25519/ed25519.go index 6b294c67..327cbb3c 100644 --- a/crypto/ed25519/ed25519.go +++ b/crypto/ed25519/ed25519.go @@ -6,8 +6,9 @@ // Package ed25519 implements the Ed25519 signature algorithm. See // https://ed25519.cr.yp.to/. // -// This package stores the private key in a different format than the -// [crypto/ed25519] package in the standard library. +// This package stores the private key in the NaCl format, which is a different +// format than that used by the [crypto/ed25519] package in the standard +// library. // // This picture will help with the rest of the explanation: // https://blog.mozilla.org/warner/files/2011/11/key-formats.png @@ -15,8 +16,9 @@ // The private key in the [crypto/ed25519] package is a 64-byte value where the // first 32-bytes are the seed and the last 32-bytes are the public key. // -// The private key in this package is stored as a 64-byte value that results -// from the SHA512 of the seed. +// The private key in this package is stored in the NaCl format. That is, the +// left 32-bytes are the private scalar A and the right 32-bytes are the right +// half of the SHA512 result. // // The contents of this package are mostly copied from the standard library, // and as such the source code is licensed under the BSD license of the @@ -187,6 +189,12 @@ func newKeyFromSeed(privateKey, seed []byte) { } h := sha512.Sum512(seed) + + // Apply clamping to get A in the left half, and leave the right half + // as-is. This gets the private key into the NaCl format. + h[0] &= 248 + h[31] &= 63 + h[31] |= 64 copy(privateKey, h[:]) } diff --git a/crypto/goolm/crypto/ed25519.go b/crypto/goolm/crypto/ed25519.go index 0756d778..57fc25fa 100644 --- a/crypto/goolm/crypto/ed25519.go +++ b/crypto/goolm/crypto/ed25519.go @@ -1,11 +1,11 @@ package crypto import ( - "crypto/ed25519" "encoding/base64" "fmt" "io" + "maunium.net/go/mautrix/crypto/ed25519" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" @@ -123,12 +123,16 @@ func (c Ed25519PrivateKey) Equal(x Ed25519PrivateKey) bool { // PubKey returns the public key derived from the private key. func (c Ed25519PrivateKey) PubKey() Ed25519PublicKey { publicKey := ed25519.PrivateKey(c).Public() - return Ed25519PublicKey(publicKey.(ed25519.PublicKey)) + return Ed25519PublicKey(publicKey.([]byte)) } // Sign returns the signature for the message. func (c Ed25519PrivateKey) Sign(message []byte) []byte { - return ed25519.Sign(ed25519.PrivateKey(c), message) + signature, err := ed25519.PrivateKey(c).Sign(nil, message, &ed25519.Options{}) + if err != nil { + panic(err) + } + return signature } // Ed25519PublicKey represents the public key for ed25519 usage. This is just a wrapper. diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index 44d001d1..ce9a4b26 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -6,7 +6,7 @@ import ( "errors" "fmt" - "maunium.net/go/mautrix/id" + "go.mau.fi/util/exerrors" "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" @@ -15,6 +15,7 @@ import ( "maunium.net/go/mautrix/crypto/goolm/megolm" "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/id" ) const ( @@ -187,9 +188,5 @@ func (s *MegolmOutboundSession) MessageIndex() uint { // Key returns the base64-encoded current ratchet key for this session. func (s *MegolmOutboundSession) Key() string { - message, err := s.SessionSharingMessage() - if err != nil { - panic(err) - } - return string(message) + return string(exerrors.Must(s.SessionSharingMessage())) } diff --git a/crypto/olm/account_test.go b/crypto/olm/account_test.go new file mode 100644 index 00000000..0c628a20 --- /dev/null +++ b/crypto/olm/account_test.go @@ -0,0 +1,122 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package olm_test + +import ( + "bytes" + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/crypto/ed25519" + "maunium.net/go/mautrix/crypto/goolm/account" + "maunium.net/go/mautrix/crypto/libolm" + "maunium.net/go/mautrix/crypto/olm" +) + +func ensureAccountsEqual(t *testing.T, a, b olm.Account) { + t.Helper() + + assert.Equal(t, a.MaxNumberOfOneTimeKeys(), b.MaxNumberOfOneTimeKeys()) + + aEd25519, aCurve25519, err := a.IdentityKeys() + require.NoError(t, err) + bEd25519, bCurve25519, err := b.IdentityKeys() + require.NoError(t, err) + assert.Equal(t, aEd25519, bEd25519) + assert.Equal(t, aCurve25519, bCurve25519) + + aIdentityKeysJSON, err := a.IdentityKeysJSON() + require.NoError(t, err) + bIdentityKeysJSON, err := b.IdentityKeysJSON() + require.NoError(t, err) + assert.JSONEq(t, string(aIdentityKeysJSON), string(bIdentityKeysJSON)) + + aOTKs, err := a.OneTimeKeys() + require.NoError(t, err) + bOTKs, err := b.OneTimeKeys() + require.NoError(t, err) + assert.Equal(t, aOTKs, bOTKs) +} + +// TestAccount_UnpickleLibolmToGoolm tests creating an account from libolm, +// pickling it, and importing it into goolm. +func TestAccount_UnpickleLibolmToGoolm(t *testing.T) { + libolmAccount, err := libolm.NewAccount(nil) + require.NoError(t, err) + + require.NoError(t, libolmAccount.GenOneTimeKeys(nil, 50)) + + libolmPickled, err := libolmAccount.Pickle([]byte("test")) + require.NoError(t, err) + + goolmAccount, err := account.AccountFromPickled(libolmPickled, []byte("test")) + require.NoError(t, err) + + ensureAccountsEqual(t, libolmAccount, goolmAccount) + + goolmPickled, err := goolmAccount.Pickle([]byte("test")) + require.NoError(t, err) + assert.Equal(t, libolmPickled, goolmPickled) +} + +// TestAccount_UnpickleGoolmToLibolm tests creating an account from goolm, +// pickling it, and importing it into libolm. +func TestAccount_UnpickleGoolmToLibolm(t *testing.T) { + goolmAccount, err := account.NewAccount(nil) + require.NoError(t, err) + + require.NoError(t, goolmAccount.GenOneTimeKeys(nil, 50)) + + goolmPickled, err := goolmAccount.Pickle([]byte("test")) + require.NoError(t, err) + + libolmAccount, err := libolm.AccountFromPickled(bytes.Clone(goolmPickled), []byte("test")) + require.NoError(t, err) + + ensureAccountsEqual(t, libolmAccount, goolmAccount) + + libolmPickled, err := libolmAccount.Pickle([]byte("test")) + require.NoError(t, err) + assert.Equal(t, goolmPickled, libolmPickled) +} + +func FuzzAccount_Sign(f *testing.F) { + f.Add([]byte("anything")) + + libolmAccount := exerrors.Must(libolm.NewAccount(nil)) + goolmAccount := exerrors.Must(account.AccountFromPickled(exerrors.Must(libolmAccount.Pickle([]byte("test"))), []byte("test"))) + + f.Fuzz(func(t *testing.T, message []byte) { + if len(message) == 0 { + t.Skip("empty message is not supported") + } + + libolmSignature, err := libolmAccount.Sign(bytes.Clone(message)) + require.NoError(t, err) + goolmSignature, err := goolmAccount.Sign(bytes.Clone(message)) + require.NoError(t, err) + assert.Equal(t, goolmSignature, libolmSignature) + + goolmSignatureBytes, err := base64.RawStdEncoding.DecodeString(string(goolmSignature)) + require.NoError(t, err) + libolmSignatureBytes, err := base64.RawStdEncoding.DecodeString(string(libolmSignature)) + require.NoError(t, err) + + libolmEd25519, _, err := libolmAccount.IdentityKeys() + require.NoError(t, err) + + assert.True(t, ed25519.Verify(ed25519.PublicKey(libolmEd25519.Bytes()), message, libolmSignatureBytes)) + assert.True(t, ed25519.Verify(ed25519.PublicKey(libolmEd25519.Bytes()), message, goolmSignatureBytes)) + + assert.True(t, goolmAccount.IdKeys.Ed25519.Verify(bytes.Clone(message), libolmSignatureBytes)) + assert.True(t, goolmAccount.IdKeys.Ed25519.Verify(bytes.Clone(message), goolmSignatureBytes)) + }) +} diff --git a/crypto/olm/errors.go b/crypto/olm/errors.go index eb0cffff..c80b82e4 100644 --- a/crypto/olm/errors.go +++ b/crypto/olm/errors.go @@ -1,3 +1,9 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + package olm import "errors" diff --git a/crypto/olm/groupsession_test.go b/crypto/olm/groupsession_test.go new file mode 100644 index 00000000..276e7cfb --- /dev/null +++ b/crypto/olm/groupsession_test.go @@ -0,0 +1,47 @@ +package olm_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/libolm" +) + +// TestEncryptDecrypt_GoolmToLibolm tests encryption where goolm encrypts and libolm decrypts +func TestEncryptDecrypt_GoolmToLibolm(t *testing.T) { + goolmOutbound, err := session.NewMegolmOutboundSession() + require.NoError(t, err) + + libolmInbound, err := libolm.NewInboundGroupSession([]byte(goolmOutbound.Key())) + require.NoError(t, err) + + for i := 0; i < 10; i++ { + ciphertext, err := goolmOutbound.Encrypt([]byte(fmt.Sprintf("message %d", i))) + require.NoError(t, err) + + plaintext, msgIdx, err := libolmInbound.Decrypt(ciphertext) + assert.NoError(t, err) + assert.Equal(t, []byte(fmt.Sprintf("message %d", i)), plaintext) + assert.Equal(t, goolmOutbound.MessageIndex()-1, msgIdx) + } +} + +func TestEncryptDecrypt_LibolmToGoolm(t *testing.T) { + libolmOutbound := libolm.NewOutboundGroupSession() + goolmInbound, err := session.NewMegolmInboundSession([]byte(libolmOutbound.Key())) + require.NoError(t, err) + + for i := 0; i < 10; i++ { + ciphertext, err := libolmOutbound.Encrypt([]byte(fmt.Sprintf("message %d", i))) + require.NoError(t, err) + + plaintext, msgIdx, err := goolmInbound.Decrypt(ciphertext) + assert.NoError(t, err) + assert.Equal(t, []byte(fmt.Sprintf("message %d", i)), plaintext) + assert.Equal(t, libolmOutbound.MessageIndex()-1, msgIdx) + } +} diff --git a/crypto/olm/outboundgroupsession_test.go b/crypto/olm/outboundgroupsession_test.go new file mode 100644 index 00000000..46c63780 --- /dev/null +++ b/crypto/olm/outboundgroupsession_test.go @@ -0,0 +1,127 @@ +package olm_test + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/crypto/goolm/session" + "maunium.net/go/mautrix/crypto/libolm" +) + +func TestMegolmOutboundSessionPickle_RoundtripThroughGoolm(t *testing.T) { + libolmSession := libolm.NewOutboundGroupSession() + libolmPickled, err := libolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + goolmSession, err := session.MegolmOutboundSessionFromPickled(libolmPickled, []byte("test")) + require.NoError(t, err) + + goolmPickled, err := goolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + assert.Equal(t, libolmPickled, goolmPickled, "pickled versions are not the same") + + libolmSession2 := libolm.NewOutboundGroupSession() + err = libolmSession2.Unpickle(bytes.Clone(goolmPickled), []byte("test")) + require.NoError(t, err) + + assert.Equal(t, libolmSession.Key(), libolmSession2.Key()) +} + +func TestMegolmOutboundSessionPickle_RoundtripThroughLibolm(t *testing.T) { + goolmSession, err := session.NewMegolmOutboundSession() + require.NoError(t, err) + + goolmPickled, err := goolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + libolmSession := libolm.NewOutboundGroupSession() + err = libolmSession.Unpickle(bytes.Clone(goolmPickled), []byte("test")) + require.NoError(t, err) + + libolmPickled, err := libolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + assert.Equal(t, goolmPickled, libolmPickled, "pickled versions are not the same") + + goolmSession2, err := session.MegolmOutboundSessionFromPickled(libolmPickled, []byte("test")) + require.NoError(t, err) + + assert.Equal(t, goolmSession.Key(), goolmSession2.Key()) + assert.Equal(t, goolmSession.SigningKey.PrivateKey, goolmSession2.SigningKey.PrivateKey) +} + +func TestMegolmOutboundSessionPickleLibolm(t *testing.T) { + libolmSession := libolm.NewOutboundGroupSession() + libolmPickled, err := libolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + goolmSession, err := session.MegolmOutboundSessionFromPickled(bytes.Clone(libolmPickled), []byte("test")) + require.NoError(t, err) + goolmPickled, err := goolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + assert.Equal(t, libolmPickled, goolmPickled, "pickled versions are not the same") + assert.Equal(t, goolmSession.SigningKey.PrivateKey.PubKey(), goolmSession.SigningKey.PublicKey) + + // Ensure that the key export is the same and that the pickle is the same + assert.Equal(t, libolmSession.Key(), goolmSession.Key(), "keys are not the same") +} + +func TestMegolmOutboundSessionPickleGoolm(t *testing.T) { + goolmSession, err := session.NewMegolmOutboundSession() + require.NoError(t, err) + goolmPickled, err := goolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + libolmSession := libolm.NewOutboundGroupSession() + err = libolmSession.Unpickle(bytes.Clone(goolmPickled), []byte("test")) + require.NoError(t, err) + libolmPickled, err := libolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + assert.Equal(t, libolmPickled, goolmPickled, "pickled versions are not the same") + assert.Equal(t, goolmSession.SigningKey.PrivateKey.PubKey(), goolmSession.SigningKey.PublicKey) + + // Ensure that the key export is the same and that the pickle is the same + assert.Equal(t, libolmSession.Key(), goolmSession.Key(), "keys are not the same") +} + +func FuzzMegolmOutboundSession_Encrypt(f *testing.F) { + f.Add([]byte("anything")) + + f.Fuzz(func(t *testing.T, plaintext []byte) { + if len(plaintext) == 0 { + t.Skip("empty plaintext is not supported") + } + + libolmSession := libolm.NewOutboundGroupSession() + libolmPickled, err := libolmSession.Pickle([]byte("test")) + require.NoError(t, err) + + goolmSession, err := session.MegolmOutboundSessionFromPickled(bytes.Clone(libolmPickled), []byte("test")) + require.NoError(t, err) + + assert.Equal(t, libolmSession.Key(), goolmSession.Key()) + + // Encrypt the plaintext ten times because the ratchet increments. + for i := 0; i < 10; i++ { + assert.EqualValues(t, i, libolmSession.MessageIndex()) + assert.EqualValues(t, i, goolmSession.MessageIndex()) + + libolmEncrypted, err := libolmSession.Encrypt(plaintext) + require.NoError(t, err) + + goolmEncrypted, err := goolmSession.Encrypt(plaintext) + require.NoError(t, err) + + assert.Equal(t, libolmEncrypted, goolmEncrypted) + + assert.EqualValues(t, i+1, libolmSession.MessageIndex()) + assert.EqualValues(t, i+1, goolmSession.MessageIndex()) + } + }) +} diff --git a/crypto/olm/session_test.go b/crypto/olm/session_test.go index ff9445d9..9f0986eb 100644 --- a/crypto/olm/session_test.go +++ b/crypto/olm/session_test.go @@ -7,12 +7,20 @@ package olm_test import ( + "bytes" + "fmt" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mau.fi/util/exerrors" + "golang.org/x/exp/maps" + "maunium.net/go/mautrix/crypto/goolm/account" "maunium.net/go/mautrix/crypto/goolm/session" "maunium.net/go/mautrix/crypto/libolm" + "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/id" ) func TestBlankSession(t *testing.T) { @@ -35,22 +43,77 @@ func TestSessionPickle(t *testing.T) { pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItVKR4ro0O9EAk6LLxJtSnRu5elSUk7YXT") pickleKey := []byte("secret_key") - goolmSession := session.NewOlmSession() - err := goolmSession.Unpickle(pickledDataFromLibOlm, pickleKey) + goolmSession, err := session.OlmSessionFromPickled(bytes.Clone(pickledDataFromLibOlm), pickleKey) assert.NoError(t, err) - libolmSession := libolm.NewBlankSession() - err = libolmSession.Unpickle(pickledDataFromLibOlm, pickleKey) + libolmSession, err := libolm.SessionFromPickled(bytes.Clone(pickledDataFromLibOlm), pickleKey) assert.NoError(t, err) - // Reset the pickle data since libolmSession.Unpickle modifies it. - pickledDataFromLibOlm = []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItVKR4ro0O9EAk6LLxJtSnRu5elSUk7YXT") - goolmPickled, err := goolmSession.Pickle(pickleKey) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, pickledDataFromLibOlm, goolmPickled) libolmPickled, err := libolmSession.Pickle(pickleKey) + require.NoError(t, err) assert.Equal(t, pickledDataFromLibOlm, libolmPickled) - assert.NoError(t, err) +} + +func TestSession_EncryptDecrypt(t *testing.T) { + combos := [][2]olm.Account{ + {exerrors.Must(libolm.NewAccount(nil)), exerrors.Must(libolm.NewAccount(nil))}, + {exerrors.Must(account.NewAccount(nil)), exerrors.Must(account.NewAccount(nil))}, + {exerrors.Must(libolm.NewAccount(nil)), exerrors.Must(account.NewAccount(nil))}, + {exerrors.Must(account.NewAccount(nil)), exerrors.Must(libolm.NewAccount(nil))}, + } + + for _, combo := range combos { + receiver, sender := combo[0], combo[1] + require.NoError(t, receiver.GenOneTimeKeys(nil, 50)) + require.NoError(t, sender.GenOneTimeKeys(nil, 50)) + + _, receiverCurve25519, err := receiver.IdentityKeys() + require.NoError(t, err) + accountAOTKs, err := receiver.OneTimeKeys() + require.NoError(t, err) + + senderSession, err := sender.NewOutboundSession(receiverCurve25519, accountAOTKs[maps.Keys(accountAOTKs)[0]]) + require.NoError(t, err) + + // Send a couple pre-key messages from sender -> receiver. + var receiverSession olm.Session + for i := 0; i < 10; i++ { + msgType, ciphertext, err := senderSession.Encrypt([]byte(fmt.Sprintf("prekey %d", i))) + require.NoError(t, err) + assert.Equal(t, id.OlmMsgTypePreKey, msgType) + + receiverSession, err = receiver.NewInboundSession(string(ciphertext)) + require.NoError(t, err) + + decrypted, err := receiverSession.Decrypt(string(ciphertext), msgType) + require.NoError(t, err) + assert.Equal(t, []byte(fmt.Sprintf("prekey %d", i)), decrypted) + } + + // Send some messages from receiver -> sender. + for i := 0; i < 10; i++ { + msgType, ciphertext, err := receiverSession.Encrypt([]byte(fmt.Sprintf("response %d", i))) + require.NoError(t, err) + assert.Equal(t, id.OlmMsgTypeMsg, msgType) + + decrypted, err := senderSession.Decrypt(string(ciphertext), msgType) + require.NoError(t, err) + assert.Equal(t, []byte(fmt.Sprintf("response %d", i)), decrypted) + } + + // Send some more messages from sender -> receiver + for i := 0; i < 10; i++ { + msgType, ciphertext, err := senderSession.Encrypt([]byte(fmt.Sprintf("%d", i))) + require.NoError(t, err) + assert.Equal(t, id.OlmMsgTypeMsg, msgType) + + decrypted, err := receiverSession.Decrypt(string(ciphertext), msgType) + require.NoError(t, err) + assert.Equal(t, []byte(fmt.Sprintf("%d", i)), decrypted) + } + } } From dae53d42c7cc244ab7c87e0b2bf596ba0425672e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 25 Aug 2024 00:54:08 +0300 Subject: [PATCH 0676/1647] bridgev2: add standard error for unsupported media types --- bridgev2/errors.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index 2834b298..e80377cf 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -53,6 +53,7 @@ var ( ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) From 649a637350a03dae96714d03deb16d002b1148bc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 25 Aug 2024 02:17:04 +0300 Subject: [PATCH 0677/1647] bridgev2/simplevent: add transaction ID field to message event --- bridgev2/simplevent/message.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/bridgev2/simplevent/message.go b/bridgev2/simplevent/message.go index 928bffc9..55d25bd8 100644 --- a/bridgev2/simplevent/message.go +++ b/bridgev2/simplevent/message.go @@ -20,6 +20,7 @@ type Message[T any] struct { Data T ID networkid.MessageID + TransactionID networkid.TransactionID TargetMessage networkid.MessageID ConvertMessageFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, data T) (*bridgev2.ConvertedMessage, error) @@ -28,9 +29,10 @@ type Message[T any] struct { } var ( - _ bridgev2.RemoteMessage = (*Message[any])(nil) - _ bridgev2.RemoteEdit = (*Message[any])(nil) - _ bridgev2.RemoteMessageUpsert = (*Message[any])(nil) + _ bridgev2.RemoteMessage = (*Message[any])(nil) + _ bridgev2.RemoteEdit = (*Message[any])(nil) + _ bridgev2.RemoteMessageUpsert = (*Message[any])(nil) + _ bridgev2.RemoteMessageWithTransactionID = (*Message[any])(nil) ) func (evt *Message[T]) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { @@ -53,6 +55,10 @@ func (evt *Message[T]) GetTargetMessage() networkid.MessageID { return evt.TargetMessage } +func (evt *Message[T]) GetTransactionID() networkid.TransactionID { + return evt.TransactionID +} + type MessageRemove struct { EventMeta From a7aa97679d72aac61aa0193513e6ee2c23325502 Mon Sep 17 00:00:00 2001 From: Brad Murray Date: Sun, 25 Aug 2024 13:45:54 -0400 Subject: [PATCH 0678/1647] filter: Add com.beeper.to_device sync filter field (#274) * Add com.beeper.to_device sync filter field * Update filter.go Co-authored-by: Tulir Asokan --------- Co-authored-by: Tulir Asokan --- filter.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/filter.go b/filter.go index fd6de7a0..2603bfb9 100644 --- a/filter.go +++ b/filter.go @@ -24,6 +24,8 @@ type Filter struct { EventFormat EventFormat `json:"event_format,omitempty"` Presence FilterPart `json:"presence,omitempty"` Room RoomFilter `json:"room,omitempty"` + + BeeperToDevice *FilterPart `json:"com.beeper.to_device,omitempty"` } // RoomFilter is used to define filtering rules for room events From b68fdc9057bff8b9ea1257ff7a50ba068c5aa879 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 25 Aug 2024 22:20:47 +0300 Subject: [PATCH 0679/1647] bridgev2: fix username templates --- bridgev2/bridgeconfig/appservice.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/bridgeconfig/appservice.go b/bridgev2/bridgeconfig/appservice.go index 9ff333e9..5e482499 100644 --- a/bridgev2/bridgeconfig/appservice.go +++ b/bridgev2/bridgeconfig/appservice.go @@ -8,9 +8,9 @@ package bridgeconfig import ( "fmt" - "html/template" "regexp" "strings" + "text/template" "go.mau.fi/util/exerrors" "go.mau.fi/util/random" From 7e7cb57ee770a54e041b31d747ef4ef472ebd10a Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 26 Aug 2024 08:54:15 -0600 Subject: [PATCH 0680/1647] bridgev2/commands: fix NPE on search Signed-off-by: Sumner Evans --- bridgev2/commands/startchat.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index 24c8a488..c18e977a 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -181,7 +181,7 @@ func fnSearch(ce *Event) { ce.Log.Err(err).Object("portal_key", res.Chat.PortalKey).Msg("Failed to get DM portal") } } - if res.Chat.Portal != nil { + if res.Chat.Portal != nil && res.Chat.Portal.MXID != "" { portalName := res.Chat.Portal.Name if portalName == "" { portalName = res.Chat.Portal.MXID.String() From 720648ffdf834c503771ee95cd2ec13376e28f03 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 26 Aug 2024 17:58:27 +0300 Subject: [PATCH 0681/1647] id: don't panic if URI methods are called with empty/nil values --- id/matrixuri.go | 19 ++++++++++++++----- id/opaque.go | 15 +++++++++++++++ id/userid.go | 3 +++ 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/id/matrixuri.go b/id/matrixuri.go index 5ec403e9..acd8e0c0 100644 --- a/id/matrixuri.go +++ b/id/matrixuri.go @@ -65,6 +65,9 @@ func (uri *MatrixURI) getQuery() url.Values { // String converts the parsed matrix: URI back into the string representation. func (uri *MatrixURI) String() string { + if uri == nil { + return "" + } parts := []string{ SigilToPathSegment[uri.Sigil1], url.PathEscape(uri.MXID1), @@ -81,6 +84,9 @@ func (uri *MatrixURI) String() string { // MatrixToURL converts to parsed matrix: URI into a matrix.to URL func (uri *MatrixURI) MatrixToURL() string { + if uri == nil { + return "" + } fragment := fmt.Sprintf("#/%s", url.PathEscape(uri.PrimaryIdentifier())) if uri.Sigil2 != 0 { fragment = fmt.Sprintf("%s/%s", fragment, url.PathEscape(uri.SecondaryIdentifier())) @@ -96,13 +102,16 @@ func (uri *MatrixURI) MatrixToURL() string { // PrimaryIdentifier returns the first Matrix identifier in the URI. // Currently room IDs, room aliases and user IDs can be in the primary identifier slot. func (uri *MatrixURI) PrimaryIdentifier() string { + if uri == nil { + return "" + } return fmt.Sprintf("%c%s", uri.Sigil1, uri.MXID1) } // SecondaryIdentifier returns the second Matrix identifier in the URI. // Currently only event IDs can be in the secondary identifier slot. func (uri *MatrixURI) SecondaryIdentifier() string { - if uri.Sigil2 == 0 { + if uri == nil || uri.Sigil2 == 0 { return "" } return fmt.Sprintf("%c%s", uri.Sigil2, uri.MXID2) @@ -110,7 +119,7 @@ func (uri *MatrixURI) SecondaryIdentifier() string { // UserID returns the user ID from the URI if the primary identifier is a user ID. func (uri *MatrixURI) UserID() UserID { - if uri.Sigil1 == '@' { + if uri != nil && uri.Sigil1 == '@' { return UserID(uri.PrimaryIdentifier()) } return "" @@ -118,7 +127,7 @@ func (uri *MatrixURI) UserID() UserID { // RoomID returns the room ID from the URI if the primary identifier is a room ID. func (uri *MatrixURI) RoomID() RoomID { - if uri.Sigil1 == '!' { + if uri != nil && uri.Sigil1 == '!' { return RoomID(uri.PrimaryIdentifier()) } return "" @@ -126,7 +135,7 @@ func (uri *MatrixURI) RoomID() RoomID { // RoomAlias returns the room alias from the URI if the primary identifier is a room alias. func (uri *MatrixURI) RoomAlias() RoomAlias { - if uri.Sigil1 == '#' { + if uri != nil && uri.Sigil1 == '#' { return RoomAlias(uri.PrimaryIdentifier()) } return "" @@ -134,7 +143,7 @@ func (uri *MatrixURI) RoomAlias() RoomAlias { // EventID returns the event ID from the URI if the primary identifier is a room ID or alias and the secondary identifier is an event ID. func (uri *MatrixURI) EventID() EventID { - if (uri.Sigil1 == '!' || uri.Sigil1 == '#') && uri.Sigil2 == '$' { + if uri != nil && (uri.Sigil1 == '!' || uri.Sigil1 == '#') && uri.Sigil2 == '$' { return EventID(uri.SecondaryIdentifier()) } return "" diff --git a/id/opaque.go b/id/opaque.go index 16863b95..1d9f0dcf 100644 --- a/id/opaque.go +++ b/id/opaque.go @@ -37,6 +37,9 @@ func (roomID RoomID) String() string { } func (roomID RoomID) URI(via ...string) *MatrixURI { + if roomID == "" { + return nil + } return &MatrixURI{ Sigil1: '!', MXID1: string(roomID)[1:], @@ -45,6 +48,11 @@ func (roomID RoomID) URI(via ...string) *MatrixURI { } func (roomID RoomID) EventURI(eventID EventID, via ...string) *MatrixURI { + if roomID == "" { + return nil + } else if eventID == "" { + return roomID.URI(via...) + } return &MatrixURI{ Sigil1: '!', MXID1: string(roomID)[1:], @@ -59,13 +67,20 @@ func (roomAlias RoomAlias) String() string { } func (roomAlias RoomAlias) URI() *MatrixURI { + if roomAlias == "" { + return nil + } return &MatrixURI{ Sigil1: '#', MXID1: string(roomAlias)[1:], } } +// Deprecated: room alias event links should not be used. Use room IDs instead. func (roomAlias RoomAlias) EventURI(eventID EventID) *MatrixURI { + if roomAlias == "" { + return nil + } return &MatrixURI{ Sigil1: '#', MXID1: string(roomAlias)[1:], diff --git a/id/userid.go b/id/userid.go index 53b68b96..1e1f3b29 100644 --- a/id/userid.go +++ b/id/userid.go @@ -81,6 +81,9 @@ func (userID UserID) Homeserver() string { // // This does not parse or validate the user ID. Use the ParseAndValidate method if you want to ensure the user ID is valid first. func (userID UserID) URI() *MatrixURI { + if userID == "" { + return nil + } return &MatrixURI{ Sigil1: '@', MXID1: string(userID)[1:], From dfc92ee926ca44bce1184a2fa035ca964ae9ba74 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 26 Aug 2024 21:57:07 +0300 Subject: [PATCH 0682/1647] bridgev2/legacymigrate: add support for running another upgrader --- bridgev2/matrix/mxmain/legacymigrate.go | 31 +++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index 32556de1..d908483e 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -14,6 +14,7 @@ import ( "fmt" "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/bridgev2" @@ -23,7 +24,7 @@ import ( "maunium.net/go/mautrix/id" ) -func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery string, newDBVersion int) func(ctx context.Context) error { +func (br *BridgeMain) LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDataQuery string, newDBVersion int, otherTable dbutil.UpgradeTable, otherTableName string, otherNewVersion int) func(ctx context.Context) error { return func(ctx context.Context) error { _, err := br.DB.Exec(ctx, renameTablesQuery) if err != nil { @@ -36,6 +37,22 @@ func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery strin if upgradesTo < newDBVersion || compat > newDBVersion { return fmt.Errorf("unexpected new database version (%d/c:%d, expected %d)", upgradesTo, compat, newDBVersion) } + if otherTable != nil { + _, err = br.DB.Exec(ctx, fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", otherTableName)) + if err != nil { + return err + } + upgradesTo, compat, err = otherTable[0].DangerouslyRun(ctx, br.DB) + if err != nil { + return err + } else if upgradesTo < otherNewVersion || compat > otherNewVersion { + return fmt.Errorf("unexpected new database version for %s (%d/c:%d, expected %d)", otherTableName, upgradesTo, compat, newDBVersion) + } + _, err = br.DB.Exec(ctx, fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", otherTableName), upgradesTo, compat) + if err != nil { + return err + } + } copyDataQuery, err = br.DB.Internals().FilterSQLUpgrade(bytes.Split([]byte(copyDataQuery), []byte("\n"))) if err != nil { return err @@ -61,7 +78,17 @@ func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery strin } } -func (br *BridgeMain) CheckLegacyDB(expectedVersion int, minBridgeVersion, firstMegaVersion string, migrator func(context.Context) error, transaction bool) { +func (br *BridgeMain) LegacyMigrateSimple(renameTablesQuery, copyDataQuery string, newDBVersion int) func(ctx context.Context) error { + return br.LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDataQuery, newDBVersion, nil, "", 0) +} + +func (br *BridgeMain) CheckLegacyDB( + expectedVersion int, + minBridgeVersion, + firstMegaVersion string, + migrator func(context.Context) error, + transaction bool, +) { log := br.Log.With().Str("action", "migrate legacy db").Logger() ctx := log.WithContext(context.Background()) exists, err := br.DB.TableExists(ctx, "database_owner") From ad32a3e60c95093836189d4edbe276870ced49f5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 27 Aug 2024 14:29:03 +0300 Subject: [PATCH 0683/1647] bridgev2/legacymigrate: upgrade version table when migrating --- bridgev2/matrix/mxmain/legacymigrate.go | 6 +++++- go.mod | 2 +- go.sum | 4 ++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index d908483e..441b59bd 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -26,7 +26,11 @@ import ( func (br *BridgeMain) LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDataQuery string, newDBVersion int, otherTable dbutil.UpgradeTable, otherTableName string, otherNewVersion int) func(ctx context.Context) error { return func(ctx context.Context) error { - _, err := br.DB.Exec(ctx, renameTablesQuery) + err := dbutil.DangerousInternalUpgradeVersionTable(ctx, br.DB) + if err != nil { + return err + } + _, err = br.DB.Exec(ctx, renameTablesQuery) if err != nil { return err } diff --git a/go.mod b/go.mod index 835c7ef8..47be73d1 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/tidwall/gjson v1.17.3 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.7.0 + go.mau.fi/util v0.7.1-0.20240827112829-84c63841c264 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.26.0 golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa diff --git a/go.sum b/go.sum index d357ac92..7e819192 100644 --- a/go.sum +++ b/go.sum @@ -50,8 +50,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.7.0 h1:l31z+ivrSQw+cv/9eFebEqtQW2zhxivGypn+JT0h/ws= -go.mau.fi/util v0.7.0/go.mod h1:bWYreIoTULL/UiRbZdfddPh7uWDFW5yX4YCv5FB0eE0= +go.mau.fi/util v0.7.1-0.20240827112829-84c63841c264 h1:mWujT3q8pxyJc/3BvWTgTN4+k41d1pBCvwxH56prqQA= +go.mau.fi/util v0.7.1-0.20240827112829-84c63841c264/go.mod h1:WuAOOV0O/otkxGkFUvfv/XE2ztegaoyM15ovS6SYbf4= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= From ae306b3efaad4dcfad41b561f199e5c4f836023a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 27 Aug 2024 14:46:26 +0300 Subject: [PATCH 0684/1647] bridgev2/legacymigrate: add support for python configs --- bridgev2/bridgeconfig/legacymigrate.go | 172 +++++++++++++++++++++++++ bridgev2/bridgeconfig/upgrade.go | 121 +---------------- 2 files changed, 176 insertions(+), 117 deletions(-) create mode 100644 bridgev2/bridgeconfig/legacymigrate.go diff --git a/bridgev2/bridgeconfig/legacymigrate.go b/bridgev2/bridgeconfig/legacymigrate.go new file mode 100644 index 00000000..9ec8ea12 --- /dev/null +++ b/bridgev2/bridgeconfig/legacymigrate.go @@ -0,0 +1,172 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package bridgeconfig + +import ( + "fmt" + "net/url" + "os" + "strings" + + up "go.mau.fi/util/configupgrade" +) + +var HackyMigrateLegacyNetworkConfig func(up.Helper) + +func CopyToOtherLocation(helper up.Helper, fieldType up.YAMLType, source, dest []string) { + val, ok := helper.Get(fieldType, source...) + if ok { + helper.Set(fieldType, val, dest...) + } +} + +func CopyMapToOtherLocation(helper up.Helper, source, dest []string) { + val := helper.GetNode(source...) + if val != nil && val.Map != nil { + helper.SetMap(val.Map, dest...) + } +} + +func doMigrateLegacy(helper up.Helper, python bool) { + if HackyMigrateLegacyNetworkConfig == nil { + _, _ = fmt.Fprintln(os.Stderr, "Legacy bridge config detected, but hacky network config migrator is not set") + os.Exit(1) + } + _, _ = fmt.Fprintln(os.Stderr, "Migrating legacy bridge config") + + helper.Copy(up.Str, "homeserver", "address") + helper.Copy(up.Str, "homeserver", "domain") + helper.Copy(up.Str, "homeserver", "software") + helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint") + helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint") + helper.Copy(up.Bool, "homeserver", "async_media") + helper.Copy(up.Str|up.Null, "homeserver", "websocket_proxy") + helper.Copy(up.Bool, "homeserver", "websocket") + helper.Copy(up.Int, "homeserver", "ping_interval_seconds") + + helper.Copy(up.Str|up.Null, "appservice", "address") + helper.Copy(up.Str|up.Null, "appservice", "hostname") + helper.Copy(up.Int|up.Null, "appservice", "port") + helper.Copy(up.Str, "appservice", "id") + if python { + CopyToOtherLocation(helper, up.Str, []string{"appservice", "bot_username"}, []string{"appservice", "bot", "username"}) + CopyToOtherLocation(helper, up.Str, []string{"appservice", "bot_displayname"}, []string{"appservice", "bot", "displayname"}) + CopyToOtherLocation(helper, up.Str, []string{"appservice", "bot_avatar"}, []string{"appservice", "bot", "avatar"}) + } else { + helper.Copy(up.Str, "appservice", "bot", "username") + helper.Copy(up.Str, "appservice", "bot", "displayname") + helper.Copy(up.Str, "appservice", "bot", "avatar") + } + helper.Copy(up.Bool, "appservice", "ephemeral_events") + helper.Copy(up.Bool, "appservice", "async_transactions") + helper.Copy(up.Str, "appservice", "as_token") + helper.Copy(up.Str, "appservice", "hs_token") + + helper.Copy(up.Str, "bridge", "command_prefix") + helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") + if oldPM, ok := helper.Get(up.Str, "bridge", "private_chat_portal_meta"); ok && (oldPM == "default" || oldPM == "always") { + helper.Set(up.Bool, "true", "bridge", "private_chat_portal_meta") + } else { + helper.Set(up.Bool, "false", "bridge", "private_chat_portal_meta") + } + helper.Copy(up.Bool, "bridge", "relay", "enabled") + helper.Copy(up.Bool, "bridge", "relay", "admin_only") + helper.Copy(up.Map, "bridge", "permissions") + + if python { + legacyDB, ok := helper.Get(up.Str, "appservice", "database") + if ok { + if strings.HasPrefix(legacyDB, "postgres") { + parsedDB, err := url.Parse(legacyDB) + if err != nil { + panic(err) + } + q := parsedDB.Query() + if parsedDB.Host == "" && !q.Has("host") { + q.Set("host", "/var/run/postgresql") + } else if !q.Has("sslmode") { + q.Set("sslmode", "disable") + } + parsedDB.RawQuery = q.Encode() + helper.Set(up.Str, parsedDB.String(), "database", "uri") + helper.Set(up.Str, "postgres", "database", "type") + } else { + dbPath := strings.TrimPrefix(strings.TrimPrefix(legacyDB, "sqlite:"), "///") + helper.Set(up.Str, dbPath, "database", "uri") + helper.Set(up.Str, "sqlite3-fk-wal", "database", "type") + } + } + if legacyDBMinSize, ok := helper.Get(up.Int, "appservice", "database_opts", "min_size"); ok { + helper.Set(up.Int, legacyDBMinSize, "database", "max_idle_conns") + } + if legacyDBMaxSize, ok := helper.Get(up.Int, "appservice", "database_opts", "max_size"); ok { + helper.Set(up.Int, legacyDBMaxSize, "database", "max_open_conns") + } + } else { + CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "type"}, []string{"database", "type"}) + CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "uri"}, []string{"database", "uri"}) + CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_open_conns"}, []string{"database", "max_open_conns"}) + CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_idle_conns"}, []string{"database", "max_idle_conns"}) + CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_conn_idle_time"}, []string{"database", "max_conn_idle_time"}) + CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_conn_lifetime"}, []string{"database", "max_conn_lifetime"}) + } + + if python { + if usernameTemplate, ok := helper.Get(up.Str, "bridge", "username_template"); ok && strings.Contains(usernameTemplate, "{userid}") { + helper.Set(up.Str, strings.ReplaceAll(usernameTemplate, "{userid}", "{{.}}"), "appservice", "username_template") + } + } else { + CopyToOtherLocation(helper, up.Str, []string{"bridge", "username_template"}, []string{"appservice", "username_template"}) + } + + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "message_status_events"}, []string{"matrix", "message_status_events"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "delivery_receipts"}, []string{"matrix", "delivery_receipts"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "message_error_notices"}, []string{"matrix", "message_error_notices"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "sync_direct_chat_list"}, []string{"matrix", "sync_direct_chat_list"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "federate_rooms"}, []string{"matrix", "federate_rooms"}) + + CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "prefix"}, []string{"provisioning", "prefix"}) + CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) + CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "prefix"}, []string{"provisioning", "prefix"}) + CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "provisioning", "debug_endpoints"}, []string{"provisioning", "debug_endpoints"}) + + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "double_puppet_allow_discovery"}, []string{"double_puppet", "allow_discovery"}) + CopyMapToOtherLocation(helper, []string{"bridge", "double_puppet_server_map"}, []string{"double_puppet", "servers"}) + CopyMapToOtherLocation(helper, []string{"bridge", "login_shared_secret_map"}, []string{"double_puppet", "secrets"}) + + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "allow"}, []string{"encryption", "allow"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "default"}, []string{"encryption", "default"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "require"}, []string{"encryption", "require"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "appservice"}, []string{"encryption", "appservice"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "allow_key_sharing"}, []string{"encryption", "allow_key_sharing"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_outbound_on_ack"}, []string{"encryption", "delete_keys", "delete_outbound_on_ack"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "dont_store_outbound"}, []string{"encryption", "delete_keys", "dont_store_outbound"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "ratchet_on_decrypt"}, []string{"encryption", "delete_keys", "ratchet_on_decrypt"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_fully_used_on_decrypt"}, []string{"encryption", "delete_keys", "delete_fully_used_on_decrypt"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_prev_on_new_session"}, []string{"encryption", "delete_keys", "delete_prev_on_new_session"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_on_device_delete"}, []string{"encryption", "delete_keys", "delete_on_device_delete"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "periodically_delete_expired"}, []string{"encryption", "delete_keys", "periodically_delete_expired"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_outdated_inbound"}, []string{"encryption", "delete_keys", "delete_outdated_inbound"}) + CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "receive"}, []string{"encryption", "verification_levels", "receive"}) + CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "send"}, []string{"encryption", "verification_levels", "send"}) + CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "share"}, []string{"encryption", "verification_levels", "share"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "rotation", "enable_custom"}, []string{"encryption", "rotation", "enable_custom"}) + CopyToOtherLocation(helper, up.Int, []string{"bridge", "encryption", "rotation", "milliseconds"}, []string{"encryption", "rotation", "milliseconds"}) + CopyToOtherLocation(helper, up.Int, []string{"bridge", "encryption", "rotation", "messages"}, []string{"encryption", "rotation", "messages"}) + CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "rotation", "disable_device_change_key_rotation"}, []string{"encryption", "rotation", "disable_device_change_key_rotation"}) + + if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "print_level") != nil || helper.GetNode("logging", "file_name_format") != nil) { + _, _ = fmt.Fprintln(os.Stderr, "Migrating maulogger configs is not supported") + } else if (helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "handlers") != nil)) || python { + _, _ = fmt.Fprintln(os.Stderr, "Migrating Python log configs is not supported") + } else { + helper.Copy(up.Map, "logging") + } + + HackyMigrateLegacyNetworkConfig(helper) +} diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 9597fa4f..7e524e84 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -7,9 +7,6 @@ package bridgeconfig import ( - "fmt" - "os" - up "go.mau.fi/util/configupgrade" "go.mau.fi/util/random" @@ -18,7 +15,10 @@ import ( func doUpgrade(helper up.Helper) { if _, isLegacyConfig := helper.Get(up.Str, "appservice", "database", "uri"); isLegacyConfig { - doMigrateLegacy(helper) + doMigrateLegacy(helper, false) + return + } else if _, isLegacyPython := helper.Get(up.Str, "appservice", "database"); isLegacyPython { + doMigrateLegacy(helper, true) return } @@ -160,119 +160,6 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Map, "logging") } -func CopyToOtherLocation(helper up.Helper, fieldType up.YAMLType, source, dest []string) { - val, ok := helper.Get(fieldType, source...) - if ok { - helper.Set(fieldType, val, dest...) - } -} - -func CopyMapToOtherLocation(helper up.Helper, source, dest []string) { - val := helper.GetNode(source...) - if val != nil && val.Map != nil { - helper.SetMap(val.Map, dest...) - } -} - -var HackyMigrateLegacyNetworkConfig func(up.Helper) - -func doMigrateLegacy(helper up.Helper) { - if HackyMigrateLegacyNetworkConfig == nil { - _, _ = fmt.Fprintln(os.Stderr, "Legacy bridge config detected, but hacky network config migrator is not set") - os.Exit(1) - } - _, _ = fmt.Fprintln(os.Stderr, "Migrating legacy bridge config") - - helper.Copy(up.Str, "homeserver", "address") - helper.Copy(up.Str, "homeserver", "domain") - helper.Copy(up.Str, "homeserver", "software") - helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint") - helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint") - helper.Copy(up.Bool, "homeserver", "async_media") - helper.Copy(up.Str|up.Null, "homeserver", "websocket_proxy") - helper.Copy(up.Bool, "homeserver", "websocket") - helper.Copy(up.Int, "homeserver", "ping_interval_seconds") - - helper.Copy(up.Str|up.Null, "appservice", "address") - helper.Copy(up.Str|up.Null, "appservice", "hostname") - helper.Copy(up.Int|up.Null, "appservice", "port") - helper.Copy(up.Str, "appservice", "id") - helper.Copy(up.Str, "appservice", "bot", "username") - helper.Copy(up.Str, "appservice", "bot", "displayname") - helper.Copy(up.Str, "appservice", "bot", "avatar") - helper.Copy(up.Bool, "appservice", "ephemeral_events") - helper.Copy(up.Bool, "appservice", "async_transactions") - helper.Copy(up.Str, "appservice", "as_token") - helper.Copy(up.Str, "appservice", "hs_token") - - helper.Copy(up.Str, "bridge", "command_prefix") - helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") - if oldPM, ok := helper.Get(up.Str, "bridge", "private_chat_portal_meta"); ok && (oldPM == "default" || oldPM == "always") { - helper.Set(up.Bool, "true", "bridge", "private_chat_portal_meta") - } else { - helper.Set(up.Bool, "false", "bridge", "private_chat_portal_meta") - } - helper.Copy(up.Bool, "bridge", "relay", "enabled") - helper.Copy(up.Bool, "bridge", "relay", "admin_only") - helper.Copy(up.Map, "bridge", "permissions") - - CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "type"}, []string{"database", "type"}) - CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "uri"}, []string{"database", "uri"}) - CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_open_conns"}, []string{"database", "max_open_conns"}) - CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_idle_conns"}, []string{"database", "max_idle_conns"}) - CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_conn_idle_time"}, []string{"database", "max_conn_idle_time"}) - CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_conn_lifetime"}, []string{"database", "max_conn_lifetime"}) - - CopyToOtherLocation(helper, up.Str, []string{"bridge", "username_template"}, []string{"appservice", "username_template"}) - - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "message_status_events"}, []string{"matrix", "message_status_events"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "delivery_receipts"}, []string{"matrix", "delivery_receipts"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "message_error_notices"}, []string{"matrix", "message_error_notices"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "sync_direct_chat_list"}, []string{"matrix", "sync_direct_chat_list"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "federate_rooms"}, []string{"matrix", "federate_rooms"}) - - CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "prefix"}, []string{"provisioning", "prefix"}) - CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) - CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "prefix"}, []string{"provisioning", "prefix"}) - CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "provisioning", "debug_endpoints"}, []string{"provisioning", "debug_endpoints"}) - - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "double_puppet_allow_discovery"}, []string{"double_puppet", "allow_discovery"}) - CopyMapToOtherLocation(helper, []string{"bridge", "double_puppet_server_map"}, []string{"double_puppet", "servers"}) - CopyMapToOtherLocation(helper, []string{"bridge", "login_shared_secret_map"}, []string{"double_puppet", "secrets"}) - - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "allow"}, []string{"encryption", "allow"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "default"}, []string{"encryption", "default"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "require"}, []string{"encryption", "require"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "appservice"}, []string{"encryption", "appservice"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "allow_key_sharing"}, []string{"encryption", "allow_key_sharing"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_outbound_on_ack"}, []string{"encryption", "delete_keys", "delete_outbound_on_ack"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "dont_store_outbound"}, []string{"encryption", "delete_keys", "dont_store_outbound"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "ratchet_on_decrypt"}, []string{"encryption", "delete_keys", "ratchet_on_decrypt"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_fully_used_on_decrypt"}, []string{"encryption", "delete_keys", "delete_fully_used_on_decrypt"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_prev_on_new_session"}, []string{"encryption", "delete_keys", "delete_prev_on_new_session"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_on_device_delete"}, []string{"encryption", "delete_keys", "delete_on_device_delete"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "periodically_delete_expired"}, []string{"encryption", "delete_keys", "periodically_delete_expired"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "delete_keys", "delete_outdated_inbound"}, []string{"encryption", "delete_keys", "delete_outdated_inbound"}) - CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "receive"}, []string{"encryption", "verification_levels", "receive"}) - CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "send"}, []string{"encryption", "verification_levels", "send"}) - CopyToOtherLocation(helper, up.Str, []string{"bridge", "encryption", "verification_levels", "share"}, []string{"encryption", "verification_levels", "share"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "rotation", "enable_custom"}, []string{"encryption", "rotation", "enable_custom"}) - CopyToOtherLocation(helper, up.Int, []string{"bridge", "encryption", "rotation", "milliseconds"}, []string{"encryption", "rotation", "milliseconds"}) - CopyToOtherLocation(helper, up.Int, []string{"bridge", "encryption", "rotation", "messages"}, []string{"encryption", "rotation", "messages"}) - CopyToOtherLocation(helper, up.Bool, []string{"bridge", "encryption", "rotation", "disable_device_change_key_rotation"}, []string{"encryption", "rotation", "disable_device_change_key_rotation"}) - - if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "print_level") != nil || helper.GetNode("logging", "file_name_format") != nil) { - _, _ = fmt.Fprintln(os.Stderr, "Migrating maulogger configs is not supported") - } else if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "handlers") != nil) { - _, _ = fmt.Fprintln(os.Stderr, "Migrating Python log configs is not supported") - } else { - helper.Copy(up.Map, "logging") - } - - HackyMigrateLegacyNetworkConfig(helper) -} - var SpacedBlocks = [][]string{ {"bridge"}, {"bridge", "relay"}, From 3eeca239f197d10123123b3fed33ec3dc17e58c2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 27 Aug 2024 14:51:37 +0300 Subject: [PATCH 0685/1647] bridgev2/legacymigrate: add txlock when migrating python sqlite db config --- bridgev2/bridgeconfig/legacymigrate.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/bridgeconfig/legacymigrate.go b/bridgev2/bridgeconfig/legacymigrate.go index 9ec8ea12..e8fab743 100644 --- a/bridgev2/bridgeconfig/legacymigrate.go +++ b/bridgev2/bridgeconfig/legacymigrate.go @@ -96,7 +96,7 @@ func doMigrateLegacy(helper up.Helper, python bool) { helper.Set(up.Str, "postgres", "database", "type") } else { dbPath := strings.TrimPrefix(strings.TrimPrefix(legacyDB, "sqlite:"), "///") - helper.Set(up.Str, dbPath, "database", "uri") + helper.Set(up.Str, fmt.Sprintf("file:%s?_txlock=immediate", dbPath), "database", "uri") helper.Set(up.Str, "sqlite3-fk-wal", "database", "type") } } From 7276f10fcf1e9a17be678112f381b16295901388 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 27 Aug 2024 14:54:26 +0300 Subject: [PATCH 0686/1647] bridgev2/legacymigrate: clear database owner on upgrade --- bridgev2/matrix/mxmain/legacymigrate.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index 441b59bd..f9456f75 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -65,7 +65,11 @@ func (br *BridgeMain) LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDa if err != nil { return err } - _, err = br.DB.Exec(ctx, "UPDATE database_owner SET owner = $1 WHERE key = 0", br.DB.Owner) + _, err = br.DB.Exec(ctx, "DELETE FROM database_owner") + if err != nil { + return err + } + _, err = br.DB.Exec(ctx, "INSERT INTO database_owner (key, owner) VALUES (0, $1)", br.DB.Owner) if err != nil { return err } @@ -103,7 +107,7 @@ func (br *BridgeMain) CheckLegacyDB( return } var owner string - err = br.DB.QueryRow(ctx, "SELECT owner FROM database_owner WHERE key=0").Scan(&owner) + err = br.DB.QueryRow(ctx, "SELECT owner FROM database_owner LIMIT 1").Scan(&owner) if err != nil && !errors.Is(err, sql.ErrNoRows) { log.Err(err).Msg("Failed to get database owner") return From 09e9cac5e8ed73375422dac7520e52d5d996aefb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 27 Aug 2024 16:13:51 +0300 Subject: [PATCH 0687/1647] bridgev2/commands: reply with login URL when doing cookie login --- bridgev2/commands/login.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index e813de70..86071a8e 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -269,6 +269,7 @@ type cookieLoginCommandState struct { } func (clcs *cookieLoginCommandState) prompt(ce *Event) { + ce.Reply("Login URL: <%s>", clcs.Data.URL) StoreCommandState(ce.User, &CommandState{ Next: MinimalCommandHandlerFunc(clcs.submit), Action: "Login", From e3eb2953ddda20dba34b18717684241faaa02dfd Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 27 Aug 2024 16:49:50 +0300 Subject: [PATCH 0688/1647] bridgev2/legacymigrate: fix version table values --- bridgev2/matrix/mxmain/legacymigrate.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index f9456f75..b2bdaa91 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -46,13 +46,13 @@ func (br *BridgeMain) LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDa if err != nil { return err } - upgradesTo, compat, err = otherTable[0].DangerouslyRun(ctx, br.DB) + otherUpgradesTo, otherCompat, err := otherTable[0].DangerouslyRun(ctx, br.DB) if err != nil { return err - } else if upgradesTo < otherNewVersion || compat > otherNewVersion { - return fmt.Errorf("unexpected new database version for %s (%d/c:%d, expected %d)", otherTableName, upgradesTo, compat, newDBVersion) + } else if otherUpgradesTo < otherNewVersion || otherCompat > otherNewVersion { + return fmt.Errorf("unexpected new database version for %s (%d/c:%d, expected %d)", otherTableName, otherUpgradesTo, otherCompat, newDBVersion) } - _, err = br.DB.Exec(ctx, fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", otherTableName), upgradesTo, compat) + _, err = br.DB.Exec(ctx, fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", otherTableName), otherUpgradesTo, otherCompat) if err != nil { return err } From f56905a27645bb38b2da452fc409c2a910d64dfc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 27 Aug 2024 22:10:23 +0300 Subject: [PATCH 0689/1647] bridgev2/portal: fix panic in FindPreferredLogin if receiver login doesn't exist --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index f9d5aa10..384a0357 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -311,7 +311,7 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowR if err != nil { return nil, nil, err } - if login.UserMXID != user.MXID { + if login == nil || login.UserMXID != user.MXID { if allowRelay && portal.Relay != nil { return nil, nil, nil } From 892e5cf01fc806c248229804210ea627516754b4 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 27 Aug 2024 16:11:41 -0600 Subject: [PATCH 0690/1647] bridgev2/provisioning: allow custom auth token retrieval Signed-off-by: Sumner Evans --- bridgev2/matrix/provisioning.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index b8a33fff..9a737a61 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -50,6 +50,10 @@ type ProvisioningAPI struct { matrixAuthCache map[string]matrixAuthCacheEntry matrixAuthCacheLock sync.Mutex + + // GetAuthFromRequest is a custom function for getting the auth token from + // the request if the Authorization header is not present. + GetAuthFromRequest func(r *http.Request) string } type ProvLogin struct { @@ -184,6 +188,9 @@ func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userI func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + if auth == "" && prov.GetAuthFromRequest != nil { + auth = prov.GetAuthFromRequest(r) + } if auth == "" { jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ Err: "Missing auth token", From eef37c3295441e97d84f2b60abd31d6adb2ddd74 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 28 Aug 2024 00:16:13 +0300 Subject: [PATCH 0691/1647] client: add option to log syncs at trace level --- client.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/client.go b/client.go index 294595f8..686ec21f 100644 --- a/client.go +++ b/client.go @@ -14,6 +14,7 @@ import ( "net/url" "os" "strconv" + "strings" "sync/atomic" "time" @@ -92,6 +93,7 @@ type Client struct { UpdateRequestOnRetry func(req *http.Request, cause error) *http.Request SyncPresence event.Presence + SyncTraceLog bool StreamSyncMinAge time.Duration @@ -321,6 +323,8 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er } else if handlerErr != nil { evt = zerolog.Ctx(req.Context()).Warn(). AnErr("body_parse_err", handlerErr) + } else if cli.SyncTraceLog && strings.HasSuffix(req.URL.Path, "/_matrix/client/v3/sync") { + evt = zerolog.Ctx(req.Context()).Trace() } else { evt = zerolog.Ctx(req.Context()).Debug() } From a224ed019dcb3e98e292891075119fb9ac6e024b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 28 Aug 2024 13:40:46 +0300 Subject: [PATCH 0692/1647] bridgev2: stop using GetCachedUserLogins --- bridgev2/bridge.go | 2 +- bridgev2/matrix/provisioning.go | 2 +- bridgev2/matrixinvite.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index aadefb0a..e64d7a40 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -151,7 +151,7 @@ func (br *Bridge) StartLogins() error { if err != nil { br.Log.Err(err).Stringer("user_id", userID).Msg("Failed to load user") } else { - for _, login := range user.GetCachedUserLogins() { + for _, login := range user.GetUserLogins() { startedAny = true br.Log.Info().Str("id", string(login.ID)).Msg("Starting user login") err = login.Client.Connect(login.Log.WithContext(ctx)) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 9a737a61..1cbb14b3 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -311,7 +311,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { CommandPrefix: prov.br.Config.Bridge.CommandPrefix, ManagementRoom: user.ManagementRoom, } - logins := user.GetCachedUserLogins() + logins := user.GetUserLogins() resp.Logins = make([]RespWhoamiLogin, len(logins)) for i, login := range logins { prevState := login.BridgeState.GetPrevUnsent() diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go index 740743f6..25938c4f 100644 --- a/bridgev2/matrixinvite.go +++ b/bridgev2/matrixinvite.go @@ -99,7 +99,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen Stringer("room_id", evt.RoomID). Logger() // TODO sort in preference order - logins := sender.GetCachedUserLogins() + logins := sender.GetUserLogins() if len(logins) == 0 { rejectInvite(ctx, evt, br.Matrix.GhostIntent(ghostID), "You're not logged in") return From 838237da73fcdc1223363d70955903e6fcb1232b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 28 Aug 2024 13:41:04 +0300 Subject: [PATCH 0693/1647] bridgev2/provisioning: add search users endpoint --- bridgev2/matrix/provisioning.go | 109 +++++++++++++++++++++--------- bridgev2/matrix/provisioning.yaml | 45 ++++++++++++ 2 files changed, 122 insertions(+), 32 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 1cbb14b3..c1f02890 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -113,6 +113,7 @@ func (prov *ProvisioningAPI) Init() { prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLogout) prov.Router.Path("/v3/logins").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLogins) prov.Router.Path("/v3/contacts").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetContactList) + prov.Router.Path("/v3/search_users").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostSearchUsers) prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetResolveIdentifier) prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateDM) prov.Router.Path("/v3/create_group").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateGroup) @@ -615,33 +616,13 @@ type RespGetContactList struct { Contacts []*RespResolveIdentifier `json:"contacts"` } -func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Request) { - login := prov.GetLoginForRequest(w, r) - if login == nil { - return - } - api, ok := login.Client.(bridgev2.ContactListingNetworkAPI) - if !ok { - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - Err: "This bridge does not support listing contacts", - ErrCode: mautrix.MUnrecognized.ErrCode, - }) - return - } - resp, err := api.GetContactList(r.Context()) - if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list") - respondMaybeCustomError(w, err, "Internal error fetching contact list") - return - } - apiResp := &RespGetContactList{ - Contacts: make([]*RespResolveIdentifier, len(resp)), - } +func (prov *ProvisioningAPI) processResolveIdentifiers(ctx context.Context, resp []*bridgev2.ResolveIdentifierResponse) (apiResp []*RespResolveIdentifier) { + apiResp = make([]*RespResolveIdentifier, len(resp)) for i, contact := range resp { apiContact := &RespResolveIdentifier{ ID: contact.UserID, } - apiResp.Contacts[i] = apiContact + apiResp[i] = apiContact if contact.UserInfo != nil { if contact.UserInfo.Name != nil { apiContact.Name = *contact.UserInfo.Name @@ -662,20 +643,84 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque } if contact.Chat != nil { if contact.Chat.Portal == nil { - contact.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(r.Context(), contact.Chat.PortalKey) + var err error + contact.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(ctx, contact.Chat.PortalKey) if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - Err: "Failed to get portal", - ErrCode: "M_UNKNOWN", - }) - return + zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") } } - apiContact.DMRoomID = contact.Chat.Portal.MXID + if contact.Chat.Portal != nil { + apiContact.DMRoomID = contact.Chat.Portal.MXID + } } } - jsonResponse(w, http.StatusOK, apiResp) + return +} + +func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Request) { + login := prov.GetLoginForRequest(w, r) + if login == nil { + return + } + api, ok := login.Client.(bridgev2.ContactListingNetworkAPI) + if !ok { + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + Err: "This bridge does not support listing contacts", + ErrCode: mautrix.MUnrecognized.ErrCode, + }) + return + } + resp, err := api.GetContactList(r.Context()) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list") + respondMaybeCustomError(w, err, "Internal error fetching contact list") + return + } + jsonResponse(w, http.StatusOK, &RespGetContactList{ + Contacts: prov.processResolveIdentifiers(r.Context(), resp), + }) +} + +type ReqSearchUsers struct { + Query string `json:"query"` +} + +type RespSearchUsers struct { + Results []*RespResolveIdentifier `json:"results"` +} + +func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Request) { + var req ReqSearchUsers + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + Err: "Failed to decode request body", + ErrCode: mautrix.MNotJSON.ErrCode, + }) + return + } + login := prov.GetLoginForRequest(w, r) + if login == nil { + return + } + api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI) + if !ok { + jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ + Err: "This bridge does not support searching for users", + ErrCode: mautrix.MUnrecognized.ErrCode, + }) + return + } + resp, err := api.SearchUsers(r.Context(), req.Query) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list") + respondMaybeCustomError(w, err, "Internal error fetching contact list") + return + } + jsonResponse(w, http.StatusOK, &RespSearchUsers{ + Results: prov.processResolveIdentifiers(r.Context(), resp), + }) } func (prov *ProvisioningAPI) GetResolveIdentifier(w http.ResponseWriter, r *http.Request) { diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml index 1daf7b07..d03101b0 100644 --- a/bridgev2/matrix/provisioning.yaml +++ b/bridgev2/matrix/provisioning.yaml @@ -240,6 +240,51 @@ paths: responses: 200: description: Contact list fetched successfully + content: + application/json: + schema: + type: object + properties: + contacts: + type: array + items: + $ref: '#/components/schemas/ResolvedIdentifier' + 401: + $ref: '#/components/responses/Unauthorized' + 404: + $ref: '#/components/responses/LoginNotFound' + 500: + $ref: '#/components/responses/InternalError' + 501: + $ref: '#/components/responses/NotSupported' + /v3/search_users: + post: + tags: [ snc ] + summary: Search for users on the remote network + operationId: searchUsers + parameters: + - $ref: "#/components/parameters/loginID" + requestBody: + content: + application/json: + schema: + type: object + properties: + query: + type: string + description: The search query to send to the remote network + responses: + 200: + description: Search completed successfully + content: + application/json: + schema: + type: object + properties: + results: + type: array + items: + $ref: '#/components/schemas/ResolvedIdentifier' 401: $ref: '#/components/responses/Unauthorized' 404: From fd89457be8f4df2287e51b81df76b81a81e9cb77 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 28 Aug 2024 16:18:27 +0300 Subject: [PATCH 0694/1647] bridgev2/backfill: add option for aggressive deduplication --- bridgev2/networkinterface.go | 42 ++++++++++++++++++++++++++++++++++++ bridgev2/portalbackfill.go | 42 ++++++++++++++++++++++++++++++------ 2 files changed, 77 insertions(+), 7 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index b68ad0c9..556b7407 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -424,6 +424,43 @@ type BackfillMessage struct { LastThreadMessage networkid.MessageID } +var ( + _ RemoteMessageWithTransactionID = (*BackfillMessage)(nil) + _ RemoteEventWithTimestamp = (*BackfillMessage)(nil) +) + +func (b *BackfillMessage) GetType() RemoteEventType { + return RemoteEventMessage +} + +func (b *BackfillMessage) GetPortalKey() networkid.PortalKey { + panic("GetPortalKey called for BackfillMessage") +} + +func (b *BackfillMessage) AddLogContext(c zerolog.Context) zerolog.Context { + return c +} + +func (b *BackfillMessage) GetSender() EventSender { + return b.Sender +} + +func (b *BackfillMessage) GetID() networkid.MessageID { + return b.ID +} + +func (b *BackfillMessage) GetTransactionID() networkid.TransactionID { + return b.TxnID +} + +func (b *BackfillMessage) GetTimestamp() time.Time { + return b.Timestamp +} + +func (b *BackfillMessage) ConvertMessage(ctx context.Context, portal *Portal, intent MatrixAPI) (*ConvertedMessage, error) { + return b.ConvertedMessage, nil +} + // FetchMessagesResponse contains the response for a message history pagination request. type FetchMessagesResponse struct { // The messages to backfill. Messages should always be sorted in chronological order (oldest to newest). @@ -440,6 +477,11 @@ type FetchMessagesResponse struct { // to mark the messages as read immediately after backfilling. MarkRead bool + // Should the bridge check each message against the database to ensure it's not a duplicate before bridging? + // By default, the bridge will only drop messages that are older than the last bridged message for forward backfills, + // or newer than the first for backward. + AggressiveDeduplication bool + // When HasMore is true, one of the following fields can be set to report backfill progress: // Approximate backfill progress as a number between 0 and 1. diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 1bff29c1..ffe68ca5 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -63,9 +63,8 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, log.Debug().Msg("No messages to backfill") return } - // TODO check pending messages // TODO mark backfill queue task as done if last message is nil (-> room was empty) and HasMore is false? - resp.Messages = cutoffMessages(&log, resp.Messages, true, lastMessage) + resp.Messages = portal.cutoffMessages(ctx, resp.Messages, resp.AggressiveDeduplication, true, lastMessage) if len(resp.Messages) == 0 { log.Warn().Msg("No messages left to backfill after cutting off old messages") return @@ -126,7 +125,7 @@ func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin } return nil } - resp.Messages = cutoffMessages(log, resp.Messages, false, firstMessage) + resp.Messages = portal.cutoffMessages(ctx, resp.Messages, resp.AggressiveDeduplication, false, firstMessage) if len(resp.Messages) == 0 { return fmt.Errorf("no messages left to backfill after cutting off too new messages") } @@ -156,7 +155,7 @@ func (portal *Portal) fetchThreadBackfill(ctx context.Context, source *UserLogin log.Debug().Msg("No messages to backfill") return nil } - resp.Messages = cutoffMessages(log, resp.Messages, true, anchor) + resp.Messages = portal.cutoffMessages(ctx, resp.Messages, resp.AggressiveDeduplication, true, anchor) if len(resp.Messages) == 0 { log.Warn().Msg("No messages left to backfill after cutting off old messages") return nil @@ -182,7 +181,7 @@ func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, t } } -func cutoffMessages(log *zerolog.Logger, messages []*BackfillMessage, forward bool, lastMessage *database.Message) []*BackfillMessage { +func (portal *Portal) cutoffMessages(ctx context.Context, messages []*BackfillMessage, aggressiveDedup, forward bool, lastMessage *database.Message) []*BackfillMessage { if lastMessage == nil { return messages } @@ -196,7 +195,7 @@ func cutoffMessages(log *zerolog.Logger, messages []*BackfillMessage, forward bo } } if cutoff != -1 { - log.Debug(). + zerolog.Ctx(ctx).Debug(). Int("cutoff_count", cutoff+1). Int("total_count", len(messages)). Time("last_bridged_ts", lastMessage.Timestamp). @@ -213,7 +212,7 @@ func cutoffMessages(log *zerolog.Logger, messages []*BackfillMessage, forward bo } } if cutoff != -1 { - log.Debug(). + zerolog.Ctx(ctx).Debug(). Int("cutoff_count", len(messages)-cutoff). Int("total_count", len(messages)). Time("oldest_bridged_ts", lastMessage.Timestamp). @@ -221,6 +220,35 @@ func cutoffMessages(log *zerolog.Logger, messages []*BackfillMessage, forward bo messages = messages[:cutoff] } } + if aggressiveDedup { + filteredMessages := messages[:0] + for _, msg := range messages { + existingMsg, err := portal.Bridge.DB.Message.GetFirstPartByID(ctx, portal.Receiver, msg.ID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Str("message_id", string(msg.ID)).Msg("Failed to check for existing message") + } else if existingMsg != nil { + zerolog.Ctx(ctx).Err(err). + Str("message_id", string(msg.ID)). + Time("message_ts", msg.Timestamp). + Str("message_sender", string(msg.Sender.Sender)). + Msg("Ignoring duplicate message in backfill") + continue + } + if forward && msg.TxnID != "" { + wasPending, _ := portal.checkPendingMessage(ctx, msg) + if wasPending { + zerolog.Ctx(ctx).Err(err). + Str("transaction_id", string(msg.TxnID)). + Str("message_id", string(msg.ID)). + Time("message_ts", msg.Timestamp). + Msg("Found pending message in backfill") + continue + } + } + filteredMessages = append(filteredMessages, msg) + } + messages = filteredMessages + } return messages } From 5f49ca683a3d492586146fb2806bb7d01ddc4747 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 28 Aug 2024 12:14:22 -0600 Subject: [PATCH 0695/1647] bridgestate: deduplicate on remote name Signed-off-by: Sumner Evans --- bridge/status/bridgestate.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridge/status/bridgestate.go b/bridge/status/bridgestate.go index 72e61415..1aa4bb1f 100644 --- a/bridge/status/bridgestate.go +++ b/bridge/status/bridgestate.go @@ -187,6 +187,7 @@ func (pong *BridgeState) SendHTTP(ctx context.Context, url, token string) error func (pong *BridgeState) ShouldDeduplicate(newPong *BridgeState) bool { return pong != nil && pong.StateEvent == newPong.StateEvent && + pong.RemoteName == newPong.RemoteName && ptr.Val(pong.RemoteProfile) == ptr.Val(newPong.RemoteProfile) && pong.Error == newPong.Error && maps.EqualFunc(pong.Info, newPong.Info, reflect.DeepEqual) && From 238cacf2d5261db9eda350bca951354d05b8f9c8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 28 Aug 2024 22:06:43 +0300 Subject: [PATCH 0696/1647] client,crypto,appservice: add MSC3202 features --- appservice/registration.go | 1 + client.go | 3 ++ crypto/cryptohelper/cryptohelper.go | 65 ++++++++++++++++++++++++----- crypto/machine.go | 5 ++- requests.go | 2 +- url.go | 4 ++ 6 files changed, 66 insertions(+), 14 deletions(-) diff --git a/appservice/registration.go b/appservice/registration.go index b11bd84b..c0b62124 100644 --- a/appservice/registration.go +++ b/appservice/registration.go @@ -28,6 +28,7 @@ type Registration struct { SoruEphemeralEvents bool `yaml:"de.sorunome.msc2409.push_ephemeral,omitempty" json:"de.sorunome.msc2409.push_ephemeral,omitempty"` EphemeralEvents bool `yaml:"push_ephemeral,omitempty" json:"push_ephemeral,omitempty"` + MSC3202 bool `yaml:"org.matrix.msc3202,omitempty" json:"org.matrix.msc3202,omitempty"` } // CreateRegistration creates a Registration with random appservice and homeserver tokens. diff --git a/client.go b/client.go index 686ec21f..e1b36f87 100644 --- a/client.go +++ b/client.go @@ -110,6 +110,9 @@ type Client struct { // Should the ?user_id= query parameter be set in requests? // See https://spec.matrix.org/v1.6/application-service-api/#identity-assertion SetAppServiceUserID bool + // Should the org.matrix.msc3202.device_id query parameter be set in requests? + // See https://github.com/matrix-org/matrix-spec-proposals/pull/3202 + SetAppServiceDeviceID bool syncingID uint32 // Identifies the current Sync. Only one Sync can be active at any given time. } diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index 7bb7037d..5f1a952f 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -38,6 +38,8 @@ type CryptoHelper struct { LoginAs *mautrix.ReqLogin + ASEventProcessor crypto.ASEventProcessor + DBAccountID string } @@ -58,7 +60,7 @@ func NewCryptoHelper(cli *mautrix.Client, pickleKey []byte, store any) (*CryptoH return nil, fmt.Errorf("pickle key must be provided") } _, isExtensible := cli.Syncer.(mautrix.ExtensibleSyncer) - if !isExtensible { + if !cli.SetAppServiceDeviceID && !isExtensible { return nil, fmt.Errorf("the client syncer must implement ExtensibleSyncer") } @@ -111,7 +113,11 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { } syncer, ok := helper.client.Syncer.(mautrix.ExtensibleSyncer) if !ok { - return fmt.Errorf("the client syncer must implement ExtensibleSyncer") + if !helper.client.SetAppServiceDeviceID { + return fmt.Errorf("the client syncer must implement ExtensibleSyncer") + } else if helper.ASEventProcessor == nil { + return fmt.Errorf("an appservice must be provided when using appservice mode encryption") + } } var stateStore crypto.StateStore @@ -140,7 +146,27 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to find existing device ID: %w", err) } - if helper.LoginAs != nil { + if helper.LoginAs != nil && helper.LoginAs.Type == mautrix.AuthTypeAppservice && helper.client.SetAppServiceDeviceID { + if storedDeviceID == "" { + helper.log.Debug(). + Str("username", helper.LoginAs.Identifier.User). + Msg("Logging in with appservice") + var resp *mautrix.RespLogin + resp, err = helper.client.Login(ctx, helper.LoginAs) + if err != nil { + return err + } + managedCryptoStore.DeviceID = resp.DeviceID + helper.client.DeviceID = resp.DeviceID + } else { + helper.log.Debug(). + Str("username", helper.LoginAs.Identifier.User). + Stringer("device_id", storedDeviceID). + Msg("Using existing device") + managedCryptoStore.DeviceID = storedDeviceID + helper.client.DeviceID = storedDeviceID + } + } else if helper.LoginAs != nil { if storedDeviceID != "" { helper.LoginAs.DeviceID = storedDeviceID } @@ -177,16 +203,29 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { return err } - syncer.OnSync(helper.mach.ProcessSyncResponse) - syncer.OnEventType(event.StateMember, helper.mach.HandleMemberEvent) - if _, ok = helper.client.Syncer.(mautrix.DispatchableSyncer); ok { - syncer.OnEventType(event.EventEncrypted, helper.HandleEncrypted) + if syncer != nil { + syncer.OnSync(helper.mach.ProcessSyncResponse) + syncer.OnEventType(event.StateMember, helper.mach.HandleMemberEvent) + if _, ok = helper.client.Syncer.(mautrix.DispatchableSyncer); ok { + syncer.OnEventType(event.EventEncrypted, helper.HandleEncrypted) + } else { + helper.log.Warn().Msg("Client syncer does not implement DispatchableSyncer. Events will not be decrypted automatically.") + } + if helper.managedStateStore != nil { + syncer.OnEvent(helper.client.StateStoreSyncHandler) + } } else { - helper.log.Warn().Msg("Client syncer does not implement DispatchableSyncer. Events will not be decrypted automatically.") + helper.mach.AddAppserviceListener(helper.ASEventProcessor) + helper.ASEventProcessor.On(event.EventEncrypted, helper.HandleEncrypted) } - if helper.managedStateStore != nil { - syncer.OnEvent(helper.client.StateStoreSyncHandler) + + if helper.client.SetAppServiceDeviceID { + err = helper.mach.ShareKeys(ctx, -1) + if err != nil { + return fmt.Errorf("failed to share keys: %w", err) + } } + return nil } @@ -281,7 +320,11 @@ func (helper *CryptoHelper) HandleEncrypted(ctx context.Context, evt *event.Even func (helper *CryptoHelper) postDecrypt(ctx context.Context, decrypted *event.Event) { decrypted.Mautrix.EventSource |= event.SourceDecrypted - helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(ctx, decrypted) + if helper.ASEventProcessor != nil { + helper.ASEventProcessor.Dispatch(ctx, decrypted) + } else { + helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(ctx, decrypted) + } } func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { diff --git a/crypto/machine.go b/crypto/machine.go index 2477b9e1..188aa210 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -208,13 +208,14 @@ func (mach *OlmMachine) OwnIdentity() *id.Device { } } -type asEventProcessor interface { +type ASEventProcessor interface { On(evtType event.Type, handler func(ctx context.Context, evt *event.Event)) OnOTK(func(ctx context.Context, otk *mautrix.OTKCount)) OnDeviceList(func(ctx context.Context, lists *mautrix.DeviceLists, since string)) + Dispatch(ctx context.Context, evt *event.Event) } -func (mach *OlmMachine) AddAppserviceListener(ep asEventProcessor) { +func (mach *OlmMachine) AddAppserviceListener(ep ASEventProcessor) { // ToDeviceForwardedRoomKey and ToDeviceRoomKey should only be present inside encrypted to-device events ep.On(event.ToDeviceEncrypted, mach.HandleToDeviceEvent) ep.On(event.ToDeviceRoomKeyRequest, mach.HandleToDeviceEvent) diff --git a/requests.go b/requests.go index f91aaa79..c49f7c9c 100644 --- a/requests.go +++ b/requests.go @@ -225,7 +225,7 @@ func (otk *OneTimeKey) MarshalJSON() ([]byte, error) { type ReqUploadKeys struct { DeviceKeys *DeviceKeys `json:"device_keys,omitempty"` - OneTimeKeys map[id.KeyID]OneTimeKey `json:"one_time_keys"` + OneTimeKeys map[id.KeyID]OneTimeKey `json:"one_time_keys,omitempty"` } type ReqKeysSignatures struct { diff --git a/url.go b/url.go index 4646b442..f35ae5e2 100644 --- a/url.go +++ b/url.go @@ -102,6 +102,10 @@ func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[str if cli.SetAppServiceUserID { query.Set("user_id", string(cli.UserID)) } + if cli.SetAppServiceDeviceID && cli.DeviceID != "" { + query.Set("device_id", string(cli.DeviceID)) + query.Set("org.matrix.msc3202.device_id", string(cli.DeviceID)) + } if urlQuery != nil { for k, v := range urlQuery { query.Set(k, v) From fe20235578acdc5ba714cea2738f05e1f7e2427c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 29 Aug 2024 16:19:49 +0300 Subject: [PATCH 0697/1647] bridgev2/backfill: remove www. prefix in deterministic event IDs --- bridgev2/matrix/connector.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 115250f2..4297cba7 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -199,7 +199,7 @@ func (br *Connector) Start(ctx context.Context) error { } parsed, _ := url.Parse(br.Bridge.Network.GetName().NetworkURL) if parsed != nil { - br.deterministicEventIDServer = parsed.Hostname() + br.deterministicEventIDServer = strings.TrimPrefix(parsed.Hostname(), "www.") } br.AS.Ready = true if br.Websocket && br.Config.Homeserver.WSPingInterval > 0 { From f46572f058176e4f1f78f53f34171a610cf86160 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 30 Aug 2024 00:59:31 +0300 Subject: [PATCH 0698/1647] event: add type for policy rule recommendations --- event/state.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/event/state.go b/event/state.go index 6a067cae..eb539069 100644 --- a/event/state.go +++ b/event/state.go @@ -182,12 +182,18 @@ type SpaceParentEventContent struct { Canonical bool `json:"canonical,omitempty"` } +type PolicyRecommendation string + +const ( + PolicyRecommendationBan PolicyRecommendation = "m.ban" +) + // ModPolicyContent represents the content of a m.room.rule.user, m.room.rule.room, and m.room.rule.server state event. // https://spec.matrix.org/v1.2/client-server-api/#moderation-policy-lists type ModPolicyContent struct { - Entity string `json:"entity"` - Reason string `json:"reason"` - Recommendation string `json:"recommendation"` + Entity string `json:"entity"` + Reason string `json:"reason"` + Recommendation PolicyRecommendation `json:"recommendation"` } // Deprecated: MSC2716 has been abandoned From 67703cf96fab5ab72e03ceddfdfaa3e95003f967 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 30 Aug 2024 02:59:42 +0300 Subject: [PATCH 0699/1647] pushrules: use glob package from util --- go.mod | 2 +- go.sum | 4 +- pushrules/condition.go | 8 +-- pushrules/glob/LICENSE | 22 -------- pushrules/glob/README.md | 28 ---------- pushrules/glob/glob.go | 108 --------------------------------------- pushrules/rule.go | 9 ++-- 7 files changed, 12 insertions(+), 169 deletions(-) delete mode 100644 pushrules/glob/LICENSE delete mode 100644 pushrules/glob/README.md delete mode 100644 pushrules/glob/glob.go diff --git a/go.mod b/go.mod index 47be73d1..66f76646 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/tidwall/gjson v1.17.3 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.7.1-0.20240827112829-84c63841c264 + go.mau.fi/util v0.7.1-0.20240829235756-95504af915a4 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.26.0 golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa diff --git a/go.sum b/go.sum index 7e819192..cb0a7a39 100644 --- a/go.sum +++ b/go.sum @@ -50,8 +50,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.7.1-0.20240827112829-84c63841c264 h1:mWujT3q8pxyJc/3BvWTgTN4+k41d1pBCvwxH56prqQA= -go.mau.fi/util v0.7.1-0.20240827112829-84c63841c264/go.mod h1:WuAOOV0O/otkxGkFUvfv/XE2ztegaoyM15ovS6SYbf4= +go.mau.fi/util v0.7.1-0.20240829235756-95504af915a4 h1:vk/X3TjGzYR9RMUpq74TuhOuR0tcXGU8y3uQL4WLD9w= +go.mau.fi/util v0.7.1-0.20240829235756-95504af915a4/go.mod h1:WuAOOV0O/otkxGkFUvfv/XE2ztegaoyM15ovS6SYbf4= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= diff --git a/pushrules/condition.go b/pushrules/condition.go index 435178fb..974da114 100644 --- a/pushrules/condition.go +++ b/pushrules/condition.go @@ -15,10 +15,10 @@ import ( "unicode" "github.com/tidwall/gjson" + "go.mau.fi/util/glob" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/pushrules/glob" ) // Room is an interface with the functions that are needed for processing room-specific push conditions @@ -219,11 +219,11 @@ func (cond *PushCondition) matchValue(evt *event.Event) bool { switch cond.Kind { case KindEventMatch, KindRelatedEventMatch, KindUnstableRelatedEventMatch: - pattern, err := glob.Compile(cond.Pattern) - if err != nil { + pattern := glob.Compile(cond.Pattern) + if pattern == nil { return false } - return pattern.MatchString(stringifyForPushCondition(val)) + return pattern.Match(stringifyForPushCondition(val)) case KindEventPropertyIs: return valueEquals(val, cond.Value) case KindEventPropertyContains: diff --git a/pushrules/glob/LICENSE b/pushrules/glob/LICENSE deleted file mode 100644 index cb00d952..00000000 --- a/pushrules/glob/LICENSE +++ /dev/null @@ -1,22 +0,0 @@ -Glob is licensed under the MIT "Expat" License: - -Copyright (c) 2016: Zachary Yedidia. - -Permission is hereby granted, free of charge, to any person obtaining -a copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, -distribute, sublicense, and/or sell copies of the Software, and to -permit persons to whom the Software is furnished to do so, subject to -the following conditions: - -The above copyright notice and this permission notice shall be -included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. -IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, -TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE -SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/pushrules/glob/README.md b/pushrules/glob/README.md deleted file mode 100644 index e2e6c649..00000000 --- a/pushrules/glob/README.md +++ /dev/null @@ -1,28 +0,0 @@ -# String globbing in Go - -[![GoDoc](https://godoc.org/github.com/zyedidia/glob?status.svg)](http://godoc.org/github.com/zyedidia/glob) - -This package adds support for globs in Go. - -It simply converts glob expressions to regexps. I try to follow the standard defined [here](http://pubs.opengroup.org/onlinepubs/009695399/utilities/xcu_chap02.html#tag_02_13). - -# Example - -```go -package main - -import "github.com/zyedidia/glob" - -func main() { - glob, err := glob.Compile("{*.go,*.c}") - if err != nil { - // Error - } - - glob.Match([]byte("test.c")) // true - glob.Match([]byte("hello.go")) // true - glob.Match([]byte("test.d")) // false -} -``` - -You can call all the same functions on a glob that you can call on a regexp. diff --git a/pushrules/glob/glob.go b/pushrules/glob/glob.go deleted file mode 100644 index c270dbc5..00000000 --- a/pushrules/glob/glob.go +++ /dev/null @@ -1,108 +0,0 @@ -// Package glob provides objects for matching strings with globs -package glob - -import "regexp" - -// Glob is a wrapper of *regexp.Regexp. -// It should contain a glob expression compiled into a regular expression. -type Glob struct { - *regexp.Regexp -} - -// Compile a takes a glob expression as a string and transforms it -// into a *Glob object (which is really just a regular expression) -// Compile also returns a possible error. -func Compile(pattern string) (*Glob, error) { - r, err := globToRegex(pattern) - return &Glob{r}, err -} - -func globToRegex(glob string) (*regexp.Regexp, error) { - regex := "" - inGroup := 0 - inClass := 0 - firstIndexInClass := -1 - arr := []byte(glob) - - hasGlobCharacters := false - - for i := 0; i < len(arr); i++ { - ch := arr[i] - - switch ch { - case '\\': - i++ - if i >= len(arr) { - regex += "\\" - } else { - next := arr[i] - switch next { - case ',': - // Nothing - case 'Q', 'E': - regex += "\\\\" - default: - regex += "\\" - } - regex += string(next) - } - case '*': - if inClass == 0 { - regex += ".*" - } else { - regex += "*" - } - hasGlobCharacters = true - case '?': - if inClass == 0 { - regex += "." - } else { - regex += "?" - } - hasGlobCharacters = true - case '[': - inClass++ - firstIndexInClass = i + 1 - regex += "[" - hasGlobCharacters = true - case ']': - inClass-- - regex += "]" - case '.', '(', ')', '+', '|', '^', '$', '@', '%': - if inClass == 0 || (firstIndexInClass == i && ch == '^') { - regex += "\\" - } - regex += string(ch) - hasGlobCharacters = true - case '!': - if firstIndexInClass == i { - regex += "^" - } else { - regex += "!" - } - hasGlobCharacters = true - case '{': - inGroup++ - regex += "(" - hasGlobCharacters = true - case '}': - inGroup-- - regex += ")" - case ',': - if inGroup > 0 { - regex += "|" - hasGlobCharacters = true - } else { - regex += "," - } - default: - regex += string(ch) - } - } - - if hasGlobCharacters { - return regexp.Compile("^" + regex + "$") - } else { - return regexp.Compile(regex) - } -} diff --git a/pushrules/rule.go b/pushrules/rule.go index 0f7436f3..75ca9322 100644 --- a/pushrules/rule.go +++ b/pushrules/rule.go @@ -9,9 +9,10 @@ package pushrules import ( "encoding/gob" + "go.mau.fi/util/glob" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/pushrules/glob" ) func init() { @@ -164,13 +165,13 @@ func (rule *PushRule) matchConditions(room Room, evt *event.Event) bool { } func (rule *PushRule) matchPattern(room Room, evt *event.Event) bool { - pattern, err := glob.Compile(rule.Pattern) - if err != nil { + pattern := glob.Compile(rule.Pattern) + if pattern == nil { return false } msg, ok := evt.Content.Raw["body"].(string) if !ok { return false } - return pattern.MatchString(msg) + return pattern.Match(msg) } From 7a86cb26ff393cd4aaa93a6b709928ccf3a02e78 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 30 Aug 2024 18:10:22 +0300 Subject: [PATCH 0700/1647] pushrules: fix glob matching --- go.mod | 2 +- go.sum | 4 ++-- pushrules/condition.go | 2 +- pushrules/rule.go | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 66f76646..e3700339 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( github.com/tidwall/gjson v1.17.3 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.7.1-0.20240829235756-95504af915a4 + go.mau.fi/util v0.7.1-0.20240830150939-8c1e9c295943 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.26.0 golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa diff --git a/go.sum b/go.sum index cb0a7a39..1584f6a8 100644 --- a/go.sum +++ b/go.sum @@ -50,8 +50,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.7.1-0.20240829235756-95504af915a4 h1:vk/X3TjGzYR9RMUpq74TuhOuR0tcXGU8y3uQL4WLD9w= -go.mau.fi/util v0.7.1-0.20240829235756-95504af915a4/go.mod h1:WuAOOV0O/otkxGkFUvfv/XE2ztegaoyM15ovS6SYbf4= +go.mau.fi/util v0.7.1-0.20240830150939-8c1e9c295943 h1:wdJ9XC/M6lVUrwDltHPodaA3SRJq+S+AzGEXdQ/o2AQ= +go.mau.fi/util v0.7.1-0.20240830150939-8c1e9c295943/go.mod h1:WuAOOV0O/otkxGkFUvfv/XE2ztegaoyM15ovS6SYbf4= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= diff --git a/pushrules/condition.go b/pushrules/condition.go index 974da114..dbe83a61 100644 --- a/pushrules/condition.go +++ b/pushrules/condition.go @@ -219,7 +219,7 @@ func (cond *PushCondition) matchValue(evt *event.Event) bool { switch cond.Kind { case KindEventMatch, KindRelatedEventMatch, KindUnstableRelatedEventMatch: - pattern := glob.Compile(cond.Pattern) + pattern := glob.CompileWithImplicitContains(cond.Pattern) if pattern == nil { return false } diff --git a/pushrules/rule.go b/pushrules/rule.go index 75ca9322..ee6d33c4 100644 --- a/pushrules/rule.go +++ b/pushrules/rule.go @@ -165,7 +165,7 @@ func (rule *PushRule) matchConditions(room Room, evt *event.Event) bool { } func (rule *PushRule) matchPattern(room Room, evt *event.Event) bool { - pattern := glob.Compile(rule.Pattern) + pattern := glob.CompileWithImplicitContains(rule.Pattern) if pattern == nil { return false } From 87ca6a9ba2b6ac3e6559d05975b7a2bdc445302b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 30 Aug 2024 21:00:38 +0300 Subject: [PATCH 0701/1647] bridgev2/portal: add event ID to handle matrix event log --- bridgev2/portal.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 384a0357..4ba7a8f6 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -395,6 +395,7 @@ func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID, func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { log := portal.Log.With(). Str("action", "handle matrix event"). + Stringer("event_id", evt.ID). Str("event_type", evt.Type.Type). Logger() ctx := log.WithContext(context.TODO()) From 79391515ed64f4c39490e39ae694ee9f25a2aee6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 31 Aug 2024 02:41:28 +0300 Subject: [PATCH 0702/1647] event: add unstable prefixes for policy events --- event/content.go | 7 +++++++ event/state.go | 3 ++- event/type.go | 7 +++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/event/content.go b/event/content.go index e0026e9e..d08acad1 100644 --- a/event/content.go +++ b/event/content.go @@ -40,6 +40,13 @@ var TypeMap = map[Type]reflect.Type{ StateSpaceChild: reflect.TypeOf(SpaceChildEventContent{}), StateInsertionMarker: reflect.TypeOf(InsertionMarkerContent{}), + StateLegacyPolicyRoom: reflect.TypeOf(ModPolicyContent{}), + StateLegacyPolicyServer: reflect.TypeOf(ModPolicyContent{}), + StateLegacyPolicyUser: reflect.TypeOf(ModPolicyContent{}), + StateUnstablePolicyRoom: reflect.TypeOf(ModPolicyContent{}), + StateUnstablePolicyServer: reflect.TypeOf(ModPolicyContent{}), + StateUnstablePolicyUser: reflect.TypeOf(ModPolicyContent{}), + StateElementFunctionalMembers: reflect.TypeOf(ElementFunctionalMembersContent{}), EventMessage: reflect.TypeOf(MessageEventContent{}), diff --git a/event/state.go b/event/state.go index eb539069..0844936a 100644 --- a/event/state.go +++ b/event/state.go @@ -185,7 +185,8 @@ type SpaceParentEventContent struct { type PolicyRecommendation string const ( - PolicyRecommendationBan PolicyRecommendation = "m.ban" + PolicyRecommendationBan PolicyRecommendation = "m.ban" + PolicyRecommendationUnstableBan PolicyRecommendation = "org.matrix.mjolnir.ban" ) // ModPolicyContent represents the content of a m.room.rule.user, m.room.rule.room, and m.room.rule.server state event. diff --git a/event/type.go b/event/type.go index 162e2ce7..3f447343 100644 --- a/event/type.go +++ b/event/type.go @@ -192,6 +192,13 @@ var ( StateSpaceChild = Type{"m.space.child", StateEventType} StateSpaceParent = Type{"m.space.parent", StateEventType} + StateLegacyPolicyRoom = Type{"m.room.rule.room", StateEventType} + StateLegacyPolicyServer = Type{"m.room.rule.server", StateEventType} + StateLegacyPolicyUser = Type{"m.room.rule.user", StateEventType} + StateUnstablePolicyRoom = Type{"org.matrix.mjolnir.rule.room", StateEventType} + StateUnstablePolicyServer = Type{"org.matrix.mjolnir.rule.server", StateEventType} + StateUnstablePolicyUser = Type{"org.matrix.mjolnir.rule.user", StateEventType} + // Deprecated: MSC2716 has been abandoned StateInsertionMarker = Type{"org.matrix.msc2716.marker", StateEventType} From e7bc21d463d6d5881f2cc4d01d64772339e8657e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 1 Sep 2024 12:54:13 +0300 Subject: [PATCH 0703/1647] client: add support for feature flag for authenticated media --- client.go | 6 +++--- versions.go | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index e1b36f87..6ec0b252 100644 --- a/client.go +++ b/client.go @@ -1463,7 +1463,7 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt // GetMediaConfig fetches the configuration of the content repository, such as upload limitations. func (cli *Client) GetMediaConfig(ctx context.Context) (resp *RespMediaConfig, err error) { var u string - if cli.SpecVersions.ContainsGreaterOrEqual(SpecV111) { + if cli.SpecVersions.Supports(FeatureAuthenticatedMedia) { u = cli.BuildClientURL("v1", "media", "config") } else { u = cli.BuildURL(MediaURLPath{"v3", "config"}) @@ -1562,7 +1562,7 @@ func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Re if ctxLog.GetLevel() == zerolog.Disabled || ctxLog == zerolog.DefaultContextLogger { ctx = cli.Log.WithContext(ctx) } - if cli.SpecVersions.ContainsGreaterOrEqual(SpecV111) { + if cli.SpecVersions.Supports(FeatureAuthenticatedMedia) { _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{ Method: http.MethodGet, URL: cli.BuildClientURL("v1", "media", "download", mxcURL.Homeserver, mxcURL.FileID), @@ -1785,7 +1785,7 @@ func (cli *Client) UploadMedia(ctx context.Context, data ReqUploadMedia) (*RespM // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3preview_url func (cli *Client) GetURLPreview(ctx context.Context, url string) (*RespPreviewURL, error) { var urlPath PrefixableURLPath - if cli.SpecVersions.ContainsGreaterOrEqual(SpecV111) { + if cli.SpecVersions.Supports(FeatureAuthenticatedMedia) { urlPath = ClientURLPath{"v1", "media", "preview_url"} } else { urlPath = MediaURLPath{"v3", "preview_url"} diff --git a/versions.go b/versions.go index 60eb0f30..672018ff 100644 --- a/versions.go +++ b/versions.go @@ -60,8 +60,9 @@ type UnstableFeature struct { } var ( - FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} - FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} + FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} + FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} + FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} From a0d427e4df091c2dbcbc5c8560cd203b6dcef993 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 2 Sep 2024 01:20:22 +0300 Subject: [PATCH 0704/1647] crypto: add hack to avoid logging about OTK counts for cross-signing keys --- crypto/machine.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/crypto/machine.go b/crypto/machine.go index 188aa210..85da2b3b 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -245,8 +245,22 @@ func (mach *OlmMachine) HandleDeviceLists(ctx context.Context, dl *mautrix.Devic } } +func (mach *OlmMachine) otkCountIsForCrossSigningKey(otkCount *mautrix.OTKCount) bool { + if mach.crossSigningPubkeys == nil || otkCount.UserID != mach.Client.UserID { + return false + } + switch id.Ed25519(otkCount.DeviceID) { + case mach.crossSigningPubkeys.MasterKey, mach.crossSigningPubkeys.UserSigningKey, mach.crossSigningPubkeys.SelfSigningKey: + return true + } + return false +} + func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.OTKCount) { if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) { + if mach.otkCountIsForCrossSigningKey(otkCount) { + return + } // TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions mach.Log.Warn(). Str("target_user_id", otkCount.UserID.String()). From fd15ca4a68f1684ab691b75b4b63cfe8829dd3a9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 2 Sep 2024 12:02:14 +0300 Subject: [PATCH 0705/1647] bridgev2/matrix: add missing return in UploadMediaStream --- bridgev2/matrix/intent.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 5b27f906..7f6ebbb9 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -289,7 +289,8 @@ func (as *ASIntent) UploadMediaStream( var res *bridgev2.FileStreamResult res, err = cb(tempFile) if err != nil { - err = fmt.Errorf("failed to write to temp file: %w", err) + err = fmt.Errorf("write callback failed: %w", err) + return } var replFile *os.File if res.ReplacementFile != "" { From 6f1a3878c452179f258852eadb01278c78532e70 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 2 Sep 2024 12:23:46 +0300 Subject: [PATCH 0706/1647] dependencies: update --- go.mod | 6 +++--- go.sum | 11 ++++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index e3700339..f1f9dc53 100644 --- a/go.mod +++ b/go.mod @@ -9,17 +9,17 @@ require ( github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.22 - github.com/rs/xid v1.5.0 + github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.33.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.17.3 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.7.1-0.20240830150939-8c1e9c295943 + go.mau.fi/util v0.7.1-0.20240901193650-bf007b10eaf6 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.26.0 - golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa + golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 golang.org/x/net v0.28.0 golang.org/x/sync v0.8.0 gopkg.in/yaml.v3 v3.0.1 diff --git a/go.sum b/go.sum index 1584f6a8..05012924 100644 --- a/go.sum +++ b/go.sum @@ -31,8 +31,9 @@ github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7/go.mod h1:pxMtw7c github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= @@ -50,14 +51,14 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.7.1-0.20240830150939-8c1e9c295943 h1:wdJ9XC/M6lVUrwDltHPodaA3SRJq+S+AzGEXdQ/o2AQ= -go.mau.fi/util v0.7.1-0.20240830150939-8c1e9c295943/go.mod h1:WuAOOV0O/otkxGkFUvfv/XE2ztegaoyM15ovS6SYbf4= +go.mau.fi/util v0.7.1-0.20240901193650-bf007b10eaf6 h1:cSLCabMKbR6rTPYRGWD2XaHo210BK3BtPg+CRC4A4og= +go.mau.fi/util v0.7.1-0.20240901193650-bf007b10eaf6/go.mod h1:WuAOOV0O/otkxGkFUvfv/XE2ztegaoyM15ovS6SYbf4= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= -golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa h1:ELnwvuAXPNtPk1TJRuGkI9fDTwym6AYBu0qzT8AcHdI= -golang.org/x/exp v0.0.0-20240808152545-0cdaa3abc0fa/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= +golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 h1:kx6Ds3MlpiUHKj7syVnbp57++8WpuKPcR5yjLBjvLEA= +golang.org/x/exp v0.0.0-20240823005443-9b4947da3948/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= From 0c14ad0f0c1c4710127b11ad459f5f72538d669c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 2 Sep 2024 23:17:30 +0300 Subject: [PATCH 0707/1647] event: add unban mod policy recommendation --- event/state.go | 1 + 1 file changed, 1 insertion(+) diff --git a/event/state.go b/event/state.go index 0844936a..6e5f0ae4 100644 --- a/event/state.go +++ b/event/state.go @@ -187,6 +187,7 @@ type PolicyRecommendation string const ( PolicyRecommendationBan PolicyRecommendation = "m.ban" PolicyRecommendationUnstableBan PolicyRecommendation = "org.matrix.mjolnir.ban" + PolicyRecommendationUnban PolicyRecommendation = "fi.mau.meowlnir.unban" ) // ModPolicyContent represents the content of a m.room.rule.user, m.room.rule.room, and m.room.rule.server state event. From db8f2433a1dbe24f51171c321980fa3c405b6687 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 2 Sep 2024 23:49:06 +0300 Subject: [PATCH 0708/1647] statestore: mass insert members on refetch --- client.go | 88 +++++++++++++++++------------------- hicli/database/statestore.go | 4 ++ sqlstatestore/statestore.go | 79 ++++++++++++++++++++++++++++---- statestore.go | 16 ++++++- 4 files changed, 130 insertions(+), 57 deletions(-) diff --git a/client.go b/client.go index 6ec0b252..9a10c987 100644 --- a/client.go +++ b/client.go @@ -13,13 +13,16 @@ import ( "net/http" "net/url" "os" + "slices" "strconv" "strings" "sync/atomic" "time" "github.com/rs/zerolog" + "go.mau.fi/util/ptr" "go.mau.fi/util/retryafter" + "golang.org/x/exp/maps" "maunium.net/go/mautrix/crypto/backup" "maunium.net/go/mautrix/event" @@ -1440,21 +1443,19 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt Handler: parseRoomStateArray, }) if err == nil && cli.StateStore != nil { - clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID) - if clearErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(clearErr). - Stringer("room_id", roomID). - Msg("Failed to clear cached member list after fetching state") - } - for _, evts := range stateMap { + for evtType, evts := range stateMap { + if evtType == event.StateMember { + continue + } for _, evt := range evts { UpdateStateStore(ctx, cli.StateStore, evt) } } - clearErr = cli.StateStore.MarkMembersFetched(ctx, roomID) - if clearErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(clearErr). - Msg("Failed to mark members as fetched after fetching full room state") + updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, maps.Values(stateMap[event.StateMember])) + if updateErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(updateErr). + Stringer("room_id", roomID). + Msg("Failed to update members in state store after fetching members") } } return @@ -1806,24 +1807,26 @@ func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *R u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members") _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) if err == nil && cli.StateStore != nil { - clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, event.MembershipJoin) - if clearErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(clearErr). - Stringer("room_id", roomID). - Msg("Failed to clear cached member list after fetching joined members") - } + fakeEvents := make([]*event.Event, len(resp.Joined)) + i := 0 for userID, member := range resp.Joined { - updateErr := cli.StateStore.SetMember(ctx, roomID, userID, &event.MemberEventContent{ - Membership: event.MembershipJoin, - AvatarURL: id.ContentURIString(member.AvatarURL), - Displayname: member.DisplayName, - }) - if updateErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(updateErr). - Stringer("room_id", roomID). - Stringer("user_id", userID). - Msg("Failed to update membership in state store after fetching joined members") + fakeEvents[i] = &event.Event{ + StateKey: ptr.Ptr(userID.String()), + Type: event.StateMember, + RoomID: roomID, + Content: event.Content{Parsed: &event.MemberEventContent{ + Membership: event.MembershipJoin, + AvatarURL: id.ContentURIString(member.AvatarURL), + Displayname: member.DisplayName, + }}, } + i++ + } + updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, fakeEvents, event.MembershipJoin) + if updateErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(updateErr). + Stringer("room_id", roomID). + Msg("Failed to update members in state store after fetching joined members") } } return @@ -1852,27 +1855,20 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb } } if err == nil && cli.StateStore != nil { - var clearMemberships []event.Membership + var onlyMemberships []event.Membership if extra.Membership != "" { - clearMemberships = append(clearMemberships, extra.Membership) + onlyMemberships = []event.Membership{extra.Membership} + } else if extra.NotMembership != "" { + onlyMemberships = []event.Membership{event.MembershipJoin, event.MembershipLeave, event.MembershipInvite, event.MembershipBan, event.MembershipKnock} + onlyMemberships = slices.DeleteFunc(onlyMemberships, func(m event.Membership) bool { + return m == extra.NotMembership + }) } - if extra.NotMembership == "" { - clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, clearMemberships...) - if clearErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(clearErr). - Stringer("room_id", roomID). - Msg("Failed to clear cached member list after fetching joined members") - } - } - for _, evt := range resp.Chunk { - UpdateStateStore(ctx, cli.StateStore, evt) - } - if extra.NotMembership == "" && extra.Membership == "" { - markErr := cli.StateStore.MarkMembersFetched(ctx, roomID) - if markErr != nil { - cli.cliOrContextLog(ctx).Warn().Err(markErr). - Msg("Failed to mark members as fetched after fetching full member list") - } + updateErr := cli.StateStore.ReplaceCachedMembers(ctx, roomID, resp.Chunk, onlyMemberships...) + if updateErr != nil { + cli.cliOrContextLog(ctx).Warn().Err(updateErr). + Stringer("room_id", roomID). + Msg("Failed to update members in state store after fetching members") } } return diff --git a/hicli/database/statestore.go b/hicli/database/statestore.go index cefe76d3..1779afa5 100644 --- a/hicli/database/statestore.go +++ b/hicli/database/statestore.go @@ -174,3 +174,7 @@ func (c *ClientStateStore) SetEncryptionEvent(ctx context.Context, roomID id.Roo } func (c *ClientStateStore) UpdateState(ctx context.Context, evt *event.Event) {} + +func (c *ClientStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error { + return nil +} diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index 2cfd1b97..d594c307 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -19,6 +19,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/confusable" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exslices" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -194,21 +195,37 @@ func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, return err } +const insertUserProfileQuery = ` + INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url, name_skeleton) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (room_id, user_id) DO UPDATE + SET membership=excluded.membership, + displayname=excluded.displayname, + avatar_url=excluded.avatar_url, + name_skeleton=excluded.name_skeleton +` + +type userProfileRow struct { + UserID id.UserID + Membership event.Membership + Displayname string + AvatarURL id.ContentURIString + NameSkeleton []byte +} + +func (u *userProfileRow) GetMassInsertValues() [5]any { + return [5]any{u.UserID, u.Membership, u.Displayname, u.AvatarURL, u.NameSkeleton} +} + +var userProfileMassInserter = dbutil.NewMassInsertBuilder[*userProfileRow, [1]any](insertUserProfileQuery, "($1, $%d, $%d, $%d, $%d, $%d)") + func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error { var nameSkeleton []byte if !store.DisableNameDisambiguation && len(member.Displayname) > 0 { nameSkeletonArr := confusable.SkeletonHash(member.Displayname) nameSkeleton = nameSkeletonArr[:] } - _, err := store.Exec(ctx, ` - INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url, name_skeleton) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (room_id, user_id) DO UPDATE - SET membership=excluded.membership, - displayname=excluded.displayname, - avatar_url=excluded.avatar_url, - name_skeleton=excluded.name_skeleton - `, roomID, userID, member.Membership, member.Displayname, member.AvatarURL, nameSkeleton) + _, err := store.Exec(ctx, insertUserProfileQuery, roomID, userID, member.Membership, member.Displayname, member.AvatarURL, nameSkeleton) return err } @@ -221,6 +238,50 @@ func (store *SQLStateStore) IsConfusableName(ctx context.Context, roomID id.Room return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() } +const userProfileMassInsertBatchSize = 500 + +func (store *SQLStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error { + return store.DoTxn(ctx, nil, func(ctx context.Context) error { + err := store.ClearCachedMembers(ctx, roomID, onlyMemberships...) + if err != nil { + return fmt.Errorf("failed to clear cached members: %w", err) + } + rows := make([]*userProfileRow, min(len(evts), userProfileMassInsertBatchSize)) + for _, evtsChunk := range exslices.Chunk(evts, userProfileMassInsertBatchSize) { + rows = rows[:0] + for _, evt := range evtsChunk { + content, ok := evt.Content.Parsed.(*event.MemberEventContent) + if !ok { + continue + } + row := &userProfileRow{ + UserID: id.UserID(*evt.StateKey), + Membership: content.Membership, + Displayname: content.Displayname, + AvatarURL: content.AvatarURL, + } + if !store.DisableNameDisambiguation && len(content.Displayname) > 0 { + nameSkeletonArr := confusable.SkeletonHash(content.Displayname) + row.NameSkeleton = nameSkeletonArr[:] + } + rows = append(rows, row) + } + query, args := userProfileMassInserter.Build([1]any{roomID}, rows) + _, err = store.Exec(ctx, query, args...) + if err != nil { + return fmt.Errorf("failed to insert members: %w", err) + } + } + if len(onlyMemberships) == 0 { + err = store.MarkMembersFetched(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to mark members as fetched: %w", err) + } + } + return nil + }) +} + func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error { query := "DELETE FROM mx_user_profile WHERE room_id=$1" params := make([]any, len(memberships)+1) diff --git a/statestore.go b/statestore.go index 5f210e4f..e728b885 100644 --- a/statestore.go +++ b/statestore.go @@ -29,6 +29,7 @@ type StateStore interface { SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error + ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error) @@ -270,9 +271,20 @@ func (store *MemoryStateStore) MarkMembersFetched(ctx context.Context, roomID id return nil } +func (store *MemoryStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error { + _ = store.ClearCachedMembers(ctx, roomID, onlyMemberships...) + for _, evt := range evts { + UpdateStateStore(ctx, store, evt) + } + if len(onlyMemberships) == 0 { + _ = store.MarkMembersFetched(ctx, roomID) + } + return nil +} + func (store *MemoryStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { - store.membersLock.Lock() - defer store.membersLock.Unlock() + store.membersLock.RLock() + defer store.membersLock.RUnlock() return maps.Clone(store.Members[roomID]), nil } From 55770a4a15b2958b99990eb16a312304706cb28e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 3 Sep 2024 01:48:53 +0300 Subject: [PATCH 0709/1647] bridgev2/provisioning: check user permissions --- bridgev2/matrix/provisioning.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index c1f02890..69720609 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -227,6 +227,14 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return } // TODO handle user being nil? + // TODO per-endpoint permissions? + if !user.Permissions.Login { + jsonResponse(w, http.StatusForbidden, &mautrix.RespError{ + Err: "User does not have login permissions", + ErrCode: mautrix.MForbidden.ErrCode, + }) + return + } ctx := context.WithValue(r.Context(), provisioningUserKey, user) if loginID, ok := mux.Vars(r)["loginProcessID"]; ok { From 4e8156519efeadab321018b63c3b399df5ad0739 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 3 Sep 2024 01:49:06 +0300 Subject: [PATCH 0710/1647] federation: limit .well-known file size --- federation/resolution.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/federation/resolution.go b/federation/resolution.go index e6785988..24085282 100644 --- a/federation/resolution.go +++ b/federation/resolution.go @@ -11,6 +11,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net" "net/http" "net/url" @@ -140,7 +141,7 @@ func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (* return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode) } var respData RespWellKnown - err = json.NewDecoder(resp.Body).Decode(&respData) + err = json.NewDecoder(io.LimitReader(resp.Body, 50*1024)).Decode(&respData) if err != nil { return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err) } else if respData.Server == "" { From a17dc5867e37d742de22f305c192b234b6e9d2fb Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 2 Sep 2024 16:53:40 -0600 Subject: [PATCH 0711/1647] bridgev2/provisioning: allow custom user ID retrieval Signed-off-by: Sumner Evans --- bridgev2/matrix/provisioning.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 69720609..00f5eb72 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -54,6 +54,11 @@ type ProvisioningAPI struct { // GetAuthFromRequest is a custom function for getting the auth token from // the request if the Authorization header is not present. GetAuthFromRequest func(r *http.Request) string + + // GetUserIDFromRequest is a custom function for getting the user ID to + // authenticate as instead of using the user ID provided in the query + // parameter. + GetUserIDFromRequest func(r *http.Request) id.UserID } type ProvLogin struct { @@ -200,6 +205,9 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return } userID := id.UserID(r.URL.Query().Get("user_id")) + if userID == "" && prov.GetUserIDFromRequest != nil { + userID = prov.GetUserIDFromRequest(r) + } if auth != prov.br.Config.Provisioning.SharedSecret { var err error if strings.HasPrefix(auth, "openid:") { From ab110b44254291c4930ed81fc3b76857363e69eb Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 3 Sep 2024 13:54:05 +0100 Subject: [PATCH 0712/1647] Add refresh token field to login request (#278) --- requests.go | 1 + 1 file changed, 1 insertion(+) diff --git a/requests.go b/requests.go index c49f7c9c..189e620d 100644 --- a/requests.go +++ b/requests.go @@ -83,6 +83,7 @@ type ReqLogin struct { Token string `json:"token,omitempty"` DeviceID id.DeviceID `json:"device_id,omitempty"` InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"` + RefreshToken bool `json:"refresh_token,omitempty"` // Whether or not the returned credentials should be stored in the Client StoreCredentials bool `json:"-"` From da2780bcbe0a1d9e0ebf0010ad7bd6fa6bc42f00 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 3 Sep 2024 13:54:10 +0100 Subject: [PATCH 0713/1647] Use a warning log when request context canceled (#279) --- client.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 9a10c987..edbeedfe 100644 --- a/client.go +++ b/client.go @@ -324,7 +324,9 @@ func (cli *Client) RequestStart(req *http.Request) { func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err error, handlerErr error, contentLength int, duration time.Duration) { var evt *zerolog.Event - if err != nil { + if errors.Is(err, context.Canceled) { + evt = zerolog.Ctx(req.Context()).Warn() + } else if err != nil { evt = zerolog.Ctx(req.Context()).Err(err) } else if handlerErr != nil { evt = zerolog.Ctx(req.Context()).Warn(). @@ -357,7 +359,9 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er if body := req.Context().Value(LogBodyContextKey); body != nil { evt.Interface("req_body", body) } - if err != nil { + if errors.Is(err, context.Canceled) { + evt.Msg("Request canceled") + } else if err != nil { evt.Msg("Request failed") } else if handlerErr != nil { evt.Msg("Request parsing failed") From a62bdb625099dcebcb927f763918084a9a182b32 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 4 Sep 2024 20:36:18 +0300 Subject: [PATCH 0714/1647] dependencies: update --- go.mod | 8 ++++---- go.sum | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index f1f9dc53..78d1b8c4 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 - github.com/mattn/go-sqlite3 v1.14.22 + github.com/mattn/go-sqlite3 v1.14.23 github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.33.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e @@ -16,7 +16,7 @@ require ( github.com/tidwall/gjson v1.17.3 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.7.1-0.20240901193650-bf007b10eaf6 + go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.26.0 golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 @@ -35,7 +35,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.24.0 // indirect - golang.org/x/text v0.17.0 // indirect + golang.org/x/sys v0.25.0 // indirect + golang.org/x/text v0.18.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 05012924..0f1a0558 100644 --- a/go.sum +++ b/go.sum @@ -24,8 +24,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= -github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0= +github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 h1:Dx7Ovyv/SFnMFw3fD4oEoeorXc6saIiQ23LrGLth0Gw= github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -51,8 +51,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.7.1-0.20240901193650-bf007b10eaf6 h1:cSLCabMKbR6rTPYRGWD2XaHo210BK3BtPg+CRC4A4og= -go.mau.fi/util v0.7.1-0.20240901193650-bf007b10eaf6/go.mod h1:WuAOOV0O/otkxGkFUvfv/XE2ztegaoyM15ovS6SYbf4= +go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2 h1:VZQlKBbeJ7KOlYSh6BnN5uWQTY/ypn/bJv0YyEd+pXc= +go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2/go.mod h1:WgYvbt9rVmoFeajP97NunQU7AjgvTPiNExN3oTHeePs= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= @@ -67,10 +67,10 @@ golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= -golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= -golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= +golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From e5ea10d64c032c6062bae02f800ae2b657d6478f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 4 Sep 2024 22:17:33 +0300 Subject: [PATCH 0715/1647] bridgev2: allow adding pending message before returning from handler --- bridgev2/networkinterface.go | 11 ++- bridgev2/portal.go | 134 ++++++++++++++++++++++++----------- 2 files changed, 100 insertions(+), 45 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 556b7407..3e0617ae 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -273,11 +273,16 @@ type MaxFileSizeingNetwork interface { SetMaxFileSize(maxSize int64) } +type RemoteEchoHandler func(RemoteMessage, *database.Message) (bool, error) + type MatrixMessageResponse struct { DB *database.Message - - Pending networkid.TransactionID - HandleEcho func(RemoteMessage, *database.Message) (bool, error) + // If Pending is set, the bridge will not save the provided message to the database. + // This should only be used if AddPendingToSave has been called. + Pending bool + // If RemovePending is set, the bridge will remove the provided transaction ID from pending messages + // after saving the provided message to the database. This should be used with AddPendingToIgnore. + RemovePending networkid.TransactionID } type FileRestriction struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 4ba7a8f6..a5a150a4 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -59,6 +59,7 @@ type portalEvent interface { type outgoingMessage struct { db *database.Message evt *event.Event + ignore bool handle func(RemoteMessage, *database.Message) (bool, error) } @@ -775,7 +776,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } } - resp, err := sender.Client.HandleMatrixMessage(ctx, &MatrixMessage{ + wrappedEvt := &MatrixMessage{ MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{ Event: evt, Content: content, @@ -784,52 +785,30 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin }, ThreadRoot: threadRoot, ReplyTo: replyTo, - }) + } + resp, err := sender.Client.HandleMatrixMessage(ctx, wrappedEvt) if err != nil { log.Err(err).Msg("Failed to handle Matrix message") portal.sendErrorStatus(ctx, evt, err) return } - message := resp.DB - if message.MXID == "" { - message.MXID = evt.ID - } - if message.Room.ID == "" { - message.Room = portal.PortalKey - } - if message.Timestamp.IsZero() { - message.Timestamp = time.UnixMilli(evt.Timestamp) - } - if message.ReplyTo.MessageID == "" && replyTo != nil { - message.ReplyTo.MessageID = replyTo.ID - message.ReplyTo.PartID = &replyTo.PartID - } - if message.ThreadRoot == "" && threadRoot != nil { - message.ThreadRoot = threadRoot.ID - if threadRoot.ThreadRoot != "" { - message.ThreadRoot = threadRoot.ThreadRoot - } - } - if message.SenderMXID == "" { - message.SenderMXID = evt.Sender - } - if resp.Pending != "" { - // TODO if the event queue is ever removed, this will have to be done by the network connector before sending the request - // (for now this is fine because incoming messages will wait in the queue for this function to return) - portal.outgoingMessagesLock.Lock() - portal.outgoingMessages[resp.Pending] = outgoingMessage{ - db: message, - evt: evt, - handle: resp.HandleEcho, - } - portal.outgoingMessagesLock.Unlock() - } else { - // Hack to ensure the ghost row exists - // TODO move to better place (like login) - portal.Bridge.GetGhostByID(ctx, message.SenderID) - err = portal.Bridge.DB.Message.Insert(ctx, message) - if err != nil { - log.Err(err).Msg("Failed to save message to database") + message := wrappedEvt.fillDBMessage(resp.DB) + if !resp.Pending { + if resp.DB == nil { + log.Error().Msg("Network connector didn't return a message to save") + } else { + // Hack to ensure the ghost row exists + // TODO move to better place (like login) + portal.Bridge.GetGhostByID(ctx, message.SenderID) + err = portal.Bridge.DB.Message.Insert(ctx, message) + if err != nil { + log.Err(err).Msg("Failed to save message to database") + } + if resp.RemovePending != "" { + portal.outgoingMessagesLock.Lock() + delete(portal.outgoingMessages, resp.RemovePending) + portal.outgoingMessagesLock.Unlock() + } } portal.sendSuccessStatus(ctx, evt) } @@ -846,6 +825,75 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } } +// AddPendingToIgnore adds a transaction ID that should be ignored if encountered as a new message. +// +// This should be used when the network connector will return the real message ID from HandleMatrixMessage. +// The [MatrixMessageResponse] should include RemovePending with the transaction ID sto remove it from the lit +// after saving to database. +// +// See also: [MatrixMessage.AddPendingToSave] +func (evt *MatrixMessage) AddPendingToIgnore(txnID networkid.TransactionID) { + evt.Portal.outgoingMessagesLock.Lock() + evt.Portal.outgoingMessages[txnID] = outgoingMessage{ + ignore: true, + } + evt.Portal.outgoingMessagesLock.Unlock() +} + +// AddPendingToSave adds a transaction ID that should be processed and pointed at the existing event if encountered. +// +// This should be used when the network connector returns `Pending: true` from HandleMatrixMessage, +// i.e. when the network connector does not know the message ID at the end of the handler. +// The [MatrixMessageResponse] should set Pending to true to prevent saving the returned message to the database. +// +// The provided function will be called when the message is encountered. +func (evt *MatrixMessage) AddPendingToSave(message *database.Message, txnID networkid.TransactionID, handleEcho RemoteEchoHandler) { + evt.Portal.outgoingMessagesLock.Lock() + evt.Portal.outgoingMessages[txnID] = outgoingMessage{ + db: evt.fillDBMessage(message), + evt: evt.Event, + handle: handleEcho, + } + evt.Portal.outgoingMessagesLock.Unlock() +} + +// RemovePending removes a transaction ID from the list of pending messages. +// This should only be called if sending the message fails. +func (evt *MatrixMessage) RemovePending(txnID networkid.TransactionID) { + evt.Portal.outgoingMessagesLock.Lock() + delete(evt.Portal.outgoingMessages, txnID) + evt.Portal.outgoingMessagesLock.Unlock() +} + +func (evt *MatrixMessage) fillDBMessage(message *database.Message) *database.Message { + if message == nil { + message = &database.Message{} + } + if message.MXID == "" { + message.MXID = evt.Event.ID + } + if message.Room.ID == "" { + message.Room = evt.Portal.PortalKey + } + if message.Timestamp.IsZero() { + message.Timestamp = time.UnixMilli(evt.Event.Timestamp) + } + if message.ReplyTo.MessageID == "" && evt.ReplyTo != nil { + message.ReplyTo.MessageID = evt.ReplyTo.ID + message.ReplyTo.PartID = &evt.ReplyTo.PartID + } + if message.ThreadRoot == "" && evt.ThreadRoot != nil { + message.ThreadRoot = evt.ThreadRoot.ID + if evt.ThreadRoot.ThreadRoot != "" { + message.ThreadRoot = evt.ThreadRoot.ThreadRoot + } + } + if message.SenderMXID == "" { + message.SenderMXID = evt.Event.Sender + } + return message +} + func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *NetworkRoomCapabilities) { log := zerolog.Ctx(ctx) editTargetID := content.RelatesTo.GetReplaceID() @@ -1715,6 +1763,8 @@ func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage pending, ok := portal.outgoingMessages[txnID] if !ok { return false, nil + } else if pending.ignore { + return true, nil } delete(portal.outgoingMessages, txnID) pending.db.ID = evt.GetID() From 6c8519d39e43b05ce8d228e68a95b34ec7aec66e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 4 Sep 2024 23:49:44 +0300 Subject: [PATCH 0716/1647] bridgev2: add timeouts for event handling --- bridgev2/bridgeconfig/config.go | 1 + bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/matrix/mxmain/example-config.yaml | 3 + bridgev2/portal.go | 117 +++++++++++++-------- bridgev2/portalinternal.go | 20 ++-- 5 files changed, 92 insertions(+), 50 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 40a17622..051e6a00 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -59,6 +59,7 @@ type BridgeConfig struct { CommandPrefix string `yaml:"command_prefix"` PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` + AsyncEvents bool `yaml:"async_events"` BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` TagOnlyOnCreate bool `yaml:"tag_only_on_create"` MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 7e524e84..d6ccf007 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -25,6 +25,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Str, "bridge", "command_prefix") helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") helper.Copy(up.Bool, "bridge", "private_chat_portal_meta") + helper.Copy(up.Bool, "bridge", "async_events") helper.Copy(up.Bool, "bridge", "bridge_matrix_leave") helper.Copy(up.Bool, "bridge", "tag_only_on_create") helper.Copy(up.Bool, "bridge", "mute_only_on_create") diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index e0a5ed87..31490bb3 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -7,6 +7,9 @@ bridge: # Whether the bridge should set names and avatars explicitly for DM portals. # This is only necessary when using clients that don't support MSC4171. private_chat_portal_meta: false + # Should events be handled asynchronously within portal rooms? + # If true, events may end up being out of order, but slow events won't block other ones. + async_events: false # Should leaving Matrix rooms be bridged as leaving groups on the remote network? bridge_matrix_leave: false diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a5a150a4..ca3be3e4 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -14,6 +14,7 @@ import ( "runtime/debug" "strings" "sync" + "sync/atomic" "time" "github.com/rs/zerolog" @@ -275,23 +276,49 @@ func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) { func (portal *Portal) eventLoop() { for rawEvt := range portal.events { - switch evt := rawEvt.(type) { - case *portalMatrixEvent: - portal.handleMatrixEvent(evt.sender, evt.evt) - case *portalRemoteEvent: - portal.handleRemoteEvent(evt.source, evt.evt) - case *portalCreateEvent: - portal.handleCreateEvent(evt) - default: - panic(fmt.Errorf("illegal type %T in eventLoop", evt)) - } + portal.handleSingleEventAsync(rawEvt) } } -func (portal *Portal) handleCreateEvent(evt *portalCreateEvent) { +func (portal *Portal) handleSingleEventAsync(rawEvt any) { + log := portal.Log.With().Logger() + if _, isCreate := rawEvt.(*portalCreateEvent); isCreate { + portal.handleSingleEvent(&log, rawEvt, func() {}) + } else if portal.Bridge.Config.AsyncEvents { + go portal.handleSingleEvent(&log, rawEvt, func() {}) + } else { + doneCh := make(chan struct{}) + var backgrounded atomic.Bool + go portal.handleSingleEvent(&log, rawEvt, func() { + close(doneCh) + if backgrounded.Load() { + log.Debug().Msg("Event that took too long finally finished handling") + } + }) + tick := time.NewTicker(30 * time.Second) + defer tick.Stop() + for i := 0; i < 10; i++ { + select { + case <-doneCh: + if i > 0 { + log.Debug().Msg("Event that took long finished handling") + } + return + case <-tick.C: + log.Warn().Msg("Event handling is taking long") + } + } + log.Warn().Msg("Event handling is taking too long, continuing in background") + backgrounded.Store(true) + } +} + +func (portal *Portal) handleSingleEvent(log *zerolog.Logger, rawEvt any, doneCallback func()) { + ctx := log.WithContext(context.Background()) defer func() { + doneCallback() if err := recover(); err != nil { - logEvt := zerolog.Ctx(evt.ctx).Error() + logEvt := log.Error() if realErr, ok := err.(error); ok { logEvt = logEvt.Err(realErr) } else { @@ -300,10 +327,36 @@ func (portal *Portal) handleCreateEvent(evt *portalCreateEvent) { logEvt. Bytes("stack", debug.Stack()). Msg("Portal creation panicked") - evt.cb(fmt.Errorf("portal creation panicked")) + switch evt := rawEvt.(type) { + case *portalMatrixEvent: + if evt.evt.ID != "" { + go portal.sendErrorStatus(ctx, evt.evt, ErrPanicInEventHandler) + } + case *portalCreateEvent: + evt.cb(fmt.Errorf("portal creation panicked")) + } } }() - evt.cb(portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info, nil)) + switch evt := rawEvt.(type) { + case *portalMatrixEvent: + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("action", "handle matrix event"). + Stringer("event_id", evt.evt.ID). + Str("event_type", evt.evt.Type.Type) + }) + portal.handleMatrixEvent(ctx, evt.sender, evt.evt) + case *portalRemoteEvent: + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("action", "handle remote event"). + Str("source_id", string(evt.source.ID)) + }) + portal.handleRemoteEvent(ctx, evt.source, evt.evt) + case *portalCreateEvent: + *log = *zerolog.Ctx(evt.ctx) + evt.cb(portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info, nil)) + default: + panic(fmt.Errorf("illegal type %T in eventLoop", evt)) + } } func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) { @@ -393,29 +446,8 @@ func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID, return false } -func (portal *Portal) handleMatrixEvent(sender *User, evt *event.Event) { - log := portal.Log.With(). - Str("action", "handle matrix event"). - Stringer("event_id", evt.ID). - Str("event_type", evt.Type.Type). - Logger() - ctx := log.WithContext(context.TODO()) - defer func() { - if err := recover(); err != nil { - logEvt := log.Error() - if realErr, ok := err.(error); ok { - logEvt = logEvt.Err(realErr) - } else { - logEvt = logEvt.Any(zerolog.ErrorFieldName, err) - } - logEvt. - Bytes("stack", debug.Stack()). - Msg("Matrix event handler panicked") - if evt.ID != "" { - go portal.sendErrorStatus(ctx, evt, ErrPanicInEventHandler) - } - } - }() +func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) { + log := zerolog.Ctx(ctx) if evt.Mautrix.EventSource&event.SourceEphemeral != 0 { switch evt.Type { case event.EphemeralEventReceipt: @@ -1458,11 +1490,8 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog portal.sendSuccessStatus(ctx, evt) } -func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { - log := portal.Log.With(). - Str("source_id", string(source.ID)). - Str("action", "handle remote event"). - Logger() +func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, evt RemoteEvent) { + log := zerolog.Ctx(ctx) defer func() { if err := recover(); err != nil { logEvt := log.Error() @@ -1481,7 +1510,6 @@ func (portal *Portal) handleRemoteEvent(source *UserLogin, evt RemoteEvent) { c = c.Stringer("bridge_evt_type", evtType) return evt.AddLogContext(c) }) - ctx := log.WithContext(context.TODO()) if portal.MXID == "" { mcp, ok := evt.(RemoteEventThatMayCreatePortal) if !ok || !mcp.ShouldCreatePortal() { @@ -1823,7 +1851,8 @@ func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin, } if len(res.SubEvents) > 0 { for _, subEvt := range res.SubEvents { - portal.handleRemoteEvent(source, subEvt) + log := portal.Log.With().Str("source_id", string(source.ID)).Str("action", "handle remote subevent").Logger() + portal.handleRemoteEvent(log.WithContext(ctx), source, subEvt) } } return res.ContinueMessageHandling diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index 1ee793a9..a4bd611a 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -37,8 +37,12 @@ func (portal *PortalInternals) EventLoop() { (*Portal)(portal).eventLoop() } -func (portal *PortalInternals) HandleCreateEvent(evt *portalCreateEvent) { - (*Portal)(portal).handleCreateEvent(evt) +func (portal *PortalInternals) HandleSingleEventAsync(rawEvt any) { + (*Portal)(portal).handleSingleEventAsync(rawEvt) +} + +func (portal *PortalInternals) HandleSingleEvent(log *zerolog.Logger, rawEvt any, doneCallback func()) { + (*Portal)(portal).handleSingleEvent(log, rawEvt, doneCallback) } func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event) { @@ -53,8 +57,8 @@ func (portal *PortalInternals) CheckConfusableName(ctx context.Context, userID i return (*Portal)(portal).checkConfusableName(ctx, userID, name) } -func (portal *PortalInternals) HandleMatrixEvent(sender *User, evt *event.Event) { - (*Portal)(portal).handleMatrixEvent(sender, evt) +func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) { + (*Portal)(portal).handleMatrixEvent(ctx, sender, evt) } func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) { @@ -109,8 +113,8 @@ func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender (*Portal)(portal).handleMatrixRedaction(ctx, sender, origSender, evt) } -func (portal *PortalInternals) HandleRemoteEvent(source *UserLogin, evt RemoteEvent) { - (*Portal)(portal).handleRemoteEvent(source, evt) +func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *UserLogin, evt RemoteEvent) { + (*Portal)(portal).handleRemoteEvent(ctx, source, evt) } func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID) { @@ -297,6 +301,10 @@ func (portal *PortalInternals) DoThreadBackfill(ctx context.Context, source *Use (*Portal)(portal).doThreadBackfill(ctx, source, threadID) } +func (portal *PortalInternals) CutoffMessages(ctx context.Context, messages []*BackfillMessage, aggressiveDedup, forward bool, lastMessage *database.Message) []*BackfillMessage { + return (*Portal)(portal).cutoffMessages(ctx, messages, aggressiveDedup, forward, lastMessage) +} + func (portal *PortalInternals) SendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) { (*Portal)(portal).sendBackfill(ctx, source, messages, forceForward, markRead, inThread) } From e750881a4ab0aae0dc13a9ac382196db2c78261b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 5 Sep 2024 00:41:12 +0300 Subject: [PATCH 0717/1647] bridgev2/backfill: respect DontBridge flag --- bridgev2/portalbackfill.go | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index ffe68ca5..e4a3e0ad 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -304,17 +304,6 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin partIDs = append(partIDs, part.ID) portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID) - out.Events = append(out.Events, &event.Event{ - Sender: intent.GetMXID(), - Type: part.Type, - Timestamp: msg.Timestamp.UnixMilli(), - ID: evtID, - RoomID: portal.MXID, - Content: event.Content{ - Parsed: part.Content, - Raw: part.Extra, - }, - }) dbMessage := &database.Message{ ID: msg.ID, PartID: part.ID, @@ -327,6 +316,22 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin ReplyTo: ptr.Val(msg.ReplyTo), Metadata: part.DBMetadata, } + if part.DontBridge { + dbMessage.SetFakeMXID() + out.DBMessages = append(out.DBMessages, dbMessage) + continue + } + out.Events = append(out.Events, &event.Event{ + Sender: intent.GetMXID(), + Type: part.Type, + Timestamp: msg.Timestamp.UnixMilli(), + ID: evtID, + RoomID: portal.MXID, + Content: event.Content{ + Parsed: part.Content, + Raw: part.Extra, + }, + }) if firstPart == nil { firstPart = dbMessage } From 059d9a36e5cfe4153b749b517b291796865cee16 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 6 Sep 2024 00:15:16 +0300 Subject: [PATCH 0718/1647] bridgev2: fix issues in event loop debug logger --- bridgev2/portal.go | 110 +++++++++++++++++++------------------ bridgev2/portalinternal.go | 16 ++++-- 2 files changed, 68 insertions(+), 58 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index ca3be3e4..dce8cfe5 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -38,8 +38,9 @@ type portalMatrixEvent struct { } type portalRemoteEvent struct { - evt RemoteEvent - source *UserLogin + evt RemoteEvent + source *UserLogin + evtType RemoteEventType } type portalCreateEvent struct { @@ -275,24 +276,31 @@ func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) { } func (portal *Portal) eventLoop() { + i := 0 for rawEvt := range portal.events { - portal.handleSingleEventAsync(rawEvt) + i++ + portal.handleSingleEventAsync(i, rawEvt) } } -func (portal *Portal) handleSingleEventAsync(rawEvt any) { - log := portal.Log.With().Logger() +func (portal *Portal) handleSingleEventAsync(idx int, rawEvt any) { + ctx := portal.getEventCtxWithLog(rawEvt, idx) if _, isCreate := rawEvt.(*portalCreateEvent); isCreate { - portal.handleSingleEvent(&log, rawEvt, func() {}) + portal.handleSingleEvent(ctx, rawEvt, func() {}) } else if portal.Bridge.Config.AsyncEvents { - go portal.handleSingleEvent(&log, rawEvt, func() {}) + go portal.handleSingleEvent(ctx, rawEvt, func() {}) } else { + log := zerolog.Ctx(ctx) doneCh := make(chan struct{}) var backgrounded atomic.Bool - go portal.handleSingleEvent(&log, rawEvt, func() { + start := time.Now() + var handleDuration time.Duration + go portal.handleSingleEvent(ctx, rawEvt, func() { + handleDuration = time.Since(start) close(doneCh) if backgrounded.Load() { - log.Debug().Msg("Event that took too long finally finished handling") + log.Debug().Stringer("duration", handleDuration). + Msg("Event that took too long finally finished handling") } }) tick := time.NewTicker(30 * time.Second) @@ -301,7 +309,8 @@ func (portal *Portal) handleSingleEventAsync(rawEvt any) { select { case <-doneCh: if i > 0 { - log.Debug().Msg("Event that took long finished handling") + log.Debug().Stringer("duration", handleDuration). + Msg("Event that took long finished handling") } return case <-tick.C: @@ -313,8 +322,34 @@ func (portal *Portal) handleSingleEventAsync(rawEvt any) { } } -func (portal *Portal) handleSingleEvent(log *zerolog.Logger, rawEvt any, doneCallback func()) { - ctx := log.WithContext(context.Background()) +func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context { + var logWith zerolog.Context + switch evt := rawEvt.(type) { + case *portalMatrixEvent: + logWith = portal.Log.With().Int("event_loop_index", idx). + Str("action", "handle matrix event"). + Stringer("event_id", evt.evt.ID). + Str("event_type", evt.evt.Type.Type) + if evt.evt.Mautrix.EventSource&event.SourceEphemeral == 0 { + logWith = logWith. + Stringer("event_id", evt.evt.ID). + Stringer("sender", evt.sender.MXID) + } + case *portalRemoteEvent: + evt.evtType = evt.evt.GetType() + logWith = portal.Log.With().Int("event_loop_index", idx). + Str("action", "handle remote event"). + Str("source_id", string(evt.source.ID)). + Stringer("bridge_evt_type", evt.evtType) + logWith = evt.evt.AddLogContext(logWith) + case *portalCreateEvent: + return evt.ctx + } + return logWith.Logger().WithContext(context.Background()) +} + +func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCallback func()) { + log := zerolog.Ctx(ctx) defer func() { doneCallback() if err := recover(); err != nil { @@ -339,20 +374,10 @@ func (portal *Portal) handleSingleEvent(log *zerolog.Logger, rawEvt any, doneCal }() switch evt := rawEvt.(type) { case *portalMatrixEvent: - log.UpdateContext(func(c zerolog.Context) zerolog.Context { - return c.Str("action", "handle matrix event"). - Stringer("event_id", evt.evt.ID). - Str("event_type", evt.evt.Type.Type) - }) portal.handleMatrixEvent(ctx, evt.sender, evt.evt) case *portalRemoteEvent: - log.UpdateContext(func(c zerolog.Context) zerolog.Context { - return c.Str("action", "handle remote event"). - Str("source_id", string(evt.source.ID)) - }) - portal.handleRemoteEvent(ctx, evt.source, evt.evt) + portal.handleRemoteEvent(ctx, evt.source, evt.evtType, evt.evt) case *portalCreateEvent: - *log = *zerolog.Ctx(evt.ctx) evt.cb(portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info, nil)) default: panic(fmt.Errorf("illegal type %T in eventLoop", evt)) @@ -457,11 +482,6 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * } return } - log.UpdateContext(func(c zerolog.Context) zerolog.Context { - return c. - Stringer("event_id", evt.ID). - Stringer("sender", sender.MXID) - }) login, _, err := portal.FindPreferredLogin(ctx, sender, true) if err != nil { log.Err(err).Msg("Failed to get user login to handle Matrix event") @@ -497,9 +517,8 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * } origSender.FormattedName = portal.Bridge.Config.Relay.FormatName(origSender) } - log.UpdateContext(func(c zerolog.Context) zerolog.Context { - return c.Str("login_id", string(login.ID)) - }) + // Copy logger because many of the handlers will use UpdateContext + ctx = log.With().Str("login_id", string(login.ID)).Logger().WithContext(ctx) switch evt.Type { case event.EventMessage, event.EventSticker: portal.handleMatrixMessage(ctx, login, origSender, evt) @@ -1490,26 +1509,8 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog portal.sendSuccessStatus(ctx, evt) } -func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, evt RemoteEvent) { +func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) { log := zerolog.Ctx(ctx) - defer func() { - if err := recover(); err != nil { - logEvt := log.Error() - if realErr, ok := err.(error); ok { - logEvt = logEvt.Err(realErr) - } else { - logEvt = logEvt.Any(zerolog.ErrorFieldName, err) - } - logEvt. - Bytes("stack", debug.Stack()). - Msg("Remote event handler panicked") - } - }() - evtType := evt.GetType() - log.UpdateContext(func(c zerolog.Context) zerolog.Context { - c = c.Stringer("bridge_evt_type", evtType) - return evt.AddLogContext(c) - }) if portal.MXID == "" { mcp, ok := evt.(RemoteEventThatMayCreatePortal) if !ok || !mcp.ShouldCreatePortal() { @@ -1851,8 +1852,13 @@ func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin, } if len(res.SubEvents) > 0 { for _, subEvt := range res.SubEvents { - log := portal.Log.With().Str("source_id", string(source.ID)).Str("action", "handle remote subevent").Logger() - portal.handleRemoteEvent(log.WithContext(ctx), source, subEvt) + subType := subEvt.GetType() + log := portal.Log.With(). + Str("source_id", string(source.ID)). + Str("action", "handle remote subevent"). + Stringer("bridge_evt_type", subType). + Logger() + portal.handleRemoteEvent(log.WithContext(ctx), source, subType, subEvt) } } return res.ContinueMessageHandling diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index a4bd611a..77bdd7fd 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -37,12 +37,16 @@ func (portal *PortalInternals) EventLoop() { (*Portal)(portal).eventLoop() } -func (portal *PortalInternals) HandleSingleEventAsync(rawEvt any) { - (*Portal)(portal).handleSingleEventAsync(rawEvt) +func (portal *PortalInternals) HandleSingleEventAsync(idx int, rawEvt any) { + (*Portal)(portal).handleSingleEventAsync(idx, rawEvt) } -func (portal *PortalInternals) HandleSingleEvent(log *zerolog.Logger, rawEvt any, doneCallback func()) { - (*Portal)(portal).handleSingleEvent(log, rawEvt, doneCallback) +func (portal *PortalInternals) GetEventCtxWithLog(rawEvt any, idx int) context.Context { + return (*Portal)(portal).getEventCtxWithLog(rawEvt, idx) +} + +func (portal *PortalInternals) HandleSingleEvent(ctx context.Context, rawEvt any, doneCallback func()) { + (*Portal)(portal).handleSingleEvent(ctx, rawEvt, doneCallback) } func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event) { @@ -113,8 +117,8 @@ func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender (*Portal)(portal).handleMatrixRedaction(ctx, sender, origSender, evt) } -func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *UserLogin, evt RemoteEvent) { - (*Portal)(portal).handleRemoteEvent(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) { + (*Portal)(portal).handleRemoteEvent(ctx, source, evtType, evt) } func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID) { From 33d724bf4c78636b425e8fa2157b0f314fa2349e Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Fri, 6 Sep 2024 13:54:54 +0300 Subject: [PATCH 0719/1647] event: add encrypted file info for m.room.avatar (#283) --- event/state.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/event/state.go b/event/state.go index 6e5f0ae4..d6c200a9 100644 --- a/event/state.go +++ b/event/state.go @@ -26,8 +26,9 @@ type RoomNameEventContent struct { // RoomAvatarEventContent represents the content of a m.room.avatar state event. // https://spec.matrix.org/v1.2/client-server-api/#mroomavatar type RoomAvatarEventContent struct { - URL id.ContentURIString `json:"url"` - Info *FileInfo `json:"info,omitempty"` + URL id.ContentURIString `json:"url,omitempty"` + Info *FileInfo `json:"info,omitempty"` + MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"` } // ServerACLEventContent represents the content of a m.room.server_acl state event. From 6b055b1475bd390abb6e5b83edd3a24113edcb2e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 6 Sep 2024 17:51:09 +0300 Subject: [PATCH 0720/1647] bridgev2: include portal receiver in m.bridge events --- bridgev2/portal.go | 1 + event/state.go | 2 ++ 2 files changed, 3 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index dce8cfe5..81a1ff89 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2817,6 +2817,7 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { ID: string(portal.ID), DisplayName: portal.Name, AvatarURL: portal.AvatarMXC, + Receiver: string(portal.Receiver), // TODO external URL? }, BeeperRoomTypeV2: string(portal.RoomType), diff --git a/event/state.go b/event/state.go index d6c200a9..15972892 100644 --- a/event/state.go +++ b/event/state.go @@ -157,6 +157,8 @@ type BridgeInfoSection struct { DisplayName string `json:"displayname,omitempty"` AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` ExternalURL string `json:"external_url,omitempty"` + + Receiver string `json:"fi.mau.receiver,omitempty"` } // BridgeEventContent represents the content of a m.bridge state event. From 4098a3726eabd98cf10c564e2578f8eccd5fe8f8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 8 Sep 2024 13:47:49 +0300 Subject: [PATCH 0721/1647] event: ensure MSC1767 audio has empty waveform --- event/audio.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/event/audio.go b/event/audio.go index 798acc8c..9eeb8edb 100644 --- a/event/audio.go +++ b/event/audio.go @@ -1,8 +1,21 @@ package event +import ( + "encoding/json" +) + type MSC1767Audio struct { Duration int `json:"duration"` Waveform []int `json:"waveform"` } +type serializableMSC1767Audio MSC1767Audio + +func (ma *MSC1767Audio) MarshalJSON() ([]byte, error) { + if ma.Waveform == nil { + ma.Waveform = []int{} + } + return json.Marshal((*serializableMSC1767Audio)(ma)) +} + type MSC3245Voice struct{} From 4fd082aba92e8fd012042b8f5e35edde233ab9ca Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Mon, 9 Sep 2024 16:56:40 +0300 Subject: [PATCH 0722/1647] event: add Beeper backup flag to unsigned (#284) --- event/events.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/event/events.go b/event/events.go index 4653a531..e5a01d02 100644 --- a/event/events.go +++ b/event/events.go @@ -144,7 +144,8 @@ type Unsigned struct { RedactedBecause *Event `json:"redacted_because,omitempty"` InviteRoomState []StrippedState `json:"invite_room_state,omitempty"` - BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"` + BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"` + BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"` } func (us *Unsigned) IsEmpty() bool { From cbc307b3112eac60c8bce9697446493366ffdeb1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 10 Sep 2024 14:15:02 +0300 Subject: [PATCH 0723/1647] bridgev2/database: add unique constraint on message mxids --- bridgev2/database/upgrades/00-latest.sql | 8 +++++--- bridgev2/database/upgrades/17-message-mxid-unique.sql | 7 +++++++ 2 files changed, 12 insertions(+), 3 deletions(-) create mode 100644 bridgev2/database/upgrades/17-message-mxid-unique.sql diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index aeb9522e..80acfab1 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v16 (compatible with v9+): Latest revision +-- v0 -> v17 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -115,7 +115,8 @@ CREATE TABLE message ( CONSTRAINT message_sender_fkey FOREIGN KEY (bridge_id, sender_id) REFERENCES ghost (bridge_id, id) ON DELETE CASCADE ON UPDATE CASCADE, - CONSTRAINT message_real_pkey UNIQUE (bridge_id, room_receiver, id, part_id) + CONSTRAINT message_real_pkey UNIQUE (bridge_id, room_receiver, id, part_id), + CONSTRAINT message_mxid_unique UNIQUE (bridge_id, mxid) ); CREATE INDEX message_room_idx ON message (bridge_id, room_id, room_receiver); @@ -154,7 +155,8 @@ CREATE TABLE reaction ( ON DELETE CASCADE ON UPDATE CASCADE, CONSTRAINT reaction_sender_fkey FOREIGN KEY (bridge_id, sender_id) REFERENCES ghost (bridge_id, id) - ON DELETE CASCADE ON UPDATE CASCADE + ON DELETE CASCADE ON UPDATE CASCADE, + CONSTRAINT reaction_mxid_unique UNIQUE (bridge_id, mxid) ); CREATE INDEX reaction_room_idx ON reaction (bridge_id, room_id, room_receiver); diff --git a/bridgev2/database/upgrades/17-message-mxid-unique.sql b/bridgev2/database/upgrades/17-message-mxid-unique.sql new file mode 100644 index 00000000..05503191 --- /dev/null +++ b/bridgev2/database/upgrades/17-message-mxid-unique.sql @@ -0,0 +1,7 @@ +-- v17 (compatible with v9+): Add unique constraint for message and reaction mxids +-- only: postgres for next 2 lines +ALTER TABLE message ADD CONSTRAINT message_mxid_unique UNIQUE (bridge_id, mxid); +ALTER TABLE reaction ADD CONSTRAINT reaction_mxid_unique UNIQUE (bridge_id, mxid); +-- only: sqlite for next 2 lines +CREATE UNIQUE INDEX message_mxid_unique ON message (bridge_id, mxid); +CREATE UNIQUE INDEX reaction_mxid_unique ON reaction (bridge_id, mxid); From ffdb1d575e5fab5fa5f9dbb0c13806632b5e6da6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 10 Sep 2024 14:16:14 +0300 Subject: [PATCH 0724/1647] bridgev2/portal: maybe fix check for adding log context --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 81a1ff89..0ab1c9f5 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -330,7 +330,7 @@ func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context { Str("action", "handle matrix event"). Stringer("event_id", evt.evt.ID). Str("event_type", evt.evt.Type.Type) - if evt.evt.Mautrix.EventSource&event.SourceEphemeral == 0 { + if evt.evt.Type.Class != event.EphemeralEventType { logWith = logWith. Stringer("event_id", evt.evt.ID). Stringer("sender", evt.sender.MXID) From c9a9cb69575af703dfed6db6b8b431856ab7faf7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 11 Sep 2024 02:07:39 +0300 Subject: [PATCH 0725/1647] bridgev2/matrixinterface: add download to file method --- bridgev2/matrix/intent.go | 62 +++++++++++++++++++++++++++++++++++++ bridgev2/matrixinterface.go | 2 ++ 2 files changed, 64 insertions(+) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 7f6ebbb9..11cee7ba 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -18,6 +18,7 @@ import ( "time" "github.com/rs/zerolog" + "go.mau.fi/util/fallocate" "go.mau.fi/util/ptr" "golang.org/x/exp/slices" @@ -213,6 +214,60 @@ func (as *ASIntent) DownloadMedia(ctx context.Context, uri id.ContentURIString, return data, nil } +func (as *ASIntent) DownloadMediaToFile(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo, writable bool) (*os.File, error) { + if file != nil { + uri = file.URL + err := file.PrepareForDecryption() + if err != nil { + return nil, err + } + } + parsedURI, err := uri.Parse() + if err != nil { + return nil, err + } + tempFile, err := os.CreateTemp("", "mautrix-download-*") + if err != nil { + return nil, fmt.Errorf("failed to create temp file: %w", err) + } + ok := false + defer func() { + if !ok { + _ = tempFile.Close() + _ = os.Remove(tempFile.Name()) + } + }() + resp, err := as.Matrix.Download(ctx, parsedURI) + if err != nil { + return nil, fmt.Errorf("failed to send download request: %w", err) + } + defer resp.Body.Close() + reader := resp.Body + if file != nil { + reader = file.DecryptStream(reader) + } + if resp.ContentLength > 0 { + err = fallocate.Fallocate(tempFile, int(resp.ContentLength)) + if err != nil { + return nil, fmt.Errorf("failed to preallocate file: %w", err) + } + } + _, err = io.Copy(tempFile, reader) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + err = reader.Close() + if err != nil { + return nil, fmt.Errorf("failed to close response body: %w", err) + } + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return nil, fmt.Errorf("failed to seek to start of temp file: %w", err) + } + ok = true + return tempFile, nil +} + func (as *ASIntent) UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) { if int64(len(data)) > as.Connector.MediaConfig.UploadSize { return "", nil, fmt.Errorf("file too large (%.2f MB > %.2f MB)", float64(len(data))/1000/1000, float64(as.Connector.MediaConfig.UploadSize)/1000/1000) @@ -275,6 +330,13 @@ func (as *ASIntent) UploadMediaStream( removeAndClose(tempFile) } }() + if size > 0 { + err = fallocate.Fallocate(tempFile, int(size)) + if err != nil { + err = fmt.Errorf("failed to preallocate file: %w", err) + return + } + } if roomID != "" { var encrypted bool if encrypted, err = as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 9fb0c82d..b302bd20 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -9,6 +9,7 @@ package bridgev2 import ( "context" "io" + "os" "time" "github.com/gorilla/mux" @@ -110,6 +111,7 @@ type MatrixAPI interface { MarkUnread(ctx context.Context, roomID id.RoomID, unread bool) error MarkTyping(ctx context.Context, roomID id.RoomID, typingType TypingType, timeout time.Duration) error DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error) + DownloadMediaToFile(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo, writable bool) (*os.File, error) UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) UploadMediaStream(ctx context.Context, roomID id.RoomID, size int64, requireFile bool, cb FileStreamCallback) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) From a95cb9adb3358851b8725bfb32f8b509b7b55985 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 11 Sep 2024 12:46:21 +0300 Subject: [PATCH 0726/1647] bridgev2/matrixinterface: use callback for DownloadMediaToFile The caller doesn't know whether the file should be removed, so let the Matrix interface deal with it. --- bridgev2/matrix/intent.go | 34 +++++++++++++++++----------------- bridgev2/matrixinterface.go | 16 +++++++++++++++- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 11cee7ba..e7e860f6 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -214,32 +214,29 @@ func (as *ASIntent) DownloadMedia(ctx context.Context, uri id.ContentURIString, return data, nil } -func (as *ASIntent) DownloadMediaToFile(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo, writable bool) (*os.File, error) { +func (as *ASIntent) DownloadMediaToFile(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo, writable bool, callback func(*os.File) error) error { if file != nil { uri = file.URL err := file.PrepareForDecryption() if err != nil { - return nil, err + return err } } parsedURI, err := uri.Parse() if err != nil { - return nil, err + return err } tempFile, err := os.CreateTemp("", "mautrix-download-*") if err != nil { - return nil, fmt.Errorf("failed to create temp file: %w", err) + return fmt.Errorf("failed to create temp file: %w", err) } - ok := false defer func() { - if !ok { - _ = tempFile.Close() - _ = os.Remove(tempFile.Name()) - } + _ = tempFile.Close() + _ = os.Remove(tempFile.Name()) }() resp, err := as.Matrix.Download(ctx, parsedURI) if err != nil { - return nil, fmt.Errorf("failed to send download request: %w", err) + return fmt.Errorf("failed to send download request: %w", err) } defer resp.Body.Close() reader := resp.Body @@ -249,23 +246,26 @@ func (as *ASIntent) DownloadMediaToFile(ctx context.Context, uri id.ContentURISt if resp.ContentLength > 0 { err = fallocate.Fallocate(tempFile, int(resp.ContentLength)) if err != nil { - return nil, fmt.Errorf("failed to preallocate file: %w", err) + return fmt.Errorf("failed to preallocate file: %w", err) } } _, err = io.Copy(tempFile, reader) if err != nil { - return nil, fmt.Errorf("failed to read response: %w", err) + return fmt.Errorf("failed to read response: %w", err) } err = reader.Close() if err != nil { - return nil, fmt.Errorf("failed to close response body: %w", err) + return fmt.Errorf("failed to close response body: %w", err) } _, err = tempFile.Seek(0, io.SeekStart) if err != nil { - return nil, fmt.Errorf("failed to seek to start of temp file: %w", err) + return fmt.Errorf("failed to seek to start of temp file: %w", err) } - ok = true - return tempFile, nil + err = callback(tempFile) + if err != nil { + return bridgev2.CallbackError{Type: "read", Wrapped: err} + } + return nil } func (as *ASIntent) UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) { @@ -351,7 +351,7 @@ func (as *ASIntent) UploadMediaStream( var res *bridgev2.FileStreamResult res, err = cb(tempFile) if err != nil { - err = fmt.Errorf("write callback failed: %w", err) + err = bridgev2.CallbackError{Type: "write", Wrapped: err} return } var replFile *os.File diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index b302bd20..4473b74e 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -8,6 +8,7 @@ package bridgev2 import ( "context" + "fmt" "io" "os" "time" @@ -102,6 +103,19 @@ type FileStreamResult struct { // The return value must be non-nil unless there's an error, and should always include FileName and MimeType. type FileStreamCallback func(file io.Writer) (*FileStreamResult, error) +type CallbackError struct { + Type string + Wrapped error +} + +func (ce CallbackError) Error() string { + return fmt.Sprintf("%s callback failed: %s", ce.Type, ce.Wrapped.Error()) +} + +func (ce CallbackError) Unwrap() error { + return ce.Wrapped +} + type MatrixAPI interface { GetMXID() id.UserID @@ -111,7 +125,7 @@ type MatrixAPI interface { MarkUnread(ctx context.Context, roomID id.RoomID, unread bool) error MarkTyping(ctx context.Context, roomID id.RoomID, typingType TypingType, timeout time.Duration) error DownloadMedia(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo) ([]byte, error) - DownloadMediaToFile(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo, writable bool) (*os.File, error) + DownloadMediaToFile(ctx context.Context, uri id.ContentURIString, file *event.EncryptedFileInfo, writable bool, callback func(*os.File) error) error UploadMedia(ctx context.Context, roomID id.RoomID, data []byte, fileName, mimeType string) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) UploadMediaStream(ctx context.Context, roomID id.RoomID, size int64, requireFile bool, cb FileStreamCallback) (url id.ContentURIString, file *event.EncryptedFileInfo, err error) From 00883680667363dda8153e5eb1ce4c210b4746bd Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 11 Sep 2024 13:53:42 +0300 Subject: [PATCH 0727/1647] bridgev2/provisioning: export responding with wrapped errors --- bridgev2/matrix/provisioning.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 00f5eb72..e0f630a0 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -380,13 +380,13 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque ) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process") - respondMaybeCustomError(w, err, "Internal error creating login process") + RespondWithError(w, err, "Internal error creating login process") return } firstStep, err := login.Start(r.Context()) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to start login") - respondMaybeCustomError(w, err, "Internal error starting login") + RespondWithError(w, err, "Internal error starting login") return } loginID := xid.New().String() @@ -438,7 +438,7 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http } if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input") - respondMaybeCustomError(w, err, "Internal error submitting input") + RespondWithError(w, err, "Internal error submitting input") return } login.NextStep = nextStep @@ -532,7 +532,7 @@ func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.R return userLogin } -func respondMaybeCustomError(w http.ResponseWriter, err error, message string) { +func RespondWithError(w http.ResponseWriter, err error, message string) { var mautrixRespErr mautrix.RespError var bv2RespErr bridgev2.RespError if errors.As(err, &bv2RespErr) { @@ -575,7 +575,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier") - respondMaybeCustomError(w, err, "Internal error resolving identifier") + RespondWithError(w, err, "Internal error resolving identifier") return } else if resp == nil { jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ @@ -689,7 +689,7 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque resp, err := api.GetContactList(r.Context()) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list") - respondMaybeCustomError(w, err, "Internal error fetching contact list") + RespondWithError(w, err, "Internal error fetching contact list") return } jsonResponse(w, http.StatusOK, &RespGetContactList{ @@ -731,7 +731,7 @@ func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Requ resp, err := api.SearchUsers(r.Context(), req.Query) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list") - respondMaybeCustomError(w, err, "Internal error fetching contact list") + RespondWithError(w, err, "Internal error fetching contact list") return } jsonResponse(w, http.StatusOK, &RespSearchUsers{ From 328bab41a309ecfbbf4e35827aed0c653c926f57 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 11 Sep 2024 17:07:14 +0300 Subject: [PATCH 0728/1647] bridgev2/portal: run portal create background tasks without context cancellation --- bridgev2/portal.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 0ab1c9f5..d86d9ff2 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3509,12 +3509,13 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } portal.Bridge.WakeupBackfillQueue() } + withoutCancelCtx := context.WithoutCancel(ctx) if portal.Parent != nil { if portal.Parent.MXID != "" { portal.addToParentSpaceAndSave(ctx, true) } else { log.Info().Msg("Parent portal doesn't exist, creating in background") - go portal.createParentAndAddToSpace(ctx, source) + go portal.createParentAndAddToSpace(withoutCancelCtx, source) } } portal.updateUserLocalInfo(ctx, info.UserLocal, source, true) @@ -3543,7 +3544,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo login := portal.Bridge.GetCachedUserLoginByID(up.LoginID) if login != nil { login.inPortalCache.Remove(portal.PortalKey) - go login.tryAddPortalToSpace(ctx, portal, up.CopyWithoutValues()) + go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues()) } } } From 2bf53fce9215eefc4ece159a142a331154f107b5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 11 Sep 2024 23:59:02 +0300 Subject: [PATCH 0729/1647] crypto/helper: don't require setting ASEventProcessor --- crypto/cryptohelper/cryptohelper.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index 5f1a952f..4a642055 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -115,8 +115,6 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { if !ok { if !helper.client.SetAppServiceDeviceID { return fmt.Errorf("the client syncer must implement ExtensibleSyncer") - } else if helper.ASEventProcessor == nil { - return fmt.Errorf("an appservice must be provided when using appservice mode encryption") } } @@ -214,7 +212,7 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { if helper.managedStateStore != nil { syncer.OnEvent(helper.client.StateStoreSyncHandler) } - } else { + } else if helper.ASEventProcessor != nil { helper.mach.AddAppserviceListener(helper.ASEventProcessor) helper.ASEventProcessor.On(event.EventEncrypted, helper.HandleEncrypted) } From 288a94ec1457aca1c797a815328ae282f37b6c5f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 11 Sep 2024 23:59:16 +0300 Subject: [PATCH 0730/1647] crypto/sqlstore: allow initializing manually --- crypto/sql_store.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 255247fd..00544a9b 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -52,14 +52,18 @@ var _ Store = (*SQLCryptoStore)(nil) // NewSQLCryptoStore initializes a new crypto Store using the given database, for a device's crypto material. // The stored material will be encrypted with the given key. func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, accountID string, deviceID id.DeviceID, pickleKey []byte) *SQLCryptoStore { - return &SQLCryptoStore{ + store := &SQLCryptoStore{ DB: db.Child(sql_store_upgrade.VersionTableName, sql_store_upgrade.Table, log), PickleKey: pickleKey, AccountID: accountID, DeviceID: deviceID, - - olmSessionCache: make(map[id.SenderKey]map[id.SessionID]*OlmSession), } + store.InitFields() + return store +} + +func (store *SQLCryptoStore) InitFields() { + store.olmSessionCache = make(map[id.SenderKey]map[id.SessionID]*OlmSession) } // Flush does nothing for this implementation as data is already persisted in the database. From 1d428be6d942535b54ad3564fb0e8d33a28f4f70 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 12 Sep 2024 02:19:50 +0300 Subject: [PATCH 0731/1647] crypto/helper: allow using LoginAs with unmanaged crypto store --- crypto/cryptohelper/cryptohelper.go | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index 4a642055..9aed5121 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -140,7 +140,13 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to upgrade crypto state store: %w", err) } - storedDeviceID, err := managedCryptoStore.FindDeviceID(ctx) + cryptoStore = managedCryptoStore + } else { + cryptoStore = helper.unmanagedCryptoStore + } + shouldFindDeviceID := helper.LoginAs != nil || helper.unmanagedCryptoStore == nil + if rawCryptoStore, ok := cryptoStore.(*crypto.SQLCryptoStore); ok && shouldFindDeviceID { + storedDeviceID, err := rawCryptoStore.FindDeviceID(ctx) if err != nil { return fmt.Errorf("failed to find existing device ID: %w", err) } @@ -154,14 +160,14 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { if err != nil { return err } - managedCryptoStore.DeviceID = resp.DeviceID + rawCryptoStore.DeviceID = resp.DeviceID helper.client.DeviceID = resp.DeviceID } else { helper.log.Debug(). Str("username", helper.LoginAs.Identifier.User). Stringer("device_id", storedDeviceID). Msg("Using existing device") - managedCryptoStore.DeviceID = storedDeviceID + rawCryptoStore.DeviceID = storedDeviceID helper.client.DeviceID = storedDeviceID } } else if helper.LoginAs != nil { @@ -178,17 +184,13 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { return err } if storedDeviceID == "" { - managedCryptoStore.DeviceID = helper.client.DeviceID + rawCryptoStore.DeviceID = helper.client.DeviceID } } else if storedDeviceID != "" && storedDeviceID != helper.client.DeviceID { return fmt.Errorf("mismatching device ID in client and crypto store (%q != %q)", storedDeviceID, helper.client.DeviceID) } - cryptoStore = managedCryptoStore - } else { - if helper.LoginAs != nil { - return fmt.Errorf("LoginAs can only be used with a managed crypto store") - } - cryptoStore = helper.unmanagedCryptoStore + } else if helper.LoginAs != nil { + return fmt.Errorf("LoginAs can only be used with a managed crypto store") } if helper.client.DeviceID == "" || helper.client.UserID == "" { return fmt.Errorf("the client must be logged in") From c62757ab15517ea9eafdb4e80af03030dbd76200 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 12 Sep 2024 02:33:15 +0300 Subject: [PATCH 0732/1647] client: add wrappers for event and room reporting endpoints --- client.go | 12 ++++++++++++ requests.go | 5 +++++ 2 files changed, 17 insertions(+) diff --git a/client.go b/client.go index edbeedfe..36f980bf 100644 --- a/client.go +++ b/client.go @@ -2359,6 +2359,18 @@ func (cli *Client) PutPushRule(ctx context.Context, scope string, kind pushrules return err } +func (cli *Client) ReportEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID, reason string) error { + urlPath := cli.BuildClientURL("v3", "rooms", roomID, "report", eventID) + _, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, &ReqReport{Reason: reason, Score: -100}, nil) + return err +} + +func (cli *Client) ReportRoom(ctx context.Context, roomID id.RoomID, reason string) error { + urlPath := cli.BuildClientURL("v3", "rooms", roomID, "report") + _, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, &ReqReport{Reason: reason, Score: -100}, nil) + return err +} + // BatchSend sends a batch of historical events into a room. This is only available for appservices. // // Deprecated: MSC2716 has been abandoned, so this is now Beeper-specific. BeeperBatchSend should be used instead. diff --git a/requests.go b/requests.go index 189e620d..a6b0ea8b 100644 --- a/requests.go +++ b/requests.go @@ -468,3 +468,8 @@ type ReqKeyBackupData struct { IsVerified bool `json:"is_verified"` SessionData json.RawMessage `json:"session_data"` } + +type ReqReport struct { + Reason string `json:"reason,omitempty"` + Score int `json:"score,omitempty"` +} From 6b4ff8b60eddfa3abb2ad551918a9d544d173191 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 12 Sep 2024 14:23:17 +0300 Subject: [PATCH 0733/1647] bridgev2/portal: fix event handlin panic log message --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index d86d9ff2..c5847f35 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -361,7 +361,7 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal } logEvt. Bytes("stack", debug.Stack()). - Msg("Portal creation panicked") + Msg("Event handling panicked") switch evt := rawEvt.(type) { case *portalMatrixEvent: if evt.evt.ID != "" { From 0328ed1c9f41d1f666cb1cb55c8111afc18724e5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 12 Sep 2024 15:59:01 +0300 Subject: [PATCH 0734/1647] bridgev2/backfill: add log before deduplicating messages --- bridgev2/portalbackfill.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index e4a3e0ad..6a912355 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -63,6 +63,11 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, log.Debug().Msg("No messages to backfill") return } + log.Debug(). + Int("message_count", len(resp.Messages)). + Bool("mark_read", resp.MarkRead). + Bool("aggressive_deduplication", resp.AggressiveDeduplication). + Msg("Fetched messages for forward backfill, deduplicating before sending") // TODO mark backfill queue task as done if last message is nil (-> room was empty) and HasMore is false? resp.Messages = portal.cutoffMessages(ctx, resp.Messages, resp.AggressiveDeduplication, true, lastMessage) if len(resp.Messages) == 0 { From bc22852f06767b630d6a71be0874d0245928ccbf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 12 Sep 2024 16:53:50 +0300 Subject: [PATCH 0735/1647] bridgev2: add analytics sending method --- bridgev2/bridgeconfig/config.go | 7 +++ bridgev2/bridgeconfig/upgrade.go | 5 ++ bridgev2/matrix/analytics.go | 62 ++++++++++++++++++++++ bridgev2/matrix/connector.go | 3 ++ bridgev2/matrix/mxmain/example-config.yaml | 9 ++++ bridgev2/matrixinterface.go | 4 ++ bridgev2/user.go | 11 ++++ 7 files changed, 101 insertions(+) create mode 100644 bridgev2/matrix/analytics.go diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 051e6a00..1731688d 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -21,6 +21,7 @@ type Config struct { Homeserver HomeserverConfig `yaml:"homeserver"` AppService AppserviceConfig `yaml:"appservice"` Matrix MatrixConfig `yaml:"matrix"` + Analytics AnalyticsConfig `yaml:"analytics"` Provisioning ProvisioningConfig `yaml:"provisioning"` PublicMedia PublicMediaConfig `yaml:"public_media"` DirectMedia DirectMediaConfig `yaml:"direct_media"` @@ -78,6 +79,12 @@ type MatrixConfig struct { UploadFileThreshold int64 `yaml:"upload_file_threshold"` } +type AnalyticsConfig struct { + Token string `yaml:"token"` + URL string `yaml:"url"` + UserID string `yaml:"user_id"` +} + type ProvisioningConfig struct { Prefix string `yaml:"prefix"` SharedSecret string `yaml:"shared_secret"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index d6ccf007..4491dedd 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -87,6 +87,10 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "matrix", "federate_rooms") helper.Copy(up.Int, "matrix", "upload_file_threshold") + helper.Copy(up.Str|up.Null, "analytics", "token") + helper.Copy(up.Str|up.Null, "analytics", "url") + helper.Copy(up.Str|up.Null, "analytics", "user_id") + helper.Copy(up.Str, "provisioning", "prefix") if secret, ok := helper.Get(up.Str, "provisioning", "shared_secret"); !ok || secret == "generate" { sharedSecret := random.String(64) @@ -176,6 +180,7 @@ var SpacedBlocks = [][]string{ {"appservice", "as_token"}, {"appservice", "username_template"}, {"matrix"}, + {"analytics"}, {"provisioning"}, {"public_media"}, {"direct_media"}, diff --git a/bridgev2/matrix/analytics.go b/bridgev2/matrix/analytics.go new file mode 100644 index 00000000..92ea2104 --- /dev/null +++ b/bridgev2/matrix/analytics.go @@ -0,0 +1,62 @@ +package matrix + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + + "maunium.net/go/mautrix/id" +) + +func (br *Connector) trackSync(userID id.UserID, event string, properties map[string]any) error { + var buf bytes.Buffer + var analyticsUserID string + if br.Config.Analytics.UserID != "" { + analyticsUserID = br.Config.Analytics.UserID + } else { + analyticsUserID = userID.String() + } + err := json.NewEncoder(&buf).Encode(map[string]any{ + "userId": analyticsUserID, + "event": event, + "properties": properties, + }) + if err != nil { + return err + } + + req, err := http.NewRequest(http.MethodPost, br.Config.Analytics.URL, &buf) + if err != nil { + return err + } + req.SetBasicAuth(br.Config.Analytics.Token, "") + resp, err := br.AS.HTTPClient.Do(req) + if err != nil { + return err + } + _ = resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code %d", resp.StatusCode) + } + return nil +} + +func (br *Connector) Track(userID id.UserID, event string, props map[string]any) { + if br.Config.Analytics.Token == "" || br.Config.Analytics.URL == "" { + return + } + + if props == nil { + props = map[string]any{} + } + props["bridge"] = br.Bridge.Network.GetName().BeeperBridgeType + go func() { + err := br.trackSync(userID, event, props) + if err != nil { + br.Log.Err(err).Str("component", "analytics").Str("event", event).Msg("Error tracking event") + } else { + br.Log.Debug().Str("component", "analytics").Str("event", event).Msg("Tracked event") + } + }() +} diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 4297cba7..c5df2421 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -103,6 +103,9 @@ var ( _ bridgev2.MatrixConnector = (*Connector)(nil) _ bridgev2.MatrixConnectorWithServer = (*Connector)(nil) _ bridgev2.MatrixConnectorWithPostRoomBridgeHandling = (*Connector)(nil) + _ bridgev2.MatrixConnectorWithPublicMedia = (*Connector)(nil) + _ bridgev2.MatrixConnectorWithNameDisambiguation = (*Connector)(nil) + _ bridgev2.MatrixConnectorWithAnalytics = (*Connector)(nil) ) func NewConnector(cfg *bridgeconfig.Config) *Connector { diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 31490bb3..7e642f96 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -205,6 +205,15 @@ matrix: # rather than keeping the whole file in memory. upload_file_threshold: 5242880 +# Segment-compatible analytics endpoint for tracking some events, like provisioning API login and encryption errors. +analytics: + # API key to send with tracking requests. Tracking is disabled if this is null. + token: null + # Address to send tracking requests to. + url: https://api.segment.io/v1/track + # Optional user ID for tracking events. If null, defaults to using Matrix user ID. + user_id: null + # Settings for provisioning API provisioning: # Prefix for the provisioning API paths. diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 4473b74e..6ff69250 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -73,6 +73,10 @@ type MatrixConnectorWithPostRoomBridgeHandling interface { HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomID) error } +type MatrixConnectorWithAnalytics interface { + Track(userID id.UserID, event string, properties map[string]any) +} + type MatrixSendExtra struct { Timestamp time.Time MessageMeta *database.Message diff --git a/bridgev2/user.go b/bridgev2/user.go index 5c2344e8..fbb8095e 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -253,3 +253,14 @@ func (user *User) GetManagementRoom(ctx context.Context) (id.RoomID, error) { func (user *User) Save(ctx context.Context) error { return user.Bridge.DB.User.Update(ctx, user.User) } + +func (br *Bridge) Track(userID id.UserID, event string, props map[string]any) { + analyticSender, ok := br.Matrix.(MatrixConnectorWithAnalytics) + if ok { + analyticSender.Track(userID, event, props) + } +} + +func (user *User) Track(event string, props map[string]any) { + user.Bridge.Track(user.MXID, event, props) +} From 08d58d4d2a2e5e8050363e4edd0edef02ef61dbd Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 12 Sep 2024 17:08:16 +0300 Subject: [PATCH 0736/1647] bridgev2/analytics: rename method --- bridgev2/matrix/analytics.go | 2 +- bridgev2/matrixinterface.go | 2 +- bridgev2/user.go | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/bridgev2/matrix/analytics.go b/bridgev2/matrix/analytics.go index 92ea2104..7eb2a33a 100644 --- a/bridgev2/matrix/analytics.go +++ b/bridgev2/matrix/analytics.go @@ -42,7 +42,7 @@ func (br *Connector) trackSync(userID id.UserID, event string, properties map[st return nil } -func (br *Connector) Track(userID id.UserID, event string, props map[string]any) { +func (br *Connector) TrackAnalytics(userID id.UserID, event string, props map[string]any) { if br.Config.Analytics.Token == "" || br.Config.Analytics.URL == "" { return } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 6ff69250..fe218db1 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -74,7 +74,7 @@ type MatrixConnectorWithPostRoomBridgeHandling interface { } type MatrixConnectorWithAnalytics interface { - Track(userID id.UserID, event string, properties map[string]any) + TrackAnalytics(userID id.UserID, event string, properties map[string]any) } type MatrixSendExtra struct { diff --git a/bridgev2/user.go b/bridgev2/user.go index fbb8095e..1530b865 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -254,13 +254,13 @@ func (user *User) Save(ctx context.Context) error { return user.Bridge.DB.User.Update(ctx, user.User) } -func (br *Bridge) Track(userID id.UserID, event string, props map[string]any) { +func (br *Bridge) TrackAnalytics(userID id.UserID, event string, props map[string]any) { analyticSender, ok := br.Matrix.(MatrixConnectorWithAnalytics) if ok { - analyticSender.Track(userID, event, props) + analyticSender.TrackAnalytics(userID, event, props) } } -func (user *User) Track(event string, props map[string]any) { - user.Bridge.Track(user.MXID, event, props) +func (user *User) TrackAnalytics(event string, props map[string]any) { + user.Bridge.TrackAnalytics(user.MXID, event, props) } From d472be34126bdd86ae4f5503b89b71a558f9db27 Mon Sep 17 00:00:00 2001 From: Scott Weber Date: Thu, 12 Sep 2024 11:09:02 -0400 Subject: [PATCH 0737/1647] event: add BeeperHSSuborder to unsigned (#287) --- event/events.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/event/events.go b/event/events.go index e5a01d02..23769ae8 100644 --- a/event/events.go +++ b/event/events.go @@ -145,11 +145,12 @@ type Unsigned struct { InviteRoomState []StrippedState `json:"invite_room_state,omitempty"` BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"` + BeeperHSSuborder int64 `json:"com.beeper.hs.suborder,omitempty"` BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"` } func (us *Unsigned) IsEmpty() bool { return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil && - us.BeeperHSOrder == 0 + us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 } From e16b681c100a31b9f6cb6aad78ea45064032f582 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 13 Sep 2024 00:46:30 +0300 Subject: [PATCH 0738/1647] crypto/helper: allow overriding post-decrypt function --- crypto/cryptohelper/cryptohelper.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index 9aed5121..0b3fbeaa 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -38,7 +38,8 @@ type CryptoHelper struct { LoginAs *mautrix.ReqLogin - ASEventProcessor crypto.ASEventProcessor + ASEventProcessor crypto.ASEventProcessor + CustomPostDecrypt func(context.Context, *event.Event) DBAccountID string } @@ -320,7 +321,9 @@ func (helper *CryptoHelper) HandleEncrypted(ctx context.Context, evt *event.Even func (helper *CryptoHelper) postDecrypt(ctx context.Context, decrypted *event.Event) { decrypted.Mautrix.EventSource |= event.SourceDecrypted - if helper.ASEventProcessor != nil { + if helper.CustomPostDecrypt != nil { + helper.CustomPostDecrypt(ctx, decrypted) + } else if helper.ASEventProcessor != nil { helper.ASEventProcessor.Dispatch(ctx, decrypted) } else { helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(ctx, decrypted) From 012aed97e32e8b06f61e28ca482458e1f285635e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 13 Sep 2024 01:33:29 +0300 Subject: [PATCH 0739/1647] error: add WithMessage and Write helpers --- bridgev2/errors.go | 9 +++++++++ error.go | 17 +++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index e80377cf..7f925b29 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -9,6 +9,7 @@ package bridgev2 import ( "errors" "fmt" + "net/http" "maunium.net/go/mautrix" ) @@ -80,6 +81,14 @@ func (re RespError) Is(err error) bool { return errors.Is(err, mautrix.RespError(re)) } +func (re *RespError) Write(w http.ResponseWriter) { + (*mautrix.RespError)(re).Write(w) +} + +func (re RespError) WithMessage(msg string, args ...any) RespError { + return RespError(mautrix.RespError(re).WithMessage(msg, args...)) +} + func (re RespError) AppendMessage(append string, args ...any) RespError { re.Err += fmt.Sprintf(append, args...) return re diff --git a/error.go b/error.go index acd90892..906bbd62 100644 --- a/error.go +++ b/error.go @@ -12,6 +12,7 @@ import ( "fmt" "net/http" + "go.mau.fi/util/exhttp" "golang.org/x/exp/maps" ) @@ -142,6 +143,22 @@ func (e *RespError) MarshalJSON() ([]byte, error) { return json.Marshal(data) } +func (e *RespError) Write(w http.ResponseWriter) { + statusCode := e.StatusCode + if statusCode == 0 { + statusCode = http.StatusInternalServerError + } + exhttp.WriteJSONResponse(w, statusCode, e) +} + +func (e RespError) WithMessage(msg string, args ...any) RespError { + if len(args) > 0 { + msg = fmt.Sprintf(msg, args...) + } + e.Err = msg + return e +} + // Error returns the errcode and error message. func (e RespError) Error() string { return e.ErrCode + ": " + e.Err From 36ef69bf7f4357a7e1451812b55b786ba23bf4b6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 13 Sep 2024 01:36:38 +0300 Subject: [PATCH 0740/1647] error: remove pointer receiver from Write Otherwise it can't be chained after WithMessage --- bridgev2/errors.go | 4 ++-- error.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index 7f925b29..0b1ef8b3 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -81,8 +81,8 @@ func (re RespError) Is(err error) bool { return errors.Is(err, mautrix.RespError(re)) } -func (re *RespError) Write(w http.ResponseWriter) { - (*mautrix.RespError)(re).Write(w) +func (re RespError) Write(w http.ResponseWriter) { + mautrix.RespError(re).Write(w) } func (re RespError) WithMessage(msg string, args ...any) RespError { diff --git a/error.go b/error.go index 906bbd62..7de5666d 100644 --- a/error.go +++ b/error.go @@ -143,12 +143,12 @@ func (e *RespError) MarshalJSON() ([]byte, error) { return json.Marshal(data) } -func (e *RespError) Write(w http.ResponseWriter) { +func (e RespError) Write(w http.ResponseWriter) { statusCode := e.StatusCode if statusCode == 0 { statusCode = http.StatusInternalServerError } - exhttp.WriteJSONResponse(w, statusCode, e) + exhttp.WriteJSONResponse(w, statusCode, &e) } func (e RespError) WithMessage(msg string, args ...any) RespError { From 65364d0133a902d932eb12a0668f3de7f31bf937 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 13 Sep 2024 01:38:18 +0300 Subject: [PATCH 0741/1647] error: add MUnknown --- error.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/error.go b/error.go index 7de5666d..2f9ab983 100644 --- a/error.go +++ b/error.go @@ -25,6 +25,9 @@ import ( // // logout // } var ( + // Generic error for when the server encounters an error and it does not have a more specific error code. + // Note that `errors.Is` will check the error message rather than code for M_UNKNOWNs. + MUnknown = RespError{ErrCode: "M_UNKNOWN", StatusCode: http.StatusInternalServerError} // Forbidden access, e.g. joining a room without permission, failed login. MForbidden = RespError{ErrCode: "M_FORBIDDEN", StatusCode: http.StatusForbidden} // Unrecognized request, e.g. the endpoint does not exist or is not implemented. From e51e36ac99a0dadc62ab2c36cec8f3d34347b0db Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 13 Sep 2024 01:53:45 +0300 Subject: [PATCH 0742/1647] client: drop support for unauthenticated media --- client.go | 109 ++++-------------------------------------------------- 1 file changed, 8 insertions(+), 101 deletions(-) diff --git a/client.go b/client.go index 36f980bf..b237612b 100644 --- a/client.go +++ b/client.go @@ -1467,13 +1467,7 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt // GetMediaConfig fetches the configuration of the content repository, such as upload limitations. func (cli *Client) GetMediaConfig(ctx context.Context) (resp *RespMediaConfig, err error) { - var u string - if cli.SpecVersions.Supports(FeatureAuthenticatedMedia) { - u = cli.BuildClientURL("v1", "media", "config") - } else { - u = cli.BuildURL(MediaURLPath{"v3", "config"}) - } - _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildClientURL("v1", "media", "config"), nil, &resp) return } @@ -1494,94 +1488,13 @@ func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUploa return cli.Upload(ctx, res.Body, res.Header.Get("Content-Type"), res.ContentLength) } -// Deprecated: unauthenticated media is deprecated as of Matrix v1.11. Use [Download] or [DownloadBytes] instead. -func (cli *Client) GetDownloadURL(mxcURL id.ContentURI) string { - return cli.BuildURLWithQuery(MediaURLPath{"v3", "download", mxcURL.Homeserver, mxcURL.FileID}, map[string]string{"allow_redirect": "true"}) -} - -func (cli *Client) doMediaRetry(req *http.Request, cause error, retries int, backoff time.Duration) (*http.Response, error) { - log := zerolog.Ctx(req.Context()) - if req.Body != nil { - var err error - if req.GetBody != nil { - req.Body, err = req.GetBody() - if err != nil { - log.Warn().Err(err).Msg("Failed to get new body to retry request") - return nil, cause - } - } else if bodySeeker, ok := req.Body.(io.ReadSeeker); ok { - _, err = bodySeeker.Seek(0, io.SeekStart) - if err != nil { - log.Warn().Err(err).Msg("Failed to seek to beginning of request body") - return nil, cause - } - } else { - log.Warn().Msg("Failed to get new body to retry request: GetBody is nil and Body is not an io.ReadSeeker") - return nil, cause - } - } - log.Warn().Err(cause). - Int("retry_in_seconds", int(backoff.Seconds())). - Msg("Request failed, retrying") - time.Sleep(backoff) - return cli.doMediaRequest(req, retries-1, backoff*2) -} - -func (cli *Client) doMediaRequest(req *http.Request, retries int, backoff time.Duration) (*http.Response, error) { - cli.RequestStart(req) - startTime := time.Now() - res, err := cli.Client.Do(req) - duration := time.Now().Sub(startTime) - if err != nil { - if retries > 0 { - return cli.doMediaRetry(req, err, retries, backoff) - } - err = HTTPError{ - Request: req, - Response: res, - - Message: "request error", - WrappedError: err, - } - cli.LogRequestDone(req, res, err, nil, 0, duration) - return nil, err - } - - if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) { - backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff) - return cli.doMediaRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff) - } - - if res.StatusCode < 200 || res.StatusCode >= 300 { - var body []byte - body, err = ParseErrorResponse(req, res) - cli.LogRequestDone(req, res, err, nil, len(body), duration) - } else { - cli.LogRequestDone(req, res, nil, nil, -1, duration) - } - return res, err -} - func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) { - ctxLog := zerolog.Ctx(ctx) - if ctxLog.GetLevel() == zerolog.Disabled || ctxLog == zerolog.DefaultContextLogger { - ctx = cli.Log.WithContext(ctx) - } - if cli.SpecVersions.Supports(FeatureAuthenticatedMedia) { - _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{ - Method: http.MethodGet, - URL: cli.BuildClientURL("v1", "media", "download", mxcURL.Homeserver, mxcURL.FileID), - DontReadResponse: true, - }) - return resp, err - } else { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, cli.GetDownloadURL(mxcURL), nil) - if err != nil { - return nil, err - } - req.Header.Set("User-Agent", cli.UserAgent+" (media downloader)") - return cli.doMediaRequest(req, cli.DefaultHTTPRetries, 4*time.Second) - } + _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{ + Method: http.MethodGet, + URL: cli.BuildClientURL("v1", "media", "download", mxcURL.Homeserver, mxcURL.FileID), + DontReadResponse: true, + }) + return resp, err } func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) { @@ -1789,13 +1702,7 @@ func (cli *Client) UploadMedia(ctx context.Context, data ReqUploadMedia) (*RespM // // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3preview_url func (cli *Client) GetURLPreview(ctx context.Context, url string) (*RespPreviewURL, error) { - var urlPath PrefixableURLPath - if cli.SpecVersions.Supports(FeatureAuthenticatedMedia) { - urlPath = ClientURLPath{"v1", "media", "preview_url"} - } else { - urlPath = MediaURLPath{"v3", "preview_url"} - } - reqURL := cli.BuildURLWithQuery(urlPath, map[string]string{ + reqURL := cli.BuildURLWithQuery(ClientURLPath{"v1", "media", "preview_url"}, map[string]string{ "url": url, }) var output RespPreviewURL From 96e68fb485d2a2654e8d585046706245aac3215e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 13 Sep 2024 12:16:47 +0300 Subject: [PATCH 0743/1647] dependencies: update --- go.mod | 12 +++++++----- go.sum | 16 ++++++++-------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/go.mod b/go.mod index 78d1b8c4..f1518cc0 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module maunium.net/go/mautrix -go 1.22 +go 1.22.0 + +toolchain go1.23.1 require ( filippo.io/edwards25519 v1.1.0 @@ -16,11 +18,11 @@ require ( github.com/tidwall/gjson v1.17.3 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2 + go.mau.fi/util v0.7.1-0.20240913091524-7617daa66719 go.mau.fi/zeroconfig v0.1.3 - golang.org/x/crypto v0.26.0 - golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 - golang.org/x/net v0.28.0 + golang.org/x/crypto v0.27.0 + golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 + golang.org/x/net v0.29.0 golang.org/x/sync v0.8.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 diff --git a/go.sum b/go.sum index 0f1a0558..6aad56c7 100644 --- a/go.sum +++ b/go.sum @@ -51,16 +51,16 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2 h1:VZQlKBbeJ7KOlYSh6BnN5uWQTY/ypn/bJv0YyEd+pXc= -go.mau.fi/util v0.7.1-0.20240904173517-ca3b3fe376c2/go.mod h1:WgYvbt9rVmoFeajP97NunQU7AjgvTPiNExN3oTHeePs= +go.mau.fi/util v0.7.1-0.20240913091524-7617daa66719 h1:sg1P/f4RHY1JuAwsPOjTCsZr8ROzR9bRTtnvvBu42d4= +go.mau.fi/util v0.7.1-0.20240913091524-7617daa66719/go.mod h1:1Ixb8HWoVbl3rT6nAX6nV4iMkzn7KU/KXwE0Rn5RmsQ= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= -golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= -golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 h1:kx6Ds3MlpiUHKj7syVnbp57++8WpuKPcR5yjLBjvLEA= -golang.org/x/exp v0.0.0-20240823005443-9b4947da3948/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= -golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= -golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= +golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= +golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= +golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= +golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From e12ecbe82d36fc985074b1996568264991140e23 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 13 Sep 2024 12:55:32 +0300 Subject: [PATCH 0744/1647] bridgev2/matrix: allow key sharing for bridge admins --- bridgev2/matrix/crypto.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index 7383ddc7..04654ff5 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -217,11 +217,12 @@ func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device zerolog.Ctx(ctx).Err(err).Msg("Failed to get user to handle key request") return &crypto.KeyShareRejectNoResponse } else if user == nil { - // TODO + zerolog.Ctx(ctx).Debug().Msg("Couldn't find user to handle key request") return &crypto.KeyShareRejectNoResponse - } else if true { - // TODO admin check and is in room check - return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "Key sharing is not yet implemented in bridgev2"} + } else if !user.Permissions.Admin { + zerolog.Ctx(ctx).Debug().Msg("Rejecting key request: user is not admin") + // TODO is in room check? + return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "Key sharing for non-admins is not yet implemented"} } zerolog.Ctx(ctx).Debug().Msg("Accepting key request") return nil From c2d0f4cf5d25d375c06b422e61dc28903fed46bf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 13 Sep 2024 13:27:34 +0300 Subject: [PATCH 0745/1647] ci: don't allow go to update itself --- .github/workflows/go.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 30f05d69..9117286f 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -2,6 +2,9 @@ name: Go on: [push, pull_request] +env: + GOTOOLCHAIN: local + jobs: lint: runs-on: ubuntu-latest From 41213c6230ae5bfa30bf54485270141a8221c843 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 13 Sep 2024 21:16:25 +0300 Subject: [PATCH 0746/1647] bridgev2/matrix: add separators to data when generating deterministic event ids --- bridgev2/matrix/connector.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index c5df2421..8bc31f0c 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -594,9 +594,11 @@ func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautr } func (br *Connector) GenerateDeterministicEventID(roomID id.RoomID, _ networkid.PortalKey, messageID networkid.MessageID, partID networkid.PartID) id.EventID { - data := make([]byte, 0, len(roomID)+len(messageID)+len(partID)) + data := make([]byte, 0, len(roomID)+1+len(messageID)+1+len(partID)) data = append(data, roomID...) + data = append(data, 0) data = append(data, messageID...) + data = append(data, 0) data = append(data, partID...) hash := sha256.Sum256(data) From ff4126b5d04d7f4b2ffe96a005f1a9de0844916d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 13 Sep 2024 21:20:28 +0300 Subject: [PATCH 0747/1647] bridgev2/database: delete duplicate mxids in migration --- bridgev2/database/upgrades/17-message-mxid-unique.sql | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/database/upgrades/17-message-mxid-unique.sql b/bridgev2/database/upgrades/17-message-mxid-unique.sql index 05503191..ee53b3f0 100644 --- a/bridgev2/database/upgrades/17-message-mxid-unique.sql +++ b/bridgev2/database/upgrades/17-message-mxid-unique.sql @@ -1,4 +1,5 @@ -- v17 (compatible with v9+): Add unique constraint for message and reaction mxids +DELETE FROM message WHERE mxid IN (SELECT mxid FROM message GROUP BY mxid HAVING COUNT(*) > 1); -- only: postgres for next 2 lines ALTER TABLE message ADD CONSTRAINT message_mxid_unique UNIQUE (bridge_id, mxid); ALTER TABLE reaction ADD CONSTRAINT reaction_mxid_unique UNIQUE (bridge_id, mxid); From a5c4446a2271c3b1d5f76ce37f0998228728dd2d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 13 Sep 2024 23:36:47 +0300 Subject: [PATCH 0748/1647] bridgev2: add option to split all portals by user login --- bridgev2/bridgeconfig/config.go | 1 + bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/database/portal.go | 38 ++++++++------ bridgev2/database/upgrades/00-latest.sql | 2 - bridgev2/matrix/mxmain/example-config.yaml | 5 ++ bridgev2/portal.go | 59 +++++++++++++++------- bridgev2/queue.go | 2 +- bridgev2/userlogin.go | 7 +++ 8 files changed, 78 insertions(+), 37 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 1731688d..06900525 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -61,6 +61,7 @@ type BridgeConfig struct { PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` AsyncEvents bool `yaml:"async_events"` + SplitPortals bool `yaml:"split_portals"` BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` TagOnlyOnCreate bool `yaml:"tag_only_on_create"` MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 4491dedd..70b20603 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -26,6 +26,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "personal_filtering_spaces") helper.Copy(up.Bool, "bridge", "private_chat_portal_meta") helper.Copy(up.Bool, "bridge", "async_events") + helper.Copy(up.Bool, "bridge", "split_portals") helper.Copy(up.Bool, "bridge", "bridge_matrix_leave") helper.Copy(up.Bool, "bridge", "tag_only_on_create") helper.Copy(up.Bool, "bridge", "mute_only_on_create") diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index bc1f2658..adbf1e6c 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -39,7 +39,7 @@ type Portal struct { networkid.PortalKey MXID id.RoomID - ParentID networkid.PortalID + ParentKey networkid.PortalKey RelayLoginID networkid.UserLoginID OtherUserID networkid.UserID Name string @@ -59,7 +59,7 @@ type Portal struct { const ( getPortalBaseQuery = ` - SELECT bridge_id, id, receiver, mxid, parent_id, relay_login_id, other_user_id, + SELECT bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, topic_set, avatar_set, name_is_custom, in_space, room_type, disappear_type, disappear_timer, @@ -72,29 +72,30 @@ const ( getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL` getAllDMPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND other_user_id=$2` getAllPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1` - getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2` + getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2 AND parent_receiver=$3` findPortalReceiverQuery = `SELECT id, receiver FROM portal WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='') LIMIT 1` insertPortalQuery = ` INSERT INTO portal ( bridge_id, id, receiver, mxid, - parent_id, relay_login_id, other_user_id, + parent_id, parent_receiver, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, avatar_set, topic_set, name_is_custom, in_space, room_type, disappear_type, disappear_timer, metadata, relay_bridge_id ) VALUES ( - $1, $2, $3, $4, $5, cast($6 AS TEXT), $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, - CASE WHEN cast($6 AS TEXT) IS NULL THEN NULL ELSE $1 END + $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, + CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE $1 END ) ` updatePortalQuery = ` UPDATE portal - SET mxid=$4, parent_id=$5, relay_login_id=cast($6 AS TEXT), relay_bridge_id=CASE WHEN cast($6 AS TEXT) IS NULL THEN NULL ELSE bridge_id END, - other_user_id=$7, name=$8, topic=$9, avatar_id=$10, avatar_hash=$11, avatar_mxc=$12, - name_set=$13, avatar_set=$14, topic_set=$15, name_is_custom=$16, in_space=$17, - room_type=$18, disappear_type=$19, disappear_timer=$20, metadata=$21 + SET mxid=$4, parent_id=$5, parent_receiver=$6, + relay_login_id=cast($7 AS TEXT), relay_bridge_id=CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE bridge_id END, + other_user_id=$8, name=$9, topic=$10, avatar_id=$11, avatar_hash=$12, avatar_mxc=$13, + name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18, + room_type=$19, disappear_type=$20, disappear_timer=$21, metadata=$22 WHERE bridge_id=$1 AND id=$2 AND receiver=$3 ` deletePortalQuery = ` @@ -136,8 +137,8 @@ func (pq *PortalQuery) GetAllDMsWith(ctx context.Context, otherUserID networkid. return pq.QueryMany(ctx, getAllDMPortalsQuery, pq.BridgeID, otherUserID) } -func (pq *PortalQuery) GetChildren(ctx context.Context, parentID networkid.PortalID) ([]*Portal, error) { - return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentID) +func (pq *PortalQuery) GetChildren(ctx context.Context, parentKey networkid.PortalKey) ([]*Portal, error) { + return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentKey.ID, parentKey.Receiver) } func (pq *PortalQuery) ReID(ctx context.Context, oldID, newID networkid.PortalKey) error { @@ -159,12 +160,12 @@ func (pq *PortalQuery) Delete(ctx context.Context, key networkid.PortalKey) erro } func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { - var mxid, parentID, relayLoginID, otherUserID, disappearType sql.NullString + var mxid, parentID, parentReceiver, relayLoginID, otherUserID, disappearType sql.NullString var disappearTimer sql.NullInt64 var avatarHash string err := row.Scan( &p.BridgeID, &p.ID, &p.Receiver, &mxid, - &parentID, &relayLoginID, &otherUserID, + &parentID, &parentReceiver, &relayLoginID, &otherUserID, &p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC, &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, &p.RoomType, &disappearType, &disappearTimer, @@ -187,7 +188,12 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { } p.MXID = id.RoomID(mxid.String) p.OtherUserID = networkid.UserID(otherUserID.String) - p.ParentID = networkid.PortalID(parentID.String) + if parentID.Valid { + p.ParentKey = networkid.PortalKey{ + ID: networkid.PortalID(parentID.String), + Receiver: networkid.UserLoginID(parentReceiver.String), + } + } p.RelayLoginID = networkid.UserLoginID(relayLoginID.String) return p, nil } @@ -206,7 +212,7 @@ func (p *Portal) sqlVariables() []any { } return []any{ p.BridgeID, p.ID, p.Receiver, dbutil.StrPtr(p.MXID), - dbutil.StrPtr(p.ParentID), dbutil.StrPtr(p.RelayLoginID), dbutil.StrPtr(p.OtherUserID), + dbutil.StrPtr(p.ParentKey.ID), p.ParentKey.Receiver, dbutil.StrPtr(p.RelayLoginID), dbutil.StrPtr(p.OtherUserID), p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC, p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, p.RoomType, dbutil.StrPtr(p.Disappear.Type), dbutil.NumPtr(p.Disappear.Timer), diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 80acfab1..8c9e0627 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -31,8 +31,6 @@ CREATE TABLE portal ( mxid TEXT, parent_id TEXT, - -- This is not accessed by the bridge, it's only used for the portal parent foreign key. - -- Parent groups are probably never DMs, so they don't need a receiver. parent_receiver TEXT NOT NULL DEFAULT '', relay_bridge_id TEXT, diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 7e642f96..f7736389 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -10,6 +10,11 @@ bridge: # Should events be handled asynchronously within portal rooms? # If true, events may end up being out of order, but slow events won't block other ones. async_events: false + # Should every user have their own portals rather than sharing them? + # By default, users who are in the same group on the remote network will be + # in the same Matrix room bridged to that group. If this is set to true, + # every user will get their own Matrix room instead. + split_portals: false # Should leaving Matrix rooms be bridged as leaving groups on the remote network? bridge_matrix_leave: false diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c5847f35..1c9a12a4 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -116,10 +116,10 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que br.portalsByMXID[portal.MXID] = portal } var err error - if portal.ParentID != "" { - portal.Parent, err = br.UnlockedGetPortalByKey(ctx, networkid.PortalKey{ID: portal.ParentID}, false) + if portal.ParentKey.ID != "" { + portal.Parent, err = br.UnlockedGetPortalByKey(ctx, portal.ParentKey, false) if err != nil { - return nil, fmt.Errorf("failed to load parent portal (%s): %w", portal.ParentID, err) + return nil, fmt.Errorf("failed to load parent portal (%s): %w", portal.ParentKey, err) } } if portal.RelayLoginID != "" { @@ -159,6 +159,9 @@ func (br *Bridge) loadManyPortals(ctx context.Context, portals []*database.Porta } func (br *Bridge) UnlockedGetPortalByKey(ctx context.Context, key networkid.PortalKey, onlyIfExists bool) (*Portal, error) { + if br.Config.SplitPortals && key.Receiver == "" { + return nil, fmt.Errorf("receiver must always be set when split portals is enabled") + } cached, ok := br.portalsByKey[key] if ok { return cached, nil @@ -184,6 +187,9 @@ func (br *Bridge) FindPortalReceiver(ctx context.Context, id networkid.PortalID, } func (br *Bridge) FindCachedPortalReceiver(id networkid.PortalID, maybeReceiver networkid.UserLoginID) networkid.PortalKey { + if br.Config.SplitPortals { + return networkid.PortalKey{ID: id, Receiver: maybeReceiver} + } br.cacheLock.Lock() defer br.cacheLock.Unlock() portal, ok := br.portalsByKey[networkid.PortalKey{ @@ -250,7 +256,7 @@ func (br *Bridge) GetPortalByKey(ctx context.Context, key networkid.PortalKey) ( func (br *Bridge) GetExistingPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() - if key.Receiver == "" { + if key.Receiver == "" || br.Config.SplitPortals { return br.UnlockedGetPortalByKey(ctx, key, true) } cached, ok := br.portalsByKey[key] @@ -2897,7 +2903,7 @@ func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMem return } var loginsInPortal []*UserLogin - if members.CheckAllLogins { + if members.CheckAllLogins && !portal.Bridge.Config.SplitPortals { loginsInPortal, err = portal.Bridge.GetUserLoginsInPortal(ctx, portal.PortalKey) if err != nil { err = fmt.Errorf("failed to get user logins in portal: %w", err) @@ -2971,7 +2977,7 @@ func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberL members.memberListToMap(ctx) var loginsInPortal []*UserLogin var err error - if members.CheckAllLogins { + if members.CheckAllLogins && !portal.Bridge.Config.SplitPortals { loginsInPortal, err = portal.Bridge.GetUserLoginsInPortal(ctx, portal.PortalKey) if err != nil { return fmt.Errorf("failed to get user logins in portal: %w", err) @@ -3208,8 +3214,12 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat return true } -func (portal *Portal) updateParent(ctx context.Context, newParent networkid.PortalID, source *UserLogin) bool { - if portal.ParentID == newParent { +func (portal *Portal) updateParent(ctx context.Context, newParentID networkid.PortalID, source *UserLogin) bool { + newParent := networkid.PortalKey{ID: newParentID} + if portal.Bridge.Config.SplitPortals { + newParent.Receiver = portal.Receiver + } + if portal.ParentKey == newParent { return false } var err error @@ -3219,10 +3229,10 @@ func (portal *Portal) updateParent(ctx context.Context, newParent networkid.Port zerolog.Ctx(ctx).Err(err).Stringer("old_space_mxid", portal.Parent.MXID).Msg("Failed to remove portal from old space") } } - portal.ParentID = newParent + portal.ParentKey = newParent portal.InSpace = false - if newParent != "" { - portal.Parent, err = portal.Bridge.GetPortalByKey(ctx, networkid.PortalKey{ID: newParent}) + if newParent.ID != "" { + portal.Parent, err = portal.Bridge.GetPortalByKey(ctx, newParent) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get new parent portal") } @@ -3536,17 +3546,30 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } } if portal.Parent == nil { - userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) - if err != nil { - log.Err(err).Msg("Failed to get user logins in portal to add portal to spaces") - } else { - for _, up := range userPortals { - login := portal.Bridge.GetCachedUserLoginByID(up.LoginID) - if login != nil { + if portal.Receiver != "" { + login := portal.Bridge.GetCachedUserLoginByID(portal.Receiver) + if login != nil { + up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to get user portal to add portal to spaces") + } else { login.inPortalCache.Remove(portal.PortalKey) go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues()) } } + } else { + userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to get user logins in portal to add portal to spaces") + } else { + for _, up := range userPortals { + login := portal.Bridge.GetCachedUserLoginByID(up.LoginID) + if login != nil { + login.inPortalCache.Remove(portal.PortalKey) + go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues()) + } + } + } } } if portal.Bridge.Config.Backfill.Enabled { diff --git a/bridgev2/queue.go b/bridgev2/queue.go index a79d56e3..213307a3 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -147,7 +147,7 @@ func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { key := evt.GetPortalKey() var portal *Portal var err error - if isUncertain { + if isUncertain && !br.Config.SplitPortals { portal, err = br.GetExistingPortalByKey(ctx, key) } else { portal, err = br.GetPortalByKey(ctx, key) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 2df43425..f9b8f7b1 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -96,6 +96,13 @@ func (br *Bridge) unlockedLoadUserLoginsByMXID(ctx context.Context, user *User) } func (br *Bridge) GetUserLoginsInPortal(ctx context.Context, portal networkid.PortalKey) ([]*UserLogin, error) { + if portal.Receiver != "" { + ul := br.GetCachedUserLoginByID(portal.Receiver) + if ul == nil { + return nil, nil + } + return []*UserLogin{ul}, nil + } logins, err := br.DB.UserLogin.GetAllInPortal(ctx, portal) if err != nil { return nil, err From d89dac594db03955c537dd2c73aebe5f9f3df064 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 14 Sep 2024 12:37:56 +0300 Subject: [PATCH 0749/1647] bridgev2: automatically update old portals when enabling split portals --- bridgev2/bridge.go | 33 +++++++++++++ bridgev2/bridgeconfig/config.go | 1 + bridgev2/bridgeconfig/upgrade.go | 3 ++ bridgev2/database/database.go | 5 ++ bridgev2/database/kvstore.go | 56 ++++++++++++++++++++++ bridgev2/database/portal.go | 22 ++++++++- bridgev2/database/upgrades/00-latest.sql | 10 +++- bridgev2/database/upgrades/18-kv-store.sql | 8 ++++ bridgev2/matrix/mxmain/example-config.yaml | 2 + 9 files changed, 138 insertions(+), 2 deletions(-) create mode 100644 bridgev2/database/kvstore.go create mode 100644 bridgev2/database/upgrades/18-kv-store.sql diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index e64d7a40..2b520e23 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -120,6 +120,7 @@ func (br *Bridge) StartConnectors() error { if err != nil { return DBUpgradeError{Err: err, Section: "main"} } + didSplitPortals := br.MigrateToSplitPortals(ctx) br.Log.Info().Msg("Starting Matrix connector") err = br.Matrix.Start(ctx) if err != nil { @@ -133,9 +134,41 @@ func (br *Bridge) StartConnectors() error { if br.Network.GetCapabilities().DisappearingMessages { go br.DisappearLoop.Start() } + if didSplitPortals || br.Config.ResendBridgeInfo { + br.ResendBridgeInfo(ctx) + } return nil } +func (br *Bridge) ResendBridgeInfo(ctx context.Context) { + log := zerolog.Ctx(ctx).With().Str("action", "resend bridge info").Logger() + portals, err := br.GetAllPortalsWithMXID(ctx) + if err != nil { + log.Err(err).Msg("Failed to get portals") + return + } + for _, portal := range portals { + portal.UpdateBridgeInfo(ctx) + } + log.Info().Msg("Resent bridge info to all portals") +} + +func (br *Bridge) MigrateToSplitPortals(ctx context.Context) bool { + log := zerolog.Ctx(ctx).With().Str("action", "migrate to split portals").Logger() + ctx = log.WithContext(ctx) + if !br.Config.SplitPortals || br.DB.KV.Get(ctx, database.KeySplitPortalsEnabled) == "true" { + return false + } + affected, err := br.DB.Portal.MigrateToSplitPortals(ctx) + if err != nil { + log.Err(err).Msg("Failed to migrate portals") + return false + } + log.Info().Int64("rows_affected", affected).Msg("Migrated to split portals") + br.DB.KV.Set(ctx, database.KeySplitPortalsEnabled, "true") + return affected > 0 +} + func (br *Bridge) StartLogins() error { ctx := br.Log.WithContext(context.Background()) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 06900525..aa07f42e 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -62,6 +62,7 @@ type BridgeConfig struct { PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` AsyncEvents bool `yaml:"async_events"` SplitPortals bool `yaml:"split_portals"` + ResendBridgeInfo bool `yaml:"resend_bridge_info"` BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` TagOnlyOnCreate bool `yaml:"tag_only_on_create"` MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 70b20603..4122f4d6 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -27,6 +27,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "private_chat_portal_meta") helper.Copy(up.Bool, "bridge", "async_events") helper.Copy(up.Bool, "bridge", "split_portals") + helper.Copy(up.Bool, "bridge", "resend_bridge_info") helper.Copy(up.Bool, "bridge", "bridge_matrix_leave") helper.Copy(up.Bool, "bridge", "tag_only_on_create") helper.Copy(up.Bool, "bridge", "mute_only_on_create") @@ -168,6 +169,8 @@ func doUpgrade(helper up.Helper) { var SpacedBlocks = [][]string{ {"bridge"}, + {"bridge", "bridge_matrix_leave"}, + {"bridge", "cleanup_on_logout"}, {"bridge", "relay"}, {"bridge", "permissions"}, {"database"}, diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go index aa77a232..f1789441 100644 --- a/bridgev2/database/database.go +++ b/bridgev2/database/database.go @@ -33,6 +33,7 @@ type Database struct { UserLogin *UserLoginQuery UserPortal *UserPortalQuery BackfillTask *BackfillTaskQuery + KV *KVQuery } type MetaMerger interface { @@ -136,6 +137,10 @@ func New(bridgeID networkid.BridgeID, mt MetaTypes, db *dbutil.Database) *Databa return &BackfillTask{} }), }, + KV: &KVQuery{ + BridgeID: bridgeID, + Database: db, + }, } } diff --git a/bridgev2/database/kvstore.go b/bridgev2/database/kvstore.go new file mode 100644 index 00000000..3fc54f2c --- /dev/null +++ b/bridgev2/database/kvstore.go @@ -0,0 +1,56 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "errors" + + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" +) + +type Key string + +const ( + KeySplitPortalsEnabled Key = "split_portals_enabled" +) + +type KVQuery struct { + BridgeID networkid.BridgeID + *dbutil.Database +} + +const ( + getKVQuery = `SELECT value FROM kv_store WHERE bridge_id = $1 AND key = $2` + setKVQuery = ` + INSERT INTO kv_store (bridge_id, key, value) VALUES ($1, $2, $3) + ON CONFLICT (bridge_id, key) DO UPDATE SET value = $3 + ` +) + +func (kvq *KVQuery) Get(ctx context.Context, key Key) string { + var value string + err := kvq.QueryRow(ctx, getKVQuery, kvq.BridgeID, key).Scan(&value) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + zerolog.Ctx(ctx).Err(err).Str("key", string(key)).Msg("Failed to get key from kvstore") + } + return value +} + +func (kvq *KVQuery) Set(ctx context.Context, key Key, value string) { + _, err := kvq.Exec(ctx, setKVQuery, kvq.BridgeID, key, value) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Str("key", string(key)). + Str("value", value). + Msg("Failed to set key in kvstore") + } +} diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index adbf1e6c..72e31454 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -102,7 +102,19 @@ const ( DELETE FROM portal WHERE bridge_id=$1 AND id=$2 AND receiver=$3 ` - reIDPortalQuery = `UPDATE portal SET id=$4, receiver=$5 WHERE bridge_id=$1 AND id=$2 AND receiver=$3` + reIDPortalQuery = `UPDATE portal SET id=$4, receiver=$5 WHERE bridge_id=$1 AND id=$2 AND receiver=$3` + migrateToSplitPortalsQuery = ` + UPDATE portal + SET receiver=COALESCE(( + SELECT login_id + FROM user_portal + WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver='' + LIMIT 1 + ), ( + SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1 + ), '') + WHERE receiver='' AND bridge_id=$1 + ` ) func (pq *PortalQuery) GetByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { @@ -159,6 +171,14 @@ func (pq *PortalQuery) Delete(ctx context.Context, key networkid.PortalKey) erro return pq.Exec(ctx, deletePortalQuery, pq.BridgeID, key.ID, key.Receiver) } +func (pq *PortalQuery) MigrateToSplitPortals(ctx context.Context) (int64, error) { + res, err := pq.GetDB().Exec(ctx, migrateToSplitPortalsQuery, pq.BridgeID) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { var mxid, parentID, parentReceiver, relayLoginID, otherUserID, disappearType sql.NullString var disappearTimer sql.NullInt64 diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 8c9e0627..6d6dcf2c 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v17 (compatible with v9+): Latest revision +-- v0 -> v18 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -198,3 +198,11 @@ CREATE TABLE backfill_task ( REFERENCES portal (bridge_id, id, receiver) ON DELETE CASCADE ON UPDATE CASCADE ); + +CREATE TABLE kv_store ( + bridge_id TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + + PRIMARY KEY (bridge_id, key) +); diff --git a/bridgev2/database/upgrades/18-kv-store.sql b/bridgev2/database/upgrades/18-kv-store.sql new file mode 100644 index 00000000..9d233095 --- /dev/null +++ b/bridgev2/database/upgrades/18-kv-store.sql @@ -0,0 +1,8 @@ +-- v18 (compatible with v9+): Add generic key-value store +CREATE TABLE kv_store ( + bridge_id TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + + PRIMARY KEY (bridge_id, key) +); diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index f7736389..8b9682ba 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -15,6 +15,8 @@ bridge: # in the same Matrix room bridged to that group. If this is set to true, # every user will get their own Matrix room instead. split_portals: false + # Should the bridge resend `m.bridge` events to all portals on startup? + resend_bridge_info: false # Should leaving Matrix rooms be bridged as leaving groups on the remote network? bridge_matrix_leave: false From cb4204ceb5226e420f2c68c606936047ef5366b2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 14 Sep 2024 12:54:05 +0300 Subject: [PATCH 0750/1647] bridgev2/networkid: update receiver docs --- bridgev2/networkid/bridgeid.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/networkid/bridgeid.go b/bridgev2/networkid/bridgeid.go index 46f82155..d78813eb 100644 --- a/bridgev2/networkid/bridgeid.go +++ b/bridgev2/networkid/bridgeid.go @@ -43,6 +43,9 @@ type PortalID string // It is also permitted to use a non-empty receiver for group chats if there is a good reason to // segregate them. For example, Telegram's non-supergroups have user-scoped message IDs instead // of chat-scoped IDs, which is easier to manage with segregated rooms. +// +// As a special case, Receiver MUST be set if the Bridge.Config.SplitPortals flag is set to true. +// The flag is intended for puppeting-only bridges which want multiple logins to create separate portals for each user. type PortalKey struct { ID PortalID Receiver UserLoginID From d86913bd5c51e0513bfe1c5014ab5ab085a7bfb6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 14 Sep 2024 14:31:30 +0300 Subject: [PATCH 0751/1647] bridgev2/legacymigrate: drop *_mxid_unique constraints before migration --- bridgev2/matrix/mxmain/legacymigrate.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index b2bdaa91..f269ccd5 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -26,6 +26,18 @@ import ( func (br *BridgeMain) LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDataQuery string, newDBVersion int, otherTable dbutil.UpgradeTable, otherTableName string, otherNewVersion int) func(ctx context.Context) error { return func(ctx context.Context) error { + // Unique constraints must have globally unique names on postgres, and renaming the table doesn't rename them, + // so just drop the ones that may conflict with the new schema. + if br.DB.Dialect == dbutil.Postgres { + _, err := br.DB.Exec(ctx, "ALTER TABLE message DROP CONSTRAINT IF EXISTS message_mxid_unique") + if err != nil { + return fmt.Errorf("failed to drop potentially conflicting constraint on message: %w", err) + } + _, err = br.DB.Exec(ctx, "ALTER TABLE reaction DROP CONSTRAINT IF EXISTS reaction_mxid_unique") + if err != nil { + return fmt.Errorf("failed to drop potentially conflicting constraint on reaction: %w", err) + } + } err := dbutil.DangerousInternalUpgradeVersionTable(ctx, br.DB) if err != nil { return err From 6f9927c3991a5acb185a3877c7bec0c8e3798a56 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 15 Sep 2024 01:11:53 +0300 Subject: [PATCH 0752/1647] crypto: make OTK count for other user log less noisy --- crypto/machine.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/crypto/machine.go b/crypto/machine.go index 85da2b3b..c130f775 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -65,8 +65,9 @@ type OlmMachine struct { megolmEncryptLock sync.Mutex megolmDecryptLock sync.Mutex - otkUploadLock sync.Mutex - lastOTKUpload time.Time + otkUploadLock sync.Mutex + lastOTKUpload time.Time + receivedOTKsForSelf bool CrossSigningKeys *CrossSigningKeysCache crossSigningPubkeys *CrossSigningPublicKeysCache @@ -258,16 +259,15 @@ func (mach *OlmMachine) otkCountIsForCrossSigningKey(otkCount *mautrix.OTKCount) func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.OTKCount) { if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) { - if mach.otkCountIsForCrossSigningKey(otkCount) { - return + if otkCount.UserID != mach.Client.UserID || (!mach.receivedOTKsForSelf && !mach.otkCountIsForCrossSigningKey(otkCount)) { + mach.Log.Warn(). + Str("target_user_id", otkCount.UserID.String()). + Str("target_device_id", otkCount.DeviceID.String()). + Msg("Dropping OTK counts targeted to someone else") } - // TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions - mach.Log.Warn(). - Str("target_user_id", otkCount.UserID.String()). - Str("target_device_id", otkCount.DeviceID.String()). - Msg("Dropping OTK counts targeted to someone else") return } + mach.receivedOTKsForSelf = true minCount := mach.account.Internal.MaxNumberOfOneTimeKeys() / 2 if otkCount.SignedCurve25519 < int(minCount) { From b5602fd4fe1b956c99482183b0d2b76312ee3b78 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 15 Sep 2024 01:20:50 +0300 Subject: [PATCH 0753/1647] appservice: increase OTK count channel size --- appservice/appservice.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appservice/appservice.go b/appservice/appservice.go index 90ace5d9..67b7e5f0 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -32,7 +32,7 @@ import ( // EventChannelSize is the size for the Events channel in Appservice instances. var EventChannelSize = 64 -var OTKChannelSize = 4 +var OTKChannelSize = 64 // Create creates a blank appservice instance. func Create() *AppService { From 1e3493188f8db68c67f9fe5cc6e824d8fa9289f6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 16 Sep 2024 13:48:59 +0300 Subject: [PATCH 0754/1647] Bump version to v0.21.0 --- CHANGELOG.md | 17 ++++++++++++++++- go.mod | 2 +- go.sum | 4 ++-- version.go | 2 +- 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d0da4c37..661a71a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,22 @@ -## unreleased +## v0.21.0 (2024-09-16) +* **Breaking change *(client)*** Dropped support for unauthenticated media. + Matrix v1.11 support is now required from the homeserver, although it's not + enforced using `/versions` as some servers don't advertise it. * *(bridgev2)* Added more features and fixed bugs. +* *(appservice,crypto)* Added support for using MSC3202 for appservice + encryption. +* *(crypto/olm)* Made everything into an interface to allow side-by-side + testing of libolm and goolm, as well as potentially support vodozemac + in the future. * *(client)* Fixed requests being retried even after context is canceled. +* *(client)* Added option to move `/sync` request logs to trace level. +* *(error)* Added `Write` and `WithMessage` helpers to `RespError` to make it + easier to use on servers. +* *(event)* Fixed `org.matrix.msc1767.audio` field allowing omitting the + duration and waveform. +* *(id)* Changed `MatrixURI` methods to not panic if the receiver is nil. +* *(federation)* Added limit to response size when fetching `.well-known` files. ## v0.20.0 (2024-08-16) diff --git a/go.mod b/go.mod index f1518cc0..ed1f66cf 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/tidwall/gjson v1.17.3 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.7.1-0.20240913091524-7617daa66719 + go.mau.fi/util v0.8.0 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.27.0 golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 diff --git a/go.sum b/go.sum index 6aad56c7..8974dea6 100644 --- a/go.sum +++ b/go.sum @@ -51,8 +51,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.7.1-0.20240913091524-7617daa66719 h1:sg1P/f4RHY1JuAwsPOjTCsZr8ROzR9bRTtnvvBu42d4= -go.mau.fi/util v0.7.1-0.20240913091524-7617daa66719/go.mod h1:1Ixb8HWoVbl3rT6nAX6nV4iMkzn7KU/KXwE0Rn5RmsQ= +go.mau.fi/util v0.8.0 h1:MiSny8jgQq4XtCLAT64gDJhZVhqiDeMVIEBDFVw+M0g= +go.mau.fi/util v0.8.0/go.mod h1:1Ixb8HWoVbl3rT6nAX6nV4iMkzn7KU/KXwE0Rn5RmsQ= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= diff --git a/version.go b/version.go index 29c5eb46..80b96661 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.20.0" +const Version = "v0.21.0" var GoModVersion = "" var Commit = "" From 830136b49d3d401b2cc6adcfdbfea966be0f9a13 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 16 Sep 2024 17:15:19 +0300 Subject: [PATCH 0755/1647] crypto: avoid data race in HandleOTKCounts --- crypto/machine.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/crypto/machine.go b/crypto/machine.go index c130f775..c9fc2249 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "sync" + "sync/atomic" "time" "github.com/rs/zerolog" @@ -67,7 +68,7 @@ type OlmMachine struct { otkUploadLock sync.Mutex lastOTKUpload time.Time - receivedOTKsForSelf bool + receivedOTKsForSelf atomic.Bool CrossSigningKeys *CrossSigningKeysCache crossSigningPubkeys *CrossSigningPublicKeysCache @@ -258,16 +259,18 @@ func (mach *OlmMachine) otkCountIsForCrossSigningKey(otkCount *mautrix.OTKCount) } func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.OTKCount) { + receivedOTKsForSelf := mach.receivedOTKsForSelf.Load() if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) { - if otkCount.UserID != mach.Client.UserID || (!mach.receivedOTKsForSelf && !mach.otkCountIsForCrossSigningKey(otkCount)) { + if otkCount.UserID != mach.Client.UserID || (!receivedOTKsForSelf && !mach.otkCountIsForCrossSigningKey(otkCount)) { mach.Log.Warn(). Str("target_user_id", otkCount.UserID.String()). Str("target_device_id", otkCount.DeviceID.String()). Msg("Dropping OTK counts targeted to someone else") } return + } else if !receivedOTKsForSelf { + mach.receivedOTKsForSelf.Store(true) } - mach.receivedOTKsForSelf = true minCount := mach.account.Internal.MaxNumberOfOneTimeKeys() / 2 if otkCount.SignedCurve25519 < int(minCount) { From a95101ea7f013e067b04f405e4df55a6016bc724 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 17 Sep 2024 16:58:25 +0300 Subject: [PATCH 0756/1647] bridgev2/backfill: add optional done callback to fetch response --- bridgev2/networkinterface.go | 3 +++ bridgev2/portalbackfill.go | 19 +++++++++++++++---- bridgev2/portalinternal.go | 8 ++++---- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 3e0617ae..30b806bc 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -495,6 +495,9 @@ type FetchMessagesResponse struct { ApproxRemainingCount int // Approximate total number of messages in the chat. ApproxTotalCount int + + // An optional function that is called after the backfill batch has been sent. + CompleteCallback func() } // BackfillingNetworkAPI is an optional interface that network connectors can implement to support backfilling message history. diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 6a912355..4350cfa2 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -74,7 +74,7 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, log.Warn().Msg("No messages left to backfill after cutting off old messages") return } - portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, false) + portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, false, resp.CompleteCallback) } func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin, task *database.BackfillTask) error { @@ -134,7 +134,7 @@ func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin if len(resp.Messages) == 0 { return fmt.Errorf("no messages left to backfill after cutting off too new messages") } - portal.sendBackfill(ctx, source, resp.Messages, false, resp.MarkRead, false) + portal.sendBackfill(ctx, source, resp.Messages, false, resp.MarkRead, false, resp.CompleteCallback) if len(resp.Messages) > 0 { task.OldestMessageID = resp.Messages[0].ID } @@ -182,7 +182,7 @@ func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, t } resp := portal.fetchThreadBackfill(ctx, source, anchorMessage) if resp != nil { - portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, true) + portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, true, resp.CompleteCallback) } } @@ -257,7 +257,15 @@ func (portal *Portal) cutoffMessages(ctx context.Context, messages []*BackfillMe return messages } -func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) { +func (portal *Portal) sendBackfill( + ctx context.Context, + source *UserLogin, + messages []*BackfillMessage, + forceForward, + markRead, + inThread bool, + done func(), +) { canBatchSend := portal.Bridge.Matrix.GetCapabilities().BatchSending unreadThreshold := time.Duration(portal.Bridge.Config.Backfill.UnreadHoursThreshold) * time.Hour forceMarkRead := unreadThreshold > 0 && time.Since(messages[len(messages)-1].Timestamp) > unreadThreshold @@ -272,6 +280,9 @@ func (portal *Portal) sendBackfill(ctx context.Context, source *UserLogin, messa } else { portal.sendLegacyBackfill(ctx, source, messages, markRead || forceMarkRead) } + if done != nil { + done() + } zerolog.Ctx(ctx).Debug().Msg("Backfill finished") if !canBatchSend && !inThread && portal.Bridge.Config.Backfill.Threads.MaxInitialMessages > 0 { for _, msg := range messages { diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index 77bdd7fd..56726fee 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -269,8 +269,8 @@ func (portal *PortalInternals) UpdateUserLocalInfo(ctx context.Context, info *Us (*Portal)(portal).updateUserLocalInfo(ctx, info, source, didJustCreate) } -func (portal *PortalInternals) UpdateParent(ctx context.Context, newParent networkid.PortalID, source *UserLogin) bool { - return (*Portal)(portal).updateParent(ctx, newParent, source) +func (portal *PortalInternals) UpdateParent(ctx context.Context, newParentID networkid.PortalID, source *UserLogin) bool { + return (*Portal)(portal).updateParent(ctx, newParentID, source) } func (portal *PortalInternals) LockedUpdateInfoFromGhost(ctx context.Context, ghost *Ghost) { @@ -309,8 +309,8 @@ func (portal *PortalInternals) CutoffMessages(ctx context.Context, messages []*B return (*Portal)(portal).cutoffMessages(ctx, messages, aggressiveDedup, forward, lastMessage) } -func (portal *PortalInternals) SendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool) { - (*Portal)(portal).sendBackfill(ctx, source, messages, forceForward, markRead, inThread) +func (portal *PortalInternals) SendBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, forceForward, markRead, inThread bool, done func()) { + (*Portal)(portal).sendBackfill(ctx, source, messages, forceForward, markRead, inThread, done) } func (portal *PortalInternals) CompileBatchMessage(ctx context.Context, source *UserLogin, msg *BackfillMessage, out *compileBatchOutput, inThread bool) { From 7324f6edec82b0d6a9960c9ce76f4b80f7090251 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 21 Sep 2024 17:51:54 +0300 Subject: [PATCH 0757/1647] bridgev2/matrix: handle token errors in /versions properly Fixes mautrix/signal#554 --- bridgev2/matrix/connector.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 8bc31f0c..8ed0710f 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -264,6 +264,9 @@ func (br *Connector) ensureConnection(ctx context.Context) { br.logInitialRequestError(err, "Failed to register after /versions failed with M_FORBIDDEN") os.Exit(16) } + } else if errors.Is(err, mautrix.MUnknownToken) || errors.Is(err, mautrix.MExclusive) { + br.logInitialRequestError(err, "/versions request failed with auth error") + os.Exit(16) } else { br.Log.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...") time.Sleep(10 * time.Second) From 5a8f566c3ccaace095ec820b7b39476e7ca1e277 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 24 Sep 2024 13:55:36 +0300 Subject: [PATCH 0758/1647] bridgev2/portal: log when thread or reply message is not found --- bridgev2/portal.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 1c9a12a4..5c8ea60b 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -806,12 +806,16 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin threadRoot, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, threadRootID) if err != nil { log.Err(err).Msg("Failed to get thread root message from database") + } else if threadRoot == nil { + log.Warn().Stringer("thread_root_id", threadRootID).Msg("Thread root message not found") } } if replyToID != "" && (caps.Replies || caps.Threads) { replyTo, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, replyToID) if err != nil { log.Err(err).Msg("Failed to get reply target message from database") + } else if replyTo == nil { + log.Warn().Stringer("reply_to_id", replyToID).Msg("Reply target message not found") } else { // Support replying to threads from non-thread-capable clients. // The fallback happens if the message is not a Matrix thread and either From a834fa84316e93a0e3464c3e5d66173770560190 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 24 Sep 2024 19:26:06 +0300 Subject: [PATCH 0759/1647] portal: plumb stream order from send response to MSS event --- bridgev2/messagestatus.go | 1 + bridgev2/networkinterface.go | 3 ++- bridgev2/portal.go | 27 +++++++++++++++------------ bridgev2/portalinternal.go | 4 ++-- 4 files changed, 20 insertions(+), 15 deletions(-) diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 04ee8eca..77ca98fd 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -24,6 +24,7 @@ type MessageStatusEventInfo struct { MessageType event.MessageType Sender id.UserID ThreadRoot id.EventID + StreamOrder int64 } func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo { diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 30b806bc..0646b2e7 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -276,7 +276,8 @@ type MaxFileSizeingNetwork interface { type RemoteEchoHandler func(RemoteMessage, *database.Message) (bool, error) type MatrixMessageResponse struct { - DB *database.Message + DB *database.Message + StreamOrder int64 // If Pending is set, the bridge will not save the provided message to the database. // This should only be used if AddPendingToSave has been called. Pending bool diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 5c8ea60b..3f674251 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -440,8 +440,10 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowR } } -func (portal *Portal) sendSuccessStatus(ctx context.Context, evt *event.Event) { - portal.Bridge.Matrix.SendMessageStatus(ctx, &MessageStatus{Status: event.MessageStatusSuccess}, StatusEventInfoFromEvent(evt)) +func (portal *Portal) sendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64) { + info := StatusEventInfoFromEvent(evt) + info.StreamOrder = streamOrder + portal.Bridge.Matrix.SendMessageStatus(ctx, &MessageStatus{Status: event.MessageStatusSuccess}, info) } func (portal *Portal) sendErrorStatus(ctx context.Context, evt *event.Event, err error) { @@ -871,7 +873,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin portal.outgoingMessagesLock.Unlock() } } - portal.sendSuccessStatus(ctx, evt) + portal.sendSuccessStatus(ctx, evt, resp.StreamOrder) } if portal.Disappear.Type != database.DisappearingTypeNone { go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ @@ -1023,7 +1025,8 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o if err != nil { log.Err(err).Msg("Failed to save message to database after editing") } - portal.sendSuccessStatus(ctx, evt) + // TODO allow returning stream order from HandleMatrixEdit + portal.sendSuccessStatus(ctx, evt, 0) } func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) { @@ -1077,7 +1080,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi } else if existing != nil { if existing.EmojiID != "" || existing.Emoji == preResp.Emoji { log.Debug().Msg("Ignoring duplicate reaction") - portal.sendSuccessStatus(ctx, evt) + portal.sendSuccessStatus(ctx, evt, 0) return } react.ReactionToOverride = existing @@ -1159,7 +1162,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if err != nil { log.Err(err).Msg("Failed to save reaction to database") } - portal.sendSuccessStatus(ctx, evt) + portal.sendSuccessStatus(ctx, evt, 0) } func handleMatrixRoomMeta[APIType any, ContentType any]( @@ -1185,17 +1188,17 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( switch typedContent := evt.Content.Parsed.(type) { case *event.RoomNameEventContent: if typedContent.Name == portal.Name { - portal.sendSuccessStatus(ctx, evt) + portal.sendSuccessStatus(ctx, evt, 0) return } case *event.TopicEventContent: if typedContent.Topic == portal.Topic { - portal.sendSuccessStatus(ctx, evt) + portal.sendSuccessStatus(ctx, evt, 0) return } case *event.RoomAvatarEventContent: if typedContent.URL == portal.AvatarMXC { - portal.sendSuccessStatus(ctx, evt) + portal.sendSuccessStatus(ctx, evt, 0) return } } @@ -1226,7 +1229,7 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( log.Err(err).Msg("Failed to save portal after updating room metadata") } } - portal.sendSuccessStatus(ctx, evt) + portal.sendSuccessStatus(ctx, evt, 0) } func handleMatrixAccountData[APIType any, ContentType any]( @@ -1516,7 +1519,7 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog return } // TODO delete msg/reaction db row - portal.sendSuccessStatus(ctx, evt) + portal.sendSuccessStatus(ctx, evt, 0) } func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) { @@ -1835,7 +1838,7 @@ func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage if statusErr != nil { portal.sendErrorStatus(ctx, pending.evt, statusErr) } else { - portal.sendSuccessStatus(ctx, pending.evt) + portal.sendSuccessStatus(ctx, pending.evt, getStreamOrder(evt)) } } zerolog.Ctx(ctx).Debug().Stringer("event_id", pending.evt.ID).Msg("Received remote echo for message") diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index 56726fee..a3b1fbf4 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -49,8 +49,8 @@ func (portal *PortalInternals) HandleSingleEvent(ctx context.Context, rawEvt any (*Portal)(portal).handleSingleEvent(ctx, rawEvt, doneCallback) } -func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event) { - (*Portal)(portal).sendSuccessStatus(ctx, evt) +func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64) { + (*Portal)(portal).sendSuccessStatus(ctx, evt, streamOrder) } func (portal *PortalInternals) SendErrorStatus(ctx context.Context, evt *event.Event, err error) { From cb64bbcff4739a3e1a14b32fbd176b78d05218c3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 24 Sep 2024 20:44:55 +0300 Subject: [PATCH 0760/1647] ci: lock closed issues automatically after 90 days [skip ci] --- .github/workflows/stale.yml | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 .github/workflows/stale.yml diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 00000000..68bc2292 --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,29 @@ +name: 'Lock old issues' + +on: + schedule: + - cron: '0 * * * *' + workflow_dispatch: + +permissions: + issues: write +# pull-requests: write +# discussions: write + +concurrency: + group: lock-threads + +jobs: + lock-stale: + runs-on: ubuntu-latest + steps: + - uses: dessant/lock-threads@v5 + id: lock + with: + issue-inactive-days: 90 + process-only: issues + - name: Log processed threads + run: | + if [ '${{ steps.lock.outputs.issues }}' ]; then + echo "Issues:" && echo '${{ steps.lock.outputs.issues }}' | jq -r '.[] | "https://github.com/\(.owner)/\(.repo)/issues/\(.issue_number)"' + fi From dff2edec782b25f91050592b6dc6e8c36698bb70 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 24 Sep 2024 22:14:50 +0300 Subject: [PATCH 0761/1647] bridgev2/legacymigrate: clear version table to be safe Fixes mautrix/signal#557 --- bridgev2/matrix/mxmain/legacymigrate.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index f269ccd5..35679e87 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -85,7 +85,11 @@ func (br *BridgeMain) LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDa if err != nil { return err } - _, err = br.DB.Exec(ctx, "UPDATE version SET version = $1, compat = $2", upgradesTo, compat) + _, err = br.DB.Exec(ctx, "DELETE FROM version", upgradesTo, compat) + if err != nil { + return err + } + _, err = br.DB.Exec(ctx, "INSERT INTO version (version, compat) VALUES ($1, $2)", upgradesTo, compat) if err != nil { return err } From b3452db0386a6d55258f0639a90e444889b2c37b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 25 Sep 2024 00:14:49 +0300 Subject: [PATCH 0762/1647] ci: reduce issue lock interval [skip ci] --- .github/workflows/stale.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 68bc2292..4ad4d792 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -2,7 +2,7 @@ name: 'Lock old issues' on: schedule: - - cron: '0 * * * *' + - cron: '0 0 * * *' workflow_dispatch: permissions: From 9a14949d9acef37d009126ed60e2425671183027 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 25 Sep 2024 14:54:08 +0300 Subject: [PATCH 0763/1647] bridgev2/matrixinterface: add method for getting URL preview --- bridgev2/matrix/connector.go | 5 +++++ bridgev2/matrixinterface.go | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 8ed0710f..4fd25306 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -105,6 +105,7 @@ var ( _ bridgev2.MatrixConnectorWithPostRoomBridgeHandling = (*Connector)(nil) _ bridgev2.MatrixConnectorWithPublicMedia = (*Connector)(nil) _ bridgev2.MatrixConnectorWithNameDisambiguation = (*Connector)(nil) + _ bridgev2.MatrixConnectorWithURLPreviews = (*Connector)(nil) _ bridgev2.MatrixConnectorWithAnalytics = (*Connector)(nil) ) @@ -642,3 +643,7 @@ func (br *Connector) HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomI } return nil } + +func (br *Connector) GetURLPreview(ctx context.Context, url string) (*event.LinkPreview, error) { + return br.Bot.GetURLPreview(ctx, url) +} diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index fe218db1..66d39403 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -69,6 +69,10 @@ type MatrixConnectorWithNameDisambiguation interface { IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error) } +type MatrixConnectorWithURLPreviews interface { + GetURLPreview(ctx context.Context, url string) (*event.LinkPreview, error) +} + type MatrixConnectorWithPostRoomBridgeHandling interface { HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomID) error } From 0c7f701828f07331d53cf674646b90c95b08a0c0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 25 Sep 2024 16:14:09 +0300 Subject: [PATCH 0764/1647] bridgev2/portal: add function for formatting disappearing timer change --- bridgev2/portal.go | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 3f674251..a5980fae 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3180,6 +3180,19 @@ func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPo } } +func DisappearingMessageNotice(expiration time.Duration, implicit bool) *event.MessageEventContent { + content := &event.MessageEventContent{ + MsgType: event.MsgNotice, + Body: fmt.Sprintf("Set the disappearing message timer to %s", exfmt.Duration(expiration)), + } + if implicit { + content.Body = fmt.Sprintf("Automatically enabled disappearing message timer (%s) because incoming message is disappearing", exfmt.Duration(expiration)) + } else if expiration == 0 { + content.Body = "Turned off disappearing messages" + } + return content +} + func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting database.DisappearingSetting, sender MatrixAPI, ts time.Time, implicit, save bool) bool { if setting.Timer == 0 { setting.Type = "" @@ -3195,15 +3208,7 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal to database after updating disappearing setting") } } - content := &event.MessageEventContent{ - MsgType: event.MsgNotice, - Body: fmt.Sprintf("Disappearing messages set to %s", exfmt.Duration(setting.Timer)), - } - if implicit { - content.Body = fmt.Sprintf("Automatically enabled disappearing message timer (%s) because incoming message is disappearing", exfmt.Duration(setting.Timer)) - } else if setting.Timer == 0 { - content.Body = "Disappearing messages disabled" - } + content := DisappearingMessageNotice(setting.Timer, implicit) if sender == nil { sender = portal.Bridge.Bot } From 4e180ee36b5007cb008f769bc06f85c2737dd61e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 25 Sep 2024 16:45:09 +0300 Subject: [PATCH 0765/1647] format/mdext: add indented paragraph fixer --- format/mdext/indentableparagraph.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 format/mdext/indentableparagraph.go diff --git a/format/mdext/indentableparagraph.go b/format/mdext/indentableparagraph.go new file mode 100644 index 00000000..a6ebd6c0 --- /dev/null +++ b/format/mdext/indentableparagraph.go @@ -0,0 +1,28 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mdext + +import ( + "github.com/yuin/goldmark" + "github.com/yuin/goldmark/parser" + "github.com/yuin/goldmark/util" +) + +// indentableParagraphParser is the default paragraph parser with CanAcceptIndentedLine. +// Used when disabling CodeBlockParser (as disabling it without a replacement will make indented blocks disappear). +type indentableParagraphParser struct { + parser.BlockParser +} + +var defaultIndentableParagraphParser = &indentableParagraphParser{BlockParser: parser.NewParagraphParser()} + +func (b *indentableParagraphParser) CanAcceptIndentedLine() bool { + return true +} + +// FixIndentedParagraphs is a goldmark option which fixes indented paragraphs when disabling CodeBlockParser. +var FixIndentedParagraphs = goldmark.WithParserOptions(parser.WithBlockParsers(util.Prioritized(defaultIndentableParagraphParser, 500))) From edae08383b3ed69e4405cfb6004fdc43c7efb142 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 26 Sep 2024 00:23:05 +0300 Subject: [PATCH 0766/1647] event: add Has method for Mentions --- event/message.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/event/message.go b/event/message.go index 3c6edfdd..097c585e 100644 --- a/event/message.go +++ b/event/message.go @@ -220,6 +220,10 @@ func (m *Mentions) Add(userID id.UserID) { } } +func (m *Mentions) Has(userID id.UserID) bool { + return m != nil && slices.Contains(m.UserIDs, userID) +} + type EncryptedFileInfo struct { attachment.EncryptedFile URL id.ContentURIString `json:"url"` From 5d916e0e9a66bc925fc6fdf23b97547b79200ed9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 26 Sep 2024 12:12:02 +0300 Subject: [PATCH 0767/1647] bridgev2/queue: add shortcut for QueueRemoteEvent --- bridgev2/queue.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 213307a3..38895953 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -139,6 +139,10 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { } } +func (ul *UserLogin) QueueRemoteEvent(evt RemoteEvent) { + ul.Bridge.QueueRemoteEvent(ul, evt) +} + func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { log := login.Log ctx := log.WithContext(context.TODO()) From 7a5f15b03c2ea9f95bd4cf3ddd535bcda3243b9d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 26 Sep 2024 12:16:54 +0300 Subject: [PATCH 0768/1647] bridgev2/networkinterface: add DeleteOnlyForMe field to message remove events --- bridgev2/networkinterface.go | 6 +++++- bridgev2/portal.go | 6 ++++++ bridgev2/simplevent/message.go | 10 +++++++++- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 0646b2e7..1fd09bb6 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -913,11 +913,15 @@ type RemoteBackfill interface { GetBackfillData(ctx context.Context, portal *Portal) (*FetchMessagesResponse, error) } -type RemoteChatDelete interface { +type RemoteDeleteOnlyForMe interface { RemoteEvent DeleteOnlyForMe() bool } +type RemoteChatDelete interface { + RemoteDeleteOnlyForMe +} + type RemoteEventThatMayCreatePortal interface { RemoteEvent ShouldCreatePortal() bool diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a5980fae..40e4b6ca 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2370,6 +2370,12 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use log.Debug().Msg("Target message not found") return } + onlyForMeProvider, ok := evt.(RemoteDeleteOnlyForMe) + onlyForMe := ok && onlyForMeProvider.DeleteOnlyForMe() + if onlyForMe && portal.Receiver == "" { + // TODO check if there are other user logins before deleting + } + intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageRemove) if intent == portal.Bridge.Bot && len(targetParts) > 0 { senderIntent, err := portal.getIntentForMXID(ctx, targetParts[0].SenderMXID) diff --git a/bridgev2/simplevent/message.go b/bridgev2/simplevent/message.go index 55d25bd8..f648ab12 100644 --- a/bridgev2/simplevent/message.go +++ b/bridgev2/simplevent/message.go @@ -63,10 +63,18 @@ type MessageRemove struct { EventMeta TargetMessage networkid.MessageID + OnlyForMe bool } -var _ bridgev2.RemoteMessageRemove = (*MessageRemove)(nil) +var ( + _ bridgev2.RemoteMessageRemove = (*MessageRemove)(nil) + _ bridgev2.RemoteDeleteOnlyForMe = (*MessageRemove)(nil) +) func (evt *MessageRemove) GetTargetMessage() networkid.MessageID { return evt.TargetMessage } + +func (evt *MessageRemove) DeleteOnlyForMe() bool { + return evt.OnlyForMe +} From d1e5b09d972b32ba2b196a3973a5b534ca79ee5a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 27 Sep 2024 14:36:33 +0300 Subject: [PATCH 0769/1647] bridgev2: add UserLogin.TrackAnalytics shortcut --- bridgev2/user.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bridgev2/user.go b/bridgev2/user.go index 1530b865..993eda92 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -264,3 +264,8 @@ func (br *Bridge) TrackAnalytics(userID id.UserID, event string, props map[strin func (user *User) TrackAnalytics(event string, props map[string]any) { user.Bridge.TrackAnalytics(user.MXID, event, props) } + +func (ul *UserLogin) TrackAnalytics(event string, props map[string]any) { + // TODO include user login ID? + ul.Bridge.TrackAnalytics(ul.UserMXID, event, props) +} From cf80de9f1a18c43c67615ecf33dd6e8019cd4f44 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 27 Sep 2024 20:45:47 +0300 Subject: [PATCH 0770/1647] bridgev2: don't include weeks in disappearing timer notices --- bridgev2/portal.go | 5 +++-- go.mod | 2 +- go.sum | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 40e4b6ca..fcff8348 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3187,12 +3187,13 @@ func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPo } func DisappearingMessageNotice(expiration time.Duration, implicit bool) *event.MessageEventContent { + formattedDuration := exfmt.DurationCustom(expiration, nil, exfmt.Day, time.Hour, time.Minute, time.Second) content := &event.MessageEventContent{ MsgType: event.MsgNotice, - Body: fmt.Sprintf("Set the disappearing message timer to %s", exfmt.Duration(expiration)), + Body: fmt.Sprintf("Set the disappearing message timer to %s", formattedDuration), } if implicit { - content.Body = fmt.Sprintf("Automatically enabled disappearing message timer (%s) because incoming message is disappearing", exfmt.Duration(expiration)) + content.Body = fmt.Sprintf("Automatically enabled disappearing message timer (%s) because incoming message is disappearing", formattedDuration) } else if expiration == 0 { content.Body = "Turned off disappearing messages" } diff --git a/go.mod b/go.mod index ed1f66cf..affeeb31 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/tidwall/gjson v1.17.3 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.8.0 + go.mau.fi/util v0.8.1-0.20240927174413-000d30f9a02a go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.27.0 golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 diff --git a/go.sum b/go.sum index 8974dea6..c5058f61 100644 --- a/go.sum +++ b/go.sum @@ -51,8 +51,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.0 h1:MiSny8jgQq4XtCLAT64gDJhZVhqiDeMVIEBDFVw+M0g= -go.mau.fi/util v0.8.0/go.mod h1:1Ixb8HWoVbl3rT6nAX6nV4iMkzn7KU/KXwE0Rn5RmsQ= +go.mau.fi/util v0.8.1-0.20240927174413-000d30f9a02a h1:4TrWJ0ooHT9YssDBUgXNU8FiR2cwi9jEAjtaVur4f0M= +go.mau.fi/util v0.8.1-0.20240927174413-000d30f9a02a/go.mod h1:1Ixb8HWoVbl3rT6nAX6nV4iMkzn7KU/KXwE0Rn5RmsQ= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= From f48a66c31c96d477fa587e5e6a6dc6186b51a92e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 28 Sep 2024 18:18:29 +0300 Subject: [PATCH 0771/1647] ci: change cron schedule --- .github/workflows/stale.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 4ad4d792..578349c9 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -2,7 +2,7 @@ name: 'Lock old issues' on: schedule: - - cron: '0 0 * * *' + - cron: '0 6 * * *' workflow_dispatch: permissions: From cc179f8ff7ac322df6e15c3c7248af61486ebd79 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 29 Sep 2024 16:16:39 +0300 Subject: [PATCH 0772/1647] appservice: remove TLS support --- appservice/appservice.go | 3 --- appservice/http.go | 9 ++------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/appservice/appservice.go b/appservice/appservice.go index 67b7e5f0..518e1073 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -224,9 +224,6 @@ type HostConfig struct { Hostname string `yaml:"hostname"` // Port is required when Hostname is an IP address, optional for unix sockets Port uint16 `yaml:"port"` - - TLSKey string `yaml:"tls_key,omitempty"` - TLSCert string `yaml:"tls_cert,omitempty"` } // Address gets the whole address of the Appservice. diff --git a/appservice/http.go b/appservice/http.go index 47f6a282..661513b4 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -59,13 +59,8 @@ func (as *AppService) listenUnix() error { } func (as *AppService) listenTCP() error { - if len(as.Host.TLSCert) == 0 || len(as.Host.TLSKey) == 0 { - as.Log.Info().Str("address", as.server.Addr).Msg("Starting HTTP listener") - return as.server.ListenAndServe() - } else { - as.Log.Info().Str("address", as.server.Addr).Msg("Starting HTTP listener with TLS") - return as.server.ListenAndServeTLS(as.Host.TLSCert, as.Host.TLSKey) - } + as.Log.Info().Str("address", as.server.Addr).Msg("Starting HTTP listener") + return as.server.ListenAndServe() } func (as *AppService) Stop() { From 0cf0b48a96e2af3575df09bc9a5e22ab79bc6668 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 28 Sep 2024 21:16:19 +0300 Subject: [PATCH 0773/1647] bridgev2/provisioning: use RespError.Write in RespondWithError --- bridgev2/matrix/provisioning.go | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index e0f630a0..3da849c2 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -532,22 +532,21 @@ func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.R return userLogin } +type WritableError interface { + Write(w http.ResponseWriter) +} + func RespondWithError(w http.ResponseWriter, err error, message string) { - var mautrixRespErr mautrix.RespError - var bv2RespErr bridgev2.RespError - if errors.As(err, &bv2RespErr) { - mautrixRespErr = mautrix.RespError(bv2RespErr) - } else if !errors.As(err, &mautrixRespErr) { - mautrixRespErr = mautrix.RespError{ + var we WritableError + if errors.As(err, &we) { + we.Write(w) + } else { + mautrix.RespError{ Err: message, ErrCode: "M_UNKNOWN", StatusCode: http.StatusInternalServerError, - } + }.Write(w) } - if mautrixRespErr.StatusCode == 0 { - mautrixRespErr.StatusCode = http.StatusInternalServerError - } - jsonResponse(w, mautrixRespErr.StatusCode, mautrixRespErr) } type RespResolveIdentifier struct { From c8d19e8e1877503e25218a1a82e96ad291cd1fe2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 28 Sep 2024 22:07:09 +0300 Subject: [PATCH 0774/1647] bridgev2/commands: add sudo and doin --- bridgev2/commands/event.go | 35 +++++------ bridgev2/commands/processor.go | 38 +++++++----- bridgev2/commands/sudo.go | 103 +++++++++++++++++++++++++++++++++ 3 files changed, 143 insertions(+), 33 deletions(-) create mode 100644 bridgev2/commands/sudo.go diff --git a/bridgev2/commands/event.go b/bridgev2/commands/event.go index 258ae2f0..bd2c52d2 100644 --- a/bridgev2/commands/event.go +++ b/bridgev2/commands/event.go @@ -23,20 +23,21 @@ import ( // Event stores all data which might be used to handle commands type Event struct { - Bot bridgev2.MatrixAPI - Bridge *bridgev2.Bridge - Portal *bridgev2.Portal - Processor *Processor - Handler MinimalCommandHandler - RoomID id.RoomID - EventID id.EventID - User *bridgev2.User - Command string - Args []string - RawArgs string - ReplyTo id.EventID - Ctx context.Context - Log *zerolog.Logger + Bot bridgev2.MatrixAPI + Bridge *bridgev2.Bridge + Portal *bridgev2.Portal + Processor *Processor + Handler MinimalCommandHandler + RoomID id.RoomID + OrigRoomID id.RoomID + EventID id.EventID + User *bridgev2.User + Command string + Args []string + RawArgs string + ReplyTo id.EventID + Ctx context.Context + Log *zerolog.Logger MessageStatus *bridgev2.MessageStatus } @@ -55,7 +56,7 @@ func (ce *Event) Reply(msg string, args ...any) { func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { content := format.RenderMarkdown(msg, allowMarkdown, allowHTML) content.MsgType = event.MsgNotice - _, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: &content}, nil) + _, err := ce.Bot.SendMessage(ce.Ctx, ce.OrigRoomID, event.EventMessage, &event.Content{Parsed: &content}, nil) if err != nil { ce.Log.Err(err).Msgf("Failed to reply to command") } @@ -63,7 +64,7 @@ func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { // React sends a reaction to the command. func (ce *Event) React(key string) { - _, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventReaction, &event.Content{ + _, err := ce.Bot.SendMessage(ce.Ctx, ce.OrigRoomID, event.EventReaction, &event.Content{ Parsed: &event.ReactionEventContent{ RelatesTo: event.RelatesTo{ Type: event.RelAnnotation, @@ -79,7 +80,7 @@ func (ce *Event) React(key string) { // Redact redacts the command. func (ce *Event) Redact(req ...mautrix.ReqRedact) { - _, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventRedaction, &event.Content{ + _, err := ce.Bot.SendMessage(ce.Ctx, ce.OrigRoomID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ Redacts: ce.EventID, }, diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index 49769514..a09418d4 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -46,6 +46,7 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor { CommandLogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, CommandSetRelay, CommandUnsetRelay, CommandResolveIdentifier, CommandStartChat, CommandSearch, + CommandSudo, CommandDoIn, ) return proc } @@ -108,25 +109,30 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id. rawArgs := strings.TrimLeft(strings.TrimPrefix(message, command), " ") portal, err := proc.bridge.GetPortalByMXID(ctx, roomID) if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") // :( } ce := &Event{ - Bot: proc.bridge.Bot, - Bridge: proc.bridge, - Portal: portal, - Processor: proc, - RoomID: roomID, - EventID: eventID, - User: user, - Command: command, - Args: args[1:], - RawArgs: rawArgs, - ReplyTo: replyTo, - Ctx: ctx, + Bot: proc.bridge.Bot, + Bridge: proc.bridge, + Portal: portal, + Processor: proc, + RoomID: roomID, + OrigRoomID: roomID, + EventID: eventID, + User: user, + Command: command, + Args: args[1:], + RawArgs: rawArgs, + ReplyTo: replyTo, + Ctx: ctx, MessageStatus: ms, } + proc.handleCommand(ctx, ce, message, args) +} +func (proc *Processor) handleCommand(ctx context.Context, ce *Event, origMessage string, origArgs []string) { realCommand, ok := proc.aliases[ce.Command] if !ok { realCommand = ce.Command @@ -138,8 +144,8 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id. state := LoadCommandState(ce.User) if state != nil && state.Next != nil { ce.Command = "" - ce.RawArgs = message - ce.Args = args + ce.RawArgs = origMessage + ce.Args = origArgs ce.Handler = state.Next log := zerolog.Ctx(ctx).With().Str("action", state.Action).Logger() ce.Log = &log @@ -147,11 +153,11 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id. log.Debug().Msg("Received reply to command state") state.Next.Run(ce) } else { - zerolog.Ctx(ctx).Debug().Str("mx_command", command).Msg("Received unknown command") + zerolog.Ctx(ctx).Debug().Str("mx_command", ce.Command).Msg("Received unknown command") ce.Reply("Unknown command, use the `help` command for help.") } } else { - log := zerolog.Ctx(ctx).With().Str("mx_command", command).Logger() + log := zerolog.Ctx(ctx).With().Str("mx_command", ce.Command).Logger() ctx = log.WithContext(ctx) ce.Log = &log ce.Ctx = ctx diff --git a/bridgev2/commands/sudo.go b/bridgev2/commands/sudo.go new file mode 100644 index 00000000..885a0f36 --- /dev/null +++ b/bridgev2/commands/sudo.go @@ -0,0 +1,103 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "strings" + + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +var CommandSudo = &FullHandler{ + Func: fnSudo, + Name: "sudo", + Aliases: []string{"doas", "do-as", "runas", "run-as"}, + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Run a command as a different user.", + Args: "[--create] <_user ID_> <_command_> [_args..._]", + }, + RequiresAdmin: true, +} + +func fnSudo(ce *Event) { + forceNonexistentUser := len(ce.Args) > 0 && strings.ToLower(ce.Args[0]) == "--create" + if forceNonexistentUser { + ce.Args = ce.Args[1:] + } + if len(ce.Args) < 2 { + ce.Reply("Usage: `$cmdprefix sudo [--create] [args...]`") + return + } + targetUserID := id.UserID(ce.Args[0]) + if _, _, err := targetUserID.Parse(); err != nil || len(targetUserID) > id.UserIDMaxLength { + ce.Reply("Invalid user ID `%s`", targetUserID) + return + } + var targetUser *bridgev2.User + var err error + if forceNonexistentUser { + targetUser, err = ce.Bridge.GetUserByMXID(ce.Ctx, targetUserID) + } else { + targetUser, err = ce.Bridge.GetExistingUserByMXID(ce.Ctx, targetUserID) + } + if err != nil { + ce.Log.Err(err).Msg("Failed to get user from database") + ce.Reply("Failed to get user") + return + } else if targetUser == nil { + ce.Reply("User not found. Use `--create` if you want to run commands as a user who has never used the bridge.") + return + } + ce.User = targetUser + origArgs := ce.Args[1:] + ce.Args = ce.Args[2:] + ce.Processor.handleCommand(ce.Ctx, ce, strings.Join(origArgs, " "), origArgs) +} + +var CommandDoIn = &FullHandler{ + Func: fnDoIn, + Name: "doin", + Aliases: []string{"do-in", "runin", "run-in"}, + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Run a command in a different room.", + Args: "<_room ID_> <_command_> [_args..._]", + }, +} + +func fnDoIn(ce *Event) { + if len(ce.Args) < 2 { + ce.Reply("Usage: `$cmdprefix doin [args...]`") + return + } + targetRoomID := id.RoomID(ce.Args[0]) + if !ce.User.Permissions.Admin { + memberInfo, err := ce.Bridge.Matrix.GetMemberInfo(ce.Ctx, targetRoomID, ce.User.MXID) + if err != nil { + ce.Log.Err(err).Msg("Failed to check if user is in doin target room") + ce.Reply("Failed to check if you're in the target room") + return + } else if memberInfo == nil || memberInfo.Membership != event.MembershipJoin { + ce.Reply("You must be in the target room to run commands there") + return + } + } + ce.RoomID = targetRoomID + var err error + ce.Portal, err = ce.Bridge.GetPortalByMXID(ce.Ctx, targetRoomID) + if err != nil { + ce.Log.Err(err).Msg("Failed to get target portal") + ce.Reply("Failed to get portal") + return + } + origArgs := ce.Args[1:] + ce.Args = ce.Args[2:] + ce.Processor.handleCommand(ce.Ctx, ce, strings.Join(origArgs, " "), origArgs) +} From 3b878b4bcdedbdca16a4f49b8596382913bc74bd Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 30 Sep 2024 12:05:58 +0300 Subject: [PATCH 0775/1647] bridgev2/commands: change how log context is applied --- bridgev2/commands/processor.go | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index a09418d4..482acf18 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -73,6 +73,8 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id. Step: status.MsgStepCommand, Status: event.MessageStatusSuccess, } + logCopy := zerolog.Ctx(ctx).With().Logger() + log := &logCopy defer func() { statusInfo := &bridgev2.MessageStatusEventInfo{ RoomID: roomID, @@ -82,7 +84,7 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id. } err := recover() if err != nil { - logEvt := zerolog.Ctx(ctx).Error(). + logEvt := log.Error(). Bytes(zerolog.ErrorStackFieldName, debug.Stack()) if realErr, ok := err.(error); ok { logEvt = logEvt.Err(realErr) @@ -109,7 +111,7 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id. rawArgs := strings.TrimLeft(strings.TrimPrefix(message, command), " ") portal, err := proc.bridge.GetPortalByMXID(ctx, roomID) if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") + log.Err(err).Msg("Failed to get portal") // :( } ce := &Event{ @@ -126,6 +128,7 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id. RawArgs: rawArgs, ReplyTo: replyTo, Ctx: ctx, + Log: log, MessageStatus: ms, } @@ -137,6 +140,7 @@ func (proc *Processor) handleCommand(ctx context.Context, ce *Event, origMessage if !ok { realCommand = ce.Command } + log := zerolog.Ctx(ctx) var handler MinimalCommandHandler handler, ok = proc.handlers[realCommand] @@ -147,9 +151,9 @@ func (proc *Processor) handleCommand(ctx context.Context, ce *Event, origMessage ce.RawArgs = origMessage ce.Args = origArgs ce.Handler = state.Next - log := zerolog.Ctx(ctx).With().Str("action", state.Action).Logger() - ce.Log = &log - ce.Ctx = log.WithContext(ctx) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("action", state.Action) + }) log.Debug().Msg("Received reply to command state") state.Next.Run(ce) } else { @@ -157,10 +161,9 @@ func (proc *Processor) handleCommand(ctx context.Context, ce *Event, origMessage ce.Reply("Unknown command, use the `help` command for help.") } } else { - log := zerolog.Ctx(ctx).With().Str("mx_command", ce.Command).Logger() - ctx = log.WithContext(ctx) - ce.Log = &log - ce.Ctx = ctx + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Str("mx_command", ce.Command) + }) log.Debug().Msg("Received command") ce.Handler = handler handler.Run(ce) From 59251f83def4313baeb930bf12bf2966fd4a860d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 30 Sep 2024 12:41:43 +0300 Subject: [PATCH 0776/1647] bridgev2/example-config: be more explicit about securing the appservice address --- bridgev2/matrix/mxmain/example-config.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 8b9682ba..c8a86ac0 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -158,8 +158,12 @@ homeserver: # Changing these values requires regeneration of the registration (except when noted otherwise) appservice: # The address that the homeserver can use to connect to this appservice. + # Like the homeserver address, a local non-https address is recommended when the bridge is on the same machine. + # If the bridge is elsewhere, you must secure the connection yourself (e.g. with https or wireguard) + # If you want to use https, you need to use a reverse proxy. The bridge does not have TLS support built in. address: http://localhost:$<> # A public address that external services can use to reach this appservice. + # This is only needed for things like public media. A reverse proxy is generally necessary when using this field. # This value doesn't affect the registration file. public_address: https://bridge.example.com From 919834e6bf864912271155eb2756cc37f7dafcb3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 30 Sep 2024 16:16:05 +0300 Subject: [PATCH 0777/1647] bridgev2/provisioning: add simpler auth for pprof endpoints --- bridgev2/matrix/provisioning.go | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 3da849c2..f43c84ca 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" "net/http" + "net/http/pprof" "strings" "sync" "time" @@ -126,8 +127,12 @@ func (prov *ProvisioningAPI) Init() { if prov.br.Config.Provisioning.DebugEndpoints { prov.log.Debug().Msg("Enabling debug API at /debug") r := prov.br.AS.Router.PathPrefix("/debug").Subrouter() - r.Use(prov.AuthMiddleware) - r.PathPrefix("/pprof").Handler(http.DefaultServeMux) + r.Use(prov.DebugAuthMiddleware) + r.HandleFunc("/pprof/cmdline", pprof.Cmdline).Methods(http.MethodGet) + r.HandleFunc("/pprof/profile", pprof.Profile).Methods(http.MethodGet) + r.HandleFunc("/pprof/symbol", pprof.Symbol).Methods(http.MethodGet) + r.HandleFunc("/pprof/trace", pprof.Trace).Methods(http.MethodGet) + r.PathPrefix("/pprof/").HandlerFunc(pprof.Index) } } @@ -191,6 +196,25 @@ func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userI } } +func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + if auth == "" { + jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ + Err: "Missing auth token", + ErrCode: mautrix.MMissingToken.ErrCode, + }) + } else if auth != prov.br.Config.Provisioning.SharedSecret { + jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ + Err: "Invalid auth token", + ErrCode: mautrix.MUnknownToken.ErrCode, + }) + } else { + h.ServeHTTP(w, r) + } + }) +} + func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") From ed074ba6577db5a735950b8083045caee16b65ab Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 30 Sep 2024 17:18:26 +0300 Subject: [PATCH 0778/1647] changelog: update --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 661a71a7..5e821fc0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +## unreleased + +* *(appservice)* Removed TLS support. A reverse proxy should be used if TLS + is needed. +* *(format/mdext)* Added goldmark extension to fix indented paragraphs when + disabling indented code block parser. + ## v0.21.0 (2024-09-16) * **Breaking change *(client)*** Dropped support for unauthenticated media. From 741b4e823ffb2e06a8d9ec0d18e3bde72981e304 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 30 Sep 2024 17:21:22 +0300 Subject: [PATCH 0779/1647] bridgev2/commands: set missing fields in sudo/doin --- bridgev2/commands/sudo.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bridgev2/commands/sudo.go b/bridgev2/commands/sudo.go index 885a0f36..f05ca1bb 100644 --- a/bridgev2/commands/sudo.go +++ b/bridgev2/commands/sudo.go @@ -57,7 +57,9 @@ func fnSudo(ce *Event) { } ce.User = targetUser origArgs := ce.Args[1:] + ce.Command = strings.ToLower(ce.Args[1]) ce.Args = ce.Args[2:] + ce.RawArgs = strings.Join(ce.Args, " ") ce.Processor.handleCommand(ce.Ctx, ce, strings.Join(origArgs, " "), origArgs) } @@ -98,6 +100,8 @@ func fnDoIn(ce *Event) { return } origArgs := ce.Args[1:] + ce.Command = strings.ToLower(ce.Args[1]) ce.Args = ce.Args[2:] + ce.RawArgs = strings.Join(ce.Args, " ") ce.Processor.handleCommand(ce.Ctx, ce, strings.Join(origArgs, " "), origArgs) } From 1b7a78b8113eb0c81d41da80eb8884054b1c72e0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 30 Sep 2024 21:28:34 +0300 Subject: [PATCH 0780/1647] bridgev2/commands: url-decode cookie parameters --- bridgev2/commands/login.go | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index 86071a8e..e5b3e50c 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -12,6 +12,7 @@ import ( "fmt" "net/url" "regexp" + "slices" "strings" "github.com/skip2/go-qrcode" @@ -120,11 +121,12 @@ func checkLoginCommandDirectParams(ce *Event, login bridgev2.LoginProcess, nextS } input := make(map[string]string) for i, param := range nextStep.CookiesParams.Fields { - if match, _ := regexp.MatchString(param.Pattern, ce.Args[i]); !match { + val := maybeURLDecodeCookie(ce.Args[i], ¶m) + if match, _ := regexp.MatchString(param.Pattern, val); !match { ce.Reply("Invalid value for %s: doesn't match regex `%s`", param.ID, param.Pattern) return nil } - input[param.ID] = ce.Args[i] + input[param.ID] = val } nextStep, err = login.(bridgev2.LoginProcessCookies).SubmitCookies(ce.Ctx, input) } @@ -349,6 +351,12 @@ func (clcs *cookieLoginCommandState) submit(ce *Event) { ce.Reply("Failed to parse input as JSON: %v", err) return } + for _, field := range clcs.Data.Fields { + val, ok := cookiesInput[field.ID] + if ok { + cookiesInput[field.ID] = maybeURLDecodeCookie(val, &field) + } + } } var missingKeys []string for _, field := range clcs.Data.Fields { @@ -374,6 +382,27 @@ func (clcs *cookieLoginCommandState) submit(ce *Event) { doLoginStep(ce, clcs.Login, nextStep) } +func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string { + if val == "" { + return val + } + isCookie := slices.ContainsFunc(field.Sources, func(src bridgev2.LoginCookieFieldSource) bool { + return src.Type == bridgev2.LoginCookieTypeCookie + }) + if !isCookie { + return val + } + match, _ := regexp.MatchString(field.Pattern, val) + if !match { + return val + } + decoded, err := url.QueryUnescape(val) + if err != nil { + return val + } + return decoded +} + func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep) { if step.Instructions != "" { ce.Reply(step.Instructions) From 37af19a01a611f3f2f62d9f1c00e7ce4934b8b56 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 1 Oct 2024 13:52:37 +0300 Subject: [PATCH 0781/1647] bridgev2/portal: allow remote events to have post handlers --- bridgev2/networkinterface.go | 5 +++++ bridgev2/portal.go | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 1fd09bb6..842de252 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -884,6 +884,11 @@ type RemotePreHandler interface { PreHandle(ctx context.Context, portal *Portal) } +type RemotePostHandler interface { + RemoteEvent + PostHandle(ctx context.Context, portal *Portal) +} + type RemoteChatInfoChange interface { RemoteEvent GetChatInfoChange(ctx context.Context) (*ChatInfoChange, error) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index fcff8348..28bbeb2e 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1594,6 +1594,10 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, default: log.Warn().Msg("Got remote event with unknown type") } + postHandler, ok := evt.(RemotePostHandler) + if ok { + postHandler.PostHandle(ctx, portal) + } } func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID) { From c259682a7cc588154be3d42e156c032dcfdd8cb1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 1 Oct 2024 16:55:37 +0300 Subject: [PATCH 0782/1647] bridgev2/backfill: catch panics in backfill queue --- bridgev2/backfillqueue.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go index 0f4ee048..95f3107d 100644 --- a/bridgev2/backfillqueue.go +++ b/bridgev2/backfillqueue.go @@ -9,6 +9,7 @@ package bridgev2 import ( "context" "fmt" + "runtime/debug" "time" "github.com/rs/zerolog" @@ -88,6 +89,19 @@ func (br *Bridge) doBackfillTask(ctx context.Context, task *database.BackfillTas Object("portal_key", task.PortalKey). Str("login_id", string(task.UserLoginID)). Logger() + defer func() { + err := recover() + if err != nil { + logEvt := log.Error(). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()) + if realErr, ok := err.(error); ok { + logEvt = logEvt.Err(realErr) + } else { + logEvt = logEvt.Any(zerolog.ErrorFieldName, err) + } + logEvt.Msg("Panic in backfill queue") + } + }() ctx = log.WithContext(ctx) err := br.DB.BackfillTask.MarkDispatched(ctx, task) if err != nil { From 31a68cbcea95eb1838482465e5da031390ddfb7c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 1 Oct 2024 17:27:29 +0300 Subject: [PATCH 0783/1647] bridgev2/portal: don't try to send disappearing notice before room is created --- bridgev2/portal.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 28bbeb2e..ea0cbab2 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3219,6 +3219,9 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal to database after updating disappearing setting") } } + if portal.MXID == "" { + return true + } content := DisappearingMessageNotice(setting.Timer, implicit) if sender == nil { sender = portal.Bridge.Bot From e0961922e5c0211bdeef6643d64ee0c06d8a7dbc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 1 Oct 2024 17:52:28 +0300 Subject: [PATCH 0784/1647] bridgev2/legacymigrate: fix error message for other db upgrade --- bridgev2/matrix/mxmain/legacymigrate.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index 35679e87..1e7f5a31 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -62,7 +62,7 @@ func (br *BridgeMain) LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDa if err != nil { return err } else if otherUpgradesTo < otherNewVersion || otherCompat > otherNewVersion { - return fmt.Errorf("unexpected new database version for %s (%d/c:%d, expected %d)", otherTableName, otherUpgradesTo, otherCompat, newDBVersion) + return fmt.Errorf("unexpected new database version for %s (%d/c:%d, expected %d)", otherTableName, otherUpgradesTo, otherCompat, otherNewVersion) } _, err = br.DB.Exec(ctx, fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", otherTableName), otherUpgradesTo, otherCompat) if err != nil { From c27b62aa24edbad09ccc2e3563f7d1a5a0638adb Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 1 Oct 2024 10:41:51 -0600 Subject: [PATCH 0785/1647] bridgev2/provisioning: add request ID middleware Signed-off-by: Sumner Evans --- bridgev2/matrix/provisioning.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index f43c84ca..951e6df1 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -108,6 +108,7 @@ func (prov *ProvisioningAPI) Init() { tp.Transport.TLSHandshakeTimeout = 10 * time.Second prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter() prov.Router.Use(hlog.NewHandler(prov.log)) + prov.Router.Use(hlog.RequestIDHandler("request_id", "Request-Id")) prov.Router.Use(corsMiddleware) prov.Router.Use(requestlog.AccessLogger(false)) prov.Router.Use(prov.AuthMiddleware) From 260f642bf0636b590bcd4b5315fa9faee172fee8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 1 Oct 2024 20:20:07 +0300 Subject: [PATCH 0786/1647] bridgev2/legacymigrate: fix extra parameters --- bridgev2/matrix/mxmain/legacymigrate.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index 1e7f5a31..18880027 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -85,7 +85,7 @@ func (br *BridgeMain) LegacyMigrateWithAnotherUpgrader(renameTablesQuery, copyDa if err != nil { return err } - _, err = br.DB.Exec(ctx, "DELETE FROM version", upgradesTo, compat) + _, err = br.DB.Exec(ctx, "DELETE FROM version") if err != nil { return err } From 3cda7344479710143601d17152fc0ac17af73e9b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 2 Oct 2024 22:41:48 +0300 Subject: [PATCH 0787/1647] bridgev2/portal: don't handle unsaved pending message as upsert --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index ea0cbab2..76c0a3ea 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1886,7 +1886,7 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin upsertEvt, isUpsert := evt.(RemoteMessageUpsert) isUpsert = isUpsert && evt.GetType() == RemoteEventMessageUpsert if wasPending, dbMessage := portal.checkPendingMessage(ctx, evt); wasPending { - if isUpsert { + if isUpsert && dbMessage != nil { portal.handleRemoteUpsert(ctx, source, upsertEvt, []*database.Message{dbMessage}) } return From eea038ea6bac23127512ccd5b82e1c54a895f8d2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 2 Oct 2024 23:01:41 +0300 Subject: [PATCH 0788/1647] bridgev2/config: mark async_events as unsafe --- bridgev2/matrix/mxmain/example-config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index c8a86ac0..ed21df38 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -9,6 +9,7 @@ bridge: private_chat_portal_meta: false # Should events be handled asynchronously within portal rooms? # If true, events may end up being out of order, but slow events won't block other ones. + # This is not yet safe to use. async_events: false # Should every user have their own portals rather than sharing them? # By default, users who are in the same group on the remote network will be From b9fdcd0dcefb738cddcb0cbb5cb6ae2a331e1eee Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 3 Oct 2024 00:35:33 +0300 Subject: [PATCH 0789/1647] bridgev2/config: add support for deprecated integer permissions --- bridgev2/bridgeconfig/permissions.go | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/bridgev2/bridgeconfig/permissions.go b/bridgev2/bridgeconfig/permissions.go index e76046f5..610051e0 100644 --- a/bridgev2/bridgeconfig/permissions.go +++ b/bridgev2/bridgeconfig/permissions.go @@ -8,6 +8,8 @@ package bridgeconfig import ( "fmt" + "os" + "strconv" "strings" "gopkg.in/yaml.v3" @@ -94,6 +96,23 @@ func (p *Permissions) UnmarshalYAML(perm *yaml.Node) error { case "!!map": err := perm.Decode((*umPerm)(p)) return err + case "!!int": + val, err := strconv.Atoi(perm.Value) + if err != nil { + return fmt.Errorf("invalid permissions level %s", perm.Value) + } + _, _ = fmt.Fprintln(os.Stderr, "Warning: config contains deprecated integer permission values") + // Integer values are deprecated, so they're hardcoded + if val < 5 { + *p = PermissionLevelBlock + } else if val < 10 { + *p = PermissionLevelRelay + } else if val < 100 { + *p = PermissionLevelUser + } else { + *p = PermissionLevelAdmin + } + return nil default: return fmt.Errorf("invalid permissions type %s", perm.Tag) } From 7e041c6e76a529e9ae0f7c5534e71ebb87e781a7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 3 Oct 2024 12:33:00 +0300 Subject: [PATCH 0790/1647] dependencies: update --- go.mod | 6 +++--- go.sum | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index affeeb31..8ef08be8 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module maunium.net/go/mautrix go 1.22.0 -toolchain go1.23.1 +toolchain go1.23.2 require ( filippo.io/edwards25519 v1.1.0 @@ -15,10 +15,10 @@ require ( github.com/rs/zerolog v1.33.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stretchr/testify v1.9.0 - github.com/tidwall/gjson v1.17.3 + github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.8.1-0.20240927174413-000d30f9a02a + go.mau.fi/util v0.8.1-0.20241003092848-3b49d3e0b9ee go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.27.0 golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 diff --git a/go.sum b/go.sum index c5058f61..ac8e03f6 100644 --- a/go.sum +++ b/go.sum @@ -41,8 +41,8 @@ github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDq github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.17.3 h1:bwWLZU7icoKRG+C+0PNwIKC6FCJO/Q3p2pZvuP0jN94= -github.com/tidwall/gjson v1.17.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= @@ -51,8 +51,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.1-0.20240927174413-000d30f9a02a h1:4TrWJ0ooHT9YssDBUgXNU8FiR2cwi9jEAjtaVur4f0M= -go.mau.fi/util v0.8.1-0.20240927174413-000d30f9a02a/go.mod h1:1Ixb8HWoVbl3rT6nAX6nV4iMkzn7KU/KXwE0Rn5RmsQ= +go.mau.fi/util v0.8.1-0.20241003092848-3b49d3e0b9ee h1:/BGpUK7fzVyFgy5KBiyP7ktEDn20vzz/5FTngrXtIEE= +go.mau.fi/util v0.8.1-0.20241003092848-3b49d3e0b9ee/go.mod h1:L9qnqEkhe4KpuYmILrdttKTXL79MwGLyJ4EOskWxO3I= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= From d48a5ca61567c6acaeae955ea4f97461550e231b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 4 Oct 2024 14:41:10 +0300 Subject: [PATCH 0791/1647] bridgev2/portal: assume newly created rooms are unmuted and untagged --- bridgev2/portal.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 76c0a3ea..62256430 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3176,13 +3176,13 @@ func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPo if info == nil { return } - if info.MutedUntil != nil && (didJustCreate || !portal.Bridge.Config.MuteOnlyOnCreate) { + if info.MutedUntil != nil && (didJustCreate || !portal.Bridge.Config.MuteOnlyOnCreate) && (!didJustCreate || info.MutedUntil.After(time.Now())) { err := dp.MuteRoom(ctx, portal.MXID, *info.MutedUntil) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to mute room") } } - if info.Tag != nil && (didJustCreate || !portal.Bridge.Config.TagOnlyOnCreate) { + if info.Tag != nil && (didJustCreate || !portal.Bridge.Config.TagOnlyOnCreate) && (!didJustCreate || *info.Tag != "") { err := dp.TagRoom(ctx, portal.MXID, *info.Tag, *info.Tag != "") if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to tag room") From 144a99595122bb384262f7cf75fdf106054cf794 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 4 Oct 2024 16:25:43 +0300 Subject: [PATCH 0792/1647] bridgev2/commands: add pm as alias to start-chat --- bridgev2/commands/startchat.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index c18e977a..53c07530 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -31,8 +31,9 @@ var CommandResolveIdentifier = &FullHandler{ } var CommandStartChat = &FullHandler{ - Func: fnResolveIdentifier, - Name: "start-chat", + Func: fnResolveIdentifier, + Name: "start-chat", + Aliases: []string{"pm"}, Help: HelpMeta{ Section: HelpSectionChats, Description: "Start a direct chat with the given user", From 1e7196ed34947b509023fc8d784611fedbd8b9f9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 5 Oct 2024 16:07:31 +0300 Subject: [PATCH 0793/1647] hicli: add helpers for using hicli over RPC --- hicli/database/account.go | 9 +++ hicli/database/room.go | 17 ++--- hicli/events.go | 8 +++ hicli/hicli.go | 30 ++++++--- hicli/json-commands.go | 136 ++++++++++++++++++++++++++++++++++++++ hicli/json.go | 113 +++++++++++++++++++++++++++++++ hicli/login.go | 7 +- hicli/verify.go | 6 +- 8 files changed, 304 insertions(+), 22 deletions(-) create mode 100644 hicli/json-commands.go create mode 100644 hicli/json.go diff --git a/hicli/database/account.go b/hicli/database/account.go index 49b50771..1dde74fd 100644 --- a/hicli/database/account.go +++ b/hicli/database/account.go @@ -8,6 +8,8 @@ package database import ( "context" + "database/sql" + "errors" "go.mau.fi/util/dbutil" @@ -32,7 +34,14 @@ type AccountQuery struct { } func (aq *AccountQuery) GetFirstUserID(ctx context.Context) (userID id.UserID, err error) { + var exists bool + if exists, err = aq.GetDB().TableExists(ctx, "account"); err != nil || !exists { + return + } err = aq.GetDB().QueryRow(ctx, `SELECT user_id FROM account LIMIT 1`).Scan(&userID) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } return } diff --git a/hicli/database/room.go b/hicli/database/room.go index 92adc279..1acbd081 100644 --- a/hicli/database/room.go +++ b/hicli/database/room.go @@ -13,6 +13,7 @@ import ( "time" "go.mau.fi/util/dbutil" + "go.mau.fi/util/jsontime" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" @@ -115,8 +116,8 @@ type Room struct { EncryptionEvent *event.EncryptionEventContent `json:"encryption_event,omitempty"` HasMemberList bool `json:"has_member_list"` - PreviewEventRowID EventRowID `json:"preview_event_rowid"` - SortingTimestamp time.Time `json:"sorting_timestamp"` + PreviewEventRowID EventRowID `json:"preview_event_rowid"` + SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"` PrevBatch string `json:"prev_batch"` } @@ -152,7 +153,7 @@ func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) { other.PreviewEventRowID = r.PreviewEventRowID hasChanges = true } - if r.SortingTimestamp.After(other.SortingTimestamp) { + if r.SortingTimestamp.After(other.SortingTimestamp.Time) { other.SortingTimestamp = r.SortingTimestamp hasChanges = true } @@ -186,7 +187,7 @@ func (r *Room) Scan(row dbutil.Scannable) (*Room, error) { } r.PrevBatch = prevBatch.String r.PreviewEventRowID = EventRowID(previewEventRowID.Int64) - r.SortingTimestamp = time.UnixMilli(sortingTimestamp.Int64) + r.SortingTimestamp = jsontime.UM(time.UnixMilli(sortingTimestamp.Int64)) return r, nil } @@ -203,19 +204,19 @@ func (r *Room) sqlVariables() []any { dbutil.JSONPtr(r.EncryptionEvent), r.HasMemberList, dbutil.NumPtr(r.PreviewEventRowID), - dbutil.UnixMilliPtr(r.SortingTimestamp), + dbutil.UnixMilliPtr(r.SortingTimestamp.Time), dbutil.StrPtr(r.PrevBatch), } } func (r *Room) BumpSortingTimestamp(evt *Event) bool { - if !evt.BumpsSortingTimestamp() || evt.Timestamp.Before(r.SortingTimestamp) { + if !evt.BumpsSortingTimestamp() || evt.Timestamp.Before(r.SortingTimestamp.Time) { return false } - r.SortingTimestamp = evt.Timestamp + r.SortingTimestamp = jsontime.UM(evt.Timestamp) now := time.Now() if r.SortingTimestamp.After(now) { - r.SortingTimestamp = now + r.SortingTimestamp = jsontime.UM(now) } return true } diff --git a/hicli/events.go b/hicli/events.go index a30dda8d..2de01f35 100644 --- a/hicli/events.go +++ b/hicli/events.go @@ -37,3 +37,11 @@ type SendComplete struct { Event *database.Event `json:"event"` Error error `json:"error"` } + +type ClientState struct { + IsLoggedIn bool `json:"is_logged_in"` + IsVerified bool `json:"is_verified"` + UserID id.UserID `json:"user_id,omitempty"` + DeviceID id.DeviceID `json:"device_id,omitempty"` + HomeserverURL string `json:"homeserver_url,omitempty"` +} diff --git a/hicli/hicli.go b/hicli/hicli.go index 7524b6bc..4253c581 100644 --- a/hicli/hicli.go +++ b/hicli/hicli.go @@ -15,6 +15,7 @@ import ( "net/http" "net/url" "sync" + "sync/atomic" "time" "github.com/rs/zerolog" @@ -47,11 +48,14 @@ type HiClient struct { firstSyncReceived bool syncingID int syncLock sync.Mutex - stopSync context.CancelFunc + stopSync atomic.Pointer[context.CancelFunc] encryptLock sync.Mutex requestQueueWakeup chan struct{} + jsonRequestsLock sync.Mutex + jsonRequests map[int64]context.CancelCauseFunc + paginationInterrupterLock sync.Mutex paginationInterrupter map[id.RoomID]context.CancelCauseFunc } @@ -74,7 +78,9 @@ func New(rawDB, cryptoDB *dbutil.Database, log zerolog.Logger, pickleKey []byte, DB: db, Log: log, - requestQueueWakeup: make(chan struct{}, 1), + requestQueueWakeup: make(chan struct{}, 1), + jsonRequests: make(map[int64]context.CancelCauseFunc), + paginationInterrupter: make(map[id.RoomID]context.CancelCauseFunc), EventHandler: evtHandler, } @@ -166,7 +172,6 @@ func (h *HiClient) Start(ctx context.Context, userID id.UserID, expectedAccount return err } go h.Sync() - go h.RunRequestQueue(ctx) } } return nil @@ -186,10 +191,14 @@ func (h *HiClient) CheckServerVersions(ctx context.Context) error { return nil } +func (h *HiClient) IsSyncing() bool { + return h.stopSync.Load() != nil +} + func (h *HiClient) Sync() { h.Client.StopSync() - if fn := h.stopSync; fn != nil { - fn() + if fn := h.stopSync.Load(); fn != nil { + (*fn)() } h.syncLock.Lock() defer h.syncLock.Unlock() @@ -199,8 +208,11 @@ func (h *HiClient) Sync() { Str("action", "sync"). Int("sync_id", syncingID). Logger() - ctx, cancel := context.WithCancel(log.WithContext(context.Background())) - h.stopSync = cancel + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + h.stopSync.Store(&cancel) + go h.RunRequestQueue(h.Log.WithContext(ctx)) + ctx = log.WithContext(ctx) log.Info().Msg("Starting syncing") err := h.Client.SyncWithContext(ctx) if err != nil && ctx.Err() == nil { @@ -212,8 +224,8 @@ func (h *HiClient) Sync() { func (h *HiClient) Stop() { h.Client.StopSync() - if fn := h.stopSync; fn != nil { - fn() + if fn := h.stopSync.Swap(nil); fn != nil { + (*fn)() } h.syncLock.Lock() h.syncLock.Unlock() diff --git a/hicli/json-commands.go b/hicli/json-commands.go new file mode 100644 index 00000000..90cb2890 --- /dev/null +++ b/hicli/json-commands.go @@ -0,0 +1,136 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" +) + +func (h *HiClient) handleJSONCommand(ctx context.Context, req *JSONCommand) (any, error) { + switch req.Command { + case "get_state": + return h.State(), nil + case "cancel": + return unmarshalAndCall(req.Data, func(params *cancelRequestParams) (bool, error) { + h.jsonRequestsLock.Lock() + cancelTarget, ok := h.jsonRequests[params.RequestID] + h.jsonRequestsLock.Unlock() + if ok { + return false, nil + } + if params.Reason == "" { + cancelTarget(nil) + } else { + cancelTarget(errors.New(params.Reason)) + } + return true, nil + }) + case "send_message": + return unmarshalAndCall(req.Data, func(params *sendParams) (*database.Event, error) { + return h.Send(ctx, params.RoomID, params.EventType, params.Content) + }) + case "get_event": + return unmarshalAndCall(req.Data, func(params *getEventParams) (*database.Event, error) { + return h.GetEvent(ctx, params.RoomID, params.EventID) + }) + case "get_events_by_rowids": + return unmarshalAndCall(req.Data, func(params *getEventsByRowIDsParams) ([]*database.Event, error) { + return h.GetEventsByRowIDs(ctx, params.RowIDs) + }) + case "paginate": + return unmarshalAndCall(req.Data, func(params *paginateParams) ([]*database.Event, error) { + return h.Paginate(ctx, params.RoomID, params.MaxTimelineID, params.Limit) + }) + case "paginate_server": + return unmarshalAndCall(req.Data, func(params *paginateParams) ([]*database.Event, error) { + return h.PaginateServer(ctx, params.RoomID, params.Limit) + }) + case "ensure_group_session_shared": + return unmarshalAndCall(req.Data, func(params *ensureGroupSessionSharedParams) (bool, error) { + return true, h.EnsureGroupSessionShared(ctx, params.RoomID) + }) + case "login": + return unmarshalAndCall(req.Data, func(params *loginParams) (bool, error) { + return true, h.LoginPassword(ctx, params.HomeserverURL, params.Username, params.Password) + }) + case "verify": + return unmarshalAndCall(req.Data, func(params *verifyParams) (bool, error) { + return true, h.VerifyWithRecoveryKey(ctx, params.RecoveryKey) + }) + case "discover_homeserver": + return unmarshalAndCall(req.Data, func(params *discoverHomeserverParams) (*mautrix.ClientWellKnown, error) { + _, homeserver, err := params.UserID.Parse() + if err != nil { + return nil, err + } + return mautrix.DiscoverClientAPI(ctx, homeserver) + }) + default: + return nil, fmt.Errorf("unknown command %q", req.Command) + } +} + +func unmarshalAndCall[T, O any](data json.RawMessage, fn func(*T) (O, error)) (output O, err error) { + var input T + err = json.Unmarshal(data, &input) + if err != nil { + return + } + return fn(&input) +} + +type cancelRequestParams struct { + RequestID int64 `json:"request_id"` + Reason string `json:"reason"` +} + +type sendParams struct { + RoomID id.RoomID `json:"room_id"` + EventType event.Type `json:"type"` + Content json.RawMessage `json:"content"` +} + +type getEventParams struct { + RoomID id.RoomID `json:"room_id"` + EventID id.EventID `json:"event_id"` +} + +type getEventsByRowIDsParams struct { + RowIDs []database.EventRowID `json:"row_ids"` +} + +type ensureGroupSessionSharedParams struct { + RoomID id.RoomID `json:"room_id"` +} + +type loginParams struct { + HomeserverURL string `json:"homeserver_url"` + Username string `json:"username"` + Password string `json:"password"` +} + +type verifyParams struct { + RecoveryKey string `json:"recovery_key"` +} + +type discoverHomeserverParams struct { + UserID id.UserID `json:"user_id"` +} + +type paginateParams struct { + RoomID id.RoomID `json:"room_id"` + MaxTimelineID database.TimelineRowID `json:"max_timeline_id"` + Limit int `json:"limit"` +} diff --git a/hicli/json.go b/hicli/json.go new file mode 100644 index 00000000..df853c33 --- /dev/null +++ b/hicli/json.go @@ -0,0 +1,113 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sync/atomic" + + "go.mau.fi/util/exerrors" +) + +type JSONCommand struct { + Command string `json:"command"` + RequestID int64 `json:"request_id"` + Data json.RawMessage `json:"data"` +} + +type JSONEventHandler func(*JSONCommand) + +var outgoingEventCounter atomic.Int64 + +func (jeh JSONEventHandler) HandleEvent(evt any) { + var command string + switch evt.(type) { + case *SyncComplete: + command = "sync_complete" + case *EventsDecrypted: + command = "events_decrypted" + case *Typing: + command = "typing" + case *SendComplete: + command = "send_complete" + case *ClientState: + command = "client_state" + default: + panic(fmt.Errorf("unknown event type %T", evt)) + } + data, err := json.Marshal(evt) + if err != nil { + panic(fmt.Errorf("failed to marshal event %T: %w", evt, err)) + } + jeh(&JSONCommand{ + Command: command, + RequestID: -outgoingEventCounter.Add(1), + Data: data, + }) +} + +func (h *HiClient) State() *ClientState { + state := &ClientState{} + if acc := h.Account; acc != nil { + state.IsLoggedIn = true + state.UserID = acc.UserID + state.DeviceID = acc.DeviceID + state.HomeserverURL = acc.HomeserverURL + state.IsVerified = h.Verified + } + return state +} + +func (h *HiClient) dispatchCurrentState() { + h.EventHandler(h.State()) +} + +func (h *HiClient) SubmitJSONCommand(ctx context.Context, req *JSONCommand) *JSONCommand { + log := h.Log.With().Int64("request_id", req.RequestID).Str("command", req.Command).Logger() + ctx, cancel := context.WithCancelCause(ctx) + defer func() { + cancel(nil) + h.jsonRequestsLock.Lock() + delete(h.jsonRequests, req.RequestID) + h.jsonRequestsLock.Unlock() + }() + ctx = log.WithContext(ctx) + h.jsonRequestsLock.Lock() + h.jsonRequests[req.RequestID] = cancel + h.jsonRequestsLock.Unlock() + resp, err := h.handleJSONCommand(ctx, req) + if err != nil { + if errors.Is(err, context.Canceled) { + causeErr := context.Cause(ctx) + if causeErr != ctx.Err() { + err = fmt.Errorf("%w: %w", err, causeErr) + } + } + return &JSONCommand{ + Command: "error", + RequestID: req.RequestID, + Data: exerrors.Must(json.Marshal(err.Error())), + } + } + var respData json.RawMessage + respData, err = json.Marshal(resp) + if err != nil { + return &JSONCommand{ + Command: "error", + RequestID: req.RequestID, + Data: exerrors.Must(json.Marshal(fmt.Sprintf("failed to marshal response json: %v", err))), + } + } + return &JSONCommand{ + Command: "response", + RequestID: req.RequestID, + Data: respData, + } +} diff --git a/hicli/login.go b/hicli/login.go index d33ea422..6dbaf6e6 100644 --- a/hicli/login.go +++ b/hicli/login.go @@ -46,6 +46,7 @@ func (h *HiClient) Login(ctx context.Context, req *mautrix.ReqLogin) error { if err != nil { return err } + defer h.dispatchCurrentState() h.Account = &database.Account{ UserID: resp.UserID, DeviceID: resp.DeviceID, @@ -73,16 +74,14 @@ func (h *HiClient) Login(ctx context.Context, req *mautrix.ReqLogin) error { return nil } -func (h *HiClient) LoginAndVerify(ctx context.Context, homeserverURL, username, password, recoveryCode string) error { +func (h *HiClient) LoginAndVerify(ctx context.Context, homeserverURL, username, password, recoveryKey string) error { err := h.LoginPassword(ctx, homeserverURL, username, password) if err != nil { return err } - err = h.VerifyWithRecoveryCode(ctx, recoveryCode) + err = h.VerifyWithRecoveryKey(ctx, recoveryKey) if err != nil { return err } - go h.Sync() - go h.RunRequestQueue(ctx) return nil } diff --git a/hicli/verify.go b/hicli/verify.go index 905be052..6dc2a4c3 100644 --- a/hicli/verify.go +++ b/hicli/verify.go @@ -124,7 +124,8 @@ func (h *HiClient) storeCrossSigningPrivateKeys(ctx context.Context) error { return nil } -func (h *HiClient) VerifyWithRecoveryCode(ctx context.Context, code string) error { +func (h *HiClient) VerifyWithRecoveryKey(ctx context.Context, code string) error { + defer h.dispatchCurrentState() keyID, keyData, err := h.Crypto.SSSS.GetDefaultKeyData(ctx) if err != nil { return fmt.Errorf("failed to get default SSSS key data: %w", err) @@ -154,5 +155,8 @@ func (h *HiClient) VerifyWithRecoveryCode(ctx context.Context, code string) erro return fmt.Errorf("failed to fetch key backup key: %w", err) } h.Verified = true + if !h.IsSyncing() { + go h.Sync() + } return nil } From 6e2a54d2b07d186f891cad60cd14390462d96a06 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 6 Oct 2024 00:12:20 +0300 Subject: [PATCH 0794/1647] hicli: parse raw before removing reply fallback --- hicli/sync.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/hicli/sync.go b/hicli/sync.go index c3f30a72..086f6dd1 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -204,6 +204,10 @@ func isDecryptionErrorRetryable(err error) bool { } func removeReplyFallback(evt *event.Event) []byte { + if evt.Type != event.EventMessage { + return nil + } + _ = evt.Content.ParseRaw(evt.Type) content, ok := evt.Content.Parsed.(*event.MessageEventContent) if ok && content.RelatesTo.GetReplyTo() != "" { prevFormattedBody := content.FormattedBody From 0e05a6b8661404ae7e8dcfd5a343d23e06985ebd Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 6 Oct 2024 02:09:30 +0300 Subject: [PATCH 0795/1647] crypto: reduce logs when verifying cross-signign --- crypto/cross_sign_store.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go index 968a52a1..4c2c80e4 100644 --- a/crypto/cross_sign_store.go +++ b/crypto/cross_sign_store.go @@ -77,11 +77,11 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK } } if len(signingKey) != 43 { - log.Debug().Msg("Cross-signing key has a signature from an unknown key") + log.Trace().Msg("Cross-signing key has a signature from an unknown key") continue } - log.Debug().Msg("Verifying cross-signing key signature") + log.Trace().Msg("Verifying cross-signing key signature") if verified, err := signatures.VerifySignatureJSON(userKeys, signUserID, signKeyName, signingKey); err != nil { log.Warn().Err(err).Msg("Error verifying cross-signing key signature") } else { From e2329e84300e155e46b99a4c5e1f64c929c9f99b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 6 Oct 2024 02:10:03 +0300 Subject: [PATCH 0796/1647] hicli: save event send error to database --- hicli/database/event.go | 21 ++++++++++++++----- hicli/database/state.go | 2 +- .../database/upgrades/00-latest-revision.sql | 1 + hicli/send.go | 8 ++++++- 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/hicli/database/event.go b/hicli/database/event.go index de21e317..57a3291d 100644 --- a/hicli/database/event.go +++ b/hicli/database/event.go @@ -25,7 +25,8 @@ import ( const ( getEventBaseQuery = ` SELECT rowid, -1, room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, - transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, reactions, last_edit_rowid + transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, + reactions, last_edit_rowid FROM event ` getEventByRowID = getEventBaseQuery + `WHERE rowid = $1` @@ -35,9 +36,9 @@ const ( insertEventBaseQuery = ` INSERT INTO event ( room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, - transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error + transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) ` insertEventQuery = insertEventBaseQuery + `RETURNING rowid` upsertEventQuery = insertEventBaseQuery + ` @@ -46,6 +47,7 @@ const ( decrypted_type=COALESCE(event.decrypted_type, excluded.decrypted_type), redacted_by=COALESCE(event.redacted_by, excluded.redacted_by), decryption_error=CASE WHEN COALESCE(event.decrypted, excluded.decrypted) IS NULL THEN COALESCE(excluded.decryption_error, event.decryption_error) END, + send_error=excluded.send_error, timestamp=excluded.timestamp, unsigned=COALESCE(excluded.unsigned, event.unsigned) ON CONFLICT (transaction_id) DO UPDATE @@ -54,7 +56,8 @@ const ( unsigned=excluded.unsigned RETURNING rowid ` - updateEventIDQuery = `UPDATE event SET event_id=$2 WHERE rowid=$1` + updateEventSendErrorQuery = `UPDATE event SET send_error = $2 WHERE rowid = $1` + updateEventIDQuery = `UPDATE event SET event_id = $2, send_error = NULL WHERE rowid=$1` updateEventDecryptedQuery = `UPDATE event SET decrypted = $1, decrypted_type = $2, decryption_error = NULL WHERE rowid = $3` getEventReactionsQuery = getEventBaseQuery + ` WHERE room_id = ? @@ -123,6 +126,10 @@ func (eq *EventQuery) UpdateID(ctx context.Context, rowID EventRowID, newID id.E return eq.Exec(ctx, updateEventIDQuery, rowID, newID) } +func (eq *EventQuery) UpdateSendError(ctx context.Context, rowID EventRowID, sendError string) error { + return eq.Exec(ctx, updateEventSendErrorQuery, rowID, sendError) +} + func (eq *EventQuery) UpdateDecrypted(ctx context.Context, rowID EventRowID, decrypted json.RawMessage, decryptedType string) error { return eq.Exec(ctx, updateEventDecryptedQuery, unsafeJSONString(decrypted), decryptedType, rowID) } @@ -280,6 +287,7 @@ type Event struct { MegolmSessionID id.SessionID `json:"-,omitempty"` DecryptionError string `json:"decryption_error,omitempty"` + SendError string `json:"send_error,omitempty"` Reactions map[string]int `json:"reactions,omitempty"` LastEditRowID *EventRowID `json:"last_edit_rowid,omitempty"` @@ -332,7 +340,7 @@ func (e *Event) AsRawMautrix() *event.Event { func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { var timestamp int64 - var transactionID, redactedBy, relatesTo, relationType, megolmSessionID, decryptionError, decryptedType sql.NullString + var transactionID, redactedBy, relatesTo, relationType, megolmSessionID, decryptionError, sendError, decryptedType sql.NullString err := row.Scan( &e.RowID, &e.TimelineRowID, @@ -352,6 +360,7 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { &relationType, &megolmSessionID, &decryptionError, + &sendError, dbutil.JSON{Data: &e.Reactions}, &e.LastEditRowID, ) @@ -366,6 +375,7 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { e.MegolmSessionID = id.SessionID(megolmSessionID.String) e.DecryptedType = decryptedType.String e.DecryptionError = decryptionError.String + e.SendError = sendError.String return e, nil } @@ -420,6 +430,7 @@ func (e *Event) sqlVariables() []any { dbutil.StrPtr(e.RelationType), dbutil.StrPtr(e.MegolmSessionID), dbutil.StrPtr(e.DecryptionError), + dbutil.StrPtr(e.SendError), dbutil.JSON{Data: reactions}, e.LastEditRowID, } diff --git a/hicli/database/state.go b/hicli/database/state.go index 845de6ed..e74f2950 100644 --- a/hicli/database/state.go +++ b/hicli/database/state.go @@ -22,7 +22,7 @@ const ( ` getCurrentRoomStateQuery = ` SELECT event.rowid, -1, event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type, unsigned, - transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, reactions, last_edit_rowid + transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid FROM current_state cs JOIN event ON cs.event_rowid = event.rowid WHERE cs.room_id = $1 diff --git a/hicli/database/upgrades/00-latest-revision.sql b/hicli/database/upgrades/00-latest-revision.sql index df6499a1..f4456165 100644 --- a/hicli/database/upgrades/00-latest-revision.sql +++ b/hicli/database/upgrades/00-latest-revision.sql @@ -74,6 +74,7 @@ CREATE TABLE event ( megolm_session_id TEXT, decryption_error TEXT, + send_error TEXT, reactions TEXT, last_edit_rowid INTEGER, diff --git a/hicli/send.go b/hicli/send.go index 66175e75..8824f3c3 100644 --- a/hicli/send.go +++ b/hicli/send.go @@ -68,6 +68,7 @@ func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Typ RelationType: relationType, MegolmSessionID: megolmSessionID, DecryptionError: "", + SendError: "not sent", Reactions: map[string]int{}, LastEditRowID: &zero, } @@ -90,8 +91,13 @@ func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Typ DontEncrypt: true, }) if err != nil { - // TODO save send error to db? + dbEvt.SendError = err.Error() err = fmt.Errorf("failed to send event: %w", err) + err2 := h.DB.Event.UpdateSendError(ctx, dbEvt.RowID, dbEvt.SendError) + if err2 != nil { + zerolog.Ctx(ctx).Err(err2).AnErr("send_error", err). + Msg("Failed to update send error in database after sending failed") + } return } dbEvt.ID = resp.EventID From a284650568ffb5f63a71c64c1cdcb27bb2ad3371 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 6 Oct 2024 02:11:05 +0300 Subject: [PATCH 0797/1647] hicli: scope EventsDecrypted event to single room --- hicli/decryptionqueue.go | 6 +++--- hicli/events.go | 5 +++-- hicli/hitest/hitest.go | 4 ++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/hicli/decryptionqueue.go b/hicli/decryptionqueue.go index 02466b69..70ea9f23 100644 --- a/hicli/decryptionqueue.go +++ b/hicli/decryptionqueue.go @@ -60,7 +60,7 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro } } if len(decrypted) > 0 { - previewRowIDChanges := make(map[id.RoomID]database.EventRowID) + var newPreview database.EventRowID err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { for _, evt := range decrypted { err = h.DB.Event.UpdateDecrypted(ctx, evt.RowID, evt.Decrypted, evt.DecryptedType) @@ -73,7 +73,7 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro if err != nil { return fmt.Errorf("failed to update room %s preview to %d: %w", evt.RoomID, evt.RowID, err) } else if previewChanged { - previewRowIDChanges[evt.RoomID] = evt.RowID + newPreview = evt.RowID } } } @@ -82,7 +82,7 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro if err != nil { log.Err(err).Msg("Failed to save decrypted events") } else { - h.EventHandler(&EventsDecrypted{Events: decrypted, PreviewRowIDs: previewRowIDChanges}) + h.EventHandler(&EventsDecrypted{Events: decrypted, PreviewEventRowID: newPreview, RoomID: roomID}) } } } diff --git a/hicli/events.go b/hicli/events.go index 2de01f35..c7228541 100644 --- a/hicli/events.go +++ b/hicli/events.go @@ -24,8 +24,9 @@ type SyncComplete struct { } type EventsDecrypted struct { - PreviewRowIDs map[id.RoomID]database.EventRowID `json:"room_preview_rowids"` - Events []*database.Event `json:"events"` + RoomID id.RoomID `json:"room_id"` + PreviewEventRowID database.EventRowID `json:"preview_event_rowid,omitempty"` + Events []*database.Event `json:"events"` } type Typing struct { diff --git a/hicli/hitest/hitest.go b/hicli/hitest/hitest.go index c6873bac..bdf1598f 100644 --- a/hicli/hitest/hitest.go +++ b/hicli/hitest/hitest.go @@ -67,8 +67,8 @@ func main() { for _, decrypted := range evt.Events { _, _ = fmt.Fprintf(rl, "Delayed decryption of %s completed: %s / %s\n", decrypted.ID, decrypted.DecryptedType, decrypted.Decrypted) } - if len(evt.PreviewRowIDs) > 0 { - _, _ = fmt.Fprintf(rl, "Room previews updated: %+v\n", evt.PreviewRowIDs) + if evt.PreviewEventRowID != 0 { + _, _ = fmt.Fprintf(rl, "Room preview updated: %+v\n", evt.PreviewEventRowID) } case *hicli.Typing: _, _ = fmt.Fprintf(rl, "Typing list in %s: %+v\n", evt.RoomID, evt.UserIDs) From 381c8780e036be338ec915763124baa673918ef7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 6 Oct 2024 14:48:00 +0300 Subject: [PATCH 0798/1647] hicli: add media cache table --- hicli/database/cachedmedia.go | 112 ++++++++++++++++++ hicli/database/database.go | 6 + .../database/upgrades/00-latest-revision.sql | 12 ++ 3 files changed, 130 insertions(+) create mode 100644 hicli/database/cachedmedia.go diff --git a/hicli/database/cachedmedia.go b/hicli/database/cachedmedia.go new file mode 100644 index 00000000..9ea9c27a --- /dev/null +++ b/hicli/database/cachedmedia.go @@ -0,0 +1,112 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + + "go.mau.fi/util/dbutil" + "golang.org/x/exp/slices" + + "maunium.net/go/mautrix/crypto/attachment" + "maunium.net/go/mautrix/id" +) + +const ( + insertCachedMediaQuery = ` + INSERT INTO cached_media (mxc, event_rowid, enc_file, file_name, mime_type, size, hash) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (mxc) DO NOTHING + ` + upsertCachedMediaQuery = ` + INSERT INTO cached_media (mxc, event_rowid, enc_file, file_name, mime_type, size, hash) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ON CONFLICT (mxc) DO UPDATE + SET enc_file = excluded.enc_file, + file_name = excluded.file_name, + mime_type = excluded.mime_type, + size = excluded.size, + hash = excluded.hash + ` + getCachedMediaQuery = ` + SELECT mxc, event_rowid, enc_file, file_name, mime_type, size, hash + FROM cached_media + WHERE mxc = $1 + ` +) + +type CachedMediaQuery struct { + *dbutil.QueryHelper[*CachedMedia] +} + +func (cmq *CachedMediaQuery) Add(ctx context.Context, cm *CachedMedia) error { + return cmq.Exec(ctx, insertCachedMediaQuery, cm.sqlVariables()...) +} + +func (cmq *CachedMediaQuery) Put(ctx context.Context, cm *CachedMedia) error { + return cmq.Exec(ctx, upsertCachedMediaQuery, cm.sqlVariables()...) +} + +func (cmq *CachedMediaQuery) Get(ctx context.Context, mxc id.ContentURI) (*CachedMedia, error) { + return cmq.QueryOne(ctx, getCachedMediaQuery, &mxc) +} + +type CachedMedia struct { + MXC id.ContentURI + EventRowID EventRowID + EncFile *attachment.EncryptedFile + FileName string + MimeType string + Size int64 + Hash *[32]byte +} + +func (c *CachedMedia) sqlVariables() []any { + var hash []byte + if c.Hash != nil { + hash = c.Hash[:] + } + return []any{ + &c.MXC, dbutil.NumPtr(c.EventRowID), dbutil.JSONPtr(c.EncFile), + dbutil.StrPtr(c.FileName), dbutil.StrPtr(c.MimeType), dbutil.NumPtr(c.Size), hash, + } +} + +var safeMimes = []string{ + "text/css", "text/plain", "text/csv", + "application/json", "application/ld+json", + "image/jpeg", "image/gif", "image/png", "image/apng", "image/webp", "image/avif", + "video/mp4", "video/webm", "video/ogg", "video/quicktime", + "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", + "audio/wav", "audio/x-wav", "audio/x-pn-wav", "audio/flac", "audio/x-flac", +} + +func (c *CachedMedia) Scan(row dbutil.Scannable) (*CachedMedia, error) { + var mimeType, fileName sql.NullString + var size, eventRowID sql.NullInt64 + var hash []byte + err := row.Scan(&c.MXC, &eventRowID, dbutil.JSON{Data: &c.EncFile}, &fileName, &mimeType, &size, &hash) + if err != nil { + return nil, err + } + c.MimeType = mimeType.String + c.FileName = fileName.String + c.EventRowID = EventRowID(eventRowID.Int64) + c.Size = size.Int64 + if hash != nil && len(hash) == 32 { + c.Hash = (*[32]byte)(hash) + } + return c, nil +} + +func (c *CachedMedia) ContentDisposition() string { + if slices.Contains(safeMimes, c.MimeType) { + return "inline" + } + return "attachment" +} diff --git a/hicli/database/database.go b/hicli/database/database.go index 601ca64b..2a357f06 100644 --- a/hicli/database/database.go +++ b/hicli/database/database.go @@ -23,6 +23,7 @@ type Database struct { Timeline TimelineQuery SessionRequest SessionRequestQuery Receipt ReceiptQuery + CachedMedia CachedMediaQuery } func New(rawDB *dbutil.Database) *Database { @@ -39,6 +40,7 @@ func New(rawDB *dbutil.Database) *Database { Timeline: TimelineQuery{QueryHelper: eventQH}, SessionRequest: SessionRequestQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newSessionRequest)}, Receipt: ReceiptQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newReceipt)}, + CachedMedia: CachedMediaQuery{QueryHelper: dbutil.MakeQueryHelper(rawDB, newCachedMedia)}, } } @@ -58,6 +60,10 @@ func newReceipt(_ *dbutil.QueryHelper[*Receipt]) *Receipt { return &Receipt{} } +func newCachedMedia(_ *dbutil.QueryHelper[*CachedMedia]) *CachedMedia { + return &CachedMedia{} +} + func newAccountData(_ *dbutil.QueryHelper[*AccountData]) *AccountData { return &AccountData{} } diff --git a/hicli/database/upgrades/00-latest-revision.sql b/hicli/database/upgrades/00-latest-revision.sql index f4456165..848dce47 100644 --- a/hicli/database/upgrades/00-latest-revision.sql +++ b/hicli/database/upgrades/00-latest-revision.sql @@ -177,6 +177,18 @@ BEGIN AND reactions IS NOT NULL; END; +CREATE TABLE cached_media ( + mxc TEXT NOT NULL PRIMARY KEY, + event_rowid INTEGER, + enc_file TEXT, + file_name TEXT, + mime_type TEXT, + size INTEGER, + hash BLOB, + + CONSTRAINT cached_media_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE SET NULL +) STRICT; + CREATE TABLE session_request ( room_id TEXT NOT NULL, session_id TEXT NOT NULL, From bb6aaf79a9402d0242aa843de132f92719a917cc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 6 Oct 2024 15:06:28 +0300 Subject: [PATCH 0799/1647] event: add helpers for getting caption and file name --- event/message.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/event/message.go b/event/message.go index 097c585e..9badd9a2 100644 --- a/event/message.go +++ b/event/message.go @@ -143,6 +143,27 @@ type MessageEventContent struct { MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"` } +func (content *MessageEventContent) GetFileName() string { + if content.FileName != "" { + return content.FileName + } + return content.Body +} + +func (content *MessageEventContent) GetCaption() string { + if content.FileName != "" && content.Body != "" && content.Body != content.FileName { + return content.Body + } + return "" +} + +func (content *MessageEventContent) GetFormattedCaption() string { + if content.Format == FormatHTML && content.FormattedBody != "" { + return content.FormattedBody + } + return "" +} + func (content *MessageEventContent) GetRelatesTo() *RelatesTo { if content.RelatesTo == nil { content.RelatesTo = &RelatesTo{} From 014ea707622a188ccf592a6e0bc537b7d668d4d3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 6 Oct 2024 15:07:04 +0300 Subject: [PATCH 0800/1647] hicli: add media cache entries when receiving events --- hicli/decryptionqueue.go | 5 ++- hicli/sync.go | 84 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 82 insertions(+), 7 deletions(-) diff --git a/hicli/decryptionqueue.go b/hicli/decryptionqueue.go index 70ea9f23..87b6b8b2 100644 --- a/hicli/decryptionqueue.go +++ b/hicli/decryptionqueue.go @@ -14,6 +14,7 @@ import ( "github.com/rs/zerolog" "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/hicli/database" "maunium.net/go/mautrix/id" ) @@ -52,11 +53,13 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro continue } - evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix()) + var mautrixEvt *event.Event + mautrixEvt, evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix()) if err != nil { log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session") } else { decrypted = append(decrypted, evt) + h.cacheMedia(ctx, mautrixEvt, evt.RowID) } } if len(decrypted) > 0 { diff --git a/hicli/sync.go b/hicli/sync.go index 086f6dd1..3b40af9f 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -223,20 +223,86 @@ func removeReplyFallback(evt *event.Event) []byte { return nil } -func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) ([]byte, string, error) { +func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, string, error) { err := evt.Content.ParseRaw(evt.Type) if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { - return nil, "", err + return nil, nil, "", err } decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt) if err != nil { - return nil, "", err + return nil, nil, "", err } withoutFallback := removeReplyFallback(decrypted) if withoutFallback != nil { - return withoutFallback, decrypted.Type.Type, nil + return decrypted, withoutFallback, decrypted.Type.Type, nil + } + return decrypted, decrypted.Content.VeryRaw, decrypted.Type.Type, nil +} + +func (h *HiClient) addMediaCache( + ctx context.Context, + eventRowID database.EventRowID, + uri id.ContentURIString, + file *event.EncryptedFileInfo, + info *event.FileInfo, + fileName string, +) { + parsedMXC := uri.ParseOrIgnore() + if !parsedMXC.IsValid() { + return + } + cm := &database.CachedMedia{ + MXC: parsedMXC, + EventRowID: eventRowID, + FileName: fileName, + } + if file != nil { + cm.EncFile = &file.EncryptedFile + } + if info != nil { + cm.MimeType = info.MimeType + } + err := h.DB.CachedMedia.Put(ctx, cm) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err). + Stringer("mxc", parsedMXC). + Int64("event_rowid", int64(eventRowID)). + Msg("Failed to add cached media entry") + } +} + +func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID database.EventRowID) { + switch evt.Type { + case event.EventMessage, event.EventSticker: + content, ok := evt.Content.Parsed.(*event.MessageEventContent) + if !ok { + return + } + if content.File != nil { + h.addMediaCache(ctx, rowID, content.File.URL, content.File, content.Info, content.GetFileName()) + } else if content.URL != "" { + h.addMediaCache(ctx, rowID, content.URL, nil, content.Info, content.GetFileName()) + } + if content.GetInfo().ThumbnailFile != nil { + h.addMediaCache(ctx, rowID, content.Info.ThumbnailFile.URL, content.Info.ThumbnailFile, content.Info.ThumbnailInfo, "") + } else if content.GetInfo().ThumbnailURL != "" { + h.addMediaCache(ctx, rowID, content.Info.ThumbnailURL, nil, content.Info.ThumbnailInfo, "") + } + case event.StateRoomAvatar: + _ = evt.Content.ParseRaw(evt.Type) + content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent) + if !ok { + return + } + h.addMediaCache(ctx, rowID, content.URL, nil, nil, "") + case event.StateMember: + _ = evt.Content.ParseRaw(evt.Type) + content, ok := evt.Content.Parsed.(*event.MemberEventContent) + if !ok { + return + } + h.addMediaCache(ctx, rowID, content.AvatarURL, nil, nil, "") } - return decrypted.Content.VeryRaw, decrypted.Type.Type, nil } func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptionQueue map[id.SessionID]*database.SessionRequest, checkDB bool) (*database.Event, error) { @@ -254,8 +320,9 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio dbEvt.Content = contentWithoutFallback } var decryptionErr error + var decryptedMautrixEvt *event.Event if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" { - dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt) + decryptedMautrixEvt, dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt) if decryptionErr != nil { dbEvt.DecryptionError = decryptionErr.Error() } @@ -272,6 +339,11 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio if err != nil { return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err) } + if decryptedMautrixEvt != nil { + h.cacheMedia(ctx, decryptedMautrixEvt, dbEvt.RowID) + } else { + h.cacheMedia(ctx, evt, dbEvt.RowID) + } if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) { req, ok := decryptionQueue[dbEvt.MegolmSessionID] if !ok { From 64692eb06e111dc50dcdec4b5a7a8e11441b5f6f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 6 Oct 2024 21:17:05 +0300 Subject: [PATCH 0801/1647] hicli: add method to get rooms by sort order --- hicli/database/room.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/hicli/database/room.go b/hicli/database/room.go index 1acbd081..5778e5f5 100644 --- a/hicli/database/room.go +++ b/hicli/database/room.go @@ -27,8 +27,9 @@ const ( preview_event_rowid, sorting_timestamp, prev_batch FROM room ` - getRoomByIDQuery = getRoomBaseQuery + `WHERE room_id = $1` - ensureRoomExistsQuery = ` + getRoomsBySortingTimestampQuery = getRoomBaseQuery + `WHERE sorting_timestamp < $1 ORDER BY sorting_timestamp DESC LIMIT $2` + getRoomByIDQuery = getRoomBaseQuery + `WHERE room_id = $1` + ensureRoomExistsQuery = ` INSERT INTO room (room_id) VALUES ($1) ON CONFLICT (room_id) DO NOTHING ` @@ -69,6 +70,10 @@ func (rq *RoomQuery) Get(ctx context.Context, roomID id.RoomID) (*Room, error) { return rq.QueryOne(ctx, getRoomByIDQuery, roomID) } +func (rq *RoomQuery) GetBySortTS(ctx context.Context, maxTS time.Time, limit int) ([]*Room, error) { + return rq.QueryMany(ctx, getRoomsBySortingTimestampQuery, maxTS.UnixMilli(), limit) +} + func (rq *RoomQuery) Upsert(ctx context.Context, room *Room) error { return rq.Exec(ctx, upsertRoomFromSyncQuery, room.sqlVariables()...) } From 2621417bf0004fd9f33075b077833403eaa26a1b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 7 Oct 2024 12:53:35 +0300 Subject: [PATCH 0802/1647] client: don't let http close request body reader --- client.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/client.go b/client.go index b237612b..836afc55 100644 --- a/client.go +++ b/client.go @@ -420,6 +420,10 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e params.RequestLength = int64(len(params.RequestBytes)) } else if params.RequestLength > 0 && params.RequestBody != nil { logBody = fmt.Sprintf("<%d bytes>", params.RequestLength) + if rsc, ok := params.RequestBody.(io.ReadSeekCloser); ok { + // Prevent HTTP from closing the request body, it might be needed for retries + reqBody = nopCloseSeeker{rsc} + } } else if params.Method != http.MethodGet && params.Method != http.MethodHead { params.RequestJSON = struct{}{} logBody = params.RequestJSON @@ -1612,6 +1616,9 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (* if data.ContentBytes != nil { data.ContentLength = int64(len(data.ContentBytes)) reader = bytes.NewReader(data.ContentBytes) + } else if rsc, ok := reader.(io.ReadSeekCloser); ok { + // Prevent HTTP from closing the request body, it might be needed for retries + reader = nopCloseSeeker{rsc} } readerSeeker, canSeek := reader.(io.ReadSeeker) if !canSeek { @@ -1656,6 +1663,14 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (* return m, nil } +type nopCloseSeeker struct { + io.ReadSeeker +} + +func (nopCloseSeeker) Close() error { + return nil +} + // UploadMedia uploads the given data to the content repository and returns an MXC URI. // See https://spec.matrix.org/v1.7/client-server-api/#post_matrixmediav3upload func (cli *Client) UploadMedia(ctx context.Context, data ReqUploadMedia) (*RespMediaUpload, error) { From cb361e1f59a18f5342740dc61fb6f714bcc63a9e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 7 Oct 2024 12:53:46 +0300 Subject: [PATCH 0803/1647] pre-commit: add todo for staticcheck --- .pre-commit-config.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1ef1b112..c15d69d6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,6 +18,8 @@ repos: - "-w" - id: go-vet-repo-mod - id: go-mod-tidy + # TODO enable this + #- id: go-staticcheck-repo-mod - repo: https://github.com/beeper/pre-commit-go rev: v0.3.1 From c4d8189d4742b7d2b70316ed7d9f59c1905c914f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 7 Oct 2024 14:09:56 +0300 Subject: [PATCH 0804/1647] bridgev2/matrix: enable DM portal meta and list syncing by default --- bridgev2/matrix/mxmain/example-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index ed21df38..d31396ff 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -6,7 +6,7 @@ bridge: personal_filtering_spaces: true # Whether the bridge should set names and avatars explicitly for DM portals. # This is only necessary when using clients that don't support MSC4171. - private_chat_portal_meta: false + private_chat_portal_meta: true # Should events be handled asynchronously within portal rooms? # If true, events may end up being out of order, but slow events won't block other ones. # This is not yet safe to use. @@ -209,7 +209,7 @@ matrix: # Whether the bridge should send error notices via m.notice events when a message fails to bridge. message_error_notices: true # Whether the bridge should update the m.direct account data event when double puppeting is enabled. - sync_direct_chat_list: false + sync_direct_chat_list: true # Whether created rooms should have federation enabled. If false, created portal rooms # will never be federated. Changing this option requires recreating rooms. federate_rooms: true From 092ba65cad89a2ad0eef41ba7146b96f699443ce Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 7 Oct 2024 16:07:02 +0300 Subject: [PATCH 0805/1647] bridgev2,event: add basic support for polls --- bridgev2/errors.go | 2 + bridgev2/matrix/connector.go | 2 + bridgev2/networkinterface.go | 18 +++++++ bridgev2/portal.go | 92 ++++++++++++++++++++++++++++++------ event/beeper.go | 7 +++ event/content.go | 3 ++ event/poll.go | 67 ++++++++++++++++++++++++++ event/profile.go | 10 ---- event/type.go | 5 +- 9 files changed, 180 insertions(+), 26 deletions(-) create mode 100644 event/poll.go delete mode 100644 event/profile.go diff --git a/bridgev2/errors.go b/bridgev2/errors.go index 0b1ef8b3..55df5357 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -41,6 +41,7 @@ var ( ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false) ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage() ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage() ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage() @@ -48,6 +49,7 @@ var ( ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage() ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage() ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage() + ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage() ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage() ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 4fd25306..35ef4a08 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -136,6 +136,8 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { } br.EventProcessor.On(event.EventMessage, br.handleRoomEvent) br.EventProcessor.On(event.EventSticker, br.handleRoomEvent) + br.EventProcessor.On(event.EventUnstablePollStart, br.handleRoomEvent) + br.EventProcessor.On(event.EventUnstablePollResponse, br.handleRoomEvent) br.EventProcessor.On(event.EventReaction, br.handleRoomEvent) br.EventProcessor.On(event.EventRedaction, br.handleRoomEvent) br.EventProcessor.On(event.EventEncrypted, br.handleEncryptedEvent) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 842de252..14c3701c 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -309,6 +309,7 @@ type NetworkRoomCapabilities struct { Captions bool MaxTextLength int MaxCaptionLength int + Polls bool Threads bool Replies bool @@ -521,6 +522,12 @@ type EditHandlingNetworkAPI interface { HandleMatrixEdit(ctx context.Context, msg *MatrixEdit) error } +type PollHandlingNetworkAPI interface { + NetworkAPI + HandleMatrixPollStart(ctx context.Context, msg *MatrixPollStart) (*MatrixMessageResponse, error) + HandleMatrixPollVote(ctx context.Context, msg *MatrixPollVote) (*MatrixMessageResponse, error) +} + // ReactionHandlingNetworkAPI is an optional interface that network connectors can implement to handle message reactions. type ReactionHandlingNetworkAPI interface { NetworkAPI @@ -1108,6 +1115,17 @@ type MatrixEdit struct { EditTarget *database.Message } +type MatrixPollStart struct { + MatrixMessage + Content *event.PollStartEventContent +} + +type MatrixPollVote struct { + MatrixMessage + VoteTo *database.Message + Content *event.PollResponseEventContent +} + type MatrixReaction struct { MatrixEventBase[*event.ReactionEventContent] TargetMessage *database.Message diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 62256430..8ceb9759 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -528,7 +528,7 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * // Copy logger because many of the handlers will use UpdateContext ctx = log.With().Str("login_id", string(login.ID)).Logger().WithContext(ctx) switch evt.Type { - case event.EventMessage, event.EventSticker: + case event.EventMessage, event.EventSticker, event.EventUnstablePollStart, event.EventUnstablePollResponse: portal.handleMatrixMessage(ctx, login, origSender, evt) case event.EventReaction: if origSender != nil { @@ -771,7 +771,21 @@ func (portal *Portal) checkMessageContentCaps(ctx context.Context, caps *Network func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { log := zerolog.Ctx(ctx) - content, ok := evt.Content.Parsed.(*event.MessageEventContent) + var relatesTo *event.RelatesTo + var msgContent *event.MessageEventContent + var pollContent *event.PollStartEventContent + var pollResponseContent *event.PollResponseEventContent + var ok bool + if evt.Type == event.EventUnstablePollStart { + pollContent, ok = evt.Content.Parsed.(*event.PollStartEventContent) + relatesTo = pollContent.RelatesTo + } else if evt.Type == event.EventUnstablePollResponse { + pollResponseContent, ok = evt.Content.Parsed.(*event.PollResponseEventContent) + relatesTo = &pollResponseContent.RelatesTo + } else { + msgContent, ok = evt.Content.Parsed.(*event.MessageEventContent) + relatesTo = msgContent.RelatesTo + } if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) @@ -779,31 +793,61 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } caps := sender.Client.GetCapabilities(ctx, portal) - if content.RelatesTo.GetReplaceID() != "" { - portal.handleMatrixEdit(ctx, sender, origSender, evt, content, caps) + if relatesTo.GetReplaceID() != "" { + if msgContent == nil { + log.Warn().Msg("Ignoring edit of poll") + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w of polls", ErrEditsNotSupported)) + return + } + portal.handleMatrixEdit(ctx, sender, origSender, evt, msgContent, caps) return } var err error if origSender != nil { - content, err = portal.Bridge.Config.Relay.FormatMessage(content, origSender) + if msgContent == nil { + log.Debug().Msg("Ignoring poll event from relayed user") + portal.sendErrorStatus(ctx, evt, ErrIgnoringPollFromRelayedUser) + return + } + msgContent, err = portal.Bridge.Config.Relay.FormatMessage(msgContent, origSender) if err != nil { log.Err(err).Msg("Failed to format message for relaying") portal.sendErrorStatus(ctx, evt, err) return } } - if !portal.checkMessageContentCaps(ctx, caps, content, evt) { - return + if msgContent != nil { + if !portal.checkMessageContentCaps(ctx, caps, msgContent, evt) { + return + } + } else if pollResponseContent != nil || pollContent != nil { + if _, ok = sender.Client.(PollHandlingNetworkAPI); !ok { + log.Debug().Msg("Ignoring poll event as network connector doesn't implement PollHandlingNetworkAPI") + portal.sendErrorStatus(ctx, evt, ErrPollsNotSupported) + return + } } - var threadRoot, replyTo *database.Message + var threadRoot, replyTo, voteTo *database.Message + if evt.Type == event.EventUnstablePollResponse { + voteTo, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, relatesTo.GetReferenceID()) + if err != nil { + log.Err(err).Msg("Failed to get poll target message from database") + // TODO send status + return + } else if voteTo == nil { + log.Warn().Stringer("vote_to_id", relatesTo.GetReferenceID()).Msg("Poll target message not found") + // TODO send status + return + } + } var replyToID id.EventID if caps.Threads { - replyToID = content.RelatesTo.GetNonFallbackReplyTo() + replyToID = relatesTo.GetNonFallbackReplyTo() } else { - replyToID = content.RelatesTo.GetReplyTo() + replyToID = relatesTo.GetReplyTo() } - threadRootID := content.RelatesTo.GetThreadParent() + threadRootID := relatesTo.GetThreadParent() if caps.Threads && threadRootID != "" { threadRoot, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, threadRootID) if err != nil { @@ -839,23 +883,41 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } } - wrappedEvt := &MatrixMessage{ + wrappedMsgEvt := &MatrixMessage{ MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{ Event: evt, - Content: content, + Content: msgContent, OrigSender: origSender, Portal: portal, }, ThreadRoot: threadRoot, ReplyTo: replyTo, } - resp, err := sender.Client.HandleMatrixMessage(ctx, wrappedEvt) + var resp *MatrixMessageResponse + if msgContent != nil { + resp, err = sender.Client.HandleMatrixMessage(ctx, wrappedMsgEvt) + } else if pollContent != nil { + resp, err = sender.Client.(PollHandlingNetworkAPI).HandleMatrixPollStart(ctx, &MatrixPollStart{ + MatrixMessage: *wrappedMsgEvt, + Content: pollContent, + }) + } else if pollResponseContent != nil { + resp, err = sender.Client.(PollHandlingNetworkAPI).HandleMatrixPollVote(ctx, &MatrixPollVote{ + MatrixMessage: *wrappedMsgEvt, + VoteTo: voteTo, + Content: pollResponseContent, + }) + } else { + log.Error().Msg("Failed to handle Matrix message: all contents are nil?") + portal.sendErrorStatus(ctx, evt, fmt.Errorf("all contents are nil")) + return + } if err != nil { log.Err(err).Msg("Failed to handle Matrix message") portal.sendErrorStatus(ctx, evt, err) return } - message := wrappedEvt.fillDBMessage(resp.DB) + message := wrappedMsgEvt.fillDBMessage(resp.DB) if !resp.Pending { if resp.DB == nil { log.Error().Msg("Network connector didn't return a message to save") diff --git a/event/beeper.go b/event/beeper.go index 1394a6ce..911bdfe3 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -97,3 +97,10 @@ type BeeperProfileExtra struct { IsBridgeBot bool `json:"com.beeper.bridge.is_bridge_bot,omitempty"` IsNetworkBot bool `json:"com.beeper.bridge.is_network_bot,omitempty"` } + +type BeeperPerMessageProfile struct { + ID string `json:"id"` + Displayname string `json:"displayname,omitempty"` + AvatarURL *id.ContentURIString `json:"avatar_url,omitempty"` + AvatarFile *EncryptedFileInfo `json:"avatar_file,omitempty"` +} diff --git a/event/content.go b/event/content.go index d08acad1..882d3368 100644 --- a/event/content.go +++ b/event/content.go @@ -55,6 +55,9 @@ var TypeMap = map[Type]reflect.Type{ EventRedaction: reflect.TypeOf(RedactionEventContent{}), EventReaction: reflect.TypeOf(ReactionEventContent{}), + EventUnstablePollStart: reflect.TypeOf(PollStartEventContent{}), + EventUnstablePollResponse: reflect.TypeOf(PollResponseEventContent{}), + BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}), AccountDataRoomTags: reflect.TypeOf(TagEventContent{}), diff --git a/event/poll.go b/event/poll.go new file mode 100644 index 00000000..37333015 --- /dev/null +++ b/event/poll.go @@ -0,0 +1,67 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package event + +type PollResponseEventContent struct { + RelatesTo RelatesTo `json:"m.relates_to"` + Response struct { + Answers []string `json:"answers"` + } `json:"org.matrix.msc3381.poll.response"` +} + +func (content *PollResponseEventContent) GetRelatesTo() *RelatesTo { + return &content.RelatesTo +} + +func (content *PollResponseEventContent) OptionalGetRelatesTo() *RelatesTo { + if content.RelatesTo.Type == "" { + return nil + } + return &content.RelatesTo +} + +func (content *PollResponseEventContent) SetRelatesTo(rel *RelatesTo) { + content.RelatesTo = *rel +} + +type MSC1767Message struct { + Text string `json:"org.matrix.msc1767.text,omitempty"` + HTML string `json:"org.matrix.msc1767.html,omitempty"` + Message []struct { + MimeType string `json:"mimetype"` + Body string `json:"body"` + } `json:"org.matrix.msc1767.message,omitempty"` +} + +type PollStartEventContent struct { + RelatesTo *RelatesTo `json:"m.relates_to"` + Mentions *Mentions `json:"m.mentions,omitempty"` + PollStart struct { + Kind string `json:"kind"` + MaxSelections int `json:"max_selections"` + Question MSC1767Message `json:"question"` + Answers []struct { + ID string `json:"id"` + MSC1767Message + } `json:"answers"` + } `json:"org.matrix.msc3381.poll.start"` +} + +func (content *PollStartEventContent) GetRelatesTo() *RelatesTo { + if content.RelatesTo == nil { + content.RelatesTo = &RelatesTo{} + } + return content.RelatesTo +} + +func (content *PollStartEventContent) OptionalGetRelatesTo() *RelatesTo { + return content.RelatesTo +} + +func (content *PollStartEventContent) SetRelatesTo(rel *RelatesTo) { + content.RelatesTo = rel +} diff --git a/event/profile.go b/event/profile.go deleted file mode 100644 index 6dc4314a..00000000 --- a/event/profile.go +++ /dev/null @@ -1,10 +0,0 @@ -package event - -import "maunium.net/go/mautrix/id" - -type BeeperPerMessageProfile struct { - ID string `json:"id"` - Displayname string `json:"displayname,omitempty"` - AvatarURL *id.ContentURIString `json:"avatar_url,omitempty"` - AvatarFile *EncryptedFileInfo `json:"avatar_file,omitempty"` -} diff --git a/event/type.go b/event/type.go index 3f447343..4396c9cc 100644 --- a/event/type.go +++ b/event/type.go @@ -126,7 +126,7 @@ func (et *Type) GuessClass() TypeClass { InRoomVerificationStart.Type, InRoomVerificationReady.Type, InRoomVerificationAccept.Type, InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type, CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type, - CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type: + CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type: return MessageEventType case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type, ToDeviceBeeperRoomKeyAck.Type: @@ -232,6 +232,9 @@ var ( CallHangup = Type{"m.call.hangup", MessageEventType} BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType} + + EventUnstablePollStart = Type{Type: "org.matrix.msc3381.poll.start", Class: MessageEventType} + EventUnstablePollResponse = Type{Type: "org.matrix.msc3381.poll.response", Class: MessageEventType} ) // Ephemeral events From 7a9269e8ff9fe765e7e60b97688ed6516f302554 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 7 Oct 2024 16:12:43 +0300 Subject: [PATCH 0806/1647] bridgev2/networkinterface: add post-save callback for matrix messages --- bridgev2/networkinterface.go | 3 +++ bridgev2/portal.go | 2 ++ 2 files changed, 5 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 14c3701c..3b406a9d 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -284,6 +284,9 @@ type MatrixMessageResponse struct { // If RemovePending is set, the bridge will remove the provided transaction ID from pending messages // after saving the provided message to the database. This should be used with AddPendingToIgnore. RemovePending networkid.TransactionID + // An optional function that is called after the message is saved to the database. + // Will not be called if the message is not saved for some reason. + PostSave func(context.Context, *database.Message) } type FileRestriction struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 8ceb9759..f8dc8596 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -928,6 +928,8 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin err = portal.Bridge.DB.Message.Insert(ctx, message) if err != nil { log.Err(err).Msg("Failed to save message to database") + } else if resp.PostSave != nil { + resp.PostSave(ctx, message) } if resp.RemovePending != "" { portal.outgoingMessagesLock.Lock() From 9dbd4333636bbb92bd89fae4c54751276d06917d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 7 Oct 2024 19:49:00 +0300 Subject: [PATCH 0807/1647] bridgev2/backfill: do complete callback if messages are cut off --- bridgev2/portalbackfill.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 4350cfa2..55225efc 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -72,6 +72,9 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, resp.Messages = portal.cutoffMessages(ctx, resp.Messages, resp.AggressiveDeduplication, true, lastMessage) if len(resp.Messages) == 0 { log.Warn().Msg("No messages left to backfill after cutting off old messages") + if resp.CompleteCallback != nil { + resp.CompleteCallback() + } return } portal.sendBackfill(ctx, source, resp.Messages, true, resp.MarkRead, false, resp.CompleteCallback) @@ -128,10 +131,16 @@ func (portal *Portal) DoBackwardsBackfill(ctx context.Context, source *UserLogin } else { log.Warn().Msg("No messages to backfill, but HasMore is true") } + if resp.CompleteCallback != nil { + resp.CompleteCallback() + } return nil } resp.Messages = portal.cutoffMessages(ctx, resp.Messages, resp.AggressiveDeduplication, false, firstMessage) if len(resp.Messages) == 0 { + if resp.CompleteCallback != nil { + resp.CompleteCallback() + } return fmt.Errorf("no messages left to backfill after cutting off too new messages") } portal.sendBackfill(ctx, source, resp.Messages, false, resp.MarkRead, false, resp.CompleteCallback) @@ -163,6 +172,9 @@ func (portal *Portal) fetchThreadBackfill(ctx context.Context, source *UserLogin resp.Messages = portal.cutoffMessages(ctx, resp.Messages, resp.AggressiveDeduplication, true, anchor) if len(resp.Messages) == 0 { log.Warn().Msg("No messages left to backfill after cutting off old messages") + if resp.CompleteCallback != nil { + resp.CompleteCallback() + } return nil } return resp From 1d8891fdb40e215b660f7f5f5f523d93e27442b2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 7 Oct 2024 22:13:03 +0300 Subject: [PATCH 0808/1647] hicli: fix pagination bugs --- hicli/database/event.go | 2 +- hicli/database/timeline.go | 17 ++++++++++---- hicli/json-commands.go | 4 ++-- hicli/paginate.go | 45 +++++++++++++++++++++++++++++++++----- 4 files changed, 55 insertions(+), 13 deletions(-) diff --git a/hicli/database/event.go b/hicli/database/event.go index 57a3291d..2953bb9b 100644 --- a/hicli/database/event.go +++ b/hicli/database/event.go @@ -154,7 +154,7 @@ func (eq *EventQuery) FillReactionCounts(ctx context.Context, roomID id.RoomID, } func (eq *EventQuery) FillLastEditRowIDs(ctx context.Context, roomID id.RoomID, events []*Event) error { - eventIDs := make([]id.EventID, 0) + eventIDs := make([]id.EventID, len(events)) eventMap := make(map[id.EventID]*Event) for i, evt := range events { if evt.LastEditRowID == nil { diff --git a/hicli/database/timeline.go b/hicli/database/timeline.go index 891f6acb..0a01c7f5 100644 --- a/hicli/database/timeline.go +++ b/hicli/database/timeline.go @@ -27,13 +27,16 @@ const ( prependTimelineQuery = ` INSERT INTO timeline (room_id, rowid, event_rowid) VALUES ($1, $2, $3) ` + checkTimelineContainsQuery = ` + SELECT EXISTS(SELECT 1 FROM timeline WHERE room_id = $1 AND event_rowid = $2) + ` findMinRowIDQuery = `SELECT MIN(rowid) FROM timeline` getTimelineQuery = ` SELECT event.rowid, timeline.rowid, event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, - redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, reactions, last_edit_rowid + transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid FROM timeline JOIN event ON event.rowid = timeline.event_rowid - WHERE timeline.room_id = $1 AND timeline.rowid < $2 + WHERE timeline.room_id = $1 AND ($2 = 0 OR timeline.rowid < $2) ORDER BY timeline.rowid DESC LIMIT $3 ` @@ -80,12 +83,13 @@ func (tq *TimelineQuery) reserveRowIDs(ctx context.Context, count int) (startFro return } if tq.minRowID >= 0 { - // No negative row IDs exist, start at -1 - tq.minRowID = -1 + // No negative row IDs exist, start at -2 + tq.minRowID = -2 } else { // We fetched the lowest row ID, but we want the next available one, so decrement one tq.minRowID-- } + tq.minRowIDFound = true } startFrom = tq.minRowID tq.minRowID -= TimelineRowID(count) @@ -121,3 +125,8 @@ func (tq *TimelineQuery) Append(ctx context.Context, roomID id.RoomID, rowIDs [] func (tq *TimelineQuery) Get(ctx context.Context, roomID id.RoomID, limit int, before TimelineRowID) ([]*Event, error) { return tq.QueryMany(ctx, getTimelineQuery, roomID, before, limit) } + +func (tq *TimelineQuery) Has(ctx context.Context, roomID id.RoomID, eventRowID EventRowID) (exists bool, err error) { + err = tq.GetDB().QueryRow(ctx, checkTimelineContainsQuery, roomID, eventRowID).Scan(&exists) + return +} diff --git a/hicli/json-commands.go b/hicli/json-commands.go index 90cb2890..29a2ac73 100644 --- a/hicli/json-commands.go +++ b/hicli/json-commands.go @@ -50,11 +50,11 @@ func (h *HiClient) handleJSONCommand(ctx context.Context, req *JSONCommand) (any return h.GetEventsByRowIDs(ctx, params.RowIDs) }) case "paginate": - return unmarshalAndCall(req.Data, func(params *paginateParams) ([]*database.Event, error) { + return unmarshalAndCall(req.Data, func(params *paginateParams) (*PaginationResponse, error) { return h.Paginate(ctx, params.RoomID, params.MaxTimelineID, params.Limit) }) case "paginate_server": - return unmarshalAndCall(req.Data, func(params *paginateParams) ([]*database.Event, error) { + return unmarshalAndCall(req.Data, func(params *paginateParams) (*PaginationResponse, error) { return h.PaginateServer(ctx, params.RoomID, params.Limit) }) case "ensure_group_session_shared": diff --git a/hicli/paginate.go b/hicli/paginate.go index 9992b36e..878a033a 100644 --- a/hicli/paginate.go +++ b/hicli/paginate.go @@ -11,6 +11,8 @@ import ( "errors" "fmt" + "github.com/rs/zerolog" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/hicli/database" "maunium.net/go/mautrix/id" @@ -60,18 +62,23 @@ func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.Ev } } -func (h *HiClient) Paginate(ctx context.Context, roomID id.RoomID, maxTimelineID database.TimelineRowID, limit int) ([]*database.Event, error) { +type PaginationResponse struct { + Events []*database.Event `json:"events"` + HasMore bool `json:"has_more"` +} + +func (h *HiClient) Paginate(ctx context.Context, roomID id.RoomID, maxTimelineID database.TimelineRowID, limit int) (*PaginationResponse, error) { evts, err := h.DB.Timeline.Get(ctx, roomID, limit, maxTimelineID) if err != nil { return nil, err } else if len(evts) > 0 { - return evts, nil + return &PaginationResponse{Events: evts, HasMore: true}, nil } else { return h.PaginateServer(ctx, roomID, limit) } } -func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit int) ([]*database.Event, error) { +func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit int) (*PaginationResponse, error) { ctx, cancel := context.WithCancelCause(ctx) h.paginationInterrupterLock.Lock() if _, alreadyPaginating := h.paginationInterrupter[roomID]; alreadyPaginating { @@ -89,12 +96,21 @@ func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit i room, err := h.DB.Room.Get(ctx, roomID) if err != nil { return nil, fmt.Errorf("failed to get room from database: %w", err) + } else if room.PrevBatch == "" { + return &PaginationResponse{Events: []*database.Event{}, HasMore: false}, nil } resp, err := h.Client.Messages(ctx, roomID, room.PrevBatch, "", mautrix.DirectionBackward, nil, limit) if err != nil { return nil, fmt.Errorf("failed to get messages from server: %w", err) } events := make([]*database.Event, len(resp.Chunk)) + if resp.End == "" || len(resp.Chunk) == 0 { + err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End) + if err != nil { + return nil, fmt.Errorf("failed to set prev_batch: %w", err) + } + return &PaginationResponse{Events: events, HasMore: resp.End != ""}, nil + } wakeupSessionRequests := false err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { if err = ctx.Err(); err != nil { @@ -102,13 +118,30 @@ func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit i } eventRowIDs := make([]database.EventRowID, len(resp.Chunk)) decryptionQueue := make(map[id.SessionID]*database.SessionRequest) + iOffset := 0 for i, evt := range resp.Chunk { - events[i], err = h.processEvent(ctx, evt, decryptionQueue, true) + dbEvt, err := h.processEvent(ctx, evt, decryptionQueue, true) if err != nil { return err + } else if exists, err := h.DB.Timeline.Has(ctx, roomID, dbEvt.RowID); err != nil { + return fmt.Errorf("failed to check if event exists in timeline: %w", err) + } else if exists { + zerolog.Ctx(ctx).Warn(). + Int64("row_id", int64(dbEvt.RowID)). + Str("event_id", dbEvt.ID.String()). + Msg("Event already exists in timeline, skipping") + iOffset++ + continue } - eventRowIDs[i] = events[i].RowID + events[i-iOffset] = dbEvt + eventRowIDs[i-iOffset] = events[i-iOffset].RowID } + if iOffset >= len(events) { + events = events[:0] + return nil + } + events = events[:len(events)-iOffset] + eventRowIDs = eventRowIDs[:len(eventRowIDs)-iOffset] wakeupSessionRequests = len(decryptionQueue) > 0 for _, entry := range decryptionQueue { err = h.DB.SessionRequest.Put(ctx, entry) @@ -137,5 +170,5 @@ func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit i if err == nil && wakeupSessionRequests { h.WakeupRequestQueue() } - return events, err + return &PaginationResponse{Events: events, HasMore: true}, err } From e192932af9f06b0baa23aa55bd4852076fcada41 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 8 Oct 2024 19:43:33 +0300 Subject: [PATCH 0809/1647] bridgev2/provisioning: add description of entire login flow --- bridgev2/matrix/provisioning.yaml | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml index d03101b0..bd9217c8 100644 --- a/bridgev2/matrix/provisioning.yaml +++ b/bridgev2/matrix/provisioning.yaml @@ -91,6 +91,24 @@ paths: post: tags: [ auth ] summary: Start a new login process. + description: | + This endpoint starts a new login process, which is used to log into the bridge. + + The basic flow of the entire login, including calling this endpoint, is: + 1. Call `GET /v3/login/flows` to get the list of available flows. + If there's more than one flow, ask the user to pick which one they want to use. + 2. Call this endpoint with the chosen flow ID to start the login. + The first login step will be returned. + 3. Render the information provided in the step. + 4. Call the `/login/step/...` endpoint corresponding to the step type: + * For `user_input` and `cookies`, acquire the requested fields before calling the endpoint. + * For `display_and_wait`, call the endpoint immediately + (as there's nothing to acquire on the client side). + 5. Handle the data returned by the login step endpoint: + * If an error is returned, the login has failed and must be restarted + (from either step 1 or step 2) if the user wants to try again. + * If step type `complete` is returned, the login finished successfully. + * Otherwise, go to step 3 with the new data. operationId: startLogin parameters: - name: login_id From 1f7f489fa97d98736eb83dc30ff0eb9310512489 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 8 Oct 2024 21:43:07 +0300 Subject: [PATCH 0810/1647] hicli: allow storing errors in media cache --- error.go | 5 ++ hicli/database/cachedmedia.go | 49 ++++++++++++++++--- .../database/upgrades/00-latest-revision.sql | 1 + 3 files changed, 47 insertions(+), 8 deletions(-) diff --git a/error.go b/error.go index 2f9ab983..a4ba9859 100644 --- a/error.go +++ b/error.go @@ -162,6 +162,11 @@ func (e RespError) WithMessage(msg string, args ...any) RespError { return e } +func (e RespError) WithStatus(status int) RespError { + e.StatusCode = status + return e +} + // Error returns the errcode and error message. func (e RespError) Error() string { return e.ErrCode + ": " + e.Err diff --git a/hicli/database/cachedmedia.go b/hicli/database/cachedmedia.go index 9ea9c27a..c546d61e 100644 --- a/hicli/database/cachedmedia.go +++ b/hicli/database/cachedmedia.go @@ -9,32 +9,38 @@ package database import ( "context" "database/sql" + "net/http" + "time" "go.mau.fi/util/dbutil" + "go.mau.fi/util/jsontime" "golang.org/x/exp/slices" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/attachment" "maunium.net/go/mautrix/id" ) const ( insertCachedMediaQuery = ` - INSERT INTO cached_media (mxc, event_rowid, enc_file, file_name, mime_type, size, hash) - VALUES ($1, $2, $3, $4, $5, $6, $7) + INSERT INTO cached_media (mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT (mxc) DO NOTHING ` upsertCachedMediaQuery = ` - INSERT INTO cached_media (mxc, event_rowid, enc_file, file_name, mime_type, size, hash) - VALUES ($1, $2, $3, $4, $5, $6, $7) + INSERT INTO cached_media (mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT (mxc) DO UPDATE SET enc_file = excluded.enc_file, file_name = excluded.file_name, mime_type = excluded.mime_type, size = excluded.size, - hash = excluded.hash + hash = excluded.hash, + error = excluded.error + WHERE excluded.error IS NULL OR cached_media.hash IS NULL ` getCachedMediaQuery = ` - SELECT mxc, event_rowid, enc_file, file_name, mime_type, size, hash + SELECT mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error FROM cached_media WHERE mxc = $1 ` @@ -56,6 +62,27 @@ func (cmq *CachedMediaQuery) Get(ctx context.Context, mxc id.ContentURI) (*Cache return cmq.QueryOne(ctx, getCachedMediaQuery, &mxc) } +type MediaError struct { + Matrix *mautrix.RespError `json:"data"` + StatusCode int `json:"status_code"` + ReceivedAt jsontime.UnixMilli `json:"received_at"` + Attempts int `json:"attempts"` +} + +const MaxMediaBackoff = 7 * 24 * time.Hour + +func (me *MediaError) UseCache() bool { + return me != nil && time.Since(me.ReceivedAt.Time) < min(time.Duration(2< Date: Tue, 8 Oct 2024 23:49:15 +0300 Subject: [PATCH 0811/1647] hicli: include next retry ts in cached errors --- hicli/database/cachedmedia.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/hicli/database/cachedmedia.go b/hicli/database/cachedmedia.go index c546d61e..2ccaca3b 100644 --- a/hicli/database/cachedmedia.go +++ b/hicli/database/cachedmedia.go @@ -71,8 +71,12 @@ type MediaError struct { const MaxMediaBackoff = 7 * 24 * time.Hour +func (me *MediaError) backoff() time.Duration { + return min(time.Duration(2< Date: Wed, 9 Oct 2024 01:22:13 +0300 Subject: [PATCH 0812/1647] hicli: include state in sync and add method to get state --- client.go | 12 +++++++ hicli/database/room.go | 5 ++- hicli/database/state.go | 48 +++++++++++++++++++++++++++ hicli/database/statestore.go | 10 +++++- hicli/events.go | 9 +++--- hicli/json-commands.go | 10 ++++++ hicli/paginate.go | 63 ++++++++++++++++++++++++++++++++++++ hicli/send.go | 16 ++++++--- hicli/sync.go | 16 +++++++-- 9 files changed, 176 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index 836afc55..125bba0d 100644 --- a/client.go +++ b/client.go @@ -1469,6 +1469,18 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt return } +// StateAsArray gets all the state in a room as an array. It does not update the state store. +// Use State to get the events as a map and also update the state store. +func (cli *Client) StateAsArray(ctx context.Context, roomID id.RoomID) (state []*event.Event, err error) { + _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildClientURL("v3", "rooms", roomID, "state"), nil, &state) + if err == nil { + for _, evt := range state { + evt.Type.Class = event.StateEventType + } + } + return +} + // GetMediaConfig fetches the configuration of the content repository, such as upload limitations. func (cli *Client) GetMediaConfig(ctx context.Context) (resp *RespMediaConfig, err error) { _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildClientURL("v1", "media", "config"), nil, &resp) diff --git a/hicli/database/room.go b/hicli/database/room.go index 5778e5f5..e7138d94 100644 --- a/hicli/database/room.go +++ b/hicli/database/room.go @@ -153,7 +153,10 @@ func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) { other.EncryptionEvent = r.EncryptionEvent hasChanges = true } - other.HasMemberList = other.HasMemberList || r.HasMemberList + if r.HasMemberList && !other.HasMemberList { + hasChanges = true + other.HasMemberList = true + } if r.PreviewEventRowID > other.PreviewEventRowID { other.PreviewEventRowID = r.PreviewEventRowID hasChanges = true diff --git a/hicli/database/state.go b/hicli/database/state.go index e74f2950..5dc13729 100644 --- a/hicli/database/state.go +++ b/hicli/database/state.go @@ -4,10 +4,14 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. +//go:build go1.23 + package database import ( "context" + "fmt" + "slices" "go.mau.fi/util/dbutil" @@ -20,6 +24,13 @@ const ( INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (room_id, event_type, state_key) DO UPDATE SET event_rowid = excluded.event_rowid, membership = excluded.membership ` + addCurrentStateQuery = ` + INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5) + ON CONFLICT DO NOTHING + ` + deleteCurrentStateQuery = ` + DELETE FROM current_state WHERE room_id = $1 + ` getCurrentRoomStateQuery = ` SELECT event.rowid, -1, event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type, unsigned, transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid @@ -30,6 +41,21 @@ const ( getCurrentStateEventQuery = getCurrentRoomStateQuery + `AND cs.event_type = $2 AND cs.state_key = $3` ) +var massInsertCurrentStateBuilder = dbutil.NewMassInsertBuilder[*CurrentStateEntry, [1]any](addCurrentStateQuery, "($1, $%d, $%d, $%d, $%d)") + +const currentStateMassInsertBatchSize = 1000 + +type CurrentStateEntry struct { + EventType event.Type + StateKey string + EventRowID EventRowID + Membership event.Membership +} + +func (cse *CurrentStateEntry) GetMassInsertValues() [4]any { + return [4]any{cse.EventType.Type, cse.StateKey, cse.EventRowID, dbutil.StrPtr(cse.Membership)} +} + type CurrentStateQuery struct { *dbutil.QueryHelper[*Event] } @@ -38,6 +64,28 @@ func (csq *CurrentStateQuery) Set(ctx context.Context, roomID id.RoomID, eventTy return csq.Exec(ctx, setCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership)) } +func (csq *CurrentStateQuery) AddMany(ctx context.Context, roomID id.RoomID, deleteOld bool, entries []*CurrentStateEntry) error { + var err error + if deleteOld { + err = csq.Exec(ctx, deleteCurrentStateQuery, roomID) + if err != nil { + return fmt.Errorf("failed to delete old state: %w", err) + } + } + for entryChunk := range slices.Chunk(entries, currentStateMassInsertBatchSize) { + query, params := massInsertCurrentStateBuilder.Build([1]any{roomID}, entryChunk) + err = csq.Exec(ctx, query, params...) + if err != nil { + return err + } + } + return nil +} + +func (csq *CurrentStateQuery) Add(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID EventRowID, membership event.Membership) error { + return csq.Exec(ctx, addCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership)) +} + func (csq *CurrentStateQuery) Get(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*Event, error) { return csq.QueryOne(ctx, getCurrentStateEventQuery, roomID, eventType.Type, stateKey) } diff --git a/hicli/database/statestore.go b/hicli/database/statestore.go index 1779afa5..fcd6aceb 100644 --- a/hicli/database/statestore.go +++ b/hicli/database/statestore.go @@ -39,6 +39,9 @@ const ( SELECT state_key FROM current_state WHERE room_id = $1 AND event_type = 'm.room.member' AND membership IN ('join', 'invite') ` + getHasFetchedMembersQuery = ` + SELECT has_member_list FROM room WHERE room_id = $1 + ` isRoomEncryptedQuery = ` SELECT room.encryption_event IS NOT NULL FROM room WHERE room_id = $1 ` @@ -116,7 +119,12 @@ func (c *ClientStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, ro return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() } -func (c *ClientStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) { +func (c *ClientStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (hasFetched bool, err error) { + //err = c.QueryRow(ctx, getHasFetchedMembersQuery, roomID).Scan(&hasFetched) + //if errors.Is(err, sql.ErrNoRows) { + // err = nil + //} + //return return false, fmt.Errorf("not implemented") } diff --git a/hicli/events.go b/hicli/events.go index c7228541..ea03be7e 100644 --- a/hicli/events.go +++ b/hicli/events.go @@ -13,10 +13,11 @@ import ( ) type SyncRoom struct { - Meta *database.Room `json:"meta"` - Timeline []database.TimelineRowTuple `json:"timeline"` - Events []*database.Event `json:"events"` - Reset bool `json:"reset"` + Meta *database.Room `json:"meta"` + Timeline []database.TimelineRowTuple `json:"timeline"` + State map[event.Type]map[string]database.EventRowID `json:"state"` + Events []*database.Event `json:"events"` + Reset bool `json:"reset"` } type SyncComplete struct { diff --git a/hicli/json-commands.go b/hicli/json-commands.go index 29a2ac73..12026f6b 100644 --- a/hicli/json-commands.go +++ b/hicli/json-commands.go @@ -49,6 +49,10 @@ func (h *HiClient) handleJSONCommand(ctx context.Context, req *JSONCommand) (any return unmarshalAndCall(req.Data, func(params *getEventsByRowIDsParams) ([]*database.Event, error) { return h.GetEventsByRowIDs(ctx, params.RowIDs) }) + case "get_room_state": + return unmarshalAndCall(req.Data, func(params *getRoomStateParams) ([]*database.Event, error) { + return h.GetRoomState(ctx, params.RoomID, params.FetchMembers, params.Refetch) + }) case "paginate": return unmarshalAndCall(req.Data, func(params *paginateParams) (*PaginationResponse, error) { return h.Paginate(ctx, params.RoomID, params.MaxTimelineID, params.Limit) @@ -111,6 +115,12 @@ type getEventsByRowIDsParams struct { RowIDs []database.EventRowID `json:"row_ids"` } +type getRoomStateParams struct { + RoomID id.RoomID `json:"room_id"` + Refetch bool `json:"refetch"` + FetchMembers bool `json:"fetch_members"` +} + type ensureGroupSessionSharedParams struct { RoomID id.RoomID `json:"room_id"` } diff --git a/hicli/paginate.go b/hicli/paginate.go index 878a033a..4109e7af 100644 --- a/hicli/paginate.go +++ b/hicli/paginate.go @@ -14,6 +14,7 @@ import ( "github.com/rs/zerolog" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/hicli/database" "maunium.net/go/mautrix/id" ) @@ -62,6 +63,68 @@ func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.Ev } } +func (h *HiClient) GetRoomState(ctx context.Context, roomID id.RoomID, fetchMembers, refetch bool) ([]*database.Event, error) { + var evts []*event.Event + if refetch { + resp, err := h.Client.StateAsArray(ctx, roomID) + if err != nil { + return nil, fmt.Errorf("failed to refetch state: %w", err) + } + evts = resp + } else if fetchMembers { + resp, err := h.Client.Members(ctx, roomID) + if err != nil { + return nil, fmt.Errorf("failed to fetch members: %w", err) + } + evts = resp.Chunk + } + if evts != nil { + err := h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { + room, err := h.DB.Room.Get(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get room from database: %w", err) + } + updatedRoom := &database.Room{ + ID: room.ID, + HasMemberList: true, + } + entries := make([]*database.CurrentStateEntry, len(evts)) + for i, evt := range evts { + dbEvt, err := h.processEvent(ctx, evt, nil, false) + if err != nil { + return fmt.Errorf("failed to process event %s: %w", evt.ID, err) + } + entries[i] = &database.CurrentStateEntry{ + EventType: evt.Type, + StateKey: *evt.StateKey, + EventRowID: dbEvt.RowID, + } + if evt.Type == event.StateMember { + entries[i].Membership = event.Membership(evt.Content.Raw["membership"].(string)) + } else { + processImportantEvent(ctx, evt, room, updatedRoom) + } + } + err = h.DB.CurrentState.AddMany(ctx, room.ID, refetch, entries) + if err != nil { + return err + } + roomChanged := updatedRoom.CheckChangesAndCopyInto(room) + if roomChanged { + err = h.DB.Room.Upsert(ctx, updatedRoom) + if err != nil { + return fmt.Errorf("failed to save room data: %w", err) + } + } + return nil + }) + if err != nil { + return nil, err + } + } + return h.DB.CurrentState.GetAll(ctx, roomID) +} + type PaginationResponse struct { Events []*database.Event `json:"events"` HasMore bool `json:"has_more"` diff --git a/hicli/send.go b/hicli/send.go index 8824f3c3..1c76a5a2 100644 --- a/hicli/send.go +++ b/hicli/send.go @@ -148,17 +148,23 @@ func (h *HiClient) loadMembers(ctx context.Context, room *database.Room) error { return fmt.Errorf("failed to get room member list: %w", err) } err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { - for _, evt := range resp.Chunk { + entries := make([]*database.CurrentStateEntry, len(resp.Chunk)) + for i, evt := range resp.Chunk { dbEvt, err := h.processEvent(ctx, evt, nil, true) if err != nil { return err } - membership := event.Membership(evt.Content.Raw["membership"].(string)) - err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership) - if err != nil { - return err + entries[i] = &database.CurrentStateEntry{ + EventType: evt.Type, + StateKey: *evt.StateKey, + EventRowID: dbEvt.RowID, + Membership: event.Membership(evt.Content.Raw["membership"].(string)), } } + err := h.DB.CurrentState.AddMany(ctx, room.ID, false, entries) + if err != nil { + return err + } return h.DB.Room.Upsert(ctx, &database.Room{ ID: room.ID, HasMemberList: true, diff --git a/hicli/sync.go b/hicli/sync.go index 3b40af9f..aaf1f5c6 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -410,15 +410,23 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R allNewEvents = append(allNewEvents, dbEvt) return dbEvt.RowID, nil } - var err error + changedState := make(map[event.Type]map[string]database.EventRowID) + setNewState := func(evtType event.Type, stateKey string, rowID database.EventRowID) { + if _, ok := changedState[evtType]; !ok { + changedState[evtType] = make(map[string]database.EventRowID) + } + changedState[evtType][stateKey] = rowID + } for _, evt := range state.Events { evt.Type.Class = event.StateEventType - _, err = processNewEvent(evt, false) + rowID, err := processNewEvent(evt, false) if err != nil { return err } + setNewState(evt.Type, *evt.StateKey, rowID) } var timelineRowTuples []database.TimelineRowTuple + var err error if len(timeline.Events) > 0 { timelineIDs := make([]database.EventRowID, len(timeline.Events)) for i, evt := range timeline.Events { @@ -431,6 +439,9 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R if err != nil { return err } + if evt.StateKey != nil { + setNewState(evt.Type, *evt.StateKey, timelineIDs[i]) + } } for _, entry := range decryptionQueue { err = h.DB.SessionRequest.Put(ctx, entry) @@ -481,6 +492,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{ Meta: room, Timeline: timelineRowTuples, + State: changedState, Reset: timeline.Limited, Events: allNewEvents, } From 0e1ff4e10a4ba2ed107912373d86a89c724fe7a9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 9 Oct 2024 01:28:14 +0300 Subject: [PATCH 0813/1647] hicli: don't dispatch event on empty sync --- hicli/events.go | 4 ++++ hicli/sync.go | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/hicli/events.go b/hicli/events.go index ea03be7e..b96fd266 100644 --- a/hicli/events.go +++ b/hicli/events.go @@ -24,6 +24,10 @@ type SyncComplete struct { Rooms map[id.RoomID]*SyncRoom `json:"rooms"` } +func (c *SyncComplete) IsEmpty() bool { + return len(c.Rooms) == 0 +} + type EventsDecrypted struct { RoomID id.RoomID `json:"room_id"` PreviewEventRowID database.EventRowID `json:"preview_event_rowid,omitempty"` diff --git a/hicli/sync.go b/hicli/sync.go index aaf1f5c6..36f0286f 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -68,7 +68,9 @@ func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.Re h.WakeupRequestQueue() } h.firstSyncReceived = true - h.EventHandler(syncCtx.evt) + if !syncCtx.evt.IsEmpty() { + h.EventHandler(syncCtx.evt) + } } func (h *HiClient) asyncPostProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) { From 463fc41125fcaba3f30e761dc15b3279226c3a36 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 9 Oct 2024 01:58:54 +0300 Subject: [PATCH 0814/1647] hicli/database: add json tags for account data and receipts --- hicli/database/accountdata.go | 8 ++++---- hicli/database/receipt.go | 15 ++++++++------- hicli/sync.go | 3 ++- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/hicli/database/accountdata.go b/hicli/database/accountdata.go index 963886c3..8723b595 100644 --- a/hicli/database/accountdata.go +++ b/hicli/database/accountdata.go @@ -50,10 +50,10 @@ func (adq *AccountDataQuery) PutRoom(ctx context.Context, userID id.UserID, room } type AccountData struct { - UserID id.UserID - RoomID id.RoomID - Type string - Content json.RawMessage + UserID id.UserID `json:"user_id"` + RoomID id.RoomID `json:"room_id,omitempty"` + Type string `json:"type"` + Content json.RawMessage `json:"content"` } func (a *AccountData) Scan(row dbutil.Scannable) (*AccountData, error) { diff --git a/hicli/database/receipt.go b/hicli/database/receipt.go index a3370fba..8830efc7 100644 --- a/hicli/database/receipt.go +++ b/hicli/database/receipt.go @@ -12,6 +12,7 @@ import ( "go.mau.fi/util/dbutil" "go.mau.fi/util/exslices" + "go.mau.fi/util/jsontime" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -54,12 +55,12 @@ func (rq *ReceiptQuery) PutMany(ctx context.Context, roomID id.RoomID, receipts } type Receipt struct { - RoomID id.RoomID - UserID id.UserID - ReceiptType event.ReceiptType - ThreadID event.ThreadID - EventID id.EventID - Timestamp time.Time + RoomID id.RoomID `json:"room_id"` + UserID id.UserID `json:"user_id"` + ReceiptType event.ReceiptType `json:"receipt_type"` + ThreadID event.ThreadID `json:"thread_id"` + EventID id.EventID `json:"event_id"` + Timestamp jsontime.UnixMilli `json:"timestamp"` } func (r *Receipt) Scan(row dbutil.Scannable) (*Receipt, error) { @@ -68,7 +69,7 @@ func (r *Receipt) Scan(row dbutil.Scannable) (*Receipt, error) { if err != nil { return nil, err } - r.Timestamp = time.UnixMilli(ts) + r.Timestamp = jsontime.UM(time.UnixMilli(ts)) return r, nil } diff --git a/hicli/sync.go b/hicli/sync.go index 36f0286f..0b6ac9d8 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -16,6 +16,7 @@ import ( "github.com/tidwall/gjson" "github.com/tidwall/sjson" "go.mau.fi/util/exzerolog" + "go.mau.fi/util/jsontime" "golang.org/x/exp/slices" "maunium.net/go/mautrix" @@ -133,7 +134,7 @@ func receiptsToList(content *event.ReceiptEventContent) []*database.Receipt { ReceiptType: receiptType, ThreadID: receiptInfo.ThreadID, EventID: eventID, - Timestamp: receiptInfo.Timestamp, + Timestamp: jsontime.UM(receiptInfo.Timestamp), }) } } From c068fd7bd7c2aad1cec60cc6cc54c886dbc829c6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 9 Oct 2024 02:02:53 +0300 Subject: [PATCH 0815/1647] hicli/database: use exslices for chunking --- hicli/database/state.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/hicli/database/state.go b/hicli/database/state.go index 5dc13729..c12f9f60 100644 --- a/hicli/database/state.go +++ b/hicli/database/state.go @@ -4,16 +4,14 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -//go:build go1.23 - package database import ( "context" "fmt" - "slices" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exslices" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -72,7 +70,7 @@ func (csq *CurrentStateQuery) AddMany(ctx context.Context, roomID id.RoomID, del return fmt.Errorf("failed to delete old state: %w", err) } } - for entryChunk := range slices.Chunk(entries, currentStateMassInsertBatchSize) { + for _, entryChunk := range exslices.Chunk(entries, currentStateMassInsertBatchSize) { query, params := massInsertCurrentStateBuilder.Build([1]any{roomID}, entryChunk) err = csq.Exec(ctx, query, params...) if err != nil { From 29b0d9b95c587b3f72834c2f16183b352ae4e3eb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 9 Oct 2024 20:57:08 +0300 Subject: [PATCH 0816/1647] crypto: move more cross-signing logs to trace level --- crypto/cross_sign_store.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go index 4c2c80e4..b583bada 100644 --- a/crypto/cross_sign_store.go +++ b/crypto/cross_sign_store.go @@ -51,7 +51,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK for _, key := range userKeys.Keys { log := log.With().Str("key", key.String()).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger() for _, usage := range userKeys.Usage { - log.Debug().Str("usage", string(usage)).Msg("Storing cross-signing key") + log.Trace().Str("usage", string(usage)).Msg("Storing cross-signing key") if err = mach.CryptoStore.PutCrossSigningKey(ctx, userID, usage, key); err != nil { log.Error().Err(err).Msg("Error storing cross-signing key") } @@ -86,7 +86,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK log.Warn().Err(err).Msg("Error verifying cross-signing key signature") } else { if verified { - log.Debug().Err(err).Msg("Cross-signing key signature verified") + log.Trace().Err(err).Msg("Cross-signing key signature verified") err = mach.CryptoStore.PutSignature(ctx, userID, key, signUserID, signingKey, signature) if err != nil { log.Error().Err(err).Msg("Error storing cross-signing key signature") From 33834b1b2cf48526e1603365bb60feaf9dbe2cef Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 9 Oct 2024 21:23:12 +0300 Subject: [PATCH 0817/1647] hicli/send: ignore context cancellation inside goroutine --- hicli/send.go | 1 + 1 file changed, 1 insertion(+) diff --git a/hicli/send.go b/hicli/send.go index 1c76a5a2..42b309a0 100644 --- a/hicli/send.go +++ b/hicli/send.go @@ -77,6 +77,7 @@ func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Typ return nil, fmt.Errorf("failed to insert event into database: %w", err) } go func() { + ctx := context.WithoutCancel(ctx) var err error defer func() { h.EventHandler(&SendComplete{ From 691c834144b07e1469d0279bb3c3f444527809a2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 10 Oct 2024 15:51:20 +0200 Subject: [PATCH 0818/1647] bridgev2: add option to use deterministic ID for outgoing messages (#292) --- bridgev2/bridgeconfig/config.go | 1 + bridgev2/commands/processor.go | 8 +++---- bridgev2/matrix/connector.go | 10 ++++----- bridgev2/messagestatus.go | 35 ++++++++++++++++--------------- bridgev2/portal.go | 37 ++++++++++++++++++++------------- bridgev2/portalinternal.go | 4 ++-- 6 files changed, 53 insertions(+), 42 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index aa07f42e..4c0fa6b4 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -66,6 +66,7 @@ type BridgeConfig struct { BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` TagOnlyOnCreate bool `yaml:"tag_only_on_create"` MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` + OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` Relay RelayConfig `yaml:"relay"` Permissions PermissionConfig `yaml:"permissions"` diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index 482acf18..1aca596c 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -77,10 +77,10 @@ func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id. log := &logCopy defer func() { statusInfo := &bridgev2.MessageStatusEventInfo{ - RoomID: roomID, - EventID: eventID, - EventType: event.EventMessage, - Sender: user.MXID, + RoomID: roomID, + SourceEventID: eventID, + EventType: event.EventMessage, + Sender: user.MXID, } err := recover() if err != nil { diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 35ef4a08..94fdd97c 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -446,7 +446,7 @@ func (br *Connector) SendMessageStatus(ctx context.Context, ms *bridgev2.Message } func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2.MessageStatus, evt *bridgev2.MessageStatusEventInfo, editEvent id.EventID) id.EventID { - if evt.EventType.IsEphemeral() || evt.EventID == "" { + if evt.EventType.IsEphemeral() || evt.SourceEventID == "" { return "" } log := zerolog.Ctx(ctx) @@ -460,7 +460,7 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 if err != nil { log.Err(err). Stringer("room_id", evt.RoomID). - Stringer("event_id", evt.EventID). + Stringer("event_id", evt.SourceEventID). Any("mss_content", mssEvt). Msg("Failed to send MSS event") } @@ -474,7 +474,7 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 if err != nil { log.Err(err). Stringer("room_id", evt.RoomID). - Stringer("event_id", evt.EventID). + Stringer("event_id", evt.SourceEventID). Str("notice_message", content.Body). Msg("Failed to send notice event") } else { @@ -482,11 +482,11 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 } } if ms.Status == event.MessageStatusSuccess && br.Config.Matrix.DeliveryReceipts { - err = br.Bot.SendReceipt(ctx, evt.RoomID, evt.EventID, event.ReceiptTypeRead, nil) + err = br.Bot.SendReceipt(ctx, evt.RoomID, evt.SourceEventID, event.ReceiptTypeRead, nil) if err != nil { log.Err(err). Stringer("room_id", evt.RoomID). - Stringer("event_id", evt.EventID). + Stringer("event_id", evt.SourceEventID). Msg("Failed to send Matrix delivery receipt") } } diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 77ca98fd..1983b4de 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -18,13 +18,14 @@ import ( ) type MessageStatusEventInfo struct { - RoomID id.RoomID - EventID id.EventID - EventType event.Type - MessageType event.MessageType - Sender id.UserID - ThreadRoot id.EventID - StreamOrder int64 + RoomID id.RoomID + SourceEventID id.EventID + NewEventID id.EventID + EventType event.Type + MessageType event.MessageType + Sender id.UserID + ThreadRoot id.EventID + StreamOrder int64 } func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo { @@ -33,12 +34,12 @@ func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo { threadRoot = relatable.OptionalGetRelatesTo().GetThreadParent() } return &MessageStatusEventInfo{ - RoomID: evt.RoomID, - EventID: evt.ID, - EventType: evt.Type, - MessageType: evt.Content.AsMessage().MsgType, - Sender: evt.Sender, - ThreadRoot: threadRoot, + RoomID: evt.RoomID, + SourceEventID: evt.ID, + EventType: evt.Type, + MessageType: evt.Content.AsMessage().MsgType, + Sender: evt.Sender, + ThreadRoot: threadRoot, } } @@ -150,7 +151,7 @@ func (ms *MessageStatus) ToCheckpoint(evt *MessageStatusEventInfo) *status.Messa } checkpoint := &status.MessageCheckpoint{ RoomID: evt.RoomID, - EventID: evt.EventID, + EventID: evt.SourceEventID, Step: step, Timestamp: jsontime.UnixMilliNow(), Status: ms.checkpointStatus(), @@ -171,7 +172,7 @@ func (ms *MessageStatus) ToMSSEvent(evt *MessageStatusEventInfo) *event.BeeperMe content := &event.BeeperMessageStatusEventContent{ RelatesTo: event.RelatesTo{ Type: event.RelReference, - EventID: evt.EventID, + EventID: evt.SourceEventID, }, Status: ms.Status, Reason: ms.ErrorReason, @@ -216,9 +217,9 @@ func (ms *MessageStatus) ToNoticeEvent(evt *MessageStatusEventInfo) *event.Messa Mentions: &event.Mentions{}, } if evt.ThreadRoot != "" { - content.RelatesTo.SetThread(evt.ThreadRoot, evt.EventID) + content.RelatesTo.SetThread(evt.ThreadRoot, evt.SourceEventID) } else { - content.RelatesTo.SetReplyTo(evt.EventID) + content.RelatesTo.SetReplyTo(evt.SourceEventID) } if evt.Sender != "" { content.Mentions.UserIDs = []id.UserID{evt.Sender} diff --git a/bridgev2/portal.go b/bridgev2/portal.go index f8dc8596..016d3693 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -440,9 +440,12 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowR } } -func (portal *Portal) sendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64) { +func (portal *Portal) sendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64, newEventID id.EventID) { info := StatusEventInfoFromEvent(evt) info.StreamOrder = streamOrder + if newEventID != evt.ID { + info.NewEventID = newEventID + } portal.Bridge.Matrix.SendMessageStatus(ctx, &MessageStatus{Status: event.MessageStatusSuccess}, info) } @@ -922,6 +925,9 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin if resp.DB == nil { log.Error().Msg("Network connector didn't return a message to save") } else { + if portal.Bridge.Config.OutgoingMessageReID { + message.MXID = portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, message.ID, message.PartID) + } // Hack to ensure the ghost row exists // TODO move to better place (like login) portal.Bridge.GetGhostByID(ctx, message.SenderID) @@ -937,7 +943,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin portal.outgoingMessagesLock.Unlock() } } - portal.sendSuccessStatus(ctx, evt, resp.StreamOrder) + portal.sendSuccessStatus(ctx, evt, resp.StreamOrder, message.MXID) } if portal.Disappear.Type != database.DisappearingTypeNone { go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ @@ -1090,7 +1096,7 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o log.Err(err).Msg("Failed to save message to database after editing") } // TODO allow returning stream order from HandleMatrixEdit - portal.sendSuccessStatus(ctx, evt, 0) + portal.sendSuccessStatus(ctx, evt, 0, "") } func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) { @@ -1144,7 +1150,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi } else if existing != nil { if existing.EmojiID != "" || existing.Emoji == preResp.Emoji { log.Debug().Msg("Ignoring duplicate reaction") - portal.sendSuccessStatus(ctx, evt, 0) + portal.sendSuccessStatus(ctx, evt, 0, "") return } react.ReactionToOverride = existing @@ -1226,7 +1232,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if err != nil { log.Err(err).Msg("Failed to save reaction to database") } - portal.sendSuccessStatus(ctx, evt, 0) + portal.sendSuccessStatus(ctx, evt, 0, "") } func handleMatrixRoomMeta[APIType any, ContentType any]( @@ -1252,17 +1258,17 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( switch typedContent := evt.Content.Parsed.(type) { case *event.RoomNameEventContent: if typedContent.Name == portal.Name { - portal.sendSuccessStatus(ctx, evt, 0) + portal.sendSuccessStatus(ctx, evt, 0, "") return } case *event.TopicEventContent: if typedContent.Topic == portal.Topic { - portal.sendSuccessStatus(ctx, evt, 0) + portal.sendSuccessStatus(ctx, evt, 0, "") return } case *event.RoomAvatarEventContent: if typedContent.URL == portal.AvatarMXC { - portal.sendSuccessStatus(ctx, evt, 0) + portal.sendSuccessStatus(ctx, evt, 0, "") return } } @@ -1293,7 +1299,7 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( log.Err(err).Msg("Failed to save portal after updating room metadata") } } - portal.sendSuccessStatus(ctx, evt, 0) + portal.sendSuccessStatus(ctx, evt, 0, "") } func handleMatrixAccountData[APIType any, ContentType any]( @@ -1583,7 +1589,7 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog return } // TODO delete msg/reaction db row - portal.sendSuccessStatus(ctx, evt, 0) + portal.sendSuccessStatus(ctx, evt, 0, "") } func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) { @@ -1894,6 +1900,9 @@ func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage saveMessage, statusErr = pending.handle(evt, pending.db) } if saveMessage { + if portal.Bridge.Config.OutgoingMessageReID { + pending.db.MXID = portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, pending.db.ID, pending.db.PartID) + } // Hack to ensure the ghost row exists // TODO move to better place (like login) portal.Bridge.GetGhostByID(ctx, pending.db.SenderID) @@ -1906,7 +1915,7 @@ func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage if statusErr != nil { portal.sendErrorStatus(ctx, pending.evt, statusErr) } else { - portal.sendSuccessStatus(ctx, pending.evt, getStreamOrder(evt)) + portal.sendSuccessStatus(ctx, pending.evt, getStreamOrder(evt), pending.evt.ID) } } zerolog.Ctx(ctx).Debug().Stringer("event_id", pending.evt.ID).Msg("Received remote echo for message") @@ -2571,9 +2580,9 @@ func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *U Status: event.MessageStatusSuccess, DeliveredTo: []id.UserID{intent.GetMXID()}, }, &MessageStatusEventInfo{ - RoomID: portal.MXID, - EventID: part.MXID, - Sender: part.SenderMXID, + RoomID: portal.MXID, + SourceEventID: part.MXID, + Sender: part.SenderMXID, }) } } diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index a3b1fbf4..a5da077b 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -49,8 +49,8 @@ func (portal *PortalInternals) HandleSingleEvent(ctx context.Context, rawEvt any (*Portal)(portal).handleSingleEvent(ctx, rawEvt, doneCallback) } -func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64) { - (*Portal)(portal).sendSuccessStatus(ctx, evt, streamOrder) +func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64, newEventID id.EventID) { + (*Portal)(portal).sendSuccessStatus(ctx, evt, streamOrder, newEventID) } func (portal *PortalInternals) SendErrorStatus(ctx context.Context, evt *event.Event, err error) { From 38610d681dcd6c4addb443c50a0efe2adb102df0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 10 Oct 2024 17:04:25 +0300 Subject: [PATCH 0819/1647] bridgev2/backfill: don't try to backfill if client is logged out --- bridgev2/backfillqueue.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go index 95f3107d..fce4a1b0 100644 --- a/bridgev2/backfillqueue.go +++ b/bridgev2/backfillqueue.go @@ -187,8 +187,10 @@ func (br *Bridge) actuallyDoBackfillTask(ctx context.Context, task *database.Bac if login == nil { task.UserLoginID = "" } + foundLogin := false for _, login = range logins { if login.Client.IsLoggedIn() { + foundLogin = true task.UserLoginID = login.ID log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("overridden_login_id", string(login.ID)) @@ -197,7 +199,7 @@ func (br *Bridge) actuallyDoBackfillTask(ctx context.Context, task *database.Bac break } } - if task.UserLoginID == "" { + if !foundLogin { log.Debug().Msg("No logged in user logins found for backfill task") task.NextDispatchMinTS = database.BackfillNextDispatchNever return false, nil From a7e8a2ce059f281b7e4a83ecbe4cb325ebb1376d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 10 Oct 2024 20:09:24 +0300 Subject: [PATCH 0820/1647] hicli/sync: update prev batch on limited timelines --- hicli/sync.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hicli/sync.go b/hicli/sync.go index 0b6ac9d8..e5b59c33 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -481,7 +481,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R updatedRoom.Name = &name updatedRoom.NameQuality = database.NameQualityParticipants } - if timeline.PrevBatch != "" && room.PrevBatch == "" { + if timeline.PrevBatch != "" && (room.PrevBatch == "" || timeline.Limited) { updatedRoom.PrevBatch = timeline.PrevBatch } roomChanged := updatedRoom.CheckChangesAndCopyInto(room) From 5e4a7b56d824871f7e6c783ecaf642fbdfd6c32d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 10 Oct 2024 20:17:13 +0300 Subject: [PATCH 0821/1647] hicli/sync: include empty list as timeline if there are no new events --- hicli/sync.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hicli/sync.go b/hicli/sync.go index e5b59c33..1f0c7979 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -471,6 +471,8 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R if err != nil { return fmt.Errorf("failed to append timeline: %w", err) } + } else { + timelineRowTuples = make([]database.TimelineRowTuple, 0) } // Calculate name from participants if participants changed and current name was generated from participants, or if the room name was unset if (heroesChanged && updatedRoom.NameQuality <= database.NameQualityParticipants) || updatedRoom.NameQuality == database.NameQualityNil { From 3d84843ec4fd2901ad20f351c214cfc60e68c0f4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 10 Oct 2024 20:18:42 +0300 Subject: [PATCH 0822/1647] hicli/paginate: add special value for pagination complete --- hicli/database/room.go | 2 ++ hicli/paginate.go | 7 +++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/hicli/database/room.go b/hicli/database/room.go index e7138d94..9b9e2a9c 100644 --- a/hicli/database/room.go +++ b/hicli/database/room.go @@ -106,6 +106,8 @@ const ( NameQualityExplicit ) +const PrevBatchPaginationComplete = "fi.mau.gomuks.pagination_complete" + type Room struct { ID id.RoomID `json:"room_id"` CreationContent *event.CreateEventContent `json:"creation_content,omitempty"` diff --git a/hicli/paginate.go b/hicli/paginate.go index 4109e7af..da927b9b 100644 --- a/hicli/paginate.go +++ b/hicli/paginate.go @@ -159,7 +159,7 @@ func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit i room, err := h.DB.Room.Get(ctx, roomID) if err != nil { return nil, fmt.Errorf("failed to get room from database: %w", err) - } else if room.PrevBatch == "" { + } else if room.PrevBatch == database.PrevBatchPaginationComplete { return &PaginationResponse{Events: []*database.Event{}, HasMore: false}, nil } resp, err := h.Client.Messages(ctx, roomID, room.PrevBatch, "", mautrix.DirectionBackward, nil, limit) @@ -167,7 +167,10 @@ func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit i return nil, fmt.Errorf("failed to get messages from server: %w", err) } events := make([]*database.Event, len(resp.Chunk)) - if resp.End == "" || len(resp.Chunk) == 0 { + if resp.End == "" { + resp.End = database.PrevBatchPaginationComplete + } + if resp.End == database.PrevBatchPaginationComplete || len(resp.Chunk) == 0 { err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End) if err != nil { return nil, fmt.Errorf("failed to set prev_batch: %w", err) From 50f4a2eec1936c0f5cbc295ee21e76cd0bbea113 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 10 Oct 2024 20:18:58 +0300 Subject: [PATCH 0823/1647] client: omit from parameter in /messages if empty --- client.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 125bba0d..b85d86fb 100644 --- a/client.go +++ b/client.go @@ -1836,11 +1836,10 @@ func (cli *Client) Hierarchy(ctx context.Context, roomID id.RoomID, req *ReqHier // Messages returns a list of message and state events for a room. It uses // pagination query parameters to paginate history in the room. -// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidmessages +// See https://spec.matrix.org/v1.12/client-server-api/#get_matrixclientv3roomsroomidmessages func (cli *Client) Messages(ctx context.Context, roomID id.RoomID, from, to string, dir Direction, filter *FilterPart, limit int) (resp *RespMessages, err error) { query := map[string]string{ - "from": from, - "dir": string(dir), + "dir": string(dir), } if filter != nil { filterJSON, err := json.Marshal(filter) @@ -1849,6 +1848,9 @@ func (cli *Client) Messages(ctx context.Context, roomID id.RoomID, from, to stri } query["filter"] = string(filterJSON) } + if from != "" { + query["from"] = from + } if to != "" { query["to"] = to } From 8b3828c7640376ff33abaf77492c5cde55009a90 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 11 Oct 2024 00:30:16 +0300 Subject: [PATCH 0824/1647] hicli/sync: remove reply fallback from stickers --- hicli/sync.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hicli/sync.go b/hicli/sync.go index 1f0c7979..300187cf 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -207,7 +207,7 @@ func isDecryptionErrorRetryable(err error) bool { } func removeReplyFallback(evt *event.Event) []byte { - if evt.Type != event.EventMessage { + if evt.Type != event.EventMessage && evt.Type != event.EventSticker { return nil } _ = evt.Content.ParseRaw(evt.Type) From 99ce4618c60205527fcb728d310e91e1ccebdc06 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 12 Oct 2024 01:08:46 +0300 Subject: [PATCH 0825/1647] hicli/sync: eagerly include related events in sync events --- hicli/database/event.go | 4 +-- hicli/database/room.go | 18 +++++++++++++ hicli/sync.go | 56 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 2 deletions(-) diff --git a/hicli/database/event.go b/hicli/database/event.go index 2953bb9b..db59afaf 100644 --- a/hicli/database/event.go +++ b/hicli/database/event.go @@ -371,7 +371,7 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { e.TransactionID = transactionID.String e.RedactedBy = id.EventID(redactedBy.String) e.RelatesTo = id.EventID(relatesTo.String) - e.RelationType = event.RelationType(relatesTo.String) + e.RelationType = event.RelationType(relationType.String) e.MegolmSessionID = id.SessionID(megolmSessionID.String) e.DecryptedType = decryptedType.String e.DecryptionError = decryptionError.String @@ -440,7 +440,7 @@ func (e *Event) CanUseForPreview() bool { return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || (e.Type == event.EventEncrypted.Type && (e.DecryptedType == event.EventMessage.Type || e.DecryptedType == event.EventSticker.Type))) && - e.RelationType != event.RelReplace + e.RelationType != event.RelReplace && e.RedactedBy == "" } func (e *Event) BumpsSortingTimestamp() bool { diff --git a/hicli/database/room.go b/hicli/database/room.go index 9b9e2a9c..a5e4f75a 100644 --- a/hicli/database/room.go +++ b/hicli/database/room.go @@ -60,6 +60,19 @@ const ( > COALESCE((SELECT rowid FROM timeline WHERE event_rowid = preview_event_rowid), 0) RETURNING preview_event_rowid ` + recalculateRoomPreviewEventQuery = ` + SELECT rowid + FROM event + WHERE + room_id = $1 + AND (type IN ('m.room.message', 'm.sticker') + OR (type = 'm.room.encrypted' + AND decrypted_type IN ('m.room.message', 'm.sticker'))) + AND relation_type <> 'm.replace' + AND redacted_by IS NULL + ORDER BY timestamp DESC + LIMIT 1 + ` ) type RoomQuery struct { @@ -97,6 +110,11 @@ func (rq *RoomQuery) UpdatePreviewIfLaterOnTimeline(ctx context.Context, roomID return } +func (rq *RoomQuery) RecalculatePreview(ctx context.Context, roomID id.RoomID) (rowID EventRowID, err error) { + err = rq.GetDB().QueryRow(ctx, recalculateRoomPreviewEventQuery, roomID).Scan(&rowID) + return +} + type NameQuality int const ( diff --git a/hicli/sync.go b/hicli/sync.go index 300187cf..d66c35bd 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -336,6 +336,8 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio if err != nil { return dbEvt, fmt.Errorf("failed to set redacts field: %w", err) } + } else if evt.Redacts == "" { + evt.Redacts = id.EventID(gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str) } } _, err := h.DB.Event.Upsert(ctx, dbEvt) @@ -382,6 +384,38 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R } decryptionQueue := make(map[id.SessionID]*database.SessionRequest) allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events)) + recalculatePreviewEvent := false + addOldEvent := func(rowID database.EventRowID, evtID id.EventID) (dbEvt *database.Event, err error) { + if rowID != 0 { + dbEvt, err = h.DB.Event.GetByRowID(ctx, rowID) + } else { + dbEvt, err = h.DB.Event.GetByID(ctx, evtID) + } + if err != nil { + return nil, fmt.Errorf("failed to get redaction target: %w", err) + } else if dbEvt == nil { + return nil, nil + } + allNewEvents = append(allNewEvents, dbEvt) + return dbEvt, nil + } + processRedaction := func(evt *event.Event) error { + dbEvt, err := addOldEvent(0, evt.Redacts) + if err != nil { + return fmt.Errorf("failed to get redaction target: %w", err) + } + if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation { + _, err = addOldEvent(0, dbEvt.RelatesTo) + if err != nil { + return fmt.Errorf("failed to get relation target of redaction target: %w", err) + } + } + if updatedRoom.PreviewEventRowID == dbEvt.RowID { + updatedRoom.PreviewEventRowID = 0 + recalculatePreviewEvent = true + } + return nil + } processNewEvent := func(evt *event.Event, isTimeline bool) (database.EventRowID, error) { evt.RoomID = room.ID dbEvt, err := h.processEvent(ctx, evt, decryptionQueue, false) @@ -391,6 +425,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R if isTimeline { if dbEvt.CanUseForPreview() { updatedRoom.PreviewEventRowID = dbEvt.RowID + recalculatePreviewEvent = false } updatedRoom.BumpSortingTimestamp(dbEvt) } @@ -411,6 +446,17 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R processImportantEvent(ctx, evt, room, updatedRoom) } allNewEvents = append(allNewEvents, dbEvt) + if evt.Type == event.EventRedaction && evt.Redacts != "" { + err = processRedaction(evt) + if err != nil { + return -1, fmt.Errorf("failed to process redaction: %w", err) + } + } else if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation { + _, err = addOldEvent(0, dbEvt.RelatesTo) + if err != nil { + return -1, fmt.Errorf("failed to get relation target of event: %w", err) + } + } return dbEvt.RowID, nil } changedState := make(map[event.Type]map[string]database.EventRowID) @@ -474,6 +520,16 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R } else { timelineRowTuples = make([]database.TimelineRowTuple, 0) } + if recalculatePreviewEvent && updatedRoom.PreviewEventRowID == 0 { + updatedRoom.PreviewEventRowID, err = h.DB.Room.RecalculatePreview(ctx, room.ID) + if err != nil { + return fmt.Errorf("failed to recalculate preview event: %w", err) + } + _, err = addOldEvent(updatedRoom.PreviewEventRowID, "") + if err != nil { + return fmt.Errorf("failed to get preview event: %w", err) + } + } // Calculate name from participants if participants changed and current name was generated from participants, or if the room name was unset if (heroesChanged && updatedRoom.NameQuality <= database.NameQualityParticipants) || updatedRoom.NameQuality == database.NameQualityNil { name, err := h.calculateRoomParticipantName(ctx, room.ID, summary) From 9e796dd66c0de4457765fb61329be78ed918c00d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 12 Oct 2024 12:14:19 +0300 Subject: [PATCH 0826/1647] hicli/json: handle ping commands --- hicli/json.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/hicli/json.go b/hicli/json.go index df853c33..a27fd007 100644 --- a/hicli/json.go +++ b/hicli/json.go @@ -70,6 +70,12 @@ func (h *HiClient) dispatchCurrentState() { } func (h *HiClient) SubmitJSONCommand(ctx context.Context, req *JSONCommand) *JSONCommand { + if req.Command == "ping" { + return &JSONCommand{ + Command: "pong", + RequestID: req.RequestID, + } + } log := h.Log.With().Int64("request_id", req.RequestID).Str("command", req.Command).Logger() ctx, cancel := context.WithCancelCause(ctx) defer func() { From 8f6dec74c7f768271d0d1bee763f7b5bd8810fe8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 12 Oct 2024 14:08:45 +0300 Subject: [PATCH 0827/1647] hicli/database: add more checks for edit triggers --- hicli/database/upgrades/00-latest-revision.sql | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/hicli/database/upgrades/00-latest-revision.sql b/hicli/database/upgrades/00-latest-revision.sql index 1617035c..8ba1fd15 100644 --- a/hicli/database/upgrades/00-latest-revision.sql +++ b/hicli/database/upgrades/00-latest-revision.sql @@ -102,6 +102,7 @@ CREATE TRIGGER event_update_last_edit_when_redacted WHEN OLD.redacted_by IS NULL AND NEW.redacted_by IS NOT NULL AND NEW.relation_type = 'm.replace' + AND NEW.state_key IS NULL BEGIN UPDATE event SET last_edit_rowid = COALESCE( @@ -113,11 +114,13 @@ BEGIN AND edit.type = event.type AND edit.sender = event.sender AND edit.redacted_by IS NULL + AND edit.state_key IS NULL ORDER BY edit.timestamp DESC LIMIT 1), 0) WHERE event_id = NEW.relates_to - AND last_edit_rowid = NEW.rowid; + AND last_edit_rowid = NEW.rowid + AND state_key IS NULL; END; CREATE TRIGGER event_insert_update_last_edit @@ -125,6 +128,7 @@ CREATE TRIGGER event_insert_update_last_edit ON event WHEN NEW.relation_type = 'm.replace' AND NEW.redacted_by IS NULL + AND NEW.state_key IS NULL BEGIN UPDATE event SET last_edit_rowid = NEW.rowid From 190760cd65849d063290b17525e7204ffd55a0dc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 12 Oct 2024 14:52:21 +0300 Subject: [PATCH 0828/1647] hicli: add support for sending markdown and rainbows --- format/mdext/rainbow/goldmark.go | 120 +++++++++++++++++++++++++++++++ format/mdext/rainbow/gradient.go | 56 +++++++++++++++ go.mod | 2 + go.sum | 4 ++ hicli/json-commands.go | 14 +++- hicli/send.go | 26 +++++++ 6 files changed, 220 insertions(+), 2 deletions(-) create mode 100644 format/mdext/rainbow/goldmark.go create mode 100644 format/mdext/rainbow/gradient.go diff --git a/format/mdext/rainbow/goldmark.go b/format/mdext/rainbow/goldmark.go new file mode 100644 index 00000000..59a36178 --- /dev/null +++ b/format/mdext/rainbow/goldmark.go @@ -0,0 +1,120 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package rainbow + +import ( + "fmt" + "unicode" + + "github.com/rivo/uniseg" + "github.com/yuin/goldmark" + "github.com/yuin/goldmark/ast" + "github.com/yuin/goldmark/renderer" + "github.com/yuin/goldmark/renderer/html" + "github.com/yuin/goldmark/util" + "go.mau.fi/util/random" +) + +// Extension is a goldmark extension that adds rainbow text coloring to the HTML renderer. +var Extension = &extRainbow{} + +type extRainbow struct{} +type rainbowRenderer struct { + HardWraps bool + ColorID string +} + +var defaultRB = &rainbowRenderer{HardWraps: true, ColorID: random.String(16)} + +func (er *extRainbow) Extend(m goldmark.Markdown) { + m.Renderer().AddOptions(renderer.WithNodeRenderers(util.Prioritized(defaultRB, 0))) +} + +func (rb *rainbowRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) { + reg.Register(ast.KindText, rb.renderText) + reg.Register(ast.KindString, rb.renderString) +} + +type rainbowBufWriter struct { + util.BufWriter + ColorID string +} + +func (rbw rainbowBufWriter) WriteString(s string) (int, error) { + i := 0 + graphemes := uniseg.NewGraphemes(s) + for graphemes.Next() { + runes := graphemes.Runes() + if len(runes) == 1 && unicode.IsSpace(runes[0]) { + i2, err := rbw.BufWriter.WriteRune(runes[0]) + i += i2 + if err != nil { + return i, err + } + continue + } + i2, err := fmt.Fprintf(rbw.BufWriter, "%s", rbw.ColorID, graphemes.Str()) + i += i2 + if err != nil { + return i, err + } + } + return i, nil +} + +func (rbw rainbowBufWriter) Write(data []byte) (int, error) { + return rbw.WriteString(string(data)) +} + +func (rbw rainbowBufWriter) WriteByte(c byte) error { + _, err := rbw.WriteRune(rune(c)) + return err +} + +func (rbw rainbowBufWriter) WriteRune(r rune) (int, error) { + if unicode.IsSpace(r) { + return rbw.BufWriter.WriteRune(r) + } else { + return fmt.Fprintf(rbw.BufWriter, "%c", rbw.ColorID, r) + } +} + +func (rb *rainbowRenderer) renderText(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) { + if !entering { + return ast.WalkContinue, nil + } + n := node.(*ast.Text) + segment := n.Segment + if n.IsRaw() { + html.DefaultWriter.RawWrite(rainbowBufWriter{w, rb.ColorID}, segment.Value(source)) + } else { + html.DefaultWriter.Write(rainbowBufWriter{w, rb.ColorID}, segment.Value(source)) + if n.HardLineBreak() || (n.SoftLineBreak() && rb.HardWraps) { + _, _ = w.WriteString("
\n") + } else if n.SoftLineBreak() { + _ = w.WriteByte('\n') + } + } + return ast.WalkContinue, nil +} + +func (rb *rainbowRenderer) renderString(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) { + if !entering { + return ast.WalkContinue, nil + } + n := node.(*ast.String) + if n.IsCode() { + _, _ = w.Write(n.Value) + } else { + if n.IsRaw() { + html.DefaultWriter.RawWrite(rainbowBufWriter{w, rb.ColorID}, n.Value) + } else { + html.DefaultWriter.Write(rainbowBufWriter{w, rb.ColorID}, n.Value) + } + } + return ast.WalkContinue, nil +} diff --git a/format/mdext/rainbow/gradient.go b/format/mdext/rainbow/gradient.go new file mode 100644 index 00000000..34c499e6 --- /dev/null +++ b/format/mdext/rainbow/gradient.go @@ -0,0 +1,56 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package rainbow + +import ( + "regexp" + "strings" + + "github.com/lucasb-eyer/go-colorful" +) + +// GradientTable from https://github.com/lucasb-eyer/go-colorful/blob/master/doc/gradientgen/gradientgen.go +type GradientTable []struct { + Col colorful.Color + Pos float64 +} + +func (gt GradientTable) GetInterpolatedColorFor(t float64) colorful.Color { + for i := 0; i < len(gt)-1; i++ { + c1 := gt[i] + c2 := gt[i+1] + if c1.Pos <= t && t <= c2.Pos { + t := (t - c1.Pos) / (c2.Pos - c1.Pos) + return c1.Col.BlendHcl(c2.Col, t).Clamped() + } + } + return gt[len(gt)-1].Col +} + +var Gradient = GradientTable{ + {colorful.LinearRgb(1, 0, 0), 0 / 11.0}, + {colorful.LinearRgb(1, 0.5, 0), 1 / 11.0}, + {colorful.LinearRgb(1, 1, 0), 2 / 11.0}, + {colorful.LinearRgb(0.5, 1, 0), 3 / 11.0}, + {colorful.LinearRgb(0, 1, 0), 4 / 11.0}, + {colorful.LinearRgb(0, 1, 0.5), 5 / 11.0}, + {colorful.LinearRgb(0, 1, 1), 6 / 11.0}, + {colorful.LinearRgb(0, 0.5, 1), 7 / 11.0}, + {colorful.LinearRgb(0, 0, 1), 8 / 11.0}, + {colorful.LinearRgb(0.5, 0, 1), 9 / 11.0}, + {colorful.LinearRgb(1, 0, 1), 10 / 11.0}, + {colorful.LinearRgb(1, 0, 0.5), 11 / 11.0}, +} + +func ApplyColor(htmlBody string) string { + count := strings.Count(htmlBody, defaultRB.ColorID) + i := -1 + return regexp.MustCompile(defaultRB.ColorID).ReplaceAllStringFunc(htmlBody, func(match string) string { + i++ + return Gradient.GetInterpolatedColorFor(float64(i) / float64(count)).Hex() + }) +} diff --git a/go.mod b/go.mod index 8ef08be8..a1e97f8c 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,9 @@ require ( github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 + github.com/lucasb-eyer/go-colorful v1.2.0 github.com/mattn/go-sqlite3 v1.14.23 + github.com/rivo/uniseg v0.4.7 github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.33.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e diff --git a/go.sum b/go.sum index ac8e03f6..b825326f 100644 --- a/go.sum +++ b/go.sum @@ -19,6 +19,8 @@ github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWm github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= +github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= @@ -31,6 +33,8 @@ github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7/go.mod h1:pxMtw7c github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= diff --git a/hicli/json-commands.go b/hicli/json-commands.go index 12026f6b..a81cf097 100644 --- a/hicli/json-commands.go +++ b/hicli/json-commands.go @@ -38,7 +38,11 @@ func (h *HiClient) handleJSONCommand(ctx context.Context, req *JSONCommand) (any return true, nil }) case "send_message": - return unmarshalAndCall(req.Data, func(params *sendParams) (*database.Event, error) { + return unmarshalAndCall(req.Data, func(params *sendMessageParams) (*database.Event, error) { + return h.SendMessage(ctx, params.RoomID, params.Text, params.MediaPath) + }) + case "send_event": + return unmarshalAndCall(req.Data, func(params *sendEventParams) (*database.Event, error) { return h.Send(ctx, params.RoomID, params.EventType, params.Content) }) case "get_event": @@ -100,7 +104,13 @@ type cancelRequestParams struct { Reason string `json:"reason"` } -type sendParams struct { +type sendMessageParams struct { + RoomID id.RoomID `json:"room_id"` + Text string `json:"text"` + MediaPath string `json:"media_path"` +} + +type sendEventParams struct { RoomID id.RoomID `json:"room_id"` EventType event.Type `json:"type"` Content json.RawMessage `json:"content"` diff --git a/hicli/send.go b/hicli/send.go index 42b309a0..d16cadf6 100644 --- a/hicli/send.go +++ b/hicli/send.go @@ -11,17 +11,43 @@ import ( "encoding/json" "errors" "fmt" + "strings" "time" "github.com/rs/zerolog" + "github.com/yuin/goldmark" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/format/mdext/rainbow" "maunium.net/go/mautrix/hicli/database" "maunium.net/go/mautrix/id" ) +var ( + rainbowWithHTML = goldmark.New(format.Extensions, format.HTMLOptions, goldmark.WithExtensions(rainbow.Extension)) +) + +func (h *HiClient) SendMessage(ctx context.Context, roomID id.RoomID, text, mediaPath string) (*database.Event, error) { + var content event.MessageEventContent + if strings.HasPrefix(text, "/rainbow ") { + text = strings.TrimPrefix(text, "/rainbow ") + content = format.RenderMarkdownCustom(text, rainbowWithHTML) + content.FormattedBody = rainbow.ApplyColor(content.FormattedBody) + } else if strings.HasPrefix(text, "/plain ") { + text = strings.TrimPrefix(text, "/plain ") + content = format.RenderMarkdown(text, false, false) + } else if strings.HasPrefix(text, "/html ") { + text = strings.TrimPrefix(text, "/html ") + content = format.RenderMarkdown(text, false, true) + } else { + content = format.RenderMarkdown(text, true, false) + } + return h.Send(ctx, roomID, event.EventMessage, &content) +} + func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (*database.Event, error) { roomMeta, err := h.DB.Room.Get(ctx, roomID) if err != nil { From 974fab0e0f4a59d36813fc044ac6e843f385387f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 12 Oct 2024 15:45:36 +0300 Subject: [PATCH 0829/1647] hicli: use unix timestamps for events --- hicli/database/event.go | 17 +++++++++-------- hicli/database/room.go | 2 +- hicli/send.go | 4 ++-- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/hicli/database/event.go b/hicli/database/event.go index db59afaf..0c55d84c 100644 --- a/hicli/database/event.go +++ b/hicli/database/event.go @@ -17,6 +17,7 @@ import ( "github.com/tidwall/gjson" "go.mau.fi/util/dbutil" "go.mau.fi/util/exgjson" + "go.mau.fi/util/jsontime" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -267,12 +268,12 @@ type Event struct { RowID EventRowID `json:"rowid"` TimelineRowID TimelineRowID `json:"timeline_rowid"` - RoomID id.RoomID `json:"room_id"` - ID id.EventID `json:"event_id"` - Sender id.UserID `json:"sender"` - Type string `json:"type"` - StateKey *string `json:"state_key,omitempty"` - Timestamp time.Time `json:"timestamp"` + RoomID id.RoomID `json:"room_id"` + ID id.EventID `json:"event_id"` + Sender id.UserID `json:"sender"` + Type string `json:"type"` + StateKey *string `json:"state_key,omitempty"` + Timestamp jsontime.UnixMilli `json:"timestamp"` Content json.RawMessage `json:"content"` Decrypted json.RawMessage `json:"decrypted,omitempty"` @@ -300,7 +301,7 @@ func MautrixToEvent(evt *event.Event) *Event { Sender: evt.Sender, Type: evt.Type.Type, StateKey: evt.StateKey, - Timestamp: time.UnixMilli(evt.Timestamp), + Timestamp: jsontime.UM(time.UnixMilli(evt.Timestamp)), Content: evt.Content.VeryRaw, MegolmSessionID: getMegolmSessionID(evt), TransactionID: evt.Unsigned.TransactionID, @@ -367,7 +368,7 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { if err != nil { return nil, err } - e.Timestamp = time.UnixMilli(timestamp) + e.Timestamp = jsontime.UM(time.UnixMilli(timestamp)) e.TransactionID = transactionID.String e.RedactedBy = id.EventID(redactedBy.String) e.RelatesTo = id.EventID(relatesTo.String) diff --git a/hicli/database/room.go b/hicli/database/room.go index a5e4f75a..5971788e 100644 --- a/hicli/database/room.go +++ b/hicli/database/room.go @@ -241,7 +241,7 @@ func (r *Room) BumpSortingTimestamp(evt *Event) bool { if !evt.BumpsSortingTimestamp() || evt.Timestamp.Before(r.SortingTimestamp.Time) { return false } - r.SortingTimestamp = jsontime.UM(evt.Timestamp) + r.SortingTimestamp = evt.Timestamp now := time.Now() if r.SortingTimestamp.After(now) { r.SortingTimestamp = jsontime.UM(now) diff --git a/hicli/send.go b/hicli/send.go index d16cadf6..8df8312f 100644 --- a/hicli/send.go +++ b/hicli/send.go @@ -12,10 +12,10 @@ import ( "errors" "fmt" "strings" - "time" "github.com/rs/zerolog" "github.com/yuin/goldmark" + "go.mau.fi/util/jsontime" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto" @@ -84,7 +84,7 @@ func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Typ ID: id.EventID(fmt.Sprintf("~%s", txnID)), Sender: h.Account.UserID, Type: evtType.Type, - Timestamp: time.Now(), + Timestamp: jsontime.UnixMilliNow(), Content: mainContent, Decrypted: decryptedContent, DecryptedType: decryptedType.Type, From 226144ca9f8507bd9f3e3779cd70eb784a44fbf7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 12 Oct 2024 18:14:22 +0300 Subject: [PATCH 0830/1647] hicli/sync: fix handling redactions to unknown events --- hicli/sync.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/hicli/sync.go b/hicli/sync.go index d66c35bd..209e4997 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -404,6 +404,9 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R if err != nil { return fmt.Errorf("failed to get redaction target: %w", err) } + if dbEvt == nil { + return nil + } if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation { _, err = addOldEvent(0, dbEvt.RelatesTo) if err != nil { From e48f081942632d24ab55daeb1800d2ec099111d0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 13 Oct 2024 17:14:18 +0300 Subject: [PATCH 0831/1647] hicli/send: add mark read method --- hicli/json-commands.go | 10 ++++++++++ hicli/send.go | 22 ++++++++++++++++++++-- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/hicli/json-commands.go b/hicli/json-commands.go index a81cf097..8c848525 100644 --- a/hicli/json-commands.go +++ b/hicli/json-commands.go @@ -45,6 +45,10 @@ func (h *HiClient) handleJSONCommand(ctx context.Context, req *JSONCommand) (any return unmarshalAndCall(req.Data, func(params *sendEventParams) (*database.Event, error) { return h.Send(ctx, params.RoomID, params.EventType, params.Content) }) + case "mark_read": + return unmarshalAndCall(req.Data, func(params *markReadParams) (bool, error) { + return h.MarkRead(ctx, params.RoomID, params.EventID, params.ReceiptType) + }) case "get_event": return unmarshalAndCall(req.Data, func(params *getEventParams) (*database.Event, error) { return h.GetEvent(ctx, params.RoomID, params.EventID) @@ -116,6 +120,12 @@ type sendEventParams struct { Content json.RawMessage `json:"content"` } +type markReadParams struct { + RoomID id.RoomID `json:"room_id"` + EventID id.EventID `json:"event_id"` + ReceiptType event.ReceiptType `json:"receipt_type"` +} + type getEventParams struct { RoomID id.RoomID `json:"room_id"` EventID id.EventID `json:"event_id"` diff --git a/hicli/send.go b/hicli/send.go index 8df8312f..fb9b4470 100644 --- a/hicli/send.go +++ b/hicli/send.go @@ -16,6 +16,7 @@ import ( "github.com/rs/zerolog" "github.com/yuin/goldmark" "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto" @@ -48,6 +49,24 @@ func (h *HiClient) SendMessage(ctx context.Context, roomID id.RoomID, text, medi return h.Send(ctx, roomID, event.EventMessage, &content) } +func (h *HiClient) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, receiptType event.ReceiptType) (bool, error) { + content := &mautrix.ReqSetReadMarkers{ + FullyRead: eventID, + } + if receiptType == event.ReceiptTypeRead { + content.Read = eventID + } else if receiptType == event.ReceiptTypeReadPrivate { + content.ReadPrivate = eventID + } else { + return false, fmt.Errorf("invalid receipt type: %v", receiptType) + } + err := h.Client.SetReadMarkers(ctx, roomID, content) + if err != nil { + return false, fmt.Errorf("failed to mark event as read: %w", err) + } + return true, nil +} + func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (*database.Event, error) { roomMeta, err := h.DB.Room.Get(ctx, roomID) if err != nil { @@ -76,7 +95,6 @@ func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Typ if err != nil { return nil, fmt.Errorf("failed to marshal event content: %w", err) } - var zero database.EventRowID txnID := "hicli-" + h.Client.TxnID() relatesTo, relationType := database.GetRelatesToFromBytes(mainContent) dbEvt := &database.Event{ @@ -96,7 +114,7 @@ func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Typ DecryptionError: "", SendError: "not sent", Reactions: map[string]int{}, - LastEditRowID: &zero, + LastEditRowID: ptr.Ptr(database.EventRowID(0)), } _, err = h.DB.Event.Insert(ctx, dbEvt) if err != nil { From 5cccf93cdc6a1163edc5523321a4fa0ffd3fc53f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 13 Oct 2024 17:14:33 +0300 Subject: [PATCH 0832/1647] hicli/sync: cache push rules in memory --- hicli/hicli.go | 14 ++++++++++++++ hicli/sync.go | 10 ++++++++++ 2 files changed, 24 insertions(+) diff --git a/hicli/hicli.go b/hicli/hicli.go index 4253c581..78a1acc0 100644 --- a/hicli/hicli.go +++ b/hicli/hicli.go @@ -27,6 +27,7 @@ import ( "maunium.net/go/mautrix/crypto/backup" "maunium.net/go/mautrix/hicli/database" "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/pushrules" ) type HiClient struct { @@ -43,6 +44,8 @@ type HiClient struct { KeyBackupVersion id.KeyBackupVersion KeyBackupKey *backup.MegolmBackupKey + PushRules atomic.Pointer[pushrules.PushRuleset] + EventHandler func(evt any) firstSyncReceived bool @@ -212,6 +215,7 @@ func (h *HiClient) Sync() { defer cancel() h.stopSync.Store(&cancel) go h.RunRequestQueue(h.Log.WithContext(ctx)) + go h.LoadPushRules(h.Log.WithContext(ctx)) ctx = log.WithContext(ctx) log.Info().Msg("Starting syncing") err := h.Client.SyncWithContext(ctx) @@ -222,6 +226,16 @@ func (h *HiClient) Sync() { } } +func (h *HiClient) LoadPushRules(ctx context.Context) { + rules, err := h.Client.GetPushRules(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to load push rules") + return + } + h.PushRules.Store(rules) + zerolog.Ctx(ctx).Debug().Msg("Updated push rules from fetch") +} + func (h *HiClient) Stop() { h.Client.StopSync() if fn := h.stopSync.Swap(nil); fn != nil { diff --git a/hicli/sync.go b/hicli/sync.go index 209e4997..1b6ec613 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -25,6 +25,7 @@ import ( "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/hicli/database" "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/pushrules" ) type syncContext struct { @@ -103,6 +104,15 @@ func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSy if err != nil { return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err) } + if evt.Type == event.AccountDataPushRules { + err = evt.Content.ParseRaw(evt.Type) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to parse push rules in sync") + } else if pushRules, ok := evt.Content.Parsed.(*pushrules.EventContent); ok { + h.PushRules.Store(pushRules.Ruleset) + zerolog.Ctx(ctx).Debug().Msg("Updated push rules from sync") + } + } } for roomID, room := range resp.Rooms.Join { err := h.processSyncJoinedRoom(ctx, roomID, room) From 8d9caf0d55f4b35f0af6b85a3a48b04042b30b61 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 13 Oct 2024 22:33:17 +0300 Subject: [PATCH 0833/1647] hicli/send: add support for sending replies --- hicli/json-commands.go | 10 ++++++---- hicli/send.go | 13 ++++++++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/hicli/json-commands.go b/hicli/json-commands.go index 8c848525..0c65c0f2 100644 --- a/hicli/json-commands.go +++ b/hicli/json-commands.go @@ -39,7 +39,7 @@ func (h *HiClient) handleJSONCommand(ctx context.Context, req *JSONCommand) (any }) case "send_message": return unmarshalAndCall(req.Data, func(params *sendMessageParams) (*database.Event, error) { - return h.SendMessage(ctx, params.RoomID, params.Text, params.MediaPath) + return h.SendMessage(ctx, params.RoomID, params.Text, params.MediaPath, params.ReplyTo, params.Mentions) }) case "send_event": return unmarshalAndCall(req.Data, func(params *sendEventParams) (*database.Event, error) { @@ -109,9 +109,11 @@ type cancelRequestParams struct { } type sendMessageParams struct { - RoomID id.RoomID `json:"room_id"` - Text string `json:"text"` - MediaPath string `json:"media_path"` + RoomID id.RoomID `json:"room_id"` + Text string `json:"text"` + MediaPath string `json:"media_path"` + ReplyTo id.EventID `json:"reply_to"` + Mentions *event.Mentions `json:"mentions"` } type sendEventParams struct { diff --git a/hicli/send.go b/hicli/send.go index fb9b4470..d732784b 100644 --- a/hicli/send.go +++ b/hicli/send.go @@ -31,7 +31,7 @@ var ( rainbowWithHTML = goldmark.New(format.Extensions, format.HTMLOptions, goldmark.WithExtensions(rainbow.Extension)) ) -func (h *HiClient) SendMessage(ctx context.Context, roomID id.RoomID, text, mediaPath string) (*database.Event, error) { +func (h *HiClient) SendMessage(ctx context.Context, roomID id.RoomID, text, mediaPath string, replyTo id.EventID, mentions *event.Mentions) (*database.Event, error) { var content event.MessageEventContent if strings.HasPrefix(text, "/rainbow ") { text = strings.TrimPrefix(text, "/rainbow ") @@ -46,6 +46,17 @@ func (h *HiClient) SendMessage(ctx context.Context, roomID id.RoomID, text, medi } else { content = format.RenderMarkdown(text, true, false) } + if mentions != nil { + content.Mentions.Room = mentions.Room + for _, userID := range mentions.UserIDs { + if userID != h.Account.UserID { + content.Mentions.Add(userID) + } + } + } + if replyTo != "" { + content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(replyTo) + } return h.Send(ctx, roomID, event.EventMessage, &content) } From 6f1c516baa07158c130165511a57dc58773b9f1f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 14 Oct 2024 00:49:45 +0300 Subject: [PATCH 0834/1647] hicli/database: ignore edits to edits and annotations --- hicli/database/upgrades/00-latest-revision.sql | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/hicli/database/upgrades/00-latest-revision.sql b/hicli/database/upgrades/00-latest-revision.sql index 8ba1fd15..8a9f8367 100644 --- a/hicli/database/upgrades/00-latest-revision.sql +++ b/hicli/database/upgrades/00-latest-revision.sql @@ -120,7 +120,8 @@ BEGIN 0) WHERE event_id = NEW.relates_to AND last_edit_rowid = NEW.rowid - AND state_key IS NULL; + AND state_key IS NULL + AND relation_type NOT IN ('m.replace', 'm.annotation'); END; CREATE TRIGGER event_insert_update_last_edit @@ -136,6 +137,7 @@ BEGIN AND type = NEW.type AND sender = NEW.sender AND state_key IS NULL + AND relation_type NOT IN ('m.replace', 'm.annotation') AND NEW.timestamp > COALESCE((SELECT prev_edit.timestamp FROM event prev_edit WHERE prev_edit.rowid = event.last_edit_rowid), 0); END; From 67b1c97f1485e04ed907381d07f65795133e7d8e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 14 Oct 2024 01:54:29 +0300 Subject: [PATCH 0835/1647] hicli/database: fix edit triggers --- hicli/database/upgrades/00-latest-revision.sql | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hicli/database/upgrades/00-latest-revision.sql b/hicli/database/upgrades/00-latest-revision.sql index 8a9f8367..04de7baf 100644 --- a/hicli/database/upgrades/00-latest-revision.sql +++ b/hicli/database/upgrades/00-latest-revision.sql @@ -121,7 +121,7 @@ BEGIN WHERE event_id = NEW.relates_to AND last_edit_rowid = NEW.rowid AND state_key IS NULL - AND relation_type NOT IN ('m.replace', 'm.annotation'); + AND (relation_type IS NULL OR relation_type NOT IN ('m.replace', 'm.annotation')); END; CREATE TRIGGER event_insert_update_last_edit @@ -137,7 +137,7 @@ BEGIN AND type = NEW.type AND sender = NEW.sender AND state_key IS NULL - AND relation_type NOT IN ('m.replace', 'm.annotation') + AND (relation_type IS NULL OR relation_type NOT IN ('m.replace', 'm.annotation')) AND NEW.timestamp > COALESCE((SELECT prev_edit.timestamp FROM event prev_edit WHERE prev_edit.rowid = event.last_edit_rowid), 0); END; From efc532bfb2ce5b4900dddb68ba615619f08e6fa3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 14 Oct 2024 17:23:15 +0300 Subject: [PATCH 0836/1647] hicli/processEvent: save session request manually if decryption queue is not provided --- hicli/sync.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/hicli/sync.go b/hicli/sync.go index 1b6ec613..210cd1ac 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -370,7 +370,18 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio } minIndex, _ := crypto.ParseMegolmMessageIndex(evt.Content.AsEncrypted().MegolmCiphertext) req.MinIndex = min(uint32(minIndex), req.MinIndex) - decryptionQueue[dbEvt.MegolmSessionID] = req + if decryptionQueue != nil { + decryptionQueue[dbEvt.MegolmSessionID] = req + } else { + err = h.DB.SessionRequest.Put(ctx, req) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("session_id", dbEvt.MegolmSessionID). + Msg("Failed to save session request") + } else { + h.WakeupRequestQueue() + } + } } return dbEvt, err } From 965008e8462e42b55c2554d8e8993fa9d90cbd81 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 14 Oct 2024 19:47:22 +0300 Subject: [PATCH 0837/1647] bridgev2: add optional stop method for network connectors --- bridgev2/bridge.go | 3 +++ bridgev2/networkinterface.go | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 2b520e23..16ebdb77 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -216,6 +216,9 @@ func (br *Bridge) Stop() { } wg.Wait() br.cacheLock.Unlock() + if stopNet, ok := br.Network.(StoppableNetwork); ok { + stopNet.Stop() + } err := br.DB.Close() if err != nil { br.Log.Warn().Err(err).Msg("Failed to close database") diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 3b406a9d..ae7d6520 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -229,6 +229,11 @@ type NetworkConnector interface { CreateLogin(ctx context.Context, user *User, flowID string) (LoginProcess, error) } +type StoppableNetwork interface { + // Stop is called when the bridge is stopping, after all network clients have been disconnected. + Stop() +} + // DirectMediableNetwork is an optional interface that network connectors can implement to support direct media access. // // If the Matrix connector has direct media enabled, SetUseDirectMedia will be called From e2c698098862f680d95d8130439d4eeaa4656069 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 15 Oct 2024 00:24:13 +0300 Subject: [PATCH 0838/1647] hicli/send: add set typing method --- hicli/json-commands.go | 12 +++++++++++- hicli/send.go | 22 +++++++++++++++++----- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/hicli/json-commands.go b/hicli/json-commands.go index 0c65c0f2..c9dc89d2 100644 --- a/hicli/json-commands.go +++ b/hicli/json-commands.go @@ -11,6 +11,7 @@ import ( "encoding/json" "errors" "fmt" + "time" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" @@ -47,7 +48,11 @@ func (h *HiClient) handleJSONCommand(ctx context.Context, req *JSONCommand) (any }) case "mark_read": return unmarshalAndCall(req.Data, func(params *markReadParams) (bool, error) { - return h.MarkRead(ctx, params.RoomID, params.EventID, params.ReceiptType) + return true, h.MarkRead(ctx, params.RoomID, params.EventID, params.ReceiptType) + }) + case "set_typing": + return unmarshalAndCall(req.Data, func(params *setTypingParams) (bool, error) { + return true, h.SetTyping(ctx, params.RoomID, time.Duration(params.Timeout)*time.Millisecond) }) case "get_event": return unmarshalAndCall(req.Data, func(params *getEventParams) (*database.Event, error) { @@ -128,6 +133,11 @@ type markReadParams struct { ReceiptType event.ReceiptType `json:"receipt_type"` } +type setTypingParams struct { + RoomID id.RoomID `json:"room_id"` + Timeout int `json:"timeout"` +} + type getEventParams struct { RoomID id.RoomID `json:"room_id"` EventID id.EventID `json:"event_id"` diff --git a/hicli/send.go b/hicli/send.go index d732784b..3b05d494 100644 --- a/hicli/send.go +++ b/hicli/send.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" "strings" + "time" "github.com/rs/zerolog" "github.com/yuin/goldmark" @@ -60,7 +61,7 @@ func (h *HiClient) SendMessage(ctx context.Context, roomID id.RoomID, text, medi return h.Send(ctx, roomID, event.EventMessage, &content) } -func (h *HiClient) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, receiptType event.ReceiptType) (bool, error) { +func (h *HiClient) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, receiptType event.ReceiptType) error { content := &mautrix.ReqSetReadMarkers{ FullyRead: eventID, } @@ -69,13 +70,18 @@ func (h *HiClient) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.Ev } else if receiptType == event.ReceiptTypeReadPrivate { content.ReadPrivate = eventID } else { - return false, fmt.Errorf("invalid receipt type: %v", receiptType) + return fmt.Errorf("invalid receipt type: %v", receiptType) } err := h.Client.SetReadMarkers(ctx, roomID, content) if err != nil { - return false, fmt.Errorf("failed to mark event as read: %w", err) + return fmt.Errorf("failed to mark event as read: %w", err) } - return true, nil + return nil +} + +func (h *HiClient) SetTyping(ctx context.Context, roomID id.RoomID, timeout time.Duration) error { + _, err := h.Client.UserTyping(ctx, roomID, timeout > 0, timeout) + return err } func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (*database.Event, error) { @@ -131,8 +137,14 @@ func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Typ if err != nil { return nil, fmt.Errorf("failed to insert event into database: %w", err) } + ctx = context.WithoutCancel(ctx) + go func() { + err := h.SetTyping(ctx, roomID, 0) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to stop typing while sending message") + } + }() go func() { - ctx := context.WithoutCancel(ctx) var err error defer func() { h.EventHandler(&SendComplete{ From 68f1ff3e69f48b666f05f343d8568455de56e726 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 15 Oct 2024 01:21:13 +0300 Subject: [PATCH 0839/1647] hicli/sync: fix calculating room name if member event is not found --- hicli/sync.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/hicli/sync.go b/hicli/sync.go index 210cd1ac..f84e4168 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -628,6 +628,9 @@ func (h *HiClient) calculateRoomParticipantName(ctx context.Context, roomID id.R heroEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateMember, hero.String()) if err != nil { return "", fmt.Errorf("failed to get %s's member event: %w", hero, err) + } else if heroEvt == nil { + leftMembers = append(leftMembers, hero.String()) + continue } results := gjson.GetManyBytes(heroEvt.Content, "membership", "displayname") name := results[1].Str From 89f78e907dbe07e9d90a006d1eb319e7849cc3fe Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 15 Oct 2024 01:39:05 +0300 Subject: [PATCH 0840/1647] hicli: use user avatar as room avatar in DMs --- hicli/database/room.go | 23 ++++++++----- .../database/upgrades/00-latest-revision.sql | 3 +- .../upgrades/02-explicit-avatar-flag.sql | 2 ++ hicli/sync.go | 34 +++++++++++++------ 4 files changed, 41 insertions(+), 21 deletions(-) create mode 100644 hicli/database/upgrades/02-explicit-avatar-flag.sql diff --git a/hicli/database/room.go b/hicli/database/room.go index 5971788e..c20ed2c5 100644 --- a/hicli/database/room.go +++ b/hicli/database/room.go @@ -22,7 +22,7 @@ import ( const ( getRoomBaseQuery = ` - SELECT room_id, creation_content, name, name_quality, avatar, topic, canonical_alias, + SELECT room_id, creation_content, name, name_quality, avatar, explicit_avatar, topic, canonical_alias, lazy_load_summary, encryption_event, has_member_list, preview_event_rowid, sorting_timestamp, prev_batch FROM room @@ -39,14 +39,15 @@ const ( name = COALESCE($3, room.name), name_quality = CASE WHEN $3 IS NOT NULL THEN $4 ELSE room.name_quality END, avatar = COALESCE($5, room.avatar), - topic = COALESCE($6, room.topic), - canonical_alias = COALESCE($7, room.canonical_alias), - lazy_load_summary = COALESCE($8, room.lazy_load_summary), - encryption_event = COALESCE($9, room.encryption_event), - has_member_list = room.has_member_list OR $10, - preview_event_rowid = COALESCE($11, room.preview_event_rowid), - sorting_timestamp = COALESCE($12, room.sorting_timestamp), - prev_batch = COALESCE($13, room.prev_batch) + explicit_avatar = CASE WHEN $5 IS NOT NULL THEN $6 ELSE room.explicit_avatar END, + topic = COALESCE($7, room.topic), + canonical_alias = COALESCE($8, room.canonical_alias), + lazy_load_summary = COALESCE($9, room.lazy_load_summary), + encryption_event = COALESCE($10, room.encryption_event), + has_member_list = room.has_member_list OR $11, + preview_event_rowid = COALESCE($12, room.preview_event_rowid), + sorting_timestamp = COALESCE($13, room.sorting_timestamp), + prev_batch = COALESCE($14, room.prev_batch) WHERE room_id = $1 ` setRoomPrevBatchQuery = ` @@ -133,6 +134,7 @@ type Room struct { Name *string `json:"name,omitempty"` NameQuality NameQuality `json:"name_quality"` Avatar *id.ContentURI `json:"avatar,omitempty"` + ExplicitAvatar bool `json:"explicit_avatar"` Topic *string `json:"topic,omitempty"` CanonicalAlias *id.RoomAlias `json:"canonical_alias,omitempty"` @@ -155,6 +157,7 @@ func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) { } if r.Avatar != nil { other.Avatar = r.Avatar + other.ExplicitAvatar = r.ExplicitAvatar hasChanges = true } if r.Topic != nil { @@ -201,6 +204,7 @@ func (r *Room) Scan(row dbutil.Scannable) (*Room, error) { &r.Name, &r.NameQuality, &r.Avatar, + &r.ExplicitAvatar, &r.Topic, &r.CanonicalAlias, dbutil.JSON{Data: &r.LazyLoadSummary}, @@ -226,6 +230,7 @@ func (r *Room) sqlVariables() []any { r.Name, r.NameQuality, r.Avatar, + r.ExplicitAvatar, r.Topic, r.CanonicalAlias, dbutil.JSONPtr(r.LazyLoadSummary), diff --git a/hicli/database/upgrades/00-latest-revision.sql b/hicli/database/upgrades/00-latest-revision.sql index 04de7baf..f8c84a61 100644 --- a/hicli/database/upgrades/00-latest-revision.sql +++ b/hicli/database/upgrades/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v1: Latest revision +-- v0 -> v2 (compatible with v1+): Latest revision CREATE TABLE account ( user_id TEXT NOT NULL PRIMARY KEY, device_id TEXT NOT NULL, @@ -15,6 +15,7 @@ CREATE TABLE room ( name TEXT, name_quality INTEGER NOT NULL DEFAULT 0, avatar TEXT, + explicit_avatar INTEGER NOT NULL DEFAULT 0, topic TEXT, canonical_alias TEXT, lazy_load_summary TEXT, diff --git a/hicli/database/upgrades/02-explicit-avatar-flag.sql b/hicli/database/upgrades/02-explicit-avatar-flag.sql new file mode 100644 index 00000000..c11e8801 --- /dev/null +++ b/hicli/database/upgrades/02-explicit-avatar-flag.sql @@ -0,0 +1,2 @@ +-- v2 (compatible with v1+): Add explicit avatar flag to rooms +ALTER TABLE room ADD COLUMN explicit_avatar INTEGER NOT NULL DEFAULT 0; diff --git a/hicli/sync.go b/hicli/sync.go index f84e4168..ffce4dcd 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -556,12 +556,15 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R } // Calculate name from participants if participants changed and current name was generated from participants, or if the room name was unset if (heroesChanged && updatedRoom.NameQuality <= database.NameQualityParticipants) || updatedRoom.NameQuality == database.NameQualityNil { - name, err := h.calculateRoomParticipantName(ctx, room.ID, summary) + name, dmAvatarURL, err := h.calculateRoomParticipantName(ctx, room.ID, summary) if err != nil { return fmt.Errorf("failed to calculate room name: %w", err) } updatedRoom.Name = &name updatedRoom.NameQuality = database.NameQualityParticipants + if !dmAvatarURL.IsEmpty() && !room.ExplicitAvatar { + updatedRoom.Avatar = &dmAvatarURL + } } if timeline.PrevBatch != "" && (room.PrevBatch == "" || timeline.Limited) { updatedRoom.PrevBatch = timeline.PrevBatch @@ -595,14 +598,15 @@ func joinMemberNames(names []string, totalCount int) string { } } -func (h *HiClient) calculateRoomParticipantName(ctx context.Context, roomID id.RoomID, summary *mautrix.LazyLoadSummary) (string, error) { +func (h *HiClient) calculateRoomParticipantName(ctx context.Context, roomID id.RoomID, summary *mautrix.LazyLoadSummary) (string, id.ContentURI, error) { + var primaryAvatarURL id.ContentURI if summary == nil || len(summary.Heroes) == 0 { - return "Empty room", nil + return "Empty room", primaryAvatarURL, nil } var functionalMembers []id.UserID functionalMembersEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateElementFunctionalMembers, "") if err != nil { - return "", fmt.Errorf("failed to get %s event: %w", event.StateElementFunctionalMembers.Type, err) + return "", primaryAvatarURL, fmt.Errorf("failed to get %s event: %w", event.StateElementFunctionalMembers.Type, err) } else if functionalMembersEvt != nil { mautrixEvt := functionalMembersEvt.AsRawMautrix() _ = mautrixEvt.Content.ParseRaw(mautrixEvt.Type) @@ -627,28 +631,35 @@ func (h *HiClient) calculateRoomParticipantName(ctx context.Context, roomID id.R } heroEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateMember, hero.String()) if err != nil { - return "", fmt.Errorf("failed to get %s's member event: %w", hero, err) + return "", primaryAvatarURL, fmt.Errorf("failed to get %s's member event: %w", hero, err) } else if heroEvt == nil { leftMembers = append(leftMembers, hero.String()) continue } - results := gjson.GetManyBytes(heroEvt.Content, "membership", "displayname") - name := results[1].Str + membership := gjson.GetBytes(heroEvt.Content, "membership").Str + name := gjson.GetBytes(heroEvt.Content, "displayname").Str if name == "" { name = hero.String() } - if results[0].Str == "join" || results[0].Str == "invite" { + avatarURL := gjson.GetBytes(heroEvt.Content, "avatar_url").Str + if avatarURL != "" { + primaryAvatarURL = id.ContentURIString(avatarURL).ParseOrIgnore() + } + if membership == "join" || membership == "invite" { members = append(members, name) } else { leftMembers = append(leftMembers, name) } } + if len(members)+len(leftMembers) > 1 || !primaryAvatarURL.IsValid() { + primaryAvatarURL = id.ContentURI{} + } if len(members) > 0 { - return joinMemberNames(members, memberCount), nil + return joinMemberNames(members, memberCount), primaryAvatarURL, nil } else if len(leftMembers) > 0 { - return fmt.Sprintf("Empty room (was %s)", joinMemberNames(leftMembers, memberCount)), nil + return fmt.Sprintf("Empty room (was %s)", joinMemberNames(leftMembers, memberCount)), primaryAvatarURL, nil } else { - return "Empty room", nil + return "Empty room", primaryAvatarURL, nil } } @@ -721,6 +732,7 @@ func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomDa if ok { url, _ := content.URL.Parse() updatedRoom.Avatar = &url + updatedRoom.ExplicitAvatar = true } case event.StateTopic: content, ok := evt.Content.Parsed.(*event.TopicEventContent) From 948c9b0f399073d300113a3c1120f5fe12e613b9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 15 Oct 2024 01:52:00 +0300 Subject: [PATCH 0841/1647] hicli/sync: don't fail event parsing if it's already parsed --- hicli/send.go | 2 +- hicli/sync.go | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/hicli/send.go b/hicli/send.go index 3b05d494..76852dde 100644 --- a/hicli/send.go +++ b/hicli/send.go @@ -272,7 +272,7 @@ func (h *HiClient) shouldShareKeysToInvitedUsers(ctx context.Context, roomID id. } mautrixEvt := historyVisibility.AsRawMautrix() err = mautrixEvt.Content.ParseRaw(mautrixEvt.Type) - if err != nil { + if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { zerolog.Ctx(ctx).Err(err).Msg("Failed to parse history visibility event") return false } diff --git a/hicli/sync.go b/hicli/sync.go index ffce4dcd..a6da9517 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -40,7 +40,7 @@ func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.Res for _, evt := range resp.ToDevice.Events { evt.Type.Class = event.ToDeviceEventType err := evt.Content.ParseRaw(evt.Type) - if err != nil { + if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { log.Warn().Err(err). Stringer("event_type", &evt.Type). Stringer("sender", evt.Sender). @@ -106,7 +106,7 @@ func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSy } if evt.Type == event.AccountDataPushRules { err = evt.Content.ParseRaw(evt.Type) - if err != nil { + if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to parse push rules in sync") } else if pushRules, ok := evt.Content.Parsed.(*pushrules.EventContent); ok { h.PushRules.Store(pushRules.Ruleset) @@ -179,7 +179,7 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, for _, evt := range room.Ephemeral.Events { evt.Type.Class = event.EphemeralEventType err = evt.Content.ParseRaw(evt.Type) - if err != nil { + if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { zerolog.Ctx(ctx).Debug().Err(err).Msg("Failed to parse ephemeral event content") continue } @@ -683,7 +683,7 @@ func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomDa return } err := evt.Content.ParseRaw(evt.Type) - if err != nil { + if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { zerolog.Ctx(ctx).Warn().Err(err). Stringer("event_type", &evt.Type). Stringer("event_id", evt.ID). From 21eaeeaecf6987a234c265a1e54979c24fac6c9f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 15 Oct 2024 02:19:04 +0300 Subject: [PATCH 0842/1647] hicli/sync: always set sorting timestamp for new rooms --- hicli/database/room.go | 2 +- hicli/sync.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hicli/database/room.go b/hicli/database/room.go index c20ed2c5..d9293cf8 100644 --- a/hicli/database/room.go +++ b/hicli/database/room.go @@ -27,7 +27,7 @@ const ( preview_event_rowid, sorting_timestamp, prev_batch FROM room ` - getRoomsBySortingTimestampQuery = getRoomBaseQuery + `WHERE sorting_timestamp < $1 ORDER BY sorting_timestamp DESC LIMIT $2` + getRoomsBySortingTimestampQuery = getRoomBaseQuery + `WHERE sorting_timestamp < $1 AND sorting_timestamp > 0 ORDER BY sorting_timestamp DESC LIMIT $2` getRoomByIDQuery = getRoomBaseQuery + `WHERE room_id = $1` ensureRoomExistsQuery = ` INSERT INTO room (room_id) VALUES ($1) diff --git a/hicli/sync.go b/hicli/sync.go index a6da9517..16930b59 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -161,7 +161,7 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, if err != nil { return fmt.Errorf("failed to ensure room row exists: %w", err) } - existingRoomData = &database.Room{ID: roomID} + existingRoomData = &database.Room{ID: roomID, SortingTimestamp: jsontime.UnixMilliNow()} } for _, evt := range room.AccountData.Events { From df65202dacf09fdf8a98fdd6128d0942e4c95d44 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 15 Oct 2024 17:04:51 +0300 Subject: [PATCH 0843/1647] dependencies: update --- go.mod | 16 ++++++++-------- go.sum | 32 ++++++++++++++++---------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/go.mod b/go.mod index a1e97f8c..4b3c7723 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 github.com/lucasb-eyer/go-colorful v1.2.0 - github.com/mattn/go-sqlite3 v1.14.23 + github.com/mattn/go-sqlite3 v1.14.24 github.com/rivo/uniseg v0.4.7 github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.33.0 @@ -19,12 +19,12 @@ require ( github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.7.4 - go.mau.fi/util v0.8.1-0.20241003092848-3b49d3e0b9ee + github.com/yuin/goldmark v1.7.7 + go.mau.fi/util v0.8.1-0.20241015132414-c3f7e22b3de9 go.mau.fi/zeroconfig v0.1.3 - golang.org/x/crypto v0.27.0 - golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 - golang.org/x/net v0.29.0 + golang.org/x/crypto v0.28.0 + golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c + golang.org/x/net v0.30.0 golang.org/x/sync v0.8.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 @@ -39,7 +39,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.25.0 // indirect - golang.org/x/text v0.18.0 // indirect + golang.org/x/sys v0.26.0 // indirect + golang.org/x/text v0.19.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index b825326f..25601050 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0= -github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= +github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 h1:Dx7Ovyv/SFnMFw3fD4oEoeorXc6saIiQ23LrGLth0Gw= github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -53,28 +53,28 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.4 h1:BDXOHExt+A7gwPCJgPIIq7ENvceR7we7rOS9TNoLZeg= -github.com/yuin/goldmark v1.7.4/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.1-0.20241003092848-3b49d3e0b9ee h1:/BGpUK7fzVyFgy5KBiyP7ktEDn20vzz/5FTngrXtIEE= -go.mau.fi/util v0.8.1-0.20241003092848-3b49d3e0b9ee/go.mod h1:L9qnqEkhe4KpuYmILrdttKTXL79MwGLyJ4EOskWxO3I= +github.com/yuin/goldmark v1.7.7 h1:5m9rrB1sW3JUMToKFQfb+FGt1U7r57IHu5GrYrG2nqU= +github.com/yuin/goldmark v1.7.7/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +go.mau.fi/util v0.8.1-0.20241015132414-c3f7e22b3de9 h1:DnZ0keW636LpkkQKA1LQilYglEjNbxwXOnsJw0fuNIo= +go.mau.fi/util v0.8.1-0.20241015132414-c3f7e22b3de9/go.mod h1:T1u/rD2rzidVrBLyaUdPpZiJdP/rsyi+aTzn0D+Q6wc= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= -golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= -golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= -golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= -golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= -golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY= +golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8= +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= -golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= -golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From dc697ecd64ae73b1a3128a630818be745d3791e0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 15 Oct 2024 20:36:12 +0300 Subject: [PATCH 0844/1647] bridgev2/portal: include receiver in deterministic room IDs --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 016d3693..aa263098 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3530,7 +3530,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo Preset: "private_chat", IsDirect: portal.RoomType == database.RoomTypeDM, PowerLevelOverride: powerLevels, - BeeperLocalRoomID: id.RoomID(fmt.Sprintf("!%s:%s", portal.ID, portal.Bridge.Matrix.ServerName())), + BeeperLocalRoomID: id.RoomID(fmt.Sprintf("!%s.%s:%s", portal.ID, portal.Receiver, portal.Bridge.Matrix.ServerName())), } autoJoinInvites := portal.Bridge.Matrix.GetCapabilities().AutoJoinInvites if autoJoinInvites { From cc4170475b4e82e401340d88a49e5b3f0e26a4c7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 16 Oct 2024 10:31:14 +0300 Subject: [PATCH 0845/1647] Bump version to v0.21.1 --- CHANGELOG.md | 6 +++++- go.mod | 2 +- go.sum | 4 ++-- version.go | 2 +- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e821fc0..e1904312 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,13 @@ -## unreleased +## v0.21.1 (2024-10-16) +* *(bridgev2)* Added more features and fixed bugs. +* *(hicli)* Added more features and fixed bugs. * *(appservice)* Removed TLS support. A reverse proxy should be used if TLS is needed. * *(format/mdext)* Added goldmark extension to fix indented paragraphs when disabling indented code block parser. +* *(event)* Added `Has` method for `Mentions`. +* *(event)* Added basic support for the unstable version of polls. ## v0.21.0 (2024-09-16) diff --git a/go.mod b/go.mod index 4b3c7723..f45b8990 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.7 - go.mau.fi/util v0.8.1-0.20241015132414-c3f7e22b3de9 + go.mau.fi/util v0.8.1 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.28.0 golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c diff --git a/go.sum b/go.sum index 25601050..e7a58076 100644 --- a/go.sum +++ b/go.sum @@ -55,8 +55,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.7 h1:5m9rrB1sW3JUMToKFQfb+FGt1U7r57IHu5GrYrG2nqU= github.com/yuin/goldmark v1.7.7/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.1-0.20241015132414-c3f7e22b3de9 h1:DnZ0keW636LpkkQKA1LQilYglEjNbxwXOnsJw0fuNIo= -go.mau.fi/util v0.8.1-0.20241015132414-c3f7e22b3de9/go.mod h1:T1u/rD2rzidVrBLyaUdPpZiJdP/rsyi+aTzn0D+Q6wc= +go.mau.fi/util v0.8.1 h1:Ga43cz6esQBYqcjZ/onRoVnYWoUwjWbsxVeJg2jOTSo= +go.mau.fi/util v0.8.1/go.mod h1:T1u/rD2rzidVrBLyaUdPpZiJdP/rsyi+aTzn0D+Q6wc= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= diff --git a/version.go b/version.go index 80b96661..29368573 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.21.0" +const Version = "v0.21.1" var GoModVersion = "" var Commit = "" From eed7bc66a05b53fe50d7e8f035fc9d5549930195 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 16 Oct 2024 17:32:04 +0300 Subject: [PATCH 0846/1647] ci: use pre-commit action instead of running manually --- .github/workflows/go.yml | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 9117286f..10025368 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -26,11 +26,8 @@ jobs: go install golang.org/x/tools/cmd/goimports@latest export PATH="$HOME/go/bin:$PATH" - - name: Install pre-commit - run: pip install pre-commit - - - name: Lint - run: pre-commit run -a + - name: Run pre-commit + uses: pre-commit/action@v3.0.1 build: runs-on: ubuntu-latest From 1d4c2d2554551363ddbc5ebd3a7f4edc48e80031 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 16 Oct 2024 17:33:40 +0300 Subject: [PATCH 0847/1647] hicli/database: ignore duplicate timeline inserts --- hicli/database/timeline.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/hicli/database/timeline.go b/hicli/database/timeline.go index 0a01c7f5..ddebd793 100644 --- a/hicli/database/timeline.go +++ b/hicli/database/timeline.go @@ -22,7 +22,9 @@ const ( DELETE FROM timeline WHERE room_id = $1 ` appendTimelineQuery = ` - INSERT INTO timeline (room_id, event_rowid) VALUES ($1, $2) RETURNING rowid, event_rowid + INSERT INTO timeline (room_id, event_rowid) VALUES ($1, $2) + ON CONFLICT DO NOTHING + RETURNING rowid, event_rowid ` prependTimelineQuery = ` INSERT INTO timeline (room_id, rowid, event_rowid) VALUES ($1, $2, $3) From 758e80a5f09ad84ab0cea66583d73746ebdd41b8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 17 Oct 2024 00:21:53 +0300 Subject: [PATCH 0848/1647] hicli: add html sanitization and push rule evaluation --- go.mod | 1 + go.sum | 2 + hicli/database/event.go | 74 ++- hicli/database/room.go | 34 +- hicli/database/state.go | 6 +- hicli/database/timeline.go | 6 +- .../database/upgrades/00-latest-revision.sql | 37 +- .../upgrades/03-more-event-fields.sql | 6 + hicli/decryptionqueue.go | 4 +- hicli/events.go | 16 +- hicli/html.go | 476 ++++++++++++++++++ hicli/paginate.go | 6 +- hicli/pushrules.go | 80 +++ hicli/send.go | 2 +- hicli/sync.go | 141 +++++- pushrules/ruleset.go | 3 + 16 files changed, 823 insertions(+), 71 deletions(-) create mode 100644 hicli/database/upgrades/03-more-event-fields.sql create mode 100644 hicli/html.go create mode 100644 hicli/pushrules.go diff --git a/go.mod b/go.mod index f45b8990..ad6dbdc5 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( golang.org/x/sync v0.8.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 + mvdan.cc/xurls/v2 v2.5.0 ) require ( diff --git a/go.sum b/go.sum index e7a58076..955cbb91 100644 --- a/go.sum +++ b/go.sum @@ -83,3 +83,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= +mvdan.cc/xurls/v2 v2.5.0 h1:lyBNOm8Wo71UknhUs4QTFUNNMyxy2JEIaKKo0RWOh+8= +mvdan.cc/xurls/v2 v2.5.0/go.mod h1:yQgaGQ1rFtJUzkmKiHYSSfuQxqfYmd//X6PxvholpeE= diff --git a/hicli/database/event.go b/hicli/database/event.go index 0c55d84c..b0f64eb3 100644 --- a/hicli/database/event.go +++ b/hicli/database/event.go @@ -25,9 +25,10 @@ import ( const ( getEventBaseQuery = ` - SELECT rowid, -1, room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, - transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, - reactions, last_edit_rowid + SELECT rowid, -1, + room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, + unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type, + megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type FROM event ` getEventByRowID = getEventBaseQuery + `WHERE rowid = $1` @@ -36,10 +37,11 @@ const ( getFailedEventsByMegolmSessionID = getEventBaseQuery + `WHERE room_id = $1 AND megolm_session_id = $2 AND decryption_error IS NOT NULL` insertEventBaseQuery = ` INSERT INTO event ( - room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, - transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error + room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, + unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type, + megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21) ` insertEventQuery = insertEventBaseQuery + `RETURNING rowid` upsertEventQuery = insertEventBaseQuery + ` @@ -50,7 +52,8 @@ const ( decryption_error=CASE WHEN COALESCE(event.decrypted, excluded.decrypted) IS NULL THEN COALESCE(excluded.decryption_error, event.decryption_error) END, send_error=excluded.send_error, timestamp=excluded.timestamp, - unsigned=COALESCE(excluded.unsigned, event.unsigned) + unsigned=COALESCE(excluded.unsigned, event.unsigned), + local_content=COALESCE(excluded.local_content, event.local_content) ON CONFLICT (transaction_id) DO UPDATE SET event_id=excluded.event_id, timestamp=excluded.timestamp, @@ -59,7 +62,7 @@ const ( ` updateEventSendErrorQuery = `UPDATE event SET send_error = $2 WHERE rowid = $1` updateEventIDQuery = `UPDATE event SET event_id = $2, send_error = NULL WHERE rowid=$1` - updateEventDecryptedQuery = `UPDATE event SET decrypted = $1, decrypted_type = $2, decryption_error = NULL WHERE rowid = $3` + updateEventDecryptedQuery = `UPDATE event SET decrypted = $2, decrypted_type = $3, decryption_error = NULL, unread_type = $4, local_content = $5 WHERE rowid = $1` getEventReactionsQuery = getEventBaseQuery + ` WHERE room_id = ? AND type = 'm.reaction' @@ -131,8 +134,16 @@ func (eq *EventQuery) UpdateSendError(ctx context.Context, rowID EventRowID, sen return eq.Exec(ctx, updateEventSendErrorQuery, rowID, sendError) } -func (eq *EventQuery) UpdateDecrypted(ctx context.Context, rowID EventRowID, decrypted json.RawMessage, decryptedType string) error { - return eq.Exec(ctx, updateEventDecryptedQuery, unsafeJSONString(decrypted), decryptedType, rowID) +func (eq *EventQuery) UpdateDecrypted(ctx context.Context, evt *Event) error { + return eq.Exec( + ctx, + updateEventDecryptedQuery, + evt.RowID, + unsafeJSONString(evt.Decrypted), + evt.DecryptedType, + evt.UnreadType, + dbutil.JSONPtr(evt.LocalContent), + ) } func (eq *EventQuery) FillReactionCounts(ctx context.Context, roomID id.RoomID, events []*Event) error { @@ -264,6 +275,24 @@ func (m EventRowID) GetMassInsertValues() [1]any { return [1]any{m} } +type LocalContent struct { + SanitizedHTML string `json:"sanitized_html,omitempty"` +} + +type UnreadType int + +func (ut UnreadType) Is(flag UnreadType) bool { + return ut&flag != 0 +} + +const ( + UnreadTypeNone UnreadType = 0b0000 + UnreadTypeNormal UnreadType = 0b0001 + UnreadTypeNotify UnreadType = 0b0010 + UnreadTypeHighlight UnreadType = 0b0100 + UnreadTypeSound UnreadType = 0b1000 +) + type Event struct { RowID EventRowID `json:"rowid"` TimelineRowID TimelineRowID `json:"timeline_rowid"` @@ -279,6 +308,7 @@ type Event struct { Decrypted json.RawMessage `json:"decrypted,omitempty"` DecryptedType string `json:"decrypted_type,omitempty"` Unsigned json.RawMessage `json:"unsigned,omitempty"` + LocalContent *LocalContent `json:"local_content,omitempty"` TransactionID string `json:"transaction_id,omitempty"` @@ -292,6 +322,7 @@ type Event struct { Reactions map[string]int `json:"reactions,omitempty"` LastEditRowID *EventRowID `json:"last_edit_rowid,omitempty"` + UnreadType UnreadType `json:"unread_type,omitempty"` } func MautrixToEvent(evt *event.Event) *Event { @@ -318,6 +349,9 @@ func MautrixToEvent(evt *event.Event) *Event { } func (e *Event) AsRawMautrix() *event.Event { + if e == nil { + return nil + } evt := &event.Event{ RoomID: e.RoomID, ID: e.ID, @@ -355,6 +389,7 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { (*[]byte)(&e.Decrypted), &decryptedType, (*[]byte)(&e.Unsigned), + dbutil.JSON{Data: &e.LocalContent}, &transactionID, &redactedBy, &relatesTo, @@ -364,6 +399,7 @@ func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { &sendError, dbutil.JSON{Data: &e.Reactions}, &e.LastEditRowID, + &e.UnreadType, ) if err != nil { return nil, err @@ -425,6 +461,7 @@ func (e *Event) sqlVariables() []any { unsafeJSONString(e.Decrypted), dbutil.StrPtr(e.DecryptedType), unsafeJSONString(e.Unsigned), + dbutil.JSONPtr(e.LocalContent), dbutil.StrPtr(e.TransactionID), dbutil.StrPtr(e.RedactedBy), dbutil.StrPtr(e.RelatesTo), @@ -434,9 +471,26 @@ func (e *Event) sqlVariables() []any { dbutil.StrPtr(e.SendError), dbutil.JSON{Data: reactions}, e.LastEditRowID, + e.UnreadType, } } +func (e *Event) GetNonPushUnreadType() UnreadType { + if e.RelationType == event.RelReplace { + return UnreadTypeNone + } + switch e.Type { + case event.EventMessage.Type, event.EventSticker.Type: + return UnreadTypeNormal + case event.EventEncrypted.Type: + switch e.DecryptedType { + case event.EventMessage.Type, event.EventSticker.Type: + return UnreadTypeNormal + } + } + return UnreadTypeNone +} + func (e *Event) CanUseForPreview() bool { return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || (e.Type == event.EventEncrypted.Type && diff --git a/hicli/database/room.go b/hicli/database/room.go index d9293cf8..42108022 100644 --- a/hicli/database/room.go +++ b/hicli/database/room.go @@ -23,8 +23,8 @@ import ( const ( getRoomBaseQuery = ` SELECT room_id, creation_content, name, name_quality, avatar, explicit_avatar, topic, canonical_alias, - lazy_load_summary, encryption_event, has_member_list, - preview_event_rowid, sorting_timestamp, prev_batch + lazy_load_summary, encryption_event, has_member_list, preview_event_rowid, sorting_timestamp, + unread_highlights, unread_notifications, unread_messages, prev_batch FROM room ` getRoomsBySortingTimestampQuery = getRoomBaseQuery + `WHERE sorting_timestamp < $1 AND sorting_timestamp > 0 ORDER BY sorting_timestamp DESC LIMIT $2` @@ -47,7 +47,10 @@ const ( has_member_list = room.has_member_list OR $11, preview_event_rowid = COALESCE($12, room.preview_event_rowid), sorting_timestamp = COALESCE($13, room.sorting_timestamp), - prev_batch = COALESCE($14, room.prev_batch) + unread_highlights = COALESCE($14, room.unread_highlights), + unread_notifications = COALESCE($15, room.unread_notifications), + unread_messages = COALESCE($16, room.unread_messages), + prev_batch = COALESCE($17, room.prev_batch) WHERE room_id = $1 ` setRoomPrevBatchQuery = ` @@ -143,8 +146,11 @@ type Room struct { EncryptionEvent *event.EncryptionEventContent `json:"encryption_event,omitempty"` HasMemberList bool `json:"has_member_list"` - PreviewEventRowID EventRowID `json:"preview_event_rowid"` - SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"` + PreviewEventRowID EventRowID `json:"preview_event_rowid"` + SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"` + UnreadHighlights int `json:"unread_highlights"` + UnreadNotifications int `json:"unread_notifications"` + UnreadMessages int `json:"unread_messages"` PrevBatch string `json:"prev_batch"` } @@ -188,6 +194,18 @@ func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) { other.SortingTimestamp = r.SortingTimestamp hasChanges = true } + if r.UnreadHighlights != other.UnreadHighlights { + other.UnreadHighlights = r.UnreadHighlights + hasChanges = true + } + if r.UnreadNotifications != other.UnreadNotifications { + other.UnreadNotifications = r.UnreadNotifications + hasChanges = true + } + if r.UnreadMessages != other.UnreadMessages { + other.UnreadMessages = r.UnreadMessages + hasChanges = true + } if r.PrevBatch != "" && other.PrevBatch == "" { other.PrevBatch = r.PrevBatch hasChanges = true @@ -212,6 +230,9 @@ func (r *Room) Scan(row dbutil.Scannable) (*Room, error) { &r.HasMemberList, &previewEventRowID, &sortingTimestamp, + &r.UnreadHighlights, + &r.UnreadNotifications, + &r.UnreadMessages, &prevBatch, ) if err != nil { @@ -238,6 +259,9 @@ func (r *Room) sqlVariables() []any { r.HasMemberList, dbutil.NumPtr(r.PreviewEventRowID), dbutil.UnixMilliPtr(r.SortingTimestamp.Time), + r.UnreadHighlights, + r.UnreadNotifications, + r.UnreadMessages, dbutil.StrPtr(r.PrevBatch), } } diff --git a/hicli/database/state.go b/hicli/database/state.go index c12f9f60..d6fbf53d 100644 --- a/hicli/database/state.go +++ b/hicli/database/state.go @@ -30,8 +30,10 @@ const ( DELETE FROM current_state WHERE room_id = $1 ` getCurrentRoomStateQuery = ` - SELECT event.rowid, -1, event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type, unsigned, - transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid + SELECT event.rowid, -1, + event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type, + unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type, + megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type FROM current_state cs JOIN event ON cs.event_rowid = event.rowid WHERE cs.room_id = $1 diff --git a/hicli/database/timeline.go b/hicli/database/timeline.go index ddebd793..e04eeb88 100644 --- a/hicli/database/timeline.go +++ b/hicli/database/timeline.go @@ -34,8 +34,10 @@ const ( ` findMinRowIDQuery = `SELECT MIN(rowid) FROM timeline` getTimelineQuery = ` - SELECT event.rowid, timeline.rowid, event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, unsigned, - transaction_id, redacted_by, relates_to, relation_type, megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid + SELECT event.rowid, timeline.rowid, + event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, + unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type, + megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type FROM timeline JOIN event ON event.rowid = timeline.event_rowid WHERE timeline.room_id = $1 AND ($2 = 0 OR timeline.rowid < $2) diff --git a/hicli/database/upgrades/00-latest-revision.sql b/hicli/database/upgrades/00-latest-revision.sql index f8c84a61..0808a6e9 100644 --- a/hicli/database/upgrades/00-latest-revision.sql +++ b/hicli/database/upgrades/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v2 (compatible with v1+): Latest revision +-- v0 -> v3 (compatible with v1+): Latest revision CREATE TABLE account ( user_id TEXT NOT NULL PRIMARY KEY, device_id TEXT NOT NULL, @@ -9,29 +9,34 @@ CREATE TABLE account ( ) STRICT; CREATE TABLE room ( - room_id TEXT NOT NULL PRIMARY KEY, - creation_content TEXT, + room_id TEXT NOT NULL PRIMARY KEY, + creation_content TEXT, - name TEXT, - name_quality INTEGER NOT NULL DEFAULT 0, - avatar TEXT, - explicit_avatar INTEGER NOT NULL DEFAULT 0, - topic TEXT, - canonical_alias TEXT, - lazy_load_summary TEXT, + name TEXT, + name_quality INTEGER NOT NULL DEFAULT 0, + avatar TEXT, + explicit_avatar INTEGER NOT NULL DEFAULT 0, + topic TEXT, + canonical_alias TEXT, + lazy_load_summary TEXT, - encryption_event TEXT, - has_member_list INTEGER NOT NULL DEFAULT false, + encryption_event TEXT, + has_member_list INTEGER NOT NULL DEFAULT false, - preview_event_rowid INTEGER, - sorting_timestamp INTEGER, + preview_event_rowid INTEGER, + sorting_timestamp INTEGER, + unread_highlights INTEGER NOT NULL DEFAULT 0, + unread_notifications INTEGER NOT NULL DEFAULT 0, + unread_messages INTEGER NOT NULL DEFAULT 0, - prev_batch TEXT, + prev_batch TEXT, CONSTRAINT room_preview_event_fkey FOREIGN KEY (preview_event_rowid) REFERENCES event (rowid) ON DELETE SET NULL ) STRICT; CREATE INDEX room_type_idx ON room (creation_content ->> 'type'); CREATE INDEX room_sorting_timestamp_idx ON room (sorting_timestamp DESC); +-- CREATE INDEX room_sorting_timestamp_idx ON room (unread_notifications > 0); +-- CREATE INDEX room_sorting_timestamp_idx ON room (unread_messages > 0); CREATE TABLE account_data ( user_id TEXT NOT NULL, @@ -66,6 +71,7 @@ CREATE TABLE event ( decrypted TEXT, decrypted_type TEXT, unsigned TEXT NOT NULL, + local_content TEXT, transaction_id TEXT, @@ -79,6 +85,7 @@ CREATE TABLE event ( reactions TEXT, last_edit_rowid INTEGER, + unread_type INTEGER NOT NULL DEFAULT 0, CONSTRAINT event_id_unique_key UNIQUE (event_id), CONSTRAINT transaction_id_unique_key UNIQUE (transaction_id), diff --git a/hicli/database/upgrades/03-more-event-fields.sql b/hicli/database/upgrades/03-more-event-fields.sql new file mode 100644 index 00000000..3e07ad75 --- /dev/null +++ b/hicli/database/upgrades/03-more-event-fields.sql @@ -0,0 +1,6 @@ +-- v3 (compatible with v1+): Add more fields to events +ALTER TABLE event ADD COLUMN local_content TEXT; +ALTER TABLE event ADD COLUMN unread_type INTEGER NOT NULL DEFAULT 0; +ALTER TABLE room ADD COLUMN unread_highlights INTEGER NOT NULL DEFAULT 0; +ALTER TABLE room ADD COLUMN unread_notifications INTEGER NOT NULL DEFAULT 0; +ALTER TABLE room ADD COLUMN unread_messages INTEGER NOT NULL DEFAULT 0; diff --git a/hicli/decryptionqueue.go b/hicli/decryptionqueue.go index 87b6b8b2..665ee78a 100644 --- a/hicli/decryptionqueue.go +++ b/hicli/decryptionqueue.go @@ -59,14 +59,14 @@ func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.Ro log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session") } else { decrypted = append(decrypted, evt) - h.cacheMedia(ctx, mautrixEvt, evt.RowID) + h.postDecryptProcess(ctx, nil, evt, mautrixEvt) } } if len(decrypted) > 0 { var newPreview database.EventRowID err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { for _, evt := range decrypted { - err = h.DB.Event.UpdateDecrypted(ctx, evt.RowID, evt.Decrypted, evt.DecryptedType) + err = h.DB.Event.UpdateDecrypted(ctx, evt) if err != nil { return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, err) } diff --git a/hicli/events.go b/hicli/events.go index b96fd266..e730475b 100644 --- a/hicli/events.go +++ b/hicli/events.go @@ -13,11 +13,17 @@ import ( ) type SyncRoom struct { - Meta *database.Room `json:"meta"` - Timeline []database.TimelineRowTuple `json:"timeline"` - State map[event.Type]map[string]database.EventRowID `json:"state"` - Events []*database.Event `json:"events"` - Reset bool `json:"reset"` + Meta *database.Room `json:"meta"` + Timeline []database.TimelineRowTuple `json:"timeline"` + State map[event.Type]map[string]database.EventRowID `json:"state"` + Events []*database.Event `json:"events"` + Reset bool `json:"reset"` + Notifications []SyncNotification `json:"notifications"` +} + +type SyncNotification struct { + RowID database.EventRowID `json:"event_rowid"` + Sound bool `json:"sound"` } type SyncComplete struct { diff --git a/hicli/html.go b/hicli/html.go new file mode 100644 index 00000000..b0ad824d --- /dev/null +++ b/hicli/html.go @@ -0,0 +1,476 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "bytes" + "errors" + "fmt" + "io" + "net/url" + "regexp" + "slices" + "strconv" + "strings" + + "golang.org/x/net/html" + "golang.org/x/net/html/atom" + "mvdan.cc/xurls/v2" + + "maunium.net/go/mautrix/id" +) + +func tagIsAllowed(tag atom.Atom) bool { + switch tag { + case atom.Del, atom.H1, atom.H2, atom.H3, atom.H4, atom.H5, atom.H6, atom.Blockquote, atom.P, + atom.A, atom.Ul, atom.Ol, atom.Sup, atom.Sub, atom.Li, atom.B, atom.I, atom.U, atom.Strong, + atom.Em, atom.S, atom.Code, atom.Hr, atom.Br, atom.Div, atom.Table, atom.Thead, atom.Tbody, + atom.Tr, atom.Th, atom.Td, atom.Caption, atom.Pre, atom.Span, atom.Font, atom.Img, + atom.Details, atom.Summary: + return true + default: + return false + } +} + +func isSelfClosing(tag atom.Atom) bool { + switch tag { + case atom.Img, atom.Br, atom.Hr: + return true + default: + return false + } +} + +var languageRegex = regexp.MustCompile(`^language-[a-zA-Z0-9-]+$`) +var allowedColorRegex = regexp.MustCompile(`^#[0-9a-fA-F]{6}$`) + +// This is approximately a mirror of web/src/util/mediasize.ts in gomuks +func calculateMediaSize(widthInt, heightInt int) (width, height float64, ok bool) { + if widthInt <= 0 || heightInt <= 0 { + return + } + width = float64(widthInt) + height = float64(heightInt) + const imageContainerWidth float64 = 320 + const imageContainerHeight float64 = 240 + const imageContainerAspectRatio = imageContainerWidth / imageContainerHeight + if width > imageContainerWidth || height > imageContainerHeight { + aspectRatio := width / height + if aspectRatio > imageContainerAspectRatio { + width = imageContainerWidth + height = imageContainerWidth / aspectRatio + } else if aspectRatio < imageContainerAspectRatio { + width = imageContainerHeight * aspectRatio + height = imageContainerHeight + } else { + width = imageContainerWidth + height = imageContainerHeight + } + } + ok = true + return +} + +func parseImgAttributes(attrs []html.Attribute) (src, alt, title string, isCustomEmoji bool, width, height int) { + for _, attr := range attrs { + switch attr.Key { + case "src": + src = attr.Val + case "alt": + alt = attr.Val + case "title": + title = attr.Val + case "data-mx-emoticon": + isCustomEmoji = true + case "width": + width, _ = strconv.Atoi(attr.Val) + case "height": + height, _ = strconv.Atoi(attr.Val) + } + } + return +} + +func parseSpanAttributes(attrs []html.Attribute) (bgColor, textColor, spoiler, maths string, isSpoiler bool) { + for _, attr := range attrs { + switch attr.Key { + case "data-mx-bg-color": + if allowedColorRegex.MatchString(attr.Val) { + bgColor = attr.Val + } + case "data-mx-color", "color": + if allowedColorRegex.MatchString(attr.Val) { + textColor = attr.Val + } + case "data-mx-spoiler": + spoiler = attr.Val + isSpoiler = true + case "data-mx-maths": + maths = attr.Val + } + } + return +} + +func parseAAttributes(attrs []html.Attribute) (href string) { + for _, attr := range attrs { + switch attr.Key { + case "href": + href = strings.TrimSpace(attr.Val) + } + } + return +} + +func attributeIsAllowed(tag atom.Atom, attr html.Attribute) bool { + switch tag { + case atom.Ol: + switch attr.Key { + case "start": + _, err := strconv.Atoi(attr.Val) + return err == nil + } + case atom.Code: + switch attr.Key { + case "class": + return languageRegex.MatchString(attr.Val) + } + case atom.Div: + switch attr.Key { + case "data-mx-maths": + return true + } + } + return false +} + +// Funny user IDs will just need to be linkified by the sender, no auto-linkification for them. +var plainUserOrAliasMentionRegex = regexp.MustCompile(`[@#][a-zA-Z0-9._=/+-]{0,254}:[a-zA-Z0-9.-]+(?:\d{1,5})?`) + +func getNextItem(items [][]int, minIndex int) (index, start, end int, ok bool) { + for i, item := range items { + if item[0] >= minIndex { + return i, item[0], item[1], true + } + } + return -1, -1, -1, false +} + +func writeMention(w *strings.Builder, mention []byte) { + w.WriteString(`') + writeEscapedBytes(w, mention) + w.WriteString("") +} + +func writeURL(w *strings.Builder, addr []byte) { + parsedURL, err := url.Parse(string(addr)) + if err != nil { + writeEscapedBytes(w, addr) + return + } + if parsedURL.Scheme == "" { + parsedURL.Scheme = "https" + } + w.WriteString(`') + writeEscapedBytes(w, addr) + w.WriteString("") +} + +func linkifyAndWriteBytes(w *strings.Builder, s []byte) { + mentions := plainUserOrAliasMentionRegex.FindAllIndex(s, -1) + urls := xurls.Relaxed().FindAllIndex(s, -1) + minIndex := 0 + for { + mentionIdx, nextMentionStart, nextMentionEnd, hasMention := getNextItem(mentions, minIndex) + urlIdx, nextURLStart, nextURLEnd, hasURL := getNextItem(urls, minIndex) + if hasMention && (!hasURL || nextMentionStart <= nextURLStart) { + writeEscapedBytes(w, s[minIndex:nextMentionStart]) + writeMention(w, s[nextMentionStart:nextMentionEnd]) + minIndex = nextMentionEnd + mentions = mentions[mentionIdx:] + } else if hasURL && (!hasMention || nextURLStart < nextMentionStart) { + writeEscapedBytes(w, s[minIndex:nextURLStart]) + writeURL(w, s[nextURLStart:nextURLEnd]) + minIndex = nextURLEnd + urls = urls[urlIdx:] + } else { + break + } + } + writeEscapedBytes(w, s[minIndex:]) +} + +const escapedChars = "&'<>\"\r" + +func writeEscapedBytes(w *strings.Builder, s []byte) { + i := bytes.IndexAny(s, escapedChars) + for i != -1 { + w.Write(s[:i]) + var esc string + switch s[i] { + case '&': + esc = "&" + case '\'': + // "'" is shorter than "'" and apos was not in HTML until HTML5. + esc = "'" + case '<': + esc = "<" + case '>': + esc = ">" + case '"': + // """ is shorter than """. + esc = """ + case '\r': + esc = " " + default: + panic("unrecognized escape character") + } + s = s[i+1:] + w.WriteString(esc) + i = bytes.IndexAny(s, escapedChars) + } + w.Write(s) +} + +func writeEscapedString(w *strings.Builder, s string) { + i := strings.IndexAny(s, escapedChars) + for i != -1 { + w.WriteString(s[:i]) + var esc string + switch s[i] { + case '&': + esc = "&" + case '\'': + // "'" is shorter than "'" and apos was not in HTML until HTML5. + esc = "'" + case '<': + esc = "<" + case '>': + esc = ">" + case '"': + // """ is shorter than """. + esc = """ + case '\r': + esc = " " + default: + panic("unrecognized escape character") + } + s = s[i+1:] + w.WriteString(esc) + i = strings.IndexAny(s, escapedChars) + } + w.WriteString(s) +} + +func writeAttribute(w *strings.Builder, key, value string) { + w.WriteByte(' ') + w.WriteString(key) + w.WriteString(`="`) + writeEscapedString(w, value) + w.WriteByte('"') +} + +func writeA(w *strings.Builder, attr []html.Attribute) { + w.WriteString("`) + w.WriteString(spoiler) + w.WriteString(" ") + } + w.WriteByte('<') + w.WriteString("span") + if isSpoiler { + writeAttribute(w, "class", "hicli-spoiler") + } + var style string + if bgColor != "" { + style += fmt.Sprintf("background-color: %s;", bgColor) + } + if textColor != "" { + style += fmt.Sprintf("color: %s;", textColor) + } + if style != "" { + writeAttribute(w, "style", style) + } +} + +type tagStack []atom.Atom + +func (ts *tagStack) contains(tags ...atom.Atom) bool { + for i := len(*ts) - 1; i >= 0; i-- { + for _, tag := range tags { + if (*ts)[i] == tag { + return true + } + } + } + return false +} + +func (ts *tagStack) push(tag atom.Atom) { + *ts = append(*ts, tag) +} + +func (ts *tagStack) pop(tag atom.Atom) bool { + if len(*ts) > 0 && (*ts)[len(*ts)-1] == tag { + *ts = (*ts)[:len(*ts)-1] + return true + } + return false +} + +func sanitizeAndLinkifyHTML(body string) (string, error) { + tz := html.NewTokenizer(strings.NewReader(body)) + var built strings.Builder + ts := make(tagStack, 2) +Loop: + for { + switch tz.Next() { + case html.ErrorToken: + err := tz.Err() + if errors.Is(err, io.EOF) { + break Loop + } + return "", err + case html.StartTagToken, html.SelfClosingTagToken: + token := tz.Token() + if !tagIsAllowed(token.DataAtom) { + continue + } + tagIsSelfClosing := isSelfClosing(token.DataAtom) + if token.Type == html.SelfClosingTagToken && !tagIsSelfClosing { + continue + } + switch token.DataAtom { + case atom.A: + writeA(&built, token.Attr) + case atom.Img: + writeImg(&built, token.Attr) + case atom.Span, atom.Font: + writeSpan(&built, token.Attr) + default: + built.WriteByte('<') + built.WriteString(token.Data) + for _, attr := range token.Attr { + if attributeIsAllowed(token.DataAtom, attr) { + writeAttribute(&built, attr.Key, attr.Val) + } + } + } + built.WriteByte('>') + if !tagIsSelfClosing { + ts.push(token.DataAtom) + } + case html.EndTagToken: + tagName, _ := tz.TagName() + tag := atom.Lookup(tagName) + if tagIsAllowed(tag) && ts.pop(tag) { + built.WriteString("') + } + case html.TextToken: + if ts.contains(atom.Pre, atom.Code, atom.A) { + writeEscapedBytes(&built, tz.Text()) + } else { + linkifyAndWriteBytes(&built, tz.Text()) + } + case html.DoctypeToken, html.CommentToken: + // ignore + } + } + slices.Reverse(ts) + for _, t := range ts { + built.WriteString("') + } + return built.String(), nil +} diff --git a/hicli/paginate.go b/hicli/paginate.go index da927b9b..7fc50827 100644 --- a/hicli/paginate.go +++ b/hicli/paginate.go @@ -59,7 +59,7 @@ func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.Ev } else if serverEvt, err := h.Client.GetEvent(ctx, roomID, eventID); err != nil { return nil, fmt.Errorf("failed to get event from server: %w", err) } else { - return h.processEvent(ctx, serverEvt, nil, false) + return h.processEvent(ctx, serverEvt, nil, nil, false) } } @@ -90,7 +90,7 @@ func (h *HiClient) GetRoomState(ctx context.Context, roomID id.RoomID, fetchMemb } entries := make([]*database.CurrentStateEntry, len(evts)) for i, evt := range evts { - dbEvt, err := h.processEvent(ctx, evt, nil, false) + dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, nil, false) if err != nil { return fmt.Errorf("failed to process event %s: %w", evt.ID, err) } @@ -186,7 +186,7 @@ func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit i decryptionQueue := make(map[id.SessionID]*database.SessionRequest) iOffset := 0 for i, evt := range resp.Chunk { - dbEvt, err := h.processEvent(ctx, evt, decryptionQueue, true) + dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, decryptionQueue, true) if err != nil { return err } else if exists, err := h.DB.Timeline.Has(ctx, roomID, dbEvt.RowID); err != nil { diff --git a/hicli/pushrules.go b/hicli/pushrules.go new file mode 100644 index 00000000..74c0e8e4 --- /dev/null +++ b/hicli/pushrules.go @@ -0,0 +1,80 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package hicli + +import ( + "context" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/hicli/database" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/pushrules" +) + +type pushRoom struct { + ctx context.Context + roomID id.RoomID + h *HiClient + ll *mautrix.LazyLoadSummary +} + +func (p *pushRoom) GetOwnDisplayname() string { + // TODO implement + return "" +} + +func (p *pushRoom) GetMemberCount() int { + if p.ll == nil { + room, err := p.h.DB.Room.Get(p.ctx, p.roomID) + if err != nil { + zerolog.Ctx(p.ctx).Err(err). + Stringer("room_id", p.roomID). + Msg("Failed to get room by ID in push rule evaluator") + } else if room != nil { + p.ll = room.LazyLoadSummary + } + } + if p.ll != nil && p.ll.JoinedMemberCount != nil { + return *p.ll.JoinedMemberCount + } + // TODO query db? + return 0 +} + +func (p *pushRoom) GetEvent(id id.EventID) *event.Event { + evt, err := p.h.DB.Event.GetByID(p.ctx, id) + if err != nil { + zerolog.Ctx(p.ctx).Err(err). + Stringer("event_id", id). + Msg("Failed to get event by ID in push rule evaluator") + } + return evt.AsRawMautrix() +} + +var _ pushrules.EventfulRoom = (*pushRoom)(nil) + +func (h *HiClient) evaluatePushRules(ctx context.Context, llSummary *mautrix.LazyLoadSummary, baseType database.UnreadType, evt *event.Event) database.UnreadType { + should := h.PushRules.Load().GetMatchingRule(&pushRoom{ + ctx: ctx, + roomID: evt.RoomID, + h: h, + ll: llSummary, + }, evt).GetActions().Should() + if should.Notify { + baseType |= database.UnreadTypeNotify + } + if should.Highlight { + baseType |= database.UnreadTypeHighlight + } + if should.PlaySound { + baseType |= database.UnreadTypeSound + } + return baseType +} diff --git a/hicli/send.go b/hicli/send.go index 76852dde..cdb8571b 100644 --- a/hicli/send.go +++ b/hicli/send.go @@ -218,7 +218,7 @@ func (h *HiClient) loadMembers(ctx context.Context, room *database.Room) error { err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { entries := make([]*database.CurrentStateEntry, len(resp.Chunk)) for i, evt := range resp.Chunk { - dbEvt, err := h.processEvent(ctx, evt, nil, true) + dbEvt, err := h.processEvent(ctx, evt, nil, nil, true) if err != nil { return err } diff --git a/hicli/sync.go b/hicli/sync.go index 16930b59..dcb33637 100644 --- a/hicli/sync.go +++ b/hicli/sync.go @@ -134,11 +134,15 @@ func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSy return nil } -func receiptsToList(content *event.ReceiptEventContent) []*database.Receipt { +func (h *HiClient) receiptsToList(content *event.ReceiptEventContent) ([]*database.Receipt, []id.EventID) { receiptList := make([]*database.Receipt, 0) + var newOwnReceipts []id.EventID for eventID, receipts := range *content { for receiptType, users := range receipts { for userID, receiptInfo := range users { + if userID == h.Account.UserID { + newOwnReceipts = append(newOwnReceipts, eventID) + } receiptList = append(receiptList, &database.Receipt{ UserID: userID, ReceiptType: receiptType, @@ -149,7 +153,12 @@ func receiptsToList(content *event.ReceiptEventContent) []*database.Receipt { } } } - return receiptList + return receiptList, newOwnReceipts +} + +type receiptsToSave struct { + roomID id.RoomID + receipts []*database.Receipt } func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error { @@ -172,10 +181,8 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err) } } - err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary) - if err != nil { - return err - } + var receipts []receiptsToSave + var newOwnReceipts []id.EventID for _, evt := range room.Ephemeral.Events { evt.Type.Class = event.EphemeralEventType err = evt.Content.ParseRaw(evt.Type) @@ -185,18 +192,24 @@ func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, } switch evt.Type { case event.EphemeralEventReceipt: - err = h.DB.Receipt.PutMany(ctx, roomID, receiptsToList(evt.Content.AsReceipt())...) - if err != nil { - return fmt.Errorf("failed to save receipts: %w", err) - } + var receiptsList []*database.Receipt + receiptsList, newOwnReceipts = h.receiptsToList(evt.Content.AsReceipt()) + receipts = append(receipts, receiptsToSave{roomID, receiptsList}) case event.EphemeralEventTyping: go h.EventHandler(&Typing{ RoomID: roomID, TypingEventContent: *evt.Content.AsTyping(), }) } - if evt.Type != event.EphemeralEventReceipt { - continue + } + err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, newOwnReceipts, room.UnreadNotifications) + if err != nil { + return err + } + for _, rs := range receipts { + err = h.DB.Receipt.PutMany(ctx, rs.roomID, rs.receipts...) + if err != nil { + return fmt.Errorf("failed to save receipts: %w", err) } } return nil @@ -209,7 +222,8 @@ func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, ro } else if existingRoomData == nil { return nil } - return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary) + // TODO delete room instead of processing? + return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, nil, nil) } func isDecryptionErrorRetryable(err error) bool { @@ -318,7 +332,47 @@ func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID datab } } -func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptionQueue map[id.SessionID]*database.SessionRequest, checkDB bool) (*database.Event, error) { +func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) *database.LocalContent { + if evt.Type != event.EventMessage && evt.Type != event.EventSticker { + return nil + } + _ = evt.Content.ParseRaw(evt.Type) + content, ok := evt.Content.Parsed.(*event.MessageEventContent) + if !ok { + return nil + } + if dbEvt.RelationType == event.RelReplace && content.NewContent != nil { + content = content.NewContent + } + if content != nil { + var sanitizedHTML string + if content.Format == event.FormatHTML { + sanitizedHTML, _ = sanitizeAndLinkifyHTML(content.FormattedBody) + } else { + var builder strings.Builder + linkifyAndWriteBytes(&builder, []byte(content.Body)) + sanitizedHTML = builder.String() + } + return &database.LocalContent{SanitizedHTML: sanitizedHTML} + } + return nil +} + +func (h *HiClient) postDecryptProcess(ctx context.Context, llSummary *mautrix.LazyLoadSummary, dbEvt *database.Event, evt *event.Event) { + if dbEvt.RowID != 0 { + h.cacheMedia(ctx, evt, dbEvt.RowID) + } + dbEvt.UnreadType = h.evaluatePushRules(ctx, llSummary, dbEvt.GetNonPushUnreadType(), evt) + dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, evt) +} + +func (h *HiClient) processEvent( + ctx context.Context, + evt *event.Event, + llSummary *mautrix.LazyLoadSummary, + decryptionQueue map[id.SessionID]*database.SessionRequest, + checkDB bool, +) (*database.Event, error) { if checkDB { dbEvt, err := h.DB.Event.GetByID(ctx, evt.ID) if err != nil { @@ -350,6 +404,11 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio evt.Redacts = id.EventID(gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str) } } + if decryptedMautrixEvt != nil { + h.postDecryptProcess(ctx, llSummary, dbEvt, decryptedMautrixEvt) + } else { + h.postDecryptProcess(ctx, llSummary, dbEvt, evt) + } _, err := h.DB.Event.Upsert(ctx, dbEvt) if err != nil { return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err) @@ -386,12 +445,27 @@ func (h *HiClient) processEvent(ctx context.Context, evt *event.Event, decryptio return dbEvt, err } -func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.Room, state *mautrix.SyncEventsList, timeline *mautrix.SyncTimeline, summary *mautrix.LazyLoadSummary) error { +func (h *HiClient) processStateAndTimeline( + ctx context.Context, + room *database.Room, + state *mautrix.SyncEventsList, + timeline *mautrix.SyncTimeline, + summary *mautrix.LazyLoadSummary, + newOwnReceipts []id.EventID, + serverNotificationCounts *mautrix.UnreadNotificationCounts, +) error { updatedRoom := &database.Room{ ID: room.ID, - SortingTimestamp: room.SortingTimestamp, - NameQuality: room.NameQuality, + SortingTimestamp: room.SortingTimestamp, + NameQuality: room.NameQuality, + UnreadHighlights: room.UnreadHighlights, + UnreadNotifications: room.UnreadNotifications, + UnreadMessages: room.UnreadMessages, + } + if serverNotificationCounts != nil { + updatedRoom.UnreadHighlights = serverNotificationCounts.HighlightCount + updatedRoom.UnreadNotifications = serverNotificationCounts.NotificationCount } heroesChanged := false if summary.Heroes == nil && summary.JoinedMemberCount == nil && summary.InvitedMemberCount == nil { @@ -405,6 +479,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R } decryptionQueue := make(map[id.SessionID]*database.SessionRequest) allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events)) + newNotifications := make([]SyncNotification, 0) recalculatePreviewEvent := false addOldEvent := func(rowID database.EventRowID, evtID id.EventID) (dbEvt *database.Event, err error) { if rowID != 0 { @@ -440,12 +515,18 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R } return nil } - processNewEvent := func(evt *event.Event, isTimeline bool) (database.EventRowID, error) { + processNewEvent := func(evt *event.Event, isTimeline, isUnread bool) (database.EventRowID, error) { evt.RoomID = room.ID - dbEvt, err := h.processEvent(ctx, evt, decryptionQueue, false) + dbEvt, err := h.processEvent(ctx, evt, summary, decryptionQueue, false) if err != nil { return -1, err } + if isUnread && dbEvt.UnreadType.Is(database.UnreadTypeNotify) { + newNotifications = append(newNotifications, SyncNotification{ + RowID: dbEvt.RowID, + Sound: dbEvt.UnreadType.Is(database.UnreadTypeSound), + }) + } if isTimeline { if dbEvt.CanUseForPreview() { updatedRoom.PreviewEventRowID = dbEvt.RowID @@ -492,7 +573,7 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R } for _, evt := range state.Events { evt.Type.Class = event.StateEventType - rowID, err := processNewEvent(evt, false) + rowID, err := processNewEvent(evt, false, false) if err != nil { return err } @@ -502,13 +583,20 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R var err error if len(timeline.Events) > 0 { timelineIDs := make([]database.EventRowID, len(timeline.Events)) + readUpToIndex := -1 + for i := len(timeline.Events) - 1; i >= 0; i-- { + if slices.Contains(newOwnReceipts, timeline.Events[i].ID) { + readUpToIndex = i + break + } + } for i, evt := range timeline.Events { if evt.StateKey != nil { evt.Type.Class = event.StateEventType } else { evt.Type.Class = event.MessageEventType } - timelineIDs[i], err = processNewEvent(evt, true) + timelineIDs[i], err = processNewEvent(evt, true, i > readUpToIndex) if err != nil { return err } @@ -578,11 +666,12 @@ func (h *HiClient) processStateAndTimeline(ctx context.Context, room *database.R } if roomChanged || len(timelineRowTuples) > 0 || len(allNewEvents) > 0 { ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{ - Meta: room, - Timeline: timelineRowTuples, - State: changedState, - Reset: timeline.Limited, - Events: allNewEvents, + Meta: room, + Timeline: timelineRowTuples, + State: changedState, + Reset: timeline.Limited, + Events: allNewEvents, + Notifications: newNotifications, } } return nil diff --git a/pushrules/ruleset.go b/pushrules/ruleset.go index 609997b4..c42d4799 100644 --- a/pushrules/ruleset.go +++ b/pushrules/ruleset.go @@ -68,6 +68,9 @@ func (rs *PushRuleset) MarshalJSON() ([]byte, error) { var DefaultPushActions = PushActionArray{&PushAction{Action: ActionDontNotify}} func (rs *PushRuleset) GetMatchingRule(room Room, evt *event.Event) (rule *PushRule) { + if rs == nil { + return nil + } // Add push rule collections to array in priority order arrays := []PushRuleCollection{rs.Override, rs.Content, rs.Room, rs.Sender, rs.Underride} // Loop until one of the push rule collections matches the room/event combo. From 915167f459856aa4a1e2314800072883d130e7da Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 17 Oct 2024 13:47:11 +0300 Subject: [PATCH 0849/1647] bridgev2/commands: use PathUnescape instead of QueryUnescape for cookies `+` should not be decoded into a space --- bridgev2/commands/login.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index e5b3e50c..bf9cdf45 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -123,7 +123,7 @@ func checkLoginCommandDirectParams(ce *Event, login bridgev2.LoginProcess, nextS for i, param := range nextStep.CookiesParams.Fields { val := maybeURLDecodeCookie(ce.Args[i], ¶m) if match, _ := regexp.MatchString(param.Pattern, val); !match { - ce.Reply("Invalid value for %s: doesn't match regex `%s`", param.ID, param.Pattern) + ce.Reply("Invalid value for %s: `%s` doesn't match regex `%s`", param.ID, val, param.Pattern) return nil } input[param.ID] = val @@ -292,7 +292,7 @@ func (clcs *cookieLoginCommandState) submit(ce *Event) { } reqCookies := make(map[string]string) for _, cookie := range parsed.Cookies() { - reqCookies[cookie.Name], err = url.QueryUnescape(cookie.Value) + reqCookies[cookie.Name], err = url.PathUnescape(cookie.Value) if err != nil { ce.Reply("Failed to parse cookie %s: %v", cookie.Name, err) return @@ -365,7 +365,7 @@ func (clcs *cookieLoginCommandState) submit(ce *Event) { missingKeys = append(missingKeys, field.ID) } if match, _ := regexp.MatchString(field.Pattern, val); !match { - ce.Reply("Invalid value for %s: doesn't match regex `%s`", field.ID, field.Pattern) + ce.Reply("Invalid value for %s: `%s` doesn't match regex `%s`", field.ID, val, field.Pattern) return } } @@ -396,7 +396,7 @@ func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string { if !match { return val } - decoded, err := url.QueryUnescape(val) + decoded, err := url.PathUnescape(val) if err != nil { return val } From af360cd534faf911d6dfb6e26bbaa48983a554ae Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 17 Oct 2024 20:02:23 +0300 Subject: [PATCH 0850/1647] id: drop support for room alias + event ID links --- id/matrixuri.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/id/matrixuri.go b/id/matrixuri.go index acd8e0c0..2637d876 100644 --- a/id/matrixuri.go +++ b/id/matrixuri.go @@ -213,7 +213,7 @@ func ProcessMatrixURI(uri *url.URL) (*MatrixURI, error) { parsed.MXID1 = parts[1] // Step 6: if the first part is a room and the URI has 4 segments, construct a second level identifier - if (parsed.Sigil1 == '!' || parsed.Sigil1 == '#') && len(parts) == 4 { + if parsed.Sigil1 == '!' && len(parts) == 4 { // a: find the sigil from the third segment switch parts[2] { case "e", "event": From 32e9a2f6e31642f93d9675706aa9c7310e348def Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 17 Oct 2024 20:04:09 +0300 Subject: [PATCH 0851/1647] hicli: move into gomuks --- go.mod | 1 - go.sum | 2 - hicli/cryptohelper.go | 65 -- hicli/database/account.go | 74 -- hicli/database/accountdata.go | 71 -- hicli/database/cachedmedia.go | 150 ---- hicli/database/database.go | 73 -- hicli/database/event.go | 504 ----------- hicli/database/receipt.go | 82 -- hicli/database/room.go | 279 ------ hicli/database/sessionrequest.go | 69 -- hicli/database/state.go | 95 -- hicli/database/statestore.go | 188 ---- hicli/database/timeline.go | 136 --- .../database/upgrades/00-latest-revision.sql | 255 ------ .../upgrades/02-explicit-avatar-flag.sql | 2 - .../upgrades/03-more-event-fields.sql | 6 - hicli/database/upgrades/upgrades.go | 22 - hicli/decryptionqueue.go | 209 ----- hicli/events.go | 59 -- hicli/hicli.go | 250 ------ hicli/hitest/hitest.go | 110 --- hicli/html.go | 476 ---------- hicli/json-commands.go | 178 ---- hicli/json.go | 119 --- hicli/login.go | 87 -- hicli/paginate.go | 240 ----- hicli/pushrules.go | 80 -- hicli/send.go | 287 ------ hicli/sync.go | 833 ------------------ hicli/syncwrap.go | 96 -- hicli/verify.go | 162 ---- 32 files changed, 5260 deletions(-) delete mode 100644 hicli/cryptohelper.go delete mode 100644 hicli/database/account.go delete mode 100644 hicli/database/accountdata.go delete mode 100644 hicli/database/cachedmedia.go delete mode 100644 hicli/database/database.go delete mode 100644 hicli/database/event.go delete mode 100644 hicli/database/receipt.go delete mode 100644 hicli/database/room.go delete mode 100644 hicli/database/sessionrequest.go delete mode 100644 hicli/database/state.go delete mode 100644 hicli/database/statestore.go delete mode 100644 hicli/database/timeline.go delete mode 100644 hicli/database/upgrades/00-latest-revision.sql delete mode 100644 hicli/database/upgrades/02-explicit-avatar-flag.sql delete mode 100644 hicli/database/upgrades/03-more-event-fields.sql delete mode 100644 hicli/database/upgrades/upgrades.go delete mode 100644 hicli/decryptionqueue.go delete mode 100644 hicli/events.go delete mode 100644 hicli/hicli.go delete mode 100644 hicli/hitest/hitest.go delete mode 100644 hicli/html.go delete mode 100644 hicli/json-commands.go delete mode 100644 hicli/json.go delete mode 100644 hicli/login.go delete mode 100644 hicli/paginate.go delete mode 100644 hicli/pushrules.go delete mode 100644 hicli/send.go delete mode 100644 hicli/sync.go delete mode 100644 hicli/syncwrap.go delete mode 100644 hicli/verify.go diff --git a/go.mod b/go.mod index ad6dbdc5..f45b8990 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,6 @@ require ( golang.org/x/sync v0.8.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 - mvdan.cc/xurls/v2 v2.5.0 ) require ( diff --git a/go.sum b/go.sum index 955cbb91..e7a58076 100644 --- a/go.sum +++ b/go.sum @@ -83,5 +83,3 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= -mvdan.cc/xurls/v2 v2.5.0 h1:lyBNOm8Wo71UknhUs4QTFUNNMyxy2JEIaKKo0RWOh+8= -mvdan.cc/xurls/v2 v2.5.0/go.mod h1:yQgaGQ1rFtJUzkmKiHYSSfuQxqfYmd//X6PxvholpeE= diff --git a/hicli/cryptohelper.go b/hicli/cryptohelper.go deleted file mode 100644 index 2a2e9626..00000000 --- a/hicli/cryptohelper.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package hicli - -import ( - "context" - "fmt" - "time" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -type hiCryptoHelper HiClient - -var _ mautrix.CryptoHelper = (*hiCryptoHelper)(nil) - -func (h *hiCryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (*event.EncryptedEventContent, error) { - roomMeta, err := h.DB.Room.Get(ctx, roomID) - if err != nil { - return nil, fmt.Errorf("failed to get room metadata: %w", err) - } else if roomMeta == nil { - return nil, fmt.Errorf("unknown room") - } - return (*HiClient)(h).Encrypt(ctx, roomMeta, evtType, content) -} - -func (h *hiCryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) { - return h.Crypto.DecryptMegolmEvent(ctx, evt) -} - -func (h *hiCryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { - return h.Crypto.WaitForSession(ctx, roomID, senderKey, sessionID, timeout) -} - -func (h *hiCryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { - err := h.Crypto.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{ - userID: {deviceID}, - h.Account.UserID: {"*"}, - }) - if err != nil { - zerolog.Ctx(ctx).Err(err). - Stringer("room_id", roomID). - Stringer("session_id", sessionID). - Stringer("user_id", userID). - Msg("Failed to send room key request") - } else { - zerolog.Ctx(ctx).Debug(). - Stringer("room_id", roomID). - Stringer("session_id", sessionID). - Stringer("user_id", userID). - Msg("Sent room key request") - } -} - -func (h *hiCryptoHelper) Init(ctx context.Context) error { - return nil -} diff --git a/hicli/database/account.go b/hicli/database/account.go deleted file mode 100644 index 1dde74fd..00000000 --- a/hicli/database/account.go +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package database - -import ( - "context" - "database/sql" - "errors" - - "go.mau.fi/util/dbutil" - - "maunium.net/go/mautrix/id" -) - -const ( - getAccountQuery = `SELECT user_id, device_id, access_token, homeserver_url, next_batch FROM account WHERE user_id = $1` - putNextBatchQuery = `UPDATE account SET next_batch = $1 WHERE user_id = $2` - upsertAccountQuery = ` - INSERT INTO account (user_id, device_id, access_token, homeserver_url, next_batch) - VALUES ($1, $2, $3, $4, $5) ON CONFLICT (user_id) - DO UPDATE SET device_id = excluded.device_id, - access_token = excluded.access_token, - homeserver_url = excluded.homeserver_url, - next_batch = excluded.next_batch - ` -) - -type AccountQuery struct { - *dbutil.QueryHelper[*Account] -} - -func (aq *AccountQuery) GetFirstUserID(ctx context.Context) (userID id.UserID, err error) { - var exists bool - if exists, err = aq.GetDB().TableExists(ctx, "account"); err != nil || !exists { - return - } - err = aq.GetDB().QueryRow(ctx, `SELECT user_id FROM account LIMIT 1`).Scan(&userID) - if errors.Is(err, sql.ErrNoRows) { - err = nil - } - return -} - -func (aq *AccountQuery) Get(ctx context.Context, userID id.UserID) (*Account, error) { - return aq.QueryOne(ctx, getAccountQuery, userID) -} - -func (aq *AccountQuery) PutNextBatch(ctx context.Context, userID id.UserID, nextBatch string) error { - return aq.Exec(ctx, putNextBatchQuery, nextBatch, userID) -} - -func (aq *AccountQuery) Put(ctx context.Context, account *Account) error { - return aq.Exec(ctx, upsertAccountQuery, account.sqlVariables()...) -} - -type Account struct { - UserID id.UserID - DeviceID id.DeviceID - AccessToken string - HomeserverURL string - NextBatch string -} - -func (a *Account) Scan(row dbutil.Scannable) (*Account, error) { - return dbutil.ValueOrErr(a, row.Scan(&a.UserID, &a.DeviceID, &a.AccessToken, &a.HomeserverURL, &a.NextBatch)) -} - -func (a *Account) sqlVariables() []any { - return []any{a.UserID, a.DeviceID, a.AccessToken, a.HomeserverURL, a.NextBatch} -} diff --git a/hicli/database/accountdata.go b/hicli/database/accountdata.go deleted file mode 100644 index 8723b595..00000000 --- a/hicli/database/accountdata.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package database - -import ( - "context" - "database/sql" - "encoding/json" - "unsafe" - - "go.mau.fi/util/dbutil" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -const ( - upsertAccountDataQuery = ` - INSERT INTO account_data (user_id, type, content) VALUES ($1, $2, $3) - ON CONFLICT (user_id, type) DO UPDATE SET content = excluded.content - ` - upsertRoomAccountDataQuery = ` - INSERT INTO room_account_data (user_id, room_id, type, content) VALUES ($1, $2, $3, $4) - ON CONFLICT (user_id, room_id, type) DO UPDATE SET content = excluded.content - ` -) - -type AccountDataQuery struct { - *dbutil.QueryHelper[*AccountData] -} - -func unsafeJSONString(content json.RawMessage) *string { - if content == nil { - return nil - } - str := unsafe.String(unsafe.SliceData(content), len(content)) - return &str -} - -func (adq *AccountDataQuery) Put(ctx context.Context, userID id.UserID, eventType event.Type, content json.RawMessage) error { - return adq.Exec(ctx, upsertAccountDataQuery, userID, eventType.Type, unsafeJSONString(content)) -} - -func (adq *AccountDataQuery) PutRoom(ctx context.Context, userID id.UserID, roomID id.RoomID, eventType event.Type, content json.RawMessage) error { - return adq.Exec(ctx, upsertRoomAccountDataQuery, userID, roomID, eventType.Type, unsafeJSONString(content)) -} - -type AccountData struct { - UserID id.UserID `json:"user_id"` - RoomID id.RoomID `json:"room_id,omitempty"` - Type string `json:"type"` - Content json.RawMessage `json:"content"` -} - -func (a *AccountData) Scan(row dbutil.Scannable) (*AccountData, error) { - var roomID sql.NullString - err := row.Scan(&a.UserID, &roomID, &a.Type, (*[]byte)(&a.Content)) - if err != nil { - return nil, err - } - a.RoomID = id.RoomID(roomID.String) - return a, nil -} - -func (a *AccountData) sqlVariables() []any { - return []any{a.UserID, dbutil.StrPtr(a.RoomID), a.Type, unsafeJSONString(a.Content)} -} diff --git a/hicli/database/cachedmedia.go b/hicli/database/cachedmedia.go deleted file mode 100644 index 2ccaca3b..00000000 --- a/hicli/database/cachedmedia.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package database - -import ( - "context" - "database/sql" - "net/http" - "time" - - "go.mau.fi/util/dbutil" - "go.mau.fi/util/jsontime" - "golang.org/x/exp/slices" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto/attachment" - "maunium.net/go/mautrix/id" -) - -const ( - insertCachedMediaQuery = ` - INSERT INTO cached_media (mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - ON CONFLICT (mxc) DO NOTHING - ` - upsertCachedMediaQuery = ` - INSERT INTO cached_media (mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - ON CONFLICT (mxc) DO UPDATE - SET enc_file = excluded.enc_file, - file_name = excluded.file_name, - mime_type = excluded.mime_type, - size = excluded.size, - hash = excluded.hash, - error = excluded.error - WHERE excluded.error IS NULL OR cached_media.hash IS NULL - ` - getCachedMediaQuery = ` - SELECT mxc, event_rowid, enc_file, file_name, mime_type, size, hash, error - FROM cached_media - WHERE mxc = $1 - ` -) - -type CachedMediaQuery struct { - *dbutil.QueryHelper[*CachedMedia] -} - -func (cmq *CachedMediaQuery) Add(ctx context.Context, cm *CachedMedia) error { - return cmq.Exec(ctx, insertCachedMediaQuery, cm.sqlVariables()...) -} - -func (cmq *CachedMediaQuery) Put(ctx context.Context, cm *CachedMedia) error { - return cmq.Exec(ctx, upsertCachedMediaQuery, cm.sqlVariables()...) -} - -func (cmq *CachedMediaQuery) Get(ctx context.Context, mxc id.ContentURI) (*CachedMedia, error) { - return cmq.QueryOne(ctx, getCachedMediaQuery, &mxc) -} - -type MediaError struct { - Matrix *mautrix.RespError `json:"data"` - StatusCode int `json:"status_code"` - ReceivedAt jsontime.UnixMilli `json:"received_at"` - Attempts int `json:"attempts"` -} - -const MaxMediaBackoff = 7 * 24 * time.Hour - -func (me *MediaError) backoff() time.Duration { - return min(time.Duration(2< 0 { - err = eq.Exec(ctx, updateReactionCountsQuery, evtID, dbutil.JSON{Data: &res.Counts}) - if err != nil { - return err - } - } - } - return nil - }) -} - -type EventRowID int64 - -func (m EventRowID) GetMassInsertValues() [1]any { - return [1]any{m} -} - -type LocalContent struct { - SanitizedHTML string `json:"sanitized_html,omitempty"` -} - -type UnreadType int - -func (ut UnreadType) Is(flag UnreadType) bool { - return ut&flag != 0 -} - -const ( - UnreadTypeNone UnreadType = 0b0000 - UnreadTypeNormal UnreadType = 0b0001 - UnreadTypeNotify UnreadType = 0b0010 - UnreadTypeHighlight UnreadType = 0b0100 - UnreadTypeSound UnreadType = 0b1000 -) - -type Event struct { - RowID EventRowID `json:"rowid"` - TimelineRowID TimelineRowID `json:"timeline_rowid"` - - RoomID id.RoomID `json:"room_id"` - ID id.EventID `json:"event_id"` - Sender id.UserID `json:"sender"` - Type string `json:"type"` - StateKey *string `json:"state_key,omitempty"` - Timestamp jsontime.UnixMilli `json:"timestamp"` - - Content json.RawMessage `json:"content"` - Decrypted json.RawMessage `json:"decrypted,omitempty"` - DecryptedType string `json:"decrypted_type,omitempty"` - Unsigned json.RawMessage `json:"unsigned,omitempty"` - LocalContent *LocalContent `json:"local_content,omitempty"` - - TransactionID string `json:"transaction_id,omitempty"` - - RedactedBy id.EventID `json:"redacted_by,omitempty"` - RelatesTo id.EventID `json:"relates_to,omitempty"` - RelationType event.RelationType `json:"relation_type,omitempty"` - - MegolmSessionID id.SessionID `json:"-,omitempty"` - DecryptionError string `json:"decryption_error,omitempty"` - SendError string `json:"send_error,omitempty"` - - Reactions map[string]int `json:"reactions,omitempty"` - LastEditRowID *EventRowID `json:"last_edit_rowid,omitempty"` - UnreadType UnreadType `json:"unread_type,omitempty"` -} - -func MautrixToEvent(evt *event.Event) *Event { - dbEvt := &Event{ - RoomID: evt.RoomID, - ID: evt.ID, - Sender: evt.Sender, - Type: evt.Type.Type, - StateKey: evt.StateKey, - Timestamp: jsontime.UM(time.UnixMilli(evt.Timestamp)), - Content: evt.Content.VeryRaw, - MegolmSessionID: getMegolmSessionID(evt), - TransactionID: evt.Unsigned.TransactionID, - } - if !strings.HasPrefix(dbEvt.TransactionID, "hicli-mautrix-go_") { - dbEvt.TransactionID = "" - } - dbEvt.RelatesTo, dbEvt.RelationType = getRelatesToFromEvent(evt) - dbEvt.Unsigned, _ = json.Marshal(&evt.Unsigned) - if evt.Unsigned.RedactedBecause != nil { - dbEvt.RedactedBy = evt.Unsigned.RedactedBecause.ID - } - return dbEvt -} - -func (e *Event) AsRawMautrix() *event.Event { - if e == nil { - return nil - } - evt := &event.Event{ - RoomID: e.RoomID, - ID: e.ID, - Sender: e.Sender, - Type: event.Type{Type: e.Type, Class: event.MessageEventType}, - StateKey: e.StateKey, - Timestamp: e.Timestamp.UnixMilli(), - Content: event.Content{VeryRaw: e.Content}, - } - if e.Decrypted != nil { - evt.Content.VeryRaw = e.Decrypted - evt.Type.Type = e.DecryptedType - evt.Mautrix.WasEncrypted = true - } - if e.StateKey != nil { - evt.Type.Class = event.StateEventType - } - _ = json.Unmarshal(e.Unsigned, &evt.Unsigned) - return evt -} - -func (e *Event) Scan(row dbutil.Scannable) (*Event, error) { - var timestamp int64 - var transactionID, redactedBy, relatesTo, relationType, megolmSessionID, decryptionError, sendError, decryptedType sql.NullString - err := row.Scan( - &e.RowID, - &e.TimelineRowID, - &e.RoomID, - &e.ID, - &e.Sender, - &e.Type, - &e.StateKey, - ×tamp, - (*[]byte)(&e.Content), - (*[]byte)(&e.Decrypted), - &decryptedType, - (*[]byte)(&e.Unsigned), - dbutil.JSON{Data: &e.LocalContent}, - &transactionID, - &redactedBy, - &relatesTo, - &relationType, - &megolmSessionID, - &decryptionError, - &sendError, - dbutil.JSON{Data: &e.Reactions}, - &e.LastEditRowID, - &e.UnreadType, - ) - if err != nil { - return nil, err - } - e.Timestamp = jsontime.UM(time.UnixMilli(timestamp)) - e.TransactionID = transactionID.String - e.RedactedBy = id.EventID(redactedBy.String) - e.RelatesTo = id.EventID(relatesTo.String) - e.RelationType = event.RelationType(relationType.String) - e.MegolmSessionID = id.SessionID(megolmSessionID.String) - e.DecryptedType = decryptedType.String - e.DecryptionError = decryptionError.String - e.SendError = sendError.String - return e, nil -} - -var relatesToPath = exgjson.Path("m.relates_to", "event_id") -var relationTypePath = exgjson.Path("m.relates_to", "rel_type") - -func getRelatesToFromEvent(evt *event.Event) (id.EventID, event.RelationType) { - if evt.StateKey != nil { - return "", "" - } - return GetRelatesToFromBytes(evt.Content.VeryRaw) -} - -func GetRelatesToFromBytes(content []byte) (id.EventID, event.RelationType) { - results := gjson.GetManyBytes(content, relatesToPath, relationTypePath) - if len(results) == 2 && results[0].Exists() && results[1].Exists() && results[0].Type == gjson.String && results[1].Type == gjson.String { - return id.EventID(results[0].Str), event.RelationType(results[1].Str) - } - return "", "" -} - -func getMegolmSessionID(evt *event.Event) id.SessionID { - if evt.Type != event.EventEncrypted { - return "" - } - res := gjson.GetBytes(evt.Content.VeryRaw, "session_id") - if res.Exists() && res.Type == gjson.String { - return id.SessionID(res.Str) - } - return "" -} - -func (e *Event) sqlVariables() []any { - var reactions any - if e.Reactions != nil { - reactions = e.Reactions - } - return []any{ - e.RoomID, - e.ID, - e.Sender, - e.Type, - e.StateKey, - e.Timestamp.UnixMilli(), - unsafeJSONString(e.Content), - unsafeJSONString(e.Decrypted), - dbutil.StrPtr(e.DecryptedType), - unsafeJSONString(e.Unsigned), - dbutil.JSONPtr(e.LocalContent), - dbutil.StrPtr(e.TransactionID), - dbutil.StrPtr(e.RedactedBy), - dbutil.StrPtr(e.RelatesTo), - dbutil.StrPtr(e.RelationType), - dbutil.StrPtr(e.MegolmSessionID), - dbutil.StrPtr(e.DecryptionError), - dbutil.StrPtr(e.SendError), - dbutil.JSON{Data: reactions}, - e.LastEditRowID, - e.UnreadType, - } -} - -func (e *Event) GetNonPushUnreadType() UnreadType { - if e.RelationType == event.RelReplace { - return UnreadTypeNone - } - switch e.Type { - case event.EventMessage.Type, event.EventSticker.Type: - return UnreadTypeNormal - case event.EventEncrypted.Type: - switch e.DecryptedType { - case event.EventMessage.Type, event.EventSticker.Type: - return UnreadTypeNormal - } - } - return UnreadTypeNone -} - -func (e *Event) CanUseForPreview() bool { - return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || - (e.Type == event.EventEncrypted.Type && - (e.DecryptedType == event.EventMessage.Type || e.DecryptedType == event.EventSticker.Type))) && - e.RelationType != event.RelReplace && e.RedactedBy == "" -} - -func (e *Event) BumpsSortingTimestamp() bool { - return (e.Type == event.EventMessage.Type || e.Type == event.EventSticker.Type || e.Type == event.EventEncrypted.Type) && - e.RelationType != event.RelReplace -} diff --git a/hicli/database/receipt.go b/hicli/database/receipt.go deleted file mode 100644 index 8830efc7..00000000 --- a/hicli/database/receipt.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package database - -import ( - "context" - "time" - - "go.mau.fi/util/dbutil" - "go.mau.fi/util/exslices" - "go.mau.fi/util/jsontime" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -const ( - upsertReceiptQuery = ` - INSERT INTO receipt (room_id, user_id, receipt_type, thread_id, event_id, timestamp) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (room_id, user_id, receipt_type, thread_id) DO UPDATE - SET event_id = excluded.event_id, - timestamp = excluded.timestamp - ` -) - -var receiptMassInserter = dbutil.NewMassInsertBuilder[*Receipt, [1]any](upsertReceiptQuery, "($1, $%d, $%d, $%d, $%d, $%d)") - -type ReceiptQuery struct { - *dbutil.QueryHelper[*Receipt] -} - -func (rq *ReceiptQuery) Put(ctx context.Context, receipt *Receipt) error { - return rq.Exec(ctx, upsertReceiptQuery, receipt.sqlVariables()...) -} - -func (rq *ReceiptQuery) PutMany(ctx context.Context, roomID id.RoomID, receipts ...*Receipt) error { - if len(receipts) > 1000 { - return rq.GetDB().DoTxn(ctx, nil, func(ctx context.Context) error { - for _, receiptChunk := range exslices.Chunk(receipts, 200) { - err := rq.PutMany(ctx, roomID, receiptChunk...) - if err != nil { - return err - } - } - return nil - }) - } - query, params := receiptMassInserter.Build([1]any{roomID}, receipts) - return rq.Exec(ctx, query, params...) -} - -type Receipt struct { - RoomID id.RoomID `json:"room_id"` - UserID id.UserID `json:"user_id"` - ReceiptType event.ReceiptType `json:"receipt_type"` - ThreadID event.ThreadID `json:"thread_id"` - EventID id.EventID `json:"event_id"` - Timestamp jsontime.UnixMilli `json:"timestamp"` -} - -func (r *Receipt) Scan(row dbutil.Scannable) (*Receipt, error) { - var ts int64 - err := row.Scan(&r.RoomID, &r.UserID, &r.ReceiptType, &r.ThreadID, &r.EventID, &ts) - if err != nil { - return nil, err - } - r.Timestamp = jsontime.UM(time.UnixMilli(ts)) - return r, nil -} - -func (r *Receipt) sqlVariables() []any { - return []any{r.RoomID, r.UserID, r.ReceiptType, r.ThreadID, r.EventID, r.Timestamp.UnixMilli()} -} - -func (r *Receipt) GetMassInsertValues() [5]any { - return [5]any{r.UserID, r.ReceiptType, r.ThreadID, r.EventID, r.Timestamp.UnixMilli()} -} diff --git a/hicli/database/room.go b/hicli/database/room.go deleted file mode 100644 index 42108022..00000000 --- a/hicli/database/room.go +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package database - -import ( - "context" - "database/sql" - "errors" - "time" - - "go.mau.fi/util/dbutil" - "go.mau.fi/util/jsontime" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -const ( - getRoomBaseQuery = ` - SELECT room_id, creation_content, name, name_quality, avatar, explicit_avatar, topic, canonical_alias, - lazy_load_summary, encryption_event, has_member_list, preview_event_rowid, sorting_timestamp, - unread_highlights, unread_notifications, unread_messages, prev_batch - FROM room - ` - getRoomsBySortingTimestampQuery = getRoomBaseQuery + `WHERE sorting_timestamp < $1 AND sorting_timestamp > 0 ORDER BY sorting_timestamp DESC LIMIT $2` - getRoomByIDQuery = getRoomBaseQuery + `WHERE room_id = $1` - ensureRoomExistsQuery = ` - INSERT INTO room (room_id) VALUES ($1) - ON CONFLICT (room_id) DO NOTHING - ` - upsertRoomFromSyncQuery = ` - UPDATE room - SET creation_content = COALESCE(room.creation_content, $2), - name = COALESCE($3, room.name), - name_quality = CASE WHEN $3 IS NOT NULL THEN $4 ELSE room.name_quality END, - avatar = COALESCE($5, room.avatar), - explicit_avatar = CASE WHEN $5 IS NOT NULL THEN $6 ELSE room.explicit_avatar END, - topic = COALESCE($7, room.topic), - canonical_alias = COALESCE($8, room.canonical_alias), - lazy_load_summary = COALESCE($9, room.lazy_load_summary), - encryption_event = COALESCE($10, room.encryption_event), - has_member_list = room.has_member_list OR $11, - preview_event_rowid = COALESCE($12, room.preview_event_rowid), - sorting_timestamp = COALESCE($13, room.sorting_timestamp), - unread_highlights = COALESCE($14, room.unread_highlights), - unread_notifications = COALESCE($15, room.unread_notifications), - unread_messages = COALESCE($16, room.unread_messages), - prev_batch = COALESCE($17, room.prev_batch) - WHERE room_id = $1 - ` - setRoomPrevBatchQuery = ` - UPDATE room SET prev_batch = $2 WHERE room_id = $1 - ` - updateRoomPreviewIfLaterOnTimelineQuery = ` - UPDATE room - SET preview_event_rowid = $2 - WHERE room_id = $1 - AND COALESCE((SELECT rowid FROM timeline WHERE event_rowid = $2), -1) - > COALESCE((SELECT rowid FROM timeline WHERE event_rowid = preview_event_rowid), 0) - RETURNING preview_event_rowid - ` - recalculateRoomPreviewEventQuery = ` - SELECT rowid - FROM event - WHERE - room_id = $1 - AND (type IN ('m.room.message', 'm.sticker') - OR (type = 'm.room.encrypted' - AND decrypted_type IN ('m.room.message', 'm.sticker'))) - AND relation_type <> 'm.replace' - AND redacted_by IS NULL - ORDER BY timestamp DESC - LIMIT 1 - ` -) - -type RoomQuery struct { - *dbutil.QueryHelper[*Room] -} - -func (rq *RoomQuery) Get(ctx context.Context, roomID id.RoomID) (*Room, error) { - return rq.QueryOne(ctx, getRoomByIDQuery, roomID) -} - -func (rq *RoomQuery) GetBySortTS(ctx context.Context, maxTS time.Time, limit int) ([]*Room, error) { - return rq.QueryMany(ctx, getRoomsBySortingTimestampQuery, maxTS.UnixMilli(), limit) -} - -func (rq *RoomQuery) Upsert(ctx context.Context, room *Room) error { - return rq.Exec(ctx, upsertRoomFromSyncQuery, room.sqlVariables()...) -} - -func (rq *RoomQuery) CreateRow(ctx context.Context, roomID id.RoomID) error { - return rq.Exec(ctx, ensureRoomExistsQuery, roomID) -} - -func (rq *RoomQuery) SetPrevBatch(ctx context.Context, roomID id.RoomID, prevBatch string) error { - return rq.Exec(ctx, setRoomPrevBatchQuery, roomID, prevBatch) -} - -func (rq *RoomQuery) UpdatePreviewIfLaterOnTimeline(ctx context.Context, roomID id.RoomID, rowID EventRowID) (previewChanged bool, err error) { - var newPreviewRowID EventRowID - err = rq.GetDB().QueryRow(ctx, updateRoomPreviewIfLaterOnTimelineQuery, roomID, rowID).Scan(&newPreviewRowID) - if errors.Is(err, sql.ErrNoRows) { - err = nil - } else if err == nil { - previewChanged = newPreviewRowID == rowID - } - return -} - -func (rq *RoomQuery) RecalculatePreview(ctx context.Context, roomID id.RoomID) (rowID EventRowID, err error) { - err = rq.GetDB().QueryRow(ctx, recalculateRoomPreviewEventQuery, roomID).Scan(&rowID) - return -} - -type NameQuality int - -const ( - NameQualityNil NameQuality = iota - NameQualityParticipants - NameQualityCanonicalAlias - NameQualityExplicit -) - -const PrevBatchPaginationComplete = "fi.mau.gomuks.pagination_complete" - -type Room struct { - ID id.RoomID `json:"room_id"` - CreationContent *event.CreateEventContent `json:"creation_content,omitempty"` - - Name *string `json:"name,omitempty"` - NameQuality NameQuality `json:"name_quality"` - Avatar *id.ContentURI `json:"avatar,omitempty"` - ExplicitAvatar bool `json:"explicit_avatar"` - Topic *string `json:"topic,omitempty"` - CanonicalAlias *id.RoomAlias `json:"canonical_alias,omitempty"` - - LazyLoadSummary *mautrix.LazyLoadSummary `json:"lazy_load_summary,omitempty"` - - EncryptionEvent *event.EncryptionEventContent `json:"encryption_event,omitempty"` - HasMemberList bool `json:"has_member_list"` - - PreviewEventRowID EventRowID `json:"preview_event_rowid"` - SortingTimestamp jsontime.UnixMilli `json:"sorting_timestamp"` - UnreadHighlights int `json:"unread_highlights"` - UnreadNotifications int `json:"unread_notifications"` - UnreadMessages int `json:"unread_messages"` - - PrevBatch string `json:"prev_batch"` -} - -func (r *Room) CheckChangesAndCopyInto(other *Room) (hasChanges bool) { - if r.Name != nil && r.NameQuality >= other.NameQuality { - other.Name = r.Name - other.NameQuality = r.NameQuality - hasChanges = true - } - if r.Avatar != nil { - other.Avatar = r.Avatar - other.ExplicitAvatar = r.ExplicitAvatar - hasChanges = true - } - if r.Topic != nil { - other.Topic = r.Topic - hasChanges = true - } - if r.CanonicalAlias != nil { - other.CanonicalAlias = r.CanonicalAlias - hasChanges = true - } - if r.LazyLoadSummary != nil { - other.LazyLoadSummary = r.LazyLoadSummary - hasChanges = true - } - if r.EncryptionEvent != nil && other.EncryptionEvent == nil { - other.EncryptionEvent = r.EncryptionEvent - hasChanges = true - } - if r.HasMemberList && !other.HasMemberList { - hasChanges = true - other.HasMemberList = true - } - if r.PreviewEventRowID > other.PreviewEventRowID { - other.PreviewEventRowID = r.PreviewEventRowID - hasChanges = true - } - if r.SortingTimestamp.After(other.SortingTimestamp.Time) { - other.SortingTimestamp = r.SortingTimestamp - hasChanges = true - } - if r.UnreadHighlights != other.UnreadHighlights { - other.UnreadHighlights = r.UnreadHighlights - hasChanges = true - } - if r.UnreadNotifications != other.UnreadNotifications { - other.UnreadNotifications = r.UnreadNotifications - hasChanges = true - } - if r.UnreadMessages != other.UnreadMessages { - other.UnreadMessages = r.UnreadMessages - hasChanges = true - } - if r.PrevBatch != "" && other.PrevBatch == "" { - other.PrevBatch = r.PrevBatch - hasChanges = true - } - return -} - -func (r *Room) Scan(row dbutil.Scannable) (*Room, error) { - var prevBatch sql.NullString - var previewEventRowID, sortingTimestamp sql.NullInt64 - err := row.Scan( - &r.ID, - dbutil.JSON{Data: &r.CreationContent}, - &r.Name, - &r.NameQuality, - &r.Avatar, - &r.ExplicitAvatar, - &r.Topic, - &r.CanonicalAlias, - dbutil.JSON{Data: &r.LazyLoadSummary}, - dbutil.JSON{Data: &r.EncryptionEvent}, - &r.HasMemberList, - &previewEventRowID, - &sortingTimestamp, - &r.UnreadHighlights, - &r.UnreadNotifications, - &r.UnreadMessages, - &prevBatch, - ) - if err != nil { - return nil, err - } - r.PrevBatch = prevBatch.String - r.PreviewEventRowID = EventRowID(previewEventRowID.Int64) - r.SortingTimestamp = jsontime.UM(time.UnixMilli(sortingTimestamp.Int64)) - return r, nil -} - -func (r *Room) sqlVariables() []any { - return []any{ - r.ID, - dbutil.JSONPtr(r.CreationContent), - r.Name, - r.NameQuality, - r.Avatar, - r.ExplicitAvatar, - r.Topic, - r.CanonicalAlias, - dbutil.JSONPtr(r.LazyLoadSummary), - dbutil.JSONPtr(r.EncryptionEvent), - r.HasMemberList, - dbutil.NumPtr(r.PreviewEventRowID), - dbutil.UnixMilliPtr(r.SortingTimestamp.Time), - r.UnreadHighlights, - r.UnreadNotifications, - r.UnreadMessages, - dbutil.StrPtr(r.PrevBatch), - } -} - -func (r *Room) BumpSortingTimestamp(evt *Event) bool { - if !evt.BumpsSortingTimestamp() || evt.Timestamp.Before(r.SortingTimestamp.Time) { - return false - } - r.SortingTimestamp = evt.Timestamp - now := time.Now() - if r.SortingTimestamp.After(now) { - r.SortingTimestamp = jsontime.UM(now) - } - return true -} diff --git a/hicli/database/sessionrequest.go b/hicli/database/sessionrequest.go deleted file mode 100644 index 6690c13f..00000000 --- a/hicli/database/sessionrequest.go +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package database - -import ( - "context" - - "go.mau.fi/util/dbutil" - - "maunium.net/go/mautrix/id" -) - -const ( - putSessionRequestQueueEntry = ` - INSERT INTO session_request (room_id, session_id, sender, min_index, backup_checked, request_sent) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (session_id) DO UPDATE - SET min_index = MIN(excluded.min_index, session_request.min_index), - backup_checked = excluded.backup_checked OR session_request.backup_checked, - request_sent = excluded.request_sent OR session_request.request_sent - ` - removeSessionRequestQuery = ` - DELETE FROM session_request WHERE session_id = $1 AND min_index >= $2 - ` - getNextSessionsToRequestQuery = ` - SELECT room_id, session_id, sender, min_index, backup_checked, request_sent - FROM session_request - WHERE request_sent = false OR backup_checked = false - ORDER BY backup_checked, rowid - LIMIT $1 - ` -) - -type SessionRequestQuery struct { - *dbutil.QueryHelper[*SessionRequest] -} - -func (srq *SessionRequestQuery) Next(ctx context.Context, count int) ([]*SessionRequest, error) { - return srq.QueryMany(ctx, getNextSessionsToRequestQuery, count) -} - -func (srq *SessionRequestQuery) Remove(ctx context.Context, sessionID id.SessionID, minIndex uint32) error { - return srq.Exec(ctx, removeSessionRequestQuery, sessionID, minIndex) -} - -func (srq *SessionRequestQuery) Put(ctx context.Context, sr *SessionRequest) error { - return srq.Exec(ctx, putSessionRequestQueueEntry, sr.sqlVariables()...) -} - -type SessionRequest struct { - RoomID id.RoomID - SessionID id.SessionID - Sender id.UserID - MinIndex uint32 - BackupChecked bool - RequestSent bool -} - -func (s *SessionRequest) Scan(row dbutil.Scannable) (*SessionRequest, error) { - return dbutil.ValueOrErr(s, row.Scan(&s.RoomID, &s.SessionID, &s.Sender, &s.MinIndex, &s.BackupChecked, &s.RequestSent)) -} - -func (s *SessionRequest) sqlVariables() []any { - return []any{s.RoomID, s.SessionID, s.Sender, s.MinIndex, s.BackupChecked, s.RequestSent} -} diff --git a/hicli/database/state.go b/hicli/database/state.go deleted file mode 100644 index d6fbf53d..00000000 --- a/hicli/database/state.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package database - -import ( - "context" - "fmt" - - "go.mau.fi/util/dbutil" - "go.mau.fi/util/exslices" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -const ( - setCurrentStateQuery = ` - INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (room_id, event_type, state_key) DO UPDATE SET event_rowid = excluded.event_rowid, membership = excluded.membership - ` - addCurrentStateQuery = ` - INSERT INTO current_state (room_id, event_type, state_key, event_rowid, membership) VALUES ($1, $2, $3, $4, $5) - ON CONFLICT DO NOTHING - ` - deleteCurrentStateQuery = ` - DELETE FROM current_state WHERE room_id = $1 - ` - getCurrentRoomStateQuery = ` - SELECT event.rowid, -1, - event.room_id, event.event_id, sender, event.type, event.state_key, timestamp, content, decrypted, decrypted_type, - unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type, - megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type - FROM current_state cs - JOIN event ON cs.event_rowid = event.rowid - WHERE cs.room_id = $1 - ` - getCurrentStateEventQuery = getCurrentRoomStateQuery + `AND cs.event_type = $2 AND cs.state_key = $3` -) - -var massInsertCurrentStateBuilder = dbutil.NewMassInsertBuilder[*CurrentStateEntry, [1]any](addCurrentStateQuery, "($1, $%d, $%d, $%d, $%d)") - -const currentStateMassInsertBatchSize = 1000 - -type CurrentStateEntry struct { - EventType event.Type - StateKey string - EventRowID EventRowID - Membership event.Membership -} - -func (cse *CurrentStateEntry) GetMassInsertValues() [4]any { - return [4]any{cse.EventType.Type, cse.StateKey, cse.EventRowID, dbutil.StrPtr(cse.Membership)} -} - -type CurrentStateQuery struct { - *dbutil.QueryHelper[*Event] -} - -func (csq *CurrentStateQuery) Set(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID EventRowID, membership event.Membership) error { - return csq.Exec(ctx, setCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership)) -} - -func (csq *CurrentStateQuery) AddMany(ctx context.Context, roomID id.RoomID, deleteOld bool, entries []*CurrentStateEntry) error { - var err error - if deleteOld { - err = csq.Exec(ctx, deleteCurrentStateQuery, roomID) - if err != nil { - return fmt.Errorf("failed to delete old state: %w", err) - } - } - for _, entryChunk := range exslices.Chunk(entries, currentStateMassInsertBatchSize) { - query, params := massInsertCurrentStateBuilder.Build([1]any{roomID}, entryChunk) - err = csq.Exec(ctx, query, params...) - if err != nil { - return err - } - } - return nil -} - -func (csq *CurrentStateQuery) Add(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, eventRowID EventRowID, membership event.Membership) error { - return csq.Exec(ctx, addCurrentStateQuery, roomID, eventType.Type, stateKey, eventRowID, dbutil.StrPtr(membership)) -} - -func (csq *CurrentStateQuery) Get(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*Event, error) { - return csq.QueryOne(ctx, getCurrentStateEventQuery, roomID, eventType.Type, stateKey) -} - -func (csq *CurrentStateQuery) GetAll(ctx context.Context, roomID id.RoomID) ([]*Event, error) { - return csq.QueryMany(ctx, getCurrentRoomStateQuery, roomID) -} diff --git a/hicli/database/statestore.go b/hicli/database/statestore.go deleted file mode 100644 index fcd6aceb..00000000 --- a/hicli/database/statestore.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package database - -import ( - "context" - "database/sql" - "errors" - "fmt" - - "go.mau.fi/util/dbutil" - "golang.org/x/exp/slices" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -const ( - getMembershipQuery = ` - SELECT membership FROM current_state - WHERE room_id = $1 AND event_type = 'm.room.member' AND state_key = $2 - ` - getStateEventContentQuery = ` - SELECT event.content FROM current_state cs - LEFT JOIN event ON event.rowid = cs.event_rowid - WHERE cs.room_id = $1 AND cs.event_type = $2 AND cs.state_key = $3 - ` - getRoomJoinedMembersQuery = ` - SELECT state_key FROM current_state - WHERE room_id = $1 AND event_type = 'm.room.member' AND membership = 'join' - ` - getRoomJoinedOrInvitedMembersQuery = ` - SELECT state_key FROM current_state - WHERE room_id = $1 AND event_type = 'm.room.member' AND membership IN ('join', 'invite') - ` - getHasFetchedMembersQuery = ` - SELECT has_member_list FROM room WHERE room_id = $1 - ` - isRoomEncryptedQuery = ` - SELECT room.encryption_event IS NOT NULL FROM room WHERE room_id = $1 - ` - getRoomEncryptionEventQuery = ` - SELECT room.encryption_event FROM room WHERE room_id = $1 - ` - findSharedRoomsQuery = ` - SELECT room_id FROM current_state - WHERE event_type = 'm.room.member' AND state_key = $1 AND membership = 'join' - ` -) - -type ClientStateStore struct { - *Database -} - -var _ mautrix.StateStore = (*ClientStateStore)(nil) -var _ mautrix.StateStoreUpdater = (*ClientStateStore)(nil) -var _ crypto.StateStore = (*ClientStateStore)(nil) - -func (c *ClientStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { - return c.IsMembership(ctx, roomID, userID, event.MembershipJoin) -} - -func (c *ClientStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { - return c.IsMembership(ctx, roomID, userID, event.MembershipInvite, event.MembershipJoin) -} - -func (c *ClientStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool { - var membership event.Membership - err := c.QueryRow(ctx, getMembershipQuery, roomID, userID).Scan(&membership) - if errors.Is(err, sql.ErrNoRows) { - err = nil - membership = event.MembershipLeave - } - return slices.Contains(allowedMemberships, membership) -} - -func (c *ClientStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) { - content, err := c.TryGetMember(ctx, roomID, userID) - if content == nil { - content = &event.MemberEventContent{Membership: event.MembershipLeave} - } - return content, err -} - -func (c *ClientStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (content *event.MemberEventContent, err error) { - err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StateMember.Type, userID).Scan(&dbutil.JSON{Data: &content}) - if errors.Is(err, sql.ErrNoRows) { - err = nil - } - return -} - -func (c *ClientStateStore) IsConfusableName(ctx context.Context, roomID id.RoomID, currentUser id.UserID, name string) ([]id.UserID, error) { - //TODO implement me - panic("implement me") -} - -func (c *ClientStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (content *event.PowerLevelsEventContent, err error) { - err = c.QueryRow(ctx, getStateEventContentQuery, roomID, event.StatePowerLevels.Type, "").Scan(&dbutil.JSON{Data: &content}) - if errors.Is(err, sql.ErrNoRows) { - err = nil - } - return -} - -func (c *ClientStateStore) GetRoomJoinedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) { - rows, err := c.Query(ctx, getRoomJoinedMembersQuery, roomID) - return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() -} - -func (c *ClientStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) { - rows, err := c.Query(ctx, getRoomJoinedOrInvitedMembersQuery, roomID) - return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() -} - -func (c *ClientStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (hasFetched bool, err error) { - //err = c.QueryRow(ctx, getHasFetchedMembersQuery, roomID).Scan(&hasFetched) - //if errors.Is(err, sql.ErrNoRows) { - // err = nil - //} - //return - return false, fmt.Errorf("not implemented") -} - -func (c *ClientStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error { - return fmt.Errorf("not implemented") -} - -func (c *ClientStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { - return nil, fmt.Errorf("not implemented") -} - -func (c *ClientStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (isEncrypted bool, err error) { - err = c.QueryRow(ctx, isRoomEncryptedQuery, roomID).Scan(&isEncrypted) - if errors.Is(err, sql.ErrNoRows) { - err = nil - } - return -} - -func (c *ClientStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (content *event.EncryptionEventContent, err error) { - err = c.QueryRow(ctx, getRoomEncryptionEventQuery, roomID). - Scan(&dbutil.JSON{Data: &content}) - if errors.Is(err, sql.ErrNoRows) { - err = nil - } - return -} - -func (c *ClientStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) ([]id.RoomID, error) { - // TODO for multiuser support, this might need to filter by the local user's membership - rows, err := c.Query(ctx, findSharedRoomsQuery, userID) - return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList() -} - -// Update methods are all intentionally no-ops as the state store wants to have the full event - -func (c *ClientStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error { - return nil -} - -func (c *ClientStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error { - return nil -} - -func (c *ClientStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error { - return nil -} - -func (c *ClientStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error { - return nil -} - -func (c *ClientStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { - return nil -} - -func (c *ClientStateStore) UpdateState(ctx context.Context, evt *event.Event) {} - -func (c *ClientStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error { - return nil -} diff --git a/hicli/database/timeline.go b/hicli/database/timeline.go deleted file mode 100644 index e04eeb88..00000000 --- a/hicli/database/timeline.go +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package database - -import ( - "context" - "database/sql" - "errors" - "sync" - - "go.mau.fi/util/dbutil" - - "maunium.net/go/mautrix/id" -) - -const ( - clearTimelineQuery = ` - DELETE FROM timeline WHERE room_id = $1 - ` - appendTimelineQuery = ` - INSERT INTO timeline (room_id, event_rowid) VALUES ($1, $2) - ON CONFLICT DO NOTHING - RETURNING rowid, event_rowid - ` - prependTimelineQuery = ` - INSERT INTO timeline (room_id, rowid, event_rowid) VALUES ($1, $2, $3) - ` - checkTimelineContainsQuery = ` - SELECT EXISTS(SELECT 1 FROM timeline WHERE room_id = $1 AND event_rowid = $2) - ` - findMinRowIDQuery = `SELECT MIN(rowid) FROM timeline` - getTimelineQuery = ` - SELECT event.rowid, timeline.rowid, - event.room_id, event_id, sender, type, state_key, timestamp, content, decrypted, decrypted_type, - unsigned, local_content, transaction_id, redacted_by, relates_to, relation_type, - megolm_session_id, decryption_error, send_error, reactions, last_edit_rowid, unread_type - FROM timeline - JOIN event ON event.rowid = timeline.event_rowid - WHERE timeline.room_id = $1 AND ($2 = 0 OR timeline.rowid < $2) - ORDER BY timeline.rowid DESC - LIMIT $3 - ` -) - -type TimelineRowID int64 - -type TimelineRowTuple struct { - Timeline TimelineRowID `json:"timeline_rowid"` - Event EventRowID `json:"event_rowid"` -} - -var timelineRowTupleScanner = dbutil.ConvertRowFn[TimelineRowTuple](func(row dbutil.Scannable) (trt TimelineRowTuple, err error) { - err = row.Scan(&trt.Timeline, &trt.Event) - return -}) - -func (trt TimelineRowTuple) GetMassInsertValues() [2]any { - return [2]any{trt.Timeline, trt.Event} -} - -var appendTimelineQueryBuilder = dbutil.NewMassInsertBuilder[EventRowID, [1]any](appendTimelineQuery, "($1, $%d)") -var prependTimelineQueryBuilder = dbutil.NewMassInsertBuilder[TimelineRowTuple, [1]any](prependTimelineQuery, "($1, $%d, $%d)") - -type TimelineQuery struct { - *dbutil.QueryHelper[*Event] - - minRowID TimelineRowID - minRowIDFound bool - prependLock sync.Mutex -} - -// Clear clears the timeline of a given room. -func (tq *TimelineQuery) Clear(ctx context.Context, roomID id.RoomID) error { - return tq.Exec(ctx, clearTimelineQuery, roomID) -} - -func (tq *TimelineQuery) reserveRowIDs(ctx context.Context, count int) (startFrom TimelineRowID, err error) { - tq.prependLock.Lock() - defer tq.prependLock.Unlock() - if !tq.minRowIDFound { - err = tq.GetDB().QueryRow(ctx, findMinRowIDQuery).Scan(&tq.minRowID) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return - } - if tq.minRowID >= 0 { - // No negative row IDs exist, start at -2 - tq.minRowID = -2 - } else { - // We fetched the lowest row ID, but we want the next available one, so decrement one - tq.minRowID-- - } - tq.minRowIDFound = true - } - startFrom = tq.minRowID - tq.minRowID -= TimelineRowID(count) - return -} - -// Prepend adds the given event row IDs to the beginning of the timeline. -// The events must be sorted in reverse chronological order (newest event first). -func (tq *TimelineQuery) Prepend(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) (prependEntries []TimelineRowTuple, err error) { - var startFrom TimelineRowID - startFrom, err = tq.reserveRowIDs(ctx, len(rowIDs)) - if err != nil { - return - } - prependEntries = make([]TimelineRowTuple, len(rowIDs)) - for i, rowID := range rowIDs { - prependEntries[i] = TimelineRowTuple{ - Timeline: startFrom - TimelineRowID(i), - Event: rowID, - } - } - query, params := prependTimelineQueryBuilder.Build([1]any{roomID}, prependEntries) - err = tq.Exec(ctx, query, params...) - return -} - -// Append adds the given event row IDs to the end of the timeline. -func (tq *TimelineQuery) Append(ctx context.Context, roomID id.RoomID, rowIDs []EventRowID) ([]TimelineRowTuple, error) { - query, params := appendTimelineQueryBuilder.Build([1]any{roomID}, rowIDs) - return timelineRowTupleScanner.NewRowIter(tq.GetDB().Query(ctx, query, params...)).AsList() -} - -func (tq *TimelineQuery) Get(ctx context.Context, roomID id.RoomID, limit int, before TimelineRowID) ([]*Event, error) { - return tq.QueryMany(ctx, getTimelineQuery, roomID, before, limit) -} - -func (tq *TimelineQuery) Has(ctx context.Context, roomID id.RoomID, eventRowID EventRowID) (exists bool, err error) { - err = tq.GetDB().QueryRow(ctx, checkTimelineContainsQuery, roomID, eventRowID).Scan(&exists) - return -} diff --git a/hicli/database/upgrades/00-latest-revision.sql b/hicli/database/upgrades/00-latest-revision.sql deleted file mode 100644 index 0808a6e9..00000000 --- a/hicli/database/upgrades/00-latest-revision.sql +++ /dev/null @@ -1,255 +0,0 @@ --- v0 -> v3 (compatible with v1+): Latest revision -CREATE TABLE account ( - user_id TEXT NOT NULL PRIMARY KEY, - device_id TEXT NOT NULL, - access_token TEXT NOT NULL, - homeserver_url TEXT NOT NULL, - - next_batch TEXT NOT NULL -) STRICT; - -CREATE TABLE room ( - room_id TEXT NOT NULL PRIMARY KEY, - creation_content TEXT, - - name TEXT, - name_quality INTEGER NOT NULL DEFAULT 0, - avatar TEXT, - explicit_avatar INTEGER NOT NULL DEFAULT 0, - topic TEXT, - canonical_alias TEXT, - lazy_load_summary TEXT, - - encryption_event TEXT, - has_member_list INTEGER NOT NULL DEFAULT false, - - preview_event_rowid INTEGER, - sorting_timestamp INTEGER, - unread_highlights INTEGER NOT NULL DEFAULT 0, - unread_notifications INTEGER NOT NULL DEFAULT 0, - unread_messages INTEGER NOT NULL DEFAULT 0, - - prev_batch TEXT, - - CONSTRAINT room_preview_event_fkey FOREIGN KEY (preview_event_rowid) REFERENCES event (rowid) ON DELETE SET NULL -) STRICT; -CREATE INDEX room_type_idx ON room (creation_content ->> 'type'); -CREATE INDEX room_sorting_timestamp_idx ON room (sorting_timestamp DESC); --- CREATE INDEX room_sorting_timestamp_idx ON room (unread_notifications > 0); --- CREATE INDEX room_sorting_timestamp_idx ON room (unread_messages > 0); - -CREATE TABLE account_data ( - user_id TEXT NOT NULL, - type TEXT NOT NULL, - content TEXT NOT NULL, - - PRIMARY KEY (user_id, type) -) STRICT; - -CREATE TABLE room_account_data ( - user_id TEXT NOT NULL, - room_id TEXT NOT NULL, - type TEXT NOT NULL, - content TEXT NOT NULL, - - PRIMARY KEY (user_id, room_id, type), - CONSTRAINT room_account_data_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE -) STRICT; -CREATE INDEX room_account_data_room_id_idx ON room_account_data (room_id); - -CREATE TABLE event ( - rowid INTEGER PRIMARY KEY, - - room_id TEXT NOT NULL, - event_id TEXT NOT NULL, - sender TEXT NOT NULL, - type TEXT NOT NULL, - state_key TEXT, - timestamp INTEGER NOT NULL, - - content TEXT NOT NULL, - decrypted TEXT, - decrypted_type TEXT, - unsigned TEXT NOT NULL, - local_content TEXT, - - transaction_id TEXT, - - redacted_by TEXT, - relates_to TEXT, - relation_type TEXT, - - megolm_session_id TEXT, - decryption_error TEXT, - send_error TEXT, - - reactions TEXT, - last_edit_rowid INTEGER, - unread_type INTEGER NOT NULL DEFAULT 0, - - CONSTRAINT event_id_unique_key UNIQUE (event_id), - CONSTRAINT transaction_id_unique_key UNIQUE (transaction_id), - CONSTRAINT event_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE -) STRICT; -CREATE INDEX event_room_id_idx ON event (room_id); -CREATE INDEX event_redacted_by_idx ON event (room_id, redacted_by); -CREATE INDEX event_relates_to_idx ON event (room_id, relates_to); -CREATE INDEX event_megolm_session_id_idx ON event (room_id, megolm_session_id); - -CREATE TRIGGER event_update_redacted_by - AFTER INSERT - ON event - WHEN NEW.type = 'm.room.redaction' -BEGIN - UPDATE event SET redacted_by = NEW.event_id WHERE room_id = NEW.room_id AND event_id = NEW.content ->> 'redacts'; -END; - -CREATE TRIGGER event_update_last_edit_when_redacted - AFTER UPDATE - ON event - WHEN OLD.redacted_by IS NULL - AND NEW.redacted_by IS NOT NULL - AND NEW.relation_type = 'm.replace' - AND NEW.state_key IS NULL -BEGIN - UPDATE event - SET last_edit_rowid = COALESCE( - (SELECT rowid - FROM event edit - WHERE edit.room_id = event.room_id - AND edit.relates_to = event.event_id - AND edit.relation_type = 'm.replace' - AND edit.type = event.type - AND edit.sender = event.sender - AND edit.redacted_by IS NULL - AND edit.state_key IS NULL - ORDER BY edit.timestamp DESC - LIMIT 1), - 0) - WHERE event_id = NEW.relates_to - AND last_edit_rowid = NEW.rowid - AND state_key IS NULL - AND (relation_type IS NULL OR relation_type NOT IN ('m.replace', 'm.annotation')); -END; - -CREATE TRIGGER event_insert_update_last_edit - AFTER INSERT - ON event - WHEN NEW.relation_type = 'm.replace' - AND NEW.redacted_by IS NULL - AND NEW.state_key IS NULL -BEGIN - UPDATE event - SET last_edit_rowid = NEW.rowid - WHERE event_id = NEW.relates_to - AND type = NEW.type - AND sender = NEW.sender - AND state_key IS NULL - AND (relation_type IS NULL OR relation_type NOT IN ('m.replace', 'm.annotation')) - AND NEW.timestamp > - COALESCE((SELECT prev_edit.timestamp FROM event prev_edit WHERE prev_edit.rowid = event.last_edit_rowid), 0); -END; - -CREATE TRIGGER event_insert_fill_reactions - AFTER INSERT - ON event - WHEN NEW.type = 'm.reaction' - AND NEW.relation_type = 'm.annotation' - AND NEW.redacted_by IS NULL - AND typeof(NEW.content ->> '$."m.relates_to".key') = 'text' -BEGIN - UPDATE event - SET reactions=json_set( - reactions, - '$.' || json_quote(NEW.content ->> '$."m.relates_to".key'), - coalesce( - reactions ->> ('$.' || json_quote(NEW.content ->> '$."m.relates_to".key')), - 0 - ) + 1) - WHERE event_id = NEW.relates_to - AND reactions IS NOT NULL; -END; - -CREATE TRIGGER event_redact_fill_reactions - AFTER UPDATE - ON event - WHEN NEW.type = 'm.reaction' - AND NEW.relation_type = 'm.annotation' - AND NEW.redacted_by IS NOT NULL - AND OLD.redacted_by IS NULL - AND typeof(NEW.content ->> '$."m.relates_to".key') = 'text' -BEGIN - UPDATE event - SET reactions=json_set( - reactions, - '$.' || json_quote(NEW.content ->> '$."m.relates_to".key'), - coalesce( - reactions ->> ('$.' || json_quote(NEW.content ->> '$."m.relates_to".key')), - 0 - ) - 1) - WHERE event_id = NEW.relates_to - AND reactions IS NOT NULL; -END; - -CREATE TABLE cached_media ( - mxc TEXT NOT NULL PRIMARY KEY, - event_rowid INTEGER, - enc_file TEXT, - file_name TEXT, - mime_type TEXT, - size INTEGER, - hash BLOB, - error TEXT, - - CONSTRAINT cached_media_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE SET NULL -) STRICT; - -CREATE TABLE session_request ( - room_id TEXT NOT NULL, - session_id TEXT NOT NULL, - sender TEXT NOT NULL, - min_index INTEGER NOT NULL, - backup_checked INTEGER NOT NULL DEFAULT false, - request_sent INTEGER NOT NULL DEFAULT false, - - PRIMARY KEY (session_id), - CONSTRAINT session_request_queue_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE -) STRICT; -CREATE INDEX session_request_room_idx ON session_request (room_id); - -CREATE TABLE timeline ( - rowid INTEGER PRIMARY KEY, - room_id TEXT NOT NULL, - event_rowid INTEGER NOT NULL, - - CONSTRAINT timeline_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE, - CONSTRAINT timeline_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) ON DELETE CASCADE, - CONSTRAINT timeline_event_unique_key UNIQUE (event_rowid) -) STRICT; -CREATE INDEX timeline_room_id_idx ON timeline (room_id); - -CREATE TABLE current_state ( - room_id TEXT NOT NULL, - event_type TEXT NOT NULL, - state_key TEXT NOT NULL, - event_rowid INTEGER NOT NULL, - - membership TEXT, - - PRIMARY KEY (room_id, event_type, state_key), - CONSTRAINT current_state_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE, - CONSTRAINT current_state_event_fkey FOREIGN KEY (event_rowid) REFERENCES event (rowid) -) STRICT, WITHOUT ROWID; - -CREATE TABLE receipt ( - room_id TEXT NOT NULL, - user_id TEXT NOT NULL, - receipt_type TEXT NOT NULL, - thread_id TEXT NOT NULL, - event_id TEXT NOT NULL, - timestamp INTEGER NOT NULL, - - PRIMARY KEY (room_id, user_id, receipt_type, thread_id), - CONSTRAINT receipt_room_fkey FOREIGN KEY (room_id) REFERENCES room (room_id) ON DELETE CASCADE - -- note: there's no foreign key on event ID because receipts could point at events that are too far in history. -) STRICT; diff --git a/hicli/database/upgrades/02-explicit-avatar-flag.sql b/hicli/database/upgrades/02-explicit-avatar-flag.sql deleted file mode 100644 index c11e8801..00000000 --- a/hicli/database/upgrades/02-explicit-avatar-flag.sql +++ /dev/null @@ -1,2 +0,0 @@ --- v2 (compatible with v1+): Add explicit avatar flag to rooms -ALTER TABLE room ADD COLUMN explicit_avatar INTEGER NOT NULL DEFAULT 0; diff --git a/hicli/database/upgrades/03-more-event-fields.sql b/hicli/database/upgrades/03-more-event-fields.sql deleted file mode 100644 index 3e07ad75..00000000 --- a/hicli/database/upgrades/03-more-event-fields.sql +++ /dev/null @@ -1,6 +0,0 @@ --- v3 (compatible with v1+): Add more fields to events -ALTER TABLE event ADD COLUMN local_content TEXT; -ALTER TABLE event ADD COLUMN unread_type INTEGER NOT NULL DEFAULT 0; -ALTER TABLE room ADD COLUMN unread_highlights INTEGER NOT NULL DEFAULT 0; -ALTER TABLE room ADD COLUMN unread_notifications INTEGER NOT NULL DEFAULT 0; -ALTER TABLE room ADD COLUMN unread_messages INTEGER NOT NULL DEFAULT 0; diff --git a/hicli/database/upgrades/upgrades.go b/hicli/database/upgrades/upgrades.go deleted file mode 100644 index 9d0bd1a0..00000000 --- a/hicli/database/upgrades/upgrades.go +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package upgrades - -import ( - "embed" - - "go.mau.fi/util/dbutil" -) - -var Table dbutil.UpgradeTable - -//go:embed *.sql -var upgrades embed.FS - -func init() { - Table.RegisterFS(upgrades) -} diff --git a/hicli/decryptionqueue.go b/hicli/decryptionqueue.go deleted file mode 100644 index 665ee78a..00000000 --- a/hicli/decryptionqueue.go +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package hicli - -import ( - "context" - "fmt" - "sync" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix/crypto" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/hicli/database" - "maunium.net/go/mautrix/id" -) - -func (h *HiClient) fetchFromKeyBackup(ctx context.Context, roomID id.RoomID, sessionID id.SessionID) (*crypto.InboundGroupSession, error) { - data, err := h.Client.GetKeyBackupForRoomAndSession(ctx, h.KeyBackupVersion, roomID, sessionID) - if err != nil { - return nil, err - } else if data == nil { - return nil, nil - } - decrypted, err := data.SessionData.Decrypt(h.KeyBackupKey) - if err != nil { - return nil, err - } - return h.Crypto.ImportRoomKeyFromBackup(ctx, h.KeyBackupVersion, roomID, sessionID, decrypted) -} - -func (h *HiClient) handleReceivedMegolmSession(ctx context.Context, roomID id.RoomID, sessionID id.SessionID, firstKnownIndex uint32) { - log := zerolog.Ctx(ctx) - err := h.DB.SessionRequest.Remove(ctx, sessionID, firstKnownIndex) - if err != nil { - log.Warn().Err(err).Msg("Failed to remove session request after receiving megolm session") - } - events, err := h.DB.Event.GetFailedByMegolmSessionID(ctx, roomID, sessionID) - if err != nil { - log.Err(err).Msg("Failed to get events that failed to decrypt to retry decryption") - return - } else if len(events) == 0 { - log.Trace().Msg("No events to retry decryption for") - return - } - decrypted := events[:0] - for _, evt := range events { - if evt.Decrypted != nil { - continue - } - - var mautrixEvt *event.Event - mautrixEvt, evt.Decrypted, evt.DecryptedType, err = h.decryptEvent(ctx, evt.AsRawMautrix()) - if err != nil { - log.Warn().Err(err).Stringer("event_id", evt.ID).Msg("Failed to decrypt event even after receiving megolm session") - } else { - decrypted = append(decrypted, evt) - h.postDecryptProcess(ctx, nil, evt, mautrixEvt) - } - } - if len(decrypted) > 0 { - var newPreview database.EventRowID - err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { - for _, evt := range decrypted { - err = h.DB.Event.UpdateDecrypted(ctx, evt) - if err != nil { - return fmt.Errorf("failed to save decrypted content for %s: %w", evt.ID, err) - } - if evt.CanUseForPreview() { - var previewChanged bool - previewChanged, err = h.DB.Room.UpdatePreviewIfLaterOnTimeline(ctx, evt.RoomID, evt.RowID) - if err != nil { - return fmt.Errorf("failed to update room %s preview to %d: %w", evt.RoomID, evt.RowID, err) - } else if previewChanged { - newPreview = evt.RowID - } - } - } - return nil - }) - if err != nil { - log.Err(err).Msg("Failed to save decrypted events") - } else { - h.EventHandler(&EventsDecrypted{Events: decrypted, PreviewEventRowID: newPreview, RoomID: roomID}) - } - } -} - -func (h *HiClient) WakeupRequestQueue() { - select { - case h.requestQueueWakeup <- struct{}{}: - default: - } -} - -func (h *HiClient) RunRequestQueue(ctx context.Context) { - log := zerolog.Ctx(ctx).With().Str("action", "request queue").Logger() - ctx = log.WithContext(ctx) - log.Info().Msg("Starting key request queue") - defer func() { - log.Info().Msg("Stopping key request queue") - }() - for { - err := h.FetchKeysForOutdatedUsers(ctx) - if err != nil { - log.Err(err).Msg("Failed to fetch outdated device lists for tracked users") - } - madeRequests, err := h.RequestQueuedSessions(ctx) - if err != nil { - log.Err(err).Msg("Failed to handle session request queue") - } else if madeRequests { - continue - } - select { - case <-ctx.Done(): - return - case <-h.requestQueueWakeup: - } - } -} - -func (h *HiClient) requestQueuedSession(ctx context.Context, req *database.SessionRequest, doneFunc func()) { - defer doneFunc() - log := zerolog.Ctx(ctx) - if !req.BackupChecked { - sess, err := h.fetchFromKeyBackup(ctx, req.RoomID, req.SessionID) - if err != nil { - log.Err(err). - Stringer("session_id", req.SessionID). - Msg("Failed to fetch session from key backup") - - // TODO should this have retries instead of just storing it's checked? - req.BackupChecked = true - err = h.DB.SessionRequest.Put(ctx, req) - if err != nil { - log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after trying to check backup") - } - } else if sess == nil || sess.Internal.FirstKnownIndex() > req.MinIndex { - req.BackupChecked = true - err = h.DB.SessionRequest.Put(ctx, req) - if err != nil { - log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after checking backup") - } - } else { - log.Debug().Stringer("session_id", req.SessionID). - Msg("Found session with sufficiently low first known index, removing from queue") - err = h.DB.SessionRequest.Remove(ctx, req.SessionID, sess.Internal.FirstKnownIndex()) - if err != nil { - log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to remove session from request queue") - } - } - } else { - err := h.Crypto.SendRoomKeyRequest(ctx, req.RoomID, "", req.SessionID, "", map[id.UserID][]id.DeviceID{ - h.Account.UserID: {"*"}, - req.Sender: {"*"}, - }) - //var err error - if err != nil { - log.Err(err). - Stringer("session_id", req.SessionID). - Msg("Failed to send key request") - } else { - log.Debug().Stringer("session_id", req.SessionID).Msg("Sent key request") - req.RequestSent = true - err = h.DB.SessionRequest.Put(ctx, req) - if err != nil { - log.Err(err).Stringer("session_id", req.SessionID).Msg("Failed to update session request after sending request") - } - } - } -} - -const MaxParallelRequests = 5 - -func (h *HiClient) RequestQueuedSessions(ctx context.Context) (bool, error) { - sessions, err := h.DB.SessionRequest.Next(ctx, MaxParallelRequests) - if err != nil { - return false, fmt.Errorf("failed to get next events to decrypt: %w", err) - } else if len(sessions) == 0 { - return false, nil - } - var wg sync.WaitGroup - wg.Add(len(sessions)) - for _, req := range sessions { - go h.requestQueuedSession(ctx, req, wg.Done) - } - wg.Wait() - - return true, err -} - -func (h *HiClient) FetchKeysForOutdatedUsers(ctx context.Context) error { - outdatedUsers, err := h.Crypto.CryptoStore.GetOutdatedTrackedUsers(ctx) - if err != nil { - return err - } else if len(outdatedUsers) == 0 { - return nil - } - _, err = h.Crypto.FetchKeys(ctx, outdatedUsers, false) - if err != nil { - return err - } - // TODO backoff for users that fail to be fetched? - return nil -} diff --git a/hicli/events.go b/hicli/events.go deleted file mode 100644 index e730475b..00000000 --- a/hicli/events.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package hicli - -import ( - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/hicli/database" - "maunium.net/go/mautrix/id" -) - -type SyncRoom struct { - Meta *database.Room `json:"meta"` - Timeline []database.TimelineRowTuple `json:"timeline"` - State map[event.Type]map[string]database.EventRowID `json:"state"` - Events []*database.Event `json:"events"` - Reset bool `json:"reset"` - Notifications []SyncNotification `json:"notifications"` -} - -type SyncNotification struct { - RowID database.EventRowID `json:"event_rowid"` - Sound bool `json:"sound"` -} - -type SyncComplete struct { - Rooms map[id.RoomID]*SyncRoom `json:"rooms"` -} - -func (c *SyncComplete) IsEmpty() bool { - return len(c.Rooms) == 0 -} - -type EventsDecrypted struct { - RoomID id.RoomID `json:"room_id"` - PreviewEventRowID database.EventRowID `json:"preview_event_rowid,omitempty"` - Events []*database.Event `json:"events"` -} - -type Typing struct { - RoomID id.RoomID `json:"room_id"` - event.TypingEventContent -} - -type SendComplete struct { - Event *database.Event `json:"event"` - Error error `json:"error"` -} - -type ClientState struct { - IsLoggedIn bool `json:"is_logged_in"` - IsVerified bool `json:"is_verified"` - UserID id.UserID `json:"user_id,omitempty"` - DeviceID id.DeviceID `json:"device_id,omitempty"` - HomeserverURL string `json:"homeserver_url,omitempty"` -} diff --git a/hicli/hicli.go b/hicli/hicli.go deleted file mode 100644 index 78a1acc0..00000000 --- a/hicli/hicli.go +++ /dev/null @@ -1,250 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -// Package hicli contains a highly opinionated high-level framework for developing instant messaging clients on Matrix. -package hicli - -import ( - "context" - "errors" - "fmt" - "net" - "net/http" - "net/url" - "sync" - "sync/atomic" - "time" - - "github.com/rs/zerolog" - "go.mau.fi/util/dbutil" - "go.mau.fi/util/exerrors" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto" - "maunium.net/go/mautrix/crypto/backup" - "maunium.net/go/mautrix/hicli/database" - "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/pushrules" -) - -type HiClient struct { - DB *database.Database - Account *database.Account - Client *mautrix.Client - Crypto *crypto.OlmMachine - CryptoStore *crypto.SQLCryptoStore - ClientStore *database.ClientStateStore - Log zerolog.Logger - - Verified bool - - KeyBackupVersion id.KeyBackupVersion - KeyBackupKey *backup.MegolmBackupKey - - PushRules atomic.Pointer[pushrules.PushRuleset] - - EventHandler func(evt any) - - firstSyncReceived bool - syncingID int - syncLock sync.Mutex - stopSync atomic.Pointer[context.CancelFunc] - encryptLock sync.Mutex - - requestQueueWakeup chan struct{} - - jsonRequestsLock sync.Mutex - jsonRequests map[int64]context.CancelCauseFunc - - paginationInterrupterLock sync.Mutex - paginationInterrupter map[id.RoomID]context.CancelCauseFunc -} - -var ErrTimelineReset = errors.New("got limited timeline sync response") - -func New(rawDB, cryptoDB *dbutil.Database, log zerolog.Logger, pickleKey []byte, evtHandler func(any)) *HiClient { - if cryptoDB == nil { - cryptoDB = rawDB - } - if rawDB.Owner == "" { - rawDB.Owner = "hicli" - rawDB.IgnoreForeignTables = true - } - if rawDB.Log == nil { - rawDB.Log = dbutil.ZeroLogger(log.With().Str("db_section", "hicli").Logger()) - } - db := database.New(rawDB) - c := &HiClient{ - DB: db, - Log: log, - - requestQueueWakeup: make(chan struct{}, 1), - jsonRequests: make(map[int64]context.CancelCauseFunc), - paginationInterrupter: make(map[id.RoomID]context.CancelCauseFunc), - - EventHandler: evtHandler, - } - c.ClientStore = &database.ClientStateStore{Database: db} - c.Client = &mautrix.Client{ - UserAgent: mautrix.DefaultUserAgent, - Client: &http.Client{ - Transport: &http.Transport{ - DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext, - TLSHandshakeTimeout: 10 * time.Second, - // This needs to be relatively high to allow initial syncs - ResponseHeaderTimeout: 180 * time.Second, - ForceAttemptHTTP2: true, - }, - Timeout: 180 * time.Second, - }, - Syncer: (*hiSyncer)(c), - Store: (*hiStore)(c), - StateStore: c.ClientStore, - Log: log.With().Str("component", "mautrix client").Logger(), - } - c.CryptoStore = crypto.NewSQLCryptoStore(cryptoDB, dbutil.ZeroLogger(log.With().Str("db_section", "crypto").Logger()), "", "", pickleKey) - cryptoLog := log.With().Str("component", "crypto").Logger() - c.Crypto = crypto.NewOlmMachine(c.Client, &cryptoLog, c.CryptoStore, c.ClientStore) - c.Crypto.SessionReceived = c.handleReceivedMegolmSession - c.Crypto.DisableRatchetTracking = true - c.Crypto.DisableDecryptKeyFetching = true - c.Client.Crypto = (*hiCryptoHelper)(c) - return c -} - -func (h *HiClient) IsLoggedIn() bool { - return h.Account != nil -} - -func (h *HiClient) Start(ctx context.Context, userID id.UserID, expectedAccount *database.Account) error { - if expectedAccount != nil && userID != expectedAccount.UserID { - panic(fmt.Errorf("invalid parameters: different user ID in expected account and user ID")) - } - err := h.DB.Upgrade(ctx) - if err != nil { - return fmt.Errorf("failed to upgrade hicli db: %w", err) - } - err = h.CryptoStore.DB.Upgrade(ctx) - if err != nil { - return fmt.Errorf("failed to upgrade crypto db: %w", err) - } - account, err := h.DB.Account.Get(ctx, userID) - if err != nil { - return err - } else if account == nil && expectedAccount != nil { - err = h.DB.Account.Put(ctx, expectedAccount) - if err != nil { - return err - } - account = expectedAccount - } else if expectedAccount != nil && expectedAccount.DeviceID != account.DeviceID { - return fmt.Errorf("device ID mismatch: expected %s, got %s", expectedAccount.DeviceID, account.DeviceID) - } - if account != nil { - zerolog.Ctx(ctx).Debug().Stringer("user_id", account.UserID).Msg("Preparing client with existing credentials") - h.Account = account - h.CryptoStore.AccountID = account.UserID.String() - h.CryptoStore.DeviceID = account.DeviceID - h.Client.UserID = account.UserID - h.Client.DeviceID = account.DeviceID - h.Client.AccessToken = account.AccessToken - h.Client.HomeserverURL, err = url.Parse(account.HomeserverURL) - if err != nil { - return err - } - err = h.CheckServerVersions(ctx) - if err != nil { - return err - } - err = h.Crypto.Load(ctx) - if err != nil { - return fmt.Errorf("failed to load olm machine: %w", err) - } - - h.Verified, err = h.checkIsCurrentDeviceVerified(ctx) - if err != nil { - return err - } - zerolog.Ctx(ctx).Debug().Bool("verified", h.Verified).Msg("Checked current device verification status") - if h.Verified { - err = h.loadPrivateKeys(ctx) - if err != nil { - return err - } - go h.Sync() - } - } - return nil -} - -var ErrFailedToCheckServerVersions = errors.New("failed to check server versions") -var ErrOutdatedServer = errors.New("homeserver is outdated") -var MinimumSpecVersion = mautrix.SpecV11 - -func (h *HiClient) CheckServerVersions(ctx context.Context) error { - versions, err := h.Client.Versions(ctx) - if err != nil { - return exerrors.NewDualError(ErrFailedToCheckServerVersions, err) - } else if !versions.Contains(MinimumSpecVersion) { - return fmt.Errorf("%w (minimum: %s, highest supported: %s)", ErrOutdatedServer, MinimumSpecVersion, versions.GetLatest()) - } - return nil -} - -func (h *HiClient) IsSyncing() bool { - return h.stopSync.Load() != nil -} - -func (h *HiClient) Sync() { - h.Client.StopSync() - if fn := h.stopSync.Load(); fn != nil { - (*fn)() - } - h.syncLock.Lock() - defer h.syncLock.Unlock() - h.syncingID++ - syncingID := h.syncingID - log := h.Log.With(). - Str("action", "sync"). - Int("sync_id", syncingID). - Logger() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - h.stopSync.Store(&cancel) - go h.RunRequestQueue(h.Log.WithContext(ctx)) - go h.LoadPushRules(h.Log.WithContext(ctx)) - ctx = log.WithContext(ctx) - log.Info().Msg("Starting syncing") - err := h.Client.SyncWithContext(ctx) - if err != nil && ctx.Err() == nil { - log.Err(err).Msg("Fatal error in syncer") - } else { - log.Info().Msg("Syncing stopped") - } -} - -func (h *HiClient) LoadPushRules(ctx context.Context) { - rules, err := h.Client.GetPushRules(ctx) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to load push rules") - return - } - h.PushRules.Store(rules) - zerolog.Ctx(ctx).Debug().Msg("Updated push rules from fetch") -} - -func (h *HiClient) Stop() { - h.Client.StopSync() - if fn := h.stopSync.Swap(nil); fn != nil { - (*fn)() - } - h.syncLock.Lock() - h.syncLock.Unlock() - err := h.DB.Close() - if err != nil { - h.Log.Err(err).Msg("Failed to close database cleanly") - } -} diff --git a/hicli/hitest/hitest.go b/hicli/hitest/hitest.go deleted file mode 100644 index bdf1598f..00000000 --- a/hicli/hitest/hitest.go +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package main - -import ( - "context" - "fmt" - "io" - "strings" - - "github.com/chzyer/readline" - _ "github.com/mattn/go-sqlite3" - "github.com/rs/zerolog" - "go.mau.fi/util/dbutil" - _ "go.mau.fi/util/dbutil/litestream" - "go.mau.fi/util/exerrors" - "go.mau.fi/util/exzerolog" - "go.mau.fi/zeroconfig" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/hicli" - "maunium.net/go/mautrix/id" -) - -var writerTypeReadline zeroconfig.WriterType = "hitest_readline" - -func main() { - hicli.InitialDeviceDisplayName = "mautrix hitest" - rl := exerrors.Must(readline.New("> ")) - defer func() { - _ = rl.Close() - }() - zeroconfig.RegisterWriter(writerTypeReadline, func(config *zeroconfig.WriterConfig) (io.Writer, error) { - return rl.Stdout(), nil - }) - debug := zerolog.DebugLevel - log := exerrors.Must((&zeroconfig.Config{ - MinLevel: &debug, - Writers: []zeroconfig.WriterConfig{{ - Type: writerTypeReadline, - Format: zeroconfig.LogFormatPrettyColored, - }}, - }).Compile()) - exzerolog.SetupDefaults(log) - - rawDB := exerrors.Must(dbutil.NewWithDialect("hicli.db", "sqlite3-fk-wal")) - ctx := log.WithContext(context.Background()) - cli := hicli.New(rawDB, nil, *log, []byte("meow"), func(a any) { - _, _ = fmt.Fprintf(rl, "Received event of type %T\n", a) - switch evt := a.(type) { - case *hicli.SyncComplete: - for _, room := range evt.Rooms { - name := "name unset" - if room.Meta.Name != nil { - name = *room.Meta.Name - } - _, _ = fmt.Fprintf(rl, "Room %s (%s) in sync:\n", name, room.Meta.ID) - _, _ = fmt.Fprintf(rl, " Preview: %d, sort: %v\n", room.Meta.PreviewEventRowID, room.Meta.SortingTimestamp) - _, _ = fmt.Fprintf(rl, " Timeline: +%d %v, reset: %t\n", len(room.Timeline), room.Timeline, room.Reset) - } - case *hicli.EventsDecrypted: - for _, decrypted := range evt.Events { - _, _ = fmt.Fprintf(rl, "Delayed decryption of %s completed: %s / %s\n", decrypted.ID, decrypted.DecryptedType, decrypted.Decrypted) - } - if evt.PreviewEventRowID != 0 { - _, _ = fmt.Fprintf(rl, "Room preview updated: %+v\n", evt.PreviewEventRowID) - } - case *hicli.Typing: - _, _ = fmt.Fprintf(rl, "Typing list in %s: %+v\n", evt.RoomID, evt.UserIDs) - } - }) - userID, _ := cli.DB.Account.GetFirstUserID(ctx) - exerrors.PanicIfNotNil(cli.Start(ctx, userID, nil)) - if !cli.IsLoggedIn() { - rl.SetPrompt("User ID: ") - userID := id.UserID(exerrors.Must(rl.Readline())) - _, serverName := exerrors.Must2(userID.Parse()) - discovery := exerrors.Must(mautrix.DiscoverClientAPI(ctx, serverName)) - password := exerrors.Must(rl.ReadPassword("Password: ")) - recoveryCode := exerrors.Must(rl.ReadPassword("Recovery code: ")) - exerrors.PanicIfNotNil(cli.LoginAndVerify(ctx, discovery.Homeserver.BaseURL, userID.String(), string(password), string(recoveryCode))) - } - rl.SetPrompt("> ") - - for { - line, err := rl.Readline() - if err != nil { - break - } - fields := strings.Fields(line) - if len(fields) == 0 { - continue - } - switch strings.ToLower(fields[0]) { - case "send": - resp, err := cli.Send(ctx, id.RoomID(fields[1]), event.EventMessage, &event.MessageEventContent{ - Body: strings.Join(fields[2:], " "), - MsgType: event.MsgText, - }) - _, _ = fmt.Fprintln(rl, err) - _, _ = fmt.Fprintf(rl, "%+v\n", resp) - } - } - cli.Stop() -} diff --git a/hicli/html.go b/hicli/html.go deleted file mode 100644 index b0ad824d..00000000 --- a/hicli/html.go +++ /dev/null @@ -1,476 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package hicli - -import ( - "bytes" - "errors" - "fmt" - "io" - "net/url" - "regexp" - "slices" - "strconv" - "strings" - - "golang.org/x/net/html" - "golang.org/x/net/html/atom" - "mvdan.cc/xurls/v2" - - "maunium.net/go/mautrix/id" -) - -func tagIsAllowed(tag atom.Atom) bool { - switch tag { - case atom.Del, atom.H1, atom.H2, atom.H3, atom.H4, atom.H5, atom.H6, atom.Blockquote, atom.P, - atom.A, atom.Ul, atom.Ol, atom.Sup, atom.Sub, atom.Li, atom.B, atom.I, atom.U, atom.Strong, - atom.Em, atom.S, atom.Code, atom.Hr, atom.Br, atom.Div, atom.Table, atom.Thead, atom.Tbody, - atom.Tr, atom.Th, atom.Td, atom.Caption, atom.Pre, atom.Span, atom.Font, atom.Img, - atom.Details, atom.Summary: - return true - default: - return false - } -} - -func isSelfClosing(tag atom.Atom) bool { - switch tag { - case atom.Img, atom.Br, atom.Hr: - return true - default: - return false - } -} - -var languageRegex = regexp.MustCompile(`^language-[a-zA-Z0-9-]+$`) -var allowedColorRegex = regexp.MustCompile(`^#[0-9a-fA-F]{6}$`) - -// This is approximately a mirror of web/src/util/mediasize.ts in gomuks -func calculateMediaSize(widthInt, heightInt int) (width, height float64, ok bool) { - if widthInt <= 0 || heightInt <= 0 { - return - } - width = float64(widthInt) - height = float64(heightInt) - const imageContainerWidth float64 = 320 - const imageContainerHeight float64 = 240 - const imageContainerAspectRatio = imageContainerWidth / imageContainerHeight - if width > imageContainerWidth || height > imageContainerHeight { - aspectRatio := width / height - if aspectRatio > imageContainerAspectRatio { - width = imageContainerWidth - height = imageContainerWidth / aspectRatio - } else if aspectRatio < imageContainerAspectRatio { - width = imageContainerHeight * aspectRatio - height = imageContainerHeight - } else { - width = imageContainerWidth - height = imageContainerHeight - } - } - ok = true - return -} - -func parseImgAttributes(attrs []html.Attribute) (src, alt, title string, isCustomEmoji bool, width, height int) { - for _, attr := range attrs { - switch attr.Key { - case "src": - src = attr.Val - case "alt": - alt = attr.Val - case "title": - title = attr.Val - case "data-mx-emoticon": - isCustomEmoji = true - case "width": - width, _ = strconv.Atoi(attr.Val) - case "height": - height, _ = strconv.Atoi(attr.Val) - } - } - return -} - -func parseSpanAttributes(attrs []html.Attribute) (bgColor, textColor, spoiler, maths string, isSpoiler bool) { - for _, attr := range attrs { - switch attr.Key { - case "data-mx-bg-color": - if allowedColorRegex.MatchString(attr.Val) { - bgColor = attr.Val - } - case "data-mx-color", "color": - if allowedColorRegex.MatchString(attr.Val) { - textColor = attr.Val - } - case "data-mx-spoiler": - spoiler = attr.Val - isSpoiler = true - case "data-mx-maths": - maths = attr.Val - } - } - return -} - -func parseAAttributes(attrs []html.Attribute) (href string) { - for _, attr := range attrs { - switch attr.Key { - case "href": - href = strings.TrimSpace(attr.Val) - } - } - return -} - -func attributeIsAllowed(tag atom.Atom, attr html.Attribute) bool { - switch tag { - case atom.Ol: - switch attr.Key { - case "start": - _, err := strconv.Atoi(attr.Val) - return err == nil - } - case atom.Code: - switch attr.Key { - case "class": - return languageRegex.MatchString(attr.Val) - } - case atom.Div: - switch attr.Key { - case "data-mx-maths": - return true - } - } - return false -} - -// Funny user IDs will just need to be linkified by the sender, no auto-linkification for them. -var plainUserOrAliasMentionRegex = regexp.MustCompile(`[@#][a-zA-Z0-9._=/+-]{0,254}:[a-zA-Z0-9.-]+(?:\d{1,5})?`) - -func getNextItem(items [][]int, minIndex int) (index, start, end int, ok bool) { - for i, item := range items { - if item[0] >= minIndex { - return i, item[0], item[1], true - } - } - return -1, -1, -1, false -} - -func writeMention(w *strings.Builder, mention []byte) { - w.WriteString(`') - writeEscapedBytes(w, mention) - w.WriteString("") -} - -func writeURL(w *strings.Builder, addr []byte) { - parsedURL, err := url.Parse(string(addr)) - if err != nil { - writeEscapedBytes(w, addr) - return - } - if parsedURL.Scheme == "" { - parsedURL.Scheme = "https" - } - w.WriteString(`') - writeEscapedBytes(w, addr) - w.WriteString("") -} - -func linkifyAndWriteBytes(w *strings.Builder, s []byte) { - mentions := plainUserOrAliasMentionRegex.FindAllIndex(s, -1) - urls := xurls.Relaxed().FindAllIndex(s, -1) - minIndex := 0 - for { - mentionIdx, nextMentionStart, nextMentionEnd, hasMention := getNextItem(mentions, minIndex) - urlIdx, nextURLStart, nextURLEnd, hasURL := getNextItem(urls, minIndex) - if hasMention && (!hasURL || nextMentionStart <= nextURLStart) { - writeEscapedBytes(w, s[minIndex:nextMentionStart]) - writeMention(w, s[nextMentionStart:nextMentionEnd]) - minIndex = nextMentionEnd - mentions = mentions[mentionIdx:] - } else if hasURL && (!hasMention || nextURLStart < nextMentionStart) { - writeEscapedBytes(w, s[minIndex:nextURLStart]) - writeURL(w, s[nextURLStart:nextURLEnd]) - minIndex = nextURLEnd - urls = urls[urlIdx:] - } else { - break - } - } - writeEscapedBytes(w, s[minIndex:]) -} - -const escapedChars = "&'<>\"\r" - -func writeEscapedBytes(w *strings.Builder, s []byte) { - i := bytes.IndexAny(s, escapedChars) - for i != -1 { - w.Write(s[:i]) - var esc string - switch s[i] { - case '&': - esc = "&" - case '\'': - // "'" is shorter than "'" and apos was not in HTML until HTML5. - esc = "'" - case '<': - esc = "<" - case '>': - esc = ">" - case '"': - // """ is shorter than """. - esc = """ - case '\r': - esc = " " - default: - panic("unrecognized escape character") - } - s = s[i+1:] - w.WriteString(esc) - i = bytes.IndexAny(s, escapedChars) - } - w.Write(s) -} - -func writeEscapedString(w *strings.Builder, s string) { - i := strings.IndexAny(s, escapedChars) - for i != -1 { - w.WriteString(s[:i]) - var esc string - switch s[i] { - case '&': - esc = "&" - case '\'': - // "'" is shorter than "'" and apos was not in HTML until HTML5. - esc = "'" - case '<': - esc = "<" - case '>': - esc = ">" - case '"': - // """ is shorter than """. - esc = """ - case '\r': - esc = " " - default: - panic("unrecognized escape character") - } - s = s[i+1:] - w.WriteString(esc) - i = strings.IndexAny(s, escapedChars) - } - w.WriteString(s) -} - -func writeAttribute(w *strings.Builder, key, value string) { - w.WriteByte(' ') - w.WriteString(key) - w.WriteString(`="`) - writeEscapedString(w, value) - w.WriteByte('"') -} - -func writeA(w *strings.Builder, attr []html.Attribute) { - w.WriteString("`) - w.WriteString(spoiler) - w.WriteString(" ") - } - w.WriteByte('<') - w.WriteString("span") - if isSpoiler { - writeAttribute(w, "class", "hicli-spoiler") - } - var style string - if bgColor != "" { - style += fmt.Sprintf("background-color: %s;", bgColor) - } - if textColor != "" { - style += fmt.Sprintf("color: %s;", textColor) - } - if style != "" { - writeAttribute(w, "style", style) - } -} - -type tagStack []atom.Atom - -func (ts *tagStack) contains(tags ...atom.Atom) bool { - for i := len(*ts) - 1; i >= 0; i-- { - for _, tag := range tags { - if (*ts)[i] == tag { - return true - } - } - } - return false -} - -func (ts *tagStack) push(tag atom.Atom) { - *ts = append(*ts, tag) -} - -func (ts *tagStack) pop(tag atom.Atom) bool { - if len(*ts) > 0 && (*ts)[len(*ts)-1] == tag { - *ts = (*ts)[:len(*ts)-1] - return true - } - return false -} - -func sanitizeAndLinkifyHTML(body string) (string, error) { - tz := html.NewTokenizer(strings.NewReader(body)) - var built strings.Builder - ts := make(tagStack, 2) -Loop: - for { - switch tz.Next() { - case html.ErrorToken: - err := tz.Err() - if errors.Is(err, io.EOF) { - break Loop - } - return "", err - case html.StartTagToken, html.SelfClosingTagToken: - token := tz.Token() - if !tagIsAllowed(token.DataAtom) { - continue - } - tagIsSelfClosing := isSelfClosing(token.DataAtom) - if token.Type == html.SelfClosingTagToken && !tagIsSelfClosing { - continue - } - switch token.DataAtom { - case atom.A: - writeA(&built, token.Attr) - case atom.Img: - writeImg(&built, token.Attr) - case atom.Span, atom.Font: - writeSpan(&built, token.Attr) - default: - built.WriteByte('<') - built.WriteString(token.Data) - for _, attr := range token.Attr { - if attributeIsAllowed(token.DataAtom, attr) { - writeAttribute(&built, attr.Key, attr.Val) - } - } - } - built.WriteByte('>') - if !tagIsSelfClosing { - ts.push(token.DataAtom) - } - case html.EndTagToken: - tagName, _ := tz.TagName() - tag := atom.Lookup(tagName) - if tagIsAllowed(tag) && ts.pop(tag) { - built.WriteString("') - } - case html.TextToken: - if ts.contains(atom.Pre, atom.Code, atom.A) { - writeEscapedBytes(&built, tz.Text()) - } else { - linkifyAndWriteBytes(&built, tz.Text()) - } - case html.DoctypeToken, html.CommentToken: - // ignore - } - } - slices.Reverse(ts) - for _, t := range ts { - built.WriteString("') - } - return built.String(), nil -} diff --git a/hicli/json-commands.go b/hicli/json-commands.go deleted file mode 100644 index c9dc89d2..00000000 --- a/hicli/json-commands.go +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package hicli - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "time" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/hicli/database" - "maunium.net/go/mautrix/id" -) - -func (h *HiClient) handleJSONCommand(ctx context.Context, req *JSONCommand) (any, error) { - switch req.Command { - case "get_state": - return h.State(), nil - case "cancel": - return unmarshalAndCall(req.Data, func(params *cancelRequestParams) (bool, error) { - h.jsonRequestsLock.Lock() - cancelTarget, ok := h.jsonRequests[params.RequestID] - h.jsonRequestsLock.Unlock() - if ok { - return false, nil - } - if params.Reason == "" { - cancelTarget(nil) - } else { - cancelTarget(errors.New(params.Reason)) - } - return true, nil - }) - case "send_message": - return unmarshalAndCall(req.Data, func(params *sendMessageParams) (*database.Event, error) { - return h.SendMessage(ctx, params.RoomID, params.Text, params.MediaPath, params.ReplyTo, params.Mentions) - }) - case "send_event": - return unmarshalAndCall(req.Data, func(params *sendEventParams) (*database.Event, error) { - return h.Send(ctx, params.RoomID, params.EventType, params.Content) - }) - case "mark_read": - return unmarshalAndCall(req.Data, func(params *markReadParams) (bool, error) { - return true, h.MarkRead(ctx, params.RoomID, params.EventID, params.ReceiptType) - }) - case "set_typing": - return unmarshalAndCall(req.Data, func(params *setTypingParams) (bool, error) { - return true, h.SetTyping(ctx, params.RoomID, time.Duration(params.Timeout)*time.Millisecond) - }) - case "get_event": - return unmarshalAndCall(req.Data, func(params *getEventParams) (*database.Event, error) { - return h.GetEvent(ctx, params.RoomID, params.EventID) - }) - case "get_events_by_rowids": - return unmarshalAndCall(req.Data, func(params *getEventsByRowIDsParams) ([]*database.Event, error) { - return h.GetEventsByRowIDs(ctx, params.RowIDs) - }) - case "get_room_state": - return unmarshalAndCall(req.Data, func(params *getRoomStateParams) ([]*database.Event, error) { - return h.GetRoomState(ctx, params.RoomID, params.FetchMembers, params.Refetch) - }) - case "paginate": - return unmarshalAndCall(req.Data, func(params *paginateParams) (*PaginationResponse, error) { - return h.Paginate(ctx, params.RoomID, params.MaxTimelineID, params.Limit) - }) - case "paginate_server": - return unmarshalAndCall(req.Data, func(params *paginateParams) (*PaginationResponse, error) { - return h.PaginateServer(ctx, params.RoomID, params.Limit) - }) - case "ensure_group_session_shared": - return unmarshalAndCall(req.Data, func(params *ensureGroupSessionSharedParams) (bool, error) { - return true, h.EnsureGroupSessionShared(ctx, params.RoomID) - }) - case "login": - return unmarshalAndCall(req.Data, func(params *loginParams) (bool, error) { - return true, h.LoginPassword(ctx, params.HomeserverURL, params.Username, params.Password) - }) - case "verify": - return unmarshalAndCall(req.Data, func(params *verifyParams) (bool, error) { - return true, h.VerifyWithRecoveryKey(ctx, params.RecoveryKey) - }) - case "discover_homeserver": - return unmarshalAndCall(req.Data, func(params *discoverHomeserverParams) (*mautrix.ClientWellKnown, error) { - _, homeserver, err := params.UserID.Parse() - if err != nil { - return nil, err - } - return mautrix.DiscoverClientAPI(ctx, homeserver) - }) - default: - return nil, fmt.Errorf("unknown command %q", req.Command) - } -} - -func unmarshalAndCall[T, O any](data json.RawMessage, fn func(*T) (O, error)) (output O, err error) { - var input T - err = json.Unmarshal(data, &input) - if err != nil { - return - } - return fn(&input) -} - -type cancelRequestParams struct { - RequestID int64 `json:"request_id"` - Reason string `json:"reason"` -} - -type sendMessageParams struct { - RoomID id.RoomID `json:"room_id"` - Text string `json:"text"` - MediaPath string `json:"media_path"` - ReplyTo id.EventID `json:"reply_to"` - Mentions *event.Mentions `json:"mentions"` -} - -type sendEventParams struct { - RoomID id.RoomID `json:"room_id"` - EventType event.Type `json:"type"` - Content json.RawMessage `json:"content"` -} - -type markReadParams struct { - RoomID id.RoomID `json:"room_id"` - EventID id.EventID `json:"event_id"` - ReceiptType event.ReceiptType `json:"receipt_type"` -} - -type setTypingParams struct { - RoomID id.RoomID `json:"room_id"` - Timeout int `json:"timeout"` -} - -type getEventParams struct { - RoomID id.RoomID `json:"room_id"` - EventID id.EventID `json:"event_id"` -} - -type getEventsByRowIDsParams struct { - RowIDs []database.EventRowID `json:"row_ids"` -} - -type getRoomStateParams struct { - RoomID id.RoomID `json:"room_id"` - Refetch bool `json:"refetch"` - FetchMembers bool `json:"fetch_members"` -} - -type ensureGroupSessionSharedParams struct { - RoomID id.RoomID `json:"room_id"` -} - -type loginParams struct { - HomeserverURL string `json:"homeserver_url"` - Username string `json:"username"` - Password string `json:"password"` -} - -type verifyParams struct { - RecoveryKey string `json:"recovery_key"` -} - -type discoverHomeserverParams struct { - UserID id.UserID `json:"user_id"` -} - -type paginateParams struct { - RoomID id.RoomID `json:"room_id"` - MaxTimelineID database.TimelineRowID `json:"max_timeline_id"` - Limit int `json:"limit"` -} diff --git a/hicli/json.go b/hicli/json.go deleted file mode 100644 index a27fd007..00000000 --- a/hicli/json.go +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package hicli - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "sync/atomic" - - "go.mau.fi/util/exerrors" -) - -type JSONCommand struct { - Command string `json:"command"` - RequestID int64 `json:"request_id"` - Data json.RawMessage `json:"data"` -} - -type JSONEventHandler func(*JSONCommand) - -var outgoingEventCounter atomic.Int64 - -func (jeh JSONEventHandler) HandleEvent(evt any) { - var command string - switch evt.(type) { - case *SyncComplete: - command = "sync_complete" - case *EventsDecrypted: - command = "events_decrypted" - case *Typing: - command = "typing" - case *SendComplete: - command = "send_complete" - case *ClientState: - command = "client_state" - default: - panic(fmt.Errorf("unknown event type %T", evt)) - } - data, err := json.Marshal(evt) - if err != nil { - panic(fmt.Errorf("failed to marshal event %T: %w", evt, err)) - } - jeh(&JSONCommand{ - Command: command, - RequestID: -outgoingEventCounter.Add(1), - Data: data, - }) -} - -func (h *HiClient) State() *ClientState { - state := &ClientState{} - if acc := h.Account; acc != nil { - state.IsLoggedIn = true - state.UserID = acc.UserID - state.DeviceID = acc.DeviceID - state.HomeserverURL = acc.HomeserverURL - state.IsVerified = h.Verified - } - return state -} - -func (h *HiClient) dispatchCurrentState() { - h.EventHandler(h.State()) -} - -func (h *HiClient) SubmitJSONCommand(ctx context.Context, req *JSONCommand) *JSONCommand { - if req.Command == "ping" { - return &JSONCommand{ - Command: "pong", - RequestID: req.RequestID, - } - } - log := h.Log.With().Int64("request_id", req.RequestID).Str("command", req.Command).Logger() - ctx, cancel := context.WithCancelCause(ctx) - defer func() { - cancel(nil) - h.jsonRequestsLock.Lock() - delete(h.jsonRequests, req.RequestID) - h.jsonRequestsLock.Unlock() - }() - ctx = log.WithContext(ctx) - h.jsonRequestsLock.Lock() - h.jsonRequests[req.RequestID] = cancel - h.jsonRequestsLock.Unlock() - resp, err := h.handleJSONCommand(ctx, req) - if err != nil { - if errors.Is(err, context.Canceled) { - causeErr := context.Cause(ctx) - if causeErr != ctx.Err() { - err = fmt.Errorf("%w: %w", err, causeErr) - } - } - return &JSONCommand{ - Command: "error", - RequestID: req.RequestID, - Data: exerrors.Must(json.Marshal(err.Error())), - } - } - var respData json.RawMessage - respData, err = json.Marshal(resp) - if err != nil { - return &JSONCommand{ - Command: "error", - RequestID: req.RequestID, - Data: exerrors.Must(json.Marshal(fmt.Sprintf("failed to marshal response json: %v", err))), - } - } - return &JSONCommand{ - Command: "response", - RequestID: req.RequestID, - Data: respData, - } -} diff --git a/hicli/login.go b/hicli/login.go deleted file mode 100644 index 6dbaf6e6..00000000 --- a/hicli/login.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package hicli - -import ( - "context" - "fmt" - "net/url" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/hicli/database" - "maunium.net/go/mautrix/id" -) - -var InitialDeviceDisplayName = "mautrix hiclient" - -func (h *HiClient) LoginPassword(ctx context.Context, homeserverURL, username, password string) error { - var err error - h.Client.HomeserverURL, err = url.Parse(homeserverURL) - if err != nil { - return err - } - return h.Login(ctx, &mautrix.ReqLogin{ - Type: mautrix.AuthTypePassword, - Identifier: mautrix.UserIdentifier{ - Type: mautrix.IdentifierTypeUser, - User: username, - }, - Password: password, - InitialDeviceDisplayName: InitialDeviceDisplayName, - }) -} - -func (h *HiClient) Login(ctx context.Context, req *mautrix.ReqLogin) error { - err := h.CheckServerVersions(ctx) - if err != nil { - return err - } - req.StoreCredentials = true - req.StoreHomeserverURL = true - resp, err := h.Client.Login(ctx, req) - if err != nil { - return err - } - defer h.dispatchCurrentState() - h.Account = &database.Account{ - UserID: resp.UserID, - DeviceID: resp.DeviceID, - AccessToken: resp.AccessToken, - HomeserverURL: h.Client.HomeserverURL.String(), - } - h.CryptoStore.AccountID = resp.UserID.String() - h.CryptoStore.DeviceID = resp.DeviceID - err = h.DB.Account.Put(ctx, h.Account) - if err != nil { - return err - } - err = h.Crypto.Load(ctx) - if err != nil { - return fmt.Errorf("failed to load olm machine: %w", err) - } - err = h.Crypto.ShareKeys(ctx, 0) - if err != nil { - return err - } - _, err = h.Crypto.FetchKeys(ctx, []id.UserID{h.Account.UserID}, true) - if err != nil { - return fmt.Errorf("failed to fetch own devices: %w", err) - } - return nil -} - -func (h *HiClient) LoginAndVerify(ctx context.Context, homeserverURL, username, password, recoveryKey string) error { - err := h.LoginPassword(ctx, homeserverURL, username, password) - if err != nil { - return err - } - err = h.VerifyWithRecoveryKey(ctx, recoveryKey) - if err != nil { - return err - } - return nil -} diff --git a/hicli/paginate.go b/hicli/paginate.go deleted file mode 100644 index 7fc50827..00000000 --- a/hicli/paginate.go +++ /dev/null @@ -1,240 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package hicli - -import ( - "context" - "errors" - "fmt" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/hicli/database" - "maunium.net/go/mautrix/id" -) - -var ErrPaginationAlreadyInProgress = errors.New("pagination is already in progress") - -func (h *HiClient) GetEventsByRowIDs(ctx context.Context, rowIDs []database.EventRowID) ([]*database.Event, error) { - events, err := h.DB.Event.GetByRowIDs(ctx, rowIDs...) - if err != nil { - return nil, err - } else if len(events) == 0 { - return events, nil - } - firstRoomID := events[0].RoomID - allInSameRoom := true - for _, evt := range events { - if evt.RoomID != firstRoomID { - allInSameRoom = false - break - } - } - if allInSameRoom { - err = h.DB.Event.FillLastEditRowIDs(ctx, firstRoomID, events) - if err != nil { - return events, fmt.Errorf("failed to fill last edit row IDs: %w", err) - } - err = h.DB.Event.FillReactionCounts(ctx, firstRoomID, events) - if err != nil { - return events, fmt.Errorf("failed to fill reaction counts: %w", err) - } - } else { - // TODO slow path where events are collected and filling is done one room at a time? - } - return events, nil -} - -func (h *HiClient) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*database.Event, error) { - if evt, err := h.DB.Event.GetByID(ctx, eventID); err != nil { - return nil, fmt.Errorf("failed to get event from database: %w", err) - } else if evt != nil { - return evt, nil - } else if serverEvt, err := h.Client.GetEvent(ctx, roomID, eventID); err != nil { - return nil, fmt.Errorf("failed to get event from server: %w", err) - } else { - return h.processEvent(ctx, serverEvt, nil, nil, false) - } -} - -func (h *HiClient) GetRoomState(ctx context.Context, roomID id.RoomID, fetchMembers, refetch bool) ([]*database.Event, error) { - var evts []*event.Event - if refetch { - resp, err := h.Client.StateAsArray(ctx, roomID) - if err != nil { - return nil, fmt.Errorf("failed to refetch state: %w", err) - } - evts = resp - } else if fetchMembers { - resp, err := h.Client.Members(ctx, roomID) - if err != nil { - return nil, fmt.Errorf("failed to fetch members: %w", err) - } - evts = resp.Chunk - } - if evts != nil { - err := h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { - room, err := h.DB.Room.Get(ctx, roomID) - if err != nil { - return fmt.Errorf("failed to get room from database: %w", err) - } - updatedRoom := &database.Room{ - ID: room.ID, - HasMemberList: true, - } - entries := make([]*database.CurrentStateEntry, len(evts)) - for i, evt := range evts { - dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, nil, false) - if err != nil { - return fmt.Errorf("failed to process event %s: %w", evt.ID, err) - } - entries[i] = &database.CurrentStateEntry{ - EventType: evt.Type, - StateKey: *evt.StateKey, - EventRowID: dbEvt.RowID, - } - if evt.Type == event.StateMember { - entries[i].Membership = event.Membership(evt.Content.Raw["membership"].(string)) - } else { - processImportantEvent(ctx, evt, room, updatedRoom) - } - } - err = h.DB.CurrentState.AddMany(ctx, room.ID, refetch, entries) - if err != nil { - return err - } - roomChanged := updatedRoom.CheckChangesAndCopyInto(room) - if roomChanged { - err = h.DB.Room.Upsert(ctx, updatedRoom) - if err != nil { - return fmt.Errorf("failed to save room data: %w", err) - } - } - return nil - }) - if err != nil { - return nil, err - } - } - return h.DB.CurrentState.GetAll(ctx, roomID) -} - -type PaginationResponse struct { - Events []*database.Event `json:"events"` - HasMore bool `json:"has_more"` -} - -func (h *HiClient) Paginate(ctx context.Context, roomID id.RoomID, maxTimelineID database.TimelineRowID, limit int) (*PaginationResponse, error) { - evts, err := h.DB.Timeline.Get(ctx, roomID, limit, maxTimelineID) - if err != nil { - return nil, err - } else if len(evts) > 0 { - return &PaginationResponse{Events: evts, HasMore: true}, nil - } else { - return h.PaginateServer(ctx, roomID, limit) - } -} - -func (h *HiClient) PaginateServer(ctx context.Context, roomID id.RoomID, limit int) (*PaginationResponse, error) { - ctx, cancel := context.WithCancelCause(ctx) - h.paginationInterrupterLock.Lock() - if _, alreadyPaginating := h.paginationInterrupter[roomID]; alreadyPaginating { - h.paginationInterrupterLock.Unlock() - return nil, ErrPaginationAlreadyInProgress - } - h.paginationInterrupter[roomID] = cancel - h.paginationInterrupterLock.Unlock() - defer func() { - h.paginationInterrupterLock.Lock() - delete(h.paginationInterrupter, roomID) - h.paginationInterrupterLock.Unlock() - }() - - room, err := h.DB.Room.Get(ctx, roomID) - if err != nil { - return nil, fmt.Errorf("failed to get room from database: %w", err) - } else if room.PrevBatch == database.PrevBatchPaginationComplete { - return &PaginationResponse{Events: []*database.Event{}, HasMore: false}, nil - } - resp, err := h.Client.Messages(ctx, roomID, room.PrevBatch, "", mautrix.DirectionBackward, nil, limit) - if err != nil { - return nil, fmt.Errorf("failed to get messages from server: %w", err) - } - events := make([]*database.Event, len(resp.Chunk)) - if resp.End == "" { - resp.End = database.PrevBatchPaginationComplete - } - if resp.End == database.PrevBatchPaginationComplete || len(resp.Chunk) == 0 { - err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End) - if err != nil { - return nil, fmt.Errorf("failed to set prev_batch: %w", err) - } - return &PaginationResponse{Events: events, HasMore: resp.End != ""}, nil - } - wakeupSessionRequests := false - err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { - if err = ctx.Err(); err != nil { - return err - } - eventRowIDs := make([]database.EventRowID, len(resp.Chunk)) - decryptionQueue := make(map[id.SessionID]*database.SessionRequest) - iOffset := 0 - for i, evt := range resp.Chunk { - dbEvt, err := h.processEvent(ctx, evt, room.LazyLoadSummary, decryptionQueue, true) - if err != nil { - return err - } else if exists, err := h.DB.Timeline.Has(ctx, roomID, dbEvt.RowID); err != nil { - return fmt.Errorf("failed to check if event exists in timeline: %w", err) - } else if exists { - zerolog.Ctx(ctx).Warn(). - Int64("row_id", int64(dbEvt.RowID)). - Str("event_id", dbEvt.ID.String()). - Msg("Event already exists in timeline, skipping") - iOffset++ - continue - } - events[i-iOffset] = dbEvt - eventRowIDs[i-iOffset] = events[i-iOffset].RowID - } - if iOffset >= len(events) { - events = events[:0] - return nil - } - events = events[:len(events)-iOffset] - eventRowIDs = eventRowIDs[:len(eventRowIDs)-iOffset] - wakeupSessionRequests = len(decryptionQueue) > 0 - for _, entry := range decryptionQueue { - err = h.DB.SessionRequest.Put(ctx, entry) - if err != nil { - return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err) - } - } - err = h.DB.Event.FillLastEditRowIDs(ctx, roomID, events) - if err != nil { - return fmt.Errorf("failed to fill last edit row IDs: %w", err) - } - err = h.DB.Room.SetPrevBatch(ctx, room.ID, resp.End) - if err != nil { - return fmt.Errorf("failed to set prev_batch: %w", err) - } - var tuples []database.TimelineRowTuple - tuples, err = h.DB.Timeline.Prepend(ctx, room.ID, eventRowIDs) - if err != nil { - return fmt.Errorf("failed to prepend events to timeline: %w", err) - } - for i, evt := range events { - evt.TimelineRowID = tuples[i].Timeline - } - return nil - }) - if err == nil && wakeupSessionRequests { - h.WakeupRequestQueue() - } - return &PaginationResponse{Events: events, HasMore: true}, err -} diff --git a/hicli/pushrules.go b/hicli/pushrules.go deleted file mode 100644 index 74c0e8e4..00000000 --- a/hicli/pushrules.go +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package hicli - -import ( - "context" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/hicli/database" - "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/pushrules" -) - -type pushRoom struct { - ctx context.Context - roomID id.RoomID - h *HiClient - ll *mautrix.LazyLoadSummary -} - -func (p *pushRoom) GetOwnDisplayname() string { - // TODO implement - return "" -} - -func (p *pushRoom) GetMemberCount() int { - if p.ll == nil { - room, err := p.h.DB.Room.Get(p.ctx, p.roomID) - if err != nil { - zerolog.Ctx(p.ctx).Err(err). - Stringer("room_id", p.roomID). - Msg("Failed to get room by ID in push rule evaluator") - } else if room != nil { - p.ll = room.LazyLoadSummary - } - } - if p.ll != nil && p.ll.JoinedMemberCount != nil { - return *p.ll.JoinedMemberCount - } - // TODO query db? - return 0 -} - -func (p *pushRoom) GetEvent(id id.EventID) *event.Event { - evt, err := p.h.DB.Event.GetByID(p.ctx, id) - if err != nil { - zerolog.Ctx(p.ctx).Err(err). - Stringer("event_id", id). - Msg("Failed to get event by ID in push rule evaluator") - } - return evt.AsRawMautrix() -} - -var _ pushrules.EventfulRoom = (*pushRoom)(nil) - -func (h *HiClient) evaluatePushRules(ctx context.Context, llSummary *mautrix.LazyLoadSummary, baseType database.UnreadType, evt *event.Event) database.UnreadType { - should := h.PushRules.Load().GetMatchingRule(&pushRoom{ - ctx: ctx, - roomID: evt.RoomID, - h: h, - ll: llSummary, - }, evt).GetActions().Should() - if should.Notify { - baseType |= database.UnreadTypeNotify - } - if should.Highlight { - baseType |= database.UnreadTypeHighlight - } - if should.PlaySound { - baseType |= database.UnreadTypeSound - } - return baseType -} diff --git a/hicli/send.go b/hicli/send.go deleted file mode 100644 index cdb8571b..00000000 --- a/hicli/send.go +++ /dev/null @@ -1,287 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package hicli - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "strings" - "time" - - "github.com/rs/zerolog" - "github.com/yuin/goldmark" - "go.mau.fi/util/jsontime" - "go.mau.fi/util/ptr" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/format" - "maunium.net/go/mautrix/format/mdext/rainbow" - "maunium.net/go/mautrix/hicli/database" - "maunium.net/go/mautrix/id" -) - -var ( - rainbowWithHTML = goldmark.New(format.Extensions, format.HTMLOptions, goldmark.WithExtensions(rainbow.Extension)) -) - -func (h *HiClient) SendMessage(ctx context.Context, roomID id.RoomID, text, mediaPath string, replyTo id.EventID, mentions *event.Mentions) (*database.Event, error) { - var content event.MessageEventContent - if strings.HasPrefix(text, "/rainbow ") { - text = strings.TrimPrefix(text, "/rainbow ") - content = format.RenderMarkdownCustom(text, rainbowWithHTML) - content.FormattedBody = rainbow.ApplyColor(content.FormattedBody) - } else if strings.HasPrefix(text, "/plain ") { - text = strings.TrimPrefix(text, "/plain ") - content = format.RenderMarkdown(text, false, false) - } else if strings.HasPrefix(text, "/html ") { - text = strings.TrimPrefix(text, "/html ") - content = format.RenderMarkdown(text, false, true) - } else { - content = format.RenderMarkdown(text, true, false) - } - if mentions != nil { - content.Mentions.Room = mentions.Room - for _, userID := range mentions.UserIDs { - if userID != h.Account.UserID { - content.Mentions.Add(userID) - } - } - } - if replyTo != "" { - content.RelatesTo = (&event.RelatesTo{}).SetReplyTo(replyTo) - } - return h.Send(ctx, roomID, event.EventMessage, &content) -} - -func (h *HiClient) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID, receiptType event.ReceiptType) error { - content := &mautrix.ReqSetReadMarkers{ - FullyRead: eventID, - } - if receiptType == event.ReceiptTypeRead { - content.Read = eventID - } else if receiptType == event.ReceiptTypeReadPrivate { - content.ReadPrivate = eventID - } else { - return fmt.Errorf("invalid receipt type: %v", receiptType) - } - err := h.Client.SetReadMarkers(ctx, roomID, content) - if err != nil { - return fmt.Errorf("failed to mark event as read: %w", err) - } - return nil -} - -func (h *HiClient) SetTyping(ctx context.Context, roomID id.RoomID, timeout time.Duration) error { - _, err := h.Client.UserTyping(ctx, roomID, timeout > 0, timeout) - return err -} - -func (h *HiClient) Send(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (*database.Event, error) { - roomMeta, err := h.DB.Room.Get(ctx, roomID) - if err != nil { - return nil, fmt.Errorf("failed to get room metadata: %w", err) - } else if roomMeta == nil { - return nil, fmt.Errorf("unknown room") - } - var decryptedType event.Type - var decryptedContent json.RawMessage - var megolmSessionID id.SessionID - if roomMeta.EncryptionEvent != nil && evtType != event.EventReaction { - decryptedType = evtType - decryptedContent, err = json.Marshal(content) - if err != nil { - return nil, fmt.Errorf("failed to marshal event content: %w", err) - } - encryptedContent, err := h.Encrypt(ctx, roomMeta, evtType, content) - if err != nil { - return nil, fmt.Errorf("failed to encrypt event: %w", err) - } - megolmSessionID = encryptedContent.SessionID - content = encryptedContent - evtType = event.EventEncrypted - } - mainContent, err := json.Marshal(content) - if err != nil { - return nil, fmt.Errorf("failed to marshal event content: %w", err) - } - txnID := "hicli-" + h.Client.TxnID() - relatesTo, relationType := database.GetRelatesToFromBytes(mainContent) - dbEvt := &database.Event{ - RoomID: roomID, - ID: id.EventID(fmt.Sprintf("~%s", txnID)), - Sender: h.Account.UserID, - Type: evtType.Type, - Timestamp: jsontime.UnixMilliNow(), - Content: mainContent, - Decrypted: decryptedContent, - DecryptedType: decryptedType.Type, - Unsigned: []byte("{}"), - TransactionID: txnID, - RelatesTo: relatesTo, - RelationType: relationType, - MegolmSessionID: megolmSessionID, - DecryptionError: "", - SendError: "not sent", - Reactions: map[string]int{}, - LastEditRowID: ptr.Ptr(database.EventRowID(0)), - } - _, err = h.DB.Event.Insert(ctx, dbEvt) - if err != nil { - return nil, fmt.Errorf("failed to insert event into database: %w", err) - } - ctx = context.WithoutCancel(ctx) - go func() { - err := h.SetTyping(ctx, roomID, 0) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to stop typing while sending message") - } - }() - go func() { - var err error - defer func() { - h.EventHandler(&SendComplete{ - Event: dbEvt, - Error: err, - }) - }() - var resp *mautrix.RespSendEvent - resp, err = h.Client.SendMessageEvent(ctx, roomID, evtType, content, mautrix.ReqSendEvent{ - Timestamp: dbEvt.Timestamp.UnixMilli(), - TransactionID: txnID, - DontEncrypt: true, - }) - if err != nil { - dbEvt.SendError = err.Error() - err = fmt.Errorf("failed to send event: %w", err) - err2 := h.DB.Event.UpdateSendError(ctx, dbEvt.RowID, dbEvt.SendError) - if err2 != nil { - zerolog.Ctx(ctx).Err(err2).AnErr("send_error", err). - Msg("Failed to update send error in database after sending failed") - } - return - } - dbEvt.ID = resp.EventID - err = h.DB.Event.UpdateID(ctx, dbEvt.RowID, dbEvt.ID) - if err != nil { - err = fmt.Errorf("failed to update event ID in database: %w", err) - } - }() - return dbEvt, nil -} - -func (h *HiClient) Encrypt(ctx context.Context, room *database.Room, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) { - h.encryptLock.Lock() - defer h.encryptLock.Unlock() - encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, room.ID, evtType, content) - if errors.Is(err, crypto.SessionExpired) || errors.Is(err, crypto.NoGroupSession) || errors.Is(err, crypto.SessionNotShared) { - if err = h.shareGroupSession(ctx, room); err != nil { - err = fmt.Errorf("failed to share group session: %w", err) - } else if encrypted, err = h.Crypto.EncryptMegolmEvent(ctx, room.ID, evtType, content); err != nil { - err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err) - } - } - return -} - -func (h *HiClient) EnsureGroupSessionShared(ctx context.Context, roomID id.RoomID) error { - h.encryptLock.Lock() - defer h.encryptLock.Unlock() - if session, err := h.CryptoStore.GetOutboundGroupSession(ctx, roomID); err != nil { - return fmt.Errorf("failed to get previous outbound group session: %w", err) - } else if session != nil && session.Shared && !session.Expired() { - return nil - } else if roomMeta, err := h.DB.Room.Get(ctx, roomID); err != nil { - return fmt.Errorf("failed to get room metadata: %w", err) - } else if roomMeta == nil { - return fmt.Errorf("unknown room") - } else { - return h.shareGroupSession(ctx, roomMeta) - } -} - -func (h *HiClient) loadMembers(ctx context.Context, room *database.Room) error { - if room.HasMemberList { - return nil - } - resp, err := h.Client.Members(ctx, room.ID) - if err != nil { - return fmt.Errorf("failed to get room member list: %w", err) - } - err = h.DB.DoTxn(ctx, nil, func(ctx context.Context) error { - entries := make([]*database.CurrentStateEntry, len(resp.Chunk)) - for i, evt := range resp.Chunk { - dbEvt, err := h.processEvent(ctx, evt, nil, nil, true) - if err != nil { - return err - } - entries[i] = &database.CurrentStateEntry{ - EventType: evt.Type, - StateKey: *evt.StateKey, - EventRowID: dbEvt.RowID, - Membership: event.Membership(evt.Content.Raw["membership"].(string)), - } - } - err := h.DB.CurrentState.AddMany(ctx, room.ID, false, entries) - if err != nil { - return err - } - return h.DB.Room.Upsert(ctx, &database.Room{ - ID: room.ID, - HasMemberList: true, - }) - }) - if err != nil { - return fmt.Errorf("failed to process room member list: %w", err) - } - return nil -} - -func (h *HiClient) shareGroupSession(ctx context.Context, room *database.Room) error { - err := h.loadMembers(ctx, room) - if err != nil { - return err - } - shareToInvited := h.shouldShareKeysToInvitedUsers(ctx, room.ID) - var users []id.UserID - if shareToInvited { - users, err = h.ClientStore.GetRoomJoinedOrInvitedMembers(ctx, room.ID) - } else { - users, err = h.ClientStore.GetRoomJoinedMembers(ctx, room.ID) - } - if err != nil { - return fmt.Errorf("failed to get room member list: %w", err) - } else if err = h.Crypto.ShareGroupSession(ctx, room.ID, users); err != nil { - return fmt.Errorf("failed to share group session: %w", err) - } - return nil -} - -func (h *HiClient) shouldShareKeysToInvitedUsers(ctx context.Context, roomID id.RoomID) bool { - historyVisibility, err := h.DB.CurrentState.Get(ctx, roomID, event.StateHistoryVisibility, "") - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get history visibility event") - return false - } - mautrixEvt := historyVisibility.AsRawMautrix() - err = mautrixEvt.Content.ParseRaw(mautrixEvt.Type) - if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { - zerolog.Ctx(ctx).Err(err).Msg("Failed to parse history visibility event") - return false - } - hv, ok := mautrixEvt.Content.Parsed.(*event.HistoryVisibilityEventContent) - if !ok { - zerolog.Ctx(ctx).Warn().Msg("Unexpected parsed content type for history visibility event") - return false - } - return hv.HistoryVisibility == event.HistoryVisibilityInvited || - hv.HistoryVisibility == event.HistoryVisibilityShared || - hv.HistoryVisibility == event.HistoryVisibilityWorldReadable -} diff --git a/hicli/sync.go b/hicli/sync.go deleted file mode 100644 index dcb33637..00000000 --- a/hicli/sync.go +++ /dev/null @@ -1,833 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package hicli - -import ( - "context" - "errors" - "fmt" - "strings" - - "github.com/rs/zerolog" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - "go.mau.fi/util/exzerolog" - "go.mau.fi/util/jsontime" - "golang.org/x/exp/slices" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto" - "maunium.net/go/mautrix/crypto/olm" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/hicli/database" - "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/pushrules" -) - -type syncContext struct { - shouldWakeupRequestQueue bool - - evt *SyncComplete -} - -func (h *HiClient) preProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { - log := zerolog.Ctx(ctx) - postponedToDevices := resp.ToDevice.Events[:0] - for _, evt := range resp.ToDevice.Events { - evt.Type.Class = event.ToDeviceEventType - err := evt.Content.ParseRaw(evt.Type) - if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { - log.Warn().Err(err). - Stringer("event_type", &evt.Type). - Stringer("sender", evt.Sender). - Msg("Failed to parse to-device event, skipping") - continue - } - - switch content := evt.Content.Parsed.(type) { - case *event.EncryptedEventContent: - h.Crypto.HandleEncryptedEvent(ctx, evt) - case *event.RoomKeyWithheldEventContent: - h.Crypto.HandleRoomKeyWithheld(ctx, content) - default: - postponedToDevices = append(postponedToDevices, evt) - } - } - resp.ToDevice.Events = postponedToDevices - - return nil -} - -func (h *HiClient) postProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) { - h.Crypto.HandleOTKCounts(ctx, &resp.DeviceOTKCount) - go h.asyncPostProcessSyncResponse(ctx, resp, since) - syncCtx := ctx.Value(syncContextKey).(*syncContext) - if syncCtx.shouldWakeupRequestQueue { - h.WakeupRequestQueue() - } - h.firstSyncReceived = true - if !syncCtx.evt.IsEmpty() { - h.EventHandler(syncCtx.evt) - } -} - -func (h *HiClient) asyncPostProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) { - for _, evt := range resp.ToDevice.Events { - switch content := evt.Content.Parsed.(type) { - case *event.SecretRequestEventContent: - h.Crypto.HandleSecretRequest(ctx, evt.Sender, content) - case *event.RoomKeyRequestEventContent: - h.Crypto.HandleRoomKeyRequest(ctx, evt.Sender, content) - } - } -} - -func (h *HiClient) processSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { - if len(resp.DeviceLists.Changed) > 0 { - zerolog.Ctx(ctx).Debug(). - Array("users", exzerolog.ArrayOfStringers(resp.DeviceLists.Changed)). - Msg("Marking changed device lists for tracked users as outdated") - err := h.Crypto.CryptoStore.MarkTrackedUsersOutdated(ctx, resp.DeviceLists.Changed) - if err != nil { - return fmt.Errorf("failed to mark changed device lists as outdated: %w", err) - } - ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true - } - - for _, evt := range resp.AccountData.Events { - evt.Type.Class = event.AccountDataEventType - err := h.DB.AccountData.Put(ctx, h.Account.UserID, evt.Type, evt.Content.VeryRaw) - if err != nil { - return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err) - } - if evt.Type == event.AccountDataPushRules { - err = evt.Content.ParseRaw(evt.Type) - if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { - zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to parse push rules in sync") - } else if pushRules, ok := evt.Content.Parsed.(*pushrules.EventContent); ok { - h.PushRules.Store(pushRules.Ruleset) - zerolog.Ctx(ctx).Debug().Msg("Updated push rules from sync") - } - } - } - for roomID, room := range resp.Rooms.Join { - err := h.processSyncJoinedRoom(ctx, roomID, room) - if err != nil { - return fmt.Errorf("failed to process joined room %s: %w", roomID, err) - } - } - for roomID, room := range resp.Rooms.Leave { - err := h.processSyncLeftRoom(ctx, roomID, room) - if err != nil { - return fmt.Errorf("failed to process left room %s: %w", roomID, err) - } - } - h.Account.NextBatch = resp.NextBatch - err := h.DB.Account.PutNextBatch(ctx, h.Account.UserID, resp.NextBatch) - if err != nil { - return fmt.Errorf("failed to save next_batch: %w", err) - } - return nil -} - -func (h *HiClient) receiptsToList(content *event.ReceiptEventContent) ([]*database.Receipt, []id.EventID) { - receiptList := make([]*database.Receipt, 0) - var newOwnReceipts []id.EventID - for eventID, receipts := range *content { - for receiptType, users := range receipts { - for userID, receiptInfo := range users { - if userID == h.Account.UserID { - newOwnReceipts = append(newOwnReceipts, eventID) - } - receiptList = append(receiptList, &database.Receipt{ - UserID: userID, - ReceiptType: receiptType, - ThreadID: receiptInfo.ThreadID, - EventID: eventID, - Timestamp: jsontime.UM(receiptInfo.Timestamp), - }) - } - } - } - return receiptList, newOwnReceipts -} - -type receiptsToSave struct { - roomID id.RoomID - receipts []*database.Receipt -} - -func (h *HiClient) processSyncJoinedRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncJoinedRoom) error { - existingRoomData, err := h.DB.Room.Get(ctx, roomID) - if err != nil { - return fmt.Errorf("failed to get room data: %w", err) - } else if existingRoomData == nil { - err = h.DB.Room.CreateRow(ctx, roomID) - if err != nil { - return fmt.Errorf("failed to ensure room row exists: %w", err) - } - existingRoomData = &database.Room{ID: roomID, SortingTimestamp: jsontime.UnixMilliNow()} - } - - for _, evt := range room.AccountData.Events { - evt.Type.Class = event.AccountDataEventType - evt.RoomID = roomID - err = h.DB.AccountData.PutRoom(ctx, h.Account.UserID, roomID, evt.Type, evt.Content.VeryRaw) - if err != nil { - return fmt.Errorf("failed to save account data event %s: %w", evt.Type.Type, err) - } - } - var receipts []receiptsToSave - var newOwnReceipts []id.EventID - for _, evt := range room.Ephemeral.Events { - evt.Type.Class = event.EphemeralEventType - err = evt.Content.ParseRaw(evt.Type) - if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { - zerolog.Ctx(ctx).Debug().Err(err).Msg("Failed to parse ephemeral event content") - continue - } - switch evt.Type { - case event.EphemeralEventReceipt: - var receiptsList []*database.Receipt - receiptsList, newOwnReceipts = h.receiptsToList(evt.Content.AsReceipt()) - receipts = append(receipts, receiptsToSave{roomID, receiptsList}) - case event.EphemeralEventTyping: - go h.EventHandler(&Typing{ - RoomID: roomID, - TypingEventContent: *evt.Content.AsTyping(), - }) - } - } - err = h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, newOwnReceipts, room.UnreadNotifications) - if err != nil { - return err - } - for _, rs := range receipts { - err = h.DB.Receipt.PutMany(ctx, rs.roomID, rs.receipts...) - if err != nil { - return fmt.Errorf("failed to save receipts: %w", err) - } - } - return nil -} - -func (h *HiClient) processSyncLeftRoom(ctx context.Context, roomID id.RoomID, room *mautrix.SyncLeftRoom) error { - existingRoomData, err := h.DB.Room.Get(ctx, roomID) - if err != nil { - return fmt.Errorf("failed to get room data: %w", err) - } else if existingRoomData == nil { - return nil - } - // TODO delete room instead of processing? - return h.processStateAndTimeline(ctx, existingRoomData, &room.State, &room.Timeline, &room.Summary, nil, nil) -} - -func isDecryptionErrorRetryable(err error) bool { - return errors.Is(err, crypto.NoSessionFound) || errors.Is(err, olm.UnknownMessageIndex) || errors.Is(err, crypto.ErrGroupSessionWithheld) -} - -func removeReplyFallback(evt *event.Event) []byte { - if evt.Type != event.EventMessage && evt.Type != event.EventSticker { - return nil - } - _ = evt.Content.ParseRaw(evt.Type) - content, ok := evt.Content.Parsed.(*event.MessageEventContent) - if ok && content.RelatesTo.GetReplyTo() != "" { - prevFormattedBody := content.FormattedBody - content.RemoveReplyFallback() - if content.FormattedBody != prevFormattedBody { - bytes, err := sjson.SetBytes(evt.Content.VeryRaw, "formatted_body", content.FormattedBody) - bytes, err2 := sjson.SetBytes(bytes, "body", content.Body) - if err == nil && err2 == nil { - return bytes - } - } - } - return nil -} - -func (h *HiClient) decryptEvent(ctx context.Context, evt *event.Event) (*event.Event, []byte, string, error) { - err := evt.Content.ParseRaw(evt.Type) - if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { - return nil, nil, "", err - } - decrypted, err := h.Crypto.DecryptMegolmEvent(ctx, evt) - if err != nil { - return nil, nil, "", err - } - withoutFallback := removeReplyFallback(decrypted) - if withoutFallback != nil { - return decrypted, withoutFallback, decrypted.Type.Type, nil - } - return decrypted, decrypted.Content.VeryRaw, decrypted.Type.Type, nil -} - -func (h *HiClient) addMediaCache( - ctx context.Context, - eventRowID database.EventRowID, - uri id.ContentURIString, - file *event.EncryptedFileInfo, - info *event.FileInfo, - fileName string, -) { - parsedMXC := uri.ParseOrIgnore() - if !parsedMXC.IsValid() { - return - } - cm := &database.CachedMedia{ - MXC: parsedMXC, - EventRowID: eventRowID, - FileName: fileName, - } - if file != nil { - cm.EncFile = &file.EncryptedFile - } - if info != nil { - cm.MimeType = info.MimeType - } - err := h.DB.CachedMedia.Put(ctx, cm) - if err != nil { - zerolog.Ctx(ctx).Warn().Err(err). - Stringer("mxc", parsedMXC). - Int64("event_rowid", int64(eventRowID)). - Msg("Failed to add cached media entry") - } -} - -func (h *HiClient) cacheMedia(ctx context.Context, evt *event.Event, rowID database.EventRowID) { - switch evt.Type { - case event.EventMessage, event.EventSticker: - content, ok := evt.Content.Parsed.(*event.MessageEventContent) - if !ok { - return - } - if content.File != nil { - h.addMediaCache(ctx, rowID, content.File.URL, content.File, content.Info, content.GetFileName()) - } else if content.URL != "" { - h.addMediaCache(ctx, rowID, content.URL, nil, content.Info, content.GetFileName()) - } - if content.GetInfo().ThumbnailFile != nil { - h.addMediaCache(ctx, rowID, content.Info.ThumbnailFile.URL, content.Info.ThumbnailFile, content.Info.ThumbnailInfo, "") - } else if content.GetInfo().ThumbnailURL != "" { - h.addMediaCache(ctx, rowID, content.Info.ThumbnailURL, nil, content.Info.ThumbnailInfo, "") - } - case event.StateRoomAvatar: - _ = evt.Content.ParseRaw(evt.Type) - content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent) - if !ok { - return - } - h.addMediaCache(ctx, rowID, content.URL, nil, nil, "") - case event.StateMember: - _ = evt.Content.ParseRaw(evt.Type) - content, ok := evt.Content.Parsed.(*event.MemberEventContent) - if !ok { - return - } - h.addMediaCache(ctx, rowID, content.AvatarURL, nil, nil, "") - } -} - -func (h *HiClient) calculateLocalContent(ctx context.Context, dbEvt *database.Event, evt *event.Event) *database.LocalContent { - if evt.Type != event.EventMessage && evt.Type != event.EventSticker { - return nil - } - _ = evt.Content.ParseRaw(evt.Type) - content, ok := evt.Content.Parsed.(*event.MessageEventContent) - if !ok { - return nil - } - if dbEvt.RelationType == event.RelReplace && content.NewContent != nil { - content = content.NewContent - } - if content != nil { - var sanitizedHTML string - if content.Format == event.FormatHTML { - sanitizedHTML, _ = sanitizeAndLinkifyHTML(content.FormattedBody) - } else { - var builder strings.Builder - linkifyAndWriteBytes(&builder, []byte(content.Body)) - sanitizedHTML = builder.String() - } - return &database.LocalContent{SanitizedHTML: sanitizedHTML} - } - return nil -} - -func (h *HiClient) postDecryptProcess(ctx context.Context, llSummary *mautrix.LazyLoadSummary, dbEvt *database.Event, evt *event.Event) { - if dbEvt.RowID != 0 { - h.cacheMedia(ctx, evt, dbEvt.RowID) - } - dbEvt.UnreadType = h.evaluatePushRules(ctx, llSummary, dbEvt.GetNonPushUnreadType(), evt) - dbEvt.LocalContent = h.calculateLocalContent(ctx, dbEvt, evt) -} - -func (h *HiClient) processEvent( - ctx context.Context, - evt *event.Event, - llSummary *mautrix.LazyLoadSummary, - decryptionQueue map[id.SessionID]*database.SessionRequest, - checkDB bool, -) (*database.Event, error) { - if checkDB { - dbEvt, err := h.DB.Event.GetByID(ctx, evt.ID) - if err != nil { - return nil, fmt.Errorf("failed to check if event %s exists: %w", evt.ID, err) - } else if dbEvt != nil { - return dbEvt, nil - } - } - dbEvt := database.MautrixToEvent(evt) - contentWithoutFallback := removeReplyFallback(evt) - if contentWithoutFallback != nil { - dbEvt.Content = contentWithoutFallback - } - var decryptionErr error - var decryptedMautrixEvt *event.Event - if evt.Type == event.EventEncrypted && dbEvt.RedactedBy == "" { - decryptedMautrixEvt, dbEvt.Decrypted, dbEvt.DecryptedType, decryptionErr = h.decryptEvent(ctx, evt) - if decryptionErr != nil { - dbEvt.DecryptionError = decryptionErr.Error() - } - } else if evt.Type == event.EventRedaction { - if evt.Redacts != "" && gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str != evt.Redacts.String() { - var err error - evt.Content.VeryRaw, err = sjson.SetBytes(evt.Content.VeryRaw, "redacts", evt.Redacts) - if err != nil { - return dbEvt, fmt.Errorf("failed to set redacts field: %w", err) - } - } else if evt.Redacts == "" { - evt.Redacts = id.EventID(gjson.GetBytes(evt.Content.VeryRaw, "redacts").Str) - } - } - if decryptedMautrixEvt != nil { - h.postDecryptProcess(ctx, llSummary, dbEvt, decryptedMautrixEvt) - } else { - h.postDecryptProcess(ctx, llSummary, dbEvt, evt) - } - _, err := h.DB.Event.Upsert(ctx, dbEvt) - if err != nil { - return dbEvt, fmt.Errorf("failed to save event %s: %w", evt.ID, err) - } - if decryptedMautrixEvt != nil { - h.cacheMedia(ctx, decryptedMautrixEvt, dbEvt.RowID) - } else { - h.cacheMedia(ctx, evt, dbEvt.RowID) - } - if decryptionErr != nil && isDecryptionErrorRetryable(decryptionErr) { - req, ok := decryptionQueue[dbEvt.MegolmSessionID] - if !ok { - req = &database.SessionRequest{ - RoomID: evt.RoomID, - SessionID: dbEvt.MegolmSessionID, - Sender: evt.Sender, - } - } - minIndex, _ := crypto.ParseMegolmMessageIndex(evt.Content.AsEncrypted().MegolmCiphertext) - req.MinIndex = min(uint32(minIndex), req.MinIndex) - if decryptionQueue != nil { - decryptionQueue[dbEvt.MegolmSessionID] = req - } else { - err = h.DB.SessionRequest.Put(ctx, req) - if err != nil { - zerolog.Ctx(ctx).Err(err). - Stringer("session_id", dbEvt.MegolmSessionID). - Msg("Failed to save session request") - } else { - h.WakeupRequestQueue() - } - } - } - return dbEvt, err -} - -func (h *HiClient) processStateAndTimeline( - ctx context.Context, - room *database.Room, - state *mautrix.SyncEventsList, - timeline *mautrix.SyncTimeline, - summary *mautrix.LazyLoadSummary, - newOwnReceipts []id.EventID, - serverNotificationCounts *mautrix.UnreadNotificationCounts, -) error { - updatedRoom := &database.Room{ - ID: room.ID, - - SortingTimestamp: room.SortingTimestamp, - NameQuality: room.NameQuality, - UnreadHighlights: room.UnreadHighlights, - UnreadNotifications: room.UnreadNotifications, - UnreadMessages: room.UnreadMessages, - } - if serverNotificationCounts != nil { - updatedRoom.UnreadHighlights = serverNotificationCounts.HighlightCount - updatedRoom.UnreadNotifications = serverNotificationCounts.NotificationCount - } - heroesChanged := false - if summary.Heroes == nil && summary.JoinedMemberCount == nil && summary.InvitedMemberCount == nil { - summary = room.LazyLoadSummary - } else if room.LazyLoadSummary == nil || - !slices.Equal(summary.Heroes, room.LazyLoadSummary.Heroes) || - !intPtrEqual(summary.JoinedMemberCount, room.LazyLoadSummary.JoinedMemberCount) || - !intPtrEqual(summary.InvitedMemberCount, room.LazyLoadSummary.InvitedMemberCount) { - updatedRoom.LazyLoadSummary = summary - heroesChanged = true - } - decryptionQueue := make(map[id.SessionID]*database.SessionRequest) - allNewEvents := make([]*database.Event, 0, len(state.Events)+len(timeline.Events)) - newNotifications := make([]SyncNotification, 0) - recalculatePreviewEvent := false - addOldEvent := func(rowID database.EventRowID, evtID id.EventID) (dbEvt *database.Event, err error) { - if rowID != 0 { - dbEvt, err = h.DB.Event.GetByRowID(ctx, rowID) - } else { - dbEvt, err = h.DB.Event.GetByID(ctx, evtID) - } - if err != nil { - return nil, fmt.Errorf("failed to get redaction target: %w", err) - } else if dbEvt == nil { - return nil, nil - } - allNewEvents = append(allNewEvents, dbEvt) - return dbEvt, nil - } - processRedaction := func(evt *event.Event) error { - dbEvt, err := addOldEvent(0, evt.Redacts) - if err != nil { - return fmt.Errorf("failed to get redaction target: %w", err) - } - if dbEvt == nil { - return nil - } - if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation { - _, err = addOldEvent(0, dbEvt.RelatesTo) - if err != nil { - return fmt.Errorf("failed to get relation target of redaction target: %w", err) - } - } - if updatedRoom.PreviewEventRowID == dbEvt.RowID { - updatedRoom.PreviewEventRowID = 0 - recalculatePreviewEvent = true - } - return nil - } - processNewEvent := func(evt *event.Event, isTimeline, isUnread bool) (database.EventRowID, error) { - evt.RoomID = room.ID - dbEvt, err := h.processEvent(ctx, evt, summary, decryptionQueue, false) - if err != nil { - return -1, err - } - if isUnread && dbEvt.UnreadType.Is(database.UnreadTypeNotify) { - newNotifications = append(newNotifications, SyncNotification{ - RowID: dbEvt.RowID, - Sound: dbEvt.UnreadType.Is(database.UnreadTypeSound), - }) - } - if isTimeline { - if dbEvt.CanUseForPreview() { - updatedRoom.PreviewEventRowID = dbEvt.RowID - recalculatePreviewEvent = false - } - updatedRoom.BumpSortingTimestamp(dbEvt) - } - if evt.StateKey != nil { - var membership event.Membership - if evt.Type == event.StateMember { - membership = event.Membership(gjson.GetBytes(evt.Content.VeryRaw, "membership").Str) - if summary != nil && slices.Contains(summary.Heroes, id.UserID(*evt.StateKey)) { - heroesChanged = true - } - } else if evt.Type == event.StateElementFunctionalMembers { - heroesChanged = true - } - err = h.DB.CurrentState.Set(ctx, room.ID, evt.Type, *evt.StateKey, dbEvt.RowID, membership) - if err != nil { - return -1, fmt.Errorf("failed to save current state event ID %s for %s/%s: %w", evt.ID, evt.Type.Type, *evt.StateKey, err) - } - processImportantEvent(ctx, evt, room, updatedRoom) - } - allNewEvents = append(allNewEvents, dbEvt) - if evt.Type == event.EventRedaction && evt.Redacts != "" { - err = processRedaction(evt) - if err != nil { - return -1, fmt.Errorf("failed to process redaction: %w", err) - } - } else if dbEvt.RelationType == event.RelReplace || dbEvt.RelationType == event.RelAnnotation { - _, err = addOldEvent(0, dbEvt.RelatesTo) - if err != nil { - return -1, fmt.Errorf("failed to get relation target of event: %w", err) - } - } - return dbEvt.RowID, nil - } - changedState := make(map[event.Type]map[string]database.EventRowID) - setNewState := func(evtType event.Type, stateKey string, rowID database.EventRowID) { - if _, ok := changedState[evtType]; !ok { - changedState[evtType] = make(map[string]database.EventRowID) - } - changedState[evtType][stateKey] = rowID - } - for _, evt := range state.Events { - evt.Type.Class = event.StateEventType - rowID, err := processNewEvent(evt, false, false) - if err != nil { - return err - } - setNewState(evt.Type, *evt.StateKey, rowID) - } - var timelineRowTuples []database.TimelineRowTuple - var err error - if len(timeline.Events) > 0 { - timelineIDs := make([]database.EventRowID, len(timeline.Events)) - readUpToIndex := -1 - for i := len(timeline.Events) - 1; i >= 0; i-- { - if slices.Contains(newOwnReceipts, timeline.Events[i].ID) { - readUpToIndex = i - break - } - } - for i, evt := range timeline.Events { - if evt.StateKey != nil { - evt.Type.Class = event.StateEventType - } else { - evt.Type.Class = event.MessageEventType - } - timelineIDs[i], err = processNewEvent(evt, true, i > readUpToIndex) - if err != nil { - return err - } - if evt.StateKey != nil { - setNewState(evt.Type, *evt.StateKey, timelineIDs[i]) - } - } - for _, entry := range decryptionQueue { - err = h.DB.SessionRequest.Put(ctx, entry) - if err != nil { - return fmt.Errorf("failed to save session request for %s: %w", entry.SessionID, err) - } - } - if len(decryptionQueue) > 0 { - ctx.Value(syncContextKey).(*syncContext).shouldWakeupRequestQueue = true - } - if timeline.Limited { - err = h.DB.Timeline.Clear(ctx, room.ID) - if err != nil { - return fmt.Errorf("failed to clear old timeline: %w", err) - } - updatedRoom.PrevBatch = timeline.PrevBatch - h.paginationInterrupterLock.Lock() - if interrupt, ok := h.paginationInterrupter[room.ID]; ok { - interrupt(ErrTimelineReset) - } - h.paginationInterrupterLock.Unlock() - } - timelineRowTuples, err = h.DB.Timeline.Append(ctx, room.ID, timelineIDs) - if err != nil { - return fmt.Errorf("failed to append timeline: %w", err) - } - } else { - timelineRowTuples = make([]database.TimelineRowTuple, 0) - } - if recalculatePreviewEvent && updatedRoom.PreviewEventRowID == 0 { - updatedRoom.PreviewEventRowID, err = h.DB.Room.RecalculatePreview(ctx, room.ID) - if err != nil { - return fmt.Errorf("failed to recalculate preview event: %w", err) - } - _, err = addOldEvent(updatedRoom.PreviewEventRowID, "") - if err != nil { - return fmt.Errorf("failed to get preview event: %w", err) - } - } - // Calculate name from participants if participants changed and current name was generated from participants, or if the room name was unset - if (heroesChanged && updatedRoom.NameQuality <= database.NameQualityParticipants) || updatedRoom.NameQuality == database.NameQualityNil { - name, dmAvatarURL, err := h.calculateRoomParticipantName(ctx, room.ID, summary) - if err != nil { - return fmt.Errorf("failed to calculate room name: %w", err) - } - updatedRoom.Name = &name - updatedRoom.NameQuality = database.NameQualityParticipants - if !dmAvatarURL.IsEmpty() && !room.ExplicitAvatar { - updatedRoom.Avatar = &dmAvatarURL - } - } - if timeline.PrevBatch != "" && (room.PrevBatch == "" || timeline.Limited) { - updatedRoom.PrevBatch = timeline.PrevBatch - } - roomChanged := updatedRoom.CheckChangesAndCopyInto(room) - if roomChanged { - err = h.DB.Room.Upsert(ctx, updatedRoom) - if err != nil { - return fmt.Errorf("failed to save room data: %w", err) - } - } - if roomChanged || len(timelineRowTuples) > 0 || len(allNewEvents) > 0 { - ctx.Value(syncContextKey).(*syncContext).evt.Rooms[room.ID] = &SyncRoom{ - Meta: room, - Timeline: timelineRowTuples, - State: changedState, - Reset: timeline.Limited, - Events: allNewEvents, - Notifications: newNotifications, - } - } - return nil -} - -func joinMemberNames(names []string, totalCount int) string { - if len(names) == 1 { - return names[0] - } else if len(names) < 5 || (len(names) == 5 && totalCount <= 6) { - return strings.Join(names[:len(names)-1], ", ") + " and " + names[len(names)-1] - } else { - return fmt.Sprintf("%s and %d others", strings.Join(names[:4], ", "), totalCount-5) - } -} - -func (h *HiClient) calculateRoomParticipantName(ctx context.Context, roomID id.RoomID, summary *mautrix.LazyLoadSummary) (string, id.ContentURI, error) { - var primaryAvatarURL id.ContentURI - if summary == nil || len(summary.Heroes) == 0 { - return "Empty room", primaryAvatarURL, nil - } - var functionalMembers []id.UserID - functionalMembersEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateElementFunctionalMembers, "") - if err != nil { - return "", primaryAvatarURL, fmt.Errorf("failed to get %s event: %w", event.StateElementFunctionalMembers.Type, err) - } else if functionalMembersEvt != nil { - mautrixEvt := functionalMembersEvt.AsRawMautrix() - _ = mautrixEvt.Content.ParseRaw(mautrixEvt.Type) - content, ok := mautrixEvt.Content.Parsed.(*event.ElementFunctionalMembersContent) - if ok { - functionalMembers = content.ServiceMembers - } - } - var members, leftMembers []string - var memberCount int - if summary.JoinedMemberCount != nil && *summary.JoinedMemberCount > 0 { - memberCount = *summary.JoinedMemberCount - } else if summary.InvitedMemberCount != nil { - memberCount = *summary.InvitedMemberCount - } - for _, hero := range summary.Heroes { - if slices.Contains(functionalMembers, hero) { - memberCount-- - continue - } else if len(members) >= 5 { - break - } - heroEvt, err := h.DB.CurrentState.Get(ctx, roomID, event.StateMember, hero.String()) - if err != nil { - return "", primaryAvatarURL, fmt.Errorf("failed to get %s's member event: %w", hero, err) - } else if heroEvt == nil { - leftMembers = append(leftMembers, hero.String()) - continue - } - membership := gjson.GetBytes(heroEvt.Content, "membership").Str - name := gjson.GetBytes(heroEvt.Content, "displayname").Str - if name == "" { - name = hero.String() - } - avatarURL := gjson.GetBytes(heroEvt.Content, "avatar_url").Str - if avatarURL != "" { - primaryAvatarURL = id.ContentURIString(avatarURL).ParseOrIgnore() - } - if membership == "join" || membership == "invite" { - members = append(members, name) - } else { - leftMembers = append(leftMembers, name) - } - } - if len(members)+len(leftMembers) > 1 || !primaryAvatarURL.IsValid() { - primaryAvatarURL = id.ContentURI{} - } - if len(members) > 0 { - return joinMemberNames(members, memberCount), primaryAvatarURL, nil - } else if len(leftMembers) > 0 { - return fmt.Sprintf("Empty room (was %s)", joinMemberNames(leftMembers, memberCount)), primaryAvatarURL, nil - } else { - return "Empty room", primaryAvatarURL, nil - } -} - -func intPtrEqual(a, b *int) bool { - if a == nil || b == nil { - return a == b - } - return *a == *b -} - -func processImportantEvent(ctx context.Context, evt *event.Event, existingRoomData, updatedRoom *database.Room) (roomDataChanged bool) { - if evt.StateKey == nil { - return - } - switch evt.Type { - case event.StateCreate, event.StateRoomName, event.StateCanonicalAlias, event.StateRoomAvatar, event.StateTopic, event.StateEncryption: - if *evt.StateKey != "" { - return - } - default: - return - } - err := evt.Content.ParseRaw(evt.Type) - if err != nil && !errors.Is(err, event.ErrContentAlreadyParsed) { - zerolog.Ctx(ctx).Warn().Err(err). - Stringer("event_type", &evt.Type). - Stringer("event_id", evt.ID). - Msg("Failed to parse state event, skipping") - return - } - switch evt.Type { - case event.StateCreate: - updatedRoom.CreationContent, _ = evt.Content.Parsed.(*event.CreateEventContent) - case event.StateEncryption: - newEncryption, _ := evt.Content.Parsed.(*event.EncryptionEventContent) - if existingRoomData.EncryptionEvent == nil || existingRoomData.EncryptionEvent.Algorithm == newEncryption.Algorithm { - updatedRoom.EncryptionEvent = newEncryption - } - case event.StateRoomName: - content, ok := evt.Content.Parsed.(*event.RoomNameEventContent) - if ok { - updatedRoom.Name = &content.Name - updatedRoom.NameQuality = database.NameQualityExplicit - if content.Name == "" { - if updatedRoom.CanonicalAlias != nil && *updatedRoom.CanonicalAlias != "" { - updatedRoom.Name = (*string)(updatedRoom.CanonicalAlias) - updatedRoom.NameQuality = database.NameQualityCanonicalAlias - } else if existingRoomData.CanonicalAlias != nil && *existingRoomData.CanonicalAlias != "" { - updatedRoom.Name = (*string)(existingRoomData.CanonicalAlias) - updatedRoom.NameQuality = database.NameQualityCanonicalAlias - } else { - updatedRoom.NameQuality = database.NameQualityNil - } - } - } - case event.StateCanonicalAlias: - content, ok := evt.Content.Parsed.(*event.CanonicalAliasEventContent) - if ok { - updatedRoom.CanonicalAlias = &content.Alias - if updatedRoom.NameQuality <= database.NameQualityCanonicalAlias { - updatedRoom.Name = (*string)(&content.Alias) - updatedRoom.NameQuality = database.NameQualityCanonicalAlias - if content.Alias == "" { - updatedRoom.NameQuality = database.NameQualityNil - } - } - } - case event.StateRoomAvatar: - content, ok := evt.Content.Parsed.(*event.RoomAvatarEventContent) - if ok { - url, _ := content.URL.Parse() - updatedRoom.Avatar = &url - updatedRoom.ExplicitAvatar = true - } - case event.StateTopic: - content, ok := evt.Content.Parsed.(*event.TopicEventContent) - if ok { - updatedRoom.Topic = &content.Topic - } - } - return -} diff --git a/hicli/syncwrap.go b/hicli/syncwrap.go deleted file mode 100644 index 13837202..00000000 --- a/hicli/syncwrap.go +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package hicli - -import ( - "context" - "fmt" - "time" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/id" -) - -type hiSyncer HiClient - -var _ mautrix.Syncer = (*hiSyncer)(nil) - -type contextKey int - -const ( - syncContextKey contextKey = iota -) - -func (h *hiSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { - c := (*HiClient)(h) - ctx = context.WithValue(ctx, syncContextKey, &syncContext{evt: &SyncComplete{Rooms: make(map[id.RoomID]*SyncRoom, len(resp.Rooms.Join))}}) - err := c.preProcessSyncResponse(ctx, resp, since) - if err != nil { - return err - } - err = c.DB.DoTxn(ctx, nil, func(ctx context.Context) error { - return c.processSyncResponse(ctx, resp, since) - }) - if err != nil { - return err - } - c.postProcessSyncResponse(ctx, resp, since) - return nil -} - -func (h *hiSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) { - (*HiClient)(h).Log.Err(err).Msg("Sync failed, retrying in 1 second") - return 1 * time.Second, nil -} - -func (h *hiSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter { - if !h.Verified { - return &mautrix.Filter{ - Presence: mautrix.FilterPart{ - NotRooms: []id.RoomID{"*"}, - }, - Room: mautrix.RoomFilter{ - NotRooms: []id.RoomID{"*"}, - }, - } - } - return &mautrix.Filter{ - Presence: mautrix.FilterPart{ - NotRooms: []id.RoomID{"*"}, - }, - Room: mautrix.RoomFilter{ - State: mautrix.FilterPart{ - LazyLoadMembers: true, - }, - Timeline: mautrix.FilterPart{ - Limit: 100, - LazyLoadMembers: true, - }, - }, - } -} - -type hiStore HiClient - -var _ mautrix.SyncStore = (*hiStore)(nil) - -// Filter ID save and load are intentionally no-ops: we want to recreate filters when restarting syncing - -func (h *hiStore) SaveFilterID(_ context.Context, _ id.UserID, _ string) error { return nil } -func (h *hiStore) LoadFilterID(_ context.Context, _ id.UserID) (string, error) { return "", nil } - -func (h *hiStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error { - // This is intentionally a no-op: we don't want to save the next batch before processing the sync - return nil -} - -func (h *hiStore) LoadNextBatch(_ context.Context, userID id.UserID) (string, error) { - if h.Account.UserID != userID { - return "", fmt.Errorf("mismatching user ID") - } - return h.Account.NextBatch, nil -} diff --git a/hicli/verify.go b/hicli/verify.go deleted file mode 100644 index 6dc2a4c3..00000000 --- a/hicli/verify.go +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package hicli - -import ( - "context" - "encoding/base64" - "fmt" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix/crypto" - "maunium.net/go/mautrix/crypto/backup" - "maunium.net/go/mautrix/crypto/ssss" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -func (h *HiClient) checkIsCurrentDeviceVerified(ctx context.Context) (bool, error) { - keys := h.Crypto.GetOwnCrossSigningPublicKeys(ctx) - if keys == nil { - return false, fmt.Errorf("own cross-signing keys not found") - } - isVerified, err := h.Crypto.CryptoStore.IsKeySignedBy(ctx, h.Account.UserID, h.Crypto.GetAccount().SigningKey(), h.Account.UserID, keys.SelfSigningKey) - if err != nil { - return false, fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err) - } - return isVerified, nil -} - -func (h *HiClient) fetchKeyBackupKey(ctx context.Context, ssssKey *ssss.Key) error { - latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx) - if err != nil { - return fmt.Errorf("failed to get key backup latest version: %w", err) - } - h.KeyBackupVersion = latestVersion.Version - data, err := h.Crypto.SSSS.GetDecryptedAccountData(ctx, event.AccountDataMegolmBackupKey, ssssKey) - if err != nil { - return fmt.Errorf("failed to get megolm backup key from SSSS: %w", err) - } - key, err := backup.MegolmBackupKeyFromBytes(data) - if err != nil { - return fmt.Errorf("failed to parse megolm backup key: %w", err) - } - err = h.CryptoStore.PutSecret(ctx, id.SecretMegolmBackupV1, base64.StdEncoding.EncodeToString(key.Bytes())) - if err != nil { - return fmt.Errorf("failed to store megolm backup key: %w", err) - } - h.KeyBackupKey = key - return nil -} - -func (h *HiClient) getAndDecodeSecret(ctx context.Context, secret id.Secret) ([]byte, error) { - secretData, err := h.CryptoStore.GetSecret(ctx, secret) - if err != nil { - return nil, fmt.Errorf("failed to get secret %s: %w", secret, err) - } - data, err := base64.StdEncoding.DecodeString(secretData) - if err != nil { - return nil, fmt.Errorf("failed to decode secret %s: %w", secret, err) - } - return data, nil -} - -func (h *HiClient) loadPrivateKeys(ctx context.Context) error { - zerolog.Ctx(ctx).Debug().Msg("Loading cross-signing private keys") - masterKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSMaster) - if err != nil { - return fmt.Errorf("failed to get master key: %w", err) - } - selfSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSSelfSigning) - if err != nil { - return fmt.Errorf("failed to get self-signing key: %w", err) - } - userSigningKeySeed, err := h.getAndDecodeSecret(ctx, id.SecretXSUserSigning) - if err != nil { - return fmt.Errorf("failed to get user signing key: %w", err) - } - err = h.Crypto.ImportCrossSigningKeys(crypto.CrossSigningSeeds{ - MasterKey: masterKeySeed, - SelfSigningKey: selfSigningKeySeed, - UserSigningKey: userSigningKeySeed, - }) - if err != nil { - return fmt.Errorf("failed to import cross-signing private keys: %w", err) - } - zerolog.Ctx(ctx).Debug().Msg("Loading key backup key") - keyBackupKey, err := h.getAndDecodeSecret(ctx, id.SecretMegolmBackupV1) - if err != nil { - return fmt.Errorf("failed to get megolm backup key: %w", err) - } - h.KeyBackupKey, err = backup.MegolmBackupKeyFromBytes(keyBackupKey) - if err != nil { - return fmt.Errorf("failed to parse megolm backup key: %w", err) - } - zerolog.Ctx(ctx).Debug().Msg("Fetching key backup version") - latestVersion, err := h.Client.GetKeyBackupLatestVersion(ctx) - if err != nil { - return fmt.Errorf("failed to get key backup latest version: %w", err) - } - h.KeyBackupVersion = latestVersion.Version - zerolog.Ctx(ctx).Debug().Msg("Secrets loaded") - return nil -} - -func (h *HiClient) storeCrossSigningPrivateKeys(ctx context.Context) error { - keys := h.Crypto.CrossSigningKeys - err := h.CryptoStore.PutSecret(ctx, id.SecretXSMaster, base64.StdEncoding.EncodeToString(keys.MasterKey.Seed())) - if err != nil { - return err - } - err = h.CryptoStore.PutSecret(ctx, id.SecretXSSelfSigning, base64.StdEncoding.EncodeToString(keys.SelfSigningKey.Seed())) - if err != nil { - return err - } - err = h.CryptoStore.PutSecret(ctx, id.SecretXSUserSigning, base64.StdEncoding.EncodeToString(keys.UserSigningKey.Seed())) - if err != nil { - return err - } - return nil -} - -func (h *HiClient) VerifyWithRecoveryKey(ctx context.Context, code string) error { - defer h.dispatchCurrentState() - keyID, keyData, err := h.Crypto.SSSS.GetDefaultKeyData(ctx) - if err != nil { - return fmt.Errorf("failed to get default SSSS key data: %w", err) - } - key, err := keyData.VerifyRecoveryKey(keyID, code) - if err != nil { - return err - } - err = h.Crypto.FetchCrossSigningKeysFromSSSS(ctx, key) - if err != nil { - return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err) - } - err = h.Crypto.SignOwnDevice(ctx, h.Crypto.OwnIdentity()) - if err != nil { - return fmt.Errorf("failed to sign own device: %w", err) - } - err = h.Crypto.SignOwnMasterKey(ctx) - if err != nil { - return fmt.Errorf("failed to sign own master key: %w", err) - } - err = h.storeCrossSigningPrivateKeys(ctx) - if err != nil { - return fmt.Errorf("failed to store cross-signing private keys: %w", err) - } - err = h.fetchKeyBackupKey(ctx, key) - if err != nil { - return fmt.Errorf("failed to fetch key backup key: %w", err) - } - h.Verified = true - if !h.IsSyncing() { - go h.Sync() - } - return nil -} From 2a2a576bf4223227a49b64d10e86d85ed88965d1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 17 Oct 2024 20:15:19 +0300 Subject: [PATCH 0852/1647] format/rainbow: move into gomuks --- format/mdext/rainbow/goldmark.go | 120 ------------------------------- format/mdext/rainbow/gradient.go | 56 --------------- go.mod | 2 - go.sum | 4 -- 4 files changed, 182 deletions(-) delete mode 100644 format/mdext/rainbow/goldmark.go delete mode 100644 format/mdext/rainbow/gradient.go diff --git a/format/mdext/rainbow/goldmark.go b/format/mdext/rainbow/goldmark.go deleted file mode 100644 index 59a36178..00000000 --- a/format/mdext/rainbow/goldmark.go +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package rainbow - -import ( - "fmt" - "unicode" - - "github.com/rivo/uniseg" - "github.com/yuin/goldmark" - "github.com/yuin/goldmark/ast" - "github.com/yuin/goldmark/renderer" - "github.com/yuin/goldmark/renderer/html" - "github.com/yuin/goldmark/util" - "go.mau.fi/util/random" -) - -// Extension is a goldmark extension that adds rainbow text coloring to the HTML renderer. -var Extension = &extRainbow{} - -type extRainbow struct{} -type rainbowRenderer struct { - HardWraps bool - ColorID string -} - -var defaultRB = &rainbowRenderer{HardWraps: true, ColorID: random.String(16)} - -func (er *extRainbow) Extend(m goldmark.Markdown) { - m.Renderer().AddOptions(renderer.WithNodeRenderers(util.Prioritized(defaultRB, 0))) -} - -func (rb *rainbowRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) { - reg.Register(ast.KindText, rb.renderText) - reg.Register(ast.KindString, rb.renderString) -} - -type rainbowBufWriter struct { - util.BufWriter - ColorID string -} - -func (rbw rainbowBufWriter) WriteString(s string) (int, error) { - i := 0 - graphemes := uniseg.NewGraphemes(s) - for graphemes.Next() { - runes := graphemes.Runes() - if len(runes) == 1 && unicode.IsSpace(runes[0]) { - i2, err := rbw.BufWriter.WriteRune(runes[0]) - i += i2 - if err != nil { - return i, err - } - continue - } - i2, err := fmt.Fprintf(rbw.BufWriter, "%s", rbw.ColorID, graphemes.Str()) - i += i2 - if err != nil { - return i, err - } - } - return i, nil -} - -func (rbw rainbowBufWriter) Write(data []byte) (int, error) { - return rbw.WriteString(string(data)) -} - -func (rbw rainbowBufWriter) WriteByte(c byte) error { - _, err := rbw.WriteRune(rune(c)) - return err -} - -func (rbw rainbowBufWriter) WriteRune(r rune) (int, error) { - if unicode.IsSpace(r) { - return rbw.BufWriter.WriteRune(r) - } else { - return fmt.Fprintf(rbw.BufWriter, "%c", rbw.ColorID, r) - } -} - -func (rb *rainbowRenderer) renderText(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) { - if !entering { - return ast.WalkContinue, nil - } - n := node.(*ast.Text) - segment := n.Segment - if n.IsRaw() { - html.DefaultWriter.RawWrite(rainbowBufWriter{w, rb.ColorID}, segment.Value(source)) - } else { - html.DefaultWriter.Write(rainbowBufWriter{w, rb.ColorID}, segment.Value(source)) - if n.HardLineBreak() || (n.SoftLineBreak() && rb.HardWraps) { - _, _ = w.WriteString("
\n") - } else if n.SoftLineBreak() { - _ = w.WriteByte('\n') - } - } - return ast.WalkContinue, nil -} - -func (rb *rainbowRenderer) renderString(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) { - if !entering { - return ast.WalkContinue, nil - } - n := node.(*ast.String) - if n.IsCode() { - _, _ = w.Write(n.Value) - } else { - if n.IsRaw() { - html.DefaultWriter.RawWrite(rainbowBufWriter{w, rb.ColorID}, n.Value) - } else { - html.DefaultWriter.Write(rainbowBufWriter{w, rb.ColorID}, n.Value) - } - } - return ast.WalkContinue, nil -} diff --git a/format/mdext/rainbow/gradient.go b/format/mdext/rainbow/gradient.go deleted file mode 100644 index 34c499e6..00000000 --- a/format/mdext/rainbow/gradient.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package rainbow - -import ( - "regexp" - "strings" - - "github.com/lucasb-eyer/go-colorful" -) - -// GradientTable from https://github.com/lucasb-eyer/go-colorful/blob/master/doc/gradientgen/gradientgen.go -type GradientTable []struct { - Col colorful.Color - Pos float64 -} - -func (gt GradientTable) GetInterpolatedColorFor(t float64) colorful.Color { - for i := 0; i < len(gt)-1; i++ { - c1 := gt[i] - c2 := gt[i+1] - if c1.Pos <= t && t <= c2.Pos { - t := (t - c1.Pos) / (c2.Pos - c1.Pos) - return c1.Col.BlendHcl(c2.Col, t).Clamped() - } - } - return gt[len(gt)-1].Col -} - -var Gradient = GradientTable{ - {colorful.LinearRgb(1, 0, 0), 0 / 11.0}, - {colorful.LinearRgb(1, 0.5, 0), 1 / 11.0}, - {colorful.LinearRgb(1, 1, 0), 2 / 11.0}, - {colorful.LinearRgb(0.5, 1, 0), 3 / 11.0}, - {colorful.LinearRgb(0, 1, 0), 4 / 11.0}, - {colorful.LinearRgb(0, 1, 0.5), 5 / 11.0}, - {colorful.LinearRgb(0, 1, 1), 6 / 11.0}, - {colorful.LinearRgb(0, 0.5, 1), 7 / 11.0}, - {colorful.LinearRgb(0, 0, 1), 8 / 11.0}, - {colorful.LinearRgb(0.5, 0, 1), 9 / 11.0}, - {colorful.LinearRgb(1, 0, 1), 10 / 11.0}, - {colorful.LinearRgb(1, 0, 0.5), 11 / 11.0}, -} - -func ApplyColor(htmlBody string) string { - count := strings.Count(htmlBody, defaultRB.ColorID) - i := -1 - return regexp.MustCompile(defaultRB.ColorID).ReplaceAllStringFunc(htmlBody, func(match string) string { - i++ - return Gradient.GetInterpolatedColorFor(float64(i) / float64(count)).Hex() - }) -} diff --git a/go.mod b/go.mod index f45b8990..2f47e155 100644 --- a/go.mod +++ b/go.mod @@ -10,9 +10,7 @@ require ( github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 - github.com/lucasb-eyer/go-colorful v1.2.0 github.com/mattn/go-sqlite3 v1.14.24 - github.com/rivo/uniseg v0.4.7 github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.33.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e diff --git a/go.sum b/go.sum index e7a58076..48d9fa8d 100644 --- a/go.sum +++ b/go.sum @@ -19,8 +19,6 @@ github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWm github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= -github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= @@ -33,8 +31,6 @@ github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7/go.mod h1:pxMtw7c github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= -github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= From 367828429297e48558c67ee4b03b000a5e3d8d71 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 17 Oct 2024 20:30:32 +0300 Subject: [PATCH 0853/1647] id: remove outdated URI tests --- id/matrixuri_test.go | 47 +++++++++++++------------------------------- 1 file changed, 14 insertions(+), 33 deletions(-) diff --git a/id/matrixuri_test.go b/id/matrixuri_test.go index d26d4bfd..8b1096cb 100644 --- a/id/matrixuri_test.go +++ b/id/matrixuri_test.go @@ -16,12 +16,11 @@ import ( ) var ( - roomIDLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org"} - roomIDViaLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Via: []string{"maunium.net", "matrix.org"}} - roomAliasLink = id.MatrixURI{Sigil1: '#', MXID1: "someroom:example.org"} - roomIDEventLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"} - roomAliasEventLink = id.MatrixURI{Sigil1: '#', MXID1: "someroom:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"} - userLink = id.MatrixURI{Sigil1: '@', MXID1: "user:example.org"} + roomIDLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org"} + roomIDViaLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Via: []string{"maunium.net", "matrix.org"}} + roomAliasLink = id.MatrixURI{Sigil1: '#', MXID1: "someroom:example.org"} + roomIDEventLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"} + userLink = id.MatrixURI{Sigil1: '@', MXID1: "user:example.org"} escapeRoomIDEventLink = id.MatrixURI{Sigil1: '!', MXID1: "meow & 🐈️:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF/dtndJ0j9je+kIK3XpV1s"} ) @@ -31,7 +30,6 @@ func TestMatrixURI_MatrixToURL(t *testing.T) { assert.Equal(t, "https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org", roomIDViaLink.MatrixToURL()) assert.Equal(t, "https://matrix.to/#/%23someroom:example.org", roomAliasLink.MatrixToURL()) assert.Equal(t, "https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomIDEventLink.MatrixToURL()) - assert.Equal(t, "https://matrix.to/#/%23someroom:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomAliasEventLink.MatrixToURL()) assert.Equal(t, "https://matrix.to/#/@user:example.org", userLink.MatrixToURL()) assert.Equal(t, "https://matrix.to/#/%21meow%20&%20%F0%9F%90%88%EF%B8%8F:example.org/$uOH4C9cK4HhMeFWkUXMbdF%2FdtndJ0j9je+kIK3XpV1s", escapeRoomIDEventLink.MatrixToURL()) } @@ -41,7 +39,6 @@ func TestMatrixURI_String(t *testing.T) { assert.Equal(t, "matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org", roomIDViaLink.String()) assert.Equal(t, "matrix:r/someroom:example.org", roomAliasLink.String()) assert.Equal(t, "matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomIDEventLink.String()) - assert.Equal(t, "matrix:r/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomAliasEventLink.String()) assert.Equal(t, "matrix:u/user:example.org", userLink.String()) assert.Equal(t, "matrix:roomid/meow%20&%20%F0%9F%90%88%EF%B8%8F:example.org/e/uOH4C9cK4HhMeFWkUXMbdF%2FdtndJ0j9je+kIK3XpV1s", escapeRoomIDEventLink.String()) } @@ -98,19 +95,11 @@ func TestParseMatrixURI_UserID(t *testing.T) { } func TestParseMatrixURI_EventID(t *testing.T) { - parsed1, err := id.ParseMatrixURI("matrix:r/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") + parsed, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") require.NoError(t, err) - require.NotNil(t, parsed1) - parsed2, err := id.ParseMatrixURI("matrix:room/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") - require.NoError(t, err) - require.NotNil(t, parsed2) - parsed3, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") - require.NoError(t, err) - require.NotNil(t, parsed3) + require.NotNil(t, parsed) - assert.Equal(t, roomAliasEventLink, *parsed1) - assert.Equal(t, roomAliasEventLink, *parsed2) - assert.Equal(t, roomIDEventLink, *parsed3) + assert.Equal(t, roomIDEventLink, *parsed) } func TestParseMatrixToURL_RoomAlias(t *testing.T) { @@ -158,21 +147,13 @@ func TestParseMatrixToURL_UserID(t *testing.T) { } func TestParseMatrixToURL_EventID(t *testing.T) { - parsed1, err := id.ParseMatrixToURL("https://matrix.to/#/#someroom:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") + parsed, err := id.ParseMatrixToURL("https://matrix.to/#/!7NdBVvkd4aLSbgKt9RXl:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") require.NoError(t, err) - require.NotNil(t, parsed1) - parsed2, err := id.ParseMatrixToURL("https://matrix.to/#/!7NdBVvkd4aLSbgKt9RXl:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") + require.NotNil(t, parsed) + parsedEncoded, err := id.ParseMatrixToURL("https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") require.NoError(t, err) - require.NotNil(t, parsed2) - parsed1Encoded, err := id.ParseMatrixToURL("https://matrix.to/#/%23someroom:example.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") - require.NoError(t, err) - require.NotNil(t, parsed1) - parsed2Encoded, err := id.ParseMatrixToURL("https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") - require.NoError(t, err) - require.NotNil(t, parsed2) + require.NotNil(t, parsedEncoded) - assert.Equal(t, roomAliasEventLink, *parsed1) - assert.Equal(t, roomAliasEventLink, *parsed1Encoded) - assert.Equal(t, roomIDEventLink, *parsed2) - assert.Equal(t, roomIDEventLink, *parsed2Encoded) + assert.Equal(t, roomIDEventLink, *parsed) + assert.Equal(t, roomIDEventLink, *parsedEncoded) } From 6c07832ed7b5a45853ccbb3b8a5bdaefb912e119 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 18 Oct 2024 14:07:25 +0300 Subject: [PATCH 0854/1647] pushrules: add support for sender_notification_permission condition kind --- pushrules/condition.go | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/pushrules/condition.go b/pushrules/condition.go index dbe83a61..caa717de 100644 --- a/pushrules/condition.go +++ b/pushrules/condition.go @@ -27,6 +27,11 @@ type Room interface { GetMemberCount() int } +type PowerLevelfulRoom interface { + Room + GetPowerLevels() *event.PowerLevelsEventContent +} + // EventfulRoom is an extension of Room to support MSC3664. type EventfulRoom interface { Room @@ -38,11 +43,12 @@ type PushCondKind string // The allowed push condition kinds as specified in https://spec.matrix.org/v1.2/client-server-api/#conditions-1 const ( - KindEventMatch PushCondKind = "event_match" - KindContainsDisplayName PushCondKind = "contains_display_name" - KindRoomMemberCount PushCondKind = "room_member_count" - KindEventPropertyIs PushCondKind = "event_property_is" - KindEventPropertyContains PushCondKind = "event_property_contains" + KindEventMatch PushCondKind = "event_match" + KindContainsDisplayName PushCondKind = "contains_display_name" + KindRoomMemberCount PushCondKind = "room_member_count" + KindEventPropertyIs PushCondKind = "event_property_is" + KindEventPropertyContains PushCondKind = "event_property_contains" + KindSenderNotificationPermission PushCondKind = "sender_notification_permission" // MSC3664: https://github.com/matrix-org/matrix-spec-proposals/pull/3664 @@ -82,6 +88,8 @@ func (cond *PushCondition) Match(room Room, evt *event.Event) bool { return cond.matchDisplayName(room, evt) case KindRoomMemberCount: return cond.matchMemberCount(room) + case KindSenderNotificationPermission: + return cond.matchSenderNotificationPermission(room, evt.Sender, cond.Key) default: return false } @@ -334,3 +342,18 @@ func (cond *PushCondition) matchMemberCount(room Room) bool { return false } } + +func (cond *PushCondition) matchSenderNotificationPermission(room Room, sender id.UserID, key string) bool { + if key != "room" { + return false + } + plRoom, ok := room.(PowerLevelfulRoom) + if !ok { + return false + } + pls := plRoom.GetPowerLevels() + if pls == nil { + return false + } + return pls.GetUserLevel(sender) >= pls.Notifications.Room() +} From 3277c529a2e582a16d3fd23dbea78e07cdaaba36 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 18 Oct 2024 20:49:05 +0300 Subject: [PATCH 0855/1647] crypto: add full support for json.RawMessage in EncryptMegolmEvent --- crypto/encryptmegolm.go | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 93fe6409..7c8a7542 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -15,6 +15,8 @@ import ( "fmt" "github.com/rs/zerolog" + "github.com/tidwall/gjson" + "go.mau.fi/util/exgjson" "go.mau.fi/util/exzerolog" "maunium.net/go/mautrix" @@ -27,7 +29,24 @@ var ( NoGroupSession = errors.New("no group session created") ) -func getRelatesTo(content interface{}) *event.RelatesTo { +func getRawJSON[T any](content json.RawMessage, path ...string) *T { + value := gjson.GetBytes(content, exgjson.Path(path...)) + if !value.IsObject() { + return nil + } + var result T + err := json.Unmarshal([]byte(value.Raw), &result) + if err != nil { + return nil + } + return &result +} + +func getRelatesTo(content any) *event.RelatesTo { + contentJSON, ok := content.(json.RawMessage) + if ok { + return getRawJSON[event.RelatesTo](contentJSON, "m.relates_to") + } contentStruct, ok := content.(*event.Content) if ok { content = contentStruct.Parsed @@ -39,7 +58,11 @@ func getRelatesTo(content interface{}) *event.RelatesTo { return nil } -func getMentions(content interface{}) *event.Mentions { +func getMentions(content any) *event.Mentions { + contentJSON, ok := content.(json.RawMessage) + if ok { + return getRawJSON[event.Mentions](contentJSON, "m.mentions") + } contentStruct, ok := content.(*event.Content) if ok { content = contentStruct.Parsed From 3f08ef0d57183a782a16a7bd4df0c291b3e6e7f3 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 17 Oct 2024 14:10:26 -0600 Subject: [PATCH 0856/1647] verificationhelper/request: check txn ID is different before sending cancellations This will allow it to fallthrought to the correct error which is that we received a new verification request for the same transaction ID. Signed-off-by: Sumner Evans --- crypto/verificationhelper/verificationhelper.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index f4e5e2f5..cbcff887 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -690,7 +690,7 @@ func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *ev TheirSupportedMethods: verificationRequest.Methods, } for existingTxnID, existingTxn := range vh.activeTransactions { - if existingTxn.TheirUser == evt.Sender && existingTxn.TheirDevice == verificationRequest.FromDevice { + if existingTxn.TheirUser == evt.Sender && existingTxn.TheirDevice == verificationRequest.FromDevice && existingTxnID != verificationRequest.TransactionID { vh.cancelVerificationTxn(ctx, existingTxn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") vh.cancelVerificationTxn(ctx, newTxn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") delete(vh.activeTransactions, existingTxnID) From e17cb8385518fddd2a6c080a9550dbec3e9f4233 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 20 Oct 2024 19:43:55 +0300 Subject: [PATCH 0857/1647] error: ignore RespError.Write calls with nil writer --- error.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/error.go b/error.go index a4ba9859..0133e80e 100644 --- a/error.go +++ b/error.go @@ -147,6 +147,9 @@ func (e *RespError) MarshalJSON() ([]byte, error) { } func (e RespError) Write(w http.ResponseWriter) { + if w == nil { + return + } statusCode := e.StatusCode if statusCode == 0 { statusCode = http.StatusInternalServerError From 8a8163106d95c631f15edf9fcde5fa313aaad797 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 22 Oct 2024 12:50:53 +0300 Subject: [PATCH 0858/1647] sqlstatestore: handle nulls in members_fetched --- sqlstatestore/statestore.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index d594c307..33c10c4c 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -303,7 +303,7 @@ func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.Ro } func (store *SQLStateStore) HasFetchedMembers(ctx context.Context, roomID id.RoomID) (fetched bool, err error) { - err = store.QueryRow(ctx, "SELECT members_fetched FROM mx_room_state WHERE room_id=$1", roomID).Scan(&fetched) + err = store.QueryRow(ctx, "SELECT COALESCE(members_fetched, false) FROM mx_room_state WHERE room_id=$1", roomID).Scan(&fetched) if errors.Is(err, sql.ErrNoRows) { err = nil } From 9b8244269bc6aa43a79d0385ae20d3372c510e56 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 22 Oct 2024 19:01:26 +0300 Subject: [PATCH 0859/1647] bridgev2/portal: re-id outgoing reactions too --- bridgev2/portal.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index aa263098..39a7724f 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1143,6 +1143,10 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi portal.sendErrorStatus(ctx, evt, err) return } + var deterministicID id.EventID + if portal.Bridge.Config.OutgoingMessageReID { + deterministicID = portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, reactionTarget, preResp.SenderID, preResp.EmojiID) + } existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID) if err != nil { log.Err(err).Msg("Failed to check if reaction is a duplicate") @@ -1150,7 +1154,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi } else if existing != nil { if existing.EmojiID != "" || existing.Emoji == preResp.Emoji { log.Debug().Msg("Ignoring duplicate reaction") - portal.sendSuccessStatus(ctx, evt, 0, "") + portal.sendSuccessStatus(ctx, evt, 0, deterministicID) return } react.ReactionToOverride = existing @@ -1209,7 +1213,9 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi dbReaction.MessageID = reactionTarget.ID dbReaction.MessagePartID = reactionTarget.PartID } - if dbReaction.MXID == "" { + if deterministicID != "" { + dbReaction.MXID = deterministicID + } else if dbReaction.MXID == "" { dbReaction.MXID = evt.ID } if dbReaction.Timestamp.IsZero() { @@ -1232,7 +1238,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if err != nil { log.Err(err).Msg("Failed to save reaction to database") } - portal.sendSuccessStatus(ctx, evt, 0, "") + portal.sendSuccessStatus(ctx, evt, 0, deterministicID) } func handleMatrixRoomMeta[APIType any, ContentType any]( From d316a6b55f367f84bcad77af029365b45fa4c8c1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 23 Oct 2024 12:12:49 +0300 Subject: [PATCH 0860/1647] bridgev2/commands: don't validate cookies before url decoding --- bridgev2/commands/login.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index bf9cdf45..8896eb60 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -392,10 +392,6 @@ func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string { if !isCookie { return val } - match, _ := regexp.MatchString(field.Pattern, val) - if !match { - return val - } decoded, err := url.PathUnescape(val) if err != nil { return val From eead5937ea5443be867efa4d71f073bef0277452 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 23 Oct 2024 12:30:22 +0300 Subject: [PATCH 0861/1647] bridgev2/provisioning: include HTTP request in login contexts --- bridgev2/matrix/provisioning.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 951e6df1..51465d05 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -78,6 +78,8 @@ const ( provisioningLoginProcessKey ) +const ProvisioningKeyRequest = "fi.mau.provision.request" + func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User { return r.Context().Value(provisioningUserKey).(*bridgev2.User) } @@ -269,7 +271,8 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return } - ctx := context.WithValue(r.Context(), provisioningUserKey, user) + ctx := context.WithValue(r.Context(), ProvisioningKeyRequest, r) + ctx = context.WithValue(ctx, provisioningUserKey, user) if loginID, ok := mux.Vars(r)["loginProcessID"]; ok { prov.loginsLock.RLock() login, ok := prov.logins[loginID] @@ -309,7 +312,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { }) return } - ctx = context.WithValue(r.Context(), provisioningLoginProcessKey, login) + ctx = context.WithValue(ctx, provisioningLoginProcessKey, login) } h.ServeHTTP(w, r.WithContext(ctx)) }) From ab6c2ed9a260c2d9795a0dd5fa119745628ccbef Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 23 Oct 2024 14:56:23 +0100 Subject: [PATCH 0862/1647] Add custom field to set state event ID when sending --- client.go | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index b85d86fb..dc59ca10 100644 --- a/client.go +++ b/client.go @@ -1133,8 +1133,19 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event // SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. -func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) { - urlPath := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) +func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { + var req ReqSendEvent + if len(extra) > 0 { + req = extra[0] + } + + queryParams := map[string]string{} + if req.MeowEventID != "" { + queryParams["fi.mau.event_id"] = req.MeowEventID.String() + } + + urlData := ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey} + urlPath := cli.BuildURLWithQuery(urlData, queryParams) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) if err == nil && cli.StateStore != nil { cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) From 07d7b3cfc28a60e55d734896449d6289ea8db524 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 23 Oct 2024 17:29:25 +0300 Subject: [PATCH 0863/1647] bridgev2/portal: add special constant to reset portal name --- bridgev2/portal.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 39a7724f..e3a1d0fa 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2807,6 +2807,10 @@ func (plc *PowerLevelOverrides) Apply(actor id.UserID, content *event.PowerLevel return changed } +// DefaultChatName can be used to explicitly clear the name of a room +// and reset it to the default one based on members. +var DefaultChatName = ptr.Ptr("") + type ChatInfo struct { Name *string Topic *string @@ -3387,7 +3391,10 @@ func (portal *Portal) UpdateInfoFromGhost(ctx context.Context, ghost *Ghost) (ch func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *UserLogin, sender MatrixAPI, ts time.Time) { changed := false - if info.Name != nil { + if info.Name == DefaultChatName { + portal.NameIsCustom = false + changed = portal.updateName(ctx, "", sender, ts) || changed + } else if info.Name != nil { portal.NameIsCustom = true changed = portal.updateName(ctx, *info.Name, sender, ts) || changed } From e7811488dd95f8ff0d6d65d7cf78123271831f3f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 23 Oct 2024 17:31:31 +0300 Subject: [PATCH 0864/1647] bridgev2/portal: only clear name if it's set to a custom one --- bridgev2/portal.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index e3a1d0fa..93410cbc 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3392,8 +3392,10 @@ func (portal *Portal) UpdateInfoFromGhost(ctx context.Context, ghost *Ghost) (ch func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *UserLogin, sender MatrixAPI, ts time.Time) { changed := false if info.Name == DefaultChatName { - portal.NameIsCustom = false - changed = portal.updateName(ctx, "", sender, ts) || changed + if portal.NameIsCustom { + portal.NameIsCustom = false + changed = portal.updateName(ctx, "", sender, ts) || changed + } } else if info.Name != nil { portal.NameIsCustom = true changed = portal.updateName(ctx, *info.Name, sender, ts) || changed From 6fd4b8a2132d535627b61e6408c27bcd4c82a316 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 23 Oct 2024 14:40:42 -0600 Subject: [PATCH 0865/1647] bridgev2/database: add function to get last N messages in portal Signed-off-by: Sumner Evans --- bridgev2/database/message.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 8173ad05..8daf7407 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -61,6 +61,7 @@ const ( getOldestMessageInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp ASC, part_id ASC LIMIT 1` getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp ASC, part_id ASC LIMIT 1` getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp DESC, part_id DESC LIMIT 1` + getLastNInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp DESC, part_id DESC LIMIT $4` getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1` @@ -141,6 +142,10 @@ func (mq *MessageQuery) GetLastThreadMessage(ctx context.Context, portal network return mq.QueryOne(ctx, getLastMessageInThread, mq.BridgeID, portal.ID, portal.Receiver, threadRoot) } +func (mq *MessageQuery) GetLastNInPortal(ctx context.Context, portal networkid.PortalKey, n int) ([]*Message, error) { + return mq.QueryMany(ctx, getLastNInPortal, mq.BridgeID, portal.ID, portal.Receiver, n) +} + func (mq *MessageQuery) Insert(ctx context.Context, msg *Message) error { ensureBridgeIDMatches(&msg.BridgeID, mq.BridgeID) return mq.GetDB().QueryRow(ctx, insertMessageQuery, msg.ensureHasMetadata(mq.MetaType).sqlVariables()...).Scan(&msg.RowID) From 7cc46f1ff37f3304bf3d4823e736ee7591eb1462 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 25 Oct 2024 01:58:22 -0600 Subject: [PATCH 0866/1647] crypto: always read from crypto/rand Signed-off-by: Sumner Evans --- crypto/account.go | 4 +- crypto/goolm/account/account.go | 22 +++++------ crypto/goolm/account/account_test.go | 38 +++++++++---------- crypto/goolm/account/register.go | 6 +-- crypto/goolm/crypto/curve25519.go | 17 ++------- crypto/goolm/crypto/curve25519_test.go | 10 ++--- crypto/goolm/crypto/ed25519.go | 7 ++-- crypto/goolm/crypto/ed25519_test.go | 10 ++--- crypto/goolm/pk/decryption.go | 2 +- crypto/goolm/ratchet/olm.go | 6 +-- crypto/goolm/ratchet/olm_test.go | 2 +- .../goolm/session/megolm_outbound_session.go | 2 +- crypto/goolm/session/megolm_session_test.go | 4 +- crypto/goolm/session/olm_session.go | 4 +- crypto/goolm/session/olm_session_test.go | 6 +-- crypto/libolm/account.go | 19 +++------- crypto/olm/account.go | 13 +++---- crypto/olm/account_test.go | 10 ++--- crypto/olm/session_test.go | 12 +++--- 19 files changed, 86 insertions(+), 108 deletions(-) diff --git a/crypto/account.go b/crypto/account.go index 2f93280c..0bd09ecf 100644 --- a/crypto/account.go +++ b/crypto/account.go @@ -27,7 +27,7 @@ type OlmAccount struct { } func NewOlmAccount() *OlmAccount { - account, err := olm.NewAccount(nil) + account, err := olm.NewAccount() if err != nil { panic(err) } @@ -105,7 +105,7 @@ func (account *OlmAccount) getInitialKeys(userID id.UserID, deviceID id.DeviceID func (account *OlmAccount) getOneTimeKeys(userID id.UserID, deviceID id.DeviceID, currentOTKCount int) map[id.KeyID]mautrix.OneTimeKey { newCount := int(account.Internal.MaxNumberOfOneTimeKeys()/2) - currentOTKCount if newCount > 0 { - account.Internal.GenOneTimeKeys(nil, uint(newCount)) + account.Internal.GenOneTimeKeys(uint(newCount)) } oneTimeKeys := make(map[id.KeyID]mautrix.OneTimeKey) internalKeys, err := account.Internal.OneTimeKeys() diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index 2b127ab5..46ae2571 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "maunium.net/go/mautrix/id" @@ -68,15 +67,15 @@ func AccountFromPickled(pickled, key []byte) (*Account, error) { return a, nil } -// NewAccount creates a new Account. If reader is nil, crypto/rand is used for the key creation. -func NewAccount(reader io.Reader) (*Account, error) { +// NewAccount creates a new Account. +func NewAccount() (*Account, error) { a := &Account{} - kPEd25519, err := crypto.Ed25519GenerateKey(reader) + kPEd25519, err := crypto.Ed25519GenerateKey() if err != nil { return nil, err } a.IdKeys.Ed25519 = kPEd25519 - kPCurve25519, err := crypto.Curve25519GenerateKey(reader) + kPCurve25519, err := crypto.Curve25519GenerateKey() if err != nil { return nil, err } @@ -151,14 +150,14 @@ func (a *Account) MarkKeysAsPublished() { // GenOneTimeKeys generates a number of new one time keys. If the total number // of keys stored by this Account exceeds MaxOneTimeKeys then the older -// keys are discarded. If reader is nil, crypto/rand is used for the key creation. -func (a *Account) GenOneTimeKeys(reader io.Reader, num uint) error { +// keys are discarded. +func (a *Account) GenOneTimeKeys(num uint) error { for i := uint(0); i < num; i++ { key := crypto.OneTimeKey{ Published: false, ID: a.NextOneTimeKeyID, } - newKP, err := crypto.Curve25519GenerateKey(reader) + newKP, err := crypto.Curve25519GenerateKey() if err != nil { return err } @@ -247,14 +246,15 @@ func (a *Account) RemoveOneTimeKeys(s olm.Session) error { //if the key is a fallback or prevFallback, don't remove it } -// GenFallbackKey generates a new fallback key. The old fallback key is stored in a.PrevFallbackKey overwriting any previous PrevFallbackKey. If reader is nil, crypto/rand is used for the key creation. -func (a *Account) GenFallbackKey(reader io.Reader) error { +// GenFallbackKey generates a new fallback key. The old fallback key is stored +// in a.PrevFallbackKey overwriting any previous PrevFallbackKey. +func (a *Account) GenFallbackKey() error { a.PrevFallbackKey = a.CurrentFallbackKey key := crypto.OneTimeKey{ Published: false, ID: a.NextOneTimeKeyID, } - newKP, err := crypto.Curve25519GenerateKey(reader) + newKP, err := crypto.Curve25519GenerateKey() if err != nil { return err } diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go index 2482d087..05d6d5fc 100644 --- a/crypto/goolm/account/account_test.go +++ b/crypto/goolm/account/account_test.go @@ -17,15 +17,15 @@ import ( ) func TestAccount(t *testing.T) { - firstAccount, err := account.NewAccount(nil) + firstAccount, err := account.NewAccount() if err != nil { t.Fatal(err) } - err = firstAccount.GenFallbackKey(nil) + err = firstAccount.GenFallbackKey() if err != nil { t.Fatal(err) } - err = firstAccount.GenOneTimeKeys(nil, 2) + err = firstAccount.GenOneTimeKeys(2) if err != nil { t.Fatal(err) } @@ -118,19 +118,19 @@ func TestAccountPickleJSON(t *testing.T) { } func TestSessions(t *testing.T) { - aliceAccount, err := account.NewAccount(nil) + aliceAccount, err := account.NewAccount() if err != nil { t.Fatal(err) } - err = aliceAccount.GenOneTimeKeys(nil, 5) + err = aliceAccount.GenOneTimeKeys(5) if err != nil { t.Fatal(err) } - bobAccount, err := account.NewAccount(nil) + bobAccount, err := account.NewAccount() if err != nil { t.Fatal(err) } - err = bobAccount.GenOneTimeKeys(nil, 5) + err = bobAccount.GenOneTimeKeys(5) if err != nil { t.Fatal(err) } @@ -217,7 +217,7 @@ func TestOldAccountPickle(t *testing.T) { "K/A/8TOu9iK2hDFszy6xETiousHnHgh2ZGbRUh4pQx+YMm8ZdNZeRnwFGLnrWyf9" + "O5TmXua1FcU") pickleKey := []byte("") - account, err := account.NewAccount(nil) + account, err := account.NewAccount() if err != nil { t.Fatal(err) } @@ -232,16 +232,16 @@ func TestOldAccountPickle(t *testing.T) { } func TestLoopback(t *testing.T) { - accountA, err := account.NewAccount(nil) + accountA, err := account.NewAccount() if err != nil { t.Fatal(err) } - accountB, err := account.NewAccount(nil) + accountB, err := account.NewAccount() if err != nil { t.Fatal(err) } - err = accountB.GenOneTimeKeys(nil, 42) + err = accountB.GenOneTimeKeys(42) if err != nil { t.Fatal(err) } @@ -328,16 +328,16 @@ func TestLoopback(t *testing.T) { } func TestMoreMessages(t *testing.T) { - accountA, err := account.NewAccount(nil) + accountA, err := account.NewAccount() if err != nil { t.Fatal(err) } - accountB, err := account.NewAccount(nil) + accountB, err := account.NewAccount() if err != nil { t.Fatal(err) } - err = accountB.GenOneTimeKeys(nil, 42) + err = accountB.GenOneTimeKeys(42) if err != nil { t.Fatal(err) } @@ -411,16 +411,16 @@ func TestMoreMessages(t *testing.T) { } func TestFallbackKey(t *testing.T) { - accountA, err := account.NewAccount(nil) + accountA, err := account.NewAccount() if err != nil { t.Fatal(err) } - accountB, err := account.NewAccount(nil) + accountB, err := account.NewAccount() if err != nil { t.Fatal(err) } - err = accountB.GenFallbackKey(nil) + err = accountB.GenFallbackKey() if err != nil { t.Fatal(err) } @@ -483,7 +483,7 @@ func TestFallbackKey(t *testing.T) { } // create a new fallback key for B (the old fallback should still be usable) - err = accountB.GenFallbackKey(nil) + err = accountB.GenFallbackKey() if err != nil { t.Fatal(err) } @@ -602,7 +602,7 @@ func TestOldV3AccountPickle(t *testing.T) { } func TestAccountSign(t *testing.T) { - accountA, err := account.NewAccount(nil) + accountA, err := account.NewAccount() require.NoError(t, err) plainText := []byte("Hello, World") signatureB64, err := accountA.Sign(plainText) diff --git a/crypto/goolm/account/register.go b/crypto/goolm/account/register.go index ab0c598a..c6b9e523 100644 --- a/crypto/goolm/account/register.go +++ b/crypto/goolm/account/register.go @@ -7,14 +7,12 @@ package account import ( - "io" - "maunium.net/go/mautrix/crypto/olm" ) func init() { - olm.InitNewAccount = func(r io.Reader) (olm.Account, error) { - return NewAccount(r) + olm.InitNewAccount = func() (olm.Account, error) { + return NewAccount() } olm.InitBlankAccount = func() olm.Account { return &Account{} diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go index 1c182caa..872ce3a1 100644 --- a/crypto/goolm/crypto/curve25519.go +++ b/crypto/goolm/crypto/curve25519.go @@ -5,7 +5,6 @@ import ( "crypto/rand" "encoding/base64" "fmt" - "io" "golang.org/x/crypto/curve25519" @@ -19,19 +18,11 @@ const ( curve25519PubKeyLength = 32 ) -// Curve25519GenerateKey creates a new curve25519 key pair. If reader is nil, the random data is taken from crypto/rand. -func Curve25519GenerateKey(reader io.Reader) (Curve25519KeyPair, error) { +// Curve25519GenerateKey creates a new curve25519 key pair. +func Curve25519GenerateKey() (Curve25519KeyPair, error) { privateKeyByte := make([]byte, Curve25519KeyLength) - if reader == nil { - _, err := rand.Read(privateKeyByte) - if err != nil { - return Curve25519KeyPair{}, err - } - } else { - _, err := reader.Read(privateKeyByte) - if err != nil { - return Curve25519KeyPair{}, err - } + if _, err := rand.Read(privateKeyByte); err != nil { + return Curve25519KeyPair{}, err } privateKey := Curve25519PrivateKey(privateKeyByte) diff --git a/crypto/goolm/crypto/curve25519_test.go b/crypto/goolm/crypto/curve25519_test.go index f7df5edc..ce5f561b 100644 --- a/crypto/goolm/crypto/curve25519_test.go +++ b/crypto/goolm/crypto/curve25519_test.go @@ -8,11 +8,11 @@ import ( ) func TestCurve25519(t *testing.T) { - firstKeypair, err := crypto.Curve25519GenerateKey(nil) + firstKeypair, err := crypto.Curve25519GenerateKey() if err != nil { t.Fatal(err) } - secondKeypair, err := crypto.Curve25519GenerateKey(nil) + secondKeypair, err := crypto.Curve25519GenerateKey() if err != nil { t.Fatal(err) } @@ -93,7 +93,7 @@ func TestCurve25519Case1(t *testing.T) { func TestCurve25519Pickle(t *testing.T) { //create keypair - keyPair, err := crypto.Curve25519GenerateKey(nil) + keyPair, err := crypto.Curve25519GenerateKey() if err != nil { t.Fatal(err) } @@ -124,7 +124,7 @@ func TestCurve25519Pickle(t *testing.T) { func TestCurve25519PicklePubKeyOnly(t *testing.T) { //create keypair - keyPair, err := crypto.Curve25519GenerateKey(nil) + keyPair, err := crypto.Curve25519GenerateKey() if err != nil { t.Fatal(err) } @@ -156,7 +156,7 @@ func TestCurve25519PicklePubKeyOnly(t *testing.T) { func TestCurve25519PicklePrivKeyOnly(t *testing.T) { //create keypair - keyPair, err := crypto.Curve25519GenerateKey(nil) + keyPair, err := crypto.Curve25519GenerateKey() if err != nil { t.Fatal(err) } diff --git a/crypto/goolm/crypto/ed25519.go b/crypto/goolm/crypto/ed25519.go index 57fc25fa..bc260377 100644 --- a/crypto/goolm/crypto/ed25519.go +++ b/crypto/goolm/crypto/ed25519.go @@ -3,7 +3,6 @@ package crypto import ( "encoding/base64" "fmt" - "io" "maunium.net/go/mautrix/crypto/ed25519" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" @@ -15,9 +14,9 @@ const ( ED25519SignatureSize = ed25519.SignatureSize //The length of a signature ) -// Ed25519GenerateKey creates a new ed25519 key pair. If reader is nil, the random data is taken from crypto/rand. -func Ed25519GenerateKey(reader io.Reader) (Ed25519KeyPair, error) { - publicKey, privateKey, err := ed25519.GenerateKey(reader) +// Ed25519GenerateKey creates a new ed25519 key pair. +func Ed25519GenerateKey() (Ed25519KeyPair, error) { + publicKey, privateKey, err := ed25519.GenerateKey(nil) if err != nil { return Ed25519KeyPair{}, err } diff --git a/crypto/goolm/crypto/ed25519_test.go b/crypto/goolm/crypto/ed25519_test.go index 391de912..3588205a 100644 --- a/crypto/goolm/crypto/ed25519_test.go +++ b/crypto/goolm/crypto/ed25519_test.go @@ -8,7 +8,7 @@ import ( ) func TestEd25519(t *testing.T) { - keypair, err := crypto.Ed25519GenerateKey(nil) + keypair, err := crypto.Ed25519GenerateKey() if err != nil { t.Fatal(err) } @@ -21,7 +21,7 @@ func TestEd25519(t *testing.T) { func TestEd25519Case1(t *testing.T) { //64 bytes for ed25519 package - keyPair, err := crypto.Ed25519GenerateKey(nil) + keyPair, err := crypto.Ed25519GenerateKey() if err != nil { t.Fatal(err) } @@ -46,7 +46,7 @@ func TestEd25519Case1(t *testing.T) { func TestEd25519Pickle(t *testing.T) { //create keypair - keyPair, err := crypto.Ed25519GenerateKey(nil) + keyPair, err := crypto.Ed25519GenerateKey() if err != nil { t.Fatal(err) } @@ -77,7 +77,7 @@ func TestEd25519Pickle(t *testing.T) { func TestEd25519PicklePubKeyOnly(t *testing.T) { //create keypair - keyPair, err := crypto.Ed25519GenerateKey(nil) + keyPair, err := crypto.Ed25519GenerateKey() if err != nil { t.Fatal(err) } @@ -109,7 +109,7 @@ func TestEd25519PicklePubKeyOnly(t *testing.T) { func TestEd25519PicklePrivKeyOnly(t *testing.T) { //create keypair - keyPair, err := crypto.Ed25519GenerateKey(nil) + keyPair, err := crypto.Ed25519GenerateKey() if err != nil { t.Fatal(err) } diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index b24716e8..dcec5107 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -26,7 +26,7 @@ type Decryption struct { // NewDecryption returns a new Decryption with a new generated key pair. func NewDecryption() (*Decryption, error) { - keyPair, err := crypto.Curve25519GenerateKey(nil) + keyPair, err := crypto.Curve25519GenerateKey() if err != nil { return nil, err } diff --git a/crypto/goolm/ratchet/olm.go b/crypto/goolm/ratchet/olm.go index bf04c1cf..4653aae7 100644 --- a/crypto/goolm/ratchet/olm.go +++ b/crypto/goolm/ratchet/olm.go @@ -94,11 +94,11 @@ func (r *Ratchet) InitializeAsAlice(sharedSecret []byte, ourRatchetKey crypto.Cu return nil } -// Encrypt encrypts the message in a message.Message with MAC. If reader is nil, crypto/rand is used for key generations. +// Encrypt encrypts the message in a message.Message with MAC. func (r *Ratchet) Encrypt(plaintext []byte) ([]byte, error) { var err error if !r.SenderChains.IsSet { - newRatchetKey, err := crypto.Curve25519GenerateKey(nil) + newRatchetKey, err := crypto.Curve25519GenerateKey() if err != nil { return nil, err } @@ -132,7 +132,7 @@ func (r *Ratchet) Encrypt(plaintext []byte) ([]byte, error) { return output, nil } -// Decrypt decrypts the ciphertext and verifies the MAC. If reader is nil, crypto/rand is used for key generations. +// Decrypt decrypts the ciphertext and verifies the MAC. func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { message := &message.Message{} //The mac is not verified here, as we do not know the key yet diff --git a/crypto/goolm/ratchet/olm_test.go b/crypto/goolm/ratchet/olm_test.go index 91549bd8..91dd6e9b 100644 --- a/crypto/goolm/ratchet/olm_test.go +++ b/crypto/goolm/ratchet/olm_test.go @@ -26,7 +26,7 @@ func initializeRatchets() (*ratchet.Ratchet, *ratchet.Ratchet, error) { aliceRatchet := ratchet.New() bobRatchet := ratchet.New() - aliceKey, err := crypto.Curve25519GenerateKey(nil) + aliceKey, err := crypto.Curve25519GenerateKey() if err != nil { return nil, nil, err } diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index ce9a4b26..b3234967 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -35,7 +35,7 @@ var _ olm.OutboundGroupSession = (*MegolmOutboundSession)(nil) func NewMegolmOutboundSession() (*MegolmOutboundSession, error) { o := &MegolmOutboundSession{} var err error - o.SigningKey, err = crypto.Ed25519GenerateKey(nil) + o.SigningKey, err = crypto.Ed25519GenerateKey() if err != nil { return nil, err } diff --git a/crypto/goolm/session/megolm_session_test.go b/crypto/goolm/session/megolm_session_test.go index 936ce982..7c3f455f 100644 --- a/crypto/goolm/session/megolm_session_test.go +++ b/crypto/goolm/session/megolm_session_test.go @@ -18,7 +18,7 @@ func TestOutboundPickleJSON(t *testing.T) { if err != nil { t.Fatal(err) } - kp, err := crypto.Ed25519GenerateKey(nil) + kp, err := crypto.Ed25519GenerateKey() if err != nil { t.Fatal(err) } @@ -50,7 +50,7 @@ func TestOutboundPickleJSON(t *testing.T) { func TestInboundPickleJSON(t *testing.T) { pickleKey := []byte("secretKey") sess := session.MegolmInboundSession{} - kp, err := crypto.Ed25519GenerateKey(nil) + kp, err := crypto.Ed25519GenerateKey() if err != nil { t.Fatal(err) } diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go index 33908edc..c0067bfa 100644 --- a/crypto/goolm/session/olm_session.go +++ b/crypto/goolm/session/olm_session.go @@ -80,12 +80,12 @@ func NewOlmSession() *OlmSession { func NewOutboundOlmSession(identityKeyAlice crypto.Curve25519KeyPair, identityKeyBob crypto.Curve25519PublicKey, oneTimeKeyBob crypto.Curve25519PublicKey) (*OlmSession, error) { s := NewOlmSession() //generate E_A - baseKey, err := crypto.Curve25519GenerateKey(nil) + baseKey, err := crypto.Curve25519GenerateKey() if err != nil { return nil, err } //generate T_0 - ratchetKey, err := crypto.Curve25519GenerateKey(nil) + ratchetKey, err := crypto.Curve25519GenerateKey() if err != nil { return nil, err } diff --git a/crypto/goolm/session/olm_session_test.go b/crypto/goolm/session/olm_session_test.go index b5ff4c32..773da8c7 100644 --- a/crypto/goolm/session/olm_session_test.go +++ b/crypto/goolm/session/olm_session_test.go @@ -14,15 +14,15 @@ import ( func TestOlmSession(t *testing.T) { pickleKey := []byte("secretKey") - aliceKeyPair, err := crypto.Curve25519GenerateKey(nil) + aliceKeyPair, err := crypto.Curve25519GenerateKey() if err != nil { t.Fatal(err) } - bobKeyPair, err := crypto.Curve25519GenerateKey(nil) + bobKeyPair, err := crypto.Curve25519GenerateKey() if err != nil { t.Fatal(err) } - bobOneTimeKey, err := crypto.Curve25519GenerateKey(nil) + bobOneTimeKey, err := crypto.Curve25519GenerateKey() if err != nil { t.Fatal(err) } diff --git a/crypto/libolm/account.go b/crypto/libolm/account.go index ad329fa3..cddce7ce 100644 --- a/crypto/libolm/account.go +++ b/crypto/libolm/account.go @@ -8,7 +8,6 @@ import ( "crypto/rand" "encoding/base64" "encoding/json" - "io" "unsafe" "github.com/tidwall/gjson" @@ -24,8 +23,8 @@ type Account struct { } func init() { - olm.InitNewAccount = func(r io.Reader) (olm.Account, error) { - return NewAccount(r) + olm.InitNewAccount = func() (olm.Account, error) { + return NewAccount() } olm.InitBlankAccount = func() olm.Account { return NewBlankAccount() @@ -60,13 +59,10 @@ func NewBlankAccount() *Account { } // NewAccount creates a new [Account]. -func NewAccount(r io.Reader) (*Account, error) { +func NewAccount() (*Account, error) { a := NewBlankAccount() random := make([]byte, a.createRandomLen()+1) - if r == nil { - r = rand.Reader - } - _, err := r.Read(random) + _, err := rand.Read(random) if err != nil { panic(olm.NotEnoughGoRandom) } @@ -307,12 +303,9 @@ func (a *Account) MaxNumberOfOneTimeKeys() uint { // GenOneTimeKeys generates a number of new one time keys. If the total number // of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old // keys are discarded. -func (a *Account) GenOneTimeKeys(reader io.Reader, num uint) error { +func (a *Account) GenOneTimeKeys(num uint) error { random := make([]byte, a.genOneTimeKeysRandomLen(num)+1) - if reader == nil { - reader = rand.Reader - } - _, err := reader.Read(random) + _, err := rand.Read(random) if err != nil { return olm.NotEnoughGoRandom } diff --git a/crypto/olm/account.go b/crypto/olm/account.go index 3271b1c1..68393e8a 100644 --- a/crypto/olm/account.go +++ b/crypto/olm/account.go @@ -7,8 +7,6 @@ package olm import ( - "io" - "maunium.net/go/mautrix/id" ) @@ -57,9 +55,8 @@ type Account interface { // GenOneTimeKeys generates a number of new one time keys. If the total // number of keys stored by this Account exceeds MaxNumberOfOneTimeKeys - // then the old keys are discarded. Reads random data from the given - // reader, or if nil is passed, defaults to crypto/rand. - GenOneTimeKeys(reader io.Reader, num uint) error + // then the old keys are discarded. + GenOneTimeKeys(num uint) error // NewOutboundSession creates a new out-bound session for sending messages to a // given curve25519 identityKey and oneTimeKey. Returns error on failure. If the @@ -91,12 +88,12 @@ type Account interface { } var InitBlankAccount func() Account -var InitNewAccount func(io.Reader) (Account, error) +var InitNewAccount func() (Account, error) var InitNewAccountFromPickled func(pickled, key []byte) (Account, error) // NewAccount creates a new Account. -func NewAccount(r io.Reader) (Account, error) { - return InitNewAccount(r) +func NewAccount() (Account, error) { + return InitNewAccount() } func NewBlankAccount() Account { diff --git a/crypto/olm/account_test.go b/crypto/olm/account_test.go index 0c628a20..0e055881 100644 --- a/crypto/olm/account_test.go +++ b/crypto/olm/account_test.go @@ -49,10 +49,10 @@ func ensureAccountsEqual(t *testing.T, a, b olm.Account) { // TestAccount_UnpickleLibolmToGoolm tests creating an account from libolm, // pickling it, and importing it into goolm. func TestAccount_UnpickleLibolmToGoolm(t *testing.T) { - libolmAccount, err := libolm.NewAccount(nil) + libolmAccount, err := libolm.NewAccount() require.NoError(t, err) - require.NoError(t, libolmAccount.GenOneTimeKeys(nil, 50)) + require.NoError(t, libolmAccount.GenOneTimeKeys(50)) libolmPickled, err := libolmAccount.Pickle([]byte("test")) require.NoError(t, err) @@ -70,10 +70,10 @@ func TestAccount_UnpickleLibolmToGoolm(t *testing.T) { // TestAccount_UnpickleGoolmToLibolm tests creating an account from goolm, // pickling it, and importing it into libolm. func TestAccount_UnpickleGoolmToLibolm(t *testing.T) { - goolmAccount, err := account.NewAccount(nil) + goolmAccount, err := account.NewAccount() require.NoError(t, err) - require.NoError(t, goolmAccount.GenOneTimeKeys(nil, 50)) + require.NoError(t, goolmAccount.GenOneTimeKeys(50)) goolmPickled, err := goolmAccount.Pickle([]byte("test")) require.NoError(t, err) @@ -91,7 +91,7 @@ func TestAccount_UnpickleGoolmToLibolm(t *testing.T) { func FuzzAccount_Sign(f *testing.F) { f.Add([]byte("anything")) - libolmAccount := exerrors.Must(libolm.NewAccount(nil)) + libolmAccount := exerrors.Must(libolm.NewAccount()) goolmAccount := exerrors.Must(account.AccountFromPickled(exerrors.Must(libolmAccount.Pickle([]byte("test"))), []byte("test"))) f.Fuzz(func(t *testing.T, message []byte) { diff --git a/crypto/olm/session_test.go b/crypto/olm/session_test.go index 9f0986eb..b0b9896f 100644 --- a/crypto/olm/session_test.go +++ b/crypto/olm/session_test.go @@ -60,16 +60,16 @@ func TestSessionPickle(t *testing.T) { func TestSession_EncryptDecrypt(t *testing.T) { combos := [][2]olm.Account{ - {exerrors.Must(libolm.NewAccount(nil)), exerrors.Must(libolm.NewAccount(nil))}, - {exerrors.Must(account.NewAccount(nil)), exerrors.Must(account.NewAccount(nil))}, - {exerrors.Must(libolm.NewAccount(nil)), exerrors.Must(account.NewAccount(nil))}, - {exerrors.Must(account.NewAccount(nil)), exerrors.Must(libolm.NewAccount(nil))}, + {exerrors.Must(libolm.NewAccount()), exerrors.Must(libolm.NewAccount())}, + {exerrors.Must(account.NewAccount()), exerrors.Must(account.NewAccount())}, + {exerrors.Must(libolm.NewAccount()), exerrors.Must(account.NewAccount())}, + {exerrors.Must(account.NewAccount()), exerrors.Must(libolm.NewAccount())}, } for _, combo := range combos { receiver, sender := combo[0], combo[1] - require.NoError(t, receiver.GenOneTimeKeys(nil, 50)) - require.NoError(t, sender.GenOneTimeKeys(nil, 50)) + require.NoError(t, receiver.GenOneTimeKeys(50)) + require.NoError(t, sender.GenOneTimeKeys(50)) _, receiverCurve25519, err := receiver.IdentityKeys() require.NoError(t, err) From eb632a9994ab189125286c41d8362a4dc664d76e Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 4 Sep 2024 00:31:14 -0600 Subject: [PATCH 0867/1647] goolm: simplify tests using testify Signed-off-by: Sumner Evans --- crypto/goolm/account/account_test.go | 543 ++++++-------------- crypto/goolm/cipher/aes_sha256_test.go | 73 +-- crypto/goolm/cipher/pickle_test.go | 15 +- crypto/goolm/crypto/curve25519_test.go | 128 ++--- crypto/goolm/crypto/ed25519_test.go | 111 ++-- crypto/goolm/crypto/hmac_test.go | 63 +-- crypto/goolm/libolmpickle/pickle_test.go | 35 +- crypto/goolm/libolmpickle/unpickle_test.go | 51 +- crypto/goolm/megolm/megolm_test.go | 80 +-- crypto/goolm/message/decoder_test.go | 19 +- crypto/goolm/message/group_message_test.go | 27 +- crypto/goolm/message/message_test.go | 35 +- crypto/goolm/message/prekey_message_test.go | 40 +- crypto/goolm/pk/pk_test.go | 91 +--- crypto/goolm/ratchet/olm_test.go | 126 ++--- crypto/goolm/session/megolm_session_test.go | 210 ++------ crypto/goolm/session/olm_session_test.go | 114 ++-- 17 files changed, 489 insertions(+), 1272 deletions(-) diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go index 05d6d5fc..e7739bb6 100644 --- a/crypto/goolm/account/account_test.go +++ b/crypto/goolm/account/account_test.go @@ -1,13 +1,10 @@ package account_test import ( - "bytes" "encoding/base64" - "errors" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "maunium.net/go/mautrix/id" @@ -18,75 +15,42 @@ import ( func TestAccount(t *testing.T) { firstAccount, err := account.NewAccount() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) err = firstAccount.GenFallbackKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) err = firstAccount.GenOneTimeKeys(2) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) encryptionKey := []byte("testkey") + //now pickle account in JSON format pickled, err := firstAccount.PickleAsJSON(encryptionKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + //now unpickle into new Account unpickledAccount, err := account.AccountFromJSONPickled(pickled, encryptionKey) - if err != nil { - t.Fatal(err) - } - //check if accounts are the same - if firstAccount.NextOneTimeKeyID != unpickledAccount.NextOneTimeKeyID { - t.Fatal("NextOneTimeKeyID unequal") - } - if !firstAccount.CurrentFallbackKey.Equal(unpickledAccount.CurrentFallbackKey) { - t.Fatal("CurrentFallbackKey unequal") - } - if !firstAccount.PrevFallbackKey.Equal(unpickledAccount.PrevFallbackKey) { - t.Fatal("PrevFallbackKey unequal") - } - if len(firstAccount.OTKeys) != len(unpickledAccount.OTKeys) { - t.Fatal("OneTimeKeysunequal") - } - for i := range firstAccount.OTKeys { - if !firstAccount.OTKeys[i].Equal(unpickledAccount.OTKeys[i]) { - t.Fatalf("OneTimeKeys %d unequal", i) - } - } - if !firstAccount.IdKeys.Curve25519.PrivateKey.Equal(unpickledAccount.IdKeys.Curve25519.PrivateKey) { - t.Fatal("IdentityKeys Curve25519 private unequal") - } - if !firstAccount.IdKeys.Curve25519.PublicKey.Equal(unpickledAccount.IdKeys.Curve25519.PublicKey) { - t.Fatal("IdentityKeys Curve25519 public unequal") - } - if !firstAccount.IdKeys.Ed25519.PrivateKey.Equal(unpickledAccount.IdKeys.Ed25519.PrivateKey) { - t.Fatal("IdentityKeys Ed25519 private unequal") - } - if !firstAccount.IdKeys.Ed25519.PublicKey.Equal(unpickledAccount.IdKeys.Ed25519.PublicKey) { - t.Fatal("IdentityKeys Ed25519 public unequal") - } + assert.NoError(t, err) - if otks, err := firstAccount.OneTimeKeys(); err != nil || len(otks) != 2 { - t.Fatal("should get 2 unpublished oneTimeKeys") - } - if len(firstAccount.FallbackKeyUnpublished()) == 0 { - t.Fatal("should get fallbackKey") - } + //check if accounts are the same + assert.Equal(t, firstAccount.NextOneTimeKeyID, unpickledAccount.NextOneTimeKeyID) + assert.Equal(t, firstAccount.CurrentFallbackKey, unpickledAccount.CurrentFallbackKey) + assert.Equal(t, firstAccount.PrevFallbackKey, unpickledAccount.PrevFallbackKey) + assert.Equal(t, firstAccount.OTKeys, unpickledAccount.OTKeys) + assert.Equal(t, firstAccount.IdKeys, unpickledAccount.IdKeys) + + // Ensure that all of the keys are unpublished right now + otks, err := firstAccount.OneTimeKeys() + assert.NoError(t, err) + assert.Len(t, otks, 2) + assert.Len(t, firstAccount.FallbackKeyUnpublished(), 1) + + // Now, publish the key and make sure that they are published firstAccount.MarkKeysAsPublished() - if len(firstAccount.FallbackKey()) == 0 { - t.Fatal("should get fallbackKey") - } - if len(firstAccount.FallbackKeyUnpublished()) != 0 { - t.Fatal("should get no fallbackKey") - } - if otks, err := firstAccount.OneTimeKeys(); err != nil || len(otks) != 0 { - t.Fatal("should get no oneTimeKeys") - } + + assert.Len(t, firstAccount.FallbackKeyUnpublished(), 0) + assert.Len(t, firstAccount.FallbackKey(), 1) + otks, err = firstAccount.OneTimeKeys() + assert.NoError(t, err) + assert.Len(t, otks, 0) } func TestAccountPickleJSON(t *testing.T) { @@ -104,109 +68,49 @@ func TestAccountPickleJSON(t *testing.T) { pickledData := []byte("6POkBWwbNl20fwvZWsOu0jgbHy4jkA5h0Ji+XCag59+ifWIRPDrqtgQi9HmkLiSF6wUhhYaV4S73WM+Hh+dlCuZRuXhTQr8yGPTifjcjq8birdAhObbEqHrYEdqaQkrgBLr/rlS5sibXeDqbkhVu4LslvootU9DkcCbd4b/0Flh7iugxqkcCs5GDndTEx9IzTVJzmK82Y0Q1Z1Z9Vuc2Iw746PtBJLtZjite6fSMp2NigPX/ZWWJ3OnwcJo0Vvjy8hgptZEWkamOHdWbUtelbHyjDIZlvxOC25D3rFif0zzPkF9qdpBPqVCWPPzGFmgnqKau6CHrnPfq7GLsM3BrprD7sHN1Js28ex14gXQPjBT7KTUo6H0e4gQMTMRp4qb8btNXDeId8xIFIElTh2SXZBTDmSq/ziVNJinEvYV8mGPvJZjDQQU+SyoS/HZ8uMc41tH0BOGDbFMHbfLMiz61E429gOrx2klu5lqyoyet7//HKi0ed5w2dQ") account, err := account.AccountFromJSONPickled(pickledData, key) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) expectedJSON := `{"ed25519":"qWvNB6Ztov5/AOsP073op0O32KJ8/tgSNarT7MaYgQE","curve25519":"TFUB6M6zwgyWhBEp2m1aUodl2AsnsrIuBr8l9AvwGS8"}` jsonData, err := account.IdentityKeysJSON() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(jsonData, []byte(expectedJSON)) { - t.Fatalf("Expected '%s' but got '%s'", expectedJSON, jsonData) - } + assert.NoError(t, err) + assert.Equal(t, expectedJSON, string(jsonData)) } func TestSessions(t *testing.T) { aliceAccount, err := account.NewAccount() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) err = aliceAccount.GenOneTimeKeys(5) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) bobAccount, err := account.NewAccount() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) err = bobAccount.GenOneTimeKeys(5) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) aliceSession, err := aliceAccount.NewOutboundSession(bobAccount.IdKeys.Curve25519.B64Encoded(), bobAccount.OTKeys[2].Key.B64Encoded()) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) plaintext := []byte("test message") msgType, crypttext, err := aliceSession.Encrypt(plaintext) - if err != nil { - t.Fatal(err) - } - if msgType != id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypePreKey, msgType) bobSession, err := bobAccount.NewInboundSession(string(crypttext)) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) decodedText, err := bobSession.Decrypt(string(crypttext), msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plaintext, decodedText) { - t.Fatalf("expected '%s' but got '%s'", string(plaintext), string(decodedText)) - } + assert.NoError(t, err) + assert.Equal(t, plaintext, decodedText) } func TestAccountPickle(t *testing.T) { pickleKey := []byte("secret_key") account, err := account.AccountFromPickled(pickledDataFromLibOlm, pickleKey) - if err != nil { - t.Fatal(err) - } - if !expectedEd25519KeyPairPickleLibOLM.PrivateKey.Equal(account.IdKeys.Ed25519.PrivateKey) { - t.Fatal("keys not equal") - } - if !expectedEd25519KeyPairPickleLibOLM.PublicKey.Equal(account.IdKeys.Ed25519.PublicKey) { - t.Fatal("keys not equal") - } - if !expectedCurve25519KeyPairPickleLibOLM.PrivateKey.Equal(account.IdKeys.Curve25519.PrivateKey) { - t.Fatal("keys not equal") - } - if !expectedCurve25519KeyPairPickleLibOLM.PublicKey.Equal(account.IdKeys.Curve25519.PublicKey) { - t.Fatal("keys not equal") - } - if account.NextOneTimeKeyID != 42 { - t.Fatal("wrong next otKey id") - } - if len(account.OTKeys) != len(expectedOTKeysPickleLibOLM) { - t.Fatal("wrong number of otKeys") - } - if account.NumFallbackKeys != 0 { - t.Fatal("fallback keys set but not in pickle") - } - for curIndex, curValue := range account.OTKeys { - curExpected := expectedOTKeysPickleLibOLM[curIndex] - if curExpected.ID != curValue.ID { - t.Fatal("OTKey id not correct") - } - if !curExpected.Key.PublicKey.Equal(curValue.Key.PublicKey) { - t.Fatal("OTKey public key not correct") - } - if !curExpected.Key.PrivateKey.Equal(curValue.Key.PrivateKey) { - t.Fatal("OTKey private key not correct") - } - } + assert.NoError(t, err) + assert.Equal(t, expectedEd25519KeyPairPickleLibOLM, account.IdKeys.Ed25519) + assert.Equal(t, expectedCurve25519KeyPairPickleLibOLM, account.IdKeys.Curve25519) + assert.EqualValues(t, 42, account.NextOneTimeKeyID) + assert.Equal(t, account.OTKeys, expectedOTKeysPickleLibOLM) + assert.EqualValues(t, 0, account.NumFallbackKeys) targetPickled, err := account.Pickle(pickleKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(targetPickled, pickledDataFromLibOlm) { - t.Fatal("repickled value does not equal given value") - } + assert.NoError(t, err) + assert.Equal(t, pickledDataFromLibOlm, targetPickled) } func TestOldAccountPickle(t *testing.T) { @@ -218,355 +122,212 @@ func TestOldAccountPickle(t *testing.T) { "O5TmXua1FcU") pickleKey := []byte("") account, err := account.NewAccount() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) err = account.Unpickle(pickled, pickleKey) - if err == nil { - t.Fatal("expected error") - } else { - if !errors.Is(err, olm.ErrBadVersion) { - t.Fatal(err) - } - } + assert.ErrorIs(t, err, olm.ErrBadVersion) } func TestLoopback(t *testing.T) { accountA, err := account.NewAccount() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) accountB, err := account.NewAccount() - if err != nil { - t.Fatal(err) - } - err = accountB.GenOneTimeKeys(42) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + err = accountB.GenOneTimeKeys( 42) + assert.NoError(t, err) aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), accountB.OTKeys[0].Key.B64Encoded()) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) plainText := []byte("Hello, World") msgType, message1, err := aliceSession.Encrypt(plainText) - if err != nil { - t.Fatal(err) - } - if msgType != id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypePreKey, msgType) bobSession, err := accountB.NewInboundSession(string(message1)) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) // Check that the inbound session matches the message it was created from. sessionIsOK, err := bobSession.MatchesInboundSessionFrom("", string(message1)) - if err != nil { - t.Fatal(err) - } - if !sessionIsOK { - t.Fatal("session was not detected to be valid") - } + assert.NoError(t, err) + assert.True(t, sessionIsOK, "session was not detected to be valid") + // Check that the inbound session matches the key this message is supposed to be from. aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded() sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(aIDKey), string(message1)) - if err != nil { - t.Fatal(err) - } - if !sessionIsOK { - t.Fatal("session is sad to be not from a but it should") - } + assert.NoError(t, err) + assert.True(t, sessionIsOK, "session is sad to be not from a but it should") + // Check that the inbound session isn't from a different user. bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded() sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(bIDKey), string(message1)) - if err != nil { - t.Fatal(err) - } - if sessionIsOK { - t.Fatal("session is sad to be from b but is from a") - } + assert.NoError(t, err) + assert.False(t, sessionIsOK, "session is sad to be from b but is from a") + // Check that we can decrypt the message. decryptedMessage, err := bobSession.Decrypt(string(message1), msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decryptedMessage, plainText) { - t.Fatal("messages are not the same") - } + assert.NoError(t, err) + assert.Equal(t, plainText, decryptedMessage) msgTyp2, message2, err := bobSession.Encrypt(plainText) - if err != nil { - t.Fatal(err) - } - if msgTyp2 == id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypeMsg, msgTyp2) decryptedMessage2, err := aliceSession.Decrypt(string(message2), msgTyp2) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decryptedMessage2, plainText) { - t.Fatal("messages are not the same") - } + assert.NoError(t, err) + assert.Equal(t, plainText, decryptedMessage2) //decrypting again should fail, as the chain moved on _, err = aliceSession.Decrypt(string(message2), msgTyp2) - if err == nil { - t.Fatal("expected error") - } + assert.Error(t, err) + assert.ErrorIs(t, err, olm.ErrMessageKeyNotFound) //compare sessionIDs - if aliceSession.ID() != bobSession.ID() { - t.Fatal("sessionIDs are not equal") - } + assert.Equal(t, aliceSession.ID(), bobSession.ID()) } func TestMoreMessages(t *testing.T) { accountA, err := account.NewAccount() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) accountB, err := account.NewAccount() - if err != nil { - t.Fatal(err) - } - err = accountB.GenOneTimeKeys(42) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + err = accountB.GenOneTimeKeys( 42) + assert.NoError(t, err) aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), accountB.OTKeys[0].Key.B64Encoded()) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) plainText := []byte("Hello, World") msgType, message1, err := aliceSession.Encrypt(plainText) - if err != nil { - t.Fatal(err) - } - if msgType != id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypePreKey, msgType) bobSession, err := accountB.NewInboundSession(string(message1)) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) decryptedMessage, err := bobSession.Decrypt(string(message1), msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decryptedMessage, plainText) { - t.Fatal("messages are not the same") - } + assert.NoError(t, err) + assert.Equal(t, plainText, decryptedMessage) for i := 0; i < 8; i++ { //alice sends, bob reveices msgType, message, err := aliceSession.Encrypt(plainText) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) if i == 0 { //The first time should still be a preKeyMessage as bob has not yet send a message to alice - if msgType != id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } + assert.Equal(t, id.OlmMsgTypePreKey, msgType) } else { - if msgType == id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } + assert.Equal(t, id.OlmMsgTypeMsg, msgType) } + decryptedMessage, err := bobSession.Decrypt(string(message), msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decryptedMessage, plainText) { - t.Fatal("messages are not the same") - } + assert.NoError(t, err) + assert.Equal(t, plainText, decryptedMessage) //now bob sends, alice receives msgType, message, err = bobSession.Encrypt(plainText) - if err != nil { - t.Fatal(err) - } - if msgType == id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypeMsg, msgType) + decryptedMessage, err = aliceSession.Decrypt(string(message), msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decryptedMessage, plainText) { - t.Fatal("messages are not the same") - } + assert.NoError(t, err) + assert.Equal(t, plainText, decryptedMessage) } } func TestFallbackKey(t *testing.T) { accountA, err := account.NewAccount() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) accountB, err := account.NewAccount() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) err = accountB.GenFallbackKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) fallBackKeys := accountB.FallbackKeyUnpublished() var fallbackKey id.Curve25519 for _, fbKey := range fallBackKeys { fallbackKey = fbKey } aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) plainText := []byte("Hello, World") msgType, message1, err := aliceSession.Encrypt(plainText) - if err != nil { - t.Fatal(err) - } - if msgType != id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypePreKey, msgType) bobSession, err := accountB.NewInboundSession(string(message1)) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) // Check that the inbound session matches the message it was created from. sessionIsOK, err := bobSession.MatchesInboundSessionFrom("", string(message1)) - if err != nil { - t.Fatal(err) - } - if !sessionIsOK { - t.Fatal("session was not detected to be valid") - } + assert.NoError(t, err) + assert.True(t, sessionIsOK, "session was not detected to be valid") + // Check that the inbound session matches the key this message is supposed to be from. aIDKey := accountA.IdKeys.Curve25519.PublicKey.B64Encoded() sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(aIDKey), string(message1)) - if err != nil { - t.Fatal(err) - } - if !sessionIsOK { - t.Fatal("session is sad to be not from a but it should") - } + assert.NoError(t, err) + assert.True(t, sessionIsOK, "session is sad to be not from a but it should") + // Check that the inbound session isn't from a different user. bIDKey := accountB.IdKeys.Curve25519.PublicKey.B64Encoded() sessionIsOK, err = bobSession.MatchesInboundSessionFrom(string(bIDKey), string(message1)) - if err != nil { - t.Fatal(err) - } - if sessionIsOK { - t.Fatal("session is sad to be from b but is from a") - } + assert.NoError(t, err) + assert.False(t, sessionIsOK, "session is sad to be from b but is from a") + // Check that we can decrypt the message. decryptedMessage, err := bobSession.Decrypt(string(message1), msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decryptedMessage, plainText) { - t.Fatal("messages are not the same") - } + assert.NoError(t, err) + assert.Equal(t, plainText, decryptedMessage) // create a new fallback key for B (the old fallback should still be usable) err = accountB.GenFallbackKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) // start another session and encrypt a message aliceSession2, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) msgType2, message2, err := aliceSession2.Encrypt(plainText) - if err != nil { - t.Fatal(err) - } - if msgType2 != id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypePreKey, msgType2) + // bobSession should not be valid for the message2 // Check that the inbound session matches the message it was created from. sessionIsOK, err = bobSession.MatchesInboundSessionFrom("", string(message2)) - if err != nil { - t.Fatal(err) - } - if sessionIsOK { - t.Fatal("session was detected to be valid but should not") - } + assert.NoError(t, err) + assert.False(t, sessionIsOK, "session was detected to be valid but should not") + bobSession2, err := accountB.NewInboundSession(string(message2)) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) // Check that the inbound session matches the message it was created from. sessionIsOK, err = bobSession2.MatchesInboundSessionFrom("", string(message2)) - if err != nil { - t.Fatal(err) - } - if !sessionIsOK { - t.Fatal("session was not detected to be valid") - } + assert.NoError(t, err) + assert.True(t, sessionIsOK, "session was not detected to be valid") + // Check that the inbound session matches the key this message is supposed to be from. sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(string(aIDKey), string(message2)) - if err != nil { - t.Fatal(err) - } - if !sessionIsOK { - t.Fatal("session is sad to be not from a but it should") - } + assert.NoError(t, err) + assert.True(t, sessionIsOK, "session is sad to be not from a but it should") + // Check that the inbound session isn't from a different user. sessionIsOK, err = bobSession2.MatchesInboundSessionFrom(string(bIDKey), string(message2)) - if err != nil { - t.Fatal(err) - } - if sessionIsOK { - t.Fatal("session is sad to be from b but is from a") - } + assert.NoError(t, err) + assert.False(t, sessionIsOK, "session is sad to be from b but is from a") + // Check that we can decrypt the message. decryptedMessage2, err := bobSession2.Decrypt(string(message2), msgType2) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decryptedMessage2, plainText) { - t.Fatal("messages are not the same") - } + assert.NoError(t, err) + assert.Equal(t, plainText, decryptedMessage2) //Forget the old fallback key -- creating a new session should fail now accountB.ForgetOldFallbackKey() // start another session and encrypt a message aliceSession3, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), fallbackKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) msgType3, message3, err := aliceSession3.Encrypt(plainText) - if err != nil { - t.Fatal(err) - } - if msgType3 != id.OlmMsgTypePreKey { - t.Fatal("wrong message type") - } + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypePreKey, msgType3) _, err = accountB.NewInboundSession(string(message3)) - if err == nil { - t.Fatal("expected error") - } - if !errors.Is(err, olm.ErrBadMessageKeyID) { - t.Fatal(err) - } + assert.ErrorIs(t, err, olm.ErrBadMessageKeyID) } func TestOldV3AccountPickle(t *testing.T) { @@ -582,33 +343,23 @@ func TestOldV3AccountPickle(t *testing.T) { expectedUnpublishedFallbackJSON := []byte("{\"curve25519\":{}}") account, err := account.AccountFromPickled(pickledData, pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) fallbackJSON, err := account.FallbackKeyJSON() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(fallbackJSON, expectedFallbackJSON) { - t.Fatalf("expected not as result:\n%s\n%s\n", expectedFallbackJSON, fallbackJSON) - } + assert.NoError(t, err) + assert.Equal(t, expectedFallbackJSON, fallbackJSON) fallbackJSONUnpublished, err := account.FallbackKeyUnpublishedJSON() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(fallbackJSONUnpublished, expectedUnpublishedFallbackJSON) { - t.Fatalf("expected not as result:\n%s\n%s\n", expectedUnpublishedFallbackJSON, fallbackJSONUnpublished) - } + assert.NoError(t, err) + assert.Equal(t, expectedUnpublishedFallbackJSON, fallbackJSONUnpublished) } func TestAccountSign(t *testing.T) { accountA, err := account.NewAccount() - require.NoError(t, err) + assert.NoError(t, err) plainText := []byte("Hello, World") signatureB64, err := accountA.Sign(plainText) - require.NoError(t, err) + assert.NoError(t, err) signature, err := base64.RawStdEncoding.DecodeString(string(signatureB64)) - require.NoError(t, err) + assert.NoError(t, err) verified, err := signatures.VerifySignature(plainText, accountA.IdKeys.Ed25519.B64Encoded(), signature) assert.NoError(t, err) diff --git a/crypto/goolm/cipher/aes_sha256_test.go b/crypto/goolm/cipher/aes_sha256_test.go index d2f49cb1..69aae100 100644 --- a/crypto/goolm/cipher/aes_sha256_test.go +++ b/crypto/goolm/cipher/aes_sha256_test.go @@ -1,52 +1,44 @@ package cipher import ( - "bytes" "crypto/aes" "testing" + + "github.com/stretchr/testify/assert" ) func TestDeriveAESKeys(t *testing.T) { kdfInfo := []byte("test") key := []byte("test key") derivedKeys, err := deriveAESKeys(kdfInfo, key) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) derivedKeys2, err := deriveAESKeys(kdfInfo, key) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + //derivedKeys and derivedKeys2 should be identical - if !bytes.Equal(derivedKeys.key, derivedKeys2.key) || - !bytes.Equal(derivedKeys.iv, derivedKeys2.iv) || - !bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) { - t.Fail() - } + assert.Equal(t, derivedKeys.key, derivedKeys2.key) + assert.Equal(t, derivedKeys.iv, derivedKeys2.iv) + assert.Equal(t, derivedKeys.hmacKey, derivedKeys2.hmacKey) + //changing kdfInfo kdfInfo = []byte("other kdf") derivedKeys2, err = deriveAESKeys(kdfInfo, key) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + //derivedKeys and derivedKeys2 should now be different - if bytes.Equal(derivedKeys.key, derivedKeys2.key) || - bytes.Equal(derivedKeys.iv, derivedKeys2.iv) || - bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) { - t.Fail() - } + assert.NotEqual(t, derivedKeys.key, derivedKeys2.key) + assert.NotEqual(t, derivedKeys.iv, derivedKeys2.iv) + assert.NotEqual(t, derivedKeys.hmacKey, derivedKeys2.hmacKey) + //changing key key = []byte("other test key") derivedKeys, err = deriveAESKeys(kdfInfo, key) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + //derivedKeys and derivedKeys2 should now be different - if bytes.Equal(derivedKeys.key, derivedKeys2.key) || - bytes.Equal(derivedKeys.iv, derivedKeys2.iv) || - bytes.Equal(derivedKeys.hmacKey, derivedKeys2.hmacKey) { - t.Fail() - } + assert.NotEqual(t, derivedKeys.key, derivedKeys2.key) + assert.NotEqual(t, derivedKeys.iv, derivedKeys2.iv) + assert.NotEqual(t, derivedKeys.hmacKey, derivedKeys2.hmacKey) } func TestCipherAESSha256(t *testing.T) { @@ -58,26 +50,15 @@ func TestCipherAESSha256(t *testing.T) { message = append(message, []byte("-")...) } encrypted, err := cipher.Encrypt(key, []byte(message)) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) mac, err := cipher.MAC(key, encrypted) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) verified, err := cipher.Verify(key, encrypted, mac[:8]) - if err != nil { - t.Fatal(err) - } - if !verified { - t.Fatal("signature verification failed") - } + assert.NoError(t, err) + assert.True(t, verified, "signature verification failed") + resultPlainText, err := cipher.Decrypt(key, encrypted) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(message, resultPlainText) { - t.Fail() - } + assert.NoError(t, err) + assert.Equal(t, message, resultPlainText) } diff --git a/crypto/goolm/cipher/pickle_test.go b/crypto/goolm/cipher/pickle_test.go index b47bf3ea..b6cfe809 100644 --- a/crypto/goolm/cipher/pickle_test.go +++ b/crypto/goolm/cipher/pickle_test.go @@ -1,10 +1,11 @@ package cipher_test import ( - "bytes" "crypto/aes" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/cipher" ) @@ -19,15 +20,9 @@ func TestEncoding(t *testing.T) { copy(toEncrypt, input) } encoded, err := cipher.Pickle(key, toEncrypt) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) decoded, err := cipher.Unpickle(key, encoded) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decoded, toEncrypt) { - t.Fatalf("Expected '%s' but got '%s'", toEncrypt, decoded) - } + assert.NoError(t, err) + assert.Equal(t, toEncrypt, decoded) } diff --git a/crypto/goolm/crypto/curve25519_test.go b/crypto/goolm/crypto/curve25519_test.go index ce5f561b..b7c86eee 100644 --- a/crypto/goolm/crypto/curve25519_test.go +++ b/crypto/goolm/crypto/curve25519_test.go @@ -1,39 +1,26 @@ package crypto_test import ( - "bytes" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/crypto" ) func TestCurve25519(t *testing.T) { firstKeypair, err := crypto.Curve25519GenerateKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) secondKeypair, err := crypto.Curve25519GenerateKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) sharedSecretFromFirst, err := firstKeypair.SharedSecret(secondKeypair.PublicKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) sharedSecretFromSecond, err := secondKeypair.SharedSecret(firstKeypair.PublicKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(sharedSecretFromFirst, sharedSecretFromSecond) { - t.Fatal("shared secret not equal") - } + assert.NoError(t, err) + assert.Equal(t, sharedSecretFromFirst, sharedSecretFromSecond, "shared secret not equal") fromPrivate, err := crypto.Curve25519GenerateFromPrivate(firstKeypair.PrivateKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(fromPrivate.PublicKey, firstKeypair.PublicKey) { - t.Fatal("public keys not equal") - } + assert.NoError(t, err) + assert.Equal(t, fromPrivate, firstKeypair) } func TestCurve25519Case1(t *testing.T) { @@ -76,112 +63,59 @@ func TestCurve25519Case1(t *testing.T) { PublicKey: bobPublic, } agreementFromAlice, err := aliceKeyPair.SharedSecret(bobKeyPair.PublicKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(agreementFromAlice, expectedAgreement) { - t.Fatal("expected agreement does not match agreement from Alice's view") - } + assert.NoError(t, err) + assert.Equal(t, expectedAgreement, agreementFromAlice, "expected agreement does not match agreement from Alice's view") agreementFromBob, err := bobKeyPair.SharedSecret(aliceKeyPair.PublicKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(agreementFromBob, expectedAgreement) { - t.Fatal("expected agreement does not match agreement from Bob's view") - } + assert.NoError(t, err) + assert.Equal(t, expectedAgreement, agreementFromBob, "expected agreement does not match agreement from Bob's view") } func TestCurve25519Pickle(t *testing.T) { //create keypair keyPair, err := crypto.Curve25519GenerateKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) target := make([]byte, keyPair.PickleLen()) writtenBytes, err := keyPair.PickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if writtenBytes != len(target) { - t.Fatal("written bytes not correct") - } + assert.NoError(t, err) + assert.Len(t, target, writtenBytes) unpickledKeyPair := crypto.Curve25519KeyPair{} readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if readBytes != len(target) { - t.Fatal("read bytes not correct") - } - if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { - t.Fatal("private keys not correct") - } - if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { - t.Fatal("public keys not correct") - } + assert.NoError(t, err) + assert.Len(t, target, readBytes) + assert.Equal(t, keyPair, unpickledKeyPair) } func TestCurve25519PicklePubKeyOnly(t *testing.T) { //create keypair keyPair, err := crypto.Curve25519GenerateKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) //Remove privateKey keyPair.PrivateKey = nil target := make([]byte, keyPair.PickleLen()) writtenBytes, err := keyPair.PickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if writtenBytes != len(target) { - t.Fatal("written bytes not correct") - } + assert.NoError(t, err) + assert.Len(t, target, writtenBytes) unpickledKeyPair := crypto.Curve25519KeyPair{} readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if readBytes != len(target) { - t.Fatal("read bytes not correct") - } - if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { - t.Fatal("private keys not correct") - } - if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { - t.Fatal("public keys not correct") - } + assert.NoError(t, err) + assert.Len(t, target, readBytes) + assert.Equal(t, keyPair, unpickledKeyPair) } func TestCurve25519PicklePrivKeyOnly(t *testing.T) { //create keypair keyPair, err := crypto.Curve25519GenerateKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) //Remove public keyPair.PublicKey = nil target := make([]byte, keyPair.PickleLen()) writtenBytes, err := keyPair.PickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if writtenBytes != len(target) { - t.Fatal("written bytes not correct") - } + assert.NoError(t, err) + assert.Len(t, target, writtenBytes) unpickledKeyPair := crypto.Curve25519KeyPair{} readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if readBytes != len(target) { - t.Fatal("read bytes not correct") - } - if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { - t.Fatal("private keys not correct") - } - if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { - t.Fatal("public keys not correct") - } + assert.NoError(t, err) + assert.Len(t, target, readBytes) + assert.Equal(t, keyPair, unpickledKeyPair) } diff --git a/crypto/goolm/crypto/ed25519_test.go b/crypto/goolm/crypto/ed25519_test.go index 3588205a..41fb0977 100644 --- a/crypto/goolm/crypto/ed25519_test.go +++ b/crypto/goolm/crypto/ed25519_test.go @@ -1,140 +1,87 @@ package crypto_test import ( - "bytes" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/crypto" ) func TestEd25519(t *testing.T) { keypair, err := crypto.Ed25519GenerateKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) message := []byte("test message") signature := keypair.Sign(message) - if !keypair.Verify(message, signature) { - t.Fail() - } + assert.True(t, keypair.Verify(message, signature)) } func TestEd25519Case1(t *testing.T) { //64 bytes for ed25519 package keyPair, err := crypto.Ed25519GenerateKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) message := []byte("Hello, World") keyPair2 := crypto.Ed25519GenerateFromPrivate(keyPair.PrivateKey) - if !bytes.Equal(keyPair.PublicKey, keyPair2.PublicKey) { - t.Fatal("not equal key pairs") - } + assert.Equal(t, keyPair, keyPair2, "not equal key pairs") signature := keyPair.Sign(message) verified := keyPair.Verify(message, signature) - if !verified { - t.Fatal("message did not verify although it should") - } + assert.True(t, verified, "message did not verify although it should") + //Now change the message and verify again message = append(message, []byte("a")...) verified = keyPair.Verify(message, signature) - if verified { - t.Fatal("message did verify although it should not") - } + assert.False(t, verified, "message did verify although it should not") } func TestEd25519Pickle(t *testing.T) { //create keypair keyPair, err := crypto.Ed25519GenerateKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) target := make([]byte, keyPair.PickleLen()) writtenBytes, err := keyPair.PickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if writtenBytes != len(target) { - t.Fatal("written bytes not correct") - } + assert.NoError(t, err) + assert.Len(t, target, writtenBytes) unpickledKeyPair := crypto.Ed25519KeyPair{} readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if readBytes != len(target) { - t.Fatal("read bytes not correct") - } - if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { - t.Fatal("private keys not correct") - } - if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { - t.Fatal("public keys not correct") - } + assert.NoError(t, err) + assert.Len(t, target, readBytes, "read bytes not correct") + assert.Equal(t, keyPair, unpickledKeyPair) } func TestEd25519PicklePubKeyOnly(t *testing.T) { //create keypair keyPair, err := crypto.Ed25519GenerateKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) //Remove privateKey keyPair.PrivateKey = nil target := make([]byte, keyPair.PickleLen()) writtenBytes, err := keyPair.PickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if writtenBytes != len(target) { - t.Fatal("written bytes not correct") - } + assert.NoError(t, err) + assert.Len(t, target, writtenBytes) + unpickledKeyPair := crypto.Ed25519KeyPair{} readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if readBytes != len(target) { - t.Fatal("read bytes not correct") - } - if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { - t.Fatal("private keys not correct") - } - if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { - t.Fatal("public keys not correct") - } + assert.NoError(t, err) + assert.Len(t, target, readBytes, "read bytes not correct") + assert.Equal(t, keyPair, unpickledKeyPair) } func TestEd25519PicklePrivKeyOnly(t *testing.T) { //create keypair keyPair, err := crypto.Ed25519GenerateKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) //Remove public keyPair.PublicKey = nil target := make([]byte, keyPair.PickleLen()) writtenBytes, err := keyPair.PickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if writtenBytes != len(target) { - t.Fatal("written bytes not correct") - } + assert.NoError(t, err) + assert.Len(t, target, writtenBytes) + unpickledKeyPair := crypto.Ed25519KeyPair{} readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) - if err != nil { - t.Fatal(err) - } - if readBytes != len(target) { - t.Fatal("read bytes not correct") - } - if !bytes.Equal(keyPair.PrivateKey, unpickledKeyPair.PrivateKey) { - t.Fatal("private keys not correct") - } - if !bytes.Equal(keyPair.PublicKey, unpickledKeyPair.PublicKey) { - t.Fatal("public keys not correct") - } + assert.NoError(t, err) + assert.Len(t, target, readBytes, "read bytes not correct") + assert.Equal(t, keyPair, unpickledKeyPair) } diff --git a/crypto/goolm/crypto/hmac_test.go b/crypto/goolm/crypto/hmac_test.go index 95c0bfd5..127be131 100644 --- a/crypto/goolm/crypto/hmac_test.go +++ b/crypto/goolm/crypto/hmac_test.go @@ -1,49 +1,44 @@ package crypto_test import ( - "bytes" "encoding/base64" "io" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/crypto" ) -func TestHMACSha256(t *testing.T) { +func TestHMACSHA256(t *testing.T) { key := []byte("test key") message := []byte("test message") hash := crypto.HMACSHA256(key, message) - if !bytes.Equal(hash, crypto.HMACSHA256(key, message)) { - t.Fail() - } + assert.Equal(t, hash, crypto.HMACSHA256(key, message)) + str := "A4M0ovdiWHaZ5msdDFbrvtChFwZIoIaRSVGmv8bmPtc" result, err := base64.RawStdEncoding.DecodeString(str) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(result, hash) { - t.Fail() - } + assert.NoError(t, err) + assert.Equal(t, result, hash) } -func TestHKDFSha256(t *testing.T) { +func TestHKDFSHA256(t *testing.T) { message := []byte("test content") + hkdf := crypto.HKDFSHA256(message, nil, nil) - hkdf2 := crypto.HKDFSHA256(message, nil, nil) result := make([]byte, 32) - if _, err := io.ReadFull(hkdf, result); err != nil { - t.Fatal(err) - } + _, err := io.ReadFull(hkdf, result) + assert.NoError(t, err) + + hkdf2 := crypto.HKDFSHA256(message, nil, nil) result2 := make([]byte, 32) - if _, err := io.ReadFull(hkdf2, result2); err != nil { - t.Fatal(err) - } - if !bytes.Equal(result, result2) { - t.Fail() - } + _, err = io.ReadFull(hkdf2, result2) + assert.NoError(t, err) + + assert.Equal(t, result, result2) } -func TestSha256Case1(t *testing.T) { +func TestSHA256Case1(t *testing.T) { input := make([]byte, 0) expected := []byte{ 0xE3, 0xB0, 0xC4, 0x42, 0x98, 0xFC, 0x1C, 0x14, @@ -52,9 +47,7 @@ func TestSha256Case1(t *testing.T) { 0xA4, 0x95, 0x99, 0x1B, 0x78, 0x52, 0xB8, 0x55, } result := crypto.SHA256(input) - if !bytes.Equal(expected, result) { - t.Fatalf("result not as expected:\n%v\n%v\n", result, expected) - } + assert.Equal(t, expected, result) } func TestHMACCase1(t *testing.T) { @@ -66,9 +59,7 @@ func TestHMACCase1(t *testing.T) { 0xc6, 0xc7, 0x12, 0x14, 0x42, 0x92, 0xc5, 0xad, } result := crypto.HMACSHA256(input, input) - if !bytes.Equal(expected, result) { - t.Fatalf("result not as expected:\n%v\n%v\n", result, expected) - } + assert.Equal(t, expected, result) } func TestHDKFCase1(t *testing.T) { @@ -92,9 +83,8 @@ func TestHDKFCase1(t *testing.T) { 0x22, 0xec, 0x84, 0x4a, 0xd7, 0xc2, 0xb3, 0xe5, } result := crypto.HMACSHA256(salt, input) - if !bytes.Equal(expectedHMAC, result) { - t.Fatalf("result not as expected:\n%v\n%v\n", result, expectedHMAC) - } + assert.Equal(t, expectedHMAC, result) + expectedHDKF := []byte{ 0x3c, 0xb2, 0x5f, 0x25, 0xfa, 0xac, 0xd5, 0x7a, 0x90, 0x43, 0x4f, 0x64, 0xd0, 0x36, 0x2f, 0x2a, @@ -105,10 +95,7 @@ func TestHDKFCase1(t *testing.T) { } resultReader := crypto.HKDFSHA256(input, salt, info) result = make([]byte, len(expectedHDKF)) - if _, err := io.ReadFull(resultReader, result); err != nil { - t.Fatal(err) - } - if !bytes.Equal(expectedHDKF, result) { - t.Fatalf("result not as expected:\n%v\n%v\n", result, expectedHDKF) - } + _, err := io.ReadFull(resultReader, result) + assert.NoError(t, err) + assert.Equal(t, expectedHDKF, result) } diff --git a/crypto/goolm/libolmpickle/pickle_test.go b/crypto/goolm/libolmpickle/pickle_test.go index ce118428..27f083a0 100644 --- a/crypto/goolm/libolmpickle/pickle_test.go +++ b/crypto/goolm/libolmpickle/pickle_test.go @@ -1,9 +1,10 @@ package libolmpickle_test import ( - "bytes" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) @@ -23,12 +24,8 @@ func TestPickleUInt32(t *testing.T) { for curIndex := range values { response := make([]byte, 4) resPLen := libolmpickle.PickleUInt32(values[curIndex], response) - if resPLen != libolmpickle.PickleUInt32Len(values[curIndex]) { - t.Fatal("written bytes not correct") - } - if !bytes.Equal(response, expected[curIndex]) { - t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) - } + assert.Equal(t, libolmpickle.PickleUInt32Len(values[curIndex]), resPLen) + assert.Equal(t, expected[curIndex], response) } } @@ -44,12 +41,8 @@ func TestPickleBool(t *testing.T) { for curIndex := range values { response := make([]byte, 1) resPLen := libolmpickle.PickleBool(values[curIndex], response) - if resPLen != libolmpickle.PickleBoolLen(values[curIndex]) { - t.Fatal("written bytes not correct") - } - if !bytes.Equal(response, expected[curIndex]) { - t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) - } + assert.Equal(t, libolmpickle.PickleBoolLen(values[curIndex]), resPLen) + assert.Equal(t, expected[curIndex], response) } } @@ -65,12 +58,8 @@ func TestPickleUInt8(t *testing.T) { for curIndex := range values { response := make([]byte, 1) resPLen := libolmpickle.PickleUInt8(values[curIndex], response) - if resPLen != libolmpickle.PickleUInt8Len(values[curIndex]) { - t.Fatal("written bytes not correct") - } - if !bytes.Equal(response, expected[curIndex]) { - t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) - } + assert.Equal(t, libolmpickle.PickleUInt8Len(values[curIndex]), resPLen) + assert.Equal(t, expected[curIndex], response) } } @@ -88,11 +77,7 @@ func TestPickleBytes(t *testing.T) { for curIndex := range values { response := make([]byte, len(values[curIndex])) resPLen := libolmpickle.PickleBytes(values[curIndex], response) - if resPLen != libolmpickle.PickleBytesLen(values[curIndex]) { - t.Fatal("written bytes not correct") - } - if !bytes.Equal(response, expected[curIndex]) { - t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) - } + assert.Equal(t, libolmpickle.PickleBytesLen(values[curIndex]), resPLen) + assert.Equal(t, expected[curIndex], response) } } diff --git a/crypto/goolm/libolmpickle/unpickle_test.go b/crypto/goolm/libolmpickle/unpickle_test.go index 937630e5..71f75b18 100644 --- a/crypto/goolm/libolmpickle/unpickle_test.go +++ b/crypto/goolm/libolmpickle/unpickle_test.go @@ -1,9 +1,10 @@ package libolmpickle_test import ( - "bytes" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) @@ -20,15 +21,9 @@ func TestUnpickleUInt32(t *testing.T) { } for curIndex := range values { response, readLength, err := libolmpickle.UnpickleUInt32(values[curIndex]) - if err != nil { - t.Fatal(err) - } - if readLength != 4 { - t.Fatal("read bytes not correct") - } - if response != expected[curIndex] { - t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) - } + assert.NoError(t, err) + assert.Equal(t, 4, readLength) + assert.Equal(t, expected[curIndex], response) } } @@ -45,15 +40,9 @@ func TestUnpickleBool(t *testing.T) { } for curIndex := range values { response, readLength, err := libolmpickle.UnpickleBool(values[curIndex]) - if err != nil { - t.Fatal(err) - } - if readLength != 1 { - t.Fatal("read bytes not correct") - } - if response != expected[curIndex] { - t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) - } + assert.NoError(t, err) + assert.Equal(t, 1, readLength) + assert.Equal(t, expected[curIndex], response) } } @@ -68,15 +57,9 @@ func TestUnpickleUInt8(t *testing.T) { } for curIndex := range values { response, readLength, err := libolmpickle.UnpickleUInt8(values[curIndex]) - if err != nil { - t.Fatal(err) - } - if readLength != 1 { - t.Fatal("read bytes not correct") - } - if response != expected[curIndex] { - t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) - } + assert.NoError(t, err) + assert.Equal(t, 1, readLength) + assert.Equal(t, expected[curIndex], response) } } @@ -93,14 +76,8 @@ func TestUnpickleBytes(t *testing.T) { } for curIndex := range values { response, readLength, err := libolmpickle.UnpickleBytes(values[curIndex], 4) - if err != nil { - t.Fatal(err) - } - if readLength != 4 { - t.Fatal("read bytes not correct") - } - if !bytes.Equal(response, expected[curIndex]) { - t.Fatalf("response not as expected:\n%v\n%v\n", response, expected[curIndex]) - } + assert.NoError(t, err) + assert.Equal(t, 4, readLength) + assert.Equal(t, expected[curIndex], response) } } diff --git a/crypto/goolm/megolm/megolm_test.go b/crypto/goolm/megolm/megolm_test.go index 40289eaf..a6f7c1a7 100644 --- a/crypto/goolm/megolm/megolm_test.go +++ b/crypto/goolm/megolm/megolm_test.go @@ -1,9 +1,10 @@ package megolm_test import ( - "bytes" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/megolm" ) @@ -19,9 +20,7 @@ func init() { func TestAdvance(t *testing.T) { m, err := megolm.New(0, startData) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) expectedData := [megolm.RatchetParts * megolm.RatchetPartLength]byte{ 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, @@ -34,9 +33,7 @@ func TestAdvance(t *testing.T) { 0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a, } m.Advance() - if !bytes.Equal(m.Data[:], expectedData[:]) { - t.Fatal("result after advancing the ratchet is not as expected") - } + assert.Equal(t, m.Data[:], expectedData[:], "result after advancing the ratchet is not as expected") //repeat with complex advance m.Data = startData @@ -51,9 +48,8 @@ func TestAdvance(t *testing.T) { 0x89, 0xbb, 0xb4, 0x23, 0xa1, 0x8f, 0x23, 0x82, 0x8f, 0xb2, 0x09, 0x0d, 0x6e, 0x2a, 0xf8, 0x6a, } m.AdvanceTo(0x1000000) - if !bytes.Equal(m.Data[:], expectedData[:]) { - t.Fatal("result after advancing the ratchet is not as expected") - } + assert.Equal(t, m.Data[:], expectedData[:], "result after advancing the ratchet is not as expected") + expectedData = [megolm.RatchetParts * megolm.RatchetPartLength]byte{ 0x54, 0x02, 0x2d, 0x7d, 0xc0, 0x29, 0x8e, 0x16, 0x37, 0xe2, 0x1c, 0x97, 0x15, 0x30, 0x92, 0xf9, 0x33, 0xc0, 0x56, 0xff, 0x74, 0xfe, 0x1b, 0x92, 0x2d, 0x97, 0x1f, 0x24, 0x82, 0xc2, 0x85, 0x9c, @@ -65,77 +61,45 @@ func TestAdvance(t *testing.T) { 0xd5, 0x6f, 0x03, 0xe2, 0x44, 0x16, 0xb9, 0x8e, 0x1c, 0xfd, 0x97, 0xc2, 0x06, 0xaa, 0x90, 0x7a, } m.AdvanceTo(0x1041506) - if !bytes.Equal(m.Data[:], expectedData[:]) { - t.Fatal("result after advancing the ratchet is not as expected") - } + assert.Equal(t, m.Data[:], expectedData[:], "result after advancing the ratchet is not as expected") } func TestAdvanceWraparound(t *testing.T) { m, err := megolm.New(0xffffffff, startData) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) m.AdvanceTo(0x1000000) - if m.Counter != 0x1000000 { - t.Fatal("counter not correct") - } + assert.EqualValues(t, 0x1000000, m.Counter, "counter not correct") m2, err := megolm.New(0, startData) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) m2.AdvanceTo(0x2000000) - if m2.Counter != 0x2000000 { - t.Fatal("counter not correct") - } - if !bytes.Equal(m.Data[:], m2.Data[:]) { - t.Fatal("result after wrapping the ratchet is not as expected") - } + assert.EqualValues(t, 0x2000000, m2.Counter, "counter not correct") + assert.Equal(t, m.Data, m2.Data, "result after wrapping the ratchet is not as expected") } func TestAdvanceOverflowByOne(t *testing.T) { m, err := megolm.New(0xffffffff, startData) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) m.AdvanceTo(0x0) - if m.Counter != 0x0 { - t.Fatal("counter not correct") - } + assert.EqualValues(t, 0x0, m.Counter, "counter not correct") m2, err := megolm.New(0xffffffff, startData) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) m2.Advance() - if m2.Counter != 0x0 { - t.Fatal("counter not correct") - } - if !bytes.Equal(m.Data[:], m2.Data[:]) { - t.Fatal("result after wrapping the ratchet is not as expected") - } + assert.EqualValues(t, 0x0, m2.Counter, "counter not correct") + assert.Equal(t, m.Data, m2.Data, "result after wrapping the ratchet is not as expected") } func TestAdvanceOverflow(t *testing.T) { m, err := megolm.New(0x1, startData) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) m.AdvanceTo(0x80000000) m.AdvanceTo(0x0) - if m.Counter != 0x0 { - t.Fatal("counter not correct") - } + assert.EqualValues(t, 0x0, m.Counter, "counter not correct") m2, err := megolm.New(0x1, startData) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) m2.AdvanceTo(0x0) - if m2.Counter != 0x0 { - t.Fatal("counter not correct") - } - if !bytes.Equal(m.Data[:], m2.Data[:]) { - t.Fatal("result after wrapping the ratchet is not as expected") - } + assert.EqualValues(t, 0x0, m2.Counter, "counter not correct") + assert.Equal(t, m.Data, m2.Data, "result after wrapping the ratchet is not as expected") } diff --git a/crypto/goolm/message/decoder_test.go b/crypto/goolm/message/decoder_test.go index 39503e3e..8b7561ad 100644 --- a/crypto/goolm/message/decoder_test.go +++ b/crypto/goolm/message/decoder_test.go @@ -1,17 +1,16 @@ package message import ( - "bytes" "testing" + + "github.com/stretchr/testify/assert" ) func TestEncodeLengthInt(t *testing.T) { numbers := []uint32{127, 128, 16383, 16384, 32767} expected := []int{1, 2, 2, 3, 3} for curIndex := range numbers { - if result := encodeVarIntByteLength(numbers[curIndex]); result != expected[curIndex] { - t.Fatalf("expected byte length of %d but got %d", expected[curIndex], result) - } + assert.Equal(t, expected[curIndex], encodeVarIntByteLength(numbers[curIndex])) } } @@ -25,9 +24,7 @@ func TestEncodeLengthString(t *testing.T) { strings = append(strings, []byte("this is an even longer message with a length between 128 and 16383 so that the varint of the length needs two byte. just needs some padding again ---------")) expected = append(expected, 2+155) for curIndex := range strings { - if result := encodeVarStringByteLength(strings[curIndex]); result != expected[curIndex] { - t.Fatalf("expected byte length of %d but got %d", expected[curIndex], result) - } + assert.Equal(t, expected[curIndex], encodeVarStringByteLength(strings[curIndex])) } } @@ -43,9 +40,7 @@ func TestEncodeInt(t *testing.T) { ints = append(ints, 16383) expected = append(expected, []byte{0b11111111, 0b01111111}) for curIndex := range ints { - if result := encodeVarInt(ints[curIndex]); !bytes.Equal(result, expected[curIndex]) { - t.Fatalf("expected byte of %b but got %b", expected[curIndex], result) - } + assert.Equal(t, expected[curIndex], encodeVarInt(ints[curIndex])) } } @@ -75,8 +70,6 @@ func TestEncodeString(t *testing.T) { res = append(res, curTest...) //Add string itself expected = append(expected, res) for curIndex := range strings { - if result := encodeVarString(strings[curIndex]); !bytes.Equal(result, expected[curIndex]) { - t.Fatalf("expected byte of %b but got %b", expected[curIndex], result) - } + assert.Equal(t, expected[curIndex], encodeVarString(strings[curIndex])) } } diff --git a/crypto/goolm/message/group_message_test.go b/crypto/goolm/message/group_message_test.go index 4ae1f830..d52cf6a3 100644 --- a/crypto/goolm/message/group_message_test.go +++ b/crypto/goolm/message/group_message_test.go @@ -1,9 +1,10 @@ package message_test import ( - "bytes" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/message" ) @@ -16,18 +17,10 @@ func TestGroupMessageDecode(t *testing.T) { msg := message.GroupMessage{} err := msg.Decode(messageRaw) - if err != nil { - t.Fatal(err) - } - if msg.Version != 3 { - t.Fatalf("Expected Version to be 3 but go %d", msg.Version) - } - if msg.MessageIndex != expectedMessageIndex { - t.Fatalf("Expected message index to be %d but got %d", expectedMessageIndex, msg.MessageIndex) - } - if !bytes.Equal(msg.Ciphertext, expectedCipherText) { - t.Fatalf("expected '%s' but got '%s'", expectedCipherText, msg.Ciphertext) - } + assert.NoError(t, err) + assert.EqualValues(t, 3, msg.Version) + assert.Equal(t, expectedMessageIndex, msg.MessageIndex) + assert.Equal(t, expectedCipherText, msg.Ciphertext) } func TestGroupMessageEncode(t *testing.T) { @@ -40,12 +33,8 @@ func TestGroupMessageEncode(t *testing.T) { Ciphertext: []byte("ciphertext"), } encoded, err := msg.EncodeAndMacAndSign(nil, nil, nil) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) encoded = append(encoded, hmacsha256...) encoded = append(encoded, sign...) - if !bytes.Equal(encoded, expectedRaw) { - t.Fatalf("expected '%s' but got '%s'", expectedRaw, encoded) - } + assert.Equal(t, expectedRaw, encoded) } diff --git a/crypto/goolm/message/message_test.go b/crypto/goolm/message/message_test.go index 4a9f29fb..b5c3551b 100644 --- a/crypto/goolm/message/message_test.go +++ b/crypto/goolm/message/message_test.go @@ -1,9 +1,10 @@ package message_test import ( - "bytes" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/message" ) @@ -14,24 +15,12 @@ func TestMessageDecode(t *testing.T) { msg := message.Message{} err := msg.Decode(messageRaw) - if err != nil { - t.Fatal(err) - } - if msg.Version != 3 { - t.Fatalf("Expected Version to be 3 but go %d", msg.Version) - } - if !msg.HasCounter { - t.Fatal("Expected to have counter") - } - if msg.Counter != 1 { - t.Fatalf("Expected counter to be 1 but got %d", msg.Counter) - } - if !bytes.Equal(msg.Ciphertext, expectedCipherText) { - t.Fatalf("expected '%s' but got '%s'", expectedCipherText, msg.Ciphertext) - } - if !bytes.Equal(msg.RatchetKey, expectedRatchetKey) { - t.Fatalf("expected '%s' but got '%s'", expectedRatchetKey, msg.RatchetKey) - } + assert.NoError(t, err) + assert.EqualValues(t, 3, msg.Version) + assert.True(t, msg.HasCounter) + assert.EqualValues(t, 1, msg.Counter) + assert.Equal(t, expectedCipherText, msg.Ciphertext) + assert.EqualValues(t, expectedRatchetKey, msg.RatchetKey) } func TestMessageEncode(t *testing.T) { @@ -44,11 +33,7 @@ func TestMessageEncode(t *testing.T) { Ciphertext: []byte("ciphertext"), } encoded, err := msg.EncodeAndMAC(nil, nil) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) encoded = append(encoded, hmacsha256...) - if !bytes.Equal(encoded, expectedRaw) { - t.Fatalf("expected '%s' but got '%s'", expectedRaw, encoded) - } + assert.Equal(t, expectedRaw, encoded) } diff --git a/crypto/goolm/message/prekey_message_test.go b/crypto/goolm/message/prekey_message_test.go index 431d27d5..fe196e31 100644 --- a/crypto/goolm/message/prekey_message_test.go +++ b/crypto/goolm/message/prekey_message_test.go @@ -1,9 +1,10 @@ package message_test import ( - "bytes" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/message" ) @@ -19,29 +20,14 @@ func TestPreKeyMessageDecode(t *testing.T) { msg := message.PreKeyMessage{} err := msg.Decode(messageRaw) - if err != nil { - t.Fatal(err) - } - if msg.Version != 3 { - t.Fatalf("Expected Version to be 3 but go %d", msg.Version) - } - if !bytes.Equal(msg.OneTimeKey, expectedOneTimeKey) { - t.Fatalf("expected '%s' but got '%s'", expectedOneTimeKey, msg.OneTimeKey) - } - if !bytes.Equal(msg.IdentityKey, expectedIdKey) { - t.Fatalf("expected '%s' but got '%s'", expectedIdKey, msg.IdentityKey) - } - if !bytes.Equal(msg.BaseKey, expectedbaseKey) { - t.Fatalf("expected '%s' but got '%s'", expectedbaseKey, msg.BaseKey) - } - if !bytes.Equal(msg.Message, expectedmessage) { - t.Fatalf("expected '%s' but got '%s'", expectedmessage, msg.Message) - } + assert.NoError(t, err) + assert.EqualValues(t, 3, msg.Version) + assert.EqualValues(t, expectedOneTimeKey, msg.OneTimeKey) + assert.EqualValues(t, expectedIdKey, msg.IdentityKey) + assert.EqualValues(t, expectedbaseKey, msg.BaseKey) + assert.Equal(t, expectedmessage, msg.Message) theirIDKey := crypto.Curve25519PublicKey(expectedIdKey) - checked := msg.CheckFields(&theirIDKey) - if !checked { - t.Fatal("field check failed") - } + assert.True(t, msg.CheckFields(&theirIDKey), "field check failed") } func TestPreKeyMessageEncode(t *testing.T) { @@ -54,10 +40,6 @@ func TestPreKeyMessageEncode(t *testing.T) { Message: []byte("message"), } encoded, err := msg.Encode() - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(encoded, expectedRaw) { - t.Fatalf("got other than expected:\nExpected:\n%v\nGot:\n%v", expectedRaw, encoded) - } + assert.NoError(t, err) + assert.Equal(t, expectedRaw, encoded) } diff --git a/crypto/goolm/pk/pk_test.go b/crypto/goolm/pk/pk_test.go index f2d9b108..4b247430 100644 --- a/crypto/goolm/pk/pk_test.go +++ b/crypto/goolm/pk/pk_test.go @@ -1,12 +1,12 @@ package pk_test import ( - "bytes" "encoding/base64" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/pk" ) @@ -26,34 +26,20 @@ func TestEncryptionDecryption(t *testing.T) { } bobPublic := []byte("3p7bfXt9wbTTW2HC7OQ1Nz+DQ8hbeGdNrfx+FG+IK08") decryption, err := pk.NewDecryptionFromPrivate(alicePrivate) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal([]byte(decryption.PublicKey()), alicePublic) { - t.Fatal("public key not correct") - } - if !bytes.Equal(decryption.PrivateKey(), alicePrivate) { - t.Fatal("private key not correct") - } + assert.NoError(t, err) + assert.EqualValues(t, alicePublic, decryption.PublicKey(), "public key not correct") + assert.EqualValues(t, alicePrivate, decryption.PrivateKey(), "private key not correct") encryption, err := pk.NewEncryption(decryption.PublicKey()) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) plaintext := []byte("This is a test") ciphertext, mac, err := encryption.Encrypt(plaintext, bobPrivate) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) decrypted, err := decryption.Decrypt(bobPublic, mac, ciphertext) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(decrypted, plaintext) { - t.Fatal("message not equal") - } + assert.NoError(t, err) + assert.EqualValues(t, plaintext, decrypted, "message not equal") } func TestSigning(t *testing.T) { @@ -66,29 +52,20 @@ func TestSigning(t *testing.T) { message := []byte("We hold these truths to be self-evident, that all men are created equal, that they are endowed by their Creator with certain unalienable Rights, that among these are Life, Liberty and the pursuit of Happiness.") signing, _ := pk.NewSigningFromSeed(seed) signature, err := signing.Sign(message) - if err != nil { - t.Fatal(err) - } - signatureDecoded, err := goolmbase64.Decode(signature) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + signatureDecoded, err := base64.RawStdEncoding.DecodeString(string(signature)) + assert.NoError(t, err) pubKeyEncoded := signing.PublicKey() pubKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(pubKeyEncoded)) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) pubKey := crypto.Ed25519PublicKey(pubKeyDecoded) verified := pubKey.Verify(message, signatureDecoded) - if !verified { - t.Fatal("signature did not verify") - } + assert.True(t, verified, "signature did not verify") + copy(signatureDecoded[0:], []byte("m")) verified = pubKey.Verify(message, signatureDecoded) - if verified { - t.Fatal("signature did verify") - } + assert.False(t, verified, "signature verified with wrong message") } func TestDecryptionPickling(t *testing.T) { @@ -100,37 +77,19 @@ func TestDecryptionPickling(t *testing.T) { } alicePublic := []byte("hSDwCYkwp1R0i33ctD73Wg2/Og0mOBr066SpjqqbTmo") decryption, err := pk.NewDecryptionFromPrivate(alicePrivate) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal([]byte(decryption.PublicKey()), alicePublic) { - t.Fatal("public key not correct") - } - if !bytes.Equal(decryption.PrivateKey(), alicePrivate) { - t.Fatal("private key not correct") - } + assert.NoError(t, err) + assert.EqualValues(t, alicePublic, decryption.PublicKey(), "public key not correct") + assert.EqualValues(t, alicePrivate, decryption.PrivateKey(), "private key not correct") pickleKey := []byte("secret_key") expectedPickle := []byte("qx37WTQrjZLz5tId/uBX9B3/okqAbV1ofl9UnHKno1eipByCpXleAAlAZoJgYnCDOQZDQWzo3luTSfkF9pU1mOILCbbouubs6TVeDyPfgGD9i86J8irHjA") pickled, err := decryption.Pickle(pickleKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(expectedPickle, pickled) { - t.Fatalf("pickle not as expected:\n%v\n%v\n", pickled, expectedPickle) - } + assert.NoError(t, err) + assert.EqualValues(t, expectedPickle, pickled, "pickle not as expected") newDecription, err := pk.NewDecryption() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) err = newDecription.Unpickle(pickled, pickleKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal([]byte(newDecription.PublicKey()), alicePublic) { - t.Fatal("public key not correct") - } - if !bytes.Equal(newDecription.PrivateKey(), alicePrivate) { - t.Fatal("private key not correct") - } + assert.NoError(t, err) + assert.EqualValues(t, alicePublic, newDecription.PublicKey(), "public key not correct") + assert.EqualValues(t, alicePrivate, newDecription.PrivateKey(), "private key not correct") } diff --git a/crypto/goolm/ratchet/olm_test.go b/crypto/goolm/ratchet/olm_test.go index 91dd6e9b..6a8fefc3 100644 --- a/crypto/goolm/ratchet/olm_test.go +++ b/crypto/goolm/ratchet/olm_test.go @@ -1,10 +1,11 @@ package ratchet_test import ( - "bytes" "encoding/json" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/ratchet" @@ -38,149 +39,90 @@ func initializeRatchets() (*ratchet.Ratchet, *ratchet.Ratchet, error) { func TestSendReceive(t *testing.T) { aliceRatchet, bobRatchet, err := initializeRatchets() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) plainText := []byte("Hello Bob") //Alice sends Bob a message encryptedMessage, err := aliceRatchet.Encrypt(plainText) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) decrypted, err := bobRatchet.Decrypt(encryptedMessage) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plainText, decrypted) { - t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) - } + assert.NoError(t, err) + assert.Equal(t, plainText, decrypted) //Bob sends Alice a message plainText = []byte("Hello Alice") encryptedMessage, err = bobRatchet.Encrypt(plainText) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) decrypted, err = aliceRatchet.Decrypt(encryptedMessage) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plainText, decrypted) { - t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) - } + assert.NoError(t, err) + assert.Equal(t, plainText, decrypted) } func TestOutOfOrder(t *testing.T) { aliceRatchet, bobRatchet, err := initializeRatchets() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) plainText1 := []byte("First Message") plainText2 := []byte("Second Messsage. A bit longer than the first.") /* Alice sends Bob two messages and they arrive out of order */ message1Encrypted, err := aliceRatchet.Encrypt(plainText1) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) message2Encrypted, err := aliceRatchet.Encrypt(plainText2) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) decrypted2, err := bobRatchet.Decrypt(message2Encrypted) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) decrypted1, err := bobRatchet.Decrypt(message1Encrypted) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plainText1, decrypted1) { - t.Fatalf("expected '%v' from decryption but got '%v'", plainText1, decrypted1) - } - if !bytes.Equal(plainText2, decrypted2) { - t.Fatalf("expected '%v' from decryption but got '%v'", plainText2, decrypted2) - } + assert.NoError(t, err) + assert.Equal(t, plainText1, decrypted1) + assert.Equal(t, plainText2, decrypted2) } func TestMoreMessages(t *testing.T) { aliceRatchet, bobRatchet, err := initializeRatchets() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) plainText := []byte("These 15 bytes") for i := 0; i < 8; i++ { messageEncrypted, err := aliceRatchet.Encrypt(plainText) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + decrypted, err := bobRatchet.Decrypt(messageEncrypted) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plainText, decrypted) { - t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) - } + assert.NoError(t, err) + assert.Equal(t, plainText, decrypted) } for i := 0; i < 8; i++ { messageEncrypted, err := bobRatchet.Encrypt(plainText) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + decrypted, err := aliceRatchet.Decrypt(messageEncrypted) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plainText, decrypted) { - t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) - } + assert.NoError(t, err) + assert.Equal(t, plainText, decrypted) } messageEncrypted, err := aliceRatchet.Encrypt(plainText) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) decrypted, err := bobRatchet.Decrypt(messageEncrypted) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plainText, decrypted) { - t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) - } + assert.NoError(t, err) + assert.Equal(t, plainText, decrypted) } func TestJSONEncoding(t *testing.T) { aliceRatchet, bobRatchet, err := initializeRatchets() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) marshaled, err := json.Marshal(aliceRatchet) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) newRatcher := ratchet.Ratchet{} err = json.Unmarshal(marshaled, &newRatcher) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) plainText := []byte("These 15 bytes") messageEncrypted, err := newRatcher.Encrypt(plainText) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) decrypted, err := bobRatchet.Decrypt(messageEncrypted) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plainText, decrypted) { - t.Fatalf("expected '%v' from decryption but got '%v'", plainText, decrypted) - } - + assert.NoError(t, err) + assert.Equal(t, plainText, decrypted) } diff --git a/crypto/goolm/session/megolm_session_test.go b/crypto/goolm/session/megolm_session_test.go index 7c3f455f..72d8857b 100644 --- a/crypto/goolm/session/megolm_session_test.go +++ b/crypto/goolm/session/megolm_session_test.go @@ -1,11 +1,12 @@ package session_test import ( - "bytes" "crypto/rand" - "errors" + "encoding/base64" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/megolm" "maunium.net/go/mautrix/crypto/goolm/session" @@ -15,78 +16,42 @@ import ( func TestOutboundPickleJSON(t *testing.T) { pickleKey := []byte("secretKey") sess, err := session.NewMegolmOutboundSession() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) kp, err := crypto.Ed25519GenerateKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) sess.SigningKey = kp pickled, err := sess.PickleAsJSON(pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) newSession := session.MegolmOutboundSession{} err = newSession.UnpickleAsJSON(pickled, pickleKey) - if err != nil { - t.Fatal(err) - } - if sess.ID() != newSession.ID() { - t.Fatal("session ids not equal") - } - if !bytes.Equal(sess.SigningKey.PrivateKey, newSession.SigningKey.PrivateKey) { - t.Fatal("private keys not equal") - } - if !bytes.Equal(sess.Ratchet.Data[:], newSession.Ratchet.Data[:]) { - t.Fatal("ratchet data not equal") - } - if sess.Ratchet.Counter != newSession.Ratchet.Counter { - t.Fatal("ratchet counter not equal") - } + assert.NoError(t, err) + assert.Equal(t, sess.ID(), newSession.ID()) + assert.Equal(t, sess.SigningKey, newSession.SigningKey) + assert.Equal(t, sess.Ratchet, newSession.Ratchet) } func TestInboundPickleJSON(t *testing.T) { pickleKey := []byte("secretKey") sess := session.MegolmInboundSession{} kp, err := crypto.Ed25519GenerateKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) sess.SigningKey = kp.PublicKey var randomData [megolm.RatchetParts * megolm.RatchetPartLength]byte _, err = rand.Read(randomData[:]) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) ratchet, err := megolm.New(0, randomData) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) sess.Ratchet = *ratchet pickled, err := sess.PickleAsJSON(pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) newSession := session.MegolmInboundSession{} err = newSession.UnpickleAsJSON(pickled, pickleKey) - if err != nil { - t.Fatal(err) - } - if sess.ID() != newSession.ID() { - t.Fatal("sess ids not equal") - } - if !bytes.Equal(sess.SigningKey, newSession.SigningKey) { - t.Fatal("private keys not equal") - } - if !bytes.Equal(sess.Ratchet.Data[:], newSession.Ratchet.Data[:]) { - t.Fatal("ratchet data not equal") - } - if sess.Ratchet.Counter != newSession.Ratchet.Counter { - t.Fatal("ratchet counter not equal") - } + assert.NoError(t, err) + assert.Equal(t, sess.ID(), newSession.ID()) + assert.Equal(t, sess.SigningKey, newSession.SigningKey) + assert.Equal(t, sess.Ratchet, newSession.Ratchet) } func TestGroupSendReceive(t *testing.T) { @@ -100,46 +65,27 @@ func TestGroupSendReceive(t *testing.T) { ) outboundSession, err := session.NewMegolmOutboundSession() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) copy(outboundSession.Ratchet.Data[:], randomData) - if outboundSession.Ratchet.Counter != 0 { - t.Fatal("ratchet counter is not correkt") - } + assert.EqualValues(t, 0, outboundSession.Ratchet.Counter) + sessionSharing, err := outboundSession.SessionSharingMessage() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) plainText := []byte("Message") ciphertext, err := outboundSession.Encrypt(plainText) - if err != nil { - t.Fatal(err) - } - if outboundSession.Ratchet.Counter != 1 { - t.Fatal("ratchet counter is not correkt") - } + assert.NoError(t, err) + assert.EqualValues(t, 1, outboundSession.Ratchet.Counter) //build inbound session inboundSession, err := session.NewMegolmInboundSession(sessionSharing) - if err != nil { - t.Fatal(err) - } - if !inboundSession.SigningKeyVerified { - t.Fatal("key not verified") - } - if inboundSession.ID() != outboundSession.ID() { - t.Fatal("session ids not equal") - } + assert.NoError(t, err) + assert.True(t, inboundSession.SigningKeyVerified) + assert.Equal(t, outboundSession.ID(), inboundSession.ID()) //decode message decoded, _, err := inboundSession.Decrypt(ciphertext) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plainText, decoded) { - t.Fatal("messages not equal") - } + assert.NoError(t, err) + assert.Equal(t, plainText, decoded) } func TestGroupSessionExportImport(t *testing.T) { @@ -158,45 +104,26 @@ func TestGroupSessionExportImport(t *testing.T) { //init inbound inboundSession, err := session.NewMegolmInboundSession(sessionKey) - if err != nil { - t.Fatal(err) - } - if !inboundSession.SigningKeyVerified { - t.Fatal("signing key not verified") - } + assert.NoError(t, err) + assert.True(t, inboundSession.SigningKeyVerified) decrypted, _, err := inboundSession.Decrypt(message) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plaintext, decrypted) { - t.Fatal("message is not correct") - } + assert.NoError(t, err) + assert.Equal(t, plaintext, decrypted) //Export the keys exported, err := inboundSession.Export(0) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) secondInboundSession, err := session.NewMegolmInboundSessionFromExport(exported) - if err != nil { - t.Fatal(err) - } - if secondInboundSession.SigningKeyVerified { - t.Fatal("signing key is verified") - } + assert.NoError(t, err) + assert.False(t, secondInboundSession.SigningKeyVerified) + //decrypt with new session decrypted, _, err = secondInboundSession.Decrypt(message) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plaintext, decrypted) { - t.Fatal("message is not correct") - } - if !secondInboundSession.SigningKeyVerified { - t.Fatal("signing key not verified") - } + assert.NoError(t, err) + assert.Equal(t, plaintext, decrypted) + assert.True(t, secondInboundSession.SigningKeyVerified) } func TestBadSignatureGroupMessage(t *testing.T) { @@ -215,70 +142,43 @@ func TestBadSignatureGroupMessage(t *testing.T) { //init inbound inboundSession, err := session.NewMegolmInboundSession(sessionKey) - if err != nil { - t.Fatal(err) - } - if !inboundSession.SigningKeyVerified { - t.Fatal("signing key not verified") - } + assert.NoError(t, err) + assert.True(t, inboundSession.SigningKeyVerified) decrypted, _, err := inboundSession.Decrypt(message) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plaintext, decrypted) { - t.Fatal("message is not correct") - } + assert.NoError(t, err) + assert.Equal(t, plaintext, decrypted) //Now twiddle the signature copy(message[len(message)-1:], []byte("E")) _, _, err = inboundSession.Decrypt(message) - if err == nil { - t.Fatal("Signature was changed but did not cause an error") - } - if !errors.Is(err, olm.ErrBadSignature) { - t.Fatalf("wrong error %s", err.Error()) - } + assert.ErrorIs(t, err, olm.ErrBadSignature) } func TestOutbountPickle(t *testing.T) { pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItUO3TiOp5I+6PnQka6n8eHTyIEh3tCetilD+BKnHvtakE0eHHvG6pjEsMNN/vs7lkB5rV6XkoUKHLTE1dAfFunYEeHEZuKQpbG385dBwaMJXt4JrC0hU5jnv6jWNqAA0Ud9GxRDvkp04") pickleKey := []byte("secret_key") sess, err := session.MegolmOutboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) newPickled, err := sess.Pickle(pickleKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(pickledDataFromLibOlm, newPickled) { - t.Fatal("pickled version does not equal libolm version") - } + assert.NoError(t, err) + assert.Equal(t, pickledDataFromLibOlm, newPickled) + pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...) _, err = session.MegolmOutboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) - if err == nil { - t.Fatal("should have gotten an error") - } + assert.ErrorIs(t, err, olm.ErrBadMAC) } func TestInbountPickle(t *testing.T) { pickledDataFromLibOlm := []byte("1/IPCdtUoQxMba5XT7sjjUW0Hrs7no9duGFnhsEmxzFX2H3qtRc4eaFBRZYXxOBRTGZ6eMgy3IiSrgAQ1gUlSZf5Q4AVKeBkhvN4LZ6hdhQFv91mM+C2C55/4B9/gDjJEbDGiRgLoMqbWPDV+y0F4h0KaR1V1PiTCC7zCi4WdxJQ098nJLgDL4VSsDbnaLcSMO60FOYgRN4KsLaKUGkXiiUBWp4boFMCiuTTOiyH8XlH0e9uWc0vMLyGNUcO8kCbpAnx3v1JTIVan3WGsnGv4K8Qu4M8GAkZewpexrsb2BSNNeLclOV9/cR203Y5KlzXcpiWNXSs8XoB3TLEtHYMnjuakMQfyrcXKIQntg4xPD/+wvfqkcMg9i7pcplQh7X2OK5ylrMZQrZkJ1fAYBGbBz1tykWOjfrZ") pickleKey := []byte("secret_key") sess, err := session.MegolmInboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) newPickled, err := sess.Pickle(pickleKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(pickledDataFromLibOlm, newPickled) { - t.Fatal("pickled version does not equal libolm version") - } + assert.NoError(t, err) + assert.Equal(t, pickledDataFromLibOlm, newPickled) + pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...) _, err = session.MegolmInboundSessionFromPickled(pickledDataFromLibOlm, pickleKey) - if err == nil { - t.Fatal("should have gotten an error") - } + assert.ErrorIs(t, err, base64.CorruptInputError(416)) } diff --git a/crypto/goolm/session/olm_session_test.go b/crypto/goolm/session/olm_session_test.go index 773da8c7..f87c2e7e 100644 --- a/crypto/goolm/session/olm_session_test.go +++ b/crypto/goolm/session/olm_session_test.go @@ -1,11 +1,11 @@ package session_test import ( - "bytes" "encoding/base64" - "errors" "testing" + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/session" "maunium.net/go/mautrix/crypto/olm" @@ -15,30 +15,18 @@ import ( func TestOlmSession(t *testing.T) { pickleKey := []byte("secretKey") aliceKeyPair, err := crypto.Curve25519GenerateKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) bobKeyPair, err := crypto.Curve25519GenerateKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) bobOneTimeKey, err := crypto.Curve25519GenerateKey() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) aliceSession, err := session.NewOutboundOlmSession(aliceKeyPair, bobKeyPair.PublicKey, bobOneTimeKey.PublicKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) //create a message so that there are more keys to marshal plaintext := []byte("Test message from Alice to Bob") msgType, message, err := aliceSession.Encrypt(plaintext) - if err != nil { - t.Fatal(err) - } - if msgType != id.OlmMsgTypePreKey { - t.Fatal("Wrong message type") - } + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypePreKey, msgType) searchFunc := func(target crypto.Curve25519PublicKey) *crypto.OneTimeKey { if target.Equal(bobOneTimeKey.PublicKey) { @@ -52,92 +40,58 @@ func TestOlmSession(t *testing.T) { } //bob receives message bobSession, err := session.NewInboundOlmSession(nil, message, searchFunc, bobKeyPair) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) decryptedMsg, err := bobSession.Decrypt(string(message), msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plaintext, decryptedMsg) { - t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg) - } + assert.NoError(t, err) + assert.Equal(t, plaintext, decryptedMsg) // Alice pickles session pickled, err := aliceSession.PickleAsJSON(pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) //bob sends a message plaintext = []byte("A message from Bob to Alice") msgType, message, err = bobSession.Encrypt(plaintext) - if err != nil { - t.Fatal(err) - } - if msgType != id.OlmMsgTypeMsg { - t.Fatal("Wrong message type") - } + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypeMsg, msgType) //Alice unpickles session newAliceSession, err := session.OlmSessionFromJSONPickled(pickled, pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) //Alice receives message decryptedMsg, err = newAliceSession.Decrypt(string(message), msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plaintext, decryptedMsg) { - t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg) - } + assert.NoError(t, err) + assert.Equal(t, plaintext, decryptedMsg) //Alice receives message again _, err = newAliceSession.Decrypt(string(message), msgType) - if err == nil { - t.Fatal("should have gotten an error") - } + assert.ErrorIs(t, err, olm.ErrMessageKeyNotFound) //Alice sends another message plaintext = []byte("A second message to Bob") msgType, message, err = newAliceSession.Encrypt(plaintext) - if err != nil { - t.Fatal(err) - } - if msgType != id.OlmMsgTypeMsg { - t.Fatal("Wrong message type") - } + assert.NoError(t, err) + assert.Equal(t, id.OlmMsgTypeMsg, msgType) + //bob receives message decryptedMsg, err = bobSession.Decrypt(string(message), msgType) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(plaintext, decryptedMsg) { - t.Fatalf("messages are not equal:\n%v\n%v\n", plaintext, decryptedMsg) - } + assert.NoError(t, err) + assert.Equal(t, plaintext, decryptedMsg) } func TestSessionPickle(t *testing.T) { pickledDataFromLibOlm := []byte("icDKYm0b4aO23WgUuOxdpPoxC0UlEOYPVeuduNH3IkpFsmnWx5KuEOpxGiZw5IuB/sSn2RZUCTiJ90IvgC7AClkYGHep9O8lpiqQX73XVKD9okZDCAkBc83eEq0DKYC7HBkGRAU/4T6QPIBBY3UK4QZwULLE/fLsi3j4YZBehMtnlsqgHK0q1bvX4cRznZItVKR4ro0O9EAk6LLxJtSnRu5elSUk7YXT") pickleKey := []byte("secret_key") sess, err := session.OlmSessionFromPickled(pickledDataFromLibOlm, pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) newPickled, err := sess.Pickle(pickleKey) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(pickledDataFromLibOlm, newPickled) { - t.Fatal("pickled version does not equal libolm version") - } + assert.NoError(t, err) + assert.Equal(t, pickledDataFromLibOlm, newPickled) + pickledDataFromLibOlm = append(pickledDataFromLibOlm, []byte("a")...) _, err = session.OlmSessionFromPickled(pickledDataFromLibOlm, pickleKey) - if err == nil { - t.Fatal("should have gotten an error") - } + assert.ErrorIs(t, err, base64.CorruptInputError(224)) } func TestDecrypts(t *testing.T) { @@ -161,17 +115,9 @@ func TestDecrypts(t *testing.T) { "dGvPXeH8qLeNZA") pickleKey := []byte("") sess, err := session.OlmSessionFromPickled(sessionPickled, pickleKey) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) for curIndex, curMessage := range messages { _, err := sess.Decrypt(string(curMessage), id.OlmMsgTypePreKey) - if err != nil { - if !errors.Is(err, expectedErr[curIndex]) { - t.Fatal(err) - } - } else { - t.Fatal("error expected") - } + assert.ErrorIs(t, err, expectedErr[curIndex]) } } From b1af9f494144aecf6535bc135250212998452b4e Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 4 Sep 2024 01:59:42 -0600 Subject: [PATCH 0868/1647] goolm/cipher: inline keys and KDF info in test Before this commit, it didn't properly check that changing only the key changes the output. Signed-off-by: Sumner Evans --- crypto/goolm/cipher/aes_sha256_test.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/crypto/goolm/cipher/aes_sha256_test.go b/crypto/goolm/cipher/aes_sha256_test.go index 69aae100..2f58605f 100644 --- a/crypto/goolm/cipher/aes_sha256_test.go +++ b/crypto/goolm/cipher/aes_sha256_test.go @@ -8,11 +8,9 @@ import ( ) func TestDeriveAESKeys(t *testing.T) { - kdfInfo := []byte("test") - key := []byte("test key") - derivedKeys, err := deriveAESKeys(kdfInfo, key) + derivedKeys, err := deriveAESKeys([]byte("test"), []byte("test key")) assert.NoError(t, err) - derivedKeys2, err := deriveAESKeys(kdfInfo, key) + derivedKeys2, err := deriveAESKeys([]byte("test"), []byte("test key")) assert.NoError(t, err) //derivedKeys and derivedKeys2 should be identical @@ -21,8 +19,7 @@ func TestDeriveAESKeys(t *testing.T) { assert.Equal(t, derivedKeys.hmacKey, derivedKeys2.hmacKey) //changing kdfInfo - kdfInfo = []byte("other kdf") - derivedKeys2, err = deriveAESKeys(kdfInfo, key) + derivedKeys2, err = deriveAESKeys([]byte("other kdf"), []byte("test key")) assert.NoError(t, err) //derivedKeys and derivedKeys2 should now be different @@ -31,8 +28,7 @@ func TestDeriveAESKeys(t *testing.T) { assert.NotEqual(t, derivedKeys.hmacKey, derivedKeys2.hmacKey) //changing key - key = []byte("other test key") - derivedKeys, err = deriveAESKeys(kdfInfo, key) + derivedKeys, err = deriveAESKeys([]byte("test"), []byte("other test key")) assert.NoError(t, err) //derivedKeys and derivedKeys2 should now be different From 95e562b2fee41d8e9d5fd9d4496f729955a8c59b Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 4 Sep 2024 02:01:40 -0600 Subject: [PATCH 0869/1647] goolm/cipher: make deriveAESKeys call io.ReadFull less Signed-off-by: Sumner Evans --- crypto/goolm/cipher/aes_sha256.go | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/crypto/goolm/cipher/aes_sha256.go b/crypto/goolm/cipher/aes_sha256.go index 2d2d58d5..6bd8997d 100644 --- a/crypto/goolm/cipher/aes_sha256.go +++ b/crypto/goolm/cipher/aes_sha256.go @@ -17,23 +17,15 @@ type derivedAESKeys struct { } // deriveAESKeys derives three keys for the AESSHA256 cipher -func deriveAESKeys(kdfInfo []byte, key []byte) (*derivedAESKeys, error) { +func deriveAESKeys(kdfInfo []byte, key []byte) (derivedAESKeys, error) { hkdf := crypto.HKDFSHA256(key, nil, kdfInfo) - keys := &derivedAESKeys{ - key: make([]byte, 32), - hmacKey: make([]byte, 32), - iv: make([]byte, 16), - } - if _, err := io.ReadFull(hkdf, keys.key); err != nil { - return nil, err - } - if _, err := io.ReadFull(hkdf, keys.hmacKey); err != nil { - return nil, err - } - if _, err := io.ReadFull(hkdf, keys.iv); err != nil { - return nil, err - } - return keys, nil + keymatter := make([]byte, 80) + _, err := io.ReadFull(hkdf, keymatter) + return derivedAESKeys{ + key: keymatter[:32], + hmacKey: keymatter[32:64], + iv: keymatter[64:], + }, err } // AESSha512BlockSize resturns the blocksize of the cipher AESSHA256. From 93a57f5378a2dfadf1707fc9a61f17014d59b5d1 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 4 Sep 2024 02:02:19 -0600 Subject: [PATCH 0870/1647] goolm/cipher: remove unnecessary function Signed-off-by: Sumner Evans --- crypto/goolm/cipher/aes_sha256.go | 6 ------ crypto/goolm/cipher/pickle.go | 3 ++- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/crypto/goolm/cipher/aes_sha256.go b/crypto/goolm/cipher/aes_sha256.go index 6bd8997d..2e8bcd9c 100644 --- a/crypto/goolm/cipher/aes_sha256.go +++ b/crypto/goolm/cipher/aes_sha256.go @@ -2,7 +2,6 @@ package cipher import ( "bytes" - "crypto/aes" "io" "maunium.net/go/mautrix/crypto/aescbc" @@ -28,11 +27,6 @@ func deriveAESKeys(kdfInfo []byte, key []byte) (derivedAESKeys, error) { }, err } -// AESSha512BlockSize resturns the blocksize of the cipher AESSHA256. -func AESSha512BlockSize() int { - return aes.BlockSize -} - // AESSHA256 is a valid cipher using AES with CBC and HKDFSha256. type AESSHA256 struct { kdfInfo []byte diff --git a/crypto/goolm/cipher/pickle.go b/crypto/goolm/cipher/pickle.go index 551f4356..55d031e8 100644 --- a/crypto/goolm/cipher/pickle.go +++ b/crypto/goolm/cipher/pickle.go @@ -1,6 +1,7 @@ package cipher import ( + "crypto/aes" "fmt" "maunium.net/go/mautrix/crypto/goolm/goolmbase64" @@ -14,7 +15,7 @@ const ( // PickleBlockSize returns the blocksize of the used cipher. func PickleBlockSize() int { - return AESSha512BlockSize() + return aes.BlockSize } // Pickle encrypts the input with the key and the cipher AESSHA256. The result is then encoded in base64. From fbea2a067ce07bcd02e9ceb60112c397e9b7c589 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 4 Sep 2024 02:15:48 -0600 Subject: [PATCH 0871/1647] goolm: simplify return statements Signed-off-by: Sumner Evans --- crypto/goolm/account/account.go | 24 ++----- crypto/goolm/cipher/aes_sha256.go | 12 +--- crypto/goolm/cipher/pickle.go | 9 +-- crypto/goolm/crypto/curve25519.go | 12 +--- crypto/goolm/crypto/ed25519.go | 5 +- crypto/goolm/pk/decryption.go | 12 +--- crypto/goolm/pk/encryption.go | 5 +- crypto/goolm/pk/signing.go | 5 +- crypto/goolm/ratchet/olm.go | 72 +++++++------------ .../goolm/session/megolm_inbound_session.go | 6 +- .../goolm/session/megolm_outbound_session.go | 11 +-- crypto/goolm/session/olm_session.go | 12 +--- 12 files changed, 46 insertions(+), 139 deletions(-) diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index 46ae2571..4708fba1 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -47,11 +47,7 @@ func AccountFromJSONPickled(pickled, key []byte) (*Account, error) { return nil, fmt.Errorf("accountFromPickled: %w", olm.ErrEmptyInput) } a := &Account{} - err := a.UnpickleAsJSON(pickled, key) - if err != nil { - return nil, err - } - return a, nil + return a, a.UnpickleAsJSON(pickled, key) } // AccountFromPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key. @@ -60,11 +56,7 @@ func AccountFromPickled(pickled, key []byte) (*Account, error) { return nil, fmt.Errorf("accountFromPickled: %w", olm.ErrEmptyInput) } a := &Account{} - err := a.Unpickle(pickled, key) - if err != nil { - return nil, err - } - return a, nil + return a, a.Unpickle(pickled, key) } // NewAccount creates a new Account. @@ -185,11 +177,7 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2 if err != nil { return nil, err } - s, err := session.NewOutboundOlmSession(a.IdKeys.Curve25519, theirIdentityKeyDecoded, theirOneTimeKeyDecoded) - if err != nil { - return nil, err - } - return s, nil + return session.NewOutboundOlmSession(a.IdKeys.Curve25519, theirIdentityKeyDecoded, theirOneTimeKeyDecoded) } // NewInboundSession creates a new in-bound session for sending/receiving @@ -450,11 +438,7 @@ func (a *Account) Pickle(key []byte) ([]byte, error) { if written != len(pickeledBytes) { return nil, errors.New("number of written bytes not correct") } - encrypted, err := cipher.Pickle(key, pickeledBytes) - if err != nil { - return nil, err - } - return encrypted, nil + return cipher.Pickle(key, pickeledBytes) } // PickleLibOlm encodes the Account into target. target has to have a size of at least PickleLen() and is written to from index 0. diff --git a/crypto/goolm/cipher/aes_sha256.go b/crypto/goolm/cipher/aes_sha256.go index 2e8bcd9c..065cb501 100644 --- a/crypto/goolm/cipher/aes_sha256.go +++ b/crypto/goolm/cipher/aes_sha256.go @@ -45,11 +45,7 @@ func (c AESSHA256) Encrypt(key, plaintext []byte) (ciphertext []byte, err error) if err != nil { return nil, err } - ciphertext, err = aescbc.Encrypt(keys.key, keys.iv, plaintext) - if err != nil { - return nil, err - } - return ciphertext, nil + return aescbc.Encrypt(keys.key, keys.iv, plaintext) } // Decrypt decrypts the ciphertext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes). @@ -58,11 +54,7 @@ func (c AESSHA256) Decrypt(key, ciphertext []byte) (plaintext []byte, err error) if err != nil { return nil, err } - plaintext, err = aescbc.Decrypt(keys.key, keys.iv, ciphertext) - if err != nil { - return nil, err - } - return plaintext, nil + return aescbc.Decrypt(keys.key, keys.iv, ciphertext) } // MAC returns the MAC for the message using the key. The key is used to derive the actual mac key (32 bytes). diff --git a/crypto/goolm/cipher/pickle.go b/crypto/goolm/cipher/pickle.go index 55d031e8..754c7963 100644 --- a/crypto/goolm/cipher/pickle.go +++ b/crypto/goolm/cipher/pickle.go @@ -30,8 +30,7 @@ func Pickle(key, input []byte) ([]byte, error) { return nil, err } ciphertext = append(ciphertext, mac[:pickleMACLength]...) - encoded := goolmbase64.Encode(ciphertext) - return encoded, nil + return goolmbase64.Encode(ciphertext), nil } // Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256. @@ -52,9 +51,5 @@ func Unpickle(key, input []byte) ([]byte, error) { //Set to next block size targetCipherText := make([]byte, int(len(ciphertext)/PickleBlockSize())*PickleBlockSize()) copy(targetCipherText, ciphertext) - plaintext, err := pickleCipher.Decrypt(key, targetCipherText) - if err != nil { - return nil, err - } - return plaintext, nil + return pickleCipher.Decrypt(key, targetCipherText) } diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go index 872ce3a1..2ae20e0e 100644 --- a/crypto/goolm/crypto/curve25519.go +++ b/crypto/goolm/crypto/curve25519.go @@ -26,15 +26,11 @@ func Curve25519GenerateKey() (Curve25519KeyPair, error) { } privateKey := Curve25519PrivateKey(privateKeyByte) - publicKey, err := privateKey.PubKey() - if err != nil { - return Curve25519KeyPair{}, err - } return Curve25519KeyPair{ PrivateKey: Curve25519PrivateKey(privateKey), PublicKey: Curve25519PublicKey(publicKey), - }, nil + }, err } // Curve25519GenerateFromPrivate creates a new curve25519 key pair with the private key given. @@ -121,11 +117,7 @@ func (c Curve25519PrivateKey) Equal(x Curve25519PrivateKey) bool { // PubKey returns the public key derived from the private key. func (c Curve25519PrivateKey) PubKey() (Curve25519PublicKey, error) { - publicKey, err := curve25519.X25519(c, curve25519.Basepoint) - if err != nil { - return nil, err - } - return publicKey, nil + return curve25519.X25519(c, curve25519.Basepoint) } // SharedSecret returns the shared secret between the private key and the given public key. diff --git a/crypto/goolm/crypto/ed25519.go b/crypto/goolm/crypto/ed25519.go index bc260377..ceb3818b 100644 --- a/crypto/goolm/crypto/ed25519.go +++ b/crypto/goolm/crypto/ed25519.go @@ -17,13 +17,10 @@ const ( // Ed25519GenerateKey creates a new ed25519 key pair. func Ed25519GenerateKey() (Ed25519KeyPair, error) { publicKey, privateKey, err := ed25519.GenerateKey(nil) - if err != nil { - return Ed25519KeyPair{}, err - } return Ed25519KeyPair{ PrivateKey: Ed25519PrivateKey(privateKey), PublicKey: Ed25519PublicKey(publicKey), - }, nil + }, err } // Ed25519GenerateFromPrivate creates a new ed25519 key pair with the private key given. diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index dcec5107..76dfa0da 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -78,11 +78,7 @@ func (s Decryption) Decrypt(ephemeralKey, mac, ciphertext []byte) ([]byte, error if !verified { return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMAC) } - plaintext, err := cipher.Decrypt(sharedSecret, ciphertext) - if err != nil { - return nil, err - } - return plaintext, nil + return cipher.Decrypt(sharedSecret, ciphertext) } // PickleAsJSON returns an Decryption as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. @@ -136,11 +132,7 @@ func (a Decryption) Pickle(key []byte) ([]byte, error) { if written != len(pickeledBytes) { return nil, errors.New("number of written bytes not correct") } - encrypted, err := cipher.Pickle(key, pickeledBytes) - if err != nil { - return nil, err - } - return encrypted, nil + return cipher.Pickle(key, pickeledBytes) } // PickleLibOlm encodes the Decryption into target. target has to have a size of at least PickleLen() and is written to from index 0. diff --git a/crypto/goolm/pk/encryption.go b/crypto/goolm/pk/encryption.go index 54f15830..c99a9517 100644 --- a/crypto/goolm/pk/encryption.go +++ b/crypto/goolm/pk/encryption.go @@ -42,8 +42,5 @@ func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519Privat return nil, nil, err } mac, err = cipher.MAC(sharedSecret, ciphertext) - if err != nil { - return nil, nil, err - } - return ciphertext, goolmbase64.Encode(mac), nil + return ciphertext, goolmbase64.Encode(mac), err } diff --git a/crypto/goolm/pk/signing.go b/crypto/goolm/pk/signing.go index b22c76dc..9dfd24a1 100644 --- a/crypto/goolm/pk/signing.go +++ b/crypto/goolm/pk/signing.go @@ -62,8 +62,5 @@ func (s Signing) SignJSON(obj any) (string, error) { objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") signature, err := s.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON)) - if err != nil { - return "", err - } - return string(signature), nil + return string(signature), err } diff --git a/crypto/goolm/ratchet/olm.go b/crypto/goolm/ratchet/olm.go index 4653aae7..879f6cfe 100644 --- a/crypto/goolm/ratchet/olm.go +++ b/crypto/goolm/ratchet/olm.go @@ -124,12 +124,7 @@ func (r *Ratchet) Encrypt(plaintext []byte) ([]byte, error) { message.RatchetKey = r.SenderChains.ratchetKey().PublicKey message.Ciphertext = encryptedText //creating the mac is done in encode - output, err := message.EncodeAndMAC(messageKey.Key, RatchetCipher) - if err != nil { - return nil, err - } - - return output, nil + return message.EncodeAndMAC(messageKey.Key, RatchetCipher) } // Decrypt decrypts the ciphertext and verifies the MAC. @@ -153,53 +148,42 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { break } } - var result []byte if receiverChainFromMessage == nil { //Advancing the chain is done in this method - result, err = r.decryptForNewChain(message, input) - if err != nil { - return nil, err - } + return r.decryptForNewChain(message, input) } else if receiverChainFromMessage.chainKey().Index > message.Counter { // No need to advance the chain // Chain already advanced beyond the key for this message // Check if the message keys are in the skipped key list. - foundSkippedKey := false for curSkippedIndex := range r.SkippedMessageKeys { - if message.Counter == r.SkippedMessageKeys[curSkippedIndex].MKey.Index { - // Found the key for this message. Check the MAC. - verified, err := message.VerifyMACInline(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, RatchetCipher, input) - if err != nil { - return nil, err - } - if !verified { - return nil, fmt.Errorf("decrypt from skipped message keys: %w", olm.ErrBadMAC) - } - result, err = RatchetCipher.Decrypt(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, message.Ciphertext) - if err != nil { - return nil, fmt.Errorf("cipher decrypt: %w", err) - } - if len(result) != 0 { - // Remove the key from the skipped keys now that we've - // decoded the message it corresponds to. - r.SkippedMessageKeys[curSkippedIndex] = r.SkippedMessageKeys[len(r.SkippedMessageKeys)-1] - r.SkippedMessageKeys = r.SkippedMessageKeys[:len(r.SkippedMessageKeys)-1] - } - foundSkippedKey = true + if message.Counter != r.SkippedMessageKeys[curSkippedIndex].MKey.Index { + continue + } + + // Found the key for this message. Check the MAC. + verified, err := message.VerifyMACInline(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, RatchetCipher, input) + if err != nil { + return nil, err + } + if !verified { + return nil, fmt.Errorf("decrypt from skipped message keys: %w", olm.ErrBadMAC) + } + result, err := RatchetCipher.Decrypt(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, message.Ciphertext) + if err != nil { + return nil, fmt.Errorf("cipher decrypt: %w", err) + } else if len(result) != 0 { + // Remove the key from the skipped keys now that we've + // decoded the message it corresponds to. + r.SkippedMessageKeys[curSkippedIndex] = r.SkippedMessageKeys[len(r.SkippedMessageKeys)-1] + r.SkippedMessageKeys = r.SkippedMessageKeys[:len(r.SkippedMessageKeys)-1] + return result, nil } } - if !foundSkippedKey { - return nil, fmt.Errorf("decrypt: %w", olm.ErrMessageKeyNotFound) - } + return nil, fmt.Errorf("decrypt: %w", olm.ErrMessageKeyNotFound) } else { //Advancing the chain is done in this method - result, err = r.decryptForExistingChain(receiverChainFromMessage, message, input) - if err != nil { - return nil, err - } + return r.decryptForExistingChain(receiverChainFromMessage, message, input) } - - return result, nil } // advanceRootKey created the next root key and returns the next chainKey @@ -281,11 +265,7 @@ func (r *Ratchet) decryptForNewChain(message *message.Message, rawMessage []byte */ r.SenderChains = senderChain{} - decrypted, err := r.decryptForExistingChain(&r.ReceiverChains[0], message, rawMessage) - if err != nil { - return nil, err - } - return decrypted, nil + return r.decryptForExistingChain(&r.ReceiverChains[0], message, rawMessage) } // PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go index f48698e7..bfb24322 100644 --- a/crypto/goolm/session/megolm_inbound_session.go +++ b/crypto/goolm/session/megolm_inbound_session.go @@ -254,11 +254,7 @@ func (o *MegolmInboundSession) Pickle(key []byte) ([]byte, error) { if written != len(pickeledBytes) { return nil, errors.New("number of written bytes not correct") } - encrypted, err := cipher.Pickle(key, pickeledBytes) - if err != nil { - return nil, err - } - return encrypted, nil + return cipher.Pickle(key, pickeledBytes) } // PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index b3234967..e85a5c12 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -68,10 +68,7 @@ func (o *MegolmOutboundSession) Encrypt(plaintext []byte) ([]byte, error) { return nil, olm.ErrEmptyInput } encrypted, err := o.Ratchet.Encrypt(plaintext, &o.SigningKey) - if err != nil { - return nil, err - } - return goolmbase64.Encode(encrypted), nil + return goolmbase64.Encode(encrypted), err } // SessionID returns the base64 endoded public signing key @@ -141,11 +138,7 @@ func (o *MegolmOutboundSession) Pickle(key []byte) ([]byte, error) { if written != len(pickeledBytes) { return nil, errors.New("number of written bytes not correct") } - encrypted, err := cipher.Pickle(key, pickeledBytes) - if err != nil { - return nil, err - } - return encrypted, nil + return cipher.Pickle(key, pickeledBytes) } // PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go index c0067bfa..f58edf87 100644 --- a/crypto/goolm/session/olm_session.go +++ b/crypto/goolm/session/olm_session.go @@ -48,11 +48,7 @@ func OlmSessionFromJSONPickled(pickled, key []byte) (*OlmSession, error) { return nil, fmt.Errorf("sessionFromPickled: %w", olm.ErrEmptyInput) } a := &OlmSession{} - err := a.UnpickleAsJSON(pickled, key) - if err != nil { - return nil, err - } - return a, nil + return a, a.UnpickleAsJSON(pickled, key) } // OlmSessionFromPickled loads the OlmSession details from a pickled base64 string. The input is decrypted with the supplied key. @@ -61,11 +57,7 @@ func OlmSessionFromPickled(pickled, key []byte) (*OlmSession, error) { return nil, fmt.Errorf("sessionFromPickled: %w", olm.ErrEmptyInput) } a := &OlmSession{} - err := a.Unpickle(pickled, key) - if err != nil { - return nil, err - } - return a, nil + return a, a.Unpickle(pickled, key) } // NewOlmSession creates a new Session. From bc1f09086f61995d0227caa68dae29c2e4f6bccd Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 24 Oct 2024 20:34:53 -0600 Subject: [PATCH 0872/1647] goolm: use constants for pickle lengths when possible Signed-off-by: Sumner Evans --- crypto/goolm/account/account.go | 16 +++---- crypto/goolm/crypto/curve25519.go | 38 +++++---------- crypto/goolm/crypto/curve25519_test.go | 6 +-- crypto/goolm/crypto/ed25519.go | 37 +++++--------- crypto/goolm/crypto/ed25519_test.go | 6 +-- crypto/goolm/crypto/one_time_key.go | 15 ++---- crypto/goolm/libolmpickle/pickle.go | 15 +++--- crypto/goolm/libolmpickle/pickle_test.go | 6 +-- crypto/goolm/megolm/megolm.go | 12 ++--- crypto/goolm/pk/decryption.go | 19 ++++---- crypto/goolm/ratchet/chain.go | 48 +++++++------------ crypto/goolm/ratchet/olm.go | 24 +++++----- crypto/goolm/ratchet/skipped_message.go | 12 ++--- .../goolm/session/megolm_inbound_session.go | 21 ++++---- .../goolm/session/megolm_outbound_session.go | 16 +++---- crypto/goolm/session/olm_session.go | 20 ++++---- 16 files changed, 119 insertions(+), 192 deletions(-) diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index 4708fba1..dcb4c9d4 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -488,14 +488,14 @@ func (a *Account) PickleLibOlm(target []byte) (int, error) { // PickleLen returns the number of bytes the pickled Account will have. func (a *Account) PickleLen() int { - length := libolmpickle.PickleUInt32Len(accountPickleVersionLibOLM) - length += a.IdKeys.Ed25519.PickleLen() - length += a.IdKeys.Curve25519.PickleLen() - length += libolmpickle.PickleUInt32Len(uint32(len(a.OTKeys))) - length += (len(a.OTKeys) * (&crypto.OneTimeKey{}).PickleLen()) - length += libolmpickle.PickleUInt8Len(a.NumFallbackKeys) - length += (int(a.NumFallbackKeys) * (&crypto.OneTimeKey{}).PickleLen()) - length += libolmpickle.PickleUInt32Len(a.NextOneTimeKeyID) + length := libolmpickle.PickleUInt32Length + length += crypto.Ed25519KeyPairPickleLength // IdKeys.Ed25519 + length += crypto.Curve25519KeyPairPickleLength // IdKeys.Curve25519 + length += libolmpickle.PickleUInt32Length + length += (len(a.OTKeys) * crypto.OneTimeKeyPickleLength) + length += libolmpickle.PickleUInt8Length + length += (int(a.NumFallbackKeys) * crypto.OneTimeKeyPickleLength) + length += libolmpickle.PickleUInt32Length return length } diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go index 2ae20e0e..01ada7e0 100644 --- a/crypto/goolm/crypto/curve25519.go +++ b/crypto/goolm/crypto/curve25519.go @@ -15,7 +15,7 @@ import ( const ( Curve25519KeyLength = curve25519.ScalarSize //The length of the private key. - curve25519PubKeyLength = 32 + Curve25519PubKeyLength = 32 ) // Curve25519GenerateKey creates a new curve25519 key pair. @@ -51,6 +51,9 @@ type Curve25519KeyPair struct { PublicKey Curve25519PublicKey `json:"public,omitempty"` } +const Curve25519KeyPairPickleLength = Curve25519PubKeyLength + // Public Key + Curve25519KeyLength // Private Key + // B64Encoded returns a base64 encoded string of the public key. func (c Curve25519KeyPair) B64Encoded() id.Curve25519 { return c.PublicKey.B64Encoded() @@ -61,10 +64,11 @@ func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, err return c.PrivateKey.SharedSecret(pubKey) } -// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0. +// PickleLibOlm encodes the key pair into target. The target has to have a size +// of at least [Curve25519KeyPairPickleLength] and is written to from index 0. // It returns the number of bytes written. func (c Curve25519KeyPair) PickleLibOlm(target []byte) (int, error) { - if len(target) < c.PickleLen() { + if len(target) < Curve25519KeyPairPickleLength { return 0, fmt.Errorf("pickle curve25519 key pair: %w", olm.ErrValueTooShort) } written, err := c.PublicKey.PickleLibOlm(target) @@ -95,18 +99,6 @@ func (c *Curve25519KeyPair) UnpickleLibOlm(value []byte) (int, error) { return read + readPriv, nil } -// PickleLen returns the number of bytes the pickled key pair will have. -func (c Curve25519KeyPair) PickleLen() int { - lenPublic := c.PublicKey.PickleLen() - var lenPrivate int - if len(c.PrivateKey) != Curve25519KeyLength { - lenPrivate = libolmpickle.PickleBytesLen(make([]byte, Curve25519KeyLength)) - } else { - lenPrivate = libolmpickle.PickleBytesLen(c.PrivateKey) - } - return lenPublic + lenPrivate -} - // Curve25519PrivateKey represents the private key for curve25519 usage type Curve25519PrivateKey []byte @@ -141,29 +133,21 @@ func (c Curve25519PublicKey) B64Encoded() id.Curve25519 { // PickleLibOlm encodes the public key into target. target has to have a size of at least PickleLen() and is written to from index 0. // It returns the number of bytes written. func (c Curve25519PublicKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < c.PickleLen() { + if len(target) < Curve25519PubKeyLength { return 0, fmt.Errorf("pickle curve25519 public key: %w", olm.ErrValueTooShort) } - if len(c) != curve25519PubKeyLength { - return libolmpickle.PickleBytes(make([]byte, curve25519PubKeyLength), target), nil + if len(c) != Curve25519PubKeyLength { + return libolmpickle.PickleBytes(make([]byte, Curve25519PubKeyLength), target), nil } return libolmpickle.PickleBytes(c, target), nil } // UnpickleLibOlm decodes the unencryted value and populates the public key accordingly. It returns the number of bytes read. func (c *Curve25519PublicKey) UnpickleLibOlm(value []byte) (int, error) { - unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, curve25519PubKeyLength) + unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, Curve25519PubKeyLength) if err != nil { return 0, err } *c = unpickled return readBytes, nil } - -// PickleLen returns the number of bytes the pickled public key will have. -func (c Curve25519PublicKey) PickleLen() int { - if len(c) != curve25519PubKeyLength { - return libolmpickle.PickleBytesLen(make([]byte, curve25519PubKeyLength)) - } - return libolmpickle.PickleBytesLen(c) -} diff --git a/crypto/goolm/crypto/curve25519_test.go b/crypto/goolm/crypto/curve25519_test.go index b7c86eee..fc8ee54b 100644 --- a/crypto/goolm/crypto/curve25519_test.go +++ b/crypto/goolm/crypto/curve25519_test.go @@ -74,7 +74,7 @@ func TestCurve25519Pickle(t *testing.T) { //create keypair keyPair, err := crypto.Curve25519GenerateKey() assert.NoError(t, err) - target := make([]byte, keyPair.PickleLen()) + target := make([]byte, crypto.Curve25519KeyPairPickleLength) writtenBytes, err := keyPair.PickleLibOlm(target) assert.NoError(t, err) assert.Len(t, target, writtenBytes) @@ -92,7 +92,7 @@ func TestCurve25519PicklePubKeyOnly(t *testing.T) { assert.NoError(t, err) //Remove privateKey keyPair.PrivateKey = nil - target := make([]byte, keyPair.PickleLen()) + target := make([]byte, crypto.Curve25519KeyPairPickleLength) writtenBytes, err := keyPair.PickleLibOlm(target) assert.NoError(t, err) assert.Len(t, target, writtenBytes) @@ -109,7 +109,7 @@ func TestCurve25519PicklePrivKeyOnly(t *testing.T) { assert.NoError(t, err) //Remove public keyPair.PublicKey = nil - target := make([]byte, keyPair.PickleLen()) + target := make([]byte, crypto.Curve25519KeyPairPickleLength) writtenBytes, err := keyPair.PickleLibOlm(target) assert.NoError(t, err) assert.Len(t, target, writtenBytes) diff --git a/crypto/goolm/crypto/ed25519.go b/crypto/goolm/crypto/ed25519.go index ceb3818b..abacb1ee 100644 --- a/crypto/goolm/crypto/ed25519.go +++ b/crypto/goolm/crypto/ed25519.go @@ -46,6 +46,9 @@ type Ed25519KeyPair struct { PublicKey Ed25519PublicKey `json:"public,omitempty"` } +const Ed25519KeyPairPickleLength = ed25519.PublicKeySize + // PublicKey + ed25519.PrivateKeySize // Private Key + // B64Encoded returns a base64 encoded string of the public key. func (c Ed25519KeyPair) B64Encoded() id.Ed25519 { return id.Ed25519(base64.RawStdEncoding.EncodeToString(c.PublicKey)) @@ -61,10 +64,11 @@ func (c Ed25519KeyPair) Verify(message, givenSignature []byte) bool { return c.PublicKey.Verify(message, givenSignature) } -// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. +// PickleLibOlm encodes the key pair into target. target has to have a size of +// at least [Ed25519KeyPairPickleLength] and is written to from index 0. It +// returns the number of bytes written. func (c Ed25519KeyPair) PickleLibOlm(target []byte) (int, error) { - if len(target) < c.PickleLen() { + if len(target) < Ed25519KeyPairPickleLength { return 0, fmt.Errorf("pickle ed25519 key pair: %w", olm.ErrValueTooShort) } written, err := c.PublicKey.PickleLibOlm(target) @@ -96,18 +100,6 @@ func (c *Ed25519KeyPair) UnpickleLibOlm(value []byte) (int, error) { return read + readPriv, nil } -// PickleLen returns the number of bytes the pickled key pair will have. -func (c Ed25519KeyPair) PickleLen() int { - lenPublic := c.PublicKey.PickleLen() - var lenPrivate int - if len(c.PrivateKey) != ed25519.PrivateKeySize { - lenPrivate = libolmpickle.PickleBytesLen(make([]byte, ed25519.PrivateKeySize)) - } else { - lenPrivate = libolmpickle.PickleBytesLen(c.PrivateKey) - } - return lenPublic + lenPrivate -} - // Curve25519PrivateKey represents the private key for ed25519 usage. This is just a wrapper. type Ed25519PrivateKey ed25519.PrivateKey @@ -149,10 +141,11 @@ func (c Ed25519PublicKey) Verify(message, givenSignature []byte) bool { return ed25519.Verify(ed25519.PublicKey(c), message, givenSignature) } -// PickleLibOlm encodes the public key into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. +// PickleLibOlm encodes the public key into target. target has to have a size +// of at least [ed25519.PublicKeySize] and is written to from index 0. It +// returns the number of bytes written. func (c Ed25519PublicKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < c.PickleLen() { + if len(target) < ed25519.PublicKeySize { return 0, fmt.Errorf("pickle ed25519 public key: %w", olm.ErrValueTooShort) } if len(c) != ed25519.PublicKeySize { @@ -170,11 +163,3 @@ func (c *Ed25519PublicKey) UnpickleLibOlm(value []byte) (int, error) { *c = unpickled return readBytes, nil } - -// PickleLen returns the number of bytes the pickled public key will have. -func (c Ed25519PublicKey) PickleLen() int { - if len(c) != ed25519.PublicKeySize { - return libolmpickle.PickleBytesLen(make([]byte, ed25519.PublicKeySize)) - } - return libolmpickle.PickleBytesLen(c) -} diff --git a/crypto/goolm/crypto/ed25519_test.go b/crypto/goolm/crypto/ed25519_test.go index 41fb0977..e5314622 100644 --- a/crypto/goolm/crypto/ed25519_test.go +++ b/crypto/goolm/crypto/ed25519_test.go @@ -38,7 +38,7 @@ func TestEd25519Pickle(t *testing.T) { //create keypair keyPair, err := crypto.Ed25519GenerateKey() assert.NoError(t, err) - target := make([]byte, keyPair.PickleLen()) + target := make([]byte, crypto.Ed25519KeyPairPickleLength) writtenBytes, err := keyPair.PickleLibOlm(target) assert.NoError(t, err) assert.Len(t, target, writtenBytes) @@ -56,7 +56,7 @@ func TestEd25519PicklePubKeyOnly(t *testing.T) { assert.NoError(t, err) //Remove privateKey keyPair.PrivateKey = nil - target := make([]byte, keyPair.PickleLen()) + target := make([]byte, crypto.Ed25519KeyPairPickleLength) writtenBytes, err := keyPair.PickleLibOlm(target) assert.NoError(t, err) assert.Len(t, target, writtenBytes) @@ -74,7 +74,7 @@ func TestEd25519PicklePrivKeyOnly(t *testing.T) { assert.NoError(t, err) //Remove public keyPair.PublicKey = nil - target := make([]byte, keyPair.PickleLen()) + target := make([]byte, crypto.Ed25519KeyPairPickleLength) writtenBytes, err := keyPair.PickleLibOlm(target) assert.NoError(t, err) assert.Len(t, target, writtenBytes) diff --git a/crypto/goolm/crypto/one_time_key.go b/crypto/goolm/crypto/one_time_key.go index aaa253d2..2fbd6366 100644 --- a/crypto/goolm/crypto/one_time_key.go +++ b/crypto/goolm/crypto/one_time_key.go @@ -17,6 +17,10 @@ type OneTimeKey struct { Key Curve25519KeyPair `json:"key,omitempty"` } +const OneTimeKeyPickleLength = libolmpickle.PickleUInt32Length + // ID + libolmpickle.PickleBoolLength + // Published + Curve25519KeyPairPickleLength // Key + // Equal compares the one time key to the given one. func (otk OneTimeKey) Equal(s OneTimeKey) bool { if otk.ID != s.ID { @@ -37,7 +41,7 @@ func (otk OneTimeKey) Equal(s OneTimeKey) bool { // PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0. // It returns the number of bytes written. func (c OneTimeKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < c.PickleLen() { + if len(target) < OneTimeKeyPickleLength { return 0, fmt.Errorf("pickle one time key: %w", olm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(uint32(c.ID), target) @@ -73,15 +77,6 @@ func (c *OneTimeKey) UnpickleLibOlm(value []byte) (int, error) { return totalReadBytes, nil } -// PickleLen returns the number of bytes the pickled OneTimeKey will have. -func (c OneTimeKey) PickleLen() int { - length := 0 - length += libolmpickle.PickleUInt32Len(c.ID) - length += libolmpickle.PickleBoolLen(c.Published) - length += c.Key.PickleLen() - return length -} - // KeyIDEncoded returns the base64 encoded id. func (c OneTimeKey) KeyIDEncoded() string { resSlice := make([]byte, 4) diff --git a/crypto/goolm/libolmpickle/pickle.go b/crypto/goolm/libolmpickle/pickle.go index ec125a34..bedeee04 100644 --- a/crypto/goolm/libolmpickle/pickle.go +++ b/crypto/goolm/libolmpickle/pickle.go @@ -4,13 +4,16 @@ import ( "encoding/binary" ) +const ( + PickleBoolLength = 1 + PickleUInt8Length = 1 + PickleUInt32Length = 4 +) + func PickleUInt8(value uint8, target []byte) int { target[0] = value return 1 } -func PickleUInt8Len(value uint8) int { - return 1 -} func PickleBool(value bool, target []byte) int { if value { @@ -20,9 +23,6 @@ func PickleBool(value bool, target []byte) int { } return 1 } -func PickleBoolLen(value bool) int { - return 1 -} func PickleBytes(value, target []byte) int { return copy(target, value) @@ -36,6 +36,3 @@ func PickleUInt32(value uint32, target []byte) int { binary.BigEndian.PutUint32(res, value) return copy(target, res) } -func PickleUInt32Len(value uint32) int { - return 4 -} diff --git a/crypto/goolm/libolmpickle/pickle_test.go b/crypto/goolm/libolmpickle/pickle_test.go index 27f083a0..d5596b2a 100644 --- a/crypto/goolm/libolmpickle/pickle_test.go +++ b/crypto/goolm/libolmpickle/pickle_test.go @@ -24,7 +24,7 @@ func TestPickleUInt32(t *testing.T) { for curIndex := range values { response := make([]byte, 4) resPLen := libolmpickle.PickleUInt32(values[curIndex], response) - assert.Equal(t, libolmpickle.PickleUInt32Len(values[curIndex]), resPLen) + assert.Equal(t, libolmpickle.PickleUInt32Length, resPLen) assert.Equal(t, expected[curIndex], response) } } @@ -41,7 +41,7 @@ func TestPickleBool(t *testing.T) { for curIndex := range values { response := make([]byte, 1) resPLen := libolmpickle.PickleBool(values[curIndex], response) - assert.Equal(t, libolmpickle.PickleBoolLen(values[curIndex]), resPLen) + assert.Equal(t, libolmpickle.PickleBoolLength, resPLen) assert.Equal(t, expected[curIndex], response) } } @@ -58,7 +58,7 @@ func TestPickleUInt8(t *testing.T) { for curIndex := range values { response := make([]byte, 1) resPLen := libolmpickle.PickleUInt8(values[curIndex], response) - assert.Equal(t, libolmpickle.PickleUInt8Len(values[curIndex]), resPLen) + assert.Equal(t, libolmpickle.PickleUInt8Length, resPLen) assert.Equal(t, expected[curIndex], response) } } diff --git a/crypto/goolm/megolm/megolm.go b/crypto/goolm/megolm/megolm.go index c88583ee..47e24077 100644 --- a/crypto/goolm/megolm/megolm.go +++ b/crypto/goolm/megolm/megolm.go @@ -22,6 +22,9 @@ const ( protocolVersion = 3 RatchetParts = 4 // number of ratchet parts RatchetPartLength = 256 / 8 // length of each ratchet part in bytes + + RatchetPickleLength = (RatchetParts * RatchetPartLength) + //Data + libolmpickle.PickleUInt32Length // Counter ) var RatchetCipher = cipher.NewAESSHA256([]byte("MEGOLM_KEYS")) @@ -219,17 +222,10 @@ func (r *Ratchet) UnpickleLibOlm(unpickled []byte) (int, error) { // PickleLibOlm encodes the ratchet into target. target has to have a size of at least PickleLen() and is written to from index 0. // It returns the number of bytes written. func (r Ratchet) PickleLibOlm(target []byte) (int, error) { - if len(target) < r.PickleLen() { + if len(target) < RatchetPickleLength { return 0, fmt.Errorf("pickle account: %w", olm.ErrValueTooShort) } written := libolmpickle.PickleBytes(r.Data[:], target) written += libolmpickle.PickleUInt32(r.Counter, target[written:]) return written, nil } - -// PickleLen returns the number of bytes the pickled ratchet will have. -func (r Ratchet) PickleLen() int { - length := libolmpickle.PickleBytesLen(r.Data[:]) - length += libolmpickle.PickleUInt32Len(r.Counter) - return length -} diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index 76dfa0da..fb537eaf 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -17,6 +17,9 @@ import ( const ( decryptionPickleVersionJSON uint8 = 1 decryptionPickleVersionLibOlm uint32 = 1 + + DecryptionPickleLength = libolmpickle.PickleUInt32Length + // Version + crypto.Curve25519KeyPairPickleLength // KeyPair ) // Decryption is used to decrypt pk messages @@ -124,7 +127,7 @@ func (a *Decryption) UnpickleLibOlm(value []byte) (int, error) { // Pickle returns a base64 encoded and with key encrypted pickled Decryption using PickleLibOlm(). func (a Decryption) Pickle(key []byte) ([]byte, error) { - pickeledBytes := make([]byte, a.PickleLen()) + pickeledBytes := make([]byte, DecryptionPickleLength) written, err := a.PickleLibOlm(pickeledBytes) if err != nil { return nil, err @@ -135,10 +138,11 @@ func (a Decryption) Pickle(key []byte) ([]byte, error) { return cipher.Pickle(key, pickeledBytes) } -// PickleLibOlm encodes the Decryption into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. +// PickleLibOlm encodes the Decryption into target. target has to have a size +// of at least [DecryptionPickleLength] and is written to from index 0. It +// returns the number of bytes written. func (a Decryption) PickleLibOlm(target []byte) (int, error) { - if len(target) < a.PickleLen() { + if len(target) < DecryptionPickleLength { return 0, fmt.Errorf("pickle Decryption: %w", olm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(decryptionPickleVersionLibOlm, target) @@ -149,10 +153,3 @@ func (a Decryption) PickleLibOlm(target []byte) (int, error) { written += writtenKey return written, nil } - -// PickleLen returns the number of bytes the pickled Decryption will have. -func (a Decryption) PickleLen() int { - length := libolmpickle.PickleUInt32Len(decryptionPickleVersionLibOlm) - length += a.KeyPair.PickleLen() - return length -} diff --git a/crypto/goolm/ratchet/chain.go b/crypto/goolm/ratchet/chain.go index 2c2789b7..ea0400e8 100644 --- a/crypto/goolm/ratchet/chain.go +++ b/crypto/goolm/ratchet/chain.go @@ -19,6 +19,9 @@ type chainKey struct { Key crypto.Curve25519PublicKey `json:"key"` } +const chainKeyPickleLength = crypto.Curve25519PubKeyLength + // Key + libolmpickle.PickleUInt32Length // Index + // advance advances the chain func (c *chainKey) advance() { c.Key = crypto.HMACSHA256(c.Key, []byte{chainKeySeed}) @@ -44,7 +47,7 @@ func (r *chainKey) UnpickleLibOlm(value []byte) (int, error) { // PickleLibOlm encodes the chain key into target. target has to have a size of at least PickleLen() and is written to from index 0. // It returns the number of bytes written. func (r chainKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < r.PickleLen() { + if len(target) < chainKeyPickleLength { return 0, fmt.Errorf("pickle chain key: %w", olm.ErrValueTooShort) } written, err := r.Key.PickleLibOlm(target) @@ -55,13 +58,6 @@ func (r chainKey) PickleLibOlm(target []byte) (int, error) { return written, nil } -// PickleLen returns the number of bytes the pickled chain key will have. -func (r chainKey) PickleLen() int { - length := r.Key.PickleLen() - length += libolmpickle.PickleUInt32Len(r.Index) - return length -} - // senderChain is a chain for sending messages type senderChain struct { RKey crypto.Curve25519KeyPair `json:"ratchet_key"` @@ -69,6 +65,9 @@ type senderChain struct { IsSet bool `json:"set"` } +const senderChainPickleLength = chainKeyPickleLength + // RKey + chainKeyPickleLength // CKey + // newSenderChain returns a sender chain initialized with chainKey and ratchet key pair. func newSenderChain(key crypto.Curve25519PublicKey, ratchet crypto.Curve25519KeyPair) *senderChain { return &senderChain{ @@ -115,7 +114,7 @@ func (r *senderChain) UnpickleLibOlm(value []byte) (int, error) { // PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0. // It returns the number of bytes written. func (r senderChain) PickleLibOlm(target []byte) (int, error) { - if len(target) < r.PickleLen() { + if len(target) < senderChainPickleLength { return 0, fmt.Errorf("pickle sender chain: %w", olm.ErrValueTooShort) } written, err := r.RKey.PickleLibOlm(target) @@ -130,19 +129,15 @@ func (r senderChain) PickleLibOlm(target []byte) (int, error) { return written, nil } -// PickleLen returns the number of bytes the pickled chain will have. -func (r senderChain) PickleLen() int { - length := r.RKey.PickleLen() - length += r.CKey.PickleLen() - return length -} - // senderChain is a chain for receiving messages type receiverChain struct { RKey crypto.Curve25519PublicKey `json:"ratchet_key"` CKey chainKey `json:"chain_key"` } +const receiverChainPickleLength = crypto.Curve25519PubKeyLength + // Ratchet Key + chainKeyPickleLength // CKey + // newReceiverChain returns a receiver chain initialized with chainKey and ratchet public key. func newReceiverChain(chain crypto.Curve25519PublicKey, ratchet crypto.Curve25519PublicKey) *receiverChain { return &receiverChain{ @@ -188,7 +183,7 @@ func (r *receiverChain) UnpickleLibOlm(value []byte) (int, error) { // PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0. // It returns the number of bytes written. func (r receiverChain) PickleLibOlm(target []byte) (int, error) { - if len(target) < r.PickleLen() { + if len(target) < receiverChainPickleLength { return 0, fmt.Errorf("pickle sender chain: %w", olm.ErrValueTooShort) } written, err := r.RKey.PickleLibOlm(target) @@ -203,19 +198,15 @@ func (r receiverChain) PickleLibOlm(target []byte) (int, error) { return written, nil } -// PickleLen returns the number of bytes the pickled chain will have. -func (r receiverChain) PickleLen() int { - length := r.RKey.PickleLen() - length += r.CKey.PickleLen() - return length -} - // messageKey wraps the index and the key of a message type messageKey struct { Index uint32 `json:"index"` Key []byte `json:"key"` } +const messageKeyPickleLength = messageKeyLength + // Key + libolmpickle.PickleUInt32Length // Index + // UnpickleLibOlm decodes the unencryted value and populates the message key accordingly. It returns the number of bytes read. func (m *messageKey) UnpickleLibOlm(value []byte) (int, error) { curPos := 0 @@ -237,7 +228,7 @@ func (m *messageKey) UnpickleLibOlm(value []byte) (int, error) { // PickleLibOlm encodes the message key into target. target has to have a size of at least PickleLen() and is written to from index 0. // It returns the number of bytes written. func (m messageKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < m.PickleLen() { + if len(target) < messageKeyPickleLength { return 0, fmt.Errorf("pickle message key: %w", olm.ErrValueTooShort) } written := 0 @@ -249,10 +240,3 @@ func (m messageKey) PickleLibOlm(target []byte) (int, error) { written += libolmpickle.PickleUInt32(m.Index, target[written:]) return written, nil } - -// PickleLen returns the number of bytes the pickled message key will have. -func (r messageKey) PickleLen() int { - length := libolmpickle.PickleBytesLen(make([]byte, messageKeyLength)) - length += libolmpickle.PickleUInt32Len(r.Index) - return length -} diff --git a/crypto/goolm/ratchet/olm.go b/crypto/goolm/ratchet/olm.go index 879f6cfe..8e0944d4 100644 --- a/crypto/goolm/ratchet/olm.go +++ b/crypto/goolm/ratchet/olm.go @@ -388,25 +388,25 @@ func (r Ratchet) PickleLibOlm(target []byte) (int, error) { // PickleLen returns the actual number of bytes the pickled ratchet will have. func (r Ratchet) PickleLen() int { - length := r.RootKey.PickleLen() + length := crypto.Curve25519PubKeyLength // Root Key if r.SenderChains.IsSet { - length += libolmpickle.PickleUInt32Len(1) - length += r.SenderChains.PickleLen() + length += libolmpickle.PickleUInt32Length // 1 + length += senderChainPickleLength // SenderChains } else { - length += libolmpickle.PickleUInt32Len(0) + length += libolmpickle.PickleUInt32Length // 0 } - length += libolmpickle.PickleUInt32Len(uint32(len(r.ReceiverChains))) - length += len(r.ReceiverChains) * receiverChain{}.PickleLen() - length += libolmpickle.PickleUInt32Len(uint32(len(r.SkippedMessageKeys))) - length += len(r.SkippedMessageKeys) * skippedMessageKey{}.PickleLen() + length += libolmpickle.PickleUInt32Length // ReceiverChains length + length += len(r.ReceiverChains) * receiverChainPickleLength + length += libolmpickle.PickleUInt32Length // SkippedMessageKeys length + length += len(r.SkippedMessageKeys) * skippedMessageKeyPickleLen return length } // PickleLen returns the minimum number of bytes the pickled ratchet must have. func (r Ratchet) PickleLenMin() int { - length := r.RootKey.PickleLen() - length += libolmpickle.PickleUInt32Len(0) - length += libolmpickle.PickleUInt32Len(0) - length += libolmpickle.PickleUInt32Len(0) + length := crypto.Curve25519PubKeyLength // Root Key + length += libolmpickle.PickleUInt32Length + length += libolmpickle.PickleUInt32Length + length += libolmpickle.PickleUInt32Length return length } diff --git a/crypto/goolm/ratchet/skipped_message.go b/crypto/goolm/ratchet/skipped_message.go index 79927480..b577edbe 100644 --- a/crypto/goolm/ratchet/skipped_message.go +++ b/crypto/goolm/ratchet/skipped_message.go @@ -13,6 +13,9 @@ type skippedMessageKey struct { MKey messageKey `json:"message_key"` } +const skippedMessageKeyPickleLen = crypto.Curve25519PubKeyLength + // RKey + messageKeyPickleLength // MKey + // UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read. func (r *skippedMessageKey) UnpickleLibOlm(value []byte) (int, error) { curPos := 0 @@ -32,7 +35,7 @@ func (r *skippedMessageKey) UnpickleLibOlm(value []byte) (int, error) { // PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0. // It returns the number of bytes written. func (r skippedMessageKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < r.PickleLen() { + if len(target) < skippedMessageKeyPickleLen { return 0, fmt.Errorf("pickle sender chain: %w", olm.ErrValueTooShort) } written, err := r.RKey.PickleLibOlm(target) @@ -46,10 +49,3 @@ func (r skippedMessageKey) PickleLibOlm(target []byte) (int, error) { written += writtenChain return written, nil } - -// PickleLen returns the number of bytes the pickled chain will have. -func (r skippedMessageKey) PickleLen() int { - length := r.RKey.PickleLen() - length += r.MKey.PickleLen() - return length -} diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go index bfb24322..50ca4d97 100644 --- a/crypto/goolm/session/megolm_inbound_session.go +++ b/crypto/goolm/session/megolm_inbound_session.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" + "maunium.net/go/mautrix/crypto/ed25519" "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/goolmbase64" @@ -19,6 +20,12 @@ import ( const ( megolmInboundSessionPickleVersionJSON byte = 1 megolmInboundSessionPickleVersionLibOlm uint32 = 2 + + megolmInboundSessionPickleLength = libolmpickle.PickleUInt32Length + // Version + megolm.RatchetPickleLength + // InitialRatchet + megolm.RatchetPickleLength + // Ratchet + ed25519.PublicKeySize + // SigningKey + libolmpickle.PickleBoolLength // Verified ) // MegolmInboundSession stores information about the sessions of receive. @@ -246,7 +253,7 @@ func (o *MegolmInboundSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { return nil, olm.ErrNoKeyProvided } - pickeledBytes := make([]byte, o.PickleLen()) + pickeledBytes := make([]byte, megolmInboundSessionPickleLength) written, err := o.PickleLibOlm(pickeledBytes) if err != nil { return nil, err @@ -260,7 +267,7 @@ func (o *MegolmInboundSession) Pickle(key []byte) ([]byte, error) { // PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. // It returns the number of bytes written. func (o *MegolmInboundSession) PickleLibOlm(target []byte) (int, error) { - if len(target) < o.PickleLen() { + if len(target) < megolmInboundSessionPickleLength { return 0, fmt.Errorf("pickle MegolmInboundSession: %w", olm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(megolmInboundSessionPickleVersionLibOlm, target) @@ -283,16 +290,6 @@ func (o *MegolmInboundSession) PickleLibOlm(target []byte) (int, error) { return written, nil } -// PickleLen returns the number of bytes the pickled session will have. -func (o *MegolmInboundSession) PickleLen() int { - length := libolmpickle.PickleUInt32Len(megolmInboundSessionPickleVersionLibOlm) - length += o.InitialRatchet.PickleLen() - length += o.Ratchet.PickleLen() - length += o.SigningKey.PickleLen() - length += libolmpickle.PickleBoolLen(o.SigningKeyVerified) - return length -} - // FirstKnownIndex returns the first message index we know how to decrypt. func (s *MegolmInboundSession) FirstKnownIndex() uint32 { return s.InitialRatchet.Counter diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index e85a5c12..d58650ce 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -21,6 +21,10 @@ import ( const ( megolmOutboundSessionPickleVersion byte = 1 megolmOutboundSessionPickleVersionLibOlm uint32 = 1 + + MegolmOutboundSessionPickleLength = libolmpickle.PickleUInt32Length + // Version + megolm.RatchetPickleLength + // Ratchet + crypto.Ed25519KeyPairPickleLength // SigningKey ) // MegolmOutboundSession stores information about the sessions to send. @@ -130,7 +134,7 @@ func (o *MegolmOutboundSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { return nil, olm.ErrNoKeyProvided } - pickeledBytes := make([]byte, o.PickleLen()) + pickeledBytes := make([]byte, MegolmOutboundSessionPickleLength) written, err := o.PickleLibOlm(pickeledBytes) if err != nil { return nil, err @@ -144,7 +148,7 @@ func (o *MegolmOutboundSession) Pickle(key []byte) ([]byte, error) { // PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. // It returns the number of bytes written. func (o *MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) { - if len(target) < o.PickleLen() { + if len(target) < MegolmOutboundSessionPickleLength { return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", olm.ErrValueTooShort) } written := libolmpickle.PickleUInt32(megolmOutboundSessionPickleVersionLibOlm, target) @@ -161,14 +165,6 @@ func (o *MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) { return written, nil } -// PickleLen returns the number of bytes the pickled session will have. -func (o *MegolmOutboundSession) PickleLen() int { - length := libolmpickle.PickleUInt32Len(megolmOutboundSessionPickleVersionLibOlm) - length += o.Ratchet.PickleLen() - length += o.SigningKey.PickleLen() - return length -} - func (o *MegolmOutboundSession) SessionSharingMessage() ([]byte, error) { return o.Ratchet.SessionSharingMessage(o.SigningKey) } diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go index f58edf87..2c9585b3 100644 --- a/crypto/goolm/session/olm_session.go +++ b/crypto/goolm/session/olm_session.go @@ -458,22 +458,22 @@ func (o *OlmSession) PickleLibOlm(target []byte) (int, error) { // PickleLen returns the actual number of bytes the pickled session will have. func (o *OlmSession) PickleLen() int { - length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm) - length += libolmpickle.PickleBoolLen(o.ReceivedMessage) - length += o.AliceIdentityKey.PickleLen() - length += o.AliceBaseKey.PickleLen() - length += o.BobOneTimeKey.PickleLen() + length := libolmpickle.PickleUInt32Length + length += libolmpickle.PickleBoolLength + length += crypto.Curve25519PubKeyLength // AliceIdentityKey + length += crypto.Curve25519PubKeyLength // AliceBaseKey + length += crypto.Curve25519PubKeyLength // BobOneTimeKey length += o.Ratchet.PickleLen() return length } // PickleLenMin returns the minimum number of bytes the pickled session must have. func (o *OlmSession) PickleLenMin() int { - length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm) - length += libolmpickle.PickleBoolLen(o.ReceivedMessage) - length += o.AliceIdentityKey.PickleLen() - length += o.AliceBaseKey.PickleLen() - length += o.BobOneTimeKey.PickleLen() + length := libolmpickle.PickleUInt32Length + length += libolmpickle.PickleBoolLength + length += crypto.Curve25519PubKeyLength // AliceIdentityKey + length += crypto.Curve25519PubKeyLength // AliceBaseKey + length += crypto.Curve25519PubKeyLength // BobOneTimeKey length += o.Ratchet.PickleLenMin() return length } From e525e151e14bc64dccce2fd245ab48f4806b785b Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 24 Oct 2024 21:52:46 -0600 Subject: [PATCH 0873/1647] goolm/libolmpickle: add Encoder for easier pickling API Signed-off-by: Sumner Evans --- crypto/goolm/account/account.go | 78 ++++------------ crypto/goolm/crypto/curve25519.go | 51 ++++------ crypto/goolm/crypto/curve25519_test.go | 41 +++++---- crypto/goolm/crypto/ed25519.go | 43 +++------ crypto/goolm/crypto/ed25519_test.go | 38 ++++---- crypto/goolm/crypto/one_time_key.go | 25 +---- crypto/goolm/libolmpickle/pickle.go | 34 +++---- crypto/goolm/libolmpickle/pickle_test.go | 56 +++++++---- crypto/goolm/libolmpickle/unpickle.go | 11 ++- crypto/goolm/megolm/megolm.go | 16 +--- crypto/goolm/message/prekey_message.go | 6 +- crypto/goolm/pk/decryption.go | 34 ++----- crypto/goolm/ratchet/chain.go | 92 +++++-------------- crypto/goolm/ratchet/olm.go | 70 +++----------- crypto/goolm/ratchet/skipped_message.go | 27 +----- .../goolm/session/megolm_inbound_session.go | 51 ++-------- .../goolm/session/megolm_outbound_session.go | 40 ++------ crypto/goolm/session/olm_session.go | 78 +++------------- crypto/olm/errors.go | 1 - 19 files changed, 241 insertions(+), 551 deletions(-) diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index dcb4c9d4..2a99e985 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -4,7 +4,6 @@ package account import ( "encoding/base64" "encoding/json" - "errors" "fmt" "maunium.net/go/mautrix/id" @@ -430,73 +429,32 @@ func (a *Account) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { return nil, olm.ErrNoKeyProvided } - pickeledBytes := make([]byte, a.PickleLen()) - written, err := a.PickleLibOlm(pickeledBytes) - if err != nil { - return nil, err - } - if written != len(pickeledBytes) { - return nil, errors.New("number of written bytes not correct") - } - return cipher.Pickle(key, pickeledBytes) + return cipher.Pickle(key, a.PickleLibOlm()) } -// PickleLibOlm encodes the Account into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (a *Account) PickleLibOlm(target []byte) (int, error) { - if len(target) < a.PickleLen() { - return 0, fmt.Errorf("pickle account: %w", olm.ErrValueTooShort) - } - written := libolmpickle.PickleUInt32(accountPickleVersionLibOLM, target) - writtenEdKey, err := a.IdKeys.Ed25519.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle account: %w", err) - } - written += writtenEdKey - writtenCurveKey, err := a.IdKeys.Curve25519.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle account: %w", err) - } - written += writtenCurveKey - written += libolmpickle.PickleUInt32(uint32(len(a.OTKeys)), target[written:]) +// PickleLibOlm pickles the [Account] and returns the raw bytes. +func (a *Account) PickleLibOlm() []byte { + encoder := libolmpickle.NewEncoder() + encoder.WriteUInt32(accountPickleVersionLibOLM) + a.IdKeys.Ed25519.PickleLibOlm(encoder) + a.IdKeys.Curve25519.PickleLibOlm(encoder) + + // One-Time Keys + encoder.WriteUInt32(uint32(len(a.OTKeys))) for _, curOTKey := range a.OTKeys { - writtenOT, err := curOTKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle account: %w", err) - } - written += writtenOT + curOTKey.PickleLibOlm(encoder) } - written += libolmpickle.PickleUInt8(a.NumFallbackKeys, target[written:]) + + // Fallback Keys + encoder.WriteUInt8(a.NumFallbackKeys) if a.NumFallbackKeys >= 1 { - writtenOT, err := a.CurrentFallbackKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle account: %w", err) - } - written += writtenOT - + a.CurrentFallbackKey.PickleLibOlm(encoder) if a.NumFallbackKeys >= 2 { - writtenOT, err := a.PrevFallbackKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle account: %w", err) - } - written += writtenOT + a.PrevFallbackKey.PickleLibOlm(encoder) } } - written += libolmpickle.PickleUInt32(a.NextOneTimeKeyID, target[written:]) - return written, nil -} - -// PickleLen returns the number of bytes the pickled Account will have. -func (a *Account) PickleLen() int { - length := libolmpickle.PickleUInt32Length - length += crypto.Ed25519KeyPairPickleLength // IdKeys.Ed25519 - length += crypto.Curve25519KeyPairPickleLength // IdKeys.Curve25519 - length += libolmpickle.PickleUInt32Length - length += (len(a.OTKeys) * crypto.OneTimeKeyPickleLength) - length += libolmpickle.PickleUInt8Length - length += (int(a.NumFallbackKeys) * crypto.OneTimeKeyPickleLength) - length += libolmpickle.PickleUInt32Length - return length + encoder.WriteUInt32(a.NextOneTimeKeyID) + return encoder.Bytes() } // MaxNumberOfOneTimeKeys returns the largest number of one time keys this diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go index 01ada7e0..459031f1 100644 --- a/crypto/goolm/crypto/curve25519.go +++ b/crypto/goolm/crypto/curve25519.go @@ -4,23 +4,21 @@ import ( "bytes" "crypto/rand" "encoding/base64" - "fmt" "golang.org/x/crypto/curve25519" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" - "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) const ( - Curve25519KeyLength = curve25519.ScalarSize //The length of the private key. - Curve25519PubKeyLength = 32 + Curve25519PrivateKeyLength = curve25519.ScalarSize //The length of the private key. + Curve25519PublicKeyLength = 32 ) // Curve25519GenerateKey creates a new curve25519 key pair. func Curve25519GenerateKey() (Curve25519KeyPair, error) { - privateKeyByte := make([]byte, Curve25519KeyLength) + privateKeyByte := make([]byte, Curve25519PrivateKeyLength) if _, err := rand.Read(privateKeyByte); err != nil { return Curve25519KeyPair{}, err } @@ -51,9 +49,6 @@ type Curve25519KeyPair struct { PublicKey Curve25519PublicKey `json:"public,omitempty"` } -const Curve25519KeyPairPickleLength = Curve25519PubKeyLength + // Public Key - Curve25519KeyLength // Private Key - // B64Encoded returns a base64 encoded string of the public key. func (c Curve25519KeyPair) B64Encoded() id.Curve25519 { return c.PublicKey.B64Encoded() @@ -64,23 +59,14 @@ func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, err return c.PrivateKey.SharedSecret(pubKey) } -// PickleLibOlm encodes the key pair into target. The target has to have a size -// of at least [Curve25519KeyPairPickleLength] and is written to from index 0. -// It returns the number of bytes written. -func (c Curve25519KeyPair) PickleLibOlm(target []byte) (int, error) { - if len(target) < Curve25519KeyPairPickleLength { - return 0, fmt.Errorf("pickle curve25519 key pair: %w", olm.ErrValueTooShort) - } - written, err := c.PublicKey.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle curve25519 key pair: %w", err) - } - if len(c.PrivateKey) != Curve25519KeyLength { - written += libolmpickle.PickleBytes(make([]byte, Curve25519KeyLength), target[written:]) +// PickleLibOlm pickles the key pair into the encoder. +func (c Curve25519KeyPair) PickleLibOlm(encoder *libolmpickle.Encoder) { + c.PublicKey.PickleLibOlm(encoder) + if len(c.PrivateKey) == Curve25519PrivateKeyLength { + encoder.Write(c.PrivateKey) } else { - written += libolmpickle.PickleBytes(c.PrivateKey, target[written:]) + encoder.WriteEmptyBytes(Curve25519PrivateKeyLength) } - return written, nil } // UnpickleLibOlm decodes the unencryted value and populates the key pair accordingly. It returns the number of bytes read. @@ -91,7 +77,7 @@ func (c *Curve25519KeyPair) UnpickleLibOlm(value []byte) (int, error) { return 0, err } //unpickle PrivateKey - privKey, readPriv, err := libolmpickle.UnpickleBytes(value[read:], Curve25519KeyLength) + privKey, readPriv, err := libolmpickle.UnpickleBytes(value[read:], Curve25519PrivateKeyLength) if err != nil { return read, err } @@ -130,21 +116,18 @@ func (c Curve25519PublicKey) B64Encoded() id.Curve25519 { return id.Curve25519(base64.RawStdEncoding.EncodeToString(c)) } -// PickleLibOlm encodes the public key into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (c Curve25519PublicKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < Curve25519PubKeyLength { - return 0, fmt.Errorf("pickle curve25519 public key: %w", olm.ErrValueTooShort) +// PickleLibOlm pickles the public key into the encoder. +func (c Curve25519PublicKey) PickleLibOlm(encoder *libolmpickle.Encoder) { + if len(c) == Curve25519PublicKeyLength { + encoder.Write(c) + } else { + encoder.WriteEmptyBytes(Curve25519PublicKeyLength) } - if len(c) != Curve25519PubKeyLength { - return libolmpickle.PickleBytes(make([]byte, Curve25519PubKeyLength), target), nil - } - return libolmpickle.PickleBytes(c, target), nil } // UnpickleLibOlm decodes the unencryted value and populates the public key accordingly. It returns the number of bytes read. func (c *Curve25519PublicKey) UnpickleLibOlm(value []byte) (int, error) { - unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, Curve25519PubKeyLength) + unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, Curve25519PublicKeyLength) if err != nil { return 0, err } diff --git a/crypto/goolm/crypto/curve25519_test.go b/crypto/goolm/crypto/curve25519_test.go index fc8ee54b..fb9f0098 100644 --- a/crypto/goolm/crypto/curve25519_test.go +++ b/crypto/goolm/crypto/curve25519_test.go @@ -6,8 +6,12 @@ import ( "github.com/stretchr/testify/assert" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) +const curve25519KeyPairPickleLength = crypto.Curve25519PublicKeyLength + // Public Key + crypto.Curve25519PrivateKeyLength // Private Key + func TestCurve25519(t *testing.T) { firstKeypair, err := crypto.Curve25519GenerateKey() assert.NoError(t, err) @@ -74,15 +78,15 @@ func TestCurve25519Pickle(t *testing.T) { //create keypair keyPair, err := crypto.Curve25519GenerateKey() assert.NoError(t, err) - target := make([]byte, crypto.Curve25519KeyPairPickleLength) - writtenBytes, err := keyPair.PickleLibOlm(target) - assert.NoError(t, err) - assert.Len(t, target, writtenBytes) + + encoder := libolmpickle.NewEncoder() + keyPair.PickleLibOlm(encoder) + assert.Len(t, encoder.Bytes(), curve25519KeyPairPickleLength) unpickledKeyPair := crypto.Curve25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + readBytes, err := unpickledKeyPair.UnpickleLibOlm(encoder.Bytes()) assert.NoError(t, err) - assert.Len(t, target, readBytes) + assert.Len(t, encoder.Bytes(), readBytes) assert.Equal(t, keyPair, unpickledKeyPair) } @@ -90,16 +94,18 @@ func TestCurve25519PicklePubKeyOnly(t *testing.T) { //create keypair keyPair, err := crypto.Curve25519GenerateKey() assert.NoError(t, err) + //Remove privateKey keyPair.PrivateKey = nil - target := make([]byte, crypto.Curve25519KeyPairPickleLength) - writtenBytes, err := keyPair.PickleLibOlm(target) - assert.NoError(t, err) - assert.Len(t, target, writtenBytes) + + encoder := libolmpickle.NewEncoder() + keyPair.PickleLibOlm(encoder) + assert.Len(t, encoder.Bytes(), curve25519KeyPairPickleLength) + unpickledKeyPair := crypto.Curve25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + readBytes, err := unpickledKeyPair.UnpickleLibOlm(encoder.Bytes()) assert.NoError(t, err) - assert.Len(t, target, readBytes) + assert.Len(t, encoder.Bytes(), readBytes) assert.Equal(t, keyPair, unpickledKeyPair) } @@ -109,13 +115,12 @@ func TestCurve25519PicklePrivKeyOnly(t *testing.T) { assert.NoError(t, err) //Remove public keyPair.PublicKey = nil - target := make([]byte, crypto.Curve25519KeyPairPickleLength) - writtenBytes, err := keyPair.PickleLibOlm(target) - assert.NoError(t, err) - assert.Len(t, target, writtenBytes) + encoder := libolmpickle.NewEncoder() + keyPair.PickleLibOlm(encoder) + assert.Len(t, encoder.Bytes(), curve25519KeyPairPickleLength) unpickledKeyPair := crypto.Curve25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + readBytes, err := unpickledKeyPair.UnpickleLibOlm(encoder.Bytes()) assert.NoError(t, err) - assert.Len(t, target, readBytes) + assert.Len(t, encoder.Bytes(), readBytes) assert.Equal(t, keyPair, unpickledKeyPair) } diff --git a/crypto/goolm/crypto/ed25519.go b/crypto/goolm/crypto/ed25519.go index abacb1ee..bc535e6a 100644 --- a/crypto/goolm/crypto/ed25519.go +++ b/crypto/goolm/crypto/ed25519.go @@ -2,11 +2,9 @@ package crypto import ( "encoding/base64" - "fmt" "maunium.net/go/mautrix/crypto/ed25519" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" - "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -46,9 +44,6 @@ type Ed25519KeyPair struct { PublicKey Ed25519PublicKey `json:"public,omitempty"` } -const Ed25519KeyPairPickleLength = ed25519.PublicKeySize + // PublicKey - ed25519.PrivateKeySize // Private Key - // B64Encoded returns a base64 encoded string of the public key. func (c Ed25519KeyPair) B64Encoded() id.Ed25519 { return id.Ed25519(base64.RawStdEncoding.EncodeToString(c.PublicKey)) @@ -64,24 +59,14 @@ func (c Ed25519KeyPair) Verify(message, givenSignature []byte) bool { return c.PublicKey.Verify(message, givenSignature) } -// PickleLibOlm encodes the key pair into target. target has to have a size of -// at least [Ed25519KeyPairPickleLength] and is written to from index 0. It -// returns the number of bytes written. -func (c Ed25519KeyPair) PickleLibOlm(target []byte) (int, error) { - if len(target) < Ed25519KeyPairPickleLength { - return 0, fmt.Errorf("pickle ed25519 key pair: %w", olm.ErrValueTooShort) - } - written, err := c.PublicKey.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle ed25519 key pair: %w", err) - } - - if len(c.PrivateKey) != ed25519.PrivateKeySize { - written += libolmpickle.PickleBytes(make([]byte, ed25519.PrivateKeySize), target[written:]) +// PickleLibOlm pickles the key pair into the encoder. +func (c Ed25519KeyPair) PickleLibOlm(encoder *libolmpickle.Encoder) { + c.PublicKey.PickleLibOlm(encoder) + if len(c.PrivateKey) == ed25519.PrivateKeySize { + encoder.Write(c.PrivateKey) } else { - written += libolmpickle.PickleBytes(c.PrivateKey, target[written:]) + encoder.WriteEmptyBytes(ed25519.PrivateKeySize) } - return written, nil } // UnpickleLibOlm decodes the unencryted value and populates the key pair accordingly. It returns the number of bytes read. @@ -141,17 +126,13 @@ func (c Ed25519PublicKey) Verify(message, givenSignature []byte) bool { return ed25519.Verify(ed25519.PublicKey(c), message, givenSignature) } -// PickleLibOlm encodes the public key into target. target has to have a size -// of at least [ed25519.PublicKeySize] and is written to from index 0. It -// returns the number of bytes written. -func (c Ed25519PublicKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < ed25519.PublicKeySize { - return 0, fmt.Errorf("pickle ed25519 public key: %w", olm.ErrValueTooShort) +// PickleLibOlm pickles the public key into the encoder. +func (c Ed25519PublicKey) PickleLibOlm(encoder *libolmpickle.Encoder) { + if len(c) == ed25519.PublicKeySize { + encoder.Write(c) + } else { + encoder.WriteEmptyBytes(ed25519.PublicKeySize) } - if len(c) != ed25519.PublicKeySize { - return libolmpickle.PickleBytes(make([]byte, ed25519.PublicKeySize), target), nil - } - return libolmpickle.PickleBytes(c, target), nil } // UnpickleLibOlm decodes the unencryted value and populates the public key accordingly. It returns the number of bytes read. diff --git a/crypto/goolm/crypto/ed25519_test.go b/crypto/goolm/crypto/ed25519_test.go index e5314622..3ac8863a 100644 --- a/crypto/goolm/crypto/ed25519_test.go +++ b/crypto/goolm/crypto/ed25519_test.go @@ -5,9 +5,14 @@ import ( "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/ed25519" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) +const ed25519KeyPairPickleLength = ed25519.PublicKeySize + // PublicKey + ed25519.PrivateKeySize // Private Key + func TestEd25519(t *testing.T) { keypair, err := crypto.Ed25519GenerateKey() assert.NoError(t, err) @@ -38,15 +43,14 @@ func TestEd25519Pickle(t *testing.T) { //create keypair keyPair, err := crypto.Ed25519GenerateKey() assert.NoError(t, err) - target := make([]byte, crypto.Ed25519KeyPairPickleLength) - writtenBytes, err := keyPair.PickleLibOlm(target) - assert.NoError(t, err) - assert.Len(t, target, writtenBytes) + encoder := libolmpickle.NewEncoder() + keyPair.PickleLibOlm(encoder) + assert.Len(t, encoder.Bytes(), ed25519KeyPairPickleLength) unpickledKeyPair := crypto.Ed25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + readBytes, err := unpickledKeyPair.UnpickleLibOlm(encoder.Bytes()) assert.NoError(t, err) - assert.Len(t, target, readBytes, "read bytes not correct") + assert.Len(t, encoder.Bytes(), readBytes, "read bytes not correct") assert.Equal(t, keyPair, unpickledKeyPair) } @@ -56,15 +60,14 @@ func TestEd25519PicklePubKeyOnly(t *testing.T) { assert.NoError(t, err) //Remove privateKey keyPair.PrivateKey = nil - target := make([]byte, crypto.Ed25519KeyPairPickleLength) - writtenBytes, err := keyPair.PickleLibOlm(target) - assert.NoError(t, err) - assert.Len(t, target, writtenBytes) + encoder := libolmpickle.NewEncoder() + keyPair.PickleLibOlm(encoder) + assert.Len(t, encoder.Bytes(), ed25519KeyPairPickleLength) unpickledKeyPair := crypto.Ed25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + readBytes, err := unpickledKeyPair.UnpickleLibOlm(encoder.Bytes()) assert.NoError(t, err) - assert.Len(t, target, readBytes, "read bytes not correct") + assert.Len(t, encoder.Bytes(), readBytes, "read bytes not correct") assert.Equal(t, keyPair, unpickledKeyPair) } @@ -74,14 +77,13 @@ func TestEd25519PicklePrivKeyOnly(t *testing.T) { assert.NoError(t, err) //Remove public keyPair.PublicKey = nil - target := make([]byte, crypto.Ed25519KeyPairPickleLength) - writtenBytes, err := keyPair.PickleLibOlm(target) - assert.NoError(t, err) - assert.Len(t, target, writtenBytes) + encoder := libolmpickle.NewEncoder() + keyPair.PickleLibOlm(encoder) + assert.Len(t, encoder.Bytes(), ed25519KeyPairPickleLength) unpickledKeyPair := crypto.Ed25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(target) + readBytes, err := unpickledKeyPair.UnpickleLibOlm(encoder.Bytes()) assert.NoError(t, err) - assert.Len(t, target, readBytes, "read bytes not correct") + assert.Len(t, encoder.Bytes(), readBytes, "read bytes not correct") assert.Equal(t, keyPair, unpickledKeyPair) } diff --git a/crypto/goolm/crypto/one_time_key.go b/crypto/goolm/crypto/one_time_key.go index 2fbd6366..b7e594ef 100644 --- a/crypto/goolm/crypto/one_time_key.go +++ b/crypto/goolm/crypto/one_time_key.go @@ -3,10 +3,8 @@ package crypto import ( "encoding/base64" "encoding/binary" - "fmt" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" - "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -17,10 +15,6 @@ type OneTimeKey struct { Key Curve25519KeyPair `json:"key,omitempty"` } -const OneTimeKeyPickleLength = libolmpickle.PickleUInt32Length + // ID - libolmpickle.PickleBoolLength + // Published - Curve25519KeyPairPickleLength // Key - // Equal compares the one time key to the given one. func (otk OneTimeKey) Equal(s OneTimeKey) bool { if otk.ID != s.ID { @@ -38,20 +32,11 @@ func (otk OneTimeKey) Equal(s OneTimeKey) bool { return true } -// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (c OneTimeKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < OneTimeKeyPickleLength { - return 0, fmt.Errorf("pickle one time key: %w", olm.ErrValueTooShort) - } - written := libolmpickle.PickleUInt32(uint32(c.ID), target) - written += libolmpickle.PickleBool(c.Published, target[written:]) - writtenKey, err := c.Key.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle one time key: %w", err) - } - written += writtenKey - return written, nil +// PickleLibOlm pickles the key pair into the encoder. +func (c OneTimeKey) PickleLibOlm(encoder *libolmpickle.Encoder) { + encoder.WriteUInt32(c.ID) + encoder.WriteBool(c.Published) + c.Key.PickleLibOlm(encoder) } // UnpickleLibOlm decodes the unencryted value and populates the OneTimeKey accordingly. It returns the number of bytes read. diff --git a/crypto/goolm/libolmpickle/pickle.go b/crypto/goolm/libolmpickle/pickle.go index bedeee04..590033fc 100644 --- a/crypto/goolm/libolmpickle/pickle.go +++ b/crypto/goolm/libolmpickle/pickle.go @@ -1,7 +1,10 @@ package libolmpickle import ( + "bytes" "encoding/binary" + + "go.mau.fi/util/exerrors" ) const ( @@ -10,29 +13,28 @@ const ( PickleUInt32Length = 4 ) -func PickleUInt8(value uint8, target []byte) int { - target[0] = value - return 1 +type Encoder struct { + bytes.Buffer } -func PickleBool(value bool, target []byte) int { +func NewEncoder() *Encoder { return &Encoder{} } + +func (p *Encoder) WriteUInt8(value uint8) { + exerrors.PanicIfNotNil(p.WriteByte(value)) +} + +func (p *Encoder) WriteBool(value bool) { if value { - target[0] = 0x01 + exerrors.PanicIfNotNil(p.WriteByte(0x01)) } else { - target[0] = 0x00 + exerrors.PanicIfNotNil(p.WriteByte(0x00)) } - return 1 } -func PickleBytes(value, target []byte) int { - return copy(target, value) -} -func PickleBytesLen(value []byte) int { - return len(value) +func (p *Encoder) WriteEmptyBytes(count int) { + exerrors.Must(p.Write(make([]byte, count))) } -func PickleUInt32(value uint32, target []byte) int { - res := make([]byte, 4) //4 bytes for int32 - binary.BigEndian.PutUint32(res, value) - return copy(target, res) +func (p *Encoder) WriteUInt32(value uint32) { + exerrors.Must(p.Write(binary.BigEndian.AppendUint32(nil, value))) } diff --git a/crypto/goolm/libolmpickle/pickle_test.go b/crypto/goolm/libolmpickle/pickle_test.go index d5596b2a..c7811225 100644 --- a/crypto/goolm/libolmpickle/pickle_test.go +++ b/crypto/goolm/libolmpickle/pickle_test.go @@ -8,6 +8,26 @@ import ( "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) +func TestEncoder(t *testing.T) { + var encoder libolmpickle.Encoder + encoder.WriteUInt32(4) + encoder.WriteUInt8(8) + encoder.WriteBool(false) + encoder.WriteEmptyBytes(10) + encoder.WriteBool(true) + encoder.Write([]byte("test")) + encoder.WriteUInt32(420_000) + assert.Equal(t, []byte{ + 0x00, 0x00, 0x00, 0x04, // 4 + 0x08, // 8 + 0x00, // false + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // ten empty bytes + 0x01, //true + 0x74, 0x65, 0x73, 0x74, // "test" (ASCII) + 0x00, 0x06, 0x68, 0xa0, // 420,000 + }, encoder.Bytes()) +} + func TestPickleUInt32(t *testing.T) { values := []uint32{ 0xffffffff, @@ -21,11 +41,10 @@ func TestPickleUInt32(t *testing.T) { {0xf0, 0x00, 0x00, 0x00}, {0xf0, 0x0f, 0x00, 0x00}, } - for curIndex := range values { - response := make([]byte, 4) - resPLen := libolmpickle.PickleUInt32(values[curIndex], response) - assert.Equal(t, libolmpickle.PickleUInt32Length, resPLen) - assert.Equal(t, expected[curIndex], response) + for i, value := range values { + var encoder libolmpickle.Encoder + encoder.WriteUInt32(value) + assert.Equal(t, expected[i], encoder.Bytes()) } } @@ -38,11 +57,10 @@ func TestPickleBool(t *testing.T) { {0x01}, {0x00}, } - for curIndex := range values { - response := make([]byte, 1) - resPLen := libolmpickle.PickleBool(values[curIndex], response) - assert.Equal(t, libolmpickle.PickleBoolLength, resPLen) - assert.Equal(t, expected[curIndex], response) + for i, value := range values { + var encoder libolmpickle.Encoder + encoder.WriteBool(value) + assert.Equal(t, expected[i], encoder.Bytes()) } } @@ -55,11 +73,10 @@ func TestPickleUInt8(t *testing.T) { {0xff}, {0x1a}, } - for curIndex := range values { - response := make([]byte, 1) - resPLen := libolmpickle.PickleUInt8(values[curIndex], response) - assert.Equal(t, libolmpickle.PickleUInt8Length, resPLen) - assert.Equal(t, expected[curIndex], response) + for i, value := range values { + var encoder libolmpickle.Encoder + encoder.WriteUInt8(value) + assert.Equal(t, expected[i], encoder.Bytes()) } } @@ -74,10 +91,9 @@ func TestPickleBytes(t *testing.T) { {0x00, 0xff, 0x00, 0xff}, {0xf0, 0x00, 0x00, 0x00}, } - for curIndex := range values { - response := make([]byte, len(values[curIndex])) - resPLen := libolmpickle.PickleBytes(values[curIndex], response) - assert.Equal(t, libolmpickle.PickleBytesLen(values[curIndex]), resPLen) - assert.Equal(t, expected[curIndex], response) + for i, value := range values { + var encoder libolmpickle.Encoder + encoder.Write(value) + assert.Equal(t, expected[i], encoder.Bytes()) } } diff --git a/crypto/goolm/libolmpickle/unpickle.go b/crypto/goolm/libolmpickle/unpickle.go index dbd275aa..66803d34 100644 --- a/crypto/goolm/libolmpickle/unpickle.go +++ b/crypto/goolm/libolmpickle/unpickle.go @@ -14,23 +14,26 @@ func isZeroByteSlice(bytes []byte) bool { return b == 0 } +type Decoder struct { +} + func UnpickleUInt8(value []byte) (uint8, int, error) { if len(value) < 1 { - return 0, 0, fmt.Errorf("unpickle uint8: %w", olm.ErrValueTooShort) + return 0, 0, fmt.Errorf("unpickle uint8: %w", olm.ErrInputToSmall) } return value[0], 1, nil } func UnpickleBool(value []byte) (bool, int, error) { if len(value) < 1 { - return false, 0, fmt.Errorf("unpickle bool: %w", olm.ErrValueTooShort) + return false, 0, fmt.Errorf("unpickle bool: %w", olm.ErrInputToSmall) } return value[0] != uint8(0x00), 1, nil } func UnpickleBytes(value []byte, length int) ([]byte, int, error) { if len(value) < length { - return nil, 0, fmt.Errorf("unpickle bytes: %w", olm.ErrValueTooShort) + return nil, 0, fmt.Errorf("unpickle bytes: %w", olm.ErrInputToSmall) } resp := value[:length] if isZeroByteSlice(resp) { @@ -41,7 +44,7 @@ func UnpickleBytes(value []byte, length int) ([]byte, int, error) { func UnpickleUInt32(value []byte) (uint32, int, error) { if len(value) < 4 { - return 0, 0, fmt.Errorf("unpickle uint32: %w", olm.ErrValueTooShort) + return 0, 0, fmt.Errorf("unpickle uint32: %w", olm.ErrInputToSmall) } var res uint32 count := 0 diff --git a/crypto/goolm/megolm/megolm.go b/crypto/goolm/megolm/megolm.go index 47e24077..f44e8cb7 100644 --- a/crypto/goolm/megolm/megolm.go +++ b/crypto/goolm/megolm/megolm.go @@ -22,9 +22,6 @@ const ( protocolVersion = 3 RatchetParts = 4 // number of ratchet parts RatchetPartLength = 256 / 8 // length of each ratchet part in bytes - - RatchetPickleLength = (RatchetParts * RatchetPartLength) + //Data - libolmpickle.PickleUInt32Length // Counter ) var RatchetCipher = cipher.NewAESSHA256([]byte("MEGOLM_KEYS")) @@ -219,13 +216,8 @@ func (r *Ratchet) UnpickleLibOlm(unpickled []byte) (int, error) { return curPos, nil } -// PickleLibOlm encodes the ratchet into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (r Ratchet) PickleLibOlm(target []byte) (int, error) { - if len(target) < RatchetPickleLength { - return 0, fmt.Errorf("pickle account: %w", olm.ErrValueTooShort) - } - written := libolmpickle.PickleBytes(r.Data[:], target) - written += libolmpickle.PickleUInt32(r.Counter, target[written:]) - return written, nil +// PickleLibOlm pickles the ratchet into the encoder. +func (r Ratchet) PickleLibOlm(encoder *libolmpickle.Encoder) { + encoder.Write(r.Data[:]) + encoder.WriteUInt32(r.Counter) } diff --git a/crypto/goolm/message/prekey_message.go b/crypto/goolm/message/prekey_message.go index 6e007e06..1238a9a5 100644 --- a/crypto/goolm/message/prekey_message.go +++ b/crypto/goolm/message/prekey_message.go @@ -74,11 +74,11 @@ func (r *PreKeyMessage) CheckFields(theirIdentityKey *crypto.Curve25519PublicKey ok := true ok = ok && (theirIdentityKey != nil || r.IdentityKey != nil) if r.IdentityKey != nil { - ok = ok && (len(r.IdentityKey) == crypto.Curve25519KeyLength) + ok = ok && (len(r.IdentityKey) == crypto.Curve25519PrivateKeyLength) } ok = ok && len(r.Message) != 0 - ok = ok && len(r.BaseKey) == crypto.Curve25519KeyLength - ok = ok && len(r.OneTimeKey) == crypto.Curve25519KeyLength + ok = ok && len(r.BaseKey) == crypto.Curve25519PrivateKeyLength + ok = ok && len(r.OneTimeKey) == crypto.Curve25519PrivateKeyLength return ok } diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index fb537eaf..cc363841 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -2,7 +2,6 @@ package pk import ( "encoding/base64" - "errors" "fmt" "maunium.net/go/mautrix/crypto/goolm/cipher" @@ -17,9 +16,6 @@ import ( const ( decryptionPickleVersionJSON uint8 = 1 decryptionPickleVersionLibOlm uint32 = 1 - - DecryptionPickleLength = libolmpickle.PickleUInt32Length + // Version - crypto.Curve25519KeyPairPickleLength // KeyPair ) // Decryption is used to decrypt pk messages @@ -127,29 +123,13 @@ func (a *Decryption) UnpickleLibOlm(value []byte) (int, error) { // Pickle returns a base64 encoded and with key encrypted pickled Decryption using PickleLibOlm(). func (a Decryption) Pickle(key []byte) ([]byte, error) { - pickeledBytes := make([]byte, DecryptionPickleLength) - written, err := a.PickleLibOlm(pickeledBytes) - if err != nil { - return nil, err - } - if written != len(pickeledBytes) { - return nil, errors.New("number of written bytes not correct") - } - return cipher.Pickle(key, pickeledBytes) + return cipher.Pickle(key, a.PickleLibOlm()) } -// PickleLibOlm encodes the Decryption into target. target has to have a size -// of at least [DecryptionPickleLength] and is written to from index 0. It -// returns the number of bytes written. -func (a Decryption) PickleLibOlm(target []byte) (int, error) { - if len(target) < DecryptionPickleLength { - return 0, fmt.Errorf("pickle Decryption: %w", olm.ErrValueTooShort) - } - written := libolmpickle.PickleUInt32(decryptionPickleVersionLibOlm, target) - writtenKey, err := a.KeyPair.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle Decryption: %w", err) - } - written += writtenKey - return written, nil +// PickleLibOlm pickles the [Decryption] into the encoder. +func (a Decryption) PickleLibOlm() []byte { + encoder := libolmpickle.NewEncoder() + encoder.WriteUInt32(decryptionPickleVersionLibOlm) + a.KeyPair.PickleLibOlm(encoder) + return encoder.Bytes() } diff --git a/crypto/goolm/ratchet/chain.go b/crypto/goolm/ratchet/chain.go index ea0400e8..124d6906 100644 --- a/crypto/goolm/ratchet/chain.go +++ b/crypto/goolm/ratchet/chain.go @@ -1,11 +1,8 @@ package ratchet import ( - "fmt" - "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" - "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -19,9 +16,6 @@ type chainKey struct { Key crypto.Curve25519PublicKey `json:"key"` } -const chainKeyPickleLength = crypto.Curve25519PubKeyLength + // Key - libolmpickle.PickleUInt32Length // Index - // advance advances the chain func (c *chainKey) advance() { c.Key = crypto.HMACSHA256(c.Key, []byte{chainKeySeed}) @@ -44,18 +38,10 @@ func (r *chainKey) UnpickleLibOlm(value []byte) (int, error) { return curPos, nil } -// PickleLibOlm encodes the chain key into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (r chainKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < chainKeyPickleLength { - return 0, fmt.Errorf("pickle chain key: %w", olm.ErrValueTooShort) - } - written, err := r.Key.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle chain key: %w", err) - } - written += libolmpickle.PickleUInt32(r.Index, target[written:]) - return written, nil +// PickleLibOlm pickles the chain key into the encoder. +func (r chainKey) PickleLibOlm(encoder *libolmpickle.Encoder) { + r.Key.PickleLibOlm(encoder) + encoder.WriteUInt32(r.Index) } // senderChain is a chain for sending messages @@ -65,9 +51,6 @@ type senderChain struct { IsSet bool `json:"set"` } -const senderChainPickleLength = chainKeyPickleLength + // RKey - chainKeyPickleLength // CKey - // newSenderChain returns a sender chain initialized with chainKey and ratchet key pair. func newSenderChain(key crypto.Curve25519PublicKey, ratchet crypto.Curve25519KeyPair) *senderChain { return &senderChain{ @@ -111,22 +94,15 @@ func (r *senderChain) UnpickleLibOlm(value []byte) (int, error) { return curPos, nil } -// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (r senderChain) PickleLibOlm(target []byte) (int, error) { - if len(target) < senderChainPickleLength { - return 0, fmt.Errorf("pickle sender chain: %w", olm.ErrValueTooShort) +// PickleLibOlm pickles the sender chain into the encoder. +func (r senderChain) PickleLibOlm(encoder *libolmpickle.Encoder) { + if r.IsSet { + encoder.WriteUInt32(1) // Length of the sender chain (1 if set) + r.RKey.PickleLibOlm(encoder) + r.CKey.PickleLibOlm(encoder) + } else { + encoder.WriteUInt32(0) } - written, err := r.RKey.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle sender chain: %w", err) - } - writtenChain, err := r.CKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle sender chain: %w", err) - } - written += writtenChain - return written, nil } // senderChain is a chain for receiving messages @@ -135,9 +111,6 @@ type receiverChain struct { CKey chainKey `json:"chain_key"` } -const receiverChainPickleLength = crypto.Curve25519PubKeyLength + // Ratchet Key - chainKeyPickleLength // CKey - // newReceiverChain returns a receiver chain initialized with chainKey and ratchet public key. func newReceiverChain(chain crypto.Curve25519PublicKey, ratchet crypto.Curve25519PublicKey) *receiverChain { return &receiverChain{ @@ -180,22 +153,10 @@ func (r *receiverChain) UnpickleLibOlm(value []byte) (int, error) { return curPos, nil } -// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (r receiverChain) PickleLibOlm(target []byte) (int, error) { - if len(target) < receiverChainPickleLength { - return 0, fmt.Errorf("pickle sender chain: %w", olm.ErrValueTooShort) - } - written, err := r.RKey.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle sender chain: %w", err) - } - writtenChain, err := r.CKey.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle sender chain: %w", err) - } - written += writtenChain - return written, nil +// PickleLibOlm pickles the receiver chain into the encoder. +func (r receiverChain) PickleLibOlm(encoder *libolmpickle.Encoder) { + r.RKey.PickleLibOlm(encoder) + r.CKey.PickleLibOlm(encoder) } // messageKey wraps the index and the key of a message @@ -204,9 +165,6 @@ type messageKey struct { Key []byte `json:"key"` } -const messageKeyPickleLength = messageKeyLength + // Key - libolmpickle.PickleUInt32Length // Index - // UnpickleLibOlm decodes the unencryted value and populates the message key accordingly. It returns the number of bytes read. func (m *messageKey) UnpickleLibOlm(value []byte) (int, error) { curPos := 0 @@ -225,18 +183,12 @@ func (m *messageKey) UnpickleLibOlm(value []byte) (int, error) { return curPos, nil } -// PickleLibOlm encodes the message key into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (m messageKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < messageKeyPickleLength { - return 0, fmt.Errorf("pickle message key: %w", olm.ErrValueTooShort) - } - written := 0 - if len(m.Key) != messageKeyLength { - written += libolmpickle.PickleBytes(make([]byte, messageKeyLength), target) +// PickleLibOlm pickles the message key into the encoder. +func (m messageKey) PickleLibOlm(encoder *libolmpickle.Encoder) { + if len(m.Key) == messageKeyLength { + encoder.Write(m.Key) } else { - written += libolmpickle.PickleBytes(m.Key, target) + encoder.WriteEmptyBytes(messageKeyLength) } - written += libolmpickle.PickleUInt32(m.Index, target[written:]) - return written, nil + encoder.WriteUInt32(m.Index) } diff --git a/crypto/goolm/ratchet/olm.go b/crypto/goolm/ratchet/olm.go index 8e0944d4..81433654 100644 --- a/crypto/goolm/ratchet/olm.go +++ b/crypto/goolm/ratchet/olm.go @@ -347,66 +347,20 @@ func (r *Ratchet) UnpickleLibOlm(value []byte, includesChainIndex bool) (int, er return curPos, nil } -// PickleLibOlm encodes the ratchet into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (r Ratchet) PickleLibOlm(target []byte) (int, error) { - if len(target) < r.PickleLen() { - return 0, fmt.Errorf("pickle ratchet: %w", olm.ErrValueTooShort) - } - written, err := r.RootKey.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle ratchet: %w", err) - } - if r.SenderChains.IsSet { - written += libolmpickle.PickleUInt32(1, target[written:]) //Length of sender chain, always 1 - writtenSender, err := r.SenderChains.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle ratchet: %w", err) - } - written += writtenSender - } else { - written += libolmpickle.PickleUInt32(0, target[written:]) //Length of sender chain - } - written += libolmpickle.PickleUInt32(uint32(len(r.ReceiverChains)), target[written:]) +// PickleLibOlm pickles the ratchet into the encoder. +func (r Ratchet) PickleLibOlm(encoder *libolmpickle.Encoder) { + r.RootKey.PickleLibOlm(encoder) + r.SenderChains.PickleLibOlm(encoder) + + // Receiver Chains + encoder.WriteUInt32(uint32(len(r.ReceiverChains))) for _, curChain := range r.ReceiverChains { - writtenChain, err := curChain.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle ratchet: %w", err) - } - written += writtenChain + curChain.PickleLibOlm(encoder) } - written += libolmpickle.PickleUInt32(uint32(len(r.SkippedMessageKeys)), target[written:]) + + // Skipped Message Keys + encoder.WriteUInt32(uint32(len(r.SkippedMessageKeys))) for _, curChain := range r.SkippedMessageKeys { - writtenChain, err := curChain.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle ratchet: %w", err) - } - written += writtenChain + curChain.PickleLibOlm(encoder) } - return written, nil -} - -// PickleLen returns the actual number of bytes the pickled ratchet will have. -func (r Ratchet) PickleLen() int { - length := crypto.Curve25519PubKeyLength // Root Key - if r.SenderChains.IsSet { - length += libolmpickle.PickleUInt32Length // 1 - length += senderChainPickleLength // SenderChains - } else { - length += libolmpickle.PickleUInt32Length // 0 - } - length += libolmpickle.PickleUInt32Length // ReceiverChains length - length += len(r.ReceiverChains) * receiverChainPickleLength - length += libolmpickle.PickleUInt32Length // SkippedMessageKeys length - length += len(r.SkippedMessageKeys) * skippedMessageKeyPickleLen - return length -} - -// PickleLen returns the minimum number of bytes the pickled ratchet must have. -func (r Ratchet) PickleLenMin() int { - length := crypto.Curve25519PubKeyLength // Root Key - length += libolmpickle.PickleUInt32Length - length += libolmpickle.PickleUInt32Length - length += libolmpickle.PickleUInt32Length - return length } diff --git a/crypto/goolm/ratchet/skipped_message.go b/crypto/goolm/ratchet/skipped_message.go index b577edbe..7510548f 100644 --- a/crypto/goolm/ratchet/skipped_message.go +++ b/crypto/goolm/ratchet/skipped_message.go @@ -1,10 +1,8 @@ package ratchet import ( - "fmt" - "maunium.net/go/mautrix/crypto/goolm/crypto" - "maunium.net/go/mautrix/crypto/olm" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) // skippedMessageKey stores a skipped message key @@ -13,9 +11,6 @@ type skippedMessageKey struct { MKey messageKey `json:"message_key"` } -const skippedMessageKeyPickleLen = crypto.Curve25519PubKeyLength + // RKey - messageKeyPickleLength // MKey - // UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read. func (r *skippedMessageKey) UnpickleLibOlm(value []byte) (int, error) { curPos := 0 @@ -32,20 +27,8 @@ func (r *skippedMessageKey) UnpickleLibOlm(value []byte) (int, error) { return curPos, nil } -// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (r skippedMessageKey) PickleLibOlm(target []byte) (int, error) { - if len(target) < skippedMessageKeyPickleLen { - return 0, fmt.Errorf("pickle sender chain: %w", olm.ErrValueTooShort) - } - written, err := r.RKey.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle sender chain: %w", err) - } - writtenChain, err := r.MKey.PickleLibOlm(target) - if err != nil { - return 0, fmt.Errorf("pickle sender chain: %w", err) - } - written += writtenChain - return written, nil +// PickleLibOlm pickles the skipped message key into the encoder. +func (r skippedMessageKey) PickleLibOlm(encoder *libolmpickle.Encoder) { + r.RKey.PickleLibOlm(encoder) + r.MKey.PickleLibOlm(encoder) } diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go index 50ca4d97..c263f268 100644 --- a/crypto/goolm/session/megolm_inbound_session.go +++ b/crypto/goolm/session/megolm_inbound_session.go @@ -2,10 +2,8 @@ package session import ( "encoding/base64" - "errors" "fmt" - "maunium.net/go/mautrix/crypto/ed25519" "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/goolmbase64" @@ -20,12 +18,6 @@ import ( const ( megolmInboundSessionPickleVersionJSON byte = 1 megolmInboundSessionPickleVersionLibOlm uint32 = 2 - - megolmInboundSessionPickleLength = libolmpickle.PickleUInt32Length + // Version - megolm.RatchetPickleLength + // InitialRatchet - megolm.RatchetPickleLength + // Ratchet - ed25519.PublicKeySize + // SigningKey - libolmpickle.PickleBoolLength // Verified ) // MegolmInboundSession stores information about the sessions of receive. @@ -253,41 +245,18 @@ func (o *MegolmInboundSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { return nil, olm.ErrNoKeyProvided } - pickeledBytes := make([]byte, megolmInboundSessionPickleLength) - written, err := o.PickleLibOlm(pickeledBytes) - if err != nil { - return nil, err - } - if written != len(pickeledBytes) { - return nil, errors.New("number of written bytes not correct") - } - return cipher.Pickle(key, pickeledBytes) + return cipher.Pickle(key, o.PickleLibOlm()) } -// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (o *MegolmInboundSession) PickleLibOlm(target []byte) (int, error) { - if len(target) < megolmInboundSessionPickleLength { - return 0, fmt.Errorf("pickle MegolmInboundSession: %w", olm.ErrValueTooShort) - } - written := libolmpickle.PickleUInt32(megolmInboundSessionPickleVersionLibOlm, target) - writtenInitRatchet, err := o.InitialRatchet.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err) - } - written += writtenInitRatchet - writtenRatchet, err := o.Ratchet.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err) - } - written += writtenRatchet - writtenPubKey, err := o.SigningKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err) - } - written += writtenPubKey - written += libolmpickle.PickleBool(o.SigningKeyVerified, target[written:]) - return written, nil +// PickleLibOlm pickles the session returning the raw bytes. +func (o *MegolmInboundSession) PickleLibOlm() []byte { + encoder := libolmpickle.NewEncoder() + encoder.WriteUInt32(megolmInboundSessionPickleVersionLibOlm) + o.InitialRatchet.PickleLibOlm(encoder) + o.Ratchet.PickleLibOlm(encoder) + o.SigningKey.PickleLibOlm(encoder) + encoder.WriteBool(o.SigningKeyVerified) + return encoder.Bytes() } // FirstKnownIndex returns the first message index we know how to decrypt. diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index d58650ce..7b498d88 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -3,7 +3,6 @@ package session import ( "crypto/rand" "encoding/base64" - "errors" "fmt" "go.mau.fi/util/exerrors" @@ -21,10 +20,6 @@ import ( const ( megolmOutboundSessionPickleVersion byte = 1 megolmOutboundSessionPickleVersionLibOlm uint32 = 1 - - MegolmOutboundSessionPickleLength = libolmpickle.PickleUInt32Length + // Version - megolm.RatchetPickleLength + // Ratchet - crypto.Ed25519KeyPairPickleLength // SigningKey ) // MegolmOutboundSession stores information about the sessions to send. @@ -134,35 +129,16 @@ func (o *MegolmOutboundSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { return nil, olm.ErrNoKeyProvided } - pickeledBytes := make([]byte, MegolmOutboundSessionPickleLength) - written, err := o.PickleLibOlm(pickeledBytes) - if err != nil { - return nil, err - } - if written != len(pickeledBytes) { - return nil, errors.New("number of written bytes not correct") - } - return cipher.Pickle(key, pickeledBytes) + return cipher.Pickle(key, o.PickleLibOlm()) } -// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (o *MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) { - if len(target) < MegolmOutboundSessionPickleLength { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", olm.ErrValueTooShort) - } - written := libolmpickle.PickleUInt32(megolmOutboundSessionPickleVersionLibOlm, target) - writtenRatchet, err := o.Ratchet.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) - } - written += writtenRatchet - writtenPubKey, err := o.SigningKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) - } - written += writtenPubKey - return written, nil +// PickleLibOlm pickles the session returning the raw bytes. +func (o *MegolmOutboundSession) PickleLibOlm() []byte { + encoder := libolmpickle.NewEncoder() + encoder.WriteUInt32(megolmOutboundSessionPickleVersionLibOlm) + o.Ratchet.PickleLibOlm(encoder) + o.SigningKey.PickleLibOlm(encoder) + return encoder.Bytes() } func (o *MegolmOutboundSession) SessionSharingMessage() ([]byte, error) { diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go index 2c9585b3..533aafb5 100644 --- a/crypto/goolm/session/olm_session.go +++ b/crypto/goolm/session/olm_session.go @@ -3,7 +3,6 @@ package session import ( "bytes" "encoding/base64" - "errors" "fmt" "strings" @@ -200,10 +199,10 @@ func (a *OlmSession) UnpickleAsJSON(pickled, key []byte) error { // ID returns an identifier for this Session. Will be the same for both ends of the conversation. // Generated by hashing the public keys used to create the session. func (s *OlmSession) ID() id.SessionID { - message := make([]byte, 3*crypto.Curve25519KeyLength) + message := make([]byte, 3*crypto.Curve25519PrivateKeyLength) copy(message, s.AliceIdentityKey) - copy(message[crypto.Curve25519KeyLength:], s.AliceBaseKey) - copy(message[2*crypto.Curve25519KeyLength:], s.BobOneTimeKey) + copy(message[crypto.Curve25519PrivateKeyLength:], s.AliceBaseKey) + copy(message[2*crypto.Curve25519PrivateKeyLength:], s.BobOneTimeKey) hash := crypto.SHA256(message) res := id.SessionID(goolmbase64.Encode(hash)) return res @@ -414,68 +413,19 @@ func (s *OlmSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { return nil, olm.ErrNoKeyProvided } - pickeledBytes := make([]byte, s.PickleLen()) - written, err := s.PickleLibOlm(pickeledBytes) - if err != nil { - return nil, err - } - if written != len(pickeledBytes) { - return nil, errors.New("number of written bytes not correct") - } - return cipher.Pickle(key, pickeledBytes) + return cipher.Pickle(key, s.PickleLibOlm()) } -// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0. -// It returns the number of bytes written. -func (o *OlmSession) PickleLibOlm(target []byte) (int, error) { - if len(target) < o.PickleLen() { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", olm.ErrValueTooShort) - } - written := libolmpickle.PickleUInt32(olmSessionPickleVersionLibOlm, target) - written += libolmpickle.PickleBool(o.ReceivedMessage, target[written:]) - writtenRatchet, err := o.AliceIdentityKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) - } - written += writtenRatchet - writtenRatchet, err = o.AliceBaseKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) - } - written += writtenRatchet - writtenRatchet, err = o.BobOneTimeKey.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) - } - written += writtenRatchet - writtenRatchet, err = o.Ratchet.PickleLibOlm(target[written:]) - if err != nil { - return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err) - } - written += writtenRatchet - return written, nil -} - -// PickleLen returns the actual number of bytes the pickled session will have. -func (o *OlmSession) PickleLen() int { - length := libolmpickle.PickleUInt32Length - length += libolmpickle.PickleBoolLength - length += crypto.Curve25519PubKeyLength // AliceIdentityKey - length += crypto.Curve25519PubKeyLength // AliceBaseKey - length += crypto.Curve25519PubKeyLength // BobOneTimeKey - length += o.Ratchet.PickleLen() - return length -} - -// PickleLenMin returns the minimum number of bytes the pickled session must have. -func (o *OlmSession) PickleLenMin() int { - length := libolmpickle.PickleUInt32Length - length += libolmpickle.PickleBoolLength - length += crypto.Curve25519PubKeyLength // AliceIdentityKey - length += crypto.Curve25519PubKeyLength // AliceBaseKey - length += crypto.Curve25519PubKeyLength // BobOneTimeKey - length += o.Ratchet.PickleLenMin() - return length +// PickleLibOlm pickles the session and returns the raw bytes. +func (o *OlmSession) PickleLibOlm() []byte { + encoder := libolmpickle.NewEncoder() + encoder.WriteUInt32(olmSessionPickleVersionLibOlm) + encoder.WriteBool(o.ReceivedMessage) + o.AliceIdentityKey.PickleLibOlm(encoder) + o.AliceBaseKey.PickleLibOlm(encoder) + o.BobOneTimeKey.PickleLibOlm(encoder) + o.Ratchet.PickleLibOlm(encoder) + return encoder.Bytes() } // Describe returns a string describing the current state of the session for debugging. diff --git a/crypto/olm/errors.go b/crypto/olm/errors.go index c80b82e4..957d7928 100644 --- a/crypto/olm/errors.go +++ b/crypto/olm/errors.go @@ -26,7 +26,6 @@ var ( ErrBadInput = errors.New("bad input") ErrBadVersion = errors.New("wrong version") ErrWrongPickleVersion = errors.New("wrong pickle version") - ErrValueTooShort = errors.New("value too short") ErrInputToSmall = errors.New("input too small (truncated?)") ErrOverflow = errors.New("overflow") ) From d2aaa2dc5c972aa26bf39db99aea8ae14da5fd25 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 25 Oct 2024 00:54:38 -0600 Subject: [PATCH 0874/1647] goolm/libolmpickle: add Decoder for easier pickling API Signed-off-by: Sumner Evans --- crypto/goolm/account/account.go | 114 +++++++----------- crypto/goolm/crypto/curve25519.go | 31 ++--- crypto/goolm/crypto/curve25519_test.go | 9 +- crypto/goolm/crypto/ed25519.go | 36 +++--- crypto/goolm/crypto/ed25519_test.go | 9 +- crypto/goolm/crypto/one_time_key.go | 28 ++--- crypto/goolm/libolmpickle/unpickle.go | 68 +++++------ crypto/goolm/libolmpickle/unpickle_test.go | 16 +-- crypto/goolm/megolm/megolm.go | 20 +-- crypto/goolm/pk/decryption.go | 25 ++-- crypto/goolm/ratchet/chain.go | 77 ++++-------- crypto/goolm/ratchet/olm.go | 88 ++++++-------- crypto/goolm/ratchet/skipped_message.go | 19 +-- .../goolm/session/megolm_inbound_session.go | 51 +++----- .../goolm/session/megolm_outbound_session.go | 34 ++---- crypto/goolm/session/olm_session.go | 56 +++------ 16 files changed, 259 insertions(+), 422 deletions(-) diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index 2a99e985..f3554e29 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -323,65 +323,46 @@ func (a *Account) Unpickle(pickled, key []byte) error { if err != nil { return err } - _, err = a.UnpickleLibOlm(decrypted) - return err + return a.UnpickleLibOlm(decrypted) } -// UnpickleLibOlm decodes the unencryted value and populates the Account accordingly. It returns the number of bytes read. -func (a *Account) UnpickleLibOlm(value []byte) (int, error) { - //First 4 bytes are the accountPickleVersion - pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) +// UnpickleLibOlm unpickles the unencryted value and populates the [Account] accordingly. +func (a *Account) UnpickleLibOlm(buf []byte) error { + decoder := libolmpickle.NewDecoder(buf) + pickledVersion, err := decoder.ReadUInt32() if err != nil { - return 0, err + return err + } else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 { + return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrBadVersion, pickledVersion) + } else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair + return err + } else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair + return err } - switch pickledVersion { - case accountPickleVersionLibOLM, 3, 2: - default: - return 0, fmt.Errorf("unpickle account: %w", olm.ErrBadVersion) - } - //read ed25519 key pair - readBytes, err := a.IdKeys.Ed25519.UnpickleLibOlm(value[curPos:]) + + otkCount, err := decoder.ReadUInt32() if err != nil { - return 0, err + return err } - curPos += readBytes - //read curve25519 key pair - readBytes, err = a.IdKeys.Curve25519.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - //Read number of onetimeKeys - numberOTKeys, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - //Read i one time keys - a.OTKeys = make([]crypto.OneTimeKey, numberOTKeys) - for i := uint32(0); i < numberOTKeys; i++ { - readBytes, err := a.OTKeys[i].UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err + + a.OTKeys = make([]crypto.OneTimeKey, otkCount) + for i := uint32(0); i < otkCount; i++ { + if err := a.OTKeys[i].UnpickleLibOlm(decoder); err != nil { + return err } - curPos += readBytes } + if pickledVersion <= 2 { // version 2 did not have fallback keys a.NumFallbackKeys = 0 } else if pickledVersion == 3 { // version 3 used the published flag to indicate how many fallback keys // were present (we'll have to assume that the keys were published) - readBytes, err := a.CurrentFallbackKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err + if err = a.CurrentFallbackKey.UnpickleLibOlm(decoder); err != nil { + return err + } else if err = a.PrevFallbackKey.UnpickleLibOlm(decoder); err != nil { + return err } - curPos += readBytes - readBytes, err = a.PrevFallbackKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes if a.CurrentFallbackKey.Published { if a.PrevFallbackKey.Published { a.NumFallbackKeys = 2 @@ -392,36 +373,33 @@ func (a *Account) UnpickleLibOlm(value []byte) (int, error) { a.NumFallbackKeys = 0 } } else { - //Read number of fallback keys - numFallbackKeys, readBytes, err := libolmpickle.UnpickleUInt8(value[curPos:]) + // Read number of fallback keys + a.NumFallbackKeys, err = decoder.ReadUInt8() if err != nil { - return 0, err + return err } - curPos += readBytes - a.NumFallbackKeys = numFallbackKeys - if a.NumFallbackKeys >= 1 { - readBytes, err := a.CurrentFallbackKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - if a.NumFallbackKeys >= 2 { - readBytes, err := a.PrevFallbackKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err + for i := 0; i < int(a.NumFallbackKeys); i++ { + switch i { + case 0: + if err = a.CurrentFallbackKey.UnpickleLibOlm(decoder); err != nil { + return err + } + case 1: + if err = a.PrevFallbackKey.UnpickleLibOlm(decoder); err != nil { + return err + } + default: + // Just drain any remaining fallback keys + if err = (&crypto.OneTimeKey{}).UnpickleLibOlm(decoder); err != nil { + return err } - curPos += readBytes } } } - //Read next onetime key id - nextOTKeyID, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - a.NextOneTimeKeyID = nextOTKeyID - return curPos, nil + + //Read next onetime key ID + a.NextOneTimeKeyID, err = decoder.ReadUInt32() + return err } // Pickle returns a base64 encoded and with key encrypted pickled account using PickleLibOlm(). diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go index 459031f1..1dbc83fd 100644 --- a/crypto/goolm/crypto/curve25519.go +++ b/crypto/goolm/crypto/curve25519.go @@ -70,19 +70,15 @@ func (c Curve25519KeyPair) PickleLibOlm(encoder *libolmpickle.Encoder) { } // UnpickleLibOlm decodes the unencryted value and populates the key pair accordingly. It returns the number of bytes read. -func (c *Curve25519KeyPair) UnpickleLibOlm(value []byte) (int, error) { - //unpickle PubKey - read, err := c.PublicKey.UnpickleLibOlm(value) - if err != nil { - return 0, err +func (c *Curve25519KeyPair) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { + if err := c.PublicKey.UnpickleLibOlm(decoder); err != nil { + return err + } else if privKey, err := decoder.ReadBytes(Curve25519PrivateKeyLength); err != nil { + return err + } else { + c.PrivateKey = privKey + return nil } - //unpickle PrivateKey - privKey, readPriv, err := libolmpickle.UnpickleBytes(value[read:], Curve25519PrivateKeyLength) - if err != nil { - return read, err - } - c.PrivateKey = privKey - return read + readPriv, nil } // Curve25519PrivateKey represents the private key for curve25519 usage @@ -126,11 +122,8 @@ func (c Curve25519PublicKey) PickleLibOlm(encoder *libolmpickle.Encoder) { } // UnpickleLibOlm decodes the unencryted value and populates the public key accordingly. It returns the number of bytes read. -func (c *Curve25519PublicKey) UnpickleLibOlm(value []byte) (int, error) { - unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, Curve25519PublicKeyLength) - if err != nil { - return 0, err - } - *c = unpickled - return readBytes, nil +func (c *Curve25519PublicKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { + pubkey, err := decoder.ReadBytes(Curve25519PublicKeyLength) + *c = pubkey + return err } diff --git a/crypto/goolm/crypto/curve25519_test.go b/crypto/goolm/crypto/curve25519_test.go index fb9f0098..9039c126 100644 --- a/crypto/goolm/crypto/curve25519_test.go +++ b/crypto/goolm/crypto/curve25519_test.go @@ -84,9 +84,8 @@ func TestCurve25519Pickle(t *testing.T) { assert.Len(t, encoder.Bytes(), curve25519KeyPairPickleLength) unpickledKeyPair := crypto.Curve25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(encoder.Bytes()) + err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) assert.NoError(t, err) - assert.Len(t, encoder.Bytes(), readBytes) assert.Equal(t, keyPair, unpickledKeyPair) } @@ -103,9 +102,8 @@ func TestCurve25519PicklePubKeyOnly(t *testing.T) { assert.Len(t, encoder.Bytes(), curve25519KeyPairPickleLength) unpickledKeyPair := crypto.Curve25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(encoder.Bytes()) + err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) assert.NoError(t, err) - assert.Len(t, encoder.Bytes(), readBytes) assert.Equal(t, keyPair, unpickledKeyPair) } @@ -119,8 +117,7 @@ func TestCurve25519PicklePrivKeyOnly(t *testing.T) { keyPair.PickleLibOlm(encoder) assert.Len(t, encoder.Bytes(), curve25519KeyPairPickleLength) unpickledKeyPair := crypto.Curve25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(encoder.Bytes()) + err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) assert.NoError(t, err) - assert.Len(t, encoder.Bytes(), readBytes) assert.Equal(t, keyPair, unpickledKeyPair) } diff --git a/crypto/goolm/crypto/ed25519.go b/crypto/goolm/crypto/ed25519.go index bc535e6a..34ad397a 100644 --- a/crypto/goolm/crypto/ed25519.go +++ b/crypto/goolm/crypto/ed25519.go @@ -69,20 +69,16 @@ func (c Ed25519KeyPair) PickleLibOlm(encoder *libolmpickle.Encoder) { } } -// UnpickleLibOlm decodes the unencryted value and populates the key pair accordingly. It returns the number of bytes read. -func (c *Ed25519KeyPair) UnpickleLibOlm(value []byte) (int, error) { - //unpickle PubKey - read, err := c.PublicKey.UnpickleLibOlm(value) - if err != nil { - return 0, err +// UnpickleLibOlm unpickles the unencryted value and populates the key pair accordingly. +func (c *Ed25519KeyPair) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { + if err := c.PublicKey.UnpickleLibOlm(decoder); err != nil { + return err + } else if privKey, err := decoder.ReadBytes(ed25519.PrivateKeySize); err != nil { + return err + } else { + c.PrivateKey = privKey + return nil } - //unpickle PrivateKey - privKey, readPriv, err := libolmpickle.UnpickleBytes(value[read:], ed25519.PrivateKeySize) - if err != nil { - return read, err - } - c.PrivateKey = privKey - return read + readPriv, nil } // Curve25519PrivateKey represents the private key for ed25519 usage. This is just a wrapper. @@ -135,12 +131,10 @@ func (c Ed25519PublicKey) PickleLibOlm(encoder *libolmpickle.Encoder) { } } -// UnpickleLibOlm decodes the unencryted value and populates the public key accordingly. It returns the number of bytes read. -func (c *Ed25519PublicKey) UnpickleLibOlm(value []byte) (int, error) { - unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, ed25519.PublicKeySize) - if err != nil { - return 0, err - } - *c = unpickled - return readBytes, nil +// UnpickleLibOlm unpickles the unencryted value and populates the public key +// accordingly. +func (c *Ed25519PublicKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { + key, err := decoder.ReadBytes(ed25519.PublicKeySize) + *c = key + return err } diff --git a/crypto/goolm/crypto/ed25519_test.go b/crypto/goolm/crypto/ed25519_test.go index 3ac8863a..96d67385 100644 --- a/crypto/goolm/crypto/ed25519_test.go +++ b/crypto/goolm/crypto/ed25519_test.go @@ -48,9 +48,8 @@ func TestEd25519Pickle(t *testing.T) { assert.Len(t, encoder.Bytes(), ed25519KeyPairPickleLength) unpickledKeyPair := crypto.Ed25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(encoder.Bytes()) + err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) assert.NoError(t, err) - assert.Len(t, encoder.Bytes(), readBytes, "read bytes not correct") assert.Equal(t, keyPair, unpickledKeyPair) } @@ -65,9 +64,8 @@ func TestEd25519PicklePubKeyOnly(t *testing.T) { assert.Len(t, encoder.Bytes(), ed25519KeyPairPickleLength) unpickledKeyPair := crypto.Ed25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(encoder.Bytes()) + err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) assert.NoError(t, err) - assert.Len(t, encoder.Bytes(), readBytes, "read bytes not correct") assert.Equal(t, keyPair, unpickledKeyPair) } @@ -82,8 +80,7 @@ func TestEd25519PicklePrivKeyOnly(t *testing.T) { assert.Len(t, encoder.Bytes(), ed25519KeyPairPickleLength) unpickledKeyPair := crypto.Ed25519KeyPair{} - readBytes, err := unpickledKeyPair.UnpickleLibOlm(encoder.Bytes()) + err = unpickledKeyPair.UnpickleLibOlm(libolmpickle.NewDecoder(encoder.Bytes())) assert.NoError(t, err) - assert.Len(t, encoder.Bytes(), readBytes, "read bytes not correct") assert.Equal(t, keyPair, unpickledKeyPair) } diff --git a/crypto/goolm/crypto/one_time_key.go b/crypto/goolm/crypto/one_time_key.go index b7e594ef..0947f43b 100644 --- a/crypto/goolm/crypto/one_time_key.go +++ b/crypto/goolm/crypto/one_time_key.go @@ -39,27 +39,15 @@ func (c OneTimeKey) PickleLibOlm(encoder *libolmpickle.Encoder) { c.Key.PickleLibOlm(encoder) } -// UnpickleLibOlm decodes the unencryted value and populates the OneTimeKey accordingly. It returns the number of bytes read. -func (c *OneTimeKey) UnpickleLibOlm(value []byte) (int, error) { - totalReadBytes := 0 - id, readBytes, err := libolmpickle.UnpickleUInt32(value) - if err != nil { - return 0, err +// UnpickleLibOlm unpickles the unencryted value and populates the [OneTimeKey] +// accordingly. +func (c *OneTimeKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) (err error) { + if c.ID, err = decoder.ReadUInt32(); err != nil { + return + } else if c.Published, err = decoder.ReadBool(); err != nil { + return } - totalReadBytes += readBytes - c.ID = id - published, readBytes, err := libolmpickle.UnpickleBool(value[totalReadBytes:]) - if err != nil { - return 0, err - } - totalReadBytes += readBytes - c.Published = published - readBytes, err = c.Key.UnpickleLibOlm(value[totalReadBytes:]) - if err != nil { - return 0, err - } - totalReadBytes += readBytes - return totalReadBytes, nil + return c.Key.UnpickleLibOlm(decoder) } // KeyIDEncoded returns the base64 encoded id. diff --git a/crypto/goolm/libolmpickle/unpickle.go b/crypto/goolm/libolmpickle/unpickle.go index 66803d34..d13be315 100644 --- a/crypto/goolm/libolmpickle/unpickle.go +++ b/crypto/goolm/libolmpickle/unpickle.go @@ -1,56 +1,52 @@ package libolmpickle import ( + "bytes" + "encoding/binary" "fmt" - - "maunium.net/go/mautrix/crypto/olm" ) -func isZeroByteSlice(bytes []byte) bool { - b := byte(0) - for _, s := range bytes { - b |= s +func isZeroByteSlice(data []byte) bool { + for _, b := range data { + if b != 0 { + return false + } } - return b == 0 + return true } type Decoder struct { + buf bytes.Buffer } -func UnpickleUInt8(value []byte) (uint8, int, error) { - if len(value) < 1 { - return 0, 0, fmt.Errorf("unpickle uint8: %w", olm.ErrInputToSmall) - } - return value[0], 1, nil +func NewDecoder(buf []byte) *Decoder { + return &Decoder{buf: *bytes.NewBuffer(buf)} } -func UnpickleBool(value []byte) (bool, int, error) { - if len(value) < 1 { - return false, 0, fmt.Errorf("unpickle bool: %w", olm.ErrInputToSmall) - } - return value[0] != uint8(0x00), 1, nil +func (d *Decoder) ReadUInt8() (uint8, error) { + return d.buf.ReadByte() } -func UnpickleBytes(value []byte, length int) ([]byte, int, error) { - if len(value) < length { - return nil, 0, fmt.Errorf("unpickle bytes: %w", olm.ErrInputToSmall) - } - resp := value[:length] - if isZeroByteSlice(resp) { - return nil, length, nil - } - return resp, length, nil +func (d *Decoder) ReadBool() (bool, error) { + val, err := d.buf.ReadByte() + return val != 0x00, err } -func UnpickleUInt32(value []byte) (uint32, int, error) { - if len(value) < 4 { - return 0, 0, fmt.Errorf("unpickle uint32: %w", olm.ErrInputToSmall) +func (d *Decoder) ReadBytes(length int) (data []byte, err error) { + data = d.buf.Next(length) + if len(data) != length { + return nil, fmt.Errorf("only %d in buffer, expected %d", len(data), length) + } else if isZeroByteSlice(data) { + return nil, nil + } + return +} + +func (d *Decoder) ReadUInt32() (uint32, error) { + data := d.buf.Next(4) + if len(data) != 4 { + return 0, fmt.Errorf("only %d bytes is buffer, expected 4 for uint32", len(data)) + } else { + return binary.BigEndian.Uint32(data), nil } - var res uint32 - count := 0 - for i := 3; i >= 0; i-- { - res |= uint32(value[count]) << (8 * i) - count++ - } - return res, 4, nil } diff --git a/crypto/goolm/libolmpickle/unpickle_test.go b/crypto/goolm/libolmpickle/unpickle_test.go index 71f75b18..30355a76 100644 --- a/crypto/goolm/libolmpickle/unpickle_test.go +++ b/crypto/goolm/libolmpickle/unpickle_test.go @@ -20,9 +20,9 @@ func TestUnpickleUInt32(t *testing.T) { {0xf0, 0x00, 0x00, 0x00}, } for curIndex := range values { - response, readLength, err := libolmpickle.UnpickleUInt32(values[curIndex]) + decoder := libolmpickle.NewDecoder(values[curIndex]) + response, err := decoder.ReadUInt32() assert.NoError(t, err) - assert.Equal(t, 4, readLength) assert.Equal(t, expected[curIndex], response) } } @@ -39,9 +39,9 @@ func TestUnpickleBool(t *testing.T) { {0x02}, } for curIndex := range values { - response, readLength, err := libolmpickle.UnpickleBool(values[curIndex]) + decoder := libolmpickle.NewDecoder(values[curIndex]) + response, err := decoder.ReadBool() assert.NoError(t, err) - assert.Equal(t, 1, readLength) assert.Equal(t, expected[curIndex], response) } } @@ -56,9 +56,9 @@ func TestUnpickleUInt8(t *testing.T) { {0x1a}, } for curIndex := range values { - response, readLength, err := libolmpickle.UnpickleUInt8(values[curIndex]) + decoder := libolmpickle.NewDecoder(values[curIndex]) + response, err := decoder.ReadUInt8() assert.NoError(t, err) - assert.Equal(t, 1, readLength) assert.Equal(t, expected[curIndex], response) } } @@ -75,9 +75,9 @@ func TestUnpickleBytes(t *testing.T) { {0xf0, 0x00, 0x00, 0x00}, } for curIndex := range values { - response, readLength, err := libolmpickle.UnpickleBytes(values[curIndex], 4) + decoder := libolmpickle.NewDecoder(values[curIndex]) + response, err := decoder.ReadBytes(4) assert.NoError(t, err) - assert.Equal(t, 4, readLength) assert.Equal(t, expected[curIndex], response) } } diff --git a/crypto/goolm/megolm/megolm.go b/crypto/goolm/megolm/megolm.go index f44e8cb7..6b5caf7e 100644 --- a/crypto/goolm/megolm/megolm.go +++ b/crypto/goolm/megolm/megolm.go @@ -197,23 +197,15 @@ func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error { } // UnpickleLibOlm decodes the unencryted value and populates the Ratchet accordingly. It returns the number of bytes read. -func (r *Ratchet) UnpickleLibOlm(unpickled []byte) (int, error) { - //read ratchet data - curPos := 0 - ratchetData, readBytes, err := libolmpickle.UnpickleBytes(unpickled, RatchetParts*RatchetPartLength) +func (r *Ratchet) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { + ratchetData, err := decoder.ReadBytes(RatchetParts * RatchetPartLength) if err != nil { - return 0, err + return err } copy(r.Data[:], ratchetData) - curPos += readBytes - //Read counter - counter, readBytes, err := libolmpickle.UnpickleUInt32(unpickled[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - r.Counter = counter - return curPos, nil + + r.Counter, err = decoder.ReadUInt32() + return err } // PickleLibOlm pickles the ratchet into the encoder. diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index cc363841..ba94dc37 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -97,28 +97,21 @@ func (a *Decryption) Unpickle(pickled, key []byte) error { if err != nil { return err } - _, err = a.UnpickleLibOlm(decrypted) - return err + return a.UnpickleLibOlm(decrypted) } // UnpickleLibOlm decodes the unencryted value and populates the Decryption accordingly. It returns the number of bytes read. -func (a *Decryption) UnpickleLibOlm(value []byte) (int, error) { - //First 4 bytes are the accountPickleVersion - pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) +func (a *Decryption) UnpickleLibOlm(unpickled []byte) error { + decoder := libolmpickle.NewDecoder(unpickled) + pickledVersion, err := decoder.ReadUInt32() if err != nil { - return 0, err + return err } - switch pickledVersion { - case decryptionPickleVersionLibOlm: - default: - return 0, fmt.Errorf("unpickle olmSession: %w", olm.ErrBadVersion) + if pickledVersion == decryptionPickleVersionLibOlm { + return a.KeyPair.UnpickleLibOlm(decoder) + } else { + return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrBadVersion, pickledVersion, decryptionPickleVersionLibOlm) } - readBytes, err := a.KeyPair.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - return curPos, nil } // Pickle returns a base64 encoded and with key encrypted pickled Decryption using PickleLibOlm(). diff --git a/crypto/goolm/ratchet/chain.go b/crypto/goolm/ratchet/chain.go index 124d6906..dc021b8a 100644 --- a/crypto/goolm/ratchet/chain.go +++ b/crypto/goolm/ratchet/chain.go @@ -22,20 +22,14 @@ func (c *chainKey) advance() { c.Index++ } -// UnpickleLibOlm decodes the unencryted value and populates the chain key accordingly. It returns the number of bytes read. -func (r *chainKey) UnpickleLibOlm(value []byte) (int, error) { - curPos := 0 - readBytes, err := r.Key.UnpickleLibOlm(value) +// UnpickleLibOlm unpickles the unencryted value and populates the chain key accordingly. +func (r *chainKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { + err := r.Key.UnpickleLibOlm(decoder) if err != nil { - return 0, err + return err } - curPos += readBytes - r.Index, readBytes, err = libolmpickle.UnpickleUInt32(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - return curPos, nil + r.Index, err = decoder.ReadUInt32() + return err } // PickleLibOlm pickles the chain key into the encoder. @@ -78,20 +72,13 @@ func (s senderChain) chainKey() chainKey { return s.CKey } -// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read. -func (r *senderChain) UnpickleLibOlm(value []byte) (int, error) { - curPos := 0 - readBytes, err := r.RKey.UnpickleLibOlm(value) - if err != nil { - return 0, err +// UnpickleLibOlm unpickles the unencryted value and populates the sender chain +// accordingly. +func (r *senderChain) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { + if err := r.RKey.UnpickleLibOlm(decoder); err != nil { + return err } - curPos += readBytes - readBytes, err = r.CKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - return curPos, nil + return r.CKey.UnpickleLibOlm(decoder) } // PickleLibOlm pickles the sender chain into the encoder. @@ -137,20 +124,12 @@ func (s receiverChain) chainKey() chainKey { return s.CKey } -// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read. -func (r *receiverChain) UnpickleLibOlm(value []byte) (int, error) { - curPos := 0 - readBytes, err := r.RKey.UnpickleLibOlm(value) - if err != nil { - return 0, err +// UnpickleLibOlm unpickles the unencryted value and populates the chain accordingly. +func (r *receiverChain) UnpickleLibOlm(decoder *libolmpickle.Decoder) error { + if err := r.RKey.UnpickleLibOlm(decoder); err != nil { + return err } - curPos += readBytes - readBytes, err = r.CKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - return curPos, nil + return r.CKey.UnpickleLibOlm(decoder) } // PickleLibOlm pickles the receiver chain into the encoder. @@ -165,22 +144,14 @@ type messageKey struct { Key []byte `json:"key"` } -// UnpickleLibOlm decodes the unencryted value and populates the message key accordingly. It returns the number of bytes read. -func (m *messageKey) UnpickleLibOlm(value []byte) (int, error) { - curPos := 0 - ratchetKey, readBytes, err := libolmpickle.UnpickleBytes(value, messageKeyLength) - if err != nil { - return 0, err +// UnpickleLibOlm unpickles the unencryted value and populates the message key +// accordingly. +func (m *messageKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) (err error) { + if m.Key, err = decoder.ReadBytes(messageKeyLength); err != nil { + return } - m.Key = ratchetKey - curPos += readBytes - keyID, readBytes, err := libolmpickle.UnpickleUInt32(value[:curPos]) - if err != nil { - return 0, err - } - curPos += readBytes - m.Index = keyID - return curPos, nil + m.Index, err = decoder.ReadUInt32() + return } // PickleLibOlm pickles the message key into the encoder. diff --git a/crypto/goolm/ratchet/olm.go b/crypto/goolm/ratchet/olm.go index 81433654..b40328ab 100644 --- a/crypto/goolm/ratchet/olm.go +++ b/crypto/goolm/ratchet/olm.go @@ -278,73 +278,59 @@ func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error { return utilities.UnpickleAsJSON(r, pickled, key, olmPickleVersion) } -// UnpickleLibOlm decodes the unencryted value and populates the Ratchet accordingly. It returns the number of bytes read. -func (r *Ratchet) UnpickleLibOlm(value []byte, includesChainIndex bool) (int, error) { - //read ratchet data - curPos := 0 - readBytes, err := r.RootKey.UnpickleLibOlm(value) - if err != nil { - return 0, err +// UnpickleLibOlm unpickles the unencryted value and populates the [Ratchet] +// accordingly. +func (r *Ratchet) UnpickleLibOlm(decoder *libolmpickle.Decoder, includesChainIndex bool) error { + if err := r.RootKey.UnpickleLibOlm(decoder); err != nil { + return err } - curPos += readBytes - countSenderChains, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of sender chain + senderChainsCount, err := decoder.ReadUInt32() if err != nil { - return 0, err + return err } - curPos += readBytes - for i := uint32(0); i < countSenderChains; i++ { + + for i := uint32(0); i < senderChainsCount; i++ { if i == 0 { - //only first is stored - readBytes, err := r.SenderChains.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes + // only the first sender key is stored + err = r.SenderChains.UnpickleLibOlm(decoder) r.SenderChains.IsSet = true } else { - dummy := senderChain{} - readBytes, err := dummy.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes + // just eat the values + err = (&senderChain{}).UnpickleLibOlm(decoder) } - } - countReceivChains, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of recevier chain - if err != nil { - return 0, err - } - curPos += readBytes - r.ReceiverChains = make([]receiverChain, countReceivChains) - for i := uint32(0); i < countReceivChains; i++ { - readBytes, err := r.ReceiverChains[i].UnpickleLibOlm(value[curPos:]) if err != nil { - return 0, err + return err } - curPos += readBytes } - countSkippedMessageKeys, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of skippedMessageKeys + + receiverChainCount, err := decoder.ReadUInt32() if err != nil { - return 0, err + return err } - curPos += readBytes - r.SkippedMessageKeys = make([]skippedMessageKey, countSkippedMessageKeys) - for i := uint32(0); i < countSkippedMessageKeys; i++ { - readBytes, err := r.SkippedMessageKeys[i].UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err + r.ReceiverChains = make([]receiverChain, receiverChainCount) + for i := uint32(0); i < receiverChainCount; i++ { + if err := r.ReceiverChains[i].UnpickleLibOlm(decoder); err != nil { + return err } - curPos += readBytes } - // pickle v 0x80000001 includes a chain index; pickle v1 does not. + + skippedMessageKeysCount, err := decoder.ReadUInt32() + if err != nil { + return err + } + r.SkippedMessageKeys = make([]skippedMessageKey, skippedMessageKeysCount) + for i := uint32(0); i < skippedMessageKeysCount; i++ { + if err := r.SkippedMessageKeys[i].UnpickleLibOlm(decoder); err != nil { + return err + } + } + + // pickle version 0x80000001 includes a chain index; pickle version 1 does not. if includesChainIndex { - _, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes + _, err = decoder.ReadUInt32() + return err } - return curPos, nil + return nil } // PickleLibOlm pickles the ratchet into the encoder. diff --git a/crypto/goolm/ratchet/skipped_message.go b/crypto/goolm/ratchet/skipped_message.go index 7510548f..2ffaee7b 100644 --- a/crypto/goolm/ratchet/skipped_message.go +++ b/crypto/goolm/ratchet/skipped_message.go @@ -11,20 +11,13 @@ type skippedMessageKey struct { MKey messageKey `json:"message_key"` } -// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read. -func (r *skippedMessageKey) UnpickleLibOlm(value []byte) (int, error) { - curPos := 0 - readBytes, err := r.RKey.UnpickleLibOlm(value) - if err != nil { - return 0, err +// UnpickleLibOlm unpickles the unencryted value and populates the skipped +// message keys accordingly. +func (r *skippedMessageKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) (err error) { + if err = r.RKey.UnpickleLibOlm(decoder); err != nil { + return } - curPos += readBytes - readBytes, err = r.MKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - return curPos, nil + return r.MKey.UnpickleLibOlm(decoder) } // PickleLibOlm pickles the skipped message key into the encoder. diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go index c263f268..4c107e92 100644 --- a/crypto/goolm/session/megolm_inbound_session.go +++ b/crypto/goolm/session/megolm_inbound_session.go @@ -196,48 +196,37 @@ func (o *MegolmInboundSession) Unpickle(pickled, key []byte) error { if err != nil { return err } - _, err = o.UnpickleLibOlm(decrypted) - return err + return o.UnpickleLibOlm(decrypted) } -// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read. -func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) (int, error) { - //First 4 bytes are the accountPickleVersion - pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) +// UnpickleLibOlm unpickles the unencryted value and populates the [Session] +// accordingly. +func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) error { + decoder := libolmpickle.NewDecoder(value) + pickledVersion, err := decoder.ReadUInt32() if err != nil { - return 0, err + return err } - switch pickledVersion { - case megolmInboundSessionPickleVersionLibOlm, 1: - default: - return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", olm.ErrBadVersion) + if pickledVersion != megolmInboundSessionPickleVersionLibOlm && pickledVersion != 1 { + return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) } - readBytes, err := o.InitialRatchet.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err + + if err = o.InitialRatchet.UnpickleLibOlm(decoder); err != nil { + return err + } else if err = o.Ratchet.UnpickleLibOlm(decoder); err != nil { + return err + } else if err = o.SigningKey.UnpickleLibOlm(decoder); err != nil { + return err } - curPos += readBytes - readBytes, err = o.Ratchet.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - readBytes, err = o.SigningKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes + if pickledVersion == 1 { // pickle v1 had no signing_key_verified field (all keyshares were verified at import time) o.SigningKeyVerified = true } else { - o.SigningKeyVerified, readBytes, err = libolmpickle.UnpickleBool(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes + o.SigningKeyVerified, err = decoder.ReadBool() + return err } - return curPos, nil + return nil } // Pickle returns a base64 encoded and with key encrypted pickled MegolmInboundSession using PickleLibOlm(). diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index 7b498d88..b42dab53 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -95,33 +95,21 @@ func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error { if err != nil { return err } - _, err = o.UnpickleLibOlm(decrypted) - return err + return o.UnpickleLibOlm(decrypted) } -// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read. -func (o *MegolmOutboundSession) UnpickleLibOlm(value []byte) (int, error) { - //First 4 bytes are the accountPickleVersion - pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) - if err != nil { - return 0, err +// UnpickleLibOlm unpickles the unencryted value and populates the +// [MegolmOutboundSession] accordingly. +func (o *MegolmOutboundSession) UnpickleLibOlm(buf []byte) error { + decoder := libolmpickle.NewDecoder(buf) + pickledVersion, err := decoder.ReadUInt32() + if pickledVersion != megolmOutboundSessionPickleVersionLibOlm { + return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) } - switch pickledVersion { - case megolmOutboundSessionPickleVersionLibOlm: - default: - return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", olm.ErrBadVersion) + if err = o.Ratchet.UnpickleLibOlm(decoder); err != nil { + return err } - readBytes, err := o.Ratchet.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - readBytes, err = o.SigningKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - return curPos, nil + return o.SigningKey.UnpickleLibOlm(decoder) } // Pickle returns a base64 encoded and with key encrypted pickled MegolmOutboundSession using PickleLibOlm(). diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go index 533aafb5..574d029f 100644 --- a/crypto/goolm/session/olm_session.go +++ b/crypto/goolm/session/olm_session.go @@ -358,53 +358,35 @@ func (o *OlmSession) Unpickle(pickled, key []byte) error { if err != nil { return err } - _, err = o.UnpickleLibOlm(decrypted) - return err + return o.UnpickleLibOlm(decrypted) } -// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read. -func (o *OlmSession) UnpickleLibOlm(value []byte) (int, error) { - //First 4 bytes are the accountPickleVersion - pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value) - if err != nil { - return 0, err - } - includesChainIndex := true +// UnpickleLibOlm unpickles the unencryted value and populates the [OlmSession] +// accordingly. +func (o *OlmSession) UnpickleLibOlm(buf []byte) error { + decoder := libolmpickle.NewDecoder(buf) + pickledVersion, err := decoder.ReadUInt32() + + var includesChainIndex bool switch pickledVersion { case olmSessionPickleVersionLibOlm: includesChainIndex = false case uint32(0x80000001): includesChainIndex = true default: - return 0, fmt.Errorf("unpickle olmSession: %w", olm.ErrBadVersion) + return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) } - var readBytes int - o.ReceivedMessage, readBytes, err = libolmpickle.UnpickleBool(value[curPos:]) - if err != nil { - return 0, err + + if o.ReceivedMessage, err = decoder.ReadBool(); err != nil { + return err + } else if err = o.AliceIdentityKey.UnpickleLibOlm(decoder); err != nil { + return err + } else if err = o.AliceBaseKey.UnpickleLibOlm(decoder); err != nil { + return err + } else if err = o.BobOneTimeKey.UnpickleLibOlm(decoder); err != nil { + return err } - curPos += readBytes - readBytes, err = o.AliceIdentityKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - readBytes, err = o.AliceBaseKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - readBytes, err = o.BobOneTimeKey.UnpickleLibOlm(value[curPos:]) - if err != nil { - return 0, err - } - curPos += readBytes - readBytes, err = o.Ratchet.UnpickleLibOlm(value[curPos:], includesChainIndex) - if err != nil { - return 0, err - } - curPos += readBytes - return curPos, nil + return o.Ratchet.UnpickleLibOlm(decoder, includesChainIndex) } // Pickle returns a base64 encoded and with key encrypted pickled olmSession From c09eae39d04dd5730bf21b4216ff5a869910fbc8 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 25 Oct 2024 01:19:50 -0600 Subject: [PATCH 0875/1647] crypto: always read from crypto/rand Signed-off-by: Sumner Evans --- crypto/goolm/account/account_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go index e7739bb6..e1c9b452 100644 --- a/crypto/goolm/account/account_test.go +++ b/crypto/goolm/account/account_test.go @@ -133,7 +133,7 @@ func TestLoopback(t *testing.T) { accountB, err := account.NewAccount() assert.NoError(t, err) - err = accountB.GenOneTimeKeys( 42) + err = accountB.GenOneTimeKeys(42) assert.NoError(t, err) aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), accountB.OTKeys[0].Key.B64Encoded()) @@ -191,7 +191,7 @@ func TestMoreMessages(t *testing.T) { accountB, err := account.NewAccount() assert.NoError(t, err) - err = accountB.GenOneTimeKeys( 42) + err = accountB.GenOneTimeKeys(42) assert.NoError(t, err) aliceSession, err := accountA.NewOutboundSession(accountB.IdKeys.Curve25519.B64Encoded(), accountB.OTKeys[0].Key.B64Encoded()) From 9f74b58d84b1661ef45ef22a53548c9dad87d279 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 25 Oct 2024 01:26:33 -0600 Subject: [PATCH 0876/1647] crypto/goolm/crypto: use stdlib for HKDF and HMAC operations Signed-off-by: Sumner Evans --- crypto/goolm/cipher/aes_sha256.go | 13 ++-- crypto/goolm/crypto/hmac.go | 29 -------- crypto/goolm/crypto/hmac_test.go | 101 ---------------------------- crypto/goolm/megolm/megolm.go | 7 +- crypto/goolm/ratchet/chain.go | 7 +- crypto/goolm/ratchet/olm.go | 20 ++++-- crypto/goolm/session/olm_session.go | 5 +- 7 files changed, 36 insertions(+), 146 deletions(-) delete mode 100644 crypto/goolm/crypto/hmac.go delete mode 100644 crypto/goolm/crypto/hmac_test.go diff --git a/crypto/goolm/cipher/aes_sha256.go b/crypto/goolm/cipher/aes_sha256.go index 065cb501..42f5d069 100644 --- a/crypto/goolm/cipher/aes_sha256.go +++ b/crypto/goolm/cipher/aes_sha256.go @@ -2,10 +2,13 @@ package cipher import ( "bytes" + "crypto/hmac" + "crypto/sha256" "io" + "golang.org/x/crypto/hkdf" + "maunium.net/go/mautrix/crypto/aescbc" - "maunium.net/go/mautrix/crypto/goolm/crypto" ) // derivedAESKeys stores the derived keys for the AESSHA256 cipher @@ -17,9 +20,9 @@ type derivedAESKeys struct { // deriveAESKeys derives three keys for the AESSHA256 cipher func deriveAESKeys(kdfInfo []byte, key []byte) (derivedAESKeys, error) { - hkdf := crypto.HKDFSHA256(key, nil, kdfInfo) + kdf := hkdf.New(sha256.New, key, nil, kdfInfo) keymatter := make([]byte, 80) - _, err := io.ReadFull(hkdf, keymatter) + _, err := io.ReadFull(kdf, keymatter) return derivedAESKeys{ key: keymatter[:32], hmacKey: keymatter[32:64], @@ -63,7 +66,9 @@ func (c AESSHA256) MAC(key, message []byte) ([]byte, error) { if err != nil { return nil, err } - return crypto.HMACSHA256(keys.hmacKey, message), nil + hash := hmac.New(sha256.New, keys.hmacKey) + _, err = hash.Write(message) + return hash.Sum(nil), err } // Verify checks the MAC of the message using the key against the givenMAC. The key is used to derive the actual mac key (32 bytes). diff --git a/crypto/goolm/crypto/hmac.go b/crypto/goolm/crypto/hmac.go deleted file mode 100644 index 8542f7cb..00000000 --- a/crypto/goolm/crypto/hmac.go +++ /dev/null @@ -1,29 +0,0 @@ -package crypto - -import ( - "crypto/hmac" - "crypto/sha256" - "io" - - "golang.org/x/crypto/hkdf" -) - -// HMACSHA256 returns the hash message authentication code with SHA-256 of the input with the key. -func HMACSHA256(key, input []byte) []byte { - hash := hmac.New(sha256.New, key) - hash.Write(input) - return hash.Sum(nil) -} - -// SHA256 return the SHA-256 of the value. -func SHA256(value []byte) []byte { - hash := sha256.New() - hash.Write(value) - return hash.Sum(nil) -} - -// HKDFSHA256 is the key deivation function based on HMAC and returns a reader based on input. salt and info can both be nil. -// The reader can be used to read an arbitary length of bytes which are based on all parameters. -func HKDFSHA256(input, salt, info []byte) io.Reader { - return hkdf.New(sha256.New, input, salt, info) -} diff --git a/crypto/goolm/crypto/hmac_test.go b/crypto/goolm/crypto/hmac_test.go deleted file mode 100644 index 127be131..00000000 --- a/crypto/goolm/crypto/hmac_test.go +++ /dev/null @@ -1,101 +0,0 @@ -package crypto_test - -import ( - "encoding/base64" - "io" - "testing" - - "github.com/stretchr/testify/assert" - - "maunium.net/go/mautrix/crypto/goolm/crypto" -) - -func TestHMACSHA256(t *testing.T) { - key := []byte("test key") - message := []byte("test message") - hash := crypto.HMACSHA256(key, message) - assert.Equal(t, hash, crypto.HMACSHA256(key, message)) - - str := "A4M0ovdiWHaZ5msdDFbrvtChFwZIoIaRSVGmv8bmPtc" - result, err := base64.RawStdEncoding.DecodeString(str) - assert.NoError(t, err) - assert.Equal(t, result, hash) -} - -func TestHKDFSHA256(t *testing.T) { - message := []byte("test content") - - hkdf := crypto.HKDFSHA256(message, nil, nil) - result := make([]byte, 32) - _, err := io.ReadFull(hkdf, result) - assert.NoError(t, err) - - hkdf2 := crypto.HKDFSHA256(message, nil, nil) - result2 := make([]byte, 32) - _, err = io.ReadFull(hkdf2, result2) - assert.NoError(t, err) - - assert.Equal(t, result, result2) -} - -func TestSHA256Case1(t *testing.T) { - input := make([]byte, 0) - expected := []byte{ - 0xE3, 0xB0, 0xC4, 0x42, 0x98, 0xFC, 0x1C, 0x14, - 0x9A, 0xFB, 0xF4, 0xC8, 0x99, 0x6F, 0xB9, 0x24, - 0x27, 0xAE, 0x41, 0xE4, 0x64, 0x9B, 0x93, 0x4C, - 0xA4, 0x95, 0x99, 0x1B, 0x78, 0x52, 0xB8, 0x55, - } - result := crypto.SHA256(input) - assert.Equal(t, expected, result) -} - -func TestHMACCase1(t *testing.T) { - input := make([]byte, 0) - expected := []byte{ - 0xb6, 0x13, 0x67, 0x9a, 0x08, 0x14, 0xd9, 0xec, - 0x77, 0x2f, 0x95, 0xd7, 0x78, 0xc3, 0x5f, 0xc5, - 0xff, 0x16, 0x97, 0xc4, 0x93, 0x71, 0x56, 0x53, - 0xc6, 0xc7, 0x12, 0x14, 0x42, 0x92, 0xc5, 0xad, - } - result := crypto.HMACSHA256(input, input) - assert.Equal(t, expected, result) -} - -func TestHDKFCase1(t *testing.T) { - input := []byte{ - 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, - 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, - 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, - } - salt := []byte{ - 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, - 0x08, 0x09, 0x0a, 0x0b, 0x0c, - } - info := []byte{ - 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, - 0xf8, 0xf9, - } - expectedHMAC := []byte{ - 0x07, 0x77, 0x09, 0x36, 0x2c, 0x2e, 0x32, 0xdf, - 0x0d, 0xdc, 0x3f, 0x0d, 0xc4, 0x7b, 0xba, 0x63, - 0x90, 0xb6, 0xc7, 0x3b, 0xb5, 0x0f, 0x9c, 0x31, - 0x22, 0xec, 0x84, 0x4a, 0xd7, 0xc2, 0xb3, 0xe5, - } - result := crypto.HMACSHA256(salt, input) - assert.Equal(t, expectedHMAC, result) - - expectedHDKF := []byte{ - 0x3c, 0xb2, 0x5f, 0x25, 0xfa, 0xac, 0xd5, 0x7a, - 0x90, 0x43, 0x4f, 0x64, 0xd0, 0x36, 0x2f, 0x2a, - 0x2d, 0x2d, 0x0a, 0x90, 0xcf, 0x1a, 0x5a, 0x4c, - 0x5d, 0xb0, 0x2d, 0x56, 0xec, 0xc4, 0xc5, 0xbf, - 0x34, 0x00, 0x72, 0x08, 0xd5, 0xb8, 0x87, 0x18, - 0x58, 0x65, - } - resultReader := crypto.HKDFSHA256(input, salt, info) - result = make([]byte, len(expectedHDKF)) - _, err := io.ReadFull(resultReader, result) - assert.NoError(t, err) - assert.Equal(t, expectedHDKF, result) -} diff --git a/crypto/goolm/megolm/megolm.go b/crypto/goolm/megolm/megolm.go index 6b5caf7e..eab82cc0 100644 --- a/crypto/goolm/megolm/megolm.go +++ b/crypto/goolm/megolm/megolm.go @@ -2,7 +2,9 @@ package megolm import ( + "crypto/hmac" "crypto/rand" + "crypto/sha256" "fmt" "maunium.net/go/mautrix/crypto/goolm/cipher" @@ -63,8 +65,9 @@ func NewWithRandom(counter uint32) (*Ratchet, error) { // rehashPart rehases the part of the ratchet data with the base defined as from storing into the target to. func (m *Ratchet) rehashPart(from, to int) { - newData := crypto.HMACSHA256(m.Data[from*RatchetPartLength:from*RatchetPartLength+RatchetPartLength], hashKeySeeds[to]) - copy(m.Data[to*RatchetPartLength:], newData[:RatchetPartLength]) + hash := hmac.New(sha256.New, m.Data[from*RatchetPartLength:from*RatchetPartLength+RatchetPartLength]) + hash.Write(hashKeySeeds[to]) + copy(m.Data[to*RatchetPartLength:], hash.Sum(nil)) } // Advance advances the ratchet one step. diff --git a/crypto/goolm/ratchet/chain.go b/crypto/goolm/ratchet/chain.go index dc021b8a..5deb90f5 100644 --- a/crypto/goolm/ratchet/chain.go +++ b/crypto/goolm/ratchet/chain.go @@ -1,6 +1,9 @@ package ratchet import ( + "crypto/hmac" + "crypto/sha256" + "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) @@ -18,7 +21,9 @@ type chainKey struct { // advance advances the chain func (c *chainKey) advance() { - c.Key = crypto.HMACSHA256(c.Key, []byte{chainKeySeed}) + hash := hmac.New(sha256.New, c.Key) + hash.Write([]byte{chainKeySeed}) + c.Key = hash.Sum(nil) c.Index++ } diff --git a/crypto/goolm/ratchet/olm.go b/crypto/goolm/ratchet/olm.go index b40328ab..e53d126a 100644 --- a/crypto/goolm/ratchet/olm.go +++ b/crypto/goolm/ratchet/olm.go @@ -2,9 +2,13 @@ package ratchet import ( + "crypto/hmac" + "crypto/sha256" "fmt" "io" + "golang.org/x/crypto/hkdf" + "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" @@ -70,7 +74,7 @@ func New() *Ratchet { // InitializeAsBob initializes this ratchet from a receiving point of view (only first message). func (r *Ratchet) InitializeAsBob(sharedSecret []byte, theirRatchetKey crypto.Curve25519PublicKey) error { - derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root) + derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, nil, KdfInfo.Root) derivedSecrets := make([]byte, 2*sharedKeyLength) if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil { return err @@ -83,7 +87,7 @@ func (r *Ratchet) InitializeAsBob(sharedSecret []byte, theirRatchetKey crypto.Cu // InitializeAsAlice initializes this ratchet from a sending point of view (only first message). func (r *Ratchet) InitializeAsAlice(sharedSecret []byte, ourRatchetKey crypto.Curve25519KeyPair) error { - derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root) + derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, nil, KdfInfo.Root) derivedSecrets := make([]byte, 2*sharedKeyLength) if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil { return err @@ -192,7 +196,7 @@ func (r *Ratchet) advanceRootKey(newRatchetKey crypto.Curve25519KeyPair, oldRatc if err != nil { return nil, err } - derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, r.RootKey, KdfInfo.Ratchet) + derivedSecretsReader := hkdf.New(sha256.New, sharedSecret, r.RootKey, KdfInfo.Ratchet) derivedSecrets := make([]byte, 2*sharedKeyLength) if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil { return nil, err @@ -203,10 +207,12 @@ func (r *Ratchet) advanceRootKey(newRatchetKey crypto.Curve25519KeyPair, oldRatc // createMessageKeys returns the messageKey derived from the chainKey func (r Ratchet) createMessageKeys(chainKey chainKey) messageKey { - res := messageKey{} - res.Key = crypto.HMACSHA256(chainKey.Key, []byte{messageKeySeed}) - res.Index = chainKey.Index - return res + hash := hmac.New(sha256.New, chainKey.Key) + hash.Write([]byte{messageKeySeed}) + return messageKey{ + Key: hash.Sum(nil), + Index: chainKey.Index, + } } // decryptForExistingChain returns the decrypted message by using the chain. The MAC of the rawMessage is verified. diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go index 574d029f..fcd9d0dc 100644 --- a/crypto/goolm/session/olm_session.go +++ b/crypto/goolm/session/olm_session.go @@ -2,6 +2,7 @@ package session import ( "bytes" + "crypto/sha256" "encoding/base64" "fmt" "strings" @@ -203,8 +204,8 @@ func (s *OlmSession) ID() id.SessionID { copy(message, s.AliceIdentityKey) copy(message[crypto.Curve25519PrivateKeyLength:], s.AliceBaseKey) copy(message[2*crypto.Curve25519PrivateKeyLength:], s.BobOneTimeKey) - hash := crypto.SHA256(message) - res := id.SessionID(goolmbase64.Encode(hash)) + hash := sha256.Sum256(message) + res := id.SessionID(goolmbase64.Encode(hash[:])) return res } From 4a2557ed15dc2707302853396b1055f70ea09e0f Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 25 Oct 2024 09:39:19 -0600 Subject: [PATCH 0877/1647] crypto: propagate more errors Signed-off-by: Sumner Evans --- crypto/encryptmegolm.go | 5 +++- crypto/goolm/account/account.go | 11 +++++---- crypto/goolm/crypto/curve25519.go | 23 ++++++++---------- crypto/goolm/crypto/ed25519.go | 12 ++++------ crypto/goolm/crypto/ed25519_test.go | 7 ++++-- crypto/goolm/crypto/one_time_key.go | 31 ++++++------------------- crypto/goolm/megolm/megolm.go | 4 ++-- crypto/goolm/message/group_message.go | 13 +++++++---- crypto/goolm/message/session_sharing.go | 6 ++--- crypto/goolm/pk/signing.go | 4 ++-- crypto/goolm/session/register.go | 12 ++-------- crypto/libolm/outboundgroupsession.go | 16 +++++-------- crypto/olm/groupsession_test.go | 3 ++- crypto/olm/outboundgroupsession.go | 4 ++-- crypto/olm/outboundgroupsession_test.go | 18 +++++++++----- crypto/sessions.go | 10 +++++--- crypto/store_test.go | 4 +++- 17 files changed, 86 insertions(+), 97 deletions(-) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 7c8a7542..3199ce57 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -179,7 +179,10 @@ func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.R Msg("Failed to get encryption event in room") return nil, fmt.Errorf("failed to get encryption event in room %s: %w", roomID, err) } - session := NewOutboundGroupSession(roomID, encryptionEvent) + session, err := NewOutboundGroupSession(roomID, encryptionEvent) + if err != nil { + return nil, err + } if !mach.DontStoreOutboundKeys { signingKey, idKey := mach.account.Keys() err := mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false) diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index f3554e29..099cc493 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -111,8 +111,11 @@ func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) { func (a *Account) Sign(message []byte) ([]byte, error) { if len(message) == 0 { return nil, fmt.Errorf("sign: %w", olm.ErrEmptyInput) + } else if signature, err := a.IdKeys.Ed25519.Sign(message); err != nil { + return nil, err + } else { + return []byte(base64.RawStdEncoding.EncodeToString(signature)), nil } - return []byte(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.Sign(message))), nil } // OneTimeKeys returns the public parts of the unpublished one time keys of the Account. @@ -122,7 +125,7 @@ func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) { oneTimeKeys := make(map[string]id.Curve25519) for _, curKey := range a.OTKeys { if !curKey.Published { - oneTimeKeys[curKey.KeyIDEncoded()] = id.Curve25519(curKey.PublicKeyEncoded()) + oneTimeKeys[curKey.KeyIDEncoded()] = curKey.Key.PublicKey.B64Encoded() } } return oneTimeKeys, nil @@ -259,7 +262,7 @@ func (a *Account) GenFallbackKey() error { func (a *Account) FallbackKey() map[string]id.Curve25519 { keys := make(map[string]id.Curve25519) if a.NumFallbackKeys >= 1 { - keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded()) + keys[a.CurrentFallbackKey.KeyIDEncoded()] = a.CurrentFallbackKey.Key.PublicKey.B64Encoded() } return keys } @@ -286,7 +289,7 @@ func (a *Account) FallbackKeyJSON() ([]byte, error) { func (a *Account) FallbackKeyUnpublished() map[string]id.Curve25519 { keys := make(map[string]id.Curve25519) if a.NumFallbackKeys >= 1 && !a.CurrentFallbackKey.Published { - keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded()) + keys[a.CurrentFallbackKey.KeyIDEncoded()] = a.CurrentFallbackKey.Key.PublicKey.B64Encoded() } return keys } diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go index 1dbc83fd..e9759501 100644 --- a/crypto/goolm/crypto/curve25519.go +++ b/crypto/goolm/crypto/curve25519.go @@ -1,8 +1,8 @@ package crypto import ( - "bytes" "crypto/rand" + "crypto/subtle" "encoding/base64" "golang.org/x/crypto/curve25519" @@ -16,6 +16,12 @@ const ( Curve25519PublicKeyLength = 32 ) +// Curve25519KeyPair stores both parts of a curve25519 key. +type Curve25519KeyPair struct { + PrivateKey Curve25519PrivateKey `json:"private,omitempty"` + PublicKey Curve25519PublicKey `json:"public,omitempty"` +} + // Curve25519GenerateKey creates a new curve25519 key pair. func Curve25519GenerateKey() (Curve25519KeyPair, error) { privateKeyByte := make([]byte, Curve25519PrivateKeyLength) @@ -34,19 +40,10 @@ func Curve25519GenerateKey() (Curve25519KeyPair, error) { // Curve25519GenerateFromPrivate creates a new curve25519 key pair with the private key given. func Curve25519GenerateFromPrivate(private Curve25519PrivateKey) (Curve25519KeyPair, error) { publicKey, err := private.PubKey() - if err != nil { - return Curve25519KeyPair{}, err - } return Curve25519KeyPair{ PrivateKey: private, PublicKey: Curve25519PublicKey(publicKey), - }, nil -} - -// Curve25519KeyPair stores both parts of a curve25519 key. -type Curve25519KeyPair struct { - PrivateKey Curve25519PrivateKey `json:"private,omitempty"` - PublicKey Curve25519PublicKey `json:"public,omitempty"` + }, err } // B64Encoded returns a base64 encoded string of the public key. @@ -86,7 +83,7 @@ type Curve25519PrivateKey []byte // Equal compares the private key to the given private key. func (c Curve25519PrivateKey) Equal(x Curve25519PrivateKey) bool { - return bytes.Equal(c, x) + return subtle.ConstantTimeCompare(c, x) == 1 } // PubKey returns the public key derived from the private key. @@ -104,7 +101,7 @@ type Curve25519PublicKey []byte // Equal compares the public key to the given public key. func (c Curve25519PublicKey) Equal(x Curve25519PublicKey) bool { - return bytes.Equal(c, x) + return subtle.ConstantTimeCompare(c, x) == 1 } // B64Encoded returns a base64 encoded string of the public key. diff --git a/crypto/goolm/crypto/ed25519.go b/crypto/goolm/crypto/ed25519.go index 34ad397a..a3345ba9 100644 --- a/crypto/goolm/crypto/ed25519.go +++ b/crypto/goolm/crypto/ed25519.go @@ -9,7 +9,7 @@ import ( ) const ( - ED25519SignatureSize = ed25519.SignatureSize //The length of a signature + Ed25519SignatureSize = ed25519.SignatureSize //The length of a signature ) // Ed25519GenerateKey creates a new ed25519 key pair. @@ -50,7 +50,7 @@ func (c Ed25519KeyPair) B64Encoded() id.Ed25519 { } // Sign returns the signature for the message. -func (c Ed25519KeyPair) Sign(message []byte) []byte { +func (c Ed25519KeyPair) Sign(message []byte) ([]byte, error) { return c.PrivateKey.Sign(message) } @@ -96,12 +96,8 @@ func (c Ed25519PrivateKey) PubKey() Ed25519PublicKey { } // Sign returns the signature for the message. -func (c Ed25519PrivateKey) Sign(message []byte) []byte { - signature, err := ed25519.PrivateKey(c).Sign(nil, message, &ed25519.Options{}) - if err != nil { - panic(err) - } - return signature +func (c Ed25519PrivateKey) Sign(message []byte) ([]byte, error) { + return ed25519.PrivateKey(c).Sign(nil, message, &ed25519.Options{}) } // Ed25519PublicKey represents the public key for ed25519 usage. This is just a wrapper. diff --git a/crypto/goolm/crypto/ed25519_test.go b/crypto/goolm/crypto/ed25519_test.go index 96d67385..610b8f3e 100644 --- a/crypto/goolm/crypto/ed25519_test.go +++ b/crypto/goolm/crypto/ed25519_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "maunium.net/go/mautrix/crypto/ed25519" "maunium.net/go/mautrix/crypto/goolm/crypto" @@ -17,7 +18,8 @@ func TestEd25519(t *testing.T) { keypair, err := crypto.Ed25519GenerateKey() assert.NoError(t, err) message := []byte("test message") - signature := keypair.Sign(message) + signature, err := keypair.Sign(message) + require.NoError(t, err) assert.True(t, keypair.Verify(message, signature)) } @@ -29,7 +31,8 @@ func TestEd25519Case1(t *testing.T) { keyPair2 := crypto.Ed25519GenerateFromPrivate(keyPair.PrivateKey) assert.Equal(t, keyPair, keyPair2, "not equal key pairs") - signature := keyPair.Sign(message) + signature, err := keyPair.Sign(message) + require.NoError(t, err) verified := keyPair.Verify(message, signature) assert.True(t, verified, "message did not verify although it should") diff --git a/crypto/goolm/crypto/one_time_key.go b/crypto/goolm/crypto/one_time_key.go index 0947f43b..888b1749 100644 --- a/crypto/goolm/crypto/one_time_key.go +++ b/crypto/goolm/crypto/one_time_key.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" - "maunium.net/go/mautrix/id" ) // OneTimeKey stores the information about a one time key. @@ -16,20 +15,11 @@ type OneTimeKey struct { } // Equal compares the one time key to the given one. -func (otk OneTimeKey) Equal(s OneTimeKey) bool { - if otk.ID != s.ID { - return false - } - if otk.Published != s.Published { - return false - } - if !otk.Key.PrivateKey.Equal(s.Key.PrivateKey) { - return false - } - if !otk.Key.PublicKey.Equal(s.Key.PublicKey) { - return false - } - return true +func (otk OneTimeKey) Equal(other OneTimeKey) bool { + return otk.ID == other.ID && + otk.Published == other.Published && + otk.Key.PrivateKey.Equal(other.Key.PrivateKey) && + otk.Key.PublicKey.Equal(other.Key.PublicKey) } // PickleLibOlm pickles the key pair into the encoder. @@ -50,14 +40,7 @@ func (c *OneTimeKey) UnpickleLibOlm(decoder *libolmpickle.Decoder) (err error) { return c.Key.UnpickleLibOlm(decoder) } -// KeyIDEncoded returns the base64 encoded id. +// KeyIDEncoded returns the base64 encoded key ID. func (c OneTimeKey) KeyIDEncoded() string { - resSlice := make([]byte, 4) - binary.BigEndian.PutUint32(resSlice, c.ID) - return base64.RawStdEncoding.EncodeToString(resSlice) -} - -// PublicKeyEncoded returns the base64 encoded public key -func (c OneTimeKey) PublicKeyEncoded() id.Curve25519 { - return c.Key.PublicKey.B64Encoded() + return base64.RawStdEncoding.EncodeToString(binary.BigEndian.AppendUint32(nil, c.ID)) } diff --git a/crypto/goolm/megolm/megolm.go b/crypto/goolm/megolm/megolm.go index eab82cc0..416db111 100644 --- a/crypto/goolm/megolm/megolm.go +++ b/crypto/goolm/megolm/megolm.go @@ -161,8 +161,8 @@ func (r Ratchet) SessionSharingMessage(key crypto.Ed25519KeyPair) ([]byte, error m := message.MegolmSessionSharing{} m.Counter = r.Counter m.RatchetData = r.Data - encoded := m.EncodeAndSign(key) - return goolmbase64.Encode(encoded), nil + encoded, err := m.EncodeAndSign(key) + return goolmbase64.Encode(encoded), err } // SessionExportMessage creates a message in the session export format. diff --git a/crypto/goolm/message/group_message.go b/crypto/goolm/message/group_message.go index ebd5b77e..b34bfa5e 100644 --- a/crypto/goolm/message/group_message.go +++ b/crypto/goolm/message/group_message.go @@ -32,7 +32,7 @@ func (r *GroupMessage) Decode(input []byte) error { //first Byte is always version r.Version = input[0] curPos := 1 - for curPos < len(input)-countMACBytesGroupMessage-crypto.ED25519SignatureSize { + for curPos < len(input)-countMACBytesGroupMessage-crypto.Ed25519SignatureSize { //Read Key curKey, readBytes := decodeVarInt(input[curPos:]) if err := checkDecodeErr(readBytes); err != nil { @@ -98,7 +98,10 @@ func (r *GroupMessage) EncodeAndMacAndSign(macKey []byte, cipher cipher.Cipher, out = append(out, mac[:countMACBytesGroupMessage]...) } if signKey != nil { - signature := signKey.Sign(out) + signature, err := signKey.Sign(out) + if err != nil { + return nil, err + } out = append(out, signature...) } return out, nil @@ -120,8 +123,8 @@ func (r *GroupMessage) VerifySignature(key crypto.Ed25519PublicKey, message, giv // VerifySignature verifies the signature taken from the message to the calculated signature of the message. func (r *GroupMessage) VerifySignatureInline(key crypto.Ed25519PublicKey, message []byte) bool { - signature := message[len(message)-crypto.ED25519SignatureSize:] - message = message[:len(message)-crypto.ED25519SignatureSize] + signature := message[len(message)-crypto.Ed25519SignatureSize:] + message = message[:len(message)-crypto.Ed25519SignatureSize] return key.Verify(message, signature) } @@ -136,7 +139,7 @@ func (r *GroupMessage) VerifyMAC(key []byte, cipher cipher.Cipher, message, give // VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message. func (r *GroupMessage) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) { - startMAC := len(message) - countMACBytesGroupMessage - crypto.ED25519SignatureSize + startMAC := len(message) - countMACBytesGroupMessage - crypto.Ed25519SignatureSize endMAC := startMAC + countMACBytesGroupMessage suplMac := message[startMAC:endMAC] message = message[:startMAC] diff --git a/crypto/goolm/message/session_sharing.go b/crypto/goolm/message/session_sharing.go index 85d5d20b..16240945 100644 --- a/crypto/goolm/message/session_sharing.go +++ b/crypto/goolm/message/session_sharing.go @@ -20,15 +20,15 @@ type MegolmSessionSharing struct { } // Encode returns the encoded message in the correct format with the signature by key appended. -func (s MegolmSessionSharing) EncodeAndSign(key crypto.Ed25519KeyPair) []byte { +func (s MegolmSessionSharing) EncodeAndSign(key crypto.Ed25519KeyPair) ([]byte, error) { output := make([]byte, 229) output[0] = sessionSharingVersion binary.BigEndian.PutUint32(output[1:], s.Counter) copy(output[5:], s.RatchetData[:]) copy(output[133:], key.PublicKey) - signature := key.Sign(output[:165]) + signature, err := key.Sign(output[:165]) copy(output[165:], signature) - return output + return output, err } // VerifyAndDecode verifies the input and populates the struct with the data encoded in input. diff --git a/crypto/goolm/pk/signing.go b/crypto/goolm/pk/signing.go index 9dfd24a1..61b31b6f 100644 --- a/crypto/goolm/pk/signing.go +++ b/crypto/goolm/pk/signing.go @@ -48,8 +48,8 @@ func (s Signing) PublicKey() id.Ed25519 { // Sign returns the signature of the message base64 encoded. func (s Signing) Sign(message []byte) ([]byte, error) { - signature := s.keyPair.Sign(message) - return goolmbase64.Encode(signature), nil + signature, err := s.keyPair.Sign(message) + return goolmbase64.Encode(signature), err } // SignJSON creates a signature for the given object after encoding it to diff --git a/crypto/goolm/session/register.go b/crypto/goolm/session/register.go index 0a8b3605..09ed42d4 100644 --- a/crypto/goolm/session/register.go +++ b/crypto/goolm/session/register.go @@ -48,16 +48,8 @@ func init() { } return MegolmOutboundSessionFromPickled(pickled, key) } - olm.InitNewOutboundGroupSession = func() olm.OutboundGroupSession { - session, err := NewMegolmOutboundSession() - if err != nil { - panic(err) - } - return session - } - olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { - return &MegolmOutboundSession{} - } + olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewMegolmOutboundSession() } + olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return &MegolmOutboundSession{} } // Olm Session olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) { diff --git a/crypto/libolm/outboundgroupsession.go b/crypto/libolm/outboundgroupsession.go index cb2ce38b..a21f8d4a 100644 --- a/crypto/libolm/outboundgroupsession.go +++ b/crypto/libolm/outboundgroupsession.go @@ -28,32 +28,28 @@ func init() { s := NewBlankOutboundGroupSession() return s, s.Unpickle(pickled, key) } - olm.InitNewOutboundGroupSession = func() olm.OutboundGroupSession { - return NewOutboundGroupSession() - } - olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { - return NewBlankOutboundGroupSession() - } + olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewOutboundGroupSession() } + olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return NewBlankOutboundGroupSession() } } // Ensure that [OutboundGroupSession] implements [olm.OutboundGroupSession]. var _ olm.OutboundGroupSession = (*OutboundGroupSession)(nil) -func NewOutboundGroupSession() *OutboundGroupSession { +func NewOutboundGroupSession() (*OutboundGroupSession, error) { s := NewBlankOutboundGroupSession() random := make([]byte, s.createRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(olm.NotEnoughGoRandom) + return nil, err } r := C.olm_init_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), (*C.uint8_t)(&random[0]), C.size_t(len(random))) if r == errorVal() { - panic(s.lastError()) + return nil, s.lastError() } - return s + return s, nil } // outboundGroupSessionSize is the size of an outbound group session object in diff --git a/crypto/olm/groupsession_test.go b/crypto/olm/groupsession_test.go index 276e7cfb..0f845e90 100644 --- a/crypto/olm/groupsession_test.go +++ b/crypto/olm/groupsession_test.go @@ -31,7 +31,8 @@ func TestEncryptDecrypt_GoolmToLibolm(t *testing.T) { } func TestEncryptDecrypt_LibolmToGoolm(t *testing.T) { - libolmOutbound := libolm.NewOutboundGroupSession() + libolmOutbound, err := libolm.NewOutboundGroupSession() + require.NoError(t, err) goolmInbound, err := session.NewMegolmInboundSession([]byte(libolmOutbound.Key())) require.NoError(t, err) diff --git a/crypto/olm/outboundgroupsession.go b/crypto/olm/outboundgroupsession.go index c5b7bcbf..7e582b7e 100644 --- a/crypto/olm/outboundgroupsession.go +++ b/crypto/olm/outboundgroupsession.go @@ -34,7 +34,7 @@ type OutboundGroupSession interface { } var InitNewOutboundGroupSessionFromPickled func(pickled, key []byte) (OutboundGroupSession, error) -var InitNewOutboundGroupSession func() OutboundGroupSession +var InitNewOutboundGroupSession func() (OutboundGroupSession, error) var InitNewBlankOutboundGroupSession func() OutboundGroupSession // OutboundGroupSessionFromPickled loads an OutboundGroupSession from a pickled @@ -47,7 +47,7 @@ func OutboundGroupSessionFromPickled(pickled, key []byte) (OutboundGroupSession, } // NewOutboundGroupSession creates a new outbound group session. -func NewOutboundGroupSession() OutboundGroupSession { +func NewOutboundGroupSession() (OutboundGroupSession, error) { return InitNewOutboundGroupSession() } diff --git a/crypto/olm/outboundgroupsession_test.go b/crypto/olm/outboundgroupsession_test.go index 46c63780..cbbc89f7 100644 --- a/crypto/olm/outboundgroupsession_test.go +++ b/crypto/olm/outboundgroupsession_test.go @@ -12,7 +12,8 @@ import ( ) func TestMegolmOutboundSessionPickle_RoundtripThroughGoolm(t *testing.T) { - libolmSession := libolm.NewOutboundGroupSession() + libolmSession, err := libolm.NewOutboundGroupSession() + require.NoError(t, err) libolmPickled, err := libolmSession.Pickle([]byte("test")) require.NoError(t, err) @@ -24,7 +25,8 @@ func TestMegolmOutboundSessionPickle_RoundtripThroughGoolm(t *testing.T) { assert.Equal(t, libolmPickled, goolmPickled, "pickled versions are not the same") - libolmSession2 := libolm.NewOutboundGroupSession() + libolmSession2, err := libolm.NewOutboundGroupSession() + require.NoError(t, err) err = libolmSession2.Unpickle(bytes.Clone(goolmPickled), []byte("test")) require.NoError(t, err) @@ -38,7 +40,8 @@ func TestMegolmOutboundSessionPickle_RoundtripThroughLibolm(t *testing.T) { goolmPickled, err := goolmSession.Pickle([]byte("test")) require.NoError(t, err) - libolmSession := libolm.NewOutboundGroupSession() + libolmSession, err := libolm.NewOutboundGroupSession() + require.NoError(t, err) err = libolmSession.Unpickle(bytes.Clone(goolmPickled), []byte("test")) require.NoError(t, err) @@ -55,7 +58,8 @@ func TestMegolmOutboundSessionPickle_RoundtripThroughLibolm(t *testing.T) { } func TestMegolmOutboundSessionPickleLibolm(t *testing.T) { - libolmSession := libolm.NewOutboundGroupSession() + libolmSession, err := libolm.NewOutboundGroupSession() + require.NoError(t, err) libolmPickled, err := libolmSession.Pickle([]byte("test")) require.NoError(t, err) @@ -77,7 +81,8 @@ func TestMegolmOutboundSessionPickleGoolm(t *testing.T) { goolmPickled, err := goolmSession.Pickle([]byte("test")) require.NoError(t, err) - libolmSession := libolm.NewOutboundGroupSession() + libolmSession, err := libolm.NewOutboundGroupSession() + require.NoError(t, err) err = libolmSession.Unpickle(bytes.Clone(goolmPickled), []byte("test")) require.NoError(t, err) libolmPickled, err := libolmSession.Pickle([]byte("test")) @@ -98,7 +103,8 @@ func FuzzMegolmOutboundSession_Encrypt(f *testing.F) { t.Skip("empty plaintext is not supported") } - libolmSession := libolm.NewOutboundGroupSession() + libolmSession, err := libolm.NewOutboundGroupSession() + require.NoError(t, err) libolmPickled, err := libolmSession.Pickle([]byte("test")) require.NoError(t, err) diff --git a/crypto/sessions.go b/crypto/sessions.go index 4aac6cf7..c22b5b58 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -180,9 +180,13 @@ type OutboundGroupSession struct { content *event.RoomKeyEventContent } -func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.EncryptionEventContent) *OutboundGroupSession { +func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.EncryptionEventContent) (*OutboundGroupSession, error) { + internal, err := olm.NewOutboundGroupSession() + if err != nil { + return nil, err + } ogs := &OutboundGroupSession{ - Internal: olm.NewOutboundGroupSession(), + Internal: internal, ExpirationMixin: ExpirationMixin{ TimeMixin: TimeMixin{ CreationTime: time.Now(), @@ -206,7 +210,7 @@ func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.Encrypti ogs.MaxMessages = min(max(encryptionContent.RotationPeriodMessages, 1), 10000) } } - return ogs + return ogs, nil } func (ogs *OutboundGroupSession) ShareContent() event.Content { diff --git a/crypto/store_test.go b/crypto/store_test.go index 08079f5e..a7c4d75a 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -13,6 +13,7 @@ import ( "testing" _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/require" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/crypto/olm" @@ -195,7 +196,8 @@ func TestStoreOutboundMegolmSession(t *testing.T) { t.Errorf("Error retrieving outbound session: %v", err) } - outbound := NewOutboundGroupSession("room1", nil) + outbound, err := NewOutboundGroupSession("room1", nil) + require.NoError(t, err) err = store.AddOutboundGroupSession(context.TODO(), outbound) if err != nil { t.Errorf("Error inserting outbound session: %v", err) From a59d4d78677fc948c0736569079454184fac2a72 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 26 Oct 2024 14:21:23 +0300 Subject: [PATCH 0878/1647] format: add support for img tags in HTML parser --- format/htmlparser.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/format/htmlparser.go b/format/htmlparser.go index d099e8a7..bafd41af 100644 --- a/format/htmlparser.go +++ b/format/htmlparser.go @@ -66,6 +66,7 @@ type LinkConverter func(text, href string, ctx Context) string type ColorConverter func(text, fg, bg string, ctx Context) string type CodeBlockConverter func(code, language string, ctx Context) string type PillConverter func(displayname, mxid, eventID string, ctx Context) string +type ImageConverter func(src, alt string, isEmoji bool) string const ContextKeyMentions = "_mentions" @@ -107,6 +108,7 @@ type HTMLParser struct { MonospaceBlockConverter CodeBlockConverter MonospaceConverter TextConverter TextConverter TextConverter + ImageConverter ImageConverter } // TaggedString is a string that also contains a HTML tag. @@ -298,6 +300,16 @@ func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) string { return fmt.Sprintf("%s (%s)", str, href) } +func (parser *HTMLParser) imgToString(node *html.Node, ctx Context) string { + src := parser.getAttribute(node, "src") + alt := parser.getAttribute(node, "alt") + _, isEmoji := parser.maybeGetAttribute(node, "data-mx-emoticon") + if parser.ImageConverter != nil { + return parser.ImageConverter(src, alt, isEmoji) + } + return alt +} + func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string { ctx = ctx.WithTag(node.Data) switch node.Data { @@ -317,6 +329,8 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string { return parser.linkToString(node, ctx) case "p": return parser.nodeToTagAwareString(node.FirstChild, ctx) + case "img": + return parser.imgToString(node, ctx) case "hr": return parser.HorizontalLine case "pre": From 0f3c599888236925f813de68002595928a715f7b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 28 Oct 2024 14:24:00 +0200 Subject: [PATCH 0879/1647] appservice/registration: update stable ephemeral event field --- appservice/registration.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appservice/registration.go b/appservice/registration.go index c0b62124..026df8ea 100644 --- a/appservice/registration.go +++ b/appservice/registration.go @@ -27,7 +27,7 @@ type Registration struct { Protocols []string `yaml:"protocols,omitempty" json:"protocols,omitempty"` SoruEphemeralEvents bool `yaml:"de.sorunome.msc2409.push_ephemeral,omitempty" json:"de.sorunome.msc2409.push_ephemeral,omitempty"` - EphemeralEvents bool `yaml:"push_ephemeral,omitempty" json:"push_ephemeral,omitempty"` + EphemeralEvents bool `yaml:"receive_ephemeral,omitempty" json:"receive_ephemeral,omitempty"` MSC3202 bool `yaml:"org.matrix.msc3202,omitempty" json:"org.matrix.msc3202,omitempty"` } From 0f31a2fb8e3420ed5f1e5b56404c7c113f179e66 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 28 Oct 2024 13:29:17 -0600 Subject: [PATCH 0880/1647] simplevent/meta: add builder pattern methods Signed-off-by: Sumner Evans --- bridgev2/simplevent/meta.go | 40 +++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/bridgev2/simplevent/meta.go b/bridgev2/simplevent/meta.go index a6b278fc..5fddd6f3 100644 --- a/bridgev2/simplevent/meta.go +++ b/bridgev2/simplevent/meta.go @@ -72,3 +72,43 @@ func (evt *EventMeta) GetType() bridgev2.RemoteEventType { func (evt *EventMeta) ShouldCreatePortal() bool { return evt.CreatePortal } + +func (evt *EventMeta) WithType(t bridgev2.RemoteEventType) *EventMeta { + evt.Type = t + return evt +} + +func (evt *EventMeta) WithLogContext(f func(c zerolog.Context) zerolog.Context) *EventMeta { + evt.LogContext = f + return evt +} + +func (evt *EventMeta) WithPortalKey(p networkid.PortalKey) *EventMeta { + evt.PortalKey = p + return evt +} + +func (evt *EventMeta) WithUncertainReceiver(u bool) *EventMeta { + evt.UncertainReceiver = u + return evt +} + +func (evt *EventMeta) WithSender(s bridgev2.EventSender) *EventMeta { + evt.Sender = s + return evt +} + +func (evt *EventMeta) WithCreatePortal(c bool) *EventMeta { + evt.CreatePortal = c + return evt +} + +func (evt *EventMeta) WithTimestamp(t time.Time) *EventMeta { + evt.Timestamp = t + return evt +} + +func (evt *EventMeta) WithStreamOrder(s int64) *EventMeta { + evt.StreamOrder = s + return evt +} From 48aa04889cd6606fcabc6b866ee60ed7bc1ac63d Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 28 Oct 2024 13:46:14 -0600 Subject: [PATCH 0881/1647] simplevent/meta: make builder pattern methods non-pointer receivers Signed-off-by: Sumner Evans --- bridgev2/simplevent/meta.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/bridgev2/simplevent/meta.go b/bridgev2/simplevent/meta.go index 5fddd6f3..9d7d697a 100644 --- a/bridgev2/simplevent/meta.go +++ b/bridgev2/simplevent/meta.go @@ -73,42 +73,42 @@ func (evt *EventMeta) ShouldCreatePortal() bool { return evt.CreatePortal } -func (evt *EventMeta) WithType(t bridgev2.RemoteEventType) *EventMeta { +func (evt EventMeta) WithType(t bridgev2.RemoteEventType) EventMeta { evt.Type = t return evt } -func (evt *EventMeta) WithLogContext(f func(c zerolog.Context) zerolog.Context) *EventMeta { +func (evt EventMeta) WithLogContext(f func(c zerolog.Context) zerolog.Context) EventMeta { evt.LogContext = f return evt } -func (evt *EventMeta) WithPortalKey(p networkid.PortalKey) *EventMeta { +func (evt EventMeta) WithPortalKey(p networkid.PortalKey) EventMeta { evt.PortalKey = p return evt } -func (evt *EventMeta) WithUncertainReceiver(u bool) *EventMeta { +func (evt EventMeta) WithUncertainReceiver(u bool) EventMeta { evt.UncertainReceiver = u return evt } -func (evt *EventMeta) WithSender(s bridgev2.EventSender) *EventMeta { +func (evt EventMeta) WithSender(s bridgev2.EventSender) EventMeta { evt.Sender = s return evt } -func (evt *EventMeta) WithCreatePortal(c bool) *EventMeta { +func (evt EventMeta) WithCreatePortal(c bool) EventMeta { evt.CreatePortal = c return evt } -func (evt *EventMeta) WithTimestamp(t time.Time) *EventMeta { +func (evt EventMeta) WithTimestamp(t time.Time) EventMeta { evt.Timestamp = t return evt } -func (evt *EventMeta) WithStreamOrder(s int64) *EventMeta { +func (evt EventMeta) WithStreamOrder(s int64) EventMeta { evt.StreamOrder = s return evt } From 7066beb946d016607b58c9b3bc7d3767ae6d7103 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 28 Oct 2024 23:02:40 -0600 Subject: [PATCH 0882/1647] event: fix receivers for event.Type Signed-off-by: Sumner Evans --- event/type.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/event/type.go b/event/type.go index 4396c9cc..f2b841ad 100644 --- a/event/type.go +++ b/event/type.go @@ -149,7 +149,7 @@ func (et *Type) MarshalJSON() ([]byte, error) { return json.Marshal(&et.Type) } -func (et Type) UnmarshalText(data []byte) error { +func (et *Type) UnmarshalText(data []byte) error { et.Type = string(data) et.Class = et.GuessClass() return nil @@ -159,11 +159,11 @@ func (et Type) MarshalText() ([]byte, error) { return []byte(et.Type), nil } -func (et *Type) String() string { +func (et Type) String() string { return et.Type } -func (et *Type) Repr() string { +func (et Type) Repr() string { return fmt.Sprintf("%s (%s)", et.Type, et.Class.Name()) } From 2a20aa32323e91b26adc4177233ecefb6ad2c1be Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 28 Oct 2024 23:04:00 -0600 Subject: [PATCH 0883/1647] verificationhelper: add better logging when sending cancellation due to unknown transaction ID Signed-off-by: Sumner Evans --- crypto/verificationhelper/verificationhelper.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index cbcff887..d3b7d4f5 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -240,6 +240,7 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { Stringer("sender", evt.Sender). Stringer("room_id", evt.RoomID). Stringer("event_id", evt.ID). + Stringer("event_type", evt.Type). Logger() var transactionID id.VerificationTransactionID @@ -266,8 +267,10 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { return } + log.Warn().Msg("Sending cancellation event for unknown transaction ID") + // We have to create a fake transaction so that the call to - // verificationCancelled works. + // cancelVerificationTxn works. txn = &verificationTransaction{ RoomID: evt.RoomID, TheirUser: evt.Sender, From 40927f4b1256bd568312b8342274e2d17ae9756b Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 30 Oct 2024 09:58:09 -0600 Subject: [PATCH 0884/1647] pre-commit: update and ban use of global zerolog logger Signed-off-by: Sumner Evans --- .pre-commit-config.yaml | 5 +++-- crypto/verificationhelper/mockserver_test.go | 2 +- .../verificationhelper_qr_crosssign_test.go | 2 +- crypto/verificationhelper/verificationhelper_qr_self_test.go | 2 +- crypto/verificationhelper/verificationhelper_sas_test.go | 2 +- crypto/verificationhelper/verificationhelper_test.go | 2 +- 6 files changed, 8 insertions(+), 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c15d69d6..8827c231 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: trailing-whitespace exclude_types: [markdown] @@ -22,6 +22,7 @@ repos: #- id: go-staticcheck-repo-mod - repo: https://github.com/beeper/pre-commit-go - rev: v0.3.1 + rev: v0.4.1 hooks: - id: prevent-literal-http-methods + - id: zerolog-ban-global-log diff --git a/crypto/verificationhelper/mockserver_test.go b/crypto/verificationhelper/mockserver_test.go index e35f51b2..b6bf3d2c 100644 --- a/crypto/verificationhelper/mockserver_test.go +++ b/crypto/verificationhelper/mockserver_test.go @@ -17,7 +17,7 @@ import ( "testing" "github.com/gorilla/mux" - "github.com/rs/zerolog/log" + "github.com/rs/zerolog/log" // zerolog-allow-global-log "github.com/stretchr/testify/require" "go.mau.fi/util/random" diff --git a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go index 2bbed25e..aace2230 100644 --- a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go @@ -11,7 +11,7 @@ import ( "fmt" "testing" - "github.com/rs/zerolog/log" + "github.com/rs/zerolog/log" // zerolog-allow-global-log "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/crypto/verificationhelper/verificationhelper_qr_self_test.go b/crypto/verificationhelper/verificationhelper_qr_self_test.go index 443157b7..11358b88 100644 --- a/crypto/verificationhelper/verificationhelper_qr_self_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_self_test.go @@ -11,7 +11,7 @@ import ( "fmt" "testing" - "github.com/rs/zerolog/log" + "github.com/rs/zerolog/log" // zerolog-allow-global-log "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" diff --git a/crypto/verificationhelper/verificationhelper_sas_test.go b/crypto/verificationhelper/verificationhelper_sas_test.go index e986cf85..20e52e0f 100644 --- a/crypto/verificationhelper/verificationhelper_sas_test.go +++ b/crypto/verificationhelper/verificationhelper_sas_test.go @@ -5,7 +5,7 @@ import ( "fmt" "testing" - "github.com/rs/zerolog/log" + "github.com/rs/zerolog/log" // zerolog-allow-global-log "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/exp/maps" diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index 876e90f7..273042c3 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -8,7 +8,7 @@ import ( "time" "github.com/rs/zerolog" - "github.com/rs/zerolog/log" + "github.com/rs/zerolog/log" // zerolog-allow-global-log "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" From 7c227e175de9354614d2bccccf031985100cd9d4 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 30 Oct 2024 10:15:37 -0600 Subject: [PATCH 0885/1647] pre-commit: update and ban Msgf on zerolog logs Signed-off-by: Sumner Evans --- .pre-commit-config.yaml | 3 ++- bridge/commands/event.go | 8 ++++---- bridgev2/commands/event.go | 8 ++++---- crypto/keysharing.go | 10 +++++++--- 4 files changed, 17 insertions(+), 12 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8827c231..81701203 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,8 @@ repos: #- id: go-staticcheck-repo-mod - repo: https://github.com/beeper/pre-commit-go - rev: v0.4.1 + rev: v0.4.2 hooks: - id: prevent-literal-http-methods - id: zerolog-ban-global-log + - id: zerolog-ban-msgf diff --git a/bridge/commands/event.go b/bridge/commands/event.go index 42b49b68..49a8b277 100644 --- a/bridge/commands/event.go +++ b/bridge/commands/event.go @@ -66,7 +66,7 @@ func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { content.MsgType = event.MsgNotice _, err := ce.MainIntent().SendMessageEvent(ce.Ctx, ce.RoomID, event.EventMessage, content) if err != nil { - ce.ZLog.Error().Err(err).Msgf("Failed to reply to command") + ce.ZLog.Error().Err(err).Msg("Failed to reply to command") } } @@ -74,7 +74,7 @@ func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { func (ce *Event) React(key string) { _, err := ce.MainIntent().SendReaction(ce.Ctx, ce.RoomID, ce.EventID, key) if err != nil { - ce.ZLog.Error().Err(err).Msgf("Failed to react to command") + ce.ZLog.Error().Err(err).Msg("Failed to react to command") } } @@ -82,7 +82,7 @@ func (ce *Event) React(key string) { func (ce *Event) Redact(req ...mautrix.ReqRedact) { _, err := ce.MainIntent().RedactEvent(ce.Ctx, ce.RoomID, ce.EventID, req...) if err != nil { - ce.ZLog.Error().Err(err).Msgf("Failed to redact command") + ce.ZLog.Error().Err(err).Msg("Failed to redact command") } } @@ -90,6 +90,6 @@ func (ce *Event) Redact(req ...mautrix.ReqRedact) { func (ce *Event) MarkRead() { err := ce.MainIntent().SendReceipt(ce.Ctx, ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil) if err != nil { - ce.ZLog.Error().Err(err).Msgf("Failed to mark command as read") + ce.ZLog.Error().Err(err).Msg("Failed to mark command as read") } } diff --git a/bridgev2/commands/event.go b/bridgev2/commands/event.go index bd2c52d2..78ed94bb 100644 --- a/bridgev2/commands/event.go +++ b/bridgev2/commands/event.go @@ -58,7 +58,7 @@ func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { content.MsgType = event.MsgNotice _, err := ce.Bot.SendMessage(ce.Ctx, ce.OrigRoomID, event.EventMessage, &event.Content{Parsed: &content}, nil) if err != nil { - ce.Log.Err(err).Msgf("Failed to reply to command") + ce.Log.Err(err).Msg("Failed to reply to command") } } @@ -74,7 +74,7 @@ func (ce *Event) React(key string) { }, }, nil) if err != nil { - ce.Log.Err(err).Msgf("Failed to react to command") + ce.Log.Err(err).Msg("Failed to react to command") } } @@ -86,7 +86,7 @@ func (ce *Event) Redact(req ...mautrix.ReqRedact) { }, }, nil) if err != nil { - ce.Log.Err(err).Msgf("Failed to redact command") + ce.Log.Err(err).Msg("Failed to redact command") } } @@ -95,6 +95,6 @@ func (ce *Event) MarkRead() { // TODO //err := ce.Bot.SendReceipt(ce.Ctx, ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil) //if err != nil { - // ce.Log.Err(err).Msgf("Failed to mark command as read") + // ce.Log.Err(err).Msg("Failed to mark command as read") //} } diff --git a/crypto/keysharing.go b/crypto/keysharing.go index 38e015c6..0ccf006a 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -59,11 +59,15 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to select { case <-keyResponseReceived: // key request successful - mach.Log.Debug().Msgf("Key for session %v was received, cancelling other key requests", sessionID) + mach.Log.Debug(). + Stringer("session_id", sessionID). + Msg("Key for session was received, cancelling other key requests") resChan <- true case <-ctx.Done(): // if the context is done, key request was unsuccessful - mach.Log.Debug().Msgf("Context closed (%v) before forwared key for session %v received, sending key request cancellation", ctx.Err(), sessionID) + mach.Log.Debug().Err(err). + Stringer("session_id", sessionID). + Msg("Context closed before forwarded key for session received, sending key request cancellation") resChan <- false } @@ -332,7 +336,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User } if internalID := igs.ID(); internalID != content.Body.SessionID { // Should this be an error? - log = log.With().Str("unexpected_session_id", internalID.String()).Logger() + log = log.With().Stringer("unexpected_session_id", internalID).Logger() } firstKnownIndex := igs.Internal.FirstKnownIndex() From 34d551085c18cd3f2eb82cf60a2caa904e102ca1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 30 Oct 2024 21:04:00 +0200 Subject: [PATCH 0886/1647] crypto/encryptmegolm: log target identity key when encrypting session --- crypto/encryptmegolm.go | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 3199ce57..ef5f404f 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -359,23 +359,26 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session toDevice.Messages[userID] = output for deviceID, device := range sessions { log.Trace(). - Str("target_user_id", userID.String()). - Str("target_device_id", deviceID.String()). + Stringer("target_user_id", userID). + Stringer("target_device_id", deviceID). + Stringer("target_identity_key", device.identity.IdentityKey). Msg("Encrypting group session for device") content := mach.encryptOlmEvent(ctx, device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent()) output[deviceID] = &event.Content{Parsed: content} deviceCount++ log.Debug(). - Str("target_user_id", userID.String()). - Str("target_device_id", deviceID.String()). + Stringer("target_user_id", userID). + Stringer("target_device_id", deviceID). + Stringer("target_identity_key", device.identity.IdentityKey). Msg("Encrypted group session for device") if !mach.DisableSharedGroupSessionTracking { err := mach.CryptoStore.MarkOutboundGroupSessionShared(ctx, userID, device.identity.IdentityKey, session.id) if err != nil { log.Warn(). Err(err). - Str("target_user_id", userID.String()). - Str("target_device_id", deviceID.String()). + Stringer("target_user_id", userID). + Stringer("target_device_id", deviceID). + Stringer("target_identity_key", device.identity.IdentityKey). Stringer("target_session_id", session.id). Msg("Failed to mark outbound group session shared") } @@ -394,8 +397,9 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*id.Device, output map[id.DeviceID]deviceSessionWrapper, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*id.Device) { for deviceID, device := range devices { log := zerolog.Ctx(ctx).With(). - Str("target_user_id", userID.String()). - Str("target_device_id", deviceID.String()). + Stringer("target_user_id", userID). + Stringer("target_device_id", deviceID). + Stringer("target_identity_key", device.IdentityKey). Logger() userKey := UserDevice{UserID: userID, DeviceID: deviceID} if state := session.Users[userKey]; state != OGSNotShared { From 7e8d435aefdbc009b45958575cdb858b4a802549 Mon Sep 17 00:00:00 2001 From: Scott Weber Date: Thu, 31 Oct 2024 14:35:48 -0400 Subject: [PATCH 0887/1647] event: add BeeperHSOrderString to unsigned (#308) --- event/events.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/event/events.go b/event/events.go index 23769ae8..e9c5955c 100644 --- a/event/events.go +++ b/event/events.go @@ -144,13 +144,14 @@ type Unsigned struct { RedactedBecause *Event `json:"redacted_because,omitempty"` InviteRoomState []StrippedState `json:"invite_room_state,omitempty"` - BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"` - BeeperHSSuborder int64 `json:"com.beeper.hs.suborder,omitempty"` - BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"` + BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"` + BeeperHSSuborder int64 `json:"com.beeper.hs.suborder,omitempty"` + BeeperHSOrderString string `json:"com.beeper.hs.order_string,omitempty"` + BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"` } func (us *Unsigned) IsEmpty() bool { return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil && - us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 + us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 && us.BeeperHSOrderString == "" } From f606129e732f66f0c06d8dca16bd58cf9c5f78fc Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Fri, 1 Nov 2024 16:26:20 +0000 Subject: [PATCH 0888/1647] Add Beeper local bridge fields to create room struct --- requests.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requests.go b/requests.go index a6b0ea8b..595f1212 100644 --- a/requests.go +++ b/requests.go @@ -123,6 +123,8 @@ type ReqCreateRoom struct { BeeperInitialMembers []id.UserID `json:"com.beeper.initial_members,omitempty"` BeeperAutoJoinInvites bool `json:"com.beeper.auto_join_invites,omitempty"` BeeperLocalRoomID id.RoomID `json:"com.beeper.local_room_id,omitempty"` + BeeperBridgeName string `json:"com.beeper.bridge_name,omitempty"` + BeeperBridgeAccountID string `json:"com.beeper.bridge_account_id,omitempty"` } // ReqRedact is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidredacteventidtxnid From 9f1a8f2cc499e13cf50c0f7f48bb470daaf89126 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 2 Nov 2024 11:53:18 +0200 Subject: [PATCH 0889/1647] format: add TextToContent helper --- format/markdown.go | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/format/markdown.go b/format/markdown.go index 11f9f684..d62013cc 100644 --- a/format/markdown.go +++ b/format/markdown.go @@ -49,6 +49,14 @@ func RenderMarkdownCustom(text string, renderer goldmark.Markdown) event.Message return HTMLToContent(htmlBody) } +func TextToContent(text string) event.MessageEventContent { + return event.MessageEventContent{ + MsgType: event.MsgText, + Body: text, + Mentions: &event.Mentions{}, + } +} + func HTMLToContent(html string) event.MessageEventContent { text, mentions := HTMLToMarkdownAndMentions(html) if html != text { @@ -60,11 +68,7 @@ func HTMLToContent(html string) event.MessageEventContent { Mentions: mentions, } } - return event.MessageEventContent{ - MsgType: event.MsgText, - Body: text, - Mentions: &event.Mentions{}, - } + return TextToContent(text) } func RenderMarkdown(text string, allowMarkdown, allowHTML bool) event.MessageEventContent { @@ -80,10 +84,6 @@ func RenderMarkdown(text string, allowMarkdown, allowHTML bool) event.MessageEve htmlBody = strings.Replace(text, "\n", "
", -1) return HTMLToContent(htmlBody) } else { - return event.MessageEventContent{ - MsgType: event.MsgText, - Body: text, - Mentions: &event.Mentions{}, - } + return TextToContent(text) } } From f0c46cf629b8d74d094ea225b422b875f2c102ec Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 2 Nov 2024 11:53:36 +0200 Subject: [PATCH 0890/1647] format/mdext: add math parser --- format/markdown_test.go | 20 ++++ format/mdext/math.go | 252 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 272 insertions(+) create mode 100644 format/mdext/math.go diff --git a/format/markdown_test.go b/format/markdown_test.go index 10ae270c..6170a47f 100644 --- a/format/markdown_test.go +++ b/format/markdown_test.go @@ -158,3 +158,23 @@ func TestRenderMarkdown_DiscordUnderline(t *testing.T) { assert.Equal(t, html, strings.ReplaceAll(rendered, "\n", "")) } } + +var mathTests = map[string]string{ + "$foo$": `foo`, + "$$foo$$": `
foo
`, + "$$\nfoo\nbar\n$$": `
foo
bar
`, + "`$foo$`": `$foo$`, + "```\n$foo$\n```": `
$foo$\n
`, + "~~$foo$~~": `foo`, + "$5 or $10": `$5 or $10`, + "5$ or 10$": `5$ or 10$`, + "$5 or 10$": `5 or 10`, +} + +func TestRenderMarkdown_Math(t *testing.T) { + renderer := goldmark.New(goldmark.WithExtensions(extension.Strikethrough, mdext.Math, mdext.EscapeHTML), format.HTMLOptions) + for markdown, html := range mathTests { + rendered := format.UnwrapSingleParagraph(render(renderer, markdown)) + assert.Equal(t, html, strings.ReplaceAll(rendered, "\n", "\\n")) + } +} diff --git a/format/mdext/math.go b/format/mdext/math.go new file mode 100644 index 00000000..0f08b9ea --- /dev/null +++ b/format/mdext/math.go @@ -0,0 +1,252 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mdext + +import ( + "bytes" + "fmt" + stdhtml "html" + "strings" + + "github.com/yuin/goldmark" + "github.com/yuin/goldmark/ast" + "github.com/yuin/goldmark/parser" + "github.com/yuin/goldmark/renderer" + "github.com/yuin/goldmark/renderer/html" + "github.com/yuin/goldmark/text" + "github.com/yuin/goldmark/util" +) + +var astKindMath = ast.NewNodeKind("Math") + +type astMath struct { + ast.BaseInline + block bool +} + +func (n *astMath) Dump(source []byte, level int) { + ast.DumpHelper(n, source, level, nil, nil) +} + +func (n *astMath) Kind() ast.NodeKind { + return astKindMath +} + +type astMathBlock struct { + ast.BaseBlock + info *ast.Text +} + +func (n *astMathBlock) Dump(source []byte, level int) { + ast.DumpHelper(n, source, level, nil, nil) +} + +func (n *astMathBlock) Kind() ast.NodeKind { + return astKindMath +} + +type mathDelimiterProcessor struct{} + +var defaultMathDelimiterProcessor = &mathDelimiterProcessor{} + +func (p *mathDelimiterProcessor) IsDelimiter(b byte) bool { + return b == '$' +} + +func (p *mathDelimiterProcessor) CanOpenCloser(opener, closer *parser.Delimiter) bool { + return opener.Char == closer.Char +} + +func (p *mathDelimiterProcessor) OnMatch(consumes int) ast.Node { + return &astMath{block: consumes > 1} +} + +type inlineMathParser struct{} + +var defaultInlineMathParser = &inlineMathParser{} + +func NewInlineMathParser() parser.InlineParser { + return defaultInlineMathParser +} + +func (s *inlineMathParser) Trigger() []byte { + return []byte{'$'} +} + +func (s *inlineMathParser) Parse(parent ast.Node, block text.Reader, pc parser.Context) ast.Node { + before := block.PrecendingCharacter() + line, segment := block.PeekLine() + node := parser.ScanDelimiter(line, before, 1, defaultMathDelimiterProcessor) + if node == nil { + return nil + } + node.Segment = segment.WithStop(segment.Start + node.OriginalLength) + block.Advance(node.OriginalLength) + pc.PushDelimiter(node) + return node +} + +func (s *inlineMathParser) CloseBlock(parent ast.Node, pc parser.Context) { + // nothing to do +} + +type blockMathParser struct{} + +var defaultBlockMathParser = &blockMathParser{} + +func NewBlockMathParser() parser.BlockParser { + return defaultBlockMathParser +} + +var mathBlockInfoKey = parser.NewContextKey() + +type mathBlockData struct { + indent int + length int + node ast.Node +} + +func (b *blockMathParser) Trigger() []byte { + return []byte{'$'} +} + +func (b *blockMathParser) Open(parent ast.Node, reader text.Reader, pc parser.Context) (ast.Node, parser.State) { + const fenceChar = '$' + line, segment := reader.PeekLine() + pos := pc.BlockOffset() + if pos < 0 || (line[pos] != fenceChar) { + return nil, parser.NoChildren + } + findent := pos + i := pos + for ; i < len(line) && line[i] == fenceChar; i++ { + } + oFenceLength := i - pos + if oFenceLength < 2 { + return nil, parser.NoChildren + } + var info *ast.Text + if i < len(line)-1 { + rest := line[i:] + left := util.TrimLeftSpaceLength(rest) + right := util.TrimRightSpaceLength(rest) + if left < len(rest)-right { + infoStart, infoStop := segment.Start-segment.Padding+i+left, segment.Stop-right + value := rest[left : len(rest)-right] + if bytes.IndexByte(value, fenceChar) > -1 { + return nil, parser.NoChildren + } else if infoStart != infoStop { + info = ast.NewTextSegment(text.NewSegment(infoStart, infoStop)) + } + } + } + node := &astMathBlock{info: info} + pc.Set(mathBlockInfoKey, &mathBlockData{findent, oFenceLength, node}) + return node, parser.NoChildren + +} + +func (b *blockMathParser) Continue(node ast.Node, reader text.Reader, pc parser.Context) parser.State { + const fenceChar = '$' + line, segment := reader.PeekLine() + fdata := pc.Get(mathBlockInfoKey).(*mathBlockData) + + w, pos := util.IndentWidth(line, reader.LineOffset()) + if w < 4 { + i := pos + for ; i < len(line) && line[i] == fenceChar; i++ { + } + length := i - pos + if length >= fdata.length && util.IsBlank(line[i:]) { + newline := 1 + if line[len(line)-1] != '\n' { + newline = 0 + } + reader.Advance(segment.Stop - segment.Start - newline + segment.Padding) + return parser.Close + } + } + pos, padding := util.IndentPositionPadding(line, reader.LineOffset(), segment.Padding, fdata.indent) + if pos < 0 { + pos = util.FirstNonSpacePosition(line) + if pos < 0 { + pos = 0 + } + padding = 0 + } + seg := text.NewSegmentPadding(segment.Start+pos, segment.Stop, padding) + seg.ForceNewline = true // EOF as newline + node.Lines().Append(seg) + reader.AdvanceAndSetPadding(segment.Stop-segment.Start-pos-1, padding) + return parser.Continue | parser.NoChildren +} + +func (b *blockMathParser) Close(node ast.Node, reader text.Reader, pc parser.Context) { + fdata := pc.Get(mathBlockInfoKey).(*mathBlockData) + if fdata.node == node { + pc.Set(mathBlockInfoKey, nil) + } +} + +func (b *blockMathParser) CanInterruptParagraph() bool { + return true +} + +func (b *blockMathParser) CanAcceptIndentedLine() bool { + return false +} + +type mathHTMLRenderer struct { + html.Config +} + +func NewMathHTMLRenderer(opts ...html.Option) renderer.NodeRenderer { + r := &mathHTMLRenderer{ + Config: html.NewConfig(), + } + for _, opt := range opts { + opt.SetHTMLOption(&r.Config) + } + return r +} + +func (r *mathHTMLRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) { + reg.Register(astKindMath, r.renderMath) +} + +func (r *mathHTMLRenderer) renderMath(w util.BufWriter, source []byte, n ast.Node, entering bool) (ast.WalkStatus, error) { + if entering { + tag := "span" + switch typed := n.(type) { + case *astMathBlock: + tag = "div" + case *astMath: + if typed.block { + tag = "div" + } + } + tex := stdhtml.EscapeString(string(n.Text(source))) + _, _ = fmt.Fprintf(w, `<%s data-mx-maths="%s">%s`, tag, tex, strings.ReplaceAll(tex, "\n", "
"), tag) + } + return ast.WalkSkipChildren, nil +} + +type math struct{} + +// Math is an extension that allow you to use math like '$$text$$'. +var Math = &math{} + +func (e *math) Extend(m goldmark.Markdown) { + m.Parser().AddOptions(parser.WithInlineParsers( + util.Prioritized(NewInlineMathParser(), 500), + ), parser.WithBlockParsers( + util.Prioritized(NewBlockMathParser(), 850), + )) + m.Renderer().AddOptions(renderer.WithNodeRenderers( + util.Prioritized(NewMathHTMLRenderer(), 500), + )) +} From 0f73c831966528b2b27750dacf9abbd1261838d6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 2 Nov 2024 12:29:18 +0200 Subject: [PATCH 0891/1647] format/mdext: maybe improve math parser --- format/markdown_test.go | 26 +++++++++------ format/mdext/math.go | 72 +++++++++++++++++------------------------ 2 files changed, 46 insertions(+), 52 deletions(-) diff --git a/format/markdown_test.go b/format/markdown_test.go index 6170a47f..bb415d8a 100644 --- a/format/markdown_test.go +++ b/format/markdown_test.go @@ -160,21 +160,27 @@ func TestRenderMarkdown_DiscordUnderline(t *testing.T) { } var mathTests = map[string]string{ - "$foo$": `foo`, - "$$foo$$": `
foo
`, - "$$\nfoo\nbar\n$$": `
foo
bar
`, - "`$foo$`": `$foo$`, - "```\n$foo$\n```": `
$foo$\n
`, - "~~$foo$~~": `foo`, - "$5 or $10": `$5 or $10`, - "5$ or 10$": `5$ or 10$`, - "$5 or 10$": `5 or 10`, + "$foo$": `foo`, + "hello $foo$ world": `hello foo world`, + "$$\nfoo\nbar\n$$": `
foo
bar
`, + "`$foo$`": `$foo$`, + "```\n$foo$\n```": `
$foo$\n
`, + "~~meow $foo$ asd~~": `meow foo asd`, + "$5 or $10": `$5 or $10`, + "5$ or 10$": `5$ or 10$`, + "$5 or 10$": `5 or 10`, + "$*500*$": `*500*`, + "$$\n*500*\n$$": `
*500*
`, + + // TODO: This doesn't work :( + // Maybe same reason as the spoiler wrapping not working? + //"~~$foo$~~": `foo`, } func TestRenderMarkdown_Math(t *testing.T) { renderer := goldmark.New(goldmark.WithExtensions(extension.Strikethrough, mdext.Math, mdext.EscapeHTML), format.HTMLOptions) for markdown, html := range mathTests { rendered := format.UnwrapSingleParagraph(render(renderer, markdown)) - assert.Equal(t, html, strings.ReplaceAll(rendered, "\n", "\\n")) + assert.Equal(t, html, strings.ReplaceAll(rendered, "\n", "\\n"), "with input %q", markdown) } } diff --git a/format/mdext/math.go b/format/mdext/math.go index 0f08b9ea..e6a6ecc5 100644 --- a/format/mdext/math.go +++ b/format/mdext/math.go @@ -10,7 +10,9 @@ import ( "bytes" "fmt" stdhtml "html" + "regexp" "strings" + "unicode" "github.com/yuin/goldmark" "github.com/yuin/goldmark/ast" @@ -25,7 +27,7 @@ var astKindMath = ast.NewNodeKind("Math") type astMath struct { ast.BaseInline - block bool + value []byte } func (n *astMath) Dump(source []byte, level int) { @@ -38,7 +40,6 @@ func (n *astMath) Kind() ast.NodeKind { type astMathBlock struct { ast.BaseBlock - info *ast.Text } func (n *astMathBlock) Dump(source []byte, level int) { @@ -49,22 +50,6 @@ func (n *astMathBlock) Kind() ast.NodeKind { return astKindMath } -type mathDelimiterProcessor struct{} - -var defaultMathDelimiterProcessor = &mathDelimiterProcessor{} - -func (p *mathDelimiterProcessor) IsDelimiter(b byte) bool { - return b == '$' -} - -func (p *mathDelimiterProcessor) CanOpenCloser(opener, closer *parser.Delimiter) bool { - return opener.Char == closer.Char -} - -func (p *mathDelimiterProcessor) OnMatch(consumes int) ast.Node { - return &astMath{block: consumes > 1} -} - type inlineMathParser struct{} var defaultInlineMathParser = &inlineMathParser{} @@ -73,21 +58,30 @@ func NewInlineMathParser() parser.InlineParser { return defaultInlineMathParser } +const mathDelimiter = '$' + func (s *inlineMathParser) Trigger() []byte { - return []byte{'$'} + return []byte{mathDelimiter} } +// This ignores lines where there's no space after the closing $ to avoid false positives +var latexInlineRegexp = regexp.MustCompile(`^(\$[^$]*\$)(?:$|\s)`) + func (s *inlineMathParser) Parse(parent ast.Node, block text.Reader, pc parser.Context) ast.Node { before := block.PrecendingCharacter() - line, segment := block.PeekLine() - node := parser.ScanDelimiter(line, before, 1, defaultMathDelimiterProcessor) - if node == nil { + // Ignore lines where the opening $ comes after a letter or number to avoid false positives + if unicode.IsLetter(before) || unicode.IsNumber(before) { return nil } - node.Segment = segment.WithStop(segment.Start + node.OriginalLength) - block.Advance(node.OriginalLength) - pc.PushDelimiter(node) - return node + line, segment := block.PeekLine() + idx := latexInlineRegexp.FindSubmatchIndex(line) + if idx == nil { + return nil + } + block.Advance(idx[3]) + return &astMath{ + value: block.Value(text.NewSegment(segment.Start+1, segment.Start+idx[3]-1)), + } } func (s *inlineMathParser) CloseBlock(parent ast.Node, pc parser.Context) { @@ -115,50 +109,44 @@ func (b *blockMathParser) Trigger() []byte { } func (b *blockMathParser) Open(parent ast.Node, reader text.Reader, pc parser.Context) (ast.Node, parser.State) { - const fenceChar = '$' - line, segment := reader.PeekLine() + line, _ := reader.PeekLine() pos := pc.BlockOffset() - if pos < 0 || (line[pos] != fenceChar) { + if pos < 0 || (line[pos] != mathDelimiter) { return nil, parser.NoChildren } findent := pos i := pos - for ; i < len(line) && line[i] == fenceChar; i++ { + for ; i < len(line) && line[i] == mathDelimiter; i++ { } oFenceLength := i - pos if oFenceLength < 2 { return nil, parser.NoChildren } - var info *ast.Text if i < len(line)-1 { rest := line[i:] left := util.TrimLeftSpaceLength(rest) right := util.TrimRightSpaceLength(rest) if left < len(rest)-right { - infoStart, infoStop := segment.Start-segment.Padding+i+left, segment.Stop-right value := rest[left : len(rest)-right] - if bytes.IndexByte(value, fenceChar) > -1 { + if bytes.IndexByte(value, mathDelimiter) > -1 { return nil, parser.NoChildren - } else if infoStart != infoStop { - info = ast.NewTextSegment(text.NewSegment(infoStart, infoStop)) } } } - node := &astMathBlock{info: info} + node := &astMathBlock{} pc.Set(mathBlockInfoKey, &mathBlockData{findent, oFenceLength, node}) return node, parser.NoChildren } func (b *blockMathParser) Continue(node ast.Node, reader text.Reader, pc parser.Context) parser.State { - const fenceChar = '$' line, segment := reader.PeekLine() fdata := pc.Get(mathBlockInfoKey).(*mathBlockData) w, pos := util.IndentWidth(line, reader.LineOffset()) if w < 4 { i := pos - for ; i < len(line) && line[i] == fenceChar; i++ { + for ; i < len(line) && line[i] == mathDelimiter; i++ { } length := i - pos if length >= fdata.length && util.IsBlank(line[i:]) { @@ -221,15 +209,15 @@ func (r *mathHTMLRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer func (r *mathHTMLRenderer) renderMath(w util.BufWriter, source []byte, n ast.Node, entering bool) (ast.WalkStatus, error) { if entering { tag := "span" + var tex string switch typed := n.(type) { case *astMathBlock: tag = "div" + tex = string(n.Lines().Value(source)) case *astMath: - if typed.block { - tag = "div" - } + tex = string(typed.value) } - tex := stdhtml.EscapeString(string(n.Text(source))) + tex = stdhtml.EscapeString(strings.TrimSpace(tex)) _, _ = fmt.Fprintf(w, `<%s data-mx-maths="%s">%s`, tag, tex, strings.ReplaceAll(tex, "\n", "
"), tag) } return ast.WalkSkipChildren, nil From 013afd06d34185108daf00e57a3d4349aee7c6b7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 2 Nov 2024 12:37:50 +0200 Subject: [PATCH 0892/1647] format/htmlparser: convert math to just latex --- format/htmlparser.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/format/htmlparser.go b/format/htmlparser.go index bafd41af..5a967669 100644 --- a/format/htmlparser.go +++ b/format/htmlparser.go @@ -102,6 +102,8 @@ type HTMLParser struct { ItalicConverter TextConverter StrikethroughConverter TextConverter UnderlineConverter TextConverter + MathConverter TextConverter + MathBlockConverter TextConverter LinkConverter LinkConverter SpoilerConverter SpoilerConverter ColorConverter ColorConverter @@ -238,6 +240,16 @@ func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) stri func (parser *HTMLParser) spanToString(node *html.Node, ctx Context) string { str := parser.nodeToTagAwareString(node.FirstChild, ctx) + if node.Data == "span" || node.Data == "div" { + math, _ := parser.maybeGetAttribute(node, "data-mx-maths") + if math != "" && parser.MathConverter != nil { + if node.Data == "div" && parser.MathBlockConverter != nil { + str = parser.MathBlockConverter(math, ctx) + } else { + str = parser.MathConverter(math, ctx) + } + } + } if node.Data == "span" { reason, isSpoiler := parser.maybeGetAttribute(node, "data-mx-spoiler") if isSpoiler { @@ -449,6 +461,12 @@ func HTMLToMarkdownAndMentions(html string) (parsed string, mentions *event.Ment } return fmt.Sprintf("[%s](%s)", text, href) }, + MathConverter: func(s string, c Context) string { + return fmt.Sprintf("$%s$", s) + }, + MathBlockConverter: func(s string, c Context) string { + return fmt.Sprintf("$$\n%s\n$$", s) + }, }).Parse(html, ctx) mentionList, _ := ctx.ReturnData[ContextKeyMentions].([]id.UserID) mentions = &event.Mentions{ From 83e60efa1558a799606bd8147f3f599c341c0002 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 2 Nov 2024 13:37:24 +0200 Subject: [PATCH 0893/1647] format: add markdown renderer for custom emojis --- format/htmlparser.go | 62 ++++++++++++++++++++----------- format/markdown.go | 8 +++- format/markdown_test.go | 12 ++++++ format/mdext/customemoji.go | 73 +++++++++++++++++++++++++++++++++++++ go.mod | 2 +- go.sum | 4 +- 6 files changed, 134 insertions(+), 27 deletions(-) create mode 100644 format/mdext/customemoji.go diff --git a/format/htmlparser.go b/format/htmlparser.go index 5a967669..7c3b3c88 100644 --- a/format/htmlparser.go +++ b/format/htmlparser.go @@ -66,7 +66,7 @@ type LinkConverter func(text, href string, ctx Context) string type ColorConverter func(text, fg, bg string, ctx Context) string type CodeBlockConverter func(code, language string, ctx Context) string type PillConverter func(displayname, mxid, eventID string, ctx Context) string -type ImageConverter func(src, alt string, isEmoji bool) string +type ImageConverter func(src, alt, title, width, height string, isEmoji bool) string const ContextKeyMentions = "_mentions" @@ -315,9 +315,12 @@ func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) string { func (parser *HTMLParser) imgToString(node *html.Node, ctx Context) string { src := parser.getAttribute(node, "src") alt := parser.getAttribute(node, "alt") + title := parser.getAttribute(node, "title") + width := parser.getAttribute(node, "width") + height := parser.getAttribute(node, "height") _, isEmoji := parser.maybeGetAttribute(node, "data-mx-emoticon") if parser.ImageConverter != nil { - return parser.ImageConverter(src, alt, isEmoji) + return parser.ImageConverter(src, alt, title, width, height, isEmoji) } return alt } @@ -438,6 +441,35 @@ func (parser *HTMLParser) Parse(htmlData string, ctx Context) string { return parser.nodeToTagAwareString(node, ctx) } +var TextHTMLParser = &HTMLParser{ + TabsToSpaces: 4, + Newline: "\n", + HorizontalLine: "\n---\n", + PillConverter: DefaultPillConverter, +} + +var MarkdownHTMLParser = &HTMLParser{ + TabsToSpaces: 4, + Newline: "\n", + HorizontalLine: "\n---\n", + PillConverter: DefaultPillConverter, + LinkConverter: func(text, href string, ctx Context) string { + if text == href { + return text + } + return fmt.Sprintf("[%s](%s)", text, href) + }, + MathConverter: func(s string, c Context) string { + return fmt.Sprintf("$%s$", s) + }, + MathBlockConverter: func(s string, c Context) string { + return fmt.Sprintf("$$\n%s\n$$", s) + }, + UnderlineConverter: func(s string, c Context) string { + return fmt.Sprintf("%s", s) + }, +} + // HTMLToText converts Matrix HTML into text with the default settings. func HTMLToText(html string) string { return (&HTMLParser{ @@ -448,26 +480,12 @@ func HTMLToText(html string) string { }).Parse(html, NewContext(context.TODO())) } -func HTMLToMarkdownAndMentions(html string) (parsed string, mentions *event.Mentions) { +func HTMLToMarkdownFull(parser *HTMLParser, html string) (parsed string, mentions *event.Mentions) { + if parser == nil { + parser = MarkdownHTMLParser + } ctx := NewContext(context.TODO()) - parsed = (&HTMLParser{ - TabsToSpaces: 4, - Newline: "\n", - HorizontalLine: "\n---\n", - PillConverter: DefaultPillConverter, - LinkConverter: func(text, href string, ctx Context) string { - if text == href { - return text - } - return fmt.Sprintf("[%s](%s)", text, href) - }, - MathConverter: func(s string, c Context) string { - return fmt.Sprintf("$%s$", s) - }, - MathBlockConverter: func(s string, c Context) string { - return fmt.Sprintf("$$\n%s\n$$", s) - }, - }).Parse(html, ctx) + parsed = parser.Parse(html, ctx) mentionList, _ := ctx.ReturnData[ContextKeyMentions].([]id.UserID) mentions = &event.Mentions{ UserIDs: mentionList, @@ -479,6 +497,6 @@ func HTMLToMarkdownAndMentions(html string) (parsed string, mentions *event.Ment // // Currently, the only difference to HTMLToText is how links are formatted. func HTMLToMarkdown(html string) string { - parsed, _ := HTMLToMarkdownAndMentions(html) + parsed, _ := HTMLToMarkdownFull(nil, html) return parsed } diff --git a/format/markdown.go b/format/markdown.go index d62013cc..d099ba00 100644 --- a/format/markdown.go +++ b/format/markdown.go @@ -57,8 +57,8 @@ func TextToContent(text string) event.MessageEventContent { } } -func HTMLToContent(html string) event.MessageEventContent { - text, mentions := HTMLToMarkdownAndMentions(html) +func HTMLToContentFull(renderer *HTMLParser, html string) event.MessageEventContent { + text, mentions := HTMLToMarkdownFull(renderer, html) if html != text { return event.MessageEventContent{ FormattedBody: html, @@ -71,6 +71,10 @@ func HTMLToContent(html string) event.MessageEventContent { return TextToContent(text) } +func HTMLToContent(html string) event.MessageEventContent { + return HTMLToContentFull(nil, html) +} + func RenderMarkdown(text string, allowMarkdown, allowHTML bool) event.MessageEventContent { var htmlBody string diff --git a/format/markdown_test.go b/format/markdown_test.go index bb415d8a..d4e7d716 100644 --- a/format/markdown_test.go +++ b/format/markdown_test.go @@ -184,3 +184,15 @@ func TestRenderMarkdown_Math(t *testing.T) { assert.Equal(t, html, strings.ReplaceAll(rendered, "\n", "\\n"), "with input %q", markdown) } } + +var customEmojiTests = map[string]string{ + `![:meow:](mxc://example.com/emoji.png "Emoji: meow")`: `:meow:`, +} + +func TestRenderMarkdown_CustomEmoji(t *testing.T) { + renderer := goldmark.New(goldmark.WithExtensions(mdext.CustomEmoji), format.HTMLOptions) + for markdown, html := range customEmojiTests { + rendered := format.UnwrapSingleParagraph(render(renderer, markdown)) + assert.Equal(t, html, rendered, "with input %q", markdown) + } +} diff --git a/format/mdext/customemoji.go b/format/mdext/customemoji.go new file mode 100644 index 00000000..2884a5ea --- /dev/null +++ b/format/mdext/customemoji.go @@ -0,0 +1,73 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mdext + +import ( + "bytes" + + "github.com/yuin/goldmark" + "github.com/yuin/goldmark/ast" + "github.com/yuin/goldmark/renderer" + "github.com/yuin/goldmark/util" +) + +type extCustomEmoji struct{} +type customEmojiRenderer struct { + funcs functionCapturer +} + +// CustomEmoji is an extension that converts certain markdown images into Matrix custom emojis. +var CustomEmoji = &extCustomEmoji{} + +type functionCapturer struct { + renderImage renderer.NodeRendererFunc + renderText renderer.NodeRendererFunc + renderString renderer.NodeRendererFunc +} + +func (fc *functionCapturer) Register(kind ast.NodeKind, rendererFunc renderer.NodeRendererFunc) { + switch kind { + case ast.KindImage: + fc.renderImage = rendererFunc + case ast.KindText: + fc.renderText = rendererFunc + case ast.KindString: + fc.renderString = rendererFunc + } +} + +var ( + _ renderer.NodeRendererFuncRegisterer = (*functionCapturer)(nil) + _ renderer.Option = (*functionCapturer)(nil) +) + +func (fc *functionCapturer) SetConfig(cfg *renderer.Config) { + cfg.NodeRenderers[0].Value.(renderer.NodeRenderer).RegisterFuncs(fc) +} + +func (eeh *extCustomEmoji) Extend(m goldmark.Markdown) { + var fc functionCapturer + m.Renderer().AddOptions(&fc) + m.Renderer().AddOptions(renderer.WithNodeRenderers(util.Prioritized(&customEmojiRenderer{fc}, 0))) +} + +func (cer *customEmojiRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) { + reg.Register(ast.KindImage, cer.renderImage) +} + +var emojiPrefix = []byte("Emoji: ") +var mxcPrefix = []byte("mxc://") + +func (cer *customEmojiRenderer) renderImage(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) { + n, ok := node.(*ast.Image) + if ok && entering && bytes.HasPrefix(n.Title, emojiPrefix) && bytes.HasPrefix(n.Destination, mxcPrefix) { + n.Title = bytes.TrimPrefix(n.Title, emojiPrefix) + n.SetAttributeString("data-mx-emoticon", nil) + n.SetAttributeString("height", "32") + } + return cer.funcs.renderImage(w, source, node, entering) +} diff --git a/go.mod b/go.mod index 2f47e155..beb2badd 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.7.7 + github.com/yuin/goldmark v1.7.8 go.mau.fi/util v0.8.1 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.28.0 diff --git a/go.sum b/go.sum index 48d9fa8d..518fe895 100644 --- a/go.sum +++ b/go.sum @@ -49,8 +49,8 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.7 h1:5m9rrB1sW3JUMToKFQfb+FGt1U7r57IHu5GrYrG2nqU= -github.com/yuin/goldmark v1.7.7/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= +github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= go.mau.fi/util v0.8.1 h1:Ga43cz6esQBYqcjZ/onRoVnYWoUwjWbsxVeJg2jOTSo= go.mau.fi/util v0.8.1/go.mod h1:T1u/rD2rzidVrBLyaUdPpZiJdP/rsyi+aTzn0D+Q6wc= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= From fff009b5fac47919db806bf539d52fd499ab1bee Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 4 Nov 2024 10:00:05 -0700 Subject: [PATCH 0894/1647] bridgev2/portal: check if client is logged in before handling read receipt Signed-off-by: Sumner Evans --- bridgev2/portal.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 93410cbc..6ada8918 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -396,7 +396,7 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowR if err != nil { return nil, nil, err } - if login == nil || login.UserMXID != user.MXID { + if login == nil || login.UserMXID != user.MXID || !login.Client.IsLoggedIn() { if allowRelay && portal.Relay != nil { return nil, nil, nil } @@ -412,9 +412,9 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowR } portal.Bridge.cacheLock.Lock() defer portal.Bridge.cacheLock.Unlock() - for i, up := range logins { + for _, up := range logins { login, ok := user.logins[up.LoginID] - if ok && login.Client != nil && (len(logins) == i-1 || login.Client.IsLoggedIn()) { + if ok && login.Client != nil && login.Client.IsLoggedIn() { return login, up, nil } } @@ -430,7 +430,7 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowR firstLogin = login break } - if firstLogin != nil { + if firstLogin != nil && firstLogin.Client.IsLoggedIn() { zerolog.Ctx(ctx).Warn(). Str("chosen_login_id", string(firstLogin.ID)). Msg("No usable user portal rows found, returning random login") From ff907f403334e75fa2df1a733c9b904c6d2c6ea8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 6 Nov 2024 09:26:59 +0100 Subject: [PATCH 0895/1647] event/reply: implement MSC2781 --- event/reply.go | 49 +------------------------------------------------ 1 file changed, 1 insertion(+), 48 deletions(-) diff --git a/event/reply.go b/event/reply.go index 73f8cfc7..1a88c619 100644 --- a/event/reply.go +++ b/event/reply.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Tulir Asokan +// Copyright (c) 2024 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -7,7 +7,6 @@ package event import ( - "fmt" "regexp" "strings" @@ -47,52 +46,6 @@ func (content *MessageEventContent) GetReplyTo() id.EventID { return content.RelatesTo.GetReplyTo() } -const ReplyFormat = `
In reply to %s
%s
` - -func (evt *Event) GenerateReplyFallbackHTML() string { - parsedContent, ok := evt.Content.Parsed.(*MessageEventContent) - if !ok { - return "" - } - parsedContent.RemoveReplyFallback() - body := parsedContent.FormattedBody - if len(body) == 0 { - body = TextToHTML(parsedContent.Body) - } - - senderDisplayName := evt.Sender - - return fmt.Sprintf(ReplyFormat, evt.RoomID, evt.ID, evt.Sender, senderDisplayName, body) -} - -func (evt *Event) GenerateReplyFallbackText() string { - parsedContent, ok := evt.Content.Parsed.(*MessageEventContent) - if !ok { - return "" - } - parsedContent.RemoveReplyFallback() - body := parsedContent.Body - lines := strings.Split(strings.TrimSpace(body), "\n") - firstLine, lines := lines[0], lines[1:] - - senderDisplayName := evt.Sender - - var fallbackText strings.Builder - _, _ = fmt.Fprintf(&fallbackText, "> <%s> %s", senderDisplayName, firstLine) - for _, line := range lines { - _, _ = fmt.Fprintf(&fallbackText, "\n> %s", line) - } - fallbackText.WriteString("\n\n") - return fallbackText.String() -} - func (content *MessageEventContent) SetReply(inReplyTo *Event) { content.RelatesTo = (&RelatesTo{}).SetReplyTo(inReplyTo.ID) - - if content.MsgType == MsgText || content.MsgType == MsgNotice { - content.EnsureHasHTML() - content.FormattedBody = inReplyTo.GenerateReplyFallbackHTML() + content.FormattedBody - content.Body = inReplyTo.GenerateReplyFallbackText() + content.Body - content.replyFallbackRemoved = false - } } From 05a970d3702728478ccf34546732cd438639bfcc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 6 Nov 2024 09:33:45 +0100 Subject: [PATCH 0896/1647] changelog: update --- CHANGELOG.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e1904312..56a297a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,21 @@ +## unreleased + +* *(hicli)* Moved package into gomuks repo. +* *(bridgev2/commands)* Fixed cookie unescaping in login commands. +* *(bridgev2/portal)* Added special `DefaultChatName` constant to explicitly + reset portal names to the default (based on members). +* *(appservice)* Updated [MSC2409] stable registration field name from + `push_ephemeral` to `receive_ephemeral`. Homeserver admins must update + existing registrations manually. +* *(format)* Added support for `img` tags. +* *(format/mdext)* Added goldmark extensions for Matrix math and custom emojis. +* *(event/reply)* Removed support for generating reply fallbacks ([MSC2781]). +* *(pushrules)* Added support for `sender_notification_permission` condition + kind (used for `@room` mentions). +* *(crypto)* Added support for `json.RawMessage` in `EncryptMegolmEvent`. + +[MSC2781]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781 + ## v0.21.1 (2024-10-16) * *(bridgev2)* Added more features and fixed bugs. From 7a4e5e549e37c53195bb4e3c342c5535191b4028 Mon Sep 17 00:00:00 2001 From: Scott Weber Date: Wed, 6 Nov 2024 04:56:20 -0500 Subject: [PATCH 0897/1647] Use BeeperEncodedOrder to encode BeeperHSOrderString (#311) --- event/beeper.go | 101 ++++++++++++++++++++++++++++++++++++++++++++++++ event/events.go | 10 ++--- 2 files changed, 106 insertions(+), 5 deletions(-) diff --git a/event/beeper.go b/event/beeper.go index 911bdfe3..0af466f8 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -7,6 +7,10 @@ package event import ( + "encoding/base32" + "encoding/binary" + "fmt" + "maunium.net/go/mautrix/id" ) @@ -104,3 +108,100 @@ type BeeperPerMessageProfile struct { AvatarURL *id.ContentURIString `json:"avatar_url,omitempty"` AvatarFile *EncryptedFileInfo `json:"avatar_file,omitempty"` } + +type BeeperEncodedOrder struct { + order int64 + suborder int64 +} + +func NewBeeperEncodedOrder(order int64, suborder int64) BeeperEncodedOrder { + return BeeperEncodedOrder{order: order, suborder: suborder} +} + +func BeeperEncodedOrderFromString(str string) (BeeperEncodedOrder, error) { + order, suborder, err := decodeIntPair(str) + if err != nil { + return BeeperEncodedOrder{}, err + } + return BeeperEncodedOrder{order: order, suborder: suborder}, nil +} + +func (b BeeperEncodedOrder) String() string { + return encodeIntPair(b.order, b.suborder) +} + +func (b BeeperEncodedOrder) OrderPair() (int64, int64) { + return b.order, b.suborder +} + +func (b BeeperEncodedOrder) IsZero() bool { + return b.order == 0 && b.suborder == 0 +} + +func (b BeeperEncodedOrder) MarshalJSON() ([]byte, error) { + return []byte(`"` + b.String() + `"`), nil +} + +func (b *BeeperEncodedOrder) UnmarshalJSON(data []byte) error { + str := string(data) + if len(str) < 2 { + return fmt.Errorf("invalid encoded order string: %s", str) + } + decoded, err := BeeperEncodedOrderFromString(str[1 : len(str)-1]) + if err != nil { + return err + } + b.order, b.suborder = decoded.order, decoded.suborder + return nil +} + +// encodeIntPair encodes two int64 integers into a lexicographically sortable string +func encodeIntPair(a, b int64) string { + // Create a buffer to hold the binary representation of the integers. + var buf [16]byte + + // Flip the sign bit of each integer to map the entire int64 range to uint64 + // in a way that preserves the order of the original integers. + // + // Explanation: + // - By XORing with (1 << 63), we flip the most significant bit (sign bit) of the int64 value. + // - Negative numbers (which have a sign bit of 1) become smaller uint64 values. + // - Non-negative numbers (with a sign bit of 0) become larger uint64 values. + // - This mapping preserves the original ordering when the uint64 values are compared. + binary.BigEndian.PutUint64(buf[0:8], uint64(a)^(1<<63)) + binary.BigEndian.PutUint64(buf[8:16], uint64(b)^(1<<63)) + + // Encode the buffer into a Base32 string without padding using the Hex encoding. + // + // Explanation: + // - Base32 encoding converts binary data into a text representation using 32 ASCII characters. + // - Using Base32HexEncoding ensures that the characters are in lexicographical order. + // - Disabling padding results in a consistent string length, which is important for sorting. + encoded := base32.HexEncoding.WithPadding(base32.NoPadding).EncodeToString(buf[:]) + + return encoded +} + +// decodeIntPair decodes a string produced by encodeIntPair back into the original int64 integers +func decodeIntPair(encoded string) (int64, int64, error) { + // Decode the Base32 string back into the original byte buffer. + buf, err := base32.HexEncoding.WithPadding(base32.NoPadding).DecodeString(encoded) + if err != nil { + return 0, 0, fmt.Errorf("failed to decode string: %w", err) + } + + // Check that the decoded buffer has the expected length. + if len(buf) != 16 { + return 0, 0, fmt.Errorf("invalid encoded string length: expected 16 bytes, got %d", len(buf)) + } + + // Read the uint64 values from the buffer using big-endian byte order. + aPos := binary.BigEndian.Uint64(buf[0:8]) + bPos := binary.BigEndian.Uint64(buf[8:16]) + + // Reverse the sign bit flip to retrieve the original int64 values. + a := int64(aPos ^ (1 << 63)) + b := int64(bPos ^ (1 << 63)) + + return a, b, nil +} diff --git a/event/events.go b/event/events.go index e9c5955c..38f0d848 100644 --- a/event/events.go +++ b/event/events.go @@ -144,14 +144,14 @@ type Unsigned struct { RedactedBecause *Event `json:"redacted_because,omitempty"` InviteRoomState []StrippedState `json:"invite_room_state,omitempty"` - BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"` - BeeperHSSuborder int64 `json:"com.beeper.hs.suborder,omitempty"` - BeeperHSOrderString string `json:"com.beeper.hs.order_string,omitempty"` - BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"` + BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"` + BeeperHSSuborder int64 `json:"com.beeper.hs.suborder,omitempty"` + BeeperHSOrderString BeeperEncodedOrder `json:"com.beeper.hs.order_string,omitempty"` + BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"` } func (us *Unsigned) IsEmpty() bool { return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil && - us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 && us.BeeperHSOrderString == "" + us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 && us.BeeperHSOrderString.IsZero() } From 3b93df0702338f961da2923627df75299191d91c Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 6 Nov 2024 11:19:23 +0100 Subject: [PATCH 0898/1647] Remove unused `InviteUser` in the `MatrixAPI` interface --- bridgev2/matrix/intent.go | 8 -------- bridgev2/matrixinterface.go | 1 - 2 files changed, 9 deletions(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index e7e860f6..9f6c520e 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -478,14 +478,6 @@ func (as *ASIntent) GetMXID() id.UserID { return as.Matrix.UserID } -func (as *ASIntent) InviteUser(ctx context.Context, roomID id.RoomID, userID id.UserID) error { - _, err := as.Matrix.InviteUser(ctx, roomID, &mautrix.ReqInviteUser{ - Reason: "", - UserID: userID, - }) - return err -} - func (as *ASIntent) EnsureJoined(ctx context.Context, roomID id.RoomID) error { err := as.Matrix.EnsureJoined(ctx, roomID) if err != nil { diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 66d39403..699ce07b 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -143,7 +143,6 @@ type MatrixAPI interface { CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error - InviteUser(ctx context.Context, roomID id.RoomID, userID id.UserID) error EnsureJoined(ctx context.Context, roomID id.RoomID) error EnsureInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) error From 56aadb232f5cda0002c1aa07128d331784574245 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 6 Nov 2024 10:36:45 +0100 Subject: [PATCH 0899/1647] mediaproxy: add support for writing directly to http response --- CHANGELOG.md | 2 ++ mediaproxy/mediaproxy.go | 58 +++++++++++++++++++++++++++++++--------- 2 files changed, 48 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 56a297a8..a971597c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ * *(pushrules)* Added support for `sender_notification_permission` condition kind (used for `@room` mentions). * *(crypto)* Added support for `json.RawMessage` in `EncryptMegolmEvent`. +* *(mediaproxy)* Added `GetMediaResponseCallback` to write proxied response + directly instead of having to use an `io.Reader`. [MSC2781]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781 diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index f2591428..d1ab0815 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -32,20 +32,54 @@ type GetMediaResponse interface { isGetMediaResponse() } -func (*GetMediaResponseURL) isGetMediaResponse() {} -func (*GetMediaResponseData) isGetMediaResponse() {} +func (*GetMediaResponseURL) isGetMediaResponse() {} +func (*GetMediaResponseData) isGetMediaResponse() {} +func (*GetMediaResponseCallback) isGetMediaResponse() {} type GetMediaResponseURL struct { URL string ExpiresAt time.Time } +type GetMediaResponseWriter interface { + GetMediaResponse + io.WriterTo + GetContentType() string + GetContentLength() int64 +} + type GetMediaResponseData struct { Reader io.ReadCloser ContentType string ContentLength int64 } +func (d *GetMediaResponseData) WriteTo(w io.Writer) (int64, error) { + return io.Copy(w, d.Reader) +} + +func (d *GetMediaResponseData) GetContentType() string { + return d.ContentType +} + +func (d *GetMediaResponseData) GetContentLength() int64 { + return d.ContentLength +} + +type GetMediaResponseCallback struct { + Callback func(w io.Writer) (int64, error) + ContentType string + ContentLength int64 +} + +func (d *GetMediaResponseCallback) WriteTo(w io.Writer) (int64, error) { + return d.Callback(w) +} + +func (d *GetMediaResponseCallback) GetContentLength() int64 { + return d.ContentLength +} + type GetMediaFunc = func(ctx context.Context, mediaID string) (response GetMediaResponse, err error) type MediaProxy struct { @@ -316,21 +350,21 @@ func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Req log.Err(err).Msg("Failed to create multipart redirect field") return } - } else if dataResp, ok := resp.(*GetMediaResponseData); ok { + } else if dataResp, ok := resp.(GetMediaResponseWriter); ok { dataPart, err := mpw.CreatePart(textproto.MIMEHeader{ - "Content-Type": {dataResp.ContentType}, + "Content-Type": {dataResp.GetContentType()}, }) if err != nil { log.Err(err).Msg("Failed to create multipart data field") return } - _, err = io.Copy(dataPart, dataResp.Reader) + _, err = dataResp.WriteTo(dataPart) if err != nil { log.Err(err).Msg("Failed to write multipart data field") return } } else { - panic("unknown GetMediaResponse type") + panic(fmt.Errorf("unknown GetMediaResponse type %T", resp)) } err = mpw.Close() if err != nil { @@ -374,18 +408,18 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "no-store") } w.WriteHeader(http.StatusTemporaryRedirect) - } else if dataResp, ok := resp.(*GetMediaResponseData); ok { - w.Header().Set("Content-Type", dataResp.ContentType) - if dataResp.ContentLength != 0 { - w.Header().Set("Content-Length", strconv.FormatInt(dataResp.ContentLength, 10)) + } else if dataResp, ok := resp.(GetMediaResponseWriter); ok { + w.Header().Set("Content-Type", dataResp.GetContentType()) + if dataResp.GetContentLength() != 0 { + w.Header().Set("Content-Length", strconv.FormatInt(dataResp.GetContentLength(), 10)) } w.WriteHeader(http.StatusOK) - _, err := io.Copy(w, dataResp.Reader) + _, err := dataResp.WriteTo(w) if err != nil { log.Err(err).Msg("Failed to write media data") } } else { - panic("unknown GetMediaResponse type") + panic(fmt.Errorf("unknown GetMediaResponse type %T", resp)) } } From 39bddeb7d3f410d41bab3b3fba45ec33bc85a5bb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 6 Nov 2024 11:01:49 +0100 Subject: [PATCH 0900/1647] mediaproxy: add support for temp files --- CHANGELOG.md | 5 ++- mediaproxy/mediaproxy.go | 89 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a971597c..26df4cbf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,8 +13,9 @@ * *(pushrules)* Added support for `sender_notification_permission` condition kind (used for `@room` mentions). * *(crypto)* Added support for `json.RawMessage` in `EncryptMegolmEvent`. -* *(mediaproxy)* Added `GetMediaResponseCallback` to write proxied response - directly instead of having to use an `io.Reader`. +* *(mediaproxy)* Added `GetMediaResponseCallback` and `GetMediaResponseFile` + to write proxied data directly to http response or temp file instead of + having to use an `io.Reader`. [MSC2781]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781 diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index d1ab0815..ce8dd99b 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -17,6 +17,7 @@ import ( "net" "net/http" "net/textproto" + "os" "strconv" "strings" "time" @@ -35,6 +36,7 @@ type GetMediaResponse interface { func (*GetMediaResponseURL) isGetMediaResponse() {} func (*GetMediaResponseData) isGetMediaResponse() {} func (*GetMediaResponseCallback) isGetMediaResponse() {} +func (*GetMediaResponseFile) isGetMediaResponse() {} type GetMediaResponseURL struct { URL string @@ -48,6 +50,11 @@ type GetMediaResponseWriter interface { GetContentLength() int64 } +var ( + _ GetMediaResponseWriter = (*GetMediaResponseCallback)(nil) + _ GetMediaResponseWriter = (*GetMediaResponseData)(nil) +) + type GetMediaResponseData struct { Reader io.ReadCloser ContentType string @@ -80,6 +87,15 @@ func (d *GetMediaResponseCallback) GetContentLength() int64 { return d.ContentLength } +func (d *GetMediaResponseCallback) GetContentType() string { + return d.ContentType +} + +type GetMediaResponseFile struct { + Callback func(w *os.File) error + ContentType string +} + type GetMediaFunc = func(ctx context.Context, mediaID string) (response GetMediaResponse, err error) type MediaProxy struct { @@ -350,6 +366,20 @@ func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Req log.Err(err).Msg("Failed to create multipart redirect field") return } + } else if fileResp, ok := resp.(*GetMediaResponseFile); ok { + _, err = doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error { + dataPart, err := mpw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {mimeType}, + }) + if err != nil { + return fmt.Errorf("failed to create multipart data field: %w", err) + } + _, err = wt.WriteTo(dataPart) + return err + }) + if err != nil { + log.Err(err).Msg("Failed to do media proxy with temp file") + } } else if dataResp, ok := resp.(GetMediaResponseWriter); ok { dataPart, err := mpw.CreatePart(textproto.MIMEHeader{ "Content-Type": {dataResp.GetContentType()}, @@ -408,6 +438,20 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { w.Header().Set("Cache-Control", "no-store") } w.WriteHeader(http.StatusTemporaryRedirect) + } else if fileResp, ok := resp.(*GetMediaResponseFile); ok { + responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error { + w.Header().Set("Content-Type", mimeType) + w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) + w.WriteHeader(http.StatusOK) + _, err := wt.WriteTo(w) + return err + }) + if err != nil { + log.Err(err).Msg("Failed to do media proxy with temp file") + if !responseStarted { + mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w) + } + } } else if dataResp, ok := resp.(GetMediaResponseWriter); ok { w.Header().Set("Content-Type", dataResp.GetContentType()) if dataResp.GetContentLength() != 0 { @@ -423,6 +467,51 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { } } +func doTempFileDownload( + data *GetMediaResponseFile, + respond func(w io.WriterTo, size int64, mimeType string) error, +) (bool, error) { + tempFile, err := os.CreateTemp("", "mautrix-mediaproxy-*") + if err != nil { + return false, fmt.Errorf("failed to create temp file: %w", err) + } + defer func() { + _ = tempFile.Close() + _ = os.Remove(tempFile.Name()) + }() + err = data.Callback(tempFile) + if err != nil { + return false, err + } + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return false, fmt.Errorf("failed to seek to start of temp file: %w", err) + } + fileInfo, err := tempFile.Stat() + if err != nil { + return false, fmt.Errorf("failed to stat temp file: %w", err) + } + mimeType := data.ContentType + if mimeType == "" { + buf := make([]byte, 512) + n, err := tempFile.Read(buf) + if err != nil { + return false, fmt.Errorf("failed to read temp file to detect mime: %w", err) + } + buf = buf[:n] + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return false, fmt.Errorf("failed to seek to start of temp file: %w", err) + } + mimeType = http.DetectContentType(buf) + } + err = respond(tempFile, fileInfo.Size(), mimeType) + if err != nil { + return true, err + } + return true, nil +} + func jsonResponse(w http.ResponseWriter, status int, response interface{}) { w.Header().Add("Content-Type", "application/json") w.WriteHeader(status) From 49310b6ec146586ed021db556fb745bbaf08bc31 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 6 Nov 2024 11:17:13 +0100 Subject: [PATCH 0901/1647] mediaproxy: refactor error responses --- bridgev2/matrix/directmedia.go | 9 +- mediaproxy/mediaproxy.go | 146 ++++++++++++++++++--------------- 2 files changed, 82 insertions(+), 73 deletions(-) diff --git a/bridgev2/matrix/directmedia.go b/bridgev2/matrix/directmedia.go index 15af0263..bc5b312c 100644 --- a/bridgev2/matrix/directmedia.go +++ b/bridgev2/matrix/directmedia.go @@ -13,7 +13,6 @@ import ( "crypto/sha256" "encoding/base64" "fmt" - "net/http" "strings" "maunium.net/go/mautrix" @@ -80,13 +79,7 @@ func (br *Connector) getDirectMedia(ctx context.Context, mediaIDStr string) (res receivedHash := mediaID[len(mediaID)-MediaIDTruncatedHashLength:] expectedHash := br.hashMediaID(mediaID[:len(mediaID)-MediaIDTruncatedHashLength]) if !hmac.Equal(receivedHash, expectedHash) { - return nil, &mediaproxy.ResponseError{ - Status: http.StatusNotFound, - Data: &mautrix.RespError{ - ErrCode: mautrix.MNotFound.ErrCode, - Err: "Invalid checksum in media ID part", - }, - } + return nil, mautrix.MNotFound.WithMessage("Invalid checksum in media ID part") } remoteMediaID := networkid.MediaID(mediaID[len(MediaIDPrefix) : len(mediaID)-MediaIDTruncatedHashLength]) return br.Bridge.Network.(bridgev2.DirectMediableNetwork).Download(ctx, remoteMediaID) diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index ce8dd99b..aab85f34 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -242,10 +242,7 @@ func (mp *MediaProxy) proxyDownload(ctx context.Context, w http.ResponseWriter, req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { log.Err(err).Str("url", url).Msg("Failed to create proxy request") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - ErrCode: "M_UNKNOWN", - Err: "Failed to create proxy request", - }) + mautrix.MUnknown.WithMessage("Failed to create proxy request").Write(w) return } req.Header.Set("User-Agent", mautrix.DefaultUserAgent+" (media proxy)") @@ -260,17 +257,11 @@ func (mp *MediaProxy) proxyDownload(ctx context.Context, w http.ResponseWriter, }() if err != nil { log.Err(err).Str("url", url).Msg("Failed to proxy download") - jsonResponse(w, http.StatusServiceUnavailable, &mautrix.RespError{ - ErrCode: "M_UNKNOWN", - Err: "Failed to proxy download", - }) + mautrix.MUnknown.WithMessage("Failed to proxy download").WithStatus(http.StatusServiceUnavailable).Write(w) return } else if resp.StatusCode != http.StatusOK { log.Warn().Str("url", url).Int("status", resp.StatusCode).Msg("Unexpected status code proxying download") - jsonResponse(w, resp.StatusCode, &mautrix.RespError{ - ErrCode: "M_UNKNOWN", - Err: "Unexpected status code proxying download", - }) + mautrix.MUnknown.WithMessage("Unexpected status code proxying download").WithStatus(resp.StatusCode).Write(w) return } w.Header()["Content-Type"] = resp.Header["Content-Type"] @@ -298,6 +289,7 @@ func (mp *MediaProxy) proxyDownload(ctx context.Context, w http.ResponseWriter, } } +// Deprecated: use mautrix.RespError instead type ResponseError struct { Status int Data any @@ -313,26 +305,45 @@ func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaR mediaID := mux.Vars(r)["mediaID"] resp, err := mp.GetMedia(r.Context(), mediaID) if err != nil { + //lint:ignore SA1019 deprecated types need to be supported until they're removed var respError *ResponseError + var mautrixRespError mautrix.RespError if errors.Is(err, ErrInvalidMediaIDSyntax) { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MNotFound.ErrCode, - Err: fmt.Sprintf("This is a media proxy at %q, other media downloads are not available here", mp.serverName), - }) + mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w) + } else if errors.As(err, &mautrixRespError) { + mautrixRespError.Write(w) } else if errors.As(err, &respError) { - jsonResponse(w, respError.Status, respError.Data) + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(respError.Status) + _ = json.NewEncoder(w).Encode(respError.Data) } else { zerolog.Ctx(r.Context()).Err(err).Str("media_id", mediaID).Msg("Failed to get media URL") - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MNotFound.ErrCode, - Err: "Media not found", - }) + mautrix.MNotFound.WithMessage("Media not found").Write(w) } return nil } return resp } +func startMultipart(ctx context.Context, w http.ResponseWriter) *multipart.Writer { + mpw := multipart.NewWriter(w) + w.Header().Set("Content-Type", strings.Replace(mpw.FormDataContentType(), "form-data", "mixed", 1)) + w.WriteHeader(http.StatusOK) + metaPart, err := mpw.CreatePart(textproto.MIMEHeader{ + "Content-Type": {"application/json"}, + }) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create multipart metadata field") + return nil + } + _, err = metaPart.Write([]byte(`{}`)) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to write multipart metadata field") + return nil + } + return mpw +} + func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Request) { ctx := r.Context() log := zerolog.Ctx(ctx) @@ -343,23 +354,13 @@ func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Req return } - mpw := multipart.NewWriter(w) - w.Header().Set("Content-Type", strings.Replace(mpw.FormDataContentType(), "form-data", "mixed", 1)) - w.WriteHeader(http.StatusOK) - metaPart, err := mpw.CreatePart(textproto.MIMEHeader{ - "Content-Type": {"application/json"}, - }) - if err != nil { - log.Err(err).Msg("Failed to create multipart metadata field") - return - } - _, err = metaPart.Write([]byte(`{}`)) - if err != nil { - log.Err(err).Msg("Failed to write multipart metadata field") - return - } + var mpw *multipart.Writer if urlResp, ok := resp.(*GetMediaResponseURL); ok { - _, err = mpw.CreatePart(textproto.MIMEHeader{ + mpw = startMultipart(ctx, w) + if mpw == nil { + return + } + _, err := mpw.CreatePart(textproto.MIMEHeader{ "Location": {urlResp.URL}, }) if err != nil { @@ -367,7 +368,11 @@ func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Req return } } else if fileResp, ok := resp.(*GetMediaResponseFile); ok { - _, err = doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error { + responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error { + mpw = startMultipart(ctx, w) + if mpw == nil { + return fmt.Errorf("failed to start multipart writer") + } dataPart, err := mpw.CreatePart(textproto.MIMEHeader{ "Content-Type": {mimeType}, }) @@ -379,8 +384,21 @@ func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Req }) if err != nil { log.Err(err).Msg("Failed to do media proxy with temp file") + if !responseStarted { + var mautrixRespError mautrix.RespError + if errors.As(err, &mautrixRespError) { + mautrixRespError.Write(w) + } else { + mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w) + } + } + return } } else if dataResp, ok := resp.(GetMediaResponseWriter); ok { + mpw = startMultipart(ctx, w) + if mpw == nil { + return + } dataPart, err := mpw.CreatePart(textproto.MIMEHeader{ "Content-Type": {dataResp.GetContentType()}, }) @@ -396,7 +414,7 @@ func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Req } else { panic(fmt.Errorf("unknown GetMediaResponse type %T", resp)) } - err = mpw.Close() + err := mpw.Close() if err != nil { log.Err(err).Msg("Failed to close multipart writer") return @@ -408,10 +426,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { log := zerolog.Ctx(ctx) vars := mux.Vars(r) if vars["serverName"] != mp.serverName { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MNotFound.ErrCode, - Err: fmt.Sprintf("This is a media proxy at %q, other media downloads are not available here", mp.serverName), - }) + mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w) return } resp := mp.getMedia(w, r) @@ -449,7 +464,12 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { if err != nil { log.Err(err).Msg("Failed to do media proxy with temp file") if !responseStarted { - mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w) + var mautrixRespError mautrix.RespError + if errors.As(err, &mautrixRespError) { + mautrixRespError.Write(w) + } else { + mautrix.MUnknown.WithMessage("Internal error proxying media").Write(w) + } } } } else if dataResp, ok := resp.(GetMediaResponseWriter); ok { @@ -512,36 +532,32 @@ func doTempFileDownload( return true, nil } -func jsonResponse(w http.ResponseWriter, status int, response interface{}) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(response) -} +var ( + ErrUploadNotSupported = mautrix.MUnrecognized. + WithMessage("This is a media proxy and does not support media uploads."). + WithStatus(http.StatusNotImplemented) + ErrPreviewURLNotSupported = mautrix.MUnrecognized. + WithMessage("This is a media proxy and does not support URL previews."). + WithStatus(http.StatusNotImplemented) + ErrUnknownEndpoint = mautrix.MUnrecognized. + WithMessage("Unrecognized endpoint") + ErrUnsupportedMethod = mautrix.MUnrecognized. + WithMessage("Invalid method for endpoint"). + WithStatus(http.StatusMethodNotAllowed) +) func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - ErrCode: mautrix.MUnrecognized.ErrCode, - Err: "This is a media proxy and does not support media uploads.", - }) + ErrUploadNotSupported.Write(w) } func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - ErrCode: mautrix.MUnrecognized.ErrCode, - Err: "This is a media proxy and does not support URL previews.", - }) + ErrPreviewURLNotSupported.Write(w) } func (mp *MediaProxy) UnknownEndpoint(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MUnrecognized.ErrCode, - Err: "Unrecognized endpoint", - }) + ErrUnknownEndpoint.Write(w) } func (mp *MediaProxy) UnsupportedMethod(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{ - ErrCode: mautrix.MUnrecognized.ErrCode, - Err: "Invalid method for endpoint", - }) + ErrUnsupportedMethod.Write(w) } From 87651815a97bbd1bccdff6f9e09fb6623c701400 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 6 Nov 2024 11:38:19 +0100 Subject: [PATCH 0902/1647] mediaproxy: drop support for proxying urls --- mediaproxy/mediaproxy.go | 109 ++++----------------------------------- 1 file changed, 9 insertions(+), 100 deletions(-) diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index aab85f34..22e62403 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -12,9 +12,7 @@ import ( "errors" "fmt" "io" - "mime" "mime/multipart" - "net" "net/http" "net/textproto" "os" @@ -99,8 +97,7 @@ type GetMediaResponseFile struct { type GetMediaFunc = func(ctx context.Context, mediaID string) (response GetMediaResponse, err error) type MediaProxy struct { - KeyServer *federation.KeyServer - ProxyClient *http.Client + KeyServer *federation.KeyServer ForceProxyLegacyFederation bool @@ -111,7 +108,6 @@ type MediaProxy struct { serverKey *federation.SigningKey FederationRouter *mux.Router - LegacyMediaRouter *mux.Router ClientMediaRouter *mux.Router } @@ -124,14 +120,6 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx serverName: serverName, serverKey: parsed, GetMedia: getMedia, - ProxyClient: &http.Client{ - Transport: &http.Transport{ - DialContext: (&net.Dialer{Timeout: 10 * time.Second}).DialContext, - TLSHandshakeTimeout: 10 * time.Second, - ForceAttemptHTTP2: false, - }, - Timeout: 60 * time.Second, - }, KeyServer: &federation.KeyServer{ KeyProvider: &federation.StaticServerKey{ ServerName: serverName, @@ -149,7 +137,6 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx type BasicConfig struct { ServerName string `yaml:"server_name" json:"server_name"` ServerKey string `yaml:"server_key" json:"server_key"` - AllowProxy bool `yaml:"allow_proxy" json:"allow_proxy"` WellKnownResponse string `yaml:"well_known_response" json:"well_known_response"` } @@ -158,9 +145,6 @@ func NewFromConfig(cfg BasicConfig, getMedia GetMediaFunc) (*MediaProxy, error) if err != nil { return nil, err } - if !cfg.AllowProxy { - mp.DisallowProxying() - } if cfg.WellKnownResponse != "" { mp.KeyServer.WellKnownTarget = cfg.WellKnownResponse } @@ -186,39 +170,24 @@ func (mp *MediaProxy) GetServerKey() *federation.SigningKey { return mp.serverKey } -func (mp *MediaProxy) DisallowProxying() { - mp.ProxyClient = nil -} - func (mp *MediaProxy) RegisterRoutes(router *mux.Router) { if mp.FederationRouter == nil { mp.FederationRouter = router.PathPrefix("/_matrix/federation").Subrouter() } - if mp.LegacyMediaRouter == nil { - mp.LegacyMediaRouter = router.PathPrefix("/_matrix/media").Subrouter() - } if mp.ClientMediaRouter == nil { mp.ClientMediaRouter = router.PathPrefix("/_matrix/client/v1/media").Subrouter() } mp.FederationRouter.HandleFunc("/v1/media/download/{mediaID}", mp.DownloadMediaFederation).Methods(http.MethodGet) mp.FederationRouter.HandleFunc("/v1/version", mp.KeyServer.GetServerVersion).Methods(http.MethodGet) - addClientRoutes := func(router *mux.Router, prefix string) { - router.HandleFunc(prefix+"/download/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet) - router.HandleFunc(prefix+"/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia).Methods(http.MethodGet) - router.HandleFunc(prefix+"/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet) - router.HandleFunc(prefix+"/upload/{serverName}/{mediaID}", mp.UploadNotSupported).Methods(http.MethodPut) - router.HandleFunc(prefix+"/upload", mp.UploadNotSupported).Methods(http.MethodPost) - router.HandleFunc(prefix+"/create", mp.UploadNotSupported).Methods(http.MethodPost) - router.HandleFunc(prefix+"/config", mp.UploadNotSupported).Methods(http.MethodGet) - router.HandleFunc(prefix+"/preview_url", mp.PreviewURLNotSupported).Methods(http.MethodGet) - } - addClientRoutes(mp.LegacyMediaRouter, "/v3") - addClientRoutes(mp.LegacyMediaRouter, "/r0") - addClientRoutes(mp.LegacyMediaRouter, "/v1") - addClientRoutes(mp.ClientMediaRouter, "") - mp.LegacyMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) - mp.LegacyMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod) + mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet) + mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia).Methods(http.MethodGet) + mp.ClientMediaRouter.HandleFunc("/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet) + mp.ClientMediaRouter.HandleFunc("/upload/{serverName}/{mediaID}", mp.UploadNotSupported).Methods(http.MethodPut) + mp.ClientMediaRouter.HandleFunc("/upload", mp.UploadNotSupported).Methods(http.MethodPost) + mp.ClientMediaRouter.HandleFunc("/create", mp.UploadNotSupported).Methods(http.MethodPost) + mp.ClientMediaRouter.HandleFunc("/config", mp.UploadNotSupported).Methods(http.MethodGet) + mp.ClientMediaRouter.HandleFunc("/preview_url", mp.PreviewURLNotSupported).Methods(http.MethodGet) mp.FederationRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) mp.FederationRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod) mp.ClientMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) @@ -232,63 +201,10 @@ func (mp *MediaProxy) RegisterRoutes(router *mux.Router) { next.ServeHTTP(w, r) }) } - mp.LegacyMediaRouter.Use(corsMiddleware) mp.ClientMediaRouter.Use(corsMiddleware) mp.KeyServer.Register(router) } -func (mp *MediaProxy) proxyDownload(ctx context.Context, w http.ResponseWriter, url, fileName string) { - log := zerolog.Ctx(ctx) - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - log.Err(err).Str("url", url).Msg("Failed to create proxy request") - mautrix.MUnknown.WithMessage("Failed to create proxy request").Write(w) - return - } - req.Header.Set("User-Agent", mautrix.DefaultUserAgent+" (media proxy)") - if mp.PrepareProxyRequest != nil { - mp.PrepareProxyRequest(req) - } - resp, err := mp.ProxyClient.Do(req) - defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() - } - }() - if err != nil { - log.Err(err).Str("url", url).Msg("Failed to proxy download") - mautrix.MUnknown.WithMessage("Failed to proxy download").WithStatus(http.StatusServiceUnavailable).Write(w) - return - } else if resp.StatusCode != http.StatusOK { - log.Warn().Str("url", url).Int("status", resp.StatusCode).Msg("Unexpected status code proxying download") - mautrix.MUnknown.WithMessage("Unexpected status code proxying download").WithStatus(resp.StatusCode).Write(w) - return - } - w.Header()["Content-Type"] = resp.Header["Content-Type"] - w.Header()["Content-Length"] = resp.Header["Content-Length"] - w.Header()["Last-Modified"] = resp.Header["Last-Modified"] - w.Header()["Cache-Control"] = resp.Header["Cache-Control"] - contentDisposition := "attachment" - switch resp.Header.Get("Content-Type") { - case "text/css", "text/plain", "text/csv", "application/json", "application/ld+json", "image/jpeg", "image/gif", - "image/png", "image/apng", "image/webp", "image/avif", "video/mp4", "video/webm", "video/ogg", "video/quicktime", - "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", "audio/wav", "audio/x-wav", - "audio/x-pn-wav", "audio/flac", "audio/x-flac", "application/pdf": - contentDisposition = "inline" - } - if fileName != "" { - contentDisposition = mime.FormatMediaType(contentDisposition, map[string]string{ - "filename": fileName, - }) - } - w.Header().Set("Content-Disposition", contentDisposition) - w.WriteHeader(http.StatusOK) - _, err = io.Copy(w, resp.Body) - if err != nil { - log.Debug().Err(err).Msg("Failed to write proxy response") - } -} - // Deprecated: use mautrix.RespError instead type ResponseError struct { Status int @@ -435,13 +351,6 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { } if urlResp, ok := resp.(*GetMediaResponseURL); ok { - // Proxy if the config allows proxying and the request doesn't allow redirects. - // In any other case, redirect to the URL. - isFederated := strings.HasPrefix(r.Header.Get("Authorization"), "X-Matrix") - if mp.ProxyClient != nil && (r.URL.Query().Get("allow_redirect") != "true" || (mp.ForceProxyLegacyFederation && isFederated)) { - mp.proxyDownload(ctx, w, urlResp.URL, vars["fileName"]) - return - } w.Header().Set("Location", urlResp.URL) expirySeconds := (time.Until(urlResp.ExpiresAt) - 5*time.Minute).Seconds() if urlResp.ExpiresAt.IsZero() { From 8c3056a447b33691abeb4e8bd560a832355be956 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 6 Nov 2024 11:41:23 +0100 Subject: [PATCH 0903/1647] mediaproxy: add content disposition for proxied downloads --- mediaproxy/mediaproxy.go | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index 22e62403..f2cde105 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" "io" + "mime" "mime/multipart" "net/http" "net/textproto" @@ -337,6 +338,25 @@ func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Req } } +func (mp *MediaProxy) addHeaders(w http.ResponseWriter, mimeType, fileName string) { + w.Header().Set("Cache-Control", "public, max-age=31536000, immutable") + contentDisposition := "attachment" + switch mimeType { + case "text/css", "text/plain", "text/csv", "application/json", "application/ld+json", "image/jpeg", "image/gif", + "image/png", "image/apng", "image/webp", "image/avif", "video/mp4", "video/webm", "video/ogg", "video/quicktime", + "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", "audio/wav", "audio/x-wav", + "audio/x-pn-wav", "audio/flac", "audio/x-flac", "application/pdf": + contentDisposition = "inline" + } + if fileName != "" { + contentDisposition = mime.FormatMediaType(contentDisposition, map[string]string{ + "filename": fileName, + }) + } + w.Header().Set("Content-Disposition", contentDisposition) + w.Header().Set("Content-Type", mimeType) +} + func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { ctx := r.Context() log := zerolog.Ctx(ctx) @@ -364,7 +384,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusTemporaryRedirect) } else if fileResp, ok := resp.(*GetMediaResponseFile); ok { responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error { - w.Header().Set("Content-Type", mimeType) + mp.addHeaders(w, mimeType, r.PathValue("fileName")) w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) w.WriteHeader(http.StatusOK) _, err := wt.WriteTo(w) @@ -382,7 +402,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { } } } else if dataResp, ok := resp.(GetMediaResponseWriter); ok { - w.Header().Set("Content-Type", dataResp.GetContentType()) + mp.addHeaders(w, dataResp.GetContentType(), r.PathValue("fileName")) if dataResp.GetContentLength() != 0 { w.Header().Set("Content-Length", strconv.FormatInt(dataResp.GetContentLength(), 10)) } From f588c35d8b1c08c39e30071fcbc00a603111bfeb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 6 Nov 2024 13:10:29 +0100 Subject: [PATCH 0904/1647] mediaproxy: pass through query parameters --- CHANGELOG.md | 2 ++ bridgev2/matrix/directmedia.go | 4 ++-- bridgev2/networkinterface.go | 2 +- go.mod | 2 +- go.sum | 4 ++-- mediaproxy/mediaproxy.go | 17 +++++++++++++---- 6 files changed, 21 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 26df4cbf..f07e1603 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ * *(mediaproxy)* Added `GetMediaResponseCallback` and `GetMediaResponseFile` to write proxied data directly to http response or temp file instead of having to use an `io.Reader`. +* *(mediaproxy)* Dropped support for legacy media download endpoints. +* *(mediaproxy,bridgev2)* Made interface pass through query parameters. [MSC2781]: https://github.com/matrix-org/matrix-spec-proposals/pull/2781 diff --git a/bridgev2/matrix/directmedia.go b/bridgev2/matrix/directmedia.go index bc5b312c..71c01078 100644 --- a/bridgev2/matrix/directmedia.go +++ b/bridgev2/matrix/directmedia.go @@ -71,7 +71,7 @@ func (br *Connector) GenerateContentURI(ctx context.Context, mediaID networkid.M return mxc, nil } -func (br *Connector) getDirectMedia(ctx context.Context, mediaIDStr string) (response mediaproxy.GetMediaResponse, err error) { +func (br *Connector) getDirectMedia(ctx context.Context, mediaIDStr string, params map[string]string) (response mediaproxy.GetMediaResponse, err error) { mediaID, err := base64.RawURLEncoding.DecodeString(strings.TrimPrefix(mediaIDStr, br.Config.DirectMedia.MediaIDPrefix)) if err != nil || !bytes.HasPrefix(mediaID, []byte(MediaIDPrefix)) || len(mediaID) < len(MediaIDPrefix)+MediaIDTruncatedHashLength+1 { return nil, mediaproxy.ErrInvalidMediaIDSyntax @@ -82,5 +82,5 @@ func (br *Connector) getDirectMedia(ctx context.Context, mediaIDStr string) (res return nil, mautrix.MNotFound.WithMessage("Invalid checksum in media ID part") } remoteMediaID := networkid.MediaID(mediaID[len(MediaIDPrefix) : len(mediaID)-MediaIDTruncatedHashLength]) - return br.Bridge.Network.(bridgev2.DirectMediableNetwork).Download(ctx, remoteMediaID) + return br.Bridge.Network.(bridgev2.DirectMediableNetwork).Download(ctx, remoteMediaID, params) } diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index ae7d6520..852d81ef 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -243,7 +243,7 @@ type StoppableNetwork interface { type DirectMediableNetwork interface { NetworkConnector SetUseDirectMedia() - Download(ctx context.Context, mediaID networkid.MediaID) (mediaproxy.GetMediaResponse, error) + Download(ctx context.Context, mediaID networkid.MediaID, params map[string]string) (mediaproxy.GetMediaResponse, error) } type IdentifierValidatingNetwork interface { diff --git a/go.mod b/go.mod index beb2badd..9626b21c 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.1 + go.mau.fi/util v0.8.2-0.20241106111346-576742786fe9 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.28.0 golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c diff --git a/go.sum b/go.sum index 518fe895..88469672 100644 --- a/go.sum +++ b/go.sum @@ -51,8 +51,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.1 h1:Ga43cz6esQBYqcjZ/onRoVnYWoUwjWbsxVeJg2jOTSo= -go.mau.fi/util v0.8.1/go.mod h1:T1u/rD2rzidVrBLyaUdPpZiJdP/rsyi+aTzn0D+Q6wc= +go.mau.fi/util v0.8.2-0.20241106111346-576742786fe9 h1:zYcb/lTZudowXAjKi6Yc2/2y5xxglPFfy9ZT2pNGsuM= +go.mau.fi/util v0.8.2-0.20241106111346-576742786fe9/go.mod h1:T1u/rD2rzidVrBLyaUdPpZiJdP/rsyi+aTzn0D+Q6wc= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index f2cde105..ff8b2157 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -16,6 +16,7 @@ import ( "mime/multipart" "net/http" "net/textproto" + "net/url" "os" "strconv" "strings" @@ -95,7 +96,7 @@ type GetMediaResponseFile struct { ContentType string } -type GetMediaFunc = func(ctx context.Context, mediaID string) (response GetMediaResponse, err error) +type GetMediaFunc = func(ctx context.Context, mediaID string, params map[string]string) (response GetMediaResponse, err error) type MediaProxy struct { KeyServer *federation.KeyServer @@ -218,9 +219,17 @@ func (err *ResponseError) Error() string { var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax") +func queryToMap(vals url.Values) map[string]string { + m := make(map[string]string, len(vals)) + for k, v := range vals { + m[k] = v[0] + } + return m +} + func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse { mediaID := mux.Vars(r)["mediaID"] - resp, err := mp.GetMedia(r.Context(), mediaID) + resp, err := mp.GetMedia(r.Context(), mediaID, queryToMap(r.URL.Query())) if err != nil { //lint:ignore SA1019 deprecated types need to be supported until they're removed var respError *ResponseError @@ -384,7 +393,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusTemporaryRedirect) } else if fileResp, ok := resp.(*GetMediaResponseFile); ok { responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error { - mp.addHeaders(w, mimeType, r.PathValue("fileName")) + mp.addHeaders(w, mimeType, vars["fileName"]) w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) w.WriteHeader(http.StatusOK) _, err := wt.WriteTo(w) @@ -402,7 +411,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { } } } else if dataResp, ok := resp.(GetMediaResponseWriter); ok { - mp.addHeaders(w, dataResp.GetContentType(), r.PathValue("fileName")) + mp.addHeaders(w, dataResp.GetContentType(), vars["fileName"]) if dataResp.GetContentLength() != 0 { w.Header().Set("Content-Length", strconv.FormatInt(dataResp.GetContentLength(), 10)) } From 5967fe7b0f82be44ae3caef35931207b3e2fc9aa Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 6 Nov 2024 14:50:41 +0100 Subject: [PATCH 0905/1647] bridgev2/simplevent: add pre/post handle support to EventMeta --- bridgev2/simplevent/meta.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/bridgev2/simplevent/meta.go b/bridgev2/simplevent/meta.go index 9d7d697a..f9a1ea6a 100644 --- a/bridgev2/simplevent/meta.go +++ b/bridgev2/simplevent/meta.go @@ -7,6 +7,7 @@ package simplevent import ( + "context" "time" "github.com/rs/zerolog" @@ -25,6 +26,9 @@ type EventMeta struct { CreatePortal bool Timestamp time.Time StreamOrder int64 + + PreHandleFunc func(context.Context, *bridgev2.Portal) + PostHandleFunc func(context.Context, *bridgev2.Portal) } var ( @@ -33,6 +37,8 @@ var ( _ bridgev2.RemoteEventThatMayCreatePortal = (*EventMeta)(nil) _ bridgev2.RemoteEventWithTimestamp = (*EventMeta)(nil) _ bridgev2.RemoteEventWithStreamOrder = (*EventMeta)(nil) + _ bridgev2.RemotePreHandler = (*EventMeta)(nil) + _ bridgev2.RemotePostHandler = (*EventMeta)(nil) ) func (evt *EventMeta) AddLogContext(c zerolog.Context) zerolog.Context { @@ -73,6 +79,14 @@ func (evt *EventMeta) ShouldCreatePortal() bool { return evt.CreatePortal } +func (evt *EventMeta) PreHandle(ctx context.Context, portal *bridgev2.Portal) { + evt.PreHandleFunc(ctx, portal) +} + +func (evt *EventMeta) PostHandle(ctx context.Context, portal *bridgev2.Portal) { + evt.PostHandleFunc(ctx, portal) +} + func (evt EventMeta) WithType(t bridgev2.RemoteEventType) EventMeta { evt.Type = t return evt From 449de115ffad2d547109fd369d2ccd0d0cdd251a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 6 Nov 2024 15:56:47 +0100 Subject: [PATCH 0906/1647] bridgev2/portal: run post handle even if chat resync is short-circuited --- bridgev2/portal.go | 4 ++++ bridgev2/simplevent/meta.go | 8 ++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 6ada8918..f3adddd9 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1628,6 +1628,10 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, } if evtType == RemoteEventChatResync { log.Debug().Msg("Not handling chat resync event further as portal was created by it") + postHandler, ok := evt.(RemotePostHandler) + if ok { + postHandler.PostHandle(ctx, portal) + } return } } diff --git a/bridgev2/simplevent/meta.go b/bridgev2/simplevent/meta.go index f9a1ea6a..8aa91866 100644 --- a/bridgev2/simplevent/meta.go +++ b/bridgev2/simplevent/meta.go @@ -80,11 +80,15 @@ func (evt *EventMeta) ShouldCreatePortal() bool { } func (evt *EventMeta) PreHandle(ctx context.Context, portal *bridgev2.Portal) { - evt.PreHandleFunc(ctx, portal) + if evt.PreHandleFunc != nil { + evt.PreHandleFunc(ctx, portal) + } } func (evt *EventMeta) PostHandle(ctx context.Context, portal *bridgev2.Portal) { - evt.PostHandleFunc(ctx, portal) + if evt.PostHandleFunc != nil { + evt.PostHandleFunc(ctx, portal) + } } func (evt EventMeta) WithType(t bridgev2.RemoteEventType) EventMeta { From 22a4c50e0d36c029a03d202c3635ab098fbd5a80 Mon Sep 17 00:00:00 2001 From: Scott Weber Date: Wed, 6 Nov 2024 11:23:37 -0500 Subject: [PATCH 0907/1647] event: add GetRaw() to initialize Raw if necessary (#313) --- event/content.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/event/content.go b/event/content.go index 882d3368..ab57c658 100644 --- a/event/content.go +++ b/event/content.go @@ -188,6 +188,13 @@ func IsUnsupportedContentType(err error) bool { var ErrContentAlreadyParsed = errors.New("content is already parsed") var ErrUnsupportedContentType = errors.New("unsupported event type") +func (content *Content) GetRaw() map[string]interface{} { + if content.Raw == nil { + content.Raw = make(map[string]interface{}) + } + return content.Raw +} + func (content *Content) ParseRaw(evtType Type) error { if content.Parsed != nil { return ErrContentAlreadyParsed From 702a0e047c2d5dd2db006a529388a49eeb2808d0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 8 Nov 2024 10:24:25 +0100 Subject: [PATCH 0908/1647] bridgev2/config: add option to disable tag bridging --- bridgev2/bridgeconfig/config.go | 1 + bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/matrix/mxmain/example-config.yaml | 2 ++ bridgev2/portal.go | 2 +- 4 files changed, 5 insertions(+), 1 deletion(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 4c0fa6b4..01cb9478 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -65,6 +65,7 @@ type BridgeConfig struct { ResendBridgeInfo bool `yaml:"resend_bridge_info"` BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` TagOnlyOnCreate bool `yaml:"tag_only_on_create"` + EnableTagBridging bool `yaml:"enable_tag_bridging"` MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 4122f4d6..a4402e58 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -30,6 +30,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "resend_bridge_info") helper.Copy(up.Bool, "bridge", "bridge_matrix_leave") helper.Copy(up.Bool, "bridge", "tag_only_on_create") + helper.Copy(up.Bool, "bridge", "enable_tag_bridging") helper.Copy(up.Bool, "bridge", "mute_only_on_create") helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private") diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index d31396ff..f705c4c0 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -24,6 +24,8 @@ bridge: # Should room tags only be synced when creating the portal? Tags mean things like favorite/pin and archive/low priority. # Tags currently can't be synced back to the remote network, so a continuous sync means tagging from Matrix will be undone. tag_only_on_create: true + # Should room tags be synced at all? + enable_tag_bridging: true # Should room mute status only be synced when creating the portal? # Like tags, mutes can't currently be synced back to the remote network. mute_only_on_create: true diff --git a/bridgev2/portal.go b/bridgev2/portal.go index f3adddd9..337921f9 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3269,7 +3269,7 @@ func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPo zerolog.Ctx(ctx).Err(err).Msg("Failed to mute room") } } - if info.Tag != nil && (didJustCreate || !portal.Bridge.Config.TagOnlyOnCreate) && (!didJustCreate || *info.Tag != "") { + if info.Tag != nil && portal.Bridge.Config.EnableTagBridging && (didJustCreate || !portal.Bridge.Config.TagOnlyOnCreate) && (!didJustCreate || *info.Tag != "") { err := dp.TagRoom(ctx, portal.MXID, *info.Tag, *info.Tag != "") if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to tag room") From 8fbf245e97b5b102c9ca35104308bd0947383a67 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 13 Nov 2024 15:26:01 +0200 Subject: [PATCH 0909/1647] bridgev2/commands: include state event in list-logins --- bridgev2/user.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/user.go b/bridgev2/user.go index 993eda92..e6a5dd99 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -180,7 +180,7 @@ func (user *User) GetFormattedUserLogins() string { user.Bridge.cacheLock.Lock() logins := make([]string, len(user.logins)) for key, val := range user.logins { - logins = append(logins, fmt.Sprintf("* `%s` (%s)", key, val.RemoteName)) + logins = append(logins, fmt.Sprintf("* `%s` (%s) - `%s`", key, val.RemoteName, val.BridgeState.GetPrev().StateEvent)) } user.Bridge.cacheLock.Unlock() return strings.Join(logins, "\n") From b22764aa170af7b361a71b4be3c75fcde024ad5b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 13 Nov 2024 15:56:38 +0200 Subject: [PATCH 0910/1647] dependencies: update --- go.mod | 18 +++++++++--------- go.sum | 32 ++++++++++++++++---------------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/go.mod b/go.mod index 9626b21c..373bbf06 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module maunium.net/go/mautrix go 1.22.0 -toolchain go1.23.2 +toolchain go1.23.3 require ( filippo.io/edwards25519 v1.1.0 @@ -18,12 +18,12 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.2-0.20241106111346-576742786fe9 + go.mau.fi/util v0.8.2-0.20241113135441-636f8643f367 go.mau.fi/zeroconfig v0.1.3 - golang.org/x/crypto v0.28.0 - golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c - golang.org/x/net v0.30.0 - golang.org/x/sync v0.8.0 + golang.org/x/crypto v0.29.0 + golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f + golang.org/x/net v0.31.0 + golang.org/x/sync v0.9.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) @@ -33,11 +33,11 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect - github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 // indirect + github.com/petermattis/goid v0.0.0-20241025130422-66cb2e6d7274 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.26.0 // indirect - golang.org/x/text v0.19.0 // indirect + golang.org/x/sys v0.27.0 // indirect + golang.org/x/text v0.20.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 88469672..0a95fccd 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,8 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7 h1:Dx7Ovyv/SFnMFw3fD4oEoeorXc6saIiQ23LrGLth0Gw= -github.com/petermattis/goid v0.0.0-20240813172612-4fcff4a6cae7/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/petermattis/goid v0.0.0-20241025130422-66cb2e6d7274 h1:qli3BGQK0tYDkSEvZ/FzZTi9ZrOX86Q6CIhKLGc489A= +github.com/petermattis/goid v0.0.0-20241025130422-66cb2e6d7274/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -51,26 +51,26 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.2-0.20241106111346-576742786fe9 h1:zYcb/lTZudowXAjKi6Yc2/2y5xxglPFfy9ZT2pNGsuM= -go.mau.fi/util v0.8.2-0.20241106111346-576742786fe9/go.mod h1:T1u/rD2rzidVrBLyaUdPpZiJdP/rsyi+aTzn0D+Q6wc= +go.mau.fi/util v0.8.2-0.20241113135441-636f8643f367 h1:GU0TYiAOU79p6r3jf3e4k5cdgnPxOcJWkWeWampdAjw= +go.mau.fi/util v0.8.2-0.20241113135441-636f8643f367/go.mod h1:SVzC++wSl8Yq4YVQRClLPa1frNpSVDVP6mfkw/OvDbc= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= -golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= -golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c h1:7dEasQXItcW1xKJ2+gg5VOiBnqWrJc+rq0DPKyvvdbY= -golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c/go.mod h1:NQtJDoLvd6faHhE7m4T/1IY708gDefGGjR/iUW8yQQ8= -golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= -golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= +golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= +golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo= +golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f/go.mod h1:D5SMRVC3C2/4+F/DB1wZsLRnSNimn2Sp/NPsCrsv8ak= +golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= +golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= +golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= +golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= -golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= +golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= +golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From 3f9a63784ec5b52139dd908d08b8b76c94c311f7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 13 Nov 2024 18:31:41 +0200 Subject: [PATCH 0911/1647] bridgev2/config: add more granular control over room tags --- bridgev2/bridgeconfig/config.go | 3 ++- bridgev2/bridgeconfig/upgrade.go | 2 +- bridgev2/matrix/mxmain/example-config.yaml | 4 ++-- bridgev2/portal.go | 6 +++++- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 01cb9478..cf87864f 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -11,6 +11,7 @@ import ( "go.mau.fi/zeroconfig" "gopkg.in/yaml.v3" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/mediaproxy" ) @@ -65,7 +66,7 @@ type BridgeConfig struct { ResendBridgeInfo bool `yaml:"resend_bridge_info"` BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` TagOnlyOnCreate bool `yaml:"tag_only_on_create"` - EnableTagBridging bool `yaml:"enable_tag_bridging"` + OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"` MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index a4402e58..3948cc11 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -30,7 +30,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "resend_bridge_info") helper.Copy(up.Bool, "bridge", "bridge_matrix_leave") helper.Copy(up.Bool, "bridge", "tag_only_on_create") - helper.Copy(up.Bool, "bridge", "enable_tag_bridging") + helper.Copy(up.List, "bridge", "only_bridge_tags") helper.Copy(up.Bool, "bridge", "mute_only_on_create") helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private") diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index f705c4c0..b8a00637 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -24,8 +24,8 @@ bridge: # Should room tags only be synced when creating the portal? Tags mean things like favorite/pin and archive/low priority. # Tags currently can't be synced back to the remote network, so a continuous sync means tagging from Matrix will be undone. tag_only_on_create: true - # Should room tags be synced at all? - enable_tag_bridging: true + # List of tags to allow bridging. If empty, no tags will be bridged. + only_bridge_tags: [m.favourite, m.lowpriority] # Should room mute status only be synced when creating the portal? # Like tags, mutes can't currently be synced back to the remote network. mute_only_on_create: true diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 337921f9..8e1afdac 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3269,7 +3269,11 @@ func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPo zerolog.Ctx(ctx).Err(err).Msg("Failed to mute room") } } - if info.Tag != nil && portal.Bridge.Config.EnableTagBridging && (didJustCreate || !portal.Bridge.Config.TagOnlyOnCreate) && (!didJustCreate || *info.Tag != "") { + if info.Tag != nil && + len(portal.Bridge.Config.OnlyBridgeTags) > 0 && + (*info.Tag == "" || slices.Contains(portal.Bridge.Config.OnlyBridgeTags, *info.Tag)) && + (didJustCreate || !portal.Bridge.Config.TagOnlyOnCreate) && + (!didJustCreate || *info.Tag != "") { err := dp.TagRoom(ctx, portal.MXID, *info.Tag, *info.Tag != "") if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to tag room") From 21aa3291f31ff16a1399f38e59f899502b0e5344 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 14 Nov 2024 14:58:08 +0200 Subject: [PATCH 0912/1647] bridgev2/database: include portal receiver in reaction queries --- bridgev2/database/reaction.go | 34 +++++++++++++++++----------------- bridgev2/portal.go | 14 +++++++------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/bridgev2/database/reaction.go b/bridgev2/database/reaction.go index 08ab2c8e..b65a5c38 100644 --- a/bridgev2/database/reaction.go +++ b/bridgev2/database/reaction.go @@ -41,11 +41,11 @@ const ( getReactionBaseQuery = ` SELECT bridge_id, message_id, message_part_id, sender_id, sender_mxid, emoji_id, emoji, room_id, room_receiver, mxid, timestamp, metadata FROM reaction ` - getReactionByIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3 AND sender_id=$4 AND emoji_id=$5` - getReactionByIDWithoutMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND sender_id=$3 AND emoji_id=$4 ORDER BY message_part_id ASC LIMIT 1` - getAllReactionsToMessageBySenderQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND sender_id=$3 ORDER BY timestamp DESC` - getAllReactionsToMessageQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2` - getAllReactionsToMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3` + getReactionByIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND message_part_id=$4 AND sender_id=$5 AND emoji_id=$6` + getReactionByIDWithoutMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND sender_id=$4 AND emoji_id=$5 ORDER BY message_part_id ASC LIMIT 1` + getAllReactionsToMessageBySenderQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND sender_id=$4 ORDER BY timestamp DESC` + getAllReactionsToMessageQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3` + getAllReactionsToMessagePartQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND message_part_id=$4` getReactionByMXIDQuery = getReactionBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` upsertReactionQuery = ` INSERT INTO reaction (bridge_id, message_id, message_part_id, sender_id, sender_mxid, emoji_id, emoji, room_id, room_receiver, mxid, timestamp, metadata) @@ -54,28 +54,28 @@ const ( DO UPDATE SET sender_mxid=excluded.sender_mxid, mxid=excluded.mxid, timestamp=excluded.timestamp, emoji=excluded.emoji, metadata=excluded.metadata ` deleteReactionQuery = ` - DELETE FROM reaction WHERE bridge_id=$1 AND message_id=$2 AND message_part_id=$3 AND sender_id=$4 AND emoji_id=$5 + DELETE FROM reaction WHERE bridge_id=$1 AND room_receiver=$2 AND message_id=$3 AND message_part_id=$4 AND sender_id=$5 AND emoji_id=$6 ` ) -func (rq *ReactionQuery) GetByID(ctx context.Context, messageID networkid.MessageID, messagePartID networkid.PartID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) { - return rq.QueryOne(ctx, getReactionByIDQuery, rq.BridgeID, messageID, messagePartID, senderID, emojiID) +func (rq *ReactionQuery) GetByID(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, messagePartID networkid.PartID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) { + return rq.QueryOne(ctx, getReactionByIDQuery, rq.BridgeID, receiver, messageID, messagePartID, senderID, emojiID) } -func (rq *ReactionQuery) GetByIDWithoutMessagePart(ctx context.Context, messageID networkid.MessageID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) { - return rq.QueryOne(ctx, getReactionByIDWithoutMessagePartQuery, rq.BridgeID, messageID, senderID, emojiID) +func (rq *ReactionQuery) GetByIDWithoutMessagePart(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, senderID networkid.UserID, emojiID networkid.EmojiID) (*Reaction, error) { + return rq.QueryOne(ctx, getReactionByIDWithoutMessagePartQuery, rq.BridgeID, receiver, messageID, senderID, emojiID) } -func (rq *ReactionQuery) GetAllToMessageBySender(ctx context.Context, messageID networkid.MessageID, senderID networkid.UserID) ([]*Reaction, error) { - return rq.QueryMany(ctx, getAllReactionsToMessageBySenderQuery, rq.BridgeID, messageID, senderID) +func (rq *ReactionQuery) GetAllToMessageBySender(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, senderID networkid.UserID) ([]*Reaction, error) { + return rq.QueryMany(ctx, getAllReactionsToMessageBySenderQuery, rq.BridgeID, receiver, messageID, senderID) } -func (rq *ReactionQuery) GetAllToMessage(ctx context.Context, messageID networkid.MessageID) ([]*Reaction, error) { - return rq.QueryMany(ctx, getAllReactionsToMessageQuery, rq.BridgeID, messageID) +func (rq *ReactionQuery) GetAllToMessage(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID) ([]*Reaction, error) { + return rq.QueryMany(ctx, getAllReactionsToMessageQuery, rq.BridgeID, receiver, messageID) } -func (rq *ReactionQuery) GetAllToMessagePart(ctx context.Context, messageID networkid.MessageID, partID networkid.PartID) ([]*Reaction, error) { - return rq.QueryMany(ctx, getAllReactionsToMessagePartQuery, rq.BridgeID, messageID, partID) +func (rq *ReactionQuery) GetAllToMessagePart(ctx context.Context, receiver networkid.UserLoginID, messageID networkid.MessageID, partID networkid.PartID) ([]*Reaction, error) { + return rq.QueryMany(ctx, getAllReactionsToMessagePartQuery, rq.BridgeID, receiver, messageID, partID) } func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) { @@ -89,7 +89,7 @@ func (rq *ReactionQuery) Upsert(ctx context.Context, reaction *Reaction) error { func (rq *ReactionQuery) Delete(ctx context.Context, reaction *Reaction) error { ensureBridgeIDMatches(&reaction.BridgeID, rq.BridgeID) - return rq.Exec(ctx, deleteReactionQuery, reaction.BridgeID, reaction.MessageID, reaction.MessagePartID, reaction.SenderID, reaction.EmojiID) + return rq.Exec(ctx, deleteReactionQuery, reaction.BridgeID, reaction.Room.Receiver, reaction.MessageID, reaction.MessagePartID, reaction.SenderID, reaction.EmojiID) } func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 8e1afdac..0b5c2adc 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1147,7 +1147,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if portal.Bridge.Config.OutgoingMessageReID { deterministicID = portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, reactionTarget, preResp.SenderID, preResp.EmojiID) } - existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID) + existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID) if err != nil { log.Err(err).Msg("Failed to check if reaction is a duplicate") return @@ -1169,7 +1169,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi } react.PreHandleResp = &preResp if preResp.MaxReactions > 0 { - allReactions, err := portal.Bridge.DB.Reaction.GetAllToMessageBySender(ctx, reactionTarget.ID, preResp.SenderID) + allReactions, err := portal.Bridge.DB.Reaction.GetAllToMessageBySender(ctx, portal.Receiver, reactionTarget.ID, preResp.SenderID) if err != nil { log.Err(err).Msg("Failed to get all reactions to message by sender") portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get previous reactions: %w", ErrDatabaseError, err)) @@ -2162,9 +2162,9 @@ func (portal *Portal) getTargetMessagePart(ctx context.Context, evt RemoteEventW func (portal *Portal) getTargetReaction(ctx context.Context, evt RemoteReactionRemove) (*database.Reaction, error) { if partTargeter, ok := evt.(RemoteEventWithTargetPart); ok { - return portal.Bridge.DB.Reaction.GetByID(ctx, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart(), evt.GetSender().Sender, evt.GetRemovedEmojiID()) + return portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart(), evt.GetSender().Sender, evt.GetRemovedEmojiID()) } else { - return portal.Bridge.DB.Reaction.GetByIDWithoutMessagePart(ctx, evt.GetTargetMessage(), evt.GetSender().Sender, evt.GetRemovedEmojiID()) + return portal.Bridge.DB.Reaction.GetByIDWithoutMessagePart(ctx, portal.Receiver, evt.GetTargetMessage(), evt.GetSender().Sender, evt.GetRemovedEmojiID()) } } @@ -2196,9 +2196,9 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User } var existingReactions []*database.Reaction if partTargeter, ok := evt.(RemoteEventWithTargetPart); ok { - existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessagePart(ctx, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart()) + existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessagePart(ctx, portal.Receiver, evt.GetTargetMessage(), partTargeter.GetTargetMessagePart()) } else { - existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessage(ctx, evt.GetTargetMessage()) + existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessage(ctx, portal.Receiver, evt.GetTargetMessage()) } existing := make(map[networkid.UserID]map[networkid.EmojiID]*database.Reaction) for _, existingReaction := range existingReactions { @@ -2317,7 +2317,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi return } emoji, emojiID := evt.GetReactionEmoji() - existingReaction, err := portal.Bridge.DB.Reaction.GetByID(ctx, targetMessage.ID, targetMessage.PartID, evt.GetSender().Sender, emojiID) + existingReaction, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, targetMessage.ID, targetMessage.PartID, evt.GetSender().Sender, emojiID) if err != nil { log.Err(err).Msg("Failed to check if reaction is a duplicate") return From 5a3dd8d45c105e29075e84911f346c319bf18cb7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 15 Nov 2024 15:34:31 +0200 Subject: [PATCH 0913/1647] imports: use html instead of x/net/html for escaping --- bridgev2/commands/login.go | 2 +- bridgev2/commands/startchat.go | 3 +-- event/message.go | 3 +-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index 8896eb60..660c90d7 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -10,6 +10,7 @@ import ( "context" "encoding/json" "fmt" + "html" "net/url" "regexp" "slices" @@ -17,7 +18,6 @@ import ( "github.com/skip2/go-qrcode" "go.mau.fi/util/curl" - "golang.org/x/net/html" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index 53c07530..42f528b0 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -9,11 +9,10 @@ package commands import ( "context" "fmt" + "html" "strings" "time" - "golang.org/x/net/html" - "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/id" diff --git a/event/message.go b/event/message.go index 9badd9a2..92bdcf07 100644 --- a/event/message.go +++ b/event/message.go @@ -8,12 +8,11 @@ package event import ( "encoding/json" + "html" "slices" "strconv" "strings" - "golang.org/x/net/html" - "maunium.net/go/mautrix/crypto/attachment" "maunium.net/go/mautrix/id" ) From c13ec82f6d1919881f6f929086c3b4824e71d956 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 16 Nov 2024 16:03:52 +0200 Subject: [PATCH 0914/1647] Bump version to v0.22.0 --- CHANGELOG.md | 4 +++- go.mod | 2 +- go.sum | 4 ++-- version.go | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f07e1603..18d96b35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,11 @@ -## unreleased +## v0.22.0 (2024-11-16) * *(hicli)* Moved package into gomuks repo. * *(bridgev2/commands)* Fixed cookie unescaping in login commands. * *(bridgev2/portal)* Added special `DefaultChatName` constant to explicitly reset portal names to the default (based on members). +* *(bridgev2/config)* Added options to disable room tag bridging. +* *(bridgev2/database)* Fixed reaction queries not including portal receiver. * *(appservice)* Updated [MSC2409] stable registration field name from `push_ephemeral` to `receive_ephemeral`. Homeserver admins must update existing registrations manually. diff --git a/go.mod b/go.mod index 373bbf06..8bf9baac 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.2-0.20241113135441-636f8643f367 + go.mau.fi/util v0.8.2 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.29.0 golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f diff --git a/go.sum b/go.sum index 0a95fccd..205cbfaf 100644 --- a/go.sum +++ b/go.sum @@ -51,8 +51,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.2-0.20241113135441-636f8643f367 h1:GU0TYiAOU79p6r3jf3e4k5cdgnPxOcJWkWeWampdAjw= -go.mau.fi/util v0.8.2-0.20241113135441-636f8643f367/go.mod h1:SVzC++wSl8Yq4YVQRClLPa1frNpSVDVP6mfkw/OvDbc= +go.mau.fi/util v0.8.2 h1:zWbVHwdRKwI6U9AusmZ8bwgcLosikwbb4GGqLrNr1YE= +go.mau.fi/util v0.8.2/go.mod h1:BHHC9R2WLMJd1bwTZfTcFxUgRFmUgUmiWcT4RbzUgiA= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= diff --git a/version.go b/version.go index 29368573..dd70d55b 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.21.1" +const Version = "v0.22.0" var GoModVersion = "" var Commit = "" From 88dd813d6722b00d70d07d6fe382d156cd2108e9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 18 Nov 2024 19:52:36 +0200 Subject: [PATCH 0915/1647] events/beeper: make order_string a pointer --- event/beeper.go | 29 +++++++++++++++++++---------- event/events.go | 8 ++++---- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/event/beeper.go b/event/beeper.go index 0af466f8..8af9e0d0 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -114,35 +114,44 @@ type BeeperEncodedOrder struct { suborder int64 } -func NewBeeperEncodedOrder(order int64, suborder int64) BeeperEncodedOrder { - return BeeperEncodedOrder{order: order, suborder: suborder} +func NewBeeperEncodedOrder(order int64, suborder int64) *BeeperEncodedOrder { + return &BeeperEncodedOrder{order: order, suborder: suborder} } -func BeeperEncodedOrderFromString(str string) (BeeperEncodedOrder, error) { +func BeeperEncodedOrderFromString(str string) (*BeeperEncodedOrder, error) { order, suborder, err := decodeIntPair(str) if err != nil { - return BeeperEncodedOrder{}, err + return nil, err } - return BeeperEncodedOrder{order: order, suborder: suborder}, nil + return &BeeperEncodedOrder{order: order, suborder: suborder}, nil } -func (b BeeperEncodedOrder) String() string { +func (b *BeeperEncodedOrder) String() string { + if b == nil { + return "" + } return encodeIntPair(b.order, b.suborder) } -func (b BeeperEncodedOrder) OrderPair() (int64, int64) { +func (b *BeeperEncodedOrder) OrderPair() (int64, int64) { + if b == nil { + return 0, 0 + } return b.order, b.suborder } -func (b BeeperEncodedOrder) IsZero() bool { - return b.order == 0 && b.suborder == 0 +func (b *BeeperEncodedOrder) IsZero() bool { + return b == nil || (b.order == 0 && b.suborder == 0) } -func (b BeeperEncodedOrder) MarshalJSON() ([]byte, error) { +func (b *BeeperEncodedOrder) MarshalJSON() ([]byte, error) { return []byte(`"` + b.String() + `"`), nil } func (b *BeeperEncodedOrder) UnmarshalJSON(data []byte) error { + if b == nil { + return fmt.Errorf("BeeperEncodedOrder: receiver is nil") + } str := string(data) if len(str) < 2 { return fmt.Errorf("invalid encoded order string: %s", str) diff --git a/event/events.go b/event/events.go index 38f0d848..56104123 100644 --- a/event/events.go +++ b/event/events.go @@ -144,10 +144,10 @@ type Unsigned struct { RedactedBecause *Event `json:"redacted_because,omitempty"` InviteRoomState []StrippedState `json:"invite_room_state,omitempty"` - BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"` - BeeperHSSuborder int64 `json:"com.beeper.hs.suborder,omitempty"` - BeeperHSOrderString BeeperEncodedOrder `json:"com.beeper.hs.order_string,omitempty"` - BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"` + BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"` + BeeperHSSuborder int64 `json:"com.beeper.hs.suborder,omitempty"` + BeeperHSOrderString *BeeperEncodedOrder `json:"com.beeper.hs.order_string,omitempty"` + BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"` } func (us *Unsigned) IsEmpty() bool { From 363fdfa3b2274cb9f4a8b5a5894bdc2e2de410f9 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 17 Jul 2024 10:02:12 -0600 Subject: [PATCH 0916/1647] verificationhelper: add callback for when other user reports done Signed-off-by: Sumner Evans --- crypto/verificationhelper/callbacks_test.go | 11 +++++++++++ crypto/verificationhelper/verificationhelper.go | 7 +++++++ .../verificationhelper_qr_crosssign_test.go | 2 ++ .../verificationhelper_qr_self_test.go | 2 ++ .../verificationhelper/verificationhelper_sas_test.go | 1 + 5 files changed, 23 insertions(+) diff --git a/crypto/verificationhelper/callbacks_test.go b/crypto/verificationhelper/callbacks_test.go index 7b1055d1..cc473f9a 100644 --- a/crypto/verificationhelper/callbacks_test.go +++ b/crypto/verificationhelper/callbacks_test.go @@ -25,6 +25,7 @@ type baseVerificationCallbacks struct { verificationsRequested map[id.UserID][]id.VerificationTransactionID qrCodesShown map[id.VerificationTransactionID]*verificationhelper.QRCode qrCodesScanned map[id.VerificationTransactionID]struct{} + otherDoneTransactions map[id.VerificationTransactionID]struct{} doneTransactions map[id.VerificationTransactionID]struct{} verificationCancellation map[id.VerificationTransactionID]*event.VerificationCancelEventContent emojisShown map[id.VerificationTransactionID][]rune @@ -36,6 +37,7 @@ func newBaseVerificationCallbacks() *baseVerificationCallbacks { verificationsRequested: map[id.UserID][]id.VerificationTransactionID{}, qrCodesShown: map[id.VerificationTransactionID]*verificationhelper.QRCode{}, qrCodesScanned: map[id.VerificationTransactionID]struct{}{}, + otherDoneTransactions: map[id.VerificationTransactionID]struct{}{}, doneTransactions: map[id.VerificationTransactionID]struct{}{}, verificationCancellation: map[id.VerificationTransactionID]*event.VerificationCancelEventContent{}, emojisShown: map[id.VerificationTransactionID][]rune{}, @@ -60,6 +62,11 @@ func (c *baseVerificationCallbacks) WasOurQRCodeScanned(txnID id.VerificationTra return ok } +func (c *baseVerificationCallbacks) OtherReportedDone(txnID id.VerificationTransactionID) bool { + _, ok := c.otherDoneTransactions[txnID] + return ok +} + func (c *baseVerificationCallbacks) IsVerificationDone(txnID id.VerificationTransactionID) bool { _, ok := c.doneTransactions[txnID] return ok @@ -88,6 +95,10 @@ func (c *baseVerificationCallbacks) VerificationCancelled(ctx context.Context, t } } +func (c *baseVerificationCallbacks) OtherReportsDone(ctx context.Context, txnID id.VerificationTransactionID) { + c.otherDoneTransactions[txnID] = struct{}{} +} + func (c *baseVerificationCallbacks) VerificationDone(ctx context.Context, txnID id.VerificationTransactionID) { c.doneTransactions[txnID] = struct{}{} } diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index d3b7d4f5..bad0b006 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -115,6 +115,10 @@ type RequiredCallbacks interface { // VerificationCancelled is called when the verification is cancelled. VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) + // OtherReportsDone is called when the other user has reported that the + // verification is done. + OtherReportsDone(ctx context.Context, txnID id.VerificationTransactionID) + // VerificationDone is called when the verification is done. VerificationDone(ctx context.Context, txnID id.VerificationTransactionID) } @@ -152,6 +156,7 @@ type VerificationHelper struct { supportedMethods []event.VerificationMethod verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) verificationCancelledCallback func(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) + otherReportsDone func(ctx context.Context, txnID id.VerificationTransactionID) verificationDone func(ctx context.Context, txnID id.VerificationTransactionID) showSAS func(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) @@ -179,6 +184,7 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, call } else { helper.verificationRequested = c.VerificationRequested helper.verificationCancelledCallback = c.VerificationCancelled + helper.otherReportsDone = c.OtherReportsDone helper.verificationDone = c.VerificationDone } @@ -880,6 +886,7 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verif } txn.ReceivedTheirDone = true + vh.otherReportsDone(ctx, txn.TransactionID) if txn.SentOurDone { delete(vh.activeTransactions, txn.TransactionID) vh.verificationDone(ctx, txn.TransactionID) diff --git a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go index aace2230..0003410b 100644 --- a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go @@ -99,6 +99,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { assert.Equal(t, txnID, doneEvt.TransactionID) ts.dispatchToDevice(t, ctx, sendingClient) + assert.True(t, sendingCallbacks.OtherReportedDone(txnID)) } else { // receiving scans QR // Emulate scanning the QR code shown by the sending device on // the receiving device. @@ -137,6 +138,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { assert.Equal(t, txnID, doneEvt.TransactionID) ts.dispatchToDevice(t, ctx, receivingClient) + assert.True(t, receivingCallbacks.OtherReportedDone(txnID)) } // Ensure that both devices have marked the verification as done. diff --git a/crypto/verificationhelper/verificationhelper_qr_self_test.go b/crypto/verificationhelper/verificationhelper_qr_self_test.go index 11358b88..9942ae30 100644 --- a/crypto/verificationhelper/verificationhelper_qr_self_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_self_test.go @@ -200,6 +200,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { assert.Equal(t, txnID, doneEvt.TransactionID) ts.dispatchToDevice(t, ctx, sendingClient) + assert.True(t, sendingCallbacks.OtherReportedDone(txnID)) } else { // receiving scans QR // Emulate scanning the QR code shown by the sending device on // the receiving device. @@ -238,6 +239,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { assert.Equal(t, txnID, doneEvt.TransactionID) ts.dispatchToDevice(t, ctx, receivingClient) + assert.True(t, receivingCallbacks.OtherReportedDone(txnID)) } // Ensure that both devices have marked the verification as done. diff --git a/crypto/verificationhelper/verificationhelper_sas_test.go b/crypto/verificationhelper/verificationhelper_sas_test.go index 20e52e0f..4f036f18 100644 --- a/crypto/verificationhelper/verificationhelper_sas_test.go +++ b/crypto/verificationhelper/verificationhelper_sas_test.go @@ -269,6 +269,7 @@ func TestVerification_SAS(t *testing.T) { // twice to process and drain all of the events. ts.dispatchToDevice(t, ctx, sendingClient) ts.dispatchToDevice(t, ctx, receivingClient) + assert.True(t, receivingCallbacks.OtherReportedDone(txnID)) ts.dispatchToDevice(t, ctx, sendingClient) ts.dispatchToDevice(t, ctx, receivingClient) assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) From 039f7335e4d0d6af31c6185272ead8b823b6b7fc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 19 Nov 2024 01:34:35 +0200 Subject: [PATCH 0917/1647] client: switch to via in /join calls as per MSC4156 --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index dc59ca10..e931fa66 100644 --- a/client.go +++ b/client.go @@ -942,7 +942,7 @@ func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName strin var urlPath string if serverName != "" { urlPath = cli.BuildURLWithQuery(ClientURLPath{"v3", "join", roomIDorAlias}, map[string]string{ - "server_name": serverName, + "via": serverName, }) } else { urlPath = cli.BuildClientURL("v3", "join", roomIDorAlias) From 4bc4bc00465a224b1262738d84ddd0a5bd375c35 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 19 Nov 2024 16:00:28 +0200 Subject: [PATCH 0918/1647] bridgev2/networkinterface: add native flag for push config --- bridgev2/networkinterface.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 852d81ef..6e31e721 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -808,9 +808,10 @@ type APNsPushConfig struct { } type PushConfig struct { - Web *WebPushConfig `json:"web,omitempty"` - FCM *FCMPushConfig `json:"fcm,omitempty"` - APNs *APNsPushConfig `json:"apns,omitempty"` + Web *WebPushConfig `json:"web,omitempty"` + FCM *FCMPushConfig `json:"fcm,omitempty"` + APNs *APNsPushConfig `json:"apns,omitempty"` + Native bool `json:"native,omitempty"` } type PushableNetworkAPI interface { From d575cc79ef863ef67043fd33cef0117e7cd45c69 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 19 Nov 2024 16:09:18 -0700 Subject: [PATCH 0919/1647] Revert "verificationhelper: add callback for when other user reports done" This reverts commit 363fdfa3b2274cb9f4a8b5a5894bdc2e2de410f9. --- crypto/verificationhelper/callbacks_test.go | 11 ----------- crypto/verificationhelper/verificationhelper.go | 7 ------- .../verificationhelper_qr_crosssign_test.go | 2 -- .../verificationhelper_qr_self_test.go | 2 -- .../verificationhelper/verificationhelper_sas_test.go | 1 - 5 files changed, 23 deletions(-) diff --git a/crypto/verificationhelper/callbacks_test.go b/crypto/verificationhelper/callbacks_test.go index cc473f9a..7b1055d1 100644 --- a/crypto/verificationhelper/callbacks_test.go +++ b/crypto/verificationhelper/callbacks_test.go @@ -25,7 +25,6 @@ type baseVerificationCallbacks struct { verificationsRequested map[id.UserID][]id.VerificationTransactionID qrCodesShown map[id.VerificationTransactionID]*verificationhelper.QRCode qrCodesScanned map[id.VerificationTransactionID]struct{} - otherDoneTransactions map[id.VerificationTransactionID]struct{} doneTransactions map[id.VerificationTransactionID]struct{} verificationCancellation map[id.VerificationTransactionID]*event.VerificationCancelEventContent emojisShown map[id.VerificationTransactionID][]rune @@ -37,7 +36,6 @@ func newBaseVerificationCallbacks() *baseVerificationCallbacks { verificationsRequested: map[id.UserID][]id.VerificationTransactionID{}, qrCodesShown: map[id.VerificationTransactionID]*verificationhelper.QRCode{}, qrCodesScanned: map[id.VerificationTransactionID]struct{}{}, - otherDoneTransactions: map[id.VerificationTransactionID]struct{}{}, doneTransactions: map[id.VerificationTransactionID]struct{}{}, verificationCancellation: map[id.VerificationTransactionID]*event.VerificationCancelEventContent{}, emojisShown: map[id.VerificationTransactionID][]rune{}, @@ -62,11 +60,6 @@ func (c *baseVerificationCallbacks) WasOurQRCodeScanned(txnID id.VerificationTra return ok } -func (c *baseVerificationCallbacks) OtherReportedDone(txnID id.VerificationTransactionID) bool { - _, ok := c.otherDoneTransactions[txnID] - return ok -} - func (c *baseVerificationCallbacks) IsVerificationDone(txnID id.VerificationTransactionID) bool { _, ok := c.doneTransactions[txnID] return ok @@ -95,10 +88,6 @@ func (c *baseVerificationCallbacks) VerificationCancelled(ctx context.Context, t } } -func (c *baseVerificationCallbacks) OtherReportsDone(ctx context.Context, txnID id.VerificationTransactionID) { - c.otherDoneTransactions[txnID] = struct{}{} -} - func (c *baseVerificationCallbacks) VerificationDone(ctx context.Context, txnID id.VerificationTransactionID) { c.doneTransactions[txnID] = struct{}{} } diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index bad0b006..d3b7d4f5 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -115,10 +115,6 @@ type RequiredCallbacks interface { // VerificationCancelled is called when the verification is cancelled. VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) - // OtherReportsDone is called when the other user has reported that the - // verification is done. - OtherReportsDone(ctx context.Context, txnID id.VerificationTransactionID) - // VerificationDone is called when the verification is done. VerificationDone(ctx context.Context, txnID id.VerificationTransactionID) } @@ -156,7 +152,6 @@ type VerificationHelper struct { supportedMethods []event.VerificationMethod verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) verificationCancelledCallback func(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) - otherReportsDone func(ctx context.Context, txnID id.VerificationTransactionID) verificationDone func(ctx context.Context, txnID id.VerificationTransactionID) showSAS func(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) @@ -184,7 +179,6 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, call } else { helper.verificationRequested = c.VerificationRequested helper.verificationCancelledCallback = c.VerificationCancelled - helper.otherReportsDone = c.OtherReportsDone helper.verificationDone = c.VerificationDone } @@ -886,7 +880,6 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verif } txn.ReceivedTheirDone = true - vh.otherReportsDone(ctx, txn.TransactionID) if txn.SentOurDone { delete(vh.activeTransactions, txn.TransactionID) vh.verificationDone(ctx, txn.TransactionID) diff --git a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go index 0003410b..aace2230 100644 --- a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go @@ -99,7 +99,6 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { assert.Equal(t, txnID, doneEvt.TransactionID) ts.dispatchToDevice(t, ctx, sendingClient) - assert.True(t, sendingCallbacks.OtherReportedDone(txnID)) } else { // receiving scans QR // Emulate scanning the QR code shown by the sending device on // the receiving device. @@ -138,7 +137,6 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { assert.Equal(t, txnID, doneEvt.TransactionID) ts.dispatchToDevice(t, ctx, receivingClient) - assert.True(t, receivingCallbacks.OtherReportedDone(txnID)) } // Ensure that both devices have marked the verification as done. diff --git a/crypto/verificationhelper/verificationhelper_qr_self_test.go b/crypto/verificationhelper/verificationhelper_qr_self_test.go index 9942ae30..11358b88 100644 --- a/crypto/verificationhelper/verificationhelper_qr_self_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_self_test.go @@ -200,7 +200,6 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { assert.Equal(t, txnID, doneEvt.TransactionID) ts.dispatchToDevice(t, ctx, sendingClient) - assert.True(t, sendingCallbacks.OtherReportedDone(txnID)) } else { // receiving scans QR // Emulate scanning the QR code shown by the sending device on // the receiving device. @@ -239,7 +238,6 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { assert.Equal(t, txnID, doneEvt.TransactionID) ts.dispatchToDevice(t, ctx, receivingClient) - assert.True(t, receivingCallbacks.OtherReportedDone(txnID)) } // Ensure that both devices have marked the verification as done. diff --git a/crypto/verificationhelper/verificationhelper_sas_test.go b/crypto/verificationhelper/verificationhelper_sas_test.go index 4f036f18..20e52e0f 100644 --- a/crypto/verificationhelper/verificationhelper_sas_test.go +++ b/crypto/verificationhelper/verificationhelper_sas_test.go @@ -269,7 +269,6 @@ func TestVerification_SAS(t *testing.T) { // twice to process and drain all of the events. ts.dispatchToDevice(t, ctx, sendingClient) ts.dispatchToDevice(t, ctx, receivingClient) - assert.True(t, receivingCallbacks.OtherReportedDone(txnID)) ts.dispatchToDevice(t, ctx, sendingClient) ts.dispatchToDevice(t, ctx, receivingClient) assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) From c8e197a4f9ee5e121c6dacedb6566e4c95af57df Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 19 Nov 2024 17:07:51 +0200 Subject: [PATCH 0920/1647] bridgev2/errors: add shared error for unknown login flow ID --- bridgev2/errors.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index 55df5357..052a606b 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -64,6 +64,11 @@ var ( ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) ) +// Common login interface errors +var ( + ErrInvalidLoginFlowID error = RespError(mautrix.MNotFound.WithMessage("Invalid login flow ID")) +) + // RespError is a class of error that certain network interface methods can return to ensure that the error // is properly translated into an HTTP error when the method is called via the provisioning API. // From 93737946060de2b319424eaa7e63cd9083e8209c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 20 Nov 2024 14:03:21 +0100 Subject: [PATCH 0921/1647] crypto: delete old olm sessions if there are too many (#315) --- crypto/decryptolm.go | 39 +++++++++++++++++++++++++++++++++++++-- crypto/sql_store.go | 7 ++++++- crypto/store.go | 16 ++++++++++++++++ 3 files changed, 59 insertions(+), 3 deletions(-) diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index 55614b76..965656a9 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -11,6 +11,7 @@ import ( "encoding/json" "errors" "fmt" + "slices" "time" "github.com/rs/zerolog" @@ -74,6 +75,11 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e return nil, UnsupportedOlmMessageType } + log := mach.machOrContextLog(ctx).With(). + Stringer("sender_key", senderKey). + Int("olm_msg_type", int(olmType)). + Logger() + ctx = log.WithContext(ctx) endTimeTrace := mach.timeTrace(ctx, "decrypting olm ciphertext", 5*time.Second) plaintext, err := mach.tryDecryptOlmCiphertext(ctx, evt.Sender, senderKey, olmType, ciphertext) endTimeTrace() @@ -168,6 +174,8 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U return plaintext, nil } +const MaxOlmSessionsPerDevice = 5 + func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) { log := *zerolog.Ctx(ctx) endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second) @@ -176,6 +184,31 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C if err != nil { return nil, fmt.Errorf("failed to get session for %s: %w", senderKey, err) } + if len(sessions) > MaxOlmSessionsPerDevice*2 { + // SQL store sorts sessions, but other implementations may not, so re-sort just in case + slices.SortFunc(sessions, func(a, b *OlmSession) int { + return b.LastDecryptedTime.Compare(a.LastDecryptedTime) + }) + log.Warn(). + Int("session_count", len(sessions)). + Time("newest_last_decrypted_at", sessions[0].LastDecryptedTime). + Time("oldest_last_decrypted_at", sessions[len(sessions)-1].LastDecryptedTime). + Msg("Too many sessions, deleting old ones") + for i := MaxOlmSessionsPerDevice; i < len(sessions); i++ { + err = mach.CryptoStore.DeleteSession(ctx, senderKey, sessions[i]) + if err != nil { + log.Warn().Err(err). + Stringer("olm_session_id", sessions[i].ID()). + Time("last_decrypt", sessions[i].LastDecryptedTime). + Msg("Failed to delete olm session") + } else { + log.Debug(). + Stringer("olm_session_id", sessions[i].ID()). + Time("last_decrypt", sessions[i].LastDecryptedTime). + Msg("Deleted olm session") + } + } + } for _, session := range sessions { log := log.With().Str("olm_session_id", session.ID().String()).Logger() @@ -190,11 +223,13 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C continue } } - log.Debug().Str("session_description", session.Describe()).Msg("Trying to decrypt olm message") endTimeTrace = mach.timeTrace(ctx, "decrypting olm message", time.Second) plaintext, err := session.Decrypt(ciphertext, olmType) endTimeTrace() if err != nil { + log.Warn().Err(err). + Str("session_description", session.Describe()). + Msg("Failed to decrypt olm message") if olmType == id.OlmMsgTypePreKey { return nil, DecryptionFailedWithMatchingSession } @@ -205,7 +240,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C if err != nil { log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting") } - log.Debug().Msg("Decrypted olm message") + log.Debug().Str("session_description", session.Describe()).Msg("Decrypted olm message") return plaintext, nil } } diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 00544a9b..e68f0df5 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -219,7 +219,7 @@ func (store *SQLCryptoStore) getOlmSessionCache(key id.SenderKey) map[id.Session return data } -// GetLatestSession retrieves the Olm session for a given sender key from the database that has the largest ID. +// GetLatestSession retrieves the Olm session for a given sender key from the database that had the most recent successful decryption. func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.SenderKey) (*OlmSession, error) { store.olmSessionCacheLock.Lock() defer store.olmSessionCacheLock.Unlock() @@ -274,6 +274,11 @@ func (store *SQLCryptoStore) UpdateSession(ctx context.Context, _ id.SenderKey, return err } +func (store *SQLCryptoStore) DeleteSession(ctx context.Context, _ id.SenderKey, session *OlmSession) error { + _, err := store.DB.Exec(ctx, "DELETE FROM crypto_olm_session WHERE session_id=$1 AND account_id=$2", session.ID(), store.AccountID) + return err +} + func datePtr(t time.Time) *time.Time { if t.IsZero() { return nil diff --git a/crypto/store.go b/crypto/store.go index a84d4f13..64fa8912 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -9,6 +9,7 @@ package crypto import ( "context" "fmt" + "slices" "sort" "sync" @@ -47,6 +48,8 @@ type Store interface { GetLatestSession(context.Context, id.SenderKey) (*OlmSession, error) // UpdateSession updates a session that has previously been inserted with AddSession. UpdateSession(context.Context, id.SenderKey, *OlmSession) error + // DeleteSession deletes the given session that has been previously inserted with AddSession. + DeleteSession(context.Context, id.SenderKey, *OlmSession) error // PutGroupSession inserts an inbound Megolm session into the store. If an earlier withhold event has been inserted // with PutWithheldGroupSession, this call should replace that. However, PutWithheldGroupSession must not replace @@ -233,6 +236,19 @@ func (gs *MemoryStore) AddSession(_ context.Context, senderKey id.SenderKey, ses return gs.save() } +func (gs *MemoryStore) DeleteSession(ctx context.Context, senderKey id.SenderKey, target *OlmSession) error { + gs.lock.Lock() + defer gs.lock.Unlock() + sessions, ok := gs.Sessions[senderKey] + if !ok { + return nil + } + gs.Sessions[senderKey] = slices.DeleteFunc(sessions, func(session *OlmSession) bool { + return session == target + }) + return gs.save() +} + func (gs *MemoryStore) UpdateSession(_ context.Context, _ id.SenderKey, _ *OlmSession) error { // we don't need to do anything here because the session is a pointer and already stored in our map return gs.save() From ed709621d43a2201cad65e576546f36ddca2ac75 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 20 Nov 2024 15:25:37 +0200 Subject: [PATCH 0922/1647] bridgev2/portal: ensure ghost row exists even when overriding sender id --- bridgev2/portal.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 0b5c2adc..45668aa3 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1687,6 +1687,12 @@ func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventS Str("original_id", string(sender.Sender)). Str("default_other_user", string(portal.OtherUserID)). Msg("Overriding event sender with primary other user in DM portal") + // Ensure the ghost row exists anyway to prevent foreign key errors when saving messages + // TODO it'd probably be better to override the sender in the saved message, but that's more effort + _, err := portal.Bridge.GetGhostByID(ctx, sender.Sender) + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get ghost with original user ID") + } sender.Sender = portal.OtherUserID } if sender.Sender != "" { From 40dbe7535dcf7e4e8fc79d786f938b6cb98b753a Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Sun, 3 Nov 2024 08:58:53 -0700 Subject: [PATCH 0923/1647] verificationhelper: save verification status in store Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 69 ++-- crypto/verificationhelper/sas.go | 103 +++--- .../verificationhelper/verificationhelper.go | 299 +++++++----------- .../verificationhelper_qr_self_test.go | 4 +- .../verificationhelper_test.go | 16 +- .../verificationhelper/verificationstore.go | 187 +++++++++++ 6 files changed, 415 insertions(+), 263 deletions(-) create mode 100644 crypto/verificationhelper/verificationstore.go diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index 21276218..395775e1 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -35,13 +35,13 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, ok := vh.activeTransactions[qrCode.TransactionID] - if !ok { - return fmt.Errorf("unknown transaction ID found in QR code") - } else if txn.VerificationState != verificationStateReady { + txn, err := vh.store.GetVerificationTransaction(ctx, qrCode.TransactionID) + if err != nil { + return fmt.Errorf("failed to get transaction %s: %w", qrCode.TransactionID, err) + } else if txn.VerificationState != VerificationStateReady { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "transaction found in the QR code is not in the ready state") } - txn.VerificationState = verificationStateTheirQRScanned + txn.VerificationState = VerificationStateTheirQRScanned // Verify the keys log.Info().Msg("Verifying keys from QR code") @@ -53,9 +53,9 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by switch qrCode.Mode { case QRCodeModeCrossSigning: - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) if err != nil { - return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUser, err) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) } if bytes.Equal(theirSigningKeys.MasterKey.Bytes(), qrCode.Key1[:]) { log.Info().Msg("Verified that the other device has the master key we expected") @@ -70,7 +70,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the master key does not match") } - if err := vh.mach.SignUser(ctx, txn.TheirUser, theirSigningKeys.MasterKey); err != nil { + if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their master key: %w", err) } case QRCodeModeSelfVerifyingMasterKeyTrusted: @@ -78,7 +78,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // means that we don't trust the key. Key1 is the master key public // key, and Key2 is what the other device thinks our device key is. - if vh.client.UserID != txn.TheirUser { + if vh.client.UserID != txn.TheirUserID { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) } @@ -114,12 +114,12 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeMasterKeyNotTrusted, "the master key is not trusted by this device, cannot verify device that does not trust the master key") } - if vh.client.UserID != txn.TheirUser { + if vh.client.UserID != txn.TheirUserID { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) } // Get their device - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to get their device: %w", err) } @@ -140,7 +140,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // Trust their device theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) if err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to update device trust state after verifying: %+v", err) } @@ -177,8 +177,12 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by txn.SentOurDone = true if txn.ReceivedTheirDone { log.Debug().Msg("We already received their done event. Setting verification state to done.") - delete(vh.activeTransactions, txn.TransactionID) + if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { + return err + } vh.verificationDone(ctx, txn.TransactionID) + } else { + vh.store.SaveVerificationTransaction(ctx, txn) } return nil } @@ -196,28 +200,27 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, ok := vh.activeTransactions[txnID] - if !ok { - log.Warn().Msg("Ignoring QR code scan confirmation for an unknown transaction") - return nil - } else if txn.VerificationState != verificationStateOurQRScanned { + txn, err := vh.store.GetVerificationTransaction(ctx, txnID) + if err != nil { + return fmt.Errorf("failed to get transaction %s: %w", txnID, err) + } else if txn.VerificationState != VerificationStateOurQRScanned { return fmt.Errorf("transaction is not in the scanned state") } log.Info().Msg("Confirming QR code scanned") - if txn.TheirUser == vh.client.UserID { + if txn.TheirUserID == vh.client.UserID { // Self-signing situation. Trust their device. // Get their device - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { return err } // Trust their device theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) if err != nil { return fmt.Errorf("failed to update device trust state after verifying: %w", err) } @@ -231,29 +234,33 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id } } else { // Cross-signing situation. Sign their master key. - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) if err != nil { - return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUser, err) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) } - if err := vh.mach.SignUser(ctx, txn.TheirUser, theirSigningKeys.MasterKey); err != nil { + if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their master key: %w", err) } } - err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { return err } txn.SentOurDone = true if txn.ReceivedTheirDone { - delete(vh.activeTransactions, txn.TransactionID) + if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { + return err + } vh.verificationDone(ctx, txn.TransactionID) + } else { + vh.store.SaveVerificationTransaction(ctx, txn) } return nil } -func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *verificationTransaction) error { +func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn VerificationTransaction) error { log := vh.getLog(ctx).With(). Str("verification_action", "generate and show QR code"). Stringer("transaction_id", txn.TransactionID). @@ -276,7 +283,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve return err } mode := QRCodeModeCrossSigning - if vh.client.UserID == txn.TheirUser { + if vh.client.UserID == txn.TheirUserID { // This is a self-signing situation. if ownMasterKeyTrusted { mode = QRCodeModeSelfVerifyingMasterKeyTrusted @@ -298,7 +305,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve key1 = ownCrossSigningPublicKeys.MasterKey.Bytes() // Key 2 is the other user's master signing key. - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) if err != nil { return err } @@ -308,7 +315,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve key1 = ownCrossSigningPublicKeys.MasterKey.Bytes() // Key 2 is the other device's key. - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { return err } @@ -326,5 +333,5 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve qrCode := NewQRCode(mode, txn.TransactionID, [32]byte(key1), [32]byte(key2)) txn.QRCodeSharedSecret = qrCode.SharedSecret vh.showQRCode(ctx, txn.TransactionID, qrCode) - return nil + return vh.store.SaveVerificationTransaction(ctx, txn) } diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index e28ec405..0492dd8d 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -40,23 +40,23 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, ok := vh.activeTransactions[txnID] - if !ok { - return fmt.Errorf("unknown transaction ID") - } else if txn.VerificationState != verificationStateReady { + txn, err := vh.store.GetVerificationTransaction(ctx, txnID) + if err != nil { + return fmt.Errorf("failed to get verification transaction %s: %w", txnID, err) + } else if txn.VerificationState != VerificationStateReady { return errors.New("transaction is not in ready state") } else if txn.StartEventContent != nil { return errors.New("start event already sent or received") } - txn.VerificationState = verificationStateSASStarted + txn.VerificationState = VerificationStateSASStarted txn.StartedByUs = true if !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS) { return fmt.Errorf("the other device does not support SAS verification") } // Ensure that we have their device key. - _, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + _, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { log.Err(err).Msg("Failed to fetch device") return err @@ -78,6 +78,9 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio event.SASMethodEmoji, }, } + if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + return err + } return vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationStart, txn.StartEventContent) } @@ -94,14 +97,13 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, ok := vh.activeTransactions[txnID] - if !ok { - return fmt.Errorf("unknown transaction ID") - } else if txn.VerificationState != verificationStateSASKeysExchanged { + txn, err := vh.store.GetVerificationTransaction(ctx, txnID) + if err != nil { + return fmt.Errorf("failed to get transaction %s: %w", txnID, err) + } else if txn.VerificationState != VerificationStateSASKeysExchanged { return errors.New("transaction is not in keys exchanged state") } - var err error keys := map[id.KeyID]jsonbytes.UnpaddedBytes{} log.Info().Msg("Signing keys") @@ -109,7 +111,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat // My device key myDevice := vh.mach.OwnIdentity() myDeviceKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, myDevice.DeviceID.String()) - keys[myDeviceKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, myDeviceKeyID.String(), myDevice.SigningKey.String()) + keys[myDeviceKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, myDeviceKeyID.String(), myDevice.SigningKey.String()) if err != nil { return err } @@ -118,7 +120,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat crossSigningKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) if crossSigningKeys != nil { crossSigningKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, crossSigningKeys.MasterKey.String()) - keys[crossSigningKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, crossSigningKeyID.String(), crossSigningKeys.MasterKey.String()) + keys[crossSigningKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, crossSigningKeyID.String(), crossSigningKeys.MasterKey.String()) if err != nil { return err } @@ -129,7 +131,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat keyIDs = append(keyIDs, keyID.String()) } slices.Sort(keyIDs) - keysMAC, err := vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, "KEY_IDS", strings.Join(keyIDs, ",")) + keysMAC, err := vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) if err != nil { return err } @@ -145,14 +147,14 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat txn.SentOurMAC = true if txn.ReceivedTheirMAC { - txn.VerificationState = verificationStateSASMACExchanged + txn.VerificationState = VerificationStateSASMACExchanged err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { return err } txn.SentOurDone = true } - return nil + return vh.store.SaveVerificationTransaction(ctx, txn) } // onVerificationStartSAS handles the m.key.verification.start events with @@ -160,7 +162,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat // Spec. // // [Section 11.12.2.2]: https://spec.matrix.org/v1.9/client-server-api/#short-authentication-string-sas-verification -func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *verificationTransaction, evt *event.Event) error { +func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn VerificationTransaction, evt *event.Event) error { startEvt := evt.Content.AsVerificationStart() log := vh.getLog(ctx).With(). Str("verification_action", "start_sas"). @@ -208,7 +210,7 @@ func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *v return fmt.Errorf("failed to generate ephemeral key: %w", err) } txn.MACMethod = macMethod - txn.EphemeralKey = ephemeralKey + txn.EphemeralKey = &ECDHPrivateKey{ephemeralKey} txn.StartEventContent = startEvt commitment, err := calculateCommitment(ephemeralKey.PublicKey(), startEvt) @@ -226,8 +228,8 @@ func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *v if err != nil { return fmt.Errorf("failed to send accept event: %w", err) } - txn.VerificationState = verificationStateSASAccepted - return nil + txn.VerificationState = VerificationStateSASAccepted + return vh.store.SaveVerificationTransaction(ctx, txn) } func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, startEvt *event.VerificationStartEventContent) ([]byte, error) { @@ -252,7 +254,7 @@ func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, startEvt *event.Verifi // event. This follows Step 4 of [Section 11.12.2.2] of the Spec. // // [Section 11.12.2.2]: https://spec.matrix.org/v1.9/client-server-api/#short-authentication-string-sas-verification -func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn VerificationTransaction, evt *event.Event) { acceptEvt := evt.Content.AsVerificationAccept() log := vh.getLog(ctx).With(). Str("verification_action", "accept"). @@ -267,7 +269,7 @@ func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *ver vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != verificationStateSASStarted { + if txn.VerificationState != VerificationStateSASStarted { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received accept event for a transaction that is not in the started state") return @@ -287,14 +289,18 @@ func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *ver return } - txn.VerificationState = verificationStateSASAccepted + txn.VerificationState = VerificationStateSASAccepted txn.MACMethod = acceptEvt.MessageAuthenticationCode txn.Commitment = acceptEvt.Commitment - txn.EphemeralKey = ephemeralKey + txn.EphemeralKey = &ECDHPrivateKey{ephemeralKey} txn.EphemeralPublicKeyShared = true + + if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + log.Err(err).Msg("failed to save verification transaction") + } } -func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn VerificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "key"). Logger() @@ -302,22 +308,23 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verifi vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != verificationStateSASAccepted { + if txn.VerificationState != VerificationStateSASAccepted { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received key event for a transaction that is not in the accepted state") return } var err error - txn.OtherPublicKey, err = ecdh.X25519().NewPublicKey(keyEvt.Key) + publicKey, err := ecdh.X25519().NewPublicKey(keyEvt.Key) if err != nil { log.Err(err).Msg("Failed to generate other public key") return } + txn.OtherPublicKey = &ECDHPublicKey{publicKey} if txn.EphemeralPublicKeyShared { // Verify that the commitment hash is correct - commitment, err := calculateCommitment(txn.OtherPublicKey, txn.StartEventContent) + commitment, err := calculateCommitment(publicKey, txn.StartEventContent) if err != nil { log.Err(err).Msg("Failed to calculate commitment") return @@ -342,7 +349,7 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verifi } txn.EphemeralPublicKeyShared = true } - txn.VerificationState = verificationStateSASKeysExchanged + txn.VerificationState = VerificationStateSASKeysExchanged sasBytes, err := vh.verificationSASHKDF(txn) if err != nil { @@ -370,10 +377,14 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verifi } } vh.showSAS(ctx, txn.TransactionID, emojis, decimals) + + if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + log.Err(err).Msg("failed to save verification transaction") + } } -func (vh *VerificationHelper) verificationSASHKDF(txn *verificationTransaction) ([]byte, error) { - sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey) +func (vh *VerificationHelper) verificationSASHKDF(txn VerificationTransaction) ([]byte, error) { + sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey.PublicKey) if err != nil { return nil, err } @@ -388,8 +399,8 @@ func (vh *VerificationHelper) verificationSASHKDF(txn *verificationTransaction) }, "|") theirInfo := strings.Join([]string{ - txn.TheirUser.String(), - txn.TheirDevice.String(), + txn.TheirUserID.String(), + txn.TheirDeviceID.String(), base64.RawStdEncoding.EncodeToString(txn.OtherPublicKey.Bytes()), }, "|") @@ -462,8 +473,8 @@ func BrokenB64Encode(input []byte) string { return string(output) } -func (vh *VerificationHelper) verificationMACHKDF(txn *verificationTransaction, senderUser id.UserID, senderDevice id.DeviceID, receivingUser id.UserID, receivingDevice id.DeviceID, keyID, key string) ([]byte, error) { - sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey) +func (vh *VerificationHelper) verificationMACHKDF(txn VerificationTransaction, senderUser id.UserID, senderDevice id.DeviceID, receivingUser id.UserID, receivingDevice id.DeviceID, keyID, key string) ([]byte, error) { + sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey.PublicKey) if err != nil { return nil, err } @@ -563,7 +574,7 @@ var allEmojis = []rune{ '📌', } -func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn VerificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "mac"). Logger() @@ -579,12 +590,12 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi for keyID := range macEvt.MAC { keyIDs = append(keyIDs, keyID.String()) _, kID := keyID.Parse() - if kID == txn.TheirDevice.String() { + if kID == txn.TheirDeviceID.String() { hasTheirDeviceKey = true } } slices.Sort(keyIDs) - expectedKeyMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) + expectedKeyMAC, err := vh.verificationMACHKDF(txn, txn.TheirUserID, txn.TheirDeviceID, vh.client.UserID, vh.client.DeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeSASMismatch, "failed to calculate key list MAC: %w", err) return @@ -610,8 +621,8 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi var key string var theirDevice *id.Device - if kID == txn.TheirDevice.String() { - theirDevice, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + if kID == txn.TheirDeviceID.String() { + theirDevice, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to fetch their device: %w", err) return @@ -630,7 +641,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi key = crossSigningKeys.MasterKey.String() } - expectedMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, keyID.String(), key) + expectedMAC, err := vh.verificationMACHKDF(txn, txn.TheirUserID, txn.TheirDeviceID, vh.client.UserID, vh.client.DeviceID, keyID.String(), key) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to calculate key MAC: %w", err) return @@ -641,9 +652,9 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi } // Trust their device - if kID == txn.TheirDevice.String() { + if kID == txn.TheirDeviceID.String() { theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to update device trust state after verifying: %w", err) return @@ -654,7 +665,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi txn.ReceivedTheirMAC = true if txn.SentOurMAC { - txn.VerificationState = verificationStateSASMACExchanged + txn.VerificationState = VerificationStateSASMACExchanged err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to send verification done event: %w", err) @@ -662,4 +673,8 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi } txn.SentOurDone = true } + + if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + log.Err(err).Msg("failed to save verification transaction") + } } diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index d3b7d4f5..f8173ee1 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -9,7 +9,7 @@ package verificationhelper import ( "bytes" "context" - "crypto/ecdh" + "errors" "fmt" "sync" "time" @@ -25,86 +25,6 @@ import ( "maunium.net/go/mautrix/id" ) -type verificationState int - -const ( - verificationStateRequested verificationState = iota - verificationStateReady - - verificationStateTheirQRScanned // We scanned their QR code - verificationStateOurQRScanned // They scanned our QR code - - verificationStateSASStarted // An SAS verification has been started - verificationStateSASAccepted // An SAS verification has been accepted - verificationStateSASKeysExchanged // An SAS verification has exchanged keys - verificationStateSASMACExchanged // An SAS verification has exchanged MACs -) - -func (step verificationState) String() string { - switch step { - case verificationStateRequested: - return "requested" - case verificationStateReady: - return "ready" - case verificationStateTheirQRScanned: - return "their_qr_scanned" - case verificationStateOurQRScanned: - return "our_qr_scanned" - case verificationStateSASStarted: - return "sas_started" - case verificationStateSASAccepted: - return "sas_accepted" - case verificationStateSASKeysExchanged: - return "sas_keys_exchanged" - case verificationStateSASMACExchanged: - return "sas_mac" - default: - return fmt.Sprintf("verificationStep(%d)", step) - } -} - -type verificationTransaction struct { - // RoomID is the room ID if the verification is happening in a room or - // empty if it is a to-device verification. - RoomID id.RoomID - - // VerificationState is the current step of the verification flow. - VerificationState verificationState - // TransactionID is the ID of the verification transaction. - TransactionID id.VerificationTransactionID - - // TheirDevice is the device ID of the device that either made the initial - // request or accepted our request. - TheirDevice id.DeviceID - // TheirUser is the user ID of the other user. - TheirUser id.UserID - // TheirSupportedMethods is a list of verification methods that the other - // device supports. - TheirSupportedMethods []event.VerificationMethod - - // SentToDeviceIDs is a list of devices which the initial request was sent - // to. This is only used for to-device verification requests, and is meant - // to be used to send cancellation requests to all other devices when a - // verification request is accepted via a m.key.verification.ready event. - SentToDeviceIDs []id.DeviceID - - // QRCodeSharedSecret is the shared secret that was encoded in the QR code - // that we showed. - QRCodeSharedSecret []byte - - StartedByUs bool // Whether the verification was started by us - StartEventContent *event.VerificationStartEventContent // The m.key.verification.start event content - Commitment []byte // The commitment from the m.key.verification.accept event - MACMethod event.MACMethod // The method used to calculate the MAC - EphemeralKey *ecdh.PrivateKey // The ephemeral key - EphemeralPublicKeyShared bool // Whether this device's ephemeral public key has been shared - OtherPublicKey *ecdh.PublicKey // The other device's ephemeral public key - ReceivedTheirMAC bool // Whether we have received their MAC - SentOurMAC bool // Whether we have sent our MAC - ReceivedTheirDone bool // Whether we have received their done event - SentOurDone bool // Whether we have sent our done event -} - // RequiredCallbacks is an interface representing the callbacks required for // the [VerificationHelper]. type RequiredCallbacks interface { @@ -145,8 +65,9 @@ type VerificationHelper struct { client *mautrix.Client mach *crypto.OlmMachine - activeTransactions map[id.VerificationTransactionID]*verificationTransaction + store VerificationStore activeTransactionsLock sync.Mutex + // activeTransactions map[id.VerificationTransactionID]*verificationTransaction // supportedMethods are the methods that *we* support supportedMethods []event.VerificationMethod @@ -163,15 +84,19 @@ type VerificationHelper struct { var _ mautrix.VerificationHelper = (*VerificationHelper)(nil) -func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, callbacks any, supportsScan bool) *VerificationHelper { +func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsScan bool) *VerificationHelper { if client.Crypto == nil { panic("client.Crypto is nil") } + if store == nil { + store = NewInMemoryVerificationStore() + } + helper := VerificationHelper{ - client: client, - mach: mach, - activeTransactions: map[id.VerificationTransactionID]*verificationTransaction{}, + client: client, + mach: mach, + store: store, } if c, ok := callbacks.(RequiredCallbacks); !ok { @@ -233,7 +158,7 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { // Wrapper for the event handlers to check that the transaction ID is known // and ignore the event if it isn't. - wrapHandler := func(callback func(context.Context, *verificationTransaction, *event.Event)) func(context.Context, *event.Event) { + wrapHandler := func(callback func(context.Context, VerificationTransaction, *event.Event)) func(context.Context, *event.Event) { return func(ctx context.Context, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "check transaction ID"). @@ -257,8 +182,11 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { log = log.With().Stringer("transaction_id", transactionID).Logger() vh.activeTransactionsLock.Lock() - txn, ok := vh.activeTransactions[transactionID] - if !ok { + txn, err := vh.store.GetVerificationTransaction(ctx, transactionID) + if err != nil && errors.Is(err, ErrUnknownVerificationTransaction) { + log.Err(err).Msg("failed to get verification transaction") + return + } else if errors.Is(err, ErrUnknownVerificationTransaction) { // If it's a cancellation event for an unknown transaction, we // can just ignore it. if evt.Type == event.ToDeviceVerificationCancel || evt.Type == event.InRoomVerificationCancel { @@ -271,9 +199,9 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { // We have to create a fake transaction so that the call to // cancelVerificationTxn works. - txn = &verificationTransaction{ - RoomID: evt.RoomID, - TheirUser: evt.Sender, + txn = VerificationTransaction{ + RoomID: evt.RoomID, + TheirUserID: evt.Sender, } if transactionable, ok := evt.Content.Parsed.(event.VerificationTransactionable); ok { txn.TransactionID = transactionable.GetTransactionID() @@ -281,7 +209,7 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { txn.TransactionID = id.VerificationTransactionID(evt.ID) } if fromDevice, ok := evt.Content.Raw["from_device"]; ok { - txn.TheirDevice = id.DeviceID(fromDevice.(string)) + txn.TheirDeviceID = id.DeviceID(fromDevice.(string)) } // Send a cancellation event. @@ -322,7 +250,11 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { syncer.OnEventType(event.InRoomVerificationKey, wrapHandler(vh.onVerificationKey)) // SAS syncer.OnEventType(event.InRoomVerificationMAC, wrapHandler(vh.onVerificationMAC)) // SAS - return nil + allTransactions, err := vh.store.GetAllVerificationTransactions(ctx) + for _, txn := range allTransactions { + vh.expireTransactionAt(txn.TransactionID, txn.ExpirationTime.Time) + } + return err } // StartVerification starts an interactive verification flow with the given @@ -382,13 +314,12 @@ func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserI vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - vh.activeTransactions[txnID] = &verificationTransaction{ - VerificationState: verificationStateRequested, + return txnID, vh.store.SaveVerificationTransaction(ctx, VerificationTransaction{ + VerificationState: VerificationStateRequested, TransactionID: txnID, - TheirUser: to, + TheirUserID: to, SentToDeviceIDs: maps.Keys(devices), - } - return txnID, nil + }) } // StartInRoomVerification starts an interactive verification flow with the @@ -422,13 +353,12 @@ func (vh *VerificationHelper) StartInRoomVerification(ctx context.Context, roomI vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - vh.activeTransactions[txnID] = &verificationTransaction{ + return txnID, vh.store.SaveVerificationTransaction(ctx, VerificationTransaction{ RoomID: roomID, - VerificationState: verificationStateRequested, + VerificationState: VerificationStateRequested, TransactionID: txnID, - TheirUser: to, - } - return txnID, nil + TheirUserID: to, + }) } // AcceptVerification accepts a verification request. The transaction ID should @@ -440,10 +370,10 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V Stringer("transaction_id", txnID). Logger() - txn, ok := vh.activeTransactions[txnID] - if !ok { - return fmt.Errorf("unknown transaction ID") - } else if txn.VerificationState != verificationStateRequested { + txn, err := vh.store.GetVerificationTransaction(ctx, txnID) + if err != nil { + return err + } else if txn.VerificationState != VerificationStateRequested { return fmt.Errorf("transaction is not in the requested state") } @@ -472,11 +402,11 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V FromDevice: vh.client.DeviceID, Methods: maps.Keys(supportedMethods), } - err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationReady, readyEvt) + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationReady, readyEvt) if err != nil { return err } - txn.VerificationState = verificationStateReady + txn.VerificationState = VerificationStateReady if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { vh.scanQRCode(ctx, txn.TransactionID) @@ -492,8 +422,7 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V func (vh *VerificationHelper) DismissVerification(ctx context.Context, txnID id.VerificationTransactionID) error { vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - delete(vh.activeTransactions, txnID) - return nil + return vh.store.DeleteVerification(ctx, txnID) } // DismissVerification cancels the verification request with the given @@ -504,9 +433,9 @@ func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.V vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, ok := vh.activeTransactions[txnID] - if !ok { - return fmt.Errorf("unknown transaction ID") + txn, err := vh.store.GetVerificationTransaction(ctx, txnID) + if err != nil { + return err } log := vh.getLog(ctx).With(). Str("verification_action", "cancel verification"). @@ -527,29 +456,28 @@ func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.V } else { cancelEvt.SetTransactionID(txn.TransactionID) req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{ - txn.TheirUser: {}, + txn.TheirUserID: {}, }} - if len(txn.TheirDevice) > 0 { + if len(txn.TheirDeviceID) > 0 { // Send the cancellation event to only the device that accepted the // verification request. All of the other devices already received a // cancellation event with code "m.acceped". - req.Messages[txn.TheirUser][txn.TheirDevice] = &event.Content{Parsed: cancelEvt} + req.Messages[txn.TheirUserID][txn.TheirDeviceID] = &event.Content{Parsed: cancelEvt} } else { // Send the cancellation event to all of the devices that we sent the // request to. for _, deviceID := range txn.SentToDeviceIDs { if deviceID != vh.client.DeviceID { - req.Messages[txn.TheirUser][deviceID] = &event.Content{Parsed: cancelEvt} + req.Messages[txn.TheirUserID][deviceID] = &event.Content{Parsed: cancelEvt} } } } _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { - return fmt.Errorf("failed to send m.key.verification.cancel event to %v: %w", maps.Keys(req.Messages[txn.TheirUser]), err) + return fmt.Errorf("failed to send m.key.verification.cancel event to %v: %w", maps.Keys(req.Messages[txn.TheirUserID]), err) } } - delete(vh.activeTransactions, txn.TransactionID) - return nil + return vh.store.DeleteVerification(ctx, txn.TransactionID) } // sendVerificationEvent sends a verification event to the other user's device @@ -561,7 +489,7 @@ func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.V // [event.VerificationTransactionable]. // - evtType can be either the to-device or in-room version of the event type // as it is always stringified. -func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn *verificationTransaction, evtType event.Type, content any) error { +func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn VerificationTransaction, evtType event.Type, content any) error { if txn.RoomID != "" { content.(event.Relatable).SetRelatesTo(&event.RelatesTo{Type: event.RelReference, EventID: id.EventID(txn.TransactionID)}) _, err := vh.client.SendMessageEvent(ctx, txn.RoomID, evtType, &event.Content{ @@ -573,13 +501,13 @@ func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn *ve } else { content.(event.VerificationTransactionable).SetTransactionID(txn.TransactionID) req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{ - txn.TheirUser: { - txn.TheirDevice: &event.Content{Parsed: content}, + txn.TheirUserID: { + txn.TheirDeviceID: &event.Content{Parsed: content}, }, }} _, err := vh.client.SendToDevice(ctx, evtType, &req) if err != nil { - return fmt.Errorf("failed to send %s event to %s: %w", evtType.String(), txn.TheirDevice, err) + return fmt.Errorf("failed to send %s event to %s: %w", evtType.String(), txn.TheirDeviceID, err) } } return nil @@ -591,7 +519,7 @@ func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn *ve // directly to expose the error to its caller). // // Must always be called with the activeTransactionsLock held. -func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn *verificationTransaction, code event.VerificationCancelCode, reasonFmtStr string, fmtArgs ...any) error { +func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn VerificationTransaction, code event.VerificationCancelCode, reasonFmtStr string, fmtArgs ...any) error { log := vh.getLog(ctx) reason := fmt.Errorf(reasonFmtStr, fmtArgs...).Error() log.Info(). @@ -605,7 +533,9 @@ func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn *ve log.Err(err).Msg("failed to send cancellation event") return fmt.Errorf("failed to send cancel verification event (code: %s, reason: %s): %w", code, reason, err) } - delete(vh.activeTransactions, txn.TransactionID) + if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { + log.Err(err).Msg("deleting verification failed") + } vh.verificationCancelledCallback(ctx, txn.TransactionID, code, reason) return fmt.Errorf("verification cancelled (code: %s): %s", code, reason) } @@ -684,54 +614,58 @@ func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *ev } vh.activeTransactionsLock.Lock() - newTxn := &verificationTransaction{ + newTxn := VerificationTransaction{ + ExpirationTime: jsontime.UnixMilli{Time: verificationRequest.Timestamp.Add(time.Minute * 10)}, RoomID: evt.RoomID, - VerificationState: verificationStateRequested, + VerificationState: VerificationStateRequested, TransactionID: verificationRequest.TransactionID, - TheirDevice: verificationRequest.FromDevice, - TheirUser: evt.Sender, + TheirDeviceID: verificationRequest.FromDevice, + TheirUserID: evt.Sender, TheirSupportedMethods: verificationRequest.Methods, } - for existingTxnID, existingTxn := range vh.activeTransactions { - if existingTxn.TheirUser == evt.Sender && existingTxn.TheirDevice == verificationRequest.FromDevice && existingTxnID != verificationRequest.TransactionID { - vh.cancelVerificationTxn(ctx, existingTxn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") + if txn, err := vh.store.FindVerificationTransactionForUserDevice(ctx, evt.Sender, verificationRequest.FromDevice); err != nil && !errors.Is(err, ErrUnknownVerificationTransaction) { + log.Err(err).Stringer("sender", evt.Sender).Stringer("device_id", verificationRequest.FromDevice).Msg("failed to find verification transaction") + vh.activeTransactionsLock.Unlock() + return + } else if !errors.Is(err, ErrUnknownVerificationTransaction) { + if txn.TransactionID == verificationRequest.TransactionID { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received a new verification request for the same transaction ID") + } else { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") vh.cancelVerificationTxn(ctx, newTxn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") - delete(vh.activeTransactions, existingTxnID) - vh.activeTransactionsLock.Unlock() - return - } - - if existingTxnID == verificationRequest.TransactionID { - vh.cancelVerificationTxn(ctx, existingTxn, event.VerificationCancelCodeUnexpectedMessage, "received a new verification request for the same transaction ID") - delete(vh.activeTransactions, existingTxnID) - vh.activeTransactionsLock.Unlock() - return } + vh.activeTransactionsLock.Unlock() + return + } + if err := vh.store.SaveVerificationTransaction(ctx, newTxn); err != nil { + log.Err(err).Msg("failed to save verification transaction") } - vh.activeTransactions[verificationRequest.TransactionID] = newTxn vh.activeTransactionsLock.Unlock() vh.expireTransactionAt(verificationRequest.TransactionID, verificationRequest.Timestamp.Add(time.Minute*10)) vh.verificationRequested(ctx, verificationRequest.TransactionID, evt.Sender) } -func (vh *VerificationHelper) expireTransactionAt(txnID id.VerificationTransactionID, expireAt time.Time) { +func (vh *VerificationHelper) expireTransactionAt(txnID id.VerificationTransactionID, expiresAt time.Time) { go func() { - time.Sleep(time.Until(expireAt)) + time.Sleep(time.Until(expiresAt)) vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, ok := vh.activeTransactions[txnID] - if !ok { + txn, err := vh.store.GetVerificationTransaction(context.Background(), txnID) + if err == ErrUnknownVerificationTransaction { + // Already deleted, nothing to expire return + } else if err != nil { + vh.getLog(context.Background()).Err(err).Msg("failed to get verification transaction to expire") + } else { + vh.cancelVerificationTxn(context.Background(), txn, event.VerificationCancelCodeTimeout, "verification timed out") } - - vh.cancelVerificationTxn(context.Background(), txn, event.VerificationCancelCodeTimeout, "verification timed out") }() } -func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn VerificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "verification ready"). Logger() @@ -739,7 +673,7 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != verificationStateRequested { + if txn.VerificationState != VerificationStateRequested { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "verification ready event received for a transaction that is not in the requested state") return } @@ -747,8 +681,8 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri readyEvt := evt.Content.AsVerificationReady() // Update the transaction state. - txn.VerificationState = verificationStateReady - txn.TheirDevice = readyEvt.FromDevice + txn.VerificationState = VerificationStateReady + txn.TheirDeviceID = readyEvt.FromDevice txn.TheirSupportedMethods = readyEvt.Methods // If we sent this verification request, send cancellations to all of the @@ -761,16 +695,16 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri Reason: "The verification was accepted on another device.", }, } - req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUserID: {}}} for _, deviceID := range txn.SentToDeviceIDs { - if deviceID == txn.TheirDevice || deviceID == vh.client.DeviceID { + if deviceID == txn.TheirDeviceID || deviceID == vh.client.DeviceID { // Don't ever send a cancellation to the device that accepted // the request or to our own device (which can happen if this // is a self-verification). continue } - req.Messages[txn.TheirUser][deviceID] = content + req.Messages[txn.TheirUserID][deviceID] = content } _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { @@ -782,13 +716,12 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri vh.scanQRCode(ctx, txn.TransactionID) } - err := vh.generateAndShowQRCode(ctx, txn) - if err != nil { + if err := vh.generateAndShowQRCode(ctx, txn); err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to generate and show QR code: %w", err) } } -func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn VerificationTransaction, evt *event.Event) { startEvt := evt.Content.AsVerificationStart() log := vh.getLog(ctx).With(). Str("verification_action", "verification start"). @@ -799,7 +732,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState == verificationStateSASStarted || txn.VerificationState == verificationStateOurQRScanned || txn.VerificationState == verificationStateTheirQRScanned { + if txn.VerificationState == VerificationStateSASStarted || txn.VerificationState == VerificationStateOurQRScanned || txn.VerificationState == VerificationStateTheirQRScanned { // We might have sent the event, and they also sent an event. if txn.StartEventContent == nil || !txn.StartedByUs { // We didn't sent a start event yet, so we have gotten ourselves @@ -831,19 +764,19 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri return } - if txn.TheirUser < vh.client.UserID || (txn.TheirUser == vh.client.UserID && txn.TheirDevice < vh.client.DeviceID) { + if txn.TheirUserID < vh.client.UserID || (txn.TheirUserID == vh.client.UserID && txn.TheirDeviceID < vh.client.DeviceID) { // Use their start event instead of ours txn.StartedByUs = false txn.StartEventContent = startEvt } - } else if txn.VerificationState != verificationStateReady { + } else if txn.VerificationState != VerificationStateReady { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got start event for transaction that is not in ready state") return } switch startEvt.Method { case event.VerificationMethodSAS: - txn.VerificationState = verificationStateSASStarted + txn.VerificationState = VerificationStateSASStarted if err := vh.onVerificationStartSAS(ctx, txn, evt); err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to handle SAS verification start: %w", err) } @@ -853,8 +786,11 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "reciprocated shared secret does not match") return } - txn.VerificationState = verificationStateOurQRScanned + txn.VerificationState = VerificationStateOurQRScanned vh.qrCodeScaned(ctx, txn.TransactionID) + if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + log.Err(err).Msg("failed to save verification transaction") + } default: // Note that we should never get m.qr_code.show.v1 or m.qr_code.scan.v1 // here, since the start command for scanning and showing QR codes @@ -864,16 +800,17 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri } } -func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verificationTransaction, evt *event.Event) { - vh.getLog(ctx).Info(). +func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn VerificationTransaction, evt *event.Event) { + log := vh.getLog(ctx).With(). Str("verification_action", "done"). Stringer("transaction_id", txn.TransactionID). - Msg("Verification done") + Logger() + log.Info().Msg("Verification done") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if !slices.Contains([]verificationState{ - verificationStateTheirQRScanned, verificationStateOurQRScanned, verificationStateSASMACExchanged, + if !slices.Contains([]VerificationState{ + VerificationStateTheirQRScanned, VerificationStateOurQRScanned, VerificationStateSASMACExchanged, }, txn.VerificationState) { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got done event for transaction that is not in QR-scanned or MAC-exchanged state") return @@ -881,12 +818,16 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verif txn.ReceivedTheirDone = true if txn.SentOurDone { - delete(vh.activeTransactions, txn.TransactionID) + if err := vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { + log.Err(err).Msg("Delete verification failed") + } vh.verificationDone(ctx, txn.TransactionID) + } else if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + log.Err(err).Msg("failed to save verification transaction") } } -func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn VerificationTransaction, evt *event.Event) { cancelEvt := evt.Content.AsVerificationCancel() log := vh.getLog(ctx).With(). Str("verification_action", "cancel"). @@ -912,7 +853,7 @@ func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *ver // that is currently in the REQUESTED state, then we will send // cancellations to all of the devices that we sent the request to. This // will ensure that all of the clients know that the request was cancelled. - if txn.VerificationState == verificationStateRequested && len(txn.SentToDeviceIDs) > 0 { + if txn.VerificationState == VerificationStateRequested && len(txn.SentToDeviceIDs) > 0 { content := &event.Content{ Parsed: &event.VerificationCancelEventContent{ ToDeviceVerificationEvent: event.ToDeviceVerificationEvent{TransactionID: txn.TransactionID}, @@ -920,9 +861,9 @@ func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *ver Reason: "The verification was rejected from another device.", }, } - req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUserID: {}}} for _, deviceID := range txn.SentToDeviceIDs { - req.Messages[txn.TheirUser][deviceID] = content + req.Messages[txn.TheirUserID][deviceID] = content } _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { @@ -930,6 +871,8 @@ func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *ver } } - delete(vh.activeTransactions, txn.TransactionID) + if err := vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { + log.Err(err).Msg("Delete verification failed") + } vh.verificationCancelledCallback(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) } diff --git a/crypto/verificationhelper/verificationhelper_qr_self_test.go b/crypto/verificationhelper/verificationhelper_qr_self_test.go index 11358b88..937cc414 100644 --- a/crypto/verificationhelper/verificationhelper_qr_self_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_self_test.go @@ -278,12 +278,12 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { // Emulate scanning the QR code shown by the receiving device // on the sending device. err = sendingHelper.HandleScannedQRData(ctx, receivingShownQRCodeBytes) - assert.ErrorContains(t, err, "unknown transaction ID found in QR code") + assert.ErrorContains(t, err, "unknown transaction ID") // Emulate scanning the QR code shown by the sending device on // the receiving device. err = receivingHelper.HandleScannedQRData(ctx, sendingShownQRCodeBytes) - assert.ErrorContains(t, err, "unknown transaction ID found in QR code") + assert.ErrorContains(t, err, "unknown transaction ID") } func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index 273042c3..d0bf2298 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -65,11 +65,11 @@ func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockServ func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, receivingClient *mautrix.Client, sendingMachine, receivingMachine *crypto.OlmMachine) (sendingCallbacks, receivingCallbacks *allVerificationCallbacks, sendingHelper, receivingHelper *verificationhelper.VerificationHelper) { t.Helper() sendingCallbacks = newAllVerificationCallbacks() - sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, sendingCallbacks, true) + sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, sendingCallbacks, true) require.NoError(t, sendingHelper.Init(ctx)) receivingCallbacks = newAllVerificationCallbacks() - receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receivingCallbacks, true) + receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, receivingCallbacks, true) require.NoError(t, receivingHelper.Init(ctx)) return } @@ -104,7 +104,7 @@ func TestVerification_Start(t *testing.T) { addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID) addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID2) - senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), tc.callbacks, tc.supportsScan) + senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, tc.callbacks, tc.supportsScan) err := senderHelper.Init(ctx) require.NoError(t, err) @@ -151,7 +151,7 @@ func TestVerification_StartThenCancel(t *testing.T) { bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID) bystanderMachine := bystanderClient.Crypto.(*cryptohelper.CryptoHelper).Machine() - bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, newAllVerificationCallbacks(), true) + bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, nil, newAllVerificationCallbacks(), true) require.NoError(t, bystanderHelper.Init(ctx)) require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, bystanderMachine.OwnIdentity())) @@ -241,12 +241,12 @@ func TestVerification_Accept_NoSupportedMethods(t *testing.T) { assert.NotEmpty(t, recoveryKey) assert.NotNil(t, cache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, newAllVerificationCallbacks(), true) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, newAllVerificationCallbacks(), true) err = sendingHelper.Init(ctx) require.NoError(t, err) receivingCallbacks := newBaseVerificationCallbacks() - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), receivingCallbacks, false) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, receivingCallbacks, false) err = receivingHelper.Init(ctx) require.NoError(t, err) @@ -289,11 +289,11 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { assert.NotEmpty(t, recoveryKey) assert.NotNil(t, sendingCrossSigningKeysCache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, tc.sendingCallbacks, tc.sendingSupportsScan) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, tc.sendingCallbacks, tc.sendingSupportsScan) err = sendingHelper.Init(ctx) require.NoError(t, err) - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, tc.receivingCallbacks, tc.receivingSupportsScan) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, tc.receivingCallbacks, tc.receivingSupportsScan) err = receivingHelper.Init(ctx) require.NoError(t, err) diff --git a/crypto/verificationhelper/verificationstore.go b/crypto/verificationhelper/verificationstore.go new file mode 100644 index 00000000..725a66a6 --- /dev/null +++ b/crypto/verificationhelper/verificationstore.go @@ -0,0 +1,187 @@ +package verificationhelper + +import ( + "context" + "crypto/ecdh" + "encoding/json" + "errors" + "fmt" + + "go.mau.fi/util/jsontime" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +var ErrUnknownVerificationTransaction = errors.New("unknown transaction ID") + +type VerificationState int + +const ( + VerificationStateRequested VerificationState = iota + VerificationStateReady + + VerificationStateTheirQRScanned // We scanned their QR code + VerificationStateOurQRScanned // They scanned our QR code + + VerificationStateSASStarted // An SAS verification has been started + VerificationStateSASAccepted // An SAS verification has been accepted + VerificationStateSASKeysExchanged // An SAS verification has exchanged keys + VerificationStateSASMACExchanged // An SAS verification has exchanged MACs +) + +func (step VerificationState) String() string { + switch step { + case VerificationStateRequested: + return "requested" + case VerificationStateReady: + return "ready" + case VerificationStateTheirQRScanned: + return "their_qr_scanned" + case VerificationStateOurQRScanned: + return "our_qr_scanned" + case VerificationStateSASStarted: + return "sas_started" + case VerificationStateSASAccepted: + return "sas_accepted" + case VerificationStateSASKeysExchanged: + return "sas_keys_exchanged" + case VerificationStateSASMACExchanged: + return "sas_mac" + default: + return fmt.Sprintf("VerificationState(%d)", step) + } +} + +type ECDHPrivateKey struct { + *ecdh.PrivateKey +} + +func (e *ECDHPrivateKey) UnmarshalJSON(data []byte) (err error) { + e.PrivateKey, err = ecdh.P256().NewPrivateKey(data) + return +} + +func (e *ECDHPrivateKey) MarshalJSON() ([]byte, error) { + return json.Marshal(e.Bytes()) +} + +type ECDHPublicKey struct { + *ecdh.PublicKey +} + +func (e *ECDHPublicKey) UnmarshalJSON(data []byte) (err error) { + e.PublicKey, err = ecdh.P256().NewPublicKey(data) + return +} + +func (e *ECDHPublicKey) MarshalJSON() ([]byte, error) { + return json.Marshal(e.Bytes()) +} + +type VerificationTransaction struct { + ExpirationTime jsontime.UnixMilli `json:"expiration_time"` + + // RoomID is the room ID if the verification is happening in a room or + // empty if it is a to-device verification. + RoomID id.RoomID `json:"room_id"` + + // VerificationState is the current step of the verification flow. + VerificationState VerificationState `json:"verification_state"` + // TransactionID is the ID of the verification transaction. + TransactionID id.VerificationTransactionID `json:"transaction_id"` + + // TheirDeviceID is the device ID of the device that either made the + // initial request or accepted our request. + TheirDeviceID id.DeviceID `json:"their_device_id"` + // TheirUserID is the user ID of the other user. + TheirUserID id.UserID `json:"their_user_id"` + // TheirSupportedMethods is a list of verification methods that the other + // device supports. + TheirSupportedMethods []event.VerificationMethod `json:"their_supported_methods"` + + // SentToDeviceIDs is a list of devices which the initial request was sent + // to. This is only used for to-device verification requests, and is meant + // to be used to send cancellation requests to all other devices when a + // verification request is accepted via a m.key.verification.ready event. + SentToDeviceIDs []id.DeviceID `json:"sent_to_device_ids"` + + // QRCodeSharedSecret is the shared secret that was encoded in the QR code + // that we showed. + QRCodeSharedSecret []byte `json:"qr_code_shared_secret"` + + StartedByUs bool `json:"started_by_us"` // Whether the verification was started by us + StartEventContent *event.VerificationStartEventContent `json:"start_event_content"` // The m.key.verification.start event content + Commitment []byte `json:"committment"` // The commitment from the m.key.verification.accept event + MACMethod event.MACMethod `json:"mac_method"` // The method used to calculate the MAC + EphemeralKey *ECDHPrivateKey `json:"ephemeral_key"` // The ephemeral key + EphemeralPublicKeyShared bool `json:"ephemeral_public_key_shared"` // Whether this device's ephemeral public key has been shared + OtherPublicKey *ECDHPublicKey `json:"other_public_key"` // The other device's ephemeral public key + ReceivedTheirMAC bool `json:"received_their_mac"` // Whether we have received their MAC + SentOurMAC bool `json:"sent_our_mac"` // Whether we have sent our MAC + ReceivedTheirDone bool `json:"received_their_done"` // Whether we have received their done event + SentOurDone bool `json:"sent_our_done"` // Whether we have sent our done event +} + +type VerificationStore interface { + // DeleteVerification deletes a verification transaction by ID + DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) error + // GetVerificationTransaction gets a verification transaction by ID + GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (VerificationTransaction, error) + // SaveVerificationTransaction saves a verification transaction by ID + SaveVerificationTransaction(ctx context.Context, txn VerificationTransaction) error + // FindVerificationTransactionForUserDevice finds a verification + // transaction by user and device ID + FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (VerificationTransaction, error) + // GetAllVerificationTransactions returns all of the verification + // transactions. This is used to reset the cancellation timeouts. + GetAllVerificationTransactions(ctx context.Context) ([]VerificationTransaction, error) +} + +type InMemoryVerificationStore struct { + txns map[id.VerificationTransactionID]VerificationTransaction +} + +var _ VerificationStore = (*InMemoryVerificationStore)(nil) + +func NewInMemoryVerificationStore() *InMemoryVerificationStore { + return &InMemoryVerificationStore{ + txns: map[id.VerificationTransactionID]VerificationTransaction{}, + } +} + +func (i *InMemoryVerificationStore) DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) error { + if _, ok := i.txns[txnID]; !ok { + return ErrUnknownVerificationTransaction + } + delete(i.txns, txnID) + return nil +} + +func (i *InMemoryVerificationStore) GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (VerificationTransaction, error) { + if _, ok := i.txns[txnID]; !ok { + return VerificationTransaction{}, ErrUnknownVerificationTransaction + } + return i.txns[txnID], nil +} + +func (i *InMemoryVerificationStore) SaveVerificationTransaction(ctx context.Context, txn VerificationTransaction) error { + i.txns[txn.TransactionID] = txn + return nil +} + +func (i *InMemoryVerificationStore) FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (VerificationTransaction, error) { + for _, existingTxn := range i.txns { + if existingTxn.TheirUserID == userID && existingTxn.TheirDeviceID == deviceID { + return existingTxn, nil + } + } + return VerificationTransaction{}, ErrUnknownVerificationTransaction +} + +func (i *InMemoryVerificationStore) GetAllVerificationTransactions(ctx context.Context) (txns []VerificationTransaction, err error) { + for _, txn := range i.txns { + txns = append(txns, txn) + } + return +} From d89912cfcb416ffe26101790c043dd47e4508811 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 20 Nov 2024 10:08:10 -0700 Subject: [PATCH 0924/1647] verificationhelper: fix hard-coded username Signed-off-by: Sumner Evans --- crypto/verificationhelper/verificationhelper.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index f8173ee1..18a24322 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -334,7 +334,7 @@ func (vh *VerificationHelper) StartInRoomVerification(ctx context.Context, roomI log.Info().Msg("Sending verification request") content := event.MessageEventContent{ MsgType: event.MsgVerificationRequest, - Body: "Alice is requesting to verify your device, but your client does not support verification, so you may need to use a different verification method.", + Body: fmt.Sprintf("%s is requesting to verify your device, but your client does not support verification, so you may need to use a different verification method.", vh.client.UserID), FromDevice: vh.client.DeviceID, Methods: vh.supportedMethods, To: to, From 4cd2bb62ff01431ab5c1ea5e0d57ad6cc6c9458f Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 20 Nov 2024 10:16:18 -0700 Subject: [PATCH 0925/1647] verificationhelper: improve logging on ready and start event handlers Signed-off-by: Sumner Evans --- crypto/verificationhelper/verificationhelper.go | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 18a24322..bc424624 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -397,7 +397,7 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V } } - log.Info().Msg("Sending ready event") + log.Info().Any("methods", maps.Keys(supportedMethods)).Msg("Sending ready event") readyEvt := &event.VerificationReadyEventContent{ FromDevice: vh.client.DeviceID, Methods: maps.Keys(supportedMethods), @@ -685,6 +685,11 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif txn.TheirDeviceID = readyEvt.FromDevice txn.TheirSupportedMethods = readyEvt.Methods + log.Info(). + Stringer("their_device_id", txn.TheirDeviceID). + Any("their_supported_methods", txn.TheirSupportedMethods). + Msg("Received verification ready event") + // If we sent this verification request, send cancellations to all of the // other devices. if len(txn.SentToDeviceIDs) > 0 { @@ -726,8 +731,12 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif log := vh.getLog(ctx).With(). Str("verification_action", "verification start"). Str("method", string(startEvt.Method)). + Stringer("their_device_id", txn.TheirDeviceID). + Any("their_supported_methods", txn.TheirSupportedMethods). + Bool("started_by_us", txn.StartedByUs). Logger() ctx = log.WithContext(ctx) + log.Info().Msg("Received verification start event") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() @@ -765,7 +774,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif } if txn.TheirUserID < vh.client.UserID || (txn.TheirUserID == vh.client.UserID && txn.TheirDeviceID < vh.client.DeviceID) { - // Use their start event instead of ours + log.Debug().Msg("Using their start event instead of ours because they are alphabetically before us") txn.StartedByUs = false txn.StartEventContent = startEvt } @@ -776,6 +785,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif switch startEvt.Method { case event.VerificationMethodSAS: + log.Info().Msg("Received SAS start event") txn.VerificationState = VerificationStateSASStarted if err := vh.onVerificationStartSAS(ctx, txn, evt); err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to handle SAS verification start: %w", err) @@ -804,6 +814,7 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn Verifi log := vh.getLog(ctx).With(). Str("verification_action", "done"). Stringer("transaction_id", txn.TransactionID). + Bool("sent_our_done", txn.SentOurDone). Logger() log.Info().Msg("Verification done") vh.activeTransactionsLock.Lock() From 1170825b092e97c780e6c583a3282d85d2f29cee Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 21 Nov 2024 18:22:34 +0200 Subject: [PATCH 0926/1647] crypto: fix key share count log --- crypto/machine.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/machine.go b/crypto/machine.go index c9fc2249..7c1093f3 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -278,7 +278,7 @@ func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.O log := mach.Log.With().Str("trace_id", traceID).Logger() ctx = log.WithContext(ctx) log.Debug(). - Int("keys_left", otkCount.Curve25519). + Int("keys_left", otkCount.SignedCurve25519). Msg("Sync response said we have less than 50 signed curve25519 keys left, sharing new ones...") err := mach.ShareKeys(ctx, otkCount.SignedCurve25519) if err != nil { From b4551fc3da8fe3658afcc84e3546e46836e46e90 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 22 Nov 2024 10:20:34 +0200 Subject: [PATCH 0927/1647] crypto/decryptolm: don't use deleted sessions for decrypting --- crypto/decryptolm.go | 1 + 1 file changed, 1 insertion(+) diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index 965656a9..8f1eb1f7 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -208,6 +208,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C Msg("Deleted olm session") } } + sessions = sessions[:MaxOlmSessionsPerDevice] } for _, session := range sessions { From 249bc1b14e9835cbb956839778e168d6a05c2889 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 25 Nov 2024 11:06:33 -0700 Subject: [PATCH 0928/1647] Revert "verificationhelper: save verification status in store" (#319) Reverts commit 40dbe7535dcf7e4e8fc79d786f938b6cb98b753a Signed-off-by: Sumner Evans Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 69 ++-- crypto/verificationhelper/sas.go | 103 +++--- .../verificationhelper/verificationhelper.go | 303 +++++++++++------- .../verificationhelper_qr_self_test.go | 4 +- .../verificationhelper_test.go | 16 +- .../verificationhelper/verificationstore.go | 187 ----------- 6 files changed, 265 insertions(+), 417 deletions(-) delete mode 100644 crypto/verificationhelper/verificationstore.go diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index 395775e1..21276218 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -35,13 +35,13 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(ctx, qrCode.TransactionID) - if err != nil { - return fmt.Errorf("failed to get transaction %s: %w", qrCode.TransactionID, err) - } else if txn.VerificationState != VerificationStateReady { + txn, ok := vh.activeTransactions[qrCode.TransactionID] + if !ok { + return fmt.Errorf("unknown transaction ID found in QR code") + } else if txn.VerificationState != verificationStateReady { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "transaction found in the QR code is not in the ready state") } - txn.VerificationState = VerificationStateTheirQRScanned + txn.VerificationState = verificationStateTheirQRScanned // Verify the keys log.Info().Msg("Verifying keys from QR code") @@ -53,9 +53,9 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by switch qrCode.Mode { case QRCodeModeCrossSigning: - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) if err != nil { - return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUser, err) } if bytes.Equal(theirSigningKeys.MasterKey.Bytes(), qrCode.Key1[:]) { log.Info().Msg("Verified that the other device has the master key we expected") @@ -70,7 +70,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the master key does not match") } - if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { + if err := vh.mach.SignUser(ctx, txn.TheirUser, theirSigningKeys.MasterKey); err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their master key: %w", err) } case QRCodeModeSelfVerifyingMasterKeyTrusted: @@ -78,7 +78,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // means that we don't trust the key. Key1 is the master key public // key, and Key2 is what the other device thinks our device key is. - if vh.client.UserID != txn.TheirUserID { + if vh.client.UserID != txn.TheirUser { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) } @@ -114,12 +114,12 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeMasterKeyNotTrusted, "the master key is not trusted by this device, cannot verify device that does not trust the master key") } - if vh.client.UserID != txn.TheirUserID { + if vh.client.UserID != txn.TheirUser { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) } // Get their device - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) if err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to get their device: %w", err) } @@ -140,7 +140,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // Trust their device theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) if err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to update device trust state after verifying: %+v", err) } @@ -177,12 +177,8 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by txn.SentOurDone = true if txn.ReceivedTheirDone { log.Debug().Msg("We already received their done event. Setting verification state to done.") - if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { - return err - } + delete(vh.activeTransactions, txn.TransactionID) vh.verificationDone(ctx, txn.TransactionID) - } else { - vh.store.SaveVerificationTransaction(ctx, txn) } return nil } @@ -200,27 +196,28 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(ctx, txnID) - if err != nil { - return fmt.Errorf("failed to get transaction %s: %w", txnID, err) - } else if txn.VerificationState != VerificationStateOurQRScanned { + txn, ok := vh.activeTransactions[txnID] + if !ok { + log.Warn().Msg("Ignoring QR code scan confirmation for an unknown transaction") + return nil + } else if txn.VerificationState != verificationStateOurQRScanned { return fmt.Errorf("transaction is not in the scanned state") } log.Info().Msg("Confirming QR code scanned") - if txn.TheirUserID == vh.client.UserID { + if txn.TheirUser == vh.client.UserID { // Self-signing situation. Trust their device. // Get their device - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) if err != nil { return err } // Trust their device theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) if err != nil { return fmt.Errorf("failed to update device trust state after verifying: %w", err) } @@ -234,33 +231,29 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id } } else { // Cross-signing situation. Sign their master key. - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) if err != nil { - return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUser, err) } - if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { + if err := vh.mach.SignUser(ctx, txn.TheirUser, theirSigningKeys.MasterKey); err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their master key: %w", err) } } - err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { return err } txn.SentOurDone = true if txn.ReceivedTheirDone { - if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { - return err - } + delete(vh.activeTransactions, txn.TransactionID) vh.verificationDone(ctx, txn.TransactionID) - } else { - vh.store.SaveVerificationTransaction(ctx, txn) } return nil } -func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn VerificationTransaction) error { +func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *verificationTransaction) error { log := vh.getLog(ctx).With(). Str("verification_action", "generate and show QR code"). Stringer("transaction_id", txn.TransactionID). @@ -283,7 +276,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn Ver return err } mode := QRCodeModeCrossSigning - if vh.client.UserID == txn.TheirUserID { + if vh.client.UserID == txn.TheirUser { // This is a self-signing situation. if ownMasterKeyTrusted { mode = QRCodeModeSelfVerifyingMasterKeyTrusted @@ -305,7 +298,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn Ver key1 = ownCrossSigningPublicKeys.MasterKey.Bytes() // Key 2 is the other user's master signing key. - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) if err != nil { return err } @@ -315,7 +308,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn Ver key1 = ownCrossSigningPublicKeys.MasterKey.Bytes() // Key 2 is the other device's key. - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) if err != nil { return err } @@ -333,5 +326,5 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn Ver qrCode := NewQRCode(mode, txn.TransactionID, [32]byte(key1), [32]byte(key2)) txn.QRCodeSharedSecret = qrCode.SharedSecret vh.showQRCode(ctx, txn.TransactionID, qrCode) - return vh.store.SaveVerificationTransaction(ctx, txn) + return nil } diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 0492dd8d..e28ec405 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -40,23 +40,23 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(ctx, txnID) - if err != nil { - return fmt.Errorf("failed to get verification transaction %s: %w", txnID, err) - } else if txn.VerificationState != VerificationStateReady { + txn, ok := vh.activeTransactions[txnID] + if !ok { + return fmt.Errorf("unknown transaction ID") + } else if txn.VerificationState != verificationStateReady { return errors.New("transaction is not in ready state") } else if txn.StartEventContent != nil { return errors.New("start event already sent or received") } - txn.VerificationState = VerificationStateSASStarted + txn.VerificationState = verificationStateSASStarted txn.StartedByUs = true if !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS) { return fmt.Errorf("the other device does not support SAS verification") } // Ensure that we have their device key. - _, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + _, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) if err != nil { log.Err(err).Msg("Failed to fetch device") return err @@ -78,9 +78,6 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio event.SASMethodEmoji, }, } - if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - return err - } return vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationStart, txn.StartEventContent) } @@ -97,13 +94,14 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(ctx, txnID) - if err != nil { - return fmt.Errorf("failed to get transaction %s: %w", txnID, err) - } else if txn.VerificationState != VerificationStateSASKeysExchanged { + txn, ok := vh.activeTransactions[txnID] + if !ok { + return fmt.Errorf("unknown transaction ID") + } else if txn.VerificationState != verificationStateSASKeysExchanged { return errors.New("transaction is not in keys exchanged state") } + var err error keys := map[id.KeyID]jsonbytes.UnpaddedBytes{} log.Info().Msg("Signing keys") @@ -111,7 +109,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat // My device key myDevice := vh.mach.OwnIdentity() myDeviceKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, myDevice.DeviceID.String()) - keys[myDeviceKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, myDeviceKeyID.String(), myDevice.SigningKey.String()) + keys[myDeviceKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, myDeviceKeyID.String(), myDevice.SigningKey.String()) if err != nil { return err } @@ -120,7 +118,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat crossSigningKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) if crossSigningKeys != nil { crossSigningKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, crossSigningKeys.MasterKey.String()) - keys[crossSigningKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, crossSigningKeyID.String(), crossSigningKeys.MasterKey.String()) + keys[crossSigningKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, crossSigningKeyID.String(), crossSigningKeys.MasterKey.String()) if err != nil { return err } @@ -131,7 +129,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat keyIDs = append(keyIDs, keyID.String()) } slices.Sort(keyIDs) - keysMAC, err := vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) + keysMAC, err := vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, "KEY_IDS", strings.Join(keyIDs, ",")) if err != nil { return err } @@ -147,14 +145,14 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat txn.SentOurMAC = true if txn.ReceivedTheirMAC { - txn.VerificationState = VerificationStateSASMACExchanged + txn.VerificationState = verificationStateSASMACExchanged err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { return err } txn.SentOurDone = true } - return vh.store.SaveVerificationTransaction(ctx, txn) + return nil } // onVerificationStartSAS handles the m.key.verification.start events with @@ -162,7 +160,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat // Spec. // // [Section 11.12.2.2]: https://spec.matrix.org/v1.9/client-server-api/#short-authentication-string-sas-verification -func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn VerificationTransaction, evt *event.Event) error { +func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *verificationTransaction, evt *event.Event) error { startEvt := evt.Content.AsVerificationStart() log := vh.getLog(ctx).With(). Str("verification_action", "start_sas"). @@ -210,7 +208,7 @@ func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn Ve return fmt.Errorf("failed to generate ephemeral key: %w", err) } txn.MACMethod = macMethod - txn.EphemeralKey = &ECDHPrivateKey{ephemeralKey} + txn.EphemeralKey = ephemeralKey txn.StartEventContent = startEvt commitment, err := calculateCommitment(ephemeralKey.PublicKey(), startEvt) @@ -228,8 +226,8 @@ func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn Ve if err != nil { return fmt.Errorf("failed to send accept event: %w", err) } - txn.VerificationState = VerificationStateSASAccepted - return vh.store.SaveVerificationTransaction(ctx, txn) + txn.VerificationState = verificationStateSASAccepted + return nil } func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, startEvt *event.VerificationStartEventContent) ([]byte, error) { @@ -254,7 +252,7 @@ func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, startEvt *event.Verifi // event. This follows Step 4 of [Section 11.12.2.2] of the Spec. // // [Section 11.12.2.2]: https://spec.matrix.org/v1.9/client-server-api/#short-authentication-string-sas-verification -func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *verificationTransaction, evt *event.Event) { acceptEvt := evt.Content.AsVerificationAccept() log := vh.getLog(ctx).With(). Str("verification_action", "accept"). @@ -269,7 +267,7 @@ func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn Veri vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != VerificationStateSASStarted { + if txn.VerificationState != verificationStateSASStarted { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received accept event for a transaction that is not in the started state") return @@ -289,18 +287,14 @@ func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn Veri return } - txn.VerificationState = VerificationStateSASAccepted + txn.VerificationState = verificationStateSASAccepted txn.MACMethod = acceptEvt.MessageAuthenticationCode txn.Commitment = acceptEvt.Commitment - txn.EphemeralKey = &ECDHPrivateKey{ephemeralKey} + txn.EphemeralKey = ephemeralKey txn.EphemeralPublicKeyShared = true - - if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - log.Err(err).Msg("failed to save verification transaction") - } } -func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "key"). Logger() @@ -308,23 +302,22 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != VerificationStateSASAccepted { + if txn.VerificationState != verificationStateSASAccepted { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received key event for a transaction that is not in the accepted state") return } var err error - publicKey, err := ecdh.X25519().NewPublicKey(keyEvt.Key) + txn.OtherPublicKey, err = ecdh.X25519().NewPublicKey(keyEvt.Key) if err != nil { log.Err(err).Msg("Failed to generate other public key") return } - txn.OtherPublicKey = &ECDHPublicKey{publicKey} if txn.EphemeralPublicKeyShared { // Verify that the commitment hash is correct - commitment, err := calculateCommitment(publicKey, txn.StartEventContent) + commitment, err := calculateCommitment(txn.OtherPublicKey, txn.StartEventContent) if err != nil { log.Err(err).Msg("Failed to calculate commitment") return @@ -349,7 +342,7 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific } txn.EphemeralPublicKeyShared = true } - txn.VerificationState = VerificationStateSASKeysExchanged + txn.VerificationState = verificationStateSASKeysExchanged sasBytes, err := vh.verificationSASHKDF(txn) if err != nil { @@ -377,14 +370,10 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific } } vh.showSAS(ctx, txn.TransactionID, emojis, decimals) - - if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - log.Err(err).Msg("failed to save verification transaction") - } } -func (vh *VerificationHelper) verificationSASHKDF(txn VerificationTransaction) ([]byte, error) { - sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey.PublicKey) +func (vh *VerificationHelper) verificationSASHKDF(txn *verificationTransaction) ([]byte, error) { + sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey) if err != nil { return nil, err } @@ -399,8 +388,8 @@ func (vh *VerificationHelper) verificationSASHKDF(txn VerificationTransaction) ( }, "|") theirInfo := strings.Join([]string{ - txn.TheirUserID.String(), - txn.TheirDeviceID.String(), + txn.TheirUser.String(), + txn.TheirDevice.String(), base64.RawStdEncoding.EncodeToString(txn.OtherPublicKey.Bytes()), }, "|") @@ -473,8 +462,8 @@ func BrokenB64Encode(input []byte) string { return string(output) } -func (vh *VerificationHelper) verificationMACHKDF(txn VerificationTransaction, senderUser id.UserID, senderDevice id.DeviceID, receivingUser id.UserID, receivingDevice id.DeviceID, keyID, key string) ([]byte, error) { - sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey.PublicKey) +func (vh *VerificationHelper) verificationMACHKDF(txn *verificationTransaction, senderUser id.UserID, senderDevice id.DeviceID, receivingUser id.UserID, receivingDevice id.DeviceID, keyID, key string) ([]byte, error) { + sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey) if err != nil { return nil, err } @@ -574,7 +563,7 @@ var allEmojis = []rune{ '📌', } -func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "mac"). Logger() @@ -590,12 +579,12 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific for keyID := range macEvt.MAC { keyIDs = append(keyIDs, keyID.String()) _, kID := keyID.Parse() - if kID == txn.TheirDeviceID.String() { + if kID == txn.TheirDevice.String() { hasTheirDeviceKey = true } } slices.Sort(keyIDs) - expectedKeyMAC, err := vh.verificationMACHKDF(txn, txn.TheirUserID, txn.TheirDeviceID, vh.client.UserID, vh.client.DeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) + expectedKeyMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeSASMismatch, "failed to calculate key list MAC: %w", err) return @@ -621,8 +610,8 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific var key string var theirDevice *id.Device - if kID == txn.TheirDeviceID.String() { - theirDevice, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + if kID == txn.TheirDevice.String() { + theirDevice, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to fetch their device: %w", err) return @@ -641,7 +630,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific key = crossSigningKeys.MasterKey.String() } - expectedMAC, err := vh.verificationMACHKDF(txn, txn.TheirUserID, txn.TheirDeviceID, vh.client.UserID, vh.client.DeviceID, keyID.String(), key) + expectedMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, keyID.String(), key) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to calculate key MAC: %w", err) return @@ -652,9 +641,9 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific } // Trust their device - if kID == txn.TheirDeviceID.String() { + if kID == txn.TheirDevice.String() { theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to update device trust state after verifying: %w", err) return @@ -665,7 +654,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific txn.ReceivedTheirMAC = true if txn.SentOurMAC { - txn.VerificationState = VerificationStateSASMACExchanged + txn.VerificationState = verificationStateSASMACExchanged err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to send verification done event: %w", err) @@ -673,8 +662,4 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific } txn.SentOurDone = true } - - if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - log.Err(err).Msg("failed to save verification transaction") - } } diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index bc424624..be8357f5 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -9,7 +9,7 @@ package verificationhelper import ( "bytes" "context" - "errors" + "crypto/ecdh" "fmt" "sync" "time" @@ -25,6 +25,86 @@ import ( "maunium.net/go/mautrix/id" ) +type verificationState int + +const ( + verificationStateRequested verificationState = iota + verificationStateReady + + verificationStateTheirQRScanned // We scanned their QR code + verificationStateOurQRScanned // They scanned our QR code + + verificationStateSASStarted // An SAS verification has been started + verificationStateSASAccepted // An SAS verification has been accepted + verificationStateSASKeysExchanged // An SAS verification has exchanged keys + verificationStateSASMACExchanged // An SAS verification has exchanged MACs +) + +func (step verificationState) String() string { + switch step { + case verificationStateRequested: + return "requested" + case verificationStateReady: + return "ready" + case verificationStateTheirQRScanned: + return "their_qr_scanned" + case verificationStateOurQRScanned: + return "our_qr_scanned" + case verificationStateSASStarted: + return "sas_started" + case verificationStateSASAccepted: + return "sas_accepted" + case verificationStateSASKeysExchanged: + return "sas_keys_exchanged" + case verificationStateSASMACExchanged: + return "sas_mac" + default: + return fmt.Sprintf("verificationStep(%d)", step) + } +} + +type verificationTransaction struct { + // RoomID is the room ID if the verification is happening in a room or + // empty if it is a to-device verification. + RoomID id.RoomID + + // VerificationState is the current step of the verification flow. + VerificationState verificationState + // TransactionID is the ID of the verification transaction. + TransactionID id.VerificationTransactionID + + // TheirDevice is the device ID of the device that either made the initial + // request or accepted our request. + TheirDevice id.DeviceID + // TheirUser is the user ID of the other user. + TheirUser id.UserID + // TheirSupportedMethods is a list of verification methods that the other + // device supports. + TheirSupportedMethods []event.VerificationMethod + + // SentToDeviceIDs is a list of devices which the initial request was sent + // to. This is only used for to-device verification requests, and is meant + // to be used to send cancellation requests to all other devices when a + // verification request is accepted via a m.key.verification.ready event. + SentToDeviceIDs []id.DeviceID + + // QRCodeSharedSecret is the shared secret that was encoded in the QR code + // that we showed. + QRCodeSharedSecret []byte + + StartedByUs bool // Whether the verification was started by us + StartEventContent *event.VerificationStartEventContent // The m.key.verification.start event content + Commitment []byte // The commitment from the m.key.verification.accept event + MACMethod event.MACMethod // The method used to calculate the MAC + EphemeralKey *ecdh.PrivateKey // The ephemeral key + EphemeralPublicKeyShared bool // Whether this device's ephemeral public key has been shared + OtherPublicKey *ecdh.PublicKey // The other device's ephemeral public key + ReceivedTheirMAC bool // Whether we have received their MAC + SentOurMAC bool // Whether we have sent our MAC + ReceivedTheirDone bool // Whether we have received their done event + SentOurDone bool // Whether we have sent our done event +} + // RequiredCallbacks is an interface representing the callbacks required for // the [VerificationHelper]. type RequiredCallbacks interface { @@ -65,9 +145,8 @@ type VerificationHelper struct { client *mautrix.Client mach *crypto.OlmMachine - store VerificationStore + activeTransactions map[id.VerificationTransactionID]*verificationTransaction activeTransactionsLock sync.Mutex - // activeTransactions map[id.VerificationTransactionID]*verificationTransaction // supportedMethods are the methods that *we* support supportedMethods []event.VerificationMethod @@ -84,19 +163,15 @@ type VerificationHelper struct { var _ mautrix.VerificationHelper = (*VerificationHelper)(nil) -func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsScan bool) *VerificationHelper { +func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, callbacks any, supportsScan bool) *VerificationHelper { if client.Crypto == nil { panic("client.Crypto is nil") } - if store == nil { - store = NewInMemoryVerificationStore() - } - helper := VerificationHelper{ - client: client, - mach: mach, - store: store, + client: client, + mach: mach, + activeTransactions: map[id.VerificationTransactionID]*verificationTransaction{}, } if c, ok := callbacks.(RequiredCallbacks); !ok { @@ -158,7 +233,7 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { // Wrapper for the event handlers to check that the transaction ID is known // and ignore the event if it isn't. - wrapHandler := func(callback func(context.Context, VerificationTransaction, *event.Event)) func(context.Context, *event.Event) { + wrapHandler := func(callback func(context.Context, *verificationTransaction, *event.Event)) func(context.Context, *event.Event) { return func(ctx context.Context, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "check transaction ID"). @@ -182,11 +257,8 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { log = log.With().Stringer("transaction_id", transactionID).Logger() vh.activeTransactionsLock.Lock() - txn, err := vh.store.GetVerificationTransaction(ctx, transactionID) - if err != nil && errors.Is(err, ErrUnknownVerificationTransaction) { - log.Err(err).Msg("failed to get verification transaction") - return - } else if errors.Is(err, ErrUnknownVerificationTransaction) { + txn, ok := vh.activeTransactions[transactionID] + if !ok { // If it's a cancellation event for an unknown transaction, we // can just ignore it. if evt.Type == event.ToDeviceVerificationCancel || evt.Type == event.InRoomVerificationCancel { @@ -199,9 +271,9 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { // We have to create a fake transaction so that the call to // cancelVerificationTxn works. - txn = VerificationTransaction{ - RoomID: evt.RoomID, - TheirUserID: evt.Sender, + txn = &verificationTransaction{ + RoomID: evt.RoomID, + TheirUser: evt.Sender, } if transactionable, ok := evt.Content.Parsed.(event.VerificationTransactionable); ok { txn.TransactionID = transactionable.GetTransactionID() @@ -209,7 +281,7 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { txn.TransactionID = id.VerificationTransactionID(evt.ID) } if fromDevice, ok := evt.Content.Raw["from_device"]; ok { - txn.TheirDeviceID = id.DeviceID(fromDevice.(string)) + txn.TheirDevice = id.DeviceID(fromDevice.(string)) } // Send a cancellation event. @@ -250,11 +322,7 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { syncer.OnEventType(event.InRoomVerificationKey, wrapHandler(vh.onVerificationKey)) // SAS syncer.OnEventType(event.InRoomVerificationMAC, wrapHandler(vh.onVerificationMAC)) // SAS - allTransactions, err := vh.store.GetAllVerificationTransactions(ctx) - for _, txn := range allTransactions { - vh.expireTransactionAt(txn.TransactionID, txn.ExpirationTime.Time) - } - return err + return nil } // StartVerification starts an interactive verification flow with the given @@ -314,12 +382,13 @@ func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserI vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - return txnID, vh.store.SaveVerificationTransaction(ctx, VerificationTransaction{ - VerificationState: VerificationStateRequested, + vh.activeTransactions[txnID] = &verificationTransaction{ + VerificationState: verificationStateRequested, TransactionID: txnID, - TheirUserID: to, + TheirUser: to, SentToDeviceIDs: maps.Keys(devices), - }) + } + return txnID, nil } // StartInRoomVerification starts an interactive verification flow with the @@ -353,12 +422,13 @@ func (vh *VerificationHelper) StartInRoomVerification(ctx context.Context, roomI vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - return txnID, vh.store.SaveVerificationTransaction(ctx, VerificationTransaction{ + vh.activeTransactions[txnID] = &verificationTransaction{ RoomID: roomID, - VerificationState: VerificationStateRequested, + VerificationState: verificationStateRequested, TransactionID: txnID, - TheirUserID: to, - }) + TheirUser: to, + } + return txnID, nil } // AcceptVerification accepts a verification request. The transaction ID should @@ -370,10 +440,10 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V Stringer("transaction_id", txnID). Logger() - txn, err := vh.store.GetVerificationTransaction(ctx, txnID) - if err != nil { - return err - } else if txn.VerificationState != VerificationStateRequested { + txn, ok := vh.activeTransactions[txnID] + if !ok { + return fmt.Errorf("unknown transaction ID") + } else if txn.VerificationState != verificationStateRequested { return fmt.Errorf("transaction is not in the requested state") } @@ -402,11 +472,11 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V FromDevice: vh.client.DeviceID, Methods: maps.Keys(supportedMethods), } - err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationReady, readyEvt) + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationReady, readyEvt) if err != nil { return err } - txn.VerificationState = VerificationStateReady + txn.VerificationState = verificationStateReady if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { vh.scanQRCode(ctx, txn.TransactionID) @@ -422,7 +492,8 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V func (vh *VerificationHelper) DismissVerification(ctx context.Context, txnID id.VerificationTransactionID) error { vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - return vh.store.DeleteVerification(ctx, txnID) + delete(vh.activeTransactions, txnID) + return nil } // DismissVerification cancels the verification request with the given @@ -433,9 +504,9 @@ func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.V vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(ctx, txnID) - if err != nil { - return err + txn, ok := vh.activeTransactions[txnID] + if !ok { + return fmt.Errorf("unknown transaction ID") } log := vh.getLog(ctx).With(). Str("verification_action", "cancel verification"). @@ -456,28 +527,29 @@ func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.V } else { cancelEvt.SetTransactionID(txn.TransactionID) req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{ - txn.TheirUserID: {}, + txn.TheirUser: {}, }} - if len(txn.TheirDeviceID) > 0 { + if len(txn.TheirDevice) > 0 { // Send the cancellation event to only the device that accepted the // verification request. All of the other devices already received a // cancellation event with code "m.acceped". - req.Messages[txn.TheirUserID][txn.TheirDeviceID] = &event.Content{Parsed: cancelEvt} + req.Messages[txn.TheirUser][txn.TheirDevice] = &event.Content{Parsed: cancelEvt} } else { // Send the cancellation event to all of the devices that we sent the // request to. for _, deviceID := range txn.SentToDeviceIDs { if deviceID != vh.client.DeviceID { - req.Messages[txn.TheirUserID][deviceID] = &event.Content{Parsed: cancelEvt} + req.Messages[txn.TheirUser][deviceID] = &event.Content{Parsed: cancelEvt} } } } _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { - return fmt.Errorf("failed to send m.key.verification.cancel event to %v: %w", maps.Keys(req.Messages[txn.TheirUserID]), err) + return fmt.Errorf("failed to send m.key.verification.cancel event to %v: %w", maps.Keys(req.Messages[txn.TheirUser]), err) } } - return vh.store.DeleteVerification(ctx, txn.TransactionID) + delete(vh.activeTransactions, txn.TransactionID) + return nil } // sendVerificationEvent sends a verification event to the other user's device @@ -489,7 +561,7 @@ func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.V // [event.VerificationTransactionable]. // - evtType can be either the to-device or in-room version of the event type // as it is always stringified. -func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn VerificationTransaction, evtType event.Type, content any) error { +func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn *verificationTransaction, evtType event.Type, content any) error { if txn.RoomID != "" { content.(event.Relatable).SetRelatesTo(&event.RelatesTo{Type: event.RelReference, EventID: id.EventID(txn.TransactionID)}) _, err := vh.client.SendMessageEvent(ctx, txn.RoomID, evtType, &event.Content{ @@ -501,13 +573,13 @@ func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn Ver } else { content.(event.VerificationTransactionable).SetTransactionID(txn.TransactionID) req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{ - txn.TheirUserID: { - txn.TheirDeviceID: &event.Content{Parsed: content}, + txn.TheirUser: { + txn.TheirDevice: &event.Content{Parsed: content}, }, }} _, err := vh.client.SendToDevice(ctx, evtType, &req) if err != nil { - return fmt.Errorf("failed to send %s event to %s: %w", evtType.String(), txn.TheirDeviceID, err) + return fmt.Errorf("failed to send %s event to %s: %w", evtType.String(), txn.TheirDevice, err) } } return nil @@ -519,7 +591,7 @@ func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn Ver // directly to expose the error to its caller). // // Must always be called with the activeTransactionsLock held. -func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn VerificationTransaction, code event.VerificationCancelCode, reasonFmtStr string, fmtArgs ...any) error { +func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn *verificationTransaction, code event.VerificationCancelCode, reasonFmtStr string, fmtArgs ...any) error { log := vh.getLog(ctx) reason := fmt.Errorf(reasonFmtStr, fmtArgs...).Error() log.Info(). @@ -533,9 +605,7 @@ func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn Ver log.Err(err).Msg("failed to send cancellation event") return fmt.Errorf("failed to send cancel verification event (code: %s, reason: %s): %w", code, reason, err) } - if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { - log.Err(err).Msg("deleting verification failed") - } + delete(vh.activeTransactions, txn.TransactionID) vh.verificationCancelledCallback(ctx, txn.TransactionID, code, reason) return fmt.Errorf("verification cancelled (code: %s): %s", code, reason) } @@ -614,58 +684,54 @@ func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *ev } vh.activeTransactionsLock.Lock() - newTxn := VerificationTransaction{ - ExpirationTime: jsontime.UnixMilli{Time: verificationRequest.Timestamp.Add(time.Minute * 10)}, + newTxn := &verificationTransaction{ RoomID: evt.RoomID, - VerificationState: VerificationStateRequested, + VerificationState: verificationStateRequested, TransactionID: verificationRequest.TransactionID, - TheirDeviceID: verificationRequest.FromDevice, - TheirUserID: evt.Sender, + TheirDevice: verificationRequest.FromDevice, + TheirUser: evt.Sender, TheirSupportedMethods: verificationRequest.Methods, } - if txn, err := vh.store.FindVerificationTransactionForUserDevice(ctx, evt.Sender, verificationRequest.FromDevice); err != nil && !errors.Is(err, ErrUnknownVerificationTransaction) { - log.Err(err).Stringer("sender", evt.Sender).Stringer("device_id", verificationRequest.FromDevice).Msg("failed to find verification transaction") - vh.activeTransactionsLock.Unlock() - return - } else if !errors.Is(err, ErrUnknownVerificationTransaction) { - if txn.TransactionID == verificationRequest.TransactionID { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received a new verification request for the same transaction ID") - } else { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") + for existingTxnID, existingTxn := range vh.activeTransactions { + if existingTxn.TheirUser == evt.Sender && existingTxn.TheirDevice == verificationRequest.FromDevice && existingTxnID != verificationRequest.TransactionID { + vh.cancelVerificationTxn(ctx, existingTxn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") vh.cancelVerificationTxn(ctx, newTxn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") + delete(vh.activeTransactions, existingTxnID) + vh.activeTransactionsLock.Unlock() + return + } + + if existingTxnID == verificationRequest.TransactionID { + vh.cancelVerificationTxn(ctx, existingTxn, event.VerificationCancelCodeUnexpectedMessage, "received a new verification request for the same transaction ID") + delete(vh.activeTransactions, existingTxnID) + vh.activeTransactionsLock.Unlock() + return } - vh.activeTransactionsLock.Unlock() - return - } - if err := vh.store.SaveVerificationTransaction(ctx, newTxn); err != nil { - log.Err(err).Msg("failed to save verification transaction") } + vh.activeTransactions[verificationRequest.TransactionID] = newTxn vh.activeTransactionsLock.Unlock() vh.expireTransactionAt(verificationRequest.TransactionID, verificationRequest.Timestamp.Add(time.Minute*10)) vh.verificationRequested(ctx, verificationRequest.TransactionID, evt.Sender) } -func (vh *VerificationHelper) expireTransactionAt(txnID id.VerificationTransactionID, expiresAt time.Time) { +func (vh *VerificationHelper) expireTransactionAt(txnID id.VerificationTransactionID, expireAt time.Time) { go func() { - time.Sleep(time.Until(expiresAt)) + time.Sleep(time.Until(expireAt)) vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, err := vh.store.GetVerificationTransaction(context.Background(), txnID) - if err == ErrUnknownVerificationTransaction { - // Already deleted, nothing to expire + txn, ok := vh.activeTransactions[txnID] + if !ok { return - } else if err != nil { - vh.getLog(context.Background()).Err(err).Msg("failed to get verification transaction to expire") - } else { - vh.cancelVerificationTxn(context.Background(), txn, event.VerificationCancelCodeTimeout, "verification timed out") } + + vh.cancelVerificationTxn(context.Background(), txn, event.VerificationCancelCodeTimeout, "verification timed out") }() } -func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *verificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "verification ready"). Logger() @@ -673,7 +739,7 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != VerificationStateRequested { + if txn.VerificationState != verificationStateRequested { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "verification ready event received for a transaction that is not in the requested state") return } @@ -681,12 +747,12 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif readyEvt := evt.Content.AsVerificationReady() // Update the transaction state. - txn.VerificationState = VerificationStateReady - txn.TheirDeviceID = readyEvt.FromDevice + txn.VerificationState = verificationStateReady + txn.TheirDevice = readyEvt.FromDevice txn.TheirSupportedMethods = readyEvt.Methods log.Info(). - Stringer("their_device_id", txn.TheirDeviceID). + Stringer("their_device_id", txn.TheirDevice). Any("their_supported_methods", txn.TheirSupportedMethods). Msg("Received verification ready event") @@ -700,16 +766,16 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif Reason: "The verification was accepted on another device.", }, } - req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUserID: {}}} + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} for _, deviceID := range txn.SentToDeviceIDs { - if deviceID == txn.TheirDeviceID || deviceID == vh.client.DeviceID { + if deviceID == txn.TheirDevice || deviceID == vh.client.DeviceID { // Don't ever send a cancellation to the device that accepted // the request or to our own device (which can happen if this // is a self-verification). continue } - req.Messages[txn.TheirUserID][deviceID] = content + req.Messages[txn.TheirUser][deviceID] = content } _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { @@ -721,17 +787,18 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif vh.scanQRCode(ctx, txn.TransactionID) } - if err := vh.generateAndShowQRCode(ctx, txn); err != nil { + err := vh.generateAndShowQRCode(ctx, txn) + if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to generate and show QR code: %w", err) } } -func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *verificationTransaction, evt *event.Event) { startEvt := evt.Content.AsVerificationStart() log := vh.getLog(ctx).With(). Str("verification_action", "verification start"). Str("method", string(startEvt.Method)). - Stringer("their_device_id", txn.TheirDeviceID). + Stringer("their_device_id", txn.TheirDevice). Any("their_supported_methods", txn.TheirSupportedMethods). Bool("started_by_us", txn.StartedByUs). Logger() @@ -741,7 +808,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState == VerificationStateSASStarted || txn.VerificationState == VerificationStateOurQRScanned || txn.VerificationState == VerificationStateTheirQRScanned { + if txn.VerificationState == verificationStateSASStarted || txn.VerificationState == verificationStateOurQRScanned || txn.VerificationState == verificationStateTheirQRScanned { // We might have sent the event, and they also sent an event. if txn.StartEventContent == nil || !txn.StartedByUs { // We didn't sent a start event yet, so we have gotten ourselves @@ -773,12 +840,12 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif return } - if txn.TheirUserID < vh.client.UserID || (txn.TheirUserID == vh.client.UserID && txn.TheirDeviceID < vh.client.DeviceID) { + if txn.TheirUser < vh.client.UserID || (txn.TheirUser == vh.client.UserID && txn.TheirDevice < vh.client.DeviceID) { log.Debug().Msg("Using their start event instead of ours because they are alphabetically before us") txn.StartedByUs = false txn.StartEventContent = startEvt } - } else if txn.VerificationState != VerificationStateReady { + } else if txn.VerificationState != verificationStateReady { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got start event for transaction that is not in ready state") return } @@ -786,7 +853,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif switch startEvt.Method { case event.VerificationMethodSAS: log.Info().Msg("Received SAS start event") - txn.VerificationState = VerificationStateSASStarted + txn.VerificationState = verificationStateSASStarted if err := vh.onVerificationStartSAS(ctx, txn, evt); err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to handle SAS verification start: %w", err) } @@ -796,11 +863,8 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "reciprocated shared secret does not match") return } - txn.VerificationState = VerificationStateOurQRScanned + txn.VerificationState = verificationStateOurQRScanned vh.qrCodeScaned(ctx, txn.TransactionID) - if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - log.Err(err).Msg("failed to save verification transaction") - } default: // Note that we should never get m.qr_code.show.v1 or m.qr_code.scan.v1 // here, since the start command for scanning and showing QR codes @@ -810,18 +874,17 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif } } -func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn VerificationTransaction, evt *event.Event) { - log := vh.getLog(ctx).With(). +func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verificationTransaction, evt *event.Event) { + vh.getLog(ctx).Info(). Str("verification_action", "done"). Stringer("transaction_id", txn.TransactionID). Bool("sent_our_done", txn.SentOurDone). - Logger() - log.Info().Msg("Verification done") + Msg("Verification done") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if !slices.Contains([]VerificationState{ - VerificationStateTheirQRScanned, VerificationStateOurQRScanned, VerificationStateSASMACExchanged, + if !slices.Contains([]verificationState{ + verificationStateTheirQRScanned, verificationStateOurQRScanned, verificationStateSASMACExchanged, }, txn.VerificationState) { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got done event for transaction that is not in QR-scanned or MAC-exchanged state") return @@ -829,16 +892,12 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn Verifi txn.ReceivedTheirDone = true if txn.SentOurDone { - if err := vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { - log.Err(err).Msg("Delete verification failed") - } + delete(vh.activeTransactions, txn.TransactionID) vh.verificationDone(ctx, txn.TransactionID) - } else if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { - log.Err(err).Msg("failed to save verification transaction") } } -func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn VerificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *verificationTransaction, evt *event.Event) { cancelEvt := evt.Content.AsVerificationCancel() log := vh.getLog(ctx).With(). Str("verification_action", "cancel"). @@ -864,7 +923,7 @@ func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn Veri // that is currently in the REQUESTED state, then we will send // cancellations to all of the devices that we sent the request to. This // will ensure that all of the clients know that the request was cancelled. - if txn.VerificationState == VerificationStateRequested && len(txn.SentToDeviceIDs) > 0 { + if txn.VerificationState == verificationStateRequested && len(txn.SentToDeviceIDs) > 0 { content := &event.Content{ Parsed: &event.VerificationCancelEventContent{ ToDeviceVerificationEvent: event.ToDeviceVerificationEvent{TransactionID: txn.TransactionID}, @@ -872,9 +931,9 @@ func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn Veri Reason: "The verification was rejected from another device.", }, } - req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUserID: {}}} + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} for _, deviceID := range txn.SentToDeviceIDs { - req.Messages[txn.TheirUserID][deviceID] = content + req.Messages[txn.TheirUser][deviceID] = content } _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { @@ -882,8 +941,6 @@ func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn Veri } } - if err := vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { - log.Err(err).Msg("Delete verification failed") - } + delete(vh.activeTransactions, txn.TransactionID) vh.verificationCancelledCallback(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) } diff --git a/crypto/verificationhelper/verificationhelper_qr_self_test.go b/crypto/verificationhelper/verificationhelper_qr_self_test.go index 937cc414..11358b88 100644 --- a/crypto/verificationhelper/verificationhelper_qr_self_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_self_test.go @@ -278,12 +278,12 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { // Emulate scanning the QR code shown by the receiving device // on the sending device. err = sendingHelper.HandleScannedQRData(ctx, receivingShownQRCodeBytes) - assert.ErrorContains(t, err, "unknown transaction ID") + assert.ErrorContains(t, err, "unknown transaction ID found in QR code") // Emulate scanning the QR code shown by the sending device on // the receiving device. err = receivingHelper.HandleScannedQRData(ctx, sendingShownQRCodeBytes) - assert.ErrorContains(t, err, "unknown transaction ID") + assert.ErrorContains(t, err, "unknown transaction ID found in QR code") } func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index d0bf2298..273042c3 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -65,11 +65,11 @@ func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockServ func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, receivingClient *mautrix.Client, sendingMachine, receivingMachine *crypto.OlmMachine) (sendingCallbacks, receivingCallbacks *allVerificationCallbacks, sendingHelper, receivingHelper *verificationhelper.VerificationHelper) { t.Helper() sendingCallbacks = newAllVerificationCallbacks() - sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, sendingCallbacks, true) + sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, sendingCallbacks, true) require.NoError(t, sendingHelper.Init(ctx)) receivingCallbacks = newAllVerificationCallbacks() - receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, receivingCallbacks, true) + receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receivingCallbacks, true) require.NoError(t, receivingHelper.Init(ctx)) return } @@ -104,7 +104,7 @@ func TestVerification_Start(t *testing.T) { addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID) addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID2) - senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, tc.callbacks, tc.supportsScan) + senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), tc.callbacks, tc.supportsScan) err := senderHelper.Init(ctx) require.NoError(t, err) @@ -151,7 +151,7 @@ func TestVerification_StartThenCancel(t *testing.T) { bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID) bystanderMachine := bystanderClient.Crypto.(*cryptohelper.CryptoHelper).Machine() - bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, nil, newAllVerificationCallbacks(), true) + bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, newAllVerificationCallbacks(), true) require.NoError(t, bystanderHelper.Init(ctx)) require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, bystanderMachine.OwnIdentity())) @@ -241,12 +241,12 @@ func TestVerification_Accept_NoSupportedMethods(t *testing.T) { assert.NotEmpty(t, recoveryKey) assert.NotNil(t, cache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, newAllVerificationCallbacks(), true) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, newAllVerificationCallbacks(), true) err = sendingHelper.Init(ctx) require.NoError(t, err) receivingCallbacks := newBaseVerificationCallbacks() - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, receivingCallbacks, false) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), receivingCallbacks, false) err = receivingHelper.Init(ctx) require.NoError(t, err) @@ -289,11 +289,11 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { assert.NotEmpty(t, recoveryKey) assert.NotNil(t, sendingCrossSigningKeysCache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, tc.sendingCallbacks, tc.sendingSupportsScan) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, tc.sendingCallbacks, tc.sendingSupportsScan) err = sendingHelper.Init(ctx) require.NoError(t, err) - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, tc.receivingCallbacks, tc.receivingSupportsScan) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, tc.receivingCallbacks, tc.receivingSupportsScan) err = receivingHelper.Init(ctx) require.NoError(t, err) diff --git a/crypto/verificationhelper/verificationstore.go b/crypto/verificationhelper/verificationstore.go deleted file mode 100644 index 725a66a6..00000000 --- a/crypto/verificationhelper/verificationstore.go +++ /dev/null @@ -1,187 +0,0 @@ -package verificationhelper - -import ( - "context" - "crypto/ecdh" - "encoding/json" - "errors" - "fmt" - - "go.mau.fi/util/jsontime" - - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" -) - -var ErrUnknownVerificationTransaction = errors.New("unknown transaction ID") - -type VerificationState int - -const ( - VerificationStateRequested VerificationState = iota - VerificationStateReady - - VerificationStateTheirQRScanned // We scanned their QR code - VerificationStateOurQRScanned // They scanned our QR code - - VerificationStateSASStarted // An SAS verification has been started - VerificationStateSASAccepted // An SAS verification has been accepted - VerificationStateSASKeysExchanged // An SAS verification has exchanged keys - VerificationStateSASMACExchanged // An SAS verification has exchanged MACs -) - -func (step VerificationState) String() string { - switch step { - case VerificationStateRequested: - return "requested" - case VerificationStateReady: - return "ready" - case VerificationStateTheirQRScanned: - return "their_qr_scanned" - case VerificationStateOurQRScanned: - return "our_qr_scanned" - case VerificationStateSASStarted: - return "sas_started" - case VerificationStateSASAccepted: - return "sas_accepted" - case VerificationStateSASKeysExchanged: - return "sas_keys_exchanged" - case VerificationStateSASMACExchanged: - return "sas_mac" - default: - return fmt.Sprintf("VerificationState(%d)", step) - } -} - -type ECDHPrivateKey struct { - *ecdh.PrivateKey -} - -func (e *ECDHPrivateKey) UnmarshalJSON(data []byte) (err error) { - e.PrivateKey, err = ecdh.P256().NewPrivateKey(data) - return -} - -func (e *ECDHPrivateKey) MarshalJSON() ([]byte, error) { - return json.Marshal(e.Bytes()) -} - -type ECDHPublicKey struct { - *ecdh.PublicKey -} - -func (e *ECDHPublicKey) UnmarshalJSON(data []byte) (err error) { - e.PublicKey, err = ecdh.P256().NewPublicKey(data) - return -} - -func (e *ECDHPublicKey) MarshalJSON() ([]byte, error) { - return json.Marshal(e.Bytes()) -} - -type VerificationTransaction struct { - ExpirationTime jsontime.UnixMilli `json:"expiration_time"` - - // RoomID is the room ID if the verification is happening in a room or - // empty if it is a to-device verification. - RoomID id.RoomID `json:"room_id"` - - // VerificationState is the current step of the verification flow. - VerificationState VerificationState `json:"verification_state"` - // TransactionID is the ID of the verification transaction. - TransactionID id.VerificationTransactionID `json:"transaction_id"` - - // TheirDeviceID is the device ID of the device that either made the - // initial request or accepted our request. - TheirDeviceID id.DeviceID `json:"their_device_id"` - // TheirUserID is the user ID of the other user. - TheirUserID id.UserID `json:"their_user_id"` - // TheirSupportedMethods is a list of verification methods that the other - // device supports. - TheirSupportedMethods []event.VerificationMethod `json:"their_supported_methods"` - - // SentToDeviceIDs is a list of devices which the initial request was sent - // to. This is only used for to-device verification requests, and is meant - // to be used to send cancellation requests to all other devices when a - // verification request is accepted via a m.key.verification.ready event. - SentToDeviceIDs []id.DeviceID `json:"sent_to_device_ids"` - - // QRCodeSharedSecret is the shared secret that was encoded in the QR code - // that we showed. - QRCodeSharedSecret []byte `json:"qr_code_shared_secret"` - - StartedByUs bool `json:"started_by_us"` // Whether the verification was started by us - StartEventContent *event.VerificationStartEventContent `json:"start_event_content"` // The m.key.verification.start event content - Commitment []byte `json:"committment"` // The commitment from the m.key.verification.accept event - MACMethod event.MACMethod `json:"mac_method"` // The method used to calculate the MAC - EphemeralKey *ECDHPrivateKey `json:"ephemeral_key"` // The ephemeral key - EphemeralPublicKeyShared bool `json:"ephemeral_public_key_shared"` // Whether this device's ephemeral public key has been shared - OtherPublicKey *ECDHPublicKey `json:"other_public_key"` // The other device's ephemeral public key - ReceivedTheirMAC bool `json:"received_their_mac"` // Whether we have received their MAC - SentOurMAC bool `json:"sent_our_mac"` // Whether we have sent our MAC - ReceivedTheirDone bool `json:"received_their_done"` // Whether we have received their done event - SentOurDone bool `json:"sent_our_done"` // Whether we have sent our done event -} - -type VerificationStore interface { - // DeleteVerification deletes a verification transaction by ID - DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) error - // GetVerificationTransaction gets a verification transaction by ID - GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (VerificationTransaction, error) - // SaveVerificationTransaction saves a verification transaction by ID - SaveVerificationTransaction(ctx context.Context, txn VerificationTransaction) error - // FindVerificationTransactionForUserDevice finds a verification - // transaction by user and device ID - FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (VerificationTransaction, error) - // GetAllVerificationTransactions returns all of the verification - // transactions. This is used to reset the cancellation timeouts. - GetAllVerificationTransactions(ctx context.Context) ([]VerificationTransaction, error) -} - -type InMemoryVerificationStore struct { - txns map[id.VerificationTransactionID]VerificationTransaction -} - -var _ VerificationStore = (*InMemoryVerificationStore)(nil) - -func NewInMemoryVerificationStore() *InMemoryVerificationStore { - return &InMemoryVerificationStore{ - txns: map[id.VerificationTransactionID]VerificationTransaction{}, - } -} - -func (i *InMemoryVerificationStore) DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) error { - if _, ok := i.txns[txnID]; !ok { - return ErrUnknownVerificationTransaction - } - delete(i.txns, txnID) - return nil -} - -func (i *InMemoryVerificationStore) GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (VerificationTransaction, error) { - if _, ok := i.txns[txnID]; !ok { - return VerificationTransaction{}, ErrUnknownVerificationTransaction - } - return i.txns[txnID], nil -} - -func (i *InMemoryVerificationStore) SaveVerificationTransaction(ctx context.Context, txn VerificationTransaction) error { - i.txns[txn.TransactionID] = txn - return nil -} - -func (i *InMemoryVerificationStore) FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (VerificationTransaction, error) { - for _, existingTxn := range i.txns { - if existingTxn.TheirUserID == userID && existingTxn.TheirDeviceID == deviceID { - return existingTxn, nil - } - } - return VerificationTransaction{}, ErrUnknownVerificationTransaction -} - -func (i *InMemoryVerificationStore) GetAllVerificationTransactions(ctx context.Context) (txns []VerificationTransaction, err error) { - for _, txn := range i.txns { - txns = append(txns, txn) - } - return -} From 4820d4da48dd2edfd07eb4d2424955af357ccc26 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 25 Nov 2024 23:49:21 +0200 Subject: [PATCH 0929/1647] client: use timeout 0 for initial sync --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index e931fa66..f68cae63 100644 --- a/client.go +++ b/client.go @@ -240,7 +240,7 @@ func (cli *Client) SyncWithContext(ctx context.Context) error { streamResp = true } timeout := 30000 - if isFailing { + if isFailing || nextBatch == "" { timeout = 0 } resSync, err := cli.FullSyncRequest(ctx, ReqSync{ From 2353d323a4fef45d071b040aab504dcbce9fd99f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 26 Nov 2024 13:28:26 +0200 Subject: [PATCH 0930/1647] bridgev2/legacymigrate: log portal info in post-migration --- bridgev2/matrix/mxmain/legacymigrate.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index 18880027..d33dd8cd 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -221,6 +221,11 @@ func (br *BridgeMain) PostMigrate(ctx context.Context) error { return fmt.Errorf("failed to get all portals: %w", err) } for _, portal := range portals { + zerolog.Ctx(ctx).Debug(). + Stringer("room_id", portal.MXID). + Object("portal_key", portal.PortalKey). + Str("room_type", string(portal.RoomType)). + Msg("Migrating portal") switch portal.RoomType { case database.RoomTypeDM: err = br.postMigrateDMPortal(ctx, portal) From 0384e800fd3ee1a9eb2c08d58c3e6bee12a7481a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 26 Nov 2024 14:50:02 +0200 Subject: [PATCH 0931/1647] bridgev2/login: add url and domain user input types --- bridgev2/login.go | 2 ++ bridgev2/matrix/provisioning.yaml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/bridgev2/login.go b/bridgev2/login.go index 7acccd9a..9e8be655 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -159,6 +159,8 @@ const ( LoginInputFieldTypeEmail LoginInputFieldType = "email" LoginInputFieldType2FACode LoginInputFieldType = "2fa_code" LoginInputFieldTypeToken LoginInputFieldType = "token" + LoginInputFieldTypeURL LoginInputFieldType = "url" + LoginInputFieldTypeDomain LoginInputFieldType = "domain" ) type LoginInputDataField struct { diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml index bd9217c8..bf6c6f3d 100644 --- a/bridgev2/matrix/provisioning.yaml +++ b/bridgev2/matrix/provisioning.yaml @@ -635,7 +635,7 @@ components: type: type: string description: The type of field. - enum: [ username, phone_number, email, password, 2fa_code, token ] + enum: [ username, phone_number, email, password, 2fa_code, token, url, domain ] id: type: string description: The internal ID of the field. This must be used as the key in the object when submitting the data back to the bridge. From 15da5bed5244b2d3ff53263f3c13e36cd7353231 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Sun, 3 Nov 2024 08:58:53 -0700 Subject: [PATCH 0932/1647] verificationhelper: save verification status in store Signed-off-by: Sumner Evans --- crypto/verificationhelper/ecdhkeys.go | 57 ++++ crypto/verificationhelper/ecdhkeys_test.go | 48 +++ crypto/verificationhelper/reciprocate.go | 67 ++-- crypto/verificationhelper/sas.go | 108 +++--- .../verificationhelper/verificationhelper.go | 312 ++++++++---------- .../verificationhelper_qr_self_test.go | 4 +- .../verificationhelper_test.go | 16 +- .../verificationhelper/verificationstore.go | 159 +++++++++ 8 files changed, 503 insertions(+), 268 deletions(-) create mode 100644 crypto/verificationhelper/ecdhkeys.go create mode 100644 crypto/verificationhelper/ecdhkeys_test.go create mode 100644 crypto/verificationhelper/verificationstore.go diff --git a/crypto/verificationhelper/ecdhkeys.go b/crypto/verificationhelper/ecdhkeys.go new file mode 100644 index 00000000..754530ed --- /dev/null +++ b/crypto/verificationhelper/ecdhkeys.go @@ -0,0 +1,57 @@ +package verificationhelper + +import ( + "crypto/ecdh" + "encoding/json" +) + +type ECDHPrivateKey struct { + *ecdh.PrivateKey +} + +func (e *ECDHPrivateKey) UnmarshalJSON(data []byte) (err error) { + if len(data) == 0 { + return nil + } + var raw []byte + err = json.Unmarshal(data, &raw) + if err != nil { + return + } + if len(raw) == 0 { + return nil + } + e.PrivateKey, err = ecdh.X25519().NewPrivateKey(raw) + return err +} + +func (e ECDHPrivateKey) MarshalJSON() ([]byte, error) { + if e.PrivateKey == nil { + return json.Marshal(nil) + } + return json.Marshal(e.Bytes()) +} + +type ECDHPublicKey struct { + *ecdh.PublicKey +} + +func (e *ECDHPublicKey) UnmarshalJSON(data []byte) (err error) { + if len(data) == 0 { + return nil + } + var raw []byte + err = json.Unmarshal(data, &raw) + if err != nil { + return + } + if len(raw) == 0 { + return nil + } + e.PublicKey, err = ecdh.X25519().NewPublicKey(raw) + return +} + +func (e ECDHPublicKey) MarshalJSON() ([]byte, error) { + return json.Marshal(e.Bytes()) +} diff --git a/crypto/verificationhelper/ecdhkeys_test.go b/crypto/verificationhelper/ecdhkeys_test.go new file mode 100644 index 00000000..109fbf88 --- /dev/null +++ b/crypto/verificationhelper/ecdhkeys_test.go @@ -0,0 +1,48 @@ +package verificationhelper_test + +import ( + "crypto/ecdh" + "crypto/rand" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/crypto/verificationhelper" +) + +func TestECDHPrivateKey(t *testing.T) { + pk, err := ecdh.X25519().GenerateKey(rand.Reader) + require.NoError(t, err) + private := verificationhelper.ECDHPrivateKey{pk} + marshalled, err := json.Marshal(private) + require.NoError(t, err) + + assert.Len(t, marshalled, 46) + + var unmarshalled verificationhelper.ECDHPrivateKey + err = json.Unmarshal(marshalled, &unmarshalled) + require.NoError(t, err) + + assert.True(t, private.Equal(unmarshalled.PrivateKey)) +} + +func TestECDHPublicKey(t *testing.T) { + private, err := ecdh.X25519().GenerateKey(rand.Reader) + require.NoError(t, err) + + public := private.PublicKey() + + pub := verificationhelper.ECDHPublicKey{public} + marshalled, err := json.Marshal(pub) + require.NoError(t, err) + + assert.Len(t, marshalled, 46) + + var unmarshalled verificationhelper.ECDHPublicKey + err = json.Unmarshal(marshalled, &unmarshalled) + require.NoError(t, err) + + assert.True(t, public.Equal(unmarshalled.PublicKey)) +} diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index 21276218..5c38c655 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -35,13 +35,13 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, ok := vh.activeTransactions[qrCode.TransactionID] - if !ok { - return fmt.Errorf("unknown transaction ID found in QR code") - } else if txn.VerificationState != verificationStateReady { + txn, err := vh.store.GetVerificationTransaction(ctx, qrCode.TransactionID) + if err != nil { + return fmt.Errorf("failed to get transaction %s: %w", qrCode.TransactionID, err) + } else if txn.VerificationState != VerificationStateReady { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "transaction found in the QR code is not in the ready state") } - txn.VerificationState = verificationStateTheirQRScanned + txn.VerificationState = VerificationStateTheirQRScanned // Verify the keys log.Info().Msg("Verifying keys from QR code") @@ -53,9 +53,9 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by switch qrCode.Mode { case QRCodeModeCrossSigning: - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) if err != nil { - return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUser, err) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) } if bytes.Equal(theirSigningKeys.MasterKey.Bytes(), qrCode.Key1[:]) { log.Info().Msg("Verified that the other device has the master key we expected") @@ -70,7 +70,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "the master key does not match") } - if err := vh.mach.SignUser(ctx, txn.TheirUser, theirSigningKeys.MasterKey); err != nil { + if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their master key: %w", err) } case QRCodeModeSelfVerifyingMasterKeyTrusted: @@ -78,7 +78,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // means that we don't trust the key. Key1 is the master key public // key, and Key2 is what the other device thinks our device key is. - if vh.client.UserID != txn.TheirUser { + if vh.client.UserID != txn.TheirUserID { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) } @@ -114,12 +114,12 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeMasterKeyNotTrusted, "the master key is not trusted by this device, cannot verify device that does not trust the master key") } - if vh.client.UserID != txn.TheirUser { + if vh.client.UserID != txn.TheirUserID { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "mode %d is only allowed when the other user is the same as the current user", qrCode.Mode) } // Get their device - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to get their device: %w", err) } @@ -140,7 +140,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by // Trust their device theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) if err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to update device trust state after verifying: %+v", err) } @@ -177,8 +177,12 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by txn.SentOurDone = true if txn.ReceivedTheirDone { log.Debug().Msg("We already received their done event. Setting verification state to done.") - delete(vh.activeTransactions, txn.TransactionID) + if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { + return err + } vh.verificationDone(ctx, txn.TransactionID) + } else { + return vh.store.SaveVerificationTransaction(ctx, txn) } return nil } @@ -196,28 +200,27 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, ok := vh.activeTransactions[txnID] - if !ok { - log.Warn().Msg("Ignoring QR code scan confirmation for an unknown transaction") - return nil - } else if txn.VerificationState != verificationStateOurQRScanned { + txn, err := vh.store.GetVerificationTransaction(ctx, txnID) + if err != nil { + return fmt.Errorf("failed to get transaction %s: %w", txnID, err) + } else if txn.VerificationState != VerificationStateOurQRScanned { return fmt.Errorf("transaction is not in the scanned state") } log.Info().Msg("Confirming QR code scanned") - if txn.TheirUser == vh.client.UserID { + if txn.TheirUserID == vh.client.UserID { // Self-signing situation. Trust their device. // Get their device - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { return err } // Trust their device theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) if err != nil { return fmt.Errorf("failed to update device trust state after verifying: %w", err) } @@ -231,29 +234,33 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id } } else { // Cross-signing situation. Sign their master key. - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) if err != nil { - return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUser, err) + return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) } - if err := vh.mach.SignUser(ctx, txn.TheirUser, theirSigningKeys.MasterKey); err != nil { + if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { return vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their master key: %w", err) } } - err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { return err } txn.SentOurDone = true if txn.ReceivedTheirDone { - delete(vh.activeTransactions, txn.TransactionID) + if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { + return err + } vh.verificationDone(ctx, txn.TransactionID) + } else { + return vh.store.SaveVerificationTransaction(ctx, txn) } return nil } -func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *verificationTransaction) error { +func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *VerificationTransaction) error { log := vh.getLog(ctx).With(). Str("verification_action", "generate and show QR code"). Stringer("transaction_id", txn.TransactionID). @@ -276,7 +283,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve return err } mode := QRCodeModeCrossSigning - if vh.client.UserID == txn.TheirUser { + if vh.client.UserID == txn.TheirUserID { // This is a self-signing situation. if ownMasterKeyTrusted { mode = QRCodeModeSelfVerifyingMasterKeyTrusted @@ -298,7 +305,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve key1 = ownCrossSigningPublicKeys.MasterKey.Bytes() // Key 2 is the other user's master signing key. - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUser) + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) if err != nil { return err } @@ -308,7 +315,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *ve key1 = ownCrossSigningPublicKeys.MasterKey.Bytes() // Key 2 is the other device's key. - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { return err } diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index e28ec405..1e465867 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -40,30 +40,30 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, ok := vh.activeTransactions[txnID] - if !ok { - return fmt.Errorf("unknown transaction ID") - } else if txn.VerificationState != verificationStateReady { + txn, err := vh.store.GetVerificationTransaction(ctx, txnID) + if err != nil { + return fmt.Errorf("failed to get verification transaction %s: %w", txnID, err) + } else if txn.VerificationState != VerificationStateReady { return errors.New("transaction is not in ready state") } else if txn.StartEventContent != nil { return errors.New("start event already sent or received") } - txn.VerificationState = verificationStateSASStarted + txn.VerificationState = VerificationStateSASStarted txn.StartedByUs = true if !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS) { return fmt.Errorf("the other device does not support SAS verification") } // Ensure that we have their device key. - _, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + _, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { log.Err(err).Msg("Failed to fetch device") return err } log.Info().Msg("Sending start event") - txn.StartEventContent = &event.VerificationStartEventContent{ + startEventContent := event.VerificationStartEventContent{ FromDevice: vh.client.DeviceID, Method: event.VerificationMethodSAS, @@ -78,7 +78,11 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio event.SASMethodEmoji, }, } - return vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationStart, txn.StartEventContent) + if err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationStart, &startEventContent); err != nil { + return err + } + txn.StartEventContent = &startEventContent + return vh.store.SaveVerificationTransaction(ctx, txn) } // ConfirmSAS indicates that the user has confirmed that the SAS matches SAS @@ -94,14 +98,13 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, ok := vh.activeTransactions[txnID] - if !ok { - return fmt.Errorf("unknown transaction ID") - } else if txn.VerificationState != verificationStateSASKeysExchanged { + txn, err := vh.store.GetVerificationTransaction(ctx, txnID) + if err != nil { + return fmt.Errorf("failed to get transaction %s: %w", txnID, err) + } else if txn.VerificationState != VerificationStateSASKeysExchanged { return errors.New("transaction is not in keys exchanged state") } - var err error keys := map[id.KeyID]jsonbytes.UnpaddedBytes{} log.Info().Msg("Signing keys") @@ -109,7 +112,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat // My device key myDevice := vh.mach.OwnIdentity() myDeviceKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, myDevice.DeviceID.String()) - keys[myDeviceKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, myDeviceKeyID.String(), myDevice.SigningKey.String()) + keys[myDeviceKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, myDeviceKeyID.String(), myDevice.SigningKey.String()) if err != nil { return err } @@ -118,7 +121,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat crossSigningKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) if crossSigningKeys != nil { crossSigningKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, crossSigningKeys.MasterKey.String()) - keys[crossSigningKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, crossSigningKeyID.String(), crossSigningKeys.MasterKey.String()) + keys[crossSigningKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, crossSigningKeyID.String(), crossSigningKeys.MasterKey.String()) if err != nil { return err } @@ -129,7 +132,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat keyIDs = append(keyIDs, keyID.String()) } slices.Sort(keyIDs) - keysMAC, err := vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUser, txn.TheirDevice, "KEY_IDS", strings.Join(keyIDs, ",")) + keysMAC, err := vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) if err != nil { return err } @@ -145,14 +148,14 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat txn.SentOurMAC = true if txn.ReceivedTheirMAC { - txn.VerificationState = verificationStateSASMACExchanged + txn.VerificationState = VerificationStateSASMACExchanged err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { return err } txn.SentOurDone = true } - return nil + return vh.store.SaveVerificationTransaction(ctx, txn) } // onVerificationStartSAS handles the m.key.verification.start events with @@ -160,7 +163,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat // Spec. // // [Section 11.12.2.2]: https://spec.matrix.org/v1.9/client-server-api/#short-authentication-string-sas-verification -func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *verificationTransaction, evt *event.Event) error { +func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn VerificationTransaction, evt *event.Event) error { startEvt := evt.Content.AsVerificationStart() log := vh.getLog(ctx).With(). Str("verification_action", "start_sas"). @@ -208,7 +211,7 @@ func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *v return fmt.Errorf("failed to generate ephemeral key: %w", err) } txn.MACMethod = macMethod - txn.EphemeralKey = ephemeralKey + txn.EphemeralKey = &ECDHPrivateKey{ephemeralKey} txn.StartEventContent = startEvt commitment, err := calculateCommitment(ephemeralKey.PublicKey(), startEvt) @@ -226,8 +229,8 @@ func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn *v if err != nil { return fmt.Errorf("failed to send accept event: %w", err) } - txn.VerificationState = verificationStateSASAccepted - return nil + txn.VerificationState = VerificationStateSASAccepted + return vh.store.SaveVerificationTransaction(ctx, txn) } func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, startEvt *event.VerificationStartEventContent) ([]byte, error) { @@ -252,7 +255,7 @@ func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, startEvt *event.Verifi // event. This follows Step 4 of [Section 11.12.2.2] of the Spec. // // [Section 11.12.2.2]: https://spec.matrix.org/v1.9/client-server-api/#short-authentication-string-sas-verification -func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn VerificationTransaction, evt *event.Event) { acceptEvt := evt.Content.AsVerificationAccept() log := vh.getLog(ctx).With(). Str("verification_action", "accept"). @@ -267,7 +270,7 @@ func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *ver vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != verificationStateSASStarted { + if txn.VerificationState != VerificationStateSASStarted { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received accept event for a transaction that is not in the started state") return @@ -287,14 +290,18 @@ func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn *ver return } - txn.VerificationState = verificationStateSASAccepted + txn.VerificationState = VerificationStateSASAccepted txn.MACMethod = acceptEvt.MessageAuthenticationCode txn.Commitment = acceptEvt.Commitment - txn.EphemeralKey = ephemeralKey + txn.EphemeralKey = &ECDHPrivateKey{ephemeralKey} txn.EphemeralPublicKeyShared = true + + if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + log.Err(err).Msg("failed to save verification transaction") + } } -func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn VerificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "key"). Logger() @@ -302,22 +309,23 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verifi vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != verificationStateSASAccepted { + if txn.VerificationState != VerificationStateSASAccepted { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received key event for a transaction that is not in the accepted state") return } var err error - txn.OtherPublicKey, err = ecdh.X25519().NewPublicKey(keyEvt.Key) + publicKey, err := ecdh.X25519().NewPublicKey(keyEvt.Key) if err != nil { log.Err(err).Msg("Failed to generate other public key") return } + txn.OtherPublicKey = &ECDHPublicKey{publicKey} if txn.EphemeralPublicKeyShared { // Verify that the commitment hash is correct - commitment, err := calculateCommitment(txn.OtherPublicKey, txn.StartEventContent) + commitment, err := calculateCommitment(publicKey, txn.StartEventContent) if err != nil { log.Err(err).Msg("Failed to calculate commitment") return @@ -342,7 +350,7 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verifi } txn.EphemeralPublicKeyShared = true } - txn.VerificationState = verificationStateSASKeysExchanged + txn.VerificationState = VerificationStateSASKeysExchanged sasBytes, err := vh.verificationSASHKDF(txn) if err != nil { @@ -370,10 +378,14 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn *verifi } } vh.showSAS(ctx, txn.TransactionID, emojis, decimals) + + if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + log.Err(err).Msg("failed to save verification transaction") + } } -func (vh *VerificationHelper) verificationSASHKDF(txn *verificationTransaction) ([]byte, error) { - sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey) +func (vh *VerificationHelper) verificationSASHKDF(txn VerificationTransaction) ([]byte, error) { + sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey.PublicKey) if err != nil { return nil, err } @@ -388,8 +400,8 @@ func (vh *VerificationHelper) verificationSASHKDF(txn *verificationTransaction) }, "|") theirInfo := strings.Join([]string{ - txn.TheirUser.String(), - txn.TheirDevice.String(), + txn.TheirUserID.String(), + txn.TheirDeviceID.String(), base64.RawStdEncoding.EncodeToString(txn.OtherPublicKey.Bytes()), }, "|") @@ -462,8 +474,8 @@ func BrokenB64Encode(input []byte) string { return string(output) } -func (vh *VerificationHelper) verificationMACHKDF(txn *verificationTransaction, senderUser id.UserID, senderDevice id.DeviceID, receivingUser id.UserID, receivingDevice id.DeviceID, keyID, key string) ([]byte, error) { - sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey) +func (vh *VerificationHelper) verificationMACHKDF(txn VerificationTransaction, senderUser id.UserID, senderDevice id.DeviceID, receivingUser id.UserID, receivingDevice id.DeviceID, keyID, key string) ([]byte, error) { + sharedSecret, err := txn.EphemeralKey.ECDH(txn.OtherPublicKey.PublicKey) if err != nil { return nil, err } @@ -563,7 +575,7 @@ var allEmojis = []rune{ '📌', } -func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn VerificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "mac"). Logger() @@ -579,12 +591,12 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi for keyID := range macEvt.MAC { keyIDs = append(keyIDs, keyID.String()) _, kID := keyID.Parse() - if kID == txn.TheirDevice.String() { + if kID == txn.TheirDeviceID.String() { hasTheirDeviceKey = true } } slices.Sort(keyIDs) - expectedKeyMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) + expectedKeyMAC, err := vh.verificationMACHKDF(txn, txn.TheirUserID, txn.TheirDeviceID, vh.client.UserID, vh.client.DeviceID, "KEY_IDS", strings.Join(keyIDs, ",")) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeSASMismatch, "failed to calculate key list MAC: %w", err) return @@ -610,8 +622,8 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi var key string var theirDevice *id.Device - if kID == txn.TheirDevice.String() { - theirDevice, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUser, txn.TheirDevice) + if kID == txn.TheirDeviceID.String() { + theirDevice, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to fetch their device: %w", err) return @@ -630,7 +642,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi key = crossSigningKeys.MasterKey.String() } - expectedMAC, err := vh.verificationMACHKDF(txn, txn.TheirUser, txn.TheirDevice, vh.client.UserID, vh.client.DeviceID, keyID.String(), key) + expectedMAC, err := vh.verificationMACHKDF(txn, txn.TheirUserID, txn.TheirDeviceID, vh.client.UserID, vh.client.DeviceID, keyID.String(), key) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to calculate key MAC: %w", err) return @@ -641,9 +653,9 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi } // Trust their device - if kID == txn.TheirDevice.String() { + if kID == txn.TheirDeviceID.String() { theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUser, theirDevice) + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to update device trust state after verifying: %w", err) return @@ -654,7 +666,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi txn.ReceivedTheirMAC = true if txn.SentOurMAC { - txn.VerificationState = verificationStateSASMACExchanged + txn.VerificationState = VerificationStateSASMACExchanged err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to send verification done event: %w", err) @@ -662,4 +674,8 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn *verifi } txn.SentOurDone = true } + + if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + log.Err(err).Msg("failed to save verification transaction") + } } diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index be8357f5..4c10827c 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -9,7 +9,7 @@ package verificationhelper import ( "bytes" "context" - "crypto/ecdh" + "errors" "fmt" "sync" "time" @@ -25,86 +25,6 @@ import ( "maunium.net/go/mautrix/id" ) -type verificationState int - -const ( - verificationStateRequested verificationState = iota - verificationStateReady - - verificationStateTheirQRScanned // We scanned their QR code - verificationStateOurQRScanned // They scanned our QR code - - verificationStateSASStarted // An SAS verification has been started - verificationStateSASAccepted // An SAS verification has been accepted - verificationStateSASKeysExchanged // An SAS verification has exchanged keys - verificationStateSASMACExchanged // An SAS verification has exchanged MACs -) - -func (step verificationState) String() string { - switch step { - case verificationStateRequested: - return "requested" - case verificationStateReady: - return "ready" - case verificationStateTheirQRScanned: - return "their_qr_scanned" - case verificationStateOurQRScanned: - return "our_qr_scanned" - case verificationStateSASStarted: - return "sas_started" - case verificationStateSASAccepted: - return "sas_accepted" - case verificationStateSASKeysExchanged: - return "sas_keys_exchanged" - case verificationStateSASMACExchanged: - return "sas_mac" - default: - return fmt.Sprintf("verificationStep(%d)", step) - } -} - -type verificationTransaction struct { - // RoomID is the room ID if the verification is happening in a room or - // empty if it is a to-device verification. - RoomID id.RoomID - - // VerificationState is the current step of the verification flow. - VerificationState verificationState - // TransactionID is the ID of the verification transaction. - TransactionID id.VerificationTransactionID - - // TheirDevice is the device ID of the device that either made the initial - // request or accepted our request. - TheirDevice id.DeviceID - // TheirUser is the user ID of the other user. - TheirUser id.UserID - // TheirSupportedMethods is a list of verification methods that the other - // device supports. - TheirSupportedMethods []event.VerificationMethod - - // SentToDeviceIDs is a list of devices which the initial request was sent - // to. This is only used for to-device verification requests, and is meant - // to be used to send cancellation requests to all other devices when a - // verification request is accepted via a m.key.verification.ready event. - SentToDeviceIDs []id.DeviceID - - // QRCodeSharedSecret is the shared secret that was encoded in the QR code - // that we showed. - QRCodeSharedSecret []byte - - StartedByUs bool // Whether the verification was started by us - StartEventContent *event.VerificationStartEventContent // The m.key.verification.start event content - Commitment []byte // The commitment from the m.key.verification.accept event - MACMethod event.MACMethod // The method used to calculate the MAC - EphemeralKey *ecdh.PrivateKey // The ephemeral key - EphemeralPublicKeyShared bool // Whether this device's ephemeral public key has been shared - OtherPublicKey *ecdh.PublicKey // The other device's ephemeral public key - ReceivedTheirMAC bool // Whether we have received their MAC - SentOurMAC bool // Whether we have sent our MAC - ReceivedTheirDone bool // Whether we have received their done event - SentOurDone bool // Whether we have sent our done event -} - // RequiredCallbacks is an interface representing the callbacks required for // the [VerificationHelper]. type RequiredCallbacks interface { @@ -145,8 +65,9 @@ type VerificationHelper struct { client *mautrix.Client mach *crypto.OlmMachine - activeTransactions map[id.VerificationTransactionID]*verificationTransaction + store VerificationStore activeTransactionsLock sync.Mutex + // activeTransactions map[id.VerificationTransactionID]*verificationTransaction // supportedMethods are the methods that *we* support supportedMethods []event.VerificationMethod @@ -163,15 +84,19 @@ type VerificationHelper struct { var _ mautrix.VerificationHelper = (*VerificationHelper)(nil) -func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, callbacks any, supportsScan bool) *VerificationHelper { +func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsScan bool) *VerificationHelper { if client.Crypto == nil { panic("client.Crypto is nil") } + if store == nil { + store = NewInMemoryVerificationStore() + } + helper := VerificationHelper{ - client: client, - mach: mach, - activeTransactions: map[id.VerificationTransactionID]*verificationTransaction{}, + client: client, + mach: mach, + store: store, } if c, ok := callbacks.(RequiredCallbacks); !ok { @@ -233,7 +158,7 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { // Wrapper for the event handlers to check that the transaction ID is known // and ignore the event if it isn't. - wrapHandler := func(callback func(context.Context, *verificationTransaction, *event.Event)) func(context.Context, *event.Event) { + wrapHandler := func(callback func(context.Context, VerificationTransaction, *event.Event)) func(context.Context, *event.Event) { return func(ctx context.Context, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "check transaction ID"). @@ -257,8 +182,11 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { log = log.With().Stringer("transaction_id", transactionID).Logger() vh.activeTransactionsLock.Lock() - txn, ok := vh.activeTransactions[transactionID] - if !ok { + txn, err := vh.store.GetVerificationTransaction(ctx, transactionID) + if err != nil && errors.Is(err, ErrUnknownVerificationTransaction) { + log.Err(err).Msg("failed to get verification transaction") + return + } else if errors.Is(err, ErrUnknownVerificationTransaction) { // If it's a cancellation event for an unknown transaction, we // can just ignore it. if evt.Type == event.ToDeviceVerificationCancel || evt.Type == event.InRoomVerificationCancel { @@ -271,9 +199,9 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { // We have to create a fake transaction so that the call to // cancelVerificationTxn works. - txn = &verificationTransaction{ - RoomID: evt.RoomID, - TheirUser: evt.Sender, + txn = VerificationTransaction{ + RoomID: evt.RoomID, + TheirUserID: evt.Sender, } if transactionable, ok := evt.Content.Parsed.(event.VerificationTransactionable); ok { txn.TransactionID = transactionable.GetTransactionID() @@ -281,7 +209,7 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { txn.TransactionID = id.VerificationTransactionID(evt.ID) } if fromDevice, ok := evt.Content.Raw["from_device"]; ok { - txn.TheirDevice = id.DeviceID(fromDevice.(string)) + txn.TheirDeviceID = id.DeviceID(fromDevice.(string)) } // Send a cancellation event. @@ -322,7 +250,11 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { syncer.OnEventType(event.InRoomVerificationKey, wrapHandler(vh.onVerificationKey)) // SAS syncer.OnEventType(event.InRoomVerificationMAC, wrapHandler(vh.onVerificationMAC)) // SAS - return nil + allTransactions, err := vh.store.GetAllVerificationTransactions(ctx) + for _, txn := range allTransactions { + vh.expireTransactionAt(txn.TransactionID, txn.ExpirationTime.Time) + } + return err } // StartVerification starts an interactive verification flow with the given @@ -382,13 +314,12 @@ func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserI vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - vh.activeTransactions[txnID] = &verificationTransaction{ - VerificationState: verificationStateRequested, + return txnID, vh.store.SaveVerificationTransaction(ctx, VerificationTransaction{ + VerificationState: VerificationStateRequested, TransactionID: txnID, - TheirUser: to, + TheirUserID: to, SentToDeviceIDs: maps.Keys(devices), - } - return txnID, nil + }) } // StartInRoomVerification starts an interactive verification flow with the @@ -422,13 +353,12 @@ func (vh *VerificationHelper) StartInRoomVerification(ctx context.Context, roomI vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - vh.activeTransactions[txnID] = &verificationTransaction{ + return txnID, vh.store.SaveVerificationTransaction(ctx, VerificationTransaction{ RoomID: roomID, - VerificationState: verificationStateRequested, + VerificationState: VerificationStateRequested, TransactionID: txnID, - TheirUser: to, - } - return txnID, nil + TheirUserID: to, + }) } // AcceptVerification accepts a verification request. The transaction ID should @@ -440,10 +370,10 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V Stringer("transaction_id", txnID). Logger() - txn, ok := vh.activeTransactions[txnID] - if !ok { - return fmt.Errorf("unknown transaction ID") - } else if txn.VerificationState != verificationStateRequested { + txn, err := vh.store.GetVerificationTransaction(ctx, txnID) + if err != nil { + return err + } else if txn.VerificationState != VerificationStateRequested { return fmt.Errorf("transaction is not in the requested state") } @@ -472,17 +402,20 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V FromDevice: vh.client.DeviceID, Methods: maps.Keys(supportedMethods), } - err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationReady, readyEvt) + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationReady, readyEvt) if err != nil { return err } - txn.VerificationState = verificationStateReady + txn.VerificationState = VerificationStateReady if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { vh.scanQRCode(ctx, txn.TransactionID) } - return vh.generateAndShowQRCode(ctx, txn) + if err := vh.generateAndShowQRCode(ctx, &txn); err != nil { + return err + } + return vh.store.SaveVerificationTransaction(ctx, txn) } // DismissVerification dismisses the verification request with the given @@ -492,8 +425,7 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V func (vh *VerificationHelper) DismissVerification(ctx context.Context, txnID id.VerificationTransactionID) error { vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - delete(vh.activeTransactions, txnID) - return nil + return vh.store.DeleteVerification(ctx, txnID) } // DismissVerification cancels the verification request with the given @@ -504,9 +436,9 @@ func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.V vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, ok := vh.activeTransactions[txnID] - if !ok { - return fmt.Errorf("unknown transaction ID") + txn, err := vh.store.GetVerificationTransaction(ctx, txnID) + if err != nil { + return err } log := vh.getLog(ctx).With(). Str("verification_action", "cancel verification"). @@ -527,29 +459,28 @@ func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.V } else { cancelEvt.SetTransactionID(txn.TransactionID) req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{ - txn.TheirUser: {}, + txn.TheirUserID: {}, }} - if len(txn.TheirDevice) > 0 { + if len(txn.TheirDeviceID) > 0 { // Send the cancellation event to only the device that accepted the // verification request. All of the other devices already received a // cancellation event with code "m.acceped". - req.Messages[txn.TheirUser][txn.TheirDevice] = &event.Content{Parsed: cancelEvt} + req.Messages[txn.TheirUserID][txn.TheirDeviceID] = &event.Content{Parsed: cancelEvt} } else { // Send the cancellation event to all of the devices that we sent the // request to. for _, deviceID := range txn.SentToDeviceIDs { if deviceID != vh.client.DeviceID { - req.Messages[txn.TheirUser][deviceID] = &event.Content{Parsed: cancelEvt} + req.Messages[txn.TheirUserID][deviceID] = &event.Content{Parsed: cancelEvt} } } } _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { - return fmt.Errorf("failed to send m.key.verification.cancel event to %v: %w", maps.Keys(req.Messages[txn.TheirUser]), err) + return fmt.Errorf("failed to send m.key.verification.cancel event to %v: %w", maps.Keys(req.Messages[txn.TheirUserID]), err) } } - delete(vh.activeTransactions, txn.TransactionID) - return nil + return vh.store.DeleteVerification(ctx, txn.TransactionID) } // sendVerificationEvent sends a verification event to the other user's device @@ -561,7 +492,7 @@ func (vh *VerificationHelper) CancelVerification(ctx context.Context, txnID id.V // [event.VerificationTransactionable]. // - evtType can be either the to-device or in-room version of the event type // as it is always stringified. -func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn *verificationTransaction, evtType event.Type, content any) error { +func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn VerificationTransaction, evtType event.Type, content any) error { if txn.RoomID != "" { content.(event.Relatable).SetRelatesTo(&event.RelatesTo{Type: event.RelReference, EventID: id.EventID(txn.TransactionID)}) _, err := vh.client.SendMessageEvent(ctx, txn.RoomID, evtType, &event.Content{ @@ -573,13 +504,13 @@ func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn *ve } else { content.(event.VerificationTransactionable).SetTransactionID(txn.TransactionID) req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{ - txn.TheirUser: { - txn.TheirDevice: &event.Content{Parsed: content}, + txn.TheirUserID: { + txn.TheirDeviceID: &event.Content{Parsed: content}, }, }} _, err := vh.client.SendToDevice(ctx, evtType, &req) if err != nil { - return fmt.Errorf("failed to send %s event to %s: %w", evtType.String(), txn.TheirDevice, err) + return fmt.Errorf("failed to send %s event to %s: %w", evtType.String(), txn.TheirDeviceID, err) } } return nil @@ -591,7 +522,7 @@ func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn *ve // directly to expose the error to its caller). // // Must always be called with the activeTransactionsLock held. -func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn *verificationTransaction, code event.VerificationCancelCode, reasonFmtStr string, fmtArgs ...any) error { +func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn VerificationTransaction, code event.VerificationCancelCode, reasonFmtStr string, fmtArgs ...any) error { log := vh.getLog(ctx) reason := fmt.Errorf(reasonFmtStr, fmtArgs...).Error() log.Info(). @@ -605,7 +536,9 @@ func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn *ve log.Err(err).Msg("failed to send cancellation event") return fmt.Errorf("failed to send cancel verification event (code: %s, reason: %s): %w", code, reason, err) } - delete(vh.activeTransactions, txn.TransactionID) + if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { + log.Err(err).Msg("deleting verification failed") + } vh.verificationCancelledCallback(ctx, txn.TransactionID, code, reason) return fmt.Errorf("verification cancelled (code: %s): %s", code, reason) } @@ -684,54 +617,58 @@ func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *ev } vh.activeTransactionsLock.Lock() - newTxn := &verificationTransaction{ + newTxn := VerificationTransaction{ + ExpirationTime: jsontime.UnixMilli{Time: verificationRequest.Timestamp.Add(time.Minute * 10)}, RoomID: evt.RoomID, - VerificationState: verificationStateRequested, + VerificationState: VerificationStateRequested, TransactionID: verificationRequest.TransactionID, - TheirDevice: verificationRequest.FromDevice, - TheirUser: evt.Sender, + TheirDeviceID: verificationRequest.FromDevice, + TheirUserID: evt.Sender, TheirSupportedMethods: verificationRequest.Methods, } - for existingTxnID, existingTxn := range vh.activeTransactions { - if existingTxn.TheirUser == evt.Sender && existingTxn.TheirDevice == verificationRequest.FromDevice && existingTxnID != verificationRequest.TransactionID { - vh.cancelVerificationTxn(ctx, existingTxn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") + if txn, err := vh.store.FindVerificationTransactionForUserDevice(ctx, evt.Sender, verificationRequest.FromDevice); err != nil && !errors.Is(err, ErrUnknownVerificationTransaction) { + log.Err(err).Stringer("sender", evt.Sender).Stringer("device_id", verificationRequest.FromDevice).Msg("failed to find verification transaction") + vh.activeTransactionsLock.Unlock() + return + } else if !errors.Is(err, ErrUnknownVerificationTransaction) { + if txn.TransactionID == verificationRequest.TransactionID { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received a new verification request for the same transaction ID") + } else { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") vh.cancelVerificationTxn(ctx, newTxn, event.VerificationCancelCodeUnexpectedMessage, "received multiple verification requests from the same device") - delete(vh.activeTransactions, existingTxnID) - vh.activeTransactionsLock.Unlock() - return - } - - if existingTxnID == verificationRequest.TransactionID { - vh.cancelVerificationTxn(ctx, existingTxn, event.VerificationCancelCodeUnexpectedMessage, "received a new verification request for the same transaction ID") - delete(vh.activeTransactions, existingTxnID) - vh.activeTransactionsLock.Unlock() - return } + vh.activeTransactionsLock.Unlock() + return + } + if err := vh.store.SaveVerificationTransaction(ctx, newTxn); err != nil { + log.Err(err).Msg("failed to save verification transaction") } - vh.activeTransactions[verificationRequest.TransactionID] = newTxn vh.activeTransactionsLock.Unlock() vh.expireTransactionAt(verificationRequest.TransactionID, verificationRequest.Timestamp.Add(time.Minute*10)) vh.verificationRequested(ctx, verificationRequest.TransactionID, evt.Sender) } -func (vh *VerificationHelper) expireTransactionAt(txnID id.VerificationTransactionID, expireAt time.Time) { +func (vh *VerificationHelper) expireTransactionAt(txnID id.VerificationTransactionID, expiresAt time.Time) { go func() { - time.Sleep(time.Until(expireAt)) + time.Sleep(time.Until(expiresAt)) vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - txn, ok := vh.activeTransactions[txnID] - if !ok { + txn, err := vh.store.GetVerificationTransaction(context.Background(), txnID) + if err == ErrUnknownVerificationTransaction { + // Already deleted, nothing to expire return + } else if err != nil { + vh.getLog(context.Background()).Err(err).Msg("failed to get verification transaction to expire") + } else { + vh.cancelVerificationTxn(context.Background(), txn, event.VerificationCancelCodeTimeout, "verification timed out") } - - vh.cancelVerificationTxn(context.Background(), txn, event.VerificationCancelCodeTimeout, "verification timed out") }() } -func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn VerificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "verification ready"). Logger() @@ -739,7 +676,7 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState != verificationStateRequested { + if txn.VerificationState != VerificationStateRequested { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "verification ready event received for a transaction that is not in the requested state") return } @@ -747,12 +684,12 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri readyEvt := evt.Content.AsVerificationReady() // Update the transaction state. - txn.VerificationState = verificationStateReady - txn.TheirDevice = readyEvt.FromDevice + txn.VerificationState = VerificationStateReady + txn.TheirDeviceID = readyEvt.FromDevice txn.TheirSupportedMethods = readyEvt.Methods log.Info(). - Stringer("their_device_id", txn.TheirDevice). + Stringer("their_device_id", txn.TheirDeviceID). Any("their_supported_methods", txn.TheirSupportedMethods). Msg("Received verification ready event") @@ -766,16 +703,16 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri Reason: "The verification was accepted on another device.", }, } - req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUserID: {}}} for _, deviceID := range txn.SentToDeviceIDs { - if deviceID == txn.TheirDevice || deviceID == vh.client.DeviceID { + if deviceID == txn.TheirDeviceID || deviceID == vh.client.DeviceID { // Don't ever send a cancellation to the device that accepted // the request or to our own device (which can happen if this // is a self-verification). continue } - req.Messages[txn.TheirUser][deviceID] = content + req.Messages[txn.TheirUserID][deviceID] = content } _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { @@ -787,18 +724,19 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn *veri vh.scanQRCode(ctx, txn.TransactionID) } - err := vh.generateAndShowQRCode(ctx, txn) - if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to generate and show QR code: %w", err) + if err := vh.generateAndShowQRCode(ctx, &txn); err != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to generate and show QR code: %w", err) + } else if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to save verification transaction: %w", err) } } -func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn VerificationTransaction, evt *event.Event) { startEvt := evt.Content.AsVerificationStart() log := vh.getLog(ctx).With(). Str("verification_action", "verification start"). Str("method", string(startEvt.Method)). - Stringer("their_device_id", txn.TheirDevice). + Stringer("their_device_id", txn.TheirDeviceID). Any("their_supported_methods", txn.TheirSupportedMethods). Bool("started_by_us", txn.StartedByUs). Logger() @@ -808,7 +746,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if txn.VerificationState == verificationStateSASStarted || txn.VerificationState == verificationStateOurQRScanned || txn.VerificationState == verificationStateTheirQRScanned { + if txn.VerificationState == VerificationStateSASStarted || txn.VerificationState == VerificationStateOurQRScanned || txn.VerificationState == VerificationStateTheirQRScanned { // We might have sent the event, and they also sent an event. if txn.StartEventContent == nil || !txn.StartedByUs { // We didn't sent a start event yet, so we have gotten ourselves @@ -840,12 +778,12 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri return } - if txn.TheirUser < vh.client.UserID || (txn.TheirUser == vh.client.UserID && txn.TheirDevice < vh.client.DeviceID) { + if txn.TheirUserID < vh.client.UserID || (txn.TheirUserID == vh.client.UserID && txn.TheirDeviceID < vh.client.DeviceID) { log.Debug().Msg("Using their start event instead of ours because they are alphabetically before us") txn.StartedByUs = false txn.StartEventContent = startEvt } - } else if txn.VerificationState != verificationStateReady { + } else if txn.VerificationState != VerificationStateReady { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got start event for transaction that is not in ready state") return } @@ -853,7 +791,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri switch startEvt.Method { case event.VerificationMethodSAS: log.Info().Msg("Received SAS start event") - txn.VerificationState = verificationStateSASStarted + txn.VerificationState = VerificationStateSASStarted if err := vh.onVerificationStartSAS(ctx, txn, evt); err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to handle SAS verification start: %w", err) } @@ -863,8 +801,11 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "reciprocated shared secret does not match") return } - txn.VerificationState = verificationStateOurQRScanned + txn.VerificationState = VerificationStateOurQRScanned vh.qrCodeScaned(ctx, txn.TransactionID) + if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + log.Err(err).Msg("failed to save verification transaction") + } default: // Note that we should never get m.qr_code.show.v1 or m.qr_code.scan.v1 // here, since the start command for scanning and showing QR codes @@ -874,17 +815,18 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn *veri } } -func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verificationTransaction, evt *event.Event) { - vh.getLog(ctx).Info(). +func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn VerificationTransaction, evt *event.Event) { + log := vh.getLog(ctx).With(). Str("verification_action", "done"). Stringer("transaction_id", txn.TransactionID). Bool("sent_our_done", txn.SentOurDone). - Msg("Verification done") + Logger() + log.Info().Msg("Verification done") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() - if !slices.Contains([]verificationState{ - verificationStateTheirQRScanned, verificationStateOurQRScanned, verificationStateSASMACExchanged, + if !slices.Contains([]VerificationState{ + VerificationStateTheirQRScanned, VerificationStateOurQRScanned, VerificationStateSASMACExchanged, }, txn.VerificationState) { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got done event for transaction that is not in QR-scanned or MAC-exchanged state") return @@ -892,12 +834,16 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn *verif txn.ReceivedTheirDone = true if txn.SentOurDone { - delete(vh.activeTransactions, txn.TransactionID) + if err := vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { + log.Err(err).Msg("Delete verification failed") + } vh.verificationDone(ctx, txn.TransactionID) + } else if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + log.Err(err).Msg("failed to save verification transaction") } } -func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *verificationTransaction, evt *event.Event) { +func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn VerificationTransaction, evt *event.Event) { cancelEvt := evt.Content.AsVerificationCancel() log := vh.getLog(ctx).With(). Str("verification_action", "cancel"). @@ -923,7 +869,7 @@ func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *ver // that is currently in the REQUESTED state, then we will send // cancellations to all of the devices that we sent the request to. This // will ensure that all of the clients know that the request was cancelled. - if txn.VerificationState == verificationStateRequested && len(txn.SentToDeviceIDs) > 0 { + if txn.VerificationState == VerificationStateRequested && len(txn.SentToDeviceIDs) > 0 { content := &event.Content{ Parsed: &event.VerificationCancelEventContent{ ToDeviceVerificationEvent: event.ToDeviceVerificationEvent{TransactionID: txn.TransactionID}, @@ -931,9 +877,9 @@ func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *ver Reason: "The verification was rejected from another device.", }, } - req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUser: {}}} + req := mautrix.ReqSendToDevice{Messages: map[id.UserID]map[id.DeviceID]*event.Content{txn.TheirUserID: {}}} for _, deviceID := range txn.SentToDeviceIDs { - req.Messages[txn.TheirUser][deviceID] = content + req.Messages[txn.TheirUserID][deviceID] = content } _, err := vh.client.SendToDevice(ctx, event.ToDeviceVerificationCancel, &req) if err != nil { @@ -941,6 +887,8 @@ func (vh *VerificationHelper) onVerificationCancel(ctx context.Context, txn *ver } } - delete(vh.activeTransactions, txn.TransactionID) + if err := vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { + log.Err(err).Msg("Delete verification failed") + } vh.verificationCancelledCallback(ctx, txn.TransactionID, cancelEvt.Code, cancelEvt.Reason) } diff --git a/crypto/verificationhelper/verificationhelper_qr_self_test.go b/crypto/verificationhelper/verificationhelper_qr_self_test.go index 11358b88..937cc414 100644 --- a/crypto/verificationhelper/verificationhelper_qr_self_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_self_test.go @@ -278,12 +278,12 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { // Emulate scanning the QR code shown by the receiving device // on the sending device. err = sendingHelper.HandleScannedQRData(ctx, receivingShownQRCodeBytes) - assert.ErrorContains(t, err, "unknown transaction ID found in QR code") + assert.ErrorContains(t, err, "unknown transaction ID") // Emulate scanning the QR code shown by the sending device on // the receiving device. err = receivingHelper.HandleScannedQRData(ctx, sendingShownQRCodeBytes) - assert.ErrorContains(t, err, "unknown transaction ID found in QR code") + assert.ErrorContains(t, err, "unknown transaction ID") } func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index 273042c3..d0bf2298 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -65,11 +65,11 @@ func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockServ func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, receivingClient *mautrix.Client, sendingMachine, receivingMachine *crypto.OlmMachine) (sendingCallbacks, receivingCallbacks *allVerificationCallbacks, sendingHelper, receivingHelper *verificationhelper.VerificationHelper) { t.Helper() sendingCallbacks = newAllVerificationCallbacks() - sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, sendingCallbacks, true) + sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, sendingCallbacks, true) require.NoError(t, sendingHelper.Init(ctx)) receivingCallbacks = newAllVerificationCallbacks() - receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receivingCallbacks, true) + receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, receivingCallbacks, true) require.NoError(t, receivingHelper.Init(ctx)) return } @@ -104,7 +104,7 @@ func TestVerification_Start(t *testing.T) { addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID) addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID2) - senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), tc.callbacks, tc.supportsScan) + senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, tc.callbacks, tc.supportsScan) err := senderHelper.Init(ctx) require.NoError(t, err) @@ -151,7 +151,7 @@ func TestVerification_StartThenCancel(t *testing.T) { bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID) bystanderMachine := bystanderClient.Crypto.(*cryptohelper.CryptoHelper).Machine() - bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, newAllVerificationCallbacks(), true) + bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, nil, newAllVerificationCallbacks(), true) require.NoError(t, bystanderHelper.Init(ctx)) require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, bystanderMachine.OwnIdentity())) @@ -241,12 +241,12 @@ func TestVerification_Accept_NoSupportedMethods(t *testing.T) { assert.NotEmpty(t, recoveryKey) assert.NotNil(t, cache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, newAllVerificationCallbacks(), true) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, newAllVerificationCallbacks(), true) err = sendingHelper.Init(ctx) require.NoError(t, err) receivingCallbacks := newBaseVerificationCallbacks() - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), receivingCallbacks, false) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, receivingCallbacks, false) err = receivingHelper.Init(ctx) require.NoError(t, err) @@ -289,11 +289,11 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { assert.NotEmpty(t, recoveryKey) assert.NotNil(t, sendingCrossSigningKeysCache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, tc.sendingCallbacks, tc.sendingSupportsScan) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, tc.sendingCallbacks, tc.sendingSupportsScan) err = sendingHelper.Init(ctx) require.NoError(t, err) - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, tc.receivingCallbacks, tc.receivingSupportsScan) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, tc.receivingCallbacks, tc.receivingSupportsScan) err = receivingHelper.Init(ctx) require.NoError(t, err) diff --git a/crypto/verificationhelper/verificationstore.go b/crypto/verificationhelper/verificationstore.go new file mode 100644 index 00000000..1eb8f752 --- /dev/null +++ b/crypto/verificationhelper/verificationstore.go @@ -0,0 +1,159 @@ +package verificationhelper + +import ( + "context" + "errors" + "fmt" + + "go.mau.fi/util/jsontime" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +var ErrUnknownVerificationTransaction = errors.New("unknown transaction ID") + +type VerificationState int + +const ( + VerificationStateRequested VerificationState = iota + VerificationStateReady + + VerificationStateTheirQRScanned // We scanned their QR code + VerificationStateOurQRScanned // They scanned our QR code + + VerificationStateSASStarted // An SAS verification has been started + VerificationStateSASAccepted // An SAS verification has been accepted + VerificationStateSASKeysExchanged // An SAS verification has exchanged keys + VerificationStateSASMACExchanged // An SAS verification has exchanged MACs +) + +func (step VerificationState) String() string { + switch step { + case VerificationStateRequested: + return "requested" + case VerificationStateReady: + return "ready" + case VerificationStateTheirQRScanned: + return "their_qr_scanned" + case VerificationStateOurQRScanned: + return "our_qr_scanned" + case VerificationStateSASStarted: + return "sas_started" + case VerificationStateSASAccepted: + return "sas_accepted" + case VerificationStateSASKeysExchanged: + return "sas_keys_exchanged" + case VerificationStateSASMACExchanged: + return "sas_mac" + default: + return fmt.Sprintf("VerificationState(%d)", step) + } +} + +type VerificationTransaction struct { + ExpirationTime jsontime.UnixMilli `json:"expiration_time,omitempty"` + + // RoomID is the room ID if the verification is happening in a room or + // empty if it is a to-device verification. + RoomID id.RoomID `json:"room_id,omitempty"` + + // VerificationState is the current step of the verification flow. + VerificationState VerificationState `json:"verification_state"` + // TransactionID is the ID of the verification transaction. + TransactionID id.VerificationTransactionID `json:"transaction_id"` + + // TheirDeviceID is the device ID of the device that either made the + // initial request or accepted our request. + TheirDeviceID id.DeviceID `json:"their_device_id,omitempty"` + // TheirUserID is the user ID of the other user. + TheirUserID id.UserID `json:"their_user_id,omitempty"` + // TheirSupportedMethods is a list of verification methods that the other + // device supports. + TheirSupportedMethods []event.VerificationMethod `json:"their_supported_methods,omitempty"` + + // SentToDeviceIDs is a list of devices which the initial request was sent + // to. This is only used for to-device verification requests, and is meant + // to be used to send cancellation requests to all other devices when a + // verification request is accepted via a m.key.verification.ready event. + SentToDeviceIDs []id.DeviceID `json:"sent_to_device_ids,omitempty"` + + // QRCodeSharedSecret is the shared secret that was encoded in the QR code + // that we showed. + QRCodeSharedSecret []byte `json:"qr_code_shared_secret,omitempty"` + + StartedByUs bool `json:"started_by_us,omitempty"` // Whether the verification was started by us + StartEventContent *event.VerificationStartEventContent `json:"start_event_content,omitempty"` // The m.key.verification.start event content + Commitment []byte `json:"committment,omitempty"` // The commitment from the m.key.verification.accept event + MACMethod event.MACMethod `json:"mac_method,omitempty"` // The method used to calculate the MAC + EphemeralKey *ECDHPrivateKey `json:"ephemeral_key,omitempty"` // The ephemeral key + EphemeralPublicKeyShared bool `json:"ephemeral_public_key_shared,omitempty"` // Whether this device's ephemeral public key has been shared + OtherPublicKey *ECDHPublicKey `json:"other_public_key,omitempty"` // The other device's ephemeral public key + ReceivedTheirMAC bool `json:"received_their_mac,omitempty"` // Whether we have received their MAC + SentOurMAC bool `json:"sent_our_mac,omitempty"` // Whether we have sent our MAC + ReceivedTheirDone bool `json:"received_their_done,omitempty"` // Whether we have received their done event + SentOurDone bool `json:"sent_our_done,omitempty"` // Whether we have sent our done event +} + +type VerificationStore interface { + // DeleteVerification deletes a verification transaction by ID + DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) error + // GetVerificationTransaction gets a verification transaction by ID + GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (VerificationTransaction, error) + // SaveVerificationTransaction saves a verification transaction by ID + SaveVerificationTransaction(ctx context.Context, txn VerificationTransaction) error + // FindVerificationTransactionForUserDevice finds a verification + // transaction by user and device ID + FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (VerificationTransaction, error) + // GetAllVerificationTransactions returns all of the verification + // transactions. This is used to reset the cancellation timeouts. + GetAllVerificationTransactions(ctx context.Context) ([]VerificationTransaction, error) +} + +type InMemoryVerificationStore struct { + txns map[id.VerificationTransactionID]VerificationTransaction +} + +var _ VerificationStore = (*InMemoryVerificationStore)(nil) + +func NewInMemoryVerificationStore() *InMemoryVerificationStore { + return &InMemoryVerificationStore{ + txns: map[id.VerificationTransactionID]VerificationTransaction{}, + } +} + +func (i *InMemoryVerificationStore) DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) error { + if _, ok := i.txns[txnID]; !ok { + return ErrUnknownVerificationTransaction + } + delete(i.txns, txnID) + return nil +} + +func (i *InMemoryVerificationStore) GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (VerificationTransaction, error) { + if _, ok := i.txns[txnID]; !ok { + return VerificationTransaction{}, ErrUnknownVerificationTransaction + } + return i.txns[txnID], nil +} + +func (i *InMemoryVerificationStore) SaveVerificationTransaction(ctx context.Context, txn VerificationTransaction) error { + i.txns[txn.TransactionID] = txn + return nil +} + +func (i *InMemoryVerificationStore) FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (VerificationTransaction, error) { + for _, existingTxn := range i.txns { + if existingTxn.TheirUserID == userID && existingTxn.TheirDeviceID == deviceID { + return existingTxn, nil + } + } + return VerificationTransaction{}, ErrUnknownVerificationTransaction +} + +func (i *InMemoryVerificationStore) GetAllVerificationTransactions(ctx context.Context) (txns []VerificationTransaction, err error) { + for _, txn := range i.txns { + txns = append(txns, txn) + } + return +} From 3f23a752e6f081a49318a857aa3b8cd114dabb92 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 22 Nov 2024 16:47:05 -0700 Subject: [PATCH 0933/1647] verificationhelper: set the context logger more aggressively Signed-off-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 4 ++++ crypto/verificationhelper/sas.go | 6 ++++++ .../verificationhelper/verificationhelper.go | 20 ++++++++++++++----- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index 5c38c655..ef69f23c 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -32,6 +32,8 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by Stringer("transaction_id", qrCode.TransactionID). Int("mode", int(qrCode.Mode)). Logger() + ctx = log.WithContext(ctx) + vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() @@ -197,6 +199,7 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id Str("verification_action", "confirm QR code scanned"). Stringer("transaction_id", txnID). Logger() + ctx = log.WithContext(ctx) vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() @@ -265,6 +268,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *Ve Str("verification_action", "generate and show QR code"). Stringer("transaction_id", txn.TransactionID). Logger() + ctx = log.WithContext(ctx) if vh.showQRCode == nil { log.Info().Msg("Ignoring QR code generation request as showing a QR code is not enabled on this device") return nil diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 1e465867..81728bd4 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -37,6 +37,7 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio Str("verification_action", "accept verification"). Stringer("transaction_id", txnID). Logger() + ctx = log.WithContext(ctx) vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() @@ -95,6 +96,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat Str("verification_action", "confirm SAS"). Stringer("transaction_id", txnID). Logger() + ctx = log.WithContext(ctx) vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() @@ -169,6 +171,7 @@ func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn Ve Str("verification_action", "start_sas"). Stringer("transaction_id", txn.TransactionID). Logger() + ctx = log.WithContext(ctx) log.Info().Msg("Received SAS verification start event") _, err := vh.mach.GetOrFetchDevice(ctx, evt.Sender, startEvt.FromDevice) @@ -266,6 +269,7 @@ func (vh *VerificationHelper) onVerificationAccept(ctx context.Context, txn Veri Str("message_authentication_code", string(acceptEvt.MessageAuthenticationCode)). Any("short_authentication_string", acceptEvt.ShortAuthenticationString). Logger() + ctx = log.WithContext(ctx) log.Info().Msg("Received SAS verification accept event") vh.activeTransactionsLock.Lock() @@ -305,6 +309,7 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific log := vh.getLog(ctx).With(). Str("verification_action", "key"). Logger() + ctx = log.WithContext(ctx) keyEvt := evt.Content.AsVerificationKey() vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() @@ -579,6 +584,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific log := vh.getLog(ctx).With(). Str("verification_action", "mac"). Logger() + ctx = log.WithContext(ctx) log.Info().Msg("Received SAS verification MAC event") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 4c10827c..eb9099f4 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -131,6 +131,8 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, stor func (vh *VerificationHelper) getLog(ctx context.Context) *zerolog.Logger { logger := zerolog.Ctx(ctx).With(). Str("component", "verification"). + Stringer("device_id", vh.client.DeviceID). + Stringer("user_id", vh.client.UserID). Any("supported_methods", vh.supportedMethods). Logger() return &logger @@ -167,6 +169,7 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { Stringer("event_id", evt.ID). Stringer("event_type", evt.Type). Logger() + ctx = log.WithContext(ctx) var transactionID id.VerificationTransactionID if evt.ID != "" { @@ -279,12 +282,14 @@ func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserI } } - vh.getLog(ctx).Info(). + log := vh.getLog(ctx).With(). Str("verification_action", "start verification"). Stringer("transaction_id", txnID). Stringer("to", to). Any("device_ids", maps.Keys(devices)). - Msg("Sending verification request") + Logger() + ctx = log.WithContext(ctx) + log.Info().Msg("Sending verification request") now := time.Now() content := &event.Content{ @@ -330,6 +335,7 @@ func (vh *VerificationHelper) StartInRoomVerification(ctx context.Context, roomI Stringer("room_id", roomID). Stringer("to", to). Logger() + ctx = log.WithContext(ctx) log.Info().Msg("Sending verification request") content := event.MessageEventContent{ @@ -369,6 +375,7 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V Str("verification_action", "accept verification"). Stringer("transaction_id", txnID). Logger() + ctx = log.WithContext(ctx) txn, err := vh.store.GetVerificationTransaction(ctx, txnID) if err != nil { @@ -523,13 +530,14 @@ func (vh *VerificationHelper) sendVerificationEvent(ctx context.Context, txn Ver // // Must always be called with the activeTransactionsLock held. func (vh *VerificationHelper) cancelVerificationTxn(ctx context.Context, txn VerificationTransaction, code event.VerificationCancelCode, reasonFmtStr string, fmtArgs ...any) error { - log := vh.getLog(ctx) reason := fmt.Errorf(reasonFmtStr, fmtArgs...).Error() - log.Info(). + log := vh.getLog(ctx).With(). Stringer("transaction_id", txn.TransactionID). Str("code", string(code)). Str("reason", reason). - Msg("Sending cancellation event") + Logger() + ctx = log.WithContext(ctx) + log.Info().Msg("Sending cancellation event") cancelEvt := &event.VerificationCancelEventContent{Code: code, Reason: reason} err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, cancelEvt) if err != nil { @@ -672,6 +680,7 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif log := vh.getLog(ctx).With(). Str("verification_action", "verification ready"). Logger() + ctx = log.WithContext(ctx) vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() @@ -821,6 +830,7 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn Verifi Stringer("transaction_id", txn.TransactionID). Bool("sent_our_done", txn.SentOurDone). Logger() + ctx = log.WithContext(ctx) log.Info().Msg("Verification done") vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() From 2a8e6fba65607ef1bfaf44308448c9e0024a6b42 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 22 Nov 2024 16:48:46 -0700 Subject: [PATCH 0934/1647] verificationhelper: set the expiration time correctly Signed-off-by: Sumner Evans --- crypto/verificationhelper/verificationhelper.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index eb9099f4..33d69cd2 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -203,8 +203,9 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { // We have to create a fake transaction so that the call to // cancelVerificationTxn works. txn = VerificationTransaction{ - RoomID: evt.RoomID, - TheirUserID: evt.Sender, + ExpirationTime: jsontime.UnixMilli{Time: time.Now().Add(time.Minute * 10)}, + RoomID: evt.RoomID, + TheirUserID: evt.Sender, } if transactionable, ok := evt.Content.Parsed.(event.VerificationTransactionable); ok { txn.TransactionID = transactionable.GetTransactionID() @@ -320,6 +321,7 @@ func (vh *VerificationHelper) StartVerification(ctx context.Context, to id.UserI vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() return txnID, vh.store.SaveVerificationTransaction(ctx, VerificationTransaction{ + ExpirationTime: jsontime.UnixMilli{Time: time.Now().Add(time.Minute * 10)}, VerificationState: VerificationStateRequested, TransactionID: txnID, TheirUserID: to, @@ -360,6 +362,7 @@ func (vh *VerificationHelper) StartInRoomVerification(ctx context.Context, roomI vh.activeTransactionsLock.Lock() defer vh.activeTransactionsLock.Unlock() return txnID, vh.store.SaveVerificationTransaction(ctx, VerificationTransaction{ + ExpirationTime: jsontime.UnixMilli{Time: time.Now().Add(time.Minute * 10)}, RoomID: roomID, VerificationState: VerificationStateRequested, TransactionID: txnID, From f7e5f0a3b688cc16426cddc5b4e0b8b92f3cbc63 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 26 Nov 2024 10:08:52 -0700 Subject: [PATCH 0935/1647] verificationhelper: add tests for using SQLite store for verification Signed-off-by: Sumner Evans --- .../verificationhelper_test.go | 14 ++- .../verificationstore_test.go | 87 +++++++++++++++++++ 2 files changed, 99 insertions(+), 2 deletions(-) create mode 100644 crypto/verificationhelper/verificationstore_test.go diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index d0bf2298..af4a28c3 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -2,6 +2,7 @@ package verificationhelper_test import ( "context" + "database/sql" "fmt" "os" "testing" @@ -65,11 +66,20 @@ func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockServ func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, receivingClient *mautrix.Client, sendingMachine, receivingMachine *crypto.OlmMachine) (sendingCallbacks, receivingCallbacks *allVerificationCallbacks, sendingHelper, receivingHelper *verificationhelper.VerificationHelper) { t.Helper() sendingCallbacks = newAllVerificationCallbacks() - sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, sendingCallbacks, true) + senderVerificationDB, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + senderVerificationStore, err := NewSQLiteVerificationStore(ctx, senderVerificationDB) + require.NoError(t, err) + + sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, senderVerificationStore, sendingCallbacks, true) require.NoError(t, sendingHelper.Init(ctx)) receivingCallbacks = newAllVerificationCallbacks() - receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, receivingCallbacks, true) + receiverVerificationDB, err := sql.Open("sqlite3", ":memory:") + require.NoError(t, err) + receiverVerificationStore, err := NewSQLiteVerificationStore(ctx, receiverVerificationDB) + require.NoError(t, err) + receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receiverVerificationStore, receivingCallbacks, true) require.NoError(t, receivingHelper.Init(ctx)) return } diff --git a/crypto/verificationhelper/verificationstore_test.go b/crypto/verificationhelper/verificationstore_test.go new file mode 100644 index 00000000..a3b1895d --- /dev/null +++ b/crypto/verificationhelper/verificationstore_test.go @@ -0,0 +1,87 @@ +package verificationhelper_test + +import ( + "context" + "database/sql" + + _ "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog" + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/crypto/verificationhelper" + "maunium.net/go/mautrix/id" +) + +type SQLiteVerificationStore struct { + db *sql.DB +} + +const ( + selectVerifications = `SELECT transaction_data FROM verifications` + getVerificationByTransactionID = selectVerifications + ` WHERE transaction_id = ?1` + getVerificationByUserDeviceID = selectVerifications + ` + WHERE transaction_data->>'their_user_id' = ?1 + AND transaction_data->>'their_device_id' = ?2 + ` + deleteVerificationsQuery = `DELETE FROM verifications WHERE transaction_id = ?1` +) + +var _ verificationhelper.VerificationStore = (*SQLiteVerificationStore)(nil) + +func NewSQLiteVerificationStore(ctx context.Context, db *sql.DB) (*SQLiteVerificationStore, error) { + _, err := db.ExecContext(ctx, ` + CREATE TABLE verifications ( + transaction_id TEXT PRIMARY KEY NOT NULL, + transaction_data JSONB NOT NULL + ); + CREATE INDEX verifications_user_device_id ON + verifications(transaction_data->>'their_user_id', transaction_data->>'their_device_id'); + `) + return &SQLiteVerificationStore{db}, err +} + +func (s *SQLiteVerificationStore) GetAllVerificationTransactions(ctx context.Context) ([]verificationhelper.VerificationTransaction, error) { + rows, err := s.db.QueryContext(ctx, selectVerifications) + if err != nil { + return nil, err + } + return dbutil.NewRowIter(rows, func(dbutil.Scannable) (txn verificationhelper.VerificationTransaction, err error) { + err = rows.Scan(&dbutil.JSON{Data: &txn}) + return + }).AsList() +} + +func (vq *SQLiteVerificationStore) GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (txn verificationhelper.VerificationTransaction, err error) { + zerolog.Ctx(ctx).Warn().Stringer("transaction_id", txnID).Msg("Getting verification transaction") + row := vq.db.QueryRowContext(ctx, getVerificationByTransactionID, txnID) + err = row.Scan(&dbutil.JSON{Data: &txn}) + if err == sql.ErrNoRows { + err = verificationhelper.ErrUnknownVerificationTransaction + } + return +} + +func (vq *SQLiteVerificationStore) FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (txn verificationhelper.VerificationTransaction, err error) { + row := vq.db.QueryRowContext(ctx, getVerificationByUserDeviceID, userID, deviceID) + err = row.Scan(&dbutil.JSON{Data: &txn}) + if err == sql.ErrNoRows { + err = verificationhelper.ErrUnknownVerificationTransaction + } + return +} + +func (vq *SQLiteVerificationStore) SaveVerificationTransaction(ctx context.Context, txn verificationhelper.VerificationTransaction) (err error) { + zerolog.Ctx(ctx).Debug().Any("transaction", &txn).Msg("Saving verification transaction") + _, err = vq.db.ExecContext(ctx, ` + INSERT INTO verifications (transaction_id, transaction_data) + VALUES (?1, ?2) + ON CONFLICT (transaction_id) DO UPDATE + SET transaction_data=excluded.transaction_data + `, txn.TransactionID, &dbutil.JSON{Data: &txn}) + return +} + +func (vq *SQLiteVerificationStore) DeleteVerification(ctx context.Context, txnID id.VerificationTransactionID) (err error) { + _, err = vq.db.ExecContext(ctx, deleteVerificationsQuery, txnID) + return +} From 4b970e0ea7e691bcc921cb6bdb2e553bcbcc569e Mon Sep 17 00:00:00 2001 From: Brad Murray Date: Tue, 26 Nov 2024 15:29:18 -0500 Subject: [PATCH 0936/1647] crypto/sqlstore: add index to crypto_olm_sessions table to speed up lookups by sender_key (#323) Co-authored-by: Tulir Asokan --- crypto/sql_store_upgrade/00-latest-revision.sql | 3 ++- crypto/sql_store_upgrade/16-crypto-olm-sessions-index.sql | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 crypto/sql_store_upgrade/16-crypto-olm-sessions-index.sql diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql index 7e039af5..7cd3331c 100644 --- a/crypto/sql_store_upgrade/00-latest-revision.sql +++ b/crypto/sql_store_upgrade/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v15: Latest revision +-- v0 -> v16 (compatible with v15+): Latest revision CREATE TABLE IF NOT EXISTS crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, @@ -43,6 +43,7 @@ CREATE TABLE IF NOT EXISTS crypto_olm_session ( last_encrypted timestamp NOT NULL, PRIMARY KEY (account_id, session_id) ); +CREATE INDEX crypto_olm_session_sender_key_idx ON crypto_olm_session (account_id, sender_key); CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( account_id TEXT, diff --git a/crypto/sql_store_upgrade/16-crypto-olm-sessions-index.sql b/crypto/sql_store_upgrade/16-crypto-olm-sessions-index.sql new file mode 100644 index 00000000..f0c3a0c5 --- /dev/null +++ b/crypto/sql_store_upgrade/16-crypto-olm-sessions-index.sql @@ -0,0 +1,2 @@ +-- v16 (compatible with v15+): Add index to crypto_olm_sessions to speedup lookups by sender_key +CREATE INDEX crypto_olm_session_sender_key_idx ON crypto_olm_session (account_id, sender_key); From b32def2b148b2861ff134fdcac948831281ddbc8 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 27 Nov 2024 15:07:47 +0000 Subject: [PATCH 0937/1647] bridgev2/userlogin: add blocking cleanup delete option This allows the caller to block until the cleanup actions are complete. --- bridgev2/userlogin.go | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index f9b8f7b1..142d67d4 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -255,6 +255,7 @@ func (ul *UserLogin) Logout(ctx context.Context) { type DeleteOpts struct { LogoutRemote bool DontCleanupRooms bool + BlockingCleanup bool unlocked bool } @@ -295,9 +296,17 @@ func (ul *UserLogin) Delete(ctx context.Context, state status.BridgeState, opts ul.Bridge.cacheLock.Unlock() } backgroundCtx := context.WithoutCancel(ctx) - go ul.deleteSpace(backgroundCtx) + if !opts.BlockingCleanup { + go ul.deleteSpace(backgroundCtx) + } else { + ul.deleteSpace(backgroundCtx) + } if portals != nil { - go ul.kickUserFromPortals(backgroundCtx, portals, state.StateEvent == status.StateBadCredentials, false) + if !opts.BlockingCleanup { + go ul.kickUserFromPortals(backgroundCtx, portals, state.StateEvent == status.StateBadCredentials, false) + } else { + ul.kickUserFromPortals(backgroundCtx, portals, state.StateEvent == status.StateBadCredentials, false) + } } if state.StateEvent != "" { ul.BridgeState.Send(state) From e3d52674855d0742646446b30f631ed59255de7a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 29 Nov 2024 18:21:41 +0200 Subject: [PATCH 0938/1647] bridgev2/networkinterface: don't allow returning errors in Connect --- bridgev2/bridge.go | 5 +---- bridgev2/networkinterface.go | 4 +++- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 16ebdb77..b2151ee6 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -187,10 +187,7 @@ func (br *Bridge) StartLogins() error { for _, login := range user.GetUserLogins() { startedAny = true br.Log.Info().Str("id", string(login.ID)).Msg("Starting user login") - err = login.Client.Connect(login.Log.WithContext(ctx)) - if err != nil { - br.Log.Err(err).Msg("Failed to connect existing client") - } + login.Client.Connect(login.Log.WithContext(ctx)) } } } diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 6e31e721..9acc6eae 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -344,7 +344,9 @@ type NetworkRoomCapabilities struct { type NetworkAPI interface { // Connect is called to actually connect to the remote network. // If there's no persistent connection, this may just check access token validity, or even do nothing at all. - Connect(ctx context.Context) error + // This method isn't allowed to return errors, because any connection errors should be sent + // using the bridge state mechanism (UserLogin.BridgeState.Send) + Connect(ctx context.Context) // Disconnect should disconnect from the remote network. // A clean disconnection is preferred, but it should not take too long. Disconnect() From 6032adb1135375db766e74c6175327262768a408 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 30 Nov 2024 22:44:21 +0200 Subject: [PATCH 0939/1647] bridgev2/legacymigrate: fix database type for legacy go configs too --- bridgev2/bridgeconfig/legacymigrate.go | 6 +++++- bridgev2/bridgeconfig/upgrade.go | 3 +++ bridgev2/matrix/mxmain/main.go | 4 ++++ 3 files changed, 12 insertions(+), 1 deletion(-) diff --git a/bridgev2/bridgeconfig/legacymigrate.go b/bridgev2/bridgeconfig/legacymigrate.go index e8fab743..fb2a86d6 100644 --- a/bridgev2/bridgeconfig/legacymigrate.go +++ b/bridgev2/bridgeconfig/legacymigrate.go @@ -107,7 +107,11 @@ func doMigrateLegacy(helper up.Helper, python bool) { helper.Set(up.Int, legacyDBMaxSize, "database", "max_open_conns") } } else { - CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "type"}, []string{"database", "type"}) + if dbType, ok := helper.Get(up.Str, "appservice", "database", "type"); ok && dbType == "sqlite3" { + helper.Set(up.Str, "sqlite3-fk-wal", "database", "type") + } else { + CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "type"}, []string{"database", "type"}) + } CopyToOtherLocation(helper, up.Str, []string{"appservice", "database", "uri"}, []string{"database", "uri"}) CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_open_conns"}, []string{"database", "max_open_conns"}) CopyToOtherLocation(helper, up.Int, []string{"appservice", "database", "max_idle_conns"}, []string{"database", "max_idle_conns"}) diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 3948cc11..165d8332 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -7,6 +7,8 @@ package bridgeconfig import ( + "fmt" + up "go.mau.fi/util/configupgrade" "go.mau.fi/util/random" @@ -49,6 +51,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Map, "bridge", "permissions") if dbType, ok := helper.Get(up.Str, "database", "type"); ok && dbType == "sqlite3" { + fmt.Println("Warning: invalid database type sqlite3 in config. Autocorrecting to sqlite3-fk-wal") helper.Set(up.Str, "sqlite3-fk-wal", "database", "type") } else { helper.Copy(up.Str, "database", "type") diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index af0868bf..c1bfb9ee 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -267,6 +267,10 @@ func (br *BridgeMain) Init() { func (br *BridgeMain) initDB() { br.Log.Debug().Msg("Initializing database connection") dbConfig := br.Config.Database + if dbConfig.Type == "sqlite3" { + br.Log.WithLevel(zerolog.FatalLevel).Msg("Invalid database type sqlite3. Use sqlite3-fk-wal instead.") + os.Exit(14) + } if (dbConfig.Type == "sqlite3-fk-wal" || dbConfig.Type == "litestream") && dbConfig.MaxOpenConns != 1 && !strings.Contains(dbConfig.URI, "_txlock=immediate") { var fixedExampleURI string if !strings.HasPrefix(dbConfig.URI, "file:") { From 166ba04aae02405c8517ec7488aaa7c2344e06c2 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 2 Dec 2024 13:00:12 +0000 Subject: [PATCH 0940/1647] verificationhelper: Add missing verification txns unlock --- crypto/verificationhelper/verificationhelper.go | 1 + 1 file changed, 1 insertion(+) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 33d69cd2..1a896b1d 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -188,6 +188,7 @@ func (vh *VerificationHelper) Init(ctx context.Context) error { txn, err := vh.store.GetVerificationTransaction(ctx, transactionID) if err != nil && errors.Is(err, ErrUnknownVerificationTransaction) { log.Err(err).Msg("failed to get verification transaction") + vh.activeTransactionsLock.Unlock() return } else if errors.Is(err, ErrUnknownVerificationTransaction) { // If it's a cancellation event for an unknown transaction, we From 9593f72d1b51b328e3389bffeba8590aefbda241 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 3 Dec 2024 15:30:24 +0000 Subject: [PATCH 0941/1647] event: add encrypted file info for m.room.member --- event/member.go | 1 + 1 file changed, 1 insertion(+) diff --git a/event/member.go b/event/member.go index ebafdcb7..d0ff2a7c 100644 --- a/event/member.go +++ b/event/member.go @@ -41,6 +41,7 @@ type MemberEventContent struct { IsDirect bool `json:"is_direct,omitempty"` ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"` Reason string `json:"reason,omitempty"` + MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"` } type ThirdPartyInvite struct { From bfa32f375f3f0cd16f14f30ae3d7744a5c5a84a0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 5 Dec 2024 01:29:55 +0200 Subject: [PATCH 0942/1647] client: add support for MSC2666 --- client.go | 16 ++++++++++++++++ requests.go | 4 ++++ responses.go | 5 +++++ versions.go | 1 + 4 files changed, 26 insertions(+) diff --git a/client.go b/client.go index f68cae63..e7cafe4f 100644 --- a/client.go +++ b/client.go @@ -978,6 +978,22 @@ func (cli *Client) GetProfile(ctx context.Context, mxid id.UserID) (resp *RespUs return } +func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, extras ...ReqMutualRooms) (resp *RespMutualRooms, err error) { + if cli.SpecVersions != nil && !cli.SpecVersions.Supports(FeatureMutualRooms) { + err = fmt.Errorf("server does not support fetching mutual rooms") + return + } + query := map[string]string{ + "user_id": otherUserID.String(), + } + if len(extras) > 0 { + query["from"] = extras[0].From + } + urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + // GetDisplayName returns the display name of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname func (cli *Client) GetDisplayName(ctx context.Context, mxid id.UserID) (resp *RespUserDisplayName, err error) { urlPath := cli.BuildClientURL("v3", "profile", mxid, "displayname") diff --git a/requests.go b/requests.go index 595f1212..5611bd57 100644 --- a/requests.go +++ b/requests.go @@ -140,6 +140,10 @@ type ReqMembers struct { NotMembership event.Membership `json:"not_membership,omitempty"` } +type ReqMutualRooms struct { + From string `json:"-"` +} + // ReqInvite3PID is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite-1 // It is also a JSON object used in https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom type ReqInvite3PID struct { diff --git a/responses.go b/responses.go index 26aaac77..6ead355e 100644 --- a/responses.go +++ b/responses.go @@ -159,6 +159,11 @@ type RespUserProfile struct { AvatarURL id.ContentURI `json:"avatar_url"` } +type RespMutualRooms struct { + Joined []id.RoomID `json:"joined"` + NextBatch string `json:"next_batch,omitempty"` +} + // RespRegisterAvailable is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3registeravailable type RespRegisterAvailable struct { Available bool `json:"available"` diff --git a/versions.go b/versions.go index 672018ff..5c0d6eaa 100644 --- a/versions.go +++ b/versions.go @@ -63,6 +63,7 @@ var ( FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} + FeatureMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} From 933daead3b3471a958bc2d5b262faa8d7ed66707 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 5 Dec 2024 14:24:54 +0200 Subject: [PATCH 0943/1647] bridgev2/commands: fix pm command not starting chat --- bridgev2/commands/startchat.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index 42f528b0..aa766c0e 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -85,7 +85,7 @@ func fnResolveIdentifier(ce *Event) { if api == nil { return } - createChat := ce.Command == "start-chat" + createChat := ce.Command == "start-chat" || ce.Command == "pm" identifier := strings.Join(identifierParts, " ") resp, err := api.ResolveIdentifier(ce.Ctx, identifier, createChat) if err != nil { From 3a9061e69cebc892458869e33be290d1d6c12041 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 7 Dec 2024 14:56:22 +0200 Subject: [PATCH 0944/1647] crypto/devicelist: add helper for getting cached device list --- crypto/devicelist.go | 89 ++++++++++++++++++++++++++++++++++++++++++++ crypto/store.go | 2 + 2 files changed, 91 insertions(+) diff --git a/crypto/devicelist.go b/crypto/devicelist.go index de6c21f3..38782e90 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -10,6 +10,8 @@ import ( "context" "errors" "fmt" + "slices" + "strings" "github.com/rs/zerolog" "go.mau.fi/util/exzerolog" @@ -26,6 +28,8 @@ var ( NoSigningKeyFound = errors.New("didn't find ed25519 signing key") NoIdentityKeyFound = errors.New("didn't find curve25519 identity key") InvalidKeySignature = errors.New("invalid signature on device keys") + + ErrUserNotTracked = errors.New("user is not tracked") ) func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) { @@ -40,6 +44,91 @@ func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys m return nil } +type GetCachedDevicesParams struct { + IsUserTracked bool +} + +type CachedDevices struct { + Devices []*id.Device + MasterKey *id.CrossSigningKey + HasValidSelfSigningKey bool + MasterKeySignedByUs bool +} + +func (mach *OlmMachine) GetCachedDevices(ctx context.Context, userID id.UserID, extra ...GetCachedDevicesParams) (*CachedDevices, error) { + var params GetCachedDevicesParams + if len(extra) > 0 { + params = extra[0] + } + if !params.IsUserTracked { + userIDs, err := mach.CryptoStore.FilterTrackedUsers(ctx, []id.UserID{userID}) + if err != nil { + return nil, fmt.Errorf("failed to check if user's devices are tracked: %w", err) + } else if len(userIDs) == 0 { + return nil, ErrUserNotTracked + } + } + ownKeys := mach.GetOwnCrossSigningPublicKeys(ctx) + var ownUserSigningKey id.Ed25519 + if ownKeys != nil { + ownUserSigningKey = ownKeys.UserSigningKey + } + var resp CachedDevices + csKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID) + theirMasterKey := csKeys[id.XSUsageMaster] + theirSelfSignKey := csKeys[id.XSUsageSelfSigning] + if err != nil { + return nil, fmt.Errorf("failed to get cross-signing keys: %w", err) + } else if csKeys != nil && theirMasterKey.Key != "" { + resp.MasterKey = &theirMasterKey + if theirSelfSignKey.Key != "" { + resp.HasValidSelfSigningKey, err = mach.CryptoStore.IsKeySignedBy(ctx, userID, theirSelfSignKey.Key, userID, theirMasterKey.Key) + if err != nil { + return nil, fmt.Errorf("failed to check if self-signing key is signed by master key: %w", err) + } + } + } + devices, err := mach.CryptoStore.GetDevices(ctx, userID) + if err != nil { + return nil, fmt.Errorf("failed to get devices: %w", err) + } + if userID == mach.Client.UserID { + if ownKeys != nil && ownKeys.MasterKey == theirMasterKey.Key { + resp.MasterKeySignedByUs, err = mach.CryptoStore.IsKeySignedBy(ctx, userID, theirMasterKey.Key, userID, mach.OwnIdentity().SigningKey) + } + } else if ownUserSigningKey != "" && theirMasterKey.Key != "" { + // TODO should own master key and user-signing key signatures be checked here too? + resp.MasterKeySignedByUs, err = mach.CryptoStore.IsKeySignedBy(ctx, userID, theirMasterKey.Key, mach.Client.UserID, ownUserSigningKey) + } + if err != nil { + return nil, fmt.Errorf("failed to check if user is trusted: %w", err) + } + resp.Devices = make([]*id.Device, len(devices)) + i := 0 + for _, device := range devices { + resp.Devices[i] = device + if resp.HasValidSelfSigningKey && device.Trust == id.TrustStateUnset { + signed, err := mach.CryptoStore.IsKeySignedBy(ctx, device.UserID, device.SigningKey, device.UserID, theirSelfSignKey.Key) + if err != nil { + return nil, fmt.Errorf("failed to check if device %s is signed by self-signing key: %w", device.DeviceID, err) + } else if signed { + if resp.MasterKeySignedByUs { + device.Trust = id.TrustStateCrossSignedVerified + } else if theirMasterKey.Key == theirMasterKey.First { + device.Trust = id.TrustStateCrossSignedTOFU + } else { + device.Trust = id.TrustStateCrossSignedUntrusted + } + } + } + i++ + } + slices.SortFunc(resp.Devices, func(a, b *id.Device) int { + return strings.Compare(a.DeviceID.String(), b.DeviceID.String()) + }) + return &resp, nil +} + func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id.UserID, deviceID id.DeviceID, resp *mautrix.RespQueryKeys) { log := zerolog.Ctx(ctx) deviceKeys := resp.DeviceKeys[userID][deviceID] diff --git a/crypto/store.go b/crypto/store.go index 64fa8912..9a3a4394 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -135,6 +135,8 @@ type Store interface { IsKeySignedBy(ctx context.Context, userID id.UserID, key id.Ed25519, signedByUser id.UserID, signedByKey id.Ed25519) (bool, error) // DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted. DropSignaturesByKey(context.Context, id.UserID, id.Ed25519) (int64, error) + // GetSignaturesForKeyBy retrieves the stored signatures for a given cross-signing or device key, by the given signer. + GetSignaturesForKeyBy(context.Context, id.UserID, id.Ed25519, id.UserID) (map[id.Ed25519]string, error) // PutSecret stores a named secret, replacing it if it exists already. PutSecret(context.Context, id.Secret, string) error From 421bd5c4c837a589fe66fff2583d91de97314df9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 7 Dec 2024 15:04:33 +0200 Subject: [PATCH 0945/1647] crypto/devicelist: remove unnecessary parameter --- crypto/devicelist.go | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/crypto/devicelist.go b/crypto/devicelist.go index 38782e90..a2116ed5 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -44,10 +44,6 @@ func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys m return nil } -type GetCachedDevicesParams struct { - IsUserTracked bool -} - type CachedDevices struct { Devices []*id.Device MasterKey *id.CrossSigningKey @@ -55,18 +51,12 @@ type CachedDevices struct { MasterKeySignedByUs bool } -func (mach *OlmMachine) GetCachedDevices(ctx context.Context, userID id.UserID, extra ...GetCachedDevicesParams) (*CachedDevices, error) { - var params GetCachedDevicesParams - if len(extra) > 0 { - params = extra[0] - } - if !params.IsUserTracked { - userIDs, err := mach.CryptoStore.FilterTrackedUsers(ctx, []id.UserID{userID}) - if err != nil { - return nil, fmt.Errorf("failed to check if user's devices are tracked: %w", err) - } else if len(userIDs) == 0 { - return nil, ErrUserNotTracked - } +func (mach *OlmMachine) GetCachedDevices(ctx context.Context, userID id.UserID) (*CachedDevices, error) { + userIDs, err := mach.CryptoStore.FilterTrackedUsers(ctx, []id.UserID{userID}) + if err != nil { + return nil, fmt.Errorf("failed to check if user's devices are tracked: %w", err) + } else if len(userIDs) == 0 { + return nil, ErrUserNotTracked } ownKeys := mach.GetOwnCrossSigningPublicKeys(ctx) var ownUserSigningKey id.Ed25519 From 3cb79ba7b5baa8dc445944b6a41b1d5c76810765 Mon Sep 17 00:00:00 2001 From: onestacked Date: Sat, 7 Dec 2024 23:33:35 +0200 Subject: [PATCH 0946/1647] client,bridgev2: add support for MSC4190 Closes #288 --- appservice/registration.go | 1 + bridgev2/bridgeconfig/appservice.go | 2 ++ bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/matrix/crypto.go | 21 ++++++++++++++++----- bridgev2/matrix/mxmain/example-config.yaml | 4 ++++ client.go | 17 +++++++++++++++++ requests.go | 4 ++++ 7 files changed, 45 insertions(+), 5 deletions(-) diff --git a/appservice/registration.go b/appservice/registration.go index 026df8ea..54eff716 100644 --- a/appservice/registration.go +++ b/appservice/registration.go @@ -29,6 +29,7 @@ type Registration struct { SoruEphemeralEvents bool `yaml:"de.sorunome.msc2409.push_ephemeral,omitempty" json:"de.sorunome.msc2409.push_ephemeral,omitempty"` EphemeralEvents bool `yaml:"receive_ephemeral,omitempty" json:"receive_ephemeral,omitempty"` MSC3202 bool `yaml:"org.matrix.msc3202,omitempty" json:"org.matrix.msc3202,omitempty"` + MSC4190 bool `yaml:"io.element.msc4190,omitempty" json:"io.element.msc4190,omitempty"` } // CreateRegistration creates a Registration with random appservice and homeserver tokens. diff --git a/bridgev2/bridgeconfig/appservice.go b/bridgev2/bridgeconfig/appservice.go index 5e482499..89ce5677 100644 --- a/bridgev2/bridgeconfig/appservice.go +++ b/bridgev2/bridgeconfig/appservice.go @@ -34,6 +34,7 @@ type AppserviceConfig struct { EphemeralEvents bool `yaml:"ephemeral_events"` AsyncTransactions bool `yaml:"async_transactions"` + MSC4190 bool `yaml:"msc4190"` UsernameTemplate string `yaml:"username_template"` usernameTemplate *template.Template `yaml:"-"` @@ -77,6 +78,7 @@ func (asc *AppserviceConfig) copyToRegistration(registration *appservice.Registr registration.RateLimited = &falseVal registration.EphemeralEvents = asc.EphemeralEvents registration.SoruEphemeralEvents = asc.EphemeralEvents + registration.MSC4190 = asc.MSC4190 } // GenerateRegistration generates a registration file for the homeserver. diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 165d8332..776fa44d 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -82,6 +82,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Str, "appservice", "bot", "avatar") helper.Copy(up.Bool, "appservice", "ephemeral_events") helper.Copy(up.Bool, "appservice", "async_transactions") + helper.Copy(up.Bool, "appservice", "msc4190") helper.Copy(up.Str, "appservice", "as_token") helper.Copy(up.Str, "appservice", "hs_token") helper.Copy(up.Str, "appservice", "username_template") diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index 04654ff5..3cb16e52 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -241,23 +241,34 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool // Create a new client instance with the default AS settings (including as_token), // the Login call will then override the access token in the client. client := helper.bridge.AS.NewMautrixClient(helper.bridge.AS.BotMXID()) + + initialDeviceDisplayName := fmt.Sprintf("%s bridge", helper.bridge.Bridge.Network.GetName().DisplayName) + if helper.bridge.Config.AppService.MSC4190 { + helper.log.Debug().Msg("Creating bot device with MSC4190") + err = client.CreateDeviceMSC4190(ctx, deviceID, initialDeviceDisplayName) + if err != nil { + return nil, deviceID != "", fmt.Errorf("failed to create device for bridge bot: %w", err) + } + helper.store.DeviceID = client.DeviceID + return client, deviceID != "", nil + } + flows, err := client.GetLoginFlows(ctx) if err != nil { return nil, deviceID != "", fmt.Errorf("failed to get supported login flows: %w", err) } else if !flows.HasFlow(mautrix.AuthTypeAppservice) { return nil, deviceID != "", fmt.Errorf("homeserver does not support appservice login") } + resp, err := client.Login(ctx, &mautrix.ReqLogin{ Type: mautrix.AuthTypeAppservice, Identifier: mautrix.UserIdentifier{ Type: mautrix.IdentifierTypeUser, User: string(helper.bridge.AS.BotMXID()), }, - DeviceID: deviceID, - StoreCredentials: true, - - // TODO find proper bridge name - InitialDeviceDisplayName: "Megabridge", // fmt.Sprintf("%s bridge", helper.bridge.ProtocolName), + DeviceID: deviceID, + StoreCredentials: true, + InitialDeviceDisplayName: initialDeviceDisplayName, }) if err != nil { return nil, deviceID != "", fmt.Errorf("failed to log in as bridge bot: %w", err) diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index b8a00637..8f7655fc 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -193,6 +193,10 @@ appservice: # However, messages will not be guaranteed to be bridged in the same order they were sent in. # This value doesn't affect the registration file. async_transactions: false + # Whether to use MSC4190 instead of appservice login to create the bridge bot device. + # Requires the homeserver to support MSC4190 and the device masquerading parts of MSC3202. + # Only relevant when using end-to-bridge encryption, required when using encryption with next-gen auth (MSC3861). + msc4190: false # Authentication tokens for AS <-> HS communication. Autogenerated; do not modify. as_token: "This value is generated when generating the registration" diff --git a/client.go b/client.go index e7cafe4f..e8689708 100644 --- a/client.go +++ b/client.go @@ -21,6 +21,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/ptr" + "go.mau.fi/util/random" "go.mau.fi/util/retryafter" "golang.org/x/exp/maps" @@ -901,6 +902,22 @@ func (cli *Client) Login(ctx context.Context, req *ReqLogin) (resp *RespLogin, e return } +// Create a device for an appservice user using MSC4190. +func (cli *Client) CreateDeviceMSC4190(ctx context.Context, deviceID id.DeviceID, initialDisplayName string) error { + if len(deviceID) == 0 { + deviceID = id.DeviceID(strings.ToUpper(random.String(10))) + } + _, err := cli.MakeRequest(ctx, http.MethodPut, cli.BuildClientURL("v3", "devices", deviceID), &ReqPutDevice{ + DisplayName: initialDisplayName, + }, nil) + if err != nil { + return err + } + cli.DeviceID = deviceID + cli.SetAppServiceDeviceID = true + return nil +} + // Logout the current user. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3logout // This does not clear the credentials from the client instance. See ClearCredentials() instead. func (cli *Client) Logout(ctx context.Context) (resp *RespLogout, err error) { diff --git a/requests.go b/requests.go index 5611bd57..9e7eb0bd 100644 --- a/requests.go +++ b/requests.go @@ -91,6 +91,10 @@ type ReqLogin struct { StoreHomeserverURL bool `json:"-"` } +type ReqPutDevice struct { + DisplayName string `json:"display_name,omitempty"` +} + type ReqUIAuthFallback struct { Session string `json:"session"` User string `json:"user"` From 8d3c208bda64cef1f5cc7f44b70bd7b3337fd827 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 8 Dec 2024 22:04:08 +0200 Subject: [PATCH 0947/1647] crypto/verificationhelper: add from device parameter to requested callback --- crypto/verificationhelper/callbacks_test.go | 2 +- crypto/verificationhelper/verificationhelper.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/crypto/verificationhelper/callbacks_test.go b/crypto/verificationhelper/callbacks_test.go index 7b1055d1..5faf2009 100644 --- a/crypto/verificationhelper/callbacks_test.go +++ b/crypto/verificationhelper/callbacks_test.go @@ -77,7 +77,7 @@ func (c *baseVerificationCallbacks) GetDecimalsShown(txnID id.VerificationTransa return c.decimalsShown[txnID] } -func (c *baseVerificationCallbacks) VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) { +func (c *baseVerificationCallbacks) VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID) { c.verificationsRequested[from] = append(c.verificationsRequested[from], txnID) } diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 1a896b1d..de943976 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -30,7 +30,7 @@ import ( type RequiredCallbacks interface { // VerificationRequested is called when a verification request is received // from another device. - VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) + VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID) // VerificationCancelled is called when the verification is cancelled. VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) @@ -71,7 +71,7 @@ type VerificationHelper struct { // supportedMethods are the methods that *we* support supportedMethods []event.VerificationMethod - verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID) + verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID) verificationCancelledCallback func(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) verificationDone func(ctx context.Context, txnID id.VerificationTransactionID) @@ -658,7 +658,7 @@ func (vh *VerificationHelper) onVerificationRequest(ctx context.Context, evt *ev vh.activeTransactionsLock.Unlock() vh.expireTransactionAt(verificationRequest.TransactionID, verificationRequest.Timestamp.Add(time.Minute*10)) - vh.verificationRequested(ctx, verificationRequest.TransactionID, evt.Sender) + vh.verificationRequested(ctx, verificationRequest.TransactionID, evt.Sender, verificationRequest.FromDevice) } func (vh *VerificationHelper) expireTransactionAt(txnID id.VerificationTransactionID, expiresAt time.Time) { From 3312a581612a94f95c903f598f30524e29d72c12 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 8 Dec 2024 23:08:45 +0200 Subject: [PATCH 0948/1647] bridgev2/matrix: log type of interrupt --- bridgev2/matrix/mxmain/main.go | 2 ++ bridgev2/matrix/websocket.go | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index c1bfb9ee..2c0c07b9 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -398,8 +398,10 @@ func (br *BridgeMain) WaitForInterrupt() int { signal.Notify(c, os.Interrupt, syscall.SIGTERM) select { case <-c: + br.Log.Info().Msg("Interrupt signal received from OS") return 0 case exitCode := <-br.manualStop: + br.Log.Info().Msg("Internal stop signal received") return exitCode } } diff --git a/bridgev2/matrix/websocket.go b/bridgev2/matrix/websocket.go index 36b8bca4..c679f960 100644 --- a/bridgev2/matrix/websocket.go +++ b/bridgev2/matrix/websocket.go @@ -61,7 +61,7 @@ func (br *Connector) startWebsocket(wg *sync.WaitGroup) { if errors.Is(err, appservice.ErrWebsocketManualStop) { return } else if closeCommand := (&appservice.CloseCommand{}); errors.As(err, &closeCommand) && closeCommand.Status == appservice.MeowConnectionReplaced { - log.Info().Msg("Appservice websocket closed by another instance of the bridge, shutting down...") + log.Warn().Msg("Appservice websocket closed by another instance of the bridge, shutting down...") if br.OnWebsocketReplaced != nil { br.OnWebsocketReplaced() } else { From c15e0dba939cfc6d4f49d89876601679c952db38 Mon Sep 17 00:00:00 2001 From: Scott Weber Date: Mon, 9 Dec 2024 14:41:36 -0500 Subject: [PATCH 0949/1647] event/beeper: change suborder field to int16 (from int64) (#328) --- event/beeper.go | 33 +++++++++++++++++---------------- event/events.go | 2 +- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/event/beeper.go b/event/beeper.go index 8af9e0d0..7ea0d068 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -111,10 +111,10 @@ type BeeperPerMessageProfile struct { type BeeperEncodedOrder struct { order int64 - suborder int64 + suborder int16 } -func NewBeeperEncodedOrder(order int64, suborder int64) *BeeperEncodedOrder { +func NewBeeperEncodedOrder(order int64, suborder int16) *BeeperEncodedOrder { return &BeeperEncodedOrder{order: order, suborder: suborder} } @@ -133,7 +133,7 @@ func (b *BeeperEncodedOrder) String() string { return encodeIntPair(b.order, b.suborder) } -func (b *BeeperEncodedOrder) OrderPair() (int64, int64) { +func (b *BeeperEncodedOrder) OrderPair() (int64, int16) { if b == nil { return 0, 0 } @@ -164,12 +164,13 @@ func (b *BeeperEncodedOrder) UnmarshalJSON(data []byte) error { return nil } -// encodeIntPair encodes two int64 integers into a lexicographically sortable string -func encodeIntPair(a, b int64) string { +// encodeIntPair encodes an int64 and an int16 into a lexicographically sortable string +func encodeIntPair(a int64, b int16) string { // Create a buffer to hold the binary representation of the integers. - var buf [16]byte + // Will need 8 bytes for the int64 and 2 bytes for the int16. + var buf [10]byte - // Flip the sign bit of each integer to map the entire int64 range to uint64 + // Flip the sign bit of each integer to map the entire int range to uint // in a way that preserves the order of the original integers. // // Explanation: @@ -178,7 +179,7 @@ func encodeIntPair(a, b int64) string { // - Non-negative numbers (with a sign bit of 0) become larger uint64 values. // - This mapping preserves the original ordering when the uint64 values are compared. binary.BigEndian.PutUint64(buf[0:8], uint64(a)^(1<<63)) - binary.BigEndian.PutUint64(buf[8:16], uint64(b)^(1<<63)) + binary.BigEndian.PutUint16(buf[8:10], uint16(b)^(1<<15)) // Encode the buffer into a Base32 string without padding using the Hex encoding. // @@ -191,8 +192,8 @@ func encodeIntPair(a, b int64) string { return encoded } -// decodeIntPair decodes a string produced by encodeIntPair back into the original int64 integers -func decodeIntPair(encoded string) (int64, int64, error) { +// decodeIntPair decodes a string produced by encodeIntPair back into the original int64 and int16 values +func decodeIntPair(encoded string) (int64, int16, error) { // Decode the Base32 string back into the original byte buffer. buf, err := base32.HexEncoding.WithPadding(base32.NoPadding).DecodeString(encoded) if err != nil { @@ -200,17 +201,17 @@ func decodeIntPair(encoded string) (int64, int64, error) { } // Check that the decoded buffer has the expected length. - if len(buf) != 16 { - return 0, 0, fmt.Errorf("invalid encoded string length: expected 16 bytes, got %d", len(buf)) + if len(buf) != 10 { + return 0, 0, fmt.Errorf("invalid encoded string length: expected 10 bytes, got %d", len(buf)) } - // Read the uint64 values from the buffer using big-endian byte order. + // Read the uint values from the buffer using big-endian byte order. aPos := binary.BigEndian.Uint64(buf[0:8]) - bPos := binary.BigEndian.Uint64(buf[8:16]) + bPos := binary.BigEndian.Uint16(buf[8:10]) - // Reverse the sign bit flip to retrieve the original int64 values. + // Reverse the sign bit flip to retrieve the original values. a := int64(aPos ^ (1 << 63)) - b := int64(bPos ^ (1 << 63)) + b := int16(bPos ^ (1 << 15)) return a, b, nil } diff --git a/event/events.go b/event/events.go index 56104123..1c173351 100644 --- a/event/events.go +++ b/event/events.go @@ -145,7 +145,7 @@ type Unsigned struct { InviteRoomState []StrippedState `json:"invite_room_state,omitempty"` BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"` - BeeperHSSuborder int64 `json:"com.beeper.hs.suborder,omitempty"` + BeeperHSSuborder int16 `json:"com.beeper.hs.suborder,omitempty"` BeeperHSOrderString *BeeperEncodedOrder `json:"com.beeper.hs.order_string,omitempty"` BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"` } From 58a30b295818fbdd7e85bb917502845bbc82864c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 11 Dec 2024 17:39:30 +0200 Subject: [PATCH 0950/1647] bridgev2/commands: redact login commands with direct parameters --- bridgev2/commands/login.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index 660c90d7..717ac194 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -105,6 +105,7 @@ func checkLoginCommandDirectParams(ce *Event, login bridgev2.LoginProcess, nextS return nil } input := make(map[string]string) + var shouldRedact bool for i, param := range nextStep.UserInputParams.Fields { param.FillDefaultValidate() input[param.ID], err = param.Validate(ce.Args[i]) @@ -112,6 +113,12 @@ func checkLoginCommandDirectParams(ce *Event, login bridgev2.LoginProcess, nextS ce.Reply("Invalid value for %s: %v", param.Name, err) return nil } + if param.Type == bridgev2.LoginInputFieldTypePassword || param.Type == bridgev2.LoginInputFieldTypeToken { + shouldRedact = true + } + } + if shouldRedact { + ce.Redact() } nextStep, err = login.(bridgev2.LoginProcessUserInput).SubmitUserInput(ce.Ctx, input) case bridgev2.LoginStepTypeCookies: @@ -128,6 +135,7 @@ func checkLoginCommandDirectParams(ce *Event, login bridgev2.LoginProcess, nextS } input[param.ID] = val } + ce.Redact() nextStep, err = login.(bridgev2.LoginProcessCookies).SubmitCookies(ce.Ctx, input) } if err != nil { @@ -162,7 +170,7 @@ func (uilcs *userInputLoginCommandState) promptNext(ce *Event) { func (uilcs *userInputLoginCommandState) submitNext(ce *Event) { field := uilcs.RemainingFields[0] field.FillDefaultValidate() - if field.Type == bridgev2.LoginInputFieldTypePassword { + if field.Type == bridgev2.LoginInputFieldTypePassword || field.Type == bridgev2.LoginInputFieldTypeToken { ce.Redact() } var err error From fb0563e5ca4c899459868fe803876d13815516ea Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 11 Dec 2024 17:43:01 +0200 Subject: [PATCH 0951/1647] bridgev2/login: add reauth hook in login process --- bridgev2/commands/login.go | 76 +++++++++++++++++++++++++++------ bridgev2/commands/processor.go | 2 +- bridgev2/login.go | 11 +++++ bridgev2/matrix/provisioning.go | 8 +++- 4 files changed, 81 insertions(+), 16 deletions(-) diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index 717ac194..5c7ae57d 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -19,6 +19,7 @@ import ( "github.com/skip2/go-qrcode" "go.mau.fi/util/curl" + "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" @@ -36,6 +37,17 @@ var CommandLogin = &FullHandler{ RequiresLoginPermission: true, } +var CommandRelogin = &FullHandler{ + Func: fnLogin, + Name: "relogin", + Help: HelpMeta{ + Section: HelpSectionAuth, + Description: "Re-authenticate an existing login", + Args: "<_login ID_> [_flow ID_]", + }, + RequiresLoginPermission: true, +} + func formatFlowsReply(flows []bridgev2.LoginFlow) string { var buf strings.Builder for _, flow := range flows { @@ -45,6 +57,19 @@ func formatFlowsReply(flows []bridgev2.LoginFlow) string { } func fnLogin(ce *Event) { + var reauth *bridgev2.UserLogin + if ce.Command == "relogin" { + if len(ce.Args) == 0 { + ce.Reply("Usage: `$cmdprefix relogin [_flow ID_]`\n\nYour logins:\n\n%s", ce.User.GetFormattedUserLogins()) + return + } + reauth = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + if reauth == nil { + ce.Reply("Login `%s` not found", ce.Args[0]) + return + } + ce.Args = ce.Args[1:] + } flows := ce.Bridge.Network.GetLoginFlows() var chosenFlowID string if len(ce.Args) > 0 { @@ -63,7 +88,11 @@ func fnLogin(ce *Event) { } else if len(flows) == 1 { chosenFlowID = flows[0].ID } else { - ce.Reply("Please specify a login flow, e.g. `login %s`.\n\n%s", flows[0].ID, formatFlowsReply(flows)) + if reauth != nil { + ce.Reply("Please specify a login flow, e.g. `relogin %s %s`.\n\n%s", reauth.ID, flows[0].ID, formatFlowsReply(flows)) + } else { + ce.Reply("Please specify a login flow, e.g. `login %s`.\n\n%s", flows[0].ID, formatFlowsReply(flows)) + } return } @@ -72,7 +101,13 @@ func fnLogin(ce *Event) { ce.Reply("Failed to prepare login process: %v", err) return } - nextStep, err := login.Start(ce.Ctx) + overridable, ok := login.(bridgev2.LoginProcessWithOverride) + var nextStep *bridgev2.LoginStep + if ok && reauth != nil { + nextStep, err = overridable.StartWithOverride(ce.Ctx, reauth) + } else { + nextStep, err = login.Start(ce.Ctx) + } if err != nil { ce.Reply("Failed to start login: %v", err) return @@ -80,7 +115,7 @@ func fnLogin(ce *Event) { nextStep = checkLoginCommandDirectParams(ce, login, nextStep) if nextStep != nil { - doLoginStep(ce, login, nextStep) + doLoginStep(ce, login, nextStep, reauth) } } @@ -150,6 +185,7 @@ type userInputLoginCommandState struct { Login bridgev2.LoginProcessUserInput Data map[string]string RemainingFields []bridgev2.LoginInputDataField + Override *bridgev2.UserLogin } func (uilcs *userInputLoginCommandState) promptNext(ce *Event) { @@ -187,7 +223,7 @@ func (uilcs *userInputLoginCommandState) submitNext(ce *Event) { if nextStep, err := uilcs.Login.SubmitUserInput(ce.Ctx, uilcs.Data); err != nil { ce.Reply("Failed to submit input: %v", err) } else { - doLoginStep(ce, uilcs.Login, nextStep) + doLoginStep(ce, uilcs.Login, nextStep, uilcs.Override) } } @@ -231,7 +267,7 @@ const ( contextKeyPrevEventID contextKey = iota ) -func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, step *bridgev2.LoginStep) { +func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, step *bridgev2.LoginStep, override *bridgev2.UserLogin) { prevEvent, ok := ce.Ctx.Value(contextKeyPrevEventID).(*id.EventID) if !ok { prevEvent = new(id.EventID) @@ -270,12 +306,13 @@ func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, ce.Reply("Login failed: %v", err) return } - doLoginStep(ce, login, nextStep) + doLoginStep(ce, login, nextStep, override) } type cookieLoginCommandState struct { - Login bridgev2.LoginProcessCookies - Data *bridgev2.LoginCookiesParams + Login bridgev2.LoginProcessCookies + Data *bridgev2.LoginCookiesParams + Override *bridgev2.UserLogin } func (clcs *cookieLoginCommandState) prompt(ce *Event) { @@ -387,7 +424,7 @@ func (clcs *cookieLoginCommandState) submit(ce *Event) { ce.Reply("Login failed: %v", err) return } - doLoginStep(ce, clcs.Login, nextStep) + doLoginStep(ce, clcs.Login, nextStep, clcs.Override) } func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string { @@ -407,27 +444,38 @@ func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string { return decoded } -func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep) { +func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep, override *bridgev2.UserLogin) { if step.Instructions != "" { ce.Reply(step.Instructions) } switch step.Type { case bridgev2.LoginStepTypeDisplayAndWait: - doLoginDisplayAndWait(ce, login.(bridgev2.LoginProcessDisplayAndWait), step) + doLoginDisplayAndWait(ce, login.(bridgev2.LoginProcessDisplayAndWait), step, override) case bridgev2.LoginStepTypeCookies: (&cookieLoginCommandState{ - Login: login.(bridgev2.LoginProcessCookies), - Data: step.CookiesParams, + Login: login.(bridgev2.LoginProcessCookies), + Data: step.CookiesParams, + Override: override, }).prompt(ce) case bridgev2.LoginStepTypeUserInput: (&userInputLoginCommandState{ Login: login.(bridgev2.LoginProcessUserInput), RemainingFields: step.UserInputParams.Fields, Data: make(map[string]string), + Override: override, }).promptNext(ce) case bridgev2.LoginStepTypeComplete: - // Nothing to do other than instructions + if override != nil && override.ID != step.CompleteParams.UserLoginID { + ce.Log.Info(). + Str("old_login_id", string(override.ID)). + Str("new_login_id", string(step.CompleteParams.UserLoginID)). + Msg("Login resulted in different remote ID than what was being overridden. Deleting previous login") + override.Delete(ce.Ctx, status.BridgeState{ + StateEvent: status.StateLoggedOut, + Reason: "LOGIN_OVERRIDDEN", + }, bridgev2.DeleteOpts{LogoutRemote: true}) + } default: panic(fmt.Errorf("unknown login step type %q", step.Type)) } diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index 1aca596c..3343e1ba 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -43,7 +43,7 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor { proc.AddHandlers( CommandHelp, CommandCancel, CommandRegisterPush, CommandDeletePortal, CommandDeleteAllPortals, - CommandLogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, + CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, CommandSetRelay, CommandUnsetRelay, CommandResolveIdentifier, CommandStartChat, CommandSearch, CommandSudo, CommandDoIn, diff --git a/bridgev2/login.go b/bridgev2/login.go index 9e8be655..b28ccfdb 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -32,6 +32,17 @@ type LoginProcess interface { Cancel() } +type LoginProcessWithOverride interface { + LoginProcess + // StartWithOverride starts the process with the intent of re-authenticating an existing login. + // + // The call to this is mutually exclusive with the call to the default Start method. + // + // The user login being overridden will still be logged out automatically + // in case the complete step returns a different login. + StartWithOverride(ctx context.Context, override *UserLogin) (*LoginStep, error) +} + type LoginProcessDisplayAndWait interface { LoginProcess Wait(ctx context.Context) (*LoginStep, error) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 51465d05..87f6576d 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -411,7 +411,13 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque RespondWithError(w, err, "Internal error creating login process") return } - firstStep, err := login.Start(r.Context()) + var firstStep *bridgev2.LoginStep + overridable, ok := login.(bridgev2.LoginProcessWithOverride) + if ok && overrideLogin != nil { + firstStep, err = overridable.StartWithOverride(r.Context(), overrideLogin) + } else { + firstStep, err = login.Start(r.Context()) + } if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to start login") RespondWithError(w, err, "Internal error starting login") From 48b7b3aca5b76810e329ca4696a4acfe142121e1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 11 Dec 2024 22:56:11 +0200 Subject: [PATCH 0952/1647] bridgev2/portal: don't try to backfill spaces --- bridgev2/portal.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 45668aa3..b1aae9e7 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2639,7 +2639,7 @@ func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLo } } backfillChecker, ok := evt.(RemoteChatResyncBackfill) - if portal.Bridge.Config.Backfill.Enabled && ok { + if portal.Bridge.Config.Backfill.Enabled && ok && portal.RoomType != database.RoomTypeSpace { latestMessage, err := portal.Bridge.DB.Message.GetLastPartAtOrBeforeTime(ctx, portal.PortalKey, time.Now().Add(10*time.Second)) if err != nil { log.Err(err).Msg("Failed to get last message in portal to check if backfill is necessary") @@ -3636,7 +3636,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo log.Err(err).Msg("Failed to save portal to database after creating Matrix room") return err } - if info.CanBackfill { + if info.CanBackfill && portal.RoomType != database.RoomTypeSpace { err = portal.Bridge.DB.BackfillTask.Upsert(ctx, &database.BackfillTask{ PortalKey: portal.PortalKey, UserLoginID: source.ID, @@ -3700,7 +3700,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } } } - if portal.Bridge.Config.Backfill.Enabled { + if portal.Bridge.Config.Backfill.Enabled && portal.RoomType != database.RoomTypeSpace { portal.doForwardBackfill(ctx, source, nil, backfillBundle) } return nil From bfdd0efd0e2ef931cb3e9a0642cbab05b5d9fed9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 12 Dec 2024 02:46:32 +0200 Subject: [PATCH 0953/1647] dependencies: update --- go.mod | 20 ++++++++++---------- go.sum | 36 ++++++++++++++++++------------------ 2 files changed, 28 insertions(+), 28 deletions(-) diff --git a/go.mod b/go.mod index 8bf9baac..965e6558 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module maunium.net/go/mautrix go 1.22.0 -toolchain go1.23.3 +toolchain go1.23.4 require ( filippo.io/edwards25519 v1.1.0 @@ -14,16 +14,16 @@ require ( github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.33.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.10.0 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.2 + go.mau.fi/util v0.8.3-0.20241212004537-24c1a9b1d8f6 go.mau.fi/zeroconfig v0.1.3 - golang.org/x/crypto v0.29.0 - golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f - golang.org/x/net v0.31.0 - golang.org/x/sync v0.9.0 + golang.org/x/crypto v0.31.0 + golang.org/x/exp v0.0.0-20241210194714-1829a127f884 + golang.org/x/net v0.32.0 + golang.org/x/sync v0.10.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) @@ -33,11 +33,11 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect - github.com/petermattis/goid v0.0.0-20241025130422-66cb2e6d7274 // indirect + github.com/petermattis/goid v0.0.0-20241211131331-93ee7e083c43 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.27.0 // indirect - golang.org/x/text v0.20.0 // indirect + golang.org/x/sys v0.28.0 // indirect + golang.org/x/text v0.21.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 205cbfaf..7dfcbbd8 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,8 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/petermattis/goid v0.0.0-20241025130422-66cb2e6d7274 h1:qli3BGQK0tYDkSEvZ/FzZTi9ZrOX86Q6CIhKLGc489A= -github.com/petermattis/goid v0.0.0-20241025130422-66cb2e6d7274/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/petermattis/goid v0.0.0-20241211131331-93ee7e083c43 h1:ah1dvbqPMN5+ocrg/ZSgZ6k8bOk+kcZQ7fnyx6UvOm4= +github.com/petermattis/goid v0.0.0-20241211131331-93ee7e083c43/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -38,8 +38,8 @@ github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -51,26 +51,26 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.2 h1:zWbVHwdRKwI6U9AusmZ8bwgcLosikwbb4GGqLrNr1YE= -go.mau.fi/util v0.8.2/go.mod h1:BHHC9R2WLMJd1bwTZfTcFxUgRFmUgUmiWcT4RbzUgiA= +go.mau.fi/util v0.8.3-0.20241212004537-24c1a9b1d8f6 h1:s4aQJQBBMkjkHTdThRRNp2E35wqmMJVGkOAXG4b0X8c= +go.mau.fi/util v0.8.3-0.20241212004537-24c1a9b1d8f6/go.mod h1:sWpI/kFgk/QP4BDJJwSjjBsuAT7oUsj9VlFhvUo+34I= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= -golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= -golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f h1:XdNn9LlyWAhLVp6P/i8QYBW+hlyhrhei9uErw2B5GJo= -golang.org/x/exp v0.0.0-20241108190413-2d47ceb2692f/go.mod h1:D5SMRVC3C2/4+F/DB1wZsLRnSNimn2Sp/NPsCrsv8ak= -golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= -golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= -golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= -golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= +golang.org/x/exp v0.0.0-20241210194714-1829a127f884 h1:Y/Mj/94zIQQGHVSv1tTtQBDaQaJe62U9bkDZKKyhPCU= +golang.org/x/exp v0.0.0-20241210194714-1829a127f884/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c= +golang.org/x/net v0.32.0 h1:ZqPmj8Kzc+Y6e0+skZsuACbx+wzMgo5MQsJh9Qd6aYI= +golang.org/x/net v0.32.0/go.mod h1:CwU0IoeOlnQQWJ6ioyFrfRuomB8GKF6KbYXZVyeXNfs= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= -golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= -golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From 351b49f8a893350508922953663c0fb3869edb14 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 15 Dec 2024 12:31:54 +0200 Subject: [PATCH 0954/1647] bridgev2/matrixinvite: add separate interface for creating DM with ghost --- bridgev2/errors.go | 5 +++-- bridgev2/matrixinvite.go | 21 +++++++++++++++------ bridgev2/networkinterface.go | 13 +++++++++++++ 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index 052a606b..789d0026 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -22,8 +22,9 @@ var ErrIgnoringRemoteEvent = errors.New("ignoring remote event") // and a status should not be sent yet. The message will still be saved into the database. var ErrNoStatus = errors.New("omit message status") -// ErrResolveIdentifierTryNext can be returned by ResolveIdentifier to signal that the identifier is valid, -// but can't be reached by the current login, and the caller should try the next login if there are more. +// ErrResolveIdentifierTryNext can be returned by ResolveIdentifier or CreateChatWithGhost to signal that +// the identifier is valid, but can't be reached by the current login, and the caller should try the next +// login if there are more. // // This should generally only be returned when resolving internal IDs (which happens when initiating chats via Matrix). // For example, Google Messages would return this when trying to resolve another login's user ID, diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go index 25938c4f..f8217700 100644 --- a/bridgev2/matrixinvite.go +++ b/bridgev2/matrixinvite.go @@ -119,7 +119,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen log.Err(err).Msg("Failed to accept invite to room") return } - var resp *ResolveIdentifierResponse + var resp *CreateChatResponse var sourceLogin *UserLogin // TODO this should somehow lock incoming event processing to avoid race conditions where a new portal room is created // between ResolveIdentifier returning and the portal MXID being updated. @@ -128,7 +128,16 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen if !ok { continue } - resp, err = api.ResolveIdentifier(ctx, string(ghostID), true) + var resolveResp *ResolveIdentifierResponse + ghostAPI, ok := login.Client.(GhostDMCreatingNetworkAPI) + if ok { + resp, err = ghostAPI.CreateChatWithGhost(ctx, invitedGhost) + } else { + resolveResp, err = api.ResolveIdentifier(ctx, string(ghostID), true) + if resolveResp != nil { + resp = resolveResp.Chat + } + } if errors.Is(err, ErrResolveIdentifierTryNext) { log.Debug().Err(err).Str("login_id", string(login.ID)).Msg("Failed to resolve identifier, trying next login") continue @@ -146,9 +155,9 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen sendErrorAndLeave(ctx, evt, br.Matrix.GhostIntent(ghostID), "Failed to create chat via any login") return } - portal := resp.Chat.Portal + portal := resp.Portal if portal == nil { - portal, err = br.GetPortalByKey(ctx, resp.Chat.PortalKey) + portal, err = br.GetPortalByKey(ctx, resp.PortalKey) if err != nil { log.Err(err).Msg("Failed to get portal by key") sendErrorAndLeave(ctx, evt, br.Matrix.GhostIntent(ghostID), "Failed to create portal entry") @@ -169,8 +178,8 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen } didSetPortal := portal.setMXIDToExistingRoom(evt.RoomID) - if resp.Chat.PortalInfo != nil { - portal.UpdateInfo(ctx, resp.Chat.PortalInfo, sourceLogin, nil, time.Time{}) + if resp.PortalInfo != nil { + portal.UpdateInfo(ctx, resp.PortalInfo, sourceLogin, nil, time.Time{}) } if didSetPortal { // TODO this might become unnecessary if UpdateInfo starts taking care of it diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 9acc6eae..8ddf1269 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -246,6 +246,9 @@ type DirectMediableNetwork interface { Download(ctx context.Context, mediaID networkid.MediaID, params map[string]string) (mediaproxy.GetMediaResponse, error) } +// IdentifierValidatingNetwork is an optional interface that network connectors can implement to validate the shape of user IDs. +// +// This should not perform any checks to see if the user ID actually exists on the network, just that the user ID looks valid. type IdentifierValidatingNetwork interface { NetworkConnector ValidateUserID(id networkid.UserID) bool @@ -662,6 +665,16 @@ type IdentifierResolvingNetworkAPI interface { ResolveIdentifier(ctx context.Context, identifier string, createChat bool) (*ResolveIdentifierResponse, error) } +// GhostDMCreatingNetworkAPI is an optional extension to IdentifierResolvingNetworkAPI for starting chats with pre-validated user IDs. +type GhostDMCreatingNetworkAPI interface { + IdentifierResolvingNetworkAPI + // CreateChatWithGhost may be called instead of [IdentifierResolvingNetworkAPI.ResolveIdentifier] + // when starting a chat with an internal user identifier that has been pre-validated using + // [IdentifierValidatingNetwork.ValidateUserID]. If this is not implemented, ResolveIdentifier + // will be used instead (by stringifying the ghost ID). + CreateChatWithGhost(ctx context.Context, ghost *Ghost) (*CreateChatResponse, error) +} + // ContactListingNetworkAPI is an optional interface that network connectors can implement to provide the user's contact list. type ContactListingNetworkAPI interface { NetworkAPI From 2dcf2f9244a5591cb08ae567eaaca2f55e4dbbfd Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 16 Dec 2024 15:59:29 +0200 Subject: [PATCH 0955/1647] Bump version to v0.22.1 --- CHANGELOG.md | 23 +++++++++++++++++++++++ go.mod | 4 ++-- go.sum | 8 ++++---- version.go | 2 +- 4 files changed, 30 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 18d96b35..0cc96b60 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,26 @@ +## v0.22.1 (2024-12-16) + +* *(crypto)* Added automatic cleanup when there are too many olm sessions with + a single device. +* *(crypto)* Added helper for getting cached device list with cross-signing + status. +* *(crypto/verificationhelper)* Added interface for persisting the state of + in-progress verifications. +* *(client)* Added `GetMutualRooms` wrapper for [MSC2666]. +* *(client)* Switched `JoinRoom` to use the `via` query param instead of + `server_name` as per [MSC4156]. +* *(bridgev2/commands)* Fixed `pm` command not actually starting the chat. +* *(bridgev2/interface)* Added separate network API interface for starting + chats with a Matrix ghost user. This allows treating internal user IDs + differently than arbitrary user-input strings. +* *(bridgev2/crypto)* Added support for [MSC4190] + (thanks to [@onestacked] in [#288]). + +[MSC2666]: https://github.com/matrix-org/matrix-spec-proposals/pull/2666 +[MSC4156]: https://github.com/matrix-org/matrix-spec-proposals/pull/4156 +[MSC4190]: https://github.com/matrix-org/matrix-spec-proposals/pull/4190 +[#288]: https://github.com/mautrix/go/pull/288 + ## v0.22.0 (2024-11-16) * *(hicli)* Moved package into gomuks repo. diff --git a/go.mod b/go.mod index 965e6558..c686489a 100644 --- a/go.mod +++ b/go.mod @@ -18,10 +18,10 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.3-0.20241212004537-24c1a9b1d8f6 + go.mau.fi/util v0.8.3 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.31.0 - golang.org/x/exp v0.0.0-20241210194714-1829a127f884 + golang.org/x/exp v0.0.0-20241215155358-4a5509556b9e golang.org/x/net v0.32.0 golang.org/x/sync v0.10.0 gopkg.in/yaml.v3 v3.0.1 diff --git a/go.sum b/go.sum index 7dfcbbd8..47ac74ef 100644 --- a/go.sum +++ b/go.sum @@ -51,14 +51,14 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.3-0.20241212004537-24c1a9b1d8f6 h1:s4aQJQBBMkjkHTdThRRNp2E35wqmMJVGkOAXG4b0X8c= -go.mau.fi/util v0.8.3-0.20241212004537-24c1a9b1d8f6/go.mod h1:sWpI/kFgk/QP4BDJJwSjjBsuAT7oUsj9VlFhvUo+34I= +go.mau.fi/util v0.8.3 h1:sulhXtfquMrQjsOP67x9CzWVBYUwhYeoo8hNQIpCWZ4= +go.mau.fi/util v0.8.3/go.mod h1:c00Db8xog70JeIsEvhdHooylTkTkakgnAOsZ04hplQY= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/exp v0.0.0-20241210194714-1829a127f884 h1:Y/Mj/94zIQQGHVSv1tTtQBDaQaJe62U9bkDZKKyhPCU= -golang.org/x/exp v0.0.0-20241210194714-1829a127f884/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c= +golang.org/x/exp v0.0.0-20241215155358-4a5509556b9e h1:4qufH0hlUYs6AO6XmZC3GqfDPGSXHVXUFR6OND+iJX4= +golang.org/x/exp v0.0.0-20241215155358-4a5509556b9e/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c= golang.org/x/net v0.32.0 h1:ZqPmj8Kzc+Y6e0+skZsuACbx+wzMgo5MQsJh9Qd6aYI= golang.org/x/net v0.32.0/go.mod h1:CwU0IoeOlnQQWJ6ioyFrfRuomB8GKF6KbYXZVyeXNfs= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= diff --git a/version.go b/version.go index dd70d55b..362a684b 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.22.0" +const Version = "v0.22.1" var GoModVersion = "" var Commit = "" From 15ab545e728e2d229027978da444a0d4cc02e24d Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 16 Dec 2024 15:35:15 +0000 Subject: [PATCH 0956/1647] crypto: add background context to olm machine Defaults to `context.Background()` but can be passed any context to support cancelling background jobs the olm instance might be executing. --- crypto/decryptolm.go | 2 +- crypto/machine.go | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index 8f1eb1f7..788bf832 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -265,7 +265,7 @@ const MinUnwedgeInterval = 1 * time.Hour func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, senderKey id.SenderKey) { log = log.With().Str("action", "unwedge olm session").Logger() - ctx := log.WithContext(context.TODO()) + ctx := log.WithContext(mach.BackgroundCtx) mach.recentlyUnwedgedLock.Lock() prevUnwedge, ok := mach.recentlyUnwedged[senderKey] delta := time.Now().Sub(prevUnwedge) diff --git a/crypto/machine.go b/crypto/machine.go index 7c1093f3..95d86d42 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -33,6 +33,8 @@ type OlmMachine struct { CryptoStore Store StateStore StateStore + BackgroundCtx context.Context + PlaintextMentions bool // Never ask the server for keys automatically as a side effect during Megolm decryption. @@ -112,6 +114,8 @@ func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Stor CryptoStore: cryptoStore, StateStore: stateStore, + BackgroundCtx: context.Background(), + SendKeysMinTrust: id.TrustStateUnset, ShareKeysMinTrust: id.TrustStateCrossSignedTOFU, From 742af7f70b71c9aec289fd6dd6a0af73ece86d79 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 13 Dec 2024 14:53:24 -0700 Subject: [PATCH 0957/1647] status/bridgestate: add IsValid function Signed-off-by: Sumner Evans --- bridge/status/bridgestate.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/bridge/status/bridgestate.go b/bridge/status/bridgestate.go index 1aa4bb1f..73410df6 100644 --- a/bridge/status/bridgestate.go +++ b/bridge/status/bridgestate.go @@ -53,6 +53,26 @@ const ( StateLoggedOut BridgeStateEvent = "LOGGED_OUT" ) +func (e BridgeStateEvent) IsValid() bool { + switch e { + case + StateStarting, + StateUnconfigured, + StateRunning, + StateBridgeUnreachable, + StateConnecting, + StateBackfilling, + StateConnected, + StateTransientDisconnect, + StateBadCredentials, + StateUnknownError, + StateLoggedOut: + return true + default: + return false + } +} + type RemoteProfile struct { Phone string `json:"phone,omitempty"` Email string `json:"email,omitempty"` From 6bf21f101946c3d74007ddd9326eb45a5fda758d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 17 Dec 2024 15:44:08 +0200 Subject: [PATCH 0958/1647] bridgev2/config: move msc4190 flag to encryption section --- bridgev2/bridgeconfig/appservice.go | 9 +++++++-- bridgev2/bridgeconfig/encryption.go | 1 + bridgev2/bridgeconfig/upgrade.go | 6 +++++- bridgev2/matrix/crypto.go | 2 +- bridgev2/matrix/mxmain/example-config.yaml | 10 ++++++---- 5 files changed, 20 insertions(+), 8 deletions(-) diff --git a/bridgev2/bridgeconfig/appservice.go b/bridgev2/bridgeconfig/appservice.go index 89ce5677..f709c8e0 100644 --- a/bridgev2/bridgeconfig/appservice.go +++ b/bridgev2/bridgeconfig/appservice.go @@ -34,7 +34,6 @@ type AppserviceConfig struct { EphemeralEvents bool `yaml:"ephemeral_events"` AsyncTransactions bool `yaml:"async_transactions"` - MSC4190 bool `yaml:"msc4190"` UsernameTemplate string `yaml:"username_template"` usernameTemplate *template.Template `yaml:"-"` @@ -78,7 +77,11 @@ func (asc *AppserviceConfig) copyToRegistration(registration *appservice.Registr registration.RateLimited = &falseVal registration.EphemeralEvents = asc.EphemeralEvents registration.SoruEphemeralEvents = asc.EphemeralEvents - registration.MSC4190 = asc.MSC4190 +} + +func (ec *EncryptionConfig) applyUnstableFlags(registration *appservice.Registration) { + registration.MSC4190 = ec.MSC4190 + registration.MSC3202 = ec.Appservice } // GenerateRegistration generates a registration file for the homeserver. @@ -87,6 +90,7 @@ func (config *Config) GenerateRegistration() *appservice.Registration { config.AppService.HSToken = registration.ServerToken config.AppService.ASToken = registration.AppToken config.AppService.copyToRegistration(registration) + config.Encryption.applyUnstableFlags(registration) registration.SenderLocalpart = random.String(32) botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$", @@ -105,6 +109,7 @@ func (config *Config) MakeAppService() *appservice.AppService { as.Host.Hostname = config.AppService.Hostname as.Host.Port = config.AppService.Port as.Registration = config.AppService.GetRegistration() + config.Encryption.applyUnstableFlags(as.Registration) return as } diff --git a/bridgev2/bridgeconfig/encryption.go b/bridgev2/bridgeconfig/encryption.go index 93a427d3..1ef7e18f 100644 --- a/bridgev2/bridgeconfig/encryption.go +++ b/bridgev2/bridgeconfig/encryption.go @@ -15,6 +15,7 @@ type EncryptionConfig struct { Default bool `yaml:"default"` Require bool `yaml:"require"` Appservice bool `yaml:"appservice"` + MSC4190 bool `yaml:"msc4190"` PlaintextMentions bool `yaml:"plaintext_mentions"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 776fa44d..17c4af13 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -82,7 +82,6 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Str, "appservice", "bot", "avatar") helper.Copy(up.Bool, "appservice", "ephemeral_events") helper.Copy(up.Bool, "appservice", "async_transactions") - helper.Copy(up.Bool, "appservice", "msc4190") helper.Copy(up.Str, "appservice", "as_token") helper.Copy(up.Str, "appservice", "hs_token") helper.Copy(up.Str, "appservice", "username_template") @@ -147,6 +146,11 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "encryption", "default") helper.Copy(up.Bool, "encryption", "require") helper.Copy(up.Bool, "encryption", "appservice") + if val, ok := helper.Get(up.Bool, "appservice", "msc4190"); ok { + helper.Set(up.Bool, val, "encryption", "msc4190") + } else { + helper.Copy(up.Bool, "encryption", "msc4190") + } helper.Copy(up.Bool, "encryption", "allow_key_sharing") if secret, ok := helper.Get(up.Str, "encryption", "pickle_key"); !ok || secret == "generate" { helper.Set(up.Str, random.String(64), "encryption", "pickle_key") diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index 3cb16e52..df6f7a63 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -243,7 +243,7 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool client := helper.bridge.AS.NewMautrixClient(helper.bridge.AS.BotMXID()) initialDeviceDisplayName := fmt.Sprintf("%s bridge", helper.bridge.Bridge.Network.GetName().DisplayName) - if helper.bridge.Config.AppService.MSC4190 { + if helper.bridge.Config.Encryption.MSC4190 { helper.log.Debug().Msg("Creating bot device with MSC4190") err = client.CreateDeviceMSC4190(ctx, deviceID, initialDeviceDisplayName) if err != nil { diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 8f7655fc..d09047c3 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -193,10 +193,6 @@ appservice: # However, messages will not be guaranteed to be bridged in the same order they were sent in. # This value doesn't affect the registration file. async_transactions: false - # Whether to use MSC4190 instead of appservice login to create the bridge bot device. - # Requires the homeserver to support MSC4190 and the device masquerading parts of MSC3202. - # Only relevant when using end-to-bridge encryption, required when using encryption with next-gen auth (MSC3861). - msc4190: false # Authentication tokens for AS <-> HS communication. Autogenerated; do not modify. as_token: "This value is generated when generating the registration" @@ -343,7 +339,13 @@ encryption: require: false # Whether to use MSC2409/MSC3202 instead of /sync long polling for receiving encryption-related data. # This option is not yet compatible with standard Matrix servers like Synapse and should not be used. + # Changing this option requires updating the appservice registration file. appservice: false + # Whether to use MSC4190 instead of appservice login to create the bridge bot device. + # Requires the homeserver to support MSC4190 and the device masquerading parts of MSC3202. + # Only relevant when using end-to-bridge encryption, required when using encryption with next-gen auth (MSC3861). + # Changing this option requires updating the appservice registration file. + msc4190: false # Enable key sharing? If enabled, key requests for rooms where users are in will be fulfilled. # You must use a client that supports requesting keys from other users to use this feature. allow_key_sharing: true From 513bae79b83c9a606fd475d626c28d366d79c033 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 17 Dec 2024 21:02:04 +0200 Subject: [PATCH 0959/1647] bridgev2/config: fix missing database config error message --- bridgev2/matrix/mxmain/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 2c0c07b9..46f27e73 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -310,7 +310,7 @@ func (br *BridgeMain) validateConfig() error { case br.Config.AppService.HSToken == "This value is generated when generating the registration": return errors.New("appservice.hs_token not configured. Did you forget to generate the registration? ") case br.Config.Database.URI == "postgres://user:password@host/database?sslmode=disable": - return errors.New("appservice.database not configured") + return errors.New("database.uri not configured") case !br.Config.Bridge.Permissions.IsConfigured(): return errors.New("bridge.permissions not configured") case !strings.Contains(br.Config.AppService.FormatUsername("1234567890"), "1234567890"): From a1a87918600cc3d0f3349e7c7d486cf8ec3776c5 Mon Sep 17 00:00:00 2001 From: SpiritCroc Date: Wed, 18 Dec 2024 17:46:39 +0100 Subject: [PATCH 0960/1647] crypto/verificationhelper: add cross-signing to SAS verification path (#332) When verifying a new device via SAS / emoji verification, we never cross-sign the new device, which thus never fully finishes verification. This commit refactors a bunch of the key signing code and makes the SAS and QR methods behave more similarly. --------- Signed-off-by: Sumner Evans Co-authored-by: Sumner Evans --- crypto/verificationhelper/reciprocate.go | 41 ++++++++------ crypto/verificationhelper/sas.go | 70 +++++++++++++++++++----- 2 files changed, 79 insertions(+), 32 deletions(-) diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index ef69f23c..33dccef9 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -212,27 +212,34 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id log.Info().Msg("Confirming QR code scanned") + // Get their device + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + if err != nil { + return err + } + + // Trust their device + theirDevice.Trust = id.TrustStateVerified + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) + if err != nil { + return fmt.Errorf("failed to update device trust state after verifying: %w", err) + } + if txn.TheirUserID == vh.client.UserID { - // Self-signing situation. Trust their device. - - // Get their device - theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) - if err != nil { - return err - } - - // Trust their device - theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) - if err != nil { - return fmt.Errorf("failed to update device trust state after verifying: %w", err) - } - - // Cross-sign their device with the self-signing key + // Self-signing situation. + // + // If we have the cross-signing keys, then we need to sign their device + // using the self-signing key. Otherwise, they have the master private + // key, so we need to trust the master public key. if vh.mach.CrossSigningKeys != nil { err = vh.mach.SignOwnDevice(ctx, theirDevice) if err != nil { - return fmt.Errorf("failed to sign their device: %w", err) + return fmt.Errorf("failed to sign our own new device: %w", err) + } + } else { + err = vh.mach.SignOwnMasterKey(ctx) + if err != nil { + return fmt.Errorf("failed to sign our own master key: %w", err) } } } else { diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 81728bd4..178838b8 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -13,6 +13,7 @@ import ( "crypto/hmac" "crypto/rand" "crypto/sha256" + "crypto/subtle" "encoding/base64" "encoding/json" "errors" @@ -336,13 +337,7 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific return } if !bytes.Equal(commitment, txn.Commitment) { - err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationCancel, &event.VerificationCancelEventContent{ - Code: event.VerificationCancelCodeKeyMismatch, - Reason: "The key was not the one we expected.", - }) - if err != nil { - log.Err(err).Msg("Failed to send cancellation event") - } + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "The key was not the one we expected") return } } else { @@ -593,12 +588,15 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific // Verifying Keys MAC log.Info().Msg("Verifying MAC for all sent keys") var hasTheirDeviceKey bool + var masterKey string var keyIDs []string for keyID := range macEvt.MAC { keyIDs = append(keyIDs, keyID.String()) _, kID := keyID.Parse() if kID == txn.TheirDeviceID.String() { hasTheirDeviceKey = true + } else { + masterKey = kID } } slices.Sort(keyIDs) @@ -617,6 +615,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific } // Verify the MAC for each key + var theirDevice *id.Device for keyID, mac := range macEvt.MAC { log.Info().Str("key_id", keyID.String()).Msg("Received MAC for key") @@ -627,8 +626,11 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific } var key string - var theirDevice *id.Device if kID == txn.TheirDeviceID.String() { + if theirDevice != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInvalidMessage, "two keys found for their device ID") + return + } theirDevice, err = vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to fetch their device: %w", err) @@ -653,22 +655,60 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to calculate key MAC: %w", err) return } - if !bytes.Equal(expectedMAC, mac) { + if subtle.ConstantTimeCompare(expectedMAC, mac) == 0 { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeSASMismatch, "MAC mismatch for key %s", keyID) return } + } + log.Info().Msg("All MACs verified") - // Trust their device - if kID == txn.TheirDeviceID.String() { - theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) + // Trust their device + theirDevice.Trust = id.TrustStateVerified + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) + if err != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to update device trust state after verifying: %w", err) + return + } + + if txn.TheirUserID == vh.client.UserID { + // Self-signing situation. + // + // If we have the cross-signing keys, then we need to sign their device + // using the self-signing key. Otherwise, they have the master private + // key, so we need to trust the master public key. + if vh.mach.CrossSigningKeys != nil { + err = vh.mach.SignOwnDevice(ctx, theirDevice) if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to update device trust state after verifying: %w", err) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to sign our own new device: %w", err) + return + } + } else { + err = vh.mach.SignOwnMasterKey(ctx) + if err != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to sign our own master key: %w", err) return } } + } else if masterKey != "" { + // Cross-signing situation. + // + // The master key was included in the list of keys to verify, so verify + // that it matches what we expect and sign their master key using the + // user-signing key. + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) + if err != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) + return + } else if theirSigningKeys.MasterKey.String() != masterKey { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "master keys do not match") + return + } + + if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their master key: %w", err) + return + } } - log.Info().Msg("All MACs verified") txn.ReceivedTheirMAC = true if txn.SentOurMAC { From 6f47d6abfcbffad0dc2488838e12ecf14e116fe9 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 19 Dec 2024 09:57:34 -0700 Subject: [PATCH 0961/1647] responses: add m.get_login_token to capabilities See https://github.com/matrix-org/matrix-spec-proposals/pull/3882 Signed-off-by: Sumner Evans --- responses.go | 1 + 1 file changed, 1 insertion(+) diff --git a/responses.go b/responses.go index 6ead355e..7b62b433 100644 --- a/responses.go +++ b/responses.go @@ -449,6 +449,7 @@ type RespCapabilities struct { SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"` SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"` ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"` + GetLoginToken *CapBooleanTrue `json:"m.get_login_token,omitempty"` Custom map[string]interface{} `json:"-"` } From 1b78c8398988813820cf9aed6460b389fa8b7e93 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 18 Dec 2024 15:10:20 +0200 Subject: [PATCH 0962/1647] bridgev2/space: ensure adding portal to space isn't cancelled --- bridgev2/space.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/space.go b/bridgev2/space.go index 17388f3e..11de9cfa 100644 --- a/bridgev2/space.go +++ b/bridgev2/space.go @@ -43,7 +43,7 @@ func (ul *UserLogin) MarkInPortal(ctx context.Context, portal *Portal) { } } if ul.Bridge.Config.PersonalFilteringSpaces && (userPortal.InSpace == nil || !*userPortal.InSpace) { - go ul.tryAddPortalToSpace(ctx, portal, userPortal.CopyWithoutValues()) + go ul.tryAddPortalToSpace(context.WithoutCancel(ctx), portal, userPortal.CopyWithoutValues()) } } } From ddcb5fa6c53c2544530b4fbdb83b0c790d39a67f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 19 Dec 2024 19:05:00 +0200 Subject: [PATCH 0963/1647] versions: add new spec version constants --- versions.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/versions.go b/versions.go index 5c0d6eaa..a8728c34 100644 --- a/versions.go +++ b/versions.go @@ -111,6 +111,8 @@ var ( SpecV19 = MustParseSpecVersion("v1.9") SpecV110 = MustParseSpecVersion("v1.10") SpecV111 = MustParseSpecVersion("v1.11") + SpecV112 = MustParseSpecVersion("v1.12") + SpecV113 = MustParseSpecVersion("v1.13") ) func (svf SpecVersionFormat) String() string { From 4b4599d4ab165a0f8a3060cd7b9e50501f7a4aa6 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 19 Dec 2024 12:56:15 -0700 Subject: [PATCH 0964/1647] filter: make sub-structs properly nullable so omitempty works Signed-off-by: Sumner Evans --- bridge/crypto.go | 14 +++++++------- bridgev2/matrix/crypto.go | 14 +++++++------- filter.go | 20 ++++++++++---------- sync.go | 4 ++-- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/bridge/crypto.go b/bridge/crypto.go index f0b90056..4765039b 100644 --- a/bridge/crypto.go +++ b/bridge/crypto.go @@ -476,14 +476,14 @@ func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.D func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter { everything := []event.Type{{Type: "*"}} return &mautrix.Filter{ - Presence: mautrix.FilterPart{NotTypes: everything}, - AccountData: mautrix.FilterPart{NotTypes: everything}, - Room: mautrix.RoomFilter{ + Presence: &mautrix.FilterPart{NotTypes: everything}, + AccountData: &mautrix.FilterPart{NotTypes: everything}, + Room: &mautrix.RoomFilter{ IncludeLeave: false, - Ephemeral: mautrix.FilterPart{NotTypes: everything}, - AccountData: mautrix.FilterPart{NotTypes: everything}, - State: mautrix.FilterPart{NotTypes: everything}, - Timeline: mautrix.FilterPart{NotTypes: everything}, + Ephemeral: &mautrix.FilterPart{NotTypes: everything}, + AccountData: &mautrix.FilterPart{NotTypes: everything}, + State: &mautrix.FilterPart{NotTypes: everything}, + Timeline: &mautrix.FilterPart{NotTypes: everything}, }, } } diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index df6f7a63..f330f9f4 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -503,14 +503,14 @@ func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.D func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter { everything := []event.Type{{Type: "*"}} return &mautrix.Filter{ - Presence: mautrix.FilterPart{NotTypes: everything}, - AccountData: mautrix.FilterPart{NotTypes: everything}, - Room: mautrix.RoomFilter{ + Presence: &mautrix.FilterPart{NotTypes: everything}, + AccountData: &mautrix.FilterPart{NotTypes: everything}, + Room: &mautrix.RoomFilter{ IncludeLeave: false, - Ephemeral: mautrix.FilterPart{NotTypes: everything}, - AccountData: mautrix.FilterPart{NotTypes: everything}, - State: mautrix.FilterPart{NotTypes: everything}, - Timeline: mautrix.FilterPart{NotTypes: everything}, + Ephemeral: &mautrix.FilterPart{NotTypes: everything}, + AccountData: &mautrix.FilterPart{NotTypes: everything}, + State: &mautrix.FilterPart{NotTypes: everything}, + Timeline: &mautrix.FilterPart{NotTypes: everything}, }, } } diff --git a/filter.go b/filter.go index 2603bfb9..ce3dde47 100644 --- a/filter.go +++ b/filter.go @@ -19,24 +19,24 @@ const ( // Filter is used by clients to specify how the server should filter responses to e.g. sync requests // Specified by: https://spec.matrix.org/v1.2/client-server-api/#filtering type Filter struct { - AccountData FilterPart `json:"account_data,omitempty"` + AccountData *FilterPart `json:"account_data,omitempty"` EventFields []string `json:"event_fields,omitempty"` EventFormat EventFormat `json:"event_format,omitempty"` - Presence FilterPart `json:"presence,omitempty"` - Room RoomFilter `json:"room,omitempty"` + Presence *FilterPart `json:"presence,omitempty"` + Room *RoomFilter `json:"room,omitempty"` BeeperToDevice *FilterPart `json:"com.beeper.to_device,omitempty"` } // RoomFilter is used to define filtering rules for room events type RoomFilter struct { - AccountData FilterPart `json:"account_data,omitempty"` - Ephemeral FilterPart `json:"ephemeral,omitempty"` + AccountData *FilterPart `json:"account_data,omitempty"` + Ephemeral *FilterPart `json:"ephemeral,omitempty"` IncludeLeave bool `json:"include_leave,omitempty"` NotRooms []id.RoomID `json:"not_rooms,omitempty"` Rooms []id.RoomID `json:"rooms,omitempty"` - State FilterPart `json:"state,omitempty"` - Timeline FilterPart `json:"timeline,omitempty"` + State *FilterPart `json:"state,omitempty"` + Timeline *FilterPart `json:"timeline,omitempty"` } // FilterPart is used to define filtering rules for specific categories of events @@ -69,7 +69,7 @@ func DefaultFilter() Filter { EventFields: nil, EventFormat: "client", Presence: DefaultFilterPart(), - Room: RoomFilter{ + Room: &RoomFilter{ AccountData: DefaultFilterPart(), Ephemeral: DefaultFilterPart(), IncludeLeave: false, @@ -82,8 +82,8 @@ func DefaultFilter() Filter { } // DefaultFilterPart returns the default filter part used by the Matrix server if no filter is provided in the request -func DefaultFilterPart() FilterPart { - return FilterPart{ +func DefaultFilterPart() *FilterPart { + return &FilterPart{ NotRooms: nil, Rooms: nil, Limit: 20, diff --git a/sync.go b/sync.go index d4208404..48906bbc 100644 --- a/sync.go +++ b/sync.go @@ -191,8 +191,8 @@ func (s *DefaultSyncer) OnFailedSync(res *RespSync, err error) (time.Duration, e } var defaultFilter = Filter{ - Room: RoomFilter{ - Timeline: FilterPart{ + Room: &RoomFilter{ + Timeline: &FilterPart{ Limit: 50, }, }, From 7cc19d9720caefb6d4a8263cb118815868257e3f Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 19 Dec 2024 13:09:21 -0700 Subject: [PATCH 0965/1647] filter: add unread_thread_notifications to FilterPart Signed-off-by: Sumner Evans --- filter.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/filter.go b/filter.go index ce3dde47..c6c8211b 100644 --- a/filter.go +++ b/filter.go @@ -41,17 +41,17 @@ type RoomFilter struct { // FilterPart is used to define filtering rules for specific categories of events type FilterPart struct { - NotRooms []id.RoomID `json:"not_rooms,omitempty"` - Rooms []id.RoomID `json:"rooms,omitempty"` - Limit int `json:"limit,omitempty"` - NotSenders []id.UserID `json:"not_senders,omitempty"` - NotTypes []event.Type `json:"not_types,omitempty"` - Senders []id.UserID `json:"senders,omitempty"` - Types []event.Type `json:"types,omitempty"` - ContainsURL *bool `json:"contains_url,omitempty"` - - LazyLoadMembers bool `json:"lazy_load_members,omitempty"` - IncludeRedundantMembers bool `json:"include_redundant_members,omitempty"` + NotRooms []id.RoomID `json:"not_rooms,omitempty"` + Rooms []id.RoomID `json:"rooms,omitempty"` + Limit int `json:"limit,omitempty"` + NotSenders []id.UserID `json:"not_senders,omitempty"` + NotTypes []event.Type `json:"not_types,omitempty"` + Senders []id.UserID `json:"senders,omitempty"` + Types []event.Type `json:"types,omitempty"` + ContainsURL *bool `json:"contains_url,omitempty"` + LazyLoadMembers bool `json:"lazy_load_members,omitempty"` + IncludeRedundantMembers bool `json:"include_redundant_members,omitempty"` + UnreadThreadNotifications bool `json:"unread_thread_notifications,omitempty"` } // Validate checks if the filter contains valid property values From fbee4248a1463e154475631948cb39dd8f2750a9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 19 Dec 2024 21:57:45 +0200 Subject: [PATCH 0966/1647] client: allow multiple vias in JoinRoom --- client.go | 23 +++++++++++------------ requests.go | 6 ++++++ url.go | 22 ++++++++++++++++------ 3 files changed, 33 insertions(+), 18 deletions(-) diff --git a/client.go b/client.go index e8689708..3db4e219 100644 --- a/client.go +++ b/client.go @@ -951,20 +951,19 @@ func (cli *Client) Capabilities(ctx context.Context) (resp *RespCapabilities, er return } -// JoinRoom joins the client to a room ID or alias. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3joinroomidoralias +// JoinRoom joins the client to a room ID or alias. See https://spec.matrix.org/v1.13/client-server-api/#post_matrixclientv3joinroomidoralias // -// If serverName is specified, this will be added as a query param to instruct the homeserver to join via that server. If content is specified, it will -// be JSON encoded and used as the request body. -func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName string, content interface{}) (resp *RespJoinRoom, err error) { - var urlPath string - if serverName != "" { - urlPath = cli.BuildURLWithQuery(ClientURLPath{"v3", "join", roomIDorAlias}, map[string]string{ - "via": serverName, - }) - } else { - urlPath = cli.BuildClientURL("v3", "join", roomIDorAlias) +// The last parameter contains optional extra fields and can be left nil. +func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias string, req *ReqJoinRoom) (resp *RespJoinRoom, err error) { + if req == nil { + req = &ReqJoinRoom{} } - _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, content, &resp) + urlPath := cli.BuildURLWithFullQuery(ClientURLPath{"v3", "join", roomIDorAlias}, func(q url.Values) { + if len(req.Via) > 0 { + q["via"] = req.Via + } + }) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) if err == nil && cli.StateStore != nil { err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin) if err != nil { diff --git a/requests.go b/requests.go index 9e7eb0bd..c1985da4 100644 --- a/requests.go +++ b/requests.go @@ -144,6 +144,12 @@ type ReqMembers struct { NotMembership event.Membership `json:"not_membership,omitempty"` } +type ReqJoinRoom struct { + Via []string `json:"-"` + Reason string `json:"reason,omitempty"` + ThirdPartySigned any `json:"third_party_signed,omitempty"` +} + type ReqMutualRooms struct { From string `json:"-"` } diff --git a/url.go b/url.go index f35ae5e2..0b4eec67 100644 --- a/url.go +++ b/url.go @@ -57,13 +57,13 @@ func BuildURL(baseURL *url.URL, path ...any) *url.URL { // BuildURL builds a URL with the Client's homeserver and appservice user ID set already. func (cli *Client) BuildURL(urlPath PrefixableURLPath) string { - return cli.BuildURLWithQuery(urlPath, nil) + return cli.BuildURLWithFullQuery(urlPath, nil) } // BuildClientURL builds a URL with the Client's homeserver and appservice user ID set already. // This method also automatically prepends the client API prefix (/_matrix/client). func (cli *Client) BuildClientURL(urlPath ...any) string { - return cli.BuildURLWithQuery(ClientURLPath(urlPath), nil) + return cli.BuildURLWithFullQuery(ClientURLPath(urlPath), nil) } type PrefixableURLPath interface { @@ -97,6 +97,18 @@ func (saup SynapseAdminURLPath) FullPath() []any { // BuildURLWithQuery builds a URL with query parameters in addition to the Client's homeserver // and appservice user ID set already. func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string { + return cli.BuildURLWithFullQuery(urlPath, func(q url.Values) { + if urlQuery != nil { + for k, v := range urlQuery { + q.Set(k, v) + } + } + }) +} + +// BuildURLWithQuery builds a URL with query parameters in addition to the Client's homeserver +// and appservice user ID set already. +func (cli *Client) BuildURLWithFullQuery(urlPath PrefixableURLPath, fn func(q url.Values)) string { hsURL := *BuildURL(cli.HomeserverURL, urlPath.FullPath()...) query := hsURL.Query() if cli.SetAppServiceUserID { @@ -106,10 +118,8 @@ func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[str query.Set("device_id", string(cli.DeviceID)) query.Set("org.matrix.msc3202.device_id", string(cli.DeviceID)) } - if urlQuery != nil { - for k, v := range urlQuery { - query.Set(k, v) - } + if fn != nil { + fn(query) } hsURL.RawQuery = query.Encode() return hsURL.String() From 6c9a29d25a4a1c03b711b6cb84dad8c4a7d7f108 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 19 Dec 2024 21:57:57 +0200 Subject: [PATCH 0967/1647] client: add GetRoomSummary to implement MSC3266 --- client.go | 11 +++++++++++ responses.go | 49 ++++++++++++++++++++++++++++++++++++------------- 2 files changed, 47 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index 3db4e219..c45bda27 100644 --- a/client.go +++ b/client.go @@ -1010,6 +1010,17 @@ func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, ex return } +func (cli *Client) GetRoomSummary(ctx context.Context, roomIDOrAlias string, via ...string) (resp *RespRoomSummary, err error) { + // TODO add version check after one is added to MSC3266 + urlPath := cli.BuildURLWithFullQuery(ClientURLPath{"unstable", "im.nheko.summary", "summary", roomIDOrAlias}, func(q url.Values) { + if len(via) > 0 { + q["via"] = via + } + }) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + // GetDisplayName returns the display name of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname func (cli *Client) GetDisplayName(ctx context.Context, mxid id.UserID) (resp *RespUserDisplayName, err error) { urlPath := cli.BuildClientURL("v3", "profile", mxid, "displayname") diff --git a/responses.go b/responses.go index 7b62b433..a067682d 100644 --- a/responses.go +++ b/responses.go @@ -164,6 +164,18 @@ type RespMutualRooms struct { NextBatch string `json:"next_batch,omitempty"` } +type RespRoomSummary struct { + PublicRoomInfo + + Membership event.Membership `json:"membership,omitempty"` + RoomVersion event.RoomVersion `json:"room_version,omitempty"` + Encryption id.Algorithm `json:"encryption,omitempty"` + + UnstableRoomVersion event.RoomVersion `json:"im.nheko.summary.room_version,omitempty"` + UnstableRoomVersionOld event.RoomVersion `json:"im.nheko.summary.version,omitempty"` + UnstableEncryption id.Algorithm `json:"im.nheko.summary.encryption,omitempty"` +} + // RespRegisterAvailable is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3registeravailable type RespRegisterAvailable struct { Available bool `json:"available"` @@ -558,24 +570,35 @@ func (vers *CapRoomVersions) IsAvailable(version string) bool { return available } +type RespPublicRooms struct { + Chunk []*PublicRoomInfo `json:"chunk"` + NextBatch string `json:"next_batch,omitempty"` + PrevBatch string `json:"prev_batch,omitempty"` + TotalRoomCountEstimate int `json:"total_room_count_estimate"` +} + +type PublicRoomInfo struct { + RoomID id.RoomID `json:"room_id"` + AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` + CanonicalAlias id.RoomAlias `json:"canonical_alias,omitempty"` + GuestCanJoin bool `json:"guest_can_join"` + JoinRule event.JoinRule `json:"join_rule,omitempty"` + Name string `json:"name,omitempty"` + NumJoinedMembers int `json:"num_joined_members"` + RoomType event.RoomType `json:"room_type"` + Topic string `json:"topic,omitempty"` + WorldReadable bool `json:"world_readable"` +} + // RespHierarchy is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy type RespHierarchy struct { - NextBatch string `json:"next_batch,omitempty"` - Rooms []ChildRoomsChunk `json:"rooms"` + NextBatch string `json:"next_batch,omitempty"` + Rooms []*ChildRoomsChunk `json:"rooms"` } type ChildRoomsChunk struct { - AvatarURL id.ContentURI `json:"avatar_url,omitempty"` - CanonicalAlias id.RoomAlias `json:"canonical_alias,omitempty"` - ChildrenState []StrippedStateWithTime `json:"children_state"` - GuestCanJoin bool `json:"guest_can_join"` - JoinRule event.JoinRule `json:"join_rule,omitempty"` - Name string `json:"name,omitempty"` - NumJoinedMembers int `json:"num_joined_members"` - RoomID id.RoomID `json:"room_id"` - RoomType event.RoomType `json:"room_type"` - Topic string `json:"topic,omitempty"` - WorldReadble bool `json:"world_readable"` + PublicRoomInfo + ChildrenState []StrippedStateWithTime `json:"children_state"` } type StrippedStateWithTime struct { From 918ed4bf23cec4a593c4c1b49d2cae7488c600dc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 19 Dec 2024 23:33:56 +0200 Subject: [PATCH 0968/1647] error: don't include path in HTTP errors The request data is logged anyway, so it's nicer if the error string only has the important part. --- error.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/error.go b/error.go index 0133e80e..653ac5a1 100644 --- a/error.go +++ b/error.go @@ -96,10 +96,9 @@ func (e HTTPError) Error() string { if e.WrappedError != nil { return fmt.Sprintf("%s: %v", e.Message, e.WrappedError) } else if e.RespError != nil { - return fmt.Sprintf("failed to %s %s: %s (HTTP %d): %s", e.Request.Method, e.Request.URL.Path, - e.RespError.ErrCode, e.Response.StatusCode, e.RespError.Err) + return fmt.Sprintf("%s (HTTP %d): %s", e.RespError.ErrCode, e.Response.StatusCode, e.RespError.Err) } else { - msg := fmt.Sprintf("failed to %s %s: HTTP %d", e.Request.Method, e.Request.URL.Path, e.Response.StatusCode) + msg := fmt.Sprintf("HTTP %d", e.Response.StatusCode) if len(e.ResponseBody) > 0 { msg = fmt.Sprintf("%s: %s", msg, e.ResponseBody) } From e844153658485a873f8a691a657c62a3368b317a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 20 Dec 2024 14:38:24 +0200 Subject: [PATCH 0969/1647] crypto/decryptolm: store olm hashes to prevent errors if they're repeated --- crypto/decryptolm.go | 45 +++++++++++++++++-- crypto/machine.go | 33 ++++++++++++++ crypto/sql_store.go | 23 ++++++++++ .../sql_store_upgrade/00-latest-revision.sql | 12 ++++- .../17-decrypted-olm-messages.sql | 11 +++++ crypto/store.go | 28 ++++++++++++ 6 files changed, 148 insertions(+), 4 deletions(-) create mode 100644 crypto/sql_store_upgrade/17-decrypted-olm-messages.sql diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index 788bf832..00fba988 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -8,6 +8,8 @@ package crypto import ( "context" + "crypto/sha256" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -29,6 +31,7 @@ var ( SenderMismatch = errors.New("mismatched sender in olm payload") RecipientMismatch = errors.New("mismatched recipient in olm payload") RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload") + ErrDuplicateMessage = errors.New("duplicate olm message") ) // DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm. @@ -113,14 +116,35 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e return &olmEvt, nil } +func olmMessageHash(ciphertext string) ([32]byte, error) { + ciphertextBytes, err := base64.RawStdEncoding.DecodeString(ciphertext) + return sha256.Sum256(ciphertextBytes), err +} + func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) { + ciphertextHash, err := olmMessageHash(ciphertext) + if err != nil { + return nil, fmt.Errorf("failed to hash olm ciphertext: %w", err) + } + log := *zerolog.Ctx(ctx) endTimeTrace := mach.timeTrace(ctx, "waiting for olm lock", 5*time.Second) mach.olmLock.Lock() endTimeTrace() defer mach.olmLock.Unlock() - plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext) + duplicateTS, err := mach.CryptoStore.GetOlmHash(ctx, ciphertextHash) + if err != nil { + log.Warn().Err(err).Msg("Failed to check for duplicate olm message") + } else if !duplicateTS.IsZero() { + log.Warn(). + Hex("ciphertext_hash", ciphertextHash[:]). + Time("duplicate_ts", duplicateTS). + Msg("Ignoring duplicate olm message") + return nil, ErrDuplicateMessage + } + + plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext, ciphertextHash) if err != nil { if err == DecryptionFailedWithMatchingSession { log.Warn().Msg("Found matching session, but decryption failed") @@ -153,6 +177,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U } log = log.With().Str("new_olm_session_id", session.ID().String()).Logger() log.Debug(). + Hex("ciphertext_hash", ciphertextHash[:]). Str("olm_session_description", session.Describe()). Msg("Created inbound olm session") ctx = log.WithContext(ctx) @@ -166,6 +191,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U } endTimeTrace = mach.timeTrace(ctx, "updating new session in database", time.Second) + err = mach.CryptoStore.PutOlmHash(ctx, ciphertextHash, time.Now()) + if err != nil { + log.Warn().Err(err).Msg("Failed to store olm message hash after decrypting") + } err = mach.CryptoStore.UpdateSession(ctx, senderKey, session) endTimeTrace() if err != nil { @@ -176,7 +205,9 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U const MaxOlmSessionsPerDevice = 5 -func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) { +func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession( + ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, ciphertextHash [32]byte, +) ([]byte, error) { log := *zerolog.Ctx(ctx) endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second) sessions, err := mach.CryptoStore.GetSessions(ctx, senderKey) @@ -229,6 +260,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C endTimeTrace() if err != nil { log.Warn().Err(err). + Hex("ciphertext_hash", ciphertextHash[:]). Str("session_description", session.Describe()). Msg("Failed to decrypt olm message") if olmType == id.OlmMsgTypePreKey { @@ -236,12 +268,19 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C } } else { endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second) + err = mach.CryptoStore.PutOlmHash(ctx, ciphertextHash, time.Now()) + if err != nil { + log.Warn().Err(err).Msg("Failed to store olm message hash after decrypting") + } err = mach.CryptoStore.UpdateSession(ctx, senderKey, session) endTimeTrace() if err != nil { log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting") } - log.Debug().Str("session_description", session.Describe()).Msg("Decrypted olm message") + log.Debug(). + Hex("ciphertext_hash", ciphertextHash[:]). + Str("session_description", session.Describe()). + Msg("Decrypted olm message") return plaintext, nil } } diff --git a/crypto/machine.go b/crypto/machine.go index 95d86d42..4594b9d8 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -63,6 +63,9 @@ type OlmMachine struct { devicesToUnwedgeLock sync.Mutex recentlyUnwedged map[id.IdentityKey]time.Time recentlyUnwedgedLock sync.Mutex + olmHashSavePoints []time.Time + lastHashDelete time.Time + olmHashSavePointLock sync.Mutex olmLock sync.Mutex megolmEncryptLock sync.Mutex @@ -312,6 +315,7 @@ func (mach *OlmMachine) ProcessSyncResponse(ctx context.Context, resp *mautrix.R } mach.HandleOTKCounts(ctx, &resp.DeviceOTKCount) + mach.MarkOlmHashSavePoint(ctx) return true } @@ -399,6 +403,35 @@ func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Eve } } +const olmHashSavePointCount = 5 +const olmHashDeleteMinInterval = 10 * time.Minute +const minSavePointInterval = 1 * time.Minute + +// MarkOlmHashSavePoint marks the current time as a save point for olm hashes and deletes old hashes if needed. +// +// This should be called after all to-device events in a sync have been processed. +// The function will then delete old olm hashes after enough syncs have happened +// (such that it's unlikely for the olm messages to repeat). +func (mach *OlmMachine) MarkOlmHashSavePoint(ctx context.Context) { + mach.olmHashSavePointLock.Lock() + defer mach.olmHashSavePointLock.Unlock() + if len(mach.olmHashSavePoints) > 0 && time.Since(mach.olmHashSavePoints[len(mach.olmHashSavePoints)-1]) < minSavePointInterval { + return + } + mach.olmHashSavePoints = append(mach.olmHashSavePoints, time.Now()) + if len(mach.olmHashSavePoints) > olmHashSavePointCount { + sp := mach.olmHashSavePoints[0] + mach.olmHashSavePoints = mach.olmHashSavePoints[1:] + if time.Since(mach.lastHashDelete) > olmHashDeleteMinInterval { + err := mach.CryptoStore.DeleteOldOlmHashes(ctx, sp) + mach.lastHashDelete = time.Now() + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to delete old olm hashes") + } + } + } +} + // HandleToDeviceEvent handles a single to-device event. This is automatically called by ProcessSyncResponse, so you // don't need to add any custom handlers if you use that method. func (mach *OlmMachine) HandleToDeviceEvent(ctx context.Context, evt *event.Event) { diff --git a/crypto/sql_store.go b/crypto/sql_store.go index e68f0df5..02dfe8a1 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -279,6 +279,29 @@ func (store *SQLCryptoStore) DeleteSession(ctx context.Context, _ id.SenderKey, return err } +func (store *SQLCryptoStore) PutOlmHash(ctx context.Context, messageHash [32]byte, receivedAt time.Time) error { + _, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_message_hash (account_id, received_at, message_hash) VALUES ($1, $2, $3) ON CONFLICT (message_hash) DO NOTHING", store.Account, messageHash[:], receivedAt.UnixMilli()) + return err +} + +func (store *SQLCryptoStore) GetOlmHash(ctx context.Context, messageHash [32]byte) (receivedAt time.Time, err error) { + var receivedAtInt int64 + err = store.DB.QueryRow(ctx, "SELECT received_at FROM crypto_olm_message_hash WHERE message_hash=$1", messageHash).Scan(&receivedAtInt) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return + } + receivedAt = time.UnixMilli(receivedAtInt) + return +} + +func (store *SQLCryptoStore) DeleteOldOlmHashes(ctx context.Context, beforeTS time.Time) error { + _, err := store.DB.Exec(ctx, "DELETE FROM crypto_olm_message_hash WHERE account_id = $1 AND received_at < $2", store.AccountID, beforeTS.UnixMilli()) + return err +} + func datePtr(t time.Time) *time.Time { if t.IsZero() { return nil diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql index 7cd3331c..00dd1387 100644 --- a/crypto/sql_store_upgrade/00-latest-revision.sql +++ b/crypto/sql_store_upgrade/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v16 (compatible with v15+): Latest revision +-- v0 -> v17 (compatible with v15+): Latest revision CREATE TABLE IF NOT EXISTS crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, @@ -45,6 +45,16 @@ CREATE TABLE IF NOT EXISTS crypto_olm_session ( ); CREATE INDEX crypto_olm_session_sender_key_idx ON crypto_olm_session (account_id, sender_key); +CREATE TABLE crypto_olm_message_hash ( + account_id TEXT NOT NULL, + received_at BIGINT NOT NULL, + message_hash bytea NOT NULL PRIMARY KEY, + + CONSTRAINT crypto_olm_message_hash_account_fkey FOREIGN KEY (account_id) + REFERENCES crypto_account (account_id) ON DELETE CASCADE ON UPDATE CASCADE +); +CREATE INDEX crypto_olm_message_hash_account_idx ON crypto_olm_message_hash (account_id); + CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( account_id TEXT, session_id CHAR(43), diff --git a/crypto/sql_store_upgrade/17-decrypted-olm-messages.sql b/crypto/sql_store_upgrade/17-decrypted-olm-messages.sql new file mode 100644 index 00000000..525bbb52 --- /dev/null +++ b/crypto/sql_store_upgrade/17-decrypted-olm-messages.sql @@ -0,0 +1,11 @@ +-- v17 (compatible with v15+): Add table for decrypted Olm message hashes +CREATE TABLE crypto_olm_message_hash ( + account_id TEXT NOT NULL, + received_at BIGINT NOT NULL, + message_hash bytea NOT NULL PRIMARY KEY, + + CONSTRAINT crypto_olm_message_hash_account_fkey FOREIGN KEY (account_id) + REFERENCES crypto_account (account_id) ON DELETE CASCADE ON UPDATE CASCADE +); + +CREATE INDEX crypto_olm_message_hash_account_idx ON crypto_olm_message_hash (account_id); diff --git a/crypto/store.go b/crypto/store.go index 9a3a4394..4e43bd2a 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -12,8 +12,10 @@ import ( "slices" "sort" "sync" + "time" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exsync" "golang.org/x/exp/maps" "maunium.net/go/mautrix/event" @@ -51,6 +53,13 @@ type Store interface { // DeleteSession deletes the given session that has been previously inserted with AddSession. DeleteSession(context.Context, id.SenderKey, *OlmSession) error + // PutOlmHash marks a given olm message hash as handled. + PutOlmHash(context.Context, [32]byte, time.Time) error + // GetOlmHash gets the time that a given olm hash was handled. + GetOlmHash(context.Context, [32]byte) (time.Time, error) + // DeleteOldOlmHashes deletes all olm hashes that were handled before the given time. + DeleteOldOlmHashes(context.Context, time.Time) error + // PutGroupSession inserts an inbound Megolm session into the store. If an earlier withhold event has been inserted // with PutWithheldGroupSession, this call should replace that. However, PutWithheldGroupSession must not replace // sessions inserted with this call. @@ -176,6 +185,7 @@ type MemoryStore struct { KeySignatures map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string OutdatedUsers map[id.UserID]struct{} Secrets map[id.Secret]string + OlmHashes *exsync.Set[[32]byte] } var _ Store = (*MemoryStore)(nil) @@ -198,6 +208,7 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore { KeySignatures: make(map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string), OutdatedUsers: make(map[id.UserID]struct{}), Secrets: make(map[id.Secret]string), + OlmHashes: exsync.NewSet[[32]byte](), } } @@ -263,6 +274,23 @@ func (gs *MemoryStore) HasSession(_ context.Context, senderKey id.SenderKey) boo return ok && len(sessions) > 0 && !sessions[0].Expired() } +func (gs *MemoryStore) PutOlmHash(_ context.Context, hash [32]byte, receivedAt time.Time) error { + gs.OlmHashes.Add(hash) + return nil +} + +func (gs *MemoryStore) GetOlmHash(_ context.Context, hash [32]byte) (time.Time, error) { + if gs.OlmHashes.Has(hash) { + // The time isn't that important, so we just return the current time + return time.Now(), nil + } + return time.Time{}, nil +} + +func (gs *MemoryStore) DeleteOldOlmHashes(_ context.Context, beforeTS time.Time) error { + return nil +} + func (gs *MemoryStore) GetLatestSession(_ context.Context, senderKey id.SenderKey) (*OlmSession, error) { gs.lock.RLock() defer gs.lock.RUnlock() From 049990cd7bfc927e41cf1df133a198dff25c8381 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 20 Dec 2024 14:48:17 +0200 Subject: [PATCH 0970/1647] crypto/decryptolm: check last olm session creation ts before unwedging --- crypto/decryptolm.go | 11 +++++++++++ crypto/sql_store.go | 10 ++++++++++ crypto/store.go | 17 ++++++++++++++--- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index 00fba988..353979d4 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -318,6 +318,17 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send mach.recentlyUnwedged[senderKey] = time.Now() mach.recentlyUnwedgedLock.Unlock() + lastCreatedAt, err := mach.CryptoStore.GetNewestSessionCreationTS(ctx, senderKey) + if err != nil { + log.Warn().Err(err).Msg("Failed to get newest session creation timestamp") + return + } else if time.Since(lastCreatedAt) < MinUnwedgeInterval { + log.Debug(). + Time("last_created_at", lastCreatedAt). + Msg("Not creating new Olm session as it was already recreated recently") + return + } + deviceIdentity, err := mach.GetOrFetchDeviceByKey(ctx, sender, senderKey) if err != nil { log.Error().Err(err).Msg("Failed to find device info by identity key") diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 02dfe8a1..d66a4760 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -249,6 +249,16 @@ func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.Sender } } +// GetNewestSessionCreationTS gets the creation timestamp of the most recently created session with the given sender key. +func (store *SQLCryptoStore) GetNewestSessionCreationTS(ctx context.Context, key id.SenderKey) (createdAt time.Time, err error) { + err = store.DB.QueryRow(ctx, "SELECT created_at FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY created_at DESC LIMIT 1", + key, store.AccountID).Scan(&createdAt) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + return +} + // AddSession persists an Olm session for a sender in the database. func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, session *OlmSession) error { store.olmSessionCacheLock.Lock() diff --git a/crypto/store.go b/crypto/store.go index 4e43bd2a..8b7c0a96 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -45,9 +45,11 @@ type Store interface { HasSession(context.Context, id.SenderKey) bool // GetSessions returns all Olm sessions in the store with the given sender key. GetSessions(context.Context, id.SenderKey) (OlmSessionList, error) - // GetLatestSession returns the session with the highest session ID (lexiographically sorting). - // It's usually safe to return the most recently added session if sorting by session ID is too difficult. + // GetLatestSession returns the most recent session that should be used for encrypting outbound messages. + // It's usually the one with the most recent successful decryption or the highest ID lexically. GetLatestSession(context.Context, id.SenderKey) (*OlmSession, error) + // GetNewestSessionCreationTS returns the creation timestamp of the most recently created session for the given sender key. + GetNewestSessionCreationTS(context.Context, id.SenderKey) (time.Time, error) // UpdateSession updates a session that has previously been inserted with AddSession. UpdateSession(context.Context, id.SenderKey, *OlmSession) error // DeleteSession deletes the given session that has been previously inserted with AddSession. @@ -298,7 +300,16 @@ func (gs *MemoryStore) GetLatestSession(_ context.Context, senderKey id.SenderKe if !ok || len(sessions) == 0 { return nil, nil } - return sessions[0], nil + return sessions[len(sessions)-1], nil +} + +func (gs *MemoryStore) GetNewestSessionCreationTS(ctx context.Context, senderKey id.SenderKey) (createdAt time.Time, err error) { + var sess *OlmSession + sess, err = gs.GetLatestSession(ctx, senderKey) + if sess != nil { + createdAt = sess.CreationTime + } + return } func (gs *MemoryStore) getGroupSessions(roomID id.RoomID) map[id.SessionID]*InboundGroupSession { From 1b66266b15d2bf8196fa95b1cc71e8633bd8b83e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 20 Dec 2024 15:21:05 +0200 Subject: [PATCH 0971/1647] responses: remove non-existent summary field in invites --- responses.go | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/responses.go b/responses.go index a067682d..7312f099 100644 --- a/responses.go +++ b/responses.go @@ -364,16 +364,7 @@ func (sjr SyncJoinedRoom) MarshalJSON() ([]byte, error) { } type SyncInvitedRoom struct { - Summary LazyLoadSummary `json:"summary"` - State SyncEventsList `json:"invite_state"` -} - -type marshalableSyncInvitedRoom SyncInvitedRoom - -var syncInvitedRoomPathsToDelete = []string{"summary"} - -func (sir SyncInvitedRoom) MarshalJSON() ([]byte, error) { - return marshalAndDeleteEmpty((marshalableSyncInvitedRoom)(sir), syncInvitedRoomPathsToDelete) + State SyncEventsList `json:"invite_state"` } type SyncKnockedRoom struct { From 33b4e823c5e55d62e3decff242ccff3795b6ffd6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 22 Dec 2024 14:09:44 +0200 Subject: [PATCH 0972/1647] crypto/sqlstore: fix mistakes in olm hash methods --- crypto/sql_store.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crypto/sql_store.go b/crypto/sql_store.go index d66a4760..0415c704 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -290,13 +290,13 @@ func (store *SQLCryptoStore) DeleteSession(ctx context.Context, _ id.SenderKey, } func (store *SQLCryptoStore) PutOlmHash(ctx context.Context, messageHash [32]byte, receivedAt time.Time) error { - _, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_message_hash (account_id, received_at, message_hash) VALUES ($1, $2, $3) ON CONFLICT (message_hash) DO NOTHING", store.Account, messageHash[:], receivedAt.UnixMilli()) + _, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_message_hash (account_id, received_at, message_hash) VALUES ($1, $2, $3) ON CONFLICT (message_hash) DO NOTHING", store.AccountID, receivedAt.UnixMilli(), messageHash[:]) return err } func (store *SQLCryptoStore) GetOlmHash(ctx context.Context, messageHash [32]byte) (receivedAt time.Time, err error) { var receivedAtInt int64 - err = store.DB.QueryRow(ctx, "SELECT received_at FROM crypto_olm_message_hash WHERE message_hash=$1", messageHash).Scan(&receivedAtInt) + err = store.DB.QueryRow(ctx, "SELECT received_at FROM crypto_olm_message_hash WHERE message_hash=$1", messageHash[:]).Scan(&receivedAtInt) if err != nil { if errors.Is(err, sql.ErrNoRows) { err = nil From 5c4474ae70ad2504cb0f617fb6f389acea905cc5 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sun, 22 Dec 2024 14:17:22 +0000 Subject: [PATCH 0973/1647] client: support setting status message in SetPresence (#336) --- client.go | 5 ++--- requests.go | 3 ++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index c45bda27..6b60f975 100644 --- a/client.go +++ b/client.go @@ -1404,10 +1404,9 @@ func (cli *Client) GetOwnPresence(ctx context.Context) (resp *RespPresence, err return cli.GetPresence(ctx, cli.UserID) } -func (cli *Client) SetPresence(ctx context.Context, status event.Presence) (err error) { - req := ReqPresence{Presence: status} +func (cli *Client) SetPresence(ctx context.Context, presence ReqPresence) (err error) { u := cli.BuildClientURL("v3", "presence", cli.UserID, "status") - _, err = cli.MakeRequest(ctx, http.MethodPut, u, req, nil) + _, err = cli.MakeRequest(ctx, http.MethodPut, u, presence, nil) return } diff --git a/requests.go b/requests.go index c1985da4..a796e653 100644 --- a/requests.go +++ b/requests.go @@ -197,7 +197,8 @@ type ReqTyping struct { } type ReqPresence struct { - Presence event.Presence `json:"presence"` + Presence event.Presence `json:"presence"` + StatusMsg string `json:"status_msg,omitempty"` } type ReqAliasCreate struct { From ba210a16b99237f7ea87236fb950d5e83b944cc8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 23 Dec 2024 13:46:08 +0200 Subject: [PATCH 0974/1647] event: add site_name to link previews --- event/beeper.go | 1 + 1 file changed, 1 insertion(+) diff --git a/event/beeper.go b/event/beeper.go index 7ea0d068..74b44a09 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -74,6 +74,7 @@ type LinkPreview struct { Title string `json:"og:title,omitempty"` Type string `json:"og:type,omitempty"` Description string `json:"og:description,omitempty"` + SiteName string `json:"og:site_name,omitempty"` ImageURL id.ContentURIString `json:"og:image,omitempty"` From 2cd6183f30c0a0e72a835db813e28421c0a9b03b Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sun, 29 Dec 2024 19:37:41 +0000 Subject: [PATCH 0975/1647] client: add support for arbitrary fields in /profile (#337) Co-authored-by: Tulir Asokan --- client.go | 16 ++++++++++++++++ responses.go | 41 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 6b60f975..902056f3 100644 --- a/client.go +++ b/client.go @@ -1043,6 +1043,22 @@ func (cli *Client) SetDisplayName(ctx context.Context, displayName string) (err return } +// UnstableSetProfileField sets an arbitrary MSC4133 profile field. See https://github.com/matrix-org/matrix-spec-proposals/pull/4133 +func (cli *Client) UnstableSetProfileField(ctx context.Context, key string, value any) (err error) { + urlPath := cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, map[string]any{ + key: value, + }, nil) + return +} + +// UnstableDeleteProfileField deletes an arbitrary MSC4133 profile field. See https://github.com/matrix-org/matrix-spec-proposals/pull/4133 +func (cli *Client) UnstableDeleteProfileField(ctx context.Context, key string) (err error) { + urlPath := cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) + _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil) + return +} + // GetAvatarURL gets the avatar URL of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseridavatar_url func (cli *Client) GetAvatarURL(ctx context.Context, mxid id.UserID) (url id.ContentURI, err error) { urlPath := cli.BuildClientURL("v3", "profile", mxid, "avatar_url") diff --git a/responses.go b/responses.go index 7312f099..a2e5c2b8 100644 --- a/responses.go +++ b/responses.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "maps" "reflect" "strconv" "strings" @@ -155,8 +156,44 @@ type RespUserDisplayName struct { } type RespUserProfile struct { - DisplayName string `json:"displayname"` - AvatarURL id.ContentURI `json:"avatar_url"` + DisplayName string `json:"displayname,omitempty"` + AvatarURL id.ContentURI `json:"avatar_url,omitempty"` + Extra map[string]any `json:"-"` +} + +type marshalableUserProfile RespUserProfile + +func (r *RespUserProfile) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, &r.Extra) + if err != nil { + return err + } + r.DisplayName, _ = r.Extra["displayname"].(string) + avatarURL, _ := r.Extra["avatar_url"].(string) + if avatarURL != "" { + r.AvatarURL, _ = id.ParseContentURI(avatarURL) + } + delete(r.Extra, "displayname") + delete(r.Extra, "avatar_url") + return nil +} + +func (r *RespUserProfile) MarshalJSON() ([]byte, error) { + if len(r.Extra) == 0 { + return json.Marshal((*marshalableUserProfile)(r)) + } + marshalMap := maps.Clone(r.Extra) + if r.DisplayName != "" { + marshalMap["displayname"] = r.DisplayName + } else { + delete(marshalMap, "displayname") + } + if !r.AvatarURL.IsEmpty() { + marshalMap["avatar_url"] = r.AvatarURL.String() + } else { + delete(marshalMap, "avatar_url") + } + return json.Marshal(r.Extra) } type RespMutualRooms struct { From 077716a4ec7cbb3983c8c40cba2757afab6d9033 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 1 Jan 2025 18:00:18 +0200 Subject: [PATCH 0976/1647] client: add wrapper for /openid/request_token --- client.go | 5 +++++ responses.go | 7 +++++++ 2 files changed, 12 insertions(+) diff --git a/client.go b/client.go index 902056f3..37de63ad 100644 --- a/client.go +++ b/client.go @@ -1556,6 +1556,11 @@ func (cli *Client) GetMediaConfig(ctx context.Context) (resp *RespMediaConfig, e return } +func (cli *Client) RequestOpenIDToken(ctx context.Context) (resp *RespOpenIDToken, err error) { + _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildClientURL("v3", "user", cli.UserID, "openid", "request_token"), nil, &resp) + return +} + // UploadLink uploads an HTTP URL and then returns an MXC URI. func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUpload, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, link, nil) diff --git a/responses.go b/responses.go index a2e5c2b8..dd52b1e7 100644 --- a/responses.go +++ b/responses.go @@ -680,3 +680,10 @@ type RespRoomKeysUpdate struct { Count int `json:"count"` ETag string `json:"etag"` } + +type RespOpenIDToken struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + MatrixServerName string `json:"matrix_server_name"` + TokenType string `json:"token_type"` // Always "Bearer" +} From dbd04afd41dad6ee9e152bc552b71d3d8943f890 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 3 Jan 2025 10:55:08 -0700 Subject: [PATCH 0977/1647] verificationhelper/sas: include emoji descriptions in callback Signed-off-by: Sumner Evans --- crypto/verificationhelper/callbacks_test.go | 9 ++- crypto/verificationhelper/sas.go | 71 ++++++++++++++++++- .../verificationhelper/verificationhelper.go | 7 +- .../verificationhelper_sas_test.go | 13 +++- .../verificationhelper_test.go | 2 +- 5 files changed, 91 insertions(+), 11 deletions(-) diff --git a/crypto/verificationhelper/callbacks_test.go b/crypto/verificationhelper/callbacks_test.go index 5faf2009..466a60fc 100644 --- a/crypto/verificationhelper/callbacks_test.go +++ b/crypto/verificationhelper/callbacks_test.go @@ -28,6 +28,7 @@ type baseVerificationCallbacks struct { doneTransactions map[id.VerificationTransactionID]struct{} verificationCancellation map[id.VerificationTransactionID]*event.VerificationCancelEventContent emojisShown map[id.VerificationTransactionID][]rune + emojiDescriptionsShown map[id.VerificationTransactionID][]string decimalsShown map[id.VerificationTransactionID][]int } @@ -39,6 +40,7 @@ func newBaseVerificationCallbacks() *baseVerificationCallbacks { doneTransactions: map[id.VerificationTransactionID]struct{}{}, verificationCancellation: map[id.VerificationTransactionID]*event.VerificationCancelEventContent{}, emojisShown: map[id.VerificationTransactionID][]rune{}, + emojiDescriptionsShown: map[id.VerificationTransactionID][]string{}, decimalsShown: map[id.VerificationTransactionID][]int{}, } } @@ -69,8 +71,8 @@ func (c *baseVerificationCallbacks) GetVerificationCancellation(txnID id.Verific return c.verificationCancellation[txnID] } -func (c *baseVerificationCallbacks) GetEmojisShown(txnID id.VerificationTransactionID) []rune { - return c.emojisShown[txnID] +func (c *baseVerificationCallbacks) GetEmojisAndDescriptionsShown(txnID id.VerificationTransactionID) ([]rune, []string) { + return c.emojisShown[txnID], c.emojiDescriptionsShown[txnID] } func (c *baseVerificationCallbacks) GetDecimalsShown(txnID id.VerificationTransactionID) []int { @@ -104,8 +106,9 @@ func newSASVerificationCallbacksWithBase(base *baseVerificationCallbacks) *sasVe return &sasVerificationCallbacks{base} } -func (c *sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) { +func (c *sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int) { c.emojisShown[txnID] = emojis + c.emojiDescriptionsShown[txnID] = emojiDescriptions c.decimalsShown[txnID] = decimals } diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 178838b8..a78b4b57 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -360,6 +360,7 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific var decimals []int var emojis []rune + var emojiDescriptions []string if slices.Contains(txn.StartEventContent.ShortAuthenticationString, event.SASMethodDecimal) { decimals = []int{ (int(sasBytes[0])<<5 | int(sasBytes[1])>>3) + 1000, @@ -375,9 +376,10 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific // Right shift the number and then mask the lowest 6 bits. emojiIdx := (sasNum >> uint(48-(i+1)*6)) & 0b111111 emojis = append(emojis, allEmojis[emojiIdx]) + emojiDescriptions = append(emojiDescriptions, allEmojiDescriptions[emojiIdx]) } } - vh.showSAS(ctx, txn.TransactionID, emojis, decimals) + vh.showSAS(ctx, txn.TransactionID, emojis, emojiDescriptions, decimals) if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { log.Err(err).Msg("failed to save verification transaction") @@ -575,6 +577,73 @@ var allEmojis = []rune{ '📌', } +var allEmojiDescriptions = []string{ + "Dog", + "Cat", + "Lion", + "Horse", + "Unicorn", + "Pig", + "Elephant", + "Rabbit", + "Panda", + "Rooster", + "Penguin", + "Turtle", + "Fish", + "Octopus", + "Butterfly", + "Flower", + "Tree", + "Cactus", + "Mushroom", + "Globe", + "Moon", + "Cloud", + "Fire", + "Banana", + "Apple", + "Strawberry", + "Corn", + "Pizza", + "Cake", + "Heart", + "Smiley", + "Robot", + "Hat", + "Glasses", + "Spanner", + "Santa", + "Thumbs Up", + "Umbrella", + "Hourglass", + "Clock", + "Gift", + "Light Bulb", + "Book", + "Pencil", + "Paperclip", + "Scissors", + "Lock", + "Key", + "Hammer", + "Telephone", + "Flag", + "Train", + "Bicycle", + "Aeroplane", + "Rocket", + "Trophy", + "Ball", + "Guitar", + "Trumpet", + "Bell", + "Anchor", + "Headphones", + "Folder", + "Pin", +} + func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn VerificationTransaction, evt *event.Event) { log := vh.getLog(ctx).With(). Str("verification_action", "mac"). diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index de943976..be547e7e 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -42,8 +42,9 @@ type RequiredCallbacks interface { type ShowSASCallbacks interface { // ShowSAS is a callback that is called when the SAS verification has // generated a short authentication string to show. It is guaranteed that - // either the emojis list, or the decimals list, or both will be present. - ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) + // either the emojis and emoji descriptions lists, or the decimals list, or + // both will be present. + ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int) } type ShowQRCodeCallbacks interface { @@ -75,7 +76,7 @@ type VerificationHelper struct { verificationCancelledCallback func(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) verificationDone func(ctx context.Context, txnID id.VerificationTransactionID) - showSAS func(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, decimals []int) + showSAS func(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int) scanQRCode func(ctx context.Context, txnID id.VerificationTransactionID) showQRCode func(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode) diff --git a/crypto/verificationhelper/verificationhelper_sas_test.go b/crypto/verificationhelper/verificationhelper_sas_test.go index 20e52e0f..22b1563c 100644 --- a/crypto/verificationhelper/verificationhelper_sas_test.go +++ b/crypto/verificationhelper/verificationhelper_sas_test.go @@ -165,7 +165,9 @@ func TestVerification_SAS(t *testing.T) { // Ensure that the receiving device showed emojis and SAS numbers. assert.Len(t, receivingCallbacks.GetDecimalsShown(txnID), 3) - assert.Len(t, receivingCallbacks.GetEmojisShown(txnID), 7) + emojis, descriptions := receivingCallbacks.GetEmojisAndDescriptionsShown(txnID) + assert.Len(t, emojis, 7) + assert.Len(t, descriptions, 7) } else { // Process the first key event on the sending device. ts.dispatchToDevice(t, ctx, sendingClient) @@ -178,7 +180,9 @@ func TestVerification_SAS(t *testing.T) { // Ensure that the sending device showed emojis and SAS numbers. assert.Len(t, sendingCallbacks.GetDecimalsShown(txnID), 3) - assert.Len(t, sendingCallbacks.GetEmojisShown(txnID), 7) + emojis, descriptions := sendingCallbacks.GetEmojisAndDescriptionsShown(txnID) + assert.Len(t, emojis, 7) + assert.Len(t, descriptions, 7) } assert.Equal(t, txnID, secondKeyEvt.TransactionID) assert.NotEmpty(t, secondKeyEvt.Key) @@ -193,7 +197,10 @@ func TestVerification_SAS(t *testing.T) { ts.dispatchToDevice(t, ctx, receivingClient) } assert.Equal(t, sendingCallbacks.GetDecimalsShown(txnID), receivingCallbacks.GetDecimalsShown(txnID)) - assert.Equal(t, sendingCallbacks.GetEmojisShown(txnID), receivingCallbacks.GetEmojisShown(txnID)) + sendingEmojis, sendingDescriptions := sendingCallbacks.GetEmojisAndDescriptionsShown(txnID) + receivingEmojis, receivingDescriptions := receivingCallbacks.GetEmojisAndDescriptionsShown(txnID) + assert.Equal(t, sendingEmojis, receivingEmojis) + assert.Equal(t, sendingDescriptions, receivingDescriptions) // Test that the first MAC event is correct var firstMACEvt *event.VerificationMACEventContent diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index af4a28c3..49c8db07 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -138,7 +138,7 @@ func TestVerification_Start(t *testing.T) { assert.NotEmpty(t, toDeviceInbox[receivingDeviceID]) assert.NotEmpty(t, toDeviceInbox[receivingDeviceID2]) assert.Equal(t, toDeviceInbox[receivingDeviceID], toDeviceInbox[receivingDeviceID2]) - assert.Len(t, toDeviceInbox[receivingDeviceID], 1) + require.Len(t, toDeviceInbox[receivingDeviceID], 1) // Ensure that the verification request is correct. verificationRequest := toDeviceInbox[receivingDeviceID][0].Content.AsVerificationRequest() From 5227c7701282e7e052062673afb5fb8f84b878bb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 24 Dec 2024 13:19:16 +0200 Subject: [PATCH 0978/1647] bridgev2/commands: hide commands based on network interface implementations --- bridgev2/commands/debug.go | 1 + bridgev2/commands/handler.go | 26 ++++++++++++++++++++++++-- bridgev2/commands/startchat.go | 3 +++ bridgev2/networkinterface.go | 2 ++ 4 files changed, 30 insertions(+), 2 deletions(-) diff --git a/bridgev2/commands/debug.go b/bridgev2/commands/debug.go index d00697ee..4c93dbd4 100644 --- a/bridgev2/commands/debug.go +++ b/bridgev2/commands/debug.go @@ -57,4 +57,5 @@ var CommandRegisterPush = &FullHandler{ }, RequiresAdmin: true, RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.PushableNetworkAPI], } diff --git a/bridgev2/commands/handler.go b/bridgev2/commands/handler.go index c1daf1af..672c81dc 100644 --- a/bridgev2/commands/handler.go +++ b/bridgev2/commands/handler.go @@ -7,6 +7,7 @@ package commands import ( + "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/event" ) @@ -37,6 +38,18 @@ type AliasedCommandHandler interface { GetAliases() []string } +func NetworkAPIImplements[T bridgev2.NetworkAPI](val bridgev2.NetworkAPI) bool { + _, ok := val.(T) + return ok +} + +func NetworkConnectorImplements[T bridgev2.NetworkConnector](val bridgev2.NetworkConnector) bool { + _, ok := val.(T) + return ok +} + +type ImplementationChecker[T any] func(val T) bool + type FullHandler struct { Func func(*Event) @@ -49,6 +62,9 @@ type FullHandler struct { RequiresLogin bool RequiresEventLevel event.Type RequiresLoginPermission bool + + NetworkAPI ImplementationChecker[bridgev2.NetworkAPI] + NetworkConnector ImplementationChecker[bridgev2.NetworkConnector] } func (fh *FullHandler) GetHelp() HelpMeta { @@ -64,9 +80,15 @@ func (fh *FullHandler) GetAliases() []string { return fh.Aliases } +func (fh *FullHandler) ImplementationsFulfilled(ce *Event) bool { + // TODO add dedicated method to get an empty NetworkAPI instead of getting default login + client := ce.User.GetDefaultLogin() + return (fh.NetworkAPI == nil || client == nil || fh.NetworkAPI(client.Client)) && + (fh.NetworkConnector == nil || fh.NetworkConnector(ce.Bridge.Network)) +} + func (fh *FullHandler) ShowInHelp(ce *Event) bool { - return true - //return !fh.RequiresAdmin || ce.User.GetPermissionLevel() >= bridgeconfig.PermissionLevelAdmin + return fh.ImplementationsFulfilled(ce) && (!fh.RequiresAdmin || ce.User.Permissions.Admin) } func (fh *FullHandler) userHasRoomPermission(ce *Event) bool { diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index aa766c0e..719d3dd5 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -27,6 +27,7 @@ var CommandResolveIdentifier = &FullHandler{ Args: "[_login ID_] <_identifier_>", }, RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI], } var CommandStartChat = &FullHandler{ @@ -39,6 +40,7 @@ var CommandStartChat = &FullHandler{ Args: "[_login ID_] <_identifier_>", }, RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI], } func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) { @@ -153,6 +155,7 @@ var CommandSearch = &FullHandler{ Args: "<_query_>", }, RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.UserSearchingNetworkAPI], } func fnSearch(ce *Event) { diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 8ddf1269..db066f0a 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -830,6 +830,8 @@ type PushConfig struct { } type PushableNetworkAPI interface { + NetworkAPI + RegisterPushNotifications(ctx context.Context, pushType PushType, token string) error GetPushConfigs() *PushConfig } From 012c246061035feced6c7d9fe060d3ef8423cff9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 4 Jan 2025 12:42:16 +0200 Subject: [PATCH 0979/1647] bridgev2/matrixinvite: fix setting service members when creating DM via invite --- bridgev2/matrixinvite.go | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go index f8217700..25c35eb7 100644 --- a/bridgev2/matrixinvite.go +++ b/bridgev2/matrixinvite.go @@ -182,6 +182,14 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen portal.UpdateInfo(ctx, resp.PortalInfo, sourceLogin, nil, time.Time{}) } if didSetPortal { + message := "Private chat portal created" + err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent) + hasWarning := false + if err != nil { + log.Warn().Err(err).Msg("Failed to give power to bot in new DM") + message += "\n\nWarning: failed to promote bot" + hasWarning = true + } // TODO this might become unnecessary if UpdateInfo starts taking care of it _, err = br.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{ Parsed: &event.ElementFunctionalMembersContent{ @@ -190,14 +198,10 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen }, time.Time{}) if err != nil { log.Warn().Err(err).Msg("Failed to set service members in room") - } - message := "Private chat portal created" - err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent) - hasWarning := false - if err != nil { - log.Warn().Err(err).Msg("Failed to give power to bot in new DM") - message += "\n\nWarning: failed to promote bot" - hasWarning = true + if !hasWarning { + message += "\n\nWarning: failed to set service members" + hasWarning = true + } } mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling) if ok { @@ -225,6 +229,9 @@ func (br *Bridge) givePowerToBot(ctx context.Context, roomID id.RoomID, userWith } userLevel := powers.GetUserLevel(userWithPower.GetMXID()) if powers.EnsureUserLevelAs(userWithPower.GetMXID(), br.Bot.GetMXID(), userLevel) { + if userLevel > powers.UsersDefault { + powers.SetUserLevel(userWithPower.GetMXID(), userLevel-1) + } _, err = userWithPower.SendState(ctx, roomID, event.StatePowerLevels, "", &event.Content{ Parsed: powers, }, time.Time{}) From 68eaa9d1df1f35ce5ee6750c9e8d57262c867a54 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 6 Jan 2025 17:24:26 +0200 Subject: [PATCH 0980/1647] dependencies: update --- go.mod | 10 +++++----- go.sum | 20 ++++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index c686489a..fd8baddb 100644 --- a/go.mod +++ b/go.mod @@ -18,11 +18,11 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.3 + go.mau.fi/util v0.8.4-0.20250106152331-30b8c95e7d7a go.mau.fi/zeroconfig v0.1.3 - golang.org/x/crypto v0.31.0 - golang.org/x/exp v0.0.0-20241215155358-4a5509556b9e - golang.org/x/net v0.32.0 + golang.org/x/crypto v0.32.0 + golang.org/x/exp v0.0.0-20250103183323-7d7fa50e5329 + golang.org/x/net v0.33.0 golang.org/x/sync v0.10.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 @@ -37,7 +37,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect - golang.org/x/sys v0.28.0 // indirect + golang.org/x/sys v0.29.0 // indirect golang.org/x/text v0.21.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 47ac74ef..18a2859e 100644 --- a/go.sum +++ b/go.sum @@ -51,24 +51,24 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.3 h1:sulhXtfquMrQjsOP67x9CzWVBYUwhYeoo8hNQIpCWZ4= -go.mau.fi/util v0.8.3/go.mod h1:c00Db8xog70JeIsEvhdHooylTkTkakgnAOsZ04hplQY= +go.mau.fi/util v0.8.4-0.20250106152331-30b8c95e7d7a h1:D9RCHBFjxah9F/YB7amvRJjT2IEOFWcz8jpcEY8dBV0= +go.mau.fi/util v0.8.4-0.20250106152331-30b8c95e7d7a/go.mod h1:MOfGTs1CBuK6ERTcSL4lb5YU7/ujz09eOPVEDckuazY= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= -golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/exp v0.0.0-20241215155358-4a5509556b9e h1:4qufH0hlUYs6AO6XmZC3GqfDPGSXHVXUFR6OND+iJX4= -golang.org/x/exp v0.0.0-20241215155358-4a5509556b9e/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c= -golang.org/x/net v0.32.0 h1:ZqPmj8Kzc+Y6e0+skZsuACbx+wzMgo5MQsJh9Qd6aYI= -golang.org/x/net v0.32.0/go.mod h1:CwU0IoeOlnQQWJ6ioyFrfRuomB8GKF6KbYXZVyeXNfs= +golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= +golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= +golang.org/x/exp v0.0.0-20250103183323-7d7fa50e5329 h1:9kj3STMvgqy3YA4VQXBrN7925ICMxD5wzMRcgA30588= +golang.org/x/exp v0.0.0-20250103183323-7d7fa50e5329/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= -golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= From ceb9c7b866e113b163cf18d5713ecb5984893783 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 7 Jan 2025 13:44:37 +0200 Subject: [PATCH 0981/1647] bridgev2/portal: fix reaction sync replacing all emojis --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index b1aae9e7..a74f9654 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2274,7 +2274,7 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User existingReaction, ok := existingUserReactions[reaction.EmojiID] if ok { delete(existingUserReactions, reaction.EmojiID) - if reaction.EmojiID != "" { + if reaction.EmojiID != "" || reaction.Emoji == existingReaction.Emoji { continue } doOverwriteReaction(reaction, existingReaction) From 6c5e4d8476d7761d210eaa0e93228601473f6ef7 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 6 Jan 2025 11:55:26 +0000 Subject: [PATCH 0982/1647] bridgev2/portal: using blocking portal queue push if buffer disabled --- bridgev2/portal.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a74f9654..77267e51 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -84,7 +84,7 @@ type Portal struct { events chan portalEvent } -const PortalEventBuffer = 64 +var PortalEventBuffer = 64 func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, queryErr error, key *networkid.PortalKey) (*Portal, error) { if queryErr != nil { @@ -272,12 +272,16 @@ func (br *Bridge) GetExistingPortalByKey(ctx context.Context, key networkid.Port } func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) { - select { - case portal.events <- evt: - default: - zerolog.Ctx(ctx).Error(). - Str("portal_id", string(portal.ID)). - Msg("Portal event channel is full") + if PortalEventBuffer == 0 { + portal.events <- evt + } else { + select { + case portal.events <- evt: + default: + zerolog.Ctx(ctx).Error(). + Str("portal_id", string(portal.ID)). + Msg("Portal event channel is full") + } } } From e571946e82ec3b1e0417b177be16c2ea73b86b89 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 8 Jan 2025 09:35:06 +0000 Subject: [PATCH 0983/1647] client: add optional media HTTP client --- client.go | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index 37de63ad..ea7fd6a1 100644 --- a/client.go +++ b/client.go @@ -76,18 +76,19 @@ type VerificationHelper interface { // Client represents a Matrix client. type Client struct { - HomeserverURL *url.URL // The base homeserver URL - UserID id.UserID // The user ID of the client. Used for forming HTTP paths which use the client's user ID. - DeviceID id.DeviceID // The device ID of the client. - AccessToken string // The access_token for the client. - UserAgent string // The value for the User-Agent header - Client *http.Client // The underlying HTTP client which will be used to make HTTP requests. - Syncer Syncer // The thing which can process /sync responses - Store SyncStore // The thing which can store tokens/ids - StateStore StateStore - Crypto CryptoHelper - Verification VerificationHelper - SpecVersions *RespVersions + HomeserverURL *url.URL // The base homeserver URL + UserID id.UserID // The user ID of the client. Used for forming HTTP paths which use the client's user ID. + DeviceID id.DeviceID // The device ID of the client. + AccessToken string // The access_token for the client. + UserAgent string // The value for the User-Agent header + Client *http.Client // The underlying HTTP client which will be used to make HTTP requests. + Syncer Syncer // The thing which can process /sync responses + Store SyncStore // The thing which can store tokens/ids + StateStore StateStore + Crypto CryptoHelper + Verification VerificationHelper + SpecVersions *RespVersions + ExternalClient *http.Client // The HTTP client used for external (not matrix) media HTTP requests. Log zerolog.Logger @@ -1693,7 +1694,11 @@ func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType str req.Header.Set("Content-Type", contentType) req.Header.Set("User-Agent", cli.UserAgent+" (external media uploader)") - return http.DefaultClient.Do(req) + if cli.ExternalClient != nil { + return cli.ExternalClient.Do(req) + } else { + return http.DefaultClient.Do(req) + } } func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (*RespMediaUpload, error) { From 9748015309bd6ec3bc9a4f1085a1acf6e68018e3 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 9 Jan 2025 10:57:01 -0700 Subject: [PATCH 0984/1647] bridgev2/portal: add function to get per-message profile for sender Signed-off-by: Sumner Evans --- bridgev2/portal.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 77267e51..3fc6b463 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3792,3 +3792,17 @@ func (portal *Portal) SetRelay(ctx context.Context, relay *UserLogin) error { } return nil } + +func (portal *Portal) PerMessageProfileForSender(ctx context.Context, sender networkid.UserID) (profile event.BeeperPerMessageProfile, err error) { + var ghost *Ghost + ghost, err = portal.Bridge.GetGhostByID(ctx, sender) + if err != nil { + return + } + profile.ID = string(ghost.Intent.GetMXID()) + profile.Displayname = ghost.Name + if ghost.AvatarMXC != "" { + profile.AvatarURL = &ghost.AvatarMXC + } + return +} From ac1ff66e3ba84960e16936d76e82071556ff2482 Mon Sep 17 00:00:00 2001 From: Brad Murray Date: Fri, 10 Jan 2025 08:19:44 -0500 Subject: [PATCH 0985/1647] bridgev2/messagestatus: prevent checkpoints for double puppeted events (#342) Co-authored-by: Tulir Asokan --- appservice/intent.go | 6 +++- bridgev2/database/message.go | 28 ++++++++++--------- bridgev2/database/upgrades/00-latest.sql | 3 +- .../19-add-double-puppeted-to-message.sql | 2 ++ bridgev2/matrix/connector.go | 14 ++++++---- bridgev2/matrix/intent.go | 4 +++ bridgev2/matrixinterface.go | 1 + bridgev2/messagestatus.go | 8 ++++++ bridgev2/portal.go | 21 ++++++++------ bridgev2/portalbackfill.go | 21 +++++++------- 10 files changed, 69 insertions(+), 39 deletions(-) create mode 100644 bridgev2/database/upgrades/19-add-double-puppeted-to-message.sql diff --git a/appservice/intent.go b/appservice/intent.go index 6848f28c..30313273 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -142,12 +142,16 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext return nil } +func (intent *IntentAPI) IsDoublePuppet() bool { + return intent.IsCustomPuppet && intent.as.DoublePuppetValue != "" +} + func (intent *IntentAPI) AddDoublePuppetValue(into any) any { return intent.AddDoublePuppetValueWithTS(into, 0) } func (intent *IntentAPI) AddDoublePuppetValueWithTS(into any, ts int64) any { - if !intent.IsCustomPuppet || intent.as.DoublePuppetValue == "" { + if !intent.IsDoublePuppet() { return into } // Only use ts deduplication feature with appservice double puppeting diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 8daf7407..04958490 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -33,11 +33,12 @@ type Message struct { PartID networkid.PartID MXID id.EventID - Room networkid.PortalKey - SenderID networkid.UserID - SenderMXID id.UserID - Timestamp time.Time - EditCount int + Room networkid.PortalKey + SenderID networkid.UserID + SenderMXID id.UserID + Timestamp time.Time + EditCount int + IsDoublePuppeted bool ThreadRoot networkid.MessageID ReplyTo networkid.MessageOptionalPartID @@ -48,7 +49,7 @@ type Message struct { const ( getMessageBaseQuery = ` SELECT rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, sender_mxid, - timestamp, edit_count, thread_root_id, reply_to_id, reply_to_part_id, metadata + timestamp, edit_count, double_puppeted, thread_root_id, reply_to_id, reply_to_part_id, metadata FROM message ` getAllMessagePartsByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3` @@ -72,15 +73,16 @@ const ( insertMessageQuery = ` INSERT INTO message ( bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, sender_mxid, - timestamp, edit_count, thread_root_id, reply_to_id, reply_to_part_id, metadata + timestamp, edit_count, double_puppeted, thread_root_id, reply_to_id, reply_to_part_id, metadata ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) RETURNING rowid ` updateMessageQuery = ` UPDATE message SET id=$2, part_id=$3, mxid=$4, room_id=$5, room_receiver=$6, sender_id=$7, sender_mxid=$8, - timestamp=$9, edit_count=$10, thread_root_id=$11, reply_to_id=$12, reply_to_part_id=$13, metadata=$14 - WHERE bridge_id=$1 AND rowid=$15 + timestamp=$9, edit_count=$10, double_puppeted=$11, thread_root_id=$12, reply_to_id=$13, + reply_to_part_id=$14, metadata=$15 + WHERE bridge_id=$1 AND rowid=$16 ` deleteAllMessagePartsByIDQuery = ` DELETE FROM message WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 @@ -174,7 +176,7 @@ func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { var threadRootID, replyToID, replyToPartID sql.NullString err := row.Scan( &m.RowID, &m.BridgeID, &m.ID, &m.PartID, &m.MXID, &m.Room.ID, &m.Room.Receiver, &m.SenderID, &m.SenderMXID, - ×tamp, &m.EditCount, &threadRootID, &replyToID, &replyToPartID, dbutil.JSON{Data: m.Metadata}, + ×tamp, &m.EditCount, &m.IsDoublePuppeted, &threadRootID, &replyToID, &replyToPartID, dbutil.JSON{Data: m.Metadata}, ) if err != nil { return nil, err @@ -200,8 +202,8 @@ func (m *Message) ensureHasMetadata(metaType MetaTypeCreator) *Message { func (m *Message) sqlVariables() []any { return []any{ m.BridgeID, m.ID, m.PartID, m.MXID, m.Room.ID, m.Room.Receiver, m.SenderID, m.SenderMXID, - m.Timestamp.UnixNano(), m.EditCount, dbutil.StrPtr(m.ThreadRoot), dbutil.StrPtr(m.ReplyTo.MessageID), m.ReplyTo.PartID, - dbutil.JSON{Data: m.Metadata}, + m.Timestamp.UnixNano(), m.EditCount, m.IsDoublePuppeted, dbutil.StrPtr(m.ThreadRoot), + dbutil.StrPtr(m.ReplyTo.MessageID), m.ReplyTo.PartID, dbutil.JSON{Data: m.Metadata}, } } diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 6d6dcf2c..a1bdccac 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -87,7 +87,7 @@ CREATE TABLE message ( -- would try to set bridge_id to null as well. -- only: sqlite (line commented) --- rowid INTEGER PRIMARY KEY, +-- rowid INTEGER PRIMARY KEY, -- only: postgres rowid BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY, @@ -102,6 +102,7 @@ CREATE TABLE message ( sender_mxid TEXT NOT NULL, timestamp BIGINT NOT NULL, edit_count INTEGER NOT NULL, + double_puppeted BOOLEAN, thread_root_id TEXT, reply_to_id TEXT, reply_to_part_id TEXT, diff --git a/bridgev2/database/upgrades/19-add-double-puppeted-to-message.sql b/bridgev2/database/upgrades/19-add-double-puppeted-to-message.sql new file mode 100644 index 00000000..ec6fe836 --- /dev/null +++ b/bridgev2/database/upgrades/19-add-double-puppeted-to-message.sql @@ -0,0 +1,2 @@ +-- v19 (compatible with v9+): Add double puppeted state to messages +ALTER TABLE message ADD COLUMN double_puppeted BOOLEAN; diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 94fdd97c..fc3dd36b 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -450,13 +450,17 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 return "" } log := zerolog.Ctx(ctx) - err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{ms.ToCheckpoint(evt)}) - if err != nil { - log.Err(err).Msg("Failed to send message checkpoint") + + if !evt.IsSourceEventDoublePuppeted { + err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{ms.ToCheckpoint(evt)}) + if err != nil { + log.Err(err).Msg("Failed to send message checkpoint") + } } + if !ms.DisableMSS && br.Config.Matrix.MessageStatusEvents { mssEvt := ms.ToMSSEvent(evt) - _, err = br.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, mssEvt) + _, err := br.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, mssEvt) if err != nil { log.Err(err). Stringer("room_id", evt.RoomID). @@ -482,7 +486,7 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 } } if ms.Status == event.MessageStatusSuccess && br.Config.Matrix.DeliveryReceipts { - err = br.Bot.SendReceipt(ctx, evt.RoomID, evt.SourceEventID, event.ReceiptTypeRead, nil) + err := br.Bot.SendReceipt(ctx, evt.RoomID, evt.SourceEventID, event.ReceiptTypeRead, nil) if err != nil { log.Err(err). Stringer("room_id", evt.RoomID). diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 9f6c520e..7efc1bab 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -478,6 +478,10 @@ func (as *ASIntent) GetMXID() id.UserID { return as.Matrix.UserID } +func (as *ASIntent) IsDoublePuppet() bool { + return as.Matrix.IsDoublePuppet() +} + func (as *ASIntent) EnsureJoined(ctx context.Context, roomID id.RoomID) error { err := as.Matrix.EnsureJoined(ctx, roomID) if err != nil { diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 699ce07b..aba1eaa4 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -126,6 +126,7 @@ func (ce CallbackError) Unwrap() error { type MatrixAPI interface { GetMXID() id.UserID + IsDoublePuppet() bool SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *MatrixSendExtra) (*mautrix.RespSendEvent, error) SendState(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (*mautrix.RespSendEvent, error) diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 1983b4de..c846f502 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -12,6 +12,7 @@ import ( "go.mau.fi/util/jsontime" + "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -26,6 +27,8 @@ type MessageStatusEventInfo struct { Sender id.UserID ThreadRoot id.EventID StreamOrder int64 + + IsSourceEventDoublePuppeted bool } func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo { @@ -33,6 +36,9 @@ func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo { if relatable, ok := evt.Content.Parsed.(event.Relatable); ok { threadRoot = relatable.OptionalGetRelatesTo().GetThreadParent() } + + _, isDoublePuppeted := evt.Content.Raw[appservice.DoublePuppetKey] + return &MessageStatusEventInfo{ RoomID: evt.RoomID, SourceEventID: evt.ID, @@ -40,6 +46,8 @@ func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo { MessageType: evt.Content.AsMessage().MsgType, Sender: evt.Sender, ThreadRoot: threadRoot, + + IsSourceEventDoublePuppeted: isDoublePuppeted, } } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 3fc6b463..06c03c93 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1827,15 +1827,16 @@ func (portal *Portal) sendConvertedMessage( for i, part := range converted.Parts { portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) dbMessage := &database.Message{ - ID: id, - PartID: part.ID, - Room: portal.PortalKey, - SenderID: senderID, - SenderMXID: intent.GetMXID(), - Timestamp: ts, - ThreadRoot: ptr.Val(converted.ThreadRoot), - ReplyTo: ptr.Val(converted.ReplyTo), - Metadata: part.DBMetadata, + ID: id, + PartID: part.ID, + Room: portal.PortalKey, + SenderID: senderID, + SenderMXID: intent.GetMXID(), + Timestamp: ts, + ThreadRoot: ptr.Val(converted.ThreadRoot), + ReplyTo: ptr.Val(converted.ReplyTo), + Metadata: part.DBMetadata, + IsDoublePuppeted: intent.IsDoublePuppet(), } if part.DontBridge { dbMessage.SetFakeMXID() @@ -2603,6 +2604,8 @@ func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *U RoomID: portal.MXID, SourceEventID: part.MXID, Sender: part.SenderMXID, + + IsSourceEventDoublePuppeted: part.IsDoublePuppeted, }) } } diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 55225efc..ac15880d 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -333,16 +333,17 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID) dbMessage := &database.Message{ - ID: msg.ID, - PartID: part.ID, - MXID: evtID, - Room: portal.PortalKey, - SenderID: msg.Sender.Sender, - SenderMXID: intent.GetMXID(), - Timestamp: msg.Timestamp, - ThreadRoot: ptr.Val(msg.ThreadRoot), - ReplyTo: ptr.Val(msg.ReplyTo), - Metadata: part.DBMetadata, + ID: msg.ID, + PartID: part.ID, + MXID: evtID, + Room: portal.PortalKey, + SenderID: msg.Sender.Sender, + SenderMXID: intent.GetMXID(), + Timestamp: msg.Timestamp, + ThreadRoot: ptr.Val(msg.ThreadRoot), + ReplyTo: ptr.Val(msg.ReplyTo), + Metadata: part.DBMetadata, + IsDoublePuppeted: intent.IsDoublePuppet(), } if part.DontBridge { dbMessage.SetFakeMXID() From fc696eaa4721246dc2a438078ba2a49f84487a5c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 10 Jan 2025 16:48:09 +0200 Subject: [PATCH 0986/1647] bridgev2/database: fix bugs with double puppeted column --- bridgev2/database/message.go | 4 +++- bridgev2/database/upgrades/00-latest.sql | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 04958490..42581c6e 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -174,15 +174,17 @@ func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { var timestamp int64 var threadRootID, replyToID, replyToPartID sql.NullString + var doublePuppeted sql.NullBool err := row.Scan( &m.RowID, &m.BridgeID, &m.ID, &m.PartID, &m.MXID, &m.Room.ID, &m.Room.Receiver, &m.SenderID, &m.SenderMXID, - ×tamp, &m.EditCount, &m.IsDoublePuppeted, &threadRootID, &replyToID, &replyToPartID, dbutil.JSON{Data: m.Metadata}, + ×tamp, &m.EditCount, &doublePuppeted, &threadRootID, &replyToID, &replyToPartID, dbutil.JSON{Data: m.Metadata}, ) if err != nil { return nil, err } m.Timestamp = time.Unix(0, timestamp) m.ThreadRoot = networkid.MessageID(threadRootID.String) + m.IsDoublePuppeted = doublePuppeted.Valid if replyToID.Valid { m.ReplyTo.MessageID = networkid.MessageID(replyToID.String) if replyToPartID.Valid { diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index a1bdccac..056c3cc4 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v18 (compatible with v9+): Latest revision +-- v0 -> v19 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, From 59645cdf73830fe2a870d23b9e023f8edf6f76b8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 10 Jan 2025 16:54:46 +0200 Subject: [PATCH 0987/1647] bridgev2/matrixinterface: let connector generate deterministic room IDs (#343) --- bridgev2/matrix/connector.go | 4 ++++ bridgev2/matrixinterface.go | 1 + bridgev2/portal.go | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index fc3dd36b..53fe1d85 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -623,6 +623,10 @@ func (br *Connector) GenerateDeterministicEventID(roomID id.RoomID, _ networkid. return id.EventID(unsafe.String(unsafe.SliceData(eventID), len(eventID))) } +func (br *Connector) GenerateDeterministicRoomID(key networkid.PortalKey) id.RoomID { + return id.RoomID(fmt.Sprintf("!%s.%s:%s", key.ID, key.Receiver, br.ServerName())) +} + func (br *Connector) GenerateReactionEventID(roomID id.RoomID, targetMessage *database.Message, sender networkid.UserID, emojiID networkid.EmojiID) id.EventID { // We don't care about determinism for reactions return id.EventID(fmt.Sprintf("$%s:%s", base64.RawURLEncoding.EncodeToString(random.Bytes(32)), br.deterministicEventIDServer)) diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index aba1eaa4..615fbcb7 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -50,6 +50,7 @@ type MatrixConnector interface { GetMemberInfo(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) BatchSend(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBeeperBatchSend, extras []*MatrixSendExtra) (*mautrix.RespBeeperBatchSend, error) + GenerateDeterministicRoomID(portalKey networkid.PortalKey) id.RoomID GenerateDeterministicEventID(roomID id.RoomID, portalKey networkid.PortalKey, messageID networkid.MessageID, partID networkid.PartID) id.EventID GenerateReactionEventID(roomID id.RoomID, targetMessage *database.Message, sender networkid.UserID, emojiID networkid.EmojiID) id.EventID diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 06c03c93..2ddba76a 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3566,7 +3566,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo Preset: "private_chat", IsDirect: portal.RoomType == database.RoomTypeDM, PowerLevelOverride: powerLevels, - BeeperLocalRoomID: id.RoomID(fmt.Sprintf("!%s.%s:%s", portal.ID, portal.Receiver, portal.Bridge.Matrix.ServerName())), + BeeperLocalRoomID: portal.Bridge.Matrix.GenerateDeterministicRoomID(portal.PortalKey), } autoJoinInvites := portal.Bridge.Matrix.GetCapabilities().AutoJoinInvites if autoJoinInvites { From 285106586976ee38ee4c9aa4b0e485c5b4b04d9d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 10 Jan 2025 16:55:18 +0200 Subject: [PATCH 0988/1647] bridgev2: send room capabilities as a state event (#344) --- bridgev2/bridge.go | 73 ++++- bridgev2/bridgeconfig/config.go | 1 + bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/database/kvstore.go | 1 + bridgev2/database/portal.go | 18 +- bridgev2/database/upgrades/00-latest.sql | 3 +- .../upgrades/20-portal-capabilities.sql | 2 + bridgev2/matrix/connector.go | 4 + bridgev2/matrix/mxmain/example-config.yaml | 3 + bridgev2/matrix/mxmain/main.go | 1 + bridgev2/matrixinterface.go | 4 + bridgev2/networkinterface.go | 40 +-- bridgev2/portal.go | 129 +++++++-- bridgev2/portalinternal.go | 4 +- event/capabilities.d.ts | 158 +++++++++++ event/capabilities.go | 259 ++++++++++++++++++ event/content.go | 1 + event/message.go | 30 +- event/type.go | 3 +- go.mod | 2 +- go.sum | 4 +- 21 files changed, 659 insertions(+), 82 deletions(-) create mode 100644 bridgev2/database/upgrades/20-portal-capabilities.sql create mode 100644 event/capabilities.d.ts create mode 100644 event/capabilities.go diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index b2151ee6..794c1f29 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -45,6 +45,8 @@ type Bridge struct { ghostsByID map[networkid.UserID]*Ghost cacheLock sync.Mutex + didSplitPortals bool + wakeupBackfillQueue chan struct{} stopBackfillQueue chan struct{} } @@ -109,18 +111,22 @@ func (br *Bridge) Start() error { if err != nil { return err } + br.PostStart() return nil } func (br *Bridge) StartConnectors() error { br.Log.Info().Msg("Starting bridge") ctx := br.Log.WithContext(context.Background()) + foreground := true err := br.DB.Upgrade(ctx) if err != nil { return DBUpgradeError{Err: err, Section: "main"} } - didSplitPortals := br.MigrateToSplitPortals(ctx) + if foreground { + br.didSplitPortals = br.MigrateToSplitPortals(ctx) + } br.Log.Info().Msg("Starting Matrix connector") err = br.Matrix.Start(ctx) if err != nil { @@ -134,13 +140,35 @@ func (br *Bridge) StartConnectors() error { if br.Network.GetCapabilities().DisappearingMessages { go br.DisappearLoop.Start() } - if didSplitPortals || br.Config.ResendBridgeInfo { - br.ResendBridgeInfo(ctx) - } return nil } -func (br *Bridge) ResendBridgeInfo(ctx context.Context) { +func (br *Bridge) PostStart() { + ctx := br.Log.WithContext(context.Background()) + rawBridgeInfoVer := br.DB.KV.Get(ctx, database.KeyBridgeInfoVersion) + bridgeInfoVer, capVer, err := parseBridgeInfoVersion(rawBridgeInfoVer) + if err != nil { + br.Log.Err(err).Str("db_bridge_info_version", rawBridgeInfoVer).Msg("Failed to parse bridge info version") + return + } + expectedBridgeInfoVer, expectedCapVer := br.Network.GetBridgeInfoVersion() + doResendBridgeInfo := bridgeInfoVer != expectedBridgeInfoVer || br.didSplitPortals || br.Config.ResendBridgeInfo + doResendCapabilities := capVer != expectedCapVer || br.didSplitPortals + if doResendBridgeInfo || doResendCapabilities { + br.ResendBridgeInfo(ctx, doResendBridgeInfo, doResendCapabilities) + } + br.DB.KV.Set(ctx, database.KeyBridgeInfoVersion, fmt.Sprintf("%d,%d", expectedBridgeInfoVer, expectedCapVer)) +} + +func parseBridgeInfoVersion(version string) (info, capabilities int, err error) { + _, err = fmt.Sscanf(version, "%d,%d", &info, &capabilities) + if version == "" { + err = nil + } + return +} + +func (br *Bridge) ResendBridgeInfo(ctx context.Context, resendInfo, resendCaps bool) { log := zerolog.Ctx(ctx).With().Str("action", "resend bridge info").Logger() portals, err := br.GetAllPortalsWithMXID(ctx) if err != nil { @@ -148,9 +176,40 @@ func (br *Bridge) ResendBridgeInfo(ctx context.Context) { return } for _, portal := range portals { - portal.UpdateBridgeInfo(ctx) + if resendInfo { + portal.UpdateBridgeInfo(ctx) + } + if resendCaps { + logins, err := br.GetUserLoginsInPortal(ctx, portal.PortalKey) + if err != nil { + log.Err(err). + Stringer("room_id", portal.MXID). + Object("portal_key", portal.PortalKey). + Msg("Failed to get user logins in portal") + } else { + found := false + for _, login := range logins { + if portal.CapState.ID == "" || login.ID == portal.CapState.Source { + portal.UpdateCapabilities(ctx, login, true) + found = true + } + } + if !found && len(logins) > 0 { + portal.CapState.Source = "" + portal.UpdateCapabilities(ctx, logins[0], true) + } else if !found { + log.Warn(). + Stringer("room_id", portal.MXID). + Object("portal_key", portal.PortalKey). + Msg("No user login found to update capabilities") + } + } + } } - log.Info().Msg("Resent bridge info to all portals") + log.Info(). + Bool("capabilities", resendCaps). + Bool("info", resendInfo). + Msg("Resent bridge info to all portals") } func (br *Bridge) MigrateToSplitPortals(ctx context.Context) bool { diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index cf87864f..12d5452b 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -64,6 +64,7 @@ type BridgeConfig struct { AsyncEvents bool `yaml:"async_events"` SplitPortals bool `yaml:"split_portals"` ResendBridgeInfo bool `yaml:"resend_bridge_info"` + NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"` BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` TagOnlyOnCreate bool `yaml:"tag_only_on_create"` OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 17c4af13..ea986fda 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -30,6 +30,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "async_events") helper.Copy(up.Bool, "bridge", "split_portals") helper.Copy(up.Bool, "bridge", "resend_bridge_info") + helper.Copy(up.Bool, "bridge", "no_bridge_info_state_key") helper.Copy(up.Bool, "bridge", "bridge_matrix_leave") helper.Copy(up.Bool, "bridge", "tag_only_on_create") helper.Copy(up.List, "bridge", "only_bridge_tags") diff --git a/bridgev2/database/kvstore.go b/bridgev2/database/kvstore.go index 3fc54f2c..5a1af019 100644 --- a/bridgev2/database/kvstore.go +++ b/bridgev2/database/kvstore.go @@ -21,6 +21,7 @@ type Key string const ( KeySplitPortalsEnabled Key = "split_portals_enabled" + KeyBridgeInfoVersion Key = "bridge_info_version" ) type KVQuery struct { diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 72e31454..17e44b09 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -34,6 +34,11 @@ type PortalQuery struct { *dbutil.QueryHelper[*Portal] } +type CapabilityState struct { + Source networkid.UserLoginID `json:"source"` + ID string `json:"id"` +} + type Portal struct { BridgeID networkid.BridgeID networkid.PortalKey @@ -54,6 +59,7 @@ type Portal struct { InSpace bool RoomType RoomType Disappear DisappearingSetting + CapState CapabilityState Metadata any } @@ -62,7 +68,7 @@ const ( SELECT bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, topic_set, avatar_set, name_is_custom, in_space, - room_type, disappear_type, disappear_timer, + room_type, disappear_type, disappear_timer, cap_state, metadata FROM portal ` @@ -82,10 +88,10 @@ const ( parent_id, parent_receiver, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, name_set, avatar_set, topic_set, name_is_custom, in_space, - room_type, disappear_type, disappear_timer, + room_type, disappear_type, disappear_timer, cap_state, metadata, relay_bridge_id ) VALUES ( - $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, + $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE $1 END ) ` @@ -95,7 +101,7 @@ const ( relay_login_id=cast($7 AS TEXT), relay_bridge_id=CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE bridge_id END, other_user_id=$8, name=$9, topic=$10, avatar_id=$11, avatar_hash=$12, avatar_mxc=$13, name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18, - room_type=$19, disappear_type=$20, disappear_timer=$21, metadata=$22 + room_type=$19, disappear_type=$20, disappear_timer=$21, cap_state=$22, metadata=$23 WHERE bridge_id=$1 AND id=$2 AND receiver=$3 ` deletePortalQuery = ` @@ -189,7 +195,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { &p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC, &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, &p.RoomType, &disappearType, &disappearTimer, - dbutil.JSON{Data: p.Metadata}, + dbutil.JSON{Data: &p.CapState}, dbutil.JSON{Data: p.Metadata}, ) if err != nil { return nil, err @@ -236,6 +242,6 @@ func (p *Portal) sqlVariables() []any { p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC, p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, p.RoomType, dbutil.StrPtr(p.Disappear.Type), dbutil.NumPtr(p.Disappear.Timer), - dbutil.JSON{Data: p.Metadata}, + dbutil.JSON{Data: p.CapState}, dbutil.JSON{Data: p.Metadata}, } } diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 056c3cc4..56976b82 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v19 (compatible with v9+): Latest revision +-- v0 -> v20 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -51,6 +51,7 @@ CREATE TABLE portal ( room_type TEXT NOT NULL, disappear_type TEXT, disappear_timer BIGINT, + cap_state jsonb, metadata jsonb NOT NULL, PRIMARY KEY (bridge_id, id, receiver), diff --git a/bridgev2/database/upgrades/20-portal-capabilities.sql b/bridgev2/database/upgrades/20-portal-capabilities.sql new file mode 100644 index 00000000..00bd96ca --- /dev/null +++ b/bridgev2/database/upgrades/20-portal-capabilities.sql @@ -0,0 +1,2 @@ +-- v20 (compatible with v9+): Add portal capability state +ALTER TABLE portal ADD COLUMN cap_state jsonb; diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 53fe1d85..0bb1ee61 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -579,6 +579,10 @@ func (br *Connector) IsConfusableName(ctx context.Context, roomID id.RoomID, use return br.AS.StateStore.IsConfusableName(ctx, roomID, userID, name) } +func (br *Connector) GetUniqueBridgeID() string { + return fmt.Sprintf("%s/%s", br.Config.Homeserver.Domain, br.Config.AppService.ID) +} + func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautrix.ReqBeeperBatchSend, extras []*bridgev2.MatrixSendExtra) (*mautrix.RespBeeperBatchSend, error) { if encrypted, err := br.StateStore.IsEncrypted(ctx, roomID); err != nil { return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index d09047c3..82c431fb 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -18,6 +18,9 @@ bridge: split_portals: false # Should the bridge resend `m.bridge` events to all portals on startup? resend_bridge_info: false + # Should `m.bridge` events be sent without a state key? + # By default, the bridge uses a unique key that won't conflict with other bridges. + no_bridge_info_state_key: false # Should leaving Matrix rooms be bridged as leaving groups on the remote network? bridge_matrix_leave: false diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 46f27e73..695b042b 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -387,6 +387,7 @@ func (br *BridgeMain) Start() { if err != nil { br.Log.Fatal().Err(err).Msg("Failed to start existing user logins") } + br.Bridge.PostStart() if br.PostStart != nil { br.PostStart() } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 615fbcb7..8ac2e92d 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -70,6 +70,10 @@ type MatrixConnectorWithNameDisambiguation interface { IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error) } +type MatrixConnectorWithBridgeIdentifier interface { + GetUniqueBridgeID() string +} + type MatrixConnectorWithURLPreviews interface { GetURLPreview(ctx context.Context, url string) (*event.LinkPreview, error) } diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index db066f0a..05d948c8 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -227,6 +227,10 @@ type NetworkConnector interface { // This should generally not do any work, it should just return a LoginProcess that remembers // the user and will execute the requested flow. The actual work should start when [LoginProcess.Start] is called. CreateLogin(ctx context.Context, user *User, flowID string) (LoginProcess, error) + + // GetBridgeInfoVersion returns version numbers for bridge info and room capabilities respectively. + // When the versions change, the bridge will automatically resend bridge info to all rooms. + GetBridgeInfoVersion() (info, capabilities int) } type StoppableNetwork interface { @@ -297,11 +301,6 @@ type MatrixMessageResponse struct { PostSave func(context.Context, *database.Message) } -type FileRestriction struct { - MaxSize int64 - MimeTypes []string -} - type NetworkGeneralCapabilities struct { // Does the network connector support disappearing messages? // This flag enables the message disappearing loop in the bridge. @@ -311,35 +310,6 @@ type NetworkGeneralCapabilities struct { AggressiveUpdateInfo bool } -type NetworkRoomCapabilities struct { - FormattedText bool - UserMentions bool - RoomMentions bool - - LocationMessages bool - Captions bool - MaxTextLength int - MaxCaptionLength int - Polls bool - - Threads bool - Replies bool - Edits bool - EditMaxCount int - EditMaxAge time.Duration - Deletes bool - DeleteMaxAge time.Duration - - DefaultFileRestriction *FileRestriction - Files map[event.MessageType]FileRestriction - - ReadReceipts bool - - Reactions bool - ReactionCount int - AllowedReactions []string -} - // NetworkAPI is an interface representing a remote network client for a single user login. // // Implementations of this interface are stored in [UserLogin.Client]. @@ -372,7 +342,7 @@ type NetworkAPI interface { // GetCapabilities returns the bridging capabilities in a given room. // This can simply return a static list if the remote network has no per-chat capability differences, // but all calls will include the portal, because some networks do have per-chat differences. - GetCapabilities(ctx context.Context, portal *Portal) *NetworkRoomCapabilities + GetCapabilities(ctx context.Context, portal *Portal) *event.RoomFeatures // HandleMatrixMessage is called when a message is sent from Matrix in an existing portal room. // This function should convert the message as appropriate, send it over to the remote network, diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 2ddba76a..4d8d76e6 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -79,6 +79,8 @@ type Portal struct { outgoingMessages map[networkid.TransactionID]outgoingMessage outgoingMessagesLock sync.Mutex + lastCapUpdate time.Time + roomCreateLock sync.Mutex events chan portalEvent @@ -755,22 +757,36 @@ func (portal *Portal) periodicTypingUpdater() { } } -func (portal *Portal) checkMessageContentCaps(ctx context.Context, caps *NetworkRoomCapabilities, content *event.MessageEventContent, evt *event.Event) bool { +func (portal *Portal) checkMessageContentCaps(ctx context.Context, caps *event.RoomFeatures, content *event.MessageEventContent, evt *event.Event) bool { switch content.MsgType { case event.MsgText, event.MsgNotice, event.MsgEmote: // No checks for now, message length is safer to check after conversion inside connector case event.MsgLocation: - if !caps.LocationMessages { + if caps.LocationMessage.Reject() { portal.sendErrorStatus(ctx, evt, ErrLocationMessagesNotAllowed) return false } - case event.MsgImage, event.MsgAudio, event.MsgVideo, event.MsgFile: - if content.FileName != "" && content.Body != content.FileName { - if !caps.Captions { - portal.sendErrorStatus(ctx, evt, ErrCaptionsNotAllowed) + case event.MsgImage, event.MsgAudio, event.MsgVideo, event.MsgFile, event.CapMsgSticker: + capMsgType := content.GetCapMsgType() + feat, ok := caps.File[capMsgType] + if !ok { + portal.sendErrorStatus(ctx, evt, ErrUnsupportedMessageType) + return false + } + if content.MsgType != event.CapMsgSticker && + content.FileName != "" && + content.Body != content.FileName && + feat.Caption.Reject() { + portal.sendErrorStatus(ctx, evt, ErrCaptionsNotAllowed) + return false + } + if content.Info != nil && content.Info.MimeType != "" { + if feat.GetMimeSupport(content.Info.MimeType).Reject() { + portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w (%s in %s)", ErrUnsupportedMediaType, content.Info.MimeType, capMsgType)) return false } } + fallthrough default: } return true @@ -792,6 +808,9 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } else { msgContent, ok = evt.Content.Parsed.(*event.MessageEventContent) relatesTo = msgContent.RelatesTo + if evt.Type == event.EventSticker { + msgContent.MsgType = event.CapMsgSticker + } } if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") @@ -849,21 +868,21 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } } var replyToID id.EventID - if caps.Threads { + threadRootID := relatesTo.GetThreadParent() + if caps.Thread.Partial() { replyToID = relatesTo.GetNonFallbackReplyTo() + if threadRootID != "" { + threadRoot, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, threadRootID) + if err != nil { + log.Err(err).Msg("Failed to get thread root message from database") + } else if threadRoot == nil { + log.Warn().Stringer("thread_root_id", threadRootID).Msg("Thread root message not found") + } + } } else { replyToID = relatesTo.GetReplyTo() } - threadRootID := relatesTo.GetThreadParent() - if caps.Threads && threadRootID != "" { - threadRoot, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, threadRootID) - if err != nil { - log.Err(err).Msg("Failed to get thread root message from database") - } else if threadRoot == nil { - log.Warn().Stringer("thread_root_id", threadRootID).Msg("Thread root message not found") - } - } - if replyToID != "" && (caps.Replies || caps.Threads) { + if replyToID != "" && (caps.Reply.Partial() || caps.Thread.Partial()) { replyTo, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, replyToID) if err != nil { log.Err(err).Msg("Failed to get reply target message from database") @@ -874,7 +893,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin // The fallback happens if the message is not a Matrix thread and either // * the replied-to message is in a thread, or // * the network only supports threads (assume the user wants to start a new thread) - if caps.Threads && threadRoot == nil && (replyTo.ThreadRoot != "" || !caps.Replies) { + if caps.Thread.Partial() && threadRoot == nil && (replyTo.ThreadRoot != "" || !caps.Reply.Partial()) { threadRootRemoteID := replyTo.ThreadRoot if threadRootRemoteID == "" { threadRootRemoteID = replyTo.ID @@ -884,7 +903,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin log.Err(err).Msg("Failed to get thread root message from database (via reply fallback)") } } - if !caps.Replies { + if !caps.Reply.Partial() { replyTo = nil } } @@ -1031,7 +1050,7 @@ func (evt *MatrixMessage) fillDBMessage(message *database.Message) *database.Mes return message } -func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *NetworkRoomCapabilities) { +func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *event.RoomFeatures) { log := zerolog.Ctx(ctx) editTargetID := content.RelatesTo.GetReplaceID() log.UpdateContext(func(c zerolog.Context) zerolog.Context { @@ -1039,6 +1058,9 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o }) if content.NewContent != nil { content = content.NewContent + if evt.Type == event.EventSticker { + content.MsgType = event.CapMsgSticker + } } if origSender != nil { var err error @@ -1055,7 +1077,7 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o log.Debug().Msg("Ignoring edit as network connector doesn't implement EditHandlingNetworkAPI") portal.sendErrorStatus(ctx, evt, ErrEditsNotSupported) return - } else if !caps.Edits { + } else if !caps.Edit.Partial() { log.Debug().Msg("Ignoring edit as room doesn't support edits") portal.sendErrorStatus(ctx, evt, ErrEditsNotSupportedInPortal) return @@ -1071,7 +1093,7 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o log.Warn().Msg("Edit target message not found in database") portal.sendErrorStatus(ctx, evt, fmt.Errorf("edit %w", ErrTargetMessageNotFound)) return - } else if caps.EditMaxAge > 0 && time.Since(editTarget.Timestamp) > caps.EditMaxAge { + } else if caps.EditMaxAge.Duration > 0 && time.Since(editTarget.Timestamp) > caps.EditMaxAge.Duration { portal.sendErrorStatus(ctx, evt, ErrEditTargetTooOld) return } else if caps.EditMaxCount > 0 && editTarget.EditCount >= caps.EditMaxCount { @@ -2931,6 +2953,17 @@ func (portal *Portal) GetTopLevelParent() *Portal { return portal.Parent.GetTopLevelParent() } +func (portal *Portal) getBridgeInfoStateKey() string { + if portal.Bridge.Config.NoBridgeInfoStateKey { + return "" + } + idProvider, ok := portal.Bridge.Matrix.(MatrixConnectorWithBridgeIdentifier) + if ok { + return idProvider.GetUniqueBridgeID() + } + return string(portal.BridgeID) +} + func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { bridgeInfo := event.BridgeEventContent{ BridgeBot: portal.Bridge.Bot.GetMXID(), @@ -2961,10 +2994,7 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { if ok { filler.FillPortalBridgeInfo(portal, &bridgeInfo) } - // TODO use something globally unique instead of bridge ID? - // maybe ask the matrix connector to use serverName+appserviceID+bridgeID - stateKey := string(portal.BridgeID) - return stateKey, bridgeInfo + return portal.getBridgeInfoStateKey(), bridgeInfo } func (portal *Portal) UpdateBridgeInfo(ctx context.Context) { @@ -2976,6 +3006,43 @@ func (portal *Portal) UpdateBridgeInfo(ctx context.Context) { portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo) } +func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, implicit bool) bool { + if portal.MXID == "" { + return false + } else if !implicit && time.Since(portal.lastCapUpdate) < 24*time.Hour { + return false + } else if portal.CapState.ID != "" && source.ID != portal.CapState.Source && source.ID != portal.Receiver { + // TODO allow capability state source to change if the old user login is removed from the portal + return false + } + caps := source.Client.GetCapabilities(ctx, portal) + capID := caps.GetID() + if capID == portal.CapState.ID { + return false + } + zerolog.Ctx(ctx).Debug(). + Str("user_login_id", string(source.ID)). + Str("old_id", portal.CapState.ID). + Str("new_id", capID). + Msg("Sending new room capability event") + success := portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperRoomFeatures, portal.getBridgeInfoStateKey(), caps) + if !success { + return false + } + portal.CapState = database.CapabilityState{ + Source: source.ID, + ID: capID, + } + portal.lastCapUpdate = time.Now() + if implicit { + err := portal.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal capability state after sending state event") + } + } + return true +} + func (portal *Portal) sendStateWithIntentOrBot(ctx context.Context, sender MatrixAPI, eventType event.Type, stateKey string, content *event.Content, ts time.Time) (resp *mautrix.RespSendEvent, err error) { if sender == nil { sender = portal.Bridge.Bot @@ -3462,6 +3529,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us if source != nil { source.MarkInPortal(ctx, portal) portal.updateUserLocalInfo(ctx, info.UserLocal, source, false) + changed = portal.UpdateCapabilities(ctx, source, false) || changed } if info.CanBackfill && source != nil && portal.MXID != "" { err := portal.Bridge.DB.BackfillTask.EnsureExists(ctx, portal.PortalKey, source.ID) @@ -3579,6 +3647,11 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo req.CreationContent["type"] = event.RoomTypeSpace } bridgeInfoStateKey, bridgeInfo := portal.getBridgeInfo() + roomFeatures := source.Client.GetCapabilities(ctx, portal) + portal.CapState = database.CapabilityState{ + Source: source.ID, + ID: roomFeatures.GetID(), + } req.InitialState = append(req.InitialState, &event.Event{ Type: event.StateElementFunctionalMembers, @@ -3593,6 +3666,10 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo StateKey: &bridgeInfoStateKey, Type: event.StateBridge, Content: event.Content{Parsed: &bridgeInfo}, + }, &event.Event{ + StateKey: &bridgeInfoStateKey, + Type: event.StateBeeperRoomFeatures, + Content: event.Content{Parsed: roomFeatures}, }) if req.Topic == "" { // Add explicit topic event if topic is empty to ensure the event is set. diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index a5da077b..e0f4ee5a 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -85,7 +85,7 @@ func (portal *PortalInternals) PeriodicTypingUpdater() { (*Portal)(portal).periodicTypingUpdater() } -func (portal *PortalInternals) CheckMessageContentCaps(ctx context.Context, caps *NetworkRoomCapabilities, content *event.MessageEventContent, evt *event.Event) bool { +func (portal *PortalInternals) CheckMessageContentCaps(ctx context.Context, caps *event.RoomFeatures, content *event.MessageEventContent, evt *event.Event) bool { return (*Portal)(portal).checkMessageContentCaps(ctx, caps, content, evt) } @@ -93,7 +93,7 @@ func (portal *PortalInternals) HandleMatrixMessage(ctx context.Context, sender * (*Portal)(portal).handleMatrixMessage(ctx, sender, origSender, evt) } -func (portal *PortalInternals) HandleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *NetworkRoomCapabilities) { +func (portal *PortalInternals) HandleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *event.RoomFeatures) { (*Portal)(portal).handleMatrixEdit(ctx, sender, origSender, evt, content, caps) } diff --git a/event/capabilities.d.ts b/event/capabilities.d.ts new file mode 100644 index 00000000..1c5d533d --- /dev/null +++ b/event/capabilities.d.ts @@ -0,0 +1,158 @@ +/** + * The content of the `com.beeper.room_features` state event. + */ +export interface RoomFeatures { + /** + * Supported formatting features. If omitted, no formatting is supported. + * + * Capability level 0 means the corresponding HTML tags/attributes are ignored + * and will be treated as if they don't exist, which means that children will + * be rendered, but attributes will be dropped. + */ + formatting?: Record + /** + * Supported file message types and their features. + * + * If a message type isn't listed here, it should be treated as support level -2 (will be rejected). + */ + file?: Record + + /** Whether location messages (`m.location`) are supported. */ + location_message?: CapabilitySupportLevel + /** Whether polls are supported. */ + poll?: CapabilitySupportLevel + /** Whether replying in a thread is supported. */ + thread?: CapabilitySupportLevel + /** Whether replying to a specific message is supported. */ + reply?: CapabilitySupportLevel + + /** Whether edits are supported. */ + edit?: CapabilitySupportLevel + /** How many times can an individual message be edited. */ + edit_max_count?: integer + /** How old messages can be edited, in seconds. */ + edit_max_age?: seconds + /** Whether deleting messages for everyone is supported */ + delete?: CapabilitySupportLevel + /** How old messages can be deleted for everyone, in seconds. */ + delete_max_age?: seconds + /** Whether deleting messages just for yourself is supported. No message age limit. */ + delete_for_me?: boolean + + /** Whether reactions are supported. */ + reaction?: CapabilitySupportLevel + /** How many reactions can be added to a single message. */ + reaction_count?: integer + /** + * The Unicode emojis allowed for reactions. If omitted, all emojis are allowed. + * Emojis in this list must include variation selector 16 if allowed in the Unicode spec. + */ + allowed_reactions?: string[] + /** Whether custom emoji reactions are allowed. */ + custom_emoji_reactions?: boolean +} + +declare type integer = number +declare type seconds = integer +declare type MIMEClass = "image" | "audio" | "video" | "text" | "font" | "model" | "application" +declare type MIMETypeOrPattern = + "*/*" + | `${MIMEClass}/*` + | `${MIMEClass}/${string}` + | `${MIMEClass}/${string}; ${string}` + +export enum CapabilityMsgType { + // Real message types used in the `msgtype` field + Image = "m.image", + File = "m.file", + Audio = "m.audio", + Video = "m.video", + + // Pseudo types only used in capabilities + /** An `m.audio` message that has `"org.matrix.msc3245.voice": {}` */ + Voice = "org.matrix.msc3245.voice", + /** An `m.video` message that has `"info": {"fi.mau.gif": true}`, or an `m.image` message of type `image/gif` */ + GIF = "fi.mau.gif", + /** An `m.sticker` event, no `msgtype` field */ + Sticker = "m.sticker", +} + +export interface FileFeatures { + /** + * The supported MIME types or type patterns and their support levels. + * + * If a mime type doesn't match any pattern provided, + * it should be treated as support level -2 (will be rejected). + */ + mime_types: Record + + /** The support level for captions within this file message type */ + caption?: CapabilitySupportLevel + /** The maximum length for captions (only applicable if captions are supported). */ + max_caption_length?: integer + /** The maximum file size as bytes. */ + max_size?: integer + /** For images and videos, the maximum width as pixels. */ + max_width?: integer + /** For images and videos, the maximum height as pixels. */ + max_height?: integer + /** For videos and audio files, the maximum duration as seconds. */ + max_duration?: seconds + + /** Can this type of file be sent as view-once media? */ + view_once?: boolean +} + +/** + * The support level for a feature. These are integers rather than booleans + * to accurately represent what the bridge is doing and hopefully make the + * state event more generally useful. Our clients should check for > 0 to + * determine if the feature should be allowed. + */ +export enum CapabilitySupportLevel { + /** The feature is unsupported and messages using it will be rejected. */ + Rejected = -2, + /** The feature is unsupported and has no fallback. The message will go through, but data may be lost. */ + Dropped = -1, + /** The feature is unsupported, but may have a fallback. The nature of the fallback depends on the context. */ + Unsupported = 0, + /** The feature is partially supported (e.g. it may be converted to a different format). */ + PartialSupport = 1, + /** The feature is fully supported and can be safely used. */ + FullySupported = 2, +} + +/** + * A formatting feature that consists of specific HTML tags and/or attributes. + */ +export enum FormattingFeature { + Bold = "bold", // strong, b + Italic = "italic", // em, i + Underline = "underline", // u + Strikethrough = "strikethrough", // del, s + InlineCode = "inline_code", // code + CodeBlock = "code_block", // pre + code + SyntaxHighlighting = "code_block.syntax_highlighting", //

+	Blockquote = "blockquote", // blockquote
+	InlineLink = "inline_link", // a
+	UserLink = "user_link", // 
+	RoomLink = "room_link", // 
+	EventLink = "event_link", // 
+	AtRoomMention = "at_room_mention", // @room (no html tag)
+	UnorderedList = "unordered_list", // ul + li
+	OrderedList = "ordered_list", // ol + li
+	ListStart = "ordered_list.start", // 
    + ListJumpValue = "ordered_list.jump_value", //
  1. + CustomEmoji = "custom_emoji", // + Spoiler = "spoiler", // + SpoilerReason = "spoiler.reason", // + TextForegroundColor = "color.foreground", // + TextBackgroundColor = "color.background", // + HorizontalLine = "horizontal_line", // hr + Headers = "headers", // h1, h2, h3, h4, h5, h6 + Superscript = "superscript", // sup + Subscript = "subscript", // sub + Math = "math", // + DetailsSummary = "details_summary", //
    ......
    + Table = "table", // table, thead, tbody, tr, th, td +} diff --git a/event/capabilities.go b/event/capabilities.go new file mode 100644 index 00000000..6a9b0b56 --- /dev/null +++ b/event/capabilities.go @@ -0,0 +1,259 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package event + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "mime" + "slices" + "strings" + + "go.mau.fi/util/exerrors" + "go.mau.fi/util/jsontime" + "golang.org/x/exp/constraints" + "golang.org/x/exp/maps" +) + +type RoomFeatures struct { + ID string `json:"id,omitempty"` + + // N.B. New fields need to be added to the Hash function to be included in the deduplication hash. + + Formatting FormattingFeatureMap `json:"formatting,omitempty"` + File FileFeatureMap `json:"file,omitempty"` + + LocationMessage CapabilitySupportLevel `json:"location_message,omitempty"` + Poll CapabilitySupportLevel `json:"poll,omitempty"` + Thread CapabilitySupportLevel `json:"thread,omitempty"` + Reply CapabilitySupportLevel `json:"reply,omitempty"` + + Edit CapabilitySupportLevel `json:"edit,omitempty"` + EditMaxCount int `json:"edit_max_count,omitempty"` + EditMaxAge *jsontime.Seconds `json:"edit_max_age,omitempty"` + Delete CapabilitySupportLevel `json:"delete,omitempty"` + DeleteForMe bool `json:"delete_for_me,omitempty"` + DeleteMaxAge *jsontime.Seconds `json:"delete_max_age,omitempty"` + + Reaction CapabilitySupportLevel `json:"reaction,omitempty"` + ReactionCount int `json:"reaction_count,omitempty"` + AllowedReactions []string `json:"allowed_reactions,omitempty"` + CustomEmojiReactions bool `json:"custom_emoji_reactions,omitempty"` + + ReadReceipts bool `json:"read_receipts,omitempty"` + TypingNotifications bool `json:"typing_notifications,omitempty"` + Archive bool `json:"archive,omitempty"` + MarkAsUnread bool `json:"mark_as_unread,omitempty"` + DeleteChat bool `json:"delete_chat,omitempty"` +} + +func (rf *RoomFeatures) GetID() string { + if rf.ID != "" { + return rf.ID + } + return base64.RawURLEncoding.EncodeToString(rf.Hash()) +} + +type FormattingFeatureMap map[FormattingFeature]CapabilitySupportLevel + +type FileFeatureMap map[CapabilityMsgType]*FileFeatures + +type CapabilityMsgType = MessageType + +// Message types which are used for event capability signaling, but aren't real values for the msgtype field. +const ( + CapMsgVoice CapabilityMsgType = "org.matrix.msc3245.voice" + CapMsgGIF CapabilityMsgType = "fi.mau.gif" + CapMsgSticker CapabilityMsgType = "m.sticker" +) + +type CapabilitySupportLevel int + +func (csl CapabilitySupportLevel) Partial() bool { + return csl >= CapLevelPartialSupport +} + +func (csl CapabilitySupportLevel) Full() bool { + return csl >= CapLevelFullySupported +} + +func (csl CapabilitySupportLevel) Reject() bool { + return csl <= CapLevelRejected +} + +const ( + CapLevelRejected CapabilitySupportLevel = -2 // The feature is unsupported and messages using it will be rejected. + CapLevelDropped CapabilitySupportLevel = -1 // The feature is unsupported and has no fallback. The message will go through, but data may be lost. + CapLevelUnsupported CapabilitySupportLevel = 0 // The feature is unsupported, but may have a fallback. + CapLevelPartialSupport CapabilitySupportLevel = 1 // The feature is partially supported (e.g. it may be converted to a different format). + CapLevelFullySupported CapabilitySupportLevel = 2 // The feature is fully supported and can be safely used. +) + +type FormattingFeature string + +const ( + FmtBold FormattingFeature = "bold" // strong, b + FmtItalic FormattingFeature = "italic" // em, i + FmtUnderline FormattingFeature = "underline" // u + FmtStrikethrough FormattingFeature = "strikethrough" // del, s + FmtInlineCode FormattingFeature = "inline_code" // code + FmtCodeBlock FormattingFeature = "code_block" // pre + code + FmtSyntaxHighlighting FormattingFeature = "code_block.syntax_highlighting" //
    
    +	FmtBlockquote          FormattingFeature = "blockquote"                     // blockquote
    +	FmtInlineLink          FormattingFeature = "inline_link"                    // a
    +	FmtUserLink            FormattingFeature = "user_link"                      // 
    +	FmtRoomLink            FormattingFeature = "room_link"                      // 
    +	FmtEventLink           FormattingFeature = "event_link"                     // 
    +	FmtAtRoomMention       FormattingFeature = "at_room_mention"                // @room (no html tag)
    +	FmtUnorderedList       FormattingFeature = "unordered_list"                 // ul + li
    +	FmtOrderedList         FormattingFeature = "ordered_list"                   // ol + li
    +	FmtListStart           FormattingFeature = "ordered_list.start"             // 
      + FmtListJumpValue FormattingFeature = "ordered_list.jump_value" //
    1. + FmtCustomEmoji FormattingFeature = "custom_emoji" // + FmtSpoiler FormattingFeature = "spoiler" // + FmtSpoilerReason FormattingFeature = "spoiler.reason" // + FmtTextForegroundColor FormattingFeature = "color.foreground" // + FmtTextBackgroundColor FormattingFeature = "color.background" // + FmtHorizontalLine FormattingFeature = "horizontal_line" // hr + FmtHeaders FormattingFeature = "headers" // h1, h2, h3, h4, h5, h6 + FmtSuperscript FormattingFeature = "superscript" // sup + FmtSubscript FormattingFeature = "subscript" // sub + FmtMath FormattingFeature = "math" // + FmtDetailsSummary FormattingFeature = "details_summary" //
      ......
      + FmtTable FormattingFeature = "table" // table, thead, tbody, tr, th, td +) + +type FileFeatures struct { + // N.B. New fields need to be added to the Hash function to be included in the deduplication hash. + + MimeTypes map[string]CapabilitySupportLevel `json:"mime_types"` + + Caption CapabilitySupportLevel `json:"caption,omitempty"` + MaxCaptionLength int `json:"max_caption_length,omitempty"` + + MaxSize int64 `json:"max_size,omitempty"` + MaxWidth int `json:"max_width,omitempty"` + MaxHeight int `json:"max_height,omitempty"` + MaxDuration *jsontime.Seconds `json:"max_duration,omitempty"` + + ViewOnce bool `json:"view_once,omitempty"` +} + +func (ff *FileFeatures) GetMimeSupport(inputType string) CapabilitySupportLevel { + match, ok := ff.MimeTypes[inputType] + if ok { + return match + } + if strings.IndexByte(inputType, ';') != -1 { + plainMime, _, _ := mime.ParseMediaType(inputType) + if plainMime != "" { + if match, ok = ff.MimeTypes[plainMime]; ok { + return match + } + } + } + if slash := strings.IndexByte(inputType, '/'); slash > 0 { + generalType := fmt.Sprintf("%s/*", inputType[:slash]) + if match, ok = ff.MimeTypes[generalType]; ok { + return match + } + } + match, ok = ff.MimeTypes["*/*"] + if ok { + return match + } + return CapLevelRejected +} + +type hashable interface { + Hash() []byte +} + +func hashMap[Key ~string, Value hashable](w io.Writer, name string, data map[Key]Value) { + keys := maps.Keys(data) + slices.Sort(keys) + exerrors.Must(w.Write([]byte(name))) + for _, key := range keys { + exerrors.Must(w.Write([]byte(key))) + exerrors.Must(w.Write(data[key].Hash())) + exerrors.Must(w.Write([]byte{0})) + } +} + +func hashValue(w io.Writer, name string, data hashable) { + exerrors.Must(w.Write([]byte(name))) + exerrors.Must(w.Write(data.Hash())) +} + +func hashInt[T constraints.Integer](w io.Writer, name string, data T) { + exerrors.Must(w.Write(binary.BigEndian.AppendUint64([]byte(name), uint64(data)))) +} + +func hashBool[T ~bool](w io.Writer, name string, data T) { + exerrors.Must(w.Write([]byte(name))) + if data { + exerrors.Must(w.Write([]byte{1})) + } else { + exerrors.Must(w.Write([]byte{0})) + } +} + +func (csl CapabilitySupportLevel) Hash() []byte { + return []byte{byte(csl + 128)} +} + +func (rf *RoomFeatures) Hash() []byte { + hasher := sha256.New() + + hashMap(hasher, "formatting", rf.Formatting) + hashMap(hasher, "file", rf.File) + + hashValue(hasher, "location_message", rf.LocationMessage) + hashValue(hasher, "poll", rf.Poll) + hashValue(hasher, "thread", rf.Thread) + hashValue(hasher, "reply", rf.Reply) + + hashValue(hasher, "edit", rf.Edit) + hashInt(hasher, "edit_max_count", rf.EditMaxCount) + hashInt(hasher, "edit_max_age", rf.EditMaxAge.Get()) + + hashValue(hasher, "delete", rf.Delete) + hashBool(hasher, "delete_for_me", rf.DeleteForMe) + hashInt(hasher, "delete_max_age", rf.DeleteMaxAge.Get()) + + hashValue(hasher, "reaction", rf.Reaction) + hashInt(hasher, "reaction_count", rf.ReactionCount) + hasher.Write([]byte("allowed_reactions")) + for _, reaction := range rf.AllowedReactions { + hasher.Write([]byte(reaction)) + } + hashBool(hasher, "custom_emoji_reactions", rf.CustomEmojiReactions) + + hashBool(hasher, "read_receipts", rf.ReadReceipts) + hashBool(hasher, "typing_notifications", rf.TypingNotifications) + hashBool(hasher, "archive", rf.Archive) + hashBool(hasher, "mark_as_unread", rf.MarkAsUnread) + hashBool(hasher, "delete_chat", rf.DeleteChat) + + return hasher.Sum(nil) +} + +func (ff *FileFeatures) Hash() []byte { + hasher := sha256.New() + hashMap(hasher, "mime_types", ff.MimeTypes) + hashValue(hasher, "caption", ff.Caption) + hashInt(hasher, "max_caption_length", ff.MaxCaptionLength) + hashInt(hasher, "max_size", ff.MaxSize) + hashInt(hasher, "max_width", ff.MaxWidth) + hashInt(hasher, "max_height", ff.MaxHeight) + hashInt(hasher, "max_duration", ff.MaxDuration.Get()) + hashBool(hasher, "view_once", ff.ViewOnce) + return hasher.Sum(nil) +} diff --git a/event/content.go b/event/content.go index ab57c658..b8e130db 100644 --- a/event/content.go +++ b/event/content.go @@ -48,6 +48,7 @@ var TypeMap = map[Type]reflect.Type{ StateUnstablePolicyUser: reflect.TypeOf(ModPolicyContent{}), StateElementFunctionalMembers: reflect.TypeOf(ElementFunctionalMembersContent{}), + StateBeeperRoomFeatures: reflect.TypeOf(RoomFeatures{}), EventMessage: reflect.TypeOf(MessageEventContent{}), EventSticker: reflect.TypeOf(MessageEventContent{}), diff --git a/event/message.go b/event/message.go index 92bdcf07..f9c4f49c 100644 --- a/event/message.go +++ b/event/message.go @@ -32,7 +32,7 @@ func (mt MessageType) IsText() bool { func (mt MessageType) IsMedia() bool { switch mt { - case MsgImage, MsgVideo, MsgAudio, MsgFile, MessageType(EventSticker.Type): + case MsgImage, MsgVideo, MsgAudio, MsgFile, CapMsgSticker: return true default: return false @@ -142,6 +142,32 @@ type MessageEventContent struct { MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"` } +func (content *MessageEventContent) GetCapMsgType() CapabilityMsgType { + switch content.MsgType { + case CapMsgSticker: + return CapMsgSticker + case "": + if content.URL != "" || content.File != nil { + return CapMsgSticker + } + case MsgImage: + return MsgImage + case MsgAudio: + if content.MSC3245Voice != nil { + return CapMsgVoice + } + return MsgAudio + case MsgVideo: + if content.Info != nil && content.Info.MauGIF { + return CapMsgGIF + } + return MsgVideo + case MsgFile: + return MsgFile + } + return "" +} + func (content *MessageEventContent) GetFileName() string { if content.FileName != "" { return content.FileName @@ -258,6 +284,8 @@ type FileInfo struct { Blurhash string `json:"blurhash,omitempty"` AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"` + MauGIF bool `json:"fi.mau.gif,omitempty"` + Width int `json:"-"` Height int `json:"-"` Duration int `json:"-"` diff --git a/event/type.go b/event/type.go index f2b841ad..41d7c47b 100644 --- a/event/type.go +++ b/event/type.go @@ -112,7 +112,7 @@ func (et *Type) GuessClass() TypeClass { StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type, StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type, StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type, - StateInsertionMarker.Type, StateElementFunctionalMembers.Type: + StateInsertionMarker.Type, StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type: return StateEventType case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type: return EphemeralEventType @@ -203,6 +203,7 @@ var ( StateInsertionMarker = Type{"org.matrix.msc2716.marker", StateEventType} StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType} + StateBeeperRoomFeatures = Type{"com.beeper.room_features", StateEventType} ) // Message events diff --git a/go.mod b/go.mod index fd8baddb..8252186f 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.4-0.20250106152331-30b8c95e7d7a + go.mau.fi/util v0.8.4-0.20250110124612-64d4dbbec957 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.32.0 golang.org/x/exp v0.0.0-20250103183323-7d7fa50e5329 diff --git a/go.sum b/go.sum index 18a2859e..4c27345f 100644 --- a/go.sum +++ b/go.sum @@ -51,8 +51,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.4-0.20250106152331-30b8c95e7d7a h1:D9RCHBFjxah9F/YB7amvRJjT2IEOFWcz8jpcEY8dBV0= -go.mau.fi/util v0.8.4-0.20250106152331-30b8c95e7d7a/go.mod h1:MOfGTs1CBuK6ERTcSL4lb5YU7/ujz09eOPVEDckuazY= +go.mau.fi/util v0.8.4-0.20250110124612-64d4dbbec957 h1:tsLt3t6ARc55niz+JMgJy6U4sL210Z0K/nyxF09xT0E= +go.mau.fi/util v0.8.4-0.20250110124612-64d4dbbec957/go.mod h1:MOfGTs1CBuK6ERTcSL4lb5YU7/ujz09eOPVEDckuazY= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= From bbcb1904e268a9ba4ab11308ea64c75967ee11da Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 10 Jan 2025 17:40:22 +0200 Subject: [PATCH 0989/1647] event/capabilities: add max text length field --- event/capabilities.d.ts | 3 +++ event/capabilities.go | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/event/capabilities.d.ts b/event/capabilities.d.ts index 1c5d533d..4cf29de7 100644 --- a/event/capabilities.d.ts +++ b/event/capabilities.d.ts @@ -17,6 +17,9 @@ export interface RoomFeatures { */ file?: Record + /** Maximum length of normal text messages. */ + max_text_length?: integer + /** Whether location messages (`m.location`) are supported. */ location_message?: CapabilitySupportLevel /** Whether polls are supported. */ diff --git a/event/capabilities.go b/event/capabilities.go index 6a9b0b56..9c9eb09a 100644 --- a/event/capabilities.go +++ b/event/capabilities.go @@ -30,6 +30,8 @@ type RoomFeatures struct { Formatting FormattingFeatureMap `json:"formatting,omitempty"` File FileFeatureMap `json:"file,omitempty"` + MaxTextLength int `json:"max_text_length,omitempty"` + LocationMessage CapabilitySupportLevel `json:"location_message,omitempty"` Poll CapabilitySupportLevel `json:"poll,omitempty"` Thread CapabilitySupportLevel `json:"thread,omitempty"` @@ -215,6 +217,8 @@ func (rf *RoomFeatures) Hash() []byte { hashMap(hasher, "formatting", rf.Formatting) hashMap(hasher, "file", rf.File) + hashInt(hasher, "max_text_length", rf.MaxTextLength) + hashValue(hasher, "location_message", rf.LocationMessage) hashValue(hasher, "poll", rf.Poll) hashValue(hasher, "thread", rf.Thread) From c05be16a5233a244e112d49d5bf7cf20764dbecf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 13 Jan 2025 22:08:41 +0200 Subject: [PATCH 0990/1647] event: fix de/serializing `fi.mau.gif` file info field --- event/message.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/event/message.go b/event/message.go index f9c4f49c..3b2330c2 100644 --- a/event/message.go +++ b/event/message.go @@ -301,6 +301,8 @@ type serializableFileInfo struct { Blurhash string `json:"blurhash,omitempty"` AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"` + MauGIF bool `json:"fi.mau.gif,omitempty"` + Width json.Number `json:"w,omitempty"` Height json.Number `json:"h,omitempty"` Duration json.Number `json:"duration,omitempty"` @@ -317,6 +319,8 @@ func (sfi *serializableFileInfo) CopyFrom(fileInfo *FileInfo) *serializableFileI ThumbnailInfo: (&serializableFileInfo{}).CopyFrom(fileInfo.ThumbnailInfo), ThumbnailFile: fileInfo.ThumbnailFile, + MauGIF: fileInfo.MauGIF, + Blurhash: fileInfo.Blurhash, AnoaBlurhash: fileInfo.AnoaBlurhash, } @@ -345,6 +349,7 @@ func (sfi *serializableFileInfo) CopyTo(fileInfo *FileInfo) { MimeType: sfi.MimeType, ThumbnailURL: sfi.ThumbnailURL, ThumbnailFile: sfi.ThumbnailFile, + MauGIF: sfi.MauGIF, Blurhash: sfi.Blurhash, AnoaBlurhash: sfi.AnoaBlurhash, } From 53a56684d3d3241efb2929ab36f8ca340b3f601d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 13 Jan 2025 22:09:22 +0200 Subject: [PATCH 0991/1647] event: remove struct tags from FileInfo They're lies, only `serializableFileInfo` is actually used --- event/message.go | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/event/message.go b/event/message.go index 3b2330c2..48313784 100644 --- a/event/message.go +++ b/event/message.go @@ -276,20 +276,20 @@ type EncryptedFileInfo struct { } type FileInfo struct { - MimeType string `json:"mimetype,omitempty"` - ThumbnailInfo *FileInfo `json:"thumbnail_info,omitempty"` - ThumbnailURL id.ContentURIString `json:"thumbnail_url,omitempty"` - ThumbnailFile *EncryptedFileInfo `json:"thumbnail_file,omitempty"` + MimeType string + ThumbnailInfo *FileInfo + ThumbnailURL id.ContentURIString + ThumbnailFile *EncryptedFileInfo - Blurhash string `json:"blurhash,omitempty"` - AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"` + Blurhash string + AnoaBlurhash string - MauGIF bool `json:"fi.mau.gif,omitempty"` + MauGIF bool - Width int `json:"-"` - Height int `json:"-"` - Duration int `json:"-"` - Size int `json:"-"` + Width int + Height int + Duration int + Size int } type serializableFileInfo struct { From 27ac910b655913da38d6468da44d29cc47ab9694 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 14 Jan 2025 21:34:33 +0200 Subject: [PATCH 0992/1647] bridgev2/portal: only use event loop when buffer is enabled When buffer is disabled, queueEvent will instead acquire a lock and call the handler directly. Hopefully the queueEvent callers are already in a queue and will block so that queueEvent itself doesn't need to be strictly FIFO (if callers aren't in a queue, even the buffered channel writes could race each other). --- bridgev2/portal.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 4d8d76e6..30433d5a 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -84,6 +84,9 @@ type Portal struct { roomCreateLock sync.Mutex events chan portalEvent + + eventsLock sync.Mutex + eventIdx int } var PortalEventBuffer = 64 @@ -109,7 +112,6 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que Portal: dbPortal, Bridge: br, - events: make(chan portalEvent, PortalEventBuffer), currentlyTypingLogins: make(map[id.UserID]*UserLogin), outgoingMessages: make(map[networkid.TransactionID]outgoingMessage), } @@ -131,7 +133,10 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que } } portal.updateLogger() - go portal.eventLoop() + if PortalEventBuffer != 0 { + portal.events = make(chan portalEvent, PortalEventBuffer) + go portal.eventLoop() + } return portal, nil } @@ -275,7 +280,10 @@ func (br *Bridge) GetExistingPortalByKey(ctx context.Context, key networkid.Port func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) { if PortalEventBuffer == 0 { - portal.events <- evt + portal.eventsLock.Lock() + defer portal.eventsLock.Unlock() + portal.eventIdx++ + portal.handleSingleEventAsync(portal.eventIdx, evt) } else { select { case portal.events <- evt: From b17a8cd74cf6e35109c50b63e7bbec45c530ae9b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 15 Jan 2025 15:05:29 +0200 Subject: [PATCH 0993/1647] bridgev2: add RunOnce method to backfill a single user login and disconnect --- bridgev2/bridge.go | 68 ++++++++++++++++++++++++---------- bridgev2/matrix/mxmain/main.go | 9 +++-- bridgev2/networkinterface.go | 9 +++++ 3 files changed, 62 insertions(+), 24 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 794c1f29..fc195d35 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -47,6 +47,8 @@ type Bridge struct { didSplitPortals bool + Background bool + wakeupBackfillQueue chan struct{} stopBackfillQueue chan struct{} } @@ -103,28 +105,48 @@ func (e DBUpgradeError) Unwrap() error { } func (br *Bridge) Start() error { - err := br.StartConnectors() + ctx := br.Log.WithContext(context.Background()) + err := br.StartConnectors(ctx) if err != nil { return err } - err = br.StartLogins() + err = br.StartLogins(ctx) if err != nil { return err } - br.PostStart() + br.PostStart(ctx) return nil } -func (br *Bridge) StartConnectors() error { +func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID) error { + br.Background = true + err := br.StartConnectors(ctx) + if err != nil { + return err + } + login, err := br.GetExistingUserLoginByID(ctx, loginID) + if err != nil { + return fmt.Errorf("failed to get user login: %w", err) + } else if login == nil { + return ErrNotLoggedIn + } + syncClient, ok := login.Client.(BackgroundSyncingNetworkAPI) + if !ok { + return fmt.Errorf("%T does not implement BackgroundSyncingNetworkAPI", login.Client) + } + defer br.stop(true) + br.Log.Info().Str("user_login_id", string(login.ID)).Msg("Starting individual user login in background mode") + return syncClient.ConnectBackground(login.Log.WithContext(ctx)) +} + +func (br *Bridge) StartConnectors(ctx context.Context) error { br.Log.Info().Msg("Starting bridge") - ctx := br.Log.WithContext(context.Background()) - foreground := true err := br.DB.Upgrade(ctx) if err != nil { return DBUpgradeError{Err: err, Section: "main"} } - if foreground { + if !br.Background { br.didSplitPortals = br.MigrateToSplitPortals(ctx) } br.Log.Info().Msg("Starting Matrix connector") @@ -137,14 +159,16 @@ func (br *Bridge) StartConnectors() error { if err != nil { return fmt.Errorf("failed to start network connector: %w", err) } - if br.Network.GetCapabilities().DisappearingMessages { + if br.Network.GetCapabilities().DisappearingMessages && !br.Background { go br.DisappearLoop.Start() } return nil } -func (br *Bridge) PostStart() { - ctx := br.Log.WithContext(context.Background()) +func (br *Bridge) PostStart(ctx context.Context) { + if br.Background { + return + } rawBridgeInfoVer := br.DB.KV.Get(ctx, database.KeyBridgeInfoVersion) bridgeInfoVer, capVer, err := parseBridgeInfoVersion(rawBridgeInfoVer) if err != nil { @@ -228,9 +252,7 @@ func (br *Bridge) MigrateToSplitPortals(ctx context.Context) bool { return affected > 0 } -func (br *Bridge) StartLogins() error { - ctx := br.Log.WithContext(context.Background()) - +func (br *Bridge) StartLogins(ctx context.Context) error { userIDs, err := br.DB.UserLogin.GetAllUserIDsWithLogins(ctx) if err != nil { return fmt.Errorf("failed to get users with logins: %w", err) @@ -261,17 +283,23 @@ func (br *Bridge) StartLogins() error { } func (br *Bridge) Stop() { + br.stop(false) +} + +func (br *Bridge) stop(isRunOnce bool) { br.Log.Info().Msg("Shutting down bridge") close(br.stopBackfillQueue) br.Matrix.Stop() - br.cacheLock.Lock() - var wg sync.WaitGroup - wg.Add(len(br.userLoginsByID)) - for _, login := range br.userLoginsByID { - go login.Disconnect(wg.Done) + if !isRunOnce { + br.cacheLock.Lock() + var wg sync.WaitGroup + wg.Add(len(br.userLoginsByID)) + for _, login := range br.userLoginsByID { + go login.Disconnect(wg.Done) + } + wg.Wait() + br.cacheLock.Unlock() } - wg.Wait() - br.cacheLock.Unlock() if stopNet, ok := br.Network.(StoppableNetwork); ok { stopNet.Stop() } diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 695b042b..dab0b914 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -370,7 +370,8 @@ func (br *BridgeMain) LoadConfig() { // Start starts the bridge after everything has been initialized. // This is called by [Run] and does not need to be called manually. func (br *BridgeMain) Start() { - err := br.Bridge.StartConnectors() + ctx := br.Log.WithContext(context.Background()) + err := br.Bridge.StartConnectors(ctx) if err != nil { var dbUpgradeErr bridgev2.DBUpgradeError if errors.As(err, &dbUpgradeErr) { @@ -379,15 +380,15 @@ func (br *BridgeMain) Start() { br.Log.Fatal().Err(err).Msg("Failed to start bridge") } } - err = br.PostMigrate(br.Log.WithContext(context.Background())) + err = br.PostMigrate(ctx) if err != nil { br.Log.Fatal().Err(err).Msg("Failed to run post-migration updates") } - err = br.Bridge.StartLogins() + err = br.Bridge.StartLogins(ctx) if err != nil { br.Log.Fatal().Err(err).Msg("Failed to start existing user logins") } - br.Bridge.PostStart() + br.Bridge.PostStart(ctx) if br.PostStart != nil { br.PostStart() } diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 05d948c8..c0e3f880 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -352,6 +352,15 @@ type NetworkAPI interface { HandleMatrixMessage(ctx context.Context, msg *MatrixMessage) (message *MatrixMessageResponse, err error) } +// BackgroundSyncingNetworkAPI is an optional interface that network connectors can implement to support background resyncs. +type BackgroundSyncingNetworkAPI interface { + NetworkAPI + // ConnectBackground is called in place of Connect for background resyncs. + // The client should connect to the remote network, handle pending messages, and then disconnect. + // This call should block until the entire sync is complete and the client is disconnected. + ConnectBackground(ctx context.Context) error +} + // FetchMessagesParams contains the parameters for a message history pagination request. type FetchMessagesParams struct { // The portal to fetch messages in. Always present. From 757cdc7563cf32b6be1056ca81dcb40748fd5d6c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 16 Jan 2025 12:03:53 +0200 Subject: [PATCH 0994/1647] bridgev2/config: update MSC reference for appservice e2ee --- bridgev2/matrix/mxmain/example-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 82c431fb..48d6a77e 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -340,7 +340,7 @@ encryption: default: false # Whether to require all messages to be encrypted and drop any unencrypted messages. require: false - # Whether to use MSC2409/MSC3202 instead of /sync long polling for receiving encryption-related data. + # Whether to use MSC3202/MSC4203 instead of /sync long polling for receiving encryption-related data. # This option is not yet compatible with standard Matrix servers like Synapse and should not be used. # Changing this option requires updating the appservice registration file. appservice: false From d579e450c65a9dd4ec068a1615e174b15d412865 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 16 Jan 2025 12:04:13 +0200 Subject: [PATCH 0995/1647] dependencies: update --- go.mod | 12 ++++++------ go.sum | 21 ++++++++++++--------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index 8252186f..0f4c7ed5 100644 --- a/go.mod +++ b/go.mod @@ -18,11 +18,11 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.4-0.20250110124612-64d4dbbec957 + go.mau.fi/util v0.8.4 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.32.0 - golang.org/x/exp v0.0.0-20250103183323-7d7fa50e5329 - golang.org/x/net v0.33.0 + golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 + golang.org/x/net v0.34.0 golang.org/x/sync v0.10.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 @@ -31,12 +31,12 @@ require ( require ( github.com/coreos/go-systemd/v22 v22.5.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/petermattis/goid v0.0.0-20241211131331-93ee7e083c43 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect golang.org/x/sys v0.29.0 // indirect golang.org/x/text v0.21.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect diff --git a/go.sum b/go.sum index 4c27345f..fc14b4d9 100644 --- a/go.sum +++ b/go.sum @@ -19,11 +19,13 @@ github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWm github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/petermattis/goid v0.0.0-20241211131331-93ee7e083c43 h1:ah1dvbqPMN5+ocrg/ZSgZ6k8bOk+kcZQ7fnyx6UvOm4= @@ -45,22 +47,23 @@ github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.4-0.20250110124612-64d4dbbec957 h1:tsLt3t6ARc55niz+JMgJy6U4sL210Z0K/nyxF09xT0E= -go.mau.fi/util v0.8.4-0.20250110124612-64d4dbbec957/go.mod h1:MOfGTs1CBuK6ERTcSL4lb5YU7/ujz09eOPVEDckuazY= +go.mau.fi/util v0.8.4 h1:mVKlJcXWfVo8ZW3f4vqtjGpqtZqJvX4ETekxawt2vnQ= +go.mau.fi/util v0.8.4/go.mod h1:MOfGTs1CBuK6ERTcSL4lb5YU7/ujz09eOPVEDckuazY= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= -golang.org/x/exp v0.0.0-20250103183323-7d7fa50e5329 h1:9kj3STMvgqy3YA4VQXBrN7925ICMxD5wzMRcgA30588= -golang.org/x/exp v0.0.0-20250103183323-7d7fa50e5329/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c= -golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= -golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA= +golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= From 250d3356a42fa53efa485561413427ba28305ff3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 16 Jan 2025 12:36:11 +0200 Subject: [PATCH 0996/1647] Bump version to v0.23.0 --- CHANGELOG.md | 35 +++++++++++++++++++++++++++++++++++ version.go | 2 +- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0cc96b60..4490978f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,38 @@ +## v0.23.0 (2025-01-16) + +* **Breaking change *(client)*** Changed `JoinRoom` parameters to allow multiple + `via`s. +* **Breaking change *(bridgev2)*** Updated capability system. + * The return type of `NetworkAPI.GetCapabilities` is now different. + * Media type capabilities are enforced automatically by bridgev2. + * Capabilities are now sent to Matrix rooms using the + `com.beeper.room_features` state event. +* *(client)* Added `GetRoomSummary` to implement [MSC3266]. +* *(client)* Added support for arbitrary profile fields to implement [MSC4133] + (thanks to [@nexy7574] in [#337]). +* *(crypto)* Started storing olm message hashes to prevent decryption errors + if messages are repeated (e.g. if the app crashes right after decrypting). +* *(crypto)* Improved olm session unwedging to check when the last session was + created instead of only relying on an in-memory map. +* *(crypto/verificationhelper)* Fixed emoji verification not doing cross-signing + properly after a successful verification. +* *(bridgev2/config)* Moved MSC4190 flag from `appservice` to `encryption`. +* *(bridgev2/space)* Fixed failing to add rooms to spaces if the room create + call was made with a temporary context. +* *(bridgev2/commands)* Changed `help` command to hide commands which require + interfaces that aren't implemented by the network connector. +* *(bridgev2/matrixinterface)* Moved deterministic room ID generation to Matrix + connector. +* *(bridgev2)* Fixed service member state event not being set correctly when + creating a DM by inviting a ghost user. +* *(bridgev2)* Fixed `RemoteReactionSync` events replacing all reactions every + time instead of only changed ones. + +[MSC3266]: https://github.com/matrix-org/matrix-spec-proposals/pull/3266 +[MSC4133]: https://github.com/matrix-org/matrix-spec-proposals/pull/4133 +[@nexy7574]: https://github.com/nexy7574 +[#337]: https://github.com/mautrix/go/pull/337 + ## v0.22.1 (2024-12-16) * *(crypto)* Added automatic cleanup when there are too many olm sessions with diff --git a/version.go b/version.go index 362a684b..16b6b2f2 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.22.1" +const Version = "v0.23.0" var GoModVersion = "" var Commit = "" From d60d8d474461c5add1c2c3a223cfbf360cd22daa Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 17 Jan 2025 11:03:49 -0700 Subject: [PATCH 0997/1647] crypto/aessha2: extract AES SHA2 functionality from crypto/goolm/cipher This also refactors it to not recompute the keys via HKDF repeatedly. Signed-off-by: Sumner Evans --- crypto/goolm/aessha2/aessha2.go | 59 ++++++++++++++ crypto/goolm/aessha2/aessha2_test.go | 33 ++++++++ crypto/goolm/cipher/aes_sha256.go | 81 ------------------- crypto/goolm/cipher/aes_sha256_test.go | 60 -------------- crypto/goolm/cipher/cipher.go | 18 ----- crypto/goolm/cipher/pickle.go | 49 +++++------ crypto/goolm/megolm/megolm.go | 26 +++--- crypto/goolm/message/group_message.go | 43 ++++------ crypto/goolm/message/group_message_test.go | 28 ++++++- crypto/goolm/message/message.go | 20 +++-- crypto/goolm/message/message_test.go | 7 +- crypto/goolm/pk/decryption.go | 25 +++--- crypto/goolm/pk/encryption.go | 8 +- crypto/goolm/ratchet/olm.go | 37 +++++---- crypto/goolm/ratchet/olm_test.go | 2 - .../goolm/session/megolm_outbound_session.go | 2 +- crypto/goolm/utilities/pickle.go | 4 +- 17 files changed, 224 insertions(+), 278 deletions(-) create mode 100644 crypto/goolm/aessha2/aessha2.go create mode 100644 crypto/goolm/aessha2/aessha2_test.go delete mode 100644 crypto/goolm/cipher/aes_sha256.go delete mode 100644 crypto/goolm/cipher/aes_sha256_test.go delete mode 100644 crypto/goolm/cipher/cipher.go diff --git a/crypto/goolm/aessha2/aessha2.go b/crypto/goolm/aessha2/aessha2.go new file mode 100644 index 00000000..42d9811b --- /dev/null +++ b/crypto/goolm/aessha2/aessha2.go @@ -0,0 +1,59 @@ +// Copyright (c) 2024 Sumner Evans +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Package aessha2 implements the m.megolm.v1.aes-sha2 encryption algorithm +// described in [Section 10.12.4.3] in the Spec +// +// [Section 10.12.4.3]: https://spec.matrix.org/v1.12/client-server-api/#mmegolmv1aes-sha2 +package aessha2 + +import ( + "crypto/hmac" + "crypto/sha256" + "crypto/subtle" + "io" + + "golang.org/x/crypto/hkdf" + + "maunium.net/go/mautrix/crypto/aescbc" +) + +type AESSHA2 struct { + aesKey, hmacKey, iv []byte +} + +func NewAESSHA2(secret, info []byte) (AESSHA2, error) { + kdf := hkdf.New(sha256.New, secret, nil, info) + keymatter := make([]byte, 80) + _, err := io.ReadFull(kdf, keymatter) + return AESSHA2{ + keymatter[:32], // AES Key + keymatter[32:64], // HMAC Key + keymatter[64:], // IV + }, err +} + +func (a *AESSHA2) Encrypt(plaintext []byte) ([]byte, error) { + return aescbc.Encrypt(a.aesKey, a.iv, plaintext) +} + +func (a *AESSHA2) Decrypt(ciphertext []byte) ([]byte, error) { + return aescbc.Decrypt(a.aesKey, a.iv, ciphertext) +} + +func (a *AESSHA2) MAC(ciphertext []byte) ([]byte, error) { + hash := hmac.New(sha256.New, a.hmacKey) + _, err := hash.Write(ciphertext) + return hash.Sum(nil), err +} + +func (a *AESSHA2) VerifyMAC(ciphertext, theirMAC []byte) (bool, error) { + if mac, err := a.MAC(ciphertext); err != nil { + return false, err + } else { + return subtle.ConstantTimeCompare(mac[:len(theirMAC)], theirMAC) == 1, nil + } +} diff --git a/crypto/goolm/aessha2/aessha2_test.go b/crypto/goolm/aessha2/aessha2_test.go new file mode 100644 index 00000000..b2cfe8aa --- /dev/null +++ b/crypto/goolm/aessha2/aessha2_test.go @@ -0,0 +1,33 @@ +package aessha2_test + +import ( + "crypto/aes" + "testing" + + "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/crypto/goolm/aessha2" +) + +func TestCipherAESSha256(t *testing.T) { + key := []byte("test key") + cipher, err := aessha2.NewAESSHA2(key, []byte("testKDFinfo")) + assert.NoError(t, err) + message := []byte("this is a random message for testing the implementation") + //increase to next block size + for len(message)%aes.BlockSize != 0 { + message = append(message, []byte("-")...) + } + encrypted, err := cipher.Encrypt([]byte(message)) + assert.NoError(t, err) + mac, err := cipher.MAC(encrypted) + assert.NoError(t, err) + + verified, err := cipher.VerifyMAC(encrypted, mac[:8]) + assert.NoError(t, err) + assert.True(t, verified, "signature verification failed") + + resultPlainText, err := cipher.Decrypt(encrypted) + assert.NoError(t, err) + assert.Equal(t, message, resultPlainText) +} diff --git a/crypto/goolm/cipher/aes_sha256.go b/crypto/goolm/cipher/aes_sha256.go deleted file mode 100644 index 42f5d069..00000000 --- a/crypto/goolm/cipher/aes_sha256.go +++ /dev/null @@ -1,81 +0,0 @@ -package cipher - -import ( - "bytes" - "crypto/hmac" - "crypto/sha256" - "io" - - "golang.org/x/crypto/hkdf" - - "maunium.net/go/mautrix/crypto/aescbc" -) - -// derivedAESKeys stores the derived keys for the AESSHA256 cipher -type derivedAESKeys struct { - key []byte - hmacKey []byte - iv []byte -} - -// deriveAESKeys derives three keys for the AESSHA256 cipher -func deriveAESKeys(kdfInfo []byte, key []byte) (derivedAESKeys, error) { - kdf := hkdf.New(sha256.New, key, nil, kdfInfo) - keymatter := make([]byte, 80) - _, err := io.ReadFull(kdf, keymatter) - return derivedAESKeys{ - key: keymatter[:32], - hmacKey: keymatter[32:64], - iv: keymatter[64:], - }, err -} - -// AESSHA256 is a valid cipher using AES with CBC and HKDFSha256. -type AESSHA256 struct { - kdfInfo []byte -} - -// NewAESSHA256 returns a new AESSHA256 cipher with the key derive function info (kdfInfo). -func NewAESSHA256(kdfInfo []byte) *AESSHA256 { - return &AESSHA256{ - kdfInfo: kdfInfo, - } -} - -// Encrypt encrypts the plaintext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes). -func (c AESSHA256) Encrypt(key, plaintext []byte) (ciphertext []byte, err error) { - keys, err := deriveAESKeys(c.kdfInfo, key) - if err != nil { - return nil, err - } - return aescbc.Encrypt(keys.key, keys.iv, plaintext) -} - -// Decrypt decrypts the ciphertext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes). -func (c AESSHA256) Decrypt(key, ciphertext []byte) (plaintext []byte, err error) { - keys, err := deriveAESKeys(c.kdfInfo, key) - if err != nil { - return nil, err - } - return aescbc.Decrypt(keys.key, keys.iv, ciphertext) -} - -// MAC returns the MAC for the message using the key. The key is used to derive the actual mac key (32 bytes). -func (c AESSHA256) MAC(key, message []byte) ([]byte, error) { - keys, err := deriveAESKeys(c.kdfInfo, key) - if err != nil { - return nil, err - } - hash := hmac.New(sha256.New, keys.hmacKey) - _, err = hash.Write(message) - return hash.Sum(nil), err -} - -// Verify checks the MAC of the message using the key against the givenMAC. The key is used to derive the actual mac key (32 bytes). -func (c AESSHA256) Verify(key, message, givenMAC []byte) (bool, error) { - mac, err := c.MAC(key, message) - if err != nil { - return false, err - } - return bytes.Equal(givenMAC, mac[:len(givenMAC)]), nil -} diff --git a/crypto/goolm/cipher/aes_sha256_test.go b/crypto/goolm/cipher/aes_sha256_test.go deleted file mode 100644 index 2f58605f..00000000 --- a/crypto/goolm/cipher/aes_sha256_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package cipher - -import ( - "crypto/aes" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestDeriveAESKeys(t *testing.T) { - derivedKeys, err := deriveAESKeys([]byte("test"), []byte("test key")) - assert.NoError(t, err) - derivedKeys2, err := deriveAESKeys([]byte("test"), []byte("test key")) - assert.NoError(t, err) - - //derivedKeys and derivedKeys2 should be identical - assert.Equal(t, derivedKeys.key, derivedKeys2.key) - assert.Equal(t, derivedKeys.iv, derivedKeys2.iv) - assert.Equal(t, derivedKeys.hmacKey, derivedKeys2.hmacKey) - - //changing kdfInfo - derivedKeys2, err = deriveAESKeys([]byte("other kdf"), []byte("test key")) - assert.NoError(t, err) - - //derivedKeys and derivedKeys2 should now be different - assert.NotEqual(t, derivedKeys.key, derivedKeys2.key) - assert.NotEqual(t, derivedKeys.iv, derivedKeys2.iv) - assert.NotEqual(t, derivedKeys.hmacKey, derivedKeys2.hmacKey) - - //changing key - derivedKeys, err = deriveAESKeys([]byte("test"), []byte("other test key")) - assert.NoError(t, err) - - //derivedKeys and derivedKeys2 should now be different - assert.NotEqual(t, derivedKeys.key, derivedKeys2.key) - assert.NotEqual(t, derivedKeys.iv, derivedKeys2.iv) - assert.NotEqual(t, derivedKeys.hmacKey, derivedKeys2.hmacKey) -} - -func TestCipherAESSha256(t *testing.T) { - key := []byte("test key") - cipher := NewAESSHA256([]byte("testKDFinfo")) - message := []byte("this is a random message for testing the implementation") - //increase to next block size - for len(message)%aes.BlockSize != 0 { - message = append(message, []byte("-")...) - } - encrypted, err := cipher.Encrypt(key, []byte(message)) - assert.NoError(t, err) - mac, err := cipher.MAC(key, encrypted) - assert.NoError(t, err) - - verified, err := cipher.Verify(key, encrypted, mac[:8]) - assert.NoError(t, err) - assert.True(t, verified, "signature verification failed") - - resultPlainText, err := cipher.Decrypt(key, encrypted) - assert.NoError(t, err) - assert.Equal(t, message, resultPlainText) -} diff --git a/crypto/goolm/cipher/cipher.go b/crypto/goolm/cipher/cipher.go deleted file mode 100644 index 43580b0b..00000000 --- a/crypto/goolm/cipher/cipher.go +++ /dev/null @@ -1,18 +0,0 @@ -// Package cipher provides the methods and structs to do encryptions for -// olm/megolm. -package cipher - -// Cipher defines a valid cipher. -type Cipher interface { - // Encrypt encrypts the plaintext. - Encrypt(key, plaintext []byte) (ciphertext []byte, err error) - - // Decrypt decrypts the ciphertext. - Decrypt(key, ciphertext []byte) (plaintext []byte, err error) - - //MAC returns the MAC of the message calculated with the key. - MAC(key, message []byte) ([]byte, error) - - //Verify checks the MAC of the message calculated with the key against the givenMAC. - Verify(key, message, givenMAC []byte) (bool, error) -} diff --git a/crypto/goolm/cipher/pickle.go b/crypto/goolm/cipher/pickle.go index 754c7963..76f0d248 100644 --- a/crypto/goolm/cipher/pickle.go +++ b/crypto/goolm/cipher/pickle.go @@ -4,52 +4,47 @@ import ( "crypto/aes" "fmt" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/olm" ) -const ( - kdfPickle = "Pickle" //used to derive the keys for encryption - pickleMACLength = 8 -) +var kdfPickle = []byte("Pickle") //used to derive the keys for encryption +const pickleMACLength = 8 -// PickleBlockSize returns the blocksize of the used cipher. -func PickleBlockSize() int { - return aes.BlockSize -} +// PickleBlockSize is the blocksize of the pickle cipher. +const PickleBlockSize = aes.BlockSize // Pickle encrypts the input with the key and the cipher AESSHA256. The result is then encoded in base64. -func Pickle(key, input []byte) ([]byte, error) { - pickleCipher := NewAESSHA256([]byte(kdfPickle)) - ciphertext, err := pickleCipher.Encrypt(key, input) - if err != nil { +func Pickle(key, plaintext []byte) ([]byte, error) { + if c, err := aessha2.NewAESSHA2(key, kdfPickle); err != nil { return nil, err - } - mac, err := pickleCipher.MAC(key, ciphertext) - if err != nil { + } else if ciphertext, err := c.Encrypt(plaintext); err != nil { return nil, err + } else if mac, err := c.MAC(ciphertext); err != nil { + return nil, err + } else { + return goolmbase64.Encode(append(ciphertext, mac[:pickleMACLength]...)), nil } - ciphertext = append(ciphertext, mac[:pickleMACLength]...) - return goolmbase64.Encode(ciphertext), nil } // Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256. func Unpickle(key, input []byte) ([]byte, error) { - pickleCipher := NewAESSHA256([]byte(kdfPickle)) ciphertext, err := goolmbase64.Decode(input) if err != nil { return nil, err } - //remove mac and check - verified, err := pickleCipher.Verify(key, ciphertext[:len(ciphertext)-pickleMACLength], ciphertext[len(ciphertext)-pickleMACLength:]) - if err != nil { + ciphertext, mac := ciphertext[:len(ciphertext)-pickleMACLength], ciphertext[len(ciphertext)-pickleMACLength:] + if c, err := aessha2.NewAESSHA2(key, kdfPickle); err != nil { return nil, err - } - if !verified { + } else if verified, err := c.VerifyMAC(ciphertext, mac); err != nil { + return nil, err + } else if !verified { return nil, fmt.Errorf("decrypt pickle: %w", olm.ErrBadMAC) + } else { + // Set to next block size + targetCipherText := make([]byte, int(len(ciphertext)/PickleBlockSize)*PickleBlockSize) + copy(targetCipherText, ciphertext) + return c.Decrypt(targetCipherText) } - //Set to next block size - targetCipherText := make([]byte, int(len(ciphertext)/PickleBlockSize())*PickleBlockSize()) - copy(targetCipherText, ciphertext) - return pickleCipher.Decrypt(key, targetCipherText) } diff --git a/crypto/goolm/megolm/megolm.go b/crypto/goolm/megolm/megolm.go index 416db111..930d8f44 100644 --- a/crypto/goolm/megolm/megolm.go +++ b/crypto/goolm/megolm/megolm.go @@ -7,7 +7,7 @@ import ( "crypto/sha256" "fmt" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" @@ -26,7 +26,7 @@ const ( RatchetPartLength = 256 / 8 // length of each ratchet part in bytes ) -var RatchetCipher = cipher.NewAESSHA256([]byte("MEGOLM_KEYS")) +var megolmKeysKDFInfo = []byte("MEGOLM_KEYS") // hasKeySeed are the seed for the different ratchet parts var hashKeySeeds [RatchetParts][]byte = [RatchetParts][]byte{ @@ -136,9 +136,8 @@ func (m *Ratchet) AdvanceTo(target uint32) { // Encrypt encrypts the message in a message.GroupMessage with MAC and signature. // The output is base64 encoded. -func (r *Ratchet) Encrypt(plaintext []byte, key *crypto.Ed25519KeyPair) ([]byte, error) { - var err error - encryptedText, err := RatchetCipher.Encrypt(r.Data[:], plaintext) +func (r *Ratchet) Encrypt(plaintext []byte, key crypto.Ed25519KeyPair) ([]byte, error) { + cipher, err := aessha2.NewAESSHA2(r.Data[:], megolmKeysKDFInfo) if err != nil { return nil, fmt.Errorf("cipher encrypt: %w", err) } @@ -146,9 +145,12 @@ func (r *Ratchet) Encrypt(plaintext []byte, key *crypto.Ed25519KeyPair) ([]byte, message := &message.GroupMessage{} message.Version = protocolVersion message.MessageIndex = r.Counter - message.Ciphertext = encryptedText - //creating the mac and signing is done in encode - output, err := message.EncodeAndMacAndSign(r.Data[:], RatchetCipher, key) + message.Ciphertext, err = cipher.Encrypt(plaintext) + if err != nil { + return nil, err + } + //creating the MAC and signing is done in encode + output, err := message.EncodeAndMACAndSign(cipher, key) if err != nil { return nil, err } @@ -178,7 +180,11 @@ func (r Ratchet) SessionExportMessage(key crypto.Ed25519PublicKey) ([]byte, erro // Decrypt decrypts the ciphertext and verifies the MAC but not the signature. func (r Ratchet) Decrypt(ciphertext []byte, signingkey *crypto.Ed25519PublicKey, msg *message.GroupMessage) ([]byte, error) { //verify mac - verifiedMAC, err := msg.VerifyMACInline(r.Data[:], RatchetCipher, ciphertext) + cipher, err := aessha2.NewAESSHA2(r.Data[:], megolmKeysKDFInfo) + if err != nil { + return nil, err + } + verifiedMAC, err := msg.VerifyMACInline(cipher, ciphertext) if err != nil { return nil, err } @@ -186,7 +192,7 @@ func (r Ratchet) Decrypt(ciphertext []byte, signingkey *crypto.Ed25519PublicKey, return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMAC) } - return RatchetCipher.Decrypt(r.Data[:], msg.Ciphertext) + return cipher.Decrypt(msg.Ciphertext) } // PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. diff --git a/crypto/goolm/message/group_message.go b/crypto/goolm/message/group_message.go index b34bfa5e..411e0879 100644 --- a/crypto/goolm/message/group_message.go +++ b/crypto/goolm/message/group_message.go @@ -3,7 +3,7 @@ package message import ( "bytes" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" ) @@ -68,9 +68,9 @@ func (r *GroupMessage) Decode(input []byte) error { return nil } -// EncodeAndMacAndSign encodes the message, creates the mac with the key and the cipher and signs the message. +// EncodeAndMACAndSign encodes the message, creates the mac with the key and the cipher and signs the message. // If macKey or cipher is nil, no mac is appended. If signKey is nil, no signature is appended. -func (r *GroupMessage) EncodeAndMacAndSign(macKey []byte, cipher cipher.Cipher, signKey *crypto.Ed25519KeyPair) ([]byte, error) { +func (r *GroupMessage) EncodeAndMACAndSign(cipher aessha2.AESSHA2, signKey crypto.Ed25519KeyPair) ([]byte, error) { var lengthOfMessage int lengthOfMessage += 1 //Version lengthOfMessage += encodeVarIntByteLength(messageIndexTag) + encodeVarIntByteLength(r.MessageIndex) @@ -90,37 +90,28 @@ func (r *GroupMessage) EncodeAndMacAndSign(macKey []byte, cipher cipher.Cipher, encodedValue = encodeVarString(r.Ciphertext) copy(out[curPos:], encodedValue) curPos += len(encodedValue) - if len(macKey) != 0 && cipher != nil { - mac, err := r.MAC(macKey, cipher, out) - if err != nil { - return nil, err - } - out = append(out, mac[:countMACBytesGroupMessage]...) + mac, err := r.MAC(cipher, out) + if err != nil { + return nil, err } - if signKey != nil { - signature, err := signKey.Sign(out) - if err != nil { - return nil, err - } - out = append(out, signature...) + out = append(out, mac[:countMACBytesGroupMessage]...) + signature, err := signKey.Sign(out) + if err != nil { + return nil, err } + out = append(out, signature...) return out, nil } // MAC returns the MAC of the message calculated with cipher and key. The length of the MAC is truncated to the correct length. -func (r *GroupMessage) MAC(key []byte, cipher cipher.Cipher, message []byte) ([]byte, error) { - mac, err := cipher.MAC(key, message) +func (r *GroupMessage) MAC(cipher aessha2.AESSHA2, ciphertext []byte) ([]byte, error) { + mac, err := cipher.MAC(ciphertext) if err != nil { return nil, err } return mac[:countMACBytesGroupMessage], nil } -// VerifySignature verifies the givenSignature to the calculated signature of the message. -func (r *GroupMessage) VerifySignature(key crypto.Ed25519PublicKey, message, givenSignature []byte) bool { - return key.Verify(message, givenSignature) -} - // VerifySignature verifies the signature taken from the message to the calculated signature of the message. func (r *GroupMessage) VerifySignatureInline(key crypto.Ed25519PublicKey, message []byte) bool { signature := message[len(message)-crypto.Ed25519SignatureSize:] @@ -129,8 +120,8 @@ func (r *GroupMessage) VerifySignatureInline(key crypto.Ed25519PublicKey, messag } // VerifyMAC verifies the givenMAC to the calculated MAC of the message. -func (r *GroupMessage) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC []byte) (bool, error) { - checkMac, err := r.MAC(key, cipher, message) +func (r *GroupMessage) VerifyMAC(cipher aessha2.AESSHA2, ciphertext, givenMAC []byte) (bool, error) { + checkMac, err := r.MAC(cipher, ciphertext) if err != nil { return false, err } @@ -138,10 +129,10 @@ func (r *GroupMessage) VerifyMAC(key []byte, cipher cipher.Cipher, message, give } // VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message. -func (r *GroupMessage) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) { +func (r *GroupMessage) VerifyMACInline(cipher aessha2.AESSHA2, message []byte) (bool, error) { startMAC := len(message) - countMACBytesGroupMessage - crypto.Ed25519SignatureSize endMAC := startMAC + countMACBytesGroupMessage suplMac := message[startMAC:endMAC] message = message[:startMAC] - return r.VerifyMAC(key, cipher, message, suplMac) + return r.VerifyMAC(cipher, message, suplMac) } diff --git a/crypto/goolm/message/group_message_test.go b/crypto/goolm/message/group_message_test.go index d52cf6a3..272138c4 100644 --- a/crypto/goolm/message/group_message_test.go +++ b/crypto/goolm/message/group_message_test.go @@ -4,7 +4,10 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "maunium.net/go/mautrix/crypto/goolm/aessha2" + "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/message" ) @@ -24,7 +27,6 @@ func TestGroupMessageDecode(t *testing.T) { } func TestGroupMessageEncode(t *testing.T) { - expectedRaw := []byte("\x03\x08\xC8\x01\x12\x0aciphertexthmacsha2signature") hmacsha256 := []byte("hmacsha2") sign := []byte("signature") msg := message.GroupMessage{ @@ -32,9 +34,29 @@ func TestGroupMessageEncode(t *testing.T) { MessageIndex: 200, Ciphertext: []byte("ciphertext"), } - encoded, err := msg.EncodeAndMacAndSign(nil, nil, nil) + + cipher, err := aessha2.NewAESSHA2(nil, nil) + require.NoError(t, err) + encoded, err := msg.EncodeAndMACAndSign(cipher, crypto.Ed25519GenerateFromSeed(make([]byte, 32))) assert.NoError(t, err) encoded = append(encoded, hmacsha256...) encoded = append(encoded, sign...) - assert.Equal(t, expectedRaw, encoded) + expected := []byte{ + 0x03, // Version + 0x08, + 0xC8, // 200 + 0x01, + 0x12, + 0x0a, + } + expected = append(expected, []byte("ciphertext")...) + expected = append(expected, []byte{ + 0x6f, 0x95, 0x35, 0x51, 0xdc, 0xdb, 0xcb, 0x03, 0x0b, 0x22, 0xa2, 0xa7, 0xa1, 0xb7, 0x4f, 0x1a, + 0xa3, 0xe9, 0x5c, 0x05, 0x5d, 0x56, 0xdc, 0x5b, 0x87, 0x73, 0x05, 0x42, 0x2a, 0x59, 0x9a, 0x9a, + 0x26, 0x7a, 0x8d, 0xba, 0x65, 0xb2, 0x17, 0x65, 0x51, 0x6f, 0x37, 0xf3, 0x8f, 0xa1, 0x70, 0xd0, + 0xc4, 0x06, 0x05, 0xdc, 0x17, 0x71, 0x5e, 0x63, 0x84, 0xbe, 0xec, 0x7b, 0xa0, 0xc4, 0x08, 0xb8, + 0x9b, 0xc5, 0x08, 0x16, 0xad, 0xe5, 0x43, 0x0c, + }...) + expected = append(expected, []byte("hmacsha2signature")...) + assert.Equal(t, expected, encoded) } diff --git a/crypto/goolm/message/message.go b/crypto/goolm/message/message.go index 8b721aeb..88efdc14 100644 --- a/crypto/goolm/message/message.go +++ b/crypto/goolm/message/message.go @@ -3,7 +3,7 @@ package message import ( "bytes" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" ) @@ -76,7 +76,7 @@ func (r *Message) Decode(input []byte) error { // EncodeAndMAC encodes the message and creates the MAC with the key and the cipher. // If key or cipher is nil, no MAC is appended. -func (r *Message) EncodeAndMAC(key []byte, cipher cipher.Cipher) ([]byte, error) { +func (r *Message) EncodeAndMAC(cipher aessha2.AESSHA2) ([]byte, error) { var lengthOfMessage int lengthOfMessage += 1 //Version lengthOfMessage += encodeVarIntByteLength(ratchetKeyTag) + encodeVarStringByteLength(r.RatchetKey) @@ -103,19 +103,17 @@ func (r *Message) EncodeAndMAC(key []byte, cipher cipher.Cipher) ([]byte, error) encodedValue = encodeVarString(r.Ciphertext) copy(out[curPos:], encodedValue) curPos += len(encodedValue) - if len(key) != 0 && cipher != nil { - mac, err := cipher.MAC(key, out) - if err != nil { - return nil, err - } - out = append(out, mac[:countMACBytesMessage]...) + mac, err := cipher.MAC(out) + if err != nil { + return nil, err } + out = append(out, mac[:countMACBytesMessage]...) return out, nil } // VerifyMAC verifies the givenMAC to the calculated MAC of the message. -func (r *Message) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC []byte) (bool, error) { - checkMAC, err := cipher.MAC(key, message) +func (r *Message) VerifyMAC(key []byte, cipher aessha2.AESSHA2, ciphertext, givenMAC []byte) (bool, error) { + checkMAC, err := cipher.MAC(ciphertext) if err != nil { return false, err } @@ -123,7 +121,7 @@ func (r *Message) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC } // VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message. -func (r *Message) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) { +func (r *Message) VerifyMACInline(key []byte, cipher aessha2.AESSHA2, message []byte) (bool, error) { givenMAC := message[len(message)-countMACBytesMessage:] return r.VerifyMAC(key, cipher, message[:len(message)-countMACBytesMessage], givenMAC) } diff --git a/crypto/goolm/message/message_test.go b/crypto/goolm/message/message_test.go index b5c3551b..f3aa7108 100644 --- a/crypto/goolm/message/message_test.go +++ b/crypto/goolm/message/message_test.go @@ -5,6 +5,7 @@ import ( "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/message" ) @@ -24,7 +25,7 @@ func TestMessageDecode(t *testing.T) { } func TestMessageEncode(t *testing.T) { - expectedRaw := []byte("\x03\n\nratchetkey\x10\x01\"\nciphertexthmacsha2") + expectedRaw := []byte("\x03\n\nratchetkey\x10\x01\"\nciphertext\x95\x95\x92\x72\x04\x70\x56\xcdhmacsha2") hmacsha256 := []byte("hmacsha2") msg := message.Message{ Version: 3, @@ -32,7 +33,9 @@ func TestMessageEncode(t *testing.T) { RatchetKey: []byte("ratchetkey"), Ciphertext: []byte("ciphertext"), } - encoded, err := msg.EncodeAndMAC(nil, nil) + cipher, err := aessha2.NewAESSHA2(nil, nil) + assert.NoError(t, err) + encoded, err := msg.EncodeAndMAC(cipher) assert.NoError(t, err) encoded = append(encoded, hmacsha256...) assert.Equal(t, expectedRaw, encoded) diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index ba94dc37..990df3c0 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "fmt" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/goolmbase64" @@ -57,27 +58,21 @@ func (s Decryption) PrivateKey() crypto.Curve25519PrivateKey { // Decrypt decrypts the ciphertext and verifies the MAC. The base64 encoded key is used to construct the shared secret. func (s Decryption) Decrypt(ephemeralKey, mac, ciphertext []byte) ([]byte, error) { - keyDecoded, err := base64.RawStdEncoding.DecodeString(string(ephemeralKey)) - if err != nil { + if keyDecoded, err := base64.RawStdEncoding.DecodeString(string(ephemeralKey)); err != nil { return nil, err - } - sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded) - if err != nil { + } else if sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded); err != nil { return nil, err - } - decodedMAC, err := goolmbase64.Decode(mac) - if err != nil { + } else if decodedMAC, err := goolmbase64.Decode(mac); err != nil { return nil, err - } - cipher := cipher.NewAESSHA256(nil) - verified, err := cipher.Verify(sharedSecret, ciphertext, decodedMAC) - if err != nil { + } else if cipher, err := aessha2.NewAESSHA2(sharedSecret, nil); err != nil { return nil, err - } - if !verified { + } else if verified, err := cipher.VerifyMAC(ciphertext, decodedMAC); err != nil { + return nil, err + } else if !verified { return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMAC) + } else { + return cipher.Decrypt(ciphertext) } - return cipher.Decrypt(sharedSecret, ciphertext) } // PickleAsJSON returns an Decryption as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. diff --git a/crypto/goolm/pk/encryption.go b/crypto/goolm/pk/encryption.go index c99a9517..23f67ddf 100644 --- a/crypto/goolm/pk/encryption.go +++ b/crypto/goolm/pk/encryption.go @@ -5,7 +5,7 @@ import ( "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/goolmbase64" ) @@ -36,11 +36,11 @@ func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519Privat if err != nil { return nil, nil, err } - cipher := cipher.NewAESSHA256(nil) - ciphertext, err = cipher.Encrypt(sharedSecret, plaintext) + cipher, err := aessha2.NewAESSHA2(sharedSecret, nil) + ciphertext, err = cipher.Encrypt(plaintext) if err != nil { return nil, nil, err } - mac, err = cipher.MAC(sharedSecret, ciphertext) + mac, err = cipher.MAC(ciphertext) return ciphertext, goolmbase64.Encode(mac), err } diff --git a/crypto/goolm/ratchet/olm.go b/crypto/goolm/ratchet/olm.go index e53d126a..fcb72c20 100644 --- a/crypto/goolm/ratchet/olm.go +++ b/crypto/goolm/ratchet/olm.go @@ -9,7 +9,7 @@ import ( "golang.org/x/crypto/hkdf" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/message" @@ -31,6 +31,8 @@ const ( sharedKeyLength = 32 ) +var olmKeysKDFInfo = []byte("OLM_KEYS") + // KdfInfo has the infos used for the kdf var KdfInfo = struct { Root []byte @@ -40,8 +42,6 @@ var KdfInfo = struct { Ratchet: []byte("OLM_RATCHET"), } -var RatchetCipher = cipher.NewAESSHA256([]byte("OLM_KEYS")) - // Ratchet represents the olm ratchet as described in // // https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/olm.md @@ -68,8 +68,7 @@ type Ratchet struct { // New creates a new ratchet, setting the kdfInfos and cipher. func New() *Ratchet { - r := &Ratchet{} - return r + return &Ratchet{} } // InitializeAsBob initializes this ratchet from a receiving point of view (only first message). @@ -117,7 +116,11 @@ func (r *Ratchet) Encrypt(plaintext []byte) ([]byte, error) { messageKey := r.createMessageKeys(r.SenderChains.chainKey()) r.SenderChains.advance() - encryptedText, err := RatchetCipher.Encrypt(messageKey.Key, plaintext) + cipher, err := aessha2.NewAESSHA2(messageKey.Key, olmKeysKDFInfo) + if err != nil { + return nil, err + } + encryptedText, err := cipher.Encrypt(plaintext) if err != nil { return nil, fmt.Errorf("cipher encrypt: %w", err) } @@ -128,7 +131,7 @@ func (r *Ratchet) Encrypt(plaintext []byte) ([]byte, error) { message.RatchetKey = r.SenderChains.ratchetKey().PublicKey message.Ciphertext = encryptedText //creating the mac is done in encode - return message.EncodeAndMAC(messageKey.Key, RatchetCipher) + return message.EncodeAndMAC(cipher) } // Decrypt decrypts the ciphertext and verifies the MAC. @@ -165,15 +168,13 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { } // Found the key for this message. Check the MAC. - verified, err := message.VerifyMACInline(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, RatchetCipher, input) - if err != nil { + if cipher, err := aessha2.NewAESSHA2(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, olmKeysKDFInfo); err != nil { return nil, err - } - if !verified { + } else if verified, err := message.VerifyMACInline(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, cipher, input); err != nil { + return nil, err + } else if !verified { return nil, fmt.Errorf("decrypt from skipped message keys: %w", olm.ErrBadMAC) - } - result, err := RatchetCipher.Decrypt(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, message.Ciphertext) - if err != nil { + } else if result, err := cipher.Decrypt(message.Ciphertext); err != nil { return nil, fmt.Errorf("cipher decrypt: %w", err) } else if len(result) != 0 { // Remove the key from the skipped keys now that we've @@ -235,14 +236,18 @@ func (r *Ratchet) decryptForExistingChain(chain *receiverChain, message *message } messageKey := r.createMessageKeys(chain.chainKey()) chain.advance() - verified, err := message.VerifyMACInline(messageKey.Key, RatchetCipher, rawMessage) + cipher, err := aessha2.NewAESSHA2(messageKey.Key, olmKeysKDFInfo) + if err != nil { + return nil, err + } + verified, err := message.VerifyMACInline(messageKey.Key, cipher, rawMessage) if err != nil { return nil, err } if !verified { return nil, fmt.Errorf("decrypt from existing chain: %w", olm.ErrBadMAC) } - return RatchetCipher.Decrypt(messageKey.Key, message.Ciphertext) + return cipher.Decrypt(message.Ciphertext) } // decryptForNewChain returns the decrypted message by creating a new chain and advancing the root key. diff --git a/crypto/goolm/ratchet/olm_test.go b/crypto/goolm/ratchet/olm_test.go index 6a8fefc3..2bf7ea0a 100644 --- a/crypto/goolm/ratchet/olm_test.go +++ b/crypto/goolm/ratchet/olm_test.go @@ -6,7 +6,6 @@ import ( "github.com/stretchr/testify/assert" - "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/ratchet" ) @@ -23,7 +22,6 @@ func initializeRatchets() (*ratchet.Ratchet, *ratchet.Ratchet, error) { Root: []byte("Olm"), Ratchet: []byte("OlmRatchet"), } - ratchet.RatchetCipher = cipher.NewAESSHA256([]byte("OlmMessageKeys")) aliceRatchet := ratchet.New() bobRatchet := ratchet.New() diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index b42dab53..6164f965 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -66,7 +66,7 @@ func (o *MegolmOutboundSession) Encrypt(plaintext []byte) ([]byte, error) { if len(plaintext) == 0 { return nil, olm.ErrEmptyInput } - encrypted, err := o.Ratchet.Encrypt(plaintext, &o.SigningKey) + encrypted, err := o.Ratchet.Encrypt(plaintext, o.SigningKey) return goolmbase64.Encode(encrypted), err } diff --git a/crypto/goolm/utilities/pickle.go b/crypto/goolm/utilities/pickle.go index 6ce35efe..c6d9d693 100644 --- a/crypto/goolm/utilities/pickle.go +++ b/crypto/goolm/utilities/pickle.go @@ -21,8 +21,8 @@ func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) { toEncrypt := make([]byte, len(marshaled)) copy(toEncrypt, marshaled) //pad marshaled to get block size - if len(marshaled)%cipher.PickleBlockSize() != 0 { - padding := cipher.PickleBlockSize() - len(marshaled)%cipher.PickleBlockSize() + if len(marshaled)%cipher.PickleBlockSize != 0 { + padding := cipher.PickleBlockSize - len(marshaled)%cipher.PickleBlockSize toEncrypt = make([]byte, len(marshaled)+padding) copy(toEncrypt, marshaled) } From 976e11ad112afc19ce47fa185e326a4f33726249 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 17 Jan 2025 11:21:24 -0700 Subject: [PATCH 0998/1647] crypto/goolm/message: use buffers for encode/decode functions Signed-off-by: Sumner Evans --- crypto/goolm/message/decoder.go | 82 +++--------- crypto/goolm/message/encoder.go | 24 ++++ .../{decoder_test.go => encoder_test.go} | 34 ++--- crypto/goolm/message/group_message.go | 91 ++++++-------- crypto/goolm/message/message.go | 98 ++++++--------- crypto/goolm/message/prekey_message.go | 119 ++++++++---------- 6 files changed, 178 insertions(+), 270 deletions(-) create mode 100644 crypto/goolm/message/encoder.go rename crypto/goolm/message/{decoder_test.go => encoder_test.go} (58%) diff --git a/crypto/goolm/message/decoder.go b/crypto/goolm/message/decoder.go index 9ce426b5..a71cf302 100644 --- a/crypto/goolm/message/decoder.go +++ b/crypto/goolm/message/decoder.go @@ -1,70 +1,28 @@ package message import ( + "bytes" "encoding/binary" - - "maunium.net/go/mautrix/crypto/olm" ) -// checkDecodeErr checks if there was an error during decode. -func checkDecodeErr(readBytes int) error { - if readBytes == 0 { - //end reached - return olm.ErrInputToSmall +type Decoder struct { + *bytes.Buffer +} + +func NewDecoder(buf []byte) *Decoder { + return &Decoder{bytes.NewBuffer(buf)} +} + +func (d *Decoder) ReadVarInt() (uint64, error) { + return binary.ReadUvarint(d) +} + +func (d *Decoder) ReadVarBytes() ([]byte, error) { + if n, err := d.ReadVarInt(); err != nil { + return nil, err + } else { + out := make([]byte, n) + _, err = d.Read(out) + return out, err } - if readBytes < 0 { - return olm.ErrOverflow - } - return nil -} - -// decodeVarInt decodes a single big-endian encoded varint. -func decodeVarInt(input []byte) (uint32, int) { - value, readBytes := binary.Uvarint(input) - return uint32(value), readBytes -} - -// decodeVarString decodes the length of the string (varint) and returns the actual string -func decodeVarString(input []byte) ([]byte, int) { - stringLen, readBytes := decodeVarInt(input) - if readBytes <= 0 { - return nil, readBytes - } - input = input[readBytes:] - value := input[:stringLen] - readBytes += int(stringLen) - return value, readBytes -} - -// encodeVarIntByteLength returns the number of bytes needed to encode the uint32. -func encodeVarIntByteLength(input uint32) int { - result := 1 - for input >= 128 { - result++ - input >>= 7 - } - return result -} - -// encodeVarStringByteLength returns the number of bytes needed to encode the input. -func encodeVarStringByteLength(input []byte) int { - result := encodeVarIntByteLength(uint32(len(input))) - result += len(input) - return result -} - -// encodeVarInt encodes a single uint32 -func encodeVarInt(input uint32) []byte { - out := make([]byte, encodeVarIntByteLength(input)) - binary.PutUvarint(out, uint64(input)) - return out -} - -// encodeVarString encodes the length of the input (varint) and appends the actual input -func encodeVarString(input []byte) []byte { - out := make([]byte, encodeVarStringByteLength(input)) - length := encodeVarInt(uint32(len(input))) - copy(out, length) - copy(out[len(length):], input) - return out } diff --git a/crypto/goolm/message/encoder.go b/crypto/goolm/message/encoder.go new file mode 100644 index 00000000..95ab6d41 --- /dev/null +++ b/crypto/goolm/message/encoder.go @@ -0,0 +1,24 @@ +package message + +import "encoding/binary" + +type Encoder struct { + buf []byte +} + +func (e *Encoder) Bytes() []byte { + return e.buf +} + +func (e *Encoder) PutByte(val byte) { + e.buf = append(e.buf, val) +} + +func (e *Encoder) PutVarInt(val uint64) { + e.buf = binary.AppendUvarint(e.buf, val) +} + +func (e *Encoder) PutVarBytes(data []byte) { + e.PutVarInt(uint64(len(data))) + e.buf = append(e.buf, data...) +} diff --git a/crypto/goolm/message/decoder_test.go b/crypto/goolm/message/encoder_test.go similarity index 58% rename from crypto/goolm/message/decoder_test.go rename to crypto/goolm/message/encoder_test.go index 8b7561ad..1fe2ebdb 100644 --- a/crypto/goolm/message/decoder_test.go +++ b/crypto/goolm/message/encoder_test.go @@ -1,33 +1,13 @@ -package message +package message_test import ( "testing" "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/crypto/goolm/message" ) -func TestEncodeLengthInt(t *testing.T) { - numbers := []uint32{127, 128, 16383, 16384, 32767} - expected := []int{1, 2, 2, 3, 3} - for curIndex := range numbers { - assert.Equal(t, expected[curIndex], encodeVarIntByteLength(numbers[curIndex])) - } -} - -func TestEncodeLengthString(t *testing.T) { - var strings [][]byte - var expected []int - strings = append(strings, []byte("test")) - expected = append(expected, 1+4) - strings = append(strings, []byte("this is a long message with a length of 127 so that the varint of the length is just one byte. just needs some padding---------")) - expected = append(expected, 1+127) - strings = append(strings, []byte("this is an even longer message with a length between 128 and 16383 so that the varint of the length needs two byte. just needs some padding again ---------")) - expected = append(expected, 2+155) - for curIndex := range strings { - assert.Equal(t, expected[curIndex], encodeVarStringByteLength(strings[curIndex])) - } -} - func TestEncodeInt(t *testing.T) { var ints []uint32 var expected [][]byte @@ -40,7 +20,9 @@ func TestEncodeInt(t *testing.T) { ints = append(ints, 16383) expected = append(expected, []byte{0b11111111, 0b01111111}) for curIndex := range ints { - assert.Equal(t, expected[curIndex], encodeVarInt(ints[curIndex])) + var encoder message.Encoder + encoder.PutVarInt(uint64(ints[curIndex])) + assert.Equal(t, expected[curIndex], encoder.Bytes()) } } @@ -70,6 +52,8 @@ func TestEncodeString(t *testing.T) { res = append(res, curTest...) //Add string itself expected = append(expected, res) for curIndex := range strings { - assert.Equal(t, expected[curIndex], encodeVarString(strings[curIndex])) + var encoder message.Encoder + encoder.PutVarBytes(strings[curIndex]) + assert.Equal(t, expected[curIndex], encoder.Bytes()) } } diff --git a/crypto/goolm/message/group_message.go b/crypto/goolm/message/group_message.go index 411e0879..c2a43b1f 100644 --- a/crypto/goolm/message/group_message.go +++ b/crypto/goolm/message/group_message.go @@ -2,6 +2,7 @@ package message import ( "bytes" + "io" "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" @@ -22,85 +23,63 @@ type GroupMessage struct { } // Decodes decodes the input and populates the corresponding fileds. MAC and signature are ignored but have to be present. -func (r *GroupMessage) Decode(input []byte) error { +func (r *GroupMessage) Decode(input []byte) (err error) { r.Version = 0 r.MessageIndex = 0 r.Ciphertext = nil if len(input) == 0 { return nil } - //first Byte is always version - r.Version = input[0] - curPos := 1 - for curPos < len(input)-countMACBytesGroupMessage-crypto.Ed25519SignatureSize { - //Read Key - curKey, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err - } - curPos += readBytes - if (curKey & 0b111) == 0 { - //The value is of type varint - value, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err + + decoder := NewDecoder(input[:len(input)-countMACBytesGroupMessage-crypto.Ed25519SignatureSize]) + r.Version, err = decoder.ReadByte() // First byte is the version + if err != nil { + return + } + + for { + // Read Key + if curKey, err := decoder.ReadVarInt(); err != nil { + if err == io.EOF { + // No more keys to read + return nil } - curPos += readBytes - switch curKey { - case messageIndexTag: - r.MessageIndex = value + return err + } else if (curKey & 0b111) == 0 { + // The value is of type varint + if value, err := decoder.ReadVarInt(); err != nil { + return err + } else if curKey == messageIndexTag { + r.MessageIndex = uint32(value) r.HasMessageIndex = true } } else if (curKey & 0b111) == 2 { - //The value is of type string - value, readBytes := decodeVarString(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { + // The value is of type string + if value, err := decoder.ReadVarBytes(); err != nil { return err - } - curPos += readBytes - switch curKey { - case cipherTextTag: + } else if curKey == cipherTextTag { r.Ciphertext = value } } } - - return nil } // EncodeAndMACAndSign encodes the message, creates the mac with the key and the cipher and signs the message. // If macKey or cipher is nil, no mac is appended. If signKey is nil, no signature is appended. func (r *GroupMessage) EncodeAndMACAndSign(cipher aessha2.AESSHA2, signKey crypto.Ed25519KeyPair) ([]byte, error) { - var lengthOfMessage int - lengthOfMessage += 1 //Version - lengthOfMessage += encodeVarIntByteLength(messageIndexTag) + encodeVarIntByteLength(r.MessageIndex) - lengthOfMessage += encodeVarIntByteLength(cipherTextTag) + encodeVarStringByteLength(r.Ciphertext) - out := make([]byte, lengthOfMessage) - out[0] = r.Version - curPos := 1 - encodedTag := encodeVarInt(messageIndexTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue := encodeVarInt(r.MessageIndex) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(cipherTextTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.Ciphertext) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - mac, err := r.MAC(cipher, out) + var encoder Encoder + encoder.PutByte(r.Version) + encoder.PutVarInt(messageIndexTag) + encoder.PutVarInt(uint64(r.MessageIndex)) + encoder.PutVarInt(cipherTextTag) + encoder.PutVarBytes(r.Ciphertext) + mac, err := r.MAC(cipher, encoder.Bytes()) if err != nil { return nil, err } - out = append(out, mac[:countMACBytesGroupMessage]...) - signature, err := signKey.Sign(out) - if err != nil { - return nil, err - } - out = append(out, signature...) - return out, nil + ciphertextWithMAC := append(encoder.Bytes(), mac[:countMACBytesGroupMessage]...) + signature, err := signKey.Sign(ciphertextWithMAC) + return append(ciphertextWithMAC, signature...), err } // MAC returns the MAC of the message calculated with cipher and key. The length of the MAC is truncated to the correct length. diff --git a/crypto/goolm/message/message.go b/crypto/goolm/message/message.go index 88efdc14..8bb6e0cd 100644 --- a/crypto/goolm/message/message.go +++ b/crypto/goolm/message/message.go @@ -2,6 +2,7 @@ package message import ( "bytes" + "io" "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" @@ -24,7 +25,7 @@ type Message struct { } // Decodes decodes the input and populates the corresponding fileds. MAC is ignored but has to be present. -func (r *Message) Decode(input []byte) error { +func (r *Message) Decode(input []byte) (err error) { r.Version = 0 r.HasCounter = false r.Counter = 0 @@ -33,82 +34,55 @@ func (r *Message) Decode(input []byte) error { if len(input) == 0 { return nil } - //first Byte is always version - r.Version = input[0] - curPos := 1 - for curPos < len(input)-countMACBytesMessage { - //Read Key - curKey, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err - } - curPos += readBytes - if (curKey & 0b111) == 0 { - //The value is of type varint - value, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err + + decoder := NewDecoder(input[:len(input)-countMACBytesMessage]) + r.Version, err = decoder.ReadByte() // first byte is always version + if err != nil { + return + } + + for { + // Read Key + if curKey, err := decoder.ReadVarInt(); err != nil { + if err == io.EOF { + // No more keys to read + return nil } - curPos += readBytes - switch curKey { - case counterTag: + return err + } else if (curKey & 0b111) == 0 { + // The value is of type varint + if value, err := decoder.ReadVarInt(); err != nil { + return err + } else if curKey == counterTag { + r.Counter = uint32(value) r.HasCounter = true - r.Counter = value } } else if (curKey & 0b111) == 2 { - //The value is of type string - value, readBytes := decodeVarString(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { + // The value is of type string + if value, err := decoder.ReadVarBytes(); err != nil { return err - } - curPos += readBytes - switch curKey { - case ratchetKeyTag: + } else if curKey == ratchetKeyTag { r.RatchetKey = value - case cipherTextKeyTag: + } else if curKey == cipherTextKeyTag { r.Ciphertext = value } } } - - return nil } // EncodeAndMAC encodes the message and creates the MAC with the key and the cipher. // If key or cipher is nil, no MAC is appended. func (r *Message) EncodeAndMAC(cipher aessha2.AESSHA2) ([]byte, error) { - var lengthOfMessage int - lengthOfMessage += 1 //Version - lengthOfMessage += encodeVarIntByteLength(ratchetKeyTag) + encodeVarStringByteLength(r.RatchetKey) - lengthOfMessage += encodeVarIntByteLength(counterTag) + encodeVarIntByteLength(r.Counter) - lengthOfMessage += encodeVarIntByteLength(cipherTextKeyTag) + encodeVarStringByteLength(r.Ciphertext) - out := make([]byte, lengthOfMessage) - out[0] = r.Version - curPos := 1 - encodedTag := encodeVarInt(ratchetKeyTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue := encodeVarString(r.RatchetKey) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(counterTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarInt(r.Counter) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(cipherTextKeyTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.Ciphertext) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - mac, err := cipher.MAC(out) - if err != nil { - return nil, err - } - out = append(out, mac[:countMACBytesMessage]...) - return out, nil + var encoder Encoder + encoder.PutByte(r.Version) + encoder.PutVarInt(ratchetKeyTag) + encoder.PutVarBytes(r.RatchetKey) + encoder.PutVarInt(counterTag) + encoder.PutVarInt(uint64(r.Counter)) + encoder.PutVarInt(cipherTextKeyTag) + encoder.PutVarBytes(r.Ciphertext) + mac, err := cipher.MAC(encoder.Bytes()) + return append(encoder.Bytes(), mac[:countMACBytesMessage]...), err } // VerifyMAC verifies the givenMAC to the calculated MAC of the message. diff --git a/crypto/goolm/message/prekey_message.go b/crypto/goolm/message/prekey_message.go index 1238a9a5..22ebf9c3 100644 --- a/crypto/goolm/message/prekey_message.go +++ b/crypto/goolm/message/prekey_message.go @@ -1,11 +1,14 @@ package message import ( + "io" + "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/olm" ) const ( - oneTimeKeyIdTag = 0x0A + oneTimeKeyIDTag = 0x0A baseKeyTag = 0x12 identityKeyTag = 0x1A messageTag = 0x22 @@ -20,7 +23,7 @@ type PreKeyMessage struct { } // Decodes decodes the input and populates the corresponding fileds. -func (r *PreKeyMessage) Decode(input []byte) error { +func (r *PreKeyMessage) Decode(input []byte) (err error) { r.Version = 0 r.IdentityKey = nil r.BaseKey = nil @@ -29,44 +32,52 @@ func (r *PreKeyMessage) Decode(input []byte) error { if len(input) == 0 { return nil } - //first Byte is always version - r.Version = input[0] - curPos := 1 - for curPos < len(input) { - //Read Key - curKey, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { - return err + + decoder := NewDecoder(input) + r.Version, err = decoder.ReadByte() // first byte is always version + if err != nil { + if err == io.EOF { + return olm.ErrInputToSmall } - curPos += readBytes - if (curKey & 0b111) == 0 { - //The value is of type varint - _, readBytes := decodeVarInt(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { + return + } + + for { + // Read Key + if curKey, err := decoder.ReadVarInt(); err != nil { + if err == io.EOF { + return nil + } + return err + } else if (curKey & 0b111) == 0 { + // The value is of type varint + if _, err = decoder.ReadVarInt(); err != nil { + if err == io.EOF { + return olm.ErrInputToSmall + } return err } - curPos += readBytes } else if (curKey & 0b111) == 2 { - //The value is of type string - value, readBytes := decodeVarString(input[curPos:]) - if err := checkDecodeErr(readBytes); err != nil { + // The value is of type string + if value, err := decoder.ReadVarBytes(); err != nil { + if err == io.EOF { + return olm.ErrInputToSmall + } return err - } - curPos += readBytes - switch curKey { - case oneTimeKeyIdTag: - r.OneTimeKey = value - case baseKeyTag: - r.BaseKey = value - case identityKeyTag: - r.IdentityKey = value - case messageTag: - r.Message = value + } else { + switch curKey { + case oneTimeKeyIDTag: + r.OneTimeKey = value + case baseKeyTag: + r.BaseKey = value + case identityKeyTag: + r.IdentityKey = value + case messageTag: + r.Message = value + } } } } - - return nil } // CheckField verifies the fields. If theirIdentityKey is nil, it is not compared to the key in the message. @@ -84,37 +95,15 @@ func (r *PreKeyMessage) CheckFields(theirIdentityKey *crypto.Curve25519PublicKey // Encode encodes the message. func (r *PreKeyMessage) Encode() ([]byte, error) { - var lengthOfMessage int - lengthOfMessage += 1 //Version - lengthOfMessage += encodeVarIntByteLength(oneTimeKeyIdTag) + encodeVarStringByteLength(r.OneTimeKey) - lengthOfMessage += encodeVarIntByteLength(identityKeyTag) + encodeVarStringByteLength(r.IdentityKey) - lengthOfMessage += encodeVarIntByteLength(baseKeyTag) + encodeVarStringByteLength(r.BaseKey) - lengthOfMessage += encodeVarIntByteLength(messageTag) + encodeVarStringByteLength(r.Message) - out := make([]byte, lengthOfMessage) - out[0] = r.Version - curPos := 1 - encodedTag := encodeVarInt(oneTimeKeyIdTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue := encodeVarString(r.OneTimeKey) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(identityKeyTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.IdentityKey) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(baseKeyTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.BaseKey) - copy(out[curPos:], encodedValue) - curPos += len(encodedValue) - encodedTag = encodeVarInt(messageTag) - copy(out[curPos:], encodedTag) - curPos += len(encodedTag) - encodedValue = encodeVarString(r.Message) - copy(out[curPos:], encodedValue) - return out, nil + var encoder Encoder + encoder.PutByte(r.Version) + encoder.PutVarInt(oneTimeKeyIDTag) + encoder.PutVarBytes(r.OneTimeKey) + encoder.PutVarInt(identityKeyTag) + encoder.PutVarBytes(r.IdentityKey) + encoder.PutVarInt(baseKeyTag) + encoder.PutVarBytes(r.BaseKey) + encoder.PutVarInt(messageTag) + encoder.PutVarBytes(r.Message) + return encoder.Bytes(), nil } From 20db7f86eccc5510f29f6962a3f73f0a1ce07134 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Sat, 26 Oct 2024 18:26:40 -0600 Subject: [PATCH 0999/1647] crypto/goolm: reorganize pickle code Signed-off-by: Sumner Evans --- crypto/goolm/account/account.go | 10 +- crypto/goolm/cipher/pickle.go | 50 -------- crypto/goolm/cipher/pickle_test.go | 28 ----- crypto/goolm/libolmpickle/encoder.go | 40 +++++++ crypto/goolm/libolmpickle/encoder_test.go | 99 ++++++++++++++++ crypto/goolm/libolmpickle/pickle.go | 62 +++++----- crypto/goolm/libolmpickle/pickle_test.go | 107 +++--------------- .../pickle.go => libolmpickle/picklejson.go} | 12 +- crypto/goolm/megolm/megolm.go | 5 +- crypto/goolm/pk/decryption.go | 10 +- crypto/goolm/ratchet/olm.go | 5 +- .../goolm/session/megolm_inbound_session.go | 10 +- .../goolm/session/megolm_outbound_session.go | 10 +- crypto/goolm/session/olm_session.go | 10 +- crypto/sql_store.go | 6 +- 15 files changed, 224 insertions(+), 240 deletions(-) delete mode 100644 crypto/goolm/cipher/pickle.go delete mode 100644 crypto/goolm/cipher/pickle_test.go create mode 100644 crypto/goolm/libolmpickle/encoder.go create mode 100644 crypto/goolm/libolmpickle/encoder_test.go rename crypto/goolm/{utilities/pickle.go => libolmpickle/picklejson.go} (84%) diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index 099cc493..4da08a73 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -8,11 +8,9 @@ import ( "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/session" - "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/crypto/olm" ) @@ -76,12 +74,12 @@ func NewAccount() (*Account, error) { // PickleAsJSON returns an Account as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (a *Account) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(a, accountPickleVersionJSON, key) + return libolmpickle.PickleAsJSON(a, accountPickleVersionJSON, key) } // UnpickleAsJSON updates an Account by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (a *Account) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(a, pickled, key, accountPickleVersionJSON) + return libolmpickle.UnpickleAsJSON(a, pickled, key, accountPickleVersionJSON) } // IdentityKeysJSON returns the public parts of the identity keys for the Account in a JSON string. @@ -322,7 +320,7 @@ func (a *Account) ForgetOldFallbackKey() { // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. func (a *Account) Unpickle(pickled, key []byte) error { - decrypted, err := cipher.Unpickle(key, pickled) + decrypted, err := libolmpickle.Unpickle(key, pickled) if err != nil { return err } @@ -410,7 +408,7 @@ func (a *Account) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { return nil, olm.ErrNoKeyProvided } - return cipher.Pickle(key, a.PickleLibOlm()) + return libolmpickle.Pickle(key, a.PickleLibOlm()) } // PickleLibOlm pickles the [Account] and returns the raw bytes. diff --git a/crypto/goolm/cipher/pickle.go b/crypto/goolm/cipher/pickle.go deleted file mode 100644 index 76f0d248..00000000 --- a/crypto/goolm/cipher/pickle.go +++ /dev/null @@ -1,50 +0,0 @@ -package cipher - -import ( - "crypto/aes" - "fmt" - - "maunium.net/go/mautrix/crypto/goolm/aessha2" - "maunium.net/go/mautrix/crypto/goolm/goolmbase64" - "maunium.net/go/mautrix/crypto/olm" -) - -var kdfPickle = []byte("Pickle") //used to derive the keys for encryption -const pickleMACLength = 8 - -// PickleBlockSize is the blocksize of the pickle cipher. -const PickleBlockSize = aes.BlockSize - -// Pickle encrypts the input with the key and the cipher AESSHA256. The result is then encoded in base64. -func Pickle(key, plaintext []byte) ([]byte, error) { - if c, err := aessha2.NewAESSHA2(key, kdfPickle); err != nil { - return nil, err - } else if ciphertext, err := c.Encrypt(plaintext); err != nil { - return nil, err - } else if mac, err := c.MAC(ciphertext); err != nil { - return nil, err - } else { - return goolmbase64.Encode(append(ciphertext, mac[:pickleMACLength]...)), nil - } -} - -// Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256. -func Unpickle(key, input []byte) ([]byte, error) { - ciphertext, err := goolmbase64.Decode(input) - if err != nil { - return nil, err - } - ciphertext, mac := ciphertext[:len(ciphertext)-pickleMACLength], ciphertext[len(ciphertext)-pickleMACLength:] - if c, err := aessha2.NewAESSHA2(key, kdfPickle); err != nil { - return nil, err - } else if verified, err := c.VerifyMAC(ciphertext, mac); err != nil { - return nil, err - } else if !verified { - return nil, fmt.Errorf("decrypt pickle: %w", olm.ErrBadMAC) - } else { - // Set to next block size - targetCipherText := make([]byte, int(len(ciphertext)/PickleBlockSize)*PickleBlockSize) - copy(targetCipherText, ciphertext) - return c.Decrypt(targetCipherText) - } -} diff --git a/crypto/goolm/cipher/pickle_test.go b/crypto/goolm/cipher/pickle_test.go deleted file mode 100644 index b6cfe809..00000000 --- a/crypto/goolm/cipher/pickle_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package cipher_test - -import ( - "crypto/aes" - "testing" - - "github.com/stretchr/testify/assert" - - "maunium.net/go/mautrix/crypto/goolm/cipher" -) - -func TestEncoding(t *testing.T) { - key := []byte("test key") - input := []byte("test") - //pad marshaled to get block size - toEncrypt := input - if len(input)%aes.BlockSize != 0 { - padding := aes.BlockSize - len(input)%aes.BlockSize - toEncrypt = make([]byte, len(input)+padding) - copy(toEncrypt, input) - } - encoded, err := cipher.Pickle(key, toEncrypt) - assert.NoError(t, err) - - decoded, err := cipher.Unpickle(key, encoded) - assert.NoError(t, err) - assert.Equal(t, toEncrypt, decoded) -} diff --git a/crypto/goolm/libolmpickle/encoder.go b/crypto/goolm/libolmpickle/encoder.go new file mode 100644 index 00000000..63e7b09b --- /dev/null +++ b/crypto/goolm/libolmpickle/encoder.go @@ -0,0 +1,40 @@ +package libolmpickle + +import ( + "bytes" + "encoding/binary" + + "go.mau.fi/util/exerrors" +) + +const ( + PickleBoolLength = 1 + PickleUInt8Length = 1 + PickleUInt32Length = 4 +) + +type Encoder struct { + bytes.Buffer +} + +func NewEncoder() *Encoder { return &Encoder{} } + +func (p *Encoder) WriteUInt8(value uint8) { + exerrors.PanicIfNotNil(p.WriteByte(value)) +} + +func (p *Encoder) WriteBool(value bool) { + if value { + exerrors.PanicIfNotNil(p.WriteByte(0x01)) + } else { + exerrors.PanicIfNotNil(p.WriteByte(0x00)) + } +} + +func (p *Encoder) WriteEmptyBytes(count int) { + exerrors.Must(p.Write(make([]byte, count))) +} + +func (p *Encoder) WriteUInt32(value uint32) { + exerrors.PanicIfNotNil(binary.Write(&p.Buffer, binary.BigEndian, value)) +} diff --git a/crypto/goolm/libolmpickle/encoder_test.go b/crypto/goolm/libolmpickle/encoder_test.go new file mode 100644 index 00000000..c7811225 --- /dev/null +++ b/crypto/goolm/libolmpickle/encoder_test.go @@ -0,0 +1,99 @@ +package libolmpickle_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" +) + +func TestEncoder(t *testing.T) { + var encoder libolmpickle.Encoder + encoder.WriteUInt32(4) + encoder.WriteUInt8(8) + encoder.WriteBool(false) + encoder.WriteEmptyBytes(10) + encoder.WriteBool(true) + encoder.Write([]byte("test")) + encoder.WriteUInt32(420_000) + assert.Equal(t, []byte{ + 0x00, 0x00, 0x00, 0x04, // 4 + 0x08, // 8 + 0x00, // false + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // ten empty bytes + 0x01, //true + 0x74, 0x65, 0x73, 0x74, // "test" (ASCII) + 0x00, 0x06, 0x68, 0xa0, // 420,000 + }, encoder.Bytes()) +} + +func TestPickleUInt32(t *testing.T) { + values := []uint32{ + 0xffffffff, + 0x00ff00ff, + 0xf0000000, + 0xf00f0000, + } + expected := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + {0xf0, 0x0f, 0x00, 0x00}, + } + for i, value := range values { + var encoder libolmpickle.Encoder + encoder.WriteUInt32(value) + assert.Equal(t, expected[i], encoder.Bytes()) + } +} + +func TestPickleBool(t *testing.T) { + values := []bool{ + true, + false, + } + expected := [][]byte{ + {0x01}, + {0x00}, + } + for i, value := range values { + var encoder libolmpickle.Encoder + encoder.WriteBool(value) + assert.Equal(t, expected[i], encoder.Bytes()) + } +} + +func TestPickleUInt8(t *testing.T) { + values := []uint8{ + 0xff, + 0x1a, + } + expected := [][]byte{ + {0xff}, + {0x1a}, + } + for i, value := range values { + var encoder libolmpickle.Encoder + encoder.WriteUInt8(value) + assert.Equal(t, expected[i], encoder.Bytes()) + } +} + +func TestPickleBytes(t *testing.T) { + values := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + } + expected := [][]byte{ + {0xff, 0xff, 0xff, 0xff}, + {0x00, 0xff, 0x00, 0xff}, + {0xf0, 0x00, 0x00, 0x00}, + } + for i, value := range values { + var encoder libolmpickle.Encoder + encoder.Write(value) + assert.Equal(t, expected[i], encoder.Bytes()) + } +} diff --git a/crypto/goolm/libolmpickle/pickle.go b/crypto/goolm/libolmpickle/pickle.go index 590033fc..d15358fd 100644 --- a/crypto/goolm/libolmpickle/pickle.go +++ b/crypto/goolm/libolmpickle/pickle.go @@ -1,40 +1,48 @@ package libolmpickle import ( - "bytes" - "encoding/binary" + "crypto/aes" + "fmt" - "go.mau.fi/util/exerrors" + "maunium.net/go/mautrix/crypto/goolm/aessha2" + "maunium.net/go/mautrix/crypto/goolm/goolmbase64" + "maunium.net/go/mautrix/crypto/olm" ) -const ( - PickleBoolLength = 1 - PickleUInt8Length = 1 - PickleUInt32Length = 4 -) +const pickleMACLength = 8 -type Encoder struct { - bytes.Buffer -} +var kdfPickle = []byte("Pickle") //used to derive the keys for encryption -func NewEncoder() *Encoder { return &Encoder{} } - -func (p *Encoder) WriteUInt8(value uint8) { - exerrors.PanicIfNotNil(p.WriteByte(value)) -} - -func (p *Encoder) WriteBool(value bool) { - if value { - exerrors.PanicIfNotNil(p.WriteByte(0x01)) +// Pickle encrypts the input with the key and the cipher AESSHA256. The result is then encoded in base64. +func Pickle(key, plaintext []byte) ([]byte, error) { + if c, err := aessha2.NewAESSHA2(key, kdfPickle); err != nil { + return nil, err + } else if ciphertext, err := c.Encrypt(plaintext); err != nil { + return nil, err + } else if mac, err := c.MAC(ciphertext); err != nil { + return nil, err } else { - exerrors.PanicIfNotNil(p.WriteByte(0x00)) + return goolmbase64.Encode(append(ciphertext, mac[:pickleMACLength]...)), nil } } -func (p *Encoder) WriteEmptyBytes(count int) { - exerrors.Must(p.Write(make([]byte, count))) -} - -func (p *Encoder) WriteUInt32(value uint32) { - exerrors.Must(p.Write(binary.BigEndian.AppendUint32(nil, value))) +// Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256. +func Unpickle(key, input []byte) ([]byte, error) { + ciphertext, err := goolmbase64.Decode(input) + if err != nil { + return nil, err + } + ciphertext, mac := ciphertext[:len(ciphertext)-pickleMACLength], ciphertext[len(ciphertext)-pickleMACLength:] + if c, err := aessha2.NewAESSHA2(key, kdfPickle); err != nil { + return nil, err + } else if verified, err := c.VerifyMAC(ciphertext, mac); err != nil { + return nil, err + } else if !verified { + return nil, fmt.Errorf("decrypt pickle: %w", olm.ErrBadMAC) + } else { + // Set to next block size + targetCipherText := make([]byte, int(len(ciphertext)/aes.BlockSize)*aes.BlockSize) + copy(targetCipherText, ciphertext) + return c.Decrypt(targetCipherText) + } } diff --git a/crypto/goolm/libolmpickle/pickle_test.go b/crypto/goolm/libolmpickle/pickle_test.go index c7811225..0720e008 100644 --- a/crypto/goolm/libolmpickle/pickle_test.go +++ b/crypto/goolm/libolmpickle/pickle_test.go @@ -1,99 +1,26 @@ -package libolmpickle_test +package libolmpickle import ( + "crypto/aes" "testing" "github.com/stretchr/testify/assert" - - "maunium.net/go/mautrix/crypto/goolm/libolmpickle" ) -func TestEncoder(t *testing.T) { - var encoder libolmpickle.Encoder - encoder.WriteUInt32(4) - encoder.WriteUInt8(8) - encoder.WriteBool(false) - encoder.WriteEmptyBytes(10) - encoder.WriteBool(true) - encoder.Write([]byte("test")) - encoder.WriteUInt32(420_000) - assert.Equal(t, []byte{ - 0x00, 0x00, 0x00, 0x04, // 4 - 0x08, // 8 - 0x00, // false - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // ten empty bytes - 0x01, //true - 0x74, 0x65, 0x73, 0x74, // "test" (ASCII) - 0x00, 0x06, 0x68, 0xa0, // 420,000 - }, encoder.Bytes()) -} +func TestEncoding(t *testing.T) { + key := []byte("test key") + input := []byte("test") + //pad marshaled to get block size + toEncrypt := input + if len(input)%aes.BlockSize != 0 { + padding := aes.BlockSize - len(input)%aes.BlockSize + toEncrypt = make([]byte, len(input)+padding) + copy(toEncrypt, input) + } + encoded, err := Pickle(key, toEncrypt) + assert.NoError(t, err) -func TestPickleUInt32(t *testing.T) { - values := []uint32{ - 0xffffffff, - 0x00ff00ff, - 0xf0000000, - 0xf00f0000, - } - expected := [][]byte{ - {0xff, 0xff, 0xff, 0xff}, - {0x00, 0xff, 0x00, 0xff}, - {0xf0, 0x00, 0x00, 0x00}, - {0xf0, 0x0f, 0x00, 0x00}, - } - for i, value := range values { - var encoder libolmpickle.Encoder - encoder.WriteUInt32(value) - assert.Equal(t, expected[i], encoder.Bytes()) - } -} - -func TestPickleBool(t *testing.T) { - values := []bool{ - true, - false, - } - expected := [][]byte{ - {0x01}, - {0x00}, - } - for i, value := range values { - var encoder libolmpickle.Encoder - encoder.WriteBool(value) - assert.Equal(t, expected[i], encoder.Bytes()) - } -} - -func TestPickleUInt8(t *testing.T) { - values := []uint8{ - 0xff, - 0x1a, - } - expected := [][]byte{ - {0xff}, - {0x1a}, - } - for i, value := range values { - var encoder libolmpickle.Encoder - encoder.WriteUInt8(value) - assert.Equal(t, expected[i], encoder.Bytes()) - } -} - -func TestPickleBytes(t *testing.T) { - values := [][]byte{ - {0xff, 0xff, 0xff, 0xff}, - {0x00, 0xff, 0x00, 0xff}, - {0xf0, 0x00, 0x00, 0x00}, - } - expected := [][]byte{ - {0xff, 0xff, 0xff, 0xff}, - {0x00, 0xff, 0x00, 0xff}, - {0xf0, 0x00, 0x00, 0x00}, - } - for i, value := range values { - var encoder libolmpickle.Encoder - encoder.Write(value) - assert.Equal(t, expected[i], encoder.Bytes()) - } + decoded, err := Unpickle(key, encoded) + assert.NoError(t, err) + assert.Equal(t, toEncrypt, decoded) } diff --git a/crypto/goolm/utilities/pickle.go b/crypto/goolm/libolmpickle/picklejson.go similarity index 84% rename from crypto/goolm/utilities/pickle.go rename to crypto/goolm/libolmpickle/picklejson.go index c6d9d693..308e472c 100644 --- a/crypto/goolm/utilities/pickle.go +++ b/crypto/goolm/libolmpickle/picklejson.go @@ -1,10 +1,10 @@ -package utilities +package libolmpickle import ( + "crypto/aes" "encoding/json" "fmt" - "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/olm" ) @@ -21,12 +21,12 @@ func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) { toEncrypt := make([]byte, len(marshaled)) copy(toEncrypt, marshaled) //pad marshaled to get block size - if len(marshaled)%cipher.PickleBlockSize != 0 { - padding := cipher.PickleBlockSize - len(marshaled)%cipher.PickleBlockSize + if len(marshaled)%aes.BlockSize != 0 { + padding := aes.BlockSize - len(marshaled)%aes.BlockSize toEncrypt = make([]byte, len(marshaled)+padding) copy(toEncrypt, marshaled) } - encrypted, err := cipher.Pickle(key, toEncrypt) + encrypted, err := Pickle(key, toEncrypt) if err != nil { return nil, fmt.Errorf("pickle encrypt: %w", err) } @@ -38,7 +38,7 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error { if len(key) == 0 { return fmt.Errorf("unpickle: %w", olm.ErrNoKeyProvided) } - decrypted, err := cipher.Unpickle(key, pickled) + decrypted, err := Unpickle(key, pickled) if err != nil { return fmt.Errorf("unpickle decrypt: %w", err) } diff --git a/crypto/goolm/megolm/megolm.go b/crypto/goolm/megolm/megolm.go index 930d8f44..3b5f1e4a 100644 --- a/crypto/goolm/megolm/megolm.go +++ b/crypto/goolm/megolm/megolm.go @@ -12,7 +12,6 @@ import ( "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/message" - "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/crypto/olm" ) @@ -197,12 +196,12 @@ func (r Ratchet) Decrypt(ciphertext []byte, signingkey *crypto.Ed25519PublicKey, // PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (r Ratchet) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(r, megolmPickleVersion, key) + return libolmpickle.PickleAsJSON(r, megolmPickleVersion, key) } // UnpickleAsJSON updates a ratchet by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(r, pickled, key, megolmPickleVersion) + return libolmpickle.UnpickleAsJSON(r, pickled, key, megolmPickleVersion) } // UnpickleLibOlm decodes the unencryted value and populates the Ratchet accordingly. It returns the number of bytes read. diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index 990df3c0..afb01f74 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -5,11 +5,9 @@ import ( "fmt" "maunium.net/go/mautrix/crypto/goolm/aessha2" - "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" - "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -77,18 +75,18 @@ func (s Decryption) Decrypt(ephemeralKey, mac, ciphertext []byte) ([]byte, error // PickleAsJSON returns an Decryption as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (a Decryption) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(a, decryptionPickleVersionJSON, key) + return libolmpickle.PickleAsJSON(a, decryptionPickleVersionJSON, key) } // UnpickleAsJSON updates an Decryption by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (a *Decryption) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(a, pickled, key, decryptionPickleVersionJSON) + return libolmpickle.UnpickleAsJSON(a, pickled, key, decryptionPickleVersionJSON) } // Unpickle decodes the base64 encoded string and decrypts the result with the key. // The decrypted value is then passed to UnpickleLibOlm. func (a *Decryption) Unpickle(pickled, key []byte) error { - decrypted, err := cipher.Unpickle(key, pickled) + decrypted, err := libolmpickle.Unpickle(key, pickled) if err != nil { return err } @@ -111,7 +109,7 @@ func (a *Decryption) UnpickleLibOlm(unpickled []byte) error { // Pickle returns a base64 encoded and with key encrypted pickled Decryption using PickleLibOlm(). func (a Decryption) Pickle(key []byte) ([]byte, error) { - return cipher.Pickle(key, a.PickleLibOlm()) + return libolmpickle.Pickle(key, a.PickleLibOlm()) } // PickleLibOlm pickles the [Decryption] into the encoder. diff --git a/crypto/goolm/ratchet/olm.go b/crypto/goolm/ratchet/olm.go index fcb72c20..229c9bd2 100644 --- a/crypto/goolm/ratchet/olm.go +++ b/crypto/goolm/ratchet/olm.go @@ -13,7 +13,6 @@ import ( "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/message" - "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/crypto/olm" ) @@ -281,12 +280,12 @@ func (r *Ratchet) decryptForNewChain(message *message.Message, rawMessage []byte // PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (r Ratchet) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(r, olmPickleVersion, key) + return libolmpickle.PickleAsJSON(r, olmPickleVersion, key) } // UnpickleAsJSON updates a ratchet by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(r, pickled, key, olmPickleVersion) + return libolmpickle.UnpickleAsJSON(r, pickled, key, olmPickleVersion) } // UnpickleLibOlm unpickles the unencryted value and populates the [Ratchet] diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go index 4c107e92..80dd71cc 100644 --- a/crypto/goolm/session/megolm_inbound_session.go +++ b/crypto/goolm/session/megolm_inbound_session.go @@ -4,13 +4,11 @@ import ( "encoding/base64" "fmt" - "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/megolm" "maunium.net/go/mautrix/crypto/goolm/message" - "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -161,12 +159,12 @@ func (o *MegolmInboundSession) ID() id.SessionID { // PickleAsJSON returns an MegolmInboundSession as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (o *MegolmInboundSession) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(o, megolmInboundSessionPickleVersionJSON, key) + return libolmpickle.PickleAsJSON(o, megolmInboundSessionPickleVersionJSON, key) } // UnpickleAsJSON updates an MegolmInboundSession by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format. func (o *MegolmInboundSession) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(o, pickled, key, megolmInboundSessionPickleVersionJSON) + return libolmpickle.UnpickleAsJSON(o, pickled, key, megolmInboundSessionPickleVersionJSON) } // Export returns the base64-encoded ratchet key for this session, at the given @@ -192,7 +190,7 @@ func (o *MegolmInboundSession) Unpickle(pickled, key []byte) error { } else if len(pickled) == 0 { return olm.ErrEmptyInput } - decrypted, err := cipher.Unpickle(key, pickled) + decrypted, err := libolmpickle.Unpickle(key, pickled) if err != nil { return err } @@ -234,7 +232,7 @@ func (o *MegolmInboundSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { return nil, olm.ErrNoKeyProvided } - return cipher.Pickle(key, o.PickleLibOlm()) + return libolmpickle.Pickle(key, o.PickleLibOlm()) } // PickleLibOlm pickles the session returning the raw bytes. diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index 6164f965..2b8e1c84 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -7,12 +7,10 @@ import ( "go.mau.fi/util/exerrors" - "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/megolm" - "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -77,12 +75,12 @@ func (o *MegolmOutboundSession) ID() id.SessionID { // PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (o *MegolmOutboundSession) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(o, megolmOutboundSessionPickleVersion, key) + return libolmpickle.PickleAsJSON(o, megolmOutboundSessionPickleVersion, key) } // UnpickleAsJSON updates an Session by a base64 encrypted string with the key. The unencrypted representation has to be in JSON format. func (o *MegolmOutboundSession) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(o, pickled, key, megolmOutboundSessionPickleVersion) + return libolmpickle.UnpickleAsJSON(o, pickled, key, megolmOutboundSessionPickleVersion) } // Unpickle decodes the base64 encoded string and decrypts the result with the key. @@ -91,7 +89,7 @@ func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { return olm.ErrNoKeyProvided } - decrypted, err := cipher.Unpickle(key, pickled) + decrypted, err := libolmpickle.Unpickle(key, pickled) if err != nil { return err } @@ -117,7 +115,7 @@ func (o *MegolmOutboundSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { return nil, olm.ErrNoKeyProvided } - return cipher.Pickle(key, o.PickleLibOlm()) + return libolmpickle.Pickle(key, o.PickleLibOlm()) } // PickleLibOlm pickles the session returning the raw bytes. diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go index fcd9d0dc..b99ab630 100644 --- a/crypto/goolm/session/olm_session.go +++ b/crypto/goolm/session/olm_session.go @@ -7,13 +7,11 @@ import ( "fmt" "strings" - "maunium.net/go/mautrix/crypto/goolm/cipher" "maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/goolmbase64" "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/goolm/message" "maunium.net/go/mautrix/crypto/goolm/ratchet" - "maunium.net/go/mautrix/crypto/goolm/utilities" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) @@ -189,12 +187,12 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received // PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format. func (a OlmSession) PickleAsJSON(key []byte) ([]byte, error) { - return utilities.PickleAsJSON(a, olmSessionPickleVersionJSON, key) + return libolmpickle.PickleAsJSON(a, olmSessionPickleVersionJSON, key) } // UnpickleAsJSON updates an Session by a base64 encrypted string with the key. The unencrypted representation has to be in JSON format. func (a *OlmSession) UnpickleAsJSON(pickled, key []byte) error { - return utilities.UnpickleAsJSON(a, pickled, key, olmSessionPickleVersionJSON) + return libolmpickle.UnpickleAsJSON(a, pickled, key, olmSessionPickleVersionJSON) } // ID returns an identifier for this Session. Will be the same for both ends of the conversation. @@ -355,7 +353,7 @@ func (o *OlmSession) Unpickle(pickled, key []byte) error { if len(pickled) == 0 { return olm.ErrEmptyInput } - decrypted, err := cipher.Unpickle(key, pickled) + decrypted, err := libolmpickle.Unpickle(key, pickled) if err != nil { return err } @@ -396,7 +394,7 @@ func (s *OlmSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { return nil, olm.ErrNoKeyProvided } - return cipher.Pickle(key, s.PickleLibOlm()) + return libolmpickle.Pickle(key, s.PickleLibOlm()) } // PickleLibOlm pickles the session and returns the raw bytes. diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 0415c704..d6bcb530 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -21,7 +21,7 @@ import ( "go.mau.fi/util/dbutil" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto/goolm/cipher" + "maunium.net/go/mautrix/crypto/goolm/libolmpickle" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/sql_store_upgrade" "maunium.net/go/mautrix/event" @@ -944,7 +944,7 @@ func (store *SQLCryptoStore) DropSignaturesByKey(ctx context.Context, userID id. } func (store *SQLCryptoStore) PutSecret(ctx context.Context, name id.Secret, value string) error { - bytes, err := cipher.Pickle(store.PickleKey, []byte(value)) + bytes, err := libolmpickle.Pickle(store.PickleKey, []byte(value)) if err != nil { return err } @@ -963,7 +963,7 @@ func (store *SQLCryptoStore) GetSecret(ctx context.Context, name id.Secret) (val } else if err != nil { return "", err } - bytes, err = cipher.Unpickle(store.PickleKey, bytes) + bytes, err = libolmpickle.Unpickle(store.PickleKey, bytes) return string(bytes), err } From 71d7d1e097789f42086d0d8cf7eadc199432685d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 21 Jan 2025 12:51:45 +0200 Subject: [PATCH 1000/1647] bridgev2/portal: fix manual CreateMatrixRoom calls when buffer is disabled --- bridgev2/portal.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 30433d5a..def3fb11 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3567,7 +3567,7 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } waiter := make(chan struct{}) closed := false - portal.events <- &portalCreateEvent{ + evt := &portalCreateEvent{ ctx: ctx, source: source, info: info, @@ -3579,6 +3579,11 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } }, } + if PortalEventBuffer == 0 { + go portal.queueEvent(ctx, evt) + } else { + portal.events <- evt + } select { case <-ctx.Done(): return ctx.Err() From 21c059184b832efdc306375dddb12a5e356759c8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 21 Jan 2025 12:52:08 +0200 Subject: [PATCH 1001/1647] bridgev2/networkinterface: add some comments --- bridgev2/networkinterface.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index c0e3f880..34ddc434 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -497,11 +497,18 @@ type FetchMessagesResponse struct { // BackfillingNetworkAPI is an optional interface that network connectors can implement to support backfilling message history. type BackfillingNetworkAPI interface { NetworkAPI + // FetchMessages returns a batch of messages to backfill in a portal room. + // For details on the input and output, see the documentation of [FetchMessagesParams] and [FetchMessagesResponse]. FetchMessages(ctx context.Context, fetchParams FetchMessagesParams) (*FetchMessagesResponse, error) } +// BackfillingNetworkAPIWithLimits is an optional interface that network connectors can implement to customize +// the limit for backwards backfilling tasks. It is recommended to implement this by reading the MaxBatchesOverride +// config field with network-specific keys for different room types. type BackfillingNetworkAPIWithLimits interface { BackfillingNetworkAPI + // GetBackfillMaxBatchCount is called before a backfill task is executed to determine the maximum number of batches + // that should be backfilled. Return values less than 0 are treated as unlimited. GetBackfillMaxBatchCount(ctx context.Context, portal *Portal, task *database.BackfillTask) int } From 2c1aa218aec567e029ac2dc8e4ee8173faa2ec08 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 21 Jan 2025 12:52:23 +0200 Subject: [PATCH 1002/1647] bridgev2/backfill: call complete callback if forward backfill has no messages --- bridgev2/portalbackfill.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index ac15880d..a5dfb42a 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -61,6 +61,9 @@ func (portal *Portal) doForwardBackfill(ctx context.Context, source *UserLogin, return } else if len(resp.Messages) == 0 { log.Debug().Msg("No messages to backfill") + if resp.CompleteCallback != nil { + resp.CompleteCallback() + } return } log.Debug(). From 9fa82729911c7783d028cea02316b2a1b2239034 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 21 Jan 2025 12:52:56 +0200 Subject: [PATCH 1003/1647] bridgev2: add fallback for RunOnce --- bridgev2/bridge.go | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index fc195d35..b574a573 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -10,6 +10,7 @@ import ( "context" "fmt" "sync" + "time" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" @@ -124,6 +125,22 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID) er if err != nil { return err } + + if loginID == "" { + br.Log.Info().Msg("No login ID provided to RunOnce, running all logins for 20 seconds") + err = br.StartLogins(ctx) + if err != nil { + return err + } + defer br.Stop() + select { + case <-time.After(20 * time.Second): + case <-ctx.Done(): + } + return nil + } + + defer br.stop(true) login, err := br.GetExistingUserLoginByID(ctx, loginID) if err != nil { return fmt.Errorf("failed to get user login: %w", err) @@ -132,11 +149,18 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID) er } syncClient, ok := login.Client.(BackgroundSyncingNetworkAPI) if !ok { - return fmt.Errorf("%T does not implement BackgroundSyncingNetworkAPI", login.Client) + br.Log.Warn().Msg("Network connector doesn't implement background mode, using fallback mechanism for RunOnce") + login.Client.Connect(ctx) + defer login.Disconnect(nil) + select { + case <-time.After(20 * time.Second): + case <-ctx.Done(): + } + return nil + } else { + br.Log.Info().Str("user_login_id", string(login.ID)).Msg("Starting individual user login in background mode") + return syncClient.ConnectBackground(login.Log.WithContext(ctx)) } - defer br.stop(true) - br.Log.Info().Str("user_login_id", string(login.ID)).Msg("Starting individual user login in background mode") - return syncClient.ConnectBackground(login.Log.WithContext(ctx)) } func (br *Bridge) StartConnectors(ctx context.Context) error { From 524379bdb327e5f34462a19d9df1fd03c3cc02f4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 22 Jan 2025 15:18:41 +0200 Subject: [PATCH 1004/1647] bridgev2/networkinterface: add PushParsingNetwork --- bridgev2/networkinterface.go | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 34ddc434..ca440ce8 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -8,6 +8,7 @@ package bridgev2 import ( "context" + "encoding/json" "fmt" "strings" "time" @@ -809,19 +810,35 @@ type APNsPushConfig struct { } type PushConfig struct { - Web *WebPushConfig `json:"web,omitempty"` - FCM *FCMPushConfig `json:"fcm,omitempty"` - APNs *APNsPushConfig `json:"apns,omitempty"` - Native bool `json:"native,omitempty"` + Web *WebPushConfig `json:"web,omitempty"` + FCM *FCMPushConfig `json:"fcm,omitempty"` + APNs *APNsPushConfig `json:"apns,omitempty"` + // If Native is true, it means the network supports registering for pushes + // that are delivered directly to the app without the use of a push relay. + Native bool `json:"native,omitempty"` } +// PushableNetworkAPI is an optional interface that network connectors can implement +// to support waking up the wrapper app using push notifications. type PushableNetworkAPI interface { NetworkAPI + // RegisterPushNotifications is called when the wrapper app wants to register a push token with the remote network. RegisterPushNotifications(ctx context.Context, pushType PushType, token string) error + // GetPushConfigs is used to find which types of push notifications the remote network can provide. GetPushConfigs() *PushConfig } +// PushParsingNetwork is an optional interface that network connectors can implement +// to support parsing native push notifications from networks. +type PushParsingNetwork interface { + NetworkConnector + + // ParsePushNotification is called when a native push is received. + // It must return the corresponding user login ID to wake up, plus optionally data to pass to the wakeup call. + ParsePushNotification(ctx context.Context, data json.RawMessage) (networkid.UserLoginID, any, error) +} + type RemoteEventType int func (ret RemoteEventType) String() string { From 2d79ce4eed56a682a12746c4aeccb3f6cb10846d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 23 Jan 2025 15:06:50 +0200 Subject: [PATCH 1005/1647] bridgev2: allow passing extra data in ConnectBackground --- bridgev2/bridge.go | 4 ++-- bridgev2/networkinterface.go | 10 +++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index b574a573..040e1c1c 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -119,7 +119,7 @@ func (br *Bridge) Start() error { return nil } -func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID) error { +func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, params *ConnectBackgroundParams) error { br.Background = true err := br.StartConnectors(ctx) if err != nil { @@ -159,7 +159,7 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID) er return nil } else { br.Log.Info().Str("user_login_id", string(login.ID)).Msg("Starting individual user login in background mode") - return syncClient.ConnectBackground(login.Log.WithContext(ctx)) + return syncClient.ConnectBackground(login.Log.WithContext(ctx), params) } } diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index ca440ce8..487f1ea6 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -353,13 +353,21 @@ type NetworkAPI interface { HandleMatrixMessage(ctx context.Context, msg *MatrixMessage) (message *MatrixMessageResponse, err error) } +type ConnectBackgroundParams struct { + // RawData is the raw data in the push that triggered the background connection. + RawData json.RawMessage + // ExtraData is the data returned by [PushParsingNetwork.ParsePushNotification]. + // It's only present for native pushes. Relayed pushes will only have the raw data. + ExtraData any +} + // BackgroundSyncingNetworkAPI is an optional interface that network connectors can implement to support background resyncs. type BackgroundSyncingNetworkAPI interface { NetworkAPI // ConnectBackground is called in place of Connect for background resyncs. // The client should connect to the remote network, handle pending messages, and then disconnect. // This call should block until the entire sync is complete and the client is disconnected. - ConnectBackground(ctx context.Context) error + ConnectBackground(ctx context.Context, params *ConnectBackgroundParams) error } // FetchMessagesParams contains the parameters for a message history pagination request. From 4cde40cfb9adef8013100dfdc78d6a1b5a3fd341 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 24 Jan 2025 18:01:28 +0200 Subject: [PATCH 1006/1647] bridgev2/matrixinterface: add interface for displaying raw notifications --- bridgev2/matrixinterface.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 8ac2e92d..ac0e4e92 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -86,6 +86,18 @@ type MatrixConnectorWithAnalytics interface { TrackAnalytics(userID id.UserID, event string, properties map[string]any) } +type DirectNotificationData struct { + Portal *Portal + Sender *Ghost + Message string + + FormattedNotification string +} + +type MatrixConnectorWithNotifications interface { + DisplayNotification(ctx context.Context, data *DirectNotificationData) +} + type MatrixSendExtra struct { Timestamp time.Time MessageMeta *database.Message From 873d34ff5db70888a2c748727e55eb7cd4cc6116 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 24 Jan 2025 18:05:41 +0200 Subject: [PATCH 1007/1647] bridgev2/matrixinterface: add message ID field to notification data --- bridgev2/matrixinterface.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index ac0e4e92..1b6477e3 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -87,9 +87,10 @@ type MatrixConnectorWithAnalytics interface { } type DirectNotificationData struct { - Portal *Portal - Sender *Ghost - Message string + Portal *Portal + Sender *Ghost + MessageID networkid.MessageID + Message string FormattedNotification string } From 625dbc6de346b066be0aaff0d1870eba326408a8 Mon Sep 17 00:00:00 2001 From: Brad Murray Date: Mon, 27 Jan 2025 14:40:10 -0500 Subject: [PATCH 1008/1647] Add local bridge state types (#348) --- bridge/status/localbridgestate.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 bridge/status/localbridgestate.go diff --git a/bridge/status/localbridgestate.go b/bridge/status/localbridgestate.go new file mode 100644 index 00000000..3ad66538 --- /dev/null +++ b/bridge/status/localbridgestate.go @@ -0,0 +1,23 @@ +package status + +type LocalBridgeAccountState string + +const ( + // LocalBridgeAccountStateSetup means the user wants this account to be setup and connected + LocalBridgeAccountStateSetup LocalBridgeAccountState = "SETUP" + // LocalBridgeAccountStateDeleted means the user wants this account to be deleted + LocalBridgeAccountStateDeleted LocalBridgeAccountState = "DELETED" +) + +type LocalBridgeDeviceState string + +const ( + // LocalBridgeDeviceStateSetup means this device is setup to be connected to this account + LocalBridgeDeviceStateSetup LocalBridgeDeviceState = "SETUP" + // LocalBridgeDeviceStateLoggedOut means the user has logged this particular device out while wanting their other devices to remain setup + LocalBridgeDeviceStateLoggedOut LocalBridgeDeviceState = "LOGGED_OUT" + // LocalBridgeDeviceStateError means this particular device has fallen into a persistent error state that may need user intervention to fix + LocalBridgeDeviceStateError LocalBridgeDeviceState = "ERROR" + // LocalBridgeDeviceStateDeleted means this particular device has cleaned up after the account as a whole was requested to be deleted + LocalBridgeDeviceStateDeleted LocalBridgeDeviceState = "DELETED" +) From 7f209326077a41afa6c9ee537e1c443b5a816142 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 29 Jan 2025 00:34:57 +0200 Subject: [PATCH 1009/1647] client: add method to get full state event --- client.go | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index ea7fd6a1..495dad3b 100644 --- a/client.go +++ b/client.go @@ -1460,8 +1460,8 @@ func (cli *Client) updateStoreWithOutgoingEvent(ctx context.Context, roomID id.R UpdateStateStore(ctx, cli.StateStore, fakeEvt) } -// StateEvent gets a single state event in a room. It will attempt to JSON unmarshal into the given "outContent" struct with -// the HTTP response body, or return an error. +// StateEvent gets the content of a single state event in a room. +// It will attempt to JSON unmarshal into the given "outContent" struct with the HTTP response body, or return an error. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstateeventtypestatekey func (cli *Client) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) (err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) @@ -1472,6 +1472,23 @@ func (cli *Client) StateEvent(ctx context.Context, roomID id.RoomID, eventType e return } +// FullStateEvent gets a single state event in a room. Unlike [StateEvent], this gets the entire event +// (including details like the sender and timestamp). +// This requires the server to support the ?format=event query parameter, which is currently missing from the spec. +// See https://github.com/matrix-org/matrix-spec/issues/1047 for more info +func (cli *Client) FullStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (evt *event.Event, err error) { + u := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{ + "format": "event", + }) + _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &evt) + if err == nil && cli.StateStore != nil { + UpdateStateStore(ctx, cli.StateStore, evt) + } + evt.Type.Class = event.StateEventType + _ = evt.Content.ParseRaw(evt.Type) + return +} + // parseRoomStateArray parses a JSON array as a stream and stores the events inside it in a room state map. func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { response := make(RoomStateMap) From f2966bc55a5843940ca63815934e100e50027e0d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 29 Jan 2025 14:48:48 +0200 Subject: [PATCH 1010/1647] dependencies: update --- go.mod | 6 +++--- go.sum | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index 0f4c7ed5..a17fe368 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module maunium.net/go/mautrix go 1.22.0 -toolchain go1.23.4 +toolchain go1.23.5 require ( filippo.io/edwards25519 v1.1.0 @@ -18,10 +18,10 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.4 + go.mau.fi/util v0.8.5-0.20250129121406-18c356e558b8 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.32.0 - golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 + golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c golang.org/x/net v0.34.0 golang.org/x/sync v0.10.0 gopkg.in/yaml.v3 v3.0.1 diff --git a/go.sum b/go.sum index fc14b4d9..b38f317a 100644 --- a/go.sum +++ b/go.sum @@ -54,14 +54,14 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.4 h1:mVKlJcXWfVo8ZW3f4vqtjGpqtZqJvX4ETekxawt2vnQ= -go.mau.fi/util v0.8.4/go.mod h1:MOfGTs1CBuK6ERTcSL4lb5YU7/ujz09eOPVEDckuazY= +go.mau.fi/util v0.8.5-0.20250129121406-18c356e558b8 h1:O1cRlXPahwbu1ckIf8XgUP3gHMJlSqJxaVTqwRlVK4s= +go.mau.fi/util v0.8.5-0.20250129121406-18c356e558b8/go.mod h1:MOfGTs1CBuK6ERTcSL4lb5YU7/ujz09eOPVEDckuazY= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= -golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA= -golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= +golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c h1:KL/ZBHXgKGVmuZBZ01Lt57yE5ws8ZPSkkihmEyq7FXc= +golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= From 30ad8a99a8ce32b058b1e5b767797a0511ec1716 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 29 Jan 2025 14:50:04 +0200 Subject: [PATCH 1011/1647] bridgev2: make restarting bridges safer --- bridgev2/backfillqueue.go | 6 ++++-- bridgev2/bridge.go | 32 ++++++++++++++++++++------------ 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go index fce4a1b0..7d521fd1 100644 --- a/bridgev2/backfillqueue.go +++ b/bridgev2/backfillqueue.go @@ -38,8 +38,10 @@ func (br *Bridge) RunBackfillQueue() { return } ctx, cancel := context.WithCancel(log.WithContext(context.Background())) + br.stopBackfillQueue.Clear() + stopChan := br.stopBackfillQueue.GetChan() go func() { - <-br.stopBackfillQueue + <-stopChan cancel() }() batchDelay := time.Duration(br.Config.Backfill.Queue.BatchDelay) * time.Second @@ -61,7 +63,7 @@ func (br *Bridge) RunBackfillQueue() { } } noTasksFoundCount = 0 - case <-br.stopBackfillQueue: + case <-stopChan: if !timer.Stop() { select { case <-timer.C: diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 040e1c1c..309f48ed 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -14,6 +14,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exsync" "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/bridgeconfig" @@ -48,10 +49,11 @@ type Bridge struct { didSplitPortals bool - Background bool + Background bool + ExternallyManagedDB bool wakeupBackfillQueue chan struct{} - stopBackfillQueue chan struct{} + stopBackfillQueue *exsync.Event } func NewBridge( @@ -79,7 +81,7 @@ func NewBridge( ghostsByID: make(map[networkid.UserID]*Ghost), wakeupBackfillQueue: make(chan struct{}), - stopBackfillQueue: make(chan struct{}), + stopBackfillQueue: exsync.NewEvent(), } if br.Config == nil { br.Config = &bridgeconfig.BridgeConfig{CommandPrefix: "!bridge"} @@ -166,15 +168,17 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa func (br *Bridge) StartConnectors(ctx context.Context) error { br.Log.Info().Msg("Starting bridge") - err := br.DB.Upgrade(ctx) - if err != nil { - return DBUpgradeError{Err: err, Section: "main"} + if !br.ExternallyManagedDB { + err := br.DB.Upgrade(ctx) + if err != nil { + return DBUpgradeError{Err: err, Section: "main"} + } } if !br.Background { br.didSplitPortals = br.MigrateToSplitPortals(ctx) } br.Log.Info().Msg("Starting Matrix connector") - err = br.Matrix.Start(ctx) + err := br.Matrix.Start(ctx) if err != nil { return fmt.Errorf("failed to start Matrix connector: %w", err) } @@ -300,7 +304,9 @@ func (br *Bridge) StartLogins(ctx context.Context) error { br.Log.Info().Msg("No user logins found") br.SendGlobalBridgeState(status.BridgeState{StateEvent: status.StateUnconfigured}) } - go br.RunBackfillQueue() + if !br.Background { + go br.RunBackfillQueue() + } br.Log.Info().Msg("Bridge started") return nil @@ -312,7 +318,7 @@ func (br *Bridge) Stop() { func (br *Bridge) stop(isRunOnce bool) { br.Log.Info().Msg("Shutting down bridge") - close(br.stopBackfillQueue) + br.stopBackfillQueue.Set() br.Matrix.Stop() if !isRunOnce { br.cacheLock.Lock() @@ -327,9 +333,11 @@ func (br *Bridge) stop(isRunOnce bool) { if stopNet, ok := br.Network.(StoppableNetwork); ok { stopNet.Stop() } - err := br.DB.Close() - if err != nil { - br.Log.Warn().Err(err).Msg("Failed to close database") + if !br.ExternallyManagedDB { + err := br.DB.Close() + if err != nil { + br.Log.Warn().Err(err).Msg("Failed to close database") + } } br.Log.Info().Msg("Shutdown complete") } From 7c0ed06e43611db248df821276345613aab61101 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 29 Jan 2025 15:11:06 +0200 Subject: [PATCH 1012/1647] bridge,crypto: fix uses of deprecated NewRowIter --- bridge/crypto.go | 6 +----- bridgev2/matrix/crypto.go | 6 +----- crypto/sql_store.go | 15 +++------------ .../verificationhelper/verificationstore_test.go | 12 +++++------- 4 files changed, 10 insertions(+), 29 deletions(-) diff --git a/bridge/crypto.go b/bridge/crypto.go index 4765039b..de1aebbc 100644 --- a/bridge/crypto.go +++ b/bridge/crypto.go @@ -139,15 +139,11 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { log := helper.log.With().Str("action", "resync encryption event").Logger() rows, err := helper.bridge.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) + roomIDs, err := dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList() if err != nil { log.Err(err).Msg("Failed to query rooms for resync") return } - roomIDs, err := dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList() - if err != nil { - log.Err(err).Msg("Failed to scan rooms for resync") - return - } if len(roomIDs) > 0 { log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms") for _, roomID := range roomIDs { diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index f330f9f4..be5e196e 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -145,15 +145,11 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { log := helper.log.With().Str("action", "resync encryption event").Logger() rows, err := helper.store.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) + roomIDs, err := dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList() if err != nil { log.Err(err).Msg("Failed to query rooms for resync") return } - roomIDs, err := dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList() - if err != nil { - log.Err(err).Msg("Failed to scan rooms for resync") - return - } if len(roomIDs) > 0 { log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms") for _, roomID := range roomIDs { diff --git a/crypto/sql_store.go b/crypto/sql_store.go index d6bcb530..5bae3a1d 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -428,10 +428,7 @@ func (store *SQLCryptoStore) RedactGroupSessions(ctx context.Context, roomID id. AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL RETURNING session_id `, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, roomID, senderKey, store.AccountID) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() + return dbutil.NewRowIterWithError(res, dbutil.ScanSingleColumn[id.SessionID], err).AsList() } func (store *SQLCryptoStore) RedactExpiredGroupSessions(ctx context.Context) ([]id.SessionID, error) { @@ -459,10 +456,7 @@ func (store *SQLCryptoStore) RedactExpiredGroupSessions(ctx context.Context) ([] return nil, fmt.Errorf("unsupported dialect") } res, err := store.DB.Query(ctx, query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() + return dbutil.NewRowIterWithError(res, dbutil.ScanSingleColumn[id.SessionID], err).AsList() } func (store *SQLCryptoStore) RedactOutdatedGroupSessions(ctx context.Context) ([]id.SessionID, error) { @@ -472,10 +466,7 @@ func (store *SQLCryptoStore) RedactOutdatedGroupSessions(ctx context.Context) ([ WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL RETURNING session_id `, event.RoomKeyWithheldBeeperRedacted, "Session redacted: outdated", store.AccountID) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList() + return dbutil.NewRowIterWithError(res, dbutil.ScanSingleColumn[id.SessionID], err).AsList() } func (store *SQLCryptoStore) PutWithheldGroupSession(ctx context.Context, content event.RoomKeyWithheldEventContent) error { diff --git a/crypto/verificationhelper/verificationstore_test.go b/crypto/verificationhelper/verificationstore_test.go index a3b1895d..e64153b1 100644 --- a/crypto/verificationhelper/verificationstore_test.go +++ b/crypto/verificationhelper/verificationstore_test.go @@ -3,6 +3,7 @@ package verificationhelper_test import ( "context" "database/sql" + "errors" _ "github.com/mattn/go-sqlite3" "github.com/rs/zerolog" @@ -42,20 +43,17 @@ func NewSQLiteVerificationStore(ctx context.Context, db *sql.DB) (*SQLiteVerific func (s *SQLiteVerificationStore) GetAllVerificationTransactions(ctx context.Context) ([]verificationhelper.VerificationTransaction, error) { rows, err := s.db.QueryContext(ctx, selectVerifications) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(rows, func(dbutil.Scannable) (txn verificationhelper.VerificationTransaction, err error) { + return dbutil.NewRowIterWithError(rows, func(dbutil.Scannable) (txn verificationhelper.VerificationTransaction, err error) { err = rows.Scan(&dbutil.JSON{Data: &txn}) return - }).AsList() + }, err).AsList() } func (vq *SQLiteVerificationStore) GetVerificationTransaction(ctx context.Context, txnID id.VerificationTransactionID) (txn verificationhelper.VerificationTransaction, err error) { zerolog.Ctx(ctx).Warn().Stringer("transaction_id", txnID).Msg("Getting verification transaction") row := vq.db.QueryRowContext(ctx, getVerificationByTransactionID, txnID) err = row.Scan(&dbutil.JSON{Data: &txn}) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { err = verificationhelper.ErrUnknownVerificationTransaction } return @@ -64,7 +62,7 @@ func (vq *SQLiteVerificationStore) GetVerificationTransaction(ctx context.Contex func (vq *SQLiteVerificationStore) FindVerificationTransactionForUserDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (txn verificationhelper.VerificationTransaction, err error) { row := vq.db.QueryRowContext(ctx, getVerificationByUserDeviceID, userID, deviceID) err = row.Scan(&dbutil.JSON{Data: &txn}) - if err == sql.ErrNoRows { + if errors.Is(err, sql.ErrNoRows) { err = verificationhelper.ErrUnknownVerificationTransaction } return From 4d1cd8432cce74df65429704362950fea94cf743 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 29 Jan 2025 15:16:06 +0200 Subject: [PATCH 1013/1647] crypto,sqlstatestore: fix more deprecated NewRowIter uses --- crypto/sql_store.go | 15 +++------------ sqlstatestore/statestore.go | 12 +++--------- 2 files changed, 6 insertions(+), 21 deletions(-) diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 5bae3a1d..b3c3c4d9 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -703,11 +703,8 @@ func (store *SQLCryptoStore) GetDevices(ctx context.Context, userID id.UserID) ( } rows, err := store.DB.Query(ctx, "SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND deleted=false", userID) - if err != nil { - return nil, err - } data := make(map[id.DeviceID]*id.Device) - err = dbutil.NewRowIter(rows, scanDevice).Iter(func(device *id.Device) (bool, error) { + err = dbutil.NewRowIterWithError(rows, scanDevice, err).Iter(func(device *id.Device) (bool, error) { data[device.DeviceID] = device return true, nil }) @@ -827,10 +824,7 @@ func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id. placeholders, params := userIDsToParams(users) rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+placeholders+")", params...) } - if err != nil { - return users, err - } - return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList() + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() } // MarkTrackedUsersOutdated flags that the device list for given users are outdated. @@ -847,10 +841,7 @@ func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users // GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated. func (store *SQLCryptoStore) GetOutdatedTrackedUsers(ctx context.Context) ([]id.UserID, error) { rows, err := store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE devices_outdated = TRUE") - if err != nil { - return nil, err - } - return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList() + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.UserID], err).AsList() } // PutCrossSigningKey stores a cross-signing key of some user along with its usage. diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index 33c10c4c..4a220a2b 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -85,14 +85,11 @@ func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID query = fmt.Sprintf("%s AND membership IN (%s)", query, strings.Join(placeholders, ",")) } rows, err := store.Query(ctx, query, args...) - if err != nil { - return nil, err - } members := make(map[id.UserID]*event.MemberEventContent) - return members, dbutil.NewRowIter(rows, func(row dbutil.Scannable) (ret Member, err error) { + return members, dbutil.NewRowIterWithError(rows, func(row dbutil.Scannable) (ret Member, err error) { err = row.Scan(&ret.UserID, &ret.Membership, &ret.Displayname, &ret.AvatarURL) return - }).Iter(func(m Member) (bool, error) { + }, err).Iter(func(m Member) (bool, error) { members[m.UserID] = &m.MemberEventContent return true, nil }) @@ -159,10 +156,7 @@ func (store *SQLStateStore) FindSharedRooms(ctx context.Context, userID id.UserI ` } rows, err := store.Query(ctx, query, userID) - if err != nil { - return nil, err - } - return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList() + return dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList() } func (store *SQLStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool { From 36942121f4877bb0a990a7fa965f26790d468b6b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 29 Jan 2025 21:35:32 +0200 Subject: [PATCH 1014/1647] crypto/helper: add support for MSC4190 --- crypto/cryptohelper/cryptohelper.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index 0b3fbeaa..1b0e08e1 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -36,6 +36,7 @@ type CryptoHelper struct { DecryptErrorCallback func(*event.Event, error) + MSC4190 bool LoginAs *mautrix.ReqLogin ASEventProcessor crypto.ASEventProcessor @@ -151,7 +152,14 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to find existing device ID: %w", err) } - if helper.LoginAs != nil && helper.LoginAs.Type == mautrix.AuthTypeAppservice && helper.client.SetAppServiceDeviceID { + if helper.MSC4190 { + helper.log.Debug().Msg("Creating bot device with MSC4190") + err = helper.client.CreateDeviceMSC4190(ctx, storedDeviceID, helper.LoginAs.InitialDeviceDisplayName) + if err != nil { + return fmt.Errorf("failed to create device for bot: %w", err) + } + rawCryptoStore.DeviceID = helper.client.DeviceID + } else if helper.LoginAs != nil && helper.LoginAs.Type == mautrix.AuthTypeAppservice && helper.client.SetAppServiceDeviceID { if storedDeviceID == "" { helper.log.Debug(). Str("username", helper.LoginAs.Identifier.User). From f915ba26710a6da1be0ef37e398afcde4c5d9b70 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 29 Jan 2025 21:48:36 +0200 Subject: [PATCH 1015/1647] client: add wrapper for MSC4194 --- client.go | 8 ++++++++ responses.go | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/client.go b/client.go index 495dad3b..039fc6da 100644 --- a/client.go +++ b/client.go @@ -1277,6 +1277,14 @@ func (cli *Client) RedactEvent(ctx context.Context, roomID id.RoomID, eventID id return } +func (cli *Client) UnstableRedactUserEvents(ctx context.Context, roomID id.RoomID, userID id.UserID, limit int) (resp *RespRedactUserEvents, err error) { + urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "org.matrix.msc4194", "rooms", roomID, "redact", "user", userID}, map[string]string{ + "limit": strconv.Itoa(limit), + }) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + // CreateRoom creates a new Matrix room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom // // resp, err := cli.CreateRoom(&mautrix.ReqCreateRoom{ diff --git a/responses.go b/responses.go index dd52b1e7..f4ab024a 100644 --- a/responses.go +++ b/responses.go @@ -100,6 +100,14 @@ type RespSendEvent struct { EventID id.EventID `json:"event_id"` } +type RespRedactUserEvents struct { + IsMoreEvents bool `json:"is_more_events"` + RedactedEvents struct { + Total int `json:"total"` + SoftFailed int `json:"soft_failed"` + } `json:"redacted_events"` +} + // RespMediaConfig is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixmediav3config type RespMediaConfig struct { UploadSize int64 `json:"m.upload.size,omitempty"` From 990519c29f47a8bf9404b69fcbd3ac17ab05be81 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 29 Jan 2025 21:49:37 +0200 Subject: [PATCH 1016/1647] versions: add constant for MSC4194 feature flag --- versions.go | 1 + 1 file changed, 1 insertion(+) diff --git a/versions.go b/versions.go index a8728c34..183bc9ad 100644 --- a/versions.go +++ b/versions.go @@ -64,6 +64,7 @@ var ( FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} FeatureMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} + FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"} BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} From 642e17f2aecb48fcbc93ee21d2e6575c3e46f025 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 29 Jan 2025 21:52:05 +0200 Subject: [PATCH 1017/1647] client: add request body for user redact --- client.go | 15 ++++++++++----- requests.go | 5 +++++ 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index 039fc6da..ae37798e 100644 --- a/client.go +++ b/client.go @@ -1277,11 +1277,16 @@ func (cli *Client) RedactEvent(ctx context.Context, roomID id.RoomID, eventID id return } -func (cli *Client) UnstableRedactUserEvents(ctx context.Context, roomID id.RoomID, userID id.UserID, limit int) (resp *RespRedactUserEvents, err error) { - urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "org.matrix.msc4194", "rooms", roomID, "redact", "user", userID}, map[string]string{ - "limit": strconv.Itoa(limit), - }) - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) +func (cli *Client) UnstableRedactUserEvents(ctx context.Context, roomID id.RoomID, userID id.UserID, req *ReqRedactUser) (resp *RespRedactUserEvents, err error) { + if req == nil { + req = &ReqRedactUser{} + } + query := map[string]string{} + if req.Limit > 0 { + query["limit"] = strconv.Itoa(req.Limit) + } + urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "org.matrix.msc4194", "rooms", roomID, "redact", "user", userID}, query) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) return } diff --git a/requests.go b/requests.go index a796e653..9788aec7 100644 --- a/requests.go +++ b/requests.go @@ -138,6 +138,11 @@ type ReqRedact struct { Extra map[string]interface{} } +type ReqRedactUser struct { + Reason string `json:"reason"` + Limit int `json:"-"` +} + type ReqMembers struct { At string `json:"at"` Membership event.Membership `json:"membership,omitempty"` From cf1004159875759d53d77bcede1c940af1de5853 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 3 Feb 2025 17:33:25 +0200 Subject: [PATCH 1018/1647] bridgev2/portal: fix handling edits if max age is undefined --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index def3fb11..63874333 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1101,7 +1101,7 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o log.Warn().Msg("Edit target message not found in database") portal.sendErrorStatus(ctx, evt, fmt.Errorf("edit %w", ErrTargetMessageNotFound)) return - } else if caps.EditMaxAge.Duration > 0 && time.Since(editTarget.Timestamp) > caps.EditMaxAge.Duration { + } else if caps.EditMaxAge != nil && caps.EditMaxAge.Duration > 0 && time.Since(editTarget.Timestamp) > caps.EditMaxAge.Duration { portal.sendErrorStatus(ctx, evt, ErrEditTargetTooOld) return } else if caps.EditMaxCount > 0 && editTarget.EditCount >= caps.EditMaxCount { From 475c4bf39d91a63da2233a1795bd7fcae6ae344c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 4 Feb 2025 00:05:23 +0200 Subject: [PATCH 1019/1647] crypto: fix key exports --- crypto/keyexport.go | 93 +++++++++++++++++++++++++++------------- crypto/keyexport_test.go | 35 +++++++++++++++ crypto/keyimport.go | 4 ++ crypto/sessions.go | 17 ++++++++ go.mod | 2 +- go.sum | 4 +- 6 files changed, 122 insertions(+), 33 deletions(-) create mode 100644 crypto/keyexport_test.go diff --git a/crypto/keyexport.go b/crypto/keyexport.go index 3d126db4..1904c8a5 100644 --- a/crypto/keyexport.go +++ b/crypto/keyexport.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -16,15 +16,21 @@ import ( "encoding/base64" "encoding/binary" "encoding/json" + "errors" "fmt" "math" + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exbytes" + "go.mau.fi/util/exerrors" "go.mau.fi/util/random" "golang.org/x/crypto/pbkdf2" "maunium.net/go/mautrix/id" ) +var ErrNoSessionsForExport = errors.New("no sessions provided for export") + type SenderClaimedKeys struct { Ed25519 id.Ed25519 `json:"ed25519"` } @@ -78,22 +84,14 @@ func makeExportKeys(passphrase string) (encryptionKey, hashKey, salt, iv []byte) return } -func exportSessions(sessions []*InboundGroupSession) ([]ExportedSession, error) { - export := make([]ExportedSession, len(sessions)) +func exportSessions(sessions []*InboundGroupSession) ([]*ExportedSession, error) { + export := make([]*ExportedSession, len(sessions)) + var err error for i, session := range sessions { - key, err := session.Internal.Export(session.Internal.FirstKnownIndex()) + export[i], err = session.export() if err != nil { return nil, fmt.Errorf("failed to export session: %w", err) } - export[i] = ExportedSession{ - Algorithm: id.AlgorithmMegolmV1, - ForwardingChains: session.ForwardingChains, - RoomID: session.RoomID, - SenderKey: session.SenderKey, - SenderClaimedKeys: SenderClaimedKeys{}, - SessionID: session.ID(), - SessionKey: string(key), - } } return export, nil } @@ -107,38 +105,73 @@ func exportSessionsJSON(sessions []*InboundGroupSession) ([]byte, error) { } func formatKeyExportData(data []byte) []byte { - base64Data := make([]byte, base64.StdEncoding.EncodedLen(len(data))) - base64.StdEncoding.Encode(base64Data, data) - - // Prefix + data and newline for each 76 characters of data + suffix + encodedLen := base64.StdEncoding.EncodedLen(len(data)) outputLength := len(exportPrefix) + - len(base64Data) + int(math.Ceil(float64(len(base64Data))/exportLineLengthLimit)) + + encodedLen + int(math.Ceil(float64(encodedLen)/exportLineLengthLimit)) + len(exportSuffix) + output := make([]byte, 0, outputLength) + outputWriter := (*exbytes.Writer)(&output) + base64Writer := base64.NewEncoder(base64.StdEncoding, outputWriter) + lineByteCount := base64.StdEncoding.DecodedLen(exportLineLengthLimit) + exerrors.Must(outputWriter.WriteString(exportPrefix)) + for i := 0; i < len(data); i += lineByteCount { + exerrors.Must(base64Writer.Write(data[i:min(i+lineByteCount, len(data))])) + if i+lineByteCount >= len(data) { + exerrors.PanicIfNotNil(base64Writer.Close()) + } + exerrors.PanicIfNotNil(outputWriter.WriteByte('\n')) + } + exerrors.Must(outputWriter.WriteString(exportSuffix)) + if len(output) != outputLength { + panic(fmt.Errorf("unexpected length %d / %d", len(output), outputLength)) + } + return output +} - var buf bytes.Buffer - buf.Grow(outputLength) - buf.WriteString(exportPrefix) - for ptr := 0; ptr < len(base64Data); ptr += exportLineLengthLimit { - buf.Write(base64Data[ptr:min(ptr+exportLineLengthLimit, len(base64Data))]) - buf.WriteRune('\n') +func ExportKeysIter(passphrase string, sessions dbutil.RowIter[*InboundGroupSession]) ([]byte, error) { + buf := bytes.NewBuffer(make([]byte, 0, 50*1024)) + enc := json.NewEncoder(buf) + buf.WriteByte('[') + err := sessions.Iter(func(session *InboundGroupSession) (bool, error) { + exported, err := session.export() + if err != nil { + return false, err + } + err = enc.Encode(exported) + if err != nil { + return false, err + } + buf.WriteByte(',') + return true, nil + }) + if err != nil { + return nil, err } - buf.WriteString(exportSuffix) - if buf.Len() != buf.Cap() || buf.Len() != outputLength { - panic(fmt.Errorf("unexpected length %d / %d / %d", buf.Len(), buf.Cap(), outputLength)) + output := buf.Bytes() + if len(output) == 1 { + return nil, ErrNoSessionsForExport } - return buf.Bytes() + output[len(output)-1] = ']' // Replace the last comma with a closing bracket + return EncryptKeyExport(passphrase, output) } // ExportKeys exports the given Megolm sessions with the format specified in the Matrix spec. // See https://spec.matrix.org/v1.2/client-server-api/#key-exports func ExportKeys(passphrase string, sessions []*InboundGroupSession) ([]byte, error) { - // Make all the keys necessary for exporting - encryptionKey, hashKey, salt, iv := makeExportKeys(passphrase) + if len(sessions) == 0 { + return nil, ErrNoSessionsForExport + } // Export all the given sessions and put them in JSON unencryptedData, err := exportSessionsJSON(sessions) if err != nil { return nil, err } + return EncryptKeyExport(passphrase, unencryptedData) +} + +func EncryptKeyExport(passphrase string, unencryptedData json.RawMessage) ([]byte, error) { + // Make all the keys necessary for exporting + encryptionKey, hashKey, salt, iv := makeExportKeys(passphrase) // The export data consists of: // 1 byte of export format version diff --git a/crypto/keyexport_test.go b/crypto/keyexport_test.go new file mode 100644 index 00000000..15d944d5 --- /dev/null +++ b/crypto/keyexport_test.go @@ -0,0 +1,35 @@ +// Copyright (c) 2024 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package crypto_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exfmt" + + "maunium.net/go/mautrix/crypto" + "maunium.net/go/mautrix/crypto/olm" +) + +func TestExportKeys(t *testing.T) { + acc := crypto.NewOlmAccount() + sess := exerrors.Must(crypto.NewInboundGroupSession( + acc.IdentityKey(), + acc.SigningKey(), + "!room:example.com", + exerrors.Must(olm.NewOutboundGroupSession()).Key(), + 7*exfmt.Day, + 100, + false, + )) + data, err := crypto.ExportKeys("meow", []*crypto.InboundGroupSession{sess}) + assert.NoError(t, err) + assert.Len(t, data, 840) +} diff --git a/crypto/keyimport.go b/crypto/keyimport.go index 108c67ac..1dc7f6cc 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -36,6 +36,10 @@ var ( var exportPrefixBytes, exportSuffixBytes = []byte(exportPrefix), []byte(exportSuffix) func decodeKeyExport(data []byte) ([]byte, error) { + // Fix some types of corruption in the key export file before checking anything + if bytes.IndexByte(data, '\r') != -1 { + data = bytes.ReplaceAll(data, []byte{'\r', '\n'}, []byte{'\n'}) + } // If the valid prefix and suffix aren't there, it's probably not a Matrix key export if !bytes.HasPrefix(data, exportPrefixBytes) { return nil, ErrMissingExportPrefix diff --git a/crypto/sessions.go b/crypto/sessions.go index c22b5b58..457a0a43 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -8,6 +8,7 @@ package crypto import ( "errors" + "fmt" "time" "maunium.net/go/mautrix/crypto/olm" @@ -152,6 +153,22 @@ func (igs *InboundGroupSession) RatchetTo(index uint32) error { return nil } +func (igs *InboundGroupSession) export() (*ExportedSession, error) { + key, err := igs.Internal.Export(igs.Internal.FirstKnownIndex()) + if err != nil { + return nil, fmt.Errorf("failed to export session: %w", err) + } + return &ExportedSession{ + Algorithm: id.AlgorithmMegolmV1, + ForwardingChains: igs.ForwardingChains, + RoomID: igs.RoomID, + SenderKey: igs.SenderKey, + SenderClaimedKeys: SenderClaimedKeys{}, + SessionID: igs.ID(), + SessionKey: string(key), + }, nil +} + type OGSState int const ( diff --git a/go.mod b/go.mod index a17fe368..e34ef036 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.5-0.20250129121406-18c356e558b8 + go.mau.fi/util v0.8.5-0.20250203220331-1c0d19ea6003 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.32.0 golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c diff --git a/go.sum b/go.sum index b38f317a..e04685a8 100644 --- a/go.sum +++ b/go.sum @@ -54,8 +54,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.5-0.20250129121406-18c356e558b8 h1:O1cRlXPahwbu1ckIf8XgUP3gHMJlSqJxaVTqwRlVK4s= -go.mau.fi/util v0.8.5-0.20250129121406-18c356e558b8/go.mod h1:MOfGTs1CBuK6ERTcSL4lb5YU7/ujz09eOPVEDckuazY= +go.mau.fi/util v0.8.5-0.20250203220331-1c0d19ea6003 h1:ye5l+QpYW5CpGVMedb3EHlmflGMQsMtw8mC4K/U8hIw= +go.mau.fi/util v0.8.5-0.20250203220331-1c0d19ea6003/go.mod h1:MOfGTs1CBuK6ERTcSL4lb5YU7/ujz09eOPVEDckuazY= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= From 890db20d8ebce9ffc37226722996370055094f0f Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 5 Feb 2025 12:22:02 -0700 Subject: [PATCH 1020/1647] verificationhelper: don't request QR scan if not enabled Signed-off-by: Sumner Evans --- crypto/verificationhelper/callbacks_test.go | 69 ++++++++++++++++--- .../verificationhelper/verificationhelper.go | 31 +++++---- .../verificationhelper_test.go | 19 ++--- 3 files changed, 86 insertions(+), 33 deletions(-) diff --git a/crypto/verificationhelper/callbacks_test.go b/crypto/verificationhelper/callbacks_test.go index 466a60fc..b5ca9af8 100644 --- a/crypto/verificationhelper/callbacks_test.go +++ b/crypto/verificationhelper/callbacks_test.go @@ -32,6 +32,8 @@ type baseVerificationCallbacks struct { decimalsShown map[id.VerificationTransactionID][]int } +var _ verificationhelper.RequiredCallbacks = (*baseVerificationCallbacks)(nil) + func newBaseVerificationCallbacks() *baseVerificationCallbacks { return &baseVerificationCallbacks{ verificationsRequested: map[id.UserID][]id.VerificationTransactionID{}, @@ -98,6 +100,8 @@ type sasVerificationCallbacks struct { *baseVerificationCallbacks } +var _ verificationhelper.ShowSASCallbacks = (*sasVerificationCallbacks)(nil) + func newSASVerificationCallbacks() *sasVerificationCallbacks { return &sasVerificationCallbacks{newBaseVerificationCallbacks()} } @@ -112,34 +116,76 @@ func (c *sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.Verific c.decimalsShown[txnID] = decimals } -type qrCodeVerificationCallbacks struct { +type scanQRCodeVerificationCallbacks struct { *baseVerificationCallbacks } -func newQRCodeVerificationCallbacks() *qrCodeVerificationCallbacks { - return &qrCodeVerificationCallbacks{newBaseVerificationCallbacks()} +var _ verificationhelper.ScanQRCodeCallbacks = (*scanQRCodeVerificationCallbacks)(nil) + +func newScanQRCodeVerificationCallbacks() *scanQRCodeVerificationCallbacks { + return &scanQRCodeVerificationCallbacks{newBaseVerificationCallbacks()} } -func newQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *qrCodeVerificationCallbacks { - return &qrCodeVerificationCallbacks{base} +func newScanQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *scanQRCodeVerificationCallbacks { + return &scanQRCodeVerificationCallbacks{base} } - -func (c *qrCodeVerificationCallbacks) ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) { +func (c *scanQRCodeVerificationCallbacks) ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) { c.scanQRCodeTransactions = append(c.scanQRCodeTransactions, txnID) } -func (c *qrCodeVerificationCallbacks) ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *verificationhelper.QRCode) { +type showQRCodeVerificationCallbacks struct { + *baseVerificationCallbacks +} + +var _ verificationhelper.ShowQRCodeCallbacks = (*showQRCodeVerificationCallbacks)(nil) + +func newShowQRCodeVerificationCallbacks() *showQRCodeVerificationCallbacks { + return &showQRCodeVerificationCallbacks{newBaseVerificationCallbacks()} +} + +func newShowQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *showQRCodeVerificationCallbacks { + return &showQRCodeVerificationCallbacks{base} +} + +func (c *showQRCodeVerificationCallbacks) ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *verificationhelper.QRCode) { c.qrCodesShown[txnID] = qrCode } -func (c *qrCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) { +func (c *showQRCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) { c.qrCodesScanned[txnID] = struct{}{} } +type showAndScanQRCodeVerificationCallbacks struct { + *baseVerificationCallbacks + *showQRCodeVerificationCallbacks + *scanQRCodeVerificationCallbacks +} + +var _ verificationhelper.ScanQRCodeCallbacks = (*showAndScanQRCodeVerificationCallbacks)(nil) +var _ verificationhelper.ShowQRCodeCallbacks = (*showAndScanQRCodeVerificationCallbacks)(nil) + +func newShowAndScanQRCodeVerificationCallbacks() *showAndScanQRCodeVerificationCallbacks { + base := newBaseVerificationCallbacks() + return &showAndScanQRCodeVerificationCallbacks{ + base, + newShowQRCodeVerificationCallbacks(), + newScanQRCodeVerificationCallbacks(), + } +} + +func newShowAndScanQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *showAndScanQRCodeVerificationCallbacks { + return &showAndScanQRCodeVerificationCallbacks{ + base, + newShowQRCodeVerificationCallbacks(), + newScanQRCodeVerificationCallbacks(), + } +} + type allVerificationCallbacks struct { *baseVerificationCallbacks *sasVerificationCallbacks - *qrCodeVerificationCallbacks + *scanQRCodeVerificationCallbacks + *showQRCodeVerificationCallbacks } func newAllVerificationCallbacks() *allVerificationCallbacks { @@ -147,6 +193,7 @@ func newAllVerificationCallbacks() *allVerificationCallbacks { return &allVerificationCallbacks{ base, newSASVerificationCallbacksWithBase(base), - newQRCodeVerificationCallbacksWithBase(base), + newScanQRCodeVerificationCallbacksWithBase(base), + newShowQRCodeVerificationCallbacksWithBase(base), } } diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index be547e7e..92d4de23 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -15,6 +15,7 @@ import ( "time" "github.com/rs/zerolog" + "go.mau.fi/util/exslices" "go.mau.fi/util/jsontime" "golang.org/x/exp/maps" "golang.org/x/exp/slices" @@ -47,12 +48,14 @@ type ShowSASCallbacks interface { ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int) } -type ShowQRCodeCallbacks interface { +type ScanQRCodeCallbacks interface { // ScanQRCode is called when another device has sent a // m.key.verification.ready event and indicated that they are capable of // showing a QR code. ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) +} +type ShowQRCodeCallbacks interface { // ShowQRCode is called when the verification has been accepted and a QR // code should be shown to the user. ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode) @@ -108,24 +111,22 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, stor helper.verificationDone = c.VerificationDone } - supportedMethods := map[event.VerificationMethod]struct{}{} if c, ok := callbacks.(ShowSASCallbacks); ok { - supportedMethods[event.VerificationMethodSAS] = struct{}{} + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS) helper.showSAS = c.ShowSAS } if c, ok := callbacks.(ShowQRCodeCallbacks); ok { - supportedMethods[event.VerificationMethodQRCodeShow] = struct{}{} - supportedMethods[event.VerificationMethodReciprocate] = struct{}{} - helper.scanQRCode = c.ScanQRCode + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeShow) + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate) helper.showQRCode = c.ShowQRCode helper.qrCodeScaned = c.QRCodeScanned } - if supportsScan { - supportedMethods[event.VerificationMethodQRCodeScan] = struct{}{} - supportedMethods[event.VerificationMethodReciprocate] = struct{}{} + if c, ok := callbacks.(ScanQRCodeCallbacks); ok && supportsScan { + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeScan) + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate) + helper.scanQRCode = c.ScanQRCode } - - helper.supportedMethods = maps.Keys(supportedMethods) + helper.supportedMethods = exslices.DeduplicateUnsorted(helper.supportedMethods) return &helper } @@ -420,7 +421,9 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V } txn.VerificationState = VerificationStateReady - if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { + if vh.scanQRCode != nil && + slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && // technically redundant because vh.scanQRCode is only set if this is true + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { vh.scanQRCode(ctx, txn.TransactionID) } @@ -734,7 +737,9 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif } } - if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { + if vh.scanQRCode != nil && + slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && // technically redundant because vh.scanQRCode is only set if this is true + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { vh.scanQRCode(ctx, txn.TransactionID) } diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index 49c8db07..31bc7d6e 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -95,11 +95,12 @@ func TestVerification_Start(t *testing.T) { expectedVerificationMethods []event.VerificationMethod }{ {false, newBaseVerificationCallbacks(), "no supported verification methods", nil}, - {true, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, {false, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, - {true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, - {true, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {false, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {true, newScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {false, newScanQRCodeVerificationCallbacks(), "no supported verification methods", nil}, + {false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {true, newShowAndScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, newShowAndScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, {false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, {true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, } @@ -124,7 +125,7 @@ func TestVerification_Start(t *testing.T) { return } - assert.NoError(t, err) + require.NoError(t, err) assert.NotEmpty(t, txnID) toDeviceInbox := ts.DeviceInbox[aliceUserID] @@ -283,8 +284,8 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { expectedVerificationMethods []event.VerificationMethod }{ {false, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, - {true, false, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {false, true, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {true, false, newShowAndScanQRCodeVerificationCallbacks(), newShowAndScanQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, true, newShowAndScanQRCodeVerificationCallbacks(), newShowAndScanQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, {true, false, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, {true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, } @@ -321,10 +322,10 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - _, sendingIsQRCallbacks := tc.sendingCallbacks.(*qrCodeVerificationCallbacks) + _, sendingIsQRCallbacks := tc.sendingCallbacks.(*showQRCodeVerificationCallbacks) _, sendingIsAllCallbacks := tc.sendingCallbacks.(*allVerificationCallbacks) sendingCanShowQR := sendingIsQRCallbacks || sendingIsAllCallbacks - _, receivingIsQRCallbacks := tc.receivingCallbacks.(*qrCodeVerificationCallbacks) + _, receivingIsQRCallbacks := tc.receivingCallbacks.(*showQRCodeVerificationCallbacks) _, receivingIsAllCallbacks := tc.receivingCallbacks.(*allVerificationCallbacks) receivingCanShowQR := receivingIsQRCallbacks || receivingIsAllCallbacks From 4c652f52008006368e27d291f3850477d4efe365 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 6 Feb 2025 15:02:35 +0200 Subject: [PATCH 1021/1647] bridgev2: add FormattedTitle to direct notification data --- bridgev2/matrixinterface.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 1b6477e3..2d6ed982 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -93,6 +93,7 @@ type DirectNotificationData struct { Message string FormattedNotification string + FormattedTitle string } type MatrixConnectorWithNotifications interface { From 29319ccfd515102659fa20136e09b30562057f71 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 8 Feb 2025 16:17:54 +0200 Subject: [PATCH 1022/1647] pushrules: fix word boundary matching and case sensitivity --- go.mod | 2 +- go.sum | 4 ++-- pushrules/rule.go | 20 +++++++++++++++----- pushrules/rule_test.go | 28 ++++++++++++++++++++++++++++ 4 files changed, 46 insertions(+), 8 deletions(-) diff --git a/go.mod b/go.mod index e34ef036..e8e71d54 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.5-0.20250203220331-1c0d19ea6003 + go.mau.fi/util v0.8.5-0.20250208141401-fde0c0c733f1 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.32.0 golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c diff --git a/go.sum b/go.sum index e04685a8..3f58ff96 100644 --- a/go.sum +++ b/go.sum @@ -54,8 +54,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.5-0.20250203220331-1c0d19ea6003 h1:ye5l+QpYW5CpGVMedb3EHlmflGMQsMtw8mC4K/U8hIw= -go.mau.fi/util v0.8.5-0.20250203220331-1c0d19ea6003/go.mod h1:MOfGTs1CBuK6ERTcSL4lb5YU7/ujz09eOPVEDckuazY= +go.mau.fi/util v0.8.5-0.20250208141401-fde0c0c733f1 h1:XQ47o9cbYOCtohOkxXIzIM3xnSsR/8lggdgLEZm8PHU= +go.mau.fi/util v0.8.5-0.20250208141401-fde0c0c733f1/go.mod h1:MOfGTs1CBuK6ERTcSL4lb5YU7/ujz09eOPVEDckuazY= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= diff --git a/pushrules/rule.go b/pushrules/rule.go index ee6d33c4..cf659695 100644 --- a/pushrules/rule.go +++ b/pushrules/rule.go @@ -8,7 +8,10 @@ package pushrules import ( "encoding/gob" + "regexp" + "strings" + "go.mau.fi/util/exerrors" "go.mau.fi/util/glob" "maunium.net/go/mautrix/event" @@ -165,13 +168,20 @@ func (rule *PushRule) matchConditions(room Room, evt *event.Event) bool { } func (rule *PushRule) matchPattern(room Room, evt *event.Event) bool { - pattern := glob.CompileWithImplicitContains(rule.Pattern) - if pattern == nil { - return false - } msg, ok := evt.Content.Raw["body"].(string) if !ok { return false } - return pattern.Match(msg) + var buf strings.Builder + // As per https://spec.matrix.org/unstable/client-server-api/#push-rules, content rules are case-insensitive + // and must match whole words, so wrap the converted glob in (?i) and \b. + buf.WriteString(`(?i)\b`) + // strings.Builder will never return errors + exerrors.PanicIfNotNil(glob.ToRegexPattern(rule.Pattern, &buf)) + buf.WriteString(`\b`) + pattern, err := regexp.Compile(buf.String()) + if err != nil { + return false + } + return pattern.MatchString(msg) } diff --git a/pushrules/rule_test.go b/pushrules/rule_test.go index 803c721e..7ff839a7 100644 --- a/pushrules/rule_test.go +++ b/pushrules/rule_test.go @@ -186,6 +186,34 @@ func TestPushRule_Match_Content(t *testing.T) { assert.True(t, rule.Match(blankTestRoom, evt)) } +func TestPushRule_Match_WordBoundary(t *testing.T) { + rule := &pushrules.PushRule{ + Type: pushrules.ContentRule, + Enabled: true, + Pattern: "test", + } + + evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{ + MsgType: event.MsgEmote, + Body: "is testing pushrules", + }) + assert.False(t, rule.Match(blankTestRoom, evt)) +} + +func TestPushRule_Match_CaseInsensitive(t *testing.T) { + rule := &pushrules.PushRule{ + Type: pushrules.ContentRule, + Enabled: true, + Pattern: "test", + } + + evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{ + MsgType: event.MsgEmote, + Body: "is TeSt-InG pushrules", + }) + assert.True(t, rule.Match(blankTestRoom, evt)) +} + func TestPushRule_Match_Content_Fail(t *testing.T) { rule := &pushrules.PushRule{ Type: pushrules.ContentRule, From aaad5119e01e190031a7cef8e63f6a729504cd04 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Feb 2025 13:44:24 +0200 Subject: [PATCH 1023/1647] dependencies: update go --- .github/workflows/go.yml | 10 +++++----- go.mod | 6 +++--- go.sum | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 10025368..71c1988b 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -15,7 +15,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.23" + go-version: "1.24" cache: true - name: Install libolm @@ -34,8 +34,8 @@ jobs: strategy: fail-fast: false matrix: - go-version: ["1.22", "1.23"] - name: Build (${{ matrix.go-version == '1.23' && 'latest' || 'old' }}, libolm) + go-version: ["1.23", "1.24"] + name: Build (${{ matrix.go-version == '1.24' && 'latest' || 'old' }}, libolm) steps: - uses: actions/checkout@v4 @@ -65,8 +65,8 @@ jobs: strategy: fail-fast: false matrix: - go-version: ["1.22", "1.23"] - name: Build (${{ matrix.go-version == '1.23' && 'latest' || 'old' }}, goolm) + go-version: ["1.23", "1.24"] + name: Build (${{ matrix.go-version == '1.24' && 'latest' || 'old' }}, goolm) steps: - uses: actions/checkout@v4 diff --git a/go.mod b/go.mod index e8e71d54..69752cdb 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module maunium.net/go/mautrix -go 1.22.0 +go 1.23.0 -toolchain go1.23.5 +toolchain go1.24.0 require ( filippo.io/edwards25519 v1.1.0 @@ -18,7 +18,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.5-0.20250208141401-fde0c0c733f1 + go.mau.fi/util v0.8.5-0.20250212114338-06310c7133a5 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.32.0 golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c diff --git a/go.sum b/go.sum index 3f58ff96..73a5e58a 100644 --- a/go.sum +++ b/go.sum @@ -54,8 +54,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.5-0.20250208141401-fde0c0c733f1 h1:XQ47o9cbYOCtohOkxXIzIM3xnSsR/8lggdgLEZm8PHU= -go.mau.fi/util v0.8.5-0.20250208141401-fde0c0c733f1/go.mod h1:MOfGTs1CBuK6ERTcSL4lb5YU7/ujz09eOPVEDckuazY= +go.mau.fi/util v0.8.5-0.20250212114338-06310c7133a5 h1:b8XKQEONXqnawxcKVdu2HqjmFSa+6tAeHyCI0E/JM2g= +go.mau.fi/util v0.8.5-0.20250212114338-06310c7133a5/go.mod h1:rleUeT8LjZQBUDg/FKAdX/xbd/Vnuof/UoCcnGvGm9M= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= From 100d945d3977a3261e01090e245c9ffa3bd47af2 Mon Sep 17 00:00:00 2001 From: Brad Murray Date: Wed, 12 Feb 2025 16:58:04 -0500 Subject: [PATCH 1024/1647] Trust key backups if the public key matches (#351) --- crypto/keybackup.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 4e9431bb..fe0b40dc 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -2,6 +2,7 @@ package crypto import ( "context" + "encoding/base64" "fmt" "time" @@ -21,7 +22,7 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg ctx = log.WithContext(ctx) - versionInfo, err := mach.GetAndVerifyLatestKeyBackupVersion(ctx) + versionInfo, err := mach.GetAndVerifyLatestKeyBackupVersion(ctx, megolmBackupKey) if err != nil { return "", err } else if versionInfo == nil { @@ -32,7 +33,7 @@ func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, meg return versionInfo.Version, err } -func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) (*mautrix.RespRoomKeysVersion[backup.MegolmAuthData], error) { +func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context, megolmBackupKey *backup.MegolmBackupKey) (*mautrix.RespRoomKeysVersion[backup.MegolmAuthData], error) { versionInfo, err := mach.Client.GetKeyBackupLatestVersion(ctx) if err != nil { return nil, err @@ -48,6 +49,17 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) Stringer("key_backup_version", versionInfo.Version). Logger() + // https://spec.matrix.org/v1.10/client-server-api/#server-side-key-backups + // "Clients must only store keys in backups after they have ensured that the auth_data is trusted. This can be done either... + // ...by deriving the public key from a private key that it obtained from a trusted source. Trusted sources for the private + // key include the user entering the key, retrieving the key stored in secret storage, or obtaining the key via secret sharing + // from a verified device belonging to the same user." + if megolmBackupKey != nil && versionInfo.AuthData.PublicKey == id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes())) { + log.Debug().Msg("key backup is trusted based on public key") + return versionInfo, nil + } + + // "...or checking that it is signed by the user’s master cross-signing key or by a verified device belonging to the same user" userSignatures, ok := versionInfo.AuthData.Signatures[mach.Client.UserID] if !ok { return nil, fmt.Errorf("no signature from user %s found in key backup", mach.Client.UserID) @@ -87,6 +99,7 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) continue } else { // One of the signatures is valid, break from the loop. + log.Debug().Stringer("key_id", keyID).Msg("key backup is trusted based on matching signature") signatureVerified = true break } From 041784441f73ef5187338894cbc09e31684944d7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 13 Feb 2025 14:07:31 +0200 Subject: [PATCH 1025/1647] crypto: add context to IsDeviceTrusted and deprecate ResolveTrust --- bridge/crypto.go | 2 +- bridgev2/matrix/crypto.go | 2 +- crypto/cross_sign_test.go | 10 +++++----- crypto/cross_sign_validation.go | 11 +++++++++-- crypto/encryptmegolm.go | 2 +- crypto/keybackup.go | 2 +- crypto/keysharing.go | 2 +- 7 files changed, 19 insertions(+), 12 deletions(-) diff --git a/bridge/crypto.go b/bridge/crypto.go index de1aebbc..e3885a22 100644 --- a/bridge/crypto.go +++ b/bridge/crypto.go @@ -193,7 +193,7 @@ func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device return &crypto.KeyShareRejectNoResponse } else if device.Trust == id.TrustStateBlacklisted { return &crypto.KeyShareRejectBlacklisted - } else if trustState := helper.mach.ResolveTrust(device); trustState >= cfg.VerificationLevels.Share { + } else if trustState, _ := helper.mach.ResolveTrustContext(ctx, device); trustState >= cfg.VerificationLevels.Share { portal := helper.bridge.Child.GetIPortal(info.RoomID) if portal == nil { zerolog.Ctx(ctx).Debug().Msg("Rejecting key request: room is not a portal") diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index be5e196e..6e6416a9 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -199,7 +199,7 @@ func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device return &crypto.KeyShareRejectNoResponse } else if device.Trust == id.TrustStateBlacklisted { return &crypto.KeyShareRejectBlacklisted - } else if trustState := helper.mach.ResolveTrust(device); trustState >= cfg.VerificationLevels.Share { + } else if trustState, _ := helper.mach.ResolveTrustContext(ctx, device); trustState >= cfg.VerificationLevels.Share { portal, err := helper.bridge.Bridge.GetPortalByMXID(ctx, info.RoomID) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal to handle key request") diff --git a/crypto/cross_sign_test.go b/crypto/cross_sign_test.go index e11fb018..5e1ffd50 100644 --- a/crypto/cross_sign_test.go +++ b/crypto/cross_sign_test.go @@ -66,7 +66,7 @@ func TestTrustOwnDevice(t *testing.T) { DeviceID: "device", SigningKey: id.Ed25519("deviceKey"), } - if m.IsDeviceTrusted(ownDevice) { + if m.IsDeviceTrusted(context.TODO(), ownDevice) { t.Error("Own device trusted while it shouldn't be") } @@ -78,7 +78,7 @@ func TestTrustOwnDevice(t *testing.T) { if trusted, _ := m.IsUserTrusted(context.TODO(), ownDevice.UserID); !trusted { t.Error("Own user not trusted while they should be") } - if !m.IsDeviceTrusted(ownDevice) { + if !m.IsDeviceTrusted(context.TODO(), ownDevice) { t.Error("Own device not trusted while it should be") } } @@ -123,7 +123,7 @@ func TestTrustOtherDevice(t *testing.T) { if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { t.Error("Other user trusted while they shouldn't be") } - if m.IsDeviceTrusted(theirDevice) { + if m.IsDeviceTrusted(context.TODO(), theirDevice) { t.Error("Other device trusted while it shouldn't be") } @@ -144,14 +144,14 @@ func TestTrustOtherDevice(t *testing.T) { m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey(), otherUser, theirMasterKey.PublicKey(), "sig3") - if m.IsDeviceTrusted(theirDevice) { + if m.IsDeviceTrusted(context.TODO(), theirDevice) { t.Error("Other device trusted before it has been signed with user's SSK") } m.CryptoStore.PutSignature(context.TODO(), otherUser, theirDevice.SigningKey, otherUser, theirSSK.PublicKey(), "sig4") - if !m.IsDeviceTrusted(theirDevice) { + if !m.IsDeviceTrusted(context.TODO(), theirDevice) { t.Error("Other device not trusted while it should be") } } diff --git a/crypto/cross_sign_validation.go b/crypto/cross_sign_validation.go index 04a179df..4cdf0dd5 100644 --- a/crypto/cross_sign_validation.go +++ b/crypto/cross_sign_validation.go @@ -13,6 +13,9 @@ import ( "maunium.net/go/mautrix/id" ) +// ResolveTrust resolves the trust state of the device from cross-signing. +// +// Deprecated: This method doesn't take a context. Use [OlmMachine.ResolveTrustContext] instead. func (mach *OlmMachine) ResolveTrust(device *id.Device) id.TrustState { state, _ := mach.ResolveTrustContext(context.Background(), device) return state @@ -77,8 +80,12 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi } // IsDeviceTrusted returns whether a device has been determined to be trusted either through verification or cross-signing. -func (mach *OlmMachine) IsDeviceTrusted(device *id.Device) bool { - switch mach.ResolveTrust(device) { +// +// Note: this will return false if resolving the trust state fails due to database errors. +// Use [OlmMachine.ResolveTrustContext] if special error handling is required. +func (mach *OlmMachine) IsDeviceTrusted(ctx context.Context, device *id.Device) bool { + trust, _ := mach.ResolveTrustContext(ctx, device) + switch trust { case id.TrustStateVerified, id.TrustStateCrossSignedTOFU, id.TrustStateCrossSignedVerified: return true default: diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index ef5f404f..804e15de 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -417,7 +417,7 @@ func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *Out Reason: "Device is blacklisted", }} session.Users[userKey] = OGSIgnored - } else if trustState := mach.ResolveTrust(device); trustState < mach.SendKeysMinTrust { + } else if trustState, _ := mach.ResolveTrustContext(ctx, device); trustState < mach.SendKeysMinTrust { log.Debug(). Str("min_trust", mach.SendKeysMinTrust.String()). Str("device_trust", trustState.String()). diff --git a/crypto/keybackup.go b/crypto/keybackup.go index fe0b40dc..00f74175 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -86,7 +86,7 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context, } else if device == nil { log.Warn().Err(err).Msg("Device does not exist, ignoring signature") continue - } else if !mach.IsDeviceTrusted(device) { + } else if !mach.IsDeviceTrusted(ctx, device) { log.Warn().Err(err).Msg("Device is not trusted") continue } else { diff --git a/crypto/keysharing.go b/crypto/keysharing.go index 0ccf006a..ea0ae65d 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -275,7 +275,7 @@ func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Dev } else if device.Trust == id.TrustStateBlacklisted { log.Debug().Msg("Rejecting key request from blacklisted device") return &KeyShareRejectBlacklisted - } else if trustState := mach.ResolveTrust(device); trustState >= mach.ShareKeysMinTrust { + } else if trustState, _ := mach.ResolveTrustContext(ctx, device); trustState >= mach.ShareKeysMinTrust { log.Debug(). Str("min_trust", mach.SendKeysMinTrust.String()). Str("device_trust", trustState.String()). From 14008caaa4d2348067efec1e1de0e6d441db4d31 Mon Sep 17 00:00:00 2001 From: Brad Murray Date: Thu, 13 Feb 2025 15:52:34 -0500 Subject: [PATCH 1026/1647] crypto/ssss: only accept secret shares from verified devices (#352) Co-authored-by: Tulir Asokan --- crypto/sharing.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/crypto/sharing.go b/crypto/sharing.go index c0f3e209..10e37ccc 100644 --- a/crypto/sharing.go +++ b/crypto/sharing.go @@ -173,6 +173,19 @@ func (mach *OlmMachine) receiveSecret(ctx context.Context, evt *DecryptedOlmEven return } + // https://spec.matrix.org/v1.10/client-server-api/#msecretsend + // "The recipient must ensure... that the device is a verified device owned by the recipient" + if senderDevice, err := mach.GetOrFetchDevice(ctx, evt.Sender, evt.SenderDevice); err != nil { + log.Err(err).Msg("Failed to get or fetch sender device, rejecting secret") + return + } else if senderDevice == nil { + log.Warn().Msg("Unknown sender device, rejecting secret") + return + } else if !mach.IsDeviceTrusted(ctx, senderDevice) { + log.Warn().Msg("Sender device is not verified, rejecting secret") + return + } + mach.secretLock.Lock() secretChan := mach.secretListeners[content.RequestID] mach.secretLock.Unlock() From 5600dd4054b165f3c3cae299372fa2514d28b1f7 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Tue, 11 Feb 2025 19:17:22 -0700 Subject: [PATCH 1027/1647] verificationhelper: add VerificationReady callback for when verification is accepted This callback supersedes the ScanQRCode and ShowQRCode callbacks. Signed-off-by: Sumner Evans --- crypto/verificationhelper/callbacks_test.go | 66 ++++----------- crypto/verificationhelper/qrcode.go | 4 + crypto/verificationhelper/reciprocate.go | 25 +++--- .../verificationhelper/verificationhelper.go | 81 ++++++++++-------- .../verificationhelper_test.go | 82 +++++++++++-------- 5 files changed, 129 insertions(+), 129 deletions(-) diff --git a/crypto/verificationhelper/callbacks_test.go b/crypto/verificationhelper/callbacks_test.go index b5ca9af8..18cb964f 100644 --- a/crypto/verificationhelper/callbacks_test.go +++ b/crypto/verificationhelper/callbacks_test.go @@ -17,12 +17,14 @@ import ( type MockVerificationCallbacks interface { GetRequestedVerifications() map[id.UserID][]id.VerificationTransactionID GetScanQRCodeTransactions() []id.VerificationTransactionID + GetVerificationsReadyTransactions() []id.VerificationTransactionID GetQRCodeShown(id.VerificationTransactionID) *verificationhelper.QRCode } type baseVerificationCallbacks struct { scanQRCodeTransactions []id.VerificationTransactionID verificationsRequested map[id.UserID][]id.VerificationTransactionID + verificationsReady []id.VerificationTransactionID qrCodesShown map[id.VerificationTransactionID]*verificationhelper.QRCode qrCodesScanned map[id.VerificationTransactionID]struct{} doneTransactions map[id.VerificationTransactionID]struct{} @@ -33,6 +35,7 @@ type baseVerificationCallbacks struct { } var _ verificationhelper.RequiredCallbacks = (*baseVerificationCallbacks)(nil) +var _ MockVerificationCallbacks = (*baseVerificationCallbacks)(nil) func newBaseVerificationCallbacks() *baseVerificationCallbacks { return &baseVerificationCallbacks{ @@ -55,6 +58,10 @@ func (c *baseVerificationCallbacks) GetScanQRCodeTransactions() []id.Verificatio return c.scanQRCodeTransactions } +func (c *baseVerificationCallbacks) GetVerificationsReadyTransactions() []id.VerificationTransactionID { + return c.verificationsReady +} + func (c *baseVerificationCallbacks) GetQRCodeShown(txnID id.VerificationTransactionID) *verificationhelper.QRCode { return c.qrCodesShown[txnID] } @@ -85,6 +92,16 @@ func (c *baseVerificationCallbacks) VerificationRequested(ctx context.Context, t c.verificationsRequested[from] = append(c.verificationsRequested[from], txnID) } +func (c *baseVerificationCallbacks) VerificationReady(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, allowScanQRCode bool, qrCode *verificationhelper.QRCode) { + c.verificationsReady = append(c.verificationsReady, txnID) + if allowScanQRCode { + c.scanQRCodeTransactions = append(c.scanQRCodeTransactions, txnID) + } + if qrCode != nil { + c.qrCodesShown[txnID] = qrCode + } +} + func (c *baseVerificationCallbacks) VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) { c.verificationCancellation[txnID] = &event.VerificationCancelEventContent{ Code: code, @@ -116,23 +133,6 @@ func (c *sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.Verific c.decimalsShown[txnID] = decimals } -type scanQRCodeVerificationCallbacks struct { - *baseVerificationCallbacks -} - -var _ verificationhelper.ScanQRCodeCallbacks = (*scanQRCodeVerificationCallbacks)(nil) - -func newScanQRCodeVerificationCallbacks() *scanQRCodeVerificationCallbacks { - return &scanQRCodeVerificationCallbacks{newBaseVerificationCallbacks()} -} - -func newScanQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *scanQRCodeVerificationCallbacks { - return &scanQRCodeVerificationCallbacks{base} -} -func (c *scanQRCodeVerificationCallbacks) ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) { - c.scanQRCodeTransactions = append(c.scanQRCodeTransactions, txnID) -} - type showQRCodeVerificationCallbacks struct { *baseVerificationCallbacks } @@ -147,44 +147,13 @@ func newShowQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) return &showQRCodeVerificationCallbacks{base} } -func (c *showQRCodeVerificationCallbacks) ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *verificationhelper.QRCode) { - c.qrCodesShown[txnID] = qrCode -} - func (c *showQRCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) { c.qrCodesScanned[txnID] = struct{}{} } -type showAndScanQRCodeVerificationCallbacks struct { - *baseVerificationCallbacks - *showQRCodeVerificationCallbacks - *scanQRCodeVerificationCallbacks -} - -var _ verificationhelper.ScanQRCodeCallbacks = (*showAndScanQRCodeVerificationCallbacks)(nil) -var _ verificationhelper.ShowQRCodeCallbacks = (*showAndScanQRCodeVerificationCallbacks)(nil) - -func newShowAndScanQRCodeVerificationCallbacks() *showAndScanQRCodeVerificationCallbacks { - base := newBaseVerificationCallbacks() - return &showAndScanQRCodeVerificationCallbacks{ - base, - newShowQRCodeVerificationCallbacks(), - newScanQRCodeVerificationCallbacks(), - } -} - -func newShowAndScanQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *showAndScanQRCodeVerificationCallbacks { - return &showAndScanQRCodeVerificationCallbacks{ - base, - newShowQRCodeVerificationCallbacks(), - newScanQRCodeVerificationCallbacks(), - } -} - type allVerificationCallbacks struct { *baseVerificationCallbacks *sasVerificationCallbacks - *scanQRCodeVerificationCallbacks *showQRCodeVerificationCallbacks } @@ -193,7 +162,6 @@ func newAllVerificationCallbacks() *allVerificationCallbacks { return &allVerificationCallbacks{ base, newSASVerificationCallbacksWithBase(base), - newScanQRCodeVerificationCallbacksWithBase(base), newShowQRCodeVerificationCallbacksWithBase(base), } } diff --git a/crypto/verificationhelper/qrcode.go b/crypto/verificationhelper/qrcode.go index a28d8fc3..11698152 100644 --- a/crypto/verificationhelper/qrcode.go +++ b/crypto/verificationhelper/qrcode.go @@ -82,6 +82,10 @@ func NewQRCodeFromBytes(data []byte) (*QRCode, error) { // // [Section 11.12.2.4.1]: https://spec.matrix.org/v1.9/client-server-api/#qr-code-format func (q *QRCode) Bytes() []byte { + if q == nil { + return nil + } + var buf bytes.Buffer buf.WriteString("MATRIX") // Header buf.WriteByte(0x02) // Version diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index 33dccef9..9cb84c24 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -270,28 +270,30 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id return nil } -func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *VerificationTransaction) error { +func (vh *VerificationHelper) generateQRCode(ctx context.Context, txn *VerificationTransaction) (*QRCode, error) { log := vh.getLog(ctx).With(). Str("verification_action", "generate and show QR code"). Stringer("transaction_id", txn.TransactionID). Logger() ctx = log.WithContext(ctx) - if vh.showQRCode == nil { - log.Info().Msg("Ignoring QR code generation request as showing a QR code is not enabled on this device") - return nil + + if !slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) || + !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate) { + log.Info().Msg("Ignoring QR code generation request as reciprocating is not supported by both devices") + return nil, nil } else if !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeScan) { log.Info().Msg("Ignoring QR code generation request as other device cannot scan QR codes") - return nil + return nil, nil } ownCrossSigningPublicKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) if ownCrossSigningPublicKeys == nil || len(ownCrossSigningPublicKeys.MasterKey) == 0 { - return errors.New("failed to get own cross-signing master public key") + return nil, errors.New("failed to get own cross-signing master public key") } ownMasterKeyTrusted, err := vh.mach.CryptoStore.IsKeySignedBy(ctx, vh.client.UserID, ownCrossSigningPublicKeys.MasterKey, vh.client.UserID, vh.mach.OwnIdentity().SigningKey) if err != nil { - return err + return nil, err } mode := QRCodeModeCrossSigning if vh.client.UserID == txn.TheirUserID { @@ -304,7 +306,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *Ve } else { // This is a cross-signing situation. if !ownMasterKeyTrusted { - return errors.New("cannot cross-sign other device when own master key is not trusted") + return nil, errors.New("cannot cross-sign other device when own master key is not trusted") } mode = QRCodeModeCrossSigning } @@ -318,7 +320,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *Ve // Key 2 is the other user's master signing key. theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) if err != nil { - return err + return nil, err } key2 = theirSigningKeys.MasterKey.Bytes() case QRCodeModeSelfVerifyingMasterKeyTrusted: @@ -328,7 +330,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *Ve // Key 2 is the other device's key. theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) if err != nil { - return err + return nil, err } key2 = theirDevice.SigningKey.Bytes() case QRCodeModeSelfVerifyingMasterKeyUntrusted: @@ -343,6 +345,5 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *Ve qrCode := NewQRCode(mode, txn.TransactionID, [32]byte(key1), [32]byte(key2)) txn.QRCodeSharedSecret = qrCode.SharedSecret - vh.showQRCode(ctx, txn.TransactionID, qrCode) - return nil + return qrCode, nil } diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 92d4de23..1fdbcc70 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -33,6 +33,10 @@ type RequiredCallbacks interface { // from another device. VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID) + // VerificationReady is called when a verification request has been + // accepted by both parties. + VerificationReady(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, supportsScanQRCode bool, qrCode *QRCode) + // VerificationCancelled is called when the verification is cancelled. VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) @@ -48,18 +52,7 @@ type ShowSASCallbacks interface { ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int) } -type ScanQRCodeCallbacks interface { - // ScanQRCode is called when another device has sent a - // m.key.verification.ready event and indicated that they are capable of - // showing a QR code. - ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) -} - type ShowQRCodeCallbacks interface { - // ShowQRCode is called when the verification has been accepted and a QR - // code should be shown to the user. - ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode) - // QRCodeScanned is called when the other user has scanned the QR code and // sent the m.key.verification.start event. QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) @@ -71,24 +64,25 @@ type VerificationHelper struct { store VerificationStore activeTransactionsLock sync.Mutex - // activeTransactions map[id.VerificationTransactionID]*verificationTransaction // supportedMethods are the methods that *we* support supportedMethods []event.VerificationMethod verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID) + verificationReady func(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, supportsScanQRCode bool, qrCode *QRCode) verificationCancelledCallback func(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) verificationDone func(ctx context.Context, txnID id.VerificationTransactionID) + // showSAS is a callback that will be called after the SAS verification + // dance is complete and we want the client to show the emojis/decimals showSAS func(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int) - - scanQRCode func(ctx context.Context, txnID id.VerificationTransactionID) - showQRCode func(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode) + // qrCodeScaned is a callback that will be called when the other device + // scanned the QR code we are showing qrCodeScaned func(ctx context.Context, txnID id.VerificationTransactionID) } var _ mautrix.VerificationHelper = (*VerificationHelper)(nil) -func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsScan bool) *VerificationHelper { +func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsQRShow, supportsQRScan bool) *VerificationHelper { if client.Crypto == nil { panic("client.Crypto is nil") } @@ -107,6 +101,7 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, stor panic("callbacks must implement RequiredCallbacks") } else { helper.verificationRequested = c.VerificationRequested + helper.verificationReady = c.VerificationReady helper.verificationCancelledCallback = c.VerificationCancelled helper.verificationDone = c.VerificationDone } @@ -115,16 +110,18 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, stor helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS) helper.showSAS = c.ShowSAS } - if c, ok := callbacks.(ShowQRCodeCallbacks); ok { - helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeShow) - helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate) - helper.showQRCode = c.ShowQRCode - helper.qrCodeScaned = c.QRCodeScanned + if supportsQRShow { + if c, ok := callbacks.(ShowQRCodeCallbacks); !ok { + panic("callbacks must implement ShowQRCodeCallbacks if supportsQRShow is true") + } else { + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeShow) + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate) + helper.qrCodeScaned = c.QRCodeScanned + } } - if c, ok := callbacks.(ScanQRCodeCallbacks); ok && supportsScan { + if supportsQRScan { helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeScan) helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate) - helper.scanQRCode = c.ScanQRCode } helper.supportedMethods = exslices.DeduplicateUnsorted(helper.supportedMethods) return &helper @@ -421,15 +418,19 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V } txn.VerificationState = VerificationStateReady - if vh.scanQRCode != nil && - slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && // technically redundant because vh.scanQRCode is only set if this is true - slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { - vh.scanQRCode(ctx, txn.TransactionID) - } + supportsSAS := slices.Contains(vh.supportedMethods, event.VerificationMethodSAS) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS) + supportsReciprocate := slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate) + supportsScanQRCode := supportsReciprocate && + slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) - if err := vh.generateAndShowQRCode(ctx, &txn); err != nil { + qrCode, err := vh.generateQRCode(ctx, &txn) + if err != nil { return err } + vh.verificationReady(ctx, txn.TransactionID, txn.TheirDeviceID, supportsSAS, supportsScanQRCode, qrCode) return vh.store.SaveVerificationTransaction(ctx, txn) } @@ -737,15 +738,23 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif } } - if vh.scanQRCode != nil && - slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && // technically redundant because vh.scanQRCode is only set if this is true - slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) { - vh.scanQRCode(ctx, txn.TransactionID) + supportsSAS := slices.Contains(vh.supportedMethods, event.VerificationMethodSAS) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS) + supportsReciprocate := slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate) + supportsScanQRCode := supportsReciprocate && + slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && + slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) + + qrCode, err := vh.generateQRCode(ctx, &txn) + if err != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to generate QR code: %w", err) + return } - if err := vh.generateAndShowQRCode(ctx, &txn); err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to generate and show QR code: %w", err) - } else if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { + vh.verificationReady(ctx, txn.TransactionID, txn.TheirDeviceID, supportsSAS, supportsScanQRCode, qrCode) + + if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to save verification transaction: %w", err) } } diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index 31bc7d6e..e192508b 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -71,7 +71,7 @@ func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, rece senderVerificationStore, err := NewSQLiteVerificationStore(ctx, senderVerificationDB) require.NoError(t, err) - sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, senderVerificationStore, sendingCallbacks, true) + sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, senderVerificationStore, sendingCallbacks, true, true) require.NoError(t, sendingHelper.Init(ctx)) receivingCallbacks = newAllVerificationCallbacks() @@ -79,7 +79,7 @@ func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, rece require.NoError(t, err) receiverVerificationStore, err := NewSQLiteVerificationStore(ctx, receiverVerificationDB) require.NoError(t, err) - receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receiverVerificationStore, receivingCallbacks, true) + receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receiverVerificationStore, receivingCallbacks, true, true) require.NoError(t, receivingHelper.Init(ctx)) return } @@ -89,20 +89,27 @@ func TestVerification_Start(t *testing.T) { receivingDeviceID2 := id.DeviceID("receiving2") testCases := []struct { + supportsShow bool supportsScan bool callbacks MockVerificationCallbacks startVerificationErrMsg string expectedVerificationMethods []event.VerificationMethod }{ - {false, newBaseVerificationCallbacks(), "no supported verification methods", nil}, - {false, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, - {true, newScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {false, newScanQRCodeVerificationCallbacks(), "no supported verification methods", nil}, - {false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {true, newShowAndScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {false, newShowAndScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, - {true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, + {false, false, newBaseVerificationCallbacks(), "no supported verification methods", nil}, + {false, true, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + + {false, false, newShowQRCodeVerificationCallbacks(), "no supported verification methods", nil}, + {true, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, true, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {true, true, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + + {false, false, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, + {false, true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + + {false, false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, + {false, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {true, false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {true, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, } for i, tc := range testCases { @@ -115,7 +122,7 @@ func TestVerification_Start(t *testing.T) { addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID) addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID2) - senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, tc.callbacks, tc.supportsScan) + senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, tc.callbacks, tc.supportsShow, tc.supportsScan) err := senderHelper.Init(ctx) require.NoError(t, err) @@ -162,7 +169,7 @@ func TestVerification_StartThenCancel(t *testing.T) { bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID) bystanderMachine := bystanderClient.Crypto.(*cryptohelper.CryptoHelper).Machine() - bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, nil, newAllVerificationCallbacks(), true) + bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, nil, newAllVerificationCallbacks(), true, true) require.NoError(t, bystanderHelper.Init(ctx)) require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, bystanderMachine.OwnIdentity())) @@ -252,12 +259,12 @@ func TestVerification_Accept_NoSupportedMethods(t *testing.T) { assert.NotEmpty(t, recoveryKey) assert.NotNil(t, cache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, newAllVerificationCallbacks(), true) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, newAllVerificationCallbacks(), true, true) err = sendingHelper.Init(ctx) require.NoError(t, err) receivingCallbacks := newBaseVerificationCallbacks() - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, receivingCallbacks, false) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, receivingCallbacks, false, false) err = receivingHelper.Init(ctx) require.NoError(t, err) @@ -278,16 +285,26 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { testCases := []struct { sendingSupportsScan bool + sendingSupportsShow bool receivingSupportsScan bool + receivingSupportsShow bool sendingCallbacks MockVerificationCallbacks receivingCallbacks MockVerificationCallbacks expectedVerificationMethods []event.VerificationMethod }{ - {false, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, - {true, false, newShowAndScanQRCodeVerificationCallbacks(), newShowAndScanQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {false, true, newShowAndScanQRCodeVerificationCallbacks(), newShowAndScanQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {true, false, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, - {true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}}, + // TODO + {false, false, false, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, + {true, false, true, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, + + {true, false, false, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, + {false, true, true, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, + {true, false, true, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, + {false, true, true, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, + {true, true, true, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, + {true, true, false, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, + {true, true, true, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}}, + + {true, true, true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}}, } for i, tc := range testCases { @@ -300,11 +317,11 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { assert.NotEmpty(t, recoveryKey) assert.NotNil(t, sendingCrossSigningKeysCache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, tc.sendingCallbacks, tc.sendingSupportsScan) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, tc.sendingCallbacks, tc.sendingSupportsShow, tc.sendingSupportsScan) err = sendingHelper.Init(ctx) require.NoError(t, err) - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, tc.receivingCallbacks, tc.receivingSupportsScan) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, tc.receivingCallbacks, tc.receivingSupportsShow, tc.receivingSupportsScan) err = receivingHelper.Init(ctx) require.NoError(t, err) @@ -322,16 +339,13 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - _, sendingIsQRCallbacks := tc.sendingCallbacks.(*showQRCodeVerificationCallbacks) - _, sendingIsAllCallbacks := tc.sendingCallbacks.(*allVerificationCallbacks) - sendingCanShowQR := sendingIsQRCallbacks || sendingIsAllCallbacks - _, receivingIsQRCallbacks := tc.receivingCallbacks.(*showQRCodeVerificationCallbacks) - _, receivingIsAllCallbacks := tc.receivingCallbacks.(*allVerificationCallbacks) - receivingCanShowQR := receivingIsQRCallbacks || receivingIsAllCallbacks + // Ensure that the receiving device get a notification about the + // transaction being ready. + assert.Contains(t, tc.receivingCallbacks.GetVerificationsReadyTransactions(), txnID) // Ensure that if the receiving device should show a QR code that // it has the correct content. - if tc.sendingSupportsScan && receivingCanShowQR { + if tc.sendingSupportsScan && tc.receivingSupportsShow { receivingShownQRCode := tc.receivingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, receivingShownQRCode) assert.Equal(t, txnID, receivingShownQRCode.TransactionID) @@ -340,7 +354,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { // Check for whether the receiving device should be scanning a QR // code. - if tc.receivingSupportsScan && sendingCanShowQR { + if tc.receivingSupportsScan && tc.sendingSupportsShow { assert.Contains(t, tc.receivingCallbacks.GetScanQRCodeTransactions(), txnID) } @@ -357,9 +371,13 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { // device. ts.dispatchToDevice(t, ctx, sendingClient) + // Ensure that the sending device got a notification about the + // transaction being ready. + assert.Contains(t, tc.sendingCallbacks.GetVerificationsReadyTransactions(), txnID) + // Ensure that if the sending device should show a QR code that it // has the correct content. - if tc.receivingSupportsScan && sendingCanShowQR { + if tc.receivingSupportsScan && tc.sendingSupportsShow { sendingShownQRCode := tc.sendingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, sendingShownQRCode) assert.Equal(t, txnID, sendingShownQRCode.TransactionID) @@ -368,7 +386,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { // Check for whether the sending device should be scanning a QR // code. - if tc.sendingSupportsScan && receivingCanShowQR { + if tc.sendingSupportsScan && tc.receivingSupportsShow { assert.Contains(t, tc.sendingCallbacks.GetScanQRCodeTransactions(), txnID) } }) From b6c225c3430d283bcee8689ef5317200c0401f19 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 16 Feb 2025 17:17:12 +0200 Subject: [PATCH 1028/1647] dependencies: update --- go.mod | 16 ++++++++-------- go.sum | 32 ++++++++++++++++---------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/go.mod b/go.mod index 69752cdb..910c759d 100644 --- a/go.mod +++ b/go.mod @@ -18,12 +18,12 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.5-0.20250212114338-06310c7133a5 + go.mau.fi/util v0.8.5 go.mau.fi/zeroconfig v0.1.3 - golang.org/x/crypto v0.32.0 - golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c - golang.org/x/net v0.34.0 - golang.org/x/sync v0.10.0 + golang.org/x/crypto v0.33.0 + golang.org/x/exp v0.0.0-20250215185904-eff6e970281f + golang.org/x/net v0.35.0 + golang.org/x/sync v0.11.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) @@ -33,11 +33,11 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/petermattis/goid v0.0.0-20241211131331-93ee7e083c43 // indirect + github.com/petermattis/goid v0.0.0-20250211185408-f2b9d978cd7a // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - golang.org/x/sys v0.29.0 // indirect - golang.org/x/text v0.21.0 // indirect + golang.org/x/sys v0.30.0 // indirect + golang.org/x/text v0.22.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 73a5e58a..653a68da 100644 --- a/go.sum +++ b/go.sum @@ -28,8 +28,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/petermattis/goid v0.0.0-20241211131331-93ee7e083c43 h1:ah1dvbqPMN5+ocrg/ZSgZ6k8bOk+kcZQ7fnyx6UvOm4= -github.com/petermattis/goid v0.0.0-20241211131331-93ee7e083c43/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/petermattis/goid v0.0.0-20250211185408-f2b9d978cd7a h1:ckxP/kGzsxvxXo8jO6E/0QJ8MMmwI7IRj4Fys9QbAZA= +github.com/petermattis/goid v0.0.0-20250211185408-f2b9d978cd7a/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -54,26 +54,26 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.5-0.20250212114338-06310c7133a5 h1:b8XKQEONXqnawxcKVdu2HqjmFSa+6tAeHyCI0E/JM2g= -go.mau.fi/util v0.8.5-0.20250212114338-06310c7133a5/go.mod h1:rleUeT8LjZQBUDg/FKAdX/xbd/Vnuof/UoCcnGvGm9M= +go.mau.fi/util v0.8.5 h1:PwCAAtcfK0XxZ4sdErJyfBMkTEWoQU33aB7QqDDzQRI= +go.mau.fi/util v0.8.5/go.mod h1:Ycug9mrbztlahHPEJ6H5r8Nu/xqZaWbE5vPHVWmfz6M= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= -golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= -golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c h1:KL/ZBHXgKGVmuZBZ01Lt57yE5ws8ZPSkkihmEyq7FXc= -golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= -golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= -golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= -golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= -golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= +golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +golang.org/x/exp v0.0.0-20250215185904-eff6e970281f h1:oFMYAjX0867ZD2jcNiLBrI9BdpmEkvPyi5YrBGXbamg= +golang.org/x/exp v0.0.0-20250215185904-eff6e970281f/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk= +golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= +golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= -golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= -golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From 12db97adb35157e87882e7d796dd22265b765fff Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 16 Feb 2025 17:17:34 +0200 Subject: [PATCH 1029/1647] Bump version to v0.23.1 --- CHANGELOG.md | 18 ++++++++++++++++++ version.go | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4490978f..85d362ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,21 @@ +## v0.23.1 (2025-02-16) + +* *(client)* Added `FullStateEvent` method to get a state event including + metadata (using the `?format=event` query parameter). +* *(client)* Added wrapper method for [MSC4194]'s redact endpoint. +* *(pushrules)* Fixed content rules not considering word boundaries and being + case-sensitive. +* *(crypto)* Fixed bugs that would cause key exports to fail for no reason. +* *(crypto)* Deprecated `ResolveTrust` in favor of `ResolveTrustContext`. +* *(crypto)* Stopped accepting secret shares from unverified devices. +* **Breaking change *(crypto)*** Changed `GetAndVerifyLatestKeyBackupVersion` + to take an optional private key parameter. The method will now trust the + public key if it matches the provided private key even if there are no valid + signatures. +* **Breaking change *(crypto)*** Added context parameter to `IsDeviceTrusted`. + +[MSC4194]: https://github.com/matrix-org/matrix-spec-proposals/pull/4194 + ## v0.23.0 (2025-01-16) * **Breaking change *(client)*** Changed `JoinRoom` parameters to allow multiple diff --git a/version.go b/version.go index 16b6b2f2..2ff08518 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.23.0" +const Version = "v0.23.1" var GoModVersion = "" var Commit = "" From 4c58b82813051d134abdc3e2a47fe2dbb08f1c14 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 20 Feb 2025 14:13:42 -0700 Subject: [PATCH 1030/1647] verificationhelper/sas: don't trust keys until both MAC events are sent Signed-off-by: Sumner Evans --- crypto/verificationhelper/sas.go | 115 +++++++++++++++++-------------- 1 file changed, 65 insertions(+), 50 deletions(-) diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index a78b4b57..70b77759 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -111,6 +111,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat keys := map[id.KeyID]jsonbytes.UnpaddedBytes{} log.Info().Msg("Signing keys") + var masterKey string // My device key myDevice := vh.mach.OwnIdentity() @@ -123,8 +124,9 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat // Master signing key crossSigningKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx) if crossSigningKeys != nil { - crossSigningKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, crossSigningKeys.MasterKey.String()) - keys[crossSigningKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, crossSigningKeyID.String(), crossSigningKeys.MasterKey.String()) + masterKey = crossSigningKeys.MasterKey.String() + crossSigningKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, masterKey) + keys[crossSigningKeyID], err = vh.verificationMACHKDF(txn, vh.client.UserID, vh.client.DeviceID, txn.TheirUserID, txn.TheirDeviceID, crossSigningKeyID.String(), masterKey) if err != nil { return err } @@ -148,10 +150,16 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat if err != nil { return err } + log.Info().Msg("Sent our MAC event") txn.SentOurMAC = true if txn.ReceivedTheirMAC { txn.VerificationState = VerificationStateSASMACExchanged + + if err := vh.trustKeysAfterMACCheck(ctx, txn, masterKey); err != nil { + return fmt.Errorf("failed to trust keys: %w", err) + } + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { return err @@ -731,57 +739,15 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific } log.Info().Msg("All MACs verified") - // Trust their device - theirDevice.Trust = id.TrustStateVerified - err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) - if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to update device trust state after verifying: %w", err) - return - } - - if txn.TheirUserID == vh.client.UserID { - // Self-signing situation. - // - // If we have the cross-signing keys, then we need to sign their device - // using the self-signing key. Otherwise, they have the master private - // key, so we need to trust the master public key. - if vh.mach.CrossSigningKeys != nil { - err = vh.mach.SignOwnDevice(ctx, theirDevice) - if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to sign our own new device: %w", err) - return - } - } else { - err = vh.mach.SignOwnMasterKey(ctx) - if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to sign our own master key: %w", err) - return - } - } - } else if masterKey != "" { - // Cross-signing situation. - // - // The master key was included in the list of keys to verify, so verify - // that it matches what we expect and sign their master key using the - // user-signing key. - theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) - if err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) - return - } else if theirSigningKeys.MasterKey.String() != masterKey { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "master keys do not match") - return - } - - if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to sign their master key: %w", err) - return - } - } - txn.ReceivedTheirMAC = true if txn.SentOurMAC { txn.VerificationState = VerificationStateSASMACExchanged + + if err := vh.trustKeysAfterMACCheck(ctx, txn, masterKey); err != nil { + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to trust keys: %w", err) + return + } + err := vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationDone, &event.VerificationDoneEventContent{}) if err != nil { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUser, "failed to send verification done event: %w", err) @@ -794,3 +760,52 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific log.Err(err).Msg("failed to save verification transaction") } } + +func (vh *VerificationHelper) trustKeysAfterMACCheck(ctx context.Context, txn VerificationTransaction, masterKey string) error { + theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID) + if err != nil { + return fmt.Errorf("failed to fetch their device: %w", err) + } + // Trust their device + theirDevice.Trust = id.TrustStateVerified + err = vh.mach.CryptoStore.PutDevice(ctx, txn.TheirUserID, theirDevice) + if err != nil { + return fmt.Errorf("failed to update device trust state after verifying: %w", err) + } + + if txn.TheirUserID == vh.client.UserID { + // Self-signing situation. + // + // If we have the cross-signing keys, then we need to sign their device + // using the self-signing key. Otherwise, they have the master private + // key, so we need to trust the master public key. + if vh.mach.CrossSigningKeys != nil { + err = vh.mach.SignOwnDevice(ctx, theirDevice) + if err != nil { + return fmt.Errorf("failed to sign our own new device: %w", err) + } + } else { + err = vh.mach.SignOwnMasterKey(ctx) + if err != nil { + return fmt.Errorf("failed to sign our own master key: %w", err) + } + } + } else if masterKey != "" { + // Cross-signing situation. + // + // The master key was included in the list of keys to verify, so verify + // that it matches what we expect and sign their master key using the + // user-signing key. + theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID) + if err != nil { + return fmt.Errorf("couldn't get %s's cross-signing keys: %w", txn.TheirUserID, err) + } else if theirSigningKeys.MasterKey.String() != masterKey { + return fmt.Errorf("master keys do not match") + } + + if err := vh.mach.SignUser(ctx, txn.TheirUserID, theirSigningKeys.MasterKey); err != nil { + return fmt.Errorf("failed to sign %s's master key: %w", txn.TheirUserID, err) + } + } + return nil +} From fcdf7fd1933a83b4cc0f5cff903401c5e16e8789 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 22 Feb 2025 20:16:00 +0200 Subject: [PATCH 1031/1647] bridgev2/commands: add command to set new management room --- bridgev2/commands/cleanup.go | 40 ++++++++++++++++++++++++++++++++++ bridgev2/commands/processor.go | 2 +- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/bridgev2/commands/cleanup.go b/bridgev2/commands/cleanup.go index f8ad1d23..dc21a16e 100644 --- a/bridgev2/commands/cleanup.go +++ b/bridgev2/commands/cleanup.go @@ -55,3 +55,43 @@ var CommandDeleteAllPortals = &FullHandler{ }, RequiresAdmin: true, } + +var CommandSetManagementRoom = &FullHandler{ + Func: func(ce *Event) { + if ce.User.ManagementRoom == ce.RoomID { + ce.Reply("This room is already your management room") + return + } else if ce.Portal != nil { + ce.Reply("This is a portal room: you can't set this as your management room") + return + } + members, err := ce.Bridge.Matrix.GetMembers(ce.Ctx, ce.RoomID) + if err != nil { + ce.Log.Err(err).Msg("Failed to get room members to check if room can be a management room") + ce.Reply("Failed to get room members") + return + } + _, hasBot := members[ce.Bot.GetMXID()] + if !hasBot { + // This reply will probably fail, but whatever + ce.Reply("The bridge bot must be in the room to set it as your management room") + return + } else if len(members) != 2 { + ce.Reply("Your management room must not have any members other than you and the bridge bot") + return + } + ce.User.ManagementRoom = ce.RoomID + err = ce.User.Save(ce.Ctx) + if err != nil { + ce.Log.Err(err).Msg("Failed to save management room") + ce.Reply("Failed to save management room") + } else { + ce.Reply("Management room updated") + } + }, + Name: "set-management-room", + Help: HelpMeta{ + Section: HelpSectionGeneral, + Description: "Mark this room as your management room", + }, +} diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index 3343e1ba..610fb825 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -42,7 +42,7 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor { } proc.AddHandlers( CommandHelp, CommandCancel, - CommandRegisterPush, CommandDeletePortal, CommandDeleteAllPortals, + CommandRegisterPush, CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom, CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, CommandSetRelay, CommandUnsetRelay, CommandResolveIdentifier, CommandStartChat, CommandSearch, From af84927e3125c0c8cc11bb2824a6d6d1da4fa400 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 22 Feb 2025 22:59:50 +0200 Subject: [PATCH 1032/1647] bridgev2: add option to disable bridging m.notice messages --- bridgev2/bridgeconfig/config.go | 1 + bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/errors.go | 1 + bridgev2/matrix/mxmain/example-config.yaml | 2 ++ bridgev2/portal.go | 7 ++++++- 5 files changed, 11 insertions(+), 1 deletion(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 12d5452b..156fb772 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -66,6 +66,7 @@ type BridgeConfig struct { ResendBridgeInfo bool `yaml:"resend_bridge_info"` NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"` BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` + BridgeNotices bool `yaml:"bridge_notices"` TagOnlyOnCreate bool `yaml:"tag_only_on_create"` OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"` MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index ea986fda..07477ef1 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -32,6 +32,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "resend_bridge_info") helper.Copy(up.Bool, "bridge", "no_bridge_info_state_key") helper.Copy(up.Bool, "bridge", "bridge_matrix_leave") + helper.Copy(up.Bool, "bridge", "bridge_notices") helper.Copy(up.Bool, "bridge", "tag_only_on_create") helper.Copy(up.List, "bridge", "only_bridge_tags") helper.Copy(up.Bool, "bridge", "mute_only_on_create") diff --git a/bridgev2/errors.go b/bridgev2/errors.go index 789d0026..0e948184 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -58,6 +58,7 @@ var ( ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 48d6a77e..86838ff1 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -24,6 +24,8 @@ bridge: # Should leaving Matrix rooms be bridged as leaving groups on the remote network? bridge_matrix_leave: false + # Should `m.notice` messages be bridged? + bridge_notices: false # Should room tags only be synced when creating the portal? Tags mean things like favorite/pin and archive/low priority. # Tags currently can't be synced back to the remote network, so a continuous sync means tagging from Matrix will be undone. tag_only_on_create: true diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 63874333..803c3c7c 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -511,7 +511,8 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * if err != nil { log.Err(err).Msg("Failed to get user login to handle Matrix event") if errors.Is(err, ErrNotLoggedIn) { - portal.sendErrorStatus(ctx, evt, WrapErrorInStatus(err).WithMessage("You're not logged in").WithIsCertain(true).WithSendNotice(true)) + shouldSendNotice := evt.Content.AsMessage().MsgType != event.MsgNotice + portal.sendErrorStatus(ctx, evt, WrapErrorInStatus(err).WithMessage("You're not logged in").WithIsCertain(true).WithSendNotice(shouldSendNotice)) } else { portal.sendErrorStatus(ctx, evt, WrapErrorInStatus(err).WithMessage("Failed to get login to handle event").WithIsCertain(true).WithSendNotice(true)) } @@ -819,6 +820,10 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin if evt.Type == event.EventSticker { msgContent.MsgType = event.CapMsgSticker } + if msgContent.MsgType == event.MsgNotice && !portal.Bridge.Config.BridgeNotices { + portal.sendErrorStatus(ctx, evt, ErrIgnoringMNotice) + return + } } if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") From 83f81ea67e14982f0bf336ac0b6b43ade1597e84 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 22 Feb 2025 23:00:02 +0200 Subject: [PATCH 1033/1647] bridgev2/messagestatus: send errors as m.notice --- bridgev2/messagestatus.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index c846f502..b29eac0b 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -219,7 +219,7 @@ func (ms *MessageStatus) ToNoticeEvent(evt *MessageStatusEventInfo) *event.Messa messagePrefix = "Handling your command panicked" } content := &event.MessageEventContent{ - MsgType: event.MsgText, + MsgType: event.MsgNotice, Body: fmt.Sprintf("\u26a0\ufe0f %s: %s", messagePrefix, msg), RelatesTo: &event.RelatesTo{}, Mentions: &event.Mentions{}, From 43dbbb1ff847c5053238c883c81acc5fbcfbcb81 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 22 Feb 2025 23:00:45 +0200 Subject: [PATCH 1034/1647] bridgev2/portal: add m.mentions for disappearing message timer notice --- bridgev2/portal.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 803c3c7c..d23e73dd 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3377,8 +3377,9 @@ func (portal *Portal) updateUserLocalInfo(ctx context.Context, info *UserLocalPo func DisappearingMessageNotice(expiration time.Duration, implicit bool) *event.MessageEventContent { formattedDuration := exfmt.DurationCustom(expiration, nil, exfmt.Day, time.Hour, time.Minute, time.Second) content := &event.MessageEventContent{ - MsgType: event.MsgNotice, - Body: fmt.Sprintf("Set the disappearing message timer to %s", formattedDuration), + MsgType: event.MsgNotice, + Body: fmt.Sprintf("Set the disappearing message timer to %s", formattedDuration), + Mentions: &event.Mentions{}, } if implicit { content.Body = fmt.Sprintf("Automatically enabled disappearing message timer (%s) because incoming message is disappearing", formattedDuration) From e879ad19ccf324c9e8222756466a9162c0c36574 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 23 Feb 2025 16:24:03 +0200 Subject: [PATCH 1035/1647] crypto/decryptmegolm: hide state event decryption behind flag --- crypto/decryptmegolm.go | 2 +- crypto/machine.go | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 00f99ce4..11ab0f49 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -149,7 +149,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event } else if megolmEvt.RoomID != encryptionRoomID { return nil, WrongRoom } - if evt.StateKey != nil && megolmEvt.StateKey != nil { + if evt.StateKey != nil && megolmEvt.StateKey != nil && mach.AllowEncryptedState { megolmEvt.Type.Class = event.StateEventType } else { megolmEvt.Type.Class = evt.Type.Class diff --git a/crypto/machine.go b/crypto/machine.go index 4594b9d8..4ad4347c 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -35,7 +35,8 @@ type OlmMachine struct { BackgroundCtx context.Context - PlaintextMentions bool + PlaintextMentions bool + AllowEncryptedState bool // Never ask the server for keys automatically as a side effect during Megolm decryption. DisableDecryptKeyFetching bool From 1cc073cde6ca7ec81502554fd7ac250586ccde2e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 23 Feb 2025 18:12:28 +0200 Subject: [PATCH 1036/1647] client: add wrapper method for MSC2815 --- client.go | 8 ++++++++ error.go | 3 +++ versions.go | 11 ++++++----- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/client.go b/client.go index ae37798e..2d67231c 100644 --- a/client.go +++ b/client.go @@ -2013,6 +2013,14 @@ func (cli *Client) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.Ev return } +func (cli *Client) GetUnredactedEventContent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (resp *event.Event, err error) { + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "event", eventID}, map[string]string{ + "fi.mau.msc2815.include_unredacted_content": "true", + }) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + func (cli *Client) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID) (err error) { return cli.SendReceipt(ctx, roomID, eventID, event.ReceiptTypeRead, nil) } diff --git a/error.go b/error.go index 653ac5a1..38039464 100644 --- a/error.go +++ b/error.go @@ -71,6 +71,9 @@ var ( MBadStatus = RespError{ErrCode: "M_BAD_STATUS"} MConnectionTimeout = RespError{ErrCode: "M_CONNECTION_TIMEOUT"} MConnectionFailed = RespError{ErrCode: "M_CONNECTION_FAILED"} + + MUnredactedContentDeleted = RespError{ErrCode: "FI.MAU.MSC2815_UNREDACTED_CONTENT_DELETED"} + MUnredactedContentNotReceived = RespError{ErrCode: "FI.MAU.MSC2815_UNREDACTED_CONTENT_NOT_RECEIVED"} ) // HTTPError An HTTP Error response, which may wrap an underlying native Go Error. diff --git a/versions.go b/versions.go index 183bc9ad..7e752986 100644 --- a/versions.go +++ b/versions.go @@ -60,11 +60,12 @@ type UnstableFeature struct { } var ( - FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} - FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} - FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} - FeatureMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} - FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"} + FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} + FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} + FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} + FeatureMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} + FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"} + FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"} BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} From 0115ba0258cb8f3cce47647cf70d6d1bdb284609 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 24 Feb 2025 11:47:59 -0700 Subject: [PATCH 1037/1647] event: add support for MSC3765 rich text room topics Signed-off-by: Sumner Evans --- event/poll.go | 9 +++------ event/state.go | 17 ++++++++++++++++- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/event/poll.go b/event/poll.go index 37333015..47131a8f 100644 --- a/event/poll.go +++ b/event/poll.go @@ -29,12 +29,9 @@ func (content *PollResponseEventContent) SetRelatesTo(rel *RelatesTo) { } type MSC1767Message struct { - Text string `json:"org.matrix.msc1767.text,omitempty"` - HTML string `json:"org.matrix.msc1767.html,omitempty"` - Message []struct { - MimeType string `json:"mimetype"` - Body string `json:"body"` - } `json:"org.matrix.msc1767.message,omitempty"` + Text string `json:"org.matrix.msc1767.text,omitempty"` + HTML string `json:"org.matrix.msc1767.html,omitempty"` + Message []ExtensibleText `json:"org.matrix.msc1767.message,omitempty"` } type PollStartEventContent struct { diff --git a/event/state.go b/event/state.go index 15972892..4e926dcd 100644 --- a/event/state.go +++ b/event/state.go @@ -42,7 +42,22 @@ type ServerACLEventContent struct { // TopicEventContent represents the content of a m.room.topic state event. // https://spec.matrix.org/v1.2/client-server-api/#mroomtopic type TopicEventContent struct { - Topic string `json:"topic"` + Topic string `json:"topic"` + ExtensibleTopic *ExtensibleTopic `json:"m.topic,omitempty"` +} + +// ExtensibleTopic represents the contents of the m.topic field within the +// m.room.topic state event as described in [MSC3765]. +// +// [MSC3765]: https://github.com/matrix-org/matrix-spec-proposals/pull/3765 +type ExtensibleTopic struct { + Text []ExtensibleText `json:"m.text"` +} + +// ExtensibleText represents the contents of an m.text field. +type ExtensibleText struct { + MimeType string `json:"mimetype,omitempty"` + Body string `json:"body"` } // TombstoneEventContent represents the content of a m.room.tombstone state event. From b72caa948c18be24000a91d844bb1dd7c192655b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 26 Feb 2025 22:56:29 +0200 Subject: [PATCH 1038/1647] format/htmlparser: keep <> when converting links without text --- format/htmlparser.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/format/htmlparser.go b/format/htmlparser.go index 7c3b3c88..25543926 100644 --- a/format/htmlparser.go +++ b/format/htmlparser.go @@ -455,7 +455,7 @@ var MarkdownHTMLParser = &HTMLParser{ PillConverter: DefaultPillConverter, LinkConverter: func(text, href string, ctx Context) string { if text == href { - return text + return fmt.Sprintf("<%s>", href) } return fmt.Sprintf("[%s](%s)", text, href) }, From 2e7bdbc7a2c393524009619eff5484fcdda06e1d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 27 Feb 2025 16:48:25 +0200 Subject: [PATCH 1039/1647] bridgev2/portal: delete old reaction before sending new one --- bridgev2/portal.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index d23e73dd..dcd3cd37 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2254,8 +2254,10 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User existing[existingReaction.SenderID][existingReaction.EmojiID] = existingReaction } - doAddReaction := func(new *BackfillReaction) MatrixAPI { - intent := portal.GetIntentFor(ctx, new.Sender, source, RemoteEventReactionSync) + doAddReaction := func(new *BackfillReaction, intent MatrixAPI) MatrixAPI { + if intent == nil { + intent = portal.GetIntentFor(ctx, new.Sender, source, RemoteEventReactionSync) + } portal.sendConvertedReaction( ctx, new.Sender.Sender, intent, targetMessage, new.EmojiID, new.Emoji, new.Timestamp, new.DBMetadata, new.ExtraContent, @@ -2299,8 +2301,9 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User } } doOverwriteReaction := func(new *BackfillReaction, old *database.Reaction) { - intent := doAddReaction(new) + intent := portal.GetIntentFor(ctx, new.Sender, source, RemoteEventReactionSync) doRemoveReaction(old, intent, false) + doAddReaction(new, intent) } newData := evt.GetReactions() @@ -2319,7 +2322,7 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User } doOverwriteReaction(reaction, existingReaction) } else { - doAddReaction(reaction) + doAddReaction(reaction, nil) } } totalReactionCount := len(existingUserReactions) + len(reactions.Reactions) @@ -2381,7 +2384,6 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi if metaProvider, ok := evt.(RemoteReactionWithMeta); ok { dbMetadata = metaProvider.GetReactionDBMetadata() } - portal.sendConvertedReaction(ctx, evt.GetSender().Sender, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extra, nil) if existingReaction != nil { _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ @@ -2392,6 +2394,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi log.Err(err).Msg("Failed to redact old reaction") } } + portal.sendConvertedReaction(ctx, evt.GetSender().Sender, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extra, nil) } func (portal *Portal) sendConvertedReaction( From 02733b5775d4fff8182b557983e465c4b515f4d8 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Thu, 27 Feb 2025 11:59:07 -0700 Subject: [PATCH 1040/1647] deps/go-util: upgrade to support -b in curl parsing Signed-off-by: Sumner Evans --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 910c759d..a622b05f 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.5 + go.mau.fi/util v0.8.6-0.20250227184636-7ff63b0b9d95 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.33.0 golang.org/x/exp v0.0.0-20250215185904-eff6e970281f diff --git a/go.sum b/go.sum index 653a68da..c92de482 100644 --- a/go.sum +++ b/go.sum @@ -54,8 +54,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.5 h1:PwCAAtcfK0XxZ4sdErJyfBMkTEWoQU33aB7QqDDzQRI= -go.mau.fi/util v0.8.5/go.mod h1:Ycug9mrbztlahHPEJ6H5r8Nu/xqZaWbE5vPHVWmfz6M= +go.mau.fi/util v0.8.6-0.20250227184636-7ff63b0b9d95 h1:5EfVWWjU2Hte9uE6B/hBgvjnVfBx/7SYDZBnsuo+EBs= +go.mau.fi/util v0.8.6-0.20250227184636-7ff63b0b9d95/go.mod h1:Ycug9mrbztlahHPEJ6H5r8Nu/xqZaWbE5vPHVWmfz6M= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= From c7cb9ff2a3888a0c71952e475ca687f74ad0edde Mon Sep 17 00:00:00 2001 From: Brad Murray Date: Fri, 28 Feb 2025 09:45:28 -0500 Subject: [PATCH 1041/1647] crypto/keybackup: log mismatching public keys when verifiying megolm backups (#356) Co-authored-by: Sumner Evans --- crypto/keybackup.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 00f74175..a9686fdf 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -54,9 +54,15 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context, // ...by deriving the public key from a private key that it obtained from a trusted source. Trusted sources for the private // key include the user entering the key, retrieving the key stored in secret storage, or obtaining the key via secret sharing // from a verified device belonging to the same user." - if megolmBackupKey != nil && versionInfo.AuthData.PublicKey == id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes())) { - log.Debug().Msg("key backup is trusted based on public key") + megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes())) + if megolmBackupKey != nil && versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey { + log.Debug().Msg("key backup is trusted based on derived public key") return versionInfo, nil + } else { + log.Debug(). + Stringer("expected_key", megolmBackupDerivedPublicKey). + Stringer("actual_key", versionInfo.AuthData.PublicKey). + Msg("key backup public keys do not match, proceeding to check device signatures") } // "...or checking that it is signed by the user’s master cross-signing key or by a verified device belonging to the same user" From 006bbe28068db1e60465e94b30980677fbd8508c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 2 Mar 2025 18:41:18 +0200 Subject: [PATCH 1042/1647] crypto/helper: use sqlite3-fk-wal by default --- crypto/cryptohelper/cryptohelper.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index 1b0e08e1..f03835ef 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -15,6 +15,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/dbutil" + _ "go.mau.fi/util/dbutil/litestream" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto" @@ -78,7 +79,7 @@ func NewCryptoHelper(cli *mautrix.Client, pickleKey []byte, store any) (*CryptoH } unmanagedCryptoStore = typedStore case string: - db, err := dbutil.NewWithDialect(typedStore, "sqlite3") + db, err := dbutil.NewWithDialect(fmt.Sprintf("file:%s?_txlock=immediate", typedStore), "sqlite3-fk-wal") if err != nil { return nil, err } From 0c1fc68ec3f8e5ea4e1d3fa2b912694afa7db320 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 4 Mar 2025 02:37:08 +0200 Subject: [PATCH 1043/1647] crypto/machine: return unhandled to-device events in HandleEncryptedEvent --- crypto/machine.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/crypto/machine.go b/crypto/machine.go index 4ad4347c..cacc73b6 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -363,16 +363,16 @@ func (mach *OlmMachine) HandleMemberEvent(ctx context.Context, evt *event.Event) } } -func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Event) { +func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Event) *DecryptedOlmEvent { if _, ok := evt.Content.Parsed.(*event.EncryptedEventContent); !ok { mach.machOrContextLog(ctx).Warn().Msg("Passed invalid event to encrypted handler") - return + return nil } decryptedEvt, err := mach.decryptOlmEvent(ctx, evt) if err != nil { mach.machOrContextLog(ctx).Error().Err(err).Msg("Failed to decrypt to-device event") - return + return nil } log := mach.machOrContextLog(ctx).With(). @@ -401,7 +401,9 @@ func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Eve log.Trace().Msg("Handled secret send event") default: log.Debug().Msg("Unhandled encrypted to-device event") + return decryptedEvt } + return nil } const olmHashSavePointCount = 5 From 7d3791ace3b7c866aec5e1b147ea1236da68c43e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 4 Mar 2025 02:46:59 +0200 Subject: [PATCH 1044/1647] crypto/encryptolm: add generic method to encrypt to-device events --- crypto/encryptolm.go | 64 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go index 52e30166..80b76dc5 100644 --- a/crypto/encryptolm.go +++ b/crypto/encryptolm.go @@ -17,6 +17,70 @@ import ( "maunium.net/go/mautrix/id" ) +func (mach *OlmMachine) EncryptToDevices(ctx context.Context, eventType event.Type, req *mautrix.ReqSendToDevice) (*mautrix.ReqSendToDevice, error) { + devicesToCreateSessions := make(map[id.UserID]map[id.DeviceID]*id.Device) + for userID, devices := range req.Messages { + for deviceID := range devices { + device, err := mach.GetOrFetchDevice(ctx, userID, deviceID) + if err != nil { + return nil, fmt.Errorf("failed to get device %s of user %s: %w", deviceID, userID, err) + } + + if _, ok := devicesToCreateSessions[userID]; !ok { + devicesToCreateSessions[userID] = make(map[id.DeviceID]*id.Device) + } + devicesToCreateSessions[userID][deviceID] = device + } + } + if err := mach.createOutboundSessions(ctx, devicesToCreateSessions); err != nil { + return nil, fmt.Errorf("failed to create outbound sessions: %w", err) + } + + mach.olmLock.Lock() + defer mach.olmLock.Unlock() + + encryptedReq := &mautrix.ReqSendToDevice{ + Messages: make(map[id.UserID]map[id.DeviceID]*event.Content), + } + + log := mach.machOrContextLog(ctx) + + for userID, devices := range req.Messages { + encryptedReq.Messages[userID] = make(map[id.DeviceID]*event.Content) + + for deviceID, content := range devices { + device := devicesToCreateSessions[userID][deviceID] + + olmSess, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey) + if err != nil { + return nil, fmt.Errorf("failed to get latest session for device %s of %s: %w", deviceID, userID, err) + } else if olmSess == nil { + log.Warn(). + Str("target_user_id", userID.String()). + Str("target_device_id", deviceID.String()). + Str("identity_key", device.IdentityKey.String()). + Msg("No outbound session found for device") + continue + } + + encrypted := mach.encryptOlmEvent(ctx, olmSess, device, eventType, *content) + encryptedContent := &event.Content{Parsed: &encrypted} + + log.Debug(). + Str("decrypted_type", eventType.Type). + Str("target_user_id", userID.String()). + Str("target_device_id", deviceID.String()). + Str("target_identity_key", device.IdentityKey.String()). + Str("olm_session_id", olmSess.ID().String()). + Msg("Encrypted to-device event") + + encryptedReq.Messages[userID][deviceID] = encryptedContent + } + } + + return encryptedReq, nil +} + func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession, recipient *id.Device, evtType event.Type, content event.Content) *event.EncryptedEventContent { evt := &DecryptedOlmEvent{ Sender: mach.Client.UserID, From 32b2376409f896c2a25b5ee40be72577aa4870da Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 4 Mar 2025 12:34:42 +0200 Subject: [PATCH 1045/1647] crypto/ssss: fix panic if key metadata is corrupted Fixes tulir/gomuks#601 --- crypto/ssss/key.go | 7 ++++- crypto/ssss/meta.go | 40 ++++++++++++++++++++------ crypto/ssss/meta_test.go | 61 +++++++++++++++++++++++++++------------- crypto/ssss/types.go | 1 + 4 files changed, 80 insertions(+), 29 deletions(-) diff --git a/crypto/ssss/key.go b/crypto/ssss/key.go index c973c1fe..aa22360a 100644 --- a/crypto/ssss/key.go +++ b/crypto/ssss/key.go @@ -57,7 +57,12 @@ func NewKey(passphrase string) (*Key, error) { // We store a certain hash in the key metadata so that clients can check if the user entered the correct key. ivBytes := random.Bytes(utils.AESCTRIVLength) keyData.IV = base64.RawStdEncoding.EncodeToString(ivBytes) - keyData.MAC = keyData.calculateHash(ssssKey) + var err error + keyData.MAC, err = keyData.calculateHash(ssssKey) + if err != nil { + // This should never happen because we just generated the IV and key. + return nil, fmt.Errorf("failed to calculate hash: %w", err) + } return &Key{ Key: ssssKey, diff --git a/crypto/ssss/meta.go b/crypto/ssss/meta.go index 210bcdcf..474c85d8 100644 --- a/crypto/ssss/meta.go +++ b/crypto/ssss/meta.go @@ -33,8 +33,8 @@ func (kd *KeyMetadata) VerifyPassphrase(keyID, passphrase string) (*Key, error) ssssKey, err := kd.Passphrase.GetKey(passphrase) if err != nil { return nil, err - } else if !kd.VerifyKey(ssssKey) { - return nil, ErrIncorrectSSSSKey + } else if err = kd.verifyKey(ssssKey); err != nil { + return nil, err } return &Key{ @@ -49,8 +49,8 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error ssssKey := utils.DecodeBase58RecoveryKey(recoveryKey) if ssssKey == nil { return nil, ErrInvalidRecoveryKey - } else if !kd.VerifyKey(ssssKey) { - return nil, ErrIncorrectSSSSKey + } else if err := kd.verifyKey(ssssKey); err != nil { + return nil, err } return &Key{ @@ -60,22 +60,46 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error }, nil } +func (kd *KeyMetadata) verifyKey(key []byte) error { + unpaddedMAC := strings.TrimRight(kd.MAC, "=") + expectedMACLength := base64.RawStdEncoding.EncodedLen(utils.SHAHashLength) + if len(unpaddedMAC) != expectedMACLength { + return fmt.Errorf("%w: invalid mac length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedMAC), expectedMACLength) + } + hash, err := kd.calculateHash(key) + if err != nil { + return err + } + if unpaddedMAC != hash { + return ErrIncorrectSSSSKey + } + return nil +} + // VerifyKey verifies the SSSS key is valid by calculating and comparing its MAC. func (kd *KeyMetadata) VerifyKey(key []byte) bool { - return strings.TrimRight(kd.MAC, "=") == kd.calculateHash(key) + return kd.verifyKey(key) == nil } // calculateHash calculates the hash used for checking if the key is entered correctly as described // in the spec: https://matrix.org/docs/spec/client_server/unstable#m-secret-storage-v1-aes-hmac-sha2 -func (kd *KeyMetadata) calculateHash(key []byte) string { +func (kd *KeyMetadata) calculateHash(key []byte) (string, error) { aesKey, hmacKey := utils.DeriveKeysSHA256(key, "") + unpaddedIV := strings.TrimRight(kd.IV, "=") + expectedIVLength := base64.RawStdEncoding.EncodedLen(utils.AESCTRIVLength) + if len(unpaddedIV) != expectedIVLength { + return "", fmt.Errorf("%w: invalid iv length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedIV), expectedIVLength) + } var ivBytes [utils.AESCTRIVLength]byte - _, _ = base64.RawStdEncoding.Decode(ivBytes[:], []byte(strings.TrimRight(kd.IV, "="))) + _, err := base64.RawStdEncoding.Decode(ivBytes[:], []byte(unpaddedIV)) + if err != nil { + return "", fmt.Errorf("%w: failed to decode iv: %w", ErrCorruptedKeyMetadata, err) + } cipher := utils.XorA256CTR(make([]byte, utils.AESCTRKeyLength), aesKey, ivBytes) - return utils.HMACSHA256B64(cipher, hmacKey) + return utils.HMACSHA256B64(cipher, hmacKey), nil } // PassphraseMetadata represents server-side metadata about a SSSS key passphrase. diff --git a/crypto/ssss/meta_test.go b/crypto/ssss/meta_test.go index 96c97282..4f2ff378 100644 --- a/crypto/ssss/meta_test.go +++ b/crypto/ssss/meta_test.go @@ -41,12 +41,28 @@ const key2Meta = ` } ` +const key2MetaBrokenIV = ` +{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfWwMeowMeowMeow", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" +} +` + +const key2MetaBrokenMAC = ` +{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfWw==", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtIMeowMeowMeow" +} +` + const key2ID = "NVe5vK6lZS9gEMQLJw0yqkzmE5Mr7dLv" const key2RecoveryKey = "EsUC xSxt XJgQ dz19 8WBZ rHdE GZo7 ybsn EFmG Y5HY MDAG GNWe" -func getKey1Meta() *ssss.KeyMetadata { +func getKeyMeta(meta string) *ssss.KeyMetadata { var km ssss.KeyMetadata - err := json.Unmarshal([]byte(key1Meta), &km) + err := json.Unmarshal([]byte(meta), &km) if err != nil { panic(err) } @@ -54,7 +70,7 @@ func getKey1Meta() *ssss.KeyMetadata { } func getKey1() *ssss.Key { - km := getKey1Meta() + km := getKeyMeta(key1Meta) key, err := km.VerifyRecoveryKey(key1ID, key1RecoveryKey) if err != nil { panic(err) @@ -63,17 +79,8 @@ func getKey1() *ssss.Key { return key } -func getKey2Meta() *ssss.KeyMetadata { - var km ssss.KeyMetadata - err := json.Unmarshal([]byte(key2Meta), &km) - if err != nil { - panic(err) - } - return &km -} - func getKey2() *ssss.Key { - km := getKey2Meta() + km := getKeyMeta(key2Meta) key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) if err != nil { panic(err) @@ -83,7 +90,7 @@ func getKey2() *ssss.Key { } func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) { - km := getKey1Meta() + km := getKeyMeta(key1Meta) key, err := km.VerifyRecoveryKey(key1ID, key1RecoveryKey) assert.NoError(t, err) assert.NotNil(t, key) @@ -91,7 +98,7 @@ func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) { } func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) { - km := getKey2Meta() + km := getKeyMeta(key2Meta) key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) assert.NoError(t, err) assert.NotNil(t, key) @@ -99,21 +106,21 @@ func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) { } func TestKeyMetadata_VerifyRecoveryKey_Invalid(t *testing.T) { - km := getKey1Meta() + km := getKeyMeta(key1Meta) key, err := km.VerifyRecoveryKey(key1ID, "foo") assert.True(t, errors.Is(err, ssss.ErrInvalidRecoveryKey), "unexpected error: %v", err) assert.Nil(t, key) } func TestKeyMetadata_VerifyRecoveryKey_Incorrect(t *testing.T) { - km := getKey1Meta() + km := getKeyMeta(key1Meta) key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error: %v", err) assert.Nil(t, key) } func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) { - km := getKey1Meta() + km := getKeyMeta(key1Meta) key, err := km.VerifyPassphrase(key1ID, key1Passphrase) assert.NoError(t, err) assert.NotNil(t, key) @@ -121,15 +128,29 @@ func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) { } func TestKeyMetadata_VerifyPassphrase_Incorrect(t *testing.T) { - km := getKey1Meta() + km := getKeyMeta(key1Meta) key, err := km.VerifyPassphrase(key1ID, "incorrect horse battery staple") assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error %v", err) assert.Nil(t, key) } func TestKeyMetadata_VerifyPassphrase_NotSet(t *testing.T) { - km := getKey2Meta() + km := getKeyMeta(key2Meta) key, err := km.VerifyPassphrase(key2ID, "hmm") assert.True(t, errors.Is(err, ssss.ErrNoPassphrase), "unexpected error %v", err) assert.Nil(t, key) } + +func TestKeyMetadata_VerifyRecoveryKey_CorruptedIV(t *testing.T) { + km := getKeyMeta(key2MetaBrokenIV) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + assert.True(t, errors.Is(err, ssss.ErrCorruptedKeyMetadata), "unexpected error %v", err) + assert.Nil(t, key) +} + +func TestKeyMetadata_VerifyRecoveryKey_CorruptedMAC(t *testing.T) { + km := getKeyMeta(key2MetaBrokenMAC) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + assert.True(t, errors.Is(err, ssss.ErrCorruptedKeyMetadata), "unexpected error %v", err) + assert.Nil(t, key) +} diff --git a/crypto/ssss/types.go b/crypto/ssss/types.go index 60852c55..345393b0 100644 --- a/crypto/ssss/types.go +++ b/crypto/ssss/types.go @@ -26,6 +26,7 @@ var ( ErrUnsupportedPassphraseAlgorithm = errors.New("unsupported passphrase KDF algorithm") ErrIncorrectSSSSKey = errors.New("incorrect SSSS key") ErrInvalidRecoveryKey = errors.New("invalid recovery key") + ErrCorruptedKeyMetadata = errors.New("corrupted key metadata") ) // Algorithm is the identifier for an SSSS encryption algorithm. From 07f0d8836a8425cc2bdbb79f2db5b17832b3667c Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 5 Mar 2025 10:46:06 +0000 Subject: [PATCH 1046/1647] bridgev2/backfillqueue: expose `DoBackfillTask` method --- bridgev2/backfillqueue.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bridgev2/backfillqueue.go b/bridgev2/backfillqueue.go index 7d521fd1..61318d94 100644 --- a/bridgev2/backfillqueue.go +++ b/bridgev2/backfillqueue.go @@ -80,13 +80,13 @@ func (br *Bridge) RunBackfillQueue() { time.Sleep(BackfillQueueErrorBackoff) continue } else if backfillTask != nil { - br.doBackfillTask(ctx, backfillTask) + br.DoBackfillTask(ctx, backfillTask) noTasksFoundCount = 0 } } } -func (br *Bridge) doBackfillTask(ctx context.Context, task *database.BackfillTask) { +func (br *Bridge) DoBackfillTask(ctx context.Context, task *database.BackfillTask) { log := zerolog.Ctx(ctx).With(). Object("portal_key", task.PortalKey). Str("login_id", string(task.UserLoginID)). From ef5eb3c9cf8760df6d18f12e147177ae63721fe2 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 5 Mar 2025 10:46:50 +0000 Subject: [PATCH 1047/1647] bridgev2/database: add `BackfillTaskQuery.GetNextForPortal` method --- bridgev2/database/backfillqueue.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/bridgev2/database/backfillqueue.go b/bridgev2/database/backfillqueue.go index fed7452d..224ae626 100644 --- a/bridgev2/database/backfillqueue.go +++ b/bridgev2/database/backfillqueue.go @@ -86,6 +86,13 @@ const ( WHERE bridge_id = $1 AND next_dispatch_min_ts < $2 AND is_done = false AND user_login_id <> '' ORDER BY next_dispatch_min_ts LIMIT 1 ` + getNextBackfillQueryForPortal = ` + SELECT + bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, + cursor, oldest_message_id, dispatched_at, completed_at, next_dispatch_min_ts + FROM backfill_task + WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 AND is_done = false AND user_login_id <> '' + ` deleteBackfillQueueQuery = ` DELETE FROM backfill_task WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 @@ -124,6 +131,10 @@ func (btq *BackfillTaskQuery) GetNext(ctx context.Context) (*BackfillTask, error return btq.QueryOne(ctx, getNextBackfillQuery, btq.BridgeID, time.Now().UnixNano()) } +func (btq *BackfillTaskQuery) GetNextForPortal(ctx context.Context, portalKey networkid.PortalKey) (*BackfillTask, error) { + return btq.QueryOne(ctx, getNextBackfillQueryForPortal, btq.BridgeID, portalKey.ID, portalKey.Receiver) +} + func (btq *BackfillTaskQuery) Delete(ctx context.Context, portalKey networkid.PortalKey) error { return btq.Exec(ctx, deleteBackfillQueueQuery, btq.BridgeID, portalKey.ID, portalKey.Receiver) } From 8c4920a6c48b008ae2e4a105bb519bf8b44ce041 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 6 Mar 2025 00:58:25 +0200 Subject: [PATCH 1048/1647] client: add wrappers for sending MSC4140 delayed events --- client.go | 23 +++++++++++++---------- requests.go | 16 ++++++++++++++++ responses.go | 4 ++++ 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/client.go b/client.go index 2d67231c..4740abba 100644 --- a/client.go +++ b/client.go @@ -1137,15 +1137,6 @@ func (cli *Client) SetRoomAccountData(ctx context.Context, roomID id.RoomID, nam return nil } -type ReqSendEvent struct { - Timestamp int64 - TransactionID string - - DontEncrypt bool - - MeowEventID id.EventID -} - // SendMessageEvent sends a message event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidsendeventtypetxnid // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { @@ -1168,6 +1159,9 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event if req.MeowEventID != "" { queryParams["fi.mau.event_id"] = req.MeowEventID.String() } + if req.UnstableDelay > 0 { + queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10) + } if !req.DontEncrypt && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted { var isEncrypted bool @@ -1203,11 +1197,14 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy if req.MeowEventID != "" { queryParams["fi.mau.event_id"] = req.MeowEventID.String() } + if req.UnstableDelay > 0 { + queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10) + } urlData := ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey} urlPath := cli.BuildURLWithQuery(urlData, queryParams) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) - if err == nil && cli.StateStore != nil { + if err == nil && cli.StateStore != nil && req.UnstableDelay == 0 { cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) } return @@ -1226,6 +1223,12 @@ func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, return } +func (cli *Client) UpdateDelayedEvent(ctx context.Context, req *ReqUpdateDelayedEvent) (resp *RespUpdateDelayedEvent, err error) { + urlPath := cli.BuildClientURL("v1", "delayed_events", req.DelayID) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp) + return +} + // SendText sends an m.room.message event into the given room with a msgtype of m.text // See https://spec.matrix.org/v1.2/client-server-api/#mtext func (cli *Client) SendText(ctx context.Context, roomID id.RoomID, text string) (*RespSendEvent, error) { diff --git a/requests.go b/requests.go index 9788aec7..0ae90d63 100644 --- a/requests.go +++ b/requests.go @@ -3,6 +3,7 @@ package mautrix import ( "encoding/json" "strconv" + "time" "maunium.net/go/mautrix/crypto/signatures" "maunium.net/go/mautrix/event" @@ -331,6 +332,21 @@ type ReqSendToDevice struct { Messages map[id.UserID]map[id.DeviceID]*event.Content `json:"messages"` } +type ReqSendEvent struct { + Timestamp int64 + TransactionID string + UnstableDelay time.Duration + + DontEncrypt bool + + MeowEventID id.EventID +} + +type ReqUpdateDelayedEvent struct { + DelayID string `json:"-"` + Action string `json:"action"` // TODO use enum +} + // ReqDeviceInfo is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3devicesdeviceid type ReqDeviceInfo struct { DisplayName string `json:"display_name,omitempty"` diff --git a/responses.go b/responses.go index f4ab024a..04f79143 100644 --- a/responses.go +++ b/responses.go @@ -98,8 +98,12 @@ type RespContext struct { // RespSendEvent is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidsendeventtypetxnid type RespSendEvent struct { EventID id.EventID `json:"event_id"` + + UnstableDelayID string `json:"delay_id,omitempty"` } +type RespUpdateDelayedEvent struct{} + type RespRedactUserEvents struct { IsMoreEvents bool `json:"is_more_events"` RedactedEvents struct { From 01d1e9d69c9047999baf4b970846bdadc24b1cfb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 6 Mar 2025 01:27:34 +0200 Subject: [PATCH 1049/1647] client: fix msc4104 unstable prefix --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index 4740abba..a6d89338 100644 --- a/client.go +++ b/client.go @@ -1224,7 +1224,7 @@ func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, } func (cli *Client) UpdateDelayedEvent(ctx context.Context, req *ReqUpdateDelayedEvent) (resp *RespUpdateDelayedEvent, err error) { - urlPath := cli.BuildClientURL("v1", "delayed_events", req.DelayID) + urlPath := cli.BuildClientURL("unstable", "org.matrix.msc4140", "delayed_events", req.DelayID) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp) return } From e306c2817edd776f87b1084eccc3cf6f541f272e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 6 Mar 2025 01:29:06 +0200 Subject: [PATCH 1050/1647] client: fix update delayed event method --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index a6d89338..e6757d67 100644 --- a/client.go +++ b/client.go @@ -1225,7 +1225,7 @@ func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, func (cli *Client) UpdateDelayedEvent(ctx context.Context, req *ReqUpdateDelayedEvent) (resp *RespUpdateDelayedEvent, err error) { urlPath := cli.BuildClientURL("unstable", "org.matrix.msc4140", "delayed_events", req.DelayID) - _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) return } From c10d4eb80bba60bc7b7de03bf59411a84a2b5e7f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 8 Mar 2025 19:27:44 +0200 Subject: [PATCH 1051/1647] event: make m.federate a pointer --- event/state.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/event/state.go b/event/state.go index 4e926dcd..d5cacbad 100644 --- a/event/state.go +++ b/event/state.go @@ -92,10 +92,12 @@ const ( // https://spec.matrix.org/v1.2/client-server-api/#mroomcreate type CreateEventContent struct { Type RoomType `json:"type,omitempty"` - Creator id.UserID `json:"creator,omitempty"` - Federate bool `json:"m.federate,omitempty"` + Federate *bool `json:"m.federate,omitempty"` RoomVersion RoomVersion `json:"room_version,omitempty"` Predecessor *Predecessor `json:"predecessor,omitempty"` + + // Deprecated: use the event sender instead + Creator id.UserID `json:"creator,omitempty"` } // JoinRule specifies how open a room is to new members. From 0f4c560bd6110e89677ebeabde920d6c3c019b9f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 8 Mar 2025 19:27:57 +0200 Subject: [PATCH 1052/1647] bridgev2/userlogin: log error if client isn't loaded --- bridgev2/userlogin.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 142d67d4..b9c7288a 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -62,6 +62,9 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da if err != nil { userLogin.Log.Err(err).Msg("Failed to load user login") return nil, nil + } else if userLogin.Client == nil { + userLogin.Log.Error().Msg("LoadUserLogin didn't fill Client") + return nil, nil } userLogin.BridgeState = br.NewBridgeStateQueue(userLogin) user.logins[userLogin.ID] = userLogin From d83b63aeaf642d537e7d6755893323ee41ac7dff Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sun, 9 Mar 2025 17:16:31 +0000 Subject: [PATCH 1053/1647] client: add wrapper for /knock endpoint (#359) --- client.go | 22 ++++++++++++++++++++++ requests.go | 5 +++++ responses.go | 5 +++++ 3 files changed, 32 insertions(+) diff --git a/client.go b/client.go index e6757d67..7b5a7fe4 100644 --- a/client.go +++ b/client.go @@ -974,6 +974,28 @@ func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias string, req *ReqJ return } +// KnockRoom requests to join a room ID or alias. See https://spec.matrix.org/v1.13/client-server-api/#post_matrixclientv3knockroomidoralias +// +// The last parameter contains optional extra fields and can be left nil. +func (cli *Client) KnockRoom(ctx context.Context, roomIDorAlias string, req *ReqKnockRoom) (resp *RespKnockRoom, err error) { + if req == nil { + req = &ReqKnockRoom{} + } + urlPath := cli.BuildURLWithFullQuery(ClientURLPath{"v3", "knock", roomIDorAlias}, func(q url.Values) { + if len(req.Via) > 0 { + q["via"] = req.Via + } + }) + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) + if err == nil && cli.StateStore != nil { + err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipKnock) + if err != nil { + err = fmt.Errorf("failed to update state store: %w", err) + } + } + return +} + // JoinRoomByID joins the client to a room ID. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidjoin // // Unlike JoinRoom, this method can only be used to join rooms that the server already knows about. diff --git a/requests.go b/requests.go index 0ae90d63..377534ae 100644 --- a/requests.go +++ b/requests.go @@ -156,6 +156,11 @@ type ReqJoinRoom struct { ThirdPartySigned any `json:"third_party_signed,omitempty"` } +type ReqKnockRoom struct { + Via []string `json:"-"` + Reason string `json:"reason,omitempty"` +} + type ReqMutualRooms struct { From string `json:"-"` } diff --git a/responses.go b/responses.go index 04f79143..158d2444 100644 --- a/responses.go +++ b/responses.go @@ -33,6 +33,11 @@ type RespJoinRoom struct { RoomID id.RoomID `json:"room_id"` } +// RespKnockRoom is the JSON response for https://spec.matrix.org/v1.13/client-server-api/#post_matrixclientv3knockroomidoralias +type RespKnockRoom struct { + RoomID id.RoomID `json:"room_id"` +} + // RespLeaveRoom is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidleave type RespLeaveRoom struct{} From 7f04ae7a9f3855087bbaa557f0e28f9ba65336bf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 9 Mar 2025 14:25:11 +0200 Subject: [PATCH 1054/1647] client: remove unprefixed beeper streaming sync flag --- client.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/client.go b/client.go index 7b5a7fe4..5a54110c 100644 --- a/client.go +++ b/client.go @@ -717,8 +717,6 @@ func (req *ReqSync) BuildQuery() map[string]string { query["full_state"] = "true" } if req.BeeperStreaming { - // TODO remove this - query["streaming"] = "" query["com.beeper.streaming"] = "true" } return query From 52c8a2e1de39bd19d809cf5f3697a4090357b652 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 9 Mar 2025 14:29:54 +0200 Subject: [PATCH 1055/1647] sync: add support for MSC4222 --- client.go | 4 ++++ event/events.go | 3 +++ responses.go | 1 + sync.go | 32 +++++++++++++++++++------------- 4 files changed, 27 insertions(+), 13 deletions(-) diff --git a/client.go b/client.go index 5a54110c..dfbd190b 100644 --- a/client.go +++ b/client.go @@ -696,6 +696,7 @@ type ReqSync struct { FullState bool SetPresence event.Presence StreamResponse bool + UseStateAfter bool BeeperStreaming bool Client *http.Client } @@ -716,6 +717,9 @@ func (req *ReqSync) BuildQuery() map[string]string { if req.FullState { query["full_state"] = "true" } + if req.UseStateAfter { + query["org.matrix.msc4222.use_state_after"] = "true" + } if req.BeeperStreaming { query["com.beeper.streaming"] = "true" } diff --git a/event/events.go b/event/events.go index 1c173351..92cc39ae 100644 --- a/event/events.go +++ b/event/events.go @@ -118,6 +118,9 @@ type MautrixInfo struct { DecryptionDuration time.Duration CheckpointSent bool + // When using MSC4222 and the state_after field, this field is set + // for timeline events to indicate they shouldn't update room state. + IgnoreState bool } func (evt *Event) GetStateKey() string { diff --git a/responses.go b/responses.go index 158d2444..3123a530 100644 --- a/responses.go +++ b/responses.go @@ -393,6 +393,7 @@ type BeeperInboxPreviewEvent struct { type SyncJoinedRoom struct { Summary LazyLoadSummary `json:"summary"` State SyncEventsList `json:"state"` + StateAfter *SyncEventsList `json:"org.matrix.msc4222.state_after,omitempty"` Timeline SyncTimeline `json:"timeline"` Ephemeral SyncEventsList `json:"ephemeral"` AccountData SyncEventsList `json:"account_data"` diff --git a/sync.go b/sync.go index 48906bbc..9a2b9edf 100644 --- a/sync.go +++ b/sync.go @@ -97,33 +97,38 @@ func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, sinc } } - s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice) - s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence) - s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData) + s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice, false) + s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence, false) + s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData, false) for roomID, roomData := range res.Rooms.Join { - s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState) - s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline) - s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral) - s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData) + if roomData.StateAfter == nil { + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState, false) + s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline, false) + } else { + s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline, true) + s.processSyncEvents(ctx, roomID, roomData.StateAfter.Events, event.SourceJoin|event.SourceState, false) + } + s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral, false) + s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData, false) } for roomID, roomData := range res.Rooms.Invite { - s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState) + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState, false) } for roomID, roomData := range res.Rooms.Leave { - s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState) - s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline) + s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState, false) + s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline, false) } return } -func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source) { +func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source, ignoreState bool) { for _, evt := range events { - s.processSyncEvent(ctx, roomID, evt, source) + s.processSyncEvent(ctx, roomID, evt, source, ignoreState) } } -func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source) { +func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source, ignoreState bool) { evt.RoomID = roomID // Ensure the type class is correct. It's safe to mutate the class since the event type is not a pointer. @@ -149,6 +154,7 @@ func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, } evt.Mautrix.EventSource = source + evt.Mautrix.IgnoreState = ignoreState s.Dispatch(ctx, evt) } From d4975cbffdf32cf5a7ebc40a07422a52f63f94fd Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 10 Mar 2025 01:25:17 +0200 Subject: [PATCH 1056/1647] crypto/sql_store: fix MarkTrackedUsersOutdated for big lists --- crypto/sql_store.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/crypto/sql_store.go b/crypto/sql_store.go index b3c3c4d9..514c1e8c 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -13,6 +13,7 @@ import ( "encoding/json" "errors" "fmt" + "slices" "strings" "sync" "time" @@ -829,11 +830,13 @@ func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id. // MarkTrackedUsersOutdated flags that the device list for given users are outdated. func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users []id.UserID) (err error) { - if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil { - _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = ANY($1)", PostgresArrayWrapper(users)) - } else { - placeholders, params := userIDsToParams(users) - _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id IN ("+placeholders+")", params...) + for chunk := range slices.Chunk(users, 1000) { + if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil { + _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = ANY($1)", PostgresArrayWrapper(chunk)) + } else { + placeholders, params := userIDsToParams(chunk) + _, err = store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id IN ("+placeholders+")", params...) + } } return } From 7492e6e3080724e7104182ff4e08f63a28d92c39 Mon Sep 17 00:00:00 2001 From: Brad Murray Date: Mon, 10 Mar 2025 14:51:26 -0400 Subject: [PATCH 1057/1647] crypto/verificationhelper: add supports SAS parameter (#361) --- .../verificationhelper/verificationhelper.go | 12 ++-- .../verificationhelper_test.go | 64 ++++++++++--------- 2 files changed, 42 insertions(+), 34 deletions(-) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 1fdbcc70..c47eea71 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -82,7 +82,7 @@ type VerificationHelper struct { var _ mautrix.VerificationHelper = (*VerificationHelper)(nil) -func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsQRShow, supportsQRScan bool) *VerificationHelper { +func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsQRShow, supportsQRScan, supportsSAS bool) *VerificationHelper { if client.Crypto == nil { panic("client.Crypto is nil") } @@ -106,9 +106,13 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, stor helper.verificationDone = c.VerificationDone } - if c, ok := callbacks.(ShowSASCallbacks); ok { - helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS) - helper.showSAS = c.ShowSAS + if supportsSAS { + if c, ok := callbacks.(ShowSASCallbacks); !ok { + panic("callbacks must implement showSAS if supportsSAS is true") + } else { + helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS) + helper.showSAS = c.ShowSAS + } } if supportsQRShow { if c, ok := callbacks.(ShowQRCodeCallbacks); !ok { diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index e192508b..b4c21c18 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -71,7 +71,7 @@ func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, rece senderVerificationStore, err := NewSQLiteVerificationStore(ctx, senderVerificationDB) require.NoError(t, err) - sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, senderVerificationStore, sendingCallbacks, true, true) + sendingHelper = verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, senderVerificationStore, sendingCallbacks, true, true, true) require.NoError(t, sendingHelper.Init(ctx)) receivingCallbacks = newAllVerificationCallbacks() @@ -79,7 +79,7 @@ func initDefaultCallbacks(t *testing.T, ctx context.Context, sendingClient, rece require.NoError(t, err) receiverVerificationStore, err := NewSQLiteVerificationStore(ctx, receiverVerificationDB) require.NoError(t, err) - receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receiverVerificationStore, receivingCallbacks, true, true) + receivingHelper = verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, receiverVerificationStore, receivingCallbacks, true, true, true) require.NoError(t, receivingHelper.Init(ctx)) return } @@ -91,25 +91,27 @@ func TestVerification_Start(t *testing.T) { testCases := []struct { supportsShow bool supportsScan bool + supportsSAS bool callbacks MockVerificationCallbacks startVerificationErrMsg string expectedVerificationMethods []event.VerificationMethod }{ - {false, false, newBaseVerificationCallbacks(), "no supported verification methods", nil}, - {false, true, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {false, false, false, newBaseVerificationCallbacks(), "no supported verification methods", nil}, + {false, true, false, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {false, false, newShowQRCodeVerificationCallbacks(), "no supported verification methods", nil}, - {true, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {false, true, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {true, true, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {false, false, false, newShowQRCodeVerificationCallbacks(), "no supported verification methods", nil}, + {true, false, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {false, true, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {true, true, false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {false, false, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, - {false, true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {false, false, true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, + {false, true, true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {false, false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, - {false, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, - {true, false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, - {true, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {false, false, false, newAllVerificationCallbacks(), "no supported verification methods", nil}, + {false, false, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}}, + {false, true, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, + {true, false, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}}, + {true, true, true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}}, } for i, tc := range testCases { @@ -122,7 +124,7 @@ func TestVerification_Start(t *testing.T) { addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID) addDeviceID(ctx, cryptoStore, aliceUserID, receivingDeviceID2) - senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, tc.callbacks, tc.supportsShow, tc.supportsScan) + senderHelper := verificationhelper.NewVerificationHelper(client, client.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, tc.callbacks, tc.supportsShow, tc.supportsScan, tc.supportsSAS) err := senderHelper.Init(ctx) require.NoError(t, err) @@ -169,7 +171,7 @@ func TestVerification_StartThenCancel(t *testing.T) { bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID) bystanderMachine := bystanderClient.Crypto.(*cryptohelper.CryptoHelper).Machine() - bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, nil, newAllVerificationCallbacks(), true, true) + bystanderHelper := verificationhelper.NewVerificationHelper(bystanderClient, bystanderMachine, nil, newAllVerificationCallbacks(), true, true, true) require.NoError(t, bystanderHelper.Init(ctx)) require.NoError(t, sendingCryptoStore.PutDevice(ctx, aliceUserID, bystanderMachine.OwnIdentity())) @@ -259,12 +261,12 @@ func TestVerification_Accept_NoSupportedMethods(t *testing.T) { assert.NotEmpty(t, recoveryKey) assert.NotNil(t, cache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, newAllVerificationCallbacks(), true, true) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, newAllVerificationCallbacks(), true, true, true) err = sendingHelper.Init(ctx) require.NoError(t, err) receivingCallbacks := newBaseVerificationCallbacks() - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, receivingCallbacks, false, false) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingClient.Crypto.(*cryptohelper.CryptoHelper).Machine(), nil, receivingCallbacks, false, false, false) err = receivingHelper.Init(ctx) require.NoError(t, err) @@ -288,23 +290,25 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { sendingSupportsShow bool receivingSupportsScan bool receivingSupportsShow bool + sendingSupportsSAS bool + receivingSupportsSAS bool sendingCallbacks MockVerificationCallbacks receivingCallbacks MockVerificationCallbacks expectedVerificationMethods []event.VerificationMethod }{ // TODO - {false, false, false, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, - {true, false, true, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, + {false, false, false, false, true, true, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, + {true, false, true, false, true, true, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}}, - {true, false, false, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, - {false, true, true, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, - {true, false, true, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, - {false, true, true, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, - {true, true, true, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, - {true, true, false, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, - {true, true, true, true, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}}, + {true, false, false, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, + {false, true, true, false, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, + {true, false, true, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, + {false, true, true, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, + {true, true, true, false, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan}}, + {true, true, false, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeShow}}, + {true, true, true, true, false, false, newShowQRCodeVerificationCallbacks(), newShowQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}}, - {true, true, true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}}, + {true, true, true, true, true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS, event.VerificationMethodReciprocate, event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow}}, } for i, tc := range testCases { @@ -317,11 +321,11 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { assert.NotEmpty(t, recoveryKey) assert.NotNil(t, sendingCrossSigningKeysCache) - sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, tc.sendingCallbacks, tc.sendingSupportsShow, tc.sendingSupportsScan) + sendingHelper := verificationhelper.NewVerificationHelper(sendingClient, sendingMachine, nil, tc.sendingCallbacks, tc.sendingSupportsShow, tc.sendingSupportsScan, tc.sendingSupportsSAS) err = sendingHelper.Init(ctx) require.NoError(t, err) - receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, tc.receivingCallbacks, tc.receivingSupportsShow, tc.receivingSupportsScan) + receivingHelper := verificationhelper.NewVerificationHelper(receivingClient, receivingMachine, nil, tc.receivingCallbacks, tc.receivingSupportsShow, tc.receivingSupportsScan, tc.receivingSupportsSAS) err = receivingHelper.Init(ctx) require.NoError(t, err) From a01edae1c3d6fe9491d45ebfe5ac899f355650e5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 10 Mar 2025 22:52:06 +0200 Subject: [PATCH 1058/1647] bridgev2/portal: don't bridge remote edits by different users --- bridgev2/portal.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index dcd3cd37..409a9c10 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2095,6 +2095,12 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventEdit) if intent == nil { return + } else if intent.GetMXID() != existing[0].SenderMXID { + log.Warn(). + Stringer("edit_sender_mxid", intent.GetMXID()). + Stringer("original_sender_mxid", existing[0].SenderMXID). + Msg("Not bridging edit: sender doesn't match original message sender") + return } ts := getEventTS(evt) converted, err := evt.ConvertEdit(ctx, portal, intent, existing) From a6d948f7c2bb2cbe7f659f0fda90cedb17d5fcf7 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Thu, 6 Mar 2025 16:41:41 +0000 Subject: [PATCH 1059/1647] bridgev2/config: add `BackfillConfig.WillPaginateManually` --- bridgev2/bridgeconfig/backfill.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bridgev2/bridgeconfig/backfill.go b/bridgev2/bridgeconfig/backfill.go index 44d2d588..53282e41 100644 --- a/bridgev2/bridgeconfig/backfill.go +++ b/bridgev2/bridgeconfig/backfill.go @@ -14,6 +14,11 @@ type BackfillConfig struct { Threads BackfillThreadsConfig `yaml:"threads"` Queue BackfillQueueConfig `yaml:"queue"` + + // Flag to indicate that the creator will not run the backfill queue but will still paginate + // backfill by calling DoBackfillTask directly. Note that this is not used anywhere within + // mautrix-go and exists so bridges can use it to decide when to drop backfill data. + WillPaginateManually bool `yaml:"will_paginate_manually"` } type BackfillThreadsConfig struct { From eed1ffe1076d089ed8d26dcb786f67dcc39d47ef Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Mon, 10 Mar 2025 22:09:46 -0600 Subject: [PATCH 1060/1647] verificationhelper/sas: fix error when both sides send StartSAS Signed-off-by: Sumner Evans --- crypto/verificationhelper/sas.go | 37 ++++----- .../verificationhelper/verificationhelper.go | 8 +- .../verificationhelper_sas_test.go | 77 +++++++++++++++++++ 3 files changed, 101 insertions(+), 21 deletions(-) diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 70b77759..2906e3a2 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -224,28 +224,29 @@ func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn Ve } txn.MACMethod = macMethod txn.EphemeralKey = &ECDHPrivateKey{ephemeralKey} - txn.StartEventContent = startEvt - commitment, err := calculateCommitment(ephemeralKey.PublicKey(), startEvt) - if err != nil { - return fmt.Errorf("failed to calculate commitment: %w", err) - } + if !txn.StartedByUs { + commitment, err := calculateCommitment(ephemeralKey.PublicKey(), txn) + if err != nil { + return fmt.Errorf("failed to calculate commitment: %w", err) + } - err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationAccept, &event.VerificationAcceptEventContent{ - Commitment: commitment, - Hash: hashAlgorithm, - KeyAgreementProtocol: keyAggreementProtocol, - MessageAuthenticationCode: macMethod, - ShortAuthenticationString: sasMethods, - }) - if err != nil { - return fmt.Errorf("failed to send accept event: %w", err) + err = vh.sendVerificationEvent(ctx, txn, event.InRoomVerificationAccept, &event.VerificationAcceptEventContent{ + Commitment: commitment, + Hash: hashAlgorithm, + KeyAgreementProtocol: keyAggreementProtocol, + MessageAuthenticationCode: macMethod, + ShortAuthenticationString: sasMethods, + }) + if err != nil { + return fmt.Errorf("failed to send accept event: %w", err) + } + txn.VerificationState = VerificationStateSASAccepted } - txn.VerificationState = VerificationStateSASAccepted return vh.store.SaveVerificationTransaction(ctx, txn) } -func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, startEvt *event.VerificationStartEventContent) ([]byte, error) { +func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, txn VerificationTransaction) ([]byte, error) { // The commitmentHashInput is the hash (encoded as unpadded base64) of the // concatenation of the device's ephemeral public key (encoded as // unpadded base64) and the canonical JSON representation of the @@ -255,7 +256,7 @@ func calculateCommitment(ephemeralPubKey *ecdh.PublicKey, startEvt *event.Verifi // hashing it, but we are just stuck on that. commitmentHashInput := sha256.New() commitmentHashInput.Write([]byte(base64.RawStdEncoding.EncodeToString(ephemeralPubKey.Bytes()))) - encodedStartEvt, err := json.Marshal(startEvt) + encodedStartEvt, err := json.Marshal(txn.StartEventContent) if err != nil { return nil, err } @@ -339,7 +340,7 @@ func (vh *VerificationHelper) onVerificationKey(ctx context.Context, txn Verific if txn.EphemeralPublicKeyShared { // Verify that the commitment hash is correct - commitment, err := calculateCommitment(publicKey, txn.StartEventContent) + commitment, err := calculateCommitment(publicKey, txn) if err != nil { log.Err(err).Msg("Failed to calculate commitment") return diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index c47eea71..550df942 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -818,6 +818,8 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif } else if txn.VerificationState != VerificationStateReady { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnexpectedMessage, "got start event for transaction that is not in ready state") return + } else { + txn.StartEventContent = startEvt } switch startEvt.Method { @@ -829,7 +831,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif } case event.VerificationMethodReciprocate: log.Info().Msg("Received reciprocate start event") - if !bytes.Equal(txn.QRCodeSharedSecret, startEvt.Secret) { + if !bytes.Equal(txn.QRCodeSharedSecret, txn.StartEventContent.Secret) { vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeKeyMismatch, "reciprocated shared secret does not match") return } @@ -842,8 +844,8 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif // Note that we should never get m.qr_code.show.v1 or m.qr_code.scan.v1 // here, since the start command for scanning and showing QR codes // should be of type m.reciprocate.v1. - log.Error().Str("method", string(startEvt.Method)).Msg("Unsupported verification method in start event") - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownMethod, fmt.Sprintf("unknown method %s", startEvt.Method)) + log.Error().Str("method", string(txn.StartEventContent.Method)).Msg("Unsupported verification method in start event") + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownMethod, fmt.Sprintf("unknown method %s", txn.StartEventContent.Method)) } } diff --git a/crypto/verificationhelper/verificationhelper_sas_test.go b/crypto/verificationhelper/verificationhelper_sas_test.go index 22b1563c..5747ac34 100644 --- a/crypto/verificationhelper/verificationhelper_sas_test.go +++ b/crypto/verificationhelper/verificationhelper_sas_test.go @@ -283,3 +283,80 @@ func TestVerification_SAS(t *testing.T) { }) } } + +func TestVerification_SAS_BothCallStart(t *testing.T) { + ctx := log.Logger.WithContext(context.TODO()) + + ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) + defer ts.Close() + sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) + var err error + + var sendingRecoveryKey string + var sendingCrossSigningKeysCache *crypto.CrossSigningKeysCache + + sendingRecoveryKey, sendingCrossSigningKeysCache, err = sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + require.NoError(t, err) + assert.NotEmpty(t, sendingRecoveryKey) + assert.NotNil(t, sendingCrossSigningKeysCache) + + // Send the verification request from the sender device and accept + // it on the receiving device and receive the verification ready + // event on the sending device. + txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, receivingClient) + err = receivingHelper.AcceptVerification(ctx, txnID) + require.NoError(t, err) + ts.dispatchToDevice(t, ctx, sendingClient) + + err = sendingHelper.StartSAS(ctx, txnID) + require.NoError(t, err) + + err = receivingHelper.StartSAS(ctx, txnID) + require.NoError(t, err) + + // Ensure that both devices have received the verification start event. + receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 1) + assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationStart().TransactionID) + sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] + assert.Len(t, sendingInbox, 1) + assert.Equal(t, txnID, sendingInbox[0].Content.AsVerificationStart().TransactionID) + + // Process the start event from the receiving client to the sending client. + ts.dispatchToDevice(t, ctx, sendingClient) + receivingInbox = ts.DeviceInbox[aliceUserID][receivingDeviceID] + assert.Len(t, receivingInbox, 2) + assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationStart().TransactionID) + assert.Equal(t, txnID, receivingInbox[1].Content.AsVerificationAccept().TransactionID) + + // Process the rest of the events until we need to confirm the SAS. + for len(ts.DeviceInbox[aliceUserID][sendingDeviceID]) > 0 || len(ts.DeviceInbox[aliceUserID][receivingDeviceID]) > 0 { + ts.dispatchToDevice(t, ctx, receivingClient) + ts.dispatchToDevice(t, ctx, sendingClient) + } + + // Confirm the SAS only the receiving device. + receivingHelper.ConfirmSAS(ctx, txnID) + ts.dispatchToDevice(t, ctx, sendingClient) + + // Verification is not done until both devices confirm the SAS. + assert.False(t, sendingCallbacks.IsVerificationDone(txnID)) + assert.False(t, receivingCallbacks.IsVerificationDone(txnID)) + + // Now, confirm it on the sending device. + sendingHelper.ConfirmSAS(ctx, txnID) + + // Dispatching the events to the receiving device should get us to the done + // state on the receiving device. + ts.dispatchToDevice(t, ctx, receivingClient) + assert.False(t, sendingCallbacks.IsVerificationDone(txnID)) + assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) + + // Dispatching the events to the sending client should get us to the done + // state on the sending device. + ts.dispatchToDevice(t, ctx, sendingClient) + assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) + assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) +} From 6bba74ecb6a655268a5c0a2209c70482377e8c82 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Mar 2025 12:27:57 +0200 Subject: [PATCH 1061/1647] client: add refresh token to login response Closes #294 --- responses.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/responses.go b/responses.go index 3123a530..93a780ef 100644 --- a/responses.go +++ b/responses.go @@ -278,6 +278,9 @@ type RespLogin struct { DeviceID id.DeviceID `json:"device_id"` UserID id.UserID `json:"user_id"` WellKnown *ClientWellKnown `json:"well_known,omitempty"` + + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresInMS int64 `json:"expires_in_ms,omitempty"` } // RespLogout is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3logout From f33b0506d0bd91325837224cb0ff282670ff8c2d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Mar 2025 19:48:05 +0200 Subject: [PATCH 1062/1647] client: add some nil safety --- client.go | 29 +++++++++++++++++++++++------ error.go | 4 ++++ url.go | 3 +++ 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index dfbd190b..a08fbed2 100644 --- a/client.go +++ b/client.go @@ -319,12 +319,15 @@ const ( ) func (cli *Client) RequestStart(req *http.Request) { - if cli.RequestHook != nil { + if cli != nil && cli.RequestHook != nil { cli.RequestHook(req) } } func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err error, handlerErr error, contentLength int, duration time.Duration) { + if cli == nil { + return + } var evt *zerolog.Event if errors.Is(err, context.Canceled) { evt = zerolog.Ctx(req.Context()).Warn() @@ -466,6 +469,9 @@ func (cli *Client) MakeFullRequest(ctx context.Context, params FullRequest) ([]b } func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullRequest) ([]byte, *http.Response, error) { + if cli == nil { + return nil, nil, ErrClientIsNil + } if params.MaxAttempts == 0 { params.MaxAttempts = 1 + cli.DefaultHTTPRetries } @@ -665,7 +671,6 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof // Whoami gets the user ID of the current user. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3accountwhoami func (cli *Client) Whoami(ctx context.Context) (resp *RespWhoami, err error) { - urlPath := cli.BuildClientURL("v3", "account", "whoami") _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return @@ -1187,7 +1192,7 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10) } - if !req.DontEncrypt && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted { + if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted { var isEncrypted bool isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID) if err != nil { @@ -1468,7 +1473,7 @@ func (cli *Client) SetPresence(ctx context.Context, presence ReqPresence) (err e } func (cli *Client) updateStoreWithOutgoingEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) { - if cli.StateStore == nil { + if cli == nil || cli.StateStore == nil { return } fakeEvt := &event.Event{ @@ -1524,8 +1529,10 @@ func (cli *Client) FullStateEvent(ctx context.Context, roomID id.RoomID, eventTy if err == nil && cli.StateStore != nil { UpdateStateStore(ctx, cli.StateStore, evt) } - evt.Type.Class = event.StateEventType - _ = evt.Content.ParseRaw(evt.Type) + if evt != nil { + evt.Type.Class = event.StateEventType + _ = evt.Content.ParseRaw(evt.Type) + } return } @@ -1621,6 +1628,10 @@ func (cli *Client) RequestOpenIDToken(ctx context.Context) (resp *RespOpenIDToke // UploadLink uploads an HTTP URL and then returns an MXC URI. func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUpload, error) { + if cli == nil { + return nil, ErrClientIsNil + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, link, nil) if err != nil { return nil, err @@ -1825,6 +1836,9 @@ func (cli *Client) UploadMedia(ctx context.Context, data ReqUploadMedia) (*RespM if data.DoneCallback != nil { defer data.DoneCallback() } + if cli == nil { + return nil, ErrClientIsNil + } if data.UnstableUploadURL != "" { if data.MXC.IsEmpty() { return nil, errors.New("MXC must also be set when uploading to external URL") @@ -2509,6 +2523,9 @@ func (cli *Client) BeeperDeleteRoom(ctx context.Context, roomID id.RoomID) (err // TxnID returns the next transaction ID. func (cli *Client) TxnID() string { + if cli == nil { + return "client is nil" + } txnID := atomic.AddInt32(&cli.txnID, 1) return fmt.Sprintf("mautrix-go_%d_%d", time.Now().UnixNano(), txnID) } diff --git a/error.go b/error.go index 38039464..6f5dbe72 100644 --- a/error.go +++ b/error.go @@ -76,6 +76,10 @@ var ( MUnredactedContentNotReceived = RespError{ErrCode: "FI.MAU.MSC2815_UNREDACTED_CONTENT_NOT_RECEIVED"} ) +var ( + ErrClientIsNil = errors.New("client is nil") +) + // HTTPError An HTTP Error response, which may wrap an underlying native Go Error. type HTTPError struct { Request *http.Request diff --git a/url.go b/url.go index 0b4eec67..d888956a 100644 --- a/url.go +++ b/url.go @@ -109,6 +109,9 @@ func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[str // BuildURLWithQuery builds a URL with query parameters in addition to the Client's homeserver // and appservice user ID set already. func (cli *Client) BuildURLWithFullQuery(urlPath PrefixableURLPath, fn func(q url.Values)) string { + if cli == nil { + return "client is nil" + } hsURL := *BuildURL(cli.HomeserverURL, urlPath.FullPath()...) query := hsURL.Query() if cli.SetAppServiceUserID { From 1b77ce1d3d336db4ebcce29817c9553d8a0b18a3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 13 Mar 2025 01:08:29 +0200 Subject: [PATCH 1063/1647] event: add fields for MSC4204 and MSC4205 --- event/state.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/event/state.go b/event/state.go index d5cacbad..0a82fbe5 100644 --- a/event/state.go +++ b/event/state.go @@ -205,17 +205,23 @@ type SpaceParentEventContent struct { type PolicyRecommendation string const ( - PolicyRecommendationBan PolicyRecommendation = "m.ban" - PolicyRecommendationUnstableBan PolicyRecommendation = "org.matrix.mjolnir.ban" - PolicyRecommendationUnban PolicyRecommendation = "fi.mau.meowlnir.unban" + PolicyRecommendationBan PolicyRecommendation = "m.ban" + PolicyRecommendationUnstableTakedown PolicyRecommendation = "org.matrix.msc4204.takedown" + PolicyRecommendationUnstableBan PolicyRecommendation = "org.matrix.mjolnir.ban" + PolicyRecommendationUnban PolicyRecommendation = "fi.mau.meowlnir.unban" ) +type PolicyHashes struct { + SHA256 string `json:"sha256"` +} + // ModPolicyContent represents the content of a m.room.rule.user, m.room.rule.room, and m.room.rule.server state event. // https://spec.matrix.org/v1.2/client-server-api/#moderation-policy-lists type ModPolicyContent struct { Entity string `json:"entity"` Reason string `json:"reason"` Recommendation PolicyRecommendation `json:"recommendation"` + UnstableHashes *PolicyHashes `json:"org.matrix.msc4205.hashes,omitempty"` } // Deprecated: MSC2716 has been abandoned From de5bee328b2c46f2826feaae6087335bf24aeb8b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 13 Mar 2025 02:08:07 +0200 Subject: [PATCH 1064/1647] event: add utility functions for hashed policy list entities --- event/state.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/event/state.go b/event/state.go index 0a82fbe5..006ed2a5 100644 --- a/event/state.go +++ b/event/state.go @@ -7,6 +7,8 @@ package event import ( + "encoding/base64" + "maunium.net/go/mautrix/id" ) @@ -215,6 +217,17 @@ type PolicyHashes struct { SHA256 string `json:"sha256"` } +func (ph *PolicyHashes) DecodeSHA256() *[32]byte { + if ph == nil || ph.SHA256 == "" { + return nil + } + decoded, _ := base64.StdEncoding.DecodeString(ph.SHA256) + if len(decoded) == 32 { + return (*[32]byte)(decoded) + } + return nil +} + // ModPolicyContent represents the content of a m.room.rule.user, m.room.rule.room, and m.room.rule.server state event. // https://spec.matrix.org/v1.2/client-server-api/#moderation-policy-lists type ModPolicyContent struct { @@ -224,6 +237,13 @@ type ModPolicyContent struct { UnstableHashes *PolicyHashes `json:"org.matrix.msc4205.hashes,omitempty"` } +func (mpc *ModPolicyContent) EntityOrHash() string { + if mpc.UnstableHashes != nil && mpc.UnstableHashes.SHA256 != "" { + return mpc.UnstableHashes.SHA256 + } + return mpc.Entity +} + // Deprecated: MSC2716 has been abandoned type InsertionMarkerContent struct { InsertionID id.EventID `json:"org.matrix.msc2716.marker.insertion"` From df7e02616d854d69ad4941711f34b3f37c91c43a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 14 Mar 2025 00:25:42 +0200 Subject: [PATCH 1065/1647] dependencies: update --- go.mod | 18 +++++++++--------- go.sum | 32 ++++++++++++++++---------------- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/go.mod b/go.mod index a622b05f..cc11719a 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module maunium.net/go/mautrix go 1.23.0 -toolchain go1.24.0 +toolchain go1.24.1 require ( filippo.io/edwards25519 v1.1.0 @@ -18,12 +18,12 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.6-0.20250227184636-7ff63b0b9d95 + go.mau.fi/util v0.8.6-0.20250313222444-739a30158a62 go.mau.fi/zeroconfig v0.1.3 - golang.org/x/crypto v0.33.0 - golang.org/x/exp v0.0.0-20250215185904-eff6e970281f - golang.org/x/net v0.35.0 - golang.org/x/sync v0.11.0 + golang.org/x/crypto v0.36.0 + golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 + golang.org/x/net v0.37.0 + golang.org/x/sync v0.12.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) @@ -33,11 +33,11 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/petermattis/goid v0.0.0-20250211185408-f2b9d978cd7a // indirect + github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - golang.org/x/sys v0.30.0 // indirect - golang.org/x/text v0.22.0 // indirect + golang.org/x/sys v0.31.0 // indirect + golang.org/x/text v0.23.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index c92de482..2fcb8086 100644 --- a/go.sum +++ b/go.sum @@ -28,8 +28,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/petermattis/goid v0.0.0-20250211185408-f2b9d978cd7a h1:ckxP/kGzsxvxXo8jO6E/0QJ8MMmwI7IRj4Fys9QbAZA= -github.com/petermattis/goid v0.0.0-20250211185408-f2b9d978cd7a/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 h1:E7Kmf11E4K7B5hDti2K2NqPb1nlYlGYsu02S1JNd/Bs= +github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -54,26 +54,26 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.6-0.20250227184636-7ff63b0b9d95 h1:5EfVWWjU2Hte9uE6B/hBgvjnVfBx/7SYDZBnsuo+EBs= -go.mau.fi/util v0.8.6-0.20250227184636-7ff63b0b9d95/go.mod h1:Ycug9mrbztlahHPEJ6H5r8Nu/xqZaWbE5vPHVWmfz6M= +go.mau.fi/util v0.8.6-0.20250313222444-739a30158a62 h1:8EjBMxX7QkT94/815jKIVK5k41ku+ES3SxSk8DyQRk4= +go.mau.fi/util v0.8.6-0.20250313222444-739a30158a62/go.mod h1:uNB3UTXFbkpp7xL1M/WvQks90B/L4gvbLpbS0603KOE= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= -golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= -golang.org/x/exp v0.0.0-20250215185904-eff6e970281f h1:oFMYAjX0867ZD2jcNiLBrI9BdpmEkvPyi5YrBGXbamg= -golang.org/x/exp v0.0.0-20250215185904-eff6e970281f/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk= -golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= -golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= -golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= -golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= +golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= +golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 h1:nDVHiLt8aIbd/VzvPWN6kSOPE7+F/fNFDSXLVYkE/Iw= +golang.org/x/exp v0.0.0-20250305212735-054e65f0b394/go.mod h1:sIifuuw/Yco/y6yb6+bDNfyeQ/MdPUy/hKEMYQV17cM= +golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= +golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= -golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= -golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= +golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= +golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From e1938c5159bf86a175ace8b4c7661e7eee105f35 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 15 Mar 2025 22:28:16 +0200 Subject: [PATCH 1066/1647] bridge: remove package --- bridge/bridge.go | 936 ------------------ bridge/bridgeconfig/config.go | 337 ------- bridge/bridgeconfig/permissions.go | 71 -- bridge/bridgestate.go | 156 --- bridge/commands/admin.go | 77 -- bridge/commands/doublepuppet.go | 83 -- bridge/commands/event.go | 95 -- bridge/commands/handler.go | 100 -- bridge/commands/help.go | 129 --- bridge/commands/meta.go | 56 -- bridge/commands/processor.go | 122 --- bridge/crypto.go | 507 ---------- bridge/cryptostore.go | 63 -- bridge/doublepuppet.go | 173 ---- bridge/matrix.go | 755 -------------- bridge/messagecheckpoint.go | 61 -- bridge/no-crypto.go | 26 - bridge/status/deprecated.go | 83 ++ bridge/websocket.go | 163 --- bridgev2/bridge.go | 2 +- bridgev2/bridgestate.go | 2 +- bridgev2/commands/login.go | 2 +- bridgev2/commands/processor.go | 3 +- bridgev2/database/userlogin.go | 2 +- bridgev2/matrix/connector.go | 2 +- bridgev2/matrix/cryptoerror.go | 2 +- bridgev2/matrix/matrix.go | 2 +- bridgev2/matrix/provisioning.go | 2 +- bridgev2/matrixinterface.go | 2 +- bridgev2/messagestatus.go | 2 +- {bridge => bridgev2}/status/bridgestate.go | 0 .../status/localbridgestate.go | 0 .../status/messagecheckpoint.go | 0 bridgev2/userlogin.go | 2 +- 34 files changed, 95 insertions(+), 3923 deletions(-) delete mode 100644 bridge/bridge.go delete mode 100644 bridge/bridgeconfig/config.go delete mode 100644 bridge/bridgeconfig/permissions.go delete mode 100644 bridge/bridgestate.go delete mode 100644 bridge/commands/admin.go delete mode 100644 bridge/commands/doublepuppet.go delete mode 100644 bridge/commands/event.go delete mode 100644 bridge/commands/handler.go delete mode 100644 bridge/commands/help.go delete mode 100644 bridge/commands/meta.go delete mode 100644 bridge/commands/processor.go delete mode 100644 bridge/crypto.go delete mode 100644 bridge/cryptostore.go delete mode 100644 bridge/doublepuppet.go delete mode 100644 bridge/matrix.go delete mode 100644 bridge/messagecheckpoint.go delete mode 100644 bridge/no-crypto.go create mode 100644 bridge/status/deprecated.go delete mode 100644 bridge/websocket.go rename {bridge => bridgev2}/status/bridgestate.go (100%) rename {bridge => bridgev2}/status/localbridgestate.go (100%) rename {bridge => bridgev2}/status/messagecheckpoint.go (100%) diff --git a/bridge/bridge.go b/bridge/bridge.go deleted file mode 100644 index 17a4a30c..00000000 --- a/bridge/bridge.go +++ /dev/null @@ -1,936 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridge - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net/url" - "os" - "os/signal" - "runtime" - "strings" - "sync" - "syscall" - "time" - - "github.com/lib/pq" - "github.com/mattn/go-sqlite3" - "github.com/rs/zerolog" - "go.mau.fi/util/configupgrade" - "go.mau.fi/util/dbutil" - _ "go.mau.fi/util/dbutil/litestream" - "go.mau.fi/util/exzerolog" - "gopkg.in/yaml.v3" - flag "maunium.net/go/mauflag" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/bridge/status" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/sqlstatestore" -) - -var configPath = flag.MakeFull("c", "config", "The path to your config file.", "config.yaml").String() -var dontSaveConfig = flag.MakeFull("n", "no-update", "Don't save updated config to disk.", "false").Bool() -var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String() -var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool() -var version = flag.MakeFull("v", "version", "View bridge version and quit.", "false").Bool() -var versionJSON = flag.Make().LongKey("version-json").Usage("Print a JSON object representing the bridge version and quit.").Default("false").Bool() -var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if the database schema is too new").Default("false").Bool() -var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Run even if the database contains tables from other programs (like Synapse)").Default("false").Bool() -var ignoreUnsupportedServer = flag.Make().LongKey("ignore-unsupported-server").Usage("Run even if the Matrix homeserver is outdated").Default("false").Bool() -var wantHelp, _ = flag.MakeHelpFlag() - -var _ appservice.StateStore = (*sqlstatestore.SQLStateStore)(nil) - -type Portal interface { - IsEncrypted() bool - IsPrivateChat() bool - MarkEncrypted() - MainIntent() *appservice.IntentAPI - - ReceiveMatrixEvent(user User, evt *event.Event) - UpdateBridgeInfo(ctx context.Context) -} - -type MembershipHandlingPortal interface { - Portal - HandleMatrixLeave(sender User, evt *event.Event) - HandleMatrixKick(sender User, ghost Ghost, evt *event.Event) - HandleMatrixInvite(sender User, ghost Ghost, evt *event.Event) -} - -type ReadReceiptHandlingPortal interface { - Portal - HandleMatrixReadReceipt(sender User, eventID id.EventID, receipt event.ReadReceipt) -} - -type TypingPortal interface { - Portal - HandleMatrixTyping(userIDs []id.UserID) -} - -type MetaHandlingPortal interface { - Portal - HandleMatrixMeta(sender User, evt *event.Event) -} - -type DisappearingPortal interface { - Portal - ScheduleDisappearing() -} - -type PowerLevelHandlingPortal interface { - Portal - HandleMatrixPowerLevels(sender User, evt *event.Event) -} - -type JoinRuleHandlingPortal interface { - Portal - HandleMatrixJoinRule(sender User, evt *event.Event) -} - -type BanHandlingPortal interface { - Portal - HandleMatrixBan(sender User, ghost Ghost, evt *event.Event) - HandleMatrixUnban(sender User, ghost Ghost, evt *event.Event) -} - -type KnockHandlingPortal interface { - Portal - HandleMatrixKnock(sender User, evt *event.Event) - HandleMatrixRetractKnock(sender User, evt *event.Event) - HandleMatrixAcceptKnock(sender User, ghost Ghost, evt *event.Event) - HandleMatrixRejectKnock(sender User, ghost Ghost, evt *event.Event) -} - -type InviteHandlingPortal interface { - Portal - HandleMatrixAcceptInvite(sender User, evt *event.Event) - HandleMatrixRejectInvite(sender User, evt *event.Event) - HandleMatrixRetractInvite(sender User, ghost Ghost, evt *event.Event) -} - -type User interface { - GetPermissionLevel() bridgeconfig.PermissionLevel - IsLoggedIn() bool - GetManagementRoomID() id.RoomID - SetManagementRoom(id.RoomID) - GetMXID() id.UserID - GetIDoublePuppet() DoublePuppet - GetIGhost() Ghost -} - -type DoublePuppet interface { - CustomIntent() *appservice.IntentAPI - SwitchCustomMXID(accessToken string, userID id.UserID) error - ClearCustomMXID() -} - -type Ghost interface { - DoublePuppet - DefaultIntent() *appservice.IntentAPI - GetMXID() id.UserID -} - -type GhostWithProfile interface { - Ghost - GetDisplayname() string - GetAvatarURL() id.ContentURI -} - -type ChildOverride interface { - GetExampleConfig() string - GetConfigPtr() interface{} - - Init() - Start() - Stop() - - GetIPortal(id.RoomID) Portal - GetAllIPortals() []Portal - GetIUser(id id.UserID, create bool) User - IsGhost(id.UserID) bool - GetIGhost(id.UserID) Ghost - CreatePrivatePortal(id.RoomID, User, Ghost) -} - -type ConfigValidatingBridge interface { - ChildOverride - ValidateConfig() error -} - -type FlagHandlingBridge interface { - ChildOverride - HandleFlags() bool -} - -type PreInitableBridge interface { - ChildOverride - PreInit() -} - -type WebsocketStartingBridge interface { - ChildOverride - OnWebsocketConnect() -} - -type CSFeatureRequirer interface { - CheckFeatures(versions *mautrix.RespVersions) (string, bool) -} - -type Bridge struct { - Name string - URL string - Description string - Version string - ProtocolName string - BeeperServiceName string - BeeperNetworkName string - - AdditionalShortFlags string - AdditionalLongFlags string - - VersionDesc string - LinkifiedVersion string - BuildTime string - commit string - baseVersion string - - PublicHSAddress *url.URL - - DoublePuppet *doublePuppetUtil - - AS *appservice.AppService - EventProcessor *appservice.EventProcessor - CommandProcessor CommandProcessor - MatrixHandler *MatrixHandler - Bot *appservice.IntentAPI - Config bridgeconfig.BaseConfig - ConfigPath string - RegistrationPath string - SaveConfig bool - ConfigUpgrader configupgrade.BaseUpgrader - DB *dbutil.Database - StateStore *sqlstatestore.SQLStateStore - Crypto Crypto - CryptoPickleKey string - - ZLog *zerolog.Logger - - MediaConfig mautrix.RespMediaConfig - SpecVersions mautrix.RespVersions - - Child ChildOverride - - manualStop chan int - Stopping bool - - latestState *status.BridgeState - - Websocket bool - wsStopPinger chan struct{} - wsStarted chan struct{} - wsStopped chan struct{} - wsShortCircuitReconnectBackoff chan struct{} - wsStartupWait *sync.WaitGroup -} - -type Crypto interface { - HandleMemberEvent(context.Context, *event.Event) - Decrypt(context.Context, *event.Event) (*event.Event, error) - Encrypt(context.Context, id.RoomID, event.Type, *event.Content) error - WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool - RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) - ResetSession(context.Context, id.RoomID) - Init(ctx context.Context) error - Start() - Stop() - Reset(ctx context.Context, startAfterReset bool) - Client() *mautrix.Client - ShareKeys(context.Context) error -} - -func (br *Bridge) GenerateRegistration() { - if !br.SaveConfig { - // We need to save the generated as_token and hs_token in the config - _, _ = fmt.Fprintln(os.Stderr, "--no-update is not compatible with --generate-registration") - os.Exit(5) - } else if br.Config.Homeserver.Domain == "example.com" { - _, _ = fmt.Fprintln(os.Stderr, "Homeserver domain is not set") - os.Exit(20) - } - reg := br.Config.GenerateRegistration() - err := reg.Save(br.RegistrationPath) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to save registration:", err) - os.Exit(21) - } - - updateTokens := func(helper configupgrade.Helper) { - helper.Set(configupgrade.Str, reg.AppToken, "appservice", "as_token") - helper.Set(configupgrade.Str, reg.ServerToken, "appservice", "hs_token") - } - _, _, err = configupgrade.Do(br.ConfigPath, true, br.ConfigUpgrader, configupgrade.SimpleUpgrader(updateTokens)) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to save config:", err) - os.Exit(22) - } - fmt.Println("Registration generated. See https://docs.mau.fi/bridges/general/registering-appservices.html for instructions on installing the registration.") - os.Exit(0) -} - -func (br *Bridge) InitVersion(tag, commit, buildTime string) { - br.baseVersion = br.Version - if len(tag) > 0 && tag[0] == 'v' { - tag = tag[1:] - } - if tag != br.Version { - suffix := "" - if !strings.HasSuffix(br.Version, "+dev") { - suffix = "+dev" - } - if len(commit) > 8 { - br.Version = fmt.Sprintf("%s%s.%s", br.Version, suffix, commit[:8]) - } else { - br.Version = fmt.Sprintf("%s%s.unknown", br.Version, suffix) - } - } - - br.LinkifiedVersion = fmt.Sprintf("v%s", br.Version) - if tag == br.Version { - br.LinkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", br.Version, br.URL, tag) - } else if len(commit) > 8 { - br.LinkifiedVersion = strings.Replace(br.LinkifiedVersion, commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", commit[:8], br.URL, commit), 1) - } - mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.Version, mautrix.DefaultUserAgent) - br.VersionDesc = fmt.Sprintf("%s %s (%s with %s)", br.Name, br.Version, buildTime, runtime.Version()) - br.commit = commit - br.BuildTime = buildTime -} - -var MinSpecVersion = mautrix.SpecV14 - -func (br *Bridge) logInitialRequestError(err error, defaultMessage string) { - if errors.Is(err, mautrix.MUnknownToken) { - br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was not accepted. Is the registration file installed in your homeserver correctly?") - br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-token for more info") - } else if errors.Is(err, mautrix.MExclusive) { - br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The as_token was accepted, but the /register request was not. Are the homeserver domain, bot username and username template in the config correct, and do they match the values in the registration?") - br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-register for more info") - } else { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg(defaultMessage) - } -} - -func (br *Bridge) ensureConnection(ctx context.Context) { - for { - versions, err := br.Bot.Versions(ctx) - if err != nil { - if errors.Is(err, mautrix.MForbidden) { - br.ZLog.Debug().Msg("M_FORBIDDEN in /versions, trying to register before retrying") - err = br.Bot.EnsureRegistered(ctx) - if err != nil { - br.logInitialRequestError(err, "Failed to register after /versions failed with M_FORBIDDEN") - os.Exit(16) - } - } else { - br.ZLog.Err(err).Msg("Failed to connect to homeserver, retrying in 10 seconds...") - time.Sleep(10 * time.Second) - } - } else { - br.SpecVersions = *versions - *br.AS.SpecVersions = *versions - break - } - } - - unsupportedServerLogLevel := zerolog.FatalLevel - if *ignoreUnsupportedServer { - unsupportedServerLogLevel = zerolog.ErrorLevel - } - if br.Config.Homeserver.Software == bridgeconfig.SoftwareHungry && !br.SpecVersions.Supports(mautrix.BeeperFeatureHungry) { - br.ZLog.WithLevel(zerolog.FatalLevel).Msg("The config claims the homeserver is hungryserv, but the /versions response didn't confirm it") - os.Exit(18) - } else if !br.SpecVersions.ContainsGreaterOrEqual(MinSpecVersion) { - br.ZLog.WithLevel(unsupportedServerLogLevel). - Stringer("server_supports", br.SpecVersions.GetLatest()). - Stringer("bridge_requires", MinSpecVersion). - Msg("The homeserver is outdated (supported spec versions are below minimum required by bridge)") - if !*ignoreUnsupportedServer { - os.Exit(18) - } - } else if fr, ok := br.Child.(CSFeatureRequirer); ok { - if msg, hasFeatures := fr.CheckFeatures(&br.SpecVersions); !hasFeatures { - br.ZLog.WithLevel(unsupportedServerLogLevel).Msg(msg) - if !*ignoreUnsupportedServer { - os.Exit(18) - } - } - } - - resp, err := br.Bot.Whoami(ctx) - if err != nil { - br.logInitialRequestError(err, "/whoami request failed with unknown error") - os.Exit(16) - } else if resp.UserID != br.Bot.UserID { - br.ZLog.WithLevel(zerolog.FatalLevel). - Stringer("got_user_id", resp.UserID). - Stringer("expected_user_id", br.Bot.UserID). - Msg("Unexpected user ID in whoami call") - os.Exit(17) - } - - if br.Websocket { - br.ZLog.Debug().Msg("Websocket mode: no need to check status of homeserver -> bridge connection") - return - } else if !br.SpecVersions.Supports(mautrix.FeatureAppservicePing) { - br.ZLog.Debug().Msg("Homeserver does not support checking status of homeserver -> bridge connection") - return - } - var pingResp *mautrix.RespAppservicePing - var txnID string - var retryCount int - const maxRetries = 6 - for { - txnID = br.Bot.TxnID() - pingResp, err = br.Bot.AppservicePing(ctx, br.Config.AppService.ID, txnID) - if err == nil { - break - } - var httpErr mautrix.HTTPError - var pingErrBody string - if errors.As(err, &httpErr) && httpErr.RespError != nil { - if val, ok := httpErr.RespError.ExtraData["body"].(string); ok { - pingErrBody = strings.TrimSpace(val) - } - } - outOfRetries := retryCount >= maxRetries - level := zerolog.ErrorLevel - if outOfRetries { - level = zerolog.FatalLevel - } - evt := br.ZLog.WithLevel(level).Err(err).Str("txn_id", txnID) - if pingErrBody != "" { - bodyBytes := []byte(pingErrBody) - if json.Valid(bodyBytes) { - evt.RawJSON("body", bodyBytes) - } else { - evt.Str("body", pingErrBody) - } - } - if outOfRetries { - evt.Msg("Homeserver -> bridge connection is not working") - br.ZLog.Info().Msg("See https://docs.mau.fi/faq/as-ping for more info") - os.Exit(13) - } - evt.Msg("Homeserver -> bridge connection is not working, retrying in 5 seconds...") - time.Sleep(5 * time.Second) - retryCount++ - } - br.ZLog.Debug(). - Str("txn_id", txnID). - Int64("duration_ms", pingResp.DurationMS). - Msg("Homeserver -> bridge connection works") -} - -func (br *Bridge) fetchMediaConfig(ctx context.Context) { - cfg, err := br.Bot.GetMediaConfig(ctx) - if err != nil { - br.ZLog.Warn().Err(err).Msg("Failed to fetch media config") - } else { - if cfg.UploadSize == 0 { - cfg.UploadSize = 50 * 1024 * 1024 - } - br.MediaConfig = *cfg - } -} - -func (br *Bridge) UpdateBotProfile(ctx context.Context) { - br.ZLog.Debug().Msg("Updating bot profile") - botConfig := &br.Config.AppService.Bot - - var err error - var mxc id.ContentURI - if botConfig.Avatar == "remove" { - err = br.Bot.SetAvatarURL(ctx, mxc) - } else if !botConfig.ParsedAvatar.IsEmpty() { - err = br.Bot.SetAvatarURL(ctx, botConfig.ParsedAvatar) - } - if err != nil { - br.ZLog.Warn().Err(err).Msg("Failed to update bot avatar") - } - - if botConfig.Displayname == "remove" { - err = br.Bot.SetDisplayName(ctx, "") - } else if len(botConfig.Displayname) > 0 { - err = br.Bot.SetDisplayName(ctx, botConfig.Displayname) - } - if err != nil { - br.ZLog.Warn().Err(err).Msg("Failed to update bot displayname") - } - - if br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) && br.BeeperNetworkName != "" { - br.ZLog.Debug().Msg("Setting contact info on the appservice bot") - br.Bot.BeeperUpdateProfile(ctx, map[string]any{ - "com.beeper.bridge.service": br.BeeperServiceName, - "com.beeper.bridge.network": br.BeeperNetworkName, - "com.beeper.bridge.is_bridge_bot": true, - }) - } -} - -func (br *Bridge) loadConfig() { - configData, upgraded, err := configupgrade.Do(br.ConfigPath, br.SaveConfig, br.ConfigUpgrader) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Error updating config:", err) - if configData == nil { - os.Exit(10) - } - } - - target := br.Child.GetConfigPtr() - if !upgraded { - // Fallback: if config upgrading failed, load example config for base values - err = yaml.Unmarshal([]byte(br.Child.GetExampleConfig()), &target) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to unmarshal example config:", err) - os.Exit(10) - } - } - err = yaml.Unmarshal(configData, target) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to parse config:", err) - os.Exit(10) - } -} - -func (br *Bridge) validateConfig() error { - switch { - case br.Config.Homeserver.Address == "https://matrix.example.com": - return errors.New("homeserver.address not configured") - case br.Config.Homeserver.Domain == "example.com": - return errors.New("homeserver.domain not configured") - case !bridgeconfig.AllowedHomeserverSoftware[br.Config.Homeserver.Software]: - return errors.New("invalid value for homeserver.software (use `standard` if you don't know what the field is for)") - case br.Config.AppService.ASToken == "This value is generated when generating the registration": - return errors.New("appservice.as_token not configured. Did you forget to generate the registration? ") - case br.Config.AppService.HSToken == "This value is generated when generating the registration": - return errors.New("appservice.hs_token not configured. Did you forget to generate the registration? ") - case br.Config.AppService.Database.URI == "postgres://user:password@host/database?sslmode=disable": - return errors.New("appservice.database not configured") - default: - err := br.Config.Bridge.Validate() - if err != nil { - return err - } - validator, ok := br.Child.(ConfigValidatingBridge) - if ok { - return validator.ValidateConfig() - } - return nil - } -} - -func (br *Bridge) getProfile(userID id.UserID, roomID id.RoomID) *event.MemberEventContent { - ghost := br.Child.GetIGhost(userID) - if ghost == nil { - return nil - } - profilefulGhost, ok := ghost.(GhostWithProfile) - if ok { - return &event.MemberEventContent{ - Displayname: profilefulGhost.GetDisplayname(), - AvatarURL: profilefulGhost.GetAvatarURL().CUString(), - } - } - return nil -} - -func (br *Bridge) init() { - pib, ok := br.Child.(PreInitableBridge) - if ok { - pib.PreInit() - } - - var err error - - br.MediaConfig.UploadSize = 50 * 1024 * 1024 - - br.ZLog, err = br.Config.Logging.Compile() - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to initialize logger:", err) - os.Exit(12) - } - exzerolog.SetupDefaults(br.ZLog) - - br.DoublePuppet = &doublePuppetUtil{br: br, log: br.ZLog.With().Str("component", "double puppet").Logger()} - - err = br.validateConfig() - if err != nil { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Configuration error") - br.ZLog.Info().Msg("See https://docs.mau.fi/faq/field-unconfigured for more info") - os.Exit(11) - } - - br.ZLog.Info(). - Str("name", br.Name). - Str("version", br.Version). - Str("built_at", br.BuildTime). - Str("go_version", runtime.Version()). - Msg("Initializing bridge") - - br.ZLog.Debug().Msg("Initializing database connection") - dbConfig := br.Config.AppService.Database - if (dbConfig.Type == "sqlite3-fk-wal" || dbConfig.Type == "litestream") && dbConfig.MaxOpenConns != 1 && !strings.Contains(dbConfig.URI, "_txlock=immediate") { - var fixedExampleURI string - if !strings.HasPrefix(dbConfig.URI, "file:") { - fixedExampleURI = fmt.Sprintf("file:%s?_txlock=immediate", dbConfig.URI) - } else if !strings.ContainsRune(dbConfig.URI, '?') { - fixedExampleURI = fmt.Sprintf("%s?_txlock=immediate", dbConfig.URI) - } else { - fixedExampleURI = fmt.Sprintf("%s&_txlock=immediate", dbConfig.URI) - } - br.ZLog.Warn(). - Str("fixed_uri_example", fixedExampleURI). - Msg("Using SQLite without _txlock=immediate is not recommended") - } - br.DB, err = dbutil.NewFromConfig(br.Name, dbConfig, dbutil.ZeroLogger(br.ZLog.With().Str("db_section", "main").Logger())) - if err != nil { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to initialize database connection") - if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { - os.Exit(18) - } - os.Exit(14) - } - br.DB.IgnoreUnsupportedDatabase = *ignoreUnsupportedDatabase - br.DB.IgnoreForeignTables = *ignoreForeignTables - - br.ZLog.Debug().Msg("Initializing state store") - br.StateStore = sqlstatestore.NewSQLStateStore(br.DB, dbutil.ZeroLogger(br.ZLog.With().Str("db_section", "matrix_state").Logger()), true) - - br.AS, err = appservice.CreateFull(appservice.CreateOpts{ - Registration: br.Config.AppService.GetRegistration(), - HomeserverDomain: br.Config.Homeserver.Domain, - HomeserverURL: br.Config.Homeserver.Address, - HostConfig: appservice.HostConfig{ - Hostname: br.Config.AppService.Hostname, - Port: br.Config.AppService.Port, - }, - StateStore: br.StateStore, - }) - if err != nil { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err). - Msg("Failed to initialize appservice") - os.Exit(15) - } - br.AS.Log = *br.ZLog - br.AS.DoublePuppetValue = br.Name - br.AS.GetProfile = br.getProfile - br.Bot = br.AS.BotIntent() - - br.ZLog.Debug().Msg("Initializing Matrix event processor") - br.EventProcessor = appservice.NewEventProcessor(br.AS) - if !br.Config.AppService.AsyncTransactions { - br.EventProcessor.ExecMode = appservice.Sync - } - br.ZLog.Debug().Msg("Initializing Matrix event handler") - br.MatrixHandler = NewMatrixHandler(br) - - br.Crypto = NewCryptoHelper(br) - - hsURL := br.Config.Homeserver.Address - if br.Config.Homeserver.PublicAddress != "" { - hsURL = br.Config.Homeserver.PublicAddress - } - br.PublicHSAddress, err = url.Parse(hsURL) - if err != nil { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err). - Str("input", hsURL). - Msg("Failed to parse public homeserver URL") - os.Exit(15) - } - - br.Child.Init() -} - -type zerologPQError pq.Error - -func (zpe *zerologPQError) MarshalZerologObject(evt *zerolog.Event) { - maybeStr := func(field, value string) { - if value != "" { - evt.Str(field, value) - } - } - maybeStr("severity", zpe.Severity) - if name := zpe.Code.Name(); name != "" { - evt.Str("code", name) - } else if zpe.Code != "" { - evt.Str("code", string(zpe.Code)) - } - //maybeStr("message", zpe.Message) - maybeStr("detail", zpe.Detail) - maybeStr("hint", zpe.Hint) - maybeStr("position", zpe.Position) - maybeStr("internal_position", zpe.InternalPosition) - maybeStr("internal_query", zpe.InternalQuery) - maybeStr("where", zpe.Where) - maybeStr("schema", zpe.Schema) - maybeStr("table", zpe.Table) - maybeStr("column", zpe.Column) - maybeStr("data_type_name", zpe.DataTypeName) - maybeStr("constraint", zpe.Constraint) - maybeStr("file", zpe.File) - maybeStr("line", zpe.Line) - maybeStr("routine", zpe.Routine) -} - -func (br *Bridge) LogDBUpgradeErrorAndExit(name string, err error) { - logEvt := br.ZLog.WithLevel(zerolog.FatalLevel). - Err(err). - Str("db_section", name) - var errWithLine *dbutil.PQErrorWithLine - if errors.As(err, &errWithLine) { - logEvt.Str("sql_line", errWithLine.Line) - } - var pqe *pq.Error - if errors.As(err, &pqe) { - logEvt.Object("pq_error", (*zerologPQError)(pqe)) - } - logEvt.Msg("Failed to initialize database") - if sqlError := (&sqlite3.Error{}); errors.As(err, sqlError) && sqlError.Code == sqlite3.ErrCorrupt { - os.Exit(18) - } else if errors.Is(err, dbutil.ErrForeignTables) { - br.ZLog.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info") - } else if errors.Is(err, dbutil.ErrNotOwned) { - br.ZLog.Info().Msg("Sharing the same database with different programs is not supported") - } else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) { - br.ZLog.Info().Msg("Downgrading the bridge is not supported") - } - os.Exit(15) -} - -func (br *Bridge) WaitWebsocketConnected() { - if br.wsStartupWait != nil { - br.wsStartupWait.Wait() - } -} - -func (br *Bridge) start() { - br.ZLog.Debug().Msg("Running database upgrades") - err := br.DB.Upgrade(br.ZLog.With().Str("db_section", "main").Logger().WithContext(context.TODO())) - if err != nil { - br.LogDBUpgradeErrorAndExit("main", err) - } else if err = br.StateStore.Upgrade(br.ZLog.With().Str("db_section", "matrix_state").Logger().WithContext(context.TODO())); err != nil { - br.LogDBUpgradeErrorAndExit("matrix_state", err) - } - - if br.Config.Homeserver.Websocket || len(br.Config.Homeserver.WSProxy) > 0 { - br.Websocket = true - br.ZLog.Debug().Msg("Starting application service websocket") - var wg sync.WaitGroup - wg.Add(1) - br.wsStartupWait = &wg - br.wsShortCircuitReconnectBackoff = make(chan struct{}) - go br.startWebsocket(&wg) - } else if br.AS.Host.IsConfigured() { - br.ZLog.Debug().Msg("Starting application service HTTP server") - go br.AS.Start() - } else { - br.ZLog.WithLevel(zerolog.FatalLevel).Msg("Neither appservice HTTP listener nor websocket is enabled") - os.Exit(23) - } - br.ZLog.Debug().Msg("Checking connection to homeserver") - - ctx := br.ZLog.WithContext(context.Background()) - br.ensureConnection(ctx) - go br.fetchMediaConfig(ctx) - - if br.Crypto != nil { - err = br.Crypto.Init(ctx) - if err != nil { - br.ZLog.WithLevel(zerolog.FatalLevel).Err(err).Msg("Error initializing end-to-bridge encryption") - os.Exit(19) - } - } - - br.ZLog.Debug().Msg("Starting event processor") - br.EventProcessor.Start(ctx) - - go br.UpdateBotProfile(ctx) - if br.Crypto != nil { - go br.Crypto.Start() - } - - br.Child.Start() - br.WaitWebsocketConnected() - br.AS.Ready = true - - if br.Config.Bridge.GetResendBridgeInfo() { - go br.ResendBridgeInfo() - } - if br.Websocket && br.Config.Homeserver.WSPingInterval > 0 { - br.wsStopPinger = make(chan struct{}, 1) - go br.websocketServerPinger() - } -} - -func (br *Bridge) ResendBridgeInfo() { - if !br.SaveConfig { - br.ZLog.Warn().Msg("Not setting resend_bridge_info to false in config due to --no-update flag") - } else { - _, _, err := configupgrade.Do(br.ConfigPath, true, br.ConfigUpgrader, configupgrade.SimpleUpgrader(func(helper configupgrade.Helper) { - helper.Set(configupgrade.Bool, "false", "bridge", "resend_bridge_info") - })) - if err != nil { - br.ZLog.Err(err).Msg("Failed to save config after setting resend_bridge_info to false") - } - } - br.ZLog.Info().Msg("Re-sending bridge info state event to all portals") - for _, portal := range br.Child.GetAllIPortals() { - portal.UpdateBridgeInfo(context.TODO()) - } - br.ZLog.Info().Msg("Finished re-sending bridge info state events") -} - -func sendStopSignal(ch chan struct{}) { - if ch != nil { - select { - case ch <- struct{}{}: - default: - } - } -} - -func (br *Bridge) stop() { - br.Stopping = true - if br.Crypto != nil { - br.Crypto.Stop() - } - waitForWS := false - if br.AS.StopWebsocket != nil { - br.ZLog.Debug().Msg("Stopping application service websocket") - br.AS.StopWebsocket(appservice.ErrWebsocketManualStop) - waitForWS = true - } - br.AS.Stop() - sendStopSignal(br.wsStopPinger) - sendStopSignal(br.wsShortCircuitReconnectBackoff) - br.EventProcessor.Stop() - br.Child.Stop() - err := br.DB.Close() - if err != nil { - br.ZLog.Warn().Err(err).Msg("Error closing database") - } - if waitForWS { - select { - case <-br.wsStopped: - case <-time.After(4 * time.Second): - br.ZLog.Warn().Msg("Timed out waiting for websocket to close") - } - } -} - -func (br *Bridge) ManualStop(exitCode int) { - if br.manualStop != nil { - br.manualStop <- exitCode - } else { - os.Exit(exitCode) - } -} - -type VersionJSONOutput struct { - Name string - URL string - - Version string - IsRelease bool - Commit string - FormattedVersion string - BuildTime string - - OS string - Arch string - - Mautrix struct { - Version string - Commit string - } -} - -func (br *Bridge) Main() { - flag.SetHelpTitles( - fmt.Sprintf("%s - %s", br.Name, br.Description), - fmt.Sprintf("%s [-hgvn%s] [-c ] [-r ]%s", br.Name, br.AdditionalShortFlags, br.AdditionalLongFlags)) - err := flag.Parse() - br.ConfigPath = *configPath - br.RegistrationPath = *registrationPath - br.SaveConfig = !*dontSaveConfig - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, err) - flag.PrintHelp() - os.Exit(1) - } else if *wantHelp { - flag.PrintHelp() - os.Exit(0) - } else if *version { - fmt.Println(br.VersionDesc) - return - } else if *versionJSON { - output := VersionJSONOutput{ - URL: br.URL, - Name: br.Name, - - Version: br.baseVersion, - IsRelease: br.Version == br.baseVersion, - Commit: br.commit, - FormattedVersion: br.Version, - BuildTime: br.BuildTime, - - OS: runtime.GOOS, - Arch: runtime.GOARCH, - } - output.Mautrix.Commit = mautrix.Commit - output.Mautrix.Version = mautrix.Version - _ = json.NewEncoder(os.Stdout).Encode(output) - return - } else if flagHandler, ok := br.Child.(FlagHandlingBridge); ok && flagHandler.HandleFlags() { - return - } - - br.loadConfig() - - if *generateRegistration { - br.GenerateRegistration() - return - } - - br.manualStop = make(chan int, 1) - br.init() - br.ZLog.Info().Msg("Bridge initialization complete, starting...") - br.start() - br.ZLog.Info().Msg("Bridge started!") - - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - var exitCode int - select { - case <-c: - br.ZLog.Info().Msg("Interrupt received, stopping...") - case exitCode = <-br.manualStop: - br.ZLog.Info().Int("exit_code", exitCode).Msg("Manual stop requested") - } - - br.stop() - br.ZLog.Info().Msg("Bridge stopped.") - os.Exit(exitCode) -} diff --git a/bridge/bridgeconfig/config.go b/bridge/bridgeconfig/config.go deleted file mode 100644 index dfb6b7e5..00000000 --- a/bridge/bridgeconfig/config.go +++ /dev/null @@ -1,337 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridgeconfig - -import ( - "fmt" - "os" - "path/filepath" - "regexp" - "strings" - - "github.com/rs/zerolog" - up "go.mau.fi/util/configupgrade" - "go.mau.fi/util/dbutil" - "go.mau.fi/util/random" - "go.mau.fi/zeroconfig" - "gopkg.in/yaml.v3" - - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/id" -) - -type HomeserverSoftware string - -const ( - SoftwareStandard HomeserverSoftware = "standard" - SoftwareAsmux HomeserverSoftware = "asmux" - SoftwareHungry HomeserverSoftware = "hungry" -) - -var AllowedHomeserverSoftware = map[HomeserverSoftware]bool{ - SoftwareStandard: true, - SoftwareAsmux: true, - SoftwareHungry: true, -} - -type HomeserverConfig struct { - Address string `yaml:"address"` - Domain string `yaml:"domain"` - AsyncMedia bool `yaml:"async_media"` - - PublicAddress string `yaml:"public_address,omitempty"` - - Software HomeserverSoftware `yaml:"software"` - - StatusEndpoint string `yaml:"status_endpoint"` - MessageSendCheckpointEndpoint string `yaml:"message_send_checkpoint_endpoint"` - - Websocket bool `yaml:"websocket"` - WSProxy string `yaml:"websocket_proxy"` - WSPingInterval int `yaml:"ping_interval_seconds"` -} - -type AppserviceConfig struct { - Address string `yaml:"address"` - Hostname string `yaml:"hostname"` - Port uint16 `yaml:"port"` - - Database dbutil.Config `yaml:"database"` - - ID string `yaml:"id"` - Bot BotUserConfig `yaml:"bot"` - - ASToken string `yaml:"as_token"` - HSToken string `yaml:"hs_token"` - - EphemeralEvents bool `yaml:"ephemeral_events"` - AsyncTransactions bool `yaml:"async_transactions"` -} - -func (config *BaseConfig) MakeUserIDRegex(matcher string) *regexp.Regexp { - usernamePlaceholder := strings.ToLower(random.String(16)) - usernameTemplate := fmt.Sprintf("@%s:%s", - config.Bridge.FormatUsername(usernamePlaceholder), - config.Homeserver.Domain) - usernameTemplate = regexp.QuoteMeta(usernameTemplate) - usernameTemplate = strings.Replace(usernameTemplate, usernamePlaceholder, matcher, 1) - usernameTemplate = fmt.Sprintf("^%s$", usernameTemplate) - return regexp.MustCompile(usernameTemplate) -} - -// GenerateRegistration generates a registration file for the homeserver. -func (config *BaseConfig) GenerateRegistration() *appservice.Registration { - registration := appservice.CreateRegistration() - config.AppService.HSToken = registration.ServerToken - config.AppService.ASToken = registration.AppToken - config.AppService.copyToRegistration(registration) - - registration.SenderLocalpart = random.String(32) - botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$", - regexp.QuoteMeta(config.AppService.Bot.Username), - regexp.QuoteMeta(config.Homeserver.Domain))) - registration.Namespaces.UserIDs.Register(botRegex, true) - registration.Namespaces.UserIDs.Register(config.MakeUserIDRegex(".*"), true) - - return registration -} - -func (config *BaseConfig) MakeAppService() *appservice.AppService { - as := appservice.Create() - as.HomeserverDomain = config.Homeserver.Domain - _ = as.SetHomeserverURL(config.Homeserver.Address) - as.Host.Hostname = config.AppService.Hostname - as.Host.Port = config.AppService.Port - as.Registration = config.AppService.GetRegistration() - return as -} - -// GetRegistration copies the data from the bridge config into an *appservice.Registration struct. -// This can't be used with the homeserver, see GenerateRegistration for generating files for the homeserver. -func (asc *AppserviceConfig) GetRegistration() *appservice.Registration { - reg := &appservice.Registration{} - asc.copyToRegistration(reg) - reg.SenderLocalpart = asc.Bot.Username - reg.ServerToken = asc.HSToken - reg.AppToken = asc.ASToken - return reg -} - -func (asc *AppserviceConfig) copyToRegistration(registration *appservice.Registration) { - registration.ID = asc.ID - registration.URL = asc.Address - falseVal := false - registration.RateLimited = &falseVal - registration.EphemeralEvents = asc.EphemeralEvents - registration.SoruEphemeralEvents = asc.EphemeralEvents -} - -type BotUserConfig struct { - Username string `yaml:"username"` - Displayname string `yaml:"displayname"` - Avatar string `yaml:"avatar"` - - ParsedAvatar id.ContentURI `yaml:"-"` -} - -type serializableBUC BotUserConfig - -func (buc *BotUserConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - var sbuc serializableBUC - err := unmarshal(&sbuc) - if err != nil { - return err - } - *buc = (BotUserConfig)(sbuc) - if buc.Avatar != "" && buc.Avatar != "remove" { - buc.ParsedAvatar, err = id.ParseContentURI(buc.Avatar) - if err != nil { - return fmt.Errorf("%w in bot avatar", err) - } - } - return nil -} - -type BridgeConfig interface { - FormatUsername(username string) string - GetEncryptionConfig() EncryptionConfig - GetCommandPrefix() string - GetManagementRoomTexts() ManagementRoomTexts - GetDoublePuppetConfig() DoublePuppetConfig - GetResendBridgeInfo() bool - EnableMessageStatusEvents() bool - EnableMessageErrorNotices() bool - Validate() error -} - -type DoublePuppetConfig struct { - ServerMap map[string]string `yaml:"double_puppet_server_map"` - AllowDiscovery bool `yaml:"double_puppet_allow_discovery"` - SharedSecretMap map[string]string `yaml:"login_shared_secret_map"` -} - -type EncryptionConfig struct { - Allow bool `yaml:"allow"` - Default bool `yaml:"default"` - Require bool `yaml:"require"` - Appservice bool `yaml:"appservice"` - - PlaintextMentions bool `yaml:"plaintext_mentions"` - - DeleteKeys struct { - DeleteOutboundOnAck bool `yaml:"delete_outbound_on_ack"` - DontStoreOutbound bool `yaml:"dont_store_outbound"` - RatchetOnDecrypt bool `yaml:"ratchet_on_decrypt"` - DeleteFullyUsedOnDecrypt bool `yaml:"delete_fully_used_on_decrypt"` - DeletePrevOnNewSession bool `yaml:"delete_prev_on_new_session"` - DeleteOnDeviceDelete bool `yaml:"delete_on_device_delete"` - PeriodicallyDeleteExpired bool `yaml:"periodically_delete_expired"` - DeleteOutdatedInbound bool `yaml:"delete_outdated_inbound"` - } `yaml:"delete_keys"` - - VerificationLevels struct { - Receive id.TrustState `yaml:"receive"` - Send id.TrustState `yaml:"send"` - Share id.TrustState `yaml:"share"` - } `yaml:"verification_levels"` - AllowKeySharing bool `yaml:"allow_key_sharing"` - - Rotation struct { - EnableCustom bool `yaml:"enable_custom"` - Milliseconds int64 `yaml:"milliseconds"` - Messages int `yaml:"messages"` - - DisableDeviceChangeKeyRotation bool `yaml:"disable_device_change_key_rotation"` - } `yaml:"rotation"` -} - -type ManagementRoomTexts struct { - Welcome string `yaml:"welcome"` - WelcomeConnected string `yaml:"welcome_connected"` - WelcomeUnconnected string `yaml:"welcome_unconnected"` - AdditionalHelp string `yaml:"additional_help"` -} - -type BaseConfig struct { - Homeserver HomeserverConfig `yaml:"homeserver"` - AppService AppserviceConfig `yaml:"appservice"` - Bridge BridgeConfig `yaml:"-"` - Logging zeroconfig.Config `yaml:"logging"` -} - -func doUpgrade(helper up.Helper) { - helper.Copy(up.Str, "homeserver", "address") - helper.Copy(up.Str, "homeserver", "domain") - if legacyAsmuxFlag, ok := helper.Get(up.Bool, "homeserver", "asmux"); ok && legacyAsmuxFlag == "true" { - helper.Set(up.Str, string(SoftwareAsmux), "homeserver", "software") - } else { - helper.Copy(up.Str, "homeserver", "software") - } - helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint") - helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint") - helper.Copy(up.Bool, "homeserver", "async_media") - helper.Copy(up.Str|up.Null, "homeserver", "websocket_proxy") - helper.Copy(up.Bool, "homeserver", "websocket") - helper.Copy(up.Int, "homeserver", "ping_interval_seconds") - - helper.Copy(up.Str|up.Null, "appservice", "address") - helper.Copy(up.Str|up.Null, "appservice", "hostname") - helper.Copy(up.Int|up.Null, "appservice", "port") - if dbType, ok := helper.Get(up.Str, "appservice", "database", "type"); ok && dbType == "sqlite3" { - helper.Set(up.Str, "sqlite3-fk-wal", "appservice", "database", "type") - } else { - helper.Copy(up.Str, "appservice", "database", "type") - } - helper.Copy(up.Str, "appservice", "database", "uri") - helper.Copy(up.Int, "appservice", "database", "max_open_conns") - helper.Copy(up.Int, "appservice", "database", "max_idle_conns") - helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_idle_time") - helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_lifetime") - helper.Copy(up.Str, "appservice", "id") - helper.Copy(up.Str, "appservice", "bot", "username") - helper.Copy(up.Str, "appservice", "bot", "displayname") - helper.Copy(up.Str, "appservice", "bot", "avatar") - helper.Copy(up.Bool, "appservice", "ephemeral_events") - helper.Copy(up.Bool, "appservice", "async_transactions") - helper.Copy(up.Str, "appservice", "as_token") - helper.Copy(up.Str, "appservice", "hs_token") - - if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "print_level") != nil || helper.GetNode("logging", "file_name_format") != nil) { - _, _ = fmt.Fprintln(os.Stderr, "Migrating legacy log config") - migrateLegacyLogConfig(helper) - } else if helper.GetNode("logging", "writers") == nil && (helper.GetNode("logging", "handlers") != nil) { - _, _ = fmt.Fprintln(os.Stderr, "Migrating Python log config is not currently supported") - // TODO implement? - //migratePythonLogConfig(helper) - } else { - helper.Copy(up.Map, "logging") - } -} - -type legacyLogConfig struct { - Directory string `yaml:"directory"` - FileNameFormat string `yaml:"file_name_format"` - FileDateFormat string `yaml:"file_date_format"` - FileMode uint32 `yaml:"file_mode"` - TimestampFormat string `yaml:"timestamp_format"` - RawPrintLevel string `yaml:"print_level"` - JSONStdout bool `yaml:"print_json"` - JSONFile bool `yaml:"file_json"` -} - -func migrateLegacyLogConfig(helper up.Helper) { - var llc legacyLogConfig - var newConfig zeroconfig.Config - err := helper.GetBaseNode("logging").Decode(&newConfig) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Base config is corrupted: failed to decode example log config:", err) - return - } else if len(newConfig.Writers) != 2 || newConfig.Writers[0].Type != "stdout" || newConfig.Writers[1].Type != "file" { - _, _ = fmt.Fprintln(os.Stderr, "Base log config is not in expected format") - return - } - err = helper.GetNode("logging").Decode(&llc) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to decode legacy log config:", err) - return - } - if llc.RawPrintLevel != "" { - level, err := zerolog.ParseLevel(llc.RawPrintLevel) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to parse minimum stdout log level:", err) - } else { - newConfig.Writers[0].MinLevel = &level - } - } - if llc.Directory != "" && llc.FileNameFormat != "" { - if llc.FileNameFormat == "{{.Date}}-{{.Index}}.log" { - llc.FileNameFormat = "bridge.log" - } else { - llc.FileNameFormat = strings.ReplaceAll(llc.FileNameFormat, "{{.Date}}", "") - llc.FileNameFormat = strings.ReplaceAll(llc.FileNameFormat, "{{.Index}}", "") - } - newConfig.Writers[1].Filename = filepath.Join(llc.Directory, llc.FileNameFormat) - } else if llc.FileNameFormat == "" { - newConfig.Writers = newConfig.Writers[0:1] - } - if llc.JSONStdout { - newConfig.Writers[0].TimeFormat = "" - newConfig.Writers[0].Format = "json" - } else if llc.TimestampFormat != "" { - newConfig.Writers[0].TimeFormat = llc.TimestampFormat - } - var updatedConfig yaml.Node - err = updatedConfig.Encode(&newConfig) - if err != nil { - _, _ = fmt.Fprintln(os.Stderr, "Failed to encode migrated log config:", err) - return - } - *helper.GetBaseNode("logging").Node = updatedConfig -} - -// Upgrader is a config upgrader that copies the default fields in the homeserver, appservice and logging blocks. -var Upgrader = up.SimpleUpgrader(doUpgrade) diff --git a/bridge/bridgeconfig/permissions.go b/bridge/bridgeconfig/permissions.go deleted file mode 100644 index 198e140e..00000000 --- a/bridge/bridgeconfig/permissions.go +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridgeconfig - -import ( - "strconv" - "strings" - - "maunium.net/go/mautrix/id" -) - -type PermissionConfig map[string]PermissionLevel - -type PermissionLevel int - -const ( - PermissionLevelBlock PermissionLevel = 0 - PermissionLevelRelay PermissionLevel = 5 - PermissionLevelUser PermissionLevel = 10 - PermissionLevelAdmin PermissionLevel = 100 -) - -var namesToLevels = map[string]PermissionLevel{ - "block": PermissionLevelBlock, - "relay": PermissionLevelRelay, - "user": PermissionLevelUser, - "admin": PermissionLevelAdmin, -} - -func RegisterPermissionLevel(name string, level PermissionLevel) { - namesToLevels[name] = level -} - -func (pc *PermissionConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - rawPC := make(map[string]string) - err := unmarshal(&rawPC) - if err != nil { - return err - } - - if *pc == nil { - *pc = make(map[string]PermissionLevel) - } - for key, value := range rawPC { - level, ok := namesToLevels[strings.ToLower(value)] - if ok { - (*pc)[key] = level - } else if val, err := strconv.Atoi(value); err == nil { - (*pc)[key] = PermissionLevel(val) - } else { - (*pc)[key] = PermissionLevelBlock - } - } - return nil -} - -func (pc PermissionConfig) Get(userID id.UserID) PermissionLevel { - if level, ok := pc[string(userID)]; ok { - return level - } else if level, ok = pc[userID.Homeserver()]; len(userID.Homeserver()) > 0 && ok { - return level - } else if level, ok = pc["*"]; ok { - return level - } else { - return PermissionLevelBlock - } -} diff --git a/bridge/bridgestate.go b/bridge/bridgestate.go deleted file mode 100644 index f9c3a3c6..00000000 --- a/bridge/bridgestate.go +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridge - -import ( - "context" - "runtime/debug" - "time" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/status" -) - -func (br *Bridge) SendBridgeState(ctx context.Context, state *status.BridgeState) error { - if br.Websocket { - // FIXME this doesn't account for multiple users - br.latestState = state - - return br.AS.SendWebsocket(&appservice.WebsocketRequest{ - Command: "bridge_status", - Data: state, - }) - } else if br.Config.Homeserver.StatusEndpoint != "" { - return state.SendHTTP(ctx, br.Config.Homeserver.StatusEndpoint, br.Config.AppService.ASToken) - } else { - return nil - } -} - -func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { - if len(br.Config.Homeserver.StatusEndpoint) == 0 && !br.Websocket { - return - } - - for { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - if err := br.SendBridgeState(ctx, &state); err != nil { - br.ZLog.Warn().Err(err).Msg("Failed to update global bridge state") - cancel() - time.Sleep(5 * time.Second) - continue - } else { - br.ZLog.Debug().Interface("bridge_state", state).Msg("Sent new global bridge state") - cancel() - break - } - } -} - -type BridgeStateQueue struct { - prev *status.BridgeState - ch chan status.BridgeState - bridge *Bridge - user status.BridgeStateFiller -} - -func (br *Bridge) NewBridgeStateQueue(user status.BridgeStateFiller) *BridgeStateQueue { - if len(br.Config.Homeserver.StatusEndpoint) == 0 && !br.Websocket { - return nil - } - bsq := &BridgeStateQueue{ - ch: make(chan status.BridgeState, 10), - bridge: br, - user: user, - } - go bsq.loop() - return bsq -} - -func (bsq *BridgeStateQueue) loop() { - defer func() { - err := recover() - if err != nil { - bsq.bridge.ZLog.Error(). - Str(zerolog.ErrorStackFieldName, string(debug.Stack())). - Interface(zerolog.ErrorFieldName, err). - Msg("Panic in bridge state loop") - } - }() - for state := range bsq.ch { - bsq.immediateSendBridgeState(state) - } -} - -func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) { - retryIn := 2 - for { - if bsq.prev != nil && bsq.prev.ShouldDeduplicate(&state) { - bsq.bridge.ZLog.Debug(). - Str("state_event", string(state.StateEvent)). - Msg("Not sending bridge state as it's a duplicate") - return - } - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - err := bsq.bridge.SendBridgeState(ctx, &state) - cancel() - - if err != nil { - bsq.bridge.ZLog.Warn().Err(err). - Int("retry_in_seconds", retryIn). - Msg("Failed to update bridge state") - time.Sleep(time.Duration(retryIn) * time.Second) - retryIn *= 2 - if retryIn > 64 { - retryIn = 64 - } - } else { - bsq.prev = &state - bsq.bridge.ZLog.Debug(). - Interface("bridge_state", state). - Msg("Sent new bridge state") - return - } - } -} - -func (bsq *BridgeStateQueue) Send(state status.BridgeState) { - if bsq == nil { - return - } - - state = state.Fill(bsq.user) - - if len(bsq.ch) >= 8 { - bsq.bridge.ZLog.Warn().Msg("Bridge state queue is nearly full, discarding an item") - select { - case <-bsq.ch: - default: - } - } - select { - case bsq.ch <- state: - default: - bsq.bridge.ZLog.Error().Msg("Bridge state queue is full, dropped new state") - } -} - -func (bsq *BridgeStateQueue) GetPrev() status.BridgeState { - if bsq != nil && bsq.prev != nil { - return *bsq.prev - } - return status.BridgeState{} -} - -func (bsq *BridgeStateQueue) SetPrev(prev status.BridgeState) { - if bsq != nil { - bsq.prev = &prev - } -} diff --git a/bridge/commands/admin.go b/bridge/commands/admin.go deleted file mode 100644 index ff3340e3..00000000 --- a/bridge/commands/admin.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright (c) 2022 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "strconv" - - "maunium.net/go/mautrix/id" -) - -var CommandDiscardMegolmSession = &FullHandler{ - Func: func(ce *Event) { - if ce.Bridge.Crypto == nil { - ce.Reply("This bridge instance doesn't have end-to-bridge encryption enabled") - } else { - ce.Bridge.Crypto.ResetSession(ce.Ctx, ce.RoomID) - ce.Reply("Successfully reset Megolm session in this room. New decryption keys will be shared the next time a message is sent from the remote network.") - } - }, - Name: "discard-megolm-session", - Aliases: []string{"discard-session"}, - Help: HelpMeta{ - Section: HelpSectionAdmin, - Description: "Discard the Megolm session in the room", - }, - RequiresAdmin: true, -} - -func fnSetPowerLevel(ce *Event) { - var level int - var userID id.UserID - var err error - if len(ce.Args) == 1 { - level, err = strconv.Atoi(ce.Args[0]) - if err != nil { - ce.Reply("Invalid power level \"%s\"", ce.Args[0]) - return - } - userID = ce.User.GetMXID() - } else if len(ce.Args) == 2 { - userID = id.UserID(ce.Args[0]) - _, _, err := userID.Parse() - if err != nil { - ce.Reply("Invalid user ID \"%s\"", ce.Args[0]) - return - } - level, err = strconv.Atoi(ce.Args[1]) - if err != nil { - ce.Reply("Invalid power level \"%s\"", ce.Args[1]) - return - } - } else { - ce.Reply("**Usage:** `set-pl [user] `") - return - } - _, err = ce.Portal.MainIntent().SetPowerLevel(ce.Ctx, ce.RoomID, userID, level) - if err != nil { - ce.Reply("Failed to set power levels: %v", err) - } -} - -var CommandSetPowerLevel = &FullHandler{ - Func: fnSetPowerLevel, - Name: "set-pl", - Aliases: []string{"set-power-level"}, - Help: HelpMeta{ - Section: HelpSectionAdmin, - Description: "Change the power level in a portal room.", - Args: "[_user ID_] <_power level_>", - }, - RequiresAdmin: true, - RequiresPortal: true, -} diff --git a/bridge/commands/doublepuppet.go b/bridge/commands/doublepuppet.go deleted file mode 100644 index 3f074951..00000000 --- a/bridge/commands/doublepuppet.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) 2022 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -var CommandLoginMatrix = &FullHandler{ - Func: fnLoginMatrix, - Name: "login-matrix", - Help: HelpMeta{ - Section: HelpSectionAuth, - Description: "Enable double puppeting.", - Args: "<_access token_>", - }, - RequiresLogin: true, -} - -func fnLoginMatrix(ce *Event) { - if len(ce.Args) == 0 { - ce.Reply("**Usage:** `login-matrix `") - return - } - puppet := ce.User.GetIDoublePuppet() - if puppet == nil { - puppet = ce.User.GetIGhost() - if puppet == nil { - ce.Reply("Didn't get a ghost :(") - return - } - } - err := puppet.SwitchCustomMXID(ce.Args[0], ce.User.GetMXID()) - if err != nil { - ce.Reply("Failed to enable double puppeting: %v", err) - } else { - ce.Reply("Successfully switched puppet") - } -} - -var CommandPingMatrix = &FullHandler{ - Func: fnPingMatrix, - Name: "ping-matrix", - Help: HelpMeta{ - Section: HelpSectionAuth, - Description: "Ping the Matrix server with the double puppet.", - }, - RequiresLogin: true, -} - -func fnPingMatrix(ce *Event) { - puppet := ce.User.GetIDoublePuppet() - if puppet == nil || puppet.CustomIntent() == nil { - ce.Reply("You are not logged in with your Matrix account.") - return - } - resp, err := puppet.CustomIntent().Whoami(ce.Ctx) - if err != nil { - ce.Reply("Failed to validate Matrix login: %v", err) - } else { - ce.Reply("Confirmed valid access token for %s / %s", resp.UserID, resp.DeviceID) - } -} - -var CommandLogoutMatrix = &FullHandler{ - Func: fnLogoutMatrix, - Name: "logout-matrix", - Help: HelpMeta{ - Section: HelpSectionAuth, - Description: "Disable double puppeting.", - }, - RequiresLogin: true, -} - -func fnLogoutMatrix(ce *Event) { - puppet := ce.User.GetIDoublePuppet() - if puppet == nil || puppet.CustomIntent() == nil { - ce.Reply("You don't have double puppeting enabled.") - return - } - puppet.ClearCustomMXID() - ce.Reply("Successfully disabled double puppeting.") -} diff --git a/bridge/commands/event.go b/bridge/commands/event.go deleted file mode 100644 index 49a8b277..00000000 --- a/bridge/commands/event.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "context" - "fmt" - "strings" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/format" - "maunium.net/go/mautrix/id" -) - -// Event stores all data which might be used to handle commands -type Event struct { - Bot *appservice.IntentAPI - Bridge *bridge.Bridge - Portal bridge.Portal - Processor *Processor - Handler MinimalHandler - RoomID id.RoomID - EventID id.EventID - User bridge.User - Command string - Args []string - RawArgs string - ReplyTo id.EventID - Ctx context.Context - ZLog *zerolog.Logger -} - -// MainIntent returns the intent to use when replying to the command. -// -// It prefers the bridge bot, but falls back to the other user in DMs if the bridge bot is not present. -func (ce *Event) MainIntent() *appservice.IntentAPI { - intent := ce.Bot - if ce.Portal != nil && ce.Portal.IsPrivateChat() && !ce.Portal.IsEncrypted() { - intent = ce.Portal.MainIntent() - } - return intent -} - -// Reply sends a reply to command as notice, with optional string formatting and automatic $cmdprefix replacement. -func (ce *Event) Reply(msg string, args ...interface{}) { - msg = strings.ReplaceAll(msg, "$cmdprefix ", ce.Bridge.Config.Bridge.GetCommandPrefix()+" ") - if len(args) > 0 { - msg = fmt.Sprintf(msg, args...) - } - ce.ReplyAdvanced(msg, true, false) -} - -// ReplyAdvanced sends a reply to command as notice. It allows using HTML and disabling markdown, -// but doesn't have built-in string formatting. -func (ce *Event) ReplyAdvanced(msg string, allowMarkdown, allowHTML bool) { - content := format.RenderMarkdown(msg, allowMarkdown, allowHTML) - content.MsgType = event.MsgNotice - _, err := ce.MainIntent().SendMessageEvent(ce.Ctx, ce.RoomID, event.EventMessage, content) - if err != nil { - ce.ZLog.Error().Err(err).Msg("Failed to reply to command") - } -} - -// React sends a reaction to the command. -func (ce *Event) React(key string) { - _, err := ce.MainIntent().SendReaction(ce.Ctx, ce.RoomID, ce.EventID, key) - if err != nil { - ce.ZLog.Error().Err(err).Msg("Failed to react to command") - } -} - -// Redact redacts the command. -func (ce *Event) Redact(req ...mautrix.ReqRedact) { - _, err := ce.MainIntent().RedactEvent(ce.Ctx, ce.RoomID, ce.EventID, req...) - if err != nil { - ce.ZLog.Error().Err(err).Msg("Failed to redact command") - } -} - -// MarkRead marks the command event as read. -func (ce *Event) MarkRead() { - err := ce.MainIntent().SendReceipt(ce.Ctx, ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil) - if err != nil { - ce.ZLog.Error().Err(err).Msg("Failed to mark command as read") - } -} diff --git a/bridge/commands/handler.go b/bridge/commands/handler.go deleted file mode 100644 index ab6899c0..00000000 --- a/bridge/commands/handler.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright (c) 2022 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "maunium.net/go/mautrix/bridge" - "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/event" -) - -type MinimalHandler interface { - Run(*Event) -} - -type MinimalHandlerFunc func(*Event) - -func (mhf MinimalHandlerFunc) Run(ce *Event) { - mhf(ce) -} - -type CommandState struct { - Next MinimalHandler - Action string - Meta interface{} -} - -type CommandingUser interface { - bridge.User - GetCommandState() *CommandState - SetCommandState(*CommandState) -} - -type Handler interface { - MinimalHandler - GetName() string -} - -type AliasedHandler interface { - Handler - GetAliases() []string -} - -type FullHandler struct { - Func func(*Event) - - Name string - Aliases []string - Help HelpMeta - - RequiresAdmin bool - RequiresPortal bool - RequiresLogin bool - - RequiresEventLevel event.Type -} - -func (fh *FullHandler) GetHelp() HelpMeta { - fh.Help.Command = fh.Name - return fh.Help -} - -func (fh *FullHandler) GetName() string { - return fh.Name -} - -func (fh *FullHandler) GetAliases() []string { - return fh.Aliases -} - -func (fh *FullHandler) ShowInHelp(ce *Event) bool { - return !fh.RequiresAdmin || ce.User.GetPermissionLevel() >= bridgeconfig.PermissionLevelAdmin -} - -func (fh *FullHandler) userHasRoomPermission(ce *Event) bool { - levels, err := ce.MainIntent().PowerLevels(ce.Ctx, ce.RoomID) - if err != nil { - ce.ZLog.Warn().Err(err).Msg("Failed to check room power levels") - ce.Reply("Failed to get room power levels to see if you're allowed to use that command") - return false - } - return levels.GetUserLevel(ce.User.GetMXID()) >= levels.GetEventLevel(fh.RequiresEventLevel) -} - -func (fh *FullHandler) Run(ce *Event) { - if fh.RequiresAdmin && ce.User.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin { - ce.Reply("That command is limited to bridge administrators.") - } else if fh.RequiresEventLevel.Type != "" && ce.User.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin && !fh.userHasRoomPermission(ce) { - ce.Reply("That command requires room admin rights.") - } else if fh.RequiresPortal && ce.Portal == nil { - ce.Reply("That command can only be ran in portal rooms.") - } else if fh.RequiresLogin && !ce.User.IsLoggedIn() { - ce.Reply("That command requires you to be logged in.") - } else { - fh.Func(ce) - } -} diff --git a/bridge/commands/help.go b/bridge/commands/help.go deleted file mode 100644 index f4891555..00000000 --- a/bridge/commands/help.go +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright (c) 2022 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "fmt" - "sort" - "strings" -) - -type HelpfulHandler interface { - Handler - GetHelp() HelpMeta - ShowInHelp(*Event) bool -} - -type HelpSection struct { - Name string - Order int -} - -var ( - // Deprecated: this should be used as a placeholder that needs to be fixed - HelpSectionUnclassified = HelpSection{"Unclassified", -1} - - HelpSectionGeneral = HelpSection{"General", 0} - HelpSectionAuth = HelpSection{"Authentication", 10} - HelpSectionAdmin = HelpSection{"Administration", 50} -) - -type HelpMeta struct { - Command string - Section HelpSection - Description string - Args string -} - -func (hm *HelpMeta) String() string { - if len(hm.Args) == 0 { - return fmt.Sprintf("**%s** - %s", hm.Command, hm.Description) - } - return fmt.Sprintf("**%s** %s - %s", hm.Command, hm.Args, hm.Description) -} - -type helpSectionList []HelpSection - -func (h helpSectionList) Len() int { - return len(h) -} - -func (h helpSectionList) Less(i, j int) bool { - return h[i].Order < h[j].Order -} - -func (h helpSectionList) Swap(i, j int) { - h[i], h[j] = h[j], h[i] -} - -type helpMetaList []HelpMeta - -func (h helpMetaList) Len() int { - return len(h) -} - -func (h helpMetaList) Less(i, j int) bool { - return h[i].Command < h[j].Command -} - -func (h helpMetaList) Swap(i, j int) { - h[i], h[j] = h[j], h[i] -} - -var _ sort.Interface = (helpSectionList)(nil) -var _ sort.Interface = (helpMetaList)(nil) - -func FormatHelp(ce *Event) string { - sections := make(map[HelpSection]helpMetaList) - for _, handler := range ce.Processor.handlers { - helpfulHandler, ok := handler.(HelpfulHandler) - if !ok || !helpfulHandler.ShowInHelp(ce) { - continue - } - help := helpfulHandler.GetHelp() - if help.Description == "" { - continue - } - sections[help.Section] = append(sections[help.Section], help) - } - - sortedSections := make(helpSectionList, 0, len(sections)) - for section := range sections { - sortedSections = append(sortedSections, section) - } - sort.Sort(sortedSections) - - var output strings.Builder - output.Grow(10240) - - var prefixMsg string - if ce.RoomID == ce.User.GetManagementRoomID() { - prefixMsg = "This is your management room: prefixing commands with `%s` is not required." - } else if ce.Portal != nil { - prefixMsg = "**This is a portal room**: you must always prefix commands with `%s`. Management commands will not be bridged." - } else { - prefixMsg = "This is not your management room: prefixing commands with `%s` is required." - } - _, _ = fmt.Fprintf(&output, prefixMsg, ce.Bridge.Config.Bridge.GetCommandPrefix()) - output.WriteByte('\n') - output.WriteString("Parameters in [square brackets] are optional, while parameters in are required.") - output.WriteByte('\n') - output.WriteByte('\n') - - for _, section := range sortedSections { - output.WriteString("#### ") - output.WriteString(section.Name) - output.WriteByte('\n') - sort.Sort(sections[section]) - for _, command := range sections[section] { - output.WriteString(command.String()) - output.WriteByte('\n') - } - output.WriteByte('\n') - } - return output.String() -} diff --git a/bridge/commands/meta.go b/bridge/commands/meta.go deleted file mode 100644 index 615f6a34..00000000 --- a/bridge/commands/meta.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) 2022 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -var CommandHelp = &FullHandler{ - Func: func(ce *Event) { - ce.Reply(FormatHelp(ce)) - }, - Name: "help", - Help: HelpMeta{ - Section: HelpSectionGeneral, - Description: "Show this help message.", - }, -} - -var CommandVersion = &FullHandler{ - Func: func(ce *Event) { - ce.Reply("[%s](%s) %s (%s)", ce.Bridge.Name, ce.Bridge.URL, ce.Bridge.LinkifiedVersion, ce.Bridge.BuildTime) - }, - Name: "version", - Help: HelpMeta{ - Section: HelpSectionGeneral, - Description: "Get the bridge version.", - }, -} - -var CommandCancel = &FullHandler{ - Func: func(ce *Event) { - commandingUser, ok := ce.User.(CommandingUser) - if !ok { - ce.Reply("This bridge does not implement cancelable commands") - return - } - state := commandingUser.GetCommandState() - - if state != nil { - action := state.Action - if action == "" { - action = "Unknown action" - } - commandingUser.SetCommandState(nil) - ce.Reply("%s cancelled.", action) - } else { - ce.Reply("No ongoing command.") - } - }, - Name: "cancel", - Help: HelpMeta{ - Section: HelpSectionGeneral, - Description: "Cancel an ongoing action.", - }, -} diff --git a/bridge/commands/processor.go b/bridge/commands/processor.go deleted file mode 100644 index 6158a7cd..00000000 --- a/bridge/commands/processor.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package commands - -import ( - "context" - "runtime/debug" - "strings" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix/bridge" - "maunium.net/go/mautrix/id" -) - -type Processor struct { - bridge *bridge.Bridge - log *zerolog.Logger - - handlers map[string]Handler - aliases map[string]string -} - -// NewProcessor creates a Processor -func NewProcessor(bridge *bridge.Bridge) *Processor { - proc := &Processor{ - bridge: bridge, - log: bridge.ZLog, - - handlers: make(map[string]Handler), - aliases: make(map[string]string), - } - proc.AddHandlers( - CommandHelp, CommandVersion, CommandCancel, - CommandLoginMatrix, CommandLogoutMatrix, CommandPingMatrix, - CommandDiscardMegolmSession, CommandSetPowerLevel) - return proc -} - -func (proc *Processor) AddHandlers(handlers ...Handler) { - for _, handler := range handlers { - proc.AddHandler(handler) - } -} - -func (proc *Processor) AddHandler(handler Handler) { - proc.handlers[handler.GetName()] = handler - aliased, ok := handler.(AliasedHandler) - if ok { - for _, alias := range aliased.GetAliases() { - proc.aliases[alias] = handler.GetName() - } - } -} - -// Handle handles messages to the bridge -func (proc *Processor) Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user bridge.User, message string, replyTo id.EventID) { - defer func() { - err := recover() - if err != nil { - zerolog.Ctx(ctx).Error(). - Str(zerolog.ErrorStackFieldName, string(debug.Stack())). - Interface(zerolog.ErrorFieldName, err). - Msg("Panic in Matrix command handler") - } - }() - args := strings.Fields(message) - if len(args) == 0 { - args = []string{"unknown-command"} - } - command := strings.ToLower(args[0]) - rawArgs := strings.TrimLeft(strings.TrimPrefix(message, command), " ") - log := zerolog.Ctx(ctx).With().Str("mx_command", command).Logger() - ctx = log.WithContext(ctx) - ce := &Event{ - Bot: proc.bridge.Bot, - Bridge: proc.bridge, - Portal: proc.bridge.Child.GetIPortal(roomID), - Processor: proc, - RoomID: roomID, - EventID: eventID, - User: user, - Command: command, - Args: args[1:], - RawArgs: rawArgs, - ReplyTo: replyTo, - Ctx: ctx, - ZLog: &log, - } - log.Debug().Msg("Received command") - - realCommand, ok := proc.aliases[ce.Command] - if !ok { - realCommand = ce.Command - } - commandingUser, ok := ce.User.(CommandingUser) - - var handler MinimalHandler - handler, ok = proc.handlers[realCommand] - if !ok { - var state *CommandState - if commandingUser != nil { - state = commandingUser.GetCommandState() - } - if state != nil && state.Next != nil { - ce.Command = "" - ce.RawArgs = message - ce.Args = args - ce.Handler = state.Next - state.Next.Run(ce) - } else { - ce.Reply("Unknown command, use the `help` command for help.") - } - } else { - ce.Handler = handler - handler.Run(ce) - } -} diff --git a/bridge/crypto.go b/bridge/crypto.go deleted file mode 100644 index e3885a22..00000000 --- a/bridge/crypto.go +++ /dev/null @@ -1,507 +0,0 @@ -// Copyright (c) 2024 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build cgo && !nocrypto - -package bridge - -import ( - "context" - "errors" - "fmt" - "os" - "runtime/debug" - "sync" - "time" - - "github.com/rs/zerolog" - "go.mau.fi/util/dbutil" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/crypto" - "maunium.net/go/mautrix/crypto/olm" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/id" - "maunium.net/go/mautrix/sqlstatestore" -) - -var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil) - -var NoSessionFound = crypto.NoSessionFound -var DuplicateMessageIndex = crypto.DuplicateMessageIndex -var UnknownMessageIndex = olm.UnknownMessageIndex - -type CryptoHelper struct { - bridge *Bridge - client *mautrix.Client - mach *crypto.OlmMachine - store *SQLCryptoStore - log *zerolog.Logger - - lock sync.RWMutex - syncDone sync.WaitGroup - cancelSync func() - - cancelPeriodicDeleteLoop func() -} - -func NewCryptoHelper(bridge *Bridge) Crypto { - if !bridge.Config.Bridge.GetEncryptionConfig().Allow { - bridge.ZLog.Debug().Msg("Bridge built with end-to-bridge encryption, but disabled in config") - return nil - } - log := bridge.ZLog.With().Str("component", "crypto").Logger() - return &CryptoHelper{ - bridge: bridge, - log: &log, - } -} - -func (helper *CryptoHelper) Init(ctx context.Context) error { - if len(helper.bridge.CryptoPickleKey) == 0 { - panic("CryptoPickleKey not set") - } - helper.log.Debug().Msg("Initializing end-to-bridge encryption...") - - helper.store = NewSQLCryptoStore( - helper.bridge.DB, - dbutil.ZeroLogger(helper.bridge.ZLog.With().Str("db_section", "crypto").Logger()), - helper.bridge.AS.BotMXID(), - fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.AS.HomeserverDomain), - helper.bridge.CryptoPickleKey, - ) - - err := helper.store.DB.Upgrade(ctx) - if err != nil { - helper.bridge.LogDBUpgradeErrorAndExit("crypto", err) - } - - var isExistingDevice bool - helper.client, isExistingDevice, err = helper.loginBot(ctx) - if err != nil { - return err - } - - helper.log.Debug(). - Str("device_id", helper.client.DeviceID.String()). - Msg("Logged in as bridge bot") - stateStore := &cryptoStateStore{helper.bridge} - helper.mach = crypto.NewOlmMachine(helper.client, helper.log, helper.store, stateStore) - helper.mach.AllowKeyShare = helper.allowKeyShare - - encryptionConfig := helper.bridge.Config.Bridge.GetEncryptionConfig() - helper.mach.SendKeysMinTrust = encryptionConfig.VerificationLevels.Receive - helper.mach.PlaintextMentions = encryptionConfig.PlaintextMentions - - helper.mach.DeleteOutboundKeysOnAck = encryptionConfig.DeleteKeys.DeleteOutboundOnAck - helper.mach.DontStoreOutboundKeys = encryptionConfig.DeleteKeys.DontStoreOutbound - helper.mach.RatchetKeysOnDecrypt = encryptionConfig.DeleteKeys.RatchetOnDecrypt - helper.mach.DeleteFullyUsedKeysOnDecrypt = encryptionConfig.DeleteKeys.DeleteFullyUsedOnDecrypt - helper.mach.DeletePreviousKeysOnReceive = encryptionConfig.DeleteKeys.DeletePrevOnNewSession - helper.mach.DeleteKeysOnDeviceDelete = encryptionConfig.DeleteKeys.DeleteOnDeviceDelete - helper.mach.DisableDeviceChangeKeyRotation = encryptionConfig.Rotation.DisableDeviceChangeKeyRotation - if encryptionConfig.DeleteKeys.PeriodicallyDeleteExpired { - ctx, cancel := context.WithCancel(context.Background()) - helper.cancelPeriodicDeleteLoop = cancel - go helper.mach.ExpiredKeyDeleteLoop(ctx) - } - - if encryptionConfig.DeleteKeys.DeleteOutdatedInbound { - deleted, err := helper.store.RedactOutdatedGroupSessions(ctx) - if err != nil { - return err - } - if len(deleted) > 0 { - helper.log.Debug().Int("deleted", len(deleted)).Msg("Deleted inbound keys which lacked expiration metadata") - } - } - - helper.client.Syncer = &cryptoSyncer{helper.mach} - helper.client.Store = helper.store - - err = helper.mach.Load(ctx) - if err != nil { - return err - } - if isExistingDevice { - helper.verifyKeysAreOnServer(ctx) - } - - go helper.resyncEncryptionInfo(context.TODO()) - - return nil -} - -func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { - log := helper.log.With().Str("action", "resync encryption event").Logger() - rows, err := helper.bridge.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) - roomIDs, err := dbutil.NewRowIterWithError(rows, dbutil.ScanSingleColumn[id.RoomID], err).AsList() - if err != nil { - log.Err(err).Msg("Failed to query rooms for resync") - return - } - if len(roomIDs) > 0 { - log.Debug().Interface("room_ids", roomIDs).Msg("Resyncing rooms") - for _, roomID := range roomIDs { - var evt event.EncryptionEventContent - err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt) - if err != nil { - log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event") - _, err = helper.bridge.DB.Exec(ctx, ` - UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}' - `, roomID) - if err != nil { - log.Err(err).Str("room_id", roomID.String()).Msg("Failed to unmark room for resync after failed sync") - } - } else { - maxAge := evt.RotationPeriodMillis - if maxAge <= 0 { - maxAge = (7 * 24 * time.Hour).Milliseconds() - } - maxMessages := evt.RotationPeriodMessages - if maxMessages <= 0 { - maxMessages = 100 - } - log.Debug(). - Str("room_id", roomID.String()). - Int64("max_age_ms", maxAge). - Int("max_messages", maxMessages). - Interface("content", &evt). - Msg("Resynced encryption event") - _, err = helper.bridge.DB.Exec(ctx, ` - UPDATE crypto_megolm_inbound_session - SET max_age=$1, max_messages=$2 - WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL - `, maxAge, maxMessages, roomID) - if err != nil { - log.Err(err).Str("room_id", roomID.String()).Msg("Failed to update megolm session table") - } else { - log.Debug().Str("room_id", roomID.String()).Msg("Updated megolm session table") - } - } - } - } -} - -func (helper *CryptoHelper) allowKeyShare(ctx context.Context, device *id.Device, info event.RequestedKeyInfo) *crypto.KeyShareRejection { - cfg := helper.bridge.Config.Bridge.GetEncryptionConfig() - if !cfg.AllowKeySharing { - return &crypto.KeyShareRejectNoResponse - } else if device.Trust == id.TrustStateBlacklisted { - return &crypto.KeyShareRejectBlacklisted - } else if trustState, _ := helper.mach.ResolveTrustContext(ctx, device); trustState >= cfg.VerificationLevels.Share { - portal := helper.bridge.Child.GetIPortal(info.RoomID) - if portal == nil { - zerolog.Ctx(ctx).Debug().Msg("Rejecting key request: room is not a portal") - return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnavailable, Reason: "Requested room is not a portal room"} - } - user := helper.bridge.Child.GetIUser(device.UserID, true) - // FIXME reimplement IsInPortal - if user.GetPermissionLevel() < bridgeconfig.PermissionLevelAdmin /*&& !user.IsInPortal(portal.Key)*/ { - zerolog.Ctx(ctx).Debug().Msg("Rejecting key request: user is not in portal") - return &crypto.KeyShareRejection{Code: event.RoomKeyWithheldUnauthorized, Reason: "You're not in that portal"} - } - zerolog.Ctx(ctx).Debug().Msg("Accepting key request") - return nil - } else { - return &crypto.KeyShareRejectUnverified - } -} - -func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool, error) { - deviceID, err := helper.store.FindDeviceID(ctx) - if err != nil { - return nil, false, fmt.Errorf("failed to find existing device ID: %w", err) - } else if len(deviceID) > 0 { - helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database") - } - // Create a new client instance with the default AS settings (including as_token), - // the Login call will then override the access token in the client. - client := helper.bridge.AS.NewMautrixClient(helper.bridge.AS.BotMXID()) - flows, err := client.GetLoginFlows(ctx) - if err != nil { - return nil, deviceID != "", fmt.Errorf("failed to get supported login flows: %w", err) - } else if !flows.HasFlow(mautrix.AuthTypeAppservice) { - return nil, deviceID != "", fmt.Errorf("homeserver does not support appservice login") - } - resp, err := client.Login(ctx, &mautrix.ReqLogin{ - Type: mautrix.AuthTypeAppservice, - Identifier: mautrix.UserIdentifier{ - Type: mautrix.IdentifierTypeUser, - User: string(helper.bridge.AS.BotMXID()), - }, - DeviceID: deviceID, - StoreCredentials: true, - - InitialDeviceDisplayName: fmt.Sprintf("%s bridge", helper.bridge.ProtocolName), - }) - if err != nil { - return nil, deviceID != "", fmt.Errorf("failed to log in as bridge bot: %w", err) - } - helper.store.DeviceID = resp.DeviceID - return client, deviceID != "", nil -} - -func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) { - helper.log.Debug().Msg("Making sure keys are still on server") - resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ - DeviceKeys: map[id.UserID]mautrix.DeviceIDList{ - helper.client.UserID: {helper.client.DeviceID}, - }, - }) - if err != nil { - helper.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to query own keys to make sure device still exists") - os.Exit(33) - } - device, ok := resp.DeviceKeys[helper.client.UserID][helper.client.DeviceID] - if ok && len(device.Keys) > 0 { - return - } - helper.log.Warn().Msg("Existing device doesn't have keys on server, resetting crypto") - helper.Reset(ctx, false) -} - -func (helper *CryptoHelper) Start() { - if helper.bridge.Config.Bridge.GetEncryptionConfig().Appservice { - helper.log.Debug().Msg("End-to-bridge encryption is in appservice mode, registering event listeners and not starting syncer") - helper.bridge.AS.Registration.EphemeralEvents = true - helper.mach.AddAppserviceListener(helper.bridge.EventProcessor) - return - } - helper.syncDone.Add(1) - defer helper.syncDone.Done() - helper.log.Debug().Msg("Starting syncer for receiving to-device messages") - var ctx context.Context - ctx, helper.cancelSync = context.WithCancel(context.Background()) - err := helper.client.SyncWithContext(ctx) - if err != nil && !errors.Is(err, context.Canceled) { - helper.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Fatal error syncing") - os.Exit(51) - } else { - helper.log.Info().Msg("Bridge bot to-device syncer stopped without error") - } -} - -func (helper *CryptoHelper) Stop() { - helper.log.Debug().Msg("CryptoHelper.Stop() called, stopping bridge bot sync") - helper.client.StopSync() - if helper.cancelSync != nil { - helper.cancelSync() - } - if helper.cancelPeriodicDeleteLoop != nil { - helper.cancelPeriodicDeleteLoop() - } - helper.syncDone.Wait() -} - -func (helper *CryptoHelper) clearDatabase(ctx context.Context) { - _, err := helper.store.DB.Exec(ctx, "DELETE FROM crypto_account") - if err != nil { - helper.log.Warn().Err(err).Msg("Failed to clear crypto_account table") - } - _, err = helper.store.DB.Exec(ctx, "DELETE FROM crypto_olm_session") - if err != nil { - helper.log.Warn().Err(err).Msg("Failed to clear crypto_olm_session table") - } - _, err = helper.store.DB.Exec(ctx, "DELETE FROM crypto_megolm_outbound_session") - if err != nil { - helper.log.Warn().Err(err).Msg("Failed to clear crypto_megolm_outbound_session table") - } - //_, _ = helper.store.DB.Exec("DELETE FROM crypto_device") - //_, _ = helper.store.DB.Exec("DELETE FROM crypto_tracked_user") - //_, _ = helper.store.DB.Exec("DELETE FROM crypto_cross_signing_keys") - //_, _ = helper.store.DB.Exec("DELETE FROM crypto_cross_signing_signatures") -} - -func (helper *CryptoHelper) Reset(ctx context.Context, startAfterReset bool) { - helper.lock.Lock() - defer helper.lock.Unlock() - helper.log.Info().Msg("Resetting end-to-bridge encryption device") - helper.Stop() - helper.log.Debug().Msg("Crypto syncer stopped, clearing database") - helper.clearDatabase(ctx) - helper.log.Debug().Msg("Crypto database cleared, logging out of all sessions") - _, err := helper.client.LogoutAll(ctx) - if err != nil { - helper.log.Warn().Err(err).Msg("Failed to log out all devices") - } - helper.client = nil - helper.store = nil - helper.mach = nil - err = helper.Init(ctx) - if err != nil { - helper.log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Error reinitializing end-to-bridge encryption") - os.Exit(50) - } - helper.log.Info().Msg("End-to-bridge encryption successfully reset") - if startAfterReset { - go helper.Start() - } -} - -func (helper *CryptoHelper) Client() *mautrix.Client { - return helper.client -} - -func (helper *CryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) { - return helper.mach.DecryptMegolmEvent(ctx, evt) -} - -func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content *event.Content) (err error) { - helper.lock.RLock() - defer helper.lock.RUnlock() - var encrypted *event.EncryptedEventContent - encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content) - if err != nil { - if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) { - return - } - helper.log.Debug().Err(err). - Str("room_id", roomID.String()). - Msg("Got error while encrypting event for room, sharing group session and trying again...") - var users []id.UserID - users, err = helper.store.GetRoomJoinedOrInvitedMembers(ctx, roomID) - if err != nil { - err = fmt.Errorf("failed to get room member list: %w", err) - } else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil { - err = fmt.Errorf("failed to share group session: %w", err) - } else if encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content); err != nil { - err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err) - } - } - if encrypted != nil { - content.Parsed = encrypted - content.Raw = nil - } - return -} - -func (helper *CryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { - helper.lock.RLock() - defer helper.lock.RUnlock() - return helper.mach.WaitForSession(ctx, roomID, senderKey, sessionID, timeout) -} - -func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) { - helper.lock.RLock() - defer helper.lock.RUnlock() - if deviceID == "" { - deviceID = "*" - } - err := helper.mach.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{userID: {deviceID}}) - if err != nil { - helper.log.Warn().Err(err). - Str("user_id", userID.String()). - Str("device_id", deviceID.String()). - Str("session_id", sessionID.String()). - Str("room_id", roomID.String()). - Msg("Failed to send key request") - } else { - helper.log.Debug(). - Str("user_id", userID.String()). - Str("device_id", deviceID.String()). - Str("session_id", sessionID.String()). - Str("room_id", roomID.String()). - Msg("Sent key request") - } -} - -func (helper *CryptoHelper) ResetSession(ctx context.Context, roomID id.RoomID) { - helper.lock.RLock() - defer helper.lock.RUnlock() - err := helper.mach.CryptoStore.RemoveOutboundGroupSession(ctx, roomID) - if err != nil { - helper.log.Debug().Err(err). - Str("room_id", roomID.String()). - Msg("Error manually removing outbound group session in room") - } -} - -func (helper *CryptoHelper) HandleMemberEvent(ctx context.Context, evt *event.Event) { - helper.lock.RLock() - defer helper.lock.RUnlock() - helper.mach.HandleMemberEvent(ctx, evt) -} - -// ShareKeys uploads the given number of one-time-keys to the server. -func (helper *CryptoHelper) ShareKeys(ctx context.Context) error { - return helper.mach.ShareKeys(ctx, -1) -} - -type cryptoSyncer struct { - *crypto.OlmMachine -} - -func (syncer *cryptoSyncer) ProcessResponse(ctx context.Context, resp *mautrix.RespSync, since string) error { - done := make(chan struct{}) - go func() { - defer func() { - if err := recover(); err != nil { - syncer.Log.Error(). - Str("since", since). - Interface("error", err). - Str("stack", string(debug.Stack())). - Msg("Processing sync response panicked") - } - done <- struct{}{} - }() - syncer.Log.Trace().Str("since", since).Msg("Starting sync response handling") - syncer.ProcessSyncResponse(ctx, resp, since) - syncer.Log.Trace().Str("since", since).Msg("Successfully handled sync response") - }() - select { - case <-done: - case <-time.After(30 * time.Second): - syncer.Log.Warn().Str("since", since).Msg("Handling sync response is taking unusually long") - } - return nil -} - -func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) { - if errors.Is(err, mautrix.MUnknownToken) { - return 0, err - } - syncer.Log.Error().Err(err).Msg("Error /syncing, waiting 10 seconds") - return 10 * time.Second, nil -} - -func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter { - everything := []event.Type{{Type: "*"}} - return &mautrix.Filter{ - Presence: &mautrix.FilterPart{NotTypes: everything}, - AccountData: &mautrix.FilterPart{NotTypes: everything}, - Room: &mautrix.RoomFilter{ - IncludeLeave: false, - Ephemeral: &mautrix.FilterPart{NotTypes: everything}, - AccountData: &mautrix.FilterPart{NotTypes: everything}, - State: &mautrix.FilterPart{NotTypes: everything}, - Timeline: &mautrix.FilterPart{NotTypes: everything}, - }, - } -} - -type cryptoStateStore struct { - bridge *Bridge -} - -var _ crypto.StateStore = (*cryptoStateStore)(nil) - -func (c *cryptoStateStore) IsEncrypted(ctx context.Context, id id.RoomID) (bool, error) { - portal := c.bridge.Child.GetIPortal(id) - if portal != nil { - return portal.IsEncrypted(), nil - } - return c.bridge.StateStore.IsEncrypted(ctx, id) -} - -func (c *cryptoStateStore) FindSharedRooms(ctx context.Context, id id.UserID) ([]id.RoomID, error) { - return c.bridge.StateStore.FindSharedRooms(ctx, id) -} - -func (c *cryptoStateStore) GetEncryptionEvent(ctx context.Context, id id.RoomID) (*event.EncryptionEventContent, error) { - return c.bridge.StateStore.GetEncryptionEvent(ctx, id) -} diff --git a/bridge/cryptostore.go b/bridge/cryptostore.go deleted file mode 100644 index dde48a25..00000000 --- a/bridge/cryptostore.go +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) 2022 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build cgo && !nocrypto - -package bridge - -import ( - "context" - - "github.com/lib/pq" - "go.mau.fi/util/dbutil" - - "maunium.net/go/mautrix/crypto" - "maunium.net/go/mautrix/id" -) - -func init() { - crypto.PostgresArrayWrapper = pq.Array -} - -type SQLCryptoStore struct { - *crypto.SQLCryptoStore - UserID id.UserID - GhostIDFormat string -} - -var _ crypto.Store = (*SQLCryptoStore)(nil) - -func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, userID id.UserID, ghostIDFormat, pickleKey string) *SQLCryptoStore { - return &SQLCryptoStore{ - SQLCryptoStore: crypto.NewSQLCryptoStore(db, log, "", "", []byte(pickleKey)), - UserID: userID, - GhostIDFormat: ghostIDFormat, - } -} - -func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) (members []id.UserID, err error) { - var rows dbutil.Rows - rows, err = store.DB.Query(ctx, ` - SELECT user_id FROM mx_user_profile - WHERE room_id=$1 - AND (membership='join' OR membership='invite') - AND user_id<>$2 - AND user_id NOT LIKE $3 - `, roomID, store.UserID, store.GhostIDFormat) - if err != nil { - return - } - for rows.Next() { - var userID id.UserID - err = rows.Scan(&userID) - if err != nil { - return members, err - } else { - members = append(members, userID) - } - } - return -} diff --git a/bridge/doublepuppet.go b/bridge/doublepuppet.go deleted file mode 100644 index 265d3d5c..00000000 --- a/bridge/doublepuppet.go +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridge - -import ( - "context" - "crypto/hmac" - "crypto/sha512" - "encoding/hex" - "errors" - "fmt" - "strings" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/id" -) - -type doublePuppetUtil struct { - br *Bridge - log zerolog.Logger -} - -func (dp *doublePuppetUtil) newClient(ctx context.Context, mxid id.UserID, accessToken string) (*mautrix.Client, error) { - _, homeserver, err := mxid.Parse() - if err != nil { - return nil, err - } - homeserverURL, found := dp.br.Config.Bridge.GetDoublePuppetConfig().ServerMap[homeserver] - if !found { - if homeserver == dp.br.AS.HomeserverDomain { - homeserverURL = "" - } else if dp.br.Config.Bridge.GetDoublePuppetConfig().AllowDiscovery { - resp, err := mautrix.DiscoverClientAPI(ctx, homeserver) - if err != nil { - return nil, fmt.Errorf("failed to find homeserver URL for %s: %v", homeserver, err) - } - homeserverURL = resp.Homeserver.BaseURL - dp.log.Debug(). - Str("homeserver", homeserver). - Str("url", homeserverURL). - Str("user_id", mxid.String()). - Msg("Discovered URL to enable double puppeting for user") - } else { - return nil, fmt.Errorf("double puppeting from %s is not allowed", homeserver) - } - } - return dp.br.AS.NewExternalMautrixClient(mxid, accessToken, homeserverURL) -} - -func (dp *doublePuppetUtil) newIntent(ctx context.Context, mxid id.UserID, accessToken string) (*appservice.IntentAPI, error) { - client, err := dp.newClient(ctx, mxid, accessToken) - if err != nil { - return nil, err - } - - ia := dp.br.AS.NewIntentAPI("custom") - ia.Client = client - ia.Localpart, _, _ = mxid.Parse() - ia.UserID = mxid - ia.IsCustomPuppet = true - return ia, nil -} - -func (dp *doublePuppetUtil) autoLogin(ctx context.Context, mxid id.UserID, loginSecret string) (string, error) { - dp.log.Debug().Str("user_id", mxid.String()).Msg("Logging into user account with shared secret") - client, err := dp.newClient(ctx, mxid, "") - if err != nil { - return "", fmt.Errorf("failed to create mautrix client to log in: %v", err) - } - bridgeName := fmt.Sprintf("%s Bridge", dp.br.ProtocolName) - req := mautrix.ReqLogin{ - Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: string(mxid)}, - DeviceID: id.DeviceID(bridgeName), - InitialDeviceDisplayName: bridgeName, - } - if loginSecret == "appservice" { - client.AccessToken = dp.br.AS.Registration.AppToken - req.Type = mautrix.AuthTypeAppservice - } else { - loginFlows, err := client.GetLoginFlows(ctx) - if err != nil { - return "", fmt.Errorf("failed to get supported login flows: %w", err) - } - mac := hmac.New(sha512.New, []byte(loginSecret)) - mac.Write([]byte(mxid)) - token := hex.EncodeToString(mac.Sum(nil)) - switch { - case loginFlows.HasFlow(mautrix.AuthTypeDevtureSharedSecret): - req.Type = mautrix.AuthTypeDevtureSharedSecret - req.Token = token - case loginFlows.HasFlow(mautrix.AuthTypePassword): - req.Type = mautrix.AuthTypePassword - req.Password = token - default: - return "", fmt.Errorf("no supported auth types for shared secret auth found") - } - } - resp, err := client.Login(ctx, &req) - if err != nil { - return "", err - } - return resp.AccessToken, nil -} - -var ( - ErrMismatchingMXID = errors.New("whoami result does not match custom mxid") - ErrNoAccessToken = errors.New("no access token provided") - ErrNoMXID = errors.New("no mxid provided") -) - -const useConfigASToken = "appservice-config" -const asTokenModePrefix = "as_token:" - -func (dp *doublePuppetUtil) Setup(ctx context.Context, mxid id.UserID, savedAccessToken string, reloginOnFail bool) (intent *appservice.IntentAPI, newAccessToken string, err error) { - if len(mxid) == 0 { - err = ErrNoMXID - return - } - _, homeserver, _ := mxid.Parse() - loginSecret, hasSecret := dp.br.Config.Bridge.GetDoublePuppetConfig().SharedSecretMap[homeserver] - // Special case appservice: prefix to not login and use it as an as_token directly. - if hasSecret && strings.HasPrefix(loginSecret, asTokenModePrefix) { - intent, err = dp.newIntent(ctx, mxid, strings.TrimPrefix(loginSecret, asTokenModePrefix)) - if err != nil { - return - } - intent.SetAppServiceUserID = true - if savedAccessToken != useConfigASToken { - var resp *mautrix.RespWhoami - resp, err = intent.Whoami(ctx) - if err == nil && resp.UserID != mxid { - err = ErrMismatchingMXID - } - } - return intent, useConfigASToken, err - } - if savedAccessToken == "" || savedAccessToken == useConfigASToken { - if reloginOnFail && hasSecret { - savedAccessToken, err = dp.autoLogin(ctx, mxid, loginSecret) - } else { - err = ErrNoAccessToken - } - if err != nil { - return - } - } - intent, err = dp.newIntent(ctx, mxid, savedAccessToken) - if err != nil { - return - } - var resp *mautrix.RespWhoami - resp, err = intent.Whoami(ctx) - if err != nil { - if reloginOnFail && hasSecret && errors.Is(err, mautrix.MUnknownToken) { - intent.AccessToken, err = dp.autoLogin(ctx, mxid, loginSecret) - if err == nil { - newAccessToken = intent.AccessToken - } - } - } else if resp.UserID != mxid { - err = ErrMismatchingMXID - } else { - newAccessToken = savedAccessToken - } - return -} diff --git a/bridge/matrix.go b/bridge/matrix.go deleted file mode 100644 index 446a0b0a..00000000 --- a/bridge/matrix.go +++ /dev/null @@ -1,755 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridge - -import ( - "context" - "errors" - "fmt" - "strings" - "time" - - "github.com/rs/zerolog" - - "maunium.net/go/mautrix" - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/bridgeconfig" - "maunium.net/go/mautrix/bridge/status" - "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/format" - "maunium.net/go/mautrix/id" -) - -type CommandProcessor interface { - Handle(ctx context.Context, roomID id.RoomID, eventID id.EventID, user User, message string, replyTo id.EventID) -} - -type MatrixHandler struct { - bridge *Bridge - as *appservice.AppService - log *zerolog.Logger - - TrackEventDuration func(event.Type) func() -} - -func noop() {} - -func noopTrack(_ event.Type) func() { - return noop -} - -func NewMatrixHandler(br *Bridge) *MatrixHandler { - handler := &MatrixHandler{ - bridge: br, - as: br.AS, - log: br.ZLog, - - TrackEventDuration: noopTrack, - } - for evtType := range status.CheckpointTypes { - br.EventProcessor.On(evtType, handler.sendBridgeCheckpoint) - } - br.EventProcessor.On(event.EventMessage, handler.HandleMessage) - br.EventProcessor.On(event.EventEncrypted, handler.HandleEncrypted) - br.EventProcessor.On(event.EventSticker, handler.HandleMessage) - br.EventProcessor.On(event.EventReaction, handler.HandleReaction) - br.EventProcessor.On(event.EventRedaction, handler.HandleRedaction) - br.EventProcessor.On(event.StateMember, handler.HandleMembership) - br.EventProcessor.On(event.StateRoomName, handler.HandleRoomMetadata) - br.EventProcessor.On(event.StateRoomAvatar, handler.HandleRoomMetadata) - br.EventProcessor.On(event.StateTopic, handler.HandleRoomMetadata) - br.EventProcessor.On(event.StateEncryption, handler.HandleEncryption) - br.EventProcessor.On(event.EphemeralEventReceipt, handler.HandleReceipt) - br.EventProcessor.On(event.EphemeralEventTyping, handler.HandleTyping) - br.EventProcessor.On(event.StatePowerLevels, handler.HandlePowerLevels) - br.EventProcessor.On(event.StateJoinRules, handler.HandleJoinRule) - return handler -} - -func (mx *MatrixHandler) sendBridgeCheckpoint(_ context.Context, evt *event.Event) { - if !evt.Mautrix.CheckpointSent { - go mx.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepBridge, 0) - } -} - -func (mx *MatrixHandler) HandleEncryption(ctx context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - if evt.Content.AsEncryption().Algorithm != id.AlgorithmMegolmV1 { - return - } - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal != nil && !portal.IsEncrypted() { - mx.log.Debug(). - Str("user_id", evt.Sender.String()). - Str("room_id", evt.RoomID.String()). - Msg("Encryption was enabled in room") - portal.MarkEncrypted() - if portal.IsPrivateChat() { - err := mx.as.BotIntent().EnsureJoined(ctx, evt.RoomID, appservice.EnsureJoinedParams{BotOverride: portal.MainIntent().Client}) - if err != nil { - mx.log.Err(err). - Str("room_id", evt.RoomID.String()). - Msg("Failed to join bot to room after encryption was enabled") - } - } - } -} - -func (mx *MatrixHandler) joinAndCheckMembers(ctx context.Context, evt *event.Event, intent *appservice.IntentAPI) *mautrix.RespJoinedMembers { - log := zerolog.Ctx(ctx) - resp, err := intent.JoinRoomByID(ctx, evt.RoomID) - if err != nil { - log.Warn().Err(err).Msg("Failed to join room with invite") - return nil - } - - members, err := intent.JoinedMembers(ctx, resp.RoomID) - if err != nil { - log.Warn().Err(err).Msg("Failed to get members in room after accepting invite, leaving room") - _, _ = intent.LeaveRoom(ctx, resp.RoomID) - return nil - } - - if len(members.Joined) < 2 { - log.Debug().Msg("Leaving empty room after accepting invite") - _, _ = intent.LeaveRoom(ctx, resp.RoomID) - return nil - } - return members -} - -func (mx *MatrixHandler) sendNoticeWithMarkdown(ctx context.Context, roomID id.RoomID, message string) (*mautrix.RespSendEvent, error) { - intent := mx.as.BotIntent() - content := format.RenderMarkdown(message, true, false) - content.MsgType = event.MsgNotice - return intent.SendMessageEvent(ctx, roomID, event.EventMessage, content) -} - -func (mx *MatrixHandler) HandleBotInvite(ctx context.Context, evt *event.Event) { - intent := mx.as.BotIntent() - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil { - return - } - - members := mx.joinAndCheckMembers(ctx, evt, intent) - if members == nil { - return - } - - if user.GetPermissionLevel() < bridgeconfig.PermissionLevelUser { - _, _ = intent.SendNotice(ctx, evt.RoomID, "You are not whitelisted to use this bridge.\n"+ - "If you're the owner of this bridge, see the bridge.permissions section in your config file.") - _, _ = intent.LeaveRoom(ctx, evt.RoomID) - return - } - - texts := mx.bridge.Config.Bridge.GetManagementRoomTexts() - _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.Welcome) - - if len(members.Joined) == 2 && (len(user.GetManagementRoomID()) == 0 || evt.Content.AsMember().IsDirect) { - user.SetManagementRoom(evt.RoomID) - _, _ = intent.SendNotice(ctx, user.GetManagementRoomID(), "This room has been registered as your bridge management/status room.") - zerolog.Ctx(ctx).Debug().Msg("Registered room as management room with inviter") - } - - if evt.RoomID == user.GetManagementRoomID() { - if user.IsLoggedIn() { - _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.WelcomeConnected) - } else { - _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, texts.WelcomeUnconnected) - } - - additionalHelp := texts.AdditionalHelp - if len(additionalHelp) > 0 { - _, _ = mx.sendNoticeWithMarkdown(ctx, evt.RoomID, additionalHelp) - } - } -} - -func (mx *MatrixHandler) HandleGhostInvite(ctx context.Context, evt *event.Event, inviter User, ghost Ghost) { - log := zerolog.Ctx(ctx) - intent := ghost.DefaultIntent() - - if inviter.GetPermissionLevel() < bridgeconfig.PermissionLevelUser { - log.Debug().Msg("Rejecting invite: inviter is not whitelisted") - _, err := intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ - Reason: "You're not whitelisted to use this bridge", - }) - if err != nil { - log.Error().Err(err).Msg("Failed to reject invite") - } - return - } else if !inviter.IsLoggedIn() { - log.Debug().Msg("Rejecting invite: inviter is not logged in") - _, err := intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ - Reason: "You're not logged into this bridge", - }) - if err != nil { - log.Error().Err(err).Msg("Failed to reject invite") - } - return - } - - members := mx.joinAndCheckMembers(ctx, evt, intent) - if members == nil { - return - } - var createEvent event.CreateEventContent - if err := intent.StateEvent(ctx, evt.RoomID, event.StateCreate, "", &createEvent); err != nil { - log.Warn().Err(err).Msg("Failed to check m.room.create event in room") - } else if createEvent.Type != "" { - log.Warn().Str("room_type", string(createEvent.Type)).Msg("Non-standard room type, leaving room") - _, err = intent.LeaveRoom(ctx, evt.RoomID, &mautrix.ReqLeave{ - Reason: "Unsupported room type", - }) - if err != nil { - log.Error().Err(err).Msg("Failed to leave room") - } - return - } - var hasBridgeBot, hasOtherUsers bool - for mxid, _ := range members.Joined { - if mxid == intent.UserID || mxid == inviter.GetMXID() { - continue - } else if mxid == mx.bridge.Bot.UserID { - hasBridgeBot = true - } else { - hasOtherUsers = true - } - } - if !hasBridgeBot && !hasOtherUsers && evt.Content.AsMember().IsDirect { - mx.bridge.Child.CreatePrivatePortal(evt.RoomID, inviter, ghost) - } else if !hasBridgeBot { - log.Debug().Msg("Leaving multi-user room after accepting invite") - _, _ = intent.SendNotice(ctx, evt.RoomID, "Please invite the bridge bot first if you want to bridge to a remote chat.") - _, _ = intent.LeaveRoom(ctx, evt.RoomID) - } else { - _, _ = intent.SendNotice(ctx, evt.RoomID, "This puppet will remain inactive until this room is bridged to a remote chat.") - } -} - -func (mx *MatrixHandler) HandleMembership(ctx context.Context, evt *event.Event) { - if evt.Sender == mx.bridge.Bot.UserID || mx.bridge.Child.IsGhost(evt.Sender) { - return - } - defer mx.TrackEventDuration(evt.Type)() - - if mx.bridge.Crypto != nil { - mx.bridge.Crypto.HandleMemberEvent(ctx, evt) - } - - log := mx.log.With(). - Str("sender", evt.Sender.String()). - Str("target", evt.GetStateKey()). - Str("room_id", evt.RoomID.String()). - Logger() - ctx = log.WithContext(ctx) - - content := evt.Content.AsMember() - if content.Membership == event.MembershipInvite && id.UserID(evt.GetStateKey()) == mx.as.BotMXID() { - mx.HandleBotInvite(ctx, evt) - return - } - - if mx.shouldIgnoreEvent(evt) { - return - } - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil { - return - } - isSelf := id.UserID(evt.GetStateKey()) == evt.Sender - ghost := mx.bridge.Child.GetIGhost(id.UserID(evt.GetStateKey())) - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal == nil { - if ghost != nil && content.Membership == event.MembershipInvite { - mx.HandleGhostInvite(ctx, evt, user, ghost) - } - return - } else if user.GetPermissionLevel() < bridgeconfig.PermissionLevelUser || !user.IsLoggedIn() { - return - } - bhp, bhpOk := portal.(BanHandlingPortal) - mhp, mhpOk := portal.(MembershipHandlingPortal) - khp, khpOk := portal.(KnockHandlingPortal) - ihp, ihpOk := portal.(InviteHandlingPortal) - if !(mhpOk || bhpOk || khpOk) { - return - } - prevContent := &event.MemberEventContent{Membership: event.MembershipLeave} - if evt.Unsigned.PrevContent != nil { - _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) - prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent) - } - if ihpOk && prevContent.Membership == event.MembershipInvite && content.Membership != event.MembershipBan { - if content.Membership == event.MembershipJoin { - ihp.HandleMatrixAcceptInvite(user, evt) - } - if content.Membership == event.MembershipLeave { - if isSelf { - ihp.HandleMatrixRejectInvite(user, evt) - } else if ghost != nil { - ihp.HandleMatrixRetractInvite(user, ghost, evt) - } - } - } - if bhpOk && ghost != nil { - if content.Membership == event.MembershipBan { - bhp.HandleMatrixBan(user, ghost, evt) - } else if content.Membership == event.MembershipLeave && prevContent.Membership == event.MembershipBan { - bhp.HandleMatrixUnban(user, ghost, evt) - } - } - if khpOk { - if content.Membership == event.MembershipKnock { - khp.HandleMatrixKnock(user, evt) - } else if prevContent.Membership == event.MembershipKnock { - if content.Membership == event.MembershipInvite && ghost != nil { - khp.HandleMatrixAcceptKnock(user, ghost, evt) - } else if content.Membership == event.MembershipLeave { - if isSelf { - khp.HandleMatrixRetractKnock(user, evt) - } else if ghost != nil { - khp.HandleMatrixRejectKnock(user, ghost, evt) - } - } - } - } - if mhpOk { - if content.Membership == event.MembershipLeave && prevContent.Membership == event.MembershipJoin { - if isSelf { - mhp.HandleMatrixLeave(user, evt) - } else if ghost != nil { - mhp.HandleMatrixKick(user, ghost, evt) - } - } else if content.Membership == event.MembershipInvite && !isSelf && ghost != nil { - mhp.HandleMatrixInvite(user, ghost, evt) - } - } - // TODO kicking/inviting non-ghost users users -} - -func (mx *MatrixHandler) HandleRoomMetadata(ctx context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - if mx.shouldIgnoreEvent(evt) { - return - } - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil { - return - } - - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal == nil || portal.IsPrivateChat() { - return - } - - metaPortal, ok := portal.(MetaHandlingPortal) - if !ok { - return - } - - metaPortal.HandleMatrixMeta(user, evt) -} - -func (mx *MatrixHandler) shouldIgnoreEvent(evt *event.Event) bool { - if evt.Sender == mx.bridge.Bot.UserID || mx.bridge.Child.IsGhost(evt.Sender) { - return true - } - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil || user.GetPermissionLevel() <= 0 { - return true - } else if val, ok := evt.Content.Raw[appservice.DoublePuppetKey]; ok && val == mx.bridge.Name && user.GetIDoublePuppet() != nil { - return true - } - return false -} - -const initialSessionWaitTimeout = 3 * time.Second -const extendedSessionWaitTimeout = 22 * time.Second - -func (mx *MatrixHandler) sendCryptoStatusError(ctx context.Context, evt *event.Event, editEvent id.EventID, err error, retryCount int, isFinal bool) id.EventID { - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepDecrypted, err, isFinal, retryCount) - - if mx.bridge.Config.Bridge.EnableMessageStatusEvents() { - statusEvent := &event.BeeperMessageStatusEventContent{ - // TODO: network - RelatesTo: event.RelatesTo{ - Type: event.RelReference, - EventID: evt.ID, - }, - Status: event.MessageStatusRetriable, - Reason: event.MessageStatusUndecryptable, - Error: err.Error(), - Message: errorToHumanMessage(err), - } - if !isFinal { - statusEvent.Status = event.MessageStatusPending - } - _, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, statusEvent) - if sendErr != nil { - zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to send message status event") - } - } - if mx.bridge.Config.Bridge.EnableMessageErrorNotices() { - update := event.MessageEventContent{ - MsgType: event.MsgNotice, - Body: fmt.Sprintf("\u26a0 Your message was not bridged: %v.", err), - } - if errors.Is(err, errNoCrypto) { - update.Body = "🔒 This bridge has not been configured to support encryption" - } - relatable, ok := evt.Content.Parsed.(event.Relatable) - if editEvent != "" { - update.SetEdit(editEvent) - } else if ok && relatable.OptionalGetRelatesTo().GetThreadParent() != "" { - update.GetRelatesTo().SetThread(relatable.OptionalGetRelatesTo().GetThreadParent(), evt.ID) - } - resp, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.EventMessage, &update) - if sendErr != nil { - zerolog.Ctx(ctx).Error().Err(sendErr).Msg("Failed to send decryption error notice") - } else if resp != nil { - return resp.EventID - } - } - return "" -} - -var ( - errDeviceNotTrusted = errors.New("your device is not trusted") - errMessageNotEncrypted = errors.New("unencrypted message") - errNoDecryptionKeys = errors.New("the bridge hasn't received the decryption keys") - errNoCrypto = errors.New("this bridge has not been configured to support encryption") -) - -func errorToHumanMessage(err error) string { - var withheld *event.RoomKeyWithheldEventContent - switch { - case errors.Is(err, errDeviceNotTrusted), errors.Is(err, errNoDecryptionKeys): - return err.Error() - case errors.Is(err, UnknownMessageIndex): - return "the keys received by the bridge can't decrypt the message" - case errors.Is(err, DuplicateMessageIndex): - return "your client encrypted multiple messages with the same key" - case errors.As(err, &withheld): - if withheld.Code == event.RoomKeyWithheldBeeperRedacted { - return "your client used an outdated encryption session" - } - return "your client refused to share decryption keys with the bridge" - case errors.Is(err, errMessageNotEncrypted): - return "the message is not encrypted" - default: - return "the bridge failed to decrypt the message" - } -} - -func deviceUnverifiedErrorWithExplanation(trust id.TrustState) error { - var explanation string - switch trust { - case id.TrustStateBlacklisted: - explanation = "device is blacklisted" - case id.TrustStateUnset: - explanation = "unverified" - case id.TrustStateUnknownDevice: - explanation = "device info not found" - case id.TrustStateForwarded: - explanation = "keys were forwarded from an unknown device" - case id.TrustStateCrossSignedUntrusted: - explanation = "cross-signing keys changed after setting up the bridge" - default: - return errDeviceNotTrusted - } - return fmt.Errorf("%w (%s)", errDeviceNotTrusted, explanation) -} - -func copySomeKeys(original, decrypted *event.Event) { - isScheduled, _ := original.Content.Raw["com.beeper.scheduled"].(bool) - _, alreadyExists := decrypted.Content.Raw["com.beeper.scheduled"] - if isScheduled && !alreadyExists { - decrypted.Content.Raw["com.beeper.scheduled"] = true - } -} - -func (mx *MatrixHandler) postDecrypt(ctx context.Context, original, decrypted *event.Event, retryCount int, errorEventID id.EventID, duration time.Duration) { - log := zerolog.Ctx(ctx) - minLevel := mx.bridge.Config.Bridge.GetEncryptionConfig().VerificationLevels.Send - if decrypted.Mautrix.TrustState < minLevel { - logEvt := log.Warn(). - Str("user_id", decrypted.Sender.String()). - Bool("forwarded_keys", decrypted.Mautrix.ForwardedKeys). - Stringer("device_trust", decrypted.Mautrix.TrustState). - Stringer("min_trust", minLevel) - if decrypted.Mautrix.TrustSource != nil { - dev := decrypted.Mautrix.TrustSource - logEvt. - Str("device_id", dev.DeviceID.String()). - Str("device_signing_key", dev.SigningKey.String()) - } else { - logEvt.Str("device_id", "unknown") - } - logEvt.Msg("Dropping event due to insufficient verification level") - err := deviceUnverifiedErrorWithExplanation(decrypted.Mautrix.TrustState) - go mx.sendCryptoStatusError(ctx, decrypted, errorEventID, err, retryCount, true) - return - } - copySomeKeys(original, decrypted) - - mx.bridge.SendMessageSuccessCheckpoint(decrypted, status.MsgStepDecrypted, retryCount) - decrypted.Mautrix.CheckpointSent = true - decrypted.Mautrix.DecryptionDuration = duration - decrypted.Mautrix.EventSource |= event.SourceDecrypted - mx.bridge.EventProcessor.Dispatch(ctx, decrypted) - if errorEventID != "" { - _, _ = mx.bridge.Bot.RedactEvent(ctx, decrypted.RoomID, errorEventID) - } -} - -func (mx *MatrixHandler) HandleEncrypted(ctx context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - if mx.shouldIgnoreEvent(evt) { - return - } - content := evt.Content.AsEncrypted() - log := zerolog.Ctx(ctx).With(). - Str("event_id", evt.ID.String()). - Str("session_id", content.SessionID.String()). - Logger() - ctx = log.WithContext(ctx) - if mx.bridge.Crypto == nil { - go mx.sendCryptoStatusError(ctx, evt, "", errNoCrypto, 0, true) - return - } - log.Debug().Msg("Decrypting received event") - - decryptionStart := time.Now() - decrypted, err := mx.bridge.Crypto.Decrypt(ctx, evt) - decryptionRetryCount := 0 - if errors.Is(err, NoSessionFound) { - decryptionRetryCount = 1 - log.Debug(). - Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). - Msg("Couldn't find session, waiting for keys to arrive...") - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepDecrypted, err, false, 0) - if mx.bridge.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { - log.Debug().Msg("Got keys after waiting, trying to decrypt event again") - decrypted, err = mx.bridge.Crypto.Decrypt(ctx, evt) - } else { - go mx.waitLongerForSession(ctx, evt, decryptionStart) - return - } - } - if err != nil { - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepDecrypted, err, true, decryptionRetryCount) - log.Warn().Err(err).Msg("Failed to decrypt event") - go mx.sendCryptoStatusError(ctx, evt, "", err, decryptionRetryCount, true) - return - } - mx.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, "", time.Since(decryptionStart)) -} - -func (mx *MatrixHandler) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time) { - log := zerolog.Ctx(ctx) - content := evt.Content.AsEncrypted() - log.Debug(). - Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())). - Msg("Couldn't find session, requesting keys and waiting longer...") - - go mx.bridge.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) - errorEventID := mx.sendCryptoStatusError(ctx, evt, "", fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), 1, false) - - if !mx.bridge.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { - log.Debug().Msg("Didn't get session, giving up trying to decrypt event") - mx.sendCryptoStatusError(ctx, evt, errorEventID, errNoDecryptionKeys, 2, true) - return - } - - log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again") - decrypted, err := mx.bridge.Crypto.Decrypt(ctx, evt) - if err != nil { - log.Error().Err(err).Msg("Failed to decrypt event") - mx.sendCryptoStatusError(ctx, evt, errorEventID, err, 2, true) - return - } - - mx.postDecrypt(ctx, evt, decrypted, 2, errorEventID, time.Since(decryptionStart)) -} - -func (mx *MatrixHandler) HandleMessage(ctx context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - log := zerolog.Ctx(ctx).With(). - Str("event_id", evt.ID.String()). - Str("room_id", evt.RoomID.String()). - Str("sender", evt.Sender.String()). - Logger() - ctx = log.WithContext(ctx) - if mx.shouldIgnoreEvent(evt) { - return - } else if !evt.Mautrix.WasEncrypted && mx.bridge.Config.Bridge.GetEncryptionConfig().Require { - log.Warn().Msg("Dropping unencrypted event") - mx.sendCryptoStatusError(ctx, evt, "", errMessageNotEncrypted, 0, true) - return - } - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil { - return - } - - content := evt.Content.AsMessage() - content.RemoveReplyFallback() - if user.GetPermissionLevel() >= bridgeconfig.PermissionLevelUser && content.MsgType == event.MsgText { - commandPrefix := mx.bridge.Config.Bridge.GetCommandPrefix() - hasCommandPrefix := strings.HasPrefix(content.Body, commandPrefix) - if hasCommandPrefix { - content.Body = strings.TrimLeft(strings.TrimPrefix(content.Body, commandPrefix), " ") - } - if hasCommandPrefix || evt.RoomID == user.GetManagementRoomID() { - go mx.bridge.CommandProcessor.Handle(ctx, evt.RoomID, evt.ID, user, content.Body, content.RelatesTo.GetReplyTo()) - go mx.bridge.SendMessageSuccessCheckpoint(evt, status.MsgStepCommand, 0) - if mx.bridge.Config.Bridge.EnableMessageStatusEvents() { - statusEvent := &event.BeeperMessageStatusEventContent{ - // TODO: network - RelatesTo: event.RelatesTo{ - Type: event.RelReference, - EventID: evt.ID, - }, - Status: event.MessageStatusSuccess, - } - _, sendErr := mx.bridge.Bot.SendMessageEvent(ctx, evt.RoomID, event.BeeperMessageStatus, statusEvent) - if sendErr != nil { - log.Warn().Err(sendErr).Msg("Failed to send message status event for command") - } - } - return - } - } - - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal != nil { - portal.ReceiveMatrixEvent(user, evt) - } else { - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepRemote, fmt.Errorf("unknown room"), true, 0) - } -} - -func (mx *MatrixHandler) HandleReaction(_ context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - if mx.shouldIgnoreEvent(evt) { - return - } - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil || user.GetPermissionLevel() < bridgeconfig.PermissionLevelUser || !user.IsLoggedIn() { - return - } - - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal != nil { - portal.ReceiveMatrixEvent(user, evt) - } else { - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepRemote, fmt.Errorf("unknown room"), true, 0) - } -} - -func (mx *MatrixHandler) HandleRedaction(_ context.Context, evt *event.Event) { - defer mx.TrackEventDuration(evt.Type)() - if mx.shouldIgnoreEvent(evt) { - return - } - - user := mx.bridge.Child.GetIUser(evt.Sender, true) - if user == nil { - return - } - - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal != nil { - portal.ReceiveMatrixEvent(user, evt) - } else { - mx.bridge.SendMessageErrorCheckpoint(evt, status.MsgStepRemote, fmt.Errorf("unknown room"), true, 0) - } -} - -func (mx *MatrixHandler) HandleReceipt(_ context.Context, evt *event.Event) { - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal == nil { - return - } - - rrPortal, ok := portal.(ReadReceiptHandlingPortal) - if !ok { - return - } - - for eventID, receipts := range *evt.Content.AsReceipt() { - for userID, receipt := range receipts[event.ReceiptTypeRead] { - user := mx.bridge.Child.GetIUser(userID, false) - if user == nil { - // Not a bridge user - continue - } - customPuppet := user.GetIDoublePuppet() - if val, ok := receipt.Extra[appservice.DoublePuppetKey].(string); ok && customPuppet != nil && val == mx.bridge.Name { - // Ignore double puppeted read receipts. - mx.log.Debug().Interface("content", evt.Content.Raw).Msg("Ignoring double-puppeted read receipt") - // But do start disappearing messages, because the user read the chat - dp, ok := portal.(DisappearingPortal) - if ok { - dp.ScheduleDisappearing() - } - } else { - rrPortal.HandleMatrixReadReceipt(user, eventID, receipt) - } - } - } -} - -func (mx *MatrixHandler) HandleTyping(_ context.Context, evt *event.Event) { - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal == nil { - return - } - typingPortal, ok := portal.(TypingPortal) - if !ok { - return - } - typingPortal.HandleMatrixTyping(evt.Content.AsTyping().UserIDs) -} - -func (mx *MatrixHandler) HandlePowerLevels(_ context.Context, evt *event.Event) { - if mx.shouldIgnoreEvent(evt) { - return - } - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal == nil { - return - } - powerLevelPortal, ok := portal.(PowerLevelHandlingPortal) - if ok { - user := mx.bridge.Child.GetIUser(evt.Sender, true) - powerLevelPortal.HandleMatrixPowerLevels(user, evt) - } -} - -func (mx *MatrixHandler) HandleJoinRule(_ context.Context, evt *event.Event) { - if mx.shouldIgnoreEvent(evt) { - return - } - portal := mx.bridge.Child.GetIPortal(evt.RoomID) - if portal == nil { - return - } - joinRulePortal, ok := portal.(JoinRuleHandlingPortal) - if ok { - user := mx.bridge.Child.GetIUser(evt.Sender, true) - joinRulePortal.HandleMatrixJoinRule(user, evt) - } -} diff --git a/bridge/messagecheckpoint.go b/bridge/messagecheckpoint.go deleted file mode 100644 index a95d2160..00000000 --- a/bridge/messagecheckpoint.go +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) 2021 Sumner Evans -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package bridge - -import ( - "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/status" - "maunium.net/go/mautrix/event" -) - -func (br *Bridge) SendMessageSuccessCheckpoint(evt *event.Event, step status.MessageCheckpointStep, retryNum int) { - br.SendMessageCheckpoint(evt, step, nil, status.MsgStatusSuccess, retryNum) -} - -func (br *Bridge) SendMessageErrorCheckpoint(evt *event.Event, step status.MessageCheckpointStep, err error, permanent bool, retryNum int) { - s := status.MsgStatusWillRetry - if permanent { - s = status.MsgStatusPermFailure - } - br.SendMessageCheckpoint(evt, step, err, s, retryNum) -} - -func (br *Bridge) SendMessageCheckpoint(evt *event.Event, step status.MessageCheckpointStep, err error, s status.MessageCheckpointStatus, retryNum int) { - checkpoint := status.NewMessageCheckpoint(evt, step, s, retryNum) - if err != nil { - checkpoint.Info = err.Error() - } - go br.SendRawMessageCheckpoint(checkpoint) -} - -func (br *Bridge) SendRawMessageCheckpoint(cp *status.MessageCheckpoint) { - err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{cp}) - if err != nil { - br.ZLog.Warn().Err(err).Interface("message_checkpoint", cp).Msg("Error sending message checkpoint") - } else { - br.ZLog.Debug().Interface("message_checkpoint", cp).Msg("Sent message checkpoint") - } -} - -func (br *Bridge) SendMessageCheckpoints(checkpoints []*status.MessageCheckpoint) error { - checkpointsJSON := status.CheckpointsJSON{Checkpoints: checkpoints} - - if br.Websocket { - return br.AS.SendWebsocket(&appservice.WebsocketRequest{ - Command: "message_checkpoint", - Data: checkpointsJSON, - }) - } - - endpoint := br.Config.Homeserver.MessageSendCheckpointEndpoint - if endpoint == "" { - return nil - } - - return checkpointsJSON.SendHTTP(endpoint, br.AS.Registration.AppToken) -} diff --git a/bridge/no-crypto.go b/bridge/no-crypto.go deleted file mode 100644 index 019ab7c1..00000000 --- a/bridge/no-crypto.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2023 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//go:build !cgo || nocrypto - -package bridge - -import ( - "errors" -) - -func NewCryptoHelper(bridge *Bridge) Crypto { - if bridge.Config.Bridge.GetEncryptionConfig().Allow { - bridge.ZLog.Warn().Msg("Bridge built without end-to-bridge encryption, but encryption is enabled in config") - } else { - bridge.ZLog.Debug().Msg("Bridge built without end-to-bridge encryption") - } - return nil -} - -var NoSessionFound = errors.New("nil") -var UnknownMessageIndex = NoSessionFound -var DuplicateMessageIndex = NoSessionFound diff --git a/bridge/status/deprecated.go b/bridge/status/deprecated.go new file mode 100644 index 00000000..1b3f24a4 --- /dev/null +++ b/bridge/status/deprecated.go @@ -0,0 +1,83 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Deprecated: use bridgev2/status +package status + +import ( + "maunium.net/go/mautrix/bridgev2/status" +) + +// Deprecated: use bridgev2/status +type ( + BridgeStateEvent = status.BridgeStateEvent + BridgeStateErrorCode = status.BridgeStateErrorCode + BridgeStateErrorMap = status.BridgeStateErrorMap + BridgeState = status.BridgeState + RemoteProfile = status.RemoteProfile + GlobalBridgeState = status.GlobalBridgeState + BridgeStateFiller = status.BridgeStateFiller + StandaloneCustomBridgeStateFiller = status.StandaloneCustomBridgeStateFiller + CustomBridgeStateFiller = status.CustomBridgeStateFiller + MessageCheckpointStep = status.MessageCheckpointStep + MessageCheckpointStatus = status.MessageCheckpointStatus + MessageCheckpointReportedBy = status.MessageCheckpointReportedBy + MessageCheckpoint = status.MessageCheckpoint + CheckpointsJSON = status.CheckpointsJSON + LocalBridgeAccountState = status.LocalBridgeAccountState + LocalBridgeDeviceState = status.LocalBridgeDeviceState +) + +// Deprecated: use bridgev2/status +const ( + StateStarting = status.StateStarting + StateUnconfigured = status.StateUnconfigured + StateRunning = status.StateRunning + StateBridgeUnreachable = status.StateBridgeUnreachable + + StateConnecting = status.StateConnecting + StateBackfilling = status.StateBackfilling + StateConnected = status.StateConnected + StateTransientDisconnect = status.StateTransientDisconnect + StateBadCredentials = status.StateBadCredentials + StateUnknownError = status.StateUnknownError + StateLoggedOut = status.StateLoggedOut + + MsgStepClient = status.MsgStepClient + MsgStepHomeserver = status.MsgStepHomeserver + MsgStepBridge = status.MsgStepBridge + MsgStepDecrypted = status.MsgStepDecrypted + MsgStepRemote = status.MsgStepRemote + MsgStepCommand = status.MsgStepCommand + + MsgStatusSuccess = status.MsgStatusSuccess + MsgStatusWillRetry = status.MsgStatusWillRetry + MsgStatusPermFailure = status.MsgStatusPermFailure + MsgStatusUnsupported = status.MsgStatusUnsupported + MsgStatusTimeout = status.MsgStatusTimeout + MsgStatusDelivered = status.MsgStatusDelivered + MsgStatusDeliveryFailed = status.MsgStatusDeliveryFailed + + MsgReportedByAsmux = status.MsgReportedByAsmux + MsgReportedByBridge = status.MsgReportedByBridge + MsgReportedByHungry = status.MsgReportedByHungry + + LocalBridgeAccountStateSetup = status.LocalBridgeAccountStateSetup + LocalBridgeAccountStateDeleted = status.LocalBridgeAccountStateDeleted + + LocalBridgeDeviceStateSetup = status.LocalBridgeDeviceStateSetup + LocalBridgeDeviceStateLoggedOut = status.LocalBridgeDeviceStateLoggedOut + LocalBridgeDeviceStateError = status.LocalBridgeDeviceStateError + LocalBridgeDeviceStateDeleted = status.LocalBridgeDeviceStateDeleted +) + +// Deprecated: use bridgev2/status +var ( + CheckpointTypes = status.CheckpointTypes + NewMessageCheckpoint = status.NewMessageCheckpoint + ReasonToCheckpointStatus = status.ReasonToCheckpointStatus + BridgeStateHumanErrors = status.BridgeStateHumanErrors +) diff --git a/bridge/websocket.go b/bridge/websocket.go deleted file mode 100644 index 44a3d8d8..00000000 --- a/bridge/websocket.go +++ /dev/null @@ -1,163 +0,0 @@ -package bridge - -import ( - "context" - "errors" - "fmt" - "sync" - "time" - - "go.mau.fi/util/jsontime" - - "maunium.net/go/mautrix/appservice" -) - -const defaultReconnectBackoff = 2 * time.Second -const maxReconnectBackoff = 2 * time.Minute -const reconnectBackoffReset = 5 * time.Minute - -func (br *Bridge) startWebsocket(wg *sync.WaitGroup) { - log := br.ZLog.With().Str("action", "appservice websocket").Logger() - var wgOnce sync.Once - onConnect := func() { - wssBr, ok := br.Child.(WebsocketStartingBridge) - if ok { - wssBr.OnWebsocketConnect() - } - if br.latestState != nil { - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - br.latestState.Timestamp = jsontime.UnixNow() - err := br.SendBridgeState(ctx, br.latestState) - if err != nil { - log.Err(err).Msg("Failed to resend latest bridge state after websocket reconnect") - } else { - log.Debug().Any("bridge_state", br.latestState).Msg("Resent bridge state after websocket reconnect") - } - }() - } - wgOnce.Do(wg.Done) - select { - case br.wsStarted <- struct{}{}: - default: - } - } - reconnectBackoff := defaultReconnectBackoff - lastDisconnect := time.Now().UnixNano() - br.wsStopped = make(chan struct{}) - defer func() { - log.Debug().Msg("Appservice websocket loop finished") - close(br.wsStopped) - }() - addr := br.Config.Homeserver.WSProxy - if addr == "" { - addr = br.Config.Homeserver.Address - } - for { - err := br.AS.StartWebsocket(addr, onConnect) - if errors.Is(err, appservice.ErrWebsocketManualStop) { - return - } else if closeCommand := (&appservice.CloseCommand{}); errors.As(err, &closeCommand) && closeCommand.Status == appservice.MeowConnectionReplaced { - log.Info().Msg("Appservice websocket closed by another instance of the bridge, shutting down...") - br.ManualStop(0) - return - } else if err != nil { - log.Err(err).Msg("Error in appservice websocket") - } - if br.Stopping { - return - } - now := time.Now().UnixNano() - if lastDisconnect+reconnectBackoffReset.Nanoseconds() < now { - reconnectBackoff = defaultReconnectBackoff - } else { - reconnectBackoff *= 2 - if reconnectBackoff > maxReconnectBackoff { - reconnectBackoff = maxReconnectBackoff - } - } - lastDisconnect = now - log.Info(). - Int("backoff_seconds", int(reconnectBackoff.Seconds())). - Msg("Websocket disconnected, reconnecting...") - select { - case <-br.wsShortCircuitReconnectBackoff: - log.Debug().Msg("Reconnect backoff was short-circuited") - case <-time.After(reconnectBackoff): - } - if br.Stopping { - return - } - } -} - -type wsPingData struct { - Timestamp int64 `json:"timestamp"` -} - -func (br *Bridge) PingServer() (start, serverTs, end time.Time) { - if !br.Websocket { - panic(fmt.Errorf("PingServer called without websocket enabled")) - } - if !br.AS.HasWebsocket() { - br.ZLog.Debug().Msg("Received server ping request, but no websocket connected. Trying to short-circuit backoff sleep") - select { - case br.wsShortCircuitReconnectBackoff <- struct{}{}: - default: - br.ZLog.Warn().Msg("Failed to ping websocket: not connected and no backoff?") - return - } - select { - case <-br.wsStarted: - case <-time.After(15 * time.Second): - if !br.AS.HasWebsocket() { - br.ZLog.Warn().Msg("Failed to ping websocket: didn't connect after 15 seconds of waiting") - return - } - } - } - start = time.Now() - var resp wsPingData - br.ZLog.Debug().Msg("Pinging appservice websocket") - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - err := br.AS.RequestWebsocket(ctx, &appservice.WebsocketRequest{ - Command: "ping", - Data: &wsPingData{Timestamp: start.UnixMilli()}, - }, &resp) - end = time.Now() - if err != nil { - br.ZLog.Warn().Err(err).Dur("duration", end.Sub(start)).Msg("Websocket ping returned error") - br.AS.StopWebsocket(fmt.Errorf("websocket ping returned error in %s: %w", end.Sub(start), err)) - } else { - serverTs = time.Unix(0, resp.Timestamp*int64(time.Millisecond)) - br.ZLog.Debug(). - Dur("duration", end.Sub(start)). - Dur("req_duration", serverTs.Sub(start)). - Dur("resp_duration", end.Sub(serverTs)). - Msg("Websocket ping returned success") - } - return -} - -func (br *Bridge) websocketServerPinger() { - interval := time.Duration(br.Config.Homeserver.WSPingInterval) * time.Second - clock := time.NewTicker(interval) - defer func() { - br.ZLog.Info().Msg("Stopping websocket pinger") - clock.Stop() - }() - br.ZLog.Info().Dur("interval_duration", interval).Msg("Starting websocket pinger") - for { - select { - case <-clock.C: - br.PingServer() - case <-br.wsStopPinger: - return - } - if br.Stopping { - return - } - } -} diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 309f48ed..bdc8480d 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -16,10 +16,10 @@ import ( "go.mau.fi/util/dbutil" "go.mau.fi/util/exsync" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/id" ) diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index 1cd6b0c5..7863dcba 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -13,7 +13,7 @@ import ( "github.com/rs/zerolog" - "maunium.net/go/mautrix/bridge/status" + "maunium.net/go/mautrix/bridgev2/status" ) type BridgeStateQueue struct { diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index 5c7ae57d..3544998c 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -19,9 +19,9 @@ import ( "github.com/skip2/go-qrcode" "go.mau.fi/util/curl" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index 610fb825..c28e3a32 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -17,8 +17,7 @@ import ( "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2" - - "maunium.net/go/mautrix/bridge/status" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go index 610e7d60..9fa6569a 100644 --- a/bridgev2/database/userlogin.go +++ b/bridgev2/database/userlogin.go @@ -12,8 +12,8 @@ import ( "go.mau.fi/util/dbutil" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/id" ) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 0bb1ee61..4a930135 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -32,12 +32,12 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/bridgev2/commands" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/mediaproxy" diff --git a/bridgev2/matrix/cryptoerror.go b/bridgev2/matrix/cryptoerror.go index 55110429..ea29703a 100644 --- a/bridgev2/matrix/cryptoerror.go +++ b/bridgev2/matrix/cryptoerror.go @@ -11,8 +11,8 @@ import ( "errors" "fmt" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index 1117fca2..84e85d24 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -17,8 +17,8 @@ import ( "go.mau.fi/util/jsontime" "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 87f6576d..e9ef4b41 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -25,9 +25,9 @@ import ( "go.mau.fi/util/requestlog" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/federation" "maunium.net/go/mautrix/id" ) diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 2d6ed982..2665956c 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -16,9 +16,9 @@ import ( "github.com/gorilla/mux" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index b29eac0b..7118649d 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -13,7 +13,7 @@ import ( "go.mau.fi/util/jsontime" "maunium.net/go/mautrix/appservice" - "maunium.net/go/mautrix/bridge/status" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) diff --git a/bridge/status/bridgestate.go b/bridgev2/status/bridgestate.go similarity index 100% rename from bridge/status/bridgestate.go rename to bridgev2/status/bridgestate.go diff --git a/bridge/status/localbridgestate.go b/bridgev2/status/localbridgestate.go similarity index 100% rename from bridge/status/localbridgestate.go rename to bridgev2/status/localbridgestate.go diff --git a/bridge/status/messagecheckpoint.go b/bridgev2/status/messagecheckpoint.go similarity index 100% rename from bridge/status/messagecheckpoint.go rename to bridgev2/status/messagecheckpoint.go diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index b9c7288a..d07acce5 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -17,10 +17,10 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/exsync" - "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" ) From c0d1df18b42343e8265240febcad866cf81faec8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 16 Mar 2025 15:05:55 +0200 Subject: [PATCH 1067/1647] appservice/http: use constant time comparisons for access tokens --- appservice/http.go | 3 ++- bridgev2/matrix/provisioning.go | 5 +++-- go.mod | 2 +- go.sum | 4 ++-- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/appservice/http.go b/appservice/http.go index 661513b4..66c7bc5b 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -19,6 +19,7 @@ import ( "github.com/gorilla/mux" "github.com/rs/zerolog" + "go.mau.fi/util/exstrings" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" @@ -83,7 +84,7 @@ func (as *AppService) CheckServerToken(w http.ResponseWriter, r *http.Request) ( HTTPStatus: http.StatusForbidden, Message: "Missing access token", }.Write(w) - } else if authHeader[len("Bearer "):] != as.Registration.ServerToken { + } else if !exstrings.ConstantTimeEqual(authHeader[len("Bearer "):], as.Registration.ServerToken) { Error{ ErrorCode: ErrUnknownToken, HTTPStatus: http.StatusForbidden, diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index e9ef4b41..126d54de 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -21,6 +21,7 @@ import ( "github.com/rs/xid" "github.com/rs/zerolog" "github.com/rs/zerolog/hlog" + "go.mau.fi/util/exstrings" "go.mau.fi/util/jsontime" "go.mau.fi/util/requestlog" @@ -207,7 +208,7 @@ func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler { Err: "Missing auth token", ErrCode: mautrix.MMissingToken.ErrCode, }) - } else if auth != prov.br.Config.Provisioning.SharedSecret { + } else if !exstrings.ConstantTimeEqual(auth, prov.br.Config.Provisioning.SharedSecret) { jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ Err: "Invalid auth token", ErrCode: mautrix.MUnknownToken.ErrCode, @@ -235,7 +236,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { if userID == "" && prov.GetUserIDFromRequest != nil { userID = prov.GetUserIDFromRequest(r) } - if auth != prov.br.Config.Provisioning.SharedSecret { + if !exstrings.ConstantTimeEqual(auth, prov.br.Config.Provisioning.SharedSecret) { var err error if strings.HasPrefix(auth, "openid:") { err = prov.checkFederatedMatrixAuth(r.Context(), userID, strings.TrimPrefix(auth, "openid:")) diff --git a/go.mod b/go.mod index cc11719a..ab20f9b1 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.6-0.20250313222444-739a30158a62 + go.mau.fi/util v0.8.6-0.20250316130503-05facedd4121 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.36.0 golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 diff --git a/go.sum b/go.sum index 2fcb8086..a9e283fb 100644 --- a/go.sum +++ b/go.sum @@ -54,8 +54,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.6-0.20250313222444-739a30158a62 h1:8EjBMxX7QkT94/815jKIVK5k41ku+ES3SxSk8DyQRk4= -go.mau.fi/util v0.8.6-0.20250313222444-739a30158a62/go.mod h1:uNB3UTXFbkpp7xL1M/WvQks90B/L4gvbLpbS0603KOE= +go.mau.fi/util v0.8.6-0.20250316130503-05facedd4121 h1:d7KUA46BWjtyEwJjVSvtZnQln+lR3+cdvzw4z2nCyhM= +go.mau.fi/util v0.8.6-0.20250316130503-05facedd4121/go.mod h1:uNB3UTXFbkpp7xL1M/WvQks90B/L4gvbLpbS0603KOE= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= From 11f93740038242b4127aa1789f82459fb4eb9fb2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 16 Mar 2025 16:18:59 +0200 Subject: [PATCH 1068/1647] Bump version to v0.23.2 --- CHANGELOG.md | 33 +++++++++++++++++++++++++++++++++ go.mod | 2 +- go.sum | 4 ++-- version.go | 2 +- 4 files changed, 37 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 85d362ee..3fc91bb7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,36 @@ +## v0.23.2 (2025-03-16) + +* **Breaking change *(bridge)*** Removed legacy bridge module. +* **Breaking change *(event)*** Changed `m.federate` field in room create event + content to a pointer to allow detecting omitted values. +* *(bridgev2/commands)* Added `set-management-room` command to set a new + management room. +* *(bridgev2/portal)* Changed edit bridging to ignore remote edits if the + original sender on Matrix can't be puppeted. +* *(bridgv2)* Added config option to disable bridging `m.notice` messages. +* *(appservice/http)* Switched access token validation to use constant time + comparisons. +* *(event)* Added support for [MSC3765] rich text topics. +* *(event)* Added fields to policy list event contents for [MSC4204] and + [MSC4205]. +* *(client)* Added method for getting the content of a redacted event using + [MSC2815]. +* *(client)* Added methods for sending and updating [MSC4140] delayed events. +* *(client)* Added support for [MSC4222] in sync payloads. +* *(crypto/cryptohelper)* Switched to using `sqlite3-fk-wal` instead of plain + `sqlite3` by default. +* *(crypto/encryptolm)* Added generic method for encrypting to-device events. +* *(crypto/ssss)* Fixed panic if server-side key metadata is corrupted. +* *(crypto/sqlstore)* Fixed error when marking over 32 thousand device lists + as outdated on SQLite. + +[MSC2815]: https://github.com/matrix-org/matrix-spec-proposals/pull/2815 +[MSC3765]: https://github.com/matrix-org/matrix-spec-proposals/pull/3765 +[MSC4140]: https://github.com/matrix-org/matrix-spec-proposals/pull/4140 +[MSC4204]: https://github.com/matrix-org/matrix-spec-proposals/pull/4204 +[MSC4205]: https://github.com/matrix-org/matrix-spec-proposals/pull/4205 +[MSC4222]: https://github.com/matrix-org/matrix-spec-proposals/pull/4222 + ## v0.23.1 (2025-02-16) * *(client)* Added `FullStateEvent` method to get a state event including diff --git a/go.mod b/go.mod index ab20f9b1..cea3c580 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.8 - go.mau.fi/util v0.8.6-0.20250316130503-05facedd4121 + go.mau.fi/util v0.8.6 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.36.0 golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 diff --git a/go.sum b/go.sum index a9e283fb..eab4a1b8 100644 --- a/go.sum +++ b/go.sum @@ -54,8 +54,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.8.6-0.20250316130503-05facedd4121 h1:d7KUA46BWjtyEwJjVSvtZnQln+lR3+cdvzw4z2nCyhM= -go.mau.fi/util v0.8.6-0.20250316130503-05facedd4121/go.mod h1:uNB3UTXFbkpp7xL1M/WvQks90B/L4gvbLpbS0603KOE= +go.mau.fi/util v0.8.6 h1:AEK13rfgtiZJL2YsNK+W4ihhYCuukcRom8WPP/w/L54= +go.mau.fi/util v0.8.6/go.mod h1:uNB3UTXFbkpp7xL1M/WvQks90B/L4gvbLpbS0603KOE= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= diff --git a/version.go b/version.go index 2ff08518..dba1a59e 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.23.1" +const Version = "v0.23.2" var GoModVersion = "" var Commit = "" From 1c2898870cd7644046d8367e8b50470e6d59dc9c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 18 Mar 2025 14:23:37 +0200 Subject: [PATCH 1069/1647] bridge: remove fallback status package --- bridge/status/deprecated.go | 83 ------------------------------------- 1 file changed, 83 deletions(-) delete mode 100644 bridge/status/deprecated.go diff --git a/bridge/status/deprecated.go b/bridge/status/deprecated.go deleted file mode 100644 index 1b3f24a4..00000000 --- a/bridge/status/deprecated.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -// Deprecated: use bridgev2/status -package status - -import ( - "maunium.net/go/mautrix/bridgev2/status" -) - -// Deprecated: use bridgev2/status -type ( - BridgeStateEvent = status.BridgeStateEvent - BridgeStateErrorCode = status.BridgeStateErrorCode - BridgeStateErrorMap = status.BridgeStateErrorMap - BridgeState = status.BridgeState - RemoteProfile = status.RemoteProfile - GlobalBridgeState = status.GlobalBridgeState - BridgeStateFiller = status.BridgeStateFiller - StandaloneCustomBridgeStateFiller = status.StandaloneCustomBridgeStateFiller - CustomBridgeStateFiller = status.CustomBridgeStateFiller - MessageCheckpointStep = status.MessageCheckpointStep - MessageCheckpointStatus = status.MessageCheckpointStatus - MessageCheckpointReportedBy = status.MessageCheckpointReportedBy - MessageCheckpoint = status.MessageCheckpoint - CheckpointsJSON = status.CheckpointsJSON - LocalBridgeAccountState = status.LocalBridgeAccountState - LocalBridgeDeviceState = status.LocalBridgeDeviceState -) - -// Deprecated: use bridgev2/status -const ( - StateStarting = status.StateStarting - StateUnconfigured = status.StateUnconfigured - StateRunning = status.StateRunning - StateBridgeUnreachable = status.StateBridgeUnreachable - - StateConnecting = status.StateConnecting - StateBackfilling = status.StateBackfilling - StateConnected = status.StateConnected - StateTransientDisconnect = status.StateTransientDisconnect - StateBadCredentials = status.StateBadCredentials - StateUnknownError = status.StateUnknownError - StateLoggedOut = status.StateLoggedOut - - MsgStepClient = status.MsgStepClient - MsgStepHomeserver = status.MsgStepHomeserver - MsgStepBridge = status.MsgStepBridge - MsgStepDecrypted = status.MsgStepDecrypted - MsgStepRemote = status.MsgStepRemote - MsgStepCommand = status.MsgStepCommand - - MsgStatusSuccess = status.MsgStatusSuccess - MsgStatusWillRetry = status.MsgStatusWillRetry - MsgStatusPermFailure = status.MsgStatusPermFailure - MsgStatusUnsupported = status.MsgStatusUnsupported - MsgStatusTimeout = status.MsgStatusTimeout - MsgStatusDelivered = status.MsgStatusDelivered - MsgStatusDeliveryFailed = status.MsgStatusDeliveryFailed - - MsgReportedByAsmux = status.MsgReportedByAsmux - MsgReportedByBridge = status.MsgReportedByBridge - MsgReportedByHungry = status.MsgReportedByHungry - - LocalBridgeAccountStateSetup = status.LocalBridgeAccountStateSetup - LocalBridgeAccountStateDeleted = status.LocalBridgeAccountStateDeleted - - LocalBridgeDeviceStateSetup = status.LocalBridgeDeviceStateSetup - LocalBridgeDeviceStateLoggedOut = status.LocalBridgeDeviceStateLoggedOut - LocalBridgeDeviceStateError = status.LocalBridgeDeviceStateError - LocalBridgeDeviceStateDeleted = status.LocalBridgeDeviceStateDeleted -) - -// Deprecated: use bridgev2/status -var ( - CheckpointTypes = status.CheckpointTypes - NewMessageCheckpoint = status.NewMessageCheckpoint - ReasonToCheckpointStatus = status.ReasonToCheckpointStatus - BridgeStateHumanErrors = status.BridgeStateHumanErrors -) From 03618fcc891f17f1ce179cf7a08350f24d448dbf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 20 Mar 2025 13:40:08 +0200 Subject: [PATCH 1070/1647] bridgev2: add support for timeouts on pending messages --- bridgev2/errors.go | 3 ++ bridgev2/networkinterface.go | 14 +++++++ bridgev2/portal.go | 72 +++++++++++++++++++++++++++++++----- 3 files changed, 79 insertions(+), 10 deletions(-) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index 0e948184..c023dcdf 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -12,6 +12,7 @@ import ( "net/http" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" ) // ErrIgnoringRemoteEvent can be returned by [RemoteMessage.ConvertMessage] or [RemoteEdit.ConvertEdit] @@ -64,6 +65,8 @@ var ( ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) + ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) + ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) ) // Common login interface errors diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 487f1ea6..29ee9fc9 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -302,6 +302,14 @@ type MatrixMessageResponse struct { PostSave func(context.Context, *database.Message) } +type OutgoingTimeoutConfig struct { + CheckInterval time.Duration + NoEchoTimeout time.Duration + NoEchoMessage string + NoAckTimeout time.Duration + NoAckMessage string +} + type NetworkGeneralCapabilities struct { // Does the network connector support disappearing messages? // This flag enables the message disappearing loop in the bridge. @@ -309,6 +317,10 @@ type NetworkGeneralCapabilities struct { // Should the bridge re-request user info on incoming messages even if the ghost already has info? // By default, info is only requested for ghosts with no name, and other updating is left to events. AggressiveUpdateInfo bool + // If the bridge uses the pending message mechanism ([MatrixMessage.AddPendingToSave]) + // to handle asynchronous message responses, this field can be set to enable + // automatic timeout errors in case the asynchronous response never arrives. + OutgoingMessageTimeouts *OutgoingTimeoutConfig } // NetworkAPI is an interface representing a remote network client for a single user login. @@ -1145,6 +1157,8 @@ type MatrixMessage struct { MatrixEventBase[*event.MessageEventContent] ThreadRoot *database.Message ReplyTo *database.Message + + pendingSaves []*outgoingMessage } type MatrixEdit struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 409a9c10..5e4328e7 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -59,10 +59,12 @@ type portalEvent interface { } type outgoingMessage struct { - db *database.Message - evt *event.Event - ignore bool - handle func(RemoteMessage, *database.Message) (bool, error) + db *database.Message + evt *event.Event + ignore bool + handle func(RemoteMessage, *database.Message) (bool, error) + ackedAt time.Time + timeouted bool } type Portal struct { @@ -76,7 +78,7 @@ type Portal struct { currentlyTypingLogins map[id.UserID]*UserLogin currentlyTypingLock sync.Mutex - outgoingMessages map[networkid.TransactionID]outgoingMessage + outgoingMessages map[networkid.TransactionID]*outgoingMessage outgoingMessagesLock sync.Mutex lastCapUpdate time.Time @@ -113,7 +115,7 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que Bridge: br, currentlyTypingLogins: make(map[id.UserID]*UserLogin), - outgoingMessages: make(map[networkid.TransactionID]outgoingMessage), + outgoingMessages: make(map[networkid.TransactionID]*outgoingMessage), } br.portalsByKey[portal.PortalKey] = portal if portal.MXID != "" { @@ -296,6 +298,11 @@ func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) { } func (portal *Portal) eventLoop() { + if cfg := portal.Bridge.Network.GetCapabilities().OutgoingMessageTimeouts; cfg != nil { + ctx, cancel := context.WithCancel(portal.Log.WithContext(context.Background())) + go portal.pendingMessageTimeoutLoop(ctx, cfg) + defer cancel() + } i := 0 for rawEvt := range portal.events { i++ @@ -957,7 +964,11 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin return } message := wrappedMsgEvt.fillDBMessage(resp.DB) - if !resp.Pending { + if resp.Pending { + for _, save := range wrappedMsgEvt.pendingSaves { + save.ackedAt = time.Now() + } + } else { if resp.DB == nil { log.Error().Msg("Network connector didn't return a message to save") } else { @@ -1003,7 +1014,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin // See also: [MatrixMessage.AddPendingToSave] func (evt *MatrixMessage) AddPendingToIgnore(txnID networkid.TransactionID) { evt.Portal.outgoingMessagesLock.Lock() - evt.Portal.outgoingMessages[txnID] = outgoingMessage{ + evt.Portal.outgoingMessages[txnID] = &outgoingMessage{ ignore: true, } evt.Portal.outgoingMessagesLock.Unlock() @@ -1017,12 +1028,14 @@ func (evt *MatrixMessage) AddPendingToIgnore(txnID networkid.TransactionID) { // // The provided function will be called when the message is encountered. func (evt *MatrixMessage) AddPendingToSave(message *database.Message, txnID networkid.TransactionID, handleEcho RemoteEchoHandler) { - evt.Portal.outgoingMessagesLock.Lock() - evt.Portal.outgoingMessages[txnID] = outgoingMessage{ + pending := &outgoingMessage{ db: evt.fillDBMessage(message), evt: evt.Event, handle: handleEcho, } + evt.Portal.outgoingMessagesLock.Lock() + evt.Portal.outgoingMessages[txnID] = pending + evt.pendingSaves = append(evt.pendingSaves, pending) evt.Portal.outgoingMessagesLock.Unlock() } @@ -1030,6 +1043,12 @@ func (evt *MatrixMessage) AddPendingToSave(message *database.Message, txnID netw // This should only be called if sending the message fails. func (evt *MatrixMessage) RemovePending(txnID networkid.TransactionID) { evt.Portal.outgoingMessagesLock.Lock() + pendingSave := evt.Portal.outgoingMessages[txnID] + if pendingSave != nil { + evt.pendingSaves = slices.DeleteFunc(evt.pendingSaves, func(save *outgoingMessage) bool { + return save == pendingSave + }) + } delete(evt.Portal.outgoingMessages, txnID) evt.Portal.outgoingMessagesLock.Unlock() } @@ -1063,6 +1082,35 @@ func (evt *MatrixMessage) fillDBMessage(message *database.Message) *database.Mes return message } +func (portal *Portal) pendingMessageTimeoutLoop(ctx context.Context, cfg *OutgoingTimeoutConfig) { + ticker := time.NewTicker(cfg.CheckInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + portal.checkPendingMessages(ctx, cfg) + case <-ctx.Done(): + return + } + } +} + +func (portal *Portal) checkPendingMessages(ctx context.Context, cfg *OutgoingTimeoutConfig) { + portal.outgoingMessagesLock.Lock() + defer portal.outgoingMessagesLock.Unlock() + for _, msg := range portal.outgoingMessages { + if msg.evt != nil && !msg.timeouted { + if cfg.NoEchoTimeout > 0 && !msg.ackedAt.IsZero() && time.Since(msg.ackedAt) > cfg.NoEchoTimeout { + msg.timeouted = true + portal.sendErrorStatus(ctx, msg.evt, ErrRemoteEchoTimeout.WithMessage(cfg.NoEchoMessage)) + } else if cfg.NoAckTimeout > 0 && time.Since(msg.db.Timestamp) > cfg.NoAckTimeout { + msg.timeouted = true + portal.sendErrorStatus(ctx, msg.evt, ErrRemoteAckTimeout.WithMessage(cfg.NoAckMessage)) + } + } + } +} + func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *event.RoomFeatures) { log := zerolog.Ctx(ctx) editTargetID := content.RelatesTo.GetReplaceID() @@ -3881,6 +3929,10 @@ func (portal *Portal) unlockedDeleteCache() { if portal.MXID != "" { delete(portal.Bridge.portalsByMXID, portal.MXID) } + if portal.events != nil { + // TODO there's a small risk of this racing with a queueEvent call + close(portal.events) + } } func (portal *Portal) Save(ctx context.Context) error { From 4a0aed30e85030bb39650482c15ba971f5b0533e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 20 Mar 2025 15:34:40 +0200 Subject: [PATCH 1071/1647] brdigev2/bridgestate: add option to send status updates to management room --- bridgev2/bridgeconfig/config.go | 1 + bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/bridgestate.go | 78 +++++++++++++++++----- bridgev2/matrix/mxmain/example-config.yaml | 4 ++ bridgev2/status/bridgestate.go | 26 ++------ 5 files changed, 72 insertions(+), 38 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 156fb772..937d9441 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -65,6 +65,7 @@ type BridgeConfig struct { SplitPortals bool `yaml:"split_portals"` ResendBridgeInfo bool `yaml:"resend_bridge_info"` NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"` + BridgeStatusNotices string `yaml:"bridge_status_notices"` BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` BridgeNotices bool `yaml:"bridge_notices"` TagOnlyOnCreate bool `yaml:"tag_only_on_create"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 07477ef1..95370681 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -31,6 +31,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "split_portals") helper.Copy(up.Bool, "bridge", "resend_bridge_info") helper.Copy(up.Bool, "bridge", "no_bridge_info_state_key") + helper.Copy(up.Str|up.Null, "bridge", "bridge_status_notices") helper.Copy(up.Bool, "bridge", "bridge_matrix_leave") helper.Copy(up.Bool, "bridge", "bridge_notices") helper.Copy(up.Bool, "bridge", "tag_only_on_create") diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index 7863dcba..61f988ad 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -8,20 +8,24 @@ package bridgev2 import ( "context" + "fmt" "runtime/debug" "time" "github.com/rs/zerolog" "maunium.net/go/mautrix/bridgev2/status" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" ) type BridgeStateQueue struct { prevUnsent *status.BridgeState prevSent *status.BridgeState + errorSent bool ch chan status.BridgeState bridge *Bridge - user status.StandaloneCustomBridgeStateFiller + login *UserLogin } func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { @@ -41,11 +45,11 @@ func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { } } -func (br *Bridge) NewBridgeStateQueue(user status.StandaloneCustomBridgeStateFiller) *BridgeStateQueue { +func (br *Bridge) NewBridgeStateQueue(login *UserLogin) *BridgeStateQueue { bsq := &BridgeStateQueue{ ch: make(chan status.BridgeState, 10), bridge: br, - user: user, + login: login, } go bsq.loop() return bsq @@ -59,7 +63,7 @@ func (bsq *BridgeStateQueue) loop() { defer func() { err := recover() if err != nil { - bsq.bridge.Log.Error(). + bsq.login.Log.Error(). Bytes(zerolog.ErrorStackFieldName, debug.Stack()). Any(zerolog.ErrorFieldName, err). Msg("Panic in bridge state loop") @@ -70,22 +74,62 @@ func (bsq *BridgeStateQueue) loop() { } } +func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState) { + noticeConfig := bsq.bridge.Config.BridgeStatusNotices + isError := state.StateEvent == status.StateBadCredentials || state.StateEvent == status.StateUnknownError + sendNotice := noticeConfig == "all" || (noticeConfig == "errors" && + (isError || (bsq.errorSent && state.StateEvent == status.StateConnected))) + if !sendNotice { + return + } + managementRoom, err := bsq.login.User.GetManagementRoom(ctx) + if err != nil { + bsq.login.Log.Err(err).Msg("Failed to get management room") + return + } + message := fmt.Sprintf("State update for %s: `%s`", bsq.login.RemoteName, state.StateEvent) + if state.Error != "" { + message += fmt.Sprintf(" (`%s`)", state.Error) + } + if state.Message != "" { + message += fmt.Sprintf(": %s", state.Message) + } + content := format.RenderMarkdown(message, true, false) + if !isError { + content.MsgType = event.MsgNotice + } + _, err = bsq.bridge.Bot.SendMessage(ctx, managementRoom, event.EventMessage, &event.Content{ + Parsed: content, + Raw: map[string]any{ + "fi.mau.bridge_state": state, + }, + }, nil) + if err != nil { + bsq.login.Log.Err(err).Msg("Failed to send bridge state notice") + } else { + bsq.errorSent = isError + } +} + func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) { + if bsq.prevSent != nil && bsq.prevSent.ShouldDeduplicate(&state) { + bsq.login.Log.Debug(). + Str("state_event", string(state.StateEvent)). + Msg("Not sending bridge state as it's a duplicate") + return + } + + ctx := bsq.login.Log.WithContext(context.Background()) + bsq.sendNotice(ctx, state) + retryIn := 2 for { - if bsq.prevSent != nil && bsq.prevSent.ShouldDeduplicate(&state) { - bsq.bridge.Log.Debug(). - Str("state_event", string(state.StateEvent)). - Msg("Not sending bridge state as it's a duplicate") - return - } - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) err := bsq.bridge.Matrix.SendBridgeStatus(ctx, &state) cancel() if err != nil { - bsq.bridge.Log.Warn().Err(err). + bsq.login.Log.Warn().Err(err). Int("retry_in_seconds", retryIn). Msg("Failed to update bridge state") time.Sleep(time.Duration(retryIn) * time.Second) @@ -95,7 +139,7 @@ func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) } } else { bsq.prevSent = &state - bsq.bridge.Log.Debug(). + bsq.login.Log.Debug(). Any("bridge_state", state). Msg("Sent new bridge state") return @@ -108,11 +152,11 @@ func (bsq *BridgeStateQueue) Send(state status.BridgeState) { return } - state = state.Fill(bsq.user) + state = state.Fill(bsq.login) bsq.prevUnsent = &state if len(bsq.ch) >= 8 { - bsq.bridge.Log.Warn().Msg("Bridge state queue is nearly full, discarding an item") + bsq.login.Log.Warn().Msg("Bridge state queue is nearly full, discarding an item") select { case <-bsq.ch: default: @@ -121,7 +165,7 @@ func (bsq *BridgeStateQueue) Send(state status.BridgeState) { select { case bsq.ch <- state: default: - bsq.bridge.Log.Error().Msg("Bridge state queue is full, dropped new state") + bsq.login.Log.Error().Msg("Bridge state queue is full, dropped new state") } } diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 86838ff1..1d4e18cf 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -21,6 +21,10 @@ bridge: # Should `m.bridge` events be sent without a state key? # By default, the bridge uses a unique key that won't conflict with other bridges. no_bridge_info_state_key: false + # Should bridge connection status be sent to the management room as `m.notice` events? + # These contain the same data that can be posted to an external HTTP server using homeserver -> status_endpoint. + # Allowed values: none, errors, all + bridge_status_notices: errors # Should leaving Matrix rooms be bridged as leaving groups on the remote network? bridge_matrix_leave: false diff --git a/bridgev2/status/bridgestate.go b/bridgev2/status/bridgestate.go index 73410df6..cb862110 100644 --- a/bridgev2/status/bridgestate.go +++ b/bridgev2/status/bridgestate.go @@ -12,6 +12,7 @@ import ( "encoding/json" "fmt" "io" + "maps" "net/http" "reflect" "time" @@ -19,7 +20,6 @@ import ( "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" "go.mau.fi/util/ptr" - "golang.org/x/exp/maps" "maunium.net/go/mautrix" "maunium.net/go/mautrix/id" @@ -125,31 +125,15 @@ type GlobalBridgeState struct { } type BridgeStateFiller interface { - GetMXID() id.UserID - GetRemoteID() string - GetRemoteName() string -} - -type StandaloneCustomBridgeStateFiller interface { FillBridgeState(BridgeState) BridgeState } -type CustomBridgeStateFiller interface { - BridgeStateFiller - StandaloneCustomBridgeStateFiller -} +// Deprecated: use BridgeStateFiller instead +type StandaloneCustomBridgeStateFiller = BridgeStateFiller -func (pong BridgeState) Fill(user any) BridgeState { +func (pong BridgeState) Fill(user BridgeStateFiller) BridgeState { if user != nil { - if std, ok := user.(BridgeStateFiller); ok { - pong.UserID = std.GetMXID() - pong.RemoteID = std.GetRemoteID() - pong.RemoteName = std.GetRemoteName() - } - - if custom, ok := user.(StandaloneCustomBridgeStateFiller); ok { - pong = custom.FillBridgeState(pong) - } + pong = user.FillBridgeState(pong) } pong.Timestamp = jsontime.UnixNow() From 06f200da0d1099ec9b59c50ceb754380d8a16674 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 20 Mar 2025 15:39:36 +0200 Subject: [PATCH 1072/1647] bridgev2: clear management room on leave. Fixes #355 --- bridgev2/queue.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 38895953..2981bdce 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -119,6 +119,17 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { if evt.Type == event.StateMember && evt.GetStateKey() == br.Bot.GetMXID().String() && evt.Content.AsMember().Membership == event.MembershipInvite && sender != nil { br.handleBotInvite(ctx, evt, sender) return + } else if sender != nil && evt.RoomID == sender.ManagementRoom { + if evt.Type == event.StateMember && evt.Content.AsMember().Membership == event.MembershipLeave && (evt.GetStateKey() == br.Bot.GetMXID().String() || evt.GetStateKey() == sender.MXID.String()) { + sender.ManagementRoom = "" + err := br.DB.User.Update(ctx, sender.User) + if err != nil { + log.Err(err).Msg("Failed to clear user's management room in database") + } else { + log.Debug().Msg("Cleared user's management room due to leave event") + } + } + return } portal, err := br.GetPortalByMXID(ctx, evt.RoomID) if err != nil { From d3ca9472cb13e9b36d66885fd4176a89faea08f1 Mon Sep 17 00:00:00 2001 From: SpiritCroc Date: Mon, 31 Mar 2025 09:53:01 +0200 Subject: [PATCH 1073/1647] event: add Beeper transcription event definitions (#364) --- event/beeper.go | 12 ++++++++++++ event/content.go | 1 + event/relations.go | 9 +++++---- event/type.go | 4 +++- 4 files changed, 21 insertions(+), 5 deletions(-) diff --git a/event/beeper.go b/event/beeper.go index 74b44a09..19c6253e 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -57,6 +57,18 @@ type BeeperMessageStatusEventContent struct { DeliveredToUsers *[]id.UserID `json:"delivered_to_users,omitempty"` } +type BeeperRelatesTo struct { + EventID id.EventID `json:"event_id,omitempty"` + RoomID id.RoomID `json:"room_id,omitempty"` + Type RelationType `json:"rel_type,omitempty"` +} + +type BeeperTranscriptionEventContent struct { + Text []ExtensibleText `json:"m.text,omitempty"` + Model string `json:"com.beeper.transcription.model,omitempty"` + RelatesTo BeeperRelatesTo `json:"com.beeper.relates_to,omitempty"` +} + type BeeperRetryMetadata struct { OriginalEventID id.EventID `json:"original_event_id"` RetryCount int `json:"retry_count"` diff --git a/event/content.go b/event/content.go index b8e130db..2347898e 100644 --- a/event/content.go +++ b/event/content.go @@ -60,6 +60,7 @@ var TypeMap = map[Type]reflect.Type{ EventUnstablePollResponse: reflect.TypeOf(PollResponseEventContent{}), BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}), + BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}), AccountDataRoomTags: reflect.TypeOf(TagEventContent{}), AccountDataDirectChats: reflect.TypeOf(DirectChatsEventContent{}), diff --git a/event/relations.go b/event/relations.go index ea40cc06..30cf6c20 100644 --- a/event/relations.go +++ b/event/relations.go @@ -15,10 +15,11 @@ import ( type RelationType string const ( - RelReplace RelationType = "m.replace" - RelReference RelationType = "m.reference" - RelAnnotation RelationType = "m.annotation" - RelThread RelationType = "m.thread" + RelReplace RelationType = "m.replace" + RelReference RelationType = "m.reference" + RelAnnotation RelationType = "m.annotation" + RelThread RelationType = "m.thread" + RelBeeperTranscription RelationType = "com.beeper.transcription" ) type RelatesTo struct { diff --git a/event/type.go b/event/type.go index 41d7c47b..591d598d 100644 --- a/event/type.go +++ b/event/type.go @@ -126,7 +126,8 @@ func (et *Type) GuessClass() TypeClass { InRoomVerificationStart.Type, InRoomVerificationReady.Type, InRoomVerificationAccept.Type, InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type, CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type, - CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type: + CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type, + BeeperTranscription.Type: return MessageEventType case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type, ToDeviceBeeperRoomKeyAck.Type: @@ -233,6 +234,7 @@ var ( CallHangup = Type{"m.call.hangup", MessageEventType} BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType} + BeeperTranscription = Type{"com.beeper.transcription", MessageEventType} EventUnstablePollStart = Type{Type: "org.matrix.msc3381.poll.start", Class: MessageEventType} EventUnstablePollResponse = Type{Type: "org.matrix.msc3381.poll.response", Class: MessageEventType} From 93b9509135e7533d9ef50fa78c5af468c2419acc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 2 Apr 2025 12:03:07 +0300 Subject: [PATCH 1074/1647] bridgev2/portal: send typing stop after message --- bridgev2/portal.go | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 5e4328e7..d8262d0a 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -20,6 +20,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/exfmt" "go.mau.fi/util/exslices" + "go.mau.fi/util/exsync" "go.mau.fi/util/ptr" "go.mau.fi/util/variationselector" "golang.org/x/exp/maps" @@ -77,6 +78,7 @@ type Portal struct { currentlyTyping []id.UserID currentlyTypingLogins map[id.UserID]*UserLogin currentlyTypingLock sync.Mutex + currentlyTypingGhosts *exsync.Set[id.UserID] outgoingMessages map[networkid.TransactionID]*outgoingMessage outgoingMessagesLock sync.Mutex @@ -115,6 +117,7 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que Bridge: br, currentlyTypingLogins: make(map[id.UserID]*UserLogin), + currentlyTypingGhosts: exsync.NewSet[id.UserID](), outgoingMessages: make(map[networkid.TransactionID]*outgoingMessage), } br.portalsByKey[portal.PortalKey] = portal @@ -2099,6 +2102,9 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin return } portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender().Sender, converted, ts, getStreamOrder(evt), nil) + if portal.currentlyTypingGhosts.Pop(intent.GetMXID()) { + intent.MarkTyping(ctx, portal.MXID, TypingTypeText, 0) + } } func (portal *Portal) sendRemoteErrorNotice(ctx context.Context, intent MatrixAPI, err error, ts time.Time, evtTypeName string) { @@ -2161,6 +2167,9 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e return } portal.sendConvertedEdit(ctx, existing[0].ID, evt.GetSender().Sender, converted, intent, ts, getStreamOrder(evt)) + if portal.currentlyTypingGhosts.Pop(intent.GetMXID()) { + intent.MarkTyping(ctx, portal.MXID, TypingTypeText, 0) + } } func (portal *Portal) sendConvertedEdit( @@ -2709,10 +2718,16 @@ func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, typingType = typedEvt.GetTypingType() } intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventTyping) - err := intent.MarkTyping(ctx, portal.MXID, typingType, evt.GetTimeout()) + timeout := evt.GetTimeout() + err := intent.MarkTyping(ctx, portal.MXID, typingType, timeout) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge typing event") } + if timeout == 0 { + portal.currentlyTypingGhosts.Remove(intent.GetMXID()) + } else { + portal.currentlyTypingGhosts.Add(intent.GetMXID()) + } } func (portal *Portal) handleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) { From 74a02366d778385dd8a75bea6a6b5325bf4481f7 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Wed, 2 Apr 2025 09:10:35 -0600 Subject: [PATCH 1075/1647] bridgev2/legacymigrate: add post-migrate hook Signed-off-by: Sumner Evans --- bridgev2/matrix/mxmain/legacymigrate.go | 15 +++++++++++---- bridgev2/matrix/mxmain/main.go | 4 ++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index d33dd8cd..8b25e210 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -226,11 +226,18 @@ func (br *BridgeMain) PostMigrate(ctx context.Context) error { Object("portal_key", portal.PortalKey). Str("room_type", string(portal.RoomType)). Msg("Migrating portal") - switch portal.RoomType { - case database.RoomTypeDM: - err = br.postMigrateDMPortal(ctx, portal) + if br.PostMigratePortal != nil { + err = br.PostMigratePortal(ctx, portal) if err != nil { - return fmt.Errorf("failed to update DM portal %s: %w", portal.MXID, err) + return fmt.Errorf("failed to run post-migrate portal hook for %s: %w", portal.MXID, err) + } + } else { + switch portal.RoomType { + case database.RoomTypeDM: + err = br.postMigrateDMPortal(ctx, portal) + if err != nil { + return fmt.Errorf("failed to update DM portal %s: %w", portal.MXID, err) + } } } _, err = br.Matrix.Bot.SendStateEvent(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.ElementFunctionalMembersContent{ diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index dab0b914..63334ba5 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -67,6 +67,10 @@ type BridgeMain struct { PostInit func() PostStart func() + // PostMigratePortal is a function that will be called during a legacy + // migration for each portal. + PostMigratePortal func(context.Context, *bridgev2.Portal) error + // Connector is the network connector for the bridge. Connector bridgev2.NetworkConnector From 49648897875fe3e44cee08661eba69958108c642 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Fri, 4 Apr 2025 08:37:49 -0600 Subject: [PATCH 1076/1647] bridgev2/legacymigrate: don't error if post migrate hook fails Signed-off-by: Sumner Evans --- bridgev2/matrix/mxmain/legacymigrate.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index 8b25e210..c8eb820b 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -208,28 +208,31 @@ func (br *BridgeMain) postMigrateDMPortal(ctx context.Context, portal *bridgev2. } func (br *BridgeMain) PostMigrate(ctx context.Context) error { + log := br.Log.With().Str("action", "post-migrate").Logger() wasMigrated, err := br.DB.TableExists(ctx, "database_was_migrated") if err != nil { return fmt.Errorf("failed to check if database_was_migrated table exists: %w", err) } else if !wasMigrated { return nil } - zerolog.Ctx(ctx).Info().Msg("Doing post-migration updates to Matrix rooms") + log.Info().Msg("Doing post-migration updates to Matrix rooms") portals, err := br.Bridge.GetAllPortalsWithMXID(ctx) if err != nil { return fmt.Errorf("failed to get all portals: %w", err) } for _, portal := range portals { - zerolog.Ctx(ctx).Debug(). + log := log.With(). Stringer("room_id", portal.MXID). Object("portal_key", portal.PortalKey). Str("room_type", string(portal.RoomType)). - Msg("Migrating portal") + Logger() + log.Debug().Msg("Migrating portal") if br.PostMigratePortal != nil { err = br.PostMigratePortal(ctx, portal) if err != nil { - return fmt.Errorf("failed to run post-migrate portal hook for %s: %w", portal.MXID, err) + log.Err(err).Msg("Failed to run post-migrate portal hook") + continue } } else { switch portal.RoomType { @@ -244,7 +247,7 @@ func (br *BridgeMain) PostMigrate(ctx context.Context) error { ServiceMembers: []id.UserID{br.Matrix.Bot.UserID}, }) if err != nil { - zerolog.Ctx(ctx).Warn().Err(err).Stringer("room_id", portal.MXID).Msg("Failed to set service members") + log.Warn().Err(err).Stringer("room_id", portal.MXID).Msg("Failed to set service members") } } @@ -252,6 +255,6 @@ func (br *BridgeMain) PostMigrate(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to drop database_was_migrated table: %w", err) } - zerolog.Ctx(ctx).Info().Msg("Post-migration updates complete") + log.Info().Msg("Post-migration updates complete") return nil } From e675a3c09c3894d994f8d6fdcdcaed024de5d33a Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Sun, 6 Apr 2025 00:41:16 +0100 Subject: [PATCH 1077/1647] client: add `allowed_room_ids` to room summary response (#367) --- responses.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/responses.go b/responses.go index 93a780ef..ee7f4703 100644 --- a/responses.go +++ b/responses.go @@ -221,9 +221,10 @@ type RespMutualRooms struct { type RespRoomSummary struct { PublicRoomInfo - Membership event.Membership `json:"membership,omitempty"` - RoomVersion event.RoomVersion `json:"room_version,omitempty"` - Encryption id.Algorithm `json:"encryption,omitempty"` + Membership event.Membership `json:"membership,omitempty"` + RoomVersion event.RoomVersion `json:"room_version,omitempty"` + Encryption id.Algorithm `json:"encryption,omitempty"` + AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"` UnstableRoomVersion event.RoomVersion `json:"im.nheko.summary.room_version,omitempty"` UnstableRoomVersionOld event.RoomVersion `json:"im.nheko.summary.version,omitempty"` From 0fcb552c27a8898a8a178c92a504becad596d264 Mon Sep 17 00:00:00 2001 From: Adam Van Ymeren Date: Wed, 9 Apr 2025 07:46:03 -0700 Subject: [PATCH 1078/1647] bridgev2: make Bridge.Start() take a context (#368) --- bridgev2/bridge.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index bdc8480d..aef86196 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -107,8 +107,8 @@ func (e DBUpgradeError) Unwrap() error { return e.Err } -func (br *Bridge) Start() error { - ctx := br.Log.WithContext(context.Background()) +func (br *Bridge) Start(ctx context.Context) error { + ctx = br.Log.WithContext(ctx) err := br.StartConnectors(ctx) if err != nil { return err From 826089e020fb838951df813138d89ab47b07b6b1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 11 Apr 2025 17:08:52 +0300 Subject: [PATCH 1079/1647] id: make user id parsing more efficient --- id/userid.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/id/userid.go b/id/userid.go index 1e1f3b29..59136013 100644 --- a/id/userid.go +++ b/id/userid.go @@ -43,10 +43,10 @@ func ParseCommonIdentifier[Stringish ~string](identifier Stringish) (sigil byte, } sigil = identifier[0] strIdentifier := string(identifier) - if strings.ContainsRune(strIdentifier, ':') { - parts := strings.SplitN(strIdentifier, ":", 2) - localpart = parts[0][1:] - homeserver = parts[1] + colonIdx := strings.IndexByte(strIdentifier, ':') + if colonIdx > 0 { + localpart = strIdentifier[1:colonIdx] + homeserver = strIdentifier[colonIdx+1:] } else { localpart = strIdentifier[1:] } From cf801729af425fa373fa1a2ff62a4fd470fce7f9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 13 Apr 2025 02:59:13 +0300 Subject: [PATCH 1080/1647] bridgev2/commands: implement MarkRead --- bridgev2/commands/event.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bridgev2/commands/event.go b/bridgev2/commands/event.go index 78ed94bb..88ba9698 100644 --- a/bridgev2/commands/event.go +++ b/bridgev2/commands/event.go @@ -10,6 +10,7 @@ import ( "context" "fmt" "strings" + "time" "github.com/rs/zerolog" @@ -92,9 +93,8 @@ func (ce *Event) Redact(req ...mautrix.ReqRedact) { // MarkRead marks the command event as read. func (ce *Event) MarkRead() { - // TODO - //err := ce.Bot.SendReceipt(ce.Ctx, ce.RoomID, ce.EventID, event.ReceiptTypeRead, nil) - //if err != nil { - // ce.Log.Err(err).Msg("Failed to mark command as read") - //} + err := ce.Bot.MarkRead(ce.Ctx, ce.RoomID, ce.EventID, time.Now()) + if err != nil { + ce.Log.Err(err).Msg("Failed to mark command as read") + } } From 0f06c9ce31cd3043d12bc7af12ab40f63496efaf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 13 Apr 2025 02:59:26 +0300 Subject: [PATCH 1081/1647] event/content: add SetThread method --- event/relations.go | 4 ++++ event/reply.go | 24 +++++++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/event/relations.go b/event/relations.go index 30cf6c20..e855a7e6 100644 --- a/event/relations.go +++ b/event/relations.go @@ -101,6 +101,10 @@ func (rel *RelatesTo) SetReplace(mxid id.EventID) *RelatesTo { } func (rel *RelatesTo) SetReplyTo(mxid id.EventID) *RelatesTo { + if rel.Type != RelThread { + rel.Type = "" + rel.EventID = "" + } rel.InReplyTo = &InReplyTo{EventID: mxid} rel.IsFallingBack = false return rel diff --git a/event/reply.go b/event/reply.go index 1a88c619..9ae1c110 100644 --- a/event/reply.go +++ b/event/reply.go @@ -47,5 +47,27 @@ func (content *MessageEventContent) GetReplyTo() id.EventID { } func (content *MessageEventContent) SetReply(inReplyTo *Event) { - content.RelatesTo = (&RelatesTo{}).SetReplyTo(inReplyTo.ID) + if content.RelatesTo == nil { + content.RelatesTo = &RelatesTo{} + } + content.RelatesTo.SetReplyTo(inReplyTo.ID) + if content.Mentions == nil { + content.Mentions = &Mentions{} + } + content.Mentions.Add(inReplyTo.Sender) +} + +func (content *MessageEventContent) SetThread(inReplyTo *Event) { + root := inReplyTo.ID + relatable, ok := inReplyTo.Content.Parsed.(Relatable) + if ok { + targetRoot := relatable.OptionalGetRelatesTo().GetThreadParent() + if targetRoot != "" { + root = targetRoot + } + } + if content.RelatesTo == nil { + content.RelatesTo = &RelatesTo{} + } + content.RelatesTo.SetThread(root, inReplyTo.ID) } From 7c1b0c5968943efe5f6f92a7bd36a174316ba5ed Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 13 Apr 2025 02:59:35 +0300 Subject: [PATCH 1082/1647] format: add EscapeMarkdown --- format/markdown.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/format/markdown.go b/format/markdown.go index d099ba00..a1c93162 100644 --- a/format/markdown.go +++ b/format/markdown.go @@ -8,6 +8,7 @@ package format import ( "fmt" + "regexp" "strings" "github.com/yuin/goldmark" @@ -39,6 +40,15 @@ func UnwrapSingleParagraph(html string) string { return html } +var mdEscapeRegex = regexp.MustCompile("([\\\\`*_[\\]])") + +func EscapeMarkdown(text string) string { + text = mdEscapeRegex.ReplaceAllString(text, "\\$1") + text = strings.ReplaceAll(text, ">", ">") + text = strings.ReplaceAll(text, "<", "<") + return text +} + func RenderMarkdownCustom(text string, renderer goldmark.Markdown) event.MessageEventContent { var buf strings.Builder err := renderer.Convert([]byte(text), &buf) From 56e2adbf831402f5f8530871972b69604b91e127 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 13 Apr 2025 02:59:49 +0300 Subject: [PATCH 1083/1647] commands: add generic command processing framework for bots --- commands/event.go | 124 ++++++++++++++++++++++++++++++++ commands/prevalidate.go | 86 ++++++++++++++++++++++ commands/processor.go | 156 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 366 insertions(+) create mode 100644 commands/event.go create mode 100644 commands/prevalidate.go create mode 100644 commands/processor.go diff --git a/commands/event.go b/commands/event.go new file mode 100644 index 00000000..baf9ecda --- /dev/null +++ b/commands/event.go @@ -0,0 +1,124 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "context" + "fmt" + "strings" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" +) + +// Event contains the data of a single command event. +// It also provides some helper methods for responding to the command. +type Event[MetaType any] struct { + *event.Event + // RawInput is the entire message before splitting into command and arguments. + RawInput string + // Command is the lowercased first word of the message. + Command string + // Args are the rest of the message split by whitespace ([strings.Fields]). + Args []string + // RawArgs is the same as args, but without the splitting by whitespace. + RawArgs string + + Ctx context.Context + Proc *Processor[MetaType] + Handler *Handler[MetaType] + Meta MetaType +} + +var IDHTMLParser = &format.HTMLParser{ + PillConverter: func(displayname, mxid, eventID string, ctx format.Context) string { + if len(mxid) == 0 { + return displayname + } + if eventID != "" { + return fmt.Sprintf("https://matrix.to/#/%s/%s", mxid, eventID) + } + return mxid + }, + ItalicConverter: func(s string, c format.Context) string { + return fmt.Sprintf("*%s*", s) + }, + Newline: "\n", +} + +// ParseEvent parses a message into a command event struct. +func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[MetaType] { + content := evt.Content.Parsed.(*event.MessageEventContent) + text := content.Body + if content.Format == event.FormatHTML { + text = IDHTMLParser.Parse(content.FormattedBody, format.NewContext(ctx)) + } + parts := strings.Fields(text) + return &Event[MetaType]{ + Event: evt, + RawInput: text, + Command: strings.ToLower(parts[0]), + Args: parts[1:], + RawArgs: strings.TrimLeft(strings.TrimPrefix(text, parts[0]), " "), + Ctx: ctx, + } +} + +type ReplyOpts struct { + AllowHTML bool + AllowMarkdown bool + Reply bool + Thread bool + SendAsText bool +} + +func (evt *Event[MetaType]) Reply(msg string, args ...any) { + if len(args) > 0 { + msg = fmt.Sprintf(msg, args...) + } + evt.Respond(msg, ReplyOpts{AllowMarkdown: true, Reply: true}) +} + +func (evt *Event[MetaType]) Respond(msg string, opts ReplyOpts) { + content := format.RenderMarkdown(msg, opts.AllowMarkdown, opts.AllowHTML) + if opts.Thread { + content.SetThread(evt.Event) + } + if opts.Reply { + content.SetReply(evt.Event) + } + if !opts.SendAsText { + content.MsgType = event.MsgNotice + } + _, err := evt.Proc.Client.SendMessageEvent(evt.Ctx, evt.RoomID, event.EventMessage, content) + if err != nil { + zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send reply") + } +} + +func (evt *Event[MetaType]) React(emoji string) { + _, err := evt.Proc.Client.SendReaction(evt.Ctx, evt.RoomID, evt.ID, emoji) + if err != nil { + zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send reaction") + } +} + +func (evt *Event[MetaType]) Redact() { + _, err := evt.Proc.Client.RedactEvent(evt.Ctx, evt.RoomID, evt.ID) + if err != nil { + zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to redact command") + } +} + +func (evt *Event[MetaType]) MarkRead() { + err := evt.Proc.Client.MarkRead(evt.Ctx, evt.RoomID, evt.ID) + if err != nil { + zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send read receipt") + } +} diff --git a/commands/prevalidate.go b/commands/prevalidate.go new file mode 100644 index 00000000..95bbcc97 --- /dev/null +++ b/commands/prevalidate.go @@ -0,0 +1,86 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "strings" +) + +// A PreValidator contains a function that takes an Event and returns true if the event should be processed further. +// +// The [PreValidator] field in [Processor] is called before the handler of the command is checked. +// It can be used to modify the command or arguments, or to skip the command entirely. +// +// The primary use case is removing a static command prefix, such as requiring all commands start with `!`. +type PreValidator[MetaType any] interface { + Validate(*Event[MetaType]) bool +} + +// FuncPreValidator is a simple function that implements the PreValidator interface. +type FuncPreValidator[MetaType any] func(*Event[MetaType]) bool + +func (f FuncPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool { + return f(ce) +} + +// AllPreValidator can be used to combine multiple PreValidators, such that +// all of them must return true for the command to be processed further. +type AllPreValidator[MetaType any] []PreValidator[MetaType] + +func (f AllPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool { + for _, validator := range f { + if !validator.Validate(ce) { + return false + } + } + return true +} + +// AnyPreValidator can be used to combine multiple PreValidators, such that +// at least one of them must return true for the command to be processed further. +type AnyPreValidator[MetaType any] []PreValidator[MetaType] + +func (f AnyPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool { + for _, validator := range f { + if validator.Validate(ce) { + return true + } + } + return false +} + +// ValidatePrefixCommand checks that the first word in the input is exactly the given string, +// and if so, removes it from the command and sets the command to the next word. +// +// For example, `ValidateCommandPrefix("!mybot")` would only allow commands in the form `!mybot foo`, +// where `foo` would be used to look up the command handler. +func ValidatePrefixCommand[MetaType any](prefix string) PreValidator[MetaType] { + return FuncPreValidator[MetaType](func(ce *Event[MetaType]) bool { + if ce.Command == prefix && len(ce.Args) > 0 { + ce.Command = strings.ToLower(ce.Args[0]) + ce.RawArgs = strings.TrimLeft(strings.TrimPrefix(ce.RawArgs, ce.Args[0]), " ") + ce.Args = ce.Args[1:] + return true + } + return false + }) +} + +// ValidatePrefixSubstring checks that the command starts with the given prefix, +// and if so, removes it from the command. +// +// For example, `ValidatePrefixSubstring("!")` would only allow commands in the form `!foo`, +// where `foo` would be used to look up the command handler. +func ValidatePrefixSubstring[MetaType any](prefix string) PreValidator[MetaType] { + return FuncPreValidator[MetaType](func(ce *Event[MetaType]) bool { + if strings.HasPrefix(ce.Command, prefix) { + ce.Command = ce.Command[len(prefix):] + return true + } + return false + }) +} diff --git a/commands/processor.go b/commands/processor.go new file mode 100644 index 00000000..d4a29690 --- /dev/null +++ b/commands/processor.go @@ -0,0 +1,156 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "context" + "fmt" + "runtime/debug" + "strings" + "sync" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" +) + +// Processor implements boilerplate code for splitting messages into a command and arguments, +// and finding the appropriate handler for the command. +type Processor[MetaType any] struct { + Client *mautrix.Client + LogArgs bool + PreValidator PreValidator[MetaType] + Meta MetaType + commands map[string]*Handler[MetaType] + aliases map[string]string + lock sync.RWMutex +} + +type Handler[MetaType any] struct { + Func func(ce *Event[MetaType]) + + // Name is the primary name of the command. It must be lowercase. + Name string + // Aliases are alternative names for the command. They must be lowercase. + Aliases []string +} + +// UnknownCommandName is the name of the fallback handler which is used if no other handler is found. +// If even the unknown command handler is not found, the command is ignored. +const UnknownCommandName = "unknown-command" + +func NewProcessor[MetaType any](cli *mautrix.Client) *Processor[MetaType] { + proc := &Processor[MetaType]{ + Client: cli, + PreValidator: ValidatePrefixSubstring[MetaType]("!"), + commands: make(map[string]*Handler[MetaType]), + aliases: make(map[string]string), + } + proc.Register(&Handler[MetaType]{ + Name: UnknownCommandName, + Func: func(ce *Event[MetaType]) { + ce.Reply("Unknown command") + }, + }) + return proc +} + +// Register registers the given command handlers. +func (proc *Processor[MetaType]) Register(handlers ...*Handler[MetaType]) { + proc.lock.Lock() + defer proc.lock.Unlock() + for _, handler := range handlers { + proc.registerOne(handler) + } +} + +func (proc *Processor[MetaType]) registerOne(handler *Handler[MetaType]) { + if strings.ToLower(handler.Name) != handler.Name { + panic(fmt.Errorf("command %q is not lowercase", handler.Name)) + } + proc.commands[handler.Name] = handler + for _, alias := range handler.Aliases { + if strings.ToLower(alias) != alias { + panic(fmt.Errorf("alias %q is not lowercase", alias)) + } + proc.aliases[alias] = handler.Name + } +} + +func (proc *Processor[MetaType]) Unregister(handlers ...*Handler[MetaType]) { + proc.lock.Lock() + defer proc.lock.Unlock() + for _, handler := range handlers { + proc.unregisterOne(handler) + } +} + +func (proc *Processor[MetaType]) unregisterOne(handler *Handler[MetaType]) { + delete(proc.commands, handler.Name) + for _, alias := range handler.Aliases { + if proc.aliases[alias] == handler.Name { + delete(proc.aliases, alias) + } + } +} + +func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) { + log := *zerolog.Ctx(ctx) + defer func() { + panicErr := recover() + if panicErr != nil { + logEvt := log.Error(). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()) + if realErr, ok := panicErr.(error); ok { + logEvt = logEvt.Err(realErr) + } else { + logEvt = logEvt.Any(zerolog.ErrorFieldName, panicErr) + } + logEvt.Msg("Panic in command handler") + _, err := proc.Client.SendReaction(ctx, evt.RoomID, evt.ID, "💥") + if err != nil { + log.Err(err).Msg("Failed to send reaction after panic") + } + } + }() + parsed := ParseEvent[MetaType](ctx, evt) + if !proc.PreValidator.Validate(parsed) { + return + } + + realCommand := parsed.Command + proc.lock.RLock() + alias, ok := proc.aliases[realCommand] + if ok { + realCommand = alias + } + handler, ok := proc.commands[realCommand] + if !ok { + handler, ok = proc.commands[UnknownCommandName] + } + proc.lock.RUnlock() + if !ok { + return + } + + logWith := log.With(). + Str("command", realCommand). + Stringer("sender", evt.Sender). + Stringer("room_id", evt.RoomID) + if proc.LogArgs { + logWith = logWith.Strs("args", parsed.Args) + } + log = logWith.Logger() + parsed.Ctx = log.WithContext(ctx) + parsed.Handler = handler + parsed.Proc = proc + parsed.Meta = proc.Meta + + log.Debug().Msg("Processing command") + handler.Func(parsed) +} From 60e14d7dffa4ff82e3abaa2a5f38e4a87da0865e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 13 Apr 2025 22:45:51 +0300 Subject: [PATCH 1084/1647] format: parse task list html --- format/htmlparser.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/format/htmlparser.go b/format/htmlparser.go index 25543926..e50a578e 100644 --- a/format/htmlparser.go +++ b/format/htmlparser.go @@ -348,6 +348,8 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string { return parser.imgToString(node, ctx) case "hr": return parser.HorizontalLine + case "input": + return parser.inputToString(node, ctx) case "pre": var preStr, language string if node.FirstChild != nil && node.FirstChild.Type == html.ElementNode && node.FirstChild.Data == "code" { @@ -371,6 +373,17 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string { } } +func (parser *HTMLParser) inputToString(node *html.Node, ctx Context) string { + if len(ctx.TagStack) > 1 && ctx.TagStack[len(ctx.TagStack)-2] == "li" { + _, checked := parser.maybeGetAttribute(node, "checked") + if checked { + return "[x]" + } + return "[ ]" + } + return parser.nodeToTagAwareString(node.FirstChild, ctx) +} + func (parser *HTMLParser) singleNodeToString(node *html.Node, ctx Context) TaggedString { switch node.Type { case html.TextNode: From 99ff0c0964e4aeb9e58004193590e8ba782b3c78 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 14 Apr 2025 23:08:11 +0300 Subject: [PATCH 1085/1647] crypto/decryptmegolm: add option to ignore failing to parse content after decryption --- crypto/decryptmegolm.go | 2 +- crypto/machine.go | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 11ab0f49..47279474 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -160,7 +160,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event if err != nil { if errors.Is(err, event.ErrUnsupportedContentType) { log.Warn().Msg("Unsupported event type in encrypted event") - } else { + } else if !mach.IgnorePostDecryptionParseErrors { return nil, fmt.Errorf("failed to parse content of megolm payload event: %w", err) } } diff --git a/crypto/machine.go b/crypto/machine.go index cacc73b6..e2af298b 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -44,6 +44,8 @@ type OlmMachine struct { // Don't mark outbound Olm sessions as shared for devices they were initially sent to. DisableSharedGroupSessionTracking bool + IgnorePostDecryptionParseErrors bool + SendKeysMinTrust id.TrustState ShareKeysMinTrust id.TrustState From 95a7e940d598ef333e231c766864bdc6a974b9ec Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 15 Apr 2025 11:06:02 +0300 Subject: [PATCH 1086/1647] bridgev2: don't keep cache lock while waiting for stop --- bridgev2/bridge.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index aef86196..db6371b2 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -327,8 +327,8 @@ func (br *Bridge) stop(isRunOnce bool) { for _, login := range br.userLoginsByID { go login.Disconnect(wg.Done) } - wg.Wait() br.cacheLock.Unlock() + wg.Wait() } if stopNet, ok := br.Network.(StoppableNetwork); ok { stopNet.Stop() From aae91f67b44dd29894e00f203ee78819237bb948 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 15 Apr 2025 12:11:18 +0300 Subject: [PATCH 1087/1647] bridgev2: split stopping matrix connector Also fix stopping the websocket in the default Matrix connector --- bridgev2/bridge.go | 3 ++- bridgev2/matrix/connector.go | 26 +++++++++++++++++++++++++- bridgev2/matrixinterface.go | 1 + 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index db6371b2..24ceaf6b 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -319,7 +319,7 @@ func (br *Bridge) Stop() { func (br *Bridge) stop(isRunOnce bool) { br.Log.Info().Msg("Shutting down bridge") br.stopBackfillQueue.Set() - br.Matrix.Stop() + br.Matrix.PreStop() if !isRunOnce { br.cacheLock.Lock() var wg sync.WaitGroup @@ -330,6 +330,7 @@ func (br *Bridge) stop(isRunOnce bool) { br.cacheLock.Unlock() wg.Wait() } + br.Matrix.Stop() if stopNet, ok := br.Network.(StoppableNetwork); ok { stopNet.Stop() } diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 4a930135..f56eece3 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -233,13 +233,37 @@ func (br *Connector) GetCapabilities() *bridgev2.MatrixCapabilities { return br.Capabilities } -func (br *Connector) Stop() { +func sendStopSignal(ch chan struct{}) { + if ch != nil { + select { + case ch <- struct{}{}: + default: + } + } +} + +func (br *Connector) PreStop() { br.stopping = true br.AS.Stop() + if stopWebsocket := br.AS.StopWebsocket; stopWebsocket != nil { + stopWebsocket(appservice.ErrWebsocketManualStop) + } + sendStopSignal(br.wsStopPinger) + sendStopSignal(br.wsShortCircuitReconnectBackoff) +} + +func (br *Connector) Stop() { br.EventProcessor.Stop() if br.Crypto != nil { br.Crypto.Stop() } + if wsStopChan := br.wsStopped; wsStopChan != nil { + select { + case <-wsStopChan: + case <-time.After(4 * time.Second): + br.Log.Warn().Msg("Timed out waiting for websocket to close") + } + } } var MinSpecVersion = mautrix.SpecV14 diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 2665956c..4ccba353 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -31,6 +31,7 @@ type MatrixCapabilities struct { type MatrixConnector interface { Init(*Bridge) Start(ctx context.Context) error + PreStop() Stop() GetCapabilities() *MatrixCapabilities From 89b41900e49b308a6e37862e3a39f83646d0c787 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 15 Apr 2025 14:35:47 +0300 Subject: [PATCH 1088/1647] bridgev2/userlogin: stop using deprecated alias --- bridgev2/userlogin.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index d07acce5..bf8f3bc6 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -491,14 +491,14 @@ func (ul *UserLogin) MarkAsPreferredIn(ctx context.Context, portal *Portal) erro return ul.Bridge.DB.UserPortal.MarkAsPreferred(ctx, ul.UserLogin, portal.PortalKey) } -var _ status.StandaloneCustomBridgeStateFiller = (*UserLogin)(nil) +var _ status.BridgeStateFiller = (*UserLogin)(nil) func (ul *UserLogin) FillBridgeState(state status.BridgeState) status.BridgeState { state.UserID = ul.UserMXID state.RemoteID = string(ul.ID) state.RemoteName = ul.RemoteName state.RemoteProfile = &ul.RemoteProfile - filler, ok := ul.Client.(status.StandaloneCustomBridgeStateFiller) + filler, ok := ul.Client.(status.BridgeStateFiller) if ok { return filler.FillBridgeState(state) } From 7cb13f8fd3df1609eee379f71f69c92bd5f6addb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 15 Apr 2025 14:43:47 +0300 Subject: [PATCH 1089/1647] bridgev2/status: add user_action field for bridge states --- bridgev2/bridgestate.go | 4 +++- bridgev2/status/bridgestate.go | 10 ++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index 61f988ad..148b522c 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -76,7 +76,9 @@ func (bsq *BridgeStateQueue) loop() { func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState) { noticeConfig := bsq.bridge.Config.BridgeStatusNotices - isError := state.StateEvent == status.StateBadCredentials || state.StateEvent == status.StateUnknownError + isError := state.StateEvent == status.StateBadCredentials || + state.StateEvent == status.StateUnknownError || + state.UserAction == status.UserActionOpenNative sendNotice := noticeConfig == "all" || (noticeConfig == "errors" && (isError || (bsq.errorSent && state.StateEvent == status.StateConnected))) if !sendNotice { diff --git a/bridgev2/status/bridgestate.go b/bridgev2/status/bridgestate.go index cb862110..005a4f62 100644 --- a/bridgev2/status/bridgestate.go +++ b/bridgev2/status/bridgestate.go @@ -73,6 +73,13 @@ func (e BridgeStateEvent) IsValid() bool { } } +type BridgeStateUserAction string + +const ( + UserActionOpenNative BridgeStateUserAction = "OPEN_NATIVE" + UserActionRelogin BridgeStateUserAction = "RELOGIN" +) + type RemoteProfile struct { Phone string `json:"phone,omitempty"` Email string `json:"email,omitempty"` @@ -110,6 +117,8 @@ type BridgeState struct { Error BridgeStateErrorCode `json:"error,omitempty"` Message string `json:"message,omitempty"` + UserAction BridgeStateUserAction `json:"user_action,omitempty"` + UserID id.UserID `json:"user_id,omitempty"` RemoteID string `json:"remote_id,omitempty"` RemoteName string `json:"remote_name,omitempty"` @@ -192,6 +201,7 @@ func (pong *BridgeState) ShouldDeduplicate(newPong *BridgeState) bool { return pong != nil && pong.StateEvent == newPong.StateEvent && pong.RemoteName == newPong.RemoteName && + pong.UserAction == newPong.UserAction && ptr.Val(pong.RemoteProfile) == ptr.Val(newPong.RemoteProfile) && pong.Error == newPong.Error && maps.EqualFunc(pong.Info, newPong.Info, reflect.DeepEqual) && From 7165d3fa583444352ea770b5c003c93a7fc6989c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 16 Apr 2025 11:59:48 +0300 Subject: [PATCH 1090/1647] Bump version to v0.23.3 --- CHANGELOG.md | 16 ++++++++++++++++ go.mod | 22 +++++++++++----------- go.sum | 41 ++++++++++++++++++++--------------------- version.go | 2 +- 4 files changed, 48 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3fc91bb7..565d7f15 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,19 @@ +## v0.23.3 (2025-04-16) + +* *(commands)* Added generic command processing framework for bots. +* *(client)* Added `allowed_room_ids` field to room summary responses + (thanks to [@nexy7574] in [#367]). +* *(bridgev2)* Added support for custom timeouts on outgoing messages which have + to wait for a remote echo. +* *(bridgev2)* Added automatic typing stop event if the ghost user had sent a + typing event before a message. +* *(bridgev2)* The saved management room is now cleared if the user leaves the + room, allowing the next DM to be automatically marked as a management room. +* *(bridge)* Removed deprecated fallback package for bridge statuses. + The status package is now only available under bridgev2. + +[#367]: https://github.com/mautrix/go/pull/367 + ## v0.23.2 (2025-03-16) * **Breaking change *(bridge)*** Removed legacy bridge module. diff --git a/go.mod b/go.mod index cea3c580..40564392 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module maunium.net/go/mautrix go 1.23.0 -toolchain go1.24.1 +toolchain go1.24.2 require ( filippo.io/edwards25519 v1.1.0 @@ -10,20 +10,20 @@ require ( github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 - github.com/mattn/go-sqlite3 v1.14.24 + github.com/mattn/go-sqlite3 v1.14.27 github.com/rs/xid v1.6.0 - github.com/rs/zerolog v1.33.0 + github.com/rs/zerolog v1.34.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stretchr/testify v1.10.0 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.7.8 + github.com/yuin/goldmark v1.7.10 go.mau.fi/util v0.8.6 go.mau.fi/zeroconfig v0.1.3 - golang.org/x/crypto v0.36.0 - golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 - golang.org/x/net v0.37.0 - golang.org/x/sync v0.12.0 + golang.org/x/crypto v0.37.0 + golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 + golang.org/x/net v0.39.0 + golang.org/x/sync v0.13.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) @@ -33,11 +33,11 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 // indirect + github.com/petermattis/goid v0.0.0-20250319124200-ccd6737f222a // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - golang.org/x/sys v0.31.0 // indirect - golang.org/x/text v0.23.0 // indirect + golang.org/x/sys v0.32.0 // indirect + golang.org/x/text v0.24.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index eab4a1b8..cb64c875 100644 --- a/go.sum +++ b/go.sum @@ -26,18 +26,17 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= -github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 h1:E7Kmf11E4K7B5hDti2K2NqPb1nlYlGYsu02S1JNd/Bs= -github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/mattn/go-sqlite3 v1.14.27 h1:drZCnuvf37yPfs95E5jd9s3XhdVWLal+6BOK6qrv6IU= +github.com/mattn/go-sqlite3 v1.14.27/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/petermattis/goid v0.0.0-20250319124200-ccd6737f222a h1:S+AGcmAESQ0pXCUNnRH7V+bOUIgkSX5qVt2cNKCrm0Q= +github.com/petermattis/goid v0.0.0-20250319124200-ccd6737f222a/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= -github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= -github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= @@ -52,28 +51,28 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= -github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= +github.com/yuin/goldmark v1.7.10 h1:S+LrtBjRmqMac2UdtB6yyCEJm+UILZ2fefI4p7o0QpI= +github.com/yuin/goldmark v1.7.10/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= go.mau.fi/util v0.8.6 h1:AEK13rfgtiZJL2YsNK+W4ihhYCuukcRom8WPP/w/L54= go.mau.fi/util v0.8.6/go.mod h1:uNB3UTXFbkpp7xL1M/WvQks90B/L4gvbLpbS0603KOE= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= -golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= -golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 h1:nDVHiLt8aIbd/VzvPWN6kSOPE7+F/fNFDSXLVYkE/Iw= -golang.org/x/exp v0.0.0-20250305212735-054e65f0b394/go.mod h1:sIifuuw/Yco/y6yb6+bDNfyeQ/MdPUy/hKEMYQV17cM= -golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= -golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= -golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= -golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= +golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM= +golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8= +golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= +golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= -golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/version.go b/version.go index dba1a59e..2e670697 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.23.2" +const Version = "v0.23.3" var GoModVersion = "" var Commit = "" From d3d20cbcf20a0a16928701e6bd14e7e1af4fa049 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 21 Apr 2025 23:43:23 +0300 Subject: [PATCH 1091/1647] client: add context parameter for setting max retries --- client.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index a08fbed2..fb310c62 100644 --- a/client.go +++ b/client.go @@ -316,6 +316,7 @@ type contextKey int const ( LogBodyContextKey contextKey = iota LogRequestIDContextKey + MaxAttemptsContextKey ) func (cli *Client) RequestStart(req *http.Request) { @@ -324,6 +325,14 @@ func (cli *Client) RequestStart(req *http.Request) { } } +// WithMaxRetries updates the context to set the maximum number of retries for any HTTP requests made with the context. +// +// 0 means the request will only be attempted once and will not be retried. +// Negative values will remove the override and fallback to the defaults. +func WithMaxRetries(ctx context.Context, maxRetries int) context.Context { + return context.WithValue(ctx, MaxAttemptsContextKey, maxRetries+1) +} + func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err error, handlerErr error, contentLength int, duration time.Duration) { if cli == nil { return @@ -473,7 +482,12 @@ func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullReque return nil, nil, ErrClientIsNil } if params.MaxAttempts == 0 { - params.MaxAttempts = 1 + cli.DefaultHTTPRetries + maxAttempts, ok := ctx.Value(MaxAttemptsContextKey).(int) + if ok && maxAttempts > 0 { + params.MaxAttempts = maxAttempts + } else { + params.MaxAttempts = 1 + cli.DefaultHTTPRetries + } } if params.BackoffDuration == 0 { if cli.DefaultHTTPBackoff == 0 { From 953334a0a03fbf2b6e16dccb674e625c2eb1dce5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 21 Apr 2025 23:43:44 +0300 Subject: [PATCH 1092/1647] client,federation: add wrappers for /publicRooms --- client.go | 6 ++++++ federation/client.go | 20 ++++++++++++++++++++ requests.go | 27 +++++++++++++++++++++++++++ 3 files changed, 53 insertions(+) diff --git a/client.go b/client.go index fb310c62..5b5f083e 100644 --- a/client.go +++ b/client.go @@ -1986,6 +1986,12 @@ func (cli *Client) JoinedRooms(ctx context.Context) (resp *RespJoinedRooms, err return } +func (cli *Client) PublicRooms(ctx context.Context, req *ReqPublicRooms) (resp *RespPublicRooms, err error) { + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "publicRooms"}, req.Query()) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + // Hierarchy returns a list of rooms that are in the room's hierarchy. See https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy // // The hierarchy API is provided to walk the space tree and discover the rooms with their aesthetic details. works in a depth-first manner: diff --git a/federation/client.go b/federation/client.go index 098df095..7fc630b7 100644 --- a/federation/client.go +++ b/federation/client.go @@ -220,6 +220,26 @@ func (c *Client) Query(ctx context.Context, serverName, queryType string, queryP return } +func queryToValues(query map[string]string) url.Values { + values := make(url.Values, len(query)) + for k, v := range query { + values[k] = []string{v} + } + return values +} + +func (c *Client) PublicRooms(ctx context.Context, serverName string, req *mautrix.ReqPublicRooms) (resp *mautrix.RespPublicRooms, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: serverName, + Method: http.MethodGet, + Path: URLPath{"v1", "publicRooms"}, + Query: queryToValues(req.Query()), + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + type RespOpenIDUserInfo struct { Sub id.UserID `json:"sub"` } diff --git a/requests.go b/requests.go index 377534ae..1bed6c7e 100644 --- a/requests.go +++ b/requests.go @@ -424,6 +424,33 @@ type ReqSendReceipt struct { ThreadID string `json:"thread_id,omitempty"` } +type ReqPublicRooms struct { + IncludeAllNetworks bool + Limit int + Since string + ThirdPartyInstanceID string +} + +func (req *ReqPublicRooms) Query() map[string]string { + query := map[string]string{} + if req == nil { + return query + } + if req.IncludeAllNetworks { + query["include_all_networks"] = "true" + } + if req.Limit > 0 { + query["limit"] = strconv.Itoa(req.Limit) + } + if req.Since != "" { + query["since"] = req.Since + } + if req.ThirdPartyInstanceID != "" { + query["third_party_instance_id"] = req.ThirdPartyInstanceID + } + return query +} + // ReqHierarchy contains the parameters for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy // // As it's a GET method, there is no JSON body, so this is only query parameters. From 87ca9bef1cdac639ce912de0621346164e4908a4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 23 Apr 2025 11:29:07 +0300 Subject: [PATCH 1093/1647] bridgev2/networkinterface: add viewing chat callback --- bridgev2/networkinterface.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 29ee9fc9..1565a92c 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -583,6 +583,15 @@ type ReadReceiptHandlingNetworkAPI interface { HandleMatrixReadReceipt(ctx context.Context, msg *MatrixReadReceipt) error } +// ChatViewingNetworkAPI is an optional interface that network connectors can implement to handle viewing chat status. +type ChatViewingNetworkAPI interface { + NetworkAPI + // HandleMatrixViewingChat is called when the user opens a portal room. + // This will never be called by the standard appservice connector, + // as Matrix doesn't have any standard way of signaling chat open status. + HandleMatrixViewingChat(ctx context.Context, msg *MatrixViewingChat) error +} + // TypingHandlingNetworkAPI is an optional interface that network connectors can implement to handle typing events. type TypingHandlingNetworkAPI interface { NetworkAPI @@ -1235,6 +1244,14 @@ type MatrixTyping struct { Type TypingType } +type MatrixViewingChat struct { + // The portal that the user is viewing. This will be nil when the user switches to a chat from a different bridge. + Portal *Portal + // An optional timeout after which the user should not be assumed to be viewing the chat anymore + // unless the event is repeated. + Timeout time.Duration +} + type MatrixMarkedUnread = MatrixRoomMeta[*event.MarkedUnreadEventContent] type MatrixMute = MatrixRoomMeta[*event.BeeperMuteEventContent] type MatrixRoomTag = MatrixRoomMeta[*event.TagEventContent] From f931c9972de92d782f8e4598eca8d1fd11a1f0f7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 23 Apr 2025 15:29:45 +0300 Subject: [PATCH 1094/1647] crypto/decryptolm: don't try to parse content if there is none --- crypto/decryptolm.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index 353979d4..b737e4e1 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -106,9 +106,11 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e return nil, RecipientKeyMismatch } - err = olmEvt.Content.ParseRaw(olmEvt.Type) - if err != nil && !errors.Is(err, event.ErrUnsupportedContentType) { - return nil, fmt.Errorf("failed to parse content of olm payload event: %w", err) + if len(olmEvt.Content.VeryRaw) > 0 { + err = olmEvt.Content.ParseRaw(olmEvt.Type) + if err != nil && !errors.Is(err, event.ErrUnsupportedContentType) { + return nil, fmt.Errorf("failed to parse content of olm payload event: %w", err) + } } olmEvt.SenderKey = senderKey From 3698f139b6eaf05e9092818497e863a52cab5010 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 23 Apr 2025 15:47:22 +0300 Subject: [PATCH 1095/1647] crypto/helper: always update crypto store device ID --- crypto/cryptohelper/cryptohelper.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index f03835ef..56f8b484 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -170,14 +170,12 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { if err != nil { return err } - rawCryptoStore.DeviceID = resp.DeviceID helper.client.DeviceID = resp.DeviceID } else { helper.log.Debug(). Str("username", helper.LoginAs.Identifier.User). Stringer("device_id", storedDeviceID). Msg("Using existing device") - rawCryptoStore.DeviceID = storedDeviceID helper.client.DeviceID = storedDeviceID } } else if helper.LoginAs != nil { @@ -193,12 +191,10 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { if err != nil { return err } - if storedDeviceID == "" { - rawCryptoStore.DeviceID = helper.client.DeviceID - } } else if storedDeviceID != "" && storedDeviceID != helper.client.DeviceID { return fmt.Errorf("mismatching device ID in client and crypto store (%q != %q)", storedDeviceID, helper.client.DeviceID) } + rawCryptoStore.DeviceID = helper.client.DeviceID } else if helper.LoginAs != nil { return fmt.Errorf("LoginAs can only be used with a managed crypto store") } From 5f4bd44baa012fde1cc33b6f7451bc3ebe111213 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 23 Apr 2025 15:57:35 +0300 Subject: [PATCH 1096/1647] event/voip: omit empty version field in call events --- event/voip.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/event/voip.go b/event/voip.go index 28f56c95..cd8364a1 100644 --- a/event/voip.go +++ b/event/voip.go @@ -76,7 +76,7 @@ func (cv *CallVersion) Int() (int, error) { type BaseCallEventContent struct { CallID string `json:"call_id"` PartyID string `json:"party_id"` - Version CallVersion `json:"version"` + Version CallVersion `json:"version,omitempty"` } type CallInviteEventContent struct { From 19153e363846b29717fd38f7d40baab777cdce09 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 23 Apr 2025 16:27:11 +0100 Subject: [PATCH 1097/1647] client: return immediately if context canceled on external upload --- client.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/client.go b/client.go index 5b5f083e..5f47aead 100644 --- a/client.go +++ b/client.go @@ -1805,6 +1805,9 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (* break } err = fmt.Errorf("HTTP %d", resp.StatusCode) + } else if errors.Is(err, context.Canceled) { + cli.Log.Warn().Str("url", data.UnstableUploadURL).Msg("External media upload canceled") + return nil, err } if retries <= 0 { cli.Log.Warn().Str("url", data.UnstableUploadURL).Err(err). From 931f89202b1b6ad953751bd7e62d2f414ed7b9f3 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 23 Apr 2025 16:27:28 +0100 Subject: [PATCH 1098/1647] crypto/verification: include the incorrect state in non-ready error message --- crypto/verificationhelper/sas.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 2906e3a2..89d4a750 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -46,7 +46,7 @@ func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.Verificatio if err != nil { return fmt.Errorf("failed to get verification transaction %s: %w", txnID, err) } else if txn.VerificationState != VerificationStateReady { - return errors.New("transaction is not in ready state") + return fmt.Errorf("transaction is not in ready state: %s", txn.VerificationState.String()) } else if txn.StartEventContent != nil { return errors.New("start event already sent or received") } From de171e38d5d5492702f7a8a59581a55016230aaf Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 23 Apr 2025 16:46:46 +0100 Subject: [PATCH 1099/1647] crypto/verification: use consistent action log --- crypto/verificationhelper/sas.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 89d4a750..1313a613 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -35,7 +35,7 @@ import ( // [StartInRoomVerification] functions. func (vh *VerificationHelper) StartSAS(ctx context.Context, txnID id.VerificationTransactionID) error { log := vh.getLog(ctx).With(). - Str("verification_action", "accept verification"). + Str("verification_action", "start SAS"). Stringer("transaction_id", txnID). Logger() ctx = log.WithContext(ctx) @@ -177,7 +177,7 @@ func (vh *VerificationHelper) ConfirmSAS(ctx context.Context, txnID id.Verificat func (vh *VerificationHelper) onVerificationStartSAS(ctx context.Context, txn VerificationTransaction, evt *event.Event) error { startEvt := evt.Content.AsVerificationStart() log := vh.getLog(ctx).With(). - Str("verification_action", "start_sas"). + Str("verification_action", "start SAS"). Stringer("transaction_id", txn.TransactionID). Logger() ctx = log.WithContext(ctx) From 33f3ccd6aef0a3ef3a37fefead208c6249169142 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 23 Apr 2025 16:46:58 +0100 Subject: [PATCH 1100/1647] crypto/verification: add missing lock in `AcceptVerification` method --- crypto/verificationhelper/verificationhelper.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 550df942..8d99dacc 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -378,6 +378,9 @@ func (vh *VerificationHelper) StartInRoomVerification(ctx context.Context, roomI // be the transaction ID of a verification request that was received via the // VerificationRequested callback in [RequiredCallbacks]. func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.VerificationTransactionID) error { + vh.activeTransactionsLock.Lock() + defer vh.activeTransactionsLock.Unlock() + log := vh.getLog(ctx).With(). Str("verification_action", "accept verification"). Stringer("transaction_id", txnID). From 3badb9b332fede4b51295e17d5d08001bf36d5f3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 28 Apr 2025 00:25:36 +0300 Subject: [PATCH 1101/1647] commands: add subcommand system --- commands/container.go | 89 ++++++++++++++++++++++++++++++++++++++ commands/event.go | 18 ++++++++ commands/handler.go | 29 +++++++++++++ commands/prevalidate.go | 4 +- commands/processor.go | 95 +++++++++++------------------------------ 5 files changed, 161 insertions(+), 74 deletions(-) create mode 100644 commands/container.go create mode 100644 commands/handler.go diff --git a/commands/container.go b/commands/container.go new file mode 100644 index 00000000..e9dfd5e9 --- /dev/null +++ b/commands/container.go @@ -0,0 +1,89 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "fmt" + "strings" + "sync" +) + +type CommandContainer[MetaType any] struct { + commands map[string]*Handler[MetaType] + aliases map[string]string + lock sync.RWMutex +} + +func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] { + return &CommandContainer[MetaType]{ + commands: make(map[string]*Handler[MetaType]), + aliases: make(map[string]string), + } +} + +// Register registers the given command handlers. +func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType]) { + if cont == nil { + return + } + cont.lock.Lock() + defer cont.lock.Unlock() + for _, handler := range handlers { + cont.registerOne(handler) + } +} + +func (cont *CommandContainer[MetaType]) registerOne(handler *Handler[MetaType]) { + if strings.ToLower(handler.Name) != handler.Name { + panic(fmt.Errorf("command %q is not lowercase", handler.Name)) + } + cont.commands[handler.Name] = handler + for _, alias := range handler.Aliases { + if strings.ToLower(alias) != alias { + panic(fmt.Errorf("alias %q is not lowercase", alias)) + } + cont.aliases[alias] = handler.Name + } + handler.initSubcommandContainer() +} + +func (cont *CommandContainer[MetaType]) Unregister(handlers ...*Handler[MetaType]) { + if cont == nil { + return + } + cont.lock.Lock() + defer cont.lock.Unlock() + for _, handler := range handlers { + cont.unregisterOne(handler) + } +} + +func (cont *CommandContainer[MetaType]) unregisterOne(handler *Handler[MetaType]) { + delete(cont.commands, handler.Name) + for _, alias := range handler.Aliases { + if cont.aliases[alias] == handler.Name { + delete(cont.aliases, alias) + } + } +} + +func (cont *CommandContainer[MetaType]) GetHandler(name string) *Handler[MetaType] { + if cont == nil { + return nil + } + cont.lock.RLock() + defer cont.lock.RUnlock() + alias, ok := cont.aliases[name] + if ok { + name = alias + } + handler, ok := cont.commands[name] + if !ok { + handler = cont.commands[UnknownCommandName] + } + return handler +} diff --git a/commands/event.go b/commands/event.go index baf9ecda..7370844c 100644 --- a/commands/event.go +++ b/commands/event.go @@ -23,6 +23,9 @@ type Event[MetaType any] struct { *event.Event // RawInput is the entire message before splitting into command and arguments. RawInput string + // ParentCommands is the chain of commands leading up to this command. + // This is only set if the command is a subcommand. + ParentCommands []string // Command is the lowercased first word of the message. Command string // Args are the rest of the message split by whitespace ([strings.Fields]). @@ -122,3 +125,18 @@ func (evt *Event[MetaType]) MarkRead() { zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send read receipt") } } + +// PromoteFirstArgToCommand promotes the first argument to the command name. +// +// Command will be set to the lowercased first item in the Args list. +// Both Args and RawArgs will be updated to remove the first argument, but RawInput will be left as-is. +// +// The caller MUST check that there are args before calling this function. +func (evt *Event[MetaType]) PromoteFirstArgToCommand() { + if len(evt.Args) == 0 { + panic(fmt.Errorf("PromoteFirstArgToCommand called with no args")) + } + evt.Command = strings.ToLower(evt.Args[0]) + evt.RawArgs = strings.TrimLeft(strings.TrimPrefix(evt.RawArgs, evt.Args[0]), " ") + evt.Args = evt.Args[1:] +} diff --git a/commands/handler.go b/commands/handler.go new file mode 100644 index 00000000..be1d4e9b --- /dev/null +++ b/commands/handler.go @@ -0,0 +1,29 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +type Handler[MetaType any] struct { + Func func(ce *Event[MetaType]) + + // Name is the primary name of the command. It must be lowercase. + Name string + // Aliases are alternative names for the command. They must be lowercase. + Aliases []string + // Subcommands are subcommands of this command. + Subcommands []*Handler[MetaType] + + subcommandContainer *CommandContainer[MetaType] +} + +func (h *Handler[MetaType]) initSubcommandContainer() { + if len(h.Subcommands) > 0 { + h.subcommandContainer = NewCommandContainer[MetaType]() + h.subcommandContainer.Register(h.Subcommands...) + } else { + h.subcommandContainer = nil + } +} diff --git a/commands/prevalidate.go b/commands/prevalidate.go index 95bbcc97..66da2b49 100644 --- a/commands/prevalidate.go +++ b/commands/prevalidate.go @@ -61,9 +61,7 @@ func (f AnyPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool { func ValidatePrefixCommand[MetaType any](prefix string) PreValidator[MetaType] { return FuncPreValidator[MetaType](func(ce *Event[MetaType]) bool { if ce.Command == prefix && len(ce.Args) > 0 { - ce.Command = strings.ToLower(ce.Args[0]) - ce.RawArgs = strings.TrimLeft(strings.TrimPrefix(ce.RawArgs, ce.Args[0]), " ") - ce.Args = ce.Args[1:] + ce.PromoteFirstArgToCommand() return true } return false diff --git a/commands/processor.go b/commands/processor.go index d4a29690..067f222e 100644 --- a/commands/processor.go +++ b/commands/processor.go @@ -8,10 +8,8 @@ package commands import ( "context" - "fmt" "runtime/debug" "strings" - "sync" "github.com/rs/zerolog" @@ -22,34 +20,23 @@ import ( // Processor implements boilerplate code for splitting messages into a command and arguments, // and finding the appropriate handler for the command. type Processor[MetaType any] struct { + *CommandContainer[MetaType] + Client *mautrix.Client LogArgs bool PreValidator PreValidator[MetaType] Meta MetaType - commands map[string]*Handler[MetaType] - aliases map[string]string - lock sync.RWMutex -} - -type Handler[MetaType any] struct { - Func func(ce *Event[MetaType]) - - // Name is the primary name of the command. It must be lowercase. - Name string - // Aliases are alternative names for the command. They must be lowercase. - Aliases []string } // UnknownCommandName is the name of the fallback handler which is used if no other handler is found. // If even the unknown command handler is not found, the command is ignored. -const UnknownCommandName = "unknown-command" +const UnknownCommandName = "__unknown-command__" func NewProcessor[MetaType any](cli *mautrix.Client) *Processor[MetaType] { proc := &Processor[MetaType]{ - Client: cli, - PreValidator: ValidatePrefixSubstring[MetaType]("!"), - commands: make(map[string]*Handler[MetaType]), - aliases: make(map[string]string), + CommandContainer: NewCommandContainer[MetaType](), + Client: cli, + PreValidator: ValidatePrefixSubstring[MetaType]("!"), } proc.Register(&Handler[MetaType]{ Name: UnknownCommandName, @@ -60,45 +47,6 @@ func NewProcessor[MetaType any](cli *mautrix.Client) *Processor[MetaType] { return proc } -// Register registers the given command handlers. -func (proc *Processor[MetaType]) Register(handlers ...*Handler[MetaType]) { - proc.lock.Lock() - defer proc.lock.Unlock() - for _, handler := range handlers { - proc.registerOne(handler) - } -} - -func (proc *Processor[MetaType]) registerOne(handler *Handler[MetaType]) { - if strings.ToLower(handler.Name) != handler.Name { - panic(fmt.Errorf("command %q is not lowercase", handler.Name)) - } - proc.commands[handler.Name] = handler - for _, alias := range handler.Aliases { - if strings.ToLower(alias) != alias { - panic(fmt.Errorf("alias %q is not lowercase", alias)) - } - proc.aliases[alias] = handler.Name - } -} - -func (proc *Processor[MetaType]) Unregister(handlers ...*Handler[MetaType]) { - proc.lock.Lock() - defer proc.lock.Unlock() - for _, handler := range handlers { - proc.unregisterOne(handler) - } -} - -func (proc *Processor[MetaType]) unregisterOne(handler *Handler[MetaType]) { - delete(proc.commands, handler.Name) - for _, alias := range handler.Aliases { - if proc.aliases[alias] == handler.Name { - delete(proc.aliases, alias) - } - } -} - func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) { log := *zerolog.Ctx(ctx) defer func() { @@ -123,25 +71,30 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) return } - realCommand := parsed.Command - proc.lock.RLock() - alias, ok := proc.aliases[realCommand] - if ok { - realCommand = alias - } - handler, ok := proc.commands[realCommand] - if !ok { - handler, ok = proc.commands[UnknownCommandName] - } - proc.lock.RUnlock() - if !ok { + handler := proc.GetHandler(parsed.Command) + if handler == nil { return } + handlerChain := zerolog.Arr() + handlerChain.Str(handler.Name) + for handler.subcommandContainer != nil && len(parsed.Args) > 0 { + subHandler := handler.subcommandContainer.GetHandler(strings.ToLower(parsed.Args[0])) + if subHandler != nil { + parsed.ParentCommands = append(parsed.ParentCommands, parsed.Command) + handlerChain.Str(subHandler.Name) + parsed.PromoteFirstArgToCommand() + handler = subHandler + } + } logWith := log.With(). - Str("command", realCommand). + Str("command", parsed.Command). + Array("handler", handlerChain). Stringer("sender", evt.Sender). Stringer("room_id", evt.RoomID) + if len(parsed.ParentCommands) > 0 { + logWith = logWith.Strs("parent_commands", parsed.ParentCommands) + } if proc.LogArgs { logWith = logWith.Strs("args", parsed.Args) } From 9dc0b3cddffb8bf88b48b1785ed51a6d0d6d9923 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 28 Apr 2025 00:32:44 +0300 Subject: [PATCH 1102/1647] commands: make unknown command handler more generic --- commands/handler.go | 17 +++++++++++++++++ commands/processor.go | 7 +------ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/commands/handler.go b/commands/handler.go index be1d4e9b..d4c53ff6 100644 --- a/commands/handler.go +++ b/commands/handler.go @@ -6,6 +6,10 @@ package commands +import ( + "strings" +) + type Handler[MetaType any] struct { Func func(ce *Event[MetaType]) @@ -27,3 +31,16 @@ func (h *Handler[MetaType]) initSubcommandContainer() { h.subcommandContainer = nil } } + +func MakeUnknownCommandHandler[MetaType any](prefix string) *Handler[MetaType] { + return &Handler[MetaType]{ + Name: UnknownCommandName, + Func: func(ce *Event[MetaType]) { + if len(ce.ParentCommands) == 0 { + ce.Reply("Unknown command `%s%s`", prefix, ce.Command) + } else { + ce.Reply("Unknown subcommand `%s%s %s`", prefix, strings.Join(ce.ParentCommands, " "), ce.Command) + } + }, + } +} diff --git a/commands/processor.go b/commands/processor.go index 067f222e..c4077250 100644 --- a/commands/processor.go +++ b/commands/processor.go @@ -38,12 +38,7 @@ func NewProcessor[MetaType any](cli *mautrix.Client) *Processor[MetaType] { Client: cli, PreValidator: ValidatePrefixSubstring[MetaType]("!"), } - proc.Register(&Handler[MetaType]{ - Name: UnknownCommandName, - Func: func(ce *Event[MetaType]) { - ce.Reply("Unknown command") - }, - }) + proc.Register(MakeUnknownCommandHandler[MetaType]("!")) return proc } From 287899435dc4d60738c777548fd92d25e41bf7b5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 28 Apr 2025 00:54:00 +0300 Subject: [PATCH 1103/1647] format: add method to quote string in markdown inline code --- format/htmlparser.go | 22 +--------------------- format/markdown.go | 14 ++++++++++++++ format/markdown_test.go | 15 +++++++++++++++ go.mod | 2 +- go.sum | 4 ++-- 5 files changed, 33 insertions(+), 24 deletions(-) diff --git a/format/htmlparser.go b/format/htmlparser.go index e50a578e..f9d51e39 100644 --- a/format/htmlparser.go +++ b/format/htmlparser.go @@ -187,25 +187,6 @@ func (parser *HTMLParser) listToString(node *html.Node, ctx Context) string { return strings.Join(children, "\n") } -func LongestSequence(in string, of rune) int { - currentSeq := 0 - maxSeq := 0 - for _, chr := range in { - if chr == of { - currentSeq++ - } else { - if currentSeq > maxSeq { - maxSeq = currentSeq - } - currentSeq = 0 - } - } - if currentSeq > maxSeq { - maxSeq = currentSeq - } - return maxSeq -} - func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) string { str := parser.nodeToTagAwareString(node.FirstChild, ctx) switch node.Data { @@ -232,8 +213,7 @@ func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) stri if parser.MonospaceConverter != nil { return parser.MonospaceConverter(str, ctx) } - surround := strings.Repeat("`", LongestSequence(str, '`')+1) - return fmt.Sprintf("%s%s%s", surround, str, surround) + return SafeMarkdownCode(str) } return str } diff --git a/format/markdown.go b/format/markdown.go index a1c93162..59248c72 100644 --- a/format/markdown.go +++ b/format/markdown.go @@ -14,6 +14,7 @@ import ( "github.com/yuin/goldmark" "github.com/yuin/goldmark/extension" "github.com/yuin/goldmark/renderer/html" + "go.mau.fi/util/exstrings" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format/mdext" @@ -49,6 +50,19 @@ func EscapeMarkdown(text string) string { return text } +func SafeMarkdownCode(text string) string { + text = strings.ReplaceAll(text, "\n", " ") + backtickCount := exstrings.LongestSequenceOf(text, '`') + if backtickCount == 0 { + return fmt.Sprintf("`%s`", text) + } + quotes := strings.Repeat("`", backtickCount+1) + if text[0] == '`' || text[len(text)-1] == '`' { + return fmt.Sprintf("%s %s %s", quotes, text, quotes) + } + return fmt.Sprintf("%s%s%s", quotes, text, quotes) +} + func RenderMarkdownCustom(text string, renderer goldmark.Markdown) event.MessageEventContent { var buf strings.Builder err := renderer.Convert([]byte(text), &buf) diff --git a/format/markdown_test.go b/format/markdown_test.go index d4e7d716..46ea4886 100644 --- a/format/markdown_test.go +++ b/format/markdown_test.go @@ -196,3 +196,18 @@ func TestRenderMarkdown_CustomEmoji(t *testing.T) { assert.Equal(t, html, rendered, "with input %q", markdown) } } + +var codeTests = map[string]string{ + "meow": "`meow`", + "me`ow": "``me`ow``", + "`me`ow": "`` `me`ow ``", + "me`ow`": "`` me`ow` ``", + "`meow`": "`` `meow` ``", + "`````````": "`````````` ````````` ``````````", +} + +func TestSafeMarkdownCode(t *testing.T) { + for input, expected := range codeTests { + assert.Equal(t, expected, format.SafeMarkdownCode(input), "with input %q", input) + } +} diff --git a/go.mod b/go.mod index 40564392..e279118e 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.10 - go.mau.fi/util v0.8.6 + go.mau.fi/util v0.8.7-0.20250427215252-d2d18a7e463c go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.37.0 golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 diff --git a/go.sum b/go.sum index cb64c875..f103b287 100644 --- a/go.sum +++ b/go.sum @@ -53,8 +53,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.10 h1:S+LrtBjRmqMac2UdtB6yyCEJm+UILZ2fefI4p7o0QpI= github.com/yuin/goldmark v1.7.10/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.8.6 h1:AEK13rfgtiZJL2YsNK+W4ihhYCuukcRom8WPP/w/L54= -go.mau.fi/util v0.8.6/go.mod h1:uNB3UTXFbkpp7xL1M/WvQks90B/L4gvbLpbS0603KOE= +go.mau.fi/util v0.8.7-0.20250427215252-d2d18a7e463c h1:qfJyMZq1pPyuXKoVWwHs6OmR9CzO3pHFRPYT/QpaaaA= +go.mau.fi/util v0.8.7-0.20250427215252-d2d18a7e463c/go.mod h1:uNB3UTXFbkpp7xL1M/WvQks90B/L4gvbLpbS0603KOE= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= From a121a6101ce6189fdc7510a86f81c1d9841ec3bd Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 28 Apr 2025 01:00:25 +0300 Subject: [PATCH 1104/1647] format: accept any string-like type in SafeMarkdownCode --- format/markdown.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/format/markdown.go b/format/markdown.go index 59248c72..f6181ed9 100644 --- a/format/markdown.go +++ b/format/markdown.go @@ -50,8 +50,8 @@ func EscapeMarkdown(text string) string { return text } -func SafeMarkdownCode(text string) string { - text = strings.ReplaceAll(text, "\n", " ") +func SafeMarkdownCode[T ~string](textInput T) string { + text := strings.ReplaceAll(string(textInput), "\n", " ") backtickCount := exstrings.LongestSequenceOf(text, '`') if backtickCount == 0 { return fmt.Sprintf("`%s`", text) From bf33889eab4cf2ca595b058c55827f7070cf9c7c Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 28 Apr 2025 14:10:57 +0100 Subject: [PATCH 1105/1647] bridgev2/userlogin: delete disappearing messages when deleting portals (#374) --- bridgev2/database/upgrades/00-latest.sql | 9 +++++-- .../21-disappearing-message-fkey.postgres.sql | 8 +++++++ .../21-disappearing-message-fkey.sqlite.sql | 24 +++++++++++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 bridgev2/database/upgrades/21-disappearing-message-fkey.postgres.sql create mode 100644 bridgev2/database/upgrades/21-disappearing-message-fkey.sqlite.sql diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 56976b82..7ad01a87 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v20 (compatible with v9+): Latest revision +-- v0 -> v21 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -63,6 +63,7 @@ CREATE TABLE portal ( REFERENCES user_login (bridge_id, id) ON DELETE SET NULL ON UPDATE CASCADE ); +CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid); CREATE TABLE ghost ( bridge_id TEXT NOT NULL, @@ -128,7 +129,11 @@ CREATE TABLE disappearing_message ( timer BIGINT NOT NULL, disappear_at BIGINT, - PRIMARY KEY (bridge_id, mxid) + PRIMARY KEY (bridge_id, mxid), + CONSTRAINT disappearing_message_portal_fkey + FOREIGN KEY (bridge_id, mx_room) + REFERENCES portal (bridge_id, mxid) + ON DELETE CASCADE ); CREATE TABLE reaction ( diff --git a/bridgev2/database/upgrades/21-disappearing-message-fkey.postgres.sql b/bridgev2/database/upgrades/21-disappearing-message-fkey.postgres.sql new file mode 100644 index 00000000..d1c1ad9a --- /dev/null +++ b/bridgev2/database/upgrades/21-disappearing-message-fkey.postgres.sql @@ -0,0 +1,8 @@ +-- v21 (compatible with v9+): Add foreign key constraint from disappearing_message.mx_room to portals.mxid +CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid); +DELETE FROM disappearing_message WHERE mx_room NOT IN (SELECT mxid FROM portal WHERE mxid IS NOT NULL); +ALTER TABLE disappearing_message + ADD CONSTRAINT disappearing_message_portal_fkey + FOREIGN KEY (bridge_id, mx_room) + REFERENCES portal (bridge_id, mxid) + ON DELETE CASCADE; diff --git a/bridgev2/database/upgrades/21-disappearing-message-fkey.sqlite.sql b/bridgev2/database/upgrades/21-disappearing-message-fkey.sqlite.sql new file mode 100644 index 00000000..f5468c6b --- /dev/null +++ b/bridgev2/database/upgrades/21-disappearing-message-fkey.sqlite.sql @@ -0,0 +1,24 @@ +-- v21 (compatible with v9+): Add foreign key constraint from disappearing_message.mx_room to portals.mxid +CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid); +CREATE TABLE disappearing_message_new ( + bridge_id TEXT NOT NULL, + mx_room TEXT NOT NULL, + mxid TEXT NOT NULL, + type TEXT NOT NULL, + timer BIGINT NOT NULL, + disappear_at BIGINT, + + PRIMARY KEY (bridge_id, mxid), + CONSTRAINT disappearing_message_portal_fkey + FOREIGN KEY (bridge_id, mx_room) + REFERENCES portal (bridge_id, mxid) + ON DELETE CASCADE +); + +WITH portal_mxids AS (SELECT mxid FROM portal WHERE mxid IS NOT NULL) +INSERT INTO disappearing_message_new (bridge_id, mx_room, mxid, type, timer, disappear_at) +SELECT bridge_id, mx_room, mxid, type, timer, disappear_at +FROM disappearing_message WHERE mx_room IN portal_mxids; + +DROP TABLE disappearing_message; +ALTER TABLE disappearing_message_new RENAME TO disappearing_message; From 06a292e1cc1538d43ff4a889fde9b361d55f7bc6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 28 Apr 2025 13:38:18 +0300 Subject: [PATCH 1106/1647] commands: add pre func for subcommand parameters --- commands/event.go | 21 ++++++++++++--------- commands/handler.go | 5 +++++ commands/prevalidate.go | 2 +- commands/processor.go | 18 +++++++++++++----- 4 files changed, 31 insertions(+), 15 deletions(-) diff --git a/commands/event.go b/commands/event.go index 7370844c..10d6283f 100644 --- a/commands/event.go +++ b/commands/event.go @@ -126,17 +126,20 @@ func (evt *Event[MetaType]) MarkRead() { } } -// PromoteFirstArgToCommand promotes the first argument to the command name. -// -// Command will be set to the lowercased first item in the Args list. -// Both Args and RawArgs will be updated to remove the first argument, but RawInput will be left as-is. -// -// The caller MUST check that there are args before calling this function. -func (evt *Event[MetaType]) PromoteFirstArgToCommand() { +// ShiftArg removes the first argument from the Args list and RawArgs data and returns it. +// RawInput will not be modified. +func (evt *Event[MetaType]) ShiftArg() string { if len(evt.Args) == 0 { - panic(fmt.Errorf("PromoteFirstArgToCommand called with no args")) + return "" } - evt.Command = strings.ToLower(evt.Args[0]) + firstArg := evt.Args[0] evt.RawArgs = strings.TrimLeft(strings.TrimPrefix(evt.RawArgs, evt.Args[0]), " ") evt.Args = evt.Args[1:] + return firstArg +} + +// UnshiftArg reverses ShiftArg by adding the given value to the beginning of the Args list and RawArgs data. +func (evt *Event[MetaType]) UnshiftArg(arg string) { + evt.RawArgs = arg + " " + evt.RawArgs + evt.Args = append([]string{arg}, evt.Args...) } diff --git a/commands/handler.go b/commands/handler.go index d4c53ff6..b01d594f 100644 --- a/commands/handler.go +++ b/commands/handler.go @@ -11,6 +11,7 @@ import ( ) type Handler[MetaType any] struct { + // Func is the function that is called when the command is executed. Func func(ce *Event[MetaType]) // Name is the primary name of the command. It must be lowercase. @@ -19,6 +20,10 @@ type Handler[MetaType any] struct { Aliases []string // Subcommands are subcommands of this command. Subcommands []*Handler[MetaType] + // PreFunc is a function that is called before checking subcommands. + // It can be used to have parameters between subcommands (e.g. `!rooms `). + // Event.ShiftArg will likely be useful for implementing such parameters. + PreFunc func(ce *Event[MetaType]) subcommandContainer *CommandContainer[MetaType] } diff --git a/commands/prevalidate.go b/commands/prevalidate.go index 66da2b49..facca4da 100644 --- a/commands/prevalidate.go +++ b/commands/prevalidate.go @@ -61,7 +61,7 @@ func (f AnyPreValidator[MetaType]) Validate(ce *Event[MetaType]) bool { func ValidatePrefixCommand[MetaType any](prefix string) PreValidator[MetaType] { return FuncPreValidator[MetaType](func(ce *Event[MetaType]) bool { if ce.Command == prefix && len(ce.Args) > 0 { - ce.PromoteFirstArgToCommand() + ce.Command = strings.ToLower(ce.ShiftArg()) return true } return false diff --git a/commands/processor.go b/commands/processor.go index c4077250..1e0a99a2 100644 --- a/commands/processor.go +++ b/commands/processor.go @@ -65,20 +65,31 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) if !proc.PreValidator.Validate(parsed) { return } + parsed.Proc = proc + parsed.Meta = proc.Meta + parsed.Ctx = ctx handler := proc.GetHandler(parsed.Command) if handler == nil { return } + parsed.Handler = handler + if handler.PreFunc != nil { + handler.PreFunc(parsed) + } handlerChain := zerolog.Arr() handlerChain.Str(handler.Name) for handler.subcommandContainer != nil && len(parsed.Args) > 0 { subHandler := handler.subcommandContainer.GetHandler(strings.ToLower(parsed.Args[0])) if subHandler != nil { + handler = subHandler parsed.ParentCommands = append(parsed.ParentCommands, parsed.Command) handlerChain.Str(subHandler.Name) - parsed.PromoteFirstArgToCommand() - handler = subHandler + parsed.Command = strings.ToLower(parsed.ShiftArg()) + parsed.Handler = subHandler + if subHandler.PreFunc != nil { + subHandler.PreFunc(parsed) + } } } @@ -95,9 +106,6 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) } log = logWith.Logger() parsed.Ctx = log.WithContext(ctx) - parsed.Handler = handler - parsed.Proc = proc - parsed.Meta = proc.Meta log.Debug().Msg("Processing command") handler.Func(parsed) From db62b9a1d875f654eec845c19d7309832d6af576 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 29 Apr 2025 02:26:50 +0300 Subject: [PATCH 1107/1647] commands: ignore notices --- commands/event.go | 6 ++++++ commands/processor.go | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/commands/event.go b/commands/event.go index 10d6283f..c02bd3b9 100644 --- a/commands/event.go +++ b/commands/event.go @@ -58,10 +58,16 @@ var IDHTMLParser = &format.HTMLParser{ // ParseEvent parses a message into a command event struct. func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[MetaType] { content := evt.Content.Parsed.(*event.MessageEventContent) + if content.MsgType == event.MsgNotice || content.RelatesTo.GetReplaceID() != "" { + return nil + } text := content.Body if content.Format == event.FormatHTML { text = IDHTMLParser.Parse(content.FormattedBody, format.NewContext(ctx)) } + if len(text) == 0 { + return nil + } parts := strings.Fields(text) return &Event[MetaType]{ Event: evt, diff --git a/commands/processor.go b/commands/processor.go index 1e0a99a2..cc55aceb 100644 --- a/commands/processor.go +++ b/commands/processor.go @@ -62,7 +62,7 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) } }() parsed := ParseEvent[MetaType](ctx, evt) - if !proc.PreValidator.Validate(parsed) { + if parsed == nil || !proc.PreValidator.Validate(parsed) { return } parsed.Proc = proc From 771424f86b6559c3b383ad65a92303d7dbf8454d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 29 Apr 2025 17:31:27 +0300 Subject: [PATCH 1108/1647] commands: stop looking for subcommands if not found --- commands/processor.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/commands/processor.go b/commands/processor.go index cc55aceb..a7c1d941 100644 --- a/commands/processor.go +++ b/commands/processor.go @@ -90,6 +90,8 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) if subHandler.PreFunc != nil { subHandler.PreFunc(parsed) } + } else { + break } } From 6c9cd6da6bea3f38ac5e3e655fcc70edef655792 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 29 Apr 2025 18:41:37 +0300 Subject: [PATCH 1109/1647] commands: return event ID to allow edits --- commands/event.go | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/commands/event.go b/commands/event.go index c02bd3b9..29a16538 100644 --- a/commands/event.go +++ b/commands/event.go @@ -15,6 +15,7 @@ import ( "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format" + "maunium.net/go/mautrix/id" ) // Event contains the data of a single command event. @@ -85,16 +86,17 @@ type ReplyOpts struct { Reply bool Thread bool SendAsText bool + Edit id.EventID } -func (evt *Event[MetaType]) Reply(msg string, args ...any) { +func (evt *Event[MetaType]) Reply(msg string, args ...any) id.EventID { if len(args) > 0 { msg = fmt.Sprintf(msg, args...) } - evt.Respond(msg, ReplyOpts{AllowMarkdown: true, Reply: true}) + return evt.Respond(msg, ReplyOpts{AllowMarkdown: true, Reply: true}) } -func (evt *Event[MetaType]) Respond(msg string, opts ReplyOpts) { +func (evt *Event[MetaType]) Respond(msg string, opts ReplyOpts) id.EventID { content := format.RenderMarkdown(msg, opts.AllowMarkdown, opts.AllowHTML) if opts.Thread { content.SetThread(evt.Event) @@ -105,24 +107,33 @@ func (evt *Event[MetaType]) Respond(msg string, opts ReplyOpts) { if !opts.SendAsText { content.MsgType = event.MsgNotice } - _, err := evt.Proc.Client.SendMessageEvent(evt.Ctx, evt.RoomID, event.EventMessage, content) + if opts.Edit != "" { + content.SetEdit(opts.Edit) + } + resp, err := evt.Proc.Client.SendMessageEvent(evt.Ctx, evt.RoomID, event.EventMessage, content) if err != nil { zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send reply") + return "" } + return resp.EventID } -func (evt *Event[MetaType]) React(emoji string) { - _, err := evt.Proc.Client.SendReaction(evt.Ctx, evt.RoomID, evt.ID, emoji) +func (evt *Event[MetaType]) React(emoji string) id.EventID { + resp, err := evt.Proc.Client.SendReaction(evt.Ctx, evt.RoomID, evt.ID, emoji) if err != nil { zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send reaction") + return "" } + return resp.EventID } -func (evt *Event[MetaType]) Redact() { - _, err := evt.Proc.Client.RedactEvent(evt.Ctx, evt.RoomID, evt.ID) +func (evt *Event[MetaType]) Redact() id.EventID { + resp, err := evt.Proc.Client.RedactEvent(evt.Ctx, evt.RoomID, evt.ID) if err != nil { zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to redact command") + return "" } + return resp.EventID } func (evt *Event[MetaType]) MarkRead() { From da25a87fc1824a789ea89be1dd87e8a3ba0acad9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 29 Apr 2025 19:58:47 +0300 Subject: [PATCH 1110/1647] event: clear mentions in SetEdit --- event/message.go | 1 + 1 file changed, 1 insertion(+) diff --git a/event/message.go b/event/message.go index 48313784..51403889 100644 --- a/event/message.go +++ b/event/message.go @@ -210,6 +210,7 @@ func (content *MessageEventContent) SetEdit(original id.EventID) { content.RelatesTo = (&RelatesTo{}).SetReplace(original) if content.MsgType == MsgText || content.MsgType == MsgNotice { content.Body = "* " + content.Body + content.Mentions = &Mentions{} if content.Format == FormatHTML && len(content.FormattedBody) > 0 { content.FormattedBody = "* " + content.FormattedBody } From e0b1e9b0d386e5789888b34ca03c52d84868f3be Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 29 Apr 2025 19:59:31 +0300 Subject: [PATCH 1111/1647] commands/event: allow overriding mentions when replying --- commands/event.go | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/commands/event.go b/commands/event.go index 29a16538..13b1d7c1 100644 --- a/commands/event.go +++ b/commands/event.go @@ -81,12 +81,13 @@ func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[Meta } type ReplyOpts struct { - AllowHTML bool - AllowMarkdown bool - Reply bool - Thread bool - SendAsText bool - Edit id.EventID + AllowHTML bool + AllowMarkdown bool + Reply bool + Thread bool + SendAsText bool + Edit id.EventID + OverrideMentions *event.Mentions } func (evt *Event[MetaType]) Reply(msg string, args ...any) id.EventID { @@ -110,6 +111,9 @@ func (evt *Event[MetaType]) Respond(msg string, opts ReplyOpts) id.EventID { if opts.Edit != "" { content.SetEdit(opts.Edit) } + if opts.OverrideMentions != nil { + content.Mentions = opts.OverrideMentions + } resp, err := evt.Proc.Client.SendMessageEvent(evt.Ctx, evt.RoomID, event.EventMessage, content) if err != nil { zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send reply") From 58e4d0f2ccb3a1e8ede45f42eb659d06f1174953 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 30 Apr 2025 15:33:33 +0300 Subject: [PATCH 1112/1647] bridgev2: stop disappearing message loop on shutdown --- bridgev2/bridge.go | 1 + bridgev2/disappear.go | 16 +++++++++++----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 24ceaf6b..38f7ce1d 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -318,6 +318,7 @@ func (br *Bridge) Stop() { func (br *Bridge) stop(isRunOnce bool) { br.Log.Info().Msg("Shutting down bridge") + br.DisappearLoop.Stop() br.stopBackfillQueue.Set() br.Matrix.PreStop() if !isRunOnce { diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go index 5f9900a5..d7b2182b 100644 --- a/bridgev2/disappear.go +++ b/bridgev2/disappear.go @@ -8,6 +8,7 @@ package bridgev2 import ( "context" + "sync/atomic" "time" "github.com/rs/zerolog" @@ -21,15 +22,17 @@ import ( type DisappearLoop struct { br *Bridge NextCheck time.Time - stop context.CancelFunc + stop atomic.Pointer[context.CancelFunc] } const DisappearCheckInterval = 1 * time.Hour func (dl *DisappearLoop) Start() { log := dl.br.Log.With().Str("component", "disappear loop").Logger() - ctx := log.WithContext(context.Background()) - ctx, dl.stop = context.WithCancel(ctx) + ctx, stop := context.WithCancel(log.WithContext(context.Background())) + if oldStop := dl.stop.Swap(&stop); oldStop != nil { + (*oldStop)() + } log.Debug().Msg("Disappearing message loop starting") for { dl.NextCheck = time.Now().Add(DisappearCheckInterval) @@ -49,8 +52,11 @@ func (dl *DisappearLoop) Start() { } func (dl *DisappearLoop) Stop() { - if dl.stop != nil { - dl.stop() + if dl == nil { + return + } + if stop := dl.stop.Load(); stop != nil { + (*stop)() } } From 69a17c6a599958b259f3eef8faeedfe6533cd906 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 May 2025 15:22:47 +0300 Subject: [PATCH 1113/1647] bridgev2/networkinterface: remove timeout from ViewingChat --- bridgev2/networkinterface.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 1565a92c..6ea0b02c 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -589,6 +589,7 @@ type ChatViewingNetworkAPI interface { // HandleMatrixViewingChat is called when the user opens a portal room. // This will never be called by the standard appservice connector, // as Matrix doesn't have any standard way of signaling chat open status. + // Clients are expected to call this every 5 seconds. There is no signal for closing a chat. HandleMatrixViewingChat(ctx context.Context, msg *MatrixViewingChat) error } @@ -1247,9 +1248,6 @@ type MatrixTyping struct { type MatrixViewingChat struct { // The portal that the user is viewing. This will be nil when the user switches to a chat from a different bridge. Portal *Portal - // An optional timeout after which the user should not be assumed to be viewing the chat anymore - // unless the event is repeated. - Timeout time.Duration } type MatrixMarkedUnread = MatrixRoomMeta[*event.MarkedUnreadEventContent] From 5c9529606e814ba035e26a6521ba774b0e87861b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 May 2025 15:23:31 +0300 Subject: [PATCH 1114/1647] crypto/keybackup: return wrapped errors in ImportRoomKeyFromBackup --- crypto/keybackup.go | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/crypto/keybackup.go b/crypto/keybackup.go index a9686fdf..5724002b 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -3,6 +3,7 @@ package crypto import ( "context" "encoding/base64" + "errors" "fmt" "time" @@ -154,13 +155,19 @@ func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version id.Key return nil } +var ( + ErrUnknownAlgorithmInKeyBackup = errors.New("ignoring room key in backup with weird algorithm") + ErrMismatchingSessionIDInKeyBackup = errors.New("mismatched session ID while creating inbound group session from key backup") + ErrFailedToStoreNewInboundGroupSessionFromBackup = errors.New("failed to store new inbound group session from key backup") +) + func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) (*InboundGroupSession, error) { log := zerolog.Ctx(ctx).With(). Str("room_id", roomID.String()). Str("session_id", sessionID.String()). Logger() if keyBackupData.Algorithm != id.AlgorithmMegolmV1 { - return nil, fmt.Errorf("ignoring room key in backup with weird algorithm %s", keyBackupData.Algorithm) + return nil, fmt.Errorf("%w %s", ErrUnknownAlgorithmInKeyBackup, keyBackupData.Algorithm) } igsInternal, err := olm.InboundGroupSessionImport([]byte(keyBackupData.SessionKey)) @@ -170,7 +177,7 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. log.Warn(). Stringer("actual_session_id", igsInternal.ID()). Msg("Mismatched session ID while creating inbound group session from key backup") - return nil, fmt.Errorf("mismatched session ID while creating inbound group session from key backup") + return nil, ErrMismatchingSessionIDInKeyBackup } var maxAge time.Duration @@ -202,7 +209,7 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. } err = mach.CryptoStore.PutGroupSession(ctx, igs) if err != nil { - return nil, fmt.Errorf("failed to store new inbound group session: %w", err) + return nil, fmt.Errorf("%w: %w", ErrFailedToStoreNewInboundGroupSessionFromBackup, err) } mach.markSessionReceived(ctx, roomID, sessionID, firstKnownIndex) return igs, nil From 5094eea718641b820671749e92f482ff0516fcf2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 1 May 2025 16:28:21 +0300 Subject: [PATCH 1115/1647] bridgev2/networkinterface: allow clients to generate transaction IDs --- bridgev2/database/message.go | 3 +++ bridgev2/networkid/bridgeid.go | 5 +++++ bridgev2/networkinterface.go | 7 +++++++ bridgev2/portal.go | 23 +++++++++++++++++++++++ 4 files changed, 38 insertions(+) diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 42581c6e..fd6b65d8 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -214,6 +214,9 @@ func (m *Message) updateSQLVariables() []any { } const FakeMXIDPrefix = "~fake:" +const TxnMXIDPrefix = "~txn:" +const NetworkTxnMXIDPrefix = TxnMXIDPrefix + "network:" +const RandomTxnMXIDPrefix = TxnMXIDPrefix + "random:" func (m *Message) SetFakeMXID() { hash := sha256.Sum256([]byte(m.ID)) diff --git a/bridgev2/networkid/bridgeid.go b/bridgev2/networkid/bridgeid.go index d78813eb..443d3655 100644 --- a/bridgev2/networkid/bridgeid.go +++ b/bridgev2/networkid/bridgeid.go @@ -94,6 +94,11 @@ type MessageID string // Transaction IDs must be unique across users in a room, but don't need to be unique across different rooms. type TransactionID string +// RawTransactionID is a client-generated identifier for a message send operation on the remote network. +// +// Unlike TransactionID, RawTransactionID's are only used for sending and don't have any uniqueness requirements. +type RawTransactionID string + // PartID is the ID of a message part on the remote network (e.g. index of image in album). // // Part IDs are only unique within a message, not globally. diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 6ea0b02c..14e3a681 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -259,6 +259,11 @@ type IdentifierValidatingNetwork interface { ValidateUserID(id networkid.UserID) bool } +type TransactionIDGeneratingNetwork interface { + NetworkConnector + GenerateTransactionID(userID id.UserID, roomID id.RoomID, eventType event.Type) networkid.RawTransactionID +} + type PortalBridgeInfoFillingNetwork interface { NetworkConnector FillPortalBridgeInfo(portal *Portal, content *event.BridgeEventContent) @@ -1161,6 +1166,8 @@ type MatrixEventBase[ContentType any] struct { // The original sender user ID. Only present in case the event is being relayed (and Sender is not the same user). OrigSender *OrigSender + + InputTransactionID networkid.RawTransactionID } type MatrixMessage struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index d8262d0a..d88f5a7c 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -811,6 +811,13 @@ func (portal *Portal) checkMessageContentCaps(ctx context.Context, caps *event.R return true } +func (portal *Portal) parseInputTransactionID(origSender *OrigSender, evt *event.Event) networkid.RawTransactionID { + if origSender != nil || !strings.HasPrefix(evt.ID.String(), database.NetworkTxnMXIDPrefix) { + return "" + } + return networkid.RawTransactionID(strings.TrimPrefix(evt.ID.String(), database.NetworkTxnMXIDPrefix)) +} + func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { log := zerolog.Ctx(ctx) var relatesTo *event.RelatesTo @@ -938,6 +945,8 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin Content: msgContent, OrigSender: origSender, Portal: portal, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, ThreadRoot: threadRoot, ReplyTo: replyTo, @@ -1173,6 +1182,8 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o Content: content, OrigSender: origSender, Portal: portal, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, EditTarget: editTarget, }) @@ -1224,6 +1235,8 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi Event: evt, Content: content, Portal: portal, + + InputTransactionID: portal.parseInputTransactionID(nil, evt), }, TargetMessage: reactionTarget, } @@ -1380,6 +1393,8 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( Content: content, Portal: portal, OrigSender: origSender, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, PrevContent: prevContent, }) @@ -1501,6 +1516,8 @@ func (portal *Portal) handleMatrixMembership( Content: content, Portal: portal, OrigSender: origSender, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, PrevContent: prevContent, }, @@ -1565,6 +1582,8 @@ func (portal *Portal) handleMatrixPowerLevels( Content: content, Portal: portal, OrigSender: origSender, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, PrevContent: prevContent, }, @@ -1651,6 +1670,8 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog Content: content, Portal: portal, OrigSender: origSender, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, TargetMessage: redactionTargetMsg, }) @@ -1671,6 +1692,8 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog Content: content, Portal: portal, OrigSender: origSender, + + InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, TargetReaction: redactionTargetReaction, }) From e491e87309b05692445bf1a27175535268a5b01e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 May 2025 01:56:46 +0300 Subject: [PATCH 1116/1647] commands: panic on duplicate registration --- commands/container.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/commands/container.go b/commands/container.go index e9dfd5e9..bc685b7b 100644 --- a/commands/container.go +++ b/commands/container.go @@ -40,11 +40,19 @@ func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType]) func (cont *CommandContainer[MetaType]) registerOne(handler *Handler[MetaType]) { if strings.ToLower(handler.Name) != handler.Name { panic(fmt.Errorf("command %q is not lowercase", handler.Name)) + } else if val, alreadyExists := cont.commands[handler.Name]; alreadyExists && val != handler { + panic(fmt.Errorf("tried to register command %q, but it's already registered", handler.Name)) + } else if aliasTarget, alreadyExists := cont.aliases[handler.Name]; alreadyExists { + panic(fmt.Errorf("tried to register command %q, but it's already registered as an alias for %q", handler.Name, aliasTarget)) } cont.commands[handler.Name] = handler for _, alias := range handler.Aliases { if strings.ToLower(alias) != alias { panic(fmt.Errorf("alias %q is not lowercase", alias)) + } else if val, alreadyExists := cont.aliases[alias]; alreadyExists && val != handler.Name { + panic(fmt.Errorf("tried to register alias %q for %q, but it's already registered for %q", alias, handler.Name, cont.aliases[alias])) + } else if _, alreadyExists = cont.commands[alias]; alreadyExists { + panic(fmt.Errorf("tried to register alias %q for %q, but it's already registered as a command", alias, handler.Name)) } cont.aliases[alias] = handler.Name } From 2b973cac00c67e125d562de80d370aae7b6ad0c3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 2 May 2025 02:05:11 +0300 Subject: [PATCH 1117/1647] commands: include handler chain in command events --- commands/event.go | 1 + commands/processor.go | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/commands/event.go b/commands/event.go index 13b1d7c1..8d51eadd 100644 --- a/commands/event.go +++ b/commands/event.go @@ -27,6 +27,7 @@ type Event[MetaType any] struct { // ParentCommands is the chain of commands leading up to this command. // This is only set if the command is a subcommand. ParentCommands []string + ParentHandlers []*Handler[MetaType] // Command is the lowercased first word of the message. Command string // Args are the rest of the message split by whitespace ([strings.Fields]). diff --git a/commands/processor.go b/commands/processor.go index a7c1d941..da802fd9 100644 --- a/commands/processor.go +++ b/commands/processor.go @@ -82,8 +82,9 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) for handler.subcommandContainer != nil && len(parsed.Args) > 0 { subHandler := handler.subcommandContainer.GetHandler(strings.ToLower(parsed.Args[0])) if subHandler != nil { - handler = subHandler parsed.ParentCommands = append(parsed.ParentCommands, parsed.Command) + parsed.ParentHandlers = append(parsed.ParentHandlers, handler) + handler = subHandler handlerChain.Str(subHandler.Name) parsed.Command = strings.ToLower(parsed.ShiftArg()) parsed.Handler = subHandler From 441349efac9ec84ef8162d9dea25b59853180e9a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 3 May 2025 03:00:31 +0300 Subject: [PATCH 1118/1647] synapseadmin: add SuspendAccount method --- synapseadmin/userapi.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/synapseadmin/userapi.go b/synapseadmin/userapi.go index 9cbb17e4..d3672367 100644 --- a/synapseadmin/userapi.go +++ b/synapseadmin/userapi.go @@ -106,6 +106,19 @@ func (cli *Client) DeactivateAccount(ctx context.Context, userID id.UserID, req return err } +type ReqSuspendUser struct { + Suspend bool `json:"suspend"` +} + +// SuspendAccount suspends or unsuspends a specific local user account. +// +// https://element-hq.github.io/synapse/latest/admin_api/user_admin_api.html#suspendunsuspend-account +func (cli *Client) SuspendAccount(ctx context.Context, userID id.UserID, req ReqSuspendUser) error { + reqURL := cli.BuildAdminURL("v1", "suspend", userID) + _, err := cli.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) + return err +} + type ReqCreateOrModifyAccount struct { Password string `json:"password,omitempty"` LogoutDevices *bool `json:"logout_devices,omitempty"` From 36781e7de4af886bf4a51fed1714be6a96d6516a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 3 May 2025 01:43:56 +0300 Subject: [PATCH 1119/1647] federation: move server name cache to separate type --- bridgev2/matrix/provisioning.go | 2 +- federation/cache.go | 71 +++++++++++++++++++++++++++++++++ federation/client.go | 4 +- federation/client_test.go | 2 +- federation/httpclient.go | 62 ++++++++++++---------------- 5 files changed, 101 insertions(+), 40 deletions(-) create mode 100644 federation/cache.go diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 126d54de..d809d039 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -103,7 +103,7 @@ func (prov *ProvisioningAPI) Init() { prov.logins = make(map[string]*ProvLogin) prov.net = prov.br.Bridge.Network prov.log = prov.br.Log.With().Str("component", "provisioning").Logger() - prov.fedClient = federation.NewClient("", nil) + prov.fedClient = federation.NewClient("", nil, nil) prov.fedClient.HTTP.Timeout = 20 * time.Second tp := prov.fedClient.HTTP.Transport.(*federation.ServerResolvingTransport) tp.Dialer.Timeout = 10 * time.Second diff --git a/federation/cache.go b/federation/cache.go new file mode 100644 index 00000000..95d096fa --- /dev/null +++ b/federation/cache.go @@ -0,0 +1,71 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "sync" + "time" +) + +// ResolutionCache is an interface for caching resolved server names. +type ResolutionCache interface { + StoreResolution(*ResolvedServerName) + // LoadResolution loads a resolved server name from the cache. + // Expired entries MUST NOT be returned. + LoadResolution(serverName string) (*ResolvedServerName, error) +} + +type KeyCache interface { + StoreKeys(*ServerKeyResponse) + LoadKeys(serverName string) (*ServerKeyResponse, error) +} + +type InMemoryCache struct { + resolutions map[string]*ResolvedServerName + resolutionsLock sync.RWMutex + keys map[string]*ServerKeyResponse + keysLock sync.RWMutex +} + +func NewInMemoryCache() *InMemoryCache { + return &InMemoryCache{ + resolutions: make(map[string]*ResolvedServerName), + keys: make(map[string]*ServerKeyResponse), + } +} + +func (c *InMemoryCache) StoreResolution(resolution *ResolvedServerName) { + c.resolutionsLock.Lock() + defer c.resolutionsLock.Unlock() + c.resolutions[resolution.ServerName] = resolution +} + +func (c *InMemoryCache) LoadResolution(serverName string) (*ResolvedServerName, error) { + c.resolutionsLock.RLock() + defer c.resolutionsLock.RUnlock() + resolution, ok := c.resolutions[serverName] + if !ok || time.Until(resolution.Expires) < 0 { + return nil, nil + } + return resolution, nil +} + +func (c *InMemoryCache) StoreKeys(keys *ServerKeyResponse) { + c.keysLock.Lock() + defer c.keysLock.Unlock() + c.keys[keys.ServerName] = keys +} + +func (c *InMemoryCache) LoadKeys(serverName string) (*ServerKeyResponse, error) { + c.keysLock.RLock() + defer c.keysLock.RUnlock() + keys, ok := c.keys[serverName] + if !ok || time.Until(keys.ValidUntilTS.Time) < 0 { + return nil, nil + } + return keys, nil +} diff --git a/federation/client.go b/federation/client.go index 7fc630b7..7aff19c9 100644 --- a/federation/client.go +++ b/federation/client.go @@ -32,10 +32,10 @@ type Client struct { Key *SigningKey } -func NewClient(serverName string, key *SigningKey) *Client { +func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Client { return &Client{ HTTP: &http.Client{ - Transport: NewServerResolvingTransport(), + Transport: NewServerResolvingTransport(cache), Timeout: 120 * time.Second, }, UserAgent: mautrix.DefaultUserAgent, diff --git a/federation/client_test.go b/federation/client_test.go index ba3c3ed4..ece399ea 100644 --- a/federation/client_test.go +++ b/federation/client_test.go @@ -16,7 +16,7 @@ import ( ) func TestClient_Version(t *testing.T) { - cli := federation.NewClient("", nil) + cli := federation.NewClient("", nil, nil) resp, err := cli.Version(context.TODO(), "maunium.net") require.NoError(t, err) require.Equal(t, "Synapse", resp.Server.Name) diff --git a/federation/httpclient.go b/federation/httpclient.go index d6d97280..cbb1674d 100644 --- a/federation/httpclient.go +++ b/federation/httpclient.go @@ -12,7 +12,6 @@ import ( "net" "net/http" "sync" - "time" ) // ServerResolvingTransport is an http.RoundTripper that resolves Matrix server names before sending requests. @@ -22,17 +21,20 @@ type ServerResolvingTransport struct { Transport *http.Transport Dialer *net.Dialer - cache map[string]*ResolvedServerName - resolveLocks map[string]*sync.Mutex - cacheLock sync.Mutex + cache ResolutionCache + + resolveLocks map[string]*sync.Mutex + resolveLocksLock sync.Mutex } -func NewServerResolvingTransport() *ServerResolvingTransport { +func NewServerResolvingTransport(cache ResolutionCache) *ServerResolvingTransport { + if cache == nil { + cache = NewInMemoryCache() + } srt := &ServerResolvingTransport{ - cache: make(map[string]*ResolvedServerName), resolveLocks: make(map[string]*sync.Mutex), - - Dialer: &net.Dialer{}, + cache: cache, + Dialer: &net.Dialer{}, } srt.Transport = &http.Transport{ DialContext: srt.DialContext, @@ -72,37 +74,25 @@ func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Res } func (srt *ServerResolvingTransport) resolve(ctx context.Context, serverName string) (*ResolvedServerName, error) { - res, lock := srt.getResolveCache(serverName) - if res != nil { - return res, nil + srt.resolveLocksLock.Lock() + lock, ok := srt.resolveLocks[serverName] + if !ok { + lock = &sync.Mutex{} + srt.resolveLocks[serverName] = lock } + srt.resolveLocksLock.Unlock() + lock.Lock() defer lock.Unlock() - res, _ = srt.getResolveCache(serverName) - if res != nil { + res, err := srt.cache.LoadResolution(serverName) + if err != nil { + return nil, fmt.Errorf("failed to read cache: %w", err) + } else if res != nil { + return res, nil + } else if res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts); err != nil { + return nil, err + } else { + srt.cache.StoreResolution(res) return res, nil } - var err error - res, err = ResolveServerName(ctx, serverName, srt.ResolveOpts) - if err != nil { - return nil, err - } - srt.cacheLock.Lock() - srt.cache[serverName] = res - srt.cacheLock.Unlock() - return res, nil -} - -func (srt *ServerResolvingTransport) getResolveCache(serverName string) (*ResolvedServerName, *sync.Mutex) { - srt.cacheLock.Lock() - defer srt.cacheLock.Unlock() - if val, ok := srt.cache[serverName]; ok && time.Until(val.Expires) > 0 { - return val, nil - } - rl, ok := srt.resolveLocks[serverName] - if !ok { - rl = &sync.Mutex{} - srt.resolveLocks[serverName] = rl - } - return nil, rl } From 66e7d834cc7cdf5415d9bb88e45b530606ee0fd5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 3 May 2025 01:44:07 +0300 Subject: [PATCH 1120/1647] federation/resolution: parse cache-control headers for .well-known --- federation/resolution.go | 41 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/federation/resolution.go b/federation/resolution.go index 24085282..69d4d3bf 100644 --- a/federation/resolution.go +++ b/federation/resolution.go @@ -120,6 +120,38 @@ func RequestSRV(ctx context.Context, cli *net.Resolver, hostname string) ([]*net return target, err } +func parseCacheControl(resp *http.Response) time.Duration { + cc := resp.Header.Get("Cache-Control") + if cc == "" { + return 0 + } + parts := strings.Split(cc, ",") + for _, part := range parts { + kv := strings.SplitN(strings.TrimSpace(part), "=", 1) + switch kv[0] { + case "no-cache", "no-store": + return 0 + case "max-age": + if len(kv) < 2 { + continue + } + maxAge, err := strconv.Atoi(kv[1]) + if err != nil || maxAge < 0 { + continue + } + age, _ := strconv.Atoi(resp.Header.Get("Age")) + return time.Duration(maxAge-age) * time.Second + } + } + return 0 +} + +const ( + MinCacheDuration = 1 * time.Hour + MaxCacheDuration = 72 * time.Hour + DefaultCacheDuration = 24 * time.Hour +) + // RequestWellKnown sends a request to the well-known endpoint of a server and returns the response, // plus the time when the cache should expire. func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (*RespWellKnown, time.Time, error) { @@ -147,6 +179,13 @@ func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (* } else if respData.Server == "" { return nil, time.Time{}, errors.New("server name not found in response") } - // TODO parse cache-control header + cacheDuration := parseCacheControl(resp) + if cacheDuration <= 0 { + cacheDuration = DefaultCacheDuration + } else if cacheDuration < MinCacheDuration { + cacheDuration = MinCacheDuration + } else if cacheDuration > MaxCacheDuration { + cacheDuration = MaxCacheDuration + } return &respData, time.Now().Add(24 * time.Hour), nil } From 44de13a7de37cd8f699002a59a6b53f2e00143d2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 3 May 2025 01:53:48 +0300 Subject: [PATCH 1121/1647] federation/keyserver: use shared utilities for writing responses --- federation/keyserver.go | 53 ++++++++++------------------------------- 1 file changed, 13 insertions(+), 40 deletions(-) diff --git a/federation/keyserver.go b/federation/keyserver.go index 3e74bfdf..505be44f 100644 --- a/federation/keyserver.go +++ b/federation/keyserver.go @@ -8,12 +8,12 @@ package federation import ( "encoding/json" - "fmt" "net/http" "strconv" "time" "github.com/gorilla/mux" + "go.mau.fi/util/exhttp" "go.mau.fi/util/jsontime" "maunium.net/go/mautrix" @@ -58,25 +58,13 @@ func (ks *KeyServer) Register(r *mux.Router) { keyRouter.HandleFunc("/v2/query/{serverName}", ks.GetQueryKeys).Methods(http.MethodGet) keyRouter.HandleFunc("/v2/query", ks.PostQueryKeys).Methods(http.MethodPost) keyRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MUnrecognized.ErrCode, - Err: "Unrecognized endpoint", - }) + mautrix.MUnrecognized.WithStatus(http.StatusNotFound).WithMessage("Unrecognized endpoint").Write(w) }) keyRouter.MethodNotAllowedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusMethodNotAllowed, &mautrix.RespError{ - ErrCode: mautrix.MUnrecognized.ErrCode, - Err: "Invalid method for endpoint", - }) + mautrix.MUnrecognized.WithStatus(http.StatusMethodNotAllowed).WithMessage("Invalid method for endpoint").Write(w) }) } -func jsonResponse(w http.ResponseWriter, code int, data any) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(code) - _ = json.NewEncoder(w).Encode(data) -} - // RespWellKnown is the response body for the `GET /.well-known/matrix/server` endpoint. type RespWellKnown struct { Server string `json:"m.server"` @@ -87,12 +75,9 @@ type RespWellKnown struct { // https://spec.matrix.org/v1.9/server-server-api/#get_well-knownmatrixserver func (ks *KeyServer) GetWellKnown(w http.ResponseWriter, r *http.Request) { if ks.WellKnownTarget == "" { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MNotFound.ErrCode, - Err: "No well-known target set", - }) + mautrix.MNotFound.WithMessage("No well-known target set").Write(w) } else { - jsonResponse(w, http.StatusOK, &RespWellKnown{Server: ks.WellKnownTarget}) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespWellKnown{Server: ks.WellKnownTarget}) } } @@ -105,7 +90,7 @@ type RespServerVersion struct { // // https://spec.matrix.org/v1.9/server-server-api/#get_matrixfederationv1version func (ks *KeyServer) GetServerVersion(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusOK, &RespServerVersion{Server: ks.Version}) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespServerVersion{Server: ks.Version}) } // GetServerKey implements the `GET /_matrix/key/v2/server` endpoint. @@ -114,12 +99,9 @@ func (ks *KeyServer) GetServerVersion(w http.ResponseWriter, r *http.Request) { func (ks *KeyServer) GetServerKey(w http.ResponseWriter, r *http.Request) { domain, key := ks.KeyProvider.Get(r) if key == nil { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MNotFound.ErrCode, - Err: fmt.Sprintf("No signing key found for %q", r.Host), - }) + mautrix.MNotFound.WithMessage("No signing key found for %q", r.Host).Write(w) } else { - jsonResponse(w, http.StatusOK, key.GenerateKeyResponse(domain, nil)) + exhttp.WriteJSONResponse(w, http.StatusOK, key.GenerateKeyResponse(domain, nil)) } } @@ -144,10 +126,7 @@ func (ks *KeyServer) PostQueryKeys(w http.ResponseWriter, r *http.Request) { var req ReqQueryKeys err := json.NewDecoder(r.Body).Decode(&req) if err != nil { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - ErrCode: mautrix.MBadJSON.ErrCode, - Err: fmt.Sprintf("failed to parse request: %v", err), - }) + mautrix.MBadJSON.WithMessage("failed to parse request: %v", err).Write(w) return } @@ -165,7 +144,7 @@ func (ks *KeyServer) PostQueryKeys(w http.ResponseWriter, r *http.Request) { } } } - jsonResponse(w, http.StatusOK, resp) + exhttp.WriteJSONResponse(w, http.StatusOK, resp) } // GetQueryKeysResponse is the response body for the `GET /_matrix/key/v2/query/{serverName}` endpoint @@ -181,16 +160,10 @@ func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) { minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts") minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64) if err != nil && minimumValidUntilTSString != "" { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - ErrCode: mautrix.MInvalidParam.ErrCode, - Err: fmt.Sprintf("failed to parse ?minimum_valid_until_ts: %v", err), - }) + mautrix.MInvalidParam.WithMessage("failed to parse ?minimum_valid_until_ts: %v", err).Write(w) return } else if time.UnixMilli(minimumValidUntilTS).After(time.Now().Add(24 * time.Hour)) { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - ErrCode: mautrix.MInvalidParam.ErrCode, - Err: "minimum_valid_until_ts may not be more than 24 hours in the future", - }) + mautrix.MInvalidParam.WithMessage("minimum_valid_until_ts may not be more than 24 hours in the future").Write(w) return } resp := &GetQueryKeysResponse{ @@ -199,5 +172,5 @@ func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) { if domain, key := ks.KeyProvider.Get(r); key != nil && domain == serverName { resp.ServerKeys = append(resp.ServerKeys, key.GenerateKeyResponse(serverName, nil)) } - jsonResponse(w, http.StatusOK, resp) + exhttp.WriteJSONResponse(w, http.StatusOK, resp) } From b1f0b1732f22917179527edb5c93bce075ab646d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 3 May 2025 02:08:14 +0300 Subject: [PATCH 1122/1647] federation/cache: add noop cache --- federation/cache.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/federation/cache.go b/federation/cache.go index 95d096fa..a491dbd1 100644 --- a/federation/cache.go +++ b/federation/cache.go @@ -31,6 +31,11 @@ type InMemoryCache struct { keysLock sync.RWMutex } +var ( + _ ResolutionCache = (*InMemoryCache)(nil) + _ KeyCache = (*InMemoryCache)(nil) +) + func NewInMemoryCache() *InMemoryCache { return &InMemoryCache{ resolutions: make(map[string]*ResolvedServerName), @@ -69,3 +74,15 @@ func (c *InMemoryCache) LoadKeys(serverName string) (*ServerKeyResponse, error) } return keys, nil } + +type NoopCache struct{} + +func (*NoopCache) StoreKeys(_ *ServerKeyResponse) {} +func (*NoopCache) LoadKeys(_ string) (*ServerKeyResponse, error) { return nil, nil } +func (*NoopCache) StoreResolution(_ *ResolvedServerName) {} +func (*NoopCache) LoadResolution(_ string) (*ResolvedServerName, error) { return nil, nil } + +var ( + _ ResolutionCache = (*NoopCache)(nil) + _ KeyCache = (*NoopCache)(nil) +) From 9c3e1b5904f08796cc924620beb4821afbe6a808 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 3 May 2025 02:08:55 +0300 Subject: [PATCH 1123/1647] federation/signingkey: add support for roundtripping ServerKeyResponses --- federation/client.go | 3 +-- federation/signingkey.go | 54 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/federation/client.go b/federation/client.go index 7aff19c9..b7927b91 100644 --- a/federation/client.go +++ b/federation/client.go @@ -9,7 +9,6 @@ package federation import ( "bytes" "context" - "encoding/base64" "encoding/json" "fmt" "io" @@ -414,6 +413,6 @@ func (r *signableRequest) Sign(key *SigningKey) (string, error) { r.Origin, r.Destination, key.ID, - base64.RawURLEncoding.EncodeToString(sig), + sig, ), nil } diff --git a/federation/signingkey.go b/federation/signingkey.go index 67751b48..a74b4d6a 100644 --- a/federation/signingkey.go +++ b/federation/signingkey.go @@ -11,9 +11,11 @@ import ( "encoding/base64" "encoding/json" "fmt" + "maps" "strings" "time" + "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" "maunium.net/go/mautrix/crypto/canonicaljson" @@ -77,6 +79,46 @@ type ServerKeyResponse struct { OldVerifyKeys map[id.KeyID]OldVerifyKey `json:"old_verify_keys,omitempty"` Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"` ValidUntilTS jsontime.UnixMilli `json:"valid_until_ts"` + + Extra map[string]any `json:"-"` +} + +type marshalableSKR ServerKeyResponse + +func (skr *ServerKeyResponse) MarshalJSON() ([]byte, error) { + if skr.Extra == nil { + return json.Marshal((*marshalableSKR)(skr)) + } + marshalable := maps.Clone(skr.Extra) + marshalable["server_name"] = skr.ServerName + marshalable["verify_keys"] = skr.VerifyKeys + marshalable["old_verify_keys"] = skr.OldVerifyKeys + marshalable["signatures"] = skr.Signatures + marshalable["valid_until_ts"] = skr.ValidUntilTS + return json.Marshal(skr.Extra) +} + +func (skr *ServerKeyResponse) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, (*marshalableSKR)(skr)) + if err != nil { + return err + } + var extra map[string]any + err = json.Unmarshal(data, &extra) + if err != nil { + return err + } + delete(extra, "server_name") + delete(extra, "verify_keys") + delete(extra, "old_verify_keys") + delete(extra, "signatures") + delete(extra, "valid_until_ts") + if len(extra) > 0 { + skr.Extra = extra + } else { + skr.Extra = nil + } + return nil } type ServerVerifyKey struct { @@ -92,12 +134,16 @@ type OldVerifyKey struct { ExpiredTS jsontime.UnixMilli `json:"expired_ts"` } -func (sk *SigningKey) SignJSON(data any) ([]byte, error) { +func (sk *SigningKey) SignJSON(data any) (string, error) { marshaled, err := json.Marshal(data) if err != nil { - return nil, err + return "", err } - return sk.SignRawJSON(marshaled), nil + marshaled, err = sjson.DeleteBytes(marshaled, "signatures") + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(sk.SignRawJSON(marshaled)), nil } func (sk *SigningKey) SignRawJSON(data json.RawMessage) []byte { @@ -120,7 +166,7 @@ func (sk *SigningKey) GenerateKeyResponse(serverName string, oldVerifyKeys map[i } skr.Signatures = map[string]map[id.KeyID]string{ serverName: { - sk.ID: base64.RawURLEncoding.EncodeToString(signature), + sk.ID: signature, }, } return skr From 2d1620ded3716850637d414cf4fe78249d661017 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 3 May 2025 02:09:51 +0300 Subject: [PATCH 1124/1647] federation/keyserver: add support for returning other servers keys --- federation/keyserver.go | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/federation/keyserver.go b/federation/keyserver.go index 505be44f..b0faf8fb 100644 --- a/federation/keyserver.go +++ b/federation/keyserver.go @@ -47,6 +47,7 @@ type KeyServer struct { KeyProvider ServerKeyProvider Version ServerVersion WellKnownTarget string + OtherKeys KeyCache } // Register registers the key server endpoints to the given router. @@ -169,8 +170,26 @@ func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) { resp := &GetQueryKeysResponse{ ServerKeys: []*ServerKeyResponse{}, } - if domain, key := ks.KeyProvider.Get(r); key != nil && domain == serverName { - resp.ServerKeys = append(resp.ServerKeys, key.GenerateKeyResponse(serverName, nil)) + domain, key := ks.KeyProvider.Get(r) + if domain == serverName { + if key != nil { + resp.ServerKeys = append(resp.ServerKeys, key.GenerateKeyResponse(serverName, nil)) + } + } else if ks.OtherKeys != nil { + otherKey, err := ks.OtherKeys.LoadKeys(serverName) + if err != nil { + mautrix.MUnknown.WithMessage("Failed to load keys from cache").Write(w) + return + } + if key != nil && domain != "" { + signature, err := key.SignJSON(otherKey) + if err == nil { + otherKey.Signatures[domain] = map[id.KeyID]string{ + key.ID: signature, + } + } + } + resp.ServerKeys = append(resp.ServerKeys, otherKey) } exhttp.WriteJSONResponse(w, http.StatusOK, resp) } From 9a02b6428d44b5563bcb87ecaaf788fb55402594 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 4 May 2025 00:06:51 +0300 Subject: [PATCH 1125/1647] federation/serverauth: implement server side of request authentication --- federation/client.go | 38 +++---- federation/context.go | 30 ++++++ federation/httpclient.go | 6 -- federation/serverauth.go | 218 +++++++++++++++++++++++++++++++++++++++ federation/signingkey.go | 48 +++++++++ 5 files changed, 316 insertions(+), 24 deletions(-) create mode 100644 federation/context.go create mode 100644 federation/serverauth.go diff --git a/federation/client.go b/federation/client.go index b7927b91..93ed759c 100644 --- a/federation/client.go +++ b/federation/client.go @@ -10,7 +10,6 @@ import ( "bytes" "context" "encoding/json" - "fmt" "io" "net/http" "net/url" @@ -373,16 +372,12 @@ func (c *Client) compileRequest(ctx context.Context, params RequestParams) (*htt Message: "client not configured for authentication", } } - var contentAny any - if reqJSON != nil { - contentAny = reqJSON - } auth, err := (&signableRequest{ Method: req.Method, URI: reqURL.RequestURI(), Origin: c.ServerName, Destination: params.ServerName, - Content: contentAny, + Content: reqJSON, }).Sign(c.Key) if err != nil { return nil, mautrix.HTTPError{ @@ -396,11 +391,19 @@ func (c *Client) compileRequest(ctx context.Context, params RequestParams) (*htt } type signableRequest struct { - Method string `json:"method"` - URI string `json:"uri"` - Origin string `json:"origin"` - Destination string `json:"destination"` - Content any `json:"content,omitempty"` + Method string `json:"method"` + URI string `json:"uri"` + Origin string `json:"origin"` + Destination string `json:"destination"` + Content json.RawMessage `json:"content,omitempty"` +} + +func (r *signableRequest) Verify(key id.SigningKey, sig string) bool { + message, err := json.Marshal(r) + if err != nil { + return false + } + return VerifyJSONRaw(key, sig, message) } func (r *signableRequest) Sign(key *SigningKey) (string, error) { @@ -408,11 +411,10 @@ func (r *signableRequest) Sign(key *SigningKey) (string, error) { if err != nil { return "", err } - return fmt.Sprintf( - `X-Matrix origin="%s",destination="%s",key="%s",sig="%s"`, - r.Origin, - r.Destination, - key.ID, - sig, - ), nil + return XMatrixAuth{ + Origin: r.Origin, + Destination: r.Destination, + KeyID: key.ID, + Signature: sig, + }.String(), nil } diff --git a/federation/context.go b/federation/context.go new file mode 100644 index 00000000..8280431f --- /dev/null +++ b/federation/context.go @@ -0,0 +1,30 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "context" + "net/http" +) + +type contextKey int + +const ( + contextKeyIPPort contextKey = iota + contextKeyDestinationServer +) + +func DestinationServerNameFromRequest(r *http.Request) string { + return DestinationServerName(r.Context()) +} + +func DestinationServerName(ctx context.Context) string { + if dest, ok := ctx.Value(contextKeyDestinationServer).(string); ok { + return dest + } + return "" +} diff --git a/federation/httpclient.go b/federation/httpclient.go index cbb1674d..2f8dbb4f 100644 --- a/federation/httpclient.go +++ b/federation/httpclient.go @@ -52,12 +52,6 @@ func (srt *ServerResolvingTransport) DialContext(ctx context.Context, network, a return srt.Dialer.DialContext(ctx, network, addrs[0]) } -type contextKey int - -const ( - contextKeyIPPort contextKey = iota -) - func (srt *ServerResolvingTransport) RoundTrip(request *http.Request) (*http.Response, error) { if request.URL.Scheme != "matrix-federation" { return nil, fmt.Errorf("unsupported scheme: %s", request.URL.Scheme) diff --git a/federation/serverauth.go b/federation/serverauth.go new file mode 100644 index 00000000..fadd500e --- /dev/null +++ b/federation/serverauth.go @@ -0,0 +1,218 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "maps" + "net/http" + "slices" + "strings" + "sync" + + "github.com/rs/zerolog" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/id" +) + +type ServerAuth struct { + Keys KeyCache + Client *Client + GetDestination func(XMatrixAuth) string + MaxBodySize int64 + + keyFetchLocks map[string]*sync.Mutex + keyFetchLocksLock sync.Mutex +} + +var MUnauthorized = mautrix.RespError{ErrCode: "M_UNAUTHORIZED", StatusCode: http.StatusUnauthorized} + +var ( + ErrMissingAuthHeader = MUnauthorized.WithMessage("Missing Authorization header") + ErrInvalidAuthHeader = MUnauthorized.WithMessage("Authorization header does not start with X-Matrix") + ErrMalformedAuthHeader = MUnauthorized.WithMessage("X-Matrix value is missing required components") + ErrInvalidDestination = MUnauthorized.WithMessage("Invalid destination in X-Matrix header") + ErrFailedToQueryKeys = MUnauthorized.WithMessage("Failed to query server keys") + ErrInvalidSelfSignatures = MUnauthorized.WithMessage("Server keys don't have valid self-signatures") + ErrRequestBodyTooLarge = mautrix.MTooLarge.WithMessage("Request body too large") + ErrInvalidJSONBody = mautrix.MBadJSON.WithMessage("Request body is not valid JSON") + ErrBodyReadFailed = mautrix.MUnknown.WithMessage("Failed to read request body") + ErrInvalidRequestSignature = MUnauthorized.WithMessage("Failed to verify request signature") +) + +type XMatrixAuth struct { + Origin string + Destination string + KeyID id.KeyID + Signature string +} + +func (xma XMatrixAuth) String() string { + return fmt.Sprintf( + `X-Matrix origin="%s",destination="%s",key="%s",sig="%s"`, + xma.Origin, + xma.Destination, + xma.KeyID, + xma.Signature, + ) +} + +func ParseXMatrixAuth(auth string) (xma XMatrixAuth) { + auth = strings.TrimPrefix(auth, "X-Matrix ") + for part := range strings.SplitSeq(auth, ",") { + part = strings.TrimSpace(part) + eqIdx := strings.Index(part, "=") + if eqIdx == -1 || strings.Count(part, "=") > 1 { + continue + } + val := strings.Trim(part[eqIdx+1:], "\"") + switch strings.ToLower(part[:eqIdx]) { + case "origin": + xma.Origin = val + case "destination": + xma.Destination = val + case "key": + xma.KeyID = id.KeyID(val) + case "sig": + xma.Signature = val + } + } + return +} + +func (sa *ServerAuth) GetKeysWithCache(ctx context.Context, serverName string) (*ServerKeyResponse, error) { + sa.keyFetchLocksLock.Lock() + lock, ok := sa.keyFetchLocks[serverName] + if !ok { + lock = &sync.Mutex{} + sa.keyFetchLocks[serverName] = lock + } + sa.keyFetchLocksLock.Unlock() + + lock.Lock() + defer lock.Unlock() + res, err := sa.Keys.LoadKeys(serverName) + if err != nil { + return nil, fmt.Errorf("failed to read cache: %w", err) + } else if res != nil { + return res, nil + } else if res, err = sa.Client.ServerKeys(ctx, serverName); err != nil { + return nil, err + } else { + sa.Keys.StoreKeys(res) + return res, nil + } +} + +type fixedLimitedReader struct { + R io.Reader + N int64 + Err error +} + +func (l *fixedLimitedReader) Read(p []byte) (n int, err error) { + if l.N <= 0 { + return 0, l.Err + } + if int64(len(p)) > l.N { + p = p[0:l.N] + } + n, err = l.R.Read(p) + l.N -= int64(n) + return +} + +func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.RespError) { + defer func() { + _ = r.Body.Close() + }() + log := zerolog.Ctx(r.Context()) + if r.ContentLength > sa.MaxBodySize { + return nil, &ErrRequestBodyTooLarge + } + auth := r.Header.Get("Authorization") + if auth == "" { + return nil, &ErrMissingAuthHeader + } else if !strings.HasPrefix(auth, "X-Matrix ") { + return nil, &ErrInvalidAuthHeader + } + parsed := ParseXMatrixAuth(auth) + if parsed.Origin == "" || parsed.KeyID == "" || parsed.Signature == "" { + log.Trace().Str("auth_header", auth).Msg("Malformed X-Matrix header") + return nil, &ErrMalformedAuthHeader + } + destination := sa.GetDestination(parsed) + if destination == "" || (parsed.Destination != "" && parsed.Destination != destination) { + log.Trace(). + Str("got_destination", parsed.Destination). + Str("expected_destination", destination). + Msg("Invalid destination in X-Matrix header") + return nil, &ErrInvalidDestination + } + resp, err := sa.GetKeysWithCache(r.Context(), parsed.Origin) + if err != nil { + log.Err(err). + Str("server_name", parsed.Origin). + Msg("Failed to query keys to authenticate request") + return nil, &ErrFailedToQueryKeys + } else if !resp.VerifySelfSignature() { + return nil, &ErrInvalidSelfSignatures + } + key, ok := resp.VerifyKeys[parsed.KeyID] + if !ok { + keys := slices.Collect(maps.Keys(resp.VerifyKeys)) + log.Trace(). + Stringer("expected_key_id", parsed.KeyID). + Any("found_key_ids", keys). + Msg("Didn't find expected key ID to verify request") + return nil, ptr.Ptr(MUnauthorized.WithMessage("Key ID %q not found (got %v)", parsed.KeyID, keys)) + } + reqBody, err := io.ReadAll(&fixedLimitedReader{R: r.Body, N: sa.MaxBodySize, Err: ErrRequestBodyTooLarge}) + if errors.Is(err, ErrRequestBodyTooLarge) { + return nil, &ErrRequestBodyTooLarge + } else if err != nil { + log.Err(err). + Str("server_name", parsed.Origin). + Msg("Failed to read request body to authenticate") + return nil, &ErrBodyReadFailed + } else if !json.Valid(reqBody) { + return nil, &ErrInvalidJSONBody + } + valid := (&signableRequest{ + Method: r.Method, + URI: r.URL.RawPath, + Origin: parsed.Origin, + Destination: destination, + Content: reqBody, + }).Verify(key.Key, parsed.Signature) + if !valid { + log.Trace().Msg("Request has invalid signature") + return nil, &ErrInvalidRequestSignature + } + ctx := context.WithValue(r.Context(), contextKeyDestinationServer, destination) + ctx = log.With().Str("destination_server_name", destination).Logger().WithContext(ctx) + modifiedReq := r.WithContext(ctx) + modifiedReq.Body = io.NopCloser(bytes.NewReader(reqBody)) + return modifiedReq, nil +} + +func (sa *ServerAuth) AuthenticateMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if modifiedReq, err := sa.Authenticate(r); err != nil { + err.Write(w) + } else { + next.ServeHTTP(w, modifiedReq) + } + }) +} diff --git a/federation/signingkey.go b/federation/signingkey.go index a74b4d6a..54c62492 100644 --- a/federation/signingkey.go +++ b/federation/signingkey.go @@ -15,7 +15,9 @@ import ( "strings" "time" + "github.com/tidwall/gjson" "github.com/tidwall/sjson" + "go.mau.fi/util/exgjson" "go.mau.fi/util/jsontime" "maunium.net/go/mautrix/crypto/canonicaljson" @@ -83,6 +85,52 @@ type ServerKeyResponse struct { Extra map[string]any `json:"-"` } +func (skr *ServerKeyResponse) VerifySelfSignature() bool { + for keyID, key := range skr.VerifyKeys { + if !VerifyJSON(skr.ServerName, keyID, key.Key, skr) { + return false + } + } + return true +} + +func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) bool { + var err error + message, ok := data.(json.RawMessage) + if !ok { + message, err = json.Marshal(data) + if err != nil { + return false + } + } + sigVal := gjson.GetBytes(message, exgjson.Path("signatures", serverName, string(keyID))) + if sigVal.Type != gjson.String { + return false + } + message, err = sjson.DeleteBytes(message, "signatures") + if err != nil { + return false + } + message, err = sjson.DeleteBytes(message, "unsigned") + if err != nil { + return false + } + return VerifyJSONRaw(key, sigVal.Str, message) +} + +func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) bool { + sigBytes, err := base64.RawURLEncoding.DecodeString(sig) + if err != nil { + return false + } + keyBytes, err := base64.RawStdEncoding.DecodeString(string(key)) + if err != nil { + return false + } + message = canonicaljson.CanonicalJSONAssumeValid(message) + return ed25519.Verify(keyBytes, message, sigBytes) +} + type marshalableSKR ServerKeyResponse func (skr *ServerKeyResponse) MarshalJSON() ([]byte, error) { From d145f008635f8bbdddbe9be44e4344816fdf1248 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 4 May 2025 00:39:43 +0300 Subject: [PATCH 1126/1647] federation/serverauth: cache key querying errors --- federation/cache.go | 67 ++++++++++++++++++++++++++++++++++++++-- federation/serverauth.go | 57 +++++++++++++++++++++++++++------- federation/signingkey.go | 9 ++++++ 3 files changed, 119 insertions(+), 14 deletions(-) diff --git a/federation/cache.go b/federation/cache.go index a491dbd1..301091b3 100644 --- a/federation/cache.go +++ b/federation/cache.go @@ -7,6 +7,9 @@ package federation import ( + "errors" + "fmt" + "math" "sync" "time" ) @@ -21,13 +24,19 @@ type ResolutionCache interface { type KeyCache interface { StoreKeys(*ServerKeyResponse) + StoreFetchError(serverName string, err error) + ShouldReQuery(serverName string) bool LoadKeys(serverName string) (*ServerKeyResponse, error) } type InMemoryCache struct { + MinKeyRefetchDelay time.Duration + resolutions map[string]*ResolvedServerName resolutionsLock sync.RWMutex keys map[string]*ServerKeyResponse + lastReQueryAt map[string]time.Time + lastError map[string]*resolutionErrorCache keysLock sync.RWMutex } @@ -38,8 +47,11 @@ var ( func NewInMemoryCache() *InMemoryCache { return &InMemoryCache{ - resolutions: make(map[string]*ResolvedServerName), - keys: make(map[string]*ServerKeyResponse), + resolutions: make(map[string]*ResolvedServerName), + keys: make(map[string]*ServerKeyResponse), + lastReQueryAt: make(map[string]time.Time), + lastError: make(map[string]*resolutionErrorCache), + MinKeyRefetchDelay: 1 * time.Hour, } } @@ -63,22 +75,73 @@ func (c *InMemoryCache) StoreKeys(keys *ServerKeyResponse) { c.keysLock.Lock() defer c.keysLock.Unlock() c.keys[keys.ServerName] = keys + delete(c.lastError, keys.ServerName) } +type resolutionErrorCache struct { + Error error + Time time.Time + Count int +} + +const MaxBackoff = 7 * 24 * time.Hour + +func (rec *resolutionErrorCache) ShouldRetry() bool { + backoff := time.Duration(math.Exp(float64(rec.Count))) * time.Second + return time.Since(rec.Time) > backoff +} + +var ErrRecentKeyQueryFailed = errors.New("last retry was too recent") + func (c *InMemoryCache) LoadKeys(serverName string) (*ServerKeyResponse, error) { c.keysLock.RLock() defer c.keysLock.RUnlock() keys, ok := c.keys[serverName] if !ok || time.Until(keys.ValidUntilTS.Time) < 0 { + err, ok := c.lastError[serverName] + if ok && !err.ShouldRetry() { + return nil, fmt.Errorf( + "%w (%s ago) and failed with %w", + ErrRecentKeyQueryFailed, + time.Since(err.Time).String(), + err.Error, + ) + } return nil, nil } return keys, nil } +func (c *InMemoryCache) StoreFetchError(serverName string, err error) { + c.keysLock.Lock() + defer c.keysLock.Unlock() + errorCache, ok := c.lastError[serverName] + if ok { + errorCache.Time = time.Now() + errorCache.Error = err + errorCache.Count++ + } else { + c.lastError[serverName] = &resolutionErrorCache{Error: err, Time: time.Now(), Count: 1} + } +} + +func (c *InMemoryCache) ShouldReQuery(serverName string) bool { + c.keysLock.Lock() + defer c.keysLock.Unlock() + lastQuery, ok := c.lastReQueryAt[serverName] + if ok && time.Since(lastQuery) < c.MinKeyRefetchDelay { + return false + } + c.lastReQueryAt[serverName] = time.Now() + return true +} + type NoopCache struct{} func (*NoopCache) StoreKeys(_ *ServerKeyResponse) {} func (*NoopCache) LoadKeys(_ string) (*ServerKeyResponse, error) { return nil, nil } +func (*NoopCache) StoreFetchError(_ string, _ error) {} +func (*NoopCache) ShouldReQuery(_ string) bool { return true } func (*NoopCache) StoreResolution(_ *ResolvedServerName) {} func (*NoopCache) LoadResolution(_ string) (*ResolvedServerName, error) { return nil, nil } diff --git a/federation/serverauth.go b/federation/serverauth.go index fadd500e..ef4ed246 100644 --- a/federation/serverauth.go +++ b/federation/serverauth.go @@ -36,6 +36,16 @@ type ServerAuth struct { keyFetchLocksLock sync.Mutex } +func NewServerAuth(client *Client, keyCache KeyCache, getDestination func(auth XMatrixAuth) string) *ServerAuth { + return &ServerAuth{ + Keys: keyCache, + Client: client, + GetDestination: getDestination, + MaxBodySize: 50 * 1024 * 1024, + keyFetchLocks: make(map[string]*sync.Mutex), + } +} + var MUnauthorized = mautrix.RespError{ErrCode: "M_UNAUTHORIZED", StatusCode: http.StatusUnauthorized} var ( @@ -91,7 +101,14 @@ func ParseXMatrixAuth(auth string) (xma XMatrixAuth) { return } -func (sa *ServerAuth) GetKeysWithCache(ctx context.Context, serverName string) (*ServerKeyResponse, error) { +func (sa *ServerAuth) GetKeysWithCache(ctx context.Context, serverName string, keyID id.KeyID) (*ServerKeyResponse, error) { + res, err := sa.Keys.LoadKeys(serverName) + if err != nil { + return nil, fmt.Errorf("failed to read cache: %w", err) + } else if res.HasKey(keyID) { + return res, nil + } + sa.keyFetchLocksLock.Lock() lock, ok := sa.keyFetchLocks[serverName] if !ok { @@ -102,17 +119,27 @@ func (sa *ServerAuth) GetKeysWithCache(ctx context.Context, serverName string) ( lock.Lock() defer lock.Unlock() - res, err := sa.Keys.LoadKeys(serverName) + res, err = sa.Keys.LoadKeys(serverName) if err != nil { return nil, fmt.Errorf("failed to read cache: %w", err) } else if res != nil { - return res, nil - } else if res, err = sa.Client.ServerKeys(ctx, serverName); err != nil { - return nil, err - } else { - sa.Keys.StoreKeys(res) - return res, nil + if res.HasKey(keyID) { + return res, nil + } else if !sa.Keys.ShouldReQuery(serverName) { + zerolog.Ctx(ctx).Trace(). + Str("server_name", serverName). + Stringer("key_id", keyID). + Msg("Not sending key request for missing key ID, last query was too recent") + return res, nil + } } + res, err = sa.Client.ServerKeys(ctx, serverName) + if err != nil { + sa.Keys.StoreFetchError(serverName, err) + return nil, err + } + sa.Keys.StoreKeys(res) + return res, nil } type fixedLimitedReader struct { @@ -160,11 +187,17 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res Msg("Invalid destination in X-Matrix header") return nil, &ErrInvalidDestination } - resp, err := sa.GetKeysWithCache(r.Context(), parsed.Origin) + resp, err := sa.GetKeysWithCache(r.Context(), parsed.Origin, parsed.KeyID) if err != nil { - log.Err(err). - Str("server_name", parsed.Origin). - Msg("Failed to query keys to authenticate request") + if !errors.Is(err, ErrRecentKeyQueryFailed) { + log.Err(err). + Str("server_name", parsed.Origin). + Msg("Failed to query keys to authenticate request") + } else { + log.Trace().Err(err). + Str("server_name", parsed.Origin). + Msg("Failed to query keys to authenticate request (cached error)") + } return nil, &ErrFailedToQueryKeys } else if !resp.VerifySelfSignature() { return nil, &ErrInvalidSelfSignatures diff --git a/federation/signingkey.go b/federation/signingkey.go index 54c62492..87c12a5e 100644 --- a/federation/signingkey.go +++ b/federation/signingkey.go @@ -85,6 +85,15 @@ type ServerKeyResponse struct { Extra map[string]any `json:"-"` } +func (skr *ServerKeyResponse) HasKey(keyID id.KeyID) bool { + if skr == nil { + return false + } else if _, ok := skr.VerifyKeys[keyID]; ok { + return true + } + return false +} + func (skr *ServerKeyResponse) VerifySelfSignature() bool { for keyID, key := range skr.VerifyKeys { if !VerifyJSON(skr.ServerName, keyID, key.Key, skr) { From dec68fb4d730b70f533ab7b750aa19e86dcfc954 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 4 May 2025 00:49:34 +0300 Subject: [PATCH 1127/1647] federation/serverauth: don't unnecessarily export errors --- federation/serverauth.go | 46 ++++++++++++++++++++-------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/federation/serverauth.go b/federation/serverauth.go index ef4ed246..e2036d30 100644 --- a/federation/serverauth.go +++ b/federation/serverauth.go @@ -49,16 +49,16 @@ func NewServerAuth(client *Client, keyCache KeyCache, getDestination func(auth X var MUnauthorized = mautrix.RespError{ErrCode: "M_UNAUTHORIZED", StatusCode: http.StatusUnauthorized} var ( - ErrMissingAuthHeader = MUnauthorized.WithMessage("Missing Authorization header") - ErrInvalidAuthHeader = MUnauthorized.WithMessage("Authorization header does not start with X-Matrix") - ErrMalformedAuthHeader = MUnauthorized.WithMessage("X-Matrix value is missing required components") - ErrInvalidDestination = MUnauthorized.WithMessage("Invalid destination in X-Matrix header") - ErrFailedToQueryKeys = MUnauthorized.WithMessage("Failed to query server keys") - ErrInvalidSelfSignatures = MUnauthorized.WithMessage("Server keys don't have valid self-signatures") - ErrRequestBodyTooLarge = mautrix.MTooLarge.WithMessage("Request body too large") - ErrInvalidJSONBody = mautrix.MBadJSON.WithMessage("Request body is not valid JSON") - ErrBodyReadFailed = mautrix.MUnknown.WithMessage("Failed to read request body") - ErrInvalidRequestSignature = MUnauthorized.WithMessage("Failed to verify request signature") + errMissingAuthHeader = MUnauthorized.WithMessage("Missing Authorization header") + errInvalidAuthHeader = MUnauthorized.WithMessage("Authorization header does not start with X-Matrix") + errMalformedAuthHeader = MUnauthorized.WithMessage("X-Matrix value is missing required components") + errInvalidDestination = MUnauthorized.WithMessage("Invalid destination in X-Matrix header") + errFailedToQueryKeys = MUnauthorized.WithMessage("Failed to query server keys") + errInvalidSelfSignatures = MUnauthorized.WithMessage("Server keys don't have valid self-signatures") + errRequestBodyTooLarge = mautrix.MTooLarge.WithMessage("Request body too large") + errInvalidJSONBody = mautrix.MBadJSON.WithMessage("Request body is not valid JSON") + errBodyReadFailed = mautrix.MUnknown.WithMessage("Failed to read request body") + errInvalidRequestSignature = MUnauthorized.WithMessage("Failed to verify request signature") ) type XMatrixAuth struct { @@ -166,18 +166,18 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res }() log := zerolog.Ctx(r.Context()) if r.ContentLength > sa.MaxBodySize { - return nil, &ErrRequestBodyTooLarge + return nil, &errRequestBodyTooLarge } auth := r.Header.Get("Authorization") if auth == "" { - return nil, &ErrMissingAuthHeader + return nil, &errMissingAuthHeader } else if !strings.HasPrefix(auth, "X-Matrix ") { - return nil, &ErrInvalidAuthHeader + return nil, &errInvalidAuthHeader } parsed := ParseXMatrixAuth(auth) if parsed.Origin == "" || parsed.KeyID == "" || parsed.Signature == "" { log.Trace().Str("auth_header", auth).Msg("Malformed X-Matrix header") - return nil, &ErrMalformedAuthHeader + return nil, &errMalformedAuthHeader } destination := sa.GetDestination(parsed) if destination == "" || (parsed.Destination != "" && parsed.Destination != destination) { @@ -185,7 +185,7 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res Str("got_destination", parsed.Destination). Str("expected_destination", destination). Msg("Invalid destination in X-Matrix header") - return nil, &ErrInvalidDestination + return nil, &errInvalidDestination } resp, err := sa.GetKeysWithCache(r.Context(), parsed.Origin, parsed.KeyID) if err != nil { @@ -198,9 +198,9 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res Str("server_name", parsed.Origin). Msg("Failed to query keys to authenticate request (cached error)") } - return nil, &ErrFailedToQueryKeys + return nil, &errFailedToQueryKeys } else if !resp.VerifySelfSignature() { - return nil, &ErrInvalidSelfSignatures + return nil, &errInvalidSelfSignatures } key, ok := resp.VerifyKeys[parsed.KeyID] if !ok { @@ -211,16 +211,16 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res Msg("Didn't find expected key ID to verify request") return nil, ptr.Ptr(MUnauthorized.WithMessage("Key ID %q not found (got %v)", parsed.KeyID, keys)) } - reqBody, err := io.ReadAll(&fixedLimitedReader{R: r.Body, N: sa.MaxBodySize, Err: ErrRequestBodyTooLarge}) - if errors.Is(err, ErrRequestBodyTooLarge) { - return nil, &ErrRequestBodyTooLarge + reqBody, err := io.ReadAll(&fixedLimitedReader{R: r.Body, N: sa.MaxBodySize, Err: errRequestBodyTooLarge}) + if errors.Is(err, errRequestBodyTooLarge) { + return nil, &errRequestBodyTooLarge } else if err != nil { log.Err(err). Str("server_name", parsed.Origin). Msg("Failed to read request body to authenticate") - return nil, &ErrBodyReadFailed + return nil, &errBodyReadFailed } else if !json.Valid(reqBody) { - return nil, &ErrInvalidJSONBody + return nil, &errInvalidJSONBody } valid := (&signableRequest{ Method: r.Method, @@ -231,7 +231,7 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res }).Verify(key.Key, parsed.Signature) if !valid { log.Trace().Msg("Request has invalid signature") - return nil, &ErrInvalidRequestSignature + return nil, &errInvalidRequestSignature } ctx := context.WithValue(r.Context(), contextKeyDestinationServer, destination) ctx = log.With().Str("destination_server_name", destination).Logger().WithContext(ctx) From 0a33bde865aec94f162c7a5351f3b8fd7c6abc37 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 4 May 2025 01:00:08 +0300 Subject: [PATCH 1128/1647] federation/cache: expose noop cache as variable instead of type --- federation/cache.go | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/federation/cache.go b/federation/cache.go index 301091b3..24154974 100644 --- a/federation/cache.go +++ b/federation/cache.go @@ -136,16 +136,18 @@ func (c *InMemoryCache) ShouldReQuery(serverName string) bool { return true } -type NoopCache struct{} +type noopCache struct{} -func (*NoopCache) StoreKeys(_ *ServerKeyResponse) {} -func (*NoopCache) LoadKeys(_ string) (*ServerKeyResponse, error) { return nil, nil } -func (*NoopCache) StoreFetchError(_ string, _ error) {} -func (*NoopCache) ShouldReQuery(_ string) bool { return true } -func (*NoopCache) StoreResolution(_ *ResolvedServerName) {} -func (*NoopCache) LoadResolution(_ string) (*ResolvedServerName, error) { return nil, nil } +func (*noopCache) StoreKeys(_ *ServerKeyResponse) {} +func (*noopCache) LoadKeys(_ string) (*ServerKeyResponse, error) { return nil, nil } +func (*noopCache) StoreFetchError(_ string, _ error) {} +func (*noopCache) ShouldReQuery(_ string) bool { return true } +func (*noopCache) StoreResolution(_ *ResolvedServerName) {} +func (*noopCache) LoadResolution(_ string) (*ResolvedServerName, error) { return nil, nil } var ( - _ ResolutionCache = (*NoopCache)(nil) - _ KeyCache = (*NoopCache)(nil) + _ ResolutionCache = (*noopCache)(nil) + _ KeyCache = (*noopCache)(nil) ) + +var NoopCache *noopCache From 63f35754c6e4e806ffb2e239b5ac9c02854331c5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 4 May 2025 01:04:12 +0300 Subject: [PATCH 1129/1647] federation/serverauth: store verified origin in request context --- federation/context.go | 12 ++++++++++++ federation/serverauth.go | 6 +++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/federation/context.go b/federation/context.go index 8280431f..eedb2dc1 100644 --- a/federation/context.go +++ b/federation/context.go @@ -16,6 +16,7 @@ type contextKey int const ( contextKeyIPPort contextKey = iota contextKeyDestinationServer + contextKeyOriginServer ) func DestinationServerNameFromRequest(r *http.Request) string { @@ -28,3 +29,14 @@ func DestinationServerName(ctx context.Context) string { } return "" } + +func OriginServerNameFromRequest(r *http.Request) string { + return OriginServerName(r.Context()) +} + +func OriginServerName(ctx context.Context) string { + if origin, ok := ctx.Value(contextKeyOriginServer).(string); ok { + return origin + } + return "" +} diff --git a/federation/serverauth.go b/federation/serverauth.go index e2036d30..02780ff8 100644 --- a/federation/serverauth.go +++ b/federation/serverauth.go @@ -234,7 +234,11 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res return nil, &errInvalidRequestSignature } ctx := context.WithValue(r.Context(), contextKeyDestinationServer, destination) - ctx = log.With().Str("destination_server_name", destination).Logger().WithContext(ctx) + ctx = context.WithValue(ctx, contextKeyOriginServer, parsed.Origin) + ctx = log.With(). + Str("origin_server_name", parsed.Origin). + Str("destination_server_name", destination). + Logger().WithContext(ctx) modifiedReq := r.WithContext(ctx) modifiedReq.Body = io.NopCloser(bytes.NewReader(reqBody)) return modifiedReq, nil From 5c2bc3b1cf2415d57854d2774127173df1f5d4ce Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 4 May 2025 01:04:40 +0300 Subject: [PATCH 1130/1647] mediaproxy: add option to enforce federation auth --- mediaproxy/mediaproxy.go | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index ff8b2157..d76439a1 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -99,9 +99,8 @@ type GetMediaResponseFile struct { type GetMediaFunc = func(ctx context.Context, mediaID string, params map[string]string) (response GetMediaResponse, err error) type MediaProxy struct { - KeyServer *federation.KeyServer - - ForceProxyLegacyFederation bool + KeyServer *federation.KeyServer + ServerAuth *federation.ServerAuth GetMedia GetMediaFunc PrepareProxyRequest func(*http.Request) @@ -139,6 +138,7 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx type BasicConfig struct { ServerName string `yaml:"server_name" json:"server_name"` ServerKey string `yaml:"server_key" json:"server_key"` + FederationAuth bool `yaml:"federation_auth" json:"federation_auth"` WellKnownResponse string `yaml:"well_known_response" json:"well_known_response"` } @@ -150,6 +150,9 @@ func NewFromConfig(cfg BasicConfig, getMedia GetMediaFunc) (*MediaProxy, error) if cfg.WellKnownResponse != "" { mp.KeyServer.WellKnownTarget = cfg.WellKnownResponse } + if cfg.FederationAuth { + mp.EnableServerAuth(nil, nil) + } return mp, nil } @@ -172,6 +175,19 @@ func (mp *MediaProxy) GetServerKey() *federation.SigningKey { return mp.serverKey } +func (mp *MediaProxy) EnableServerAuth(client *federation.Client, keyCache federation.KeyCache) { + if keyCache == nil { + keyCache = federation.NewInMemoryCache() + } + if client == nil { + resCache, _ := keyCache.(federation.ResolutionCache) + client = federation.NewClient(mp.serverName, mp.serverKey, resCache) + } + mp.ServerAuth = federation.NewServerAuth(client, keyCache, func(auth federation.XMatrixAuth) string { + return mp.GetServerName() + }) +} + func (mp *MediaProxy) RegisterRoutes(router *mux.Router) { if mp.FederationRouter == nil { mp.FederationRouter = router.PathPrefix("/_matrix/federation").Subrouter() @@ -271,9 +287,16 @@ func startMultipart(ctx context.Context, w http.ResponseWriter) *multipart.Write } func (mp *MediaProxy) DownloadMediaFederation(w http.ResponseWriter, r *http.Request) { + if mp.ServerAuth != nil { + var err *mautrix.RespError + r, err = mp.ServerAuth.Authenticate(r) + if err != nil { + err.Write(w) + return + } + } ctx := r.Context() log := zerolog.Ctx(ctx) - // TODO check destination header in X-Matrix auth resp := mp.getMedia(w, r) if resp == nil { From b45dcd42fc38aa0e99a266f1f02496bdb2e5bad3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 4 May 2025 01:09:12 +0300 Subject: [PATCH 1131/1647] federation/serverauth: fix get requests --- federation/serverauth.go | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/federation/serverauth.go b/federation/serverauth.go index 02780ff8..f3875498 100644 --- a/federation/serverauth.go +++ b/federation/serverauth.go @@ -211,16 +211,19 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res Msg("Didn't find expected key ID to verify request") return nil, ptr.Ptr(MUnauthorized.WithMessage("Key ID %q not found (got %v)", parsed.KeyID, keys)) } - reqBody, err := io.ReadAll(&fixedLimitedReader{R: r.Body, N: sa.MaxBodySize, Err: errRequestBodyTooLarge}) - if errors.Is(err, errRequestBodyTooLarge) { - return nil, &errRequestBodyTooLarge - } else if err != nil { - log.Err(err). - Str("server_name", parsed.Origin). - Msg("Failed to read request body to authenticate") - return nil, &errBodyReadFailed - } else if !json.Valid(reqBody) { - return nil, &errInvalidJSONBody + var reqBody []byte + if r.ContentLength != 0 && r.Method != http.MethodGet && r.Method != http.MethodHead { + reqBody, err = io.ReadAll(&fixedLimitedReader{R: r.Body, N: sa.MaxBodySize, Err: errRequestBodyTooLarge}) + if errors.Is(err, errRequestBodyTooLarge) { + return nil, &errRequestBodyTooLarge + } else if err != nil { + log.Err(err). + Str("server_name", parsed.Origin). + Msg("Failed to read request body to authenticate") + return nil, &errBodyReadFailed + } else if !json.Valid(reqBody) { + return nil, &errInvalidJSONBody + } } valid := (&signableRequest{ Method: r.Method, @@ -240,7 +243,9 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res Str("destination_server_name", destination). Logger().WithContext(ctx) modifiedReq := r.WithContext(ctx) - modifiedReq.Body = io.NopCloser(bytes.NewReader(reqBody)) + if reqBody != nil { + modifiedReq.Body = io.NopCloser(bytes.NewReader(reqBody)) + } return modifiedReq, nil } From 5cd8ba88877562dc6452b4baa57e00b9b27fa1cc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 4 May 2025 01:14:44 +0300 Subject: [PATCH 1132/1647] federation/serverauth: fix go 1.23 compatibility --- federation/serverauth.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/federation/serverauth.go b/federation/serverauth.go index f3875498..22ce8403 100644 --- a/federation/serverauth.go +++ b/federation/serverauth.go @@ -80,7 +80,8 @@ func (xma XMatrixAuth) String() string { func ParseXMatrixAuth(auth string) (xma XMatrixAuth) { auth = strings.TrimPrefix(auth, "X-Matrix ") - for part := range strings.SplitSeq(auth, ",") { + // TODO upgrade to strings.SplitSeq after Go 1.24 is the minimum + for _, part := range strings.Split(auth, ",") { part = strings.TrimSpace(part) eqIdx := strings.Index(part, "=") if eqIdx == -1 || strings.Count(part, "=") > 1 { From 6eb4c7b17f97004887d76145553a1ab456c0ad37 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 4 May 2025 14:09:06 +0300 Subject: [PATCH 1133/1647] crypto/keybackup: allow importing room keys without saving --- crypto/keybackup.go | 56 ++++++++++++++++++++++++++++++-------------- crypto/keyimport.go | 2 +- crypto/keysharing.go | 2 +- crypto/machine.go | 4 ++-- 4 files changed, 43 insertions(+), 21 deletions(-) diff --git a/crypto/keybackup.go b/crypto/keybackup.go index 5724002b..d8b3d715 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -13,6 +13,7 @@ import ( "maunium.net/go/mautrix/crypto/backup" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -161,11 +162,15 @@ var ( ErrFailedToStoreNewInboundGroupSessionFromBackup = errors.New("failed to store new inbound group session from key backup") ) -func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) (*InboundGroupSession, error) { - log := zerolog.Ctx(ctx).With(). - Str("room_id", roomID.String()). - Str("session_id", sessionID.String()). - Logger() +func (mach *OlmMachine) ImportRoomKeyFromBackupWithoutSaving( + ctx context.Context, + version id.KeyBackupVersion, + roomID id.RoomID, + config *event.EncryptionEventContent, + sessionID id.SessionID, + keyBackupData *backup.MegolmSessionData, +) (*InboundGroupSession, error) { + log := zerolog.Ctx(ctx) if keyBackupData.Algorithm != id.AlgorithmMegolmV1 { return nil, fmt.Errorf("%w %s", ErrUnknownAlgorithmInKeyBackup, keyBackupData.Algorithm) } @@ -175,6 +180,8 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. return nil, fmt.Errorf("failed to import inbound group session: %w", err) } else if igsInternal.ID() != sessionID { log.Warn(). + Stringer("room_id", roomID). + Stringer("session_id", sessionID). Stringer("actual_session_id", igsInternal.ID()). Msg("Mismatched session ID while creating inbound group session from key backup") return nil, ErrMismatchingSessionIDInKeyBackup @@ -182,19 +189,12 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. var maxAge time.Duration var maxMessages int - if config, err := mach.StateStore.GetEncryptionEvent(ctx, roomID); err != nil { - log.Error().Err(err).Msg("Failed to get encryption event for room") - } else if config != nil { + if config != nil { maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond maxMessages = config.RotationPeriodMessages } - firstKnownIndex := igsInternal.FirstKnownIndex() - if firstKnownIndex > 0 { - log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session") - } - - igs := &InboundGroupSession{ + return &InboundGroupSession{ Internal: igsInternal, SigningKey: keyBackupData.SenderClaimedKeys.Ed25519, SenderKey: keyBackupData.SenderKey, @@ -206,11 +206,33 @@ func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id. MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, KeyBackupVersion: version, + }, nil +} + +func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) (*InboundGroupSession, error) { + config, err := mach.StateStore.GetEncryptionEvent(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("room_id", roomID). + Stringer("session_id", sessionID). + Msg("Failed to get encryption event for room") } - err = mach.CryptoStore.PutGroupSession(ctx, igs) + imported, err := mach.ImportRoomKeyFromBackupWithoutSaving(ctx, version, roomID, config, sessionID, keyBackupData) + if err != nil { + return nil, err + } + firstKnownIndex := imported.Internal.FirstKnownIndex() + if firstKnownIndex > 0 { + zerolog.Ctx(ctx).Warn(). + Stringer("room_id", roomID). + Stringer("session_id", sessionID). + Uint32("first_known_index", firstKnownIndex). + Msg("Importing partial session") + } + err = mach.CryptoStore.PutGroupSession(ctx, imported) if err != nil { return nil, fmt.Errorf("%w: %w", ErrFailedToStoreNewInboundGroupSessionFromBackup, err) } - mach.markSessionReceived(ctx, roomID, sessionID, firstKnownIndex) - return igs, nil + mach.MarkSessionReceived(ctx, roomID, sessionID, firstKnownIndex) + return imported, nil } diff --git a/crypto/keyimport.go b/crypto/keyimport.go index 1dc7f6cc..36ad6b9c 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -127,7 +127,7 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor if err != nil { return false, fmt.Errorf("failed to store imported session: %w", err) } - mach.markSessionReceived(ctx, session.RoomID, igs.ID(), firstKnownIndex) + mach.MarkSessionReceived(ctx, session.RoomID, igs.ID(), firstKnownIndex) return true, nil } diff --git a/crypto/keysharing.go b/crypto/keysharing.go index ea0ae65d..e78bb65c 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -200,7 +200,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt log.Error().Err(err).Msg("Failed to store new inbound group session") return false } - mach.markSessionReceived(ctx, content.RoomID, content.SessionID, firstKnownIndex) + mach.MarkSessionReceived(ctx, content.RoomID, content.SessionID, firstKnownIndex) log.Debug().Msg("Received forwarded inbound group session") return true } diff --git a/crypto/machine.go b/crypto/machine.go index e2af298b..cac91bf8 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -584,7 +584,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session") return fmt.Errorf("failed to store new inbound group session: %w", err) } - mach.markSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex()) + mach.MarkSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex()) log.Debug(). Str("session_id", sessionID.String()). Str("sender_key", senderKey.String()). @@ -595,7 +595,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen return nil } -func (mach *OlmMachine) markSessionReceived(ctx context.Context, roomID id.RoomID, id id.SessionID, firstKnownIndex uint32) { +func (mach *OlmMachine) MarkSessionReceived(ctx context.Context, roomID id.RoomID, id id.SessionID, firstKnownIndex uint32) { if mach.SessionReceived != nil { mach.SessionReceived(ctx, roomID, id, firstKnownIndex) } From ba43e615f8e971f5290d7807c68cde3a2cc57b1f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 6 May 2025 18:49:54 +0300 Subject: [PATCH 1134/1647] bridgev2/login: add wait_for_url_pattern field to cookie logins --- bridgev2/login.go | 6 ++++++ bridgev2/matrix/provisioning.yaml | 14 ++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/bridgev2/login.go b/bridgev2/login.go index b28ccfdb..1fa3afbc 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -159,6 +159,12 @@ type LoginCookiesParams struct { // The snippet will evaluate to a promise that resolves when the relevant fields are found. // Fields that are not present in the promise result must be extracted another way. ExtractJS string `json:"extract_js,omitempty"` + // A regex pattern that the URL should match before the client closes the webview. + // + // The client may submit the login if the user closes the webview after all cookies are collected + // even if this URL is not reached, but it should only automatically close the webview after + // both cookies and the URL match. + WaitForURLPattern string `json:"wait_for_url_pattern,omitempty"` } type LoginInputFieldType string diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml index bf6c6f3d..b9879ea5 100644 --- a/bridgev2/matrix/provisioning.yaml +++ b/bridgev2/matrix/provisioning.yaml @@ -671,6 +671,20 @@ components: user_agent: type: string description: An optional user agent that the webview should use. + wait_for_url_pattern: + type: string + description: | + A regex pattern that the URL should match before the client closes the webview. + + The client may submit the login if the user closes the webview after all cookies are collected + even if this URL is not reached, but it should only automatically close the webview after + both cookies and the URL match. + extract_js: + type: string + description: | + A JavaScript snippet that can extract some or all of the fields. + The snippet will evaluate to a promise that resolves when the relevant fields are found. + Fields that are not present in the promise result must be extracted another way. fields: type: array description: The list of cookies or other stored data that must be extracted. From 37d486dfcd0e19b36e956835c2fcfba90ea6d606 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 6 May 2025 20:53:00 +0300 Subject: [PATCH 1135/1647] bridgev2/portal: ignore fake mxids when bridging read receipts --- bridgev2/database/message.go | 7 ++++++- bridgev2/portal.go | 9 ++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index fd6b65d8..6447ac1d 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -64,7 +64,8 @@ const ( getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp DESC, part_id DESC LIMIT 1` getLastNInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp DESC, part_id DESC LIMIT $4` - getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1` + getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1` + getLastNonFakeMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 AND mxid NOT LIKE '~fake:%' ORDER BY timestamp DESC, part_id DESC LIMIT 1` countMessagesInPortalQuery = ` SELECT COUNT(*) FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 @@ -128,6 +129,10 @@ func (mq *MessageQuery) GetLastPartAtOrBeforeTime(ctx context.Context, portal ne return mq.QueryOne(ctx, getLastMessagePartAtOrBeforeTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, maxTS.UnixNano()) } +func (mq *MessageQuery) GetLastNonFakePartAtOrBeforeTime(ctx context.Context, portal networkid.PortalKey, maxTS time.Time) (*Message, error) { + return mq.QueryOne(ctx, getLastNonFakeMessagePartAtOrBeforeTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, maxTS.UnixNano()) +} + func (mq *MessageQuery) GetMessagesBetweenTimeQuery(ctx context.Context, portal networkid.PortalKey, start, end time.Time) ([]*Message, error) { return mq.QueryMany(ctx, getMessagesBetweenTimeQuery, mq.BridgeID, portal.ID, portal.Receiver, start.UnixNano(), end.UnixNano()) } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index d88f5a7c..3e0353d0 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2638,7 +2638,6 @@ func (portal *Portal) redactMessageParts(ctx context.Context, parts []*database. } func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReadReceipt) { - // TODO exclude fake mxids log := zerolog.Ctx(ctx) var err error var lastTarget *database.Message @@ -2651,6 +2650,10 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL } else if lastTarget == nil { log.Debug().Str("last_target_id", string(lastTargetID)). Msg("Last target message not found") + } else if lastTarget.HasFakeMXID() { + log.Debug().Str("last_target_id", string(lastTargetID)). + Msg("Last target message is fake") + lastTarget = nil } } if lastTarget == nil { @@ -2660,14 +2663,14 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL log.Err(err).Str("target_id", string(targetID)). Msg("Failed to get target message for read receipt") return - } else if target != nil && (lastTarget == nil || target.Timestamp.After(lastTarget.Timestamp)) { + } else if target != nil && !target.HasFakeMXID() && (lastTarget == nil || target.Timestamp.After(lastTarget.Timestamp)) { lastTarget = target } } } readUpTo := evt.GetReadUpTo() if lastTarget == nil && !readUpTo.IsZero() { - lastTarget, err = portal.Bridge.DB.Message.GetLastPartAtOrBeforeTime(ctx, portal.PortalKey, readUpTo) + lastTarget, err = portal.Bridge.DB.Message.GetLastNonFakePartAtOrBeforeTime(ctx, portal.PortalKey, readUpTo) if err != nil { log.Err(err).Time("read_up_to", readUpTo).Msg("Failed to get target message for read receipt") } From a7faac33c8158cadb369713c198104be08df78af Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 6 May 2025 20:55:26 +0300 Subject: [PATCH 1136/1647] bridgev2/portal: add fallback if last receipt target is fake --- bridgev2/portal.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 3e0353d0..c49f041c 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2641,6 +2641,7 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL log := zerolog.Ctx(ctx) var err error var lastTarget *database.Message + readUpTo := evt.GetReadUpTo() if lastTargetID := evt.GetLastReceiptTarget(); lastTargetID != "" { lastTarget, err = portal.Bridge.DB.Message.GetLastPartByID(ctx, portal.Receiver, lastTargetID) if err != nil { @@ -2653,6 +2654,9 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL } else if lastTarget.HasFakeMXID() { log.Debug().Str("last_target_id", string(lastTargetID)). Msg("Last target message is fake") + if readUpTo.IsZero() { + readUpTo = lastTarget.Timestamp + } lastTarget = nil } } @@ -2668,7 +2672,6 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL } } } - readUpTo := evt.GetReadUpTo() if lastTarget == nil && !readUpTo.IsZero() { lastTarget, err = portal.Bridge.DB.Message.GetLastNonFakePartAtOrBeforeTime(ctx, portal.PortalKey, readUpTo) if err != nil { From bef23edaea2851d0e8a75ee054197456b89d6ba9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 6 May 2025 22:50:46 +0300 Subject: [PATCH 1137/1647] crypto/keysharing: ensure forwarding chains is always set --- crypto/keysharing.go | 5 +---- crypto/sessions.go | 9 ++++++++- crypto/sql_store.go | 2 ++ 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/crypto/keysharing.go b/crypto/keysharing.go index e78bb65c..e6d8d603 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -347,9 +347,6 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body) return } - if igs.ForwardingChains == nil { - igs.ForwardingChains = []string{} - } forwardedRoomKey := event.Content{ Parsed: &event.ForwardedRoomKeyEventContent{ @@ -360,7 +357,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User SessionKey: string(exportedKey), }, SenderKey: content.Body.SenderKey, - ForwardingKeyChain: igs.ForwardingChains, + ForwardingKeyChain: igs.ForwardingChainsOrEmpty(), SenderClaimedKey: igs.SigningKey, }, } diff --git a/crypto/sessions.go b/crypto/sessions.go index 457a0a43..8724d05a 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -125,7 +125,7 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI SigningKey: signingKey, SenderKey: senderKey, RoomID: roomID, - ForwardingChains: nil, + ForwardingChains: []string{}, ReceivedAt: time.Now().UTC(), MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, @@ -133,6 +133,13 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI }, nil } +func (igs *InboundGroupSession) ForwardingChainsOrEmpty() []string { + if igs.ForwardingChains == nil { + return []string{} + } + return igs.ForwardingChains +} + func (igs *InboundGroupSession) ID() id.SessionID { if igs.id == "" { igs.id = igs.Internal.ID() diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 514c1e8c..9d8e7ed7 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -509,6 +509,8 @@ func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSa } if forwardingChains != "" { chains = strings.Split(forwardingChains, ",") + } else { + chains = []string{} } var rs RatchetSafety if len(ratchetSafetyBytes) > 0 { From 0ffe3524f68bc44beadef265facc6c9adc8b17f5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 6 May 2025 22:54:32 +0300 Subject: [PATCH 1138/1647] crypto/sql_store: ensure forwarding chains is always set instead of having fallback in getter --- crypto/keysharing.go | 2 +- crypto/sessions.go | 7 ------- crypto/sql_store.go | 3 +++ 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/crypto/keysharing.go b/crypto/keysharing.go index e6d8d603..f1d427af 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -357,7 +357,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User SessionKey: string(exportedKey), }, SenderKey: content.Body.SenderKey, - ForwardingKeyChain: igs.ForwardingChainsOrEmpty(), + ForwardingKeyChain: igs.ForwardingChains, SenderClaimedKey: igs.SigningKey, }, } diff --git a/crypto/sessions.go b/crypto/sessions.go index 8724d05a..aecb0416 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -133,13 +133,6 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI }, nil } -func (igs *InboundGroupSession) ForwardingChainsOrEmpty() []string { - if igs.ForwardingChains == nil { - return []string{} - } - return igs.ForwardingChains -} - func (igs *InboundGroupSession) ID() id.SessionID { if igs.id == "" { igs.id = igs.Internal.ID() diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 9d8e7ed7..b0625763 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -326,6 +326,9 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, session *Inbou if err != nil { return err } + if session.ForwardingChains == nil { + session.ForwardingChains = []string{} + } forwardingChains := strings.Join(session.ForwardingChains, ",") ratchetSafety, err := json.Marshal(&session.RatchetSafety) if err != nil { From 72f6229f40c6f8b8f49c5b70d3a002d39330ce23 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 6 May 2025 23:18:23 +0300 Subject: [PATCH 1139/1647] crypto: fix key export test --- crypto/keyexport_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/keyexport_test.go b/crypto/keyexport_test.go index 15d944d5..47616a20 100644 --- a/crypto/keyexport_test.go +++ b/crypto/keyexport_test.go @@ -31,5 +31,5 @@ func TestExportKeys(t *testing.T) { )) data, err := crypto.ExportKeys("meow", []*crypto.InboundGroupSession{sess}) assert.NoError(t, err) - assert.Len(t, data, 840) + assert.Len(t, data, 836) } From c93d30a83c87e1cb9169d9786da70b9301df2dc1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 7 May 2025 14:47:04 +0300 Subject: [PATCH 1140/1647] bridgev2: add option to deduplicate Matrix messages by event or transaction ID --- bridgev2/bridgeconfig/config.go | 37 ++++++++++--------- bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/database/message.go | 30 +++++++++++---- bridgev2/database/upgrades/00-latest.sql | 6 ++- .../upgrades/22-message-send-txn-id.sql | 6 +++ bridgev2/matrix/mxmain/example-config.yaml | 2 + bridgev2/portal.go | 15 ++++++++ 7 files changed, 69 insertions(+), 28 deletions(-) create mode 100644 bridgev2/database/upgrades/22-message-send-txn-id.sql diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 937d9441..bd7746d1 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -58,24 +58,25 @@ type CleanupOnLogouts struct { } type BridgeConfig struct { - CommandPrefix string `yaml:"command_prefix"` - PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` - PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` - AsyncEvents bool `yaml:"async_events"` - SplitPortals bool `yaml:"split_portals"` - ResendBridgeInfo bool `yaml:"resend_bridge_info"` - NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"` - BridgeStatusNotices string `yaml:"bridge_status_notices"` - BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` - BridgeNotices bool `yaml:"bridge_notices"` - TagOnlyOnCreate bool `yaml:"tag_only_on_create"` - OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"` - MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` - OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` - CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` - Relay RelayConfig `yaml:"relay"` - Permissions PermissionConfig `yaml:"permissions"` - Backfill BackfillConfig `yaml:"backfill"` + CommandPrefix string `yaml:"command_prefix"` + PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` + PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` + AsyncEvents bool `yaml:"async_events"` + SplitPortals bool `yaml:"split_portals"` + ResendBridgeInfo bool `yaml:"resend_bridge_info"` + NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"` + BridgeStatusNotices string `yaml:"bridge_status_notices"` + BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` + BridgeNotices bool `yaml:"bridge_notices"` + TagOnlyOnCreate bool `yaml:"tag_only_on_create"` + OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"` + MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` + DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"` + OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` + CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` + Relay RelayConfig `yaml:"relay"` + Permissions PermissionConfig `yaml:"permissions"` + Backfill BackfillConfig `yaml:"backfill"` } type MatrixConfig struct { diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 95370681..18b98263 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -37,6 +37,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "tag_only_on_create") helper.Copy(up.List, "bridge", "only_bridge_tags") helper.Copy(up.Bool, "bridge", "mute_only_on_create") + helper.Copy(up.Bool, "bridge", "deduplicate_matrix_messages") helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "relayed") diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 6447ac1d..9b3b1493 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -43,19 +43,23 @@ type Message struct { ThreadRoot networkid.MessageID ReplyTo networkid.MessageOptionalPartID + SendTxnID networkid.RawTransactionID + Metadata any } const ( getMessageBaseQuery = ` SELECT rowid, bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, sender_mxid, - timestamp, edit_count, double_puppeted, thread_root_id, reply_to_id, reply_to_part_id, metadata + timestamp, edit_count, double_puppeted, thread_root_id, reply_to_id, reply_to_part_id, + send_txn_id, metadata FROM message ` getAllMessagePartsByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3` getMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 AND part_id=$4` getMessagePartByRowIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND rowid=$2` getMessageByMXIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` + getMessageByTxnIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND (mxid=$3 OR send_txn_id=$4)` getLastMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id DESC LIMIT 1` getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1` getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5` @@ -74,16 +78,17 @@ const ( insertMessageQuery = ` INSERT INTO message ( bridge_id, id, part_id, mxid, room_id, room_receiver, sender_id, sender_mxid, - timestamp, edit_count, double_puppeted, thread_root_id, reply_to_id, reply_to_part_id, metadata + timestamp, edit_count, double_puppeted, thread_root_id, reply_to_id, reply_to_part_id, + send_txn_id, metadata ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) RETURNING rowid ` updateMessageQuery = ` UPDATE message SET id=$2, part_id=$3, mxid=$4, room_id=$5, room_receiver=$6, sender_id=$7, sender_mxid=$8, timestamp=$9, edit_count=$10, double_puppeted=$11, thread_root_id=$12, reply_to_id=$13, - reply_to_part_id=$14, metadata=$15 - WHERE bridge_id=$1 AND rowid=$16 + reply_to_part_id=$14, send_txn_id=$15, metadata=$16 + WHERE bridge_id=$1 AND rowid=$17 ` deleteAllMessagePartsByIDQuery = ` DELETE FROM message WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 @@ -105,6 +110,10 @@ func (mq *MessageQuery) GetPartByMXID(ctx context.Context, mxid id.EventID) (*Me return mq.QueryOne(ctx, getMessageByMXIDQuery, mq.BridgeID, mxid) } +func (mq *MessageQuery) GetPartByTxnID(ctx context.Context, receiver networkid.UserLoginID, mxid id.EventID, txnID networkid.RawTransactionID) (*Message, error) { + return mq.QueryOne(ctx, getMessageByTxnIDQuery, mq.BridgeID, receiver, mxid, txnID) +} + func (mq *MessageQuery) GetLastPartByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) (*Message, error) { return mq.QueryOne(ctx, getLastMessagePartByIDQuery, mq.BridgeID, receiver, id) } @@ -178,11 +187,12 @@ func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { var timestamp int64 - var threadRootID, replyToID, replyToPartID sql.NullString + var threadRootID, replyToID, replyToPartID, sendTxnID sql.NullString var doublePuppeted sql.NullBool err := row.Scan( &m.RowID, &m.BridgeID, &m.ID, &m.PartID, &m.MXID, &m.Room.ID, &m.Room.Receiver, &m.SenderID, &m.SenderMXID, - ×tamp, &m.EditCount, &doublePuppeted, &threadRootID, &replyToID, &replyToPartID, dbutil.JSON{Data: m.Metadata}, + ×tamp, &m.EditCount, &doublePuppeted, &threadRootID, &replyToID, &replyToPartID, &sendTxnID, + dbutil.JSON{Data: m.Metadata}, ) if err != nil { return nil, err @@ -196,6 +206,9 @@ func (m *Message) Scan(row dbutil.Scannable) (*Message, error) { m.ReplyTo.PartID = (*networkid.PartID)(&replyToPartID.String) } } + if sendTxnID.Valid { + m.SendTxnID = networkid.RawTransactionID(sendTxnID.String) + } return m, nil } @@ -210,7 +223,8 @@ func (m *Message) sqlVariables() []any { return []any{ m.BridgeID, m.ID, m.PartID, m.MXID, m.Room.ID, m.Room.Receiver, m.SenderID, m.SenderMXID, m.Timestamp.UnixNano(), m.EditCount, m.IsDoublePuppeted, dbutil.StrPtr(m.ThreadRoot), - dbutil.StrPtr(m.ReplyTo.MessageID), m.ReplyTo.PartID, dbutil.JSON{Data: m.Metadata}, + dbutil.StrPtr(m.ReplyTo.MessageID), m.ReplyTo.PartID, dbutil.StrPtr(m.SendTxnID), + dbutil.JSON{Data: m.Metadata}, } } diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 7ad01a87..4eea05bb 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v21 (compatible with v9+): Latest revision +-- v0 -> v22 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -108,6 +108,7 @@ CREATE TABLE message ( thread_root_id TEXT, reply_to_id TEXT, reply_to_part_id TEXT, + send_txn_id TEXT, metadata jsonb NOT NULL, CONSTRAINT message_room_fkey FOREIGN KEY (bridge_id, room_id, room_receiver) @@ -117,7 +118,8 @@ CREATE TABLE message ( REFERENCES ghost (bridge_id, id) ON DELETE CASCADE ON UPDATE CASCADE, CONSTRAINT message_real_pkey UNIQUE (bridge_id, room_receiver, id, part_id), - CONSTRAINT message_mxid_unique UNIQUE (bridge_id, mxid) + CONSTRAINT message_mxid_unique UNIQUE (bridge_id, mxid), + CONSTRAINT message_txn_id_unique UNIQUE (bridge_id, room_receiver, send_txn_id) ); CREATE INDEX message_room_idx ON message (bridge_id, room_id, room_receiver); diff --git a/bridgev2/database/upgrades/22-message-send-txn-id.sql b/bridgev2/database/upgrades/22-message-send-txn-id.sql new file mode 100644 index 00000000..8933984e --- /dev/null +++ b/bridgev2/database/upgrades/22-message-send-txn-id.sql @@ -0,0 +1,6 @@ +-- v22 (compatible with v9+): Add message send transaction ID column +ALTER TABLE message ADD COLUMN send_txn_id TEXT; +-- only: postgres +ALTER TABLE message ADD CONSTRAINT message_txn_id_unique UNIQUE (bridge_id, room_receiver, send_txn_id); +-- only: sqlite +CREATE UNIQUE INDEX message_txn_id_unique ON message (bridge_id, room_receiver, send_txn_id); diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 1d4e18cf..4dee2650 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -38,6 +38,8 @@ bridge: # Should room mute status only be synced when creating the portal? # Like tags, mutes can't currently be synced back to the remote network. mute_only_on_create: true + # Should the bridge check the db to ensure that incoming events haven't been handled before + deduplicate_matrix_messages: false # What should be done to portal rooms when a user logs out or is logged out? # Permitted values: diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c49f041c..cfdad822 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -951,6 +951,18 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin ThreadRoot: threadRoot, ReplyTo: replyTo, } + if portal.Bridge.Config.DeduplicateMatrixMessages { + if part, err := portal.Bridge.DB.Message.GetPartByTxnID(ctx, portal.Receiver, evt.ID, wrappedMsgEvt.InputTransactionID); err != nil { + log.Err(err).Msg("Failed to check db if message is already sent") + } else if part != nil { + log.Debug(). + Stringer("message_mxid", part.MXID). + Stringer("input_event_id", evt.ID). + Msg("Message already sent, ignoring") + return + } + } + var resp *MatrixMessageResponse if msgContent != nil { resp, err = sender.Client.HandleMatrixMessage(ctx, wrappedMsgEvt) @@ -1091,6 +1103,9 @@ func (evt *MatrixMessage) fillDBMessage(message *database.Message) *database.Mes if message.SenderMXID == "" { message.SenderMXID = evt.Event.Sender } + if message.SendTxnID != "" { + message.SendTxnID = evt.InputTransactionID + } return message } From 4ffe1d23e9e7edd6337db7a6a6639179740bf7e4 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 7 May 2025 14:19:01 +0100 Subject: [PATCH 1141/1647] client: don't attempt to make requests if the homeserver URL isn't set (#376) Quick guard for where the client is created without using the `NewClient` method. --- client.go | 3 +++ error.go | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 5f47aead..c8e366b0 100644 --- a/client.go +++ b/client.go @@ -481,6 +481,9 @@ func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullReque if cli == nil { return nil, nil, ErrClientIsNil } + if cli.HomeserverURL == nil || cli.HomeserverURL.Scheme == "" { + return nil, nil, ErrClientHasNoHomeserver + } if params.MaxAttempts == 0 { maxAttempts, ok := ctx.Value(MaxAttemptsContextKey).(int) if ok && maxAttempts > 0 { diff --git a/error.go b/error.go index 6f5dbe72..6f4880df 100644 --- a/error.go +++ b/error.go @@ -77,7 +77,8 @@ var ( ) var ( - ErrClientIsNil = errors.New("client is nil") + ErrClientIsNil = errors.New("client is nil") + ErrClientHasNoHomeserver = errors.New("client has no homeserver set") ) // HTTPError An HTTP Error response, which may wrap an underlying native Go Error. From 27769dfc98bebdcec16bb2293029904109a2b9df Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 7 May 2025 15:33:33 +0100 Subject: [PATCH 1142/1647] bridgev2: add shared event handling context This context is then passed into the network connectors handlers and message conversion functions which may require making network requests, which before this would not be canceled on bridge stop. --- bridgev2/bridge.go | 10 ++++++++++ bridgev2/portal.go | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 38f7ce1d..2e1fe8f1 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -54,6 +54,9 @@ type Bridge struct { wakeupBackfillQueue chan struct{} stopBackfillQueue *exsync.Event + + backgroundCtx context.Context + cancelBackgroundCtx context.CancelFunc } func NewBridge( @@ -108,6 +111,10 @@ func (e DBUpgradeError) Unwrap() error { } func (br *Bridge) Start(ctx context.Context) error { + if br.backgroundCtx == nil || br.backgroundCtx.Err() != nil { + // Ensure we have a valid event handling context + br.backgroundCtx, br.cancelBackgroundCtx = context.WithCancel(context.Background()) + } ctx = br.Log.WithContext(ctx) err := br.StartConnectors(ctx) if err != nil { @@ -332,6 +339,9 @@ func (br *Bridge) stop(isRunOnce bool) { wg.Wait() } br.Matrix.Stop() + if br.cancelBackgroundCtx != nil { + br.cancelBackgroundCtx() + } if stopNet, ok := br.Network.(StoppableNetwork); ok { stopNet.Stop() } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index cfdad822..a7ca5995 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -375,7 +375,7 @@ func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context { case *portalCreateEvent: return evt.ctx } - return logWith.Logger().WithContext(context.Background()) + return logWith.Logger().WithContext(portal.Bridge.backgroundCtx) } func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCallback func()) { From 376fa1f36898e60ba2d4e4142499ed6470ab28c0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 8 May 2025 15:25:20 +0300 Subject: [PATCH 1143/1647] bridgev2: fix initializing background context --- bridgev2/bridge.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 2e1fe8f1..bef0c79c 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -111,10 +111,6 @@ func (e DBUpgradeError) Unwrap() error { } func (br *Bridge) Start(ctx context.Context) error { - if br.backgroundCtx == nil || br.backgroundCtx.Err() != nil { - // Ensure we have a valid event handling context - br.backgroundCtx, br.cancelBackgroundCtx = context.WithCancel(context.Background()) - } ctx = br.Log.WithContext(ctx) err := br.StartConnectors(ctx) if err != nil { @@ -174,6 +170,9 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa func (br *Bridge) StartConnectors(ctx context.Context) error { br.Log.Info().Msg("Starting bridge") + if br.backgroundCtx == nil || br.backgroundCtx.Err() != nil { + br.backgroundCtx, br.cancelBackgroundCtx = context.WithCancel(context.Background()) + } if !br.ExternallyManagedDB { err := br.DB.Upgrade(ctx) From 23d91b64cb68dc897797a1a48de530c5b95c8a00 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 9 May 2025 14:25:43 +0300 Subject: [PATCH 1144/1647] bridgev2: fall back to remote ID for state update notices --- bridgev2/bridgestate.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index 148b522c..81ec8160 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -89,7 +89,11 @@ func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.Bridge bsq.login.Log.Err(err).Msg("Failed to get management room") return } - message := fmt.Sprintf("State update for %s: `%s`", bsq.login.RemoteName, state.StateEvent) + name := bsq.login.RemoteName + if name == "" { + name = fmt.Sprintf("`%s`", bsq.login.ID) + } + message := fmt.Sprintf("State update for %s: `%s`", name, state.StateEvent) if state.Error != "" { message += fmt.Sprintf(" (`%s`)", state.Error) } From a0191c8f5847a67568c86de7c958ad89eeada97c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 9 May 2025 15:16:14 +0300 Subject: [PATCH 1145/1647] bridgev2: expose background context --- bridgev2/bridge.go | 6 +++--- bridgev2/portal.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index bef0c79c..05a67b6a 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -55,7 +55,7 @@ type Bridge struct { wakeupBackfillQueue chan struct{} stopBackfillQueue *exsync.Event - backgroundCtx context.Context + BackgroundCtx context.Context cancelBackgroundCtx context.CancelFunc } @@ -170,8 +170,8 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa func (br *Bridge) StartConnectors(ctx context.Context) error { br.Log.Info().Msg("Starting bridge") - if br.backgroundCtx == nil || br.backgroundCtx.Err() != nil { - br.backgroundCtx, br.cancelBackgroundCtx = context.WithCancel(context.Background()) + if br.BackgroundCtx == nil || br.BackgroundCtx.Err() != nil { + br.BackgroundCtx, br.cancelBackgroundCtx = context.WithCancel(context.Background()) } if !br.ExternallyManagedDB { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a7ca5995..63081f57 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -375,7 +375,7 @@ func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context { case *portalCreateEvent: return evt.ctx } - return logWith.Logger().WithContext(portal.Bridge.backgroundCtx) + return logWith.Logger().WithContext(portal.Bridge.BackgroundCtx) } func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCallback func()) { From f23fc99ef40d933502342e0e2b339f05bac80595 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 10 May 2025 11:32:42 +0300 Subject: [PATCH 1146/1647] crypto/cross_signing: allow json marshaling cross-signing key seeds --- crypto/cross_sign_key.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/crypto/cross_sign_key.go b/crypto/cross_sign_key.go index 97ecd865..4094f695 100644 --- a/crypto/cross_sign_key.go +++ b/crypto/cross_sign_key.go @@ -11,6 +11,8 @@ import ( "context" "fmt" + "go.mau.fi/util/jsonbytes" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures" @@ -33,9 +35,9 @@ func (cskc *CrossSigningKeysCache) PublicKeys() *CrossSigningPublicKeysCache { } type CrossSigningSeeds struct { - MasterKey []byte - SelfSigningKey []byte - UserSigningKey []byte + MasterKey jsonbytes.UnpaddedURLBytes `json:"m.cross_signing.master"` + SelfSigningKey jsonbytes.UnpaddedURLBytes `json:"m.cross_signing.self_signing"` + UserSigningKey jsonbytes.UnpaddedURLBytes `json:"m.cross_signing.user_signing"` } func (mach *OlmMachine) ExportCrossSigningKeys() CrossSigningSeeds { From 978e0983eadf3815b0b63c8a8f6df07493209417 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 15 May 2025 14:15:34 +0300 Subject: [PATCH 1147/1647] dependencies: update --- go.mod | 22 +++++++++++----------- go.sum | 40 ++++++++++++++++++++-------------------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/go.mod b/go.mod index e279118e..fbe4274e 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module maunium.net/go/mautrix go 1.23.0 -toolchain go1.24.2 +toolchain go1.24.3 require ( filippo.io/edwards25519 v1.1.0 @@ -10,20 +10,20 @@ require ( github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 - github.com/mattn/go-sqlite3 v1.14.27 + github.com/mattn/go-sqlite3 v1.14.28 github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.34.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stretchr/testify v1.10.0 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.7.10 - go.mau.fi/util v0.8.7-0.20250427215252-d2d18a7e463c + github.com/yuin/goldmark v1.7.11 + go.mau.fi/util v0.8.7-0.20250515110144-747f5904911e go.mau.fi/zeroconfig v0.1.3 - golang.org/x/crypto v0.37.0 - golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 - golang.org/x/net v0.39.0 - golang.org/x/sync v0.13.0 + golang.org/x/crypto v0.38.0 + golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 + golang.org/x/net v0.40.0 + golang.org/x/sync v0.14.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) @@ -33,11 +33,11 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/petermattis/goid v0.0.0-20250319124200-ccd6737f222a // indirect + github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - golang.org/x/sys v0.32.0 // indirect - golang.org/x/text v0.24.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.25.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index f103b287..3fbbb766 100644 --- a/go.sum +++ b/go.sum @@ -26,10 +26,10 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.27 h1:drZCnuvf37yPfs95E5jd9s3XhdVWLal+6BOK6qrv6IU= -github.com/mattn/go-sqlite3 v1.14.27/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/petermattis/goid v0.0.0-20250319124200-ccd6737f222a h1:S+AGcmAESQ0pXCUNnRH7V+bOUIgkSX5qVt2cNKCrm0Q= -github.com/petermattis/goid v0.0.0-20250319124200-ccd6737f222a/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A= +github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb h1:3PrKuO92dUTMrQ9dx0YNejC6U/Si6jqKmyQ9vWjwqR4= +github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -51,28 +51,28 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.10 h1:S+LrtBjRmqMac2UdtB6yyCEJm+UILZ2fefI4p7o0QpI= -github.com/yuin/goldmark v1.7.10/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.8.7-0.20250427215252-d2d18a7e463c h1:qfJyMZq1pPyuXKoVWwHs6OmR9CzO3pHFRPYT/QpaaaA= -go.mau.fi/util v0.8.7-0.20250427215252-d2d18a7e463c/go.mod h1:uNB3UTXFbkpp7xL1M/WvQks90B/L4gvbLpbS0603KOE= +github.com/yuin/goldmark v1.7.11 h1:ZCxLyDMtz0nT2HFfsYG8WZ47Trip2+JyLysKcMYE5bo= +github.com/yuin/goldmark v1.7.11/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= +go.mau.fi/util v0.8.7-0.20250515110144-747f5904911e h1:8kfjOQ+L38Zq2HbhMFVhbkTdwiGbAmgTriioRnRB+LQ= +go.mau.fi/util v0.8.7-0.20250515110144-747f5904911e/go.mod h1:j6R3cENakc1f8HpQeFl0N15UiSTcNmIfDBNJUbL71RY= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= -golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= -golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM= -golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8= -golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY= -golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E= -golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= -golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= +golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= +golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI= +golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ= +golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= +golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= -golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= -golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From 0a8e8230164055ea3d6f59fcefbba05756abd907 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 16 May 2025 08:13:16 +0300 Subject: [PATCH 1148/1647] Bump version to v0.24.0 --- CHANGELOG.md | 25 +++++++++++++++++++++++++ go.mod | 2 +- go.sum | 4 ++-- version.go | 2 +- 4 files changed, 29 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 565d7f15..95f214c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,28 @@ +## v0.24.0 (2025-05-16) + +* *(commands)* Added generic framework for implementing bot commands. +* *(client)* Added support for specifying maximum number of HTTP retries using + a context value instead of having to call `MakeFullRequest` manually. +* *(client,federation)* Added methods for fetching room directories. +* *(federation)* Added support for server side of request authentication. +* *(synapseadmin)* Added wrapper for the account suspension endpoint. +* *(format)* Added method for safely wrapping a string in markdown inline code. +* *(crypto)* Added method to import key backup without persisting to database, + to allow the client more control over the process. +* *(bridgev2)* Added viewing chat interface to signal when the user is viewing + a given chat. +* *(bridgev2)* Added option to pass through transaction ID from client when + sending messages to remote network. +* *(crypto)* Fixed unnecessary error log when decrypting dummy events used for + unwedging Olm sessions. +* *(crypto)* Fixed `forwarding_curve25519_key_chain` not being set consistently + when backing up keys. +* *(event)* Fixed marshaling legacy VoIP events with no version field. +* *(bridgev2)* Fixed disappearing message references not being deleted when the + portal is deleted. +* *(bridgev2)* Fixed read receipt bridging not ignoring fake message entries + and causing unnecessary error logs. + ## v0.23.3 (2025-04-16) * *(commands)* Added generic command processing framework for bots. diff --git a/go.mod b/go.mod index fbe4274e..ebc7a61c 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.11 - go.mau.fi/util v0.8.7-0.20250515110144-747f5904911e + go.mau.fi/util v0.8.7 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.38.0 golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 diff --git a/go.sum b/go.sum index 3fbbb766..a3c7542d 100644 --- a/go.sum +++ b/go.sum @@ -53,8 +53,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.11 h1:ZCxLyDMtz0nT2HFfsYG8WZ47Trip2+JyLysKcMYE5bo= github.com/yuin/goldmark v1.7.11/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.8.7-0.20250515110144-747f5904911e h1:8kfjOQ+L38Zq2HbhMFVhbkTdwiGbAmgTriioRnRB+LQ= -go.mau.fi/util v0.8.7-0.20250515110144-747f5904911e/go.mod h1:j6R3cENakc1f8HpQeFl0N15UiSTcNmIfDBNJUbL71RY= +go.mau.fi/util v0.8.7 h1:ywKarPxouJQEEijTs4mPlxC7F4AWEKokEpWc+2TYy6c= +go.mau.fi/util v0.8.7/go.mod h1:j6R3cENakc1f8HpQeFl0N15UiSTcNmIfDBNJUbL71RY= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= diff --git a/version.go b/version.go index 2e670697..8366c5bf 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.23.3" +const Version = "v0.24.0" var GoModVersion = "" var Commit = "" From a205a77db46adefe4e9bfab7cf6ff1d8a6752424 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 20 May 2025 10:27:35 +0100 Subject: [PATCH 1149/1647] bridgev2: add `CredentialExportingNetworkAPI` interface --- bridgev2/networkinterface.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 14e3a681..76db8cc8 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -387,6 +387,13 @@ type BackgroundSyncingNetworkAPI interface { ConnectBackground(ctx context.Context, params *ConnectBackgroundParams) error } +// CredentialExportingNetworkAPI is an optional interface that networks connectors can implement to support export of +// the credentials associated with that login. Credential type is bridge specific. +type CredentialExportingNetworkAPI interface { + NetworkAPI + ExportCredentials(ctx context.Context) any +} + // FetchMessagesParams contains the parameters for a message history pagination request. type FetchMessagesParams struct { // The portal to fetch messages in. Always present. From 487fc699fe8f4154b45f37cccfcea3d915eea9c0 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 20 May 2025 10:32:33 +0100 Subject: [PATCH 1150/1647] bridgev2/provisioning: add session transfer support For connector logins that support it this will expose an API to transfer credentials between bridge instances. Currently does not do any extra validation beyond the usual provisioning API request validation (so shared secret or matrix token). One future improvement would be to require clients to sign incoming requests, and to then validate a) the signature and b) the device is verified. --- bridgev2/bridgeconfig/config.go | 7 +- bridgev2/matrix/mxmain/example-config.yaml | 3 + bridgev2/matrix/provisioning.go | 121 +++++++++++++++++++++ 3 files changed, 128 insertions(+), 3 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index bd7746d1..37517818 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -95,9 +95,10 @@ type AnalyticsConfig struct { } type ProvisioningConfig struct { - Prefix string `yaml:"prefix"` - SharedSecret string `yaml:"shared_secret"` - DebugEndpoints bool `yaml:"debug_endpoints"` + Prefix string `yaml:"prefix"` + SharedSecret string `yaml:"shared_secret"` + DebugEndpoints bool `yaml:"debug_endpoints"` + EnableSessionTransfers bool `yaml:"enable_session_transfers"` } type DirectMediaConfig struct { diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 4dee2650..a9d05fd1 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -252,6 +252,9 @@ provisioning: allow_matrix_auth: true # Enable debug API at /debug with provisioning authentication. debug_endpoints: false + # Enable session transfers between bridges. Note that this only validates Matrix or shared secret + # auth before passing live network client credentials down in the response. + enable_session_transfers: false # Some networks require publicly accessible media download links (e.g. for user avatars when using Discord webhooks). # These settings control whether the bridge will provide such public media access. diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index d809d039..2b9b5124 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -53,6 +53,11 @@ type ProvisioningAPI struct { matrixAuthCache map[string]matrixAuthCacheEntry matrixAuthCacheLock sync.Mutex + // Set for a given login once credentials have been exported, once in this state the finish + // API is available which will call logout on the client in question. + sessionTransfers map[networkid.UserLoginID]struct{} + sessionTransfersLock sync.Mutex + // GetAuthFromRequest is a custom function for getting the auth token from // the request if the Authorization header is not present. GetAuthFromRequest func(r *http.Request) string @@ -101,6 +106,7 @@ func (br *Connector) GetProvisioning() IProvisioningAPI { func (prov *ProvisioningAPI) Init() { prov.matrixAuthCache = make(map[string]matrixAuthCacheEntry) prov.logins = make(map[string]*ProvLogin) + prov.sessionTransfers = make(map[networkid.UserLoginID]struct{}) prov.net = prov.br.Bridge.Network prov.log = prov.br.Log.With().Str("component", "provisioning").Logger() prov.fedClient = federation.NewClient("", nil, nil) @@ -128,6 +134,12 @@ func (prov *ProvisioningAPI) Init() { prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateDM) prov.Router.Path("/v3/create_group").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateGroup) + if prov.br.Config.Provisioning.EnableSessionTransfers { + prov.log.Debug().Msg("Enabling session transfer API") + prov.Router.Path("/v3/session_transfer/init").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostInitSessionTransfer) + prov.Router.Path("/v3/session_transfer/finish").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostFinishSessionTransfer) + } + if prov.br.Config.Provisioning.DebugEndpoints { prov.log.Debug().Msg("Enabling debug API at /debug") r := prov.br.AS.Router.PathPrefix("/debug").Subrouter() @@ -791,3 +803,112 @@ func (prov *ProvisioningAPI) PostCreateGroup(w http.ResponseWriter, r *http.Requ ErrCode: mautrix.MUnrecognized.ErrCode, }) } + +type ReqExportCredentials struct { + RemoteID networkid.UserLoginID `json:"remote_name"` +} + +type RespExportCredentials struct { + Credentials any `json:"credentials"` +} + +func (prov *ProvisioningAPI) PostInitSessionTransfer(w http.ResponseWriter, r *http.Request) { + prov.sessionTransfersLock.Lock() + defer prov.sessionTransfersLock.Unlock() + + var req ReqExportCredentials + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + Err: "Failed to decode request body", + ErrCode: mautrix.MNotJSON.ErrCode, + }) + return + } + + user := prov.GetUser(r) + logins := user.GetUserLogins() + var loginToExport *bridgev2.UserLogin + for _, login := range logins { + if login.ID == req.RemoteID { + loginToExport = login + break + } + } + if loginToExport == nil { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + Err: "No matching user login found", + ErrCode: mautrix.MNotFound.ErrCode, + }) + return + } + + client, ok := loginToExport.Client.(bridgev2.CredentialExportingNetworkAPI) + if !ok { + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + Err: "Client does not support credential exporting", + ErrCode: mautrix.MInvalidParam.ErrCode, + }) + return + } + + if _, ok := prov.sessionTransfers[loginToExport.ID]; ok { + // Warn, but allow, double exports. This might happen if a client crashes handling creds, + // and should be safe to call multiple times. + zerolog.Ctx(r.Context()).Warn().Msg("Exporting already exported credentials") + } + + resp := RespExportCredentials{ + Credentials: client.ExportCredentials(r.Context()), + } + jsonResponse(w, http.StatusOK, resp) +} + +func (prov *ProvisioningAPI) PostFinishSessionTransfer(w http.ResponseWriter, r *http.Request) { + prov.sessionTransfersLock.Lock() + defer prov.sessionTransfersLock.Unlock() + + var req ReqExportCredentials + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + Err: "Failed to decode request body", + ErrCode: mautrix.MNotJSON.ErrCode, + }) + return + } + + user := prov.GetUser(r) + logins := user.GetUserLogins() + var loginToExport *bridgev2.UserLogin + for _, login := range logins { + if login.ID == req.RemoteID { + loginToExport = login + break + } + } + if loginToExport == nil { + jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ + Err: "No matching user login found", + ErrCode: mautrix.MNotFound.ErrCode, + }) + return + } else if _, ok := prov.sessionTransfers[loginToExport.ID]; !ok { + jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ + Err: "No matching credential export found", + ErrCode: mautrix.MNotJSON.ErrCode, + }) + return + } + + zerolog.Ctx(r.Context()).Info(). + Str("remote_name", string(req.RemoteID)). + Msg("Logging out remote after finishing credential export") + + loginToExport.Client.LogoutRemote(r.Context()) + delete(prov.sessionTransfers, req.RemoteID) + + jsonResponse(w, http.StatusOK, struct{}{}) +} From a3efaa36322985859faa1d31a95907db1905d0ad Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Thu, 22 May 2025 15:20:36 +0100 Subject: [PATCH 1151/1647] bridgev2/provisioning: disconnect login before exporting credentials --- bridgev2/matrix/provisioning.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 2b9b5124..eacc86c6 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -859,6 +859,8 @@ func (prov *ProvisioningAPI) PostInitSessionTransfer(w http.ResponseWriter, r *h zerolog.Ctx(r.Context()).Warn().Msg("Exporting already exported credentials") } + // Disconnect now so we don't use the same network session in two places at once + client.Disconnect() resp := RespExportCredentials{ Credentials: client.ExportCredentials(r.Context()), } From 203e402ebf1af54f215279757e1a5a03a870fd3c Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Thu, 22 May 2025 15:22:13 +0100 Subject: [PATCH 1152/1647] bridgev2/provisioning: correct field name --- bridgev2/matrix/provisioning.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index eacc86c6..d05b005e 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -805,7 +805,7 @@ func (prov *ProvisioningAPI) PostCreateGroup(w http.ResponseWriter, r *http.Requ } type ReqExportCredentials struct { - RemoteID networkid.UserLoginID `json:"remote_name"` + RemoteID networkid.UserLoginID `json:"remote_id"` } type RespExportCredentials struct { From ad8145c43b4e46c00dd000ceb06a126479e6b9cc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 24 May 2025 14:23:18 +0300 Subject: [PATCH 1153/1647] synapseadmin: don't embed mautrix.Client in admin client struct --- synapseadmin/client.go | 4 ++-- synapseadmin/register.go | 4 ++-- synapseadmin/roomapi.go | 26 ++++++++++++++++---------- synapseadmin/userapi.go | 22 +++++++++++----------- 4 files changed, 31 insertions(+), 25 deletions(-) diff --git a/synapseadmin/client.go b/synapseadmin/client.go index 775b4b13..6925ca7d 100644 --- a/synapseadmin/client.go +++ b/synapseadmin/client.go @@ -14,9 +14,9 @@ import ( // // https://matrix-org.github.io/synapse/latest/usage/administration/admin_api/index.html type Client struct { - *mautrix.Client + Client *mautrix.Client } func (cli *Client) BuildAdminURL(path ...any) string { - return cli.BuildURL(mautrix.SynapseAdminURLPath(path)) + return cli.Client.BuildURL(mautrix.SynapseAdminURLPath(path)) } diff --git a/synapseadmin/register.go b/synapseadmin/register.go index 641f9b56..05e0729a 100644 --- a/synapseadmin/register.go +++ b/synapseadmin/register.go @@ -73,7 +73,7 @@ func (req *ReqSharedSecretRegister) Sign(secret string) string { // This does not need to be called manually as SharedSecretRegister will automatically call this if no nonce is provided. func (cli *Client) GetRegisterNonce(ctx context.Context) (string, error) { var resp respGetRegisterNonce - _, err := cli.MakeRequest(ctx, http.MethodGet, cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), nil, &resp) + _, err := cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "register"), nil, &resp) if err != nil { return "", err } @@ -93,7 +93,7 @@ func (cli *Client) SharedSecretRegister(ctx context.Context, sharedSecret string } req.SHA1Checksum = req.Sign(sharedSecret) var resp mautrix.RespRegister - _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildURL(mautrix.SynapseAdminURLPath{"v1", "register"}), &req, &resp) + _, err = cli.Client.MakeRequest(ctx, http.MethodPost, cli.BuildAdminURL("v1", "register"), &req, &resp) if err != nil { return nil, err } diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go index 6c072e23..b2d82fb3 100644 --- a/synapseadmin/roomapi.go +++ b/synapseadmin/roomapi.go @@ -76,11 +76,17 @@ type RespListRooms struct { func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRooms, error) { var resp RespListRooms var reqURL string - reqURL = cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) - _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) + reqURL = cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) + _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } +func (cli *Client) RoomInfo(ctx context.Context, roomID id.RoomID) (resp *RoomInfo, err error) { + reqURL := cli.BuildAdminURL("v1", "rooms", roomID) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) + return +} + type RespRoomMessages = mautrix.RespMessages // RoomMessages returns a list of messages in a room. @@ -104,8 +110,8 @@ func (cli *Client) RoomMessages(ctx context.Context, roomID id.RoomID, from, to if limit != 0 { query["limit"] = strconv.Itoa(limit) } - urlPath := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms", roomID, "messages"}, query) - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + urlPath := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms", roomID, "messages"}, query) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return resp, err } @@ -129,7 +135,7 @@ type RespDeleteRoom struct { func (cli *Client) DeleteRoom(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (RespDeleteRoom, error) { reqURL := cli.BuildAdminURL("v2", "rooms", roomID) var resp RespDeleteRoom - _, err := cli.MakeRequest(ctx, http.MethodDelete, reqURL, &req, &resp) + _, err := cli.Client.MakeRequest(ctx, http.MethodDelete, reqURL, &req, &resp) return resp, err } @@ -144,7 +150,7 @@ type RespRoomsMembers struct { func (cli *Client) RoomMembers(ctx context.Context, roomID id.RoomID) (RespRoomsMembers, error) { reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "members") var resp RespRoomsMembers - _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) + _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } @@ -157,7 +163,7 @@ type ReqMakeRoomAdmin struct { // https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#make-room-admin-api func (cli *Client) MakeRoomAdmin(ctx context.Context, roomIDOrAlias string, req ReqMakeRoomAdmin) error { reqURL := cli.BuildAdminURL("v1", "rooms", roomIDOrAlias, "make_room_admin") - _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) + _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -170,7 +176,7 @@ type ReqJoinUserToRoom struct { // https://matrix-org.github.io/synapse/latest/admin_api/room_membership.html func (cli *Client) JoinUserToRoom(ctx context.Context, roomID id.RoomID, req ReqJoinUserToRoom) error { reqURL := cli.BuildAdminURL("v1", "join", roomID) - _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) + _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -183,7 +189,7 @@ type ReqBlockRoom struct { // https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#block-room-api func (cli *Client) BlockRoom(ctx context.Context, roomID id.RoomID, req ReqBlockRoom) error { reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block") - _, err := cli.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) + _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) return err } @@ -199,6 +205,6 @@ type RoomsBlockResponse struct { func (cli *Client) GetRoomBlockStatus(ctx context.Context, roomID id.RoomID) (RoomsBlockResponse, error) { var resp RoomsBlockResponse reqURL := cli.BuildAdminURL("v1", "rooms", roomID, "block") - _, err := cli.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) + _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } diff --git a/synapseadmin/userapi.go b/synapseadmin/userapi.go index d3672367..b1de55b6 100644 --- a/synapseadmin/userapi.go +++ b/synapseadmin/userapi.go @@ -32,7 +32,7 @@ type ReqResetPassword struct { // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#reset-password func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) error { reqURL := cli.BuildAdminURL("v1", "reset_password", req.UserID) - _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) + _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -43,8 +43,8 @@ func (cli *Client) ResetPassword(ctx context.Context, req ReqResetPassword) erro // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#check-username-availability func (cli *Client) UsernameAvailable(ctx context.Context, username string) (resp *mautrix.RespRegisterAvailable, err error) { - u := cli.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "username_available"}, map[string]string{"username": username}) - _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp) + u := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "username_available"}, map[string]string{"username": username}) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, u, nil, &resp) if err == nil && !resp.Available { err = fmt.Errorf(`request returned OK status without "available": true`) } @@ -65,7 +65,7 @@ type RespListDevices struct { // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#list-all-devices func (cli *Client) ListDevices(ctx context.Context, userID id.UserID) (resp *RespListDevices, err error) { - _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID, "devices"), nil, &resp) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID, "devices"), nil, &resp) return } @@ -89,7 +89,7 @@ type RespUserInfo struct { // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#query-user-account func (cli *Client) GetUserInfo(ctx context.Context, userID id.UserID) (resp *RespUserInfo, err error) { - _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID), nil, &resp) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v2", "users", userID), nil, &resp) return } @@ -102,7 +102,7 @@ type ReqDeleteUser struct { // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#deactivate-account func (cli *Client) DeactivateAccount(ctx context.Context, userID id.UserID, req ReqDeleteUser) error { reqURL := cli.BuildAdminURL("v1", "deactivate", userID) - _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) + _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -115,7 +115,7 @@ type ReqSuspendUser struct { // https://element-hq.github.io/synapse/latest/admin_api/user_admin_api.html#suspendunsuspend-account func (cli *Client) SuspendAccount(ctx context.Context, userID id.UserID, req ReqSuspendUser) error { reqURL := cli.BuildAdminURL("v1", "suspend", userID) - _, err := cli.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) + _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) return err } @@ -137,7 +137,7 @@ type ReqCreateOrModifyAccount struct { // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#create-or-modify-account func (cli *Client) CreateOrModifyAccount(ctx context.Context, userID id.UserID, req ReqCreateOrModifyAccount) error { reqURL := cli.BuildAdminURL("v2", "users", userID) - _, err := cli.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) + _, err := cli.Client.MakeRequest(ctx, http.MethodPut, reqURL, &req, nil) return err } @@ -153,7 +153,7 @@ type ReqSetRatelimit = RatelimitOverride // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#set-ratelimit func (cli *Client) SetUserRatelimit(ctx context.Context, userID id.UserID, req ReqSetRatelimit) error { reqURL := cli.BuildAdminURL("v1", "users", userID, "override_ratelimit") - _, err := cli.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) + _, err := cli.Client.MakeRequest(ctx, http.MethodPost, reqURL, &req, nil) return err } @@ -163,7 +163,7 @@ type RespUserRatelimit = RatelimitOverride // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#get-status-of-ratelimit func (cli *Client) GetUserRatelimit(ctx context.Context, userID id.UserID) (resp RespUserRatelimit, err error) { - _, err = cli.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, &resp) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, &resp) return } @@ -171,6 +171,6 @@ func (cli *Client) GetUserRatelimit(ctx context.Context, userID id.UserID) (resp // // https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#delete-ratelimit func (cli *Client) DeleteUserRatelimit(ctx context.Context, userID id.UserID) (err error) { - _, err = cli.MakeRequest(ctx, http.MethodDelete, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, nil) + _, err = cli.Client.MakeRequest(ctx, http.MethodDelete, cli.BuildAdminURL("v1", "users", userID, "override_ratelimit"), nil, nil) return } From 49d2f391835481fcf100ea39ca0418b635c7445c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 24 May 2025 14:35:00 +0300 Subject: [PATCH 1154/1647] format: add markdown link utilities --- format/markdown.go | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/format/markdown.go b/format/markdown.go index f6181ed9..3d9979b4 100644 --- a/format/markdown.go +++ b/format/markdown.go @@ -18,6 +18,7 @@ import ( "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/format/mdext" + "maunium.net/go/mautrix/id" ) const paragraphStart = "

      " @@ -41,7 +42,7 @@ func UnwrapSingleParagraph(html string) string { return html } -var mdEscapeRegex = regexp.MustCompile("([\\\\`*_[\\]])") +var mdEscapeRegex = regexp.MustCompile("([\\\\`*_[\\]()])") func EscapeMarkdown(text string) string { text = mdEscapeRegex.ReplaceAllString(text, "\\$1") @@ -50,7 +51,23 @@ func EscapeMarkdown(text string) string { return text } +type uriAble interface { + String() string + URI() *id.MatrixURI +} + +func MarkdownMention(id uriAble) string { + return MarkdownLink(id.String(), id.URI().MatrixToURL()) +} + +func MarkdownLink(name string, url string) string { + return fmt.Sprintf("[%s](%s)", EscapeMarkdown(name), EscapeMarkdown(url)) +} + func SafeMarkdownCode[T ~string](textInput T) string { + if textInput == "" { + return "` `" + } text := strings.ReplaceAll(string(textInput), "\n", " ") backtickCount := exstrings.LongestSequenceOf(text, '`') if backtickCount == 0 { From 50f0b5fa7d581a2d26ae3a8a27944ddc0c3a47cb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 24 May 2025 14:42:06 +0300 Subject: [PATCH 1155/1647] synapseadmin: add support for synchronous room delete --- synapseadmin/roomapi.go | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go index b2d82fb3..fa391b73 100644 --- a/synapseadmin/roomapi.go +++ b/synapseadmin/roomapi.go @@ -127,6 +127,14 @@ type RespDeleteRoom struct { DeleteID string `json:"delete_id"` } +type RespDeleteRoomStatus struct { + Status string `json:"status,omitempty"` + KickedUsers []id.UserID `json:"kicked_users,omitempty"` + FailedToKickUsers []id.UserID `json:"failed_to_kick_users,omitempty"` + LocalAliases []id.RoomAlias `json:"local_aliases,omitempty"` + NewRoomID id.RoomID `json:"new_room_id,omitempty"` +} + // DeleteRoom deletes a room from the server, optionally blocking it and/or purging all data from the database. // // This calls the async version of the endpoint, which will return immediately and delete the room in the background. @@ -139,6 +147,27 @@ func (cli *Client) DeleteRoom(ctx context.Context, roomID id.RoomID, req ReqDele return resp, err } +// DeleteRoomSync deletes a room from the server, optionally blocking it and/or purging all data from the database. +// +// This calls the synchronous version of the endpoint, which will block until the room is deleted. +// +// https://element-hq.github.io/synapse/latest/admin_api/rooms.html#version-1-old-version +func (cli *Client) DeleteRoomSync(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (resp RespDeleteRoomStatus, err error) { + reqURL := cli.BuildAdminURL("v1", "rooms", roomID) + httpClient := &http.Client{} + _, err = cli.Client.MakeFullRequest(ctx, mautrix.FullRequest{ + Method: http.MethodDelete, + URL: reqURL, + RequestJSON: &req, + ResponseJSON: &resp, + MaxAttempts: 1, + // Use a fresh HTTP client without timeouts + Client: httpClient, + }) + httpClient.CloseIdleConnections() + return +} + type RespRoomsMembers struct { Members []id.UserID `json:"members"` Total int `json:"total"` From 68565a1f18c80e261c2a818b1517fa310c06317a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 24 May 2025 15:35:33 +0300 Subject: [PATCH 1156/1647] client: add wrapper for /relations endpoints --- client.go | 6 ++++++ requests.go | 41 +++++++++++++++++++++++++++++++++++++++++ responses.go | 7 +++++++ 3 files changed, 54 insertions(+) diff --git a/client.go b/client.go index c8e366b0..bf3bb16e 100644 --- a/client.go +++ b/client.go @@ -2088,6 +2088,12 @@ func (cli *Client) GetUnredactedEventContent(ctx context.Context, roomID id.Room return } +func (cli *Client) GetRelations(ctx context.Context, roomID id.RoomID, eventID id.EventID, req *ReqGetRelations) (resp *RespGetRelations, err error) { + urlPath := cli.BuildURLWithQuery(append(ClientURLPath{"v1", "rooms", roomID, "relations", eventID}, req.PathSuffix()...), req.Query()) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + func (cli *Client) MarkRead(ctx context.Context, roomID id.RoomID, eventID id.EventID) (err error) { return cli.SendReceipt(ctx, roomID, eventID, event.ReceiptTypeRead, nil) } diff --git a/requests.go b/requests.go index 1bed6c7e..42d257fb 100644 --- a/requests.go +++ b/requests.go @@ -543,3 +543,44 @@ type ReqReport struct { Reason string `json:"reason,omitempty"` Score int `json:"score,omitempty"` } + +type ReqGetRelations struct { + RelationType event.RelationType + EventType event.Type + + Dir Direction + From string + To string + Limit int + Recurse bool +} + +func (rgr *ReqGetRelations) PathSuffix() ClientURLPath { + if rgr.RelationType != "" { + if rgr.EventType.Type != "" { + return ClientURLPath{rgr.RelationType, rgr.EventType.Type} + } + return ClientURLPath{rgr.RelationType} + } + return ClientURLPath{} +} + +func (rgr *ReqGetRelations) Query() map[string]string { + query := map[string]string{} + if rgr.Dir != 0 { + query["dir"] = string(rgr.Dir) + } + if rgr.From != "" { + query["from"] = rgr.From + } + if rgr.To != "" { + query["to"] = rgr.To + } + if rgr.Limit > 0 { + query["limit"] = strconv.Itoa(rgr.Limit) + } + if rgr.Recurse { + query["recurse"] = "true" + } + return query +} diff --git a/responses.go b/responses.go index ee7f4703..20d02af5 100644 --- a/responses.go +++ b/responses.go @@ -709,3 +709,10 @@ type RespOpenIDToken struct { MatrixServerName string `json:"matrix_server_name"` TokenType string `json:"token_type"` // Always "Bearer" } + +type RespGetRelations struct { + Chunk []*event.Event `json:"chunk"` + NextBatch string `json:"next_batch,omitempty"` + PrevBatch string `json:"prev_batch,omitempty"` + RecursionDepth int `json:"recursion_depth,omitempty"` +} From e9dfee45c0d8682e711c7e7d027a7d1e1463bad8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 24 May 2025 16:09:09 +0300 Subject: [PATCH 1157/1647] event: add missing letter to docstring --- event/content.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/event/content.go b/event/content.go index 2347898e..b56e35f2 100644 --- a/event/content.go +++ b/event/content.go @@ -123,7 +123,7 @@ var TypeMap = map[Type]reflect.Type{ // When being marshaled into JSON, the data in Parsed will be marshaled first and then recursively merged // with the data in Raw. Values in Raw are preferred, but nested objects will be recursed into before merging, // rather than overriding the whole object with the one in Raw). -// If one of them is nil, the only the other is used. If both (Parsed and Raw) are nil, VeryRaw is used instead. +// If one of them is nil, then only the other is used. If both (Parsed and Raw) are nil, VeryRaw is used instead. type Content struct { VeryRaw json.RawMessage Raw map[string]interface{} From ec15b79493651cbc45c5a7b5e98fe7e0bcd4cad4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 24 May 2025 16:18:39 +0300 Subject: [PATCH 1158/1647] commands: add event id to logger --- commands/processor.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/commands/processor.go b/commands/processor.go index da802fd9..24670d2f 100644 --- a/commands/processor.go +++ b/commands/processor.go @@ -43,7 +43,11 @@ func NewProcessor[MetaType any](cli *mautrix.Client) *Processor[MetaType] { } func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) { - log := *zerolog.Ctx(ctx) + log := zerolog.Ctx(ctx).With(). + Stringer("sender", evt.Sender). + Stringer("room_id", evt.RoomID). + Stringer("event_id", evt.ID). + Logger() defer func() { panicErr := recover() if panicErr != nil { @@ -98,9 +102,7 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) logWith := log.With(). Str("command", parsed.Command). - Array("handler", handlerChain). - Stringer("sender", evt.Sender). - Stringer("room_id", evt.RoomID) + Array("handler", handlerChain) if len(parsed.ParentCommands) > 0 { logWith = logWith.Strs("parent_commands", parsed.ParentCommands) } From da9e72e61680d6837170bd04cb783f36296e2060 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 24 May 2025 16:21:46 +0300 Subject: [PATCH 1159/1647] commands: add separate field for logger in event --- commands/event.go | 2 ++ commands/processor.go | 1 + 2 files changed, 3 insertions(+) diff --git a/commands/event.go b/commands/event.go index 8d51eadd..96d5e921 100644 --- a/commands/event.go +++ b/commands/event.go @@ -36,6 +36,7 @@ type Event[MetaType any] struct { RawArgs string Ctx context.Context + Log *zerolog.Logger Proc *Processor[MetaType] Handler *Handler[MetaType] Meta MetaType @@ -77,6 +78,7 @@ func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[Meta Command: strings.ToLower(parts[0]), Args: parts[1:], RawArgs: strings.TrimLeft(strings.TrimPrefix(text, parts[0]), " "), + Log: zerolog.Ctx(ctx), Ctx: ctx, } } diff --git a/commands/processor.go b/commands/processor.go index 24670d2f..c4940526 100644 --- a/commands/processor.go +++ b/commands/processor.go @@ -111,6 +111,7 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) } log = logWith.Logger() parsed.Ctx = log.WithContext(ctx) + parsed.Log = &log log.Debug().Msg("Processing command") handler.Func(parsed) From 89fad2f462145ea13fdd05c062de9b5bd3fc193d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 24 May 2025 16:29:40 +0300 Subject: [PATCH 1160/1647] commands: add reaction button system --- commands/event.go | 23 +++++++- commands/processor.go | 10 +++- commands/reactions.go | 125 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 156 insertions(+), 2 deletions(-) create mode 100644 commands/reactions.go diff --git a/commands/event.go b/commands/event.go index 96d5e921..65ddd3da 100644 --- a/commands/event.go +++ b/commands/event.go @@ -40,6 +40,8 @@ type Event[MetaType any] struct { Proc *Processor[MetaType] Handler *Handler[MetaType] Meta MetaType + + redactedBy id.EventID } var IDHTMLParser = &format.HTMLParser{ @@ -71,7 +73,14 @@ func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[Meta if len(text) == 0 { return nil } + return RawTextToEvent[MetaType](ctx, evt, text) +} + +func RawTextToEvent[MetaType any](ctx context.Context, evt *event.Event, text string) *Event[MetaType] { parts := strings.Fields(text) + if len(parts) == 0 { + parts = []string{""} + } return &Event[MetaType]{ Event: evt, RawInput: text, @@ -91,6 +100,7 @@ type ReplyOpts struct { SendAsText bool Edit id.EventID OverrideMentions *event.Mentions + Extra map[string]any } func (evt *Event[MetaType]) Reply(msg string, args ...any) id.EventID { @@ -117,7 +127,14 @@ func (evt *Event[MetaType]) Respond(msg string, opts ReplyOpts) id.EventID { if opts.OverrideMentions != nil { content.Mentions = opts.OverrideMentions } - resp, err := evt.Proc.Client.SendMessageEvent(evt.Ctx, evt.RoomID, event.EventMessage, content) + var wrapped any = &content + if opts.Extra != nil { + wrapped = &event.Content{ + Parsed: &content, + Raw: opts.Extra, + } + } + resp, err := evt.Proc.Client.SendMessageEvent(evt.Ctx, evt.RoomID, event.EventMessage, wrapped) if err != nil { zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to send reply") return "" @@ -135,11 +152,15 @@ func (evt *Event[MetaType]) React(emoji string) id.EventID { } func (evt *Event[MetaType]) Redact() id.EventID { + if evt.redactedBy != "" { + return evt.redactedBy + } resp, err := evt.Proc.Client.RedactEvent(evt.Ctx, evt.RoomID, evt.ID) if err != nil { zerolog.Ctx(evt.Ctx).Err(err).Msg("Failed to redact command") return "" } + evt.redactedBy = resp.EventID return resp.EventID } diff --git a/commands/processor.go b/commands/processor.go index c4940526..9341329b 100644 --- a/commands/processor.go +++ b/commands/processor.go @@ -26,6 +26,8 @@ type Processor[MetaType any] struct { LogArgs bool PreValidator PreValidator[MetaType] Meta MetaType + + ReactionCommandPrefix string } // UnknownCommandName is the name of the fallback handler which is used if no other handler is found. @@ -65,7 +67,13 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) } } }() - parsed := ParseEvent[MetaType](ctx, evt) + var parsed *Event[MetaType] + switch evt.Type { + case event.EventReaction: + parsed = proc.ParseReaction(ctx, evt) + case event.EventMessage: + parsed = ParseEvent[MetaType](ctx, evt) + } if parsed == nil || !proc.PreValidator.Validate(parsed) { return } diff --git a/commands/reactions.go b/commands/reactions.go new file mode 100644 index 00000000..0df372e5 --- /dev/null +++ b/commands/reactions.go @@ -0,0 +1,125 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package commands + +import ( + "context" + "strings" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" +) + +const ReactionCommandsKey = "fi.mau.reaction_commands" +const ReactionMultiUseKey = "fi.mau.reaction_multi_use" + +func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.Event) *Event[MetaType] { + content, ok := evt.Content.Parsed.(*event.ReactionEventContent) + if !ok { + return nil + } + evtID := content.RelatesTo.EventID + if evtID == "" || !strings.HasPrefix(content.RelatesTo.Key, proc.ReactionCommandPrefix) { + return nil + } + targetEvt, err := proc.Client.GetEvent(ctx, evt.RoomID, evtID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("target_event_id", evtID).Msg("Failed to get target event for reaction") + return nil + } else if targetEvt.Sender != proc.Client.UserID || targetEvt.Unsigned.RedactedBecause != nil { + return nil + } + if targetEvt.Type == event.EventEncrypted { + if proc.Client.Crypto == nil { + zerolog.Ctx(ctx).Warn(). + Stringer("target_event_id", evtID). + Msg("Received reaction to encrypted event, but don't have crypto helper in client") + return nil + } + _ = targetEvt.Content.ParseRaw(targetEvt.Type) + targetEvt, err = proc.Client.Crypto.Decrypt(ctx, targetEvt) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Stringer("target_event_id", evtID). + Msg("Failed to decrypt target event for reaction") + return nil + } + } + reactionCommands, ok := targetEvt.Content.Raw[ReactionCommandsKey].(map[string]any) + if !ok { + zerolog.Ctx(ctx).Trace(). + Stringer("target_event_id", evtID). + Msg("Reaction target event doesn't have commands key") + return nil + } + isMultiUse, _ := targetEvt.Content.Raw[ReactionMultiUseKey].(bool) + rawCmd, ok := reactionCommands[content.RelatesTo.Key] + if !ok { + zerolog.Ctx(ctx).Debug(). + Stringer("target_event_id", evtID). + Str("reaction_key", content.RelatesTo.Key). + Msg("Reaction command not found in target event") + return nil + } + cmdString, ok := rawCmd.(string) + if !ok { + zerolog.Ctx(ctx).Debug(). + Stringer("target_event_id", evtID). + Str("reaction_key", content.RelatesTo.Key). + Msg("Reaction command data is invalid") + return nil + } + wrappedEvt := RawTextToEvent[MetaType](ctx, evt, cmdString) + wrappedEvt.Proc = proc + wrappedEvt.Redact() + if !isMultiUse { + DeleteAllReactions(ctx, proc.Client, evt) + } + if cmdString == "" { + return nil + } + return wrappedEvt +} + +func DeleteAllReactionsCommandFunc[MetaType any](ce *Event[MetaType]) { + DeleteAllReactions(ce.Ctx, ce.Proc.Client, ce.Event) +} + +func DeleteAllReactions(ctx context.Context, client *mautrix.Client, evt *event.Event) { + rel, ok := evt.Content.Parsed.(event.Relatable) + if !ok { + return + } + relation := rel.OptionalGetRelatesTo() + if relation == nil { + return + } + targetEvt := relation.GetReplyTo() + if targetEvt == "" { + targetEvt = relation.GetAnnotationID() + } + if targetEvt == "" { + return + } + relations, err := client.GetRelations(ctx, evt.RoomID, targetEvt, &mautrix.ReqGetRelations{ + RelationType: event.RelAnnotation, + EventType: event.EventReaction, + Limit: 20, + }) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get reactions to delete") + return + } + for _, relEvt := range relations.Chunk { + _, err = client.RedactEvent(ctx, relEvt.RoomID, relEvt.ID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("event_id", relEvt.ID).Msg("Failed to redact reaction event") + } + } +} From 306b48bd6814f1744b234f0b38a600ca4ba3271d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 26 May 2025 19:34:51 +0300 Subject: [PATCH 1161/1647] bridgev2/ghost: ensure GetGhostByID can't return nil --- bridgev2/ghost.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index e4e007cd..087e0b64 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -85,7 +85,13 @@ func (br *Bridge) GetGhostByMXID(ctx context.Context, mxid id.UserID) (*Ghost, e func (br *Bridge) GetGhostByID(ctx context.Context, id networkid.UserID) (*Ghost, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() - return br.unlockedGetGhostByID(ctx, id, false) + ghost, err := br.unlockedGetGhostByID(ctx, id, false) + if err != nil { + return nil, err + } else if ghost == nil { + panic(fmt.Errorf("unlockedGetGhostByID(ctx, %q, false) returned nil", id)) + } + return ghost, nil } func (br *Bridge) GetExistingGhostByID(ctx context.Context, id networkid.UserID) (*Ghost, error) { From c5ef0f9d90addf7ba1b5658ff18cd331987fa21a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 26 May 2025 20:24:33 +0300 Subject: [PATCH 1162/1647] bridgev2/userlogin: ensure Client is filled in NewLogin --- bridgev2/userlogin.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index bf8f3bc6..9be3da3f 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -230,6 +230,9 @@ func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params err = params.LoadUserLogin(ul.Log.WithContext(context.Background()), ul) if err != nil { return nil, err + } else if ul.Client == nil { + ul.Log.Error().Msg("LoadUserLogin didn't fill Client in NewLogin") + return nil, fmt.Errorf("client not filled by LoadUserLogin") } if doInsert { err = user.Bridge.DB.UserLogin.Insert(ctx, ul.UserLogin) From 6ed660557b57d36ead3d240e34615a7c574d7495 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 26 May 2025 21:14:37 +0300 Subject: [PATCH 1163/1647] federation/signingkey: store raw response for validation --- federation/serverauth_test.go | 29 ++++++++++++++++++++++++ federation/signingkey.go | 42 +++++------------------------------ 2 files changed, 34 insertions(+), 37 deletions(-) create mode 100644 federation/serverauth_test.go diff --git a/federation/serverauth_test.go b/federation/serverauth_test.go new file mode 100644 index 00000000..d79dce36 --- /dev/null +++ b/federation/serverauth_test.go @@ -0,0 +1,29 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package federation_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/federation" +) + +func TestServerKeyResponse_VerifySelfSignature(t *testing.T) { + cli := federation.NewClient("", nil, nil) + ctx := context.Background() + for _, name := range []string{"matrix.org", "maunium.net", "continuwuity.org"} { + t.Run(name, func(t *testing.T) { + resp, err := cli.ServerKeys(ctx, "matrix.org") + require.NoError(t, err) + assert.True(t, resp.VerifySelfSignature()) + }) + } +} diff --git a/federation/signingkey.go b/federation/signingkey.go index 87c12a5e..a8362247 100644 --- a/federation/signingkey.go +++ b/federation/signingkey.go @@ -11,7 +11,6 @@ import ( "encoding/base64" "encoding/json" "fmt" - "maps" "strings" "time" @@ -82,7 +81,7 @@ type ServerKeyResponse struct { Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"` ValidUntilTS jsontime.UnixMilli `json:"valid_until_ts"` - Extra map[string]any `json:"-"` + Raw json.RawMessage `json:"-"` } func (skr *ServerKeyResponse) HasKey(keyID id.KeyID) bool { @@ -96,7 +95,7 @@ func (skr *ServerKeyResponse) HasKey(keyID id.KeyID) bool { func (skr *ServerKeyResponse) VerifySelfSignature() bool { for keyID, key := range skr.VerifyKeys { - if !VerifyJSON(skr.ServerName, keyID, key.Key, skr) { + if !VerifyJSON(skr.ServerName, keyID, key.Key, skr.Raw) { return false } } @@ -128,7 +127,7 @@ func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) } func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) bool { - sigBytes, err := base64.RawURLEncoding.DecodeString(sig) + sigBytes, err := base64.RawStdEncoding.DecodeString(sig) if err != nil { return false } @@ -142,40 +141,9 @@ func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) bool type marshalableSKR ServerKeyResponse -func (skr *ServerKeyResponse) MarshalJSON() ([]byte, error) { - if skr.Extra == nil { - return json.Marshal((*marshalableSKR)(skr)) - } - marshalable := maps.Clone(skr.Extra) - marshalable["server_name"] = skr.ServerName - marshalable["verify_keys"] = skr.VerifyKeys - marshalable["old_verify_keys"] = skr.OldVerifyKeys - marshalable["signatures"] = skr.Signatures - marshalable["valid_until_ts"] = skr.ValidUntilTS - return json.Marshal(skr.Extra) -} - func (skr *ServerKeyResponse) UnmarshalJSON(data []byte) error { - err := json.Unmarshal(data, (*marshalableSKR)(skr)) - if err != nil { - return err - } - var extra map[string]any - err = json.Unmarshal(data, &extra) - if err != nil { - return err - } - delete(extra, "server_name") - delete(extra, "verify_keys") - delete(extra, "old_verify_keys") - delete(extra, "signatures") - delete(extra, "valid_until_ts") - if len(extra) > 0 { - skr.Extra = extra - } else { - skr.Extra = nil - } - return nil + skr.Raw = data + return json.Unmarshal(data, (*marshalableSKR)(skr)) } type ServerVerifyKey struct { From 92311e5c9852b2459c5489854dc4fb9a54dd19d6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 26 May 2025 21:14:47 +0300 Subject: [PATCH 1164/1647] federation/client: fix QueryKeys return format --- federation/client.go | 2 +- federation/signingkey.go | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/federation/client.go b/federation/client.go index 93ed759c..4f26f4b4 100644 --- a/federation/client.go +++ b/federation/client.go @@ -52,7 +52,7 @@ func (c *Client) ServerKeys(ctx context.Context, serverName string) (resp *Serve return } -func (c *Client) QueryKeys(ctx context.Context, serverName string, req *ReqQueryKeys) (resp *ServerKeyResponse, err error) { +func (c *Client) QueryKeys(ctx context.Context, serverName string, req *ReqQueryKeys) (resp *QueryKeysResponse, err error) { err = c.MakeRequest(ctx, serverName, false, http.MethodPost, KeyURLPath{"v2", "query"}, req, &resp) return } diff --git a/federation/signingkey.go b/federation/signingkey.go index a8362247..c13e5f35 100644 --- a/federation/signingkey.go +++ b/federation/signingkey.go @@ -84,6 +84,10 @@ type ServerKeyResponse struct { Raw json.RawMessage `json:"-"` } +type QueryKeysResponse struct { + ServerKeys []*ServerKeyResponse `json:"server_keys"` +} + func (skr *ServerKeyResponse) HasKey(keyID id.KeyID) bool { if skr == nil { return false From a3d5da315fb828e76c01e6e9dca7062d54220cb7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 26 May 2025 21:58:35 +0300 Subject: [PATCH 1165/1647] federation: use errors in signature verification --- federation/client.go | 5 +++-- federation/serverauth.go | 11 +++++++---- federation/serverauth_test.go | 2 +- federation/signingkey.go | 33 ++++++++++++++++++++------------- 4 files changed, 31 insertions(+), 20 deletions(-) diff --git a/federation/client.go b/federation/client.go index 4f26f4b4..7c460d44 100644 --- a/federation/client.go +++ b/federation/client.go @@ -10,6 +10,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" "net/http" "net/url" @@ -398,10 +399,10 @@ type signableRequest struct { Content json.RawMessage `json:"content,omitempty"` } -func (r *signableRequest) Verify(key id.SigningKey, sig string) bool { +func (r *signableRequest) Verify(key id.SigningKey, sig string) error { message, err := json.Marshal(r) if err != nil { - return false + return fmt.Errorf("failed to marshal data: %w", err) } return VerifyJSONRaw(key, sig, message) } diff --git a/federation/serverauth.go b/federation/serverauth.go index 22ce8403..92860cc8 100644 --- a/federation/serverauth.go +++ b/federation/serverauth.go @@ -200,7 +200,10 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res Msg("Failed to query keys to authenticate request (cached error)") } return nil, &errFailedToQueryKeys - } else if !resp.VerifySelfSignature() { + } else if err := resp.VerifySelfSignature(); err != nil { + log.Trace().Err(err). + Str("server_name", parsed.Origin). + Msg("Failed to validate self-signatures of server keys") return nil, &errInvalidSelfSignatures } key, ok := resp.VerifyKeys[parsed.KeyID] @@ -226,15 +229,15 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res return nil, &errInvalidJSONBody } } - valid := (&signableRequest{ + err = (&signableRequest{ Method: r.Method, URI: r.URL.RawPath, Origin: parsed.Origin, Destination: destination, Content: reqBody, }).Verify(key.Key, parsed.Signature) - if !valid { - log.Trace().Msg("Request has invalid signature") + if err != nil { + log.Trace().Err(err).Msg("Request has invalid signature") return nil, &errInvalidRequestSignature } ctx := context.WithValue(r.Context(), contextKeyDestinationServer, destination) diff --git a/federation/serverauth_test.go b/federation/serverauth_test.go index d79dce36..9fa15459 100644 --- a/federation/serverauth_test.go +++ b/federation/serverauth_test.go @@ -23,7 +23,7 @@ func TestServerKeyResponse_VerifySelfSignature(t *testing.T) { t.Run(name, func(t *testing.T) { resp, err := cli.ServerKeys(ctx, "matrix.org") require.NoError(t, err) - assert.True(t, resp.VerifySelfSignature()) + assert.NoError(t, resp.VerifySelfSignature()) }) } } diff --git a/federation/signingkey.go b/federation/signingkey.go index c13e5f35..5b111947 100644 --- a/federation/signingkey.go +++ b/federation/signingkey.go @@ -10,6 +10,7 @@ import ( "crypto/ed25519" "encoding/base64" "encoding/json" + "errors" "fmt" "strings" "time" @@ -97,50 +98,56 @@ func (skr *ServerKeyResponse) HasKey(keyID id.KeyID) bool { return false } -func (skr *ServerKeyResponse) VerifySelfSignature() bool { +func (skr *ServerKeyResponse) VerifySelfSignature() error { for keyID, key := range skr.VerifyKeys { - if !VerifyJSON(skr.ServerName, keyID, key.Key, skr.Raw) { - return false + if err := VerifyJSON(skr.ServerName, keyID, key.Key, skr.Raw); err != nil { + return fmt.Errorf("failed to verify self signature for key %s: %w", keyID, err) } } - return true + return nil } -func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) bool { +func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) error { var err error message, ok := data.(json.RawMessage) if !ok { message, err = json.Marshal(data) if err != nil { - return false + return fmt.Errorf("failed to marshal data: %w", err) } } sigVal := gjson.GetBytes(message, exgjson.Path("signatures", serverName, string(keyID))) if sigVal.Type != gjson.String { - return false + return ErrSignatureNotFound } message, err = sjson.DeleteBytes(message, "signatures") if err != nil { - return false + return fmt.Errorf("failed to delete signatures: %w", err) } message, err = sjson.DeleteBytes(message, "unsigned") if err != nil { - return false + return fmt.Errorf("failed to delete unsigned: %w", err) } return VerifyJSONRaw(key, sigVal.Str, message) } -func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) bool { +var ErrSignatureNotFound = errors.New("signature not found") +var ErrInvalidSignature = errors.New("invalid signature") + +func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) error { sigBytes, err := base64.RawStdEncoding.DecodeString(sig) if err != nil { - return false + return fmt.Errorf("failed to decode signature: %w", err) } keyBytes, err := base64.RawStdEncoding.DecodeString(string(key)) if err != nil { - return false + return fmt.Errorf("failed to decode key: %w", err) } message = canonicaljson.CanonicalJSONAssumeValid(message) - return ed25519.Verify(keyBytes, message, sigBytes) + if !ed25519.Verify(keyBytes, message, sigBytes) { + return ErrInvalidSignature + } + return nil } type marshalableSKR ServerKeyResponse From c7fbfd150f9b2761dcc72eac09bd7f31c9fe51d1 Mon Sep 17 00:00:00 2001 From: nexy7574 Date: Mon, 26 May 2025 20:37:28 +0100 Subject: [PATCH 1166/1647] federation/serverauth: fix URI passed to signableRequest (#381) --- federation/serverauth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federation/serverauth.go b/federation/serverauth.go index 92860cc8..f46c7991 100644 --- a/federation/serverauth.go +++ b/federation/serverauth.go @@ -231,7 +231,7 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res } err = (&signableRequest{ Method: r.Method, - URI: r.URL.RawPath, + URI: r.URL.EscapedPath(), Origin: parsed.Origin, Destination: destination, Content: reqBody, From e7322f04b80ed68bb5f57657538b06fd91413c61 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 27 May 2025 09:14:12 +0300 Subject: [PATCH 1167/1647] bridgev2: fix handling some cases of context cancellation --- bridgev2/disappear.go | 2 +- bridgev2/portal.go | 12 +++++++++++- bridgev2/userlogin.go | 2 +- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go index d7b2182b..0eea8bc3 100644 --- a/bridgev2/disappear.go +++ b/bridgev2/disappear.go @@ -85,7 +85,7 @@ func (dl *DisappearLoop) Add(ctx context.Context, dm *database.DisappearingMessa Msg("Failed to save disappearing message") } if !dm.DisappearAt.IsZero() && dm.DisappearAt.Before(dl.NextCheck) { - go dl.sleepAndDisappear(context.WithoutCancel(ctx), dm) + go dl.sleepAndDisappear(zerolog.Ctx(ctx).WithContext(dl.br.BackgroundCtx), dm) } } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 63081f57..d769e9f1 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3226,6 +3226,10 @@ func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMem members.PowerLevels.Apply("", pl) members.memberListToMap(ctx) for _, member := range members.MemberMap { + if ctx.Err() != nil { + err = ctx.Err() + return + } if member.Membership != event.MembershipJoin && member.Membership != "" { continue } @@ -3403,6 +3407,9 @@ func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberL } } for _, member := range members.MemberMap { + if ctx.Err() != nil { + return ctx.Err() + } if member.Sender != "" && member.UserInfo != nil { ghost, err := portal.Bridge.GetGhostByID(ctx, member.Sender) if err != nil { @@ -3742,6 +3749,9 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } portal.UpdateInfo(ctx, info, source, nil, time.Time{}) + if ctx.Err() != nil { + return ctx.Err() + } powerLevels := &event.PowerLevelsEventContent{ Events: map[string]int{ @@ -3866,7 +3876,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } portal.Bridge.WakeupBackfillQueue() } - withoutCancelCtx := context.WithoutCancel(ctx) + withoutCancelCtx := zerolog.Ctx(ctx).WithContext(portal.Bridge.BackgroundCtx) if portal.Parent != nil { if portal.Parent.MXID != "" { portal.addToParentSpaceAndSave(ctx, true) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 9be3da3f..396cf899 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -301,7 +301,7 @@ func (ul *UserLogin) Delete(ctx context.Context, state status.BridgeState, opts if !opts.unlocked { ul.Bridge.cacheLock.Unlock() } - backgroundCtx := context.WithoutCancel(ctx) + backgroundCtx := zerolog.Ctx(ctx).WithContext(ul.Bridge.BackgroundCtx) if !opts.BlockingCleanup { go ul.deleteSpace(backgroundCtx) } else { From 34afb98ef05d0e9414221f784834a67a39850484 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 May 2025 16:57:49 +0530 Subject: [PATCH 1168/1647] event: fix parsing some url preview responses --- event/beeper.go | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/event/beeper.go b/event/beeper.go index 19c6253e..891204e5 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -9,7 +9,9 @@ package event import ( "encoding/base32" "encoding/binary" + "encoding/json" "fmt" + "strconv" "maunium.net/go/mautrix/id" ) @@ -81,6 +83,25 @@ type BeeperRoomKeyAckEventContent struct { FirstMessageIndex int `json:"first_message_index"` } +type IntOrString int + +func (ios *IntOrString) UnmarshalJSON(data []byte) error { + if len(data) > 0 && data[0] == '"' { + var str string + err := json.Unmarshal(data, &str) + if err != nil { + return err + } + intVal, err := strconv.Atoi(str) + if err != nil { + return err + } + *ios = IntOrString(intVal) + return nil + } + return json.Unmarshal(data, (*int)(ios)) +} + type LinkPreview struct { CanonicalURL string `json:"og:url,omitempty"` Title string `json:"og:title,omitempty"` @@ -90,10 +111,10 @@ type LinkPreview struct { ImageURL id.ContentURIString `json:"og:image,omitempty"` - ImageSize int `json:"matrix:image:size,omitempty"` - ImageWidth int `json:"og:image:width,omitempty"` - ImageHeight int `json:"og:image:height,omitempty"` - ImageType string `json:"og:image:type,omitempty"` + ImageSize IntOrString `json:"matrix:image:size,omitempty"` + ImageWidth IntOrString `json:"og:image:width,omitempty"` + ImageHeight IntOrString `json:"og:image:height,omitempty"` + ImageType string `json:"og:image:type,omitempty"` } // BeeperLinkPreview contains the data for a bundled URL preview as specified in MSC4095 From 140b20cab90a8b8d085ecba7535ee5fe3eb6ed6d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 May 2025 16:58:36 +0530 Subject: [PATCH 1169/1647] id: add utilities for validating server names --- id/servername.go | 58 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 id/servername.go diff --git a/id/servername.go b/id/servername.go new file mode 100644 index 00000000..591f394a --- /dev/null +++ b/id/servername.go @@ -0,0 +1,58 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package id + +import ( + "regexp" + "strconv" +) + +type ParsedServerNameType int + +const ( + ServerNameDNS ParsedServerNameType = iota + ServerNameIPv4 + ServerNameIPv6 +) + +type ParsedServerName struct { + Type ParsedServerNameType + Host string + Port int +} + +var ServerNameRegex = regexp.MustCompile(`^(?:\[([0-9A-Fa-f:.]{2,45})]|(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})|([0-9A-Za-z.-]{1,255}))(\d{1,5})?$`) + +func ValidateServerName(serverName string) bool { + return len(serverName) <= 255 && len(serverName) > 0 && ServerNameRegex.MatchString(serverName) +} + +func ParseServerName(serverName string) *ParsedServerName { + if len(serverName) > 255 || len(serverName) < 1 { + return nil + } + match := ServerNameRegex.FindStringSubmatch(serverName) + if len(match) != 5 { + return nil + } + port, _ := strconv.Atoi(match[4]) + parsed := &ParsedServerName{ + Port: port, + } + switch { + case match[1] != "": + parsed.Type = ServerNameIPv6 + parsed.Host = match[1] + case match[2] != "": + parsed.Type = ServerNameIPv4 + parsed.Host = match[2] + case match[3] != "": + parsed.Type = ServerNameDNS + parsed.Host = match[3] + } + return parsed +} From cdb99239d36c5227566d1396ca821a940820694d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 27 May 2025 14:05:40 +0530 Subject: [PATCH 1170/1647] bridgev2: add interfaces for reading up to stream order (#379) --- bridgev2/matrixinterface.go | 4 ++++ bridgev2/networkinterface.go | 5 +++++ bridgev2/portal.go | 30 ++++++++++++++++++++++-------- bridgev2/simplevent/receipt.go | 6 ++++++ 4 files changed, 37 insertions(+), 8 deletions(-) diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 4ccba353..ae1b99d7 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -171,6 +171,10 @@ type MatrixAPI interface { MuteRoom(ctx context.Context, roomID id.RoomID, until time.Time) error } +type StreamOrderReadingMatrixAPI interface { + MarkStreamOrderRead(ctx context.Context, roomID id.RoomID, streamOrder int64, ts time.Time) error +} + type MarkAsDMMatrixAPI interface { MarkAsDM(ctx context.Context, roomID id.RoomID, otherUser id.UserID) error } diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 76db8cc8..14d502e3 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -1124,6 +1124,11 @@ type RemoteReadReceipt interface { GetReadUpTo() time.Time } +type RemoteReadReceiptWithStreamOrder interface { + RemoteReadReceipt + GetReadUpToStreamOrder() int64 +} + type RemoteDeliveryReceipt interface { RemoteEvent GetReceiptTargets() []networkid.MessageID diff --git a/bridgev2/portal.go b/bridgev2/portal.go index d769e9f1..1b12c0da 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2693,17 +2693,31 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL log.Err(err).Time("read_up_to", readUpTo).Msg("Failed to get target message for read receipt") } } - if lastTarget == nil { - log.Warn().Msg("No target message found for read receipt") - return - } sender := evt.GetSender() intent := portal.GetIntentFor(ctx, sender, source, RemoteEventReadReceipt) - err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt)) - if err != nil { - log.Err(err).Stringer("target_mxid", lastTarget.MXID).Msg("Failed to bridge read receipt") + var addTargetLog func(evt *zerolog.Event) *zerolog.Event + if lastTarget == nil { + sevt, evtOK := evt.(RemoteReadReceiptWithStreamOrder) + soIntent, soIntentOK := intent.(StreamOrderReadingMatrixAPI) + if !evtOK || !soIntentOK || sevt.GetReadUpToStreamOrder() == 0 { + log.Warn().Msg("No target message found for read receipt") + return + } + targetStreamOrder := sevt.GetReadUpToStreamOrder() + addTargetLog = func(evt *zerolog.Event) *zerolog.Event { + return evt.Int64("target_stream_order", targetStreamOrder) + } + err = soIntent.MarkStreamOrderRead(ctx, portal.MXID, targetStreamOrder, getEventTS(evt)) } else { - log.Debug().Stringer("target_mxid", lastTarget.MXID).Msg("Bridged read receipt") + addTargetLog = func(evt *zerolog.Event) *zerolog.Event { + return evt.Stringer("target_mxid", lastTarget.MXID) + } + err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt)) + } + if err != nil { + addTargetLog(log.Err(err)).Msg("Failed to bridge read receipt") + } else { + addTargetLog(log.Debug()).Msg("Bridged read receipt") } if sender.IsFromMe { portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) diff --git a/bridgev2/simplevent/receipt.go b/bridgev2/simplevent/receipt.go index 3565986b..41614e40 100644 --- a/bridgev2/simplevent/receipt.go +++ b/bridgev2/simplevent/receipt.go @@ -19,6 +19,8 @@ type Receipt struct { LastTarget networkid.MessageID Targets []networkid.MessageID ReadUpTo time.Time + + ReadUpToStreamOrder int64 } var ( @@ -38,6 +40,10 @@ func (evt *Receipt) GetReadUpTo() time.Time { return evt.ReadUpTo } +func (evt *Receipt) GetReadUpToStreamOrder() int64 { + return evt.ReadUpToStreamOrder +} + type MarkUnread struct { EventMeta Unread bool From 8a745c0d03ec1f7318b8a962c7278544b537d417 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 27 May 2025 11:37:51 +0300 Subject: [PATCH 1171/1647] bridgev2/portal: allow always using deterministic ids for replies --- bridgev2/portal.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 1b12c0da..15412788 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1885,7 +1885,7 @@ func (portal *Portal) getRelationMeta(ctx context.Context, currentMsg networkid. if err != nil { log.Err(err).Msg("Failed to get reply target message from database") } else if replyTo == nil { - if isBatchSend { + if isBatchSend || portal.Bridge.Config.OutgoingMessageReID { // This is somewhat evil replyTo = &database.Message{ MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, replyToPtr.MessageID, ptr.Val(replyToPtr.PartID)), @@ -1900,7 +1900,7 @@ func (portal *Portal) getRelationMeta(ctx context.Context, currentMsg networkid. if err != nil { log.Err(err).Msg("Failed to get thread root message from database") } else if threadRoot == nil { - if isBatchSend { + if isBatchSend || portal.Bridge.Config.OutgoingMessageReID { threadRoot = &database.Message{ MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, *threadRootPtr, ""), } From 5c8ea2c2691ea15b8d4228ebcfd8462dcc8b2365 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 27 May 2025 15:54:46 +0300 Subject: [PATCH 1172/1647] synapseadmin: add wrapper for room delete status --- synapseadmin/roomapi.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go index fa391b73..a91f653f 100644 --- a/synapseadmin/roomapi.go +++ b/synapseadmin/roomapi.go @@ -147,6 +147,12 @@ func (cli *Client) DeleteRoom(ctx context.Context, roomID id.RoomID, req ReqDele return resp, err } +func (cli *Client) DeleteRoomStatus(ctx context.Context, deleteID string) (resp RespDeleteRoomStatus, err error) { + reqURL := cli.BuildAdminURL("v2", "rooms", "delete_status", deleteID) + _, err = cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) + return +} + // DeleteRoomSync deletes a room from the server, optionally blocking it and/or purging all data from the database. // // This calls the synchronous version of the endpoint, which will block until the room is deleted. From 0589b8757b438a582abc17697668a2b395ef1f90 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 27 May 2025 15:57:07 +0300 Subject: [PATCH 1173/1647] synapseadmin: fix response structs again --- synapseadmin/roomapi.go | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go index a91f653f..a09ba174 100644 --- a/synapseadmin/roomapi.go +++ b/synapseadmin/roomapi.go @@ -127,14 +127,19 @@ type RespDeleteRoom struct { DeleteID string `json:"delete_id"` } -type RespDeleteRoomStatus struct { - Status string `json:"status,omitempty"` +type RespDeleteRoomResult struct { KickedUsers []id.UserID `json:"kicked_users,omitempty"` FailedToKickUsers []id.UserID `json:"failed_to_kick_users,omitempty"` LocalAliases []id.RoomAlias `json:"local_aliases,omitempty"` NewRoomID id.RoomID `json:"new_room_id,omitempty"` } +type RespDeleteRoomStatus struct { + Status string `json:"status,omitempty"` + Error string `json:"error,omitempty"` + ShutdownRoom RespDeleteRoomResult `json:"shutdown_room,omitempty"` +} + // DeleteRoom deletes a room from the server, optionally blocking it and/or purging all data from the database. // // This calls the async version of the endpoint, which will return immediately and delete the room in the background. @@ -158,7 +163,7 @@ func (cli *Client) DeleteRoomStatus(ctx context.Context, deleteID string) (resp // This calls the synchronous version of the endpoint, which will block until the room is deleted. // // https://element-hq.github.io/synapse/latest/admin_api/rooms.html#version-1-old-version -func (cli *Client) DeleteRoomSync(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (resp RespDeleteRoomStatus, err error) { +func (cli *Client) DeleteRoomSync(ctx context.Context, roomID id.RoomID, req ReqDeleteRoom) (resp RespDeleteRoomResult, err error) { reqURL := cli.BuildAdminURL("v1", "rooms", roomID) httpClient := &http.Client{} _, err = cli.Client.MakeFullRequest(ctx, mautrix.FullRequest{ From 50cc3d4d470508f253e6aa9a31b6a3e462a328e7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 27 May 2025 16:37:51 +0300 Subject: [PATCH 1174/1647] bridgev2/queue: fix context used for queueing remote events --- bridgev2/queue.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 2981bdce..3d329b22 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -156,7 +156,7 @@ func (ul *UserLogin) QueueRemoteEvent(evt RemoteEvent) { func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { log := login.Log - ctx := log.WithContext(context.TODO()) + ctx := log.WithContext(br.BackgroundCtx) maybeUncertain, ok := evt.(RemoteEventWithUncertainPortalReceiver) isUncertain := ok && maybeUncertain.PortalReceiverIsUncertain() key := evt.GetPortalKey() From a3092e5195fad2ef1cd71ecbda5c750ed023c9c8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 27 May 2025 16:45:39 +0300 Subject: [PATCH 1175/1647] bridgev2/portal: don't do initial backfill in background --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 15412788..784b1590 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3943,7 +3943,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } } } - if portal.Bridge.Config.Backfill.Enabled && portal.RoomType != database.RoomTypeSpace { + if portal.Bridge.Config.Backfill.Enabled && portal.RoomType != database.RoomTypeSpace && !portal.Bridge.Background { portal.doForwardBackfill(ctx, source, nil, backfillBundle) } return nil From f5746ee0f68d7a98e615b02ba58ede4ae0103a42 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 27 May 2025 18:04:52 +0300 Subject: [PATCH 1176/1647] event: add omitempty for mod policy entity Only one of hash and entity should be set --- event/state.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/event/state.go b/event/state.go index 006ed2a5..028691e1 100644 --- a/event/state.go +++ b/event/state.go @@ -231,7 +231,7 @@ func (ph *PolicyHashes) DecodeSHA256() *[32]byte { // ModPolicyContent represents the content of a m.room.rule.user, m.room.rule.room, and m.room.rule.server state event. // https://spec.matrix.org/v1.2/client-server-api/#moderation-policy-lists type ModPolicyContent struct { - Entity string `json:"entity"` + Entity string `json:"entity,omitempty"` Reason string `json:"reason"` Recommendation PolicyRecommendation `json:"recommendation"` UnstableHashes *PolicyHashes `json:"org.matrix.msc4205.hashes,omitempty"` From d89130ba76e87f8cd7f818eb141f92d46adaacf5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 28 May 2025 21:12:48 +0300 Subject: [PATCH 1177/1647] bridgev2/provisioning: fix returning wait errors Closes #382 --- bridgev2/matrix/provisioning.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index d05b005e..83f56fa0 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -500,10 +500,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques nextStep, err := login.Process.(bridgev2.LoginProcessDisplayAndWait).Wait(r.Context()) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to wait") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - Err: "Failed to wait", - ErrCode: "M_UNKNOWN", - }) + RespondWithError(w, err, "Internal error waiting for login") return } login.NextStep = nextStep From 64f55ac3a7eb9fba7fc0ad74e2900135253fea52 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 28 May 2025 21:24:15 +0300 Subject: [PATCH 1178/1647] bridgev2/provisioning: use exhttp utilities for writing responses --- bridgev2/matrix/provisioning.go | 196 ++++++++------------------------ 1 file changed, 49 insertions(+), 147 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 83f56fa0..2a84bdf2 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -21,6 +21,7 @@ import ( "github.com/rs/xid" "github.com/rs/zerolog" "github.com/rs/zerolog/hlog" + "go.mau.fi/util/exhttp" "go.mau.fi/util/exstrings" "go.mau.fi/util/jsontime" "go.mau.fi/util/requestlog" @@ -118,7 +119,7 @@ func (prov *ProvisioningAPI) Init() { prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter() prov.Router.Use(hlog.NewHandler(prov.log)) prov.Router.Use(hlog.RequestIDHandler("request_id", "Request-Id")) - prov.Router.Use(corsMiddleware) + prov.Router.Use(exhttp.CORSMiddleware) prov.Router.Use(requestlog.AccessLogger(false)) prov.Router.Use(prov.AuthMiddleware) prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami) @@ -152,25 +153,6 @@ func (prov *ProvisioningAPI) Init() { } } -func corsMiddleware(handler http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization") - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusOK) - return - } - handler.ServeHTTP(w, r) - }) -} - -func jsonResponse(w http.ResponseWriter, status int, response any) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(response) -} - func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.UserID, token string) error { prov.matrixAuthCacheLock.Lock() defer prov.matrixAuthCacheLock.Unlock() @@ -216,15 +198,9 @@ func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") if auth == "" { - jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ - Err: "Missing auth token", - ErrCode: mautrix.MMissingToken.ErrCode, - }) + mautrix.MMissingToken.WithMessage("Missing auth token").Write(w) } else if !exstrings.ConstantTimeEqual(auth, prov.br.Config.Provisioning.SharedSecret) { - jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ - Err: "Invalid auth token", - ErrCode: mautrix.MUnknownToken.ErrCode, - }) + mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w) } else { h.ServeHTTP(w, r) } @@ -238,10 +214,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { auth = prov.GetAuthFromRequest(r) } if auth == "" { - jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ - Err: "Missing auth token", - ErrCode: mautrix.MMissingToken.ErrCode, - }) + mautrix.MMissingToken.WithMessage("Missing auth token").Write(w) return } userID := id.UserID(r.URL.Query().Get("user_id")) @@ -258,29 +231,20 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { if err != nil { zerolog.Ctx(r.Context()).Warn().Err(err). Msg("Provisioning API request contained invalid auth") - jsonResponse(w, http.StatusUnauthorized, &mautrix.RespError{ - Err: "Invalid auth token", - ErrCode: mautrix.MUnknownToken.ErrCode, - }) + mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w) return } } user, err := prov.br.Bridge.GetUserByMXID(r.Context(), userID) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get user") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - Err: "Failed to get user", - ErrCode: "M_UNKNOWN", - }) + mautrix.MUnknown.WithMessage("Failed to get user").Write(w) return } // TODO handle user being nil? // TODO per-endpoint permissions? if !user.Permissions.Login { - jsonResponse(w, http.StatusForbidden, &mautrix.RespError{ - Err: "User does not have login permissions", - ErrCode: mautrix.MForbidden.ErrCode, - }) + mautrix.MForbidden.WithMessage("User does not have login permissions").Write(w) return } @@ -292,10 +256,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { prov.loginsLock.RUnlock() if !ok { zerolog.Ctx(r.Context()).Warn().Str("login_id", loginID).Msg("Login not found") - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - Err: "Login not found", - ErrCode: mautrix.MNotFound.ErrCode, - }) + mautrix.MNotFound.WithMessage("Login not found").Write(w) return } login.Lock.Lock() @@ -307,10 +268,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { Str("request_step_id", stepID). Str("expected_step_id", login.NextStep.StepID). Msg("Step ID does not match") - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Step ID does not match", - ErrCode: mautrix.MBadState.ErrCode, - }) + mautrix.MBadState.WithMessage("Step ID does not match").Write(w) return } stepType := mux.Vars(r)["stepType"] @@ -319,10 +277,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { Str("request_step_type", stepType). Str("expected_step_type", string(login.NextStep.Type)). Msg("Step type does not match") - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Step type does not match", - ErrCode: mautrix.MBadState.ErrCode, - }) + mautrix.MBadState.WithMessage("Step type does not match").Write(w) return } ctx = context.WithValue(ctx, provisioningLoginProcessKey, login) @@ -391,7 +346,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { SpaceRoom: login.SpaceRoom, } } - jsonResponse(w, http.StatusOK, resp) + exhttp.WriteJSONResponse(w, http.StatusOK, resp) } type RespLoginFlows struct { @@ -404,7 +359,7 @@ type RespSubmitLogin struct { } func (prov *ProvisioningAPI) GetLoginFlows(w http.ResponseWriter, r *http.Request) { - jsonResponse(w, http.StatusOK, &RespLoginFlows{ + exhttp.WriteJSONResponse(w, http.StatusOK, &RespLoginFlows{ Flows: prov.net.GetLoginFlows(), }) } @@ -445,7 +400,7 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque Override: overrideLogin, } prov.loginsLock.Unlock() - jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep}) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep}) } func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *ProvLogin, step *bridgev2.LoginStep) { @@ -467,10 +422,7 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http err := json.NewDecoder(r.Body).Decode(¶ms) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Failed to decode request body", - ErrCode: mautrix.MNotJSON.ErrCode, - }) + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) return } login := r.Context().Value(provisioningLoginProcessKey).(*ProvLogin) @@ -492,7 +444,7 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http if nextStep.Type == bridgev2.LoginStepTypeComplete { prov.handleCompleteStep(r.Context(), login, nextStep) } - jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Request) { @@ -507,7 +459,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques if nextStep.Type == bridgev2.LoginStepTypeComplete { prov.handleCompleteStep(r.Context(), login, nextStep) } - jsonResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) { @@ -524,15 +476,12 @@ func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) } else { userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID) if userLogin == nil || userLogin.UserMXID != user.MXID { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - Err: "Login not found", - ErrCode: mautrix.MNotFound.ErrCode, - }) + mautrix.MNotFound.WithMessage("Login not found").Write(w) return } userLogin.Logout(r.Context()) } - jsonResponse(w, http.StatusOK, json.RawMessage("{}")) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } type RespGetLogins struct { @@ -541,7 +490,7 @@ type RespGetLogins struct { func (prov *ProvisioningAPI) GetLogins(w http.ResponseWriter, r *http.Request) { user := prov.GetUser(r) - jsonResponse(w, http.StatusOK, &RespGetLogins{LoginIDs: user.GetUserLoginIDs()}) + exhttp.WriteJSONResponse(w, http.StatusOK, &RespGetLogins{LoginIDs: user.GetUserLoginIDs()}) } func (prov *ProvisioningAPI) GetExplicitLoginForRequest(w http.ResponseWriter, r *http.Request) (*bridgev2.UserLogin, bool) { @@ -551,15 +500,18 @@ func (prov *ProvisioningAPI) GetExplicitLoginForRequest(w http.ResponseWriter, r } userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID) if userLogin == nil || userLogin.UserMXID != prov.GetUser(r).MXID { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - Err: "Login not found", - ErrCode: mautrix.MNotFound.ErrCode, - }) + mautrix.MNotFound.WithMessage("Login not found").Write(w) return nil, true } return userLogin, false } +var ErrNotLoggedIn = mautrix.RespError{ + Err: "Not logged in", + ErrCode: "FI.MAU.NOT_LOGGED_IN", + StatusCode: http.StatusBadRequest, +} + func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.Request) *bridgev2.UserLogin { userLogin, failed := prov.GetExplicitLoginForRequest(w, r) if userLogin != nil || failed { @@ -567,10 +519,7 @@ func (prov *ProvisioningAPI) GetLoginForRequest(w http.ResponseWriter, r *http.R } userLogin = prov.GetUser(r).GetDefaultLogin() if userLogin == nil { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Not logged in", - ErrCode: "FI.MAU.NOT_LOGGED_IN", - }) + ErrNotLoggedIn.Write(w) return nil } return userLogin @@ -585,11 +534,7 @@ func RespondWithError(w http.ResponseWriter, err error, message string) { if errors.As(err, &we) { we.Write(w) } else { - mautrix.RespError{ - Err: message, - ErrCode: "M_UNKNOWN", - StatusCode: http.StatusInternalServerError, - }.Write(w) + mautrix.MUnknown.WithMessage(message).Write(w) } } @@ -609,10 +554,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. } api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI) if !ok { - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - Err: "This bridge does not support resolving identifiers", - ErrCode: mautrix.MUnrecognized.ErrCode, - }) + mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers").Write(w) return } resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat) @@ -621,10 +563,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. RespondWithError(w, err, "Internal error resolving identifier") return } else if resp == nil { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - ErrCode: mautrix.MNotFound.ErrCode, - Err: "Identifier not found", - }) + mautrix.MNotFound.WithMessage("Identifier not found").Write(w) return } apiResp := &RespResolveIdentifier{ @@ -647,10 +586,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. resp.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(r.Context(), resp.Chat.PortalKey) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - Err: "Failed to get portal", - ErrCode: "M_UNKNOWN", - }) + mautrix.MUnknown.WithMessage("Failed to get portal").Write(w) return } } @@ -659,16 +595,13 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. err = resp.Chat.Portal.CreateMatrixRoom(r.Context(), login, resp.Chat.PortalInfo) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create portal room") - jsonResponse(w, http.StatusInternalServerError, &mautrix.RespError{ - Err: "Failed to create portal room", - ErrCode: "M_UNKNOWN", - }) + mautrix.MUnknown.WithMessage("Failed to create portal room").Write(w) return } } apiResp.DMRoomID = resp.Chat.Portal.MXID } - jsonResponse(w, status, apiResp) + exhttp.WriteJSONResponse(w, status, apiResp) } type RespGetContactList struct { @@ -723,10 +656,7 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque } api, ok := login.Client.(bridgev2.ContactListingNetworkAPI) if !ok { - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - Err: "This bridge does not support listing contacts", - ErrCode: mautrix.MUnrecognized.ErrCode, - }) + mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts").Write(w) return } resp, err := api.GetContactList(r.Context()) @@ -735,7 +665,7 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque RespondWithError(w, err, "Internal error fetching contact list") return } - jsonResponse(w, http.StatusOK, &RespGetContactList{ + exhttp.WriteJSONResponse(w, http.StatusOK, &RespGetContactList{ Contacts: prov.processResolveIdentifiers(r.Context(), resp), }) } @@ -753,10 +683,7 @@ func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Requ err := json.NewDecoder(r.Body).Decode(&req) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Failed to decode request body", - ErrCode: mautrix.MNotJSON.ErrCode, - }) + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) return } login := prov.GetLoginForRequest(w, r) @@ -765,10 +692,7 @@ func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Requ } api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI) if !ok { - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - Err: "This bridge does not support searching for users", - ErrCode: mautrix.MUnrecognized.ErrCode, - }) + mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users").Write(w) return } resp, err := api.SearchUsers(r.Context(), req.Query) @@ -777,7 +701,7 @@ func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Requ RespondWithError(w, err, "Internal error fetching contact list") return } - jsonResponse(w, http.StatusOK, &RespSearchUsers{ + exhttp.WriteJSONResponse(w, http.StatusOK, &RespSearchUsers{ Results: prov.processResolveIdentifiers(r.Context(), resp), }) } @@ -795,10 +719,7 @@ func (prov *ProvisioningAPI) PostCreateGroup(w http.ResponseWriter, r *http.Requ if login == nil { return } - jsonResponse(w, http.StatusNotImplemented, &mautrix.RespError{ - Err: "Creating groups is not yet implemented", - ErrCode: mautrix.MUnrecognized.ErrCode, - }) + mautrix.MUnrecognized.WithMessage("Creating groups is not yet implemented").Write(w) } type ReqExportCredentials struct { @@ -817,10 +738,7 @@ func (prov *ProvisioningAPI) PostInitSessionTransfer(w http.ResponseWriter, r *h err := json.NewDecoder(r.Body).Decode(&req) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Failed to decode request body", - ErrCode: mautrix.MNotJSON.ErrCode, - }) + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) return } @@ -834,19 +752,13 @@ func (prov *ProvisioningAPI) PostInitSessionTransfer(w http.ResponseWriter, r *h } } if loginToExport == nil { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - Err: "No matching user login found", - ErrCode: mautrix.MNotFound.ErrCode, - }) + mautrix.MNotFound.WithMessage("No matching user login found").Write(w) return } client, ok := loginToExport.Client.(bridgev2.CredentialExportingNetworkAPI) if !ok { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Client does not support credential exporting", - ErrCode: mautrix.MInvalidParam.ErrCode, - }) + mautrix.MUnrecognized.WithMessage("This bridge does not support exporting credentials").Write(w) return } @@ -858,10 +770,9 @@ func (prov *ProvisioningAPI) PostInitSessionTransfer(w http.ResponseWriter, r *h // Disconnect now so we don't use the same network session in two places at once client.Disconnect() - resp := RespExportCredentials{ + exhttp.WriteJSONResponse(w, http.StatusOK, &RespExportCredentials{ Credentials: client.ExportCredentials(r.Context()), - } - jsonResponse(w, http.StatusOK, resp) + }) } func (prov *ProvisioningAPI) PostFinishSessionTransfer(w http.ResponseWriter, r *http.Request) { @@ -872,10 +783,7 @@ func (prov *ProvisioningAPI) PostFinishSessionTransfer(w http.ResponseWriter, r err := json.NewDecoder(r.Body).Decode(&req) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "Failed to decode request body", - ErrCode: mautrix.MNotJSON.ErrCode, - }) + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) return } @@ -889,16 +797,10 @@ func (prov *ProvisioningAPI) PostFinishSessionTransfer(w http.ResponseWriter, r } } if loginToExport == nil { - jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ - Err: "No matching user login found", - ErrCode: mautrix.MNotFound.ErrCode, - }) + mautrix.MNotFound.WithMessage("No matching user login found").Write(w) return } else if _, ok := prov.sessionTransfers[loginToExport.ID]; !ok { - jsonResponse(w, http.StatusBadRequest, &mautrix.RespError{ - Err: "No matching credential export found", - ErrCode: mautrix.MNotJSON.ErrCode, - }) + mautrix.MBadState.WithMessage("No matching credential export found").Write(w) return } @@ -909,5 +811,5 @@ func (prov *ProvisioningAPI) PostFinishSessionTransfer(w http.ResponseWriter, r loginToExport.Client.LogoutRemote(r.Context()) delete(prov.sessionTransfers, req.RemoteID) - jsonResponse(w, http.StatusOK, struct{}{}) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } From 53d027c06ffb1cc85df7e66555b4a2b8b2e9ec2b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 28 May 2025 21:34:46 +0300 Subject: [PATCH 1179/1647] appservice: replace custom response utilities with RespError and exhttp --- appservice/http.go | 73 +++++++++++------------------------------ appservice/protocol.go | 51 +--------------------------- appservice/websocket.go | 8 +++-- 3 files changed, 25 insertions(+), 107 deletions(-) diff --git a/appservice/http.go b/appservice/http.go index 66c7bc5b..1ebe6e56 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -19,6 +19,7 @@ import ( "github.com/gorilla/mux" "github.com/rs/zerolog" + "go.mau.fi/util/exhttp" "go.mau.fi/util/exstrings" "maunium.net/go/mautrix" @@ -79,17 +80,9 @@ func (as *AppService) Stop() { func (as *AppService) CheckServerToken(w http.ResponseWriter, r *http.Request) (isValid bool) { authHeader := r.Header.Get("Authorization") if !strings.HasPrefix(authHeader, "Bearer ") { - Error{ - ErrorCode: ErrUnknownToken, - HTTPStatus: http.StatusForbidden, - Message: "Missing access token", - }.Write(w) + mautrix.MMissingToken.WithMessage("Missing access token").Write(w) } else if !exstrings.ConstantTimeEqual(authHeader[len("Bearer "):], as.Registration.ServerToken) { - Error{ - ErrorCode: ErrUnknownToken, - HTTPStatus: http.StatusForbidden, - Message: "Incorrect access token", - }.Write(w) + mautrix.MUnknownToken.WithMessage("Invalid access token").Write(w) } else { isValid = true } @@ -105,21 +98,13 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) txnID := vars["txnID"] if len(txnID) == 0 { - Error{ - ErrorCode: ErrNoTransactionID, - HTTPStatus: http.StatusBadRequest, - Message: "Missing transaction ID", - }.Write(w) + mautrix.MInvalidParam.WithMessage("Missing transaction ID").Write(w) return } defer r.Body.Close() body, err := io.ReadAll(r.Body) if err != nil || len(body) == 0 { - Error{ - ErrorCode: ErrNotJSON, - HTTPStatus: http.StatusBadRequest, - Message: "Missing request body", - }.Write(w) + mautrix.MNotJSON.WithMessage("Failed to read response body").Write(w) return } log := as.Log.With().Str("transaction_id", txnID).Logger() @@ -128,7 +113,7 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { ctx = log.WithContext(ctx) if as.txnIDC.IsProcessed(txnID) { // Duplicate transaction ID: no-op - WriteBlankOK(w) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) log.Debug().Msg("Ignoring duplicate transaction") return } @@ -137,14 +122,10 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { err = json.Unmarshal(body, &txn) if err != nil { log.Error().Err(err).Msg("Failed to parse transaction content") - Error{ - ErrorCode: ErrBadJSON, - HTTPStatus: http.StatusBadRequest, - Message: "Failed to parse body JSON", - }.Write(w) + mautrix.MBadJSON.WithMessage("Failed to parse transaction content").Write(w) } else { as.handleTransaction(ctx, txnID, &txn) - WriteBlankOK(w) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } } @@ -263,12 +244,9 @@ func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) { roomAlias := vars["roomAlias"] ok := as.QueryHandler.QueryAlias(roomAlias) if ok { - WriteBlankOK(w) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } else { - Error{ - ErrorCode: ErrUnknown, - HTTPStatus: http.StatusNotFound, - }.Write(w) + mautrix.MNotFound.WithMessage("Alias not found").Write(w) } } @@ -282,12 +260,9 @@ func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) { userID := id.UserID(vars["userID"]) ok := as.QueryHandler.QueryUser(userID) if ok { - WriteBlankOK(w) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } else { - Error{ - ErrorCode: ErrUnknown, - HTTPStatus: http.StatusNotFound, - }.Write(w) + mautrix.MNotFound.WithMessage("User not found").Write(w) } } @@ -297,11 +272,7 @@ func (as *AppService) PostPing(w http.ResponseWriter, r *http.Request) { } body, err := io.ReadAll(r.Body) if err != nil || len(body) == 0 || !json.Valid(body) { - Error{ - ErrorCode: ErrNotJSON, - HTTPStatus: http.StatusBadRequest, - Message: "Missing request body", - }.Write(w) + mautrix.MNotJSON.WithMessage("Invalid or missing request body").Write(w) return } @@ -309,27 +280,21 @@ func (as *AppService) PostPing(w http.ResponseWriter, r *http.Request) { _ = json.Unmarshal(body, &txn) as.Log.Debug().Str("txn_id", txn.TxnID).Msg("Received ping from homeserver") - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - w.Write([]byte("{}")) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } func (as *AppService) GetLive(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-Type", "application/json") if as.Live { - w.WriteHeader(http.StatusOK) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } else { - w.WriteHeader(http.StatusInternalServerError) + exhttp.WriteEmptyJSONResponse(w, http.StatusInternalServerError) } - w.Write([]byte("{}")) } func (as *AppService) GetReady(w http.ResponseWriter, r *http.Request) { - w.Header().Add("Content-Type", "application/json") if as.Ready { - w.WriteHeader(http.StatusOK) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } else { - w.WriteHeader(http.StatusInternalServerError) + exhttp.WriteEmptyJSONResponse(w, http.StatusInternalServerError) } - w.Write([]byte("{}")) } diff --git a/appservice/protocol.go b/appservice/protocol.go index 7a9891ef..7c493bcb 100644 --- a/appservice/protocol.go +++ b/appservice/protocol.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -7,9 +7,7 @@ package appservice import ( - "encoding/json" "fmt" - "net/http" "strings" "github.com/rs/zerolog" @@ -103,50 +101,3 @@ func (txn *Transaction) ContentString() string { // EventListener is a function that receives events. type EventListener func(evt *event.Event) - -// WriteBlankOK writes a blank OK message as a reply to a HTTP request. -func WriteBlankOK(w http.ResponseWriter) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("{}")) -} - -// Respond responds to a HTTP request with a JSON object. -func Respond(w http.ResponseWriter, data interface{}) error { - w.Header().Add("Content-Type", "application/json") - dataStr, err := json.Marshal(data) - if err != nil { - return err - } - _, err = w.Write(dataStr) - return err -} - -// Error represents a Matrix protocol error. -type Error struct { - HTTPStatus int `json:"-"` - ErrorCode ErrorCode `json:"errcode"` - Message string `json:"error"` -} - -func (err Error) Write(w http.ResponseWriter) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(err.HTTPStatus) - _ = Respond(w, &err) -} - -// ErrorCode is the machine-readable code in an Error. -type ErrorCode string - -// Native ErrorCodes -const ( - ErrUnknownToken ErrorCode = "M_UNKNOWN_TOKEN" - ErrBadJSON ErrorCode = "M_BAD_JSON" - ErrNotJSON ErrorCode = "M_NOT_JSON" - ErrUnknown ErrorCode = "M_UNKNOWN" -) - -// Custom ErrorCodes -const ( - ErrNoTransactionID ErrorCode = "NET.MAUNIUM.NO_TRANSACTION_ID" -) diff --git a/appservice/websocket.go b/appservice/websocket.go index 598d70d1..3d5bd232 100644 --- a/appservice/websocket.go +++ b/appservice/websocket.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -23,6 +23,8 @@ import ( "github.com/rs/zerolog" "github.com/tidwall/gjson" "github.com/tidwall/sjson" + + "maunium.net/go/mautrix" ) type WebsocketRequest struct { @@ -371,12 +373,12 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { "X-Mautrix-Websocket-Version": []string{"3"}, }) if resp != nil && resp.StatusCode >= 400 { - var errResp Error + var errResp mautrix.RespError err = json.NewDecoder(resp.Body).Decode(&errResp) if err != nil { return fmt.Errorf("websocket request returned HTTP %d with non-JSON body", resp.StatusCode) } else { - return fmt.Errorf("websocket request returned %s (HTTP %d): %s", errResp.ErrorCode, resp.StatusCode, errResp.Message) + return fmt.Errorf("websocket request returned %s (HTTP %d): %s", errResp.ErrCode, resp.StatusCode, errResp.Err) } } else if err != nil { return fmt.Errorf("failed to open websocket: %w", err) From f73480446c6f5377fa0d419012fa4a3d24fab0f4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 28 May 2025 21:39:34 +0300 Subject: [PATCH 1180/1647] mediaproxy: remove deprecated custom ResponseError struct --- mediaproxy/mediaproxy.go | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index d76439a1..1300a305 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -8,7 +8,6 @@ package mediaproxy import ( "context" - "encoding/json" "errors" "fmt" "io" @@ -223,16 +222,6 @@ func (mp *MediaProxy) RegisterRoutes(router *mux.Router) { mp.KeyServer.Register(router) } -// Deprecated: use mautrix.RespError instead -type ResponseError struct { - Status int - Data any -} - -func (err *ResponseError) Error() string { - return fmt.Sprintf("HTTP %d: %v", err.Status, err.Data) -} - var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax") func queryToMap(vals url.Values) map[string]string { @@ -247,17 +236,11 @@ func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaR mediaID := mux.Vars(r)["mediaID"] resp, err := mp.GetMedia(r.Context(), mediaID, queryToMap(r.URL.Query())) if err != nil { - //lint:ignore SA1019 deprecated types need to be supported until they're removed - var respError *ResponseError var mautrixRespError mautrix.RespError if errors.Is(err, ErrInvalidMediaIDSyntax) { mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w) } else if errors.As(err, &mautrixRespError) { mautrixRespError.Write(w) - } else if errors.As(err, &respError) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(respError.Status) - _ = json.NewEncoder(w).Encode(respError.Data) } else { zerolog.Ctx(r.Context()).Err(err).Str("media_id", mediaID).Msg("Failed to get media URL") mautrix.MNotFound.WithMessage("Media not found").Write(w) From 842f21b24f9fe25f09b0484bdee85c4ac5ffe411 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 28 May 2025 22:25:17 +0300 Subject: [PATCH 1181/1647] bridgev2/provisioning: add log when explicitly specified login ID is not found --- bridgev2/matrix/provisioning.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 2a84bdf2..571e3c7f 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -500,6 +500,9 @@ func (prov *ProvisioningAPI) GetExplicitLoginForRequest(w http.ResponseWriter, r } userLogin := prov.br.Bridge.GetCachedUserLoginByID(userLoginID) if userLogin == nil || userLogin.UserMXID != prov.GetUser(r).MXID { + hlog.FromRequest(r).Warn(). + Str("login_id", string(userLoginID)). + Msg("Tried to use non-existent login, returning 404") mautrix.MNotFound.WithMessage("Login not found").Write(w) return nil, true } From 3473f918645d1dbd22a97a27a7f9b4a2b7a7156f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 29 May 2025 13:58:05 +0300 Subject: [PATCH 1182/1647] bridgev2/portal: add some default log context fields for remote events --- bridgev2/networkinterface.go | 13 +++++++++++++ bridgev2/portal.go | 18 ++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 14d502e3..2b99e4e6 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -77,6 +77,19 @@ type EventSender struct { ForceDMUser bool } +func (es EventSender) MarshalZerologObject(evt *zerolog.Event) { + evt.Str("user_id", string(es.Sender)) + if string(es.SenderLogin) != string(es.Sender) { + evt.Str("sender_login", string(es.SenderLogin)) + } + if es.IsFromMe { + evt.Bool("is_from_me", true) + } + if es.ForceDMUser { + evt.Bool("force_dm_user", true) + } +} + type ConvertedMessage struct { ReplyTo *networkid.MessageOptionalPartID ThreadRoot *networkid.MessageID diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 784b1590..126c64a9 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -372,6 +372,24 @@ func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context { Str("source_id", string(evt.source.ID)). Stringer("bridge_evt_type", evt.evtType) logWith = evt.evt.AddLogContext(logWith) + if remoteSender := evt.evt.GetSender(); remoteSender.Sender != "" || remoteSender.IsFromMe { + logWith = logWith.Object("remote_sender", remoteSender) + } + if remoteMsg, ok := evt.evt.(RemoteMessage); ok { + if remoteMsgID := remoteMsg.GetID(); remoteMsgID != "" { + logWith = logWith.Str("remote_message_id", string(remoteMsgID)) + } + } + if remoteMsg, ok := evt.evt.(RemoteEventWithTargetMessage); ok { + if targetMsgID := remoteMsg.GetTargetMessage(); targetMsgID != "" { + logWith = logWith.Str("remote_target_message_id", string(targetMsgID)) + } + } + if remoteMsg, ok := evt.evt.(RemoteEventWithStreamOrder); ok { + if remoteStreamOrder := remoteMsg.GetStreamOrder(); remoteStreamOrder != 0 { + logWith = logWith.Int64("remote_stream_order", remoteStreamOrder) + } + } case *portalCreateEvent: return evt.ctx } From e859fd8333411060120815e7ec0523a4bae54457 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Fri, 30 May 2025 15:09:43 +0100 Subject: [PATCH 1183/1647] bridgev2/bridgeconfig: add missing copy for session transfer config --- bridgev2/bridgeconfig/upgrade.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 18b98263..3e19bf8f 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -109,6 +109,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Str, "provisioning", "shared_secret") } helper.Copy(up.Bool, "provisioning", "debug_endpoints") + helper.Copy(up.Bool, "provisioning", "enable_session_transfers") helper.Copy(up.Bool, "direct_media", "enabled") helper.Copy(up.Str|up.Null, "direct_media", "media_id_prefix") From 1b1b83298c33820c61a79151bc422b86e1da2183 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 30 May 2025 12:11:23 +0300 Subject: [PATCH 1184/1647] client,bridgev2: use time.After instead of sleep --- bridgev2/disappear.go | 6 +++++- client.go | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go index 0eea8bc3..1d063088 100644 --- a/bridgev2/disappear.go +++ b/bridgev2/disappear.go @@ -91,7 +91,11 @@ func (dl *DisappearLoop) Add(ctx context.Context, dm *database.DisappearingMessa func (dl *DisappearLoop) sleepAndDisappear(ctx context.Context, dms ...*database.DisappearingMessage) { for _, msg := range dms { - time.Sleep(time.Until(msg.DisappearAt)) + select { + case <-time.After(time.Until(msg.DisappearAt)): + case <-ctx.Done(): + return + } resp, err := dl.br.Bot.SendMessage(ctx, msg.RoomID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ Redacts: msg.EventID, diff --git a/client.go b/client.go index bf3bb16e..5e2189e9 100644 --- a/client.go +++ b/client.go @@ -555,7 +555,11 @@ func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff log.Warn().Err(cause). Int("retry_in_seconds", int(backoff.Seconds())). Msg("Request failed, retrying") - time.Sleep(backoff) + select { + case <-time.After(backoff): + case <-req.Context().Done(): + return nil, nil, req.Context().Err() + } if cli.UpdateRequestOnRetry != nil { req = cli.UpdateRequestOnRetry(req, cause) } From 788621f7e03503283124bf23103be39b2a4139b4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 31 May 2025 17:59:25 +0300 Subject: [PATCH 1185/1647] bridgev2/crypto: fix ghost ID format in db queries --- bridgev2/matrix/crypto.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index 6e6416a9..47226625 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -14,6 +14,7 @@ import ( "fmt" "os" "runtime/debug" + "strings" "sync" "time" @@ -77,7 +78,7 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { dbutil.ZeroLogger(helper.bridge.Log.With().Str("db_section", "crypto").Logger()), string(helper.bridge.Bridge.ID), helper.bridge.AS.BotMXID(), - fmt.Sprintf("@%s:%s", helper.bridge.Config.AppService.FormatUsername("%"), helper.bridge.AS.HomeserverDomain), + fmt.Sprintf("@%s:%s", strings.ReplaceAll(helper.bridge.Config.AppService.FormatUsername("%"), "_", `\_`), helper.bridge.AS.HomeserverDomain), helper.bridge.Config.Encryption.PickleKey, ) From 8fb04d1806970a1b9387ca423c01affe22ca56ee Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 2 Jun 2025 20:04:19 +0300 Subject: [PATCH 1186/1647] id/matrixuri: fix parsing url-encoded matrix URIs --- id/matrixuri.go | 11 +++++++++-- id/matrixuri_test.go | 4 ++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/id/matrixuri.go b/id/matrixuri.go index 2637d876..8f5ec849 100644 --- a/id/matrixuri.go +++ b/id/matrixuri.go @@ -210,7 +210,11 @@ func ProcessMatrixURI(uri *url.URL) (*MatrixURI, error) { if len(parts[1]) == 0 { return nil, ErrEmptySecondSegment } - parsed.MXID1 = parts[1] + var err error + parsed.MXID1, err = url.PathUnescape(parts[1]) + if err != nil { + return nil, fmt.Errorf("failed to url decode second segment %q: %w", parts[1], err) + } // Step 6: if the first part is a room and the URI has 4 segments, construct a second level identifier if parsed.Sigil1 == '!' && len(parts) == 4 { @@ -226,7 +230,10 @@ func ProcessMatrixURI(uri *url.URL) (*MatrixURI, error) { if len(parts[3]) == 0 { return nil, ErrEmptyFourthSegment } - parsed.MXID2 = parts[3] + parsed.MXID2, err = url.PathUnescape(parts[3]) + if err != nil { + return nil, fmt.Errorf("failed to url decode fourth segment %q: %w", parts[3], err) + } } // Step 7: parse the query and extract via and action items diff --git a/id/matrixuri_test.go b/id/matrixuri_test.go index 8b1096cb..90a0754d 100644 --- a/id/matrixuri_test.go +++ b/id/matrixuri_test.go @@ -77,8 +77,12 @@ func TestParseMatrixURI_RoomID(t *testing.T) { parsedVia, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org") require.NoError(t, err) require.NotNil(t, parsedVia) + parsedEncoded, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl%3Aexample.org") + require.NoError(t, err) + require.NotNil(t, parsedEncoded) assert.Equal(t, roomIDLink, *parsed) + assert.Equal(t, roomIDLink, *parsedEncoded) assert.Equal(t, roomIDViaLink, *parsedVia) } From 522a373c688150e473a1b386b83023705ac88519 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 3 Jun 2025 00:32:07 +0300 Subject: [PATCH 1187/1647] id: validate server names in UserID.ParseAndValidate --- id/userid.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/id/userid.go b/id/userid.go index 59136013..6d9f4080 100644 --- a/id/userid.go +++ b/id/userid.go @@ -30,10 +30,11 @@ func NewEncodedUserID(localpart, homeserver string) UserID { } var ( - ErrInvalidUserID = errors.New("is not a valid user ID") - ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed") - ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters") - ErrEmptyLocalpart = errors.New("empty localparts are not allowed") + ErrInvalidUserID = errors.New("is not a valid user ID") + ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed") + ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters") + ErrEmptyLocalpart = errors.New("empty localparts are not allowed") + ErrNoncompliantServerPart = errors.New("is not a valid server name") ) // ParseCommonIdentifier parses a common identifier according to https://spec.matrix.org/v1.9/appendices/#common-identifier-format @@ -113,6 +114,9 @@ func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error if err == nil && len(userID) > UserIDMaxLength { err = ErrUserIDTooLong } + if err == nil && !ValidateServerName(homeserver) { + err = fmt.Errorf("%q %q", homeserver, ErrNoncompliantServerPart) + } return } From d804b5d96187c90d8bd18b466e7d1fb4aee3a510 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 4 Jun 2025 14:48:22 +0300 Subject: [PATCH 1188/1647] client: add support for stable version of room summary endpoint --- client.go | 8 ++++++-- versions.go | 2 ++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 5e2189e9..6f746015 100644 --- a/client.go +++ b/client.go @@ -1062,13 +1062,17 @@ func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, ex } func (cli *Client) GetRoomSummary(ctx context.Context, roomIDOrAlias string, via ...string) (resp *RespRoomSummary, err error) { + urlPath := ClientURLPath{"unstable", "im.nheko.summary", "summary", roomIDOrAlias} + if cli.SpecVersions.ContainsGreaterOrEqual(SpecV115) { + urlPath = ClientURLPath{"v1", "room_summary", roomIDOrAlias} + } // TODO add version check after one is added to MSC3266 - urlPath := cli.BuildURLWithFullQuery(ClientURLPath{"unstable", "im.nheko.summary", "summary", roomIDOrAlias}, func(q url.Values) { + fullURL := cli.BuildURLWithFullQuery(urlPath, func(q url.Values) { if len(via) > 0 { q["via"] = via } }) - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + _, err = cli.MakeRequest(ctx, http.MethodGet, fullURL, nil, &resp) return } diff --git a/versions.go b/versions.go index 7e752986..f87bddda 100644 --- a/versions.go +++ b/versions.go @@ -115,6 +115,8 @@ var ( SpecV111 = MustParseSpecVersion("v1.11") SpecV112 = MustParseSpecVersion("v1.12") SpecV113 = MustParseSpecVersion("v1.13") + SpecV114 = MustParseSpecVersion("v1.14") + SpecV115 = MustParseSpecVersion("v1.15") ) func (svf SpecVersionFormat) String() string { From baf4cc3ee43e541ba583bf3974c42c1b95b7d4e1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 4 Jun 2025 16:13:16 +0300 Subject: [PATCH 1189/1647] bridgev2/portal: log start time when event handling takes long --- bridgev2/portal.go | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 126c64a9..5cc7bb4a 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -329,7 +329,9 @@ func (portal *Portal) handleSingleEventAsync(idx int, rawEvt any) { handleDuration = time.Since(start) close(doneCh) if backgrounded.Load() { - log.Debug().Stringer("duration", handleDuration). + log.Debug(). + Time("started_at", start). + Stringer("duration", handleDuration). Msg("Event that took too long finally finished handling") } }) @@ -339,15 +341,21 @@ func (portal *Portal) handleSingleEventAsync(idx int, rawEvt any) { select { case <-doneCh: if i > 0 { - log.Debug().Stringer("duration", handleDuration). + log.Debug(). + Time("started_at", start). + Stringer("duration", handleDuration). Msg("Event that took long finished handling") } return case <-tick.C: - log.Warn().Msg("Event handling is taking long") + log.Warn(). + Time("started_at", start). + Msg("Event handling is taking long") } } - log.Warn().Msg("Event handling is taking too long, continuing in background") + log.Warn(). + Time("started_at", start). + Msg("Event handling is taking too long, continuing in background") backgrounded.Store(true) } } From 1e10d9460ad216e3aadbd1a98871807eb5846995 Mon Sep 17 00:00:00 2001 From: Brad Murray Date: Wed, 4 Jun 2025 11:37:00 -0400 Subject: [PATCH 1190/1647] bridgev2/status: add RESTART UserAction (#384) --- bridgev2/status/bridgestate.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/status/bridgestate.go b/bridgev2/status/bridgestate.go index 005a4f62..01a235a0 100644 --- a/bridgev2/status/bridgestate.go +++ b/bridgev2/status/bridgestate.go @@ -78,6 +78,7 @@ type BridgeStateUserAction string const ( UserActionOpenNative BridgeStateUserAction = "OPEN_NATIVE" UserActionRelogin BridgeStateUserAction = "RELOGIN" + UserActionRestart BridgeStateUserAction = "RESTART" ) type RemoteProfile struct { From d228995d718ae930a4c51df5fcafcfa05227c53e Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Thu, 5 Jun 2025 07:25:48 +0300 Subject: [PATCH 1191/1647] bridgev2: Configurable disconnect timeout (#383) Let the caller decide if they want to have a timeout or not. For standalone bridges using the Bridge struct the behavior is kept the same by waiting for five seconds when UserLogin DisconnectWithTimeout() is called. --- bridgev2/bridge.go | 19 +++++++++++++------ bridgev2/matrix/mxmain/main.go | 2 +- bridgev2/userlogin.go | 31 ++++++++++++++++++++++--------- 3 files changed, 36 insertions(+), 16 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 05a67b6a..5e3b74b7 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -137,7 +137,7 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa if err != nil { return err } - defer br.Stop() + defer br.StopWithTimeout(5 * time.Second) select { case <-time.After(20 * time.Second): case <-ctx.Done(): @@ -145,7 +145,7 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa return nil } - defer br.stop(true) + defer br.stop(true, 5*time.Second) login, err := br.GetExistingUserLoginByID(ctx, loginID) if err != nil { return fmt.Errorf("failed to get user login: %w", err) @@ -156,7 +156,7 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa if !ok { br.Log.Warn().Msg("Network connector doesn't implement background mode, using fallback mechanism for RunOnce") login.Client.Connect(ctx) - defer login.Disconnect(nil) + defer login.DisconnectWithTimeout(5 * time.Second) select { case <-time.After(20 * time.Second): case <-ctx.Done(): @@ -319,10 +319,14 @@ func (br *Bridge) StartLogins(ctx context.Context) error { } func (br *Bridge) Stop() { - br.stop(false) + br.stop(false, 0) } -func (br *Bridge) stop(isRunOnce bool) { +func (br *Bridge) StopWithTimeout(timeout time.Duration) { + br.stop(false, timeout) +} + +func (br *Bridge) stop(isRunOnce bool, timeout time.Duration) { br.Log.Info().Msg("Shutting down bridge") br.DisappearLoop.Stop() br.stopBackfillQueue.Set() @@ -332,7 +336,10 @@ func (br *Bridge) stop(isRunOnce bool) { var wg sync.WaitGroup wg.Add(len(br.userLoginsByID)) for _, login := range br.userLoginsByID { - go login.Disconnect(wg.Done) + go func() { + login.DisconnectWithTimeout(timeout) + wg.Done() + }() } br.cacheLock.Unlock() wg.Wait() diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 63334ba5..e6219c50 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -421,7 +421,7 @@ func (br *BridgeMain) TriggerStop(exitCode int) { // Stop cleanly stops the bridge. This is called by [Run] and does not need to be called manually. func (br *BridgeMain) Stop() { - br.Bridge.Stop() + br.Bridge.StopWithTimeout(5 * time.Second) } // InitVersion formats the bridge version and build time nicely for things like diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 396cf899..e83e66c2 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -279,7 +279,8 @@ func (ul *UserLogin) Delete(ctx context.Context, state status.BridgeState, opts if opts.LogoutRemote { ul.Client.LogoutRemote(ctx) } else { - ul.Disconnect(nil) + // we probably shouldn't delete the login if disconnect isn't finished + ul.Disconnect() } var portals []*database.UserPortal var err error @@ -508,10 +509,11 @@ func (ul *UserLogin) FillBridgeState(state status.BridgeState) status.BridgeStat return state } -func (ul *UserLogin) Disconnect(done func()) { - if done != nil { - defer done() - } +func (ul *UserLogin) Disconnect() { + ul.DisconnectWithTimeout(0) +} + +func (ul *UserLogin) DisconnectWithTimeout(timeout time.Duration) { client := ul.Client if client != nil { ul.Client = nil @@ -520,10 +522,21 @@ func (ul *UserLogin) Disconnect(done func()) { client.Disconnect() close(disconnected) }() - select { - case <-disconnected: - case <-time.After(5 * time.Second): - ul.Log.Warn().Msg("Client disconnection timed out") + + var timeoutC <-chan time.Time + if timeout > 0 { + timeoutC = time.After(timeout) + } + for { + select { + case <-disconnected: + return + case <-time.After(2 * time.Second): + ul.Log.Warn().Msg("Client disconnection taking long") + case <-timeoutC: + ul.Log.Error().Msg("Client disconnection timed out") + return + } } } } From d04d524209dbf1ce36a9379b687859cf8a6e8e01 Mon Sep 17 00:00:00 2001 From: Brad Murray Date: Thu, 5 Jun 2025 13:38:19 -0400 Subject: [PATCH 1192/1647] crypto/verificationhelper: add method to verification done callback (#385) --- crypto/verificationhelper/callbacks_test.go | 2 +- crypto/verificationhelper/reciprocate.go | 4 ++-- crypto/verificationhelper/verificationhelper.go | 14 +++++++------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/crypto/verificationhelper/callbacks_test.go b/crypto/verificationhelper/callbacks_test.go index 18cb964f..3b943f28 100644 --- a/crypto/verificationhelper/callbacks_test.go +++ b/crypto/verificationhelper/callbacks_test.go @@ -109,7 +109,7 @@ func (c *baseVerificationCallbacks) VerificationCancelled(ctx context.Context, t } } -func (c *baseVerificationCallbacks) VerificationDone(ctx context.Context, txnID id.VerificationTransactionID) { +func (c *baseVerificationCallbacks) VerificationDone(ctx context.Context, txnID id.VerificationTransactionID, method event.VerificationMethod) { c.doneTransactions[txnID] = struct{}{} } diff --git a/crypto/verificationhelper/reciprocate.go b/crypto/verificationhelper/reciprocate.go index 9cb84c24..d8827b8b 100644 --- a/crypto/verificationhelper/reciprocate.go +++ b/crypto/verificationhelper/reciprocate.go @@ -182,7 +182,7 @@ func (vh *VerificationHelper) HandleScannedQRData(ctx context.Context, data []by if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { return err } - vh.verificationDone(ctx, txn.TransactionID) + vh.verificationDone(ctx, txn.TransactionID, txn.StartEventContent.Method) } else { return vh.store.SaveVerificationTransaction(ctx, txn) } @@ -263,7 +263,7 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id if err = vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { return err } - vh.verificationDone(ctx, txn.TransactionID) + vh.verificationDone(ctx, txn.TransactionID, txn.StartEventContent.Method) } else { return vh.store.SaveVerificationTransaction(ctx, txn) } diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 8d99dacc..9d843ea8 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -41,7 +41,7 @@ type RequiredCallbacks interface { VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) // VerificationDone is called when the verification is done. - VerificationDone(ctx context.Context, txnID id.VerificationTransactionID) + VerificationDone(ctx context.Context, txnID id.VerificationTransactionID, method event.VerificationMethod) } type ShowSASCallbacks interface { @@ -70,14 +70,14 @@ type VerificationHelper struct { verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID) verificationReady func(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, supportsScanQRCode bool, qrCode *QRCode) verificationCancelledCallback func(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) - verificationDone func(ctx context.Context, txnID id.VerificationTransactionID) + verificationDone func(ctx context.Context, txnID id.VerificationTransactionID, method event.VerificationMethod) // showSAS is a callback that will be called after the SAS verification // dance is complete and we want the client to show the emojis/decimals showSAS func(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int) - // qrCodeScaned is a callback that will be called when the other device + // qrCodeScanned is a callback that will be called when the other device // scanned the QR code we are showing - qrCodeScaned func(ctx context.Context, txnID id.VerificationTransactionID) + qrCodeScanned func(ctx context.Context, txnID id.VerificationTransactionID) } var _ mautrix.VerificationHelper = (*VerificationHelper)(nil) @@ -120,7 +120,7 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, stor } else { helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeShow) helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate) - helper.qrCodeScaned = c.QRCodeScanned + helper.qrCodeScanned = c.QRCodeScanned } } if supportsQRScan { @@ -839,7 +839,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif return } txn.VerificationState = VerificationStateOurQRScanned - vh.qrCodeScaned(ctx, txn.TransactionID) + vh.qrCodeScanned(ctx, txn.TransactionID) if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { log.Err(err).Msg("failed to save verification transaction") } @@ -875,7 +875,7 @@ func (vh *VerificationHelper) onVerificationDone(ctx context.Context, txn Verifi if err := vh.store.DeleteVerification(ctx, txn.TransactionID); err != nil { log.Err(err).Msg("Delete verification failed") } - vh.verificationDone(ctx, txn.TransactionID) + vh.verificationDone(ctx, txn.TransactionID, txn.StartEventContent.Method) } else if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil { log.Err(err).Msg("failed to save verification transaction") } From d296f7b6604bb2f95b8f548ec38ce87e6e6f6c50 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 6 Jun 2025 14:08:04 +0300 Subject: [PATCH 1193/1647] bridgev2/provisioning: ensure that Start returns a non-nil first step --- bridgev2/matrix/provisioning.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 571e3c7f..e897bae8 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -390,6 +390,10 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque zerolog.Ctx(r.Context()).Err(err).Msg("Failed to start login") RespondWithError(w, err, "Internal error starting login") return + } else if firstStep == nil { + zerolog.Ctx(r.Context()).Error().Msg("Bridge returned nil first step in Start with no error") + RespondWithError(w, err, "Internal error starting login") + return } loginID := xid.New().String() prov.loginsLock.Lock() From 40fd8dfcbd71fce9163f6510671f920f9c7aad43 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 8 Jun 2025 00:05:59 +0300 Subject: [PATCH 1194/1647] event/relations: use unstable prefix for reply room ID field --- event/relations.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/event/relations.go b/event/relations.go index e855a7e6..2316cbc7 100644 --- a/event/relations.go +++ b/event/relations.go @@ -34,7 +34,7 @@ type RelatesTo struct { type InReplyTo struct { EventID id.EventID `json:"event_id,omitempty"` - UnstableRoomID id.RoomID `json:"room_id,omitempty"` + UnstableRoomID id.RoomID `json:"com.beeper.cross_room_id,omitempty"` } func (rel *RelatesTo) Copy() *RelatesTo { From 07567f6f96d6dc420e37f561a3474c3bcdf98b1e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 8 Jun 2025 00:12:53 +0300 Subject: [PATCH 1195/1647] bridgev2/portal: include room id in cross-room replies --- bridgev2/bridgeconfig/config.go | 1 + bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/matrix/mxmain/example-config.yaml | 3 +++ bridgev2/portal.go | 23 +++++++++++++++++++--- bridgev2/portalbackfill.go | 2 +- bridgev2/portalinternal.go | 20 +++++++++++++++++-- 6 files changed, 44 insertions(+), 6 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 37517818..bd6f53c3 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -72,6 +72,7 @@ type BridgeConfig struct { OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"` MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"` + CrossRoomReplies bool `yaml:"cross_room_replies"` OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` Relay RelayConfig `yaml:"relay"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 3e19bf8f..fa4b4493 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -38,6 +38,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.List, "bridge", "only_bridge_tags") helper.Copy(up.Bool, "bridge", "mute_only_on_create") helper.Copy(up.Bool, "bridge", "deduplicate_matrix_messages") + helper.Copy(up.Bool, "bridge", "cross_room_replies") helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "relayed") diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index a9d05fd1..dad3f8a8 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -40,6 +40,9 @@ bridge: mute_only_on_create: true # Should the bridge check the db to ensure that incoming events haven't been handled before deduplicate_matrix_messages: false + # Should cross-room reply metadata be bridged? + # Most Matrix clients don't support this and servers may reject such messages too. + cross_room_replies: false # What should be done to portal rooms when a user logs out or is logged out? # Permitted values: diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 5cc7bb4a..333c1889 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1943,7 +1943,7 @@ func (portal *Portal) getRelationMeta(ctx context.Context, currentMsg networkid. return } -func (portal *Portal) applyRelationMeta(content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { +func (portal *Portal) applyRelationMeta(ctx context.Context, content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { if content.Mentions == nil { content.Mentions = &event.Mentions{} } @@ -1951,7 +1951,24 @@ func (portal *Portal) applyRelationMeta(content *event.MessageEventContent, repl content.GetRelatesTo().SetThread(threadRoot.MXID, prevThreadEvent.MXID) } if replyTo != nil { - content.GetRelatesTo().SetReplyTo(replyTo.MXID) + crossRoom := replyTo.Room != portal.PortalKey + if !crossRoom || portal.Bridge.Config.CrossRoomReplies { + content.GetRelatesTo().SetReplyTo(replyTo.MXID) + } + if crossRoom && portal.Bridge.Config.CrossRoomReplies { + targetPortal, err := portal.Bridge.GetExistingPortalByKey(ctx, replyTo.Room) + if err != nil { + zerolog.Ctx(ctx).Err(err). + Object("target_portal_key", replyTo.Room). + Msg("Failed to get cross-room reply portal") + } else if targetPortal == nil || targetPortal.MXID == "" { + zerolog.Ctx(ctx).Warn(). + Object("target_portal_key", replyTo.Room). + Msg("Cross-room reply portal not found") + } else { + content.RelatesTo.InReplyTo.UnstableRoomID = targetPortal.MXID + } + } content.Mentions.Add(replyTo.SenderMXID) } } @@ -1975,7 +1992,7 @@ func (portal *Portal) sendConvertedMessage( replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, id, converted.ReplyTo, converted.ThreadRoot, false) output := make([]*database.Message, 0, len(converted.Parts)) for i, part := range converted.Parts { - portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) + portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent) dbMessage := &database.Message{ ID: id, PartID: part.ID, diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index a5dfb42a..3953a043 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -333,7 +333,7 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin var firstPart *database.Message for i, part := range msg.Parts { partIDs = append(partIDs, part.ID) - portal.applyRelationMeta(part.Content, replyTo, threadRoot, prevThreadEvent) + portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent) evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID) dbMessage := &database.Message{ ID: msg.ID, diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index e0f4ee5a..fd6724f4 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -89,10 +89,22 @@ func (portal *PortalInternals) CheckMessageContentCaps(ctx context.Context, caps return (*Portal)(portal).checkMessageContentCaps(ctx, caps, content, evt) } +func (portal *PortalInternals) ParseInputTransactionID(origSender *OrigSender, evt *event.Event) networkid.RawTransactionID { + return (*Portal)(portal).parseInputTransactionID(origSender, evt) +} + func (portal *PortalInternals) HandleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { (*Portal)(portal).handleMatrixMessage(ctx, sender, origSender, evt) } +func (portal *PortalInternals) PendingMessageTimeoutLoop(ctx context.Context, cfg *OutgoingTimeoutConfig) { + (*Portal)(portal).pendingMessageTimeoutLoop(ctx, cfg) +} + +func (portal *PortalInternals) CheckPendingMessages(ctx context.Context, cfg *OutgoingTimeoutConfig) { + (*Portal)(portal).checkPendingMessages(ctx, cfg) +} + func (portal *PortalInternals) HandleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *event.RoomFeatures) { (*Portal)(portal).handleMatrixEdit(ctx, sender, origSender, evt, content, caps) } @@ -129,8 +141,8 @@ func (portal *PortalInternals) GetRelationMeta(ctx context.Context, currentMsg n return (*Portal)(portal).getRelationMeta(ctx, currentMsg, replyToPtr, threadRootPtr, isBatchSend) } -func (portal *PortalInternals) ApplyRelationMeta(content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { - (*Portal)(portal).applyRelationMeta(content, replyTo, threadRoot, prevThreadEvent) +func (portal *PortalInternals) ApplyRelationMeta(ctx context.Context, content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { + (*Portal)(portal).applyRelationMeta(ctx, content, replyTo, threadRoot, prevThreadEvent) } func (portal *PortalInternals) SendConvertedMessage(ctx context.Context, id networkid.MessageID, intent MatrixAPI, senderID networkid.UserID, converted *ConvertedMessage, ts time.Time, streamOrder int64, logContext func(*zerolog.Event) *zerolog.Event) []*database.Message { @@ -241,6 +253,10 @@ func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts) } +func (portal *PortalInternals) GetBridgeInfoStateKey() string { + return (*Portal)(portal).getBridgeInfoStateKey() +} + func (portal *PortalInternals) GetBridgeInfo() (string, event.BridgeEventContent) { return (*Portal)(portal).getBridgeInfo() } From 8fb41765e2ff525e94ee9ff01bcff19c42ab3036 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 9 Jun 2025 13:52:24 +0300 Subject: [PATCH 1196/1647] event: add custom soft fail fields --- event/events.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/event/events.go b/event/events.go index 92cc39ae..07a2c7cb 100644 --- a/event/events.go +++ b/event/events.go @@ -151,10 +151,14 @@ type Unsigned struct { BeeperHSSuborder int16 `json:"com.beeper.hs.suborder,omitempty"` BeeperHSOrderString *BeeperEncodedOrder `json:"com.beeper.hs.order_string,omitempty"` BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"` + + MauSoftFailed bool `json:"fi.mau.soft_failed,omitempty"` + MauRejectionReason string `json:"fi.mau.rejection_reason,omitempty"` } func (us *Unsigned) IsEmpty() bool { return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil && - us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 && us.BeeperHSOrderString.IsZero() + us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 && us.BeeperHSOrderString.IsZero() && + !us.MauSoftFailed && us.MauRejectionReason == "" } From 05f371a48092429a04c6ebb3b81ea3a10368228b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 9 Jun 2025 14:00:13 +0300 Subject: [PATCH 1197/1647] event: add membership field to unsigned --- event/events.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/event/events.go b/event/events.go index 07a2c7cb..a763cc31 100644 --- a/event/events.go +++ b/event/events.go @@ -140,6 +140,7 @@ type StrippedState struct { type Unsigned struct { PrevContent *Content `json:"prev_content,omitempty"` PrevSender id.UserID `json:"prev_sender,omitempty"` + Membership Membership `json:"membership,omitempty"` ReplacesState id.EventID `json:"replaces_state,omitempty"` Age int64 `json:"age,omitempty"` TransactionID string `json:"transaction_id,omitempty"` @@ -157,7 +158,7 @@ type Unsigned struct { } func (us *Unsigned) IsEmpty() bool { - return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && + return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && us.Membership == "" && us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil && us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 && us.BeeperHSOrderString.IsZero() && !us.MauSoftFailed && us.MauRejectionReason == "" From a154718b5d4cd43a61dd56de141edbab9eed9aa4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 9 Jun 2025 17:24:46 +0530 Subject: [PATCH 1198/1647] bridgev2/portal: allow specifying extra fields for portal members (#386) --- bridgev2/ghost.go | 23 ++++++++------ bridgev2/portal.go | 76 +++++++++++++++++++++++++++++----------------- 2 files changed, 61 insertions(+), 38 deletions(-) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index 087e0b64..f06c0363 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -185,6 +185,18 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool { return true } +func (ghost *Ghost) getExtraProfileMeta() *event.BeeperProfileExtra { + bridgeName := ghost.Bridge.Network.GetName() + return &event.BeeperProfileExtra{ + RemoteID: string(ghost.ID), + Identifiers: ghost.Identifiers, + Service: bridgeName.BeeperBridgeType, + Network: bridgeName.NetworkID, + IsBridgeBot: false, + IsNetworkBot: ghost.IsBot, + } +} + func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool) bool { if identifiers != nil { slices.Sort(identifiers) @@ -200,16 +212,7 @@ func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, if isBot != nil { ghost.IsBot = *isBot } - bridgeName := ghost.Bridge.Network.GetName() - meta := &event.BeeperProfileExtra{ - RemoteID: string(ghost.ID), - Identifiers: ghost.Identifiers, - Service: bridgeName.BeeperBridgeType, - Network: bridgeName.NetworkID, - IsBridgeBot: false, - IsNetworkBot: ghost.IsBot, - } - err := ghost.Intent.SetExtraProfileMeta(ctx, meta) + err := ghost.Intent.SetExtraProfileMeta(ctx, ghost.getExtraProfileMeta()) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to set extra profile metadata") } else { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 333c1889..1a57d211 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2926,7 +2926,8 @@ type ChatMember struct { PowerLevel *int UserInfo *UserInfo - PrevMembership event.Membership + MemberEventExtra map[string]any + PrevMembership event.Membership } type ChatMemberList struct { @@ -3347,7 +3348,13 @@ func (portal *Portal) updateOtherUser(ctx context.Context, members *ChatMemberLi return false } -func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error { +func (portal *Portal) syncParticipants( + ctx context.Context, + members *ChatMemberList, + source *UserLogin, + sender MatrixAPI, + ts time.Time, +) error { members.memberListToMap(ctx) var loginsInPortal []*UserLogin var err error @@ -3371,7 +3378,7 @@ func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberL } delete(currentMembers, portal.Bridge.Bot.GetMXID()) powerChanged := members.PowerLevels.Apply(portal.Bridge.Bot.GetMXID(), currentPower) - syncUser := func(extraUserID id.UserID, member ChatMember, hasIntent bool) bool { + syncUser := func(extraUserID id.UserID, member ChatMember, intent MatrixAPI) bool { if member.Membership == "" { member.Membership = event.MembershipJoin } @@ -3400,58 +3407,71 @@ func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberL Displayname: currentMember.Displayname, AvatarURL: currentMember.AvatarURL, } - wrappedContent := &event.Content{Parsed: content, Raw: make(map[string]any)} + wrappedContent := &event.Content{Parsed: content, Raw: maps.Clone(member.MemberEventExtra)} + if wrappedContent.Raw == nil { + wrappedContent.Raw = make(map[string]any) + } thisEvtSender := sender if member.Membership == event.MembershipJoin { content.Membership = event.MembershipInvite - if hasIntent { + if intent != nil { wrappedContent.Raw["fi.mau.will_auto_accept"] = true } if thisEvtSender.GetMXID() == extraUserID { thisEvtSender = portal.Bridge.Bot } } + addLogContext := func(e *zerolog.Event) *zerolog.Event { + return e.Stringer("target_user_id", extraUserID). + Stringer("sender_user_id", thisEvtSender.GetMXID()). + Str("prev_membership", string(currentMember.Membership)) + } if currentMember != nil && currentMember.Membership == event.MembershipBan && member.Membership != event.MembershipLeave { unbanContent := *content unbanContent.Membership = event.MembershipLeave wrappedUnbanContent := &event.Content{Parsed: &unbanContent} _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedUnbanContent, ts) if err != nil { - log.Err(err). - Stringer("target_user_id", extraUserID). - Stringer("sender_user_id", thisEvtSender.GetMXID()). - Str("prev_membership", string(currentMember.Membership)). - Str("membership", string(member.Membership)). + addLogContext(log.Err(err)). + Str("new_membership", string(unbanContent.Membership)). Msg("Failed to unban user to update membership") } else { - log.Trace(). - Stringer("target_user_id", extraUserID). - Stringer("sender_user_id", thisEvtSender.GetMXID()). - Str("prev_membership", string(currentMember.Membership)). - Str("membership", string(member.Membership)). + addLogContext(log.Trace()). + Str("new_membership", string(unbanContent.Membership)). Msg("Unbanned user to update membership") + currentMember.Membership = event.MembershipLeave } } _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedContent, ts) if err != nil { - log.Err(err). - Stringer("target_user_id", extraUserID). - Stringer("sender_user_id", thisEvtSender.GetMXID()). - Str("prev_membership", string(currentMember.Membership)). - Str("membership", string(member.Membership)). + addLogContext(log.Err(err)). + Str("new_membership", string(content.Membership)). Msg("Failed to update user membership") } else { - log.Trace(). - Stringer("target_user_id", extraUserID). - Stringer("sender_user_id", thisEvtSender.GetMXID()). - Str("prev_membership", string(currentMember.Membership)). - Str("membership", string(member.Membership)). - Msg("Updating membership in room") + addLogContext(log.Trace()). + Str("new_membership", string(content.Membership)). + Msg("Updated membership in room") + currentMember.Membership = content.Membership + + if intent != nil && content.Membership == event.MembershipInvite && member.Membership == event.MembershipJoin { + content.Membership = event.MembershipJoin + wrappedJoinContent := &event.Content{Parsed: content, Raw: member.MemberEventExtra} + _, err = intent.SendState(ctx, portal.MXID, event.StateMember, intent.GetMXID().String(), wrappedJoinContent, ts) + if err != nil { + addLogContext(log.Err(err)). + Str("new_membership", string(content.Membership)). + Msg("Failed to join with intent") + } else { + addLogContext(log.Trace()). + Str("new_membership", string(content.Membership)). + Msg("Joined room with intent") + } + } } return true } syncIntent := func(intent MatrixAPI, member ChatMember) { - if !syncUser(intent.GetMXID(), member, true) { + if !syncUser(intent.GetMXID(), member, intent) { return } if member.Membership == event.MembershipJoin || member.Membership == "" { @@ -3480,7 +3500,7 @@ func (portal *Portal) syncParticipants(ctx context.Context, members *ChatMemberL syncIntent(intent, member) } if extraUserID != "" { - syncUser(extraUserID, member, false) + syncUser(extraUserID, member, nil) } } if powerChanged { From 72bacbb666fb97adacabd5198c120910bd869d2f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 9 Jun 2025 19:30:37 +0300 Subject: [PATCH 1199/1647] appservice/intent: ensure registered when sending own member state event --- appservice/intent.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/appservice/intent.go b/appservice/intent.go index 30313273..d6cda137 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -228,6 +228,8 @@ func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, e if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } + } else if err := intent.EnsureRegistered(ctx); err != nil { + return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON) From 99cfa0b53ae6900f8f4b7ad33b71016122542910 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 10 Jun 2025 15:03:31 +0300 Subject: [PATCH 1200/1647] bridgev2/matrixinvite: save portal after setting mxid --- bridgev2/matrixinvite.go | 17 ++++++++++++----- bridgev2/portalinternal.go | 4 ++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go index 25c35eb7..0f1601d1 100644 --- a/bridgev2/matrixinvite.go +++ b/bridgev2/matrixinvite.go @@ -177,10 +177,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen return } - didSetPortal := portal.setMXIDToExistingRoom(evt.RoomID) - if resp.PortalInfo != nil { - portal.UpdateInfo(ctx, resp.PortalInfo, sourceLogin, nil, time.Time{}) - } + didSetPortal := portal.setMXIDToExistingRoom(ctx, evt.RoomID) if didSetPortal { message := "Private chat portal created" err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent) @@ -190,6 +187,12 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen message += "\n\nWarning: failed to promote bot" hasWarning = true } + if resp.PortalInfo != nil { + portal.UpdateInfo(ctx, resp.PortalInfo, sourceLogin, nil, time.Time{}) + } else { + portal.UpdateCapabilities(ctx, sourceLogin, true) + portal.UpdateBridgeInfo(ctx) + } // TODO this might become unnecessary if UpdateInfo starts taking care of it _, err = br.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{ Parsed: &event.ElementFunctionalMembersContent{ @@ -242,7 +245,7 @@ func (br *Bridge) givePowerToBot(ctx context.Context, roomID id.RoomID, userWith return nil } -func (portal *Portal) setMXIDToExistingRoom(roomID id.RoomID) bool { +func (portal *Portal) setMXIDToExistingRoom(ctx context.Context, roomID id.RoomID) bool { portal.roomCreateLock.Lock() defer portal.roomCreateLock.Unlock() if portal.MXID != "" { @@ -253,5 +256,9 @@ func (portal *Portal) setMXIDToExistingRoom(roomID id.RoomID) bool { portal.Bridge.cacheLock.Lock() portal.Bridge.portalsByMXID[portal.MXID] = portal portal.Bridge.cacheLock.Unlock() + err := portal.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after updating mxid") + } return true } diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index fd6724f4..bde0b170 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -361,6 +361,6 @@ func (portal *PortalInternals) ToggleSpace(ctx context.Context, spaceID id.RoomI return (*Portal)(portal).toggleSpace(ctx, spaceID, canonical, remove) } -func (portal *PortalInternals) SetMXIDToExistingRoom(roomID id.RoomID) bool { - return (*Portal)(portal).setMXIDToExistingRoom(roomID) +func (portal *PortalInternals) SetMXIDToExistingRoom(ctx context.Context, roomID id.RoomID) bool { + return (*Portal)(portal).setMXIDToExistingRoom(ctx, roomID) } From 12502e213a3f5de840e1e9f26a938e9ec2586e7f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 10 Jun 2025 17:46:23 +0300 Subject: [PATCH 1201/1647] bridgev2/userlogin: never set client to nil --- bridgev2/userlogin.go | 47 +++++++++++++++++++++++-------------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index e83e66c2..2016dd59 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -37,6 +37,7 @@ type UserLogin struct { spaceCreateLock sync.Mutex deleteLock sync.Mutex + disconnectOnce sync.Once } func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *database.UserLogin) (*UserLogin, error) { @@ -514,29 +515,31 @@ func (ul *UserLogin) Disconnect() { } func (ul *UserLogin) DisconnectWithTimeout(timeout time.Duration) { - client := ul.Client - if client != nil { - ul.Client = nil - disconnected := make(chan struct{}) - go func() { - client.Disconnect() - close(disconnected) - }() + ul.disconnectOnce.Do(func() { + ul.disconnectInternal(timeout) + }) +} - var timeoutC <-chan time.Time - if timeout > 0 { - timeoutC = time.After(timeout) - } - for { - select { - case <-disconnected: - return - case <-time.After(2 * time.Second): - ul.Log.Warn().Msg("Client disconnection taking long") - case <-timeoutC: - ul.Log.Error().Msg("Client disconnection timed out") - return - } +func (ul *UserLogin) disconnectInternal(timeout time.Duration) { + disconnected := make(chan struct{}) + go func() { + ul.Client.Disconnect() + close(disconnected) + }() + + var timeoutC <-chan time.Time + if timeout > 0 { + timeoutC = time.After(timeout) + } + for { + select { + case <-disconnected: + return + case <-time.After(2 * time.Second): + ul.Log.Warn().Msg("Client disconnection taking long") + case <-timeoutC: + ul.Log.Error().Msg("Client disconnection timed out") + return } } } From 9c67d238d739bbaac7b9cd8e6854b0faf88bc1fa Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 10 Jun 2025 18:53:40 +0300 Subject: [PATCH 1202/1647] bridgev2/portal: check only for me flag in delete chat events --- bridgev2/portal.go | 46 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 1a57d211..62705ff5 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2872,17 +2872,57 @@ func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLo } func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) { + log := zerolog.Ctx(ctx) if portal.Receiver == "" && evt.DeleteOnlyForMe() { - // TODO check if there are other users + logins, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to check if portal has other logins") + return + } + var ownUP *database.UserPortal + logins = slices.DeleteFunc(logins, func(up *database.UserPortal) bool { + if up.LoginID == source.ID { + ownUP = up + return true + } + return false + }) + if len(logins) > 0 { + log.Debug().Msg("Not deleting portal with other logins in remote chat delete event") + if ownUP != nil { + err = portal.Bridge.DB.UserPortal.Delete(ctx, ownUP) + if err != nil { + log.Err(err).Msg("Failed to delete own user portal row from database") + } else { + log.Debug().Msg("Deleted own user portal row from database") + } + } + _, err = portal.sendStateWithIntentOrBot( + ctx, + source.User.DoublePuppet(ctx), + event.StateMember, + source.UserMXID.String(), + &event.Content{Parsed: &event.MemberEventContent{Membership: event.MembershipLeave}}, + getEventTS(evt), + ) + if err != nil { + log.Err(err).Msg("Failed to send leave state event for user after remote chat delete") + } else { + log.Debug().Msg("Sent leave state event for user after remote chat delete") + } + return + } } err := portal.Delete(ctx) if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to delete portal from database") + log.Err(err).Msg("Failed to delete portal from database") return } err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false) if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to delete Matrix room") + log.Err(err).Msg("Failed to delete Matrix room") + } else { + log.Info().Msg("Deleted room after remote chat delete event") } } From 1038f6a73cd9122b714fdc0e6d61198ad86eb1b9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 10 Jun 2025 19:33:02 +0300 Subject: [PATCH 1203/1647] bridgev2: fix more background contexts --- bridgev2/portal.go | 2 +- bridgev2/userlogin.go | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 62705ff5..6dd5711f 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -302,7 +302,7 @@ func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) { func (portal *Portal) eventLoop() { if cfg := portal.Bridge.Network.GetCapabilities().OutgoingMessageTimeouts; cfg != nil { - ctx, cancel := context.WithCancel(portal.Log.WithContext(context.Background())) + ctx, cancel := context.WithCancel(portal.Log.WithContext(portal.Bridge.BackgroundCtx)) go portal.pendingMessageTimeoutLoop(ctx, cfg) defer cancel() } diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 2016dd59..05574e71 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -228,7 +228,8 @@ func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params } ul.BridgeState = user.Bridge.NewBridgeStateQueue(ul) } - err = params.LoadUserLogin(ul.Log.WithContext(context.Background()), ul) + noCancelCtx := ul.Log.WithContext(user.Bridge.BackgroundCtx) + err = params.LoadUserLogin(noCancelCtx, ul) if err != nil { return nil, err } else if ul.Client == nil { @@ -236,14 +237,14 @@ func (user *User) NewLogin(ctx context.Context, data *database.UserLogin, params return nil, fmt.Errorf("client not filled by LoadUserLogin") } if doInsert { - err = user.Bridge.DB.UserLogin.Insert(ctx, ul.UserLogin) + err = user.Bridge.DB.UserLogin.Insert(noCancelCtx, ul.UserLogin) if err != nil { return nil, err } user.Bridge.userLoginsByID[ul.ID] = ul user.logins[ul.ID] = ul } else { - err = ul.Save(ctx) + err = ul.Save(noCancelCtx) if err != nil { return nil, err } From 15d0b63eb6ab11a62d4b5287cc09a991716b5744 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 11 Jun 2025 15:34:34 +0300 Subject: [PATCH 1204/1647] bridgev2/provisioning: check for nil steps in submit and wait calls --- bridgev2/matrix/provisioning.go | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index e897bae8..0a11aa79 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -364,6 +364,8 @@ func (prov *ProvisioningAPI) GetLoginFlows(w http.ResponseWriter, r *http.Reques }) } +var ErrNilStep = errors.New("bridge returned nil step with no error") + func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Request) { overrideLogin, failed := prov.GetExplicitLoginForRequest(w, r) if failed { @@ -386,14 +388,13 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque } else { firstStep, err = login.Start(r.Context()) } + if err == nil && firstStep == nil { + err = ErrNilStep + } if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to start login") RespondWithError(w, err, "Internal error starting login") return - } else if firstStep == nil { - zerolog.Ctx(r.Context()).Error().Msg("Bridge returned nil first step in Start with no error") - RespondWithError(w, err, "Internal error starting login") - return } loginID := xid.New().String() prov.loginsLock.Lock() @@ -439,6 +440,9 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http default: panic("Impossible state") } + if err == nil && nextStep == nil { + err = ErrNilStep + } if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input") RespondWithError(w, err, "Internal error submitting input") @@ -454,6 +458,9 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Request) { login := r.Context().Value(provisioningLoginProcessKey).(*ProvLogin) nextStep, err := login.Process.(bridgev2.LoginProcessDisplayAndWait).Wait(r.Context()) + if err == nil && nextStep == nil { + err = ErrNilStep + } if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to wait") RespondWithError(w, err, "Internal error waiting for login") From b8921397b82f3eb24765c8bf1dd2a4c563cc73bf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 11 Jun 2025 19:10:19 +0300 Subject: [PATCH 1205/1647] event,requests: add MSC4293 redact events field to member events --- event/member.go | 2 ++ requests.go | 2 ++ 2 files changed, 4 insertions(+) diff --git a/event/member.go b/event/member.go index d0ff2a7c..02b7cae9 100644 --- a/event/member.go +++ b/event/member.go @@ -42,6 +42,8 @@ type MemberEventContent struct { ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"` Reason string `json:"reason,omitempty"` MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"` + + MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"` } type ThirdPartyInvite struct { diff --git a/requests.go b/requests.go index 42d257fb..8363aeda 100644 --- a/requests.go +++ b/requests.go @@ -193,6 +193,8 @@ type ReqKickUser struct { type ReqBanUser struct { Reason string `json:"reason,omitempty"` UserID id.UserID `json:"user_id"` + + MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"` } // ReqUnbanUser is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidunban From c540f30ef9ef6f3db6449475034516619d70848c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 12 Jun 2025 13:32:00 +0300 Subject: [PATCH 1206/1647] dependencies: update --- go.mod | 16 ++++++++-------- go.sum | 28 ++++++++++++++-------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/go.mod b/go.mod index ebc7a61c..1e1ea939 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module maunium.net/go/mautrix go 1.23.0 -toolchain go1.24.3 +toolchain go1.24.4 require ( filippo.io/edwards25519 v1.1.0 @@ -17,13 +17,13 @@ require ( github.com/stretchr/testify v1.10.0 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.7.11 - go.mau.fi/util v0.8.7 + github.com/yuin/goldmark v1.7.12 + go.mau.fi/util v0.8.8-0.20250612103042-2aa072eb60f0 go.mau.fi/zeroconfig v0.1.3 - golang.org/x/crypto v0.38.0 - golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 - golang.org/x/net v0.40.0 - golang.org/x/sync v0.14.0 + golang.org/x/crypto v0.39.0 + golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476 + golang.org/x/net v0.41.0 + golang.org/x/sync v0.15.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) @@ -38,6 +38,6 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect golang.org/x/sys v0.33.0 // indirect - golang.org/x/text v0.25.0 // indirect + golang.org/x/text v0.26.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index a3c7542d..357085e6 100644 --- a/go.sum +++ b/go.sum @@ -51,28 +51,28 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.11 h1:ZCxLyDMtz0nT2HFfsYG8WZ47Trip2+JyLysKcMYE5bo= -github.com/yuin/goldmark v1.7.11/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.8.7 h1:ywKarPxouJQEEijTs4mPlxC7F4AWEKokEpWc+2TYy6c= -go.mau.fi/util v0.8.7/go.mod h1:j6R3cENakc1f8HpQeFl0N15UiSTcNmIfDBNJUbL71RY= +github.com/yuin/goldmark v1.7.12 h1:YwGP/rrea2/CnCtUHgjuolG/PnMxdQtPMO5PvaE2/nY= +github.com/yuin/goldmark v1.7.12/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= +go.mau.fi/util v0.8.8-0.20250612103042-2aa072eb60f0 h1:EcDJfYWX6aVT3c6nWTg9Qly41rNKabzzERt7OFzVerA= +go.mau.fi/util v0.8.8-0.20250612103042-2aa072eb60f0/go.mod h1:Y/kS3loxTEhy8Vill513EtPXr+CRDdae+Xj2BXXMy/c= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= -golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= -golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6 h1:y5zboxd6LQAqYIhHnB48p0ByQ/GnQx2BE33L8BOHQkI= -golang.org/x/exp v0.0.0-20250506013437-ce4c2cf36ca6/go.mod h1:U6Lno4MTRCDY+Ba7aCcauB9T60gsv5s4ralQzP72ZoQ= -golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= -golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= -golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= -golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= +golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476 h1:bsqhLWFR6G6xiQcb+JoGqdKdRU6WzPWmK8E0jxTjzo4= +golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= +golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= +golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= +golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= +golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= -golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= +golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From c888801751a65223817daef8840cd0a7bae02b86 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 14 Jun 2025 12:21:05 +0300 Subject: [PATCH 1207/1647] bridgev2/matrixinvite: allow redirecting DM creations to another user --- bridgev2/matrixinvite.go | 20 ++++++++++++++++++++ bridgev2/networkinterface.go | 3 +++ 2 files changed, 23 insertions(+) diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go index 0f1601d1..a57c91b8 100644 --- a/bridgev2/matrixinvite.go +++ b/bridgev2/matrixinvite.go @@ -187,6 +187,26 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen message += "\n\nWarning: failed to promote bot" hasWarning = true } + if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID { + log.Debug(). + Str("dm_redirected_to_id", string(resp.DMRedirectedTo)). + Msg("Created DM was redirected to another user ID") + _, err = invitedGhost.Intent.SendState(ctx, portal.MXID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{ + Parsed: &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: "Direct chat redirected to another internal user ID", + }, + }, time.Time{}) + if err != nil { + log.Err(err).Msg("Failed to make incorrect ghost leave new DM room") + } + otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo) + if err != nil { + log.Err(err).Msg("Failed to get ghost of real portal other user ID") + } else { + invitedGhost = otherUserGhost + } + } if resp.PortalInfo != nil { portal.UpdateInfo(ctx, resp.PortalInfo, sourceLogin, nil, time.Time{}) } else { diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 2b99e4e6..457a7bd4 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -696,6 +696,9 @@ type CreateChatResponse struct { // Portal and PortalInfo are not required, the caller will fetch them automatically based on PortalKey if necessary. Portal *Portal PortalInfo *ChatInfo + // If a start DM request (CreateChatWithGhost or ResolveIdentifier) returns the DM to a different user, + // this field should have the user ID of said different user. + DMRedirectedTo networkid.UserID } // IdentifierResolvingNetworkAPI is an optional interface that network connectors can implement to support starting new direct chats. From 79969306e740432f8fdf58cb1b87d1f6a1862a70 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 14 Jun 2025 12:23:36 +0300 Subject: [PATCH 1208/1647] bridgev2/matrix: check stream upload size after writing file --- bridgev2/matrix/intent.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 7efc1bab..f99437b3 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -393,9 +393,13 @@ func (as *ASIntent) UploadMediaStream( err = fmt.Errorf("failed to get temp file info: %w", err) return } + size = info.Size() + if size > as.Connector.MediaConfig.UploadSize { + return "", nil, fmt.Errorf("file too large (%.2f MB > %.2f MB)", float64(size)/1000/1000, float64(as.Connector.MediaConfig.UploadSize)/1000/1000) + } req := mautrix.ReqUploadMedia{ Content: replFile, - ContentLength: info.Size(), + ContentLength: size, ContentType: res.MimeType, FileName: res.FileName, } From c836dbafdfd9062a3b2a2dcba2f4e037e81e4e81 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 14 Jun 2025 12:31:03 +0300 Subject: [PATCH 1209/1647] bridgev2/matrixinvite: clean up old portal room if user is not a member --- bridgev2/matrixinvite.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go index a57c91b8..11826b40 100644 --- a/bridgev2/matrixinvite.go +++ b/bridgev2/matrixinvite.go @@ -164,6 +164,34 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen return } } + if portal.MXID != "" { + doCleanup := true + existingPortalMembers, err := br.Matrix.GetMembers(ctx, portal.MXID) + if err != nil { + log.Err(err). + Stringer("old_portal_mxid", portal.MXID). + Msg("Failed to check existing portal members, deleting room") + } else if targetUserMember, ok := existingPortalMembers[sender.MXID]; !ok { + log.Debug(). + Stringer("old_portal_mxid", portal.MXID). + Msg("Inviter has no member event in old portal, deleting room") + } else if targetUserMember.Membership.IsInviteOrJoin() { + doCleanup = false + } else { + log.Debug(). + Stringer("old_portal_mxid", portal.MXID). + Str("membership", string(targetUserMember.Membership)). + Msg("Inviter is not in old portal, deleting room") + } + + if doCleanup { + if err = portal.RemoveMXID(ctx); err != nil { + log.Err(err).Msg("Failed to delete old portal mxid") + } else if err = br.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil { + log.Err(err).Msg("Failed to clean up old portal room") + } + } + } err = invitedGhost.Intent.EnsureInvited(ctx, evt.RoomID, br.Bot.GetMXID()) if err != nil { log.Err(err).Msg("Failed to ensure bot is invited to room") From 1143cfaa85bc75449811918ecc31cbfac5c92c3a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 14 Jun 2025 18:00:35 +0300 Subject: [PATCH 1210/1647] event: implement fallbacks for per-message profiles --- bridgev2/matrix/intent.go | 4 ++++ bridgev2/queue.go | 1 + event/beeper.go | 31 +++++++++++++++++++++++++++++++ 3 files changed, 36 insertions(+) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index f99437b3..2088d5b1 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -58,6 +58,10 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType }) } if eventType != event.EventReaction && eventType != event.EventRedaction { + msgContent, ok := content.Parsed.(*event.MessageEventContent) + if ok { + msgContent.AddPerMessageProfileFallback() + } if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) } else if encrypted { diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 3d329b22..74424290 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -99,6 +99,7 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { if evt.Type == event.EventMessage && sender != nil { msg := evt.Content.AsMessage() msg.RemoveReplyFallback() + msg.RemovePerMessageProfileFallback() if strings.HasPrefix(msg.Body, br.Config.CommandPrefix) || evt.RoomID == sender.ManagementRoom { if !sender.Permissions.Commands { status := WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() diff --git a/event/beeper.go b/event/beeper.go index 891204e5..a85e82fc 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -11,7 +11,10 @@ import ( "encoding/binary" "encoding/json" "fmt" + "html" + "regexp" "strconv" + "strings" "maunium.net/go/mautrix/id" ) @@ -141,6 +144,34 @@ type BeeperPerMessageProfile struct { Displayname string `json:"displayname,omitempty"` AvatarURL *id.ContentURIString `json:"avatar_url,omitempty"` AvatarFile *EncryptedFileInfo `json:"avatar_file,omitempty"` + HasFallback bool `json:"has_fallback,omitempty"` +} + +func (content *MessageEventContent) AddPerMessageProfileFallback() { + if content.BeeperPerMessageProfile == nil || content.BeeperPerMessageProfile.HasFallback || content.BeeperPerMessageProfile.Displayname == "" { + return + } + content.BeeperPerMessageProfile.HasFallback = true + content.EnsureHasHTML() + content.Body = fmt.Sprintf("%s: %s", content.BeeperPerMessageProfile.Displayname, content.Body) + content.FormattedBody = fmt.Sprintf( + "%s: %s", + html.EscapeString(content.BeeperPerMessageProfile.Displayname), + content.FormattedBody, + ) +} + +var HTMLProfileFallbackRegex = regexp.MustCompile(`([^<]+): `) + +func (content *MessageEventContent) RemovePerMessageProfileFallback() { + if content.BeeperPerMessageProfile == nil || !content.BeeperPerMessageProfile.HasFallback || content.BeeperPerMessageProfile.Displayname == "" { + return + } + content.BeeperPerMessageProfile.HasFallback = false + content.Body = strings.TrimPrefix(content.Body, content.BeeperPerMessageProfile.Displayname+": ") + if content.Format == FormatHTML { + content.FormattedBody = HTMLProfileFallbackRegex.ReplaceAllLiteralString(content.FormattedBody, "") + } } type BeeperEncodedOrder struct { From 1878700a9df6339806d4ef4c4af3c56c6d0a0bef Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 16 Jun 2025 16:42:04 +0300 Subject: [PATCH 1211/1647] Bump version to v0.24.1 --- CHANGELOG.md | 33 +++++++++++++++++++++++++++++++++ bridgev2/matrix/provisioning.go | 2 +- go.mod | 2 +- go.sum | 4 ++-- version.go | 2 +- 5 files changed, 38 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 95f214c3..b2eefb3d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,36 @@ +## v0.24.1 (2025-06-16) + +* *(commands)* Added framework for using reactions as buttons that execute + command handlers. +* *(client)* Added wrapper for `/relations` endpoints. +* *(client)* Added support for stable version of room summary endpoint. +* *(client)* Fixed parsing URL preview responses where width/height are strings. +* *(federation)* Fixed bugs in server auth. +* *(id)* Added utilities for validating server names. +* *(event)* Fixed incorrect empty `entity` field when sending hashed moderation + policy events. +* *(event)* Added [MSC4293] redact events field to member events. +* *(event)* Added support for fallbacks in [MSC4144] per-message profiles. +* *(format)* Added `MarkdownLink` and `MarkdownMention` utility functions for + generating properly escaped markdown. +* *(synapseadmin)* Added support for synchronous (v1) room delete endpoint. +* *(synapseadmin)* Changed `Client` struct to not embed the `mautrix.Client`. + This is a breaking change if you were relying on accessing non-admin functions + from the admin client. +* *(bridgev2/provisioning)* Fixed `/display_and_wait` not passing through errors + from the network connector properly. +* *(bridgev2/crypto)* Fixed encryption not working if the user's ID had the same + prefix as the bridge ghosts (e.g. `@whatsappbridgeuser:example.com` with a + `@whatsapp_` prefix). +* *(bridgev2)* Fixed portals not being saved after creating a DM portal from a + Matrix DM invite. +* *(bridgev2)* Added config option to determine whether cross-room replies + should be bridged. +* *(appservice)* Fixed `EnsureRegistered` not being called when sending a custom + member event for the controlled user. + +[MSC4293]: https://github.com/matrix-org/matrix-spec-proposals/pull/4293 + ## v0.24.0 (2025-05-16) * *(commands)* Added generic framework for implementing bot commands. diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 0a11aa79..f865a19e 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -120,7 +120,7 @@ func (prov *ProvisioningAPI) Init() { prov.Router.Use(hlog.NewHandler(prov.log)) prov.Router.Use(hlog.RequestIDHandler("request_id", "Request-Id")) prov.Router.Use(exhttp.CORSMiddleware) - prov.Router.Use(requestlog.AccessLogger(false)) + prov.Router.Use(requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true})) prov.Router.Use(prov.AuthMiddleware) prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami) prov.Router.Path("/v3/login/flows").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLoginFlows) diff --git a/go.mod b/go.mod index 1e1ea939..dcc6616c 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.12 - go.mau.fi/util v0.8.8-0.20250612103042-2aa072eb60f0 + go.mau.fi/util v0.8.8 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.39.0 golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476 diff --git a/go.sum b/go.sum index 357085e6..779e05db 100644 --- a/go.sum +++ b/go.sum @@ -53,8 +53,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.12 h1:YwGP/rrea2/CnCtUHgjuolG/PnMxdQtPMO5PvaE2/nY= github.com/yuin/goldmark v1.7.12/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.8.8-0.20250612103042-2aa072eb60f0 h1:EcDJfYWX6aVT3c6nWTg9Qly41rNKabzzERt7OFzVerA= -go.mau.fi/util v0.8.8-0.20250612103042-2aa072eb60f0/go.mod h1:Y/kS3loxTEhy8Vill513EtPXr+CRDdae+Xj2BXXMy/c= +go.mau.fi/util v0.8.8 h1:OnuEEc/sIJFhnq4kFggiImUpcmnmL/xpvQMRu5Fiy5c= +go.mau.fi/util v0.8.8/go.mod h1:Y/kS3loxTEhy8Vill513EtPXr+CRDdae+Xj2BXXMy/c= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= diff --git a/version.go b/version.go index 8366c5bf..193205ee 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.24.0" +const Version = "v0.24.1" var GoModVersion = "" var Commit = "" From 26da46dbbf6e927191bf17f75b784060126f5e09 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 17 Jun 2025 22:08:29 +0530 Subject: [PATCH 1212/1647] bridgev2/portal: return result of handling remote events (#389) --- bridgev2/portal.go | 362 ++++++++++++++++++++++++------------- bridgev2/portalbackfill.go | 22 ++- bridgev2/portalinternal.go | 88 ++++----- bridgev2/queue.go | 21 ++- 4 files changed, 318 insertions(+), 175 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 6dd5711f..21d8550f 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -283,19 +283,21 @@ func (br *Bridge) GetExistingPortalByKey(ctx context.Context, key networkid.Port return br.loadPortal(ctx, db, err, nil) } -func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) { +func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) EventHandlingResult { if PortalEventBuffer == 0 { portal.eventsLock.Lock() defer portal.eventsLock.Unlock() portal.eventIdx++ - portal.handleSingleEventAsync(portal.eventIdx, evt) + return portal.handleSingleEventAsync(portal.eventIdx, evt) } else { select { case portal.events <- evt: + return EventHandlingResultQueued default: zerolog.Ctx(ctx).Error(). Str("portal_id", string(portal.ID)). Msg("Portal event channel is full") + return EventHandlingResultFailed } } } @@ -313,19 +315,25 @@ func (portal *Portal) eventLoop() { } } -func (portal *Portal) handleSingleEventAsync(idx int, rawEvt any) { +func (portal *Portal) handleSingleEventAsync(idx int, rawEvt any) (outerRes EventHandlingResult) { ctx := portal.getEventCtxWithLog(rawEvt, idx) if _, isCreate := rawEvt.(*portalCreateEvent); isCreate { - portal.handleSingleEvent(ctx, rawEvt, func() {}) + portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) { + outerRes = res + }) } else if portal.Bridge.Config.AsyncEvents { - go portal.handleSingleEvent(ctx, rawEvt, func() {}) + outerRes = EventHandlingResultQueued + go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) {}) } else { log := zerolog.Ctx(ctx) doneCh := make(chan struct{}) var backgrounded atomic.Bool start := time.Now() var handleDuration time.Duration - go portal.handleSingleEvent(ctx, rawEvt, func() { + // Note: this will not set the success flag if the handler times out + outerRes = EventHandlingResult{Queued: true} + go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) { + outerRes = res handleDuration = time.Since(start) close(doneCh) if backgrounded.Load() { @@ -358,6 +366,7 @@ func (portal *Portal) handleSingleEventAsync(idx int, rawEvt any) { Msg("Event handling is taking too long, continuing in background") backgrounded.Store(true) } + return } func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context { @@ -404,10 +413,11 @@ func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context { return logWith.Logger().WithContext(portal.Bridge.BackgroundCtx) } -func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCallback func()) { +func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCallback func(res EventHandlingResult)) { log := zerolog.Ctx(ctx) + var res EventHandlingResult defer func() { - doneCallback() + doneCallback(res) if err := recover(); err != nil { logEvt := log.Error() if realErr, ok := err.(error); ok { @@ -432,9 +442,11 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal case *portalMatrixEvent: portal.handleMatrixEvent(ctx, evt.sender, evt.evt) case *portalRemoteEvent: - portal.handleRemoteEvent(ctx, evt.source, evt.evtType, evt.evt) + res = portal.handleRemoteEvent(ctx, evt.source, evt.evtType, evt.evt) case *portalCreateEvent: - evt.cb(portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info, nil)) + err := portal.createMatrixRoomInLoop(evt.ctx, evt.source, evt.info, nil) + res.Success = err == nil + evt.cb(err) default: panic(fmt.Errorf("illegal type %T in eventLoop", evt)) } @@ -627,7 +639,7 @@ func (portal *Portal) handleMatrixReceipts(ctx context.Context, evt *event.Event for userID, receipt := range readReceipts { sender, err := portal.Bridge.GetUserByMXID(ctx, userID) if err != nil { - // TODO log + zerolog.Ctx(ctx).Err(err).Msg("Failed to get user to handle read receipt") return } portal.handleMatrixReadReceipt(ctx, sender, evtID, receipt) @@ -1752,13 +1764,13 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog portal.sendSuccessStatus(ctx, evt, 0, "") } -func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) { +func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) (res EventHandlingResult) { log := zerolog.Ctx(ctx) if portal.MXID == "" { mcp, ok := evt.(RemoteEventThatMayCreatePortal) if !ok || !mcp.ShouldCreatePortal() { log.Debug().Msg("Dropping event as portal doesn't exist") - return + return EventHandlingResultIgnored } infoProvider, ok := mcp.(RemoteChatResyncWithInfo) var info *ChatInfo @@ -1777,8 +1789,7 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, err = portal.createMatrixRoomInLoop(ctx, source, info, bundle) if err != nil { log.Err(err).Msg("Failed to create portal to handle event") - // TODO error - return + return EventHandlingResultFailed } if evtType == RemoteEventChatResync { log.Debug().Msg("Not handling chat resync event further as portal was created by it") @@ -1786,7 +1797,7 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, if ok { postHandler.PostHandle(ctx, portal) } - return + return EventHandlingResultSuccess } } preHandler, ok := evt.(RemotePreHandler) @@ -1798,33 +1809,33 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, case RemoteEventUnknown: log.Debug().Msg("Ignoring remote event with type unknown") case RemoteEventMessage, RemoteEventMessageUpsert: - portal.handleRemoteMessage(ctx, source, evt.(RemoteMessage)) + res = portal.handleRemoteMessage(ctx, source, evt.(RemoteMessage)) case RemoteEventEdit: - portal.handleRemoteEdit(ctx, source, evt.(RemoteEdit)) + res = portal.handleRemoteEdit(ctx, source, evt.(RemoteEdit)) case RemoteEventReaction: - portal.handleRemoteReaction(ctx, source, evt.(RemoteReaction)) + res = portal.handleRemoteReaction(ctx, source, evt.(RemoteReaction)) case RemoteEventReactionRemove: - portal.handleRemoteReactionRemove(ctx, source, evt.(RemoteReactionRemove)) + res = portal.handleRemoteReactionRemove(ctx, source, evt.(RemoteReactionRemove)) case RemoteEventReactionSync: - portal.handleRemoteReactionSync(ctx, source, evt.(RemoteReactionSync)) + res = portal.handleRemoteReactionSync(ctx, source, evt.(RemoteReactionSync)) case RemoteEventMessageRemove: - portal.handleRemoteMessageRemove(ctx, source, evt.(RemoteMessageRemove)) + res = portal.handleRemoteMessageRemove(ctx, source, evt.(RemoteMessageRemove)) case RemoteEventReadReceipt: - portal.handleRemoteReadReceipt(ctx, source, evt.(RemoteReadReceipt)) + res = portal.handleRemoteReadReceipt(ctx, source, evt.(RemoteReadReceipt)) case RemoteEventMarkUnread: - portal.handleRemoteMarkUnread(ctx, source, evt.(RemoteMarkUnread)) + res = portal.handleRemoteMarkUnread(ctx, source, evt.(RemoteMarkUnread)) case RemoteEventDeliveryReceipt: - portal.handleRemoteDeliveryReceipt(ctx, source, evt.(RemoteDeliveryReceipt)) + res = portal.handleRemoteDeliveryReceipt(ctx, source, evt.(RemoteDeliveryReceipt)) case RemoteEventTyping: - portal.handleRemoteTyping(ctx, source, evt.(RemoteTyping)) + res = portal.handleRemoteTyping(ctx, source, evt.(RemoteTyping)) case RemoteEventChatInfoChange: - portal.handleRemoteChatInfoChange(ctx, source, evt.(RemoteChatInfoChange)) + res = portal.handleRemoteChatInfoChange(ctx, source, evt.(RemoteChatInfoChange)) case RemoteEventChatResync: - portal.handleRemoteChatResync(ctx, source, evt.(RemoteChatResync)) + res = portal.handleRemoteChatResync(ctx, source, evt.(RemoteChatResync)) case RemoteEventChatDelete: - portal.handleRemoteChatDelete(ctx, source, evt.(RemoteChatDelete)) + res = portal.handleRemoteChatDelete(ctx, source, evt.(RemoteChatDelete)) case RemoteEventBackfill: - portal.handleRemoteBackfill(ctx, source, evt.(RemoteBackfill)) + res = portal.handleRemoteBackfill(ctx, source, evt.(RemoteBackfill)) default: log.Warn().Msg("Got remote event with unknown type") } @@ -1832,9 +1843,10 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, if ok { postHandler.PostHandle(ctx, portal) } + return } -func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID) { +func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID, err error) { var ghost *Ghost if !sender.IsFromMe && sender.ForceDMUser && portal.OtherUserID != "" && sender.Sender != portal.OtherUserID { zerolog.Ctx(ctx).Warn(). @@ -1843,21 +1855,20 @@ func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventS Msg("Overriding event sender with primary other user in DM portal") // Ensure the ghost row exists anyway to prevent foreign key errors when saving messages // TODO it'd probably be better to override the sender in the saved message, but that's more effort - _, err := portal.Bridge.GetGhostByID(ctx, sender.Sender) + _, err = portal.Bridge.GetGhostByID(ctx, sender.Sender) if err != nil { zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get ghost with original user ID") + return } sender.Sender = portal.OtherUserID } if sender.Sender != "" { - var err error ghost, err = portal.Bridge.GetGhostByID(ctx, sender.Sender) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for message sender") return - } else { - ghost.UpdateInfoIfNecessary(ctx, source, evtType) } + ghost.UpdateInfoIfNecessary(ctx, source, evtType) } if sender.IsFromMe { intent = source.User.DoublePuppet(ctx) @@ -1892,15 +1903,21 @@ func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventS return } -func (portal *Portal) GetIntentFor(ctx context.Context, sender EventSender, source *UserLogin, evtType RemoteEventType) MatrixAPI { - intent, _ := portal.getIntentAndUserMXIDFor(ctx, sender, source, nil, evtType) +func (portal *Portal) GetIntentFor(ctx context.Context, sender EventSender, source *UserLogin, evtType RemoteEventType) (MatrixAPI, bool) { + intent, _, err := portal.getIntentAndUserMXIDFor(ctx, sender, source, nil, evtType) + if err != nil { + return nil, false + } if intent == nil { // TODO this is very hacky - we should either insert an empty ghost row automatically // (and not fetch it at runtime) or make the message sender column nullable. portal.Bridge.GetGhostByID(ctx, "") intent = portal.Bridge.Bot + if intent == nil { + panic(fmt.Errorf("bridge bot is nil")) + } } - return intent + return intent, true } func (portal *Portal) getRelationMeta(ctx context.Context, currentMsg networkid.MessageID, replyToPtr *networkid.MessageOptionalPartID, threadRootPtr *networkid.MessageID, isBatchSend bool) (replyTo, threadRoot, prevThreadEvent *database.Message) { @@ -1982,7 +1999,7 @@ func (portal *Portal) sendConvertedMessage( ts time.Time, streamOrder int64, logContext func(*zerolog.Event) *zerolog.Event, -) []*database.Message { +) ([]*database.Message, EventHandlingResult) { if logContext == nil { logContext = func(e *zerolog.Event) *zerolog.Event { return e @@ -1991,6 +2008,7 @@ func (portal *Portal) sendConvertedMessage( log := zerolog.Ctx(ctx) replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, id, converted.ReplyTo, converted.ThreadRoot, false) output := make([]*database.Message, 0, len(converted.Parts)) + allSuccess := true for i, part := range converted.Parts { portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent) dbMessage := &database.Message{ @@ -2023,6 +2041,7 @@ func (portal *Portal) sendConvertedMessage( }) if err != nil { logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to send message part to Matrix") + allSuccess = false continue } logContext(log.Debug()). @@ -2034,12 +2053,13 @@ func (portal *Portal) sendConvertedMessage( err := portal.Bridge.DB.Message.Insert(ctx, dbMessage) if err != nil { logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to save message part to database") + allSuccess = false } if converted.Disappear.Type != database.DisappearingTypeNone && !dbMessage.HasFakeMXID() { if converted.Disappear.Type == database.DisappearingTypeAfterSend && converted.Disappear.DisappearAt.IsZero() { converted.Disappear.DisappearAt = dbMessage.Timestamp.Add(converted.Disappear.Timer) } - go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ + portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ RoomID: portal.MXID, EventID: dbMessage.MXID, DisappearingSetting: converted.Disappear, @@ -2050,7 +2070,10 @@ func (portal *Portal) sendConvertedMessage( } output = append(output, dbMessage) } - return output + if !allSuccess { + return output, EventHandlingResultFailed + } + return output, EventHandlingResultSuccess } func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage) (bool, *database.Message) { @@ -2110,21 +2133,24 @@ func (portal *Portal) checkPendingMessage(ctx context.Context, evt RemoteMessage return true, pending.db } -func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin, evt RemoteMessageUpsert, existing []*database.Message) bool { +func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin, evt RemoteMessageUpsert, existing []*database.Message) (handleRes EventHandlingResult, continueHandling bool) { log := zerolog.Ctx(ctx) - intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageUpsert) - if intent == nil { - return false + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageUpsert) + if !ok { + return } res, err := evt.HandleExisting(ctx, portal, intent, existing) if err != nil { log.Err(err).Msg("Failed to handle existing message in upsert event after receiving remote echo") + } else { + handleRes = EventHandlingResultSuccess } if res.SaveParts { for _, part := range existing { err = portal.Bridge.DB.Message.Update(ctx, part) if err != nil { log.Err(err).Str("part_id", string(part.PartID)).Msg("Failed to update message part in database") + handleRes = EventHandlingResultFailed } } } @@ -2136,19 +2162,25 @@ func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin, Str("action", "handle remote subevent"). Stringer("bridge_evt_type", subType). Logger() - portal.handleRemoteEvent(log.WithContext(ctx), source, subType, subEvt) + subRes := portal.handleRemoteEvent(log.WithContext(ctx), source, subType, subEvt) + if !subRes.Success { + handleRes.Success = false + } } } - return res.ContinueMessageHandling + continueHandling = res.ContinueMessageHandling + return } -func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) { +func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) (res EventHandlingResult) { log := zerolog.Ctx(ctx) upsertEvt, isUpsert := evt.(RemoteMessageUpsert) isUpsert = isUpsert && evt.GetType() == RemoteEventMessageUpsert if wasPending, dbMessage := portal.checkPendingMessage(ctx, evt); wasPending { if isUpsert && dbMessage != nil { - portal.handleRemoteUpsert(ctx, source, upsertEvt, []*database.Message{dbMessage}) + res, _ = portal.handleRemoteUpsert(ctx, source, upsertEvt, []*database.Message{dbMessage}) + } else { + res = EventHandlingResultIgnored } return } @@ -2157,35 +2189,42 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin log.Err(err).Msg("Failed to check if message is a duplicate") } else if len(existing) > 0 { if isUpsert { - if portal.handleRemoteUpsert(ctx, source, upsertEvt, existing) { + var continueHandling bool + res, continueHandling = portal.handleRemoteUpsert(ctx, source, upsertEvt, existing) + if continueHandling { log.Debug().Msg("Upsert handler said to continue message handling normally") } else { - return + return res } } else { log.Debug().Stringer("existing_mxid", existing[0].MXID).Msg("Ignoring duplicate message") - return + return EventHandlingResultIgnored } } - intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessage) - if intent == nil { - return + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessage) + if !ok { + return EventHandlingResultFailed } ts := getEventTS(evt) converted, err := evt.ConvertMessage(ctx, portal, intent) if err != nil { if errors.Is(err, ErrIgnoringRemoteEvent) { log.Debug().Err(err).Msg("Remote message handling was cancelled by convert function") + return EventHandlingResultIgnored } else { log.Err(err).Msg("Failed to convert remote message") portal.sendRemoteErrorNotice(ctx, intent, err, ts, "message") + return EventHandlingResultFailed } - return } - portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender().Sender, converted, ts, getStreamOrder(evt), nil) + _, res = portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender().Sender, converted, ts, getStreamOrder(evt), nil) if portal.currentlyTypingGhosts.Pop(intent.GetMXID()) { - intent.MarkTyping(ctx, portal.MXID, TypingTypeText, 0) + err = intent.MarkTyping(ctx, portal.MXID, TypingTypeText, 0) + if err != nil { + log.Warn().Err(err).Msg("Failed to send stop typing event after bridging message") + } } + return } func (portal *Portal) sendRemoteErrorNotice(ctx context.Context, intent MatrixAPI, err error, ts time.Time, evtTypeName string) { @@ -2208,7 +2247,7 @@ func (portal *Portal) sendRemoteErrorNotice(ctx context.Context, intent MatrixAP } } -func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) { +func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) EventHandlingResult { log := zerolog.Ctx(ctx) var existing []*database.Message if bundledEvt, ok := evt.(RemoteEventWithBundledParts); ok { @@ -2220,37 +2259,41 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e existing, err = portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, targetID) if err != nil { log.Err(err).Msg("Failed to get edit target message") - return + return EventHandlingResultFailed } } if existing == nil { log.Warn().Msg("Edit target message not found") - return + return EventHandlingResultIgnored } - intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventEdit) - if intent == nil { - return + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventEdit) + if !ok { + return EventHandlingResultFailed } else if intent.GetMXID() != existing[0].SenderMXID { log.Warn(). Stringer("edit_sender_mxid", intent.GetMXID()). Stringer("original_sender_mxid", existing[0].SenderMXID). Msg("Not bridging edit: sender doesn't match original message sender") - return + return EventHandlingResultIgnored } ts := getEventTS(evt) converted, err := evt.ConvertEdit(ctx, portal, intent, existing) if errors.Is(err, ErrIgnoringRemoteEvent) { log.Debug().Err(err).Msg("Remote edit handling was cancelled by convert function") - return + return EventHandlingResultIgnored } else if err != nil { log.Err(err).Msg("Failed to convert remote edit") portal.sendRemoteErrorNotice(ctx, intent, err, ts, "edit") - return + return EventHandlingResultFailed } - portal.sendConvertedEdit(ctx, existing[0].ID, evt.GetSender().Sender, converted, intent, ts, getStreamOrder(evt)) + res := portal.sendConvertedEdit(ctx, existing[0].ID, evt.GetSender().Sender, converted, intent, ts, getStreamOrder(evt)) if portal.currentlyTypingGhosts.Pop(intent.GetMXID()) { - intent.MarkTyping(ctx, portal.MXID, TypingTypeText, 0) + err = intent.MarkTyping(ctx, portal.MXID, TypingTypeText, 0) + if err != nil { + log.Warn().Err(err).Msg("Failed to send stop typing event after bridging edit") + } } + return res } func (portal *Portal) sendConvertedEdit( @@ -2261,8 +2304,9 @@ func (portal *Portal) sendConvertedEdit( intent MatrixAPI, ts time.Time, streamOrder int64, -) { +) EventHandlingResult { log := zerolog.Ctx(ctx) + allSuccess := true for i, part := range converted.ModifiedParts { if part.Content.Mentions == nil { part.Content.Mentions = &event.Mentions{} @@ -2298,6 +2342,7 @@ func (portal *Portal) sendConvertedEdit( }) if err != nil { log.Err(err).Stringer("part_mxid", part.Part.MXID).Msg("Failed to edit message part") + allSuccess = false continue } else { log.Debug(). @@ -2312,6 +2357,7 @@ func (portal *Portal) sendConvertedEdit( err := portal.Bridge.DB.Message.Update(ctx, part.Part) if err != nil { log.Err(err).Int64("part_rowid", part.Part.RowID).Msg("Failed to update message part in database") + allSuccess = false } } for _, part := range converted.DeletedParts { @@ -2325,6 +2371,7 @@ func (portal *Portal) sendConvertedEdit( }) if err != nil { log.Err(err).Stringer("part_mxid", part.MXID).Msg("Failed to redact message part deleted in edit") + allSuccess = false } else { log.Debug(). Stringer("redaction_event_id", resp.EventID). @@ -2335,11 +2382,19 @@ func (portal *Portal) sendConvertedEdit( err = portal.Bridge.DB.Message.Delete(ctx, part.RowID) if err != nil { log.Err(err).Int64("part_rowid", part.RowID).Msg("Failed to delete message part from database") + allSuccess = false } } if converted.AddedParts != nil { - portal.sendConvertedMessage(ctx, targetID, intent, senderID, converted.AddedParts, ts, streamOrder, nil) + _, res := portal.sendConvertedMessage(ctx, targetID, intent, senderID, converted.AddedParts, ts, streamOrder, nil) + if !res.Success { + allSuccess = false + } } + if !allSuccess { + return EventHandlingResultFailed + } + return EventHandlingResultSuccess } func (portal *Portal) getTargetMessagePart(ctx context.Context, evt RemoteEventWithTargetMessage) (*database.Message, error) { @@ -2372,17 +2427,17 @@ func getStreamOrder(evt RemoteEvent) int64 { return 0 } -func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *UserLogin, evt RemoteReactionSync) { +func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *UserLogin, evt RemoteReactionSync) EventHandlingResult { log := zerolog.Ctx(ctx) eventTS := getEventTS(evt) targetMessage, err := portal.getTargetMessagePart(ctx, evt) if err != nil { log.Err(err).Msg("Failed to get target message for reaction") - return + return EventHandlingResultFailed } else if targetMessage == nil { // TODO use deterministic event ID as target if applicable? log.Warn().Msg("Target message for reaction not found") - return + return EventHandlingResultIgnored } var existingReactions []*database.Reaction if partTargeter, ok := evt.(RemoteEventWithTargetPart); ok { @@ -2390,6 +2445,10 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User } else { existingReactions, err = portal.Bridge.DB.Reaction.GetAllToMessage(ctx, portal.Receiver, evt.GetTargetMessage()) } + if err != nil { + log.Err(err).Msg("Failed to get existing reactions for reaction sync") + return EventHandlingResultFailed + } existing := make(map[networkid.UserID]map[networkid.EmojiID]*database.Reaction) for _, existingReaction := range existingReactions { if existing[existingReaction.SenderID] == nil { @@ -2398,9 +2457,13 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User existing[existingReaction.SenderID][existingReaction.EmojiID] = existingReaction } - doAddReaction := func(new *BackfillReaction, intent MatrixAPI) MatrixAPI { + doAddReaction := func(new *BackfillReaction, intent MatrixAPI) { if intent == nil { - intent = portal.GetIntentFor(ctx, new.Sender, source, RemoteEventReactionSync) + var ok bool + intent, ok = portal.GetIntentFor(ctx, new.Sender, source, RemoteEventReactionSync) + if !ok { + return + } } portal.sendConvertedReaction( ctx, new.Sender.Sender, intent, targetMessage, new.EmojiID, new.Emoji, @@ -2411,7 +2474,6 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User Time("reaction_ts", new.Timestamp) }, ) - return intent } doRemoveReaction := func(old *database.Reaction, intent MatrixAPI, deleteRow bool) { if intent == nil && old.SenderMXID != "" { @@ -2445,7 +2507,10 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User } } doOverwriteReaction := func(new *BackfillReaction, old *database.Reaction) { - intent := portal.GetIntentFor(ctx, new.Sender, source, RemoteEventReactionSync) + intent, ok := portal.GetIntentFor(ctx, new.Sender, source, RemoteEventReactionSync) + if !ok { + return + } doRemoveReaction(old, intent, false) doAddReaction(new, intent) } @@ -2496,30 +2561,34 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User } } } + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) { +func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) EventHandlingResult { log := zerolog.Ctx(ctx) targetMessage, err := portal.getTargetMessagePart(ctx, evt) if err != nil { log.Err(err).Msg("Failed to get target message for reaction") - return + return EventHandlingResultFailed } else if targetMessage == nil { // TODO use deterministic event ID as target if applicable? log.Warn().Msg("Target message for reaction not found") - return + return EventHandlingResultIgnored } emoji, emojiID := evt.GetReactionEmoji() existingReaction, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, targetMessage.ID, targetMessage.PartID, evt.GetSender().Sender, emojiID) if err != nil { log.Err(err).Msg("Failed to check if reaction is a duplicate") - return + return EventHandlingResultFailed } else if existingReaction != nil && (emojiID != "" || existingReaction.Emoji == emoji) { log.Debug().Msg("Ignoring duplicate reaction") - return + return EventHandlingResultIgnored } ts := getEventTS(evt) - intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReaction) + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReaction) + if !ok { + return EventHandlingResultFailed + } var extra map[string]any if extraContentProvider, ok := evt.(RemoteReactionWithExtraContent); ok { extra = extraContentProvider.GetReactionExtraContent() @@ -2538,14 +2607,14 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi log.Err(err).Msg("Failed to redact old reaction") } } - portal.sendConvertedReaction(ctx, evt.GetSender().Sender, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extra, nil) + return portal.sendConvertedReaction(ctx, evt.GetSender().Sender, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extra, nil) } func (portal *Portal) sendConvertedReaction( ctx context.Context, senderID networkid.UserID, intent MatrixAPI, targetMessage *database.Message, emojiID networkid.EmojiID, emoji string, ts time.Time, dbMetadata any, extraContent map[string]any, logContext func(*zerolog.Event) *zerolog.Event, -) { +) EventHandlingResult { if logContext == nil { logContext = func(e *zerolog.Event) *zerolog.Event { return e @@ -2580,7 +2649,7 @@ func (portal *Portal) sendConvertedReaction( }) if err != nil { logContext(log.Err(err)).Msg("Failed to send reaction to Matrix") - return + return EventHandlingResultFailed } logContext(log.Debug()). Stringer("event_id", resp.EventID). @@ -2589,7 +2658,9 @@ func (portal *Portal) sendConvertedReaction( err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) if err != nil { logContext(log.Err(err)).Msg("Failed to save reaction to database") + return EventHandlingResultFailed } + return EventHandlingResultSuccess } func (portal *Portal) getIntentForMXID(ctx context.Context, userID id.UserID) (MatrixAPI, error) { @@ -2608,22 +2679,26 @@ func (portal *Portal) getIntentForMXID(ctx context.Context, userID id.UserID) (M } } -func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) { +func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) EventHandlingResult { log := zerolog.Ctx(ctx) targetReaction, err := portal.getTargetReaction(ctx, evt) if err != nil { log.Err(err).Msg("Failed to get target reaction for removal") - return + return EventHandlingResultFailed } else if targetReaction == nil { log.Warn().Msg("Target reaction not found") - return + return EventHandlingResultIgnored } intent, err := portal.getIntentForMXID(ctx, targetReaction.SenderMXID) if err != nil { log.Err(err).Stringer("sender_mxid", targetReaction.SenderMXID).Msg("Failed to get intent for removing reaction") } if intent == nil { - intent = portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReactionRemove) + var ok bool + intent, ok = portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventReactionRemove) + if !ok { + return EventHandlingResultFailed + } } ts := getEventTS(evt) _, err = intent.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ @@ -2633,22 +2708,24 @@ func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *Us }, &MatrixSendExtra{Timestamp: ts, ReactionMeta: targetReaction}) if err != nil { log.Err(err).Stringer("reaction_mxid", targetReaction.MXID).Msg("Failed to redact reaction") + return EventHandlingResultFailed } err = portal.Bridge.DB.Reaction.Delete(ctx, targetReaction) if err != nil { log.Err(err).Msg("Failed to delete target reaction from database") } + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) { +func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) EventHandlingResult { log := zerolog.Ctx(ctx) targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, evt.GetTargetMessage()) if err != nil { log.Err(err).Msg("Failed to get target message for removal") - return + return EventHandlingResultFailed } else if len(targetParts) == 0 { log.Debug().Msg("Target message not found") - return + return EventHandlingResultIgnored } onlyForMeProvider, ok := evt.(RemoteDeleteOnlyForMe) onlyForMe := ok && onlyForMeProvider.DeleteOnlyForMe() @@ -2656,7 +2733,10 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use // TODO check if there are other user logins before deleting } - intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageRemove) + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageRemove) + if !ok { + return EventHandlingResultFailed + } if intent == portal.Bridge.Bot && len(targetParts) > 0 { senderIntent, err := portal.getIntentForMXID(ctx, targetParts[0].SenderMXID) if err != nil { @@ -2665,15 +2745,17 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use intent = senderIntent } } - portal.redactMessageParts(ctx, targetParts, intent, getEventTS(evt)) + res := portal.redactMessageParts(ctx, targetParts, intent, getEventTS(evt)) err = portal.Bridge.DB.Message.DeleteAllParts(ctx, portal.Receiver, evt.GetTargetMessage()) if err != nil { log.Err(err).Msg("Failed to delete target message from database") } + return res } -func (portal *Portal) redactMessageParts(ctx context.Context, parts []*database.Message, intent MatrixAPI, ts time.Time) { +func (portal *Portal) redactMessageParts(ctx context.Context, parts []*database.Message, intent MatrixAPI, ts time.Time) EventHandlingResult { log := zerolog.Ctx(ctx) + var anyFailed bool for _, part := range parts { if part.HasFakeMXID() { continue @@ -2685,6 +2767,7 @@ func (portal *Portal) redactMessageParts(ctx context.Context, parts []*database. }, &MatrixSendExtra{Timestamp: ts, MessageMeta: part}) if err != nil { log.Err(err).Stringer("part_mxid", part.MXID).Msg("Failed to redact message part") + anyFailed = true } else { log.Debug(). Stringer("redaction_event_id", resp.EventID). @@ -2693,9 +2776,13 @@ func (portal *Portal) redactMessageParts(ctx context.Context, parts []*database. Msg("Sent redaction of message part to Matrix") } } + if anyFailed { + return EventHandlingResultFailed + } + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReadReceipt) { +func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReadReceipt) EventHandlingResult { log := zerolog.Ctx(ctx) var err error var lastTarget *database.Message @@ -2705,7 +2792,7 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL if err != nil { log.Err(err).Str("last_target_id", string(lastTargetID)). Msg("Failed to get last target message for read receipt") - return + return EventHandlingResultFailed } else if lastTarget == nil { log.Debug().Str("last_target_id", string(lastTargetID)). Msg("Last target message not found") @@ -2724,7 +2811,7 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL if err != nil { log.Err(err).Str("target_id", string(targetID)). Msg("Failed to get target message for read receipt") - return + return EventHandlingResultFailed } else if target != nil && !target.HasFakeMXID() && (lastTarget == nil || target.Timestamp.After(lastTarget.Timestamp)) { lastTarget = target } @@ -2737,14 +2824,17 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL } } sender := evt.GetSender() - intent := portal.GetIntentFor(ctx, sender, source, RemoteEventReadReceipt) + intent, ok := portal.GetIntentFor(ctx, sender, source, RemoteEventReadReceipt) + if !ok { + return EventHandlingResultFailed + } var addTargetLog func(evt *zerolog.Event) *zerolog.Event if lastTarget == nil { sevt, evtOK := evt.(RemoteReadReceiptWithStreamOrder) soIntent, soIntentOK := intent.(StreamOrderReadingMatrixAPI) if !evtOK || !soIntentOK || sevt.GetReadUpToStreamOrder() == 0 { log.Warn().Msg("No target message found for read receipt") - return + return EventHandlingResultIgnored } targetStreamOrder := sevt.GetReadUpToStreamOrder() addTargetLog = func(evt *zerolog.Event) *zerolog.Event { @@ -2759,40 +2849,47 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL } if err != nil { addTargetLog(log.Err(err)).Msg("Failed to bridge read receipt") + return EventHandlingResultFailed } else { addTargetLog(log.Debug()).Msg("Bridged read receipt") } if sender.IsFromMe { portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) } + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLogin, evt RemoteMarkUnread) { +func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLogin, evt RemoteMarkUnread) EventHandlingResult { if !evt.GetSender().IsFromMe { zerolog.Ctx(ctx).Warn().Msg("Ignoring mark unread event from non-self user") - return + return EventHandlingResultIgnored } dp := source.User.DoublePuppet(ctx) if dp == nil { - return + return EventHandlingResultIgnored } err := dp.MarkUnread(ctx, portal.MXID, evt.GetUnread()) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge mark unread event") + return EventHandlingResultFailed } + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) { +func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) EventHandlingResult { if portal.RoomType != database.RoomTypeDM || evt.GetSender().Sender != portal.OtherUserID { - return + return EventHandlingResultIgnored + } + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventDeliveryReceipt) + if !ok { + return EventHandlingResultFailed } - intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventDeliveryReceipt) log := zerolog.Ctx(ctx) for _, target := range evt.GetReceiptTargets() { targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, target) if err != nil { log.Err(err).Str("target_id", string(target)).Msg("Failed to get target message for delivery receipt") - continue + return EventHandlingResultFailed } else if len(targetParts) == 0 { continue } else if _, sentByGhost := portal.Bridge.Matrix.ParseGhostMXID(targetParts[0].SenderMXID); sentByGhost { @@ -2811,36 +2908,43 @@ func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *U }) } } + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) { +func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) EventHandlingResult { var typingType TypingType if typedEvt, ok := evt.(RemoteTypingWithType); ok { typingType = typedEvt.GetTypingType() } - intent := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventTyping) + intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventTyping) + if !ok { + return EventHandlingResultFailed + } timeout := evt.GetTimeout() err := intent.MarkTyping(ctx, portal.MXID, typingType, timeout) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge typing event") + return EventHandlingResultFailed } if timeout == 0 { portal.currentlyTypingGhosts.Remove(intent.GetMXID()) } else { portal.currentlyTypingGhosts.Add(intent.GetMXID()) } + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) { +func (portal *Portal) handleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) EventHandlingResult { info, err := evt.GetChatInfoChange(ctx) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get chat info change") - return + return EventHandlingResultFailed } portal.ProcessChatInfoChange(ctx, evt.GetSender(), source, info, getEventTS(evt)) + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLogin, evt RemoteChatResync) { +func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLogin, evt RemoteChatResync) EventHandlingResult { log := zerolog.Ctx(ctx) infoProvider, ok := evt.(RemoteChatResyncWithInfo) if ok { @@ -2869,15 +2973,16 @@ func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLo portal.doForwardBackfill(ctx, source, latestMessage, bundle) } } + return EventHandlingResultSuccess } -func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) { +func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult { log := zerolog.Ctx(ctx) if portal.Receiver == "" && evt.DeleteOnlyForMe() { logins, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) if err != nil { log.Err(err).Msg("Failed to check if portal has other logins") - return + return EventHandlingResultFailed } var ownUP *database.UserPortal logins = slices.DeleteFunc(logins, func(up *database.UserPortal) bool { @@ -2907,31 +3012,35 @@ func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLo ) if err != nil { log.Err(err).Msg("Failed to send leave state event for user after remote chat delete") + return EventHandlingResultFailed } else { log.Debug().Msg("Sent leave state event for user after remote chat delete") + return EventHandlingResultSuccess } - return } } err := portal.Delete(ctx) if err != nil { log.Err(err).Msg("Failed to delete portal from database") - return + return EventHandlingResultFailed } err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false) if err != nil { log.Err(err).Msg("Failed to delete Matrix room") + return EventHandlingResultFailed } else { log.Info().Msg("Deleted room after remote chat delete event") + return EventHandlingResultSuccess } } -func (portal *Portal) handleRemoteBackfill(ctx context.Context, source *UserLogin, backfill RemoteBackfill) { +func (portal *Portal) handleRemoteBackfill(ctx context.Context, source *UserLogin, backfill RemoteBackfill) (res EventHandlingResult) { //data, err := backfill.GetBackfillData(ctx, portal) //if err != nil { // zerolog.Ctx(ctx).Err(err).Msg("Failed to get backfill data") // return //} + return } type ChatInfoChange struct { @@ -2944,7 +3053,10 @@ type ChatInfoChange struct { } func (portal *Portal) ProcessChatInfoChange(ctx context.Context, sender EventSender, source *UserLogin, change *ChatInfoChange, ts time.Time) { - intent := portal.GetIntentFor(ctx, sender, source, RemoteEventChatInfoChange) + intent, ok := portal.GetIntentFor(ctx, sender, source, RemoteEventChatInfoChange) + if !ok { + return + } if change.ChatInfo != nil { portal.UpdateInfo(ctx, change.ChatInfo, source, intent, ts) } @@ -3339,7 +3451,10 @@ func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMem ghost.UpdateInfo(ctx, member.UserInfo) } } - intent, extraUserID := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) + intent, extraUserID, err := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) + if err != nil { + return nil, nil, err + } if extraUserID != "" { invite = append(invite, extraUserID) if member.PowerLevel != nil { @@ -3535,7 +3650,10 @@ func (portal *Portal) syncParticipants( ghost.UpdateInfo(ctx, member.UserInfo) } } - intent, extraUserID := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) + intent, extraUserID, err := portal.getIntentAndUserMXIDFor(ctx, member.EventSender, source, loginsInPortal, 0) + if err != nil { + return err + } if intent != nil { syncIntent(intent, member) } diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 3953a043..74b75df2 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -323,7 +323,10 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin if len(msg.Parts) == 0 { return } - intent := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) + intent, ok := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) + if !ok { + return + } replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, msg.ID, msg.ReplyTo, msg.ThreadRoot, true) if threadRoot != nil && out.PrevThreadEvents[*msg.ThreadRoot] != "" { prevThreadEvent.MXID = out.PrevThreadEvents[*msg.ThreadRoot] @@ -387,7 +390,10 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin } slices.Sort(partIDs) for _, reaction := range msg.Reactions { - reactionIntent := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReactionRemove) + reactionIntent, ok := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReactionRemove) + if !ok { + continue + } if reaction.TargetPart == nil { reaction.TargetPart = &partIDs[0] } @@ -513,8 +519,11 @@ func (portal *Portal) sendBatch(ctx context.Context, source *UserLogin, messages func (portal *Portal) sendLegacyBackfill(ctx context.Context, source *UserLogin, messages []*BackfillMessage, markRead bool) { var lastPart id.EventID for _, msg := range messages { - intent := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) - dbMessages := portal.sendConvertedMessage(ctx, msg.ID, intent, msg.Sender.Sender, msg.ConvertedMessage, msg.Timestamp, msg.StreamOrder, func(z *zerolog.Event) *zerolog.Event { + intent, ok := portal.GetIntentFor(ctx, msg.Sender, source, RemoteEventMessage) + if !ok { + continue + } + dbMessages, _ := portal.sendConvertedMessage(ctx, msg.ID, intent, msg.Sender.Sender, msg.ConvertedMessage, msg.Timestamp, msg.StreamOrder, func(z *zerolog.Event) *zerolog.Event { return z. Str("message_id", string(msg.ID)). Any("sender_id", msg.Sender). @@ -523,7 +532,10 @@ func (portal *Portal) sendLegacyBackfill(ctx context.Context, source *UserLogin, if len(dbMessages) > 0 { lastPart = dbMessages[len(dbMessages)-1].MXID for _, reaction := range msg.Reactions { - reactionIntent := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReaction) + reactionIntent, ok := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReaction) + if !ok { + continue + } targetPart := dbMessages[0] if reaction.TargetPart != nil { targetPartIdx := slices.IndexFunc(dbMessages, func(dbMsg *database.Message) bool { diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index bde0b170..2b25f0cf 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -29,23 +29,23 @@ func (portal *PortalInternals) UpdateLogger() { (*Portal)(portal).updateLogger() } -func (portal *PortalInternals) QueueEvent(ctx context.Context, evt portalEvent) { - (*Portal)(portal).queueEvent(ctx, evt) +func (portal *PortalInternals) QueueEvent(ctx context.Context, evt portalEvent) EventHandlingResult { + return (*Portal)(portal).queueEvent(ctx, evt) } func (portal *PortalInternals) EventLoop() { (*Portal)(portal).eventLoop() } -func (portal *PortalInternals) HandleSingleEventAsync(idx int, rawEvt any) { - (*Portal)(portal).handleSingleEventAsync(idx, rawEvt) +func (portal *PortalInternals) HandleSingleEventAsync(idx int, rawEvt any) (outerRes EventHandlingResult) { + return (*Portal)(portal).handleSingleEventAsync(idx, rawEvt) } func (portal *PortalInternals) GetEventCtxWithLog(rawEvt any, idx int) context.Context { return (*Portal)(portal).getEventCtxWithLog(rawEvt, idx) } -func (portal *PortalInternals) HandleSingleEvent(ctx context.Context, rawEvt any, doneCallback func()) { +func (portal *PortalInternals) HandleSingleEvent(ctx context.Context, rawEvt any, doneCallback func(EventHandlingResult)) { (*Portal)(portal).handleSingleEvent(ctx, rawEvt, doneCallback) } @@ -129,11 +129,11 @@ func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender (*Portal)(portal).handleMatrixRedaction(ctx, sender, origSender, evt) } -func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) { - (*Portal)(portal).handleRemoteEvent(ctx, source, evtType, evt) +func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) (res EventHandlingResult) { + return (*Portal)(portal).handleRemoteEvent(ctx, source, evtType, evt) } -func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID) { +func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID, err error) { return (*Portal)(portal).getIntentAndUserMXIDFor(ctx, sender, source, otherLogins, evtType) } @@ -145,7 +145,7 @@ func (portal *PortalInternals) ApplyRelationMeta(ctx context.Context, content *e (*Portal)(portal).applyRelationMeta(ctx, content, replyTo, threadRoot, prevThreadEvent) } -func (portal *PortalInternals) SendConvertedMessage(ctx context.Context, id networkid.MessageID, intent MatrixAPI, senderID networkid.UserID, converted *ConvertedMessage, ts time.Time, streamOrder int64, logContext func(*zerolog.Event) *zerolog.Event) []*database.Message { +func (portal *PortalInternals) SendConvertedMessage(ctx context.Context, id networkid.MessageID, intent MatrixAPI, senderID networkid.UserID, converted *ConvertedMessage, ts time.Time, streamOrder int64, logContext func(*zerolog.Event) *zerolog.Event) ([]*database.Message, EventHandlingResult) { return (*Portal)(portal).sendConvertedMessage(ctx, id, intent, senderID, converted, ts, streamOrder, logContext) } @@ -153,24 +153,24 @@ func (portal *PortalInternals) CheckPendingMessage(ctx context.Context, evt Remo return (*Portal)(portal).checkPendingMessage(ctx, evt) } -func (portal *PortalInternals) HandleRemoteUpsert(ctx context.Context, source *UserLogin, evt RemoteMessageUpsert, existing []*database.Message) bool { +func (portal *PortalInternals) HandleRemoteUpsert(ctx context.Context, source *UserLogin, evt RemoteMessageUpsert, existing []*database.Message) (handleRes EventHandlingResult, continueHandling bool) { return (*Portal)(portal).handleRemoteUpsert(ctx, source, evt, existing) } -func (portal *PortalInternals) HandleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) { - (*Portal)(portal).handleRemoteMessage(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteMessage(ctx context.Context, source *UserLogin, evt RemoteMessage) (res EventHandlingResult) { + return (*Portal)(portal).handleRemoteMessage(ctx, source, evt) } func (portal *PortalInternals) SendRemoteErrorNotice(ctx context.Context, intent MatrixAPI, err error, ts time.Time, evtTypeName string) { (*Portal)(portal).sendRemoteErrorNotice(ctx, intent, err, ts, evtTypeName) } -func (portal *PortalInternals) HandleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) { - (*Portal)(portal).handleRemoteEdit(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteEdit(ctx context.Context, source *UserLogin, evt RemoteEdit) EventHandlingResult { + return (*Portal)(portal).handleRemoteEdit(ctx, source, evt) } -func (portal *PortalInternals) SendConvertedEdit(ctx context.Context, targetID networkid.MessageID, senderID networkid.UserID, converted *ConvertedEdit, intent MatrixAPI, ts time.Time, streamOrder int64) { - (*Portal)(portal).sendConvertedEdit(ctx, targetID, senderID, converted, intent, ts, streamOrder) +func (portal *PortalInternals) SendConvertedEdit(ctx context.Context, targetID networkid.MessageID, senderID networkid.UserID, converted *ConvertedEdit, intent MatrixAPI, ts time.Time, streamOrder int64) EventHandlingResult { + return (*Portal)(portal).sendConvertedEdit(ctx, targetID, senderID, converted, intent, ts, streamOrder) } func (portal *PortalInternals) GetTargetMessagePart(ctx context.Context, evt RemoteEventWithTargetMessage) (*database.Message, error) { @@ -181,64 +181,64 @@ func (portal *PortalInternals) GetTargetReaction(ctx context.Context, evt Remote return (*Portal)(portal).getTargetReaction(ctx, evt) } -func (portal *PortalInternals) HandleRemoteReactionSync(ctx context.Context, source *UserLogin, evt RemoteReactionSync) { - (*Portal)(portal).handleRemoteReactionSync(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteReactionSync(ctx context.Context, source *UserLogin, evt RemoteReactionSync) EventHandlingResult { + return (*Portal)(portal).handleRemoteReactionSync(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) { - (*Portal)(portal).handleRemoteReaction(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteReaction(ctx context.Context, source *UserLogin, evt RemoteReaction) EventHandlingResult { + return (*Portal)(portal).handleRemoteReaction(ctx, source, evt) } -func (portal *PortalInternals) SendConvertedReaction(ctx context.Context, senderID networkid.UserID, intent MatrixAPI, targetMessage *database.Message, emojiID networkid.EmojiID, emoji string, ts time.Time, dbMetadata any, extraContent map[string]any, logContext func(*zerolog.Event) *zerolog.Event) { - (*Portal)(portal).sendConvertedReaction(ctx, senderID, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extraContent, logContext) +func (portal *PortalInternals) SendConvertedReaction(ctx context.Context, senderID networkid.UserID, intent MatrixAPI, targetMessage *database.Message, emojiID networkid.EmojiID, emoji string, ts time.Time, dbMetadata any, extraContent map[string]any, logContext func(*zerolog.Event) *zerolog.Event) EventHandlingResult { + return (*Portal)(portal).sendConvertedReaction(ctx, senderID, intent, targetMessage, emojiID, emoji, ts, dbMetadata, extraContent, logContext) } func (portal *PortalInternals) GetIntentForMXID(ctx context.Context, userID id.UserID) (MatrixAPI, error) { return (*Portal)(portal).getIntentForMXID(ctx, userID) } -func (portal *PortalInternals) HandleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) { - (*Portal)(portal).handleRemoteReactionRemove(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteReactionRemove(ctx context.Context, source *UserLogin, evt RemoteReactionRemove) EventHandlingResult { + return (*Portal)(portal).handleRemoteReactionRemove(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) { - (*Portal)(portal).handleRemoteMessageRemove(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteMessageRemove(ctx context.Context, source *UserLogin, evt RemoteMessageRemove) EventHandlingResult { + return (*Portal)(portal).handleRemoteMessageRemove(ctx, source, evt) } -func (portal *PortalInternals) RedactMessageParts(ctx context.Context, parts []*database.Message, intent MatrixAPI, ts time.Time) { - (*Portal)(portal).redactMessageParts(ctx, parts, intent, ts) +func (portal *PortalInternals) RedactMessageParts(ctx context.Context, parts []*database.Message, intent MatrixAPI, ts time.Time) EventHandlingResult { + return (*Portal)(portal).redactMessageParts(ctx, parts, intent, ts) } -func (portal *PortalInternals) HandleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReadReceipt) { - (*Portal)(portal).handleRemoteReadReceipt(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteReadReceipt(ctx context.Context, source *UserLogin, evt RemoteReadReceipt) EventHandlingResult { + return (*Portal)(portal).handleRemoteReadReceipt(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteMarkUnread(ctx context.Context, source *UserLogin, evt RemoteMarkUnread) { - (*Portal)(portal).handleRemoteMarkUnread(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteMarkUnread(ctx context.Context, source *UserLogin, evt RemoteMarkUnread) EventHandlingResult { + return (*Portal)(portal).handleRemoteMarkUnread(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) { - (*Portal)(portal).handleRemoteDeliveryReceipt(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) EventHandlingResult { + return (*Portal)(portal).handleRemoteDeliveryReceipt(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) { - (*Portal)(portal).handleRemoteTyping(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteTyping(ctx context.Context, source *UserLogin, evt RemoteTyping) EventHandlingResult { + return (*Portal)(portal).handleRemoteTyping(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) { - (*Portal)(portal).handleRemoteChatInfoChange(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteChatInfoChange(ctx context.Context, source *UserLogin, evt RemoteChatInfoChange) EventHandlingResult { + return (*Portal)(portal).handleRemoteChatInfoChange(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteChatResync(ctx context.Context, source *UserLogin, evt RemoteChatResync) { - (*Portal)(portal).handleRemoteChatResync(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteChatResync(ctx context.Context, source *UserLogin, evt RemoteChatResync) EventHandlingResult { + return (*Portal)(portal).handleRemoteChatResync(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) { - (*Portal)(portal).handleRemoteChatDelete(ctx, source, evt) +func (portal *PortalInternals) HandleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult { + return (*Portal)(portal).handleRemoteChatDelete(ctx, source, evt) } -func (portal *PortalInternals) HandleRemoteBackfill(ctx context.Context, source *UserLogin, backfill RemoteBackfill) { - (*Portal)(portal).handleRemoteBackfill(ctx, source, backfill) +func (portal *PortalInternals) HandleRemoteBackfill(ctx context.Context, source *UserLogin, backfill RemoteBackfill) (res EventHandlingResult) { + return (*Portal)(portal).handleRemoteBackfill(ctx, source, backfill) } func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool { diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 74424290..48ee78f1 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -151,11 +151,24 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { } } -func (ul *UserLogin) QueueRemoteEvent(evt RemoteEvent) { - ul.Bridge.QueueRemoteEvent(ul, evt) +type EventHandlingResult struct { + Success bool + Ignored bool + Queued bool } -func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { +var ( + EventHandlingResultFailed = EventHandlingResult{} + EventHandlingResultQueued = EventHandlingResult{Success: true, Queued: true} + EventHandlingResultSuccess = EventHandlingResult{Success: true} + EventHandlingResultIgnored = EventHandlingResult{Success: true, Ignored: true} +) + +func (ul *UserLogin) QueueRemoteEvent(evt RemoteEvent) EventHandlingResult { + return ul.Bridge.QueueRemoteEvent(ul, evt) +} + +func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) (res EventHandlingResult) { log := login.Log ctx := log.WithContext(br.BackgroundCtx) maybeUncertain, ok := evt.(RemoteEventWithUncertainPortalReceiver) @@ -182,7 +195,7 @@ func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) { } // TODO put this in a better place, and maybe cache to avoid constant db queries login.MarkInPortal(ctx, portal) - portal.queueEvent(ctx, &portalRemoteEvent{ + return portal.queueEvent(ctx, &portalRemoteEvent{ evt: evt, source: login, }) From f3722ca31f3d78f77648610ccb5638bc351ce150 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 19 Jun 2025 17:17:27 +0200 Subject: [PATCH 1213/1647] mediaproxy: validate media IDs --- mediaproxy/mediaproxy.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index 1300a305..c906fc8e 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -26,6 +26,7 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/federation" + "maunium.net/go/mautrix/id" ) type GetMediaResponse interface { @@ -234,6 +235,10 @@ func queryToMap(vals url.Values) map[string]string { func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse { mediaID := mux.Vars(r)["mediaID"] + if !id.IsValidMediaID(mediaID) { + mautrix.MNotFound.WithMessage("Media ID %q is not valid", mediaID).Write(w) + return nil + } resp, err := mp.GetMedia(r.Context(), mediaID, queryToMap(r.URL.Query())) if err != nil { var mautrixRespError mautrix.RespError From 324be4ecb99766aaa7f6a2ac0d31e80a3e8adc97 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 19 Jun 2025 17:55:09 +0200 Subject: [PATCH 1214/1647] mediaproxy: fix closing data response readers --- mediaproxy/mediaproxy.go | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index c906fc8e..4be799d3 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -421,13 +421,16 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { } } } - } else if dataResp, ok := resp.(GetMediaResponseWriter); ok { - mp.addHeaders(w, dataResp.GetContentType(), vars["fileName"]) - if dataResp.GetContentLength() != 0 { - w.Header().Set("Content-Length", strconv.FormatInt(dataResp.GetContentLength(), 10)) + } else if writerResp, ok := resp.(GetMediaResponseWriter); ok { + if dataResp, ok := writerResp.(*GetMediaResponseData); ok { + defer dataResp.Reader.Close() + } + mp.addHeaders(w, writerResp.GetContentType(), vars["fileName"]) + if writerResp.GetContentLength() != 0 { + w.Header().Set("Content-Length", strconv.FormatInt(writerResp.GetContentLength(), 10)) } w.WriteHeader(http.StatusOK) - _, err := dataResp.WriteTo(w) + _, err := writerResp.WriteTo(w) if err != nil { log.Err(err).Msg("Failed to write media data") } From 3a135b6b1586ea449e79781f7afa865654312786 Mon Sep 17 00:00:00 2001 From: Matthias Kesler Date: Wed, 25 Jun 2025 12:35:18 +0200 Subject: [PATCH 1215/1647] id: fix ServerNameRegex not matching port correctly (#392) fixes #391 --- id/servername.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/id/servername.go b/id/servername.go index 591f394a..923705b6 100644 --- a/id/servername.go +++ b/id/servername.go @@ -25,7 +25,7 @@ type ParsedServerName struct { Port int } -var ServerNameRegex = regexp.MustCompile(`^(?:\[([0-9A-Fa-f:.]{2,45})]|(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})|([0-9A-Za-z.-]{1,255}))(\d{1,5})?$`) +var ServerNameRegex = regexp.MustCompile(`^(?:\[([0-9A-Fa-f:.]{2,45})]|(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})|([0-9A-Za-z.-]{1,255}))(?::(\d{1,5}))?$`) func ValidateServerName(serverName string) bool { return len(serverName) <= 255 && len(serverName) > 0 && ServerNameRegex.MatchString(serverName) From 7a7d7f70ef92b45b74d6009a0b7af95f65e1612d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 29 Jun 2025 19:10:35 +0300 Subject: [PATCH 1216/1647] federation: fix base64 in generated signatures --- federation/signingkey.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federation/signingkey.go b/federation/signingkey.go index 5b111947..0ae6a571 100644 --- a/federation/signingkey.go +++ b/federation/signingkey.go @@ -179,7 +179,7 @@ func (sk *SigningKey) SignJSON(data any) (string, error) { if err != nil { return "", err } - return base64.RawURLEncoding.EncodeToString(sk.SignRawJSON(marshaled)), nil + return base64.RawStdEncoding.EncodeToString(sk.SignRawJSON(marshaled)), nil } func (sk *SigningKey) SignRawJSON(data json.RawMessage) []byte { From 94950585c94dbead4ca7da5e50f16c604ff5622a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 1 Jul 2025 01:15:24 +0300 Subject: [PATCH 1217/1647] event: fix removing per-message profile fallback in edits --- event/beeper.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/event/beeper.go b/event/beeper.go index a85e82fc..921e3466 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -164,7 +164,10 @@ func (content *MessageEventContent) AddPerMessageProfileFallback() { var HTMLProfileFallbackRegex = regexp.MustCompile(`([^<]+): `) func (content *MessageEventContent) RemovePerMessageProfileFallback() { - if content.BeeperPerMessageProfile == nil || !content.BeeperPerMessageProfile.HasFallback || content.BeeperPerMessageProfile.Displayname == "" { + if content.NewContent != nil && content.NewContent != content { + content.NewContent.RemovePerMessageProfileFallback() + } + if content == nil || content.BeeperPerMessageProfile == nil || !content.BeeperPerMessageProfile.HasFallback || content.BeeperPerMessageProfile.Displayname == "" { return } content.BeeperPerMessageProfile.HasFallback = false From 4f6d4d7c63f31b2103509f0752131e6e652568d6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 1 Jul 2025 01:34:42 +0300 Subject: [PATCH 1218/1647] bridgev2/portal: add support for per-message profiles in relay mode --- bridgev2/networkinterface.go | 1 + bridgev2/portal.go | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 457a7bd4..a107fae7 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -1180,6 +1180,7 @@ type OrigSender struct { RequiresDisambiguation bool DisambiguatedName string FormattedName string + PerMessageProfile event.BeeperPerMessageProfile event.MemberEventContent } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 21d8550f..ad3f0e0d 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -544,6 +544,8 @@ func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID, return false } +var fakePerMessageProfileEventType = event.Type{Class: event.StateEventType, Type: "m.per_message_profile"} + func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) { log := zerolog.Ctx(ctx) if evt.Mautrix.EventSource&event.SourceEphemeral != 0 { @@ -589,6 +591,24 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * } else { origSender.DisambiguatedName = sender.MXID.String() } + msg := evt.Content.AsMessage() + if msg != nil && msg.BeeperPerMessageProfile != nil && msg.BeeperPerMessageProfile.Displayname != "" { + pmp := msg.BeeperPerMessageProfile + origSender.PerMessageProfile = *pmp + roomPLs, err := portal.Bridge.Matrix.GetPowerLevels(ctx, portal.MXID) + if err != nil { + log.Warn().Err(err).Msg("Failed to get power levels to check relay profile") + } + if roomPLs != nil && + roomPLs.GetUserLevel(sender.MXID) >= roomPLs.GetEventLevel(fakePerMessageProfileEventType) && + !portal.checkConfusableName(ctx, sender.MXID, pmp.Displayname) { + origSender.DisambiguatedName = pmp.Displayname + origSender.RequiresDisambiguation = false + } else { + origSender.DisambiguatedName = fmt.Sprintf("%s via %s", pmp.Displayname, origSender.DisambiguatedName) + } + } + origSender.FormattedName = portal.Bridge.Config.Relay.FormatName(origSender) } // Copy logger because many of the handlers will use UpdateContext From 6f370cc3bb3953b2a22056ec583239eda4e3f4a6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 1 Jul 2025 23:27:45 +0300 Subject: [PATCH 1219/1647] bridgev2,appservice: move appservice ping loop to appservice package --- appservice/ping.go | 68 ++++++++++++++++++++++++++++++++++++ bridgev2/matrix/connector.go | 47 ++----------------------- 2 files changed, 70 insertions(+), 45 deletions(-) create mode 100644 appservice/ping.go diff --git a/appservice/ping.go b/appservice/ping.go new file mode 100644 index 00000000..bd6bcbd1 --- /dev/null +++ b/appservice/ping.go @@ -0,0 +1,68 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package appservice + +import ( + "context" + "encoding/json" + "errors" + "os" + "strings" + "time" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" +) + +func (intent *IntentAPI) EnsureAppserviceConnection(ctx context.Context, appserviceID string) { + var pingResp *mautrix.RespAppservicePing + var txnID string + var retryCount int + var err error + const maxRetries = 6 + for { + txnID = intent.TxnID() + pingResp, err = intent.AppservicePing(ctx, appserviceID, txnID) + if err == nil { + break + } + var httpErr mautrix.HTTPError + var pingErrBody string + if errors.As(err, &httpErr) && httpErr.RespError != nil { + if val, ok := httpErr.RespError.ExtraData["body"].(string); ok { + pingErrBody = strings.TrimSpace(val) + } + } + outOfRetries := retryCount >= maxRetries + level := zerolog.ErrorLevel + if outOfRetries { + level = zerolog.FatalLevel + } + evt := zerolog.Ctx(ctx).WithLevel(level).Err(err).Str("txn_id", txnID) + if pingErrBody != "" { + bodyBytes := []byte(pingErrBody) + if json.Valid(bodyBytes) { + evt.RawJSON("body", bodyBytes) + } else { + evt.Str("body", pingErrBody) + } + } + if outOfRetries { + evt.Msg("Homeserver -> appservice connection is not working") + zerolog.Ctx(ctx).Info().Msg("See https://docs.mau.fi/faq/as-ping for more info") + os.Exit(13) + } + evt.Msg("Homeserver -> appservice connection is not working, retrying in 5 seconds...") + time.Sleep(5 * time.Second) + retryCount++ + } + zerolog.Ctx(ctx).Debug(). + Str("txn_id", txnID). + Int64("duration_ms", pingResp.DurationMS). + Msg("Homeserver -> appservice connection works") +} diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index f56eece3..9fdb6804 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -10,7 +10,6 @@ import ( "context" "crypto/sha256" "encoding/base64" - "encoding/json" "errors" "fmt" "net/url" @@ -343,50 +342,8 @@ func (br *Connector) ensureConnection(ctx context.Context) { br.Log.Debug().Msg("Homeserver does not support checking status of homeserver -> bridge connection") return } - var pingResp *mautrix.RespAppservicePing - var txnID string - var retryCount int - const maxRetries = 6 - for { - txnID = br.Bot.TxnID() - pingResp, err = br.Bot.AppservicePing(ctx, br.Config.AppService.ID, txnID) - if err == nil { - break - } - var httpErr mautrix.HTTPError - var pingErrBody string - if errors.As(err, &httpErr) && httpErr.RespError != nil { - if val, ok := httpErr.RespError.ExtraData["body"].(string); ok { - pingErrBody = strings.TrimSpace(val) - } - } - outOfRetries := retryCount >= maxRetries - level := zerolog.ErrorLevel - if outOfRetries { - level = zerolog.FatalLevel - } - evt := br.Log.WithLevel(level).Err(err).Str("txn_id", txnID) - if pingErrBody != "" { - bodyBytes := []byte(pingErrBody) - if json.Valid(bodyBytes) { - evt.RawJSON("body", bodyBytes) - } else { - evt.Str("body", pingErrBody) - } - } - if outOfRetries { - evt.Msg("Homeserver -> bridge connection is not working") - br.Log.Info().Msg("See https://docs.mau.fi/faq/as-ping for more info") - os.Exit(13) - } - evt.Msg("Homeserver -> bridge connection is not working, retrying in 5 seconds...") - time.Sleep(5 * time.Second) - retryCount++ - } - br.Log.Debug(). - Str("txn_id", txnID). - Int64("duration_ms", pingResp.DurationMS). - Msg("Homeserver -> bridge connection works") + + br.Bot.EnsureAppserviceConnection(ctx, br.Config.AppService.ID) } func (br *Connector) fetchMediaConfig(ctx context.Context) { From 71b994b3fd47a23d854fcdbd87f7593bc5aaaf0b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 1 Jul 2025 23:29:43 +0300 Subject: [PATCH 1220/1647] appservice: remove unnecessary parameter in ping --- appservice/ping.go | 4 ++-- bridgev2/matrix/connector.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/appservice/ping.go b/appservice/ping.go index bd6bcbd1..774ec423 100644 --- a/appservice/ping.go +++ b/appservice/ping.go @@ -19,7 +19,7 @@ import ( "maunium.net/go/mautrix" ) -func (intent *IntentAPI) EnsureAppserviceConnection(ctx context.Context, appserviceID string) { +func (intent *IntentAPI) EnsureAppserviceConnection(ctx context.Context) { var pingResp *mautrix.RespAppservicePing var txnID string var retryCount int @@ -27,7 +27,7 @@ func (intent *IntentAPI) EnsureAppserviceConnection(ctx context.Context, appserv const maxRetries = 6 for { txnID = intent.TxnID() - pingResp, err = intent.AppservicePing(ctx, appserviceID, txnID) + pingResp, err = intent.AppservicePing(ctx, intent.as.Registration.ID, txnID) if err == nil { break } diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 9fdb6804..7af2d128 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -343,7 +343,7 @@ func (br *Connector) ensureConnection(ctx context.Context) { return } - br.Bot.EnsureAppserviceConnection(ctx, br.Config.AppService.ID) + br.Bot.EnsureAppserviceConnection(ctx) } func (br *Connector) fetchMediaConfig(ctx context.Context) { From b62535edaa57e2a5337ec305aa2ab9f64d62bf27 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 3 Jul 2025 21:22:19 +0300 Subject: [PATCH 1221/1647] bridgev2/portal: fix disappearing message notice for implicitly turning off timer --- bridgev2/portal.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index ad3f0e0d..88fc5fe9 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3759,10 +3759,14 @@ func DisappearingMessageNotice(expiration time.Duration, implicit bool) *event.M Body: fmt.Sprintf("Set the disappearing message timer to %s", formattedDuration), Mentions: &event.Mentions{}, } - if implicit { + if expiration == 0 { + if implicit { + content.Body = "Automatically turned off disappearing messages because incoming message is not disappearing" + } else { + content.Body = "Turned off disappearing messages" + } + } else if implicit { content.Body = fmt.Sprintf("Automatically enabled disappearing message timer (%s) because incoming message is disappearing", formattedDuration) - } else if expiration == 0 { - content.Body = "Turned off disappearing messages" } return content } From 44515616d454b6229cf7749f4921355f7220b893 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 9 Jul 2025 16:28:02 +0300 Subject: [PATCH 1222/1647] bridgev2/portal: don't assume unknown reply events are cross-room --- bridgev2/portal.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 88fc5fe9..e0b0d4f6 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1940,7 +1940,13 @@ func (portal *Portal) GetIntentFor(ctx context.Context, sender EventSender, sour return intent, true } -func (portal *Portal) getRelationMeta(ctx context.Context, currentMsg networkid.MessageID, replyToPtr *networkid.MessageOptionalPartID, threadRootPtr *networkid.MessageID, isBatchSend bool) (replyTo, threadRoot, prevThreadEvent *database.Message) { +func (portal *Portal) getRelationMeta( + ctx context.Context, + currentMsg networkid.MessageID, + replyToPtr *networkid.MessageOptionalPartID, + threadRootPtr *networkid.MessageID, + isBatchSend bool, +) (replyTo, threadRoot, prevThreadEvent *database.Message) { log := zerolog.Ctx(ctx) var err error if replyToPtr != nil { @@ -1950,6 +1956,7 @@ func (portal *Portal) getRelationMeta(ctx context.Context, currentMsg networkid. } else if replyTo == nil { if isBatchSend || portal.Bridge.Config.OutgoingMessageReID { // This is somewhat evil + // TODO this does not work with cross-room replies replyTo = &database.Message{ MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, replyToPtr.MessageID, ptr.Val(replyToPtr.PartID)), } @@ -1988,7 +1995,7 @@ func (portal *Portal) applyRelationMeta(ctx context.Context, content *event.Mess content.GetRelatesTo().SetThread(threadRoot.MXID, prevThreadEvent.MXID) } if replyTo != nil { - crossRoom := replyTo.Room != portal.PortalKey + crossRoom := !replyTo.Room.IsEmpty() && replyTo.Room != portal.PortalKey if !crossRoom || portal.Bridge.Config.CrossRoomReplies { content.GetRelatesTo().SetReplyTo(replyTo.MXID) } From 0777c10028375b23af567110c543806b4dd0fcd0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 9 Jul 2025 16:35:14 +0300 Subject: [PATCH 1223/1647] bridgev2/networkinterface: add extra fields to reply metadata to allow unknown cross-room replies --- bridgev2/networkinterface.go | 9 ++++++- bridgev2/portal.go | 49 +++++++++++++++++++++++++----------- bridgev2/portalbackfill.go | 4 ++- bridgev2/portalinternal.go | 4 +-- 4 files changed, 48 insertions(+), 18 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index a107fae7..eb38bd2d 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -91,7 +91,14 @@ func (es EventSender) MarshalZerologObject(evt *zerolog.Event) { } type ConvertedMessage struct { - ReplyTo *networkid.MessageOptionalPartID + ReplyTo *networkid.MessageOptionalPartID + // Optional additional info about the reply. This is only used when backfilling messages + // on Beeper, where replies may target messages that haven't been bridged yet. + // Standard Matrix servers can't backwards backfill, so these are never used. + ReplyToRoom networkid.PortalKey + ReplyToUser networkid.UserID + ReplyToLogin networkid.UserLoginID + ThreadRoot *networkid.MessageID Parts []*ConvertedMessagePart Disappear database.DisappearingSetting diff --git a/bridgev2/portal.go b/bridgev2/portal.go index e0b0d4f6..856b6331 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1942,42 +1942,61 @@ func (portal *Portal) GetIntentFor(ctx context.Context, sender EventSender, sour func (portal *Portal) getRelationMeta( ctx context.Context, - currentMsg networkid.MessageID, - replyToPtr *networkid.MessageOptionalPartID, - threadRootPtr *networkid.MessageID, + currentMsgID networkid.MessageID, + currentMsg *ConvertedMessage, isBatchSend bool, ) (replyTo, threadRoot, prevThreadEvent *database.Message) { log := zerolog.Ctx(ctx) var err error - if replyToPtr != nil { - replyTo, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, portal.Receiver, *replyToPtr) + if currentMsg.ReplyTo != nil { + replyTo, err = portal.Bridge.DB.Message.GetFirstOrSpecificPartByID(ctx, portal.Receiver, *currentMsg.ReplyTo) if err != nil { log.Err(err).Msg("Failed to get reply target message from database") } else if replyTo == nil { if isBatchSend || portal.Bridge.Config.OutgoingMessageReID { // This is somewhat evil - // TODO this does not work with cross-room replies replyTo = &database.Message{ - MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, replyToPtr.MessageID, ptr.Val(replyToPtr.PartID)), + MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, currentMsg.ReplyTo.MessageID, ptr.Val(currentMsg.ReplyTo.PartID)), + Room: currentMsg.ReplyToRoom, + SenderID: currentMsg.ReplyToUser, + } + if currentMsg.ReplyToLogin != "" && (portal.Receiver == "" || portal.Receiver == currentMsg.ReplyToLogin) { + userLogin, err := portal.Bridge.GetExistingUserLoginByID(ctx, currentMsg.ReplyToLogin) + if err != nil { + log.Err(err). + Str("reply_to_login", string(currentMsg.ReplyToLogin)). + Msg("Failed to get reply target user login") + } else if userLogin != nil { + replyTo.SenderMXID = userLogin.UserMXID + } + } else { + ghost, err := portal.Bridge.GetGhostByID(ctx, currentMsg.ReplyToUser) + if err != nil { + log.Err(err). + Str("reply_to_user_id", string(currentMsg.ReplyToUser)). + Msg("Failed to get reply target ghost") + } else { + replyTo.SenderMXID = ghost.Intent.GetMXID() + } } } else { - log.Warn().Any("reply_to", *replyToPtr).Msg("Reply target message not found in database") + log.Warn().Any("reply_to", *currentMsg.ReplyTo).Msg("Reply target message not found in database") } } } - if threadRootPtr != nil && *threadRootPtr != currentMsg { - threadRoot, err = portal.Bridge.DB.Message.GetFirstThreadMessage(ctx, portal.PortalKey, *threadRootPtr) + if currentMsg.ThreadRoot != nil && *currentMsg.ThreadRoot != currentMsgID { + threadRoot, err = portal.Bridge.DB.Message.GetFirstThreadMessage(ctx, portal.PortalKey, *currentMsg.ThreadRoot) if err != nil { log.Err(err).Msg("Failed to get thread root message from database") } else if threadRoot == nil { if isBatchSend || portal.Bridge.Config.OutgoingMessageReID { threadRoot = &database.Message{ - MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, *threadRootPtr, ""), + MXID: portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, *currentMsg.ThreadRoot, ""), } } else { - log.Warn().Str("thread_root", string(*threadRootPtr)).Msg("Thread root message not found in database") + log.Warn().Str("thread_root", string(*currentMsg.ThreadRoot)).Msg("Thread root message not found in database") } - } else if prevThreadEvent, err = portal.Bridge.DB.Message.GetLastThreadMessage(ctx, portal.PortalKey, *threadRootPtr); err != nil { + } else if prevThreadEvent, err = portal.Bridge.DB.Message.GetLastThreadMessage(ctx, portal.PortalKey, *currentMsg.ThreadRoot); err != nil { log.Err(err).Msg("Failed to get last thread message from database") } if prevThreadEvent == nil { @@ -2033,7 +2052,9 @@ func (portal *Portal) sendConvertedMessage( } } log := zerolog.Ctx(ctx) - replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, id, converted.ReplyTo, converted.ThreadRoot, false) + replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta( + ctx, id, converted, false, + ) output := make([]*database.Message, 0, len(converted.Parts)) allSuccess := true for i, part := range converted.Parts { diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 74b75df2..9883fb12 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -327,7 +327,9 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin if !ok { return } - replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta(ctx, msg.ID, msg.ReplyTo, msg.ThreadRoot, true) + replyTo, threadRoot, prevThreadEvent := portal.getRelationMeta( + ctx, msg.ID, msg.ConvertedMessage, true, + ) if threadRoot != nil && out.PrevThreadEvents[*msg.ThreadRoot] != "" { prevThreadEvent.MXID = out.PrevThreadEvents[*msg.ThreadRoot] } diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index 2b25f0cf..6815f043 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -137,8 +137,8 @@ func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, send return (*Portal)(portal).getIntentAndUserMXIDFor(ctx, sender, source, otherLogins, evtType) } -func (portal *PortalInternals) GetRelationMeta(ctx context.Context, currentMsg networkid.MessageID, replyToPtr *networkid.MessageOptionalPartID, threadRootPtr *networkid.MessageID, isBatchSend bool) (replyTo, threadRoot, prevThreadEvent *database.Message) { - return (*Portal)(portal).getRelationMeta(ctx, currentMsg, replyToPtr, threadRootPtr, isBatchSend) +func (portal *PortalInternals) GetRelationMeta(ctx context.Context, currentMsgID networkid.MessageID, currentMsg *ConvertedMessage, isBatchSend bool) (replyTo, threadRoot, prevThreadEvent *database.Message) { + return (*Portal)(portal).getRelationMeta(ctx, currentMsgID, currentMsg, isBatchSend) } func (portal *PortalInternals) ApplyRelationMeta(ctx context.Context, content *event.MessageEventContent, replyTo, threadRoot, prevThreadEvent *database.Message) { From c80808439d5bc31d442323929710f1ddfd51e4f6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 10 Jul 2025 13:45:11 +0300 Subject: [PATCH 1224/1647] bridgev2: add logger to background context --- bridgev2/bridge.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 5e3b74b7..a4ce033e 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -172,6 +172,7 @@ func (br *Bridge) StartConnectors(ctx context.Context) error { br.Log.Info().Msg("Starting bridge") if br.BackgroundCtx == nil || br.BackgroundCtx.Err() != nil { br.BackgroundCtx, br.cancelBackgroundCtx = context.WithCancel(context.Background()) + br.BackgroundCtx = br.Log.WithContext(br.BackgroundCtx) } if !br.ExternallyManagedDB { From 22587e915906f6e6d90531430a8fcf26857aa092 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 10 Jul 2025 13:45:23 +0300 Subject: [PATCH 1225/1647] bridgev2/portal: track event handler panics --- bridgev2/portal.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 856b6331..c264caea 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -420,10 +420,13 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal doneCallback(res) if err := recover(); err != nil { logEvt := log.Error() + var errorString string if realErr, ok := err.(error); ok { logEvt = logEvt.Err(realErr) + errorString = realErr.Error() } else { logEvt = logEvt.Any(zerolog.ErrorFieldName, err) + errorString = fmt.Sprintf("%v", err) } logEvt. Bytes("stack", debug.Stack()). @@ -436,6 +439,9 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal case *portalCreateEvent: evt.cb(fmt.Errorf("portal creation panicked")) } + portal.Bridge.TrackAnalytics("", "Bridge Event Handler Panic", map[string]any{ + "error": errorString, + }) } }() switch evt := rawEvt.(type) { From 40bb9637cdc44979358a337a1870d021ea4a07ad Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 10 Jul 2025 14:48:54 +0300 Subject: [PATCH 1226/1647] bridgev2/queue: add event handling result for matrix events --- bridgev2/matrixinvite.go | 30 ++--- bridgev2/portal.go | 226 +++++++++++++++++++++---------------- bridgev2/portalinternal.go | 36 +++--- bridgev2/queue.go | 41 ++++--- 4 files changed, 191 insertions(+), 142 deletions(-) diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go index 11826b40..bfbabd26 100644 --- a/bridgev2/matrixinvite.go +++ b/bridgev2/matrixinvite.go @@ -19,17 +19,17 @@ import ( "maunium.net/go/mautrix/id" ) -func (br *Bridge) handleBotInvite(ctx context.Context, evt *event.Event, sender *User) { +func (br *Bridge) handleBotInvite(ctx context.Context, evt *event.Event, sender *User) EventHandlingResult { log := zerolog.Ctx(ctx) // These invites should already be rejected in QueueMatrixEvent if !sender.Permissions.Commands { log.Warn().Msg("Received bot invite from user without permission to send commands") - return + return EventHandlingResultIgnored } err := br.Bot.EnsureJoined(ctx, evt.RoomID) if err != nil { log.Err(err).Msg("Failed to accept invite to room") - return + return EventHandlingResultFailed } log.Debug().Msg("Accepted invite to room as bot") members, err := br.Matrix.GetMembers(ctx, evt.RoomID) @@ -55,6 +55,7 @@ func (br *Bridge) handleBotInvite(ctx context.Context, evt *event.Event, sender log.Err(err).Msg("Failed to send welcome message to room") } } + return EventHandlingResultSuccess } func sendNotice(ctx context.Context, evt *event.Event, intent MatrixAPI, message string, args ...any) { @@ -87,12 +88,12 @@ func sendErrorAndLeave(ctx context.Context, evt *event.Event, intent MatrixAPI, rejectInvite(ctx, evt, intent, "") } -func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sender *User) { +func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sender *User) EventHandlingResult { ghostID, _ := br.Matrix.ParseGhostMXID(id.UserID(evt.GetStateKey())) validator, ok := br.Network.(IdentifierValidatingNetwork) if ghostID == "" || (ok && !validator.ValidateUserID(ghostID)) { rejectInvite(ctx, evt, br.Matrix.GhostIntent(ghostID), "Malformed user ID") - return + return EventHandlingResultIgnored } log := zerolog.Ctx(ctx).With(). Str("invitee_network_id", string(ghostID)). @@ -102,22 +103,22 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen logins := sender.GetUserLogins() if len(logins) == 0 { rejectInvite(ctx, evt, br.Matrix.GhostIntent(ghostID), "You're not logged in") - return + return EventHandlingResultIgnored } _, ok = logins[0].Client.(IdentifierResolvingNetworkAPI) if !ok { rejectInvite(ctx, evt, br.Matrix.GhostIntent(ghostID), "This bridge does not support starting chats") - return + return EventHandlingResultIgnored } invitedGhost, err := br.GetGhostByID(ctx, ghostID) if err != nil { log.Err(err).Msg("Failed to get invited ghost") - return + return EventHandlingResultFailed } err = invitedGhost.Intent.EnsureJoined(ctx, evt.RoomID) if err != nil { log.Err(err).Msg("Failed to accept invite to room") - return + return EventHandlingResultFailed } var resp *CreateChatResponse var sourceLogin *UserLogin @@ -144,7 +145,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen } else if err != nil { log.Err(err).Msg("Failed to resolve identifier") sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to create chat") - return + return EventHandlingResultFailed } else { sourceLogin = login break @@ -153,7 +154,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen if resp == nil { log.Warn().Msg("No login could resolve the identifier") sendErrorAndLeave(ctx, evt, br.Matrix.GhostIntent(ghostID), "Failed to create chat via any login") - return + return EventHandlingResultFailed } portal := resp.Portal if portal == nil { @@ -161,7 +162,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen if err != nil { log.Err(err).Msg("Failed to get portal by key") sendErrorAndLeave(ctx, evt, br.Matrix.GhostIntent(ghostID), "Failed to create portal entry") - return + return EventHandlingResultFailed } } if portal.MXID != "" { @@ -196,13 +197,13 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen if err != nil { log.Err(err).Msg("Failed to ensure bot is invited to room") sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to invite bridge bot") - return + return EventHandlingResultFailed } err = br.Bot.EnsureJoined(ctx, evt.RoomID) if err != nil { log.Err(err).Msg("Failed to ensure bot is joined to room") sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to join with bridge bot") - return + return EventHandlingResultFailed } didSetPortal := portal.setMXIDToExistingRoom(ctx, evt.RoomID) @@ -271,6 +272,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portal.MXID, portal.MXID.URI(br.Matrix.ServerName()).MatrixToURL()) rejectInvite(ctx, evt, br.Bot, "") } + return EventHandlingResultSuccess } func (br *Bridge) givePowerToBot(ctx context.Context, roomID id.RoomID, userWithPower MatrixAPI) error { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c264caea..900d057d 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -446,7 +446,7 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal }() switch evt := rawEvt.(type) { case *portalMatrixEvent: - portal.handleMatrixEvent(ctx, evt.sender, evt.evt) + res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt) case *portalRemoteEvent: res = portal.handleRemoteEvent(ctx, evt.source, evt.evtType, evt.evt) case *portalCreateEvent: @@ -552,16 +552,17 @@ func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID, var fakePerMessageProfileEventType = event.Type{Class: event.StateEventType, Type: "m.per_message_profile"} -func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) { +func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult { log := zerolog.Ctx(ctx) if evt.Mautrix.EventSource&event.SourceEphemeral != 0 { switch evt.Type { case event.EphemeralEventReceipt: - portal.handleMatrixReceipts(ctx, evt) + return portal.handleMatrixReceipts(ctx, evt) case event.EphemeralEventTyping: - portal.handleMatrixTyping(ctx, evt) + return portal.handleMatrixTyping(ctx, evt) + default: + return EventHandlingResultIgnored } - return } login, _, err := portal.FindPreferredLogin(ctx, sender, true) if err != nil { @@ -572,7 +573,7 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * } else { portal.sendErrorStatus(ctx, evt, WrapErrorInStatus(err).WithMessage("Failed to get login to handle event").WithIsCertain(true).WithSendNotice(true)) } - return + return EventHandlingResultFailed } var origSender *OrigSender if login == nil { @@ -621,41 +622,44 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * ctx = log.With().Str("login_id", string(login.ID)).Logger().WithContext(ctx) switch evt.Type { case event.EventMessage, event.EventSticker, event.EventUnstablePollStart, event.EventUnstablePollResponse: - portal.handleMatrixMessage(ctx, login, origSender, evt) + return portal.handleMatrixMessage(ctx, login, origSender, evt) case event.EventReaction: if origSender != nil { log.Debug().Msg("Ignoring reaction event from relayed user") portal.sendErrorStatus(ctx, evt, ErrIgnoringReactionFromRelayedUser) - return + return EventHandlingResultIgnored } - portal.handleMatrixReaction(ctx, login, evt) + return portal.handleMatrixReaction(ctx, login, evt) case event.EventRedaction: - portal.handleMatrixRedaction(ctx, login, origSender, evt) + return portal.handleMatrixRedaction(ctx, login, origSender, evt) case event.StateRoomName: - handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomNameHandlingNetworkAPI.HandleMatrixRoomName) + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomNameHandlingNetworkAPI.HandleMatrixRoomName) case event.StateTopic: - handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic) + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic) case event.StateRoomAvatar: - handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar) + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar) case event.StateEncryption: // TODO? + return EventHandlingResultIgnored case event.AccountDataMarkedUnread: - handleMatrixAccountData(portal, ctx, login, evt, MarkedUnreadHandlingNetworkAPI.HandleMarkedUnread) + return handleMatrixAccountData(portal, ctx, login, evt, MarkedUnreadHandlingNetworkAPI.HandleMarkedUnread) case event.AccountDataRoomTags: - handleMatrixAccountData(portal, ctx, login, evt, TagHandlingNetworkAPI.HandleRoomTag) + return handleMatrixAccountData(portal, ctx, login, evt, TagHandlingNetworkAPI.HandleRoomTag) case event.AccountDataBeeperMute: - handleMatrixAccountData(portal, ctx, login, evt, MuteHandlingNetworkAPI.HandleMute) + return handleMatrixAccountData(portal, ctx, login, evt, MuteHandlingNetworkAPI.HandleMute) case event.StateMember: - portal.handleMatrixMembership(ctx, login, origSender, evt) + return portal.handleMatrixMembership(ctx, login, origSender, evt) case event.StatePowerLevels: - portal.handleMatrixPowerLevels(ctx, login, origSender, evt) + return portal.handleMatrixPowerLevels(ctx, login, origSender, evt) + default: + return EventHandlingResultIgnored } } -func (portal *Portal) handleMatrixReceipts(ctx context.Context, evt *event.Event) { +func (portal *Portal) handleMatrixReceipts(ctx context.Context, evt *event.Event) EventHandlingResult { content, ok := evt.Content.Parsed.(*event.ReceiptEventContent) if !ok { - return + return EventHandlingResultFailed } for evtID, receipts := range *content { readReceipts, ok := receipts[event.ReceiptTypeRead] @@ -666,11 +670,13 @@ func (portal *Portal) handleMatrixReceipts(ctx context.Context, evt *event.Event sender, err := portal.Bridge.GetUserByMXID(ctx, userID) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get user to handle read receipt") - return + return EventHandlingResultFailed } portal.handleMatrixReadReceipt(ctx, sender, evtID, receipt) } } + // TODO actual status + return EventHandlingResultSuccess } func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, eventID id.EventID, receipt event.ReadReceipt) { @@ -736,10 +742,10 @@ func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, e portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) } -func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) { +func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult { content, ok := evt.Content.Parsed.(*event.TypingEventContent) if !ok { - return + return EventHandlingResultFailed } portal.currentlyTypingLock.Lock() defer portal.currentlyTypingLock.Unlock() @@ -750,6 +756,8 @@ func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) portal.sendTypings(ctx, stoppedTyping, false) portal.sendTypings(ctx, startedTyping, true) portal.currentlyTyping = content.UserIDs + // TODO actual status + return EventHandlingResultSuccess } func (portal *Portal) sendTypings(ctx context.Context, userIDs []id.UserID, typing bool) { @@ -882,7 +890,7 @@ func (portal *Portal) parseInputTransactionID(origSender *OrigSender, evt *event return networkid.RawTransactionID(strings.TrimPrefix(evt.ID.String(), database.NetworkTxnMXIDPrefix)) } -func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { +func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { log := zerolog.Ctx(ctx) var relatesTo *event.RelatesTo var msgContent *event.MessageEventContent @@ -903,13 +911,14 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } if msgContent.MsgType == event.MsgNotice && !portal.Bridge.Config.BridgeNotices { portal.sendErrorStatus(ctx, evt, ErrIgnoringMNotice) - return + return EventHandlingResultIgnored } } if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - return + typeErr := fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) + portal.sendErrorStatus(ctx, evt, typeErr) + return EventHandlingResultFailed.WithError(typeErr) } caps := sender.Client.GetCapabilities(ctx, portal) @@ -917,34 +926,33 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin if msgContent == nil { log.Warn().Msg("Ignoring edit of poll") portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w of polls", ErrEditsNotSupported)) - return + return EventHandlingResultFailed.WithError(fmt.Errorf("%w of polls", ErrEditsNotSupported)) } - portal.handleMatrixEdit(ctx, sender, origSender, evt, msgContent, caps) - return + return portal.handleMatrixEdit(ctx, sender, origSender, evt, msgContent, caps) } var err error if origSender != nil { if msgContent == nil { log.Debug().Msg("Ignoring poll event from relayed user") portal.sendErrorStatus(ctx, evt, ErrIgnoringPollFromRelayedUser) - return + return EventHandlingResultIgnored } msgContent, err = portal.Bridge.Config.Relay.FormatMessage(msgContent, origSender) if err != nil { log.Err(err).Msg("Failed to format message for relaying") portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithError(err) } } if msgContent != nil { if !portal.checkMessageContentCaps(ctx, caps, msgContent, evt) { - return + return EventHandlingResultFailed } } else if pollResponseContent != nil || pollContent != nil { if _, ok = sender.Client.(PollHandlingNetworkAPI); !ok { log.Debug().Msg("Ignoring poll event as network connector doesn't implement PollHandlingNetworkAPI") portal.sendErrorStatus(ctx, evt, ErrPollsNotSupported) - return + return EventHandlingResultIgnored } } @@ -954,11 +962,11 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin if err != nil { log.Err(err).Msg("Failed to get poll target message from database") // TODO send status - return + return EventHandlingResultFailed } else if voteTo == nil { log.Warn().Stringer("vote_to_id", relatesTo.GetReferenceID()).Msg("Poll target message not found") // TODO send status - return + return EventHandlingResultFailed } } var replyToID id.EventID @@ -1023,7 +1031,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin Stringer("message_mxid", part.MXID). Stringer("input_event_id", evt.ID). Msg("Message already sent, ignoring") - return + return EventHandlingResultIgnored } } @@ -1044,12 +1052,12 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } else { log.Error().Msg("Failed to handle Matrix message: all contents are nil?") portal.sendErrorStatus(ctx, evt, fmt.Errorf("all contents are nil")) - return + return EventHandlingResultFailed } if err != nil { log.Err(err).Msg("Failed to handle Matrix message") portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithError(err) } message := wrappedMsgEvt.fillDBMessage(resp.DB) if resp.Pending { @@ -1091,6 +1099,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin }, }) } + return EventHandlingResultSuccess } // AddPendingToIgnore adds a transaction ID that should be ignored if encountered as a new message. @@ -1202,7 +1211,14 @@ func (portal *Portal) checkPendingMessages(ctx context.Context, cfg *OutgoingTim } } -func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *event.RoomFeatures) { +func (portal *Portal) handleMatrixEdit( + ctx context.Context, + sender *UserLogin, + origSender *OrigSender, + evt *event.Event, + content *event.MessageEventContent, + caps *event.RoomFeatures, +) EventHandlingResult { log := zerolog.Ctx(ctx) editTargetID := content.RelatesTo.GetReplaceID() log.UpdateContext(func(c zerolog.Context) zerolog.Context { @@ -1220,7 +1236,7 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o if err != nil { log.Err(err).Msg("Failed to format message for relaying") portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithError(err) } } @@ -1228,29 +1244,30 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o if !ok { log.Debug().Msg("Ignoring edit as network connector doesn't implement EditHandlingNetworkAPI") portal.sendErrorStatus(ctx, evt, ErrEditsNotSupported) - return + return EventHandlingResultIgnored } else if !caps.Edit.Partial() { log.Debug().Msg("Ignoring edit as room doesn't support edits") portal.sendErrorStatus(ctx, evt, ErrEditsNotSupportedInPortal) - return + return EventHandlingResultIgnored } else if !portal.checkMessageContentCaps(ctx, caps, content, evt) { - return + return EventHandlingResultFailed } editTarget, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, editTargetID) if err != nil { log.Err(err).Msg("Failed to get edit target message from database") portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get edit target: %w", ErrDatabaseError, err)) - return + return EventHandlingResultFailed } else if editTarget == nil { log.Warn().Msg("Edit target message not found in database") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("edit %w", ErrTargetMessageNotFound)) - return + notFoundErr := fmt.Errorf("edit %w", ErrTargetMessageNotFound) + portal.sendErrorStatus(ctx, evt, notFoundErr) + return EventHandlingResultFailed.WithError(notFoundErr) } else if caps.EditMaxAge != nil && caps.EditMaxAge.Duration > 0 && time.Since(editTarget.Timestamp) > caps.EditMaxAge.Duration { portal.sendErrorStatus(ctx, evt, ErrEditTargetTooOld) - return + return EventHandlingResultFailed.WithError(ErrEditTargetTooOld) } else if caps.EditMaxCount > 0 && editTarget.EditCount >= caps.EditMaxCount { portal.sendErrorStatus(ctx, evt, ErrEditTargetTooManyEdits) - return + return EventHandlingResultFailed.WithError(ErrEditTargetTooManyEdits) } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("edit_target_remote_id", string(editTarget.ID)) @@ -1269,7 +1286,7 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o if err != nil { log.Err(err).Msg("Failed to handle Matrix edit") portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithError(err) } err = portal.Bridge.DB.Message.Update(ctx, editTarget) if err != nil { @@ -1277,21 +1294,23 @@ func (portal *Portal) handleMatrixEdit(ctx context.Context, sender *UserLogin, o } // TODO allow returning stream order from HandleMatrixEdit portal.sendSuccessStatus(ctx, evt, 0, "") + return EventHandlingResultSuccess } -func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) { +func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) EventHandlingResult { log := zerolog.Ctx(ctx) reactingAPI, ok := sender.Client.(ReactionHandlingNetworkAPI) if !ok { log.Debug().Msg("Ignoring reaction as network connector doesn't implement ReactionHandlingNetworkAPI") portal.sendErrorStatus(ctx, evt, ErrReactionsNotSupported) - return + return EventHandlingResultIgnored } content, ok := evt.Content.Parsed.(*event.ReactionEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - return + typeErr := fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) + portal.sendErrorStatus(ctx, evt, typeErr) + return EventHandlingResultFailed.WithError(typeErr) } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Stringer("reaction_target_mxid", content.RelatesTo.EventID) @@ -1300,11 +1319,12 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if err != nil { log.Err(err).Msg("Failed to get reaction target message from database") portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get reaction target: %w", ErrDatabaseError, err)) - return + return EventHandlingResultFailed } else if reactionTarget == nil { log.Warn().Msg("Reaction target message not found in database") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("reaction %w", ErrTargetMessageNotFound)) - return + notFoundErr := fmt.Errorf("reaction %w", ErrTargetMessageNotFound) + portal.sendErrorStatus(ctx, evt, notFoundErr) + return EventHandlingResultFailed.WithError(notFoundErr) } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("reaction_target_remote_id", string(reactionTarget.ID)) @@ -1323,7 +1343,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if err != nil { log.Err(err).Msg("Failed to pre-handle Matrix reaction") portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithError(err) } var deterministicID id.EventID if portal.Bridge.Config.OutgoingMessageReID { @@ -1332,12 +1352,12 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID) if err != nil { log.Err(err).Msg("Failed to check if reaction is a duplicate") - return + return EventHandlingResultFailed } else if existing != nil { if existing.EmojiID != "" || existing.Emoji == preResp.Emoji { log.Debug().Msg("Ignoring duplicate reaction") portal.sendSuccessStatus(ctx, evt, 0, deterministicID) - return + return EventHandlingResultIgnored } react.ReactionToOverride = existing _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ @@ -1355,7 +1375,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if err != nil { log.Err(err).Msg("Failed to get all reactions to message by sender") portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get previous reactions: %w", ErrDatabaseError, err)) - return + return EventHandlingResultFailed } if len(allReactions) < preResp.MaxReactions { react.ExistingReactionsToKeep = allReactions @@ -1382,7 +1402,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if err != nil { log.Err(err).Msg("Failed to handle Matrix reaction") portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithError(err) } if dbReaction == nil { dbReaction = &database.Reaction{} @@ -1421,6 +1441,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi log.Err(err).Msg("Failed to save reaction to database") } portal.sendSuccessStatus(ctx, evt, 0, deterministicID) + return EventHandlingResultSuccess } func handleMatrixRoomMeta[APIType any, ContentType any]( @@ -1430,34 +1451,35 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( origSender *OrigSender, evt *event.Event, fn func(APIType, context.Context, *MatrixRoomMeta[ContentType]) (bool, error), -) { +) EventHandlingResult { api, ok := sender.Client.(APIType) if !ok { portal.sendErrorStatus(ctx, evt, ErrRoomMetadataNotSupported) - return + return EventHandlingResultIgnored } log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(ContentType) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - return + typeErr := fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) + portal.sendErrorStatus(ctx, evt, typeErr) + return EventHandlingResultFailed.WithError(typeErr) } switch typedContent := evt.Content.Parsed.(type) { case *event.RoomNameEventContent: if typedContent.Name == portal.Name { portal.sendSuccessStatus(ctx, evt, 0, "") - return + return EventHandlingResultIgnored } case *event.TopicEventContent: if typedContent.Topic == portal.Topic { portal.sendSuccessStatus(ctx, evt, 0, "") - return + return EventHandlingResultIgnored } case *event.RoomAvatarEventContent: if typedContent.URL == portal.AvatarMXC { portal.sendSuccessStatus(ctx, evt, 0, "") - return + return EventHandlingResultIgnored } } var prevContent ContentType @@ -1480,7 +1502,7 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( if err != nil { log.Err(err).Msg("Failed to handle Matrix room metadata") portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithError(err) } if changed { portal.UpdateBridgeInfo(ctx) @@ -1490,21 +1512,22 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( } } portal.sendSuccessStatus(ctx, evt, 0, "") + return EventHandlingResultSuccess } func handleMatrixAccountData[APIType any, ContentType any]( portal *Portal, ctx context.Context, sender *UserLogin, evt *event.Event, fn func(APIType, context.Context, *MatrixRoomMeta[ContentType]) error, -) { +) EventHandlingResult { api, ok := sender.Client.(APIType) if !ok { - return + return EventHandlingResultIgnored } log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(ContentType) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - return + return EventHandlingResultFailed.WithError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) } var prevContent ContentType if evt.Unsigned.PrevContent != nil { @@ -1522,7 +1545,9 @@ func handleMatrixAccountData[APIType any, ContentType any]( }) if err != nil { log.Err(err).Msg("Failed to handle Matrix room account data") + return EventHandlingResultFailed.WithError(err) } + return EventHandlingResultSuccess } func (portal *Portal) getTargetUser(ctx context.Context, userID id.UserID) (GhostOrUserLogin, error) { @@ -1547,13 +1572,14 @@ func (portal *Portal) handleMatrixMembership( sender *UserLogin, origSender *OrigSender, evt *event.Event, -) { +) EventHandlingResult { log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.MemberEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - return + typeErr := fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) + portal.sendErrorStatus(ctx, evt, typeErr) + return EventHandlingResultFailed.WithError(typeErr) } prevContent := &event.MemberEventContent{Membership: event.MembershipLeave} if evt.Unsigned.PrevContent != nil { @@ -1569,7 +1595,7 @@ func (portal *Portal) handleMatrixMembership( api, ok := sender.Client.(MembershipHandlingNetworkAPI) if !ok { portal.sendErrorStatus(ctx, evt, ErrMembershipNotSupported) - return + return EventHandlingResultIgnored } targetMXID := id.UserID(*evt.StateKey) isSelf := sender.User.MXID == targetMXID @@ -1577,14 +1603,14 @@ func (portal *Portal) handleMatrixMembership( if err != nil { log.Err(err).Msg("Failed to get member event target") portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed } membershipChangeType := MembershipChangeType{From: prevContent.Membership, To: content.Membership, IsSelf: isSelf} if !portal.Bridge.Config.BridgeMatrixLeave && membershipChangeType == Leave { log.Debug().Msg("Dropping leave event") //portal.sendErrorStatus(ctx, evt, ErrIgnoringLeaveEvent) - return + return EventHandlingResultIgnored } targetGhost, _ := target.(*Ghost) targetUserLogin, _ := target.(*UserLogin) @@ -1609,8 +1635,9 @@ func (portal *Portal) handleMatrixMembership( if err != nil { log.Err(err).Msg("Failed to handle Matrix membership change") portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithError(err) } + return EventHandlingResultSuccess } func makePLChange(old, new int, newIsSet bool) *SinglePowerLevelChange { @@ -1635,18 +1662,19 @@ func (portal *Portal) handleMatrixPowerLevels( sender *UserLogin, origSender *OrigSender, evt *event.Event, -) { +) EventHandlingResult { log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - return + typeErr := fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) + portal.sendErrorStatus(ctx, evt, typeErr) + return EventHandlingResultFailed.WithError(typeErr) } api, ok := sender.Client.(PowerLevelHandlingNetworkAPI) if !ok { portal.sendErrorStatus(ctx, evt, ErrPowerLevelsNotSupported) - return + return EventHandlingResultIgnored } prevContent := &event.PowerLevelsEventContent{} if evt.Unsigned.PrevContent != nil { @@ -1706,17 +1734,21 @@ func (portal *Portal) handleMatrixPowerLevels( if err != nil { log.Err(err).Msg("Failed to handle Matrix power level change") portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithError(err) } + return EventHandlingResultSuccess } -func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { +func (portal *Portal) handleMatrixRedaction( + ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, +) EventHandlingResult { log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.RedactionEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) - return + typeErr := fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) + portal.sendErrorStatus(ctx, evt, typeErr) + return EventHandlingResultFailed.WithError(typeErr) } if evt.Redacts != "" && content.Redacts != evt.Redacts { content.Redacts = evt.Redacts @@ -1729,19 +1761,19 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog if !deleteOK && !reactOK { log.Debug().Msg("Ignoring redaction without checking target as network connector doesn't implement RedactionHandlingNetworkAPI nor ReactionHandlingNetworkAPI") portal.sendErrorStatus(ctx, evt, ErrRedactionsNotSupported) - return + return EventHandlingResultIgnored } var redactionTargetReaction *database.Reaction redactionTargetMsg, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, content.Redacts) if err != nil { log.Err(err).Msg("Failed to get redaction target message from database") portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get redaction target message: %w", ErrDatabaseError, err)) - return + return EventHandlingResultFailed } else if redactionTargetMsg != nil { if !deleteOK { log.Debug().Msg("Ignoring message redaction event as network connector doesn't implement RedactionHandlingNetworkAPI") portal.sendErrorStatus(ctx, evt, ErrRedactionsNotSupported) - return + return EventHandlingResultIgnored } err = deletingAPI.HandleMatrixMessageRemove(ctx, &MatrixMessageRemove{ MatrixEventBase: MatrixEventBase[*event.RedactionEventContent]{ @@ -1757,12 +1789,12 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog } else if redactionTargetReaction, err = portal.Bridge.DB.Reaction.GetByMXID(ctx, content.Redacts); err != nil { log.Err(err).Msg("Failed to get redaction target reaction from database") portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get redaction target message reaction: %w", ErrDatabaseError, err)) - return + return EventHandlingResultFailed } else if redactionTargetReaction != nil { if !reactOK { log.Debug().Msg("Ignoring reaction redaction event as network connector doesn't implement ReactionHandlingNetworkAPI") portal.sendErrorStatus(ctx, evt, ErrReactionsNotSupported) - return + return EventHandlingResultIgnored } // TODO ignore if sender doesn't match? err = reactingAPI.HandleMatrixReactionRemove(ctx, &MatrixReactionRemove{ @@ -1778,16 +1810,18 @@ func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *UserLog }) } else { log.Debug().Msg("Redaction target message not found in database") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("redaction %w", ErrTargetMessageNotFound)) - return + notFoundErr := fmt.Errorf("redaction %w", ErrTargetMessageNotFound) + portal.sendErrorStatus(ctx, evt, notFoundErr) + return EventHandlingResultIgnored } if err != nil { log.Err(err).Msg("Failed to handle Matrix redaction") portal.sendErrorStatus(ctx, evt, err) - return + return EventHandlingResultFailed.WithError(err) } // TODO delete msg/reaction db row portal.sendSuccessStatus(ctx, evt, 0, "") + return EventHandlingResultSuccess } func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) (res EventHandlingResult) { diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index 6815f043..ae338383 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -61,20 +61,20 @@ func (portal *PortalInternals) CheckConfusableName(ctx context.Context, userID i return (*Portal)(portal).checkConfusableName(ctx, userID, name) } -func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) { - (*Portal)(portal).handleMatrixEvent(ctx, sender, evt) +func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt) } -func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) { - (*Portal)(portal).handleMatrixReceipts(ctx, evt) +func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixReceipts(ctx, evt) } func (portal *PortalInternals) HandleMatrixReadReceipt(ctx context.Context, user *User, eventID id.EventID, receipt event.ReadReceipt) { (*Portal)(portal).handleMatrixReadReceipt(ctx, user, eventID, receipt) } -func (portal *PortalInternals) HandleMatrixTyping(ctx context.Context, evt *event.Event) { - (*Portal)(portal).handleMatrixTyping(ctx, evt) +func (portal *PortalInternals) HandleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixTyping(ctx, evt) } func (portal *PortalInternals) SendTypings(ctx context.Context, userIDs []id.UserID, typing bool) { @@ -93,8 +93,8 @@ func (portal *PortalInternals) ParseInputTransactionID(origSender *OrigSender, e return (*Portal)(portal).parseInputTransactionID(origSender, evt) } -func (portal *PortalInternals) HandleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { - (*Portal)(portal).handleMatrixMessage(ctx, sender, origSender, evt) +func (portal *PortalInternals) HandleMatrixMessage(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixMessage(ctx, sender, origSender, evt) } func (portal *PortalInternals) PendingMessageTimeoutLoop(ctx context.Context, cfg *OutgoingTimeoutConfig) { @@ -105,28 +105,28 @@ func (portal *PortalInternals) CheckPendingMessages(ctx context.Context, cfg *Ou (*Portal)(portal).checkPendingMessages(ctx, cfg) } -func (portal *PortalInternals) HandleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *event.RoomFeatures) { - (*Portal)(portal).handleMatrixEdit(ctx, sender, origSender, evt, content, caps) +func (portal *PortalInternals) HandleMatrixEdit(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, content *event.MessageEventContent, caps *event.RoomFeatures) EventHandlingResult { + return (*Portal)(portal).handleMatrixEdit(ctx, sender, origSender, evt, content, caps) } -func (portal *PortalInternals) HandleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) { - (*Portal)(portal).handleMatrixReaction(ctx, sender, evt) +func (portal *PortalInternals) HandleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixReaction(ctx, sender, evt) } func (portal *PortalInternals) GetTargetUser(ctx context.Context, userID id.UserID) (GhostOrUserLogin, error) { return (*Portal)(portal).getTargetUser(ctx, userID) } -func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { - (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt) +func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt) } -func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { - (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt) +func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt) } -func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) { - (*Portal)(portal).handleMatrixRedaction(ctx, sender, origSender, evt) +func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixRedaction(ctx, sender, origSender, evt) } func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) (res EventHandlingResult) { diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 48ee78f1..4a107d36 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -63,7 +63,7 @@ func (br *Bridge) rejectInviteOnNoPermission(ctx context.Context, evt *event.Eve return true } -func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { +func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventHandlingResult { // TODO maybe HandleMatrixEvent would be more appropriate as this also handles bot invites and commands log := zerolog.Ctx(ctx) @@ -75,26 +75,26 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { log.Err(err).Msg("Failed to get sender user for incoming Matrix event") status := WrapErrorInStatus(fmt.Errorf("%w: failed to get sender user: %w", ErrDatabaseError, err)) br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) - return + return EventHandlingResultFailed } else if sender == nil { log.Error().Msg("Couldn't get sender for incoming non-ephemeral Matrix event") status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) - return + return EventHandlingResultFailed } else if !sender.Permissions.SendEvents { if !br.rejectInviteOnNoPermission(ctx, evt, "interact with") { status := WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) } - return + return EventHandlingResultIgnored } else if !sender.Permissions.Commands && br.rejectInviteOnNoPermission(ctx, evt, "send commands to") { - return + return EventHandlingResultIgnored } } else if evt.Type.Class != event.EphemeralEventType { log.Error().Msg("Missing sender for incoming non-ephemeral Matrix event") status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) - return + return EventHandlingResultIgnored } if evt.Type == event.EventMessage && sender != nil { msg := evt.Content.AsMessage() @@ -104,7 +104,7 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { if !sender.Permissions.Commands { status := WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) - return + return EventHandlingResultIgnored } br.Commands.Handle( ctx, @@ -114,40 +114,41 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) { strings.TrimPrefix(msg.Body, br.Config.CommandPrefix+" "), msg.RelatesTo.GetReplyTo(), ) - return + return EventHandlingResultSuccess } } if evt.Type == event.StateMember && evt.GetStateKey() == br.Bot.GetMXID().String() && evt.Content.AsMember().Membership == event.MembershipInvite && sender != nil { - br.handleBotInvite(ctx, evt, sender) - return + return br.handleBotInvite(ctx, evt, sender) } else if sender != nil && evt.RoomID == sender.ManagementRoom { if evt.Type == event.StateMember && evt.Content.AsMember().Membership == event.MembershipLeave && (evt.GetStateKey() == br.Bot.GetMXID().String() || evt.GetStateKey() == sender.MXID.String()) { sender.ManagementRoom = "" err := br.DB.User.Update(ctx, sender.User) if err != nil { log.Err(err).Msg("Failed to clear user's management room in database") + return EventHandlingResultFailed } else { log.Debug().Msg("Cleared user's management room due to leave event") } } - return + return EventHandlingResultSuccess } portal, err := br.GetPortalByMXID(ctx, evt.RoomID) if err != nil { log.Err(err).Msg("Failed to get portal for incoming Matrix event") status := WrapErrorInStatus(fmt.Errorf("%w: failed to get portal: %w", ErrDatabaseError, err)) br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) - return + return EventHandlingResultFailed } else if portal != nil { - portal.queueEvent(ctx, &portalMatrixEvent{ + return portal.queueEvent(ctx, &portalMatrixEvent{ evt: evt, sender: sender, }) } else if evt.Type == event.StateMember && br.IsGhostMXID(id.UserID(evt.GetStateKey())) && evt.Content.AsMember().Membership == event.MembershipInvite && evt.Content.AsMember().IsDirect { - br.handleGhostDMInvite(ctx, evt, sender) + return br.handleGhostDMInvite(ctx, evt, sender) } else { status := WrapErrorInStatus(ErrNoPortal) br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + return EventHandlingResultIgnored } } @@ -155,6 +156,18 @@ type EventHandlingResult struct { Success bool Ignored bool Queued bool + + // Error is an optional reason for failure. It is not required, Success may be false even without a specific error. + Error error +} + +func (ehr EventHandlingResult) WithError(err error) EventHandlingResult { + if err == nil { + return ehr + } + ehr.Error = err + ehr.Success = false + return ehr } var ( From 4f8ff2a35079a0cea77e2855a636b38fffe3dfff Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 10 Jul 2025 15:04:57 +0300 Subject: [PATCH 1227/1647] bridgev2/portal: merge MSS errors with handling result --- bridgev2/portal.go | 191 ++++++++++++++----------------------- bridgev2/portalinternal.go | 4 +- bridgev2/queue.go | 14 +++ 3 files changed, 88 insertions(+), 121 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 900d057d..136ecd12 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -447,6 +447,13 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal switch evt := rawEvt.(type) { case *portalMatrixEvent: res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt) + if res.SendMSS { + if res.Error != nil { + portal.sendErrorStatus(ctx, evt.evt, res.Error) + } else { + portal.sendSuccessStatus(ctx, evt.evt, 0, "") + } + } case *portalRemoteEvent: res = portal.handleRemoteEvent(ctx, evt.source, evt.evtType, evt.evt) case *portalCreateEvent: @@ -569,11 +576,14 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * log.Err(err).Msg("Failed to get user login to handle Matrix event") if errors.Is(err, ErrNotLoggedIn) { shouldSendNotice := evt.Content.AsMessage().MsgType != event.MsgNotice - portal.sendErrorStatus(ctx, evt, WrapErrorInStatus(err).WithMessage("You're not logged in").WithIsCertain(true).WithSendNotice(shouldSendNotice)) + return EventHandlingResultFailed.WithMSSError( + WrapErrorInStatus(err).WithMessage("You're not logged in").WithIsCertain(true).WithSendNotice(shouldSendNotice), + ) } else { - portal.sendErrorStatus(ctx, evt, WrapErrorInStatus(err).WithMessage("Failed to get login to handle event").WithIsCertain(true).WithSendNotice(true)) + return EventHandlingResultFailed.WithMSSError( + WrapErrorInStatus(err).WithMessage("Failed to get login to handle event").WithIsCertain(true).WithSendNotice(true), + ) } - return EventHandlingResultFailed } var origSender *OrigSender if login == nil { @@ -626,8 +636,7 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * case event.EventReaction: if origSender != nil { log.Debug().Msg("Ignoring reaction event from relayed user") - portal.sendErrorStatus(ctx, evt, ErrIgnoringReactionFromRelayedUser) - return EventHandlingResultIgnored + return EventHandlingResultIgnored.WithMSSError(ErrIgnoringReactionFromRelayedUser) } return portal.handleMatrixReaction(ctx, login, evt) case event.EventRedaction: @@ -848,39 +857,35 @@ func (portal *Portal) periodicTypingUpdater() { } } -func (portal *Portal) checkMessageContentCaps(ctx context.Context, caps *event.RoomFeatures, content *event.MessageEventContent, evt *event.Event) bool { +func (portal *Portal) checkMessageContentCaps(caps *event.RoomFeatures, content *event.MessageEventContent) error { switch content.MsgType { case event.MsgText, event.MsgNotice, event.MsgEmote: // No checks for now, message length is safer to check after conversion inside connector case event.MsgLocation: if caps.LocationMessage.Reject() { - portal.sendErrorStatus(ctx, evt, ErrLocationMessagesNotAllowed) - return false + return ErrLocationMessagesNotAllowed } case event.MsgImage, event.MsgAudio, event.MsgVideo, event.MsgFile, event.CapMsgSticker: capMsgType := content.GetCapMsgType() feat, ok := caps.File[capMsgType] if !ok { - portal.sendErrorStatus(ctx, evt, ErrUnsupportedMessageType) - return false + return ErrUnsupportedMessageType } if content.MsgType != event.CapMsgSticker && content.FileName != "" && content.Body != content.FileName && feat.Caption.Reject() { - portal.sendErrorStatus(ctx, evt, ErrCaptionsNotAllowed) - return false + return ErrCaptionsNotAllowed } if content.Info != nil && content.Info.MimeType != "" { if feat.GetMimeSupport(content.Info.MimeType).Reject() { - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w (%s in %s)", ErrUnsupportedMediaType, content.Info.MimeType, capMsgType)) - return false + return fmt.Errorf("%w (%s in %s)", ErrUnsupportedMediaType, content.Info.MimeType, capMsgType) } } fallthrough default: } - return true + return nil } func (portal *Portal) parseInputTransactionID(origSender *OrigSender, evt *event.Event) networkid.RawTransactionID { @@ -910,23 +915,20 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin msgContent.MsgType = event.CapMsgSticker } if msgContent.MsgType == event.MsgNotice && !portal.Bridge.Config.BridgeNotices { - portal.sendErrorStatus(ctx, evt, ErrIgnoringMNotice) - return EventHandlingResultIgnored + return EventHandlingResultIgnored.WithMSSError(ErrIgnoringMNotice) } } if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - typeErr := fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) - portal.sendErrorStatus(ctx, evt, typeErr) - return EventHandlingResultFailed.WithError(typeErr) + return EventHandlingResultFailed. + WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) } caps := sender.Client.GetCapabilities(ctx, portal) if relatesTo.GetReplaceID() != "" { if msgContent == nil { log.Warn().Msg("Ignoring edit of poll") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w of polls", ErrEditsNotSupported)) - return EventHandlingResultFailed.WithError(fmt.Errorf("%w of polls", ErrEditsNotSupported)) + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w of polls", ErrEditsNotSupported)) } return portal.handleMatrixEdit(ctx, sender, origSender, evt, msgContent, caps) } @@ -934,25 +936,22 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin if origSender != nil { if msgContent == nil { log.Debug().Msg("Ignoring poll event from relayed user") - portal.sendErrorStatus(ctx, evt, ErrIgnoringPollFromRelayedUser) - return EventHandlingResultIgnored + return EventHandlingResultIgnored.WithMSSError(ErrIgnoringPollFromRelayedUser) } msgContent, err = portal.Bridge.Config.Relay.FormatMessage(msgContent, origSender) if err != nil { log.Err(err).Msg("Failed to format message for relaying") - portal.sendErrorStatus(ctx, evt, err) - return EventHandlingResultFailed.WithError(err) + return EventHandlingResultFailed.WithMSSError(err) } } if msgContent != nil { - if !portal.checkMessageContentCaps(ctx, caps, msgContent, evt) { - return EventHandlingResultFailed + if err = portal.checkMessageContentCaps(caps, msgContent); err != nil { + return EventHandlingResultFailed.WithMSSError(err) } } else if pollResponseContent != nil || pollContent != nil { if _, ok = sender.Client.(PollHandlingNetworkAPI); !ok { log.Debug().Msg("Ignoring poll event as network connector doesn't implement PollHandlingNetworkAPI") - portal.sendErrorStatus(ctx, evt, ErrPollsNotSupported) - return EventHandlingResultIgnored + return EventHandlingResultIgnored.WithMSSError(ErrPollsNotSupported) } } @@ -1051,13 +1050,11 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin }) } else { log.Error().Msg("Failed to handle Matrix message: all contents are nil?") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("all contents are nil")) - return EventHandlingResultFailed + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("all contents are nil")) } if err != nil { log.Err(err).Msg("Failed to handle Matrix message") - portal.sendErrorStatus(ctx, evt, err) - return EventHandlingResultFailed.WithError(err) + return EventHandlingResultFailed.WithMSSError(err) } message := wrappedMsgEvt.fillDBMessage(resp.DB) if resp.Pending { @@ -1235,39 +1232,31 @@ func (portal *Portal) handleMatrixEdit( content, err = portal.Bridge.Config.Relay.FormatMessage(content, origSender) if err != nil { log.Err(err).Msg("Failed to format message for relaying") - portal.sendErrorStatus(ctx, evt, err) - return EventHandlingResultFailed.WithError(err) + return EventHandlingResultFailed.WithMSSError(err) } } editingAPI, ok := sender.Client.(EditHandlingNetworkAPI) if !ok { log.Debug().Msg("Ignoring edit as network connector doesn't implement EditHandlingNetworkAPI") - portal.sendErrorStatus(ctx, evt, ErrEditsNotSupported) - return EventHandlingResultIgnored + return EventHandlingResultIgnored.WithMSSError(ErrEditsNotSupported) } else if !caps.Edit.Partial() { log.Debug().Msg("Ignoring edit as room doesn't support edits") - portal.sendErrorStatus(ctx, evt, ErrEditsNotSupportedInPortal) - return EventHandlingResultIgnored - } else if !portal.checkMessageContentCaps(ctx, caps, content, evt) { - return EventHandlingResultFailed + return EventHandlingResultIgnored.WithMSSError(ErrEditsNotSupportedInPortal) + } else if err := portal.checkMessageContentCaps(caps, content); err != nil { + return EventHandlingResultFailed.WithMSSError(err) } editTarget, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, editTargetID) if err != nil { log.Err(err).Msg("Failed to get edit target message from database") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get edit target: %w", ErrDatabaseError, err)) - return EventHandlingResultFailed + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get edit target: %w", ErrDatabaseError, err)) } else if editTarget == nil { log.Warn().Msg("Edit target message not found in database") - notFoundErr := fmt.Errorf("edit %w", ErrTargetMessageNotFound) - portal.sendErrorStatus(ctx, evt, notFoundErr) - return EventHandlingResultFailed.WithError(notFoundErr) + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("edit %w", ErrTargetMessageNotFound)) } else if caps.EditMaxAge != nil && caps.EditMaxAge.Duration > 0 && time.Since(editTarget.Timestamp) > caps.EditMaxAge.Duration { - portal.sendErrorStatus(ctx, evt, ErrEditTargetTooOld) - return EventHandlingResultFailed.WithError(ErrEditTargetTooOld) + return EventHandlingResultFailed.WithMSSError(ErrEditTargetTooOld) } else if caps.EditMaxCount > 0 && editTarget.EditCount >= caps.EditMaxCount { - portal.sendErrorStatus(ctx, evt, ErrEditTargetTooManyEdits) - return EventHandlingResultFailed.WithError(ErrEditTargetTooManyEdits) + return EventHandlingResultFailed.WithMSSError(ErrEditTargetTooManyEdits) } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("edit_target_remote_id", string(editTarget.ID)) @@ -1285,8 +1274,7 @@ func (portal *Portal) handleMatrixEdit( }) if err != nil { log.Err(err).Msg("Failed to handle Matrix edit") - portal.sendErrorStatus(ctx, evt, err) - return EventHandlingResultFailed.WithError(err) + return EventHandlingResultFailed.WithMSSError(err) } err = portal.Bridge.DB.Message.Update(ctx, editTarget) if err != nil { @@ -1302,15 +1290,12 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi reactingAPI, ok := sender.Client.(ReactionHandlingNetworkAPI) if !ok { log.Debug().Msg("Ignoring reaction as network connector doesn't implement ReactionHandlingNetworkAPI") - portal.sendErrorStatus(ctx, evt, ErrReactionsNotSupported) - return EventHandlingResultIgnored + return EventHandlingResultIgnored.WithMSSError(ErrReactionsNotSupported) } content, ok := evt.Content.Parsed.(*event.ReactionEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - typeErr := fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) - portal.sendErrorStatus(ctx, evt, typeErr) - return EventHandlingResultFailed.WithError(typeErr) + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Stringer("reaction_target_mxid", content.RelatesTo.EventID) @@ -1318,13 +1303,10 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi reactionTarget, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, content.RelatesTo.EventID) if err != nil { log.Err(err).Msg("Failed to get reaction target message from database") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get reaction target: %w", ErrDatabaseError, err)) - return EventHandlingResultFailed + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get reaction target: %w", ErrDatabaseError, err)) } else if reactionTarget == nil { log.Warn().Msg("Reaction target message not found in database") - notFoundErr := fmt.Errorf("reaction %w", ErrTargetMessageNotFound) - portal.sendErrorStatus(ctx, evt, notFoundErr) - return EventHandlingResultFailed.WithError(notFoundErr) + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("reaction %w", ErrTargetMessageNotFound)) } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("reaction_target_remote_id", string(reactionTarget.ID)) @@ -1342,8 +1324,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi preResp, err := reactingAPI.PreHandleMatrixReaction(ctx, react) if err != nil { log.Err(err).Msg("Failed to pre-handle Matrix reaction") - portal.sendErrorStatus(ctx, evt, err) - return EventHandlingResultFailed.WithError(err) + return EventHandlingResultFailed.WithMSSError(err) } var deterministicID id.EventID if portal.Bridge.Config.OutgoingMessageReID { @@ -1352,7 +1333,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID) if err != nil { log.Err(err).Msg("Failed to check if reaction is a duplicate") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to check for existing reaction: %w", ErrDatabaseError, err)) } else if existing != nil { if existing.EmojiID != "" || existing.Emoji == preResp.Emoji { log.Debug().Msg("Ignoring duplicate reaction") @@ -1374,8 +1355,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi allReactions, err := portal.Bridge.DB.Reaction.GetAllToMessageBySender(ctx, portal.Receiver, reactionTarget.ID, preResp.SenderID) if err != nil { log.Err(err).Msg("Failed to get all reactions to message by sender") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get previous reactions: %w", ErrDatabaseError, err)) - return EventHandlingResultFailed + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get previous reactions: %w", ErrDatabaseError, err)) } if len(allReactions) < preResp.MaxReactions { react.ExistingReactionsToKeep = allReactions @@ -1401,8 +1381,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi dbReaction, err := reactingAPI.HandleMatrixReaction(ctx, react) if err != nil { log.Err(err).Msg("Failed to handle Matrix reaction") - portal.sendErrorStatus(ctx, evt, err) - return EventHandlingResultFailed.WithError(err) + return EventHandlingResultFailed.WithMSSError(err) } if dbReaction == nil { dbReaction = &database.Reaction{} @@ -1454,16 +1433,13 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( ) EventHandlingResult { api, ok := sender.Client.(APIType) if !ok { - portal.sendErrorStatus(ctx, evt, ErrRoomMetadataNotSupported) - return EventHandlingResultIgnored + return EventHandlingResultIgnored.WithMSSError(ErrRoomMetadataNotSupported) } log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(ContentType) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - typeErr := fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) - portal.sendErrorStatus(ctx, evt, typeErr) - return EventHandlingResultFailed.WithError(typeErr) + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) } switch typedContent := evt.Content.Parsed.(type) { case *event.RoomNameEventContent: @@ -1501,8 +1477,7 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( }) if err != nil { log.Err(err).Msg("Failed to handle Matrix room metadata") - portal.sendErrorStatus(ctx, evt, err) - return EventHandlingResultFailed.WithError(err) + return EventHandlingResultFailed.WithMSSError(err) } if changed { portal.UpdateBridgeInfo(ctx) @@ -1511,8 +1486,7 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( log.Err(err).Msg("Failed to save portal after updating room metadata") } } - portal.sendSuccessStatus(ctx, evt, 0, "") - return EventHandlingResultSuccess + return EventHandlingResultSuccess.WithMSS() } func handleMatrixAccountData[APIType any, ContentType any]( @@ -1577,9 +1551,7 @@ func (portal *Portal) handleMatrixMembership( content, ok := evt.Content.Parsed.(*event.MemberEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - typeErr := fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) - portal.sendErrorStatus(ctx, evt, typeErr) - return EventHandlingResultFailed.WithError(typeErr) + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) } prevContent := &event.MemberEventContent{Membership: event.MembershipLeave} if evt.Unsigned.PrevContent != nil { @@ -1594,23 +1566,20 @@ func (portal *Portal) handleMatrixMembership( }) api, ok := sender.Client.(MembershipHandlingNetworkAPI) if !ok { - portal.sendErrorStatus(ctx, evt, ErrMembershipNotSupported) - return EventHandlingResultIgnored + return EventHandlingResultIgnored.WithMSSError(ErrMembershipNotSupported) } targetMXID := id.UserID(*evt.StateKey) isSelf := sender.User.MXID == targetMXID target, err := portal.getTargetUser(ctx, targetMXID) if err != nil { log.Err(err).Msg("Failed to get member event target") - portal.sendErrorStatus(ctx, evt, err) - return EventHandlingResultFailed + return EventHandlingResultFailed.WithMSSError(err) } membershipChangeType := MembershipChangeType{From: prevContent.Membership, To: content.Membership, IsSelf: isSelf} if !portal.Bridge.Config.BridgeMatrixLeave && membershipChangeType == Leave { log.Debug().Msg("Dropping leave event") - //portal.sendErrorStatus(ctx, evt, ErrIgnoringLeaveEvent) - return EventHandlingResultIgnored + return EventHandlingResultIgnored //.WithMSSError(ErrIgnoringLeaveEvent) } targetGhost, _ := target.(*Ghost) targetUserLogin, _ := target.(*UserLogin) @@ -1634,10 +1603,9 @@ func (portal *Portal) handleMatrixMembership( _, err = api.HandleMatrixMembership(ctx, membershipChange) if err != nil { log.Err(err).Msg("Failed to handle Matrix membership change") - portal.sendErrorStatus(ctx, evt, err) - return EventHandlingResultFailed.WithError(err) + return EventHandlingResultFailed.WithMSSError(err) } - return EventHandlingResultSuccess + return EventHandlingResultSuccess.WithMSS() } func makePLChange(old, new int, newIsSet bool) *SinglePowerLevelChange { @@ -1667,14 +1635,11 @@ func (portal *Portal) handleMatrixPowerLevels( content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - typeErr := fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) - portal.sendErrorStatus(ctx, evt, typeErr) - return EventHandlingResultFailed.WithError(typeErr) + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) } api, ok := sender.Client.(PowerLevelHandlingNetworkAPI) if !ok { - portal.sendErrorStatus(ctx, evt, ErrPowerLevelsNotSupported) - return EventHandlingResultIgnored + return EventHandlingResultIgnored.WithMSSError(ErrPowerLevelsNotSupported) } prevContent := &event.PowerLevelsEventContent{} if evt.Unsigned.PrevContent != nil { @@ -1733,10 +1698,9 @@ func (portal *Portal) handleMatrixPowerLevels( _, err := api.HandleMatrixPowerLevels(ctx, plChange) if err != nil { log.Err(err).Msg("Failed to handle Matrix power level change") - portal.sendErrorStatus(ctx, evt, err) - return EventHandlingResultFailed.WithError(err) + return EventHandlingResultFailed.WithMSSError(err) } - return EventHandlingResultSuccess + return EventHandlingResultSuccess.WithMSS() } func (portal *Portal) handleMatrixRedaction( @@ -1746,9 +1710,7 @@ func (portal *Portal) handleMatrixRedaction( content, ok := evt.Content.Parsed.(*event.RedactionEventContent) if !ok { log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") - typeErr := fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) - portal.sendErrorStatus(ctx, evt, typeErr) - return EventHandlingResultFailed.WithError(typeErr) + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) } if evt.Redacts != "" && content.Redacts != evt.Redacts { content.Redacts = evt.Redacts @@ -1760,20 +1722,17 @@ func (portal *Portal) handleMatrixRedaction( reactingAPI, reactOK := sender.Client.(ReactionHandlingNetworkAPI) if !deleteOK && !reactOK { log.Debug().Msg("Ignoring redaction without checking target as network connector doesn't implement RedactionHandlingNetworkAPI nor ReactionHandlingNetworkAPI") - portal.sendErrorStatus(ctx, evt, ErrRedactionsNotSupported) - return EventHandlingResultIgnored + return EventHandlingResultIgnored.WithMSSError(ErrRedactionsNotSupported) } var redactionTargetReaction *database.Reaction redactionTargetMsg, err := portal.Bridge.DB.Message.GetPartByMXID(ctx, content.Redacts) if err != nil { log.Err(err).Msg("Failed to get redaction target message from database") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get redaction target message: %w", ErrDatabaseError, err)) - return EventHandlingResultFailed + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get redaction target message: %w", ErrDatabaseError, err)) } else if redactionTargetMsg != nil { if !deleteOK { log.Debug().Msg("Ignoring message redaction event as network connector doesn't implement RedactionHandlingNetworkAPI") - portal.sendErrorStatus(ctx, evt, ErrRedactionsNotSupported) - return EventHandlingResultIgnored + return EventHandlingResultIgnored.WithMSSError(ErrRedactionsNotSupported) } err = deletingAPI.HandleMatrixMessageRemove(ctx, &MatrixMessageRemove{ MatrixEventBase: MatrixEventBase[*event.RedactionEventContent]{ @@ -1788,13 +1747,11 @@ func (portal *Portal) handleMatrixRedaction( }) } else if redactionTargetReaction, err = portal.Bridge.DB.Reaction.GetByMXID(ctx, content.Redacts); err != nil { log.Err(err).Msg("Failed to get redaction target reaction from database") - portal.sendErrorStatus(ctx, evt, fmt.Errorf("%w: failed to get redaction target message reaction: %w", ErrDatabaseError, err)) - return EventHandlingResultFailed + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: failed to get redaction target message reaction: %w", ErrDatabaseError, err)) } else if redactionTargetReaction != nil { if !reactOK { log.Debug().Msg("Ignoring reaction redaction event as network connector doesn't implement ReactionHandlingNetworkAPI") - portal.sendErrorStatus(ctx, evt, ErrReactionsNotSupported) - return EventHandlingResultIgnored + return EventHandlingResultIgnored.WithMSSError(ErrReactionsNotSupported) } // TODO ignore if sender doesn't match? err = reactingAPI.HandleMatrixReactionRemove(ctx, &MatrixReactionRemove{ @@ -1810,18 +1767,14 @@ func (portal *Portal) handleMatrixRedaction( }) } else { log.Debug().Msg("Redaction target message not found in database") - notFoundErr := fmt.Errorf("redaction %w", ErrTargetMessageNotFound) - portal.sendErrorStatus(ctx, evt, notFoundErr) - return EventHandlingResultIgnored + return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("redaction %w", ErrTargetMessageNotFound)) } if err != nil { log.Err(err).Msg("Failed to handle Matrix redaction") - portal.sendErrorStatus(ctx, evt, err) - return EventHandlingResultFailed.WithError(err) + return EventHandlingResultFailed.WithMSSError(err) } // TODO delete msg/reaction db row - portal.sendSuccessStatus(ctx, evt, 0, "") - return EventHandlingResultSuccess + return EventHandlingResultSuccess.WithMSS() } func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, evtType RemoteEventType, evt RemoteEvent) (res EventHandlingResult) { diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index ae338383..e82c481a 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -85,8 +85,8 @@ func (portal *PortalInternals) PeriodicTypingUpdater() { (*Portal)(portal).periodicTypingUpdater() } -func (portal *PortalInternals) CheckMessageContentCaps(ctx context.Context, caps *event.RoomFeatures, content *event.MessageEventContent, evt *event.Event) bool { - return (*Portal)(portal).checkMessageContentCaps(ctx, caps, content, evt) +func (portal *PortalInternals) CheckMessageContentCaps(caps *event.RoomFeatures, content *event.MessageEventContent) error { + return (*Portal)(portal).checkMessageContentCaps(caps, content) } func (portal *PortalInternals) ParseInputTransactionID(origSender *OrigSender, evt *event.Event) networkid.RawTransactionID { diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 4a107d36..04d982b5 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -159,6 +159,8 @@ type EventHandlingResult struct { // Error is an optional reason for failure. It is not required, Success may be false even without a specific error. Error error + // Whether the Error should be sent as a MSS event. + SendMSS bool } func (ehr EventHandlingResult) WithError(err error) EventHandlingResult { @@ -170,6 +172,18 @@ func (ehr EventHandlingResult) WithError(err error) EventHandlingResult { return ehr } +func (ehr EventHandlingResult) WithMSS() EventHandlingResult { + ehr.SendMSS = true + return ehr +} + +func (ehr EventHandlingResult) WithMSSError(err error) EventHandlingResult { + if err == nil { + return ehr + } + return ehr.WithError(err).WithMSS() +} + var ( EventHandlingResultFailed = EventHandlingResult{} EventHandlingResultQueued = EventHandlingResult{Success: true, Queued: true} From 5e29bac3dd9ce315be28f20aed9b03e19333f49c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 10 Jul 2025 16:19:37 +0300 Subject: [PATCH 1228/1647] bridgev2/portal: adjust handleMatrixMessage return value for pending messages --- bridgev2/portal.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 136ecd12..1fab94e6 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1096,6 +1096,10 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin }, }) } + if resp.Pending { + // Not exactly queued, but not finished either + return EventHandlingResultQueued + } return EventHandlingResultSuccess } From b74368ac2302a10d9805a548e68fdc306d0526d1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 15 Jul 2025 13:19:44 +0300 Subject: [PATCH 1229/1647] commands: add safety to type check --- commands/event.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/commands/event.go b/commands/event.go index 65ddd3da..77a3c0d2 100644 --- a/commands/event.go +++ b/commands/event.go @@ -62,8 +62,8 @@ var IDHTMLParser = &format.HTMLParser{ // ParseEvent parses a message into a command event struct. func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[MetaType] { - content := evt.Content.Parsed.(*event.MessageEventContent) - if content.MsgType == event.MsgNotice || content.RelatesTo.GetReplaceID() != "" { + content, ok := evt.Content.Parsed.(*event.MessageEventContent) + if !ok || content.MsgType == event.MsgNotice || content.RelatesTo.GetReplaceID() != "" { return nil } text := content.Body From 687717bd73cd6f163af3843ad0bf486826716810 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 15 Jul 2025 13:44:45 +0300 Subject: [PATCH 1230/1647] bridgev2: hardcode room v11 for new rooms Upcoming breaking changes in room v12 prevent safely using the default room version and security embargoes prevent fixing them ahead of time. --- bridgev2/portal.go | 1 + bridgev2/space.go | 3 ++- bridgev2/user.go | 5 +++-- requests.go | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 1fab94e6..1d8faa1a 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -4064,6 +4064,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo IsDirect: portal.RoomType == database.RoomTypeDM, PowerLevelOverride: powerLevels, BeeperLocalRoomID: portal.Bridge.Matrix.GenerateDeterministicRoomID(portal.PortalKey), + RoomVersion: event.RoomV11, } autoJoinInvites := portal.Bridge.Matrix.GetCapabilities().AutoJoinInvites if autoJoinInvites { diff --git a/bridgev2/space.go b/bridgev2/space.go index 11de9cfa..ccb74b26 100644 --- a/bridgev2/space.go +++ b/bridgev2/space.go @@ -164,7 +164,8 @@ func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) { ul.UserMXID: 50, }, }, - Invite: []id.UserID{ul.UserMXID}, + RoomVersion: event.RoomV11, + Invite: []id.UserID{ul.UserMXID}, } if autoJoin { req.BeeperInitialMembers = []id.UserID{ul.UserMXID} diff --git a/bridgev2/user.go b/bridgev2/user.go index e6a5dd99..350cecd1 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -225,8 +225,9 @@ func (user *User) GetManagementRoom(ctx context.Context) (id.RoomID, error) { user.MXID: 50, }, }, - Invite: []id.UserID{user.MXID}, - IsDirect: true, + RoomVersion: event.RoomV11, + Invite: []id.UserID{user.MXID}, + IsDirect: true, } if autoJoin { req.BeeperInitialMembers = []id.UserID{user.MXID} diff --git a/requests.go b/requests.go index 8363aeda..09e4b3cd 100644 --- a/requests.go +++ b/requests.go @@ -120,7 +120,7 @@ type ReqCreateRoom struct { InitialState []*event.Event `json:"initial_state,omitempty"` Preset string `json:"preset,omitempty"` IsDirect bool `json:"is_direct,omitempty"` - RoomVersion string `json:"room_version,omitempty"` + RoomVersion event.RoomVersion `json:"room_version,omitempty"` PowerLevelOverride *event.PowerLevelsEventContent `json:"power_level_content_override,omitempty"` From 1d37430204bdceb16dbd029215524f93810daba0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 15 Jul 2025 14:31:44 +0300 Subject: [PATCH 1231/1647] bridgev2/portal: block in queueEvent if buffer is full --- bridgev2/portal.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 1d8faa1a..9fa90d89 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -296,8 +296,16 @@ func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) EventHand default: zerolog.Ctx(ctx).Error(). Str("portal_id", string(portal.ID)). - Msg("Portal event channel is full") - return EventHandlingResultFailed + Msg("Portal event channel is full, queue will block") + for { + select { + case portal.events <- evt: + case <-time.After(5 * time.Second): + zerolog.Ctx(ctx).Error(). + Str("portal_id", string(portal.ID)). + Msg("Portal event channel is still full") + } + } } } } From 1ee29a47b6c601e760132867cc3d80944528c530 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 15 Jul 2025 13:37:07 +0200 Subject: [PATCH 1232/1647] bridgev2: add option to auto-reconnect after unknown error (#394) --- bridgev2/bridgeconfig/config.go | 3 + bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/bridgestate.go | 78 +++++++++++++++++++++- bridgev2/matrix/mxmain/example-config.yaml | 3 + bridgev2/userlogin.go | 16 +++++ 5 files changed, 98 insertions(+), 3 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index bd6f53c3..9bdee5fe 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -7,6 +7,8 @@ package bridgeconfig import ( + "time" + "go.mau.fi/util/dbutil" "go.mau.fi/zeroconfig" "gopkg.in/yaml.v3" @@ -66,6 +68,7 @@ type BridgeConfig struct { ResendBridgeInfo bool `yaml:"resend_bridge_info"` NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"` BridgeStatusNotices string `yaml:"bridge_status_notices"` + UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"` BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` BridgeNotices bool `yaml:"bridge_notices"` TagOnlyOnCreate bool `yaml:"tag_only_on_create"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index fa4b4493..b69a1fdb 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -32,6 +32,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "resend_bridge_info") helper.Copy(up.Bool, "bridge", "no_bridge_info_state_key") helper.Copy(up.Str|up.Null, "bridge", "bridge_status_notices") + helper.Copy(up.Str|up.Int|up.Null, "bridge", "unknown_error_auto_reconnect") helper.Copy(up.Bool, "bridge", "bridge_matrix_leave") helper.Copy(up.Bool, "bridge", "bridge_notices") helper.Copy(up.Bool, "bridge", "tag_only_on_create") diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index 81ec8160..f31d4e92 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -9,7 +9,9 @@ package bridgev2 import ( "context" "fmt" + "math/rand/v2" "runtime/debug" + "sync/atomic" "time" "github.com/rs/zerolog" @@ -26,6 +28,9 @@ type BridgeStateQueue struct { ch chan status.BridgeState bridge *Bridge login *UserLogin + + stopChan chan struct{} + stopReconnect atomic.Pointer[context.CancelFunc] } func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { @@ -47,16 +52,28 @@ func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { func (br *Bridge) NewBridgeStateQueue(login *UserLogin) *BridgeStateQueue { bsq := &BridgeStateQueue{ - ch: make(chan status.BridgeState, 10), - bridge: br, - login: login, + ch: make(chan status.BridgeState, 10), + stopChan: make(chan struct{}), + bridge: br, + login: login, } go bsq.loop() return bsq } func (bsq *BridgeStateQueue) Destroy() { + close(bsq.stopChan) close(bsq.ch) + bsq.StopUnknownErrorReconnect() +} + +func (bsq *BridgeStateQueue) StopUnknownErrorReconnect() { + if bsq == nil { + return + } + if cancelFn := bsq.stopReconnect.Swap(nil); cancelFn != nil { + (*cancelFn)() + } } func (bsq *BridgeStateQueue) loop() { @@ -117,6 +134,58 @@ func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.Bridge } } +func (bsq *BridgeStateQueue) unknownErrorReconnect(triggeredBy status.BridgeState) { + log := bsq.login.Log.With().Str("action", "unknown error reconnect").Logger() + ctx := log.WithContext(bsq.bridge.BackgroundCtx) + if !bsq.waitForUnknownErrorReconnect(ctx) { + return + } + prevUnsent := bsq.GetPrevUnsent() + prev := bsq.GetPrev() + if triggeredBy.Timestamp != prev.Timestamp { + log.Debug().Msg("Not reconnecting as a new bridge state was sent after the unknown error") + return + } else if len(bsq.ch) > 0 { + log.Warn().Msg("Not reconnecting as there are unsent bridge states") + return + } else if prevUnsent.StateEvent != status.StateUnknownError || prev.StateEvent != status.StateUnknownError { + log.Debug().Msg("Not reconnecting as the previous state was not an unknown error") + return + } + log.Info().Msg("Disconnecting and reconnecting login due to unknown error") + bsq.login.Disconnect() + log.Debug().Msg("Disconnection finished, recreating client and reconnecting") + err := bsq.login.recreateClient(ctx) + if err != nil { + log.Err(err).Msg("Failed to recreate client after unknown error") + return + } + bsq.login.Client.Connect(ctx) + log.Debug().Msg("Reconnection finished") +} + +func (bsq *BridgeStateQueue) waitForUnknownErrorReconnect(ctx context.Context) bool { + reconnectIn := bsq.bridge.Config.UnknownErrorAutoReconnect + // Don't allow too low values + if reconnectIn < 1*time.Minute { + return false + } + reconnectIn += time.Duration(rand.Int64N(int64(float64(reconnectIn)*0.4)) - int64(float64(reconnectIn)*0.2)) + cancelCtx, cancel := context.WithCancel(ctx) + defer cancel() + if oldCancel := bsq.stopReconnect.Swap(&cancel); oldCancel != nil { + (*oldCancel)() + } + select { + case <-time.After(reconnectIn): + return bsq.stopReconnect.CompareAndSwap(&cancel, nil) + case <-cancelCtx.Done(): + return false + case <-bsq.stopChan: + return false + } +} + func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) { if bsq.prevSent != nil && bsq.prevSent.ShouldDeduplicate(&state) { bsq.login.Log.Debug(). @@ -124,6 +193,9 @@ func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) Msg("Not sending bridge state as it's a duplicate") return } + if state.StateEvent == status.StateUnknownError { + go bsq.unknownErrorReconnect(state) + } ctx := bsq.login.Log.WithContext(context.Background()) bsq.sendNotice(ctx, state) diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index dad3f8a8..48e0d528 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -25,6 +25,9 @@ bridge: # These contain the same data that can be posted to an external HTTP server using homeserver -> status_endpoint. # Allowed values: none, errors, all bridge_status_notices: errors + # How long after an unknown error should the bridge attempt a full reconnect? + # Must be at least 1 minute. The bridge will add an extra ±20% jitter to this value. + unknown_error_auto_reconnect: null # Should leaving Matrix rooms be bridged as leaving groups on the remote network? bridge_matrix_leave: false diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 05574e71..203dc122 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -522,6 +522,7 @@ func (ul *UserLogin) DisconnectWithTimeout(timeout time.Duration) { } func (ul *UserLogin) disconnectInternal(timeout time.Duration) { + ul.BridgeState.StopUnknownErrorReconnect() disconnected := make(chan struct{}) go func() { ul.Client.Disconnect() @@ -544,3 +545,18 @@ func (ul *UserLogin) disconnectInternal(timeout time.Duration) { } } } + +func (ul *UserLogin) recreateClient(ctx context.Context) error { + oldClient := ul.Client + err := ul.Bridge.Network.LoadUserLogin(ctx, ul) + if err != nil { + return err + } + if ul.Client == oldClient { + zerolog.Ctx(ctx).Warn().Msg("LoadUserLogin didn't update client") + } else { + zerolog.Ctx(ctx).Debug().Msg("Recreated user login client") + } + ul.disconnectOnce = sync.Once{} + return nil +} From 095c63a97eb9a55b4fd3b271f134989037fff132 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 15 Jul 2025 14:57:52 +0300 Subject: [PATCH 1233/1647] bridgev2/portal: add missing return --- bridgev2/portal.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 9fa90d89..ab1f37f1 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -300,6 +300,7 @@ func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) EventHand for { select { case portal.events <- evt: + return EventHandlingResultQueued case <-time.After(5 * time.Second): zerolog.Ctx(ctx).Error(). Str("portal_id", string(portal.ID)). From fcc72dc54b5d50ac2c7e1cb03e0158a8d3fd8e03 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 16 Jul 2025 11:06:39 +0300 Subject: [PATCH 1234/1647] dependencies: update --- go.mod | 12 ++++++------ go.sum | 24 ++++++++++++------------ 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/go.mod b/go.mod index dcc6616c..59f29c0c 100644 --- a/go.mod +++ b/go.mod @@ -20,10 +20,10 @@ require ( github.com/yuin/goldmark v1.7.12 go.mau.fi/util v0.8.8 go.mau.fi/zeroconfig v0.1.3 - golang.org/x/crypto v0.39.0 - golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476 - golang.org/x/net v0.41.0 - golang.org/x/sync v0.15.0 + golang.org/x/crypto v0.40.0 + golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc + golang.org/x/net v0.42.0 + golang.org/x/sync v0.16.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) @@ -37,7 +37,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - golang.org/x/sys v0.33.0 // indirect - golang.org/x/text v0.26.0 // indirect + golang.org/x/sys v0.34.0 // indirect + golang.org/x/text v0.27.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 779e05db..9f48386e 100644 --- a/go.sum +++ b/go.sum @@ -57,22 +57,22 @@ go.mau.fi/util v0.8.8 h1:OnuEEc/sIJFhnq4kFggiImUpcmnmL/xpvQMRu5Fiy5c= go.mau.fi/util v0.8.8/go.mod h1:Y/kS3loxTEhy8Vill513EtPXr+CRDdae+Xj2BXXMy/c= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= -golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= -golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476 h1:bsqhLWFR6G6xiQcb+JoGqdKdRU6WzPWmK8E0jxTjzo4= -golang.org/x/exp v0.0.0-20250606033433-dcc06ee1d476/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= -golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw= -golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA= -golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= -golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= +golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= +golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc= +golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc= +golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= +golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= -golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= -golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= +golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= +golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= +golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From 81a807a6c9824922c6a3ffb9ad4ffa248171d899 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 16 Jul 2025 11:32:09 +0300 Subject: [PATCH 1235/1647] Bump version to v0.24.2 --- CHANGELOG.md | 23 +++++++++++++++++++++++ version.go | 2 +- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b2eefb3d..8e71381e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,26 @@ +## v0.24.2 (2025-07-16) + +* *(bridgev2)* Added support for return values from portal event handlers. Note + that the return value will always be "queued" unless the event buffer is + disabled. +* *(bridgev2)* Added support for [MSC4144] per-message profile passthrough in + relay mode. +* *(bridgev2)* Added option to auto-reconnect logins after a certain period if + they hit an `UNKNOWN_ERROR` state. +* *(bridgev2)* Added analytics for event handler panics. +* *(bridgev2)* Changed new room creation to hardcode room v11 to avoid v12 rooms + being created before proper support for them can be added. +* *(bridgev2)* Changed queuing events to block instead of dropping events if the + buffer is full. +* *(bridgev2)* Fixed assumption that replies to unknown messages are cross-room. +* *(id)* Fixed server name validation not including ports correctly + (thanks to [@krombel] in [#392]). +* *(federation)* Fixed base64 algorithm in signature generation. +* *(event)* Fixed [MSC4144] fallbacks not being removed from edits. + +[@krombel]: https://github.com/krombel +[#392]: https://github.com/mautrix/go/pull/392 + ## v0.24.1 (2025-06-16) * *(commands)* Added framework for using reactions as buttons that execute diff --git a/version.go b/version.go index 193205ee..6b8af5ef 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.24.1" +const Version = "v0.24.2" var GoModVersion = "" var Commit = "" From 7ffdbe8bfc97be3acd19ea5e6f86b46a72d343be Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 17 Jul 2025 16:54:53 +0300 Subject: [PATCH 1236/1647] bridgev2/disappear: add limit to getting messages from the db --- bridgev2/database/disappear.go | 6 +++--- bridgev2/disappear.go | 28 +++++++++++++++++++++++----- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/bridgev2/database/disappear.go b/bridgev2/database/disappear.go index 23db1448..4e6f5e0a 100644 --- a/bridgev2/database/disappear.go +++ b/bridgev2/database/disappear.go @@ -61,7 +61,7 @@ const ( getUpcomingDisappearingMessagesQuery = ` SELECT bridge_id, mx_room, mxid, type, timer, disappear_at FROM disappearing_message WHERE bridge_id = $1 AND disappear_at IS NOT NULL AND disappear_at < $2 - ORDER BY disappear_at + ORDER BY disappear_at LIMIT $3 ` deleteDisappearingMessageQuery = ` DELETE FROM disappearing_message WHERE bridge_id=$1 AND mxid=$2 @@ -77,8 +77,8 @@ func (dmq *DisappearingMessageQuery) StartAll(ctx context.Context, roomID id.Roo return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID) } -func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration) ([]*DisappearingMessage, error) { - return dmq.QueryMany(ctx, getUpcomingDisappearingMessagesQuery, dmq.BridgeID, time.Now().Add(duration).UnixNano()) +func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration, limit int) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, getUpcomingDisappearingMessagesQuery, dmq.BridgeID, time.Now().Add(duration).UnixNano(), limit) } func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.EventID) error { diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go index 1d063088..8305f84b 100644 --- a/bridgev2/disappear.go +++ b/bridgev2/disappear.go @@ -36,10 +36,21 @@ func (dl *DisappearLoop) Start() { log.Debug().Msg("Disappearing message loop starting") for { dl.NextCheck = time.Now().Add(DisappearCheckInterval) - messages, err := dl.br.DB.DisappearingMessage.GetUpcoming(ctx, DisappearCheckInterval) + const MessageLimit = 200 + messages, err := dl.br.DB.DisappearingMessage.GetUpcoming(ctx, DisappearCheckInterval, MessageLimit) if err != nil { log.Err(err).Msg("Failed to get upcoming disappearing messages") } else if len(messages) > 0 { + if len(messages) > MessageLimit/2 && messages[len(messages)-1].DisappearAt.Before(time.Now()) { + // If there are many messages, and they're all due immediately, + // process them synchronously and then check again. + dl.sleepAndDisappear(ctx, messages...) + log.Debug(). + Int("message_count", len(messages)). + Time("last_due", messages[len(messages)-1].DisappearAt). + Msg("Checking for disappearing messages again immediately") + continue + } go dl.sleepAndDisappear(ctx, messages...) } select { @@ -91,10 +102,17 @@ func (dl *DisappearLoop) Add(ctx context.Context, dm *database.DisappearingMessa func (dl *DisappearLoop) sleepAndDisappear(ctx context.Context, dms ...*database.DisappearingMessage) { for _, msg := range dms { - select { - case <-time.After(time.Until(msg.DisappearAt)): - case <-ctx.Done(): - return + timeUntilDisappear := time.Until(msg.DisappearAt) + if timeUntilDisappear <= 0 { + if ctx.Err() != nil { + return + } + } else { + select { + case <-time.After(timeUntilDisappear): + case <-ctx.Done(): + return + } } resp, err := dl.br.Bot.SendMessage(ctx, msg.RoomID, event.EventRedaction, &event.Content{ Parsed: &event.RedactionEventContent{ From 8efdbc029bbde2f9192be0834eaec9fba99d4352 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 17 Jul 2025 17:20:28 +0300 Subject: [PATCH 1237/1647] bridgev2/disappear: reduce disappear loop interval when there are lots of messages --- bridgev2/disappear.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go index 8305f84b..546118de 100644 --- a/bridgev2/disappear.go +++ b/bridgev2/disappear.go @@ -41,15 +41,22 @@ func (dl *DisappearLoop) Start() { if err != nil { log.Err(err).Msg("Failed to get upcoming disappearing messages") } else if len(messages) > 0 { - if len(messages) > MessageLimit/2 && messages[len(messages)-1].DisappearAt.Before(time.Now()) { + lastDisappearTime := messages[len(messages)-1].DisappearAt + if len(messages) > MessageLimit/2 && lastDisappearTime.Before(time.Now()) { // If there are many messages, and they're all due immediately, // process them synchronously and then check again. dl.sleepAndDisappear(ctx, messages...) log.Debug(). Int("message_count", len(messages)). - Time("last_due", messages[len(messages)-1].DisappearAt). + Time("last_due", lastDisappearTime). Msg("Checking for disappearing messages again immediately") continue + } else if len(messages) >= MessageLimit && lastDisappearTime.Add(5*time.Second).Before(dl.NextCheck) { + log.Debug(). + Int("message_count", len(messages)). + Time("last_due", lastDisappearTime). + Msg("Using lower disappearing message check interval as the limit was reached, but the last message isn't due yet") + dl.NextCheck = lastDisappearTime.Add(5 * time.Second) } go dl.sleepAndDisappear(ctx, messages...) } From 5a9e20e4511dbea327ae7846d27cadde1a976908 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 17 Jul 2025 17:27:08 +0300 Subject: [PATCH 1238/1647] bridgev2/disappear: always delete synchronously if limit is reached --- bridgev2/disappear.go | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go index 546118de..58ff9bf6 100644 --- a/bridgev2/disappear.go +++ b/bridgev2/disappear.go @@ -42,21 +42,16 @@ func (dl *DisappearLoop) Start() { log.Err(err).Msg("Failed to get upcoming disappearing messages") } else if len(messages) > 0 { lastDisappearTime := messages[len(messages)-1].DisappearAt - if len(messages) > MessageLimit/2 && lastDisappearTime.Before(time.Now()) { + if len(messages) >= MessageLimit { + log.Debug(). + Int("message_count", len(messages)). + Time("last_due", lastDisappearTime). + Msg("Deleting disappearing messages synchronously and checking again immediately") + dl.NextCheck = lastDisappearTime // If there are many messages, and they're all due immediately, // process them synchronously and then check again. dl.sleepAndDisappear(ctx, messages...) - log.Debug(). - Int("message_count", len(messages)). - Time("last_due", lastDisappearTime). - Msg("Checking for disappearing messages again immediately") continue - } else if len(messages) >= MessageLimit && lastDisappearTime.Add(5*time.Second).Before(dl.NextCheck) { - log.Debug(). - Int("message_count", len(messages)). - Time("last_due", lastDisappearTime). - Msg("Using lower disappearing message check interval as the limit was reached, but the last message isn't due yet") - dl.NextCheck = lastDisappearTime.Add(5 * time.Second) } go dl.sleepAndDisappear(ctx, messages...) } From 0508f02a9e1ce38e686e62d78602d417689d0b13 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 17 Jul 2025 17:35:03 +0300 Subject: [PATCH 1239/1647] bridgev2/disappear: make next check field atomic --- bridgev2/disappear.go | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go index 58ff9bf6..f072c01f 100644 --- a/bridgev2/disappear.go +++ b/bridgev2/disappear.go @@ -21,7 +21,7 @@ import ( type DisappearLoop struct { br *Bridge - NextCheck time.Time + nextCheck atomic.Pointer[time.Time] stop atomic.Pointer[context.CancelFunc] } @@ -35,28 +35,30 @@ func (dl *DisappearLoop) Start() { } log.Debug().Msg("Disappearing message loop starting") for { - dl.NextCheck = time.Now().Add(DisappearCheckInterval) + nextCheck := time.Now().Add(DisappearCheckInterval) + dl.nextCheck.Store(&nextCheck) const MessageLimit = 200 messages, err := dl.br.DB.DisappearingMessage.GetUpcoming(ctx, DisappearCheckInterval, MessageLimit) if err != nil { log.Err(err).Msg("Failed to get upcoming disappearing messages") } else if len(messages) > 0 { - lastDisappearTime := messages[len(messages)-1].DisappearAt if len(messages) >= MessageLimit { + lastDisappearTime := messages[len(messages)-1].DisappearAt log.Debug(). Int("message_count", len(messages)). Time("last_due", lastDisappearTime). Msg("Deleting disappearing messages synchronously and checking again immediately") - dl.NextCheck = lastDisappearTime - // If there are many messages, and they're all due immediately, - // process them synchronously and then check again. + // Store the expected next check time to avoid Add spawning unnecessary goroutines. + // This can be in the past, in which case Add will put everything in the db, which is also fine. + dl.nextCheck.Store(&lastDisappearTime) + // If there are many messages, process them synchronously and then check again. dl.sleepAndDisappear(ctx, messages...) continue } go dl.sleepAndDisappear(ctx, messages...) } select { - case <-time.After(time.Until(dl.NextCheck)): + case <-time.After(time.Until(dl.GetNextCheck())): case <-ctx.Done(): log.Debug().Msg("Disappearing message loop stopping") return @@ -64,6 +66,17 @@ func (dl *DisappearLoop) Start() { } } +func (dl *DisappearLoop) GetNextCheck() time.Time { + if dl == nil { + return time.Time{} + } + nextCheck := dl.nextCheck.Load() + if nextCheck == nil { + return time.Time{} + } + return *nextCheck +} + func (dl *DisappearLoop) Stop() { if dl == nil { return @@ -80,7 +93,7 @@ func (dl *DisappearLoop) StartAll(ctx context.Context, roomID id.RoomID) { return } startedMessages = slices.DeleteFunc(startedMessages, func(dm *database.DisappearingMessage) bool { - return dm.DisappearAt.After(dl.NextCheck) + return dm.DisappearAt.After(dl.GetNextCheck()) }) slices.SortFunc(startedMessages, func(a, b *database.DisappearingMessage) int { return a.DisappearAt.Compare(b.DisappearAt) @@ -97,7 +110,7 @@ func (dl *DisappearLoop) Add(ctx context.Context, dm *database.DisappearingMessa Stringer("event_id", dm.EventID). Msg("Failed to save disappearing message") } - if !dm.DisappearAt.IsZero() && dm.DisappearAt.Before(dl.NextCheck) { + if !dm.DisappearAt.IsZero() && dm.DisappearAt.Before(dl.GetNextCheck()) { go dl.sleepAndDisappear(zerolog.Ctx(ctx).WithContext(dl.br.BackgroundCtx), dm) } } From 90a7dc3c75196da529f281cde5ff43282c1dd43c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 18 Jul 2025 16:05:04 +0300 Subject: [PATCH 1240/1647] bridgev2/portal: ignore delete for me in multi-user portals --- bridgev2/portal.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index ab1f37f1..a36524d8 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2784,7 +2784,14 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use onlyForMeProvider, ok := evt.(RemoteDeleteOnlyForMe) onlyForMe := ok && onlyForMeProvider.DeleteOnlyForMe() if onlyForMe && portal.Receiver == "" { - // TODO check if there are other user logins before deleting + logins, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to check if portal has other logins") + return EventHandlingResultFailed + } else if len(logins) > 1 { + log.Debug().Msg("Ignoring delete for me event in portal with multiple logins") + return EventHandlingResultIgnored + } } intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventMessageRemove) From c7263bab40c0a183197fe11c841e00761ec4ee6d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 18 Jul 2025 17:37:45 +0300 Subject: [PATCH 1241/1647] bridgev2/portal: add support for following tombstones --- bridgev2/database/userportal.go | 7 ++ bridgev2/matrix/connector.go | 1 + bridgev2/matrix/matrix.go | 2 +- bridgev2/portal.go | 217 ++++++++++++++++++++++++++++---- 4 files changed, 199 insertions(+), 28 deletions(-) diff --git a/bridgev2/database/userportal.go b/bridgev2/database/userportal.go index 278b236b..e928a4c7 100644 --- a/bridgev2/database/userportal.go +++ b/bridgev2/database/userportal.go @@ -67,6 +67,9 @@ const ( markLoginAsPreferredQuery = ` UPDATE user_portal SET preferred=(login_id=$3) WHERE bridge_id=$1 AND user_mxid=$2 AND portal_id=$4 AND portal_receiver=$5 ` + markAllNotInSpaceQuery = ` + UPDATE user_portal SET in_space=false WHERE bridge_id=$1 AND portal_id=$2 AND portal_receiver=$3 + ` deleteUserPortalQuery = ` DELETE FROM user_portal WHERE bridge_id=$1 AND user_mxid=$2 AND login_id=$3 AND portal_id=$4 AND portal_receiver=$5 ` @@ -110,6 +113,10 @@ func (upq *UserPortalQuery) MarkAsPreferred(ctx context.Context, login *UserLogi return upq.Exec(ctx, markLoginAsPreferredQuery, upq.BridgeID, login.UserMXID, login.ID, portal.ID, portal.Receiver) } +func (upq *UserPortalQuery) MarkAllNotInSpace(ctx context.Context, portal networkid.PortalKey) error { + return upq.Exec(ctx, markAllNotInSpaceQuery, upq.BridgeID, portal.ID, portal.Receiver) +} + func (upq *UserPortalQuery) Delete(ctx context.Context, up *UserPortal) error { return upq.Exec(ctx, deleteUserPortalQuery, up.BridgeID, up.UserMXID, up.LoginID, up.Portal.ID, up.Portal.Receiver) } diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 7af2d128..7075a1aa 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -145,6 +145,7 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.EventProcessor.On(event.StateRoomName, br.handleRoomEvent) br.EventProcessor.On(event.StateRoomAvatar, br.handleRoomEvent) br.EventProcessor.On(event.StateTopic, br.handleRoomEvent) + br.EventProcessor.On(event.StateTombstone, br.handleRoomEvent) br.EventProcessor.On(event.EphemeralEventReceipt, br.handleEphemeralEvent) br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent) br.Bot = br.AS.BotIntent() diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index 84e85d24..fed9d37a 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -169,7 +169,7 @@ func (br *Connector) shouldIgnoreEventFromUser(userID id.UserID) bool { } func (br *Connector) shouldIgnoreEvent(evt *event.Event) bool { - if br.shouldIgnoreEventFromUser(evt.Sender) { + if br.shouldIgnoreEventFromUser(evt.Sender) && evt.Type != event.StateTombstone { return true } dpVal, ok := evt.Content.Raw[appservice.DoublePuppetKey] diff --git a/bridgev2/portal.go b/bridgev2/portal.go index a36524d8..82e76318 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -580,6 +580,10 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * return EventHandlingResultIgnored } } + if evt.Type == event.StateTombstone { + // Tombstones aren't bridged so they don't need a login + return portal.handleMatrixTombstone(ctx, evt) + } login, _, err := portal.FindPreferredLogin(ctx, sender, true) if err != nil { log.Err(err).Msg("Failed to get user login to handle Matrix event") @@ -1716,6 +1720,158 @@ func (portal *Portal) handleMatrixPowerLevels( return EventHandlingResultSuccess.WithMSS() } +func (portal *Portal) handleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult { + if evt.StateKey == nil || *evt.StateKey != "" || portal.MXID != evt.RoomID { + return EventHandlingResultIgnored + } + log := *zerolog.Ctx(ctx) + sentByBridge := evt.Sender == portal.Bridge.Bot.GetMXID() || portal.Bridge.IsGhostMXID(evt.Sender) + var senderUser *User + var err error + if !sentByBridge { + senderUser, err = portal.Bridge.GetUserByMXID(ctx, evt.Sender) + if err != nil { + log.Err(err).Msg("Failed to get tombstone sender user") + return EventHandlingResultFailed + } + } + content, ok := evt.Content.Parsed.(*event.TombstoneEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + log = log.With(). + Stringer("replacement_room", content.ReplacementRoom). + Logger() + if content.ReplacementRoom == "" { + log.Info().Msg("Received tombstone with no replacement room, cleaning up portal") + err := portal.RemoveMXID(ctx) + if err != nil { + log.Err(err).Msg("Failed to remove portal MXID") + return EventHandlingResultFailed.WithMSSError(err) + } + err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true) + if err != nil { + log.Err(err).Msg("Failed to clean up Matrix room") + return EventHandlingResultFailed + } + return EventHandlingResultSuccess + } + existingMemberEvt, err := portal.Bridge.Matrix.GetMemberInfo(ctx, content.ReplacementRoom, portal.Bridge.Bot.GetMXID()) + if err != nil { + log.Err(err).Msg("Failed to get member info of bot in replacement room") + return EventHandlingResultFailed + } + leaveOnError := func() { + if existingMemberEvt != nil && existingMemberEvt.Membership == event.MembershipJoin { + return + } + log.Debug().Msg("Leaving replacement room with bot after tombstone validation failed") + _, err = portal.Bridge.Bot.SendState( + ctx, + content.ReplacementRoom, + event.StateMember, + portal.Bridge.Bot.GetMXID().String(), + &event.Content{ + Parsed: &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: fmt.Sprintf("Failed to validate tombstone sent by %s from %s", evt.Sender, evt.RoomID), + }, + }, + time.Time{}, + ) + if err != nil { + log.Err(err).Msg("Failed to leave replacement room after tombstone validation failed") + } + } + err = portal.Bridge.Bot.EnsureJoined(ctx, content.ReplacementRoom) + if err != nil { + log.Err(err).Msg("Failed to join replacement room from tombstone") + return EventHandlingResultFailed + } + if !sentByBridge && !senderUser.Permissions.Admin { + powers, err := portal.Bridge.Matrix.GetPowerLevels(ctx, content.ReplacementRoom) + if err != nil { + log.Err(err).Msg("Failed to get power levels in replacement room") + leaveOnError() + return EventHandlingResultFailed + } + if powers.GetUserLevel(evt.Sender) < powers.Invite() { + log.Warn().Msg("Tombstone sender doesn't have enough power to invite the bot to the replacement room") + leaveOnError() + return EventHandlingResultIgnored + } + } + + portal.Bridge.cacheLock.Lock() + if _, alreadyExists := portal.Bridge.portalsByMXID[content.ReplacementRoom]; alreadyExists { + log.Warn().Msg("Replacement room is already a portal, ignoring tombstone") + portal.Bridge.cacheLock.Unlock() + return EventHandlingResultIgnored + } + delete(portal.Bridge.portalsByMXID, portal.MXID) + portal.MXID = content.ReplacementRoom + portal.Bridge.portalsByMXID[portal.MXID] = portal + portal.NameSet = false + portal.AvatarSet = false + portal.TopicSet = false + portal.InSpace = false + portal.CapState = database.CapabilityState{} + portal.Bridge.cacheLock.Unlock() + + err = portal.Save(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal after tombstone") + return EventHandlingResultFailed + } + log.Info().Msg("Successfully followed tombstone and updated portal MXID") + err = portal.Bridge.DB.UserPortal.MarkAllNotInSpace(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to update in_space flag for user portals after tombstone") + } + go portal.addToUserSpaces(ctx) + go portal.updateInfoAfterTombstone(ctx, senderUser) + go func() { + err = portal.Bridge.Bot.DeleteRoom(ctx, evt.RoomID, true) + if err != nil { + log.Err(err).Msg("Failed to clean up Matrix room after following tombstone") + } + }() + return EventHandlingResultSuccess +} + +func (portal *Portal) updateInfoAfterTombstone(ctx context.Context, senderUser *User) { + log := zerolog.Ctx(ctx) + logins, err := portal.Bridge.GetUserLoginsInPortal(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to get user logins in portal to sync info") + return + } + var preferredLogin *UserLogin + for _, login := range logins { + if !login.Client.IsLoggedIn() { + continue + } else if preferredLogin == nil { + preferredLogin = login + } else if senderUser != nil && login.User == senderUser { + preferredLogin = login + } + } + if preferredLogin == nil { + log.Warn().Msg("No logins found to sync info") + return + } + info, err := preferredLogin.Client.GetChatInfo(ctx, portal) + if err != nil { + log.Err(err).Msg("Failed to get chat info") + return + } + log.Info(). + Str("info_source_login", string(preferredLogin.ID)). + Msg("Fetched info to update portal after tombstone") + portal.UpdateInfo(ctx, info, preferredLogin, nil, time.Time{}) +} + func (portal *Portal) handleMatrixRedaction( ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, ) EventHandlingResult { @@ -4203,39 +4359,46 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } } } - if portal.Parent == nil { - if portal.Receiver != "" { - login := portal.Bridge.GetCachedUserLoginByID(portal.Receiver) - if login != nil { - up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey) - if err != nil { - log.Err(err).Msg("Failed to get user portal to add portal to spaces") - } else { - login.inPortalCache.Remove(portal.PortalKey) - go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues()) - } - } - } else { - userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) - if err != nil { - log.Err(err).Msg("Failed to get user logins in portal to add portal to spaces") - } else { - for _, up := range userPortals { - login := portal.Bridge.GetCachedUserLoginByID(up.LoginID) - if login != nil { - login.inPortalCache.Remove(portal.PortalKey) - go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues()) - } - } - } - } - } + portal.addToUserSpaces(ctx) if portal.Bridge.Config.Backfill.Enabled && portal.RoomType != database.RoomTypeSpace && !portal.Bridge.Background { portal.doForwardBackfill(ctx, source, nil, backfillBundle) } return nil } +func (portal *Portal) addToUserSpaces(ctx context.Context) { + if portal.Parent == nil { + return + } + log := zerolog.Ctx(ctx) + withoutCancelCtx := log.WithContext(portal.Bridge.BackgroundCtx) + if portal.Receiver != "" { + login := portal.Bridge.GetCachedUserLoginByID(portal.Receiver) + if login != nil { + up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to get user portal to add portal to spaces") + } else { + login.inPortalCache.Remove(portal.PortalKey) + go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues()) + } + } + } else { + userPortals, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to get user logins in portal to add portal to spaces") + } else { + for _, up := range userPortals { + login := portal.Bridge.GetCachedUserLoginByID(up.LoginID) + if login != nil { + login.inPortalCache.Remove(portal.PortalKey) + go login.tryAddPortalToSpace(withoutCancelCtx, portal, up.CopyWithoutValues()) + } + } + } + } +} + func (portal *Portal) Delete(ctx context.Context) error { portal.removeInPortalCache(ctx) err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) From 9a170d26695315d363dc74ec6db17fda20dbdaea Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 18 Jul 2025 17:55:27 +0300 Subject: [PATCH 1242/1647] bridgev2,appservice: add via to EnsureJoined and use it for tombstone handling --- appservice/intent.go | 11 +++++++++-- bridgev2/matrix/intent.go | 8 ++++++-- bridgev2/matrixinterface.go | 6 +++++- bridgev2/portal.go | 6 +++++- 4 files changed, 25 insertions(+), 6 deletions(-) diff --git a/appservice/intent.go b/appservice/intent.go index d6cda137..194057f7 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -86,6 +86,7 @@ func (intent *IntentAPI) EnsureRegistered(ctx context.Context) error { type EnsureJoinedParams struct { IgnoreCache bool BotOverride *mautrix.Client + Via []string } func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, extra ...EnsureJoinedParams) error { @@ -99,11 +100,17 @@ func (intent *IntentAPI) EnsureJoined(ctx context.Context, roomID id.RoomID, ext return nil } - if err := intent.EnsureRegistered(ctx); err != nil { + err := intent.EnsureRegistered(ctx) + if err != nil { return fmt.Errorf("failed to ensure joined: %w", err) } - resp, err := intent.JoinRoomByID(ctx, roomID) + var resp *mautrix.RespJoinRoom + if len(params.Via) > 0 { + resp, err = intent.JoinRoom(ctx, roomID.String(), &mautrix.ReqJoinRoom{Via: params.Via}) + } else { + resp, err = intent.JoinRoomByID(ctx, roomID) + } if err != nil { bot := intent.bot if params.BotOverride != nil { diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 2088d5b1..4a337e53 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -490,8 +490,12 @@ func (as *ASIntent) IsDoublePuppet() bool { return as.Matrix.IsDoublePuppet() } -func (as *ASIntent) EnsureJoined(ctx context.Context, roomID id.RoomID) error { - err := as.Matrix.EnsureJoined(ctx, roomID) +func (as *ASIntent) EnsureJoined(ctx context.Context, roomID id.RoomID, extra ...bridgev2.EnsureJoinedParams) error { + var params bridgev2.EnsureJoinedParams + if len(extra) > 0 { + params = extra[0] + } + err := as.Matrix.EnsureJoined(ctx, roomID, appservice.EnsureJoinedParams{Via: params.Via}) if err != nil { return err } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index ae1b99d7..c1bd69b8 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -144,6 +144,10 @@ func (ce CallbackError) Unwrap() error { return ce.Wrapped } +type EnsureJoinedParams struct { + Via []string +} + type MatrixAPI interface { GetMXID() id.UserID IsDoublePuppet() bool @@ -164,7 +168,7 @@ type MatrixAPI interface { CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error - EnsureJoined(ctx context.Context, roomID id.RoomID) error + EnsureJoined(ctx context.Context, roomID id.RoomID, params ...EnsureJoinedParams) error EnsureInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) error TagRoom(ctx context.Context, roomID id.RoomID, tag event.RoomTag, isTagged bool) error diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 82e76318..b9ea3385 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1784,7 +1784,11 @@ func (portal *Portal) handleMatrixTombstone(ctx context.Context, evt *event.Even log.Err(err).Msg("Failed to leave replacement room after tombstone validation failed") } } - err = portal.Bridge.Bot.EnsureJoined(ctx, content.ReplacementRoom) + var via []string + if senderHS := evt.Sender.Homeserver(); senderHS != "" { + via = []string{senderHS} + } + err = portal.Bridge.Bot.EnsureJoined(ctx, content.ReplacementRoom, EnsureJoinedParams{Via: via}) if err != nil { log.Err(err).Msg("Failed to join replacement room from tombstone") return EventHandlingResultFailed From 237ce1c64c7f079d38fe6a630770181ea8d4842f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 11 Jul 2025 12:55:34 +0300 Subject: [PATCH 1243/1647] client: remove redundant state store update in room create --- client.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/client.go b/client.go index 6f746015..7a83619f 100644 --- a/client.go +++ b/client.go @@ -1378,9 +1378,6 @@ func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *Re Msg("Failed to update membership in state store after creating room") } } - for _, evt := range req.InitialState { - cli.updateStoreWithOutgoingEvent(ctx, resp.RoomID, evt.Type, evt.GetStateKey(), &evt.Content) - } } return } From 0b62253d3b48ec0ea3540af0568d74fd889f9ecc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 11 Jul 2025 12:55:44 +0300 Subject: [PATCH 1244/1647] all: add support for creator power --- appservice/intent.go | 27 +++++- bridgev2/matrix/connector.go | 8 ++ bridgev2/matrixinterface.go | 1 + bridgev2/portal.go | 8 ++ client.go | 18 +++- event/powerlevels.go | 30 ++++++- event/state.go | 14 ++++ sqlstatestore/statestore.go | 116 +++++++++++--------------- sqlstatestore/v00-latest-revision.sql | 3 +- sqlstatestore/v08-create-event.sql | 2 + statestore.go | 29 ++++++- 11 files changed, 178 insertions(+), 78 deletions(-) create mode 100644 sqlstatestore/v08-create-event.sql diff --git a/appservice/intent.go b/appservice/intent.go index 194057f7..a1245d74 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -375,6 +375,24 @@ func (intent *IntentAPI) Member(ctx context.Context, roomID id.RoomID, userID id return member } +func (intent *IntentAPI) FillPowerLevelCreateEvent(ctx context.Context, roomID id.RoomID, pl *event.PowerLevelsEventContent) error { + if pl.CreateEvent != nil { + return nil + } + var err error + pl.CreateEvent, err = intent.StateStore.GetCreate(ctx, roomID) + if err != nil { + return fmt.Errorf("failed to get create event from cache: %w", err) + } else if pl.CreateEvent != nil { + return nil + } + pl.CreateEvent, err = intent.FullStateEvent(ctx, roomID, event.StateCreate, "") + if err != nil { + return fmt.Errorf("failed to get create event from server: %w", err) + } + return nil +} + func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) { pl, err = intent.as.StateStore.GetPowerLevels(ctx, roomID) if err != nil { @@ -384,6 +402,12 @@ func (intent *IntentAPI) PowerLevels(ctx context.Context, roomID id.RoomID) (pl if pl == nil { pl = &event.PowerLevelsEventContent{} err = intent.StateEvent(ctx, roomID, event.StatePowerLevels, "", pl) + if err != nil { + return + } + } + if pl.CreateEvent == nil { + pl.CreateEvent, err = intent.FullStateEvent(ctx, roomID, event.StateCreate, "") } return } @@ -398,8 +422,7 @@ func (intent *IntentAPI) SetPowerLevel(ctx context.Context, roomID id.RoomID, us return nil, err } - if pl.GetUserLevel(userID) != level { - pl.SetUserLevel(userID, level) + if pl.EnsureUserLevelAs(intent.UserID, userID, level) { return intent.SendStateEvent(ctx, roomID, event.StatePowerLevels, "", &pl) } return nil, nil diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 7075a1aa..978f666f 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -534,6 +534,14 @@ func (br *Connector) GetPowerLevels(ctx context.Context, roomID id.RoomID) (*eve return br.Bot.PowerLevels(ctx, roomID) } +func (br *Connector) GetCreateEvent(ctx context.Context, roomID id.RoomID) (*event.Event, error) { + createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID) + if err != nil || createEvt != nil { + return createEvt, err + } + return br.Bot.FullStateEvent(ctx, roomID, event.StateCreate, "") +} + func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { fetched, err := br.Bot.StateStore.HasFetchedMembers(ctx, roomID) if err != nil { diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index c1bd69b8..5d0cb014 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -47,6 +47,7 @@ type MatrixConnector interface { GenerateContentURI(ctx context.Context, mediaID networkid.MediaID) (id.ContentURIString, error) GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error) + GetCreateEvent(ctx context.Context, roomID id.RoomID) (*event.Event, error) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) GetMemberInfo(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index b9ea3385..c91523ef 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1654,6 +1654,13 @@ func (portal *Portal) handleMatrixPowerLevels( log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) } + if content.CreateEvent == nil { + var err error + content.CreateEvent, err = portal.Bridge.Matrix.GetCreateEvent(ctx, portal.MXID) + if err != nil { + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("failed to get create event for power levels: %w", err)) + } + } api, ok := sender.Client.(PowerLevelHandlingNetworkAPI) if !ok { return EventHandlingResultIgnored.WithMSSError(ErrPowerLevelsNotSupported) @@ -1662,6 +1669,7 @@ func (portal *Portal) handleMatrixPowerLevels( if evt.Unsigned.PrevContent != nil { _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) prevContent, _ = evt.Unsigned.PrevContent.Parsed.(*event.PowerLevelsEventContent) + prevContent.CreateEvent = content.CreateEvent } plChange := &MatrixPowerLevelChange{ diff --git a/client.go b/client.go index 7a83619f..886dbb63 100644 --- a/client.go +++ b/client.go @@ -1548,12 +1548,15 @@ func (cli *Client) FullStateEvent(ctx context.Context, roomID id.RoomID, eventTy "format": "event", }) _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &evt) - if err == nil && cli.StateStore != nil { - UpdateStateStore(ctx, cli.StateStore, evt) - } if evt != nil { evt.Type.Class = event.StateEventType _ = evt.Content.ParseRaw(evt.Type) + if evt.RoomID == "" { + evt.RoomID = roomID + } + } + if err == nil && cli.StateStore != nil { + UpdateStateStore(ctx, cli.StateStore, evt) } return } @@ -1606,12 +1609,21 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt ResponseJSON: &stateMap, Handler: parseRoomStateArray, }) + if stateMap != nil { + pls, ok := stateMap[event.StatePowerLevels][""] + if ok { + pls.Content.AsPowerLevels().CreateEvent = stateMap[event.StateCreate][""] + } + } if err == nil && cli.StateStore != nil { for evtType, evts := range stateMap { if evtType == event.StateMember { continue } for _, evt := range evts { + if evt.RoomID == "" { + evt.RoomID = roomID + } UpdateStateStore(ctx, cli.StateStore, evt) } } diff --git a/event/powerlevels.go b/event/powerlevels.go index 2f4d4573..79dbd1f3 100644 --- a/event/powerlevels.go +++ b/event/powerlevels.go @@ -7,6 +7,8 @@ package event import ( + "math" + "slices" "sync" "go.mau.fi/util/ptr" @@ -34,6 +36,10 @@ type PowerLevelsEventContent struct { KickPtr *int `json:"kick,omitempty"` BanPtr *int `json:"ban,omitempty"` RedactPtr *int `json:"redact,omitempty"` + + // This is not a part of power levels, it's added by mautrix-go internally in certain places + // in order to detect creator power accurately. + CreateEvent *Event `json:"-,omitempty"` } func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { @@ -53,6 +59,8 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { KickPtr: ptr.Clone(pl.KickPtr), BanPtr: ptr.Clone(pl.BanPtr), RedactPtr: ptr.Clone(pl.RedactPtr), + + CreateEvent: pl.CreateEvent, } } @@ -112,6 +120,9 @@ func (pl *PowerLevelsEventContent) StateDefault() int { } func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int { + if pl.isCreator(userID) { + return math.MaxInt + } pl.usersLock.RLock() defer pl.usersLock.RUnlock() level, ok := pl.Users[userID] @@ -138,9 +149,24 @@ func (pl *PowerLevelsEventContent) EnsureUserLevel(target id.UserID, level int) return pl.EnsureUserLevelAs("", target, level) } +func (pl *PowerLevelsEventContent) createContent() *CreateEventContent { + if pl.CreateEvent == nil { + return &CreateEventContent{} + } + return pl.CreateEvent.Content.AsCreate() +} + +func (pl *PowerLevelsEventContent) isCreator(userID id.UserID) bool { + cc := pl.createContent() + return cc.SupportsCreatorPower() && (userID == pl.CreateEvent.Sender || slices.Contains(cc.AdditionalCreators, userID)) +} + func (pl *PowerLevelsEventContent) EnsureUserLevelAs(actor, target id.UserID, level int) bool { + if pl.isCreator(target) { + return false + } existingLevel := pl.GetUserLevel(target) - if actor != "" { + if actor != "" && !pl.isCreator(actor) { actorLevel := pl.GetUserLevel(actor) if actorLevel <= existingLevel || actorLevel < level { return false @@ -185,7 +211,7 @@ func (pl *PowerLevelsEventContent) EnsureEventLevel(eventType Type, level int) b func (pl *PowerLevelsEventContent) EnsureEventLevelAs(actor id.UserID, eventType Type, level int) bool { existingLevel := pl.GetEventLevel(eventType) - if actor != "" { + if actor != "" && !pl.isCreator(actor) { actorLevel := pl.GetUserLevel(actor) if existingLevel > actorLevel || level > actorLevel { return false diff --git a/event/state.go b/event/state.go index 028691e1..ff6dabfa 100644 --- a/event/state.go +++ b/event/state.go @@ -88,6 +88,7 @@ const ( RoomV9 RoomVersion = "9" RoomV10 RoomVersion = "10" RoomV11 RoomVersion = "11" + RoomV12 RoomVersion = "12" ) // CreateEventContent represents the content of a m.room.create state event. @@ -98,10 +99,23 @@ type CreateEventContent struct { RoomVersion RoomVersion `json:"room_version,omitempty"` Predecessor *Predecessor `json:"predecessor,omitempty"` + // Room v12+ only + AdditionalCreators []id.UserID `json:"additional_creators,omitempty"` + // Deprecated: use the event sender instead Creator id.UserID `json:"creator,omitempty"` } +func (cec *CreateEventContent) SupportsCreatorPower() bool { + switch cec.RoomVersion { + case "", RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11: + return false + default: + // Assume anything except known old versions supports creator power. + return true + } +} + // JoinRule specifies how open a room is to new members. // https://spec.matrix.org/v1.2/client-server-api/#mroomjoin_rules type JoinRule string diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index 4a220a2b..f9a7e421 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -379,89 +379,67 @@ func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID } func (store *SQLStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) { + levels = &event.PowerLevelsEventContent{} err = store. - QueryRow(ctx, "SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID). - Scan(&dbutil.JSON{Data: &levels}) + QueryRow(ctx, "SELECT power_levels, create_event FROM mx_room_state WHERE room_id=$1", roomID). + Scan(&dbutil.JSON{Data: &levels}, &dbutil.JSON{Data: &levels.CreateEvent}) if errors.Is(err, sql.ErrNoRows) { - err = nil + return nil, nil + } else if err != nil { + return nil, err + } + if levels.CreateEvent != nil { + err = levels.CreateEvent.Content.ParseRaw(event.StateCreate) } return } func (store *SQLStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) { - if store.Dialect == dbutil.Postgres { - var powerLevel int - err := store. - QueryRow(ctx, ` - SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0) - FROM mx_room_state WHERE room_id=$1 - `, roomID, userID). - Scan(&powerLevel) - return powerLevel, err - } else { - levels, err := store.GetPowerLevels(ctx, roomID) - if err != nil { - return 0, err - } - return levels.GetUserLevel(userID), nil + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return 0, err } + return levels.GetUserLevel(userID), nil } func (store *SQLStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) { - if store.Dialect == dbutil.Postgres { - defaultType := "events_default" - defaultValue := 0 - if eventType.IsState() { - defaultType = "state_default" - defaultValue = 50 - } - var powerLevel int - err := store. - QueryRow(ctx, ` - SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4) - FROM mx_room_state WHERE room_id=$1 - `, roomID, eventType.Type, defaultType, defaultValue). - Scan(&powerLevel) - if errors.Is(err, sql.ErrNoRows) { - err = nil - powerLevel = defaultValue - } - return powerLevel, err - } else { - levels, err := store.GetPowerLevels(ctx, roomID) - if err != nil { - return 0, err - } - return levels.GetEventLevel(eventType), nil + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return 0, err } + return levels.GetEventLevel(eventType), nil } func (store *SQLStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) { - if store.Dialect == dbutil.Postgres { - defaultType := "events_default" - defaultValue := 0 - if eventType.IsState() { - defaultType = "state_default" - defaultValue = 50 - } - var hasPower bool - err := store. - QueryRow(ctx, `SELECT - COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0) - >= - COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5) - FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue). - Scan(&hasPower) - if errors.Is(err, sql.ErrNoRows) { - err = nil - hasPower = defaultValue == 0 - } - return hasPower, err - } else { - levels, err := store.GetPowerLevels(ctx, roomID) - if err != nil { - return false, err - } - return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil + levels, err := store.GetPowerLevels(ctx, roomID) + if err != nil { + return false, err } + return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil +} + +func (store *SQLStateStore) SetCreate(ctx context.Context, evt *event.Event) error { + if evt.Type != event.StateCreate { + return fmt.Errorf("invalid event type for create event: %s", evt.Type) + } + _, err := store.Exec(ctx, ` + INSERT INTO mx_room_state (room_id, create_event) VALUES ($1, $2) + ON CONFLICT (room_id) DO UPDATE SET create_event=excluded.create_event + `, evt.RoomID, dbutil.JSON{Data: evt}) + return err +} + +func (store *SQLStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (evt *event.Event, err error) { + err = store. + QueryRow(ctx, "SELECT create_event FROM mx_room_state WHERE room_id=$1", roomID). + Scan(&dbutil.JSON{Data: &evt}) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } else if err != nil { + return nil, err + } + if evt != nil { + err = evt.Content.ParseRaw(event.StateCreate) + } + return } diff --git a/sqlstatestore/v00-latest-revision.sql b/sqlstatestore/v00-latest-revision.sql index a58cc56a..132ed1ab 100644 --- a/sqlstatestore/v00-latest-revision.sql +++ b/sqlstatestore/v00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v7 (compatible with v3+): Latest revision +-- v0 -> v8 (compatible with v3+): Latest revision CREATE TABLE mx_registrations ( user_id TEXT PRIMARY KEY @@ -26,5 +26,6 @@ CREATE TABLE mx_room_state ( room_id TEXT PRIMARY KEY, power_levels jsonb, encryption jsonb, + create_event jsonb, members_fetched BOOLEAN NOT NULL DEFAULT false ); diff --git a/sqlstatestore/v08-create-event.sql b/sqlstatestore/v08-create-event.sql new file mode 100644 index 00000000..9f1b55c9 --- /dev/null +++ b/sqlstatestore/v08-create-event.sql @@ -0,0 +1,2 @@ +-- v8 (compatible with v3+): Add create event to room state table +ALTER TABLE mx_room_state ADD COLUMN create_event jsonb; diff --git a/statestore.go b/statestore.go index e728b885..1933ab95 100644 --- a/statestore.go +++ b/statestore.go @@ -34,6 +34,9 @@ type StateStore interface { SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error) + SetCreate(ctx context.Context, evt *event.Event) error + GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error) + HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) @@ -68,9 +71,11 @@ func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) { err = store.SetPowerLevels(ctx, evt.RoomID, content) case *event.EncryptionEventContent: err = store.SetEncryptionEvent(ctx, evt.RoomID, content) + case *event.CreateEventContent: + err = store.SetCreate(ctx, evt) default: switch evt.Type { - case event.StateMember, event.StatePowerLevels, event.StateEncryption: + case event.StateMember, event.StatePowerLevels, event.StateEncryption, event.StateCreate: zerolog.Ctx(ctx).Warn(). Stringer("event_id", evt.ID). Str("event_type", evt.Type.Type). @@ -101,6 +106,7 @@ type MemoryStateStore struct { MembersFetched map[id.RoomID]bool `json:"members_fetched"` PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"` Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"` + Create map[id.RoomID]*event.Event `json:"create"` registrationsLock sync.RWMutex membersLock sync.RWMutex @@ -115,6 +121,7 @@ func NewMemoryStateStore() StateStore { MembersFetched: make(map[id.RoomID]bool), PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent), Encryption: make(map[id.RoomID]*event.EncryptionEventContent), + Create: make(map[id.RoomID]*event.Event), } } @@ -298,6 +305,9 @@ func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomI func (store *MemoryStateStore) GetPowerLevels(_ context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) { store.powerLevelsLock.RLock() levels = store.PowerLevels[roomID] + if levels != nil && levels.CreateEvent == nil { + levels.CreateEvent = store.Create[roomID] + } store.powerLevelsLock.RUnlock() return } @@ -314,6 +324,23 @@ func (store *MemoryStateStore) HasPowerLevel(ctx context.Context, roomID id.Room return exerrors.Must(store.GetPowerLevel(ctx, roomID, userID)) >= exerrors.Must(store.GetPowerLevelRequirement(ctx, roomID, eventType)), nil } +func (store *MemoryStateStore) SetCreate(ctx context.Context, evt *event.Event) error { + store.powerLevelsLock.Lock() + store.Create[evt.RoomID] = evt + if pls, ok := store.PowerLevels[evt.RoomID]; ok && pls.CreateEvent == nil { + pls.CreateEvent = evt + } + store.powerLevelsLock.Unlock() + return nil +} + +func (store *MemoryStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error) { + store.powerLevelsLock.RLock() + evt := store.Create[roomID] + store.powerLevelsLock.RUnlock() + return evt, nil +} + func (store *MemoryStateStore) SetEncryptionEvent(_ context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { store.encryptionLock.Lock() store.Encryption[roomID] = content From 96b07ad724dd3d7785c5eb30ae02749b17500070 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 13 Jul 2025 12:44:15 +0300 Subject: [PATCH 1245/1647] event: use full event type for stripped state for MSC4311 --- event/events.go | 25 +++++++++---------------- responses.go | 7 +------ sync.go | 9 ++------- 3 files changed, 12 insertions(+), 29 deletions(-) diff --git a/event/events.go b/event/events.go index a763cc31..1428bf8a 100644 --- a/event/events.go +++ b/event/events.go @@ -130,23 +130,16 @@ func (evt *Event) GetStateKey() string { return "" } -type StrippedState struct { - Content Content `json:"content"` - Type Type `json:"type"` - StateKey string `json:"state_key"` - Sender id.UserID `json:"sender"` -} - type Unsigned struct { - PrevContent *Content `json:"prev_content,omitempty"` - PrevSender id.UserID `json:"prev_sender,omitempty"` - Membership Membership `json:"membership,omitempty"` - ReplacesState id.EventID `json:"replaces_state,omitempty"` - Age int64 `json:"age,omitempty"` - TransactionID string `json:"transaction_id,omitempty"` - Relations *Relations `json:"m.relations,omitempty"` - RedactedBecause *Event `json:"redacted_because,omitempty"` - InviteRoomState []StrippedState `json:"invite_room_state,omitempty"` + PrevContent *Content `json:"prev_content,omitempty"` + PrevSender id.UserID `json:"prev_sender,omitempty"` + Membership Membership `json:"membership,omitempty"` + ReplacesState id.EventID `json:"replaces_state,omitempty"` + Age int64 `json:"age,omitempty"` + TransactionID string `json:"transaction_id,omitempty"` + Relations *Relations `json:"m.relations,omitempty"` + RedactedBecause *Event `json:"redacted_because,omitempty"` + InviteRoomState []*Event `json:"invite_room_state,omitempty"` BeeperHSOrder int64 `json:"com.beeper.hs.order,omitempty"` BeeperHSSuborder int16 `json:"com.beeper.hs.suborder,omitempty"` diff --git a/responses.go b/responses.go index 20d02af5..2e8005d4 100644 --- a/responses.go +++ b/responses.go @@ -648,12 +648,7 @@ type RespHierarchy struct { type ChildRoomsChunk struct { PublicRoomInfo - ChildrenState []StrippedStateWithTime `json:"children_state"` -} - -type StrippedStateWithTime struct { - event.StrippedState - Timestamp jsontime.UnixMilli `json:"origin_server_ts"` + ChildrenState []*event.Event `json:"children_state"` } type RespAppservicePing struct { diff --git a/sync.go b/sync.go index 9a2b9edf..c52bd2f9 100644 --- a/sync.go +++ b/sync.go @@ -263,7 +263,7 @@ func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool { // cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.MoveInviteState) func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string) bool { for _, meta := range resp.Rooms.Invite { - var inviteState []event.StrippedState + var inviteState []*event.Event var inviteEvt *event.Event for _, evt := range meta.State.Events { if evt.Type == event.StateMember && evt.GetStateKey() == cli.UserID.String() { @@ -271,12 +271,7 @@ func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string } else { evt.Type.Class = event.StateEventType _ = evt.Content.ParseRaw(evt.Type) - inviteState = append(inviteState, event.StrippedState{ - Content: evt.Content, - Type: evt.Type, - StateKey: evt.GetStateKey(), - Sender: evt.Sender, - }) + inviteState = append(inviteState, evt) } } if inviteEvt != nil { From 4866da52005c376dd10cb53e08399a543894fb93 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 18 Jul 2025 23:29:51 +0300 Subject: [PATCH 1246/1647] client: add custom room create ts field --- requests.go | 1 + 1 file changed, 1 insertion(+) diff --git a/requests.go b/requests.go index 09e4b3cd..17eda7d2 100644 --- a/requests.go +++ b/requests.go @@ -125,6 +125,7 @@ type ReqCreateRoom struct { PowerLevelOverride *event.PowerLevelsEventContent `json:"power_level_content_override,omitempty"` MeowRoomID id.RoomID `json:"fi.mau.room_id,omitempty"` + MeowCreateTS int64 `json:"fi.mau.origin_server_ts,omitempty"` BeeperInitialMembers []id.UserID `json:"com.beeper.initial_members,omitempty"` BeeperAutoJoinInvites bool `json:"com.beeper.auto_join_invites,omitempty"` BeeperLocalRoomID id.RoomID `json:"com.beeper.local_room_id,omitempty"` From 65a64c8044dd02a930fb5e446edaa52ac97389fd Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 20 Jul 2025 14:23:20 +0300 Subject: [PATCH 1247/1647] client: allow using custom http client for .well-known resolution --- client.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 886dbb63..1f907608 100644 --- a/client.go +++ b/client.go @@ -139,6 +139,10 @@ type IdentityServerInfo struct { // Use ParseUserID to extract the server name from a user ID. // https://spec.matrix.org/v1.2/client-server-api/#server-discovery func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown, error) { + return DiscoverClientAPIWithClient(ctx, &http.Client{Timeout: 30 * time.Second}, serverName) +} + +func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serverName string) (*ClientWellKnown, error) { wellKnownURL := url.URL{ Scheme: "https", Host: serverName, @@ -153,7 +157,6 @@ func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown req.Header.Set("Accept", "application/json") req.Header.Set("User-Agent", DefaultUserAgent+" (.well-known fetcher)") - client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { return nil, err From ea72271badd31f4f1ec27ae63be15ae25fcaa9df Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 21 Jul 2025 11:15:23 +0300 Subject: [PATCH 1248/1647] bridgev2/queue: run command handlers in background --- bridgev2/queue.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 04d982b5..95011cda 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -106,7 +106,7 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) return EventHandlingResultIgnored } - br.Commands.Handle( + go br.Commands.Handle( ctx, evt.RoomID, evt.ID, @@ -114,7 +114,7 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH strings.TrimPrefix(msg.Body, br.Config.CommandPrefix+" "), msg.RelatesTo.GetReplyTo(), ) - return EventHandlingResultSuccess + return EventHandlingResultQueued } } if evt.Type == event.StateMember && evt.GetStateKey() == br.Bot.GetMXID().String() && evt.Content.AsMember().Membership == event.MembershipInvite && sender != nil { From 3ecdb886bfd03c850bec6737ea6bf94db44a5ec9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 22 Jul 2025 16:18:25 +0300 Subject: [PATCH 1249/1647] bridgev2/database: add method to mark backfill task as not done --- bridgev2/database/backfillqueue.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/bridgev2/database/backfillqueue.go b/bridgev2/database/backfillqueue.go index 224ae626..1f920640 100644 --- a/bridgev2/database/backfillqueue.go +++ b/bridgev2/database/backfillqueue.go @@ -78,6 +78,11 @@ const ( dispatched_at=$9, completed_at=$10, next_dispatch_min_ts=$11 WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 ` + markBackfillTaskNotDoneQuery = ` + UPDATE backfill_task + SET is_done = false + WHERE bridge_id = $1 AND portal_id = $2 AND portal_receiver = $3 AND user_login_id = $4 + ` getNextBackfillQuery = ` SELECT bridge_id, portal_id, portal_receiver, user_login_id, batch_count, is_done, @@ -127,6 +132,10 @@ func (btq *BackfillTaskQuery) Update(ctx context.Context, bq *BackfillTask) erro return btq.Exec(ctx, updateBackfillQueueQuery, bq.sqlVariables()...) } +func (btq *BackfillTaskQuery) MarkNotDone(ctx context.Context, portalKey networkid.PortalKey, userLoginID networkid.UserLoginID) error { + return btq.Exec(ctx, markBackfillTaskNotDoneQuery, btq.BridgeID, portalKey.ID, portalKey.Receiver, userLoginID) +} + func (btq *BackfillTaskQuery) GetNext(ctx context.Context) (*BackfillTask, error) { return btq.QueryOne(ctx, getNextBackfillQuery, btq.BridgeID, time.Now().UnixNano()) } From 3fe5a7badc23f233ab951de46d5af16ad03e29da Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 22 Jul 2025 17:19:47 +0300 Subject: [PATCH 1250/1647] event: replace soft failed field in unsigned --- event/events.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/event/events.go b/event/events.go index 1428bf8a..1a57fb4b 100644 --- a/event/events.go +++ b/event/events.go @@ -146,13 +146,12 @@ type Unsigned struct { BeeperHSOrderString *BeeperEncodedOrder `json:"com.beeper.hs.order_string,omitempty"` BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"` - MauSoftFailed bool `json:"fi.mau.soft_failed,omitempty"` - MauRejectionReason string `json:"fi.mau.rejection_reason,omitempty"` + ElementSoftFailed bool `json:"io.element.synapse.soft_failed,omitempty"` } func (us *Unsigned) IsEmpty() bool { return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && us.Membership == "" && us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations == nil && us.BeeperHSOrder == 0 && us.BeeperHSSuborder == 0 && us.BeeperHSOrderString.IsZero() && - !us.MauSoftFailed && us.MauRejectionReason == "" + !us.ElementSoftFailed } From fcd7d9a525ad36f812d1d31fb676ac3b2796f120 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 22 Jul 2025 19:19:37 +0300 Subject: [PATCH 1251/1647] bridgev2/commands: allow canceling qr login --- bridgev2/commands/login.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index 3544998c..a18564c2 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -273,6 +273,13 @@ func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, prevEvent = new(id.EventID) ce.Ctx = context.WithValue(ce.Ctx, contextKeyPrevEventID, prevEvent) } + cancelCtx, cancelFunc := context.WithCancel(ce.Ctx) + defer cancelFunc() + StoreCommandState(ce.User, &CommandState{ + Action: "Login", + Cancel: cancelFunc, + }) + defer StoreCommandState(ce.User, nil) switch step.DisplayAndWaitParams.Type { case bridgev2.LoginDisplayTypeQR: err := sendQR(ce, step.DisplayAndWaitParams.Data, prevEvent) @@ -292,7 +299,7 @@ func doLoginDisplayAndWait(ce *Event, login bridgev2.LoginProcessDisplayAndWait, login.Cancel() return } - nextStep, err := login.Wait(ce.Ctx) + nextStep, err := login.Wait(cancelCtx) // Redact the QR code, unless the next step is refreshing the code (in which case the event is just edited) if *prevEvent != "" && (nextStep == nil || nextStep.StepID != step.StepID) { _, _ = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventRedaction, &event.Content{ From cb80e5c63f7f7a4e2ec61f19b6e84754d83c0df5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 22 Jul 2025 20:31:25 +0300 Subject: [PATCH 1252/1647] bridgev2/portal: fix adding rooms to personal space on create --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c91523ef..114609ce 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -4379,7 +4379,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } func (portal *Portal) addToUserSpaces(ctx context.Context) { - if portal.Parent == nil { + if portal.Parent != nil { return } log := zerolog.Ctx(ctx) From 69a3d27c1c9f360da7d96bdd6c12de0f810f77c6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 22 Jul 2025 22:50:26 +0300 Subject: [PATCH 1253/1647] bridgev2: add interface for getting arbitrary state event --- bridgev2/matrix/connector.go | 13 ++++++++----- bridgev2/matrixinterface.go | 5 ++++- bridgev2/portal.go | 11 +++++++---- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 978f666f..c168ae3d 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -101,6 +101,7 @@ type Connector struct { var ( _ bridgev2.MatrixConnector = (*Connector)(nil) _ bridgev2.MatrixConnectorWithServer = (*Connector)(nil) + _ bridgev2.MatrixConnectorWithArbitraryRoomState = (*Connector)(nil) _ bridgev2.MatrixConnectorWithPostRoomBridgeHandling = (*Connector)(nil) _ bridgev2.MatrixConnectorWithPublicMedia = (*Connector)(nil) _ bridgev2.MatrixConnectorWithNameDisambiguation = (*Connector)(nil) @@ -534,12 +535,14 @@ func (br *Connector) GetPowerLevels(ctx context.Context, roomID id.RoomID) (*eve return br.Bot.PowerLevels(ctx, roomID) } -func (br *Connector) GetCreateEvent(ctx context.Context, roomID id.RoomID) (*event.Event, error) { - createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID) - if err != nil || createEvt != nil { - return createEvt, err +func (br *Connector) GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) { + if eventType == event.StateCreate && stateKey == "" { + createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID) + if err != nil || createEvt != nil { + return createEvt, err + } } - return br.Bot.FullStateEvent(ctx, roomID, event.StateCreate, "") + return br.Bot.FullStateEvent(ctx, roomID, eventType, "") } func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 5d0cb014..b5a575ba 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -47,7 +47,6 @@ type MatrixConnector interface { GenerateContentURI(ctx context.Context, mediaID networkid.MediaID) (id.ContentURIString, error) GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error) - GetCreateEvent(ctx context.Context, roomID id.RoomID) (*event.Event, error) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) GetMemberInfo(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) @@ -59,6 +58,10 @@ type MatrixConnector interface { ServerName() string } +type MatrixConnectorWithArbitraryRoomState interface { + GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) +} + type MatrixConnectorWithServer interface { GetPublicAddress() string GetRouter() *mux.Router diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 114609ce..55f1cd47 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1655,10 +1655,13 @@ func (portal *Portal) handleMatrixPowerLevels( return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) } if content.CreateEvent == nil { - var err error - content.CreateEvent, err = portal.Bridge.Matrix.GetCreateEvent(ctx, portal.MXID) - if err != nil { - return EventHandlingResultFailed.WithMSSError(fmt.Errorf("failed to get create event for power levels: %w", err)) + ars, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) + if ok { + var err error + content.CreateEvent, err = ars.GetStateEvent(ctx, portal.MXID, event.StateCreate, "") + if err != nil { + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("failed to get create event for power levels: %w", err)) + } } } api, ok := sender.Client.(PowerLevelHandlingNetworkAPI) From 463d2ea6d01154a2b1970a6197aa4d35dbb921bf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 22 Jul 2025 23:31:16 +0300 Subject: [PATCH 1254/1647] bridgev2/portal: add bots to functional members in DMs --- bridgev2/portal.go | 43 +++++++++++++++++++++++++++++++++++++++++++ event/state.go | 9 +++++++++ 2 files changed, 52 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 55f1cd47..5fea134f 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -87,6 +87,9 @@ type Portal struct { roomCreateLock sync.Mutex + functionalMembersLock sync.Mutex + functionalMembersCache *event.ElementFunctionalMembersContent + events chan portalEvent eventsLock sync.Mutex @@ -2043,6 +2046,45 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, return } +func (portal *Portal) ensureFunctionalMember(ctx context.Context, ghost *Ghost) { + if !ghost.IsBot || portal.RoomType != database.RoomTypeDM || portal.OtherUserID == ghost.ID { + return + } + ars, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) + if !ok { + return + } + portal.functionalMembersLock.Lock() + defer portal.functionalMembersLock.Unlock() + var functionalMembers *event.ElementFunctionalMembersContent + if portal.functionalMembersCache != nil { + functionalMembers = portal.functionalMembersCache + } else { + evt, err := ars.GetStateEvent(ctx, portal.MXID, event.StateElementFunctionalMembers, "") + if err != nil && !errors.Is(err, mautrix.MNotFound) { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get functional members state event") + return + } + functionalMembers = &event.ElementFunctionalMembersContent{} + if evt != nil { + evtContent, ok := evt.Content.Parsed.(*event.ElementFunctionalMembersContent) + if ok && evtContent != nil { + functionalMembers = evtContent + } + } + } + functionalMembers.Add(portal.Bridge.Bot.GetMXID()) + if functionalMembers.Add(ghost.Intent.GetMXID()) { + _, err := portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{ + Parsed: functionalMembers, + }, time.Time{}) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to update functional members state event") + return + } + } +} + func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID, err error) { var ghost *Ghost if !sender.IsFromMe && sender.ForceDMUser && portal.OtherUserID != "" && sender.Sender != portal.OtherUserID { @@ -2066,6 +2108,7 @@ func (portal *Portal) getIntentAndUserMXIDFor(ctx context.Context, sender EventS return } ghost.UpdateInfoIfNecessary(ctx, source, evtType) + portal.ensureFunctionalMember(ctx, ghost) } if sender.IsFromMe { intent = source.User.DoublePuppet(ctx) diff --git a/event/state.go b/event/state.go index ff6dabfa..83390c90 100644 --- a/event/state.go +++ b/event/state.go @@ -8,6 +8,7 @@ package event import ( "encoding/base64" + "slices" "maunium.net/go/mautrix/id" ) @@ -267,3 +268,11 @@ type InsertionMarkerContent struct { type ElementFunctionalMembersContent struct { ServiceMembers []id.UserID `json:"service_members"` } + +func (efmc *ElementFunctionalMembersContent) Add(mxid id.UserID) bool { + if slices.Contains(efmc.ServiceMembers, mxid) { + return false + } + efmc.ServiceMembers = append(efmc.ServiceMembers, mxid) + return true +} From 5b55330b859c8eca30cdba627fbe8e12f62fa9b0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 23 Jul 2025 14:37:57 +0300 Subject: [PATCH 1255/1647] bridgev2: run PostStart in background --- bridgev2/bridge.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index a4ce033e..24619c79 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -120,7 +120,7 @@ func (br *Bridge) Start(ctx context.Context) error { if err != nil { return err } - br.PostStart(ctx) + go br.PostStart(ctx) return nil } From d5223cdc8fcebace5f1de8d9b02f4f4568fc663d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 23 Jul 2025 20:30:43 +0300 Subject: [PATCH 1256/1647] all: replace gorilla/mux with standard library --- appservice/appservice.go | 25 +++-- appservice/http.go | 10 +- bridgev2/matrix/connector.go | 4 +- bridgev2/matrix/provisioning.go | 96 ++++++++++++-------- bridgev2/matrix/publicmedia.go | 13 +-- bridgev2/matrixinterface.go | 7 +- crypto/verificationhelper/mockserver_test.go | 41 +++------ federation/keyserver.go | 34 +++---- go.mod | 9 +- go.sum | 14 ++- mediaproxy/mediaproxy.go | 95 ++++++++----------- 11 files changed, 161 insertions(+), 187 deletions(-) diff --git a/appservice/appservice.go b/appservice/appservice.go index 518e1073..5dd067c0 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -19,7 +19,6 @@ import ( "syscall" "time" - "github.com/gorilla/mux" "github.com/gorilla/websocket" "github.com/rs/zerolog" "golang.org/x/net/publicsuffix" @@ -43,7 +42,7 @@ func Create() *AppService { intents: make(map[id.UserID]*IntentAPI), HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar}, StateStore: mautrix.NewMemoryStateStore().(StateStore), - Router: mux.NewRouter(), + Router: http.NewServeMux(), UserAgent: mautrix.DefaultUserAgent, txnIDC: NewTransactionIDCache(128), Live: true, @@ -61,12 +60,12 @@ func Create() *AppService { DefaultHTTPRetries: 4, } - as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut) - as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet) - as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet) - as.Router.HandleFunc("/_matrix/app/v1/ping", as.PostPing).Methods(http.MethodPost) - as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet) - as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet) + as.Router.HandleFunc("PUT /_matrix/app/v1/transactions/{txnID}", as.PutTransaction) + as.Router.HandleFunc("GET /_matrix/app/v1/rooms/{roomAlias}", as.GetRoom) + as.Router.HandleFunc("GET /_matrix/app/v1/users/{userID}", as.GetUser) + as.Router.HandleFunc("POST /_matrix/app/v1/ping", as.PostPing) + as.Router.HandleFunc("GET /_matrix/mau/live", as.GetLive) + as.Router.HandleFunc("GET /_matrix/mau/ready", as.GetReady) return as } @@ -114,13 +113,13 @@ var _ StateStore = (*mautrix.MemoryStateStore)(nil) // QueryHandler handles room alias and user ID queries from the homeserver. type QueryHandler interface { - QueryAlias(alias string) bool + QueryAlias(alias id.RoomAlias) bool QueryUser(userID id.UserID) bool } type QueryHandlerStub struct{} -func (qh *QueryHandlerStub) QueryAlias(alias string) bool { +func (qh *QueryHandlerStub) QueryAlias(alias id.RoomAlias) bool { return false } @@ -128,7 +127,7 @@ func (qh *QueryHandlerStub) QueryUser(userID id.UserID) bool { return false } -type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{}) +type WebsocketHandler func(WebsocketCommand) (ok bool, data any) type StateStore interface { mautrix.StateStore @@ -160,7 +159,7 @@ type AppService struct { QueryHandler QueryHandler StateStore StateStore - Router *mux.Router + Router *http.ServeMux UserAgent string server *http.Server HTTPClient *http.Client diff --git a/appservice/http.go b/appservice/http.go index 1ebe6e56..862de7fd 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -17,7 +17,6 @@ import ( "syscall" "time" - "github.com/gorilla/mux" "github.com/rs/zerolog" "go.mau.fi/util/exhttp" "go.mau.fi/util/exstrings" @@ -95,8 +94,7 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { return } - vars := mux.Vars(r) - txnID := vars["txnID"] + txnID := r.PathValue("txnID") if len(txnID) == 0 { mautrix.MInvalidParam.WithMessage("Missing transaction ID").Write(w) return @@ -240,8 +238,7 @@ func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) { return } - vars := mux.Vars(r) - roomAlias := vars["roomAlias"] + roomAlias := id.RoomAlias(r.PathValue("roomAlias")) ok := as.QueryHandler.QueryAlias(roomAlias) if ok { exhttp.WriteEmptyJSONResponse(w, http.StatusOK) @@ -256,8 +253,7 @@ func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) { return } - vars := mux.Vars(r) - userID := id.UserID(vars["userID"]) + userID := id.UserID(r.PathValue("userID")) ok := as.QueryHandler.QueryUser(userID) if ok { exhttp.WriteEmptyJSONResponse(w, http.StatusOK) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index c168ae3d..af9931b0 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -12,6 +12,7 @@ import ( "encoding/base64" "errors" "fmt" + "net/http" "net/url" "os" "regexp" @@ -20,7 +21,6 @@ import ( "time" "unsafe" - "github.com/gorilla/mux" _ "github.com/lib/pq" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" @@ -223,7 +223,7 @@ func (br *Connector) GetPublicAddress() string { return br.Config.AppService.PublicAddress } -func (br *Connector) GetRouter() *mux.Router { +func (br *Connector) GetRouter() *http.ServeMux { if br.GetPublicAddress() != "" { return br.AS.Router } diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index f865a19e..7f4b8a2e 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -17,7 +17,6 @@ import ( "sync" "time" - "github.com/gorilla/mux" "github.com/rs/xid" "github.com/rs/zerolog" "github.com/rs/zerolog/hlog" @@ -40,7 +39,7 @@ type matrixAuthCacheEntry struct { } type ProvisioningAPI struct { - Router *mux.Router + Router *http.ServeMux br *Connector log zerolog.Logger @@ -91,12 +90,12 @@ func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User { return r.Context().Value(provisioningUserKey).(*bridgev2.User) } -func (prov *ProvisioningAPI) GetRouter() *mux.Router { +func (prov *ProvisioningAPI) GetRouter() *http.ServeMux { return prov.Router } type IProvisioningAPI interface { - GetRouter() *mux.Router + GetRouter() *http.ServeMux GetUser(r *http.Request) *bridgev2.User } @@ -116,41 +115,48 @@ func (prov *ProvisioningAPI) Init() { tp.Dialer.Timeout = 10 * time.Second tp.Transport.ResponseHeaderTimeout = 10 * time.Second tp.Transport.TLSHandshakeTimeout = 10 * time.Second - prov.Router = prov.br.AS.Router.PathPrefix(prov.br.Config.Provisioning.Prefix).Subrouter() - prov.Router.Use(hlog.NewHandler(prov.log)) - prov.Router.Use(hlog.RequestIDHandler("request_id", "Request-Id")) - prov.Router.Use(exhttp.CORSMiddleware) - prov.Router.Use(requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true})) - prov.Router.Use(prov.AuthMiddleware) - prov.Router.Path("/v3/whoami").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetWhoami) - prov.Router.Path("/v3/login/flows").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLoginFlows) - prov.Router.Path("/v3/login/start/{flowID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginStart) - prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:user_input|cookies}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginSubmitInput) - prov.Router.Path("/v3/login/step/{loginProcessID}/{stepID}/{stepType:display_and_wait}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLoginWait) - prov.Router.Path("/v3/logout/{loginID}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostLogout) - prov.Router.Path("/v3/logins").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetLogins) - prov.Router.Path("/v3/contacts").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetContactList) - prov.Router.Path("/v3/search_users").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostSearchUsers) - prov.Router.Path("/v3/resolve_identifier/{identifier}").Methods(http.MethodGet, http.MethodOptions).HandlerFunc(prov.GetResolveIdentifier) - prov.Router.Path("/v3/create_dm/{identifier}").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateDM) - prov.Router.Path("/v3/create_group").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostCreateGroup) + prov.Router = http.NewServeMux() + prov.Router.HandleFunc("GET /v3/whoami", prov.GetWhoami) + prov.Router.HandleFunc("GET /v3/login/flows", prov.GetLoginFlows) + prov.Router.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart) + prov.Router.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType}", prov.PostLoginStep) + prov.Router.HandleFunc("POST /v3/logout/{loginID}", prov.PostLogout) + prov.Router.HandleFunc("GET /v3/logins", prov.GetLogins) + prov.Router.HandleFunc("GET /v3/contacts", prov.GetContactList) + prov.Router.HandleFunc("POST /v3/search_users", prov.PostSearchUsers) + prov.Router.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier) + prov.Router.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM) + prov.Router.HandleFunc("POST /v3/create_group", prov.PostCreateGroup) if prov.br.Config.Provisioning.EnableSessionTransfers { prov.log.Debug().Msg("Enabling session transfer API") - prov.Router.Path("/v3/session_transfer/init").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostInitSessionTransfer) - prov.Router.Path("/v3/session_transfer/finish").Methods(http.MethodPost, http.MethodOptions).HandlerFunc(prov.PostFinishSessionTransfer) + prov.Router.HandleFunc("POST /v3/session_transfer/init", prov.PostInitSessionTransfer) + prov.Router.HandleFunc("POST /v3/session_transfer/finish", prov.PostFinishSessionTransfer) } if prov.br.Config.Provisioning.DebugEndpoints { prov.log.Debug().Msg("Enabling debug API at /debug") - r := prov.br.AS.Router.PathPrefix("/debug").Subrouter() - r.Use(prov.DebugAuthMiddleware) - r.HandleFunc("/pprof/cmdline", pprof.Cmdline).Methods(http.MethodGet) - r.HandleFunc("/pprof/profile", pprof.Profile).Methods(http.MethodGet) - r.HandleFunc("/pprof/symbol", pprof.Symbol).Methods(http.MethodGet) - r.HandleFunc("/pprof/trace", pprof.Trace).Methods(http.MethodGet) - r.PathPrefix("/pprof/").HandlerFunc(pprof.Index) + debugRouter := http.NewServeMux() + debugRouter.HandleFunc("GET /pprof/cmdline", pprof.Cmdline) + debugRouter.HandleFunc("GET /pprof/profile", pprof.Profile) + debugRouter.HandleFunc("GET /pprof/symbol", pprof.Symbol) + debugRouter.HandleFunc("GET /pprof/trace", pprof.Trace) + debugRouter.HandleFunc("/pprof/", pprof.Index) + prov.br.AS.Router.Handle("/debug", exhttp.ApplyMiddleware( + debugRouter, + hlog.NewHandler(prov.br.Log.With().Str("component", "debug api").Logger()), + prov.DebugAuthMiddleware, + )) } + + prov.br.AS.Router.Handle("/_matrix/provision", exhttp.ApplyMiddleware( + prov.Router, + hlog.NewHandler(prov.log), + hlog.RequestIDHandler("request_id", "Request-Id"), + exhttp.CORSMiddleware, + requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), + prov.AuthMiddleware, + )) } func (prov *ProvisioningAPI) checkMatrixAuth(ctx context.Context, userID id.UserID, token string) error { @@ -250,7 +256,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { ctx := context.WithValue(r.Context(), ProvisioningKeyRequest, r) ctx = context.WithValue(ctx, provisioningUserKey, user) - if loginID, ok := mux.Vars(r)["loginProcessID"]; ok { + if loginID := r.PathValue("loginProcessID"); loginID != "" { prov.loginsLock.RLock() login, ok := prov.logins[loginID] prov.loginsLock.RUnlock() @@ -262,7 +268,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { login.Lock.Lock() // This will only unlock after the handler runs defer login.Lock.Unlock() - stepID := mux.Vars(r)["stepID"] + stepID := r.PathValue("stepID") if login.NextStep.StepID != stepID { zerolog.Ctx(r.Context()).Warn(). Str("request_step_id", stepID). @@ -271,7 +277,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { mautrix.MBadState.WithMessage("Step ID does not match").Write(w) return } - stepType := mux.Vars(r)["stepType"] + stepType := r.PathValue("stepType") if login.NextStep.Type != bridgev2.LoginStepType(stepType) { zerolog.Ctx(r.Context()).Warn(). Str("request_step_type", stepType). @@ -374,7 +380,7 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque login, err := prov.net.CreateLogin( r.Context(), prov.GetUser(r), - mux.Vars(r)["flowID"], + r.PathValue("flowID"), ) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process") @@ -422,6 +428,20 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov }, bridgev2.DeleteOpts{LogoutRemote: true}) } +func (prov *ProvisioningAPI) PostLoginStep(w http.ResponseWriter, r *http.Request) { + switch bridgev2.LoginStepType(r.PathValue("stepType")) { + case bridgev2.LoginStepTypeUserInput, bridgev2.LoginStepTypeCookies: + prov.PostLoginSubmitInput(w, r) + case bridgev2.LoginStepTypeDisplayAndWait: + prov.PostLoginWait(w, r) + case bridgev2.LoginStepTypeComplete: + fallthrough + default: + // This is probably impossible because AuthMiddleware checks that the next step type matches the request. + mautrix.MUnrecognized.WithMessage("Invalid step type %q", r.PathValue("stepType")).Write(w) + } +} + func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http.Request) { var params map[string]string err := json.NewDecoder(r.Body).Decode(¶ms) @@ -475,7 +495,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques func (prov *ProvisioningAPI) PostLogout(w http.ResponseWriter, r *http.Request) { user := prov.GetUser(r) - userLoginID := networkid.UserLoginID(mux.Vars(r)["loginID"]) + userLoginID := networkid.UserLoginID(r.PathValue("loginID")) if userLoginID == "all" { for { login := user.GetDefaultLogin() @@ -571,7 +591,7 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers").Write(w) return } - resp, err := api.ResolveIdentifier(r.Context(), mux.Vars(r)["identifier"], createChat) + resp, err := api.ResolveIdentifier(r.Context(), r.PathValue("identifier"), createChat) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier") RespondWithError(w, err, "Internal error resolving identifier") diff --git a/bridgev2/matrix/publicmedia.go b/bridgev2/matrix/publicmedia.go index 9db5f442..95e37262 100644 --- a/bridgev2/matrix/publicmedia.go +++ b/bridgev2/matrix/publicmedia.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -16,8 +16,6 @@ import ( "net/http" "time" - "github.com/gorilla/mux" - "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/id" ) @@ -35,7 +33,7 @@ func (br *Connector) initPublicMedia() error { return fmt.Errorf("public media hash length is negative") } br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey) - br.AS.Router.HandleFunc("/_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia).Methods(http.MethodGet) + br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia) return nil } @@ -76,16 +74,15 @@ var proxyHeadersToCopy = []string{ } func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) contentURI := id.ContentURI{ - Homeserver: vars["server"], - FileID: vars["mediaID"], + Homeserver: r.PathValue("server"), + FileID: r.PathValue("mediaID"), } if !contentURI.IsValid() { http.Error(w, "invalid content URI", http.StatusBadRequest) return } - checksum, err := base64.RawURLEncoding.DecodeString(vars["checksum"]) + checksum, err := base64.RawURLEncoding.DecodeString(r.PathValue("checksum")) if err != nil || !hmac.Equal(checksum, br.makePublicMediaChecksum(contentURI)) { http.Error(w, "invalid base64 in checksum", http.StatusBadRequest) return diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index b5a575ba..b30e274a 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -10,11 +10,10 @@ import ( "context" "fmt" "io" + "net/http" "os" "time" - "github.com/gorilla/mux" - "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -64,7 +63,7 @@ type MatrixConnectorWithArbitraryRoomState interface { type MatrixConnectorWithServer interface { GetPublicAddress() string - GetRouter() *mux.Router + GetRouter() *http.ServeMux } type MatrixConnectorWithPublicMedia interface { diff --git a/crypto/verificationhelper/mockserver_test.go b/crypto/verificationhelper/mockserver_test.go index b6bf3d2c..45ca7781 100644 --- a/crypto/verificationhelper/mockserver_test.go +++ b/crypto/verificationhelper/mockserver_test.go @@ -12,11 +12,9 @@ import ( "io" "net/http" "net/http/httptest" - "net/url" "strings" "testing" - "github.com/gorilla/mux" "github.com/rs/zerolog/log" // zerolog-allow-global-log "github.com/stretchr/testify/require" "go.mau.fi/util/random" @@ -42,20 +40,6 @@ type mockServer struct { UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys } -func DecodeVarsMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - var err error - for k, v := range vars { - vars[k], err = url.PathUnescape(v) - if err != nil { - panic(err) - } - } - next.ServeHTTP(w, r) - }) -} - func createMockServer(t *testing.T) *mockServer { t.Helper() @@ -69,15 +53,14 @@ func createMockServer(t *testing.T) *mockServer { UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, } - router := mux.NewRouter().SkipClean(true).StrictSlash(false).UseEncodedPath() - router.Use(DecodeVarsMiddleware) - router.HandleFunc("/_matrix/client/v3/login", server.postLogin).Methods(http.MethodPost) - router.HandleFunc("/_matrix/client/v3/keys/query", server.postKeysQuery).Methods(http.MethodPost) - router.HandleFunc("/_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice).Methods(http.MethodPut) - router.HandleFunc("/_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData).Methods(http.MethodPut) - router.HandleFunc("/_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload).Methods(http.MethodPost) - router.HandleFunc("/_matrix/client/v3/keys/signatures/upload", server.emptyResp).Methods(http.MethodPost) - router.HandleFunc("/_matrix/client/v3/keys/upload", server.postKeysUpload).Methods(http.MethodPost) + router := http.NewServeMux() + router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin) + router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery) + router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice) + router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData) + router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload) + router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp) + router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload) server.Server = httptest.NewServer(router) return &server @@ -118,10 +101,9 @@ func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) { } func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) var req mautrix.ReqSendToDevice json.NewDecoder(r.Body).Decode(&req) - evtType := event.Type{Type: vars["type"], Class: event.ToDeviceEventType} + evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType} for user, devices := range req.Messages { for device, content := range devices { @@ -140,9 +122,8 @@ func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { } func (s *mockServer) putAccountData(w http.ResponseWriter, r *http.Request) { - vars := mux.Vars(r) - userID := id.UserID(vars["userID"]) - eventType := event.Type{Type: vars["type"], Class: event.AccountDataEventType} + userID := id.UserID(r.PathValue("userID")) + eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType} jsonData, _ := io.ReadAll(r.Body) if _, ok := s.AccountData[userID]; !ok { diff --git a/federation/keyserver.go b/federation/keyserver.go index b0faf8fb..37998786 100644 --- a/federation/keyserver.go +++ b/federation/keyserver.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -12,7 +12,7 @@ import ( "strconv" "time" - "github.com/gorilla/mux" + "go.mau.fi/util/exerrors" "go.mau.fi/util/exhttp" "go.mau.fi/util/jsontime" @@ -51,19 +51,21 @@ type KeyServer struct { } // Register registers the key server endpoints to the given router. -func (ks *KeyServer) Register(r *mux.Router) { - r.HandleFunc("/.well-known/matrix/server", ks.GetWellKnown).Methods(http.MethodGet) - r.HandleFunc("/_matrix/federation/v1/version", ks.GetServerVersion).Methods(http.MethodGet) - keyRouter := r.PathPrefix("/_matrix/key").Subrouter() - keyRouter.HandleFunc("/v2/server", ks.GetServerKey).Methods(http.MethodGet) - keyRouter.HandleFunc("/v2/query/{serverName}", ks.GetQueryKeys).Methods(http.MethodGet) - keyRouter.HandleFunc("/v2/query", ks.PostQueryKeys).Methods(http.MethodPost) - keyRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - mautrix.MUnrecognized.WithStatus(http.StatusNotFound).WithMessage("Unrecognized endpoint").Write(w) - }) - keyRouter.MethodNotAllowedHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - mautrix.MUnrecognized.WithStatus(http.StatusMethodNotAllowed).WithMessage("Invalid method for endpoint").Write(w) - }) +func (ks *KeyServer) Register(r *http.ServeMux) { + r.HandleFunc("GET /.well-known/matrix/server", ks.GetWellKnown) + r.HandleFunc("GET /_matrix/federation/v1/version", ks.GetServerVersion) + keyRouter := http.NewServeMux() + keyRouter.HandleFunc("GET /v2/server", ks.GetServerKey) + keyRouter.HandleFunc("GET /v2/query/{serverName}", ks.GetQueryKeys) + keyRouter.HandleFunc("POST /v2/query", ks.PostQueryKeys) + errorBodies := exhttp.ErrorBodies{ + NotFound: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint"))), + MethodNotAllowed: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint"))), + } + r.Handle("/_matrix/key", exhttp.ApplyMiddleware( + keyRouter, + exhttp.HandleErrors(errorBodies), + )) } // RespWellKnown is the response body for the `GET /.well-known/matrix/server` endpoint. @@ -157,7 +159,7 @@ type GetQueryKeysResponse struct { // // https://spec.matrix.org/v1.9/server-server-api/#get_matrixkeyv2queryservername func (ks *KeyServer) GetQueryKeys(w http.ResponseWriter, r *http.Request) { - serverName := mux.Vars(r)["serverName"] + serverName := r.PathValue("serverName") minimumValidUntilTSString := r.URL.Query().Get("minimum_valid_until_ts") minimumValidUntilTS, err := strconv.ParseInt(minimumValidUntilTSString, 10, 64) if err != nil && minimumValidUntilTSString != "" { diff --git a/go.mod b/go.mod index 59f29c0c..d71e86ab 100644 --- a/go.mod +++ b/go.mod @@ -2,12 +2,11 @@ module maunium.net/go/mautrix go 1.23.0 -toolchain go1.24.4 +toolchain go1.24.5 require ( filippo.io/edwards25519 v1.1.0 github.com/chzyer/readline v1.5.1 - github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.28 @@ -18,10 +17,10 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.12 - go.mau.fi/util v0.8.8 + go.mau.fi/util v0.8.9-0.20250723171559-474867266038 go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.40.0 - golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc + golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 golang.org/x/net v0.42.0 golang.org/x/sync v0.16.0 gopkg.in/yaml.v3 v3.0.1 @@ -33,7 +32,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb // indirect + github.com/petermattis/goid v0.0.0-20250721140440-ea1c0173183e // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect diff --git a/go.sum b/go.sum index 9f48386e..eaa97cc8 100644 --- a/go.sum +++ b/go.sum @@ -13,8 +13,6 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= @@ -28,8 +26,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A= github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb h1:3PrKuO92dUTMrQ9dx0YNejC6U/Si6jqKmyQ9vWjwqR4= -github.com/petermattis/goid v0.0.0-20250508124226-395b08cebbdb/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/petermattis/goid v0.0.0-20250721140440-ea1c0173183e h1:D0bJD+4O3G4izvrQUmzCL80zazlN7EwJ0PPDhpJWC/I= +github.com/petermattis/goid v0.0.0-20250721140440-ea1c0173183e/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -53,14 +51,14 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.12 h1:YwGP/rrea2/CnCtUHgjuolG/PnMxdQtPMO5PvaE2/nY= github.com/yuin/goldmark v1.7.12/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.8.8 h1:OnuEEc/sIJFhnq4kFggiImUpcmnmL/xpvQMRu5Fiy5c= -go.mau.fi/util v0.8.8/go.mod h1:Y/kS3loxTEhy8Vill513EtPXr+CRDdae+Xj2BXXMy/c= +go.mau.fi/util v0.8.9-0.20250723171559-474867266038 h1:RVL8TVaYc3LTBBopfjCNDtD+6eZks0O+qgXN/9hsz7k= +go.mau.fi/util v0.8.9-0.20250723171559-474867266038/go.mod h1:GZZp5f9r2MgEu4GDvtB0XxCF7i6Z7Z8fM0w9a5oZH3Y= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= -golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc= -golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc= +golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4= +golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc= golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index 4be799d3..6fbcdbad 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,6 +8,7 @@ package mediaproxy import ( "context" + "encoding/json" "errors" "fmt" "io" @@ -21,8 +22,9 @@ import ( "strings" "time" - "github.com/gorilla/mux" "github.com/rs/zerolog" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exhttp" "maunium.net/go/mautrix" "maunium.net/go/mautrix/federation" @@ -108,8 +110,8 @@ type MediaProxy struct { serverName string serverKey *federation.SigningKey - FederationRouter *mux.Router - ClientMediaRouter *mux.Router + FederationRouter *http.ServeMux + ClientMediaRouter *http.ServeMux } func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProxy, error) { @@ -117,7 +119,7 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx if err != nil { return nil, err } - return &MediaProxy{ + mp := &MediaProxy{ serverName: serverName, serverKey: parsed, GetMedia: getMedia, @@ -132,7 +134,20 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx Version: strings.TrimPrefix(mautrix.VersionWithCommit, "v"), }, }, - }, nil + } + mp.FederationRouter = http.NewServeMux() + mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation) + mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion) + mp.ClientMediaRouter = http.NewServeMux() + mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}", mp.DownloadMedia) + mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia) + mp.ClientMediaRouter.HandleFunc("GET /thumbnail/{serverName}/{mediaID}", mp.DownloadMedia) + mp.ClientMediaRouter.HandleFunc("PUT /upload/{serverName}/{mediaID}", mp.UploadNotSupported) + mp.ClientMediaRouter.HandleFunc("POST /upload", mp.UploadNotSupported) + mp.ClientMediaRouter.HandleFunc("POST /create", mp.UploadNotSupported) + mp.ClientMediaRouter.HandleFunc("GET /config", mp.UploadNotSupported) + mp.ClientMediaRouter.HandleFunc("GET /preview_url", mp.PreviewURLNotSupported) + return mp, nil } type BasicConfig struct { @@ -162,7 +177,7 @@ type ServerConfig struct { } func (mp *MediaProxy) Listen(cfg ServerConfig) error { - router := mux.NewRouter() + router := http.NewServeMux() mp.RegisterRoutes(router) return http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router) } @@ -188,38 +203,20 @@ func (mp *MediaProxy) EnableServerAuth(client *federation.Client, keyCache feder }) } -func (mp *MediaProxy) RegisterRoutes(router *mux.Router) { - if mp.FederationRouter == nil { - mp.FederationRouter = router.PathPrefix("/_matrix/federation").Subrouter() +func (mp *MediaProxy) RegisterRoutes(router *http.ServeMux) { + errorBodies := exhttp.ErrorBodies{ + NotFound: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint"))), + MethodNotAllowed: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint"))), } - if mp.ClientMediaRouter == nil { - mp.ClientMediaRouter = router.PathPrefix("/_matrix/client/v1/media").Subrouter() - } - - mp.FederationRouter.HandleFunc("/v1/media/download/{mediaID}", mp.DownloadMediaFederation).Methods(http.MethodGet) - mp.FederationRouter.HandleFunc("/v1/version", mp.KeyServer.GetServerVersion).Methods(http.MethodGet) - mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet) - mp.ClientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}/{fileName}", mp.DownloadMedia).Methods(http.MethodGet) - mp.ClientMediaRouter.HandleFunc("/thumbnail/{serverName}/{mediaID}", mp.DownloadMedia).Methods(http.MethodGet) - mp.ClientMediaRouter.HandleFunc("/upload/{serverName}/{mediaID}", mp.UploadNotSupported).Methods(http.MethodPut) - mp.ClientMediaRouter.HandleFunc("/upload", mp.UploadNotSupported).Methods(http.MethodPost) - mp.ClientMediaRouter.HandleFunc("/create", mp.UploadNotSupported).Methods(http.MethodPost) - mp.ClientMediaRouter.HandleFunc("/config", mp.UploadNotSupported).Methods(http.MethodGet) - mp.ClientMediaRouter.HandleFunc("/preview_url", mp.PreviewURLNotSupported).Methods(http.MethodGet) - mp.FederationRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) - mp.FederationRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod) - mp.ClientMediaRouter.NotFoundHandler = http.HandlerFunc(mp.UnknownEndpoint) - mp.ClientMediaRouter.MethodNotAllowedHandler = http.HandlerFunc(mp.UnsupportedMethod) - corsMiddleware := func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "X-Requested-With, Content-Type, Authorization") - w.Header().Set("Content-Security-Policy", "sandbox; default-src 'none'; script-src 'none'; plugin-types application/pdf; style-src 'unsafe-inline'; object-src 'self';") - next.ServeHTTP(w, r) - }) - } - mp.ClientMediaRouter.Use(corsMiddleware) + router.Handle("/_matrix/federation", exhttp.ApplyMiddleware( + mp.FederationRouter, + exhttp.HandleErrors(errorBodies), + )) + router.Handle("/_matrix/client/v1/media", exhttp.ApplyMiddleware( + mp.ClientMediaRouter, + exhttp.CORSMiddleware, + exhttp.HandleErrors(errorBodies), + )) mp.KeyServer.Register(router) } @@ -234,7 +231,7 @@ func queryToMap(vals url.Values) map[string]string { } func (mp *MediaProxy) getMedia(w http.ResponseWriter, r *http.Request) GetMediaResponse { - mediaID := mux.Vars(r)["mediaID"] + mediaID := r.PathValue("mediaID") if !id.IsValidMediaID(mediaID) { mautrix.MNotFound.WithMessage("Media ID %q is not valid", mediaID).Write(w) return nil @@ -380,8 +377,7 @@ func (mp *MediaProxy) addHeaders(w http.ResponseWriter, mimeType, fileName strin func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { ctx := r.Context() log := zerolog.Ctx(ctx) - vars := mux.Vars(r) - if vars["serverName"] != mp.serverName { + if r.PathValue("serverName") != mp.serverName { mautrix.MNotFound.WithMessage("This is a media proxy at %q, other media downloads are not available here", mp.serverName).Write(w) return } @@ -404,7 +400,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusTemporaryRedirect) } else if fileResp, ok := resp.(*GetMediaResponseFile); ok { responseStarted, err := doTempFileDownload(fileResp, func(wt io.WriterTo, size int64, mimeType string) error { - mp.addHeaders(w, mimeType, vars["fileName"]) + mp.addHeaders(w, mimeType, r.PathValue("fileName")) w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) w.WriteHeader(http.StatusOK) _, err := wt.WriteTo(w) @@ -425,7 +421,7 @@ func (mp *MediaProxy) DownloadMedia(w http.ResponseWriter, r *http.Request) { if dataResp, ok := writerResp.(*GetMediaResponseData); ok { defer dataResp.Reader.Close() } - mp.addHeaders(w, writerResp.GetContentType(), vars["fileName"]) + mp.addHeaders(w, writerResp.GetContentType(), r.PathValue("fileName")) if writerResp.GetContentLength() != 0 { w.Header().Set("Content-Length", strconv.FormatInt(writerResp.GetContentLength(), 10)) } @@ -491,11 +487,6 @@ var ( ErrPreviewURLNotSupported = mautrix.MUnrecognized. WithMessage("This is a media proxy and does not support URL previews."). WithStatus(http.StatusNotImplemented) - ErrUnknownEndpoint = mautrix.MUnrecognized. - WithMessage("Unrecognized endpoint") - ErrUnsupportedMethod = mautrix.MUnrecognized. - WithMessage("Invalid method for endpoint"). - WithStatus(http.StatusMethodNotAllowed) ) func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) { @@ -505,11 +496,3 @@ func (mp *MediaProxy) UploadNotSupported(w http.ResponseWriter, r *http.Request) func (mp *MediaProxy) PreviewURLNotSupported(w http.ResponseWriter, r *http.Request) { ErrPreviewURLNotSupported.Write(w) } - -func (mp *MediaProxy) UnknownEndpoint(w http.ResponseWriter, r *http.Request) { - ErrUnknownEndpoint.Write(w) -} - -func (mp *MediaProxy) UnsupportedMethod(w http.ResponseWriter, r *http.Request) { - ErrUnsupportedMethod.Write(w) -} From 62c03d093a13c78e18b4e5886d0941247b982a3a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 23 Jul 2025 22:58:54 +0300 Subject: [PATCH 1257/1647] bridgev2/status: take context and http client in checkpoint SendHTTP --- bridgev2/matrix/connector.go | 6 +++--- bridgev2/matrix/matrix.go | 2 +- bridgev2/status/messagecheckpoint.go | 9 ++++++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index af9931b0..0a859e42 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -435,7 +435,7 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 log := zerolog.Ctx(ctx) if !evt.IsSourceEventDoublePuppeted { - err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{ms.ToCheckpoint(evt)}) + err := br.SendMessageCheckpoints(ctx, []*status.MessageCheckpoint{ms.ToCheckpoint(evt)}) if err != nil { log.Err(err).Msg("Failed to send message checkpoint") } @@ -480,7 +480,7 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 return "" } -func (br *Connector) SendMessageCheckpoints(checkpoints []*status.MessageCheckpoint) error { +func (br *Connector) SendMessageCheckpoints(ctx context.Context, checkpoints []*status.MessageCheckpoint) error { checkpointsJSON := status.CheckpointsJSON{Checkpoints: checkpoints} if br.Websocket { @@ -495,7 +495,7 @@ func (br *Connector) SendMessageCheckpoints(checkpoints []*status.MessageCheckpo return nil } - return checkpointsJSON.SendHTTP(endpoint, br.AS.Registration.AppToken) + return checkpointsJSON.SendHTTP(ctx, br.AS.HTTPClient, endpoint, br.AS.Registration.AppToken) } func (br *Connector) ParseGhostMXID(userID id.UserID) (networkid.UserID, bool) { diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index fed9d37a..49c377db 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -142,7 +142,7 @@ type CommandProcessor interface { } func (br *Connector) sendSuccessCheckpoint(ctx context.Context, evt *event.Event, step status.MessageCheckpointStep, retryNum int) { - err := br.SendMessageCheckpoints([]*status.MessageCheckpoint{{ + err := br.SendMessageCheckpoints(ctx, []*status.MessageCheckpoint{{ RoomID: evt.RoomID, EventID: evt.ID, EventType: evt.Type, diff --git a/bridgev2/status/messagecheckpoint.go b/bridgev2/status/messagecheckpoint.go index ea859b84..b3c05f4f 100644 --- a/bridgev2/status/messagecheckpoint.go +++ b/bridgev2/status/messagecheckpoint.go @@ -169,13 +169,13 @@ type CheckpointsJSON struct { Checkpoints []*MessageCheckpoint `json:"checkpoints"` } -func (cj *CheckpointsJSON) SendHTTP(endpoint string, token string) error { +func (cj *CheckpointsJSON) SendHTTP(ctx context.Context, cli *http.Client, endpoint string, token string) error { var body bytes.Buffer if err := json.NewEncoder(&body).Encode(cj); err != nil { return fmt.Errorf("failed to encode message checkpoint JSON: %w", err) } - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, &body) if err != nil { @@ -186,7 +186,10 @@ func (cj *CheckpointsJSON) SendHTTP(endpoint string, token string) error { req.Header.Set("User-Agent", mautrix.DefaultUserAgent+" (checkpoint sender)") req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) + if cli == nil { + cli = http.DefaultClient + } + resp, err := cli.Do(req) if err != nil { return mautrix.HTTPError{ Request: req, From 83b4b71a167c1f871f8f8da36b5d2338e92db983 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 23 Jul 2025 22:59:10 +0300 Subject: [PATCH 1258/1647] appservice/websocket: switch from gorilla to coder --- appservice/appservice.go | 3 +- appservice/websocket.go | 97 +++++++++++++++++++----------------- bridgev2/matrix/connector.go | 4 +- bridgev2/matrix/websocket.go | 2 +- go.mod | 2 +- go.sum | 4 +- 6 files changed, 58 insertions(+), 54 deletions(-) diff --git a/appservice/appservice.go b/appservice/appservice.go index 5dd067c0..b0af02cd 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -19,7 +19,7 @@ import ( "syscall" "time" - "github.com/gorilla/websocket" + "github.com/coder/websocket" "github.com/rs/zerolog" "golang.org/x/net/publicsuffix" "gopkg.in/yaml.v3" @@ -178,7 +178,6 @@ type AppService struct { intentsLock sync.RWMutex ws *websocket.Conn - wsWriteLock sync.Mutex StopWebsocket func(error) websocketHandlers map[string]WebsocketHandler websocketHandlersLock sync.RWMutex diff --git a/appservice/websocket.go b/appservice/websocket.go index 3d5bd232..62f4370c 100644 --- a/appservice/websocket.go +++ b/appservice/websocket.go @@ -17,9 +17,8 @@ import ( "strings" "sync" "sync/atomic" - "time" - "github.com/gorilla/websocket" + "github.com/coder/websocket" "github.com/rs/zerolog" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -28,11 +27,9 @@ import ( ) type WebsocketRequest struct { - ReqID int `json:"id,omitempty"` - Command string `json:"command"` - Data interface{} `json:"data"` - - Deadline time.Duration `json:"-"` + ReqID int `json:"id,omitempty"` + Command string `json:"command"` + Data any `json:"data"` } type WebsocketCommand struct { @@ -43,7 +40,7 @@ type WebsocketCommand struct { Ctx context.Context `json:"-"` } -func (wsc *WebsocketCommand) MakeResponse(ok bool, data interface{}) *WebsocketRequest { +func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest { if wsc.ReqID == 0 || wsc.Command == "response" || wsc.Command == "error" { return nil } @@ -100,8 +97,8 @@ type WebsocketMessage struct { } const ( - WebsocketCloseConnReplaced = 4001 - WebsocketCloseTxnNotAcknowledged = 4002 + WebsocketCloseConnReplaced websocket.StatusCode = 4001 + WebsocketCloseTxnNotAcknowledged websocket.StatusCode = 4002 ) type MeowWebsocketCloseCode string @@ -135,7 +132,7 @@ func (mwcc MeowWebsocketCloseCode) String() string { } type CloseCommand struct { - Code int `json:"-"` + Code websocket.StatusCode `json:"-"` Command string `json:"command"` Status MeowWebsocketCloseCode `json:"status"` } @@ -145,15 +142,15 @@ func (cc CloseCommand) Error() string { } func parseCloseError(err error) error { - closeError := &websocket.CloseError{} + var closeError websocket.CloseError if !errors.As(err, &closeError) { return err } var closeCommand CloseCommand closeCommand.Code = closeError.Code closeCommand.Command = "disconnect" - if len(closeError.Text) > 0 { - jsonErr := json.Unmarshal([]byte(closeError.Text), &closeCommand) + if len(closeError.Reason) > 0 { + jsonErr := json.Unmarshal([]byte(closeError.Reason), &closeCommand) if jsonErr != nil { return err } @@ -161,7 +158,7 @@ func parseCloseError(err error) error { if len(closeCommand.Status) == 0 { if closeCommand.Code == WebsocketCloseConnReplaced { closeCommand.Status = MeowConnectionReplaced - } else if closeCommand.Code == websocket.CloseServiceRestart { + } else if closeCommand.Code == websocket.StatusServiceRestart { closeCommand.Status = MeowServerShuttingDown } } @@ -172,20 +169,22 @@ func (as *AppService) HasWebsocket() bool { return as.ws != nil } -func (as *AppService) SendWebsocket(cmd *WebsocketRequest) error { +func (as *AppService) SendWebsocket(ctx context.Context, cmd *WebsocketRequest) error { ws := as.ws if cmd == nil { return nil } else if ws == nil { return ErrWebsocketNotConnected } - as.wsWriteLock.Lock() - defer as.wsWriteLock.Unlock() - if cmd.Deadline == 0 { - cmd.Deadline = 3 * time.Minute + wr, err := ws.Writer(ctx, websocket.MessageText) + if err != nil { + return err } - _ = ws.SetWriteDeadline(time.Now().Add(cmd.Deadline)) - return ws.WriteJSON(cmd) + err = json.NewEncoder(wr).Encode(cmd) + if err != nil { + return err + } + return nil } func (as *AppService) clearWebsocketResponseWaiters() { @@ -222,12 +221,12 @@ func (er *ErrorResponse) Error() string { return fmt.Sprintf("%s: %s", er.Code, er.Message) } -func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response interface{}) error { +func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response any) error { cmd.ReqID = int(atomic.AddInt32(&as.websocketRequestID, 1)) respChan := make(chan *WebsocketCommand, 1) as.addWebsocketResponseWaiter(cmd.ReqID, respChan) defer as.removeWebsocketResponseWaiter(cmd.ReqID, respChan) - err := as.SendWebsocket(cmd) + err := as.SendWebsocket(ctx, cmd) if err != nil { return err } @@ -256,7 +255,7 @@ func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketReques } } -func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, interface{}) { +func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, any) { zerolog.Ctx(cmd.Ctx).Warn().Msg("No handler for websocket command") return false, fmt.Errorf("unknown request type") } @@ -280,14 +279,22 @@ func (as *AppService) defaultHandleWebsocketTransaction(ctx context.Context, msg return true, &WebsocketTransactionResponse{TxnID: msg.TxnID} } -func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) { +func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error), ws *websocket.Conn) { defer stopFunc(ErrWebsocketUnknownError) - ctx := context.Background() for { - var msg WebsocketMessage - err := ws.ReadJSON(&msg) + msgType, reader, err := ws.Reader(ctx) if err != nil { - as.Log.Debug().Err(err).Msg("Error reading from websocket") + as.Log.Debug().Err(err).Msg("Error getting reader from websocket") + stopFunc(parseCloseError(err)) + return + } else if msgType != websocket.MessageText { + as.Log.Debug().Msg("Ignoring non-text message from websocket") + continue + } + var msg WebsocketMessage + err = json.NewDecoder(reader).Decode(&msg) + if err != nil { + as.Log.Debug().Err(err).Msg("Error reading JSON from websocket") stopFunc(parseCloseError(err)) return } @@ -298,11 +305,11 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) with = with.Str("transaction_id", msg.TxnID) } log := with.Logger() - ctx = log.WithContext(ctx) + ctx := log.WithContext(ctx) if msg.Command == "" || msg.Command == "transaction" { ok, resp := as.WebsocketTransactionHandler(ctx, msg) go func() { - err := as.SendWebsocket(msg.MakeResponse(ok, resp)) + err := as.SendWebsocket(ctx, msg.MakeResponse(ok, resp)) if err != nil { log.Warn().Err(err).Msg("Failed to send response to websocket transaction") } else { @@ -334,7 +341,7 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) } go func() { okResp, data := handler(msg.WebsocketCommand) - err := as.SendWebsocket(msg.MakeResponse(okResp, data)) + err := as.SendWebsocket(ctx, msg.MakeResponse(okResp, data)) if err != nil { log.Error().Err(err).Msg("Failed to send response to websocket command") } else if okResp { @@ -347,7 +354,7 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) } } -func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { +func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConnect func()) error { var parsed *url.URL if baseURL != "" { var err error @@ -365,12 +372,15 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { } else if parsed.Scheme == "https" { parsed.Scheme = "wss" } - ws, resp, err := websocket.DefaultDialer.Dial(parsed.String(), http.Header{ - "Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)}, - "User-Agent": []string{as.BotClient().UserAgent}, + ws, resp, err := websocket.Dial(ctx, parsed.String(), &websocket.DialOptions{ + HTTPClient: as.HTTPClient, + HTTPHeader: http.Header{ + "Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)}, + "User-Agent": []string{as.BotClient().UserAgent}, - "X-Mautrix-Process-ID": []string{as.ProcessID}, - "X-Mautrix-Websocket-Version": []string{"3"}, + "X-Mautrix-Process-ID": []string{as.ProcessID}, + "X-Mautrix-Websocket-Version": []string{"3"}, + }, }) if resp != nil && resp.StatusCode >= 400 { var errResp mautrix.RespError @@ -406,7 +416,7 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { as.PrepareWebsocket() as.Log.Debug().Msg("Appservice transaction websocket opened") - go as.consumeWebsocket(stopFunc, ws) + go as.consumeWebsocket(ctx, stopFunc, ws) var onConnectDone atomic.Bool if onConnect != nil { @@ -428,12 +438,7 @@ func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { as.ws = nil } - _ = ws.SetWriteDeadline(time.Now().Add(3 * time.Second)) - err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, "")) - if err != nil && !errors.Is(err, websocket.ErrCloseSent) { - as.Log.Warn().Err(err).Msg("Error writing close message to websocket") - } - err = ws.Close() + err = ws.Close(websocket.StatusGoingAway, "") if err != nil { as.Log.Warn().Err(err).Msg("Error closing websocket") } diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 0a859e42..158148f3 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -413,7 +413,7 @@ func (br *Connector) GhostIntent(userID networkid.UserID) bridgev2.MatrixAPI { func (br *Connector) SendBridgeStatus(ctx context.Context, state *status.BridgeState) error { if br.Websocket { br.hasSentAnyStates = true - return br.AS.SendWebsocket(&appservice.WebsocketRequest{ + return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{ Command: "bridge_status", Data: state, }) @@ -484,7 +484,7 @@ func (br *Connector) SendMessageCheckpoints(ctx context.Context, checkpoints []* checkpointsJSON := status.CheckpointsJSON{Checkpoints: checkpoints} if br.Websocket { - return br.AS.SendWebsocket(&appservice.WebsocketRequest{ + return br.AS.SendWebsocket(ctx, &appservice.WebsocketRequest{ Command: "message_checkpoint", Data: checkpointsJSON, }) diff --git a/bridgev2/matrix/websocket.go b/bridgev2/matrix/websocket.go index c679f960..b498cacd 100644 --- a/bridgev2/matrix/websocket.go +++ b/bridgev2/matrix/websocket.go @@ -57,7 +57,7 @@ func (br *Connector) startWebsocket(wg *sync.WaitGroup) { addr = br.Config.Homeserver.Address } for { - err := br.AS.StartWebsocket(addr, onConnect) + err := br.AS.StartWebsocket(br.Bridge.BackgroundCtx, addr, onConnect) if errors.Is(err, appservice.ErrWebsocketManualStop) { return } else if closeCommand := (&appservice.CloseCommand{}); errors.As(err, &closeCommand) && closeCommand.Status == appservice.MeowConnectionReplaced { diff --git a/go.mod b/go.mod index d71e86ab..1133313f 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ toolchain go1.24.5 require ( filippo.io/edwards25519 v1.1.0 github.com/chzyer/readline v1.5.1 - github.com/gorilla/websocket v1.5.0 + github.com/coder/websocket v1.8.13 github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.28 github.com/rs/xid v1.6.0 diff --git a/go.sum b/go.sum index eaa97cc8..461ee542 100644 --- a/go.sum +++ b/go.sum @@ -8,13 +8,13 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= +github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= +github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= -github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= From 74ab3b118e10006ed3afcab3a34c08b6a3a71fa8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 28 Jul 2025 15:53:17 +0300 Subject: [PATCH 1259/1647] bridgev2/portal: add todo --- bridgev2/portal.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 5fea134f..7301c8ad 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2073,6 +2073,7 @@ func (portal *Portal) ensureFunctionalMember(ctx context.Context, ghost *Ghost) } } } + // TODO what about non-double-puppeted user ghosts? functionalMembers.Add(portal.Bridge.Bot.GetMXID()) if functionalMembers.Add(ghost.Intent.GetMXID()) { _, err := portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{ From ae2c07fb863a69c22e3ed66c89dbb35b39f20e2a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 28 Jul 2025 17:34:28 +0300 Subject: [PATCH 1260/1647] appservice/websocket: close writer after sending --- appservice/websocket.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/appservice/websocket.go b/appservice/websocket.go index 62f4370c..18768098 100644 --- a/appservice/websocket.go +++ b/appservice/websocket.go @@ -182,9 +182,10 @@ func (as *AppService) SendWebsocket(ctx context.Context, cmd *WebsocketRequest) } err = json.NewEncoder(wr).Encode(cmd) if err != nil { + _ = wr.Close() return err } - return nil + return wr.Close() } func (as *AppService) clearWebsocketResponseWaiters() { From 2e7ff3fedd4c3fb89dad8bddceb8e10846c2cef6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 28 Jul 2025 22:03:43 +0300 Subject: [PATCH 1261/1647] all: fix trailing slash in subrouters --- bridgev2/matrix/provisioning.go | 6 ++++-- federation/keyserver.go | 3 ++- mediaproxy/mediaproxy.go | 6 ++++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 7f4b8a2e..6b594deb 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -142,15 +142,17 @@ func (prov *ProvisioningAPI) Init() { debugRouter.HandleFunc("GET /pprof/symbol", pprof.Symbol) debugRouter.HandleFunc("GET /pprof/trace", pprof.Trace) debugRouter.HandleFunc("/pprof/", pprof.Index) - prov.br.AS.Router.Handle("/debug", exhttp.ApplyMiddleware( + prov.br.AS.Router.Handle("/debug/", exhttp.ApplyMiddleware( debugRouter, + exhttp.StripPrefix("/debug"), hlog.NewHandler(prov.br.Log.With().Str("component", "debug api").Logger()), prov.DebugAuthMiddleware, )) } - prov.br.AS.Router.Handle("/_matrix/provision", exhttp.ApplyMiddleware( + prov.br.AS.Router.Handle("/_matrix/provision/", exhttp.ApplyMiddleware( prov.Router, + exhttp.StripPrefix("/_matrix/provision"), hlog.NewHandler(prov.log), hlog.RequestIDHandler("request_id", "Request-Id"), exhttp.CORSMiddleware, diff --git a/federation/keyserver.go b/federation/keyserver.go index 37998786..35ec59fd 100644 --- a/federation/keyserver.go +++ b/federation/keyserver.go @@ -62,8 +62,9 @@ func (ks *KeyServer) Register(r *http.ServeMux) { NotFound: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint"))), MethodNotAllowed: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint"))), } - r.Handle("/_matrix/key", exhttp.ApplyMiddleware( + r.Handle("/_matrix/key/", exhttp.ApplyMiddleware( keyRouter, + exhttp.StripPrefix("/_matrix/key"), exhttp.HandleErrors(errorBodies), )) } diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index 6fbcdbad..a5f07afa 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -208,12 +208,14 @@ func (mp *MediaProxy) RegisterRoutes(router *http.ServeMux) { NotFound: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint"))), MethodNotAllowed: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint"))), } - router.Handle("/_matrix/federation", exhttp.ApplyMiddleware( + router.Handle("/_matrix/federation/", exhttp.ApplyMiddleware( mp.FederationRouter, + exhttp.StripPrefix("/_matrix/federation"), exhttp.HandleErrors(errorBodies), )) - router.Handle("/_matrix/client/v1/media", exhttp.ApplyMiddleware( + router.Handle("/_matrix/client/v1/media/", exhttp.ApplyMiddleware( mp.ClientMediaRouter, + exhttp.StripPrefix("/_matrix/client/v1/media"), exhttp.CORSMiddleware, exhttp.HandleErrors(errorBodies), )) From f1da44490c55eea5856e4dc40e6dbb74ecfe4627 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 29 Jul 2025 16:15:16 +0300 Subject: [PATCH 1262/1647] bridgev2/provisioning: move login step checks into handler --- bridgev2/matrix/provisioning.go | 66 ++++++++++++++++----------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 6b594deb..e3ec21dd 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -258,38 +258,6 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { ctx := context.WithValue(r.Context(), ProvisioningKeyRequest, r) ctx = context.WithValue(ctx, provisioningUserKey, user) - if loginID := r.PathValue("loginProcessID"); loginID != "" { - prov.loginsLock.RLock() - login, ok := prov.logins[loginID] - prov.loginsLock.RUnlock() - if !ok { - zerolog.Ctx(r.Context()).Warn().Str("login_id", loginID).Msg("Login not found") - mautrix.MNotFound.WithMessage("Login not found").Write(w) - return - } - login.Lock.Lock() - // This will only unlock after the handler runs - defer login.Lock.Unlock() - stepID := r.PathValue("stepID") - if login.NextStep.StepID != stepID { - zerolog.Ctx(r.Context()).Warn(). - Str("request_step_id", stepID). - Str("expected_step_id", login.NextStep.StepID). - Msg("Step ID does not match") - mautrix.MBadState.WithMessage("Step ID does not match").Write(w) - return - } - stepType := r.PathValue("stepType") - if login.NextStep.Type != bridgev2.LoginStepType(stepType) { - zerolog.Ctx(r.Context()).Warn(). - Str("request_step_type", stepType). - Str("expected_step_type", string(login.NextStep.Type)). - Msg("Step type does not match") - mautrix.MBadState.WithMessage("Step type does not match").Write(w) - return - } - ctx = context.WithValue(ctx, provisioningLoginProcessKey, login) - } h.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -431,6 +399,38 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov } func (prov *ProvisioningAPI) PostLoginStep(w http.ResponseWriter, r *http.Request) { + loginID := r.PathValue("loginProcessID") + prov.loginsLock.RLock() + login, ok := prov.logins[loginID] + prov.loginsLock.RUnlock() + if !ok { + zerolog.Ctx(r.Context()).Warn().Str("login_id", loginID).Msg("Login not found") + mautrix.MNotFound.WithMessage("Login not found").Write(w) + return + } + login.Lock.Lock() + // This will only unlock after the handler runs + defer login.Lock.Unlock() + stepID := r.PathValue("stepID") + if login.NextStep.StepID != stepID { + zerolog.Ctx(r.Context()).Warn(). + Str("request_step_id", stepID). + Str("expected_step_id", login.NextStep.StepID). + Msg("Step ID does not match") + mautrix.MBadState.WithMessage("Step ID does not match").Write(w) + return + } + stepType := r.PathValue("stepType") + if login.NextStep.Type != bridgev2.LoginStepType(stepType) { + zerolog.Ctx(r.Context()).Warn(). + Str("request_step_type", stepType). + Str("expected_step_type", string(login.NextStep.Type)). + Msg("Step type does not match") + mautrix.MBadState.WithMessage("Step type does not match").Write(w) + return + } + ctx := context.WithValue(r.Context(), provisioningLoginProcessKey, login) + r = r.WithContext(ctx) switch bridgev2.LoginStepType(r.PathValue("stepType")) { case bridgev2.LoginStepTypeUserInput, bridgev2.LoginStepTypeCookies: prov.PostLoginSubmitInput(w, r) @@ -439,7 +439,7 @@ func (prov *ProvisioningAPI) PostLoginStep(w http.ResponseWriter, r *http.Reques case bridgev2.LoginStepTypeComplete: fallthrough default: - // This is probably impossible because AuthMiddleware checks that the next step type matches the request. + // This is probably impossible because of the above check that the next step type matches the request. mautrix.MUnrecognized.WithMessage("Invalid step type %q", r.PathValue("stepType")).Write(w) } } From 26e66f293e6a25d1167d8dc60bb7e9efcfe69d37 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 29 Jul 2025 16:15:36 +0300 Subject: [PATCH 1263/1647] bridgev2/portal: return event ignored result for type unknown --- bridgev2/portal.go | 1 + 1 file changed, 1 insertion(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 7301c8ad..2f973aae 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2008,6 +2008,7 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, switch evtType { case RemoteEventUnknown: log.Debug().Msg("Ignoring remote event with type unknown") + res = EventHandlingResultIgnored case RemoteEventMessage, RemoteEventMessageUpsert: res = portal.handleRemoteMessage(ctx, source, evt.(RemoteMessage)) case RemoteEventEdit: From b4c7abd62b509ba84c72c02a63c926c578a42e81 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 29 Jul 2025 17:10:50 +0300 Subject: [PATCH 1264/1647] bridgev2,federation,mediaproxy: enable http access logging --- bridgev2/matrix/directmedia.go | 2 +- bridgev2/matrix/provisioning.go | 8 ++++++++ federation/keyserver.go | 13 ++++++++++--- mediaproxy/mediaproxy.go | 20 ++++++++++++++------ 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/bridgev2/matrix/directmedia.go b/bridgev2/matrix/directmedia.go index 71c01078..0667981a 100644 --- a/bridgev2/matrix/directmedia.go +++ b/bridgev2/matrix/directmedia.go @@ -39,7 +39,7 @@ func (br *Connector) initDirectMedia() error { if err != nil { return fmt.Errorf("failed to initialize media proxy: %w", err) } - br.MediaProxy.RegisterRoutes(br.AS.Router) + br.MediaProxy.RegisterRoutes(br.AS.Router, br.Log.With().Str("component", "media proxy").Logger()) br.dmaSigKey = sha256.Sum256(br.MediaProxy.GetServerKey().Priv.Seed()) dmn.SetUseDirectMedia() br.Log.Debug().Str("server_name", br.MediaProxy.GetServerName()).Msg("Enabled direct media access") diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index e3ec21dd..df3e1bdf 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -20,9 +20,11 @@ import ( "github.com/rs/xid" "github.com/rs/zerolog" "github.com/rs/zerolog/hlog" + "go.mau.fi/util/exerrors" "go.mau.fi/util/exhttp" "go.mau.fi/util/exstrings" "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" "go.mau.fi/util/requestlog" "maunium.net/go/mautrix" @@ -146,10 +148,15 @@ func (prov *ProvisioningAPI) Init() { debugRouter, exhttp.StripPrefix("/debug"), hlog.NewHandler(prov.br.Log.With().Str("component", "debug api").Logger()), + requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), prov.DebugAuthMiddleware, )) } + errorBodies := exhttp.ErrorBodies{ + NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()), + MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()), + } prov.br.AS.Router.Handle("/_matrix/provision/", exhttp.ApplyMiddleware( prov.Router, exhttp.StripPrefix("/_matrix/provision"), @@ -157,6 +164,7 @@ func (prov *ProvisioningAPI) Init() { hlog.RequestIDHandler("request_id", "Request-Id"), exhttp.CORSMiddleware, requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), + exhttp.HandleErrors(errorBodies), prov.AuthMiddleware, )) } diff --git a/federation/keyserver.go b/federation/keyserver.go index 35ec59fd..d32ba5cf 100644 --- a/federation/keyserver.go +++ b/federation/keyserver.go @@ -12,9 +12,13 @@ import ( "strconv" "time" + "github.com/rs/zerolog" + "github.com/rs/zerolog/hlog" "go.mau.fi/util/exerrors" "go.mau.fi/util/exhttp" "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" + "go.mau.fi/util/requestlog" "maunium.net/go/mautrix" "maunium.net/go/mautrix/id" @@ -51,7 +55,7 @@ type KeyServer struct { } // Register registers the key server endpoints to the given router. -func (ks *KeyServer) Register(r *http.ServeMux) { +func (ks *KeyServer) Register(r *http.ServeMux, log zerolog.Logger) { r.HandleFunc("GET /.well-known/matrix/server", ks.GetWellKnown) r.HandleFunc("GET /_matrix/federation/v1/version", ks.GetServerVersion) keyRouter := http.NewServeMux() @@ -59,12 +63,15 @@ func (ks *KeyServer) Register(r *http.ServeMux) { keyRouter.HandleFunc("GET /v2/query/{serverName}", ks.GetQueryKeys) keyRouter.HandleFunc("POST /v2/query", ks.PostQueryKeys) errorBodies := exhttp.ErrorBodies{ - NotFound: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint"))), - MethodNotAllowed: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint"))), + NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()), + MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()), } r.Handle("/_matrix/key/", exhttp.ApplyMiddleware( keyRouter, exhttp.StripPrefix("/_matrix/key"), + hlog.NewHandler(log), + hlog.RequestIDHandler("request_id", "Request-Id"), + requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), exhttp.HandleErrors(errorBodies), )) } diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index a5f07afa..07e30810 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -8,7 +8,6 @@ package mediaproxy import ( "context" - "encoding/json" "errors" "fmt" "io" @@ -23,8 +22,11 @@ import ( "time" "github.com/rs/zerolog" + "github.com/rs/zerolog/hlog" "go.mau.fi/util/exerrors" "go.mau.fi/util/exhttp" + "go.mau.fi/util/ptr" + "go.mau.fi/util/requestlog" "maunium.net/go/mautrix" "maunium.net/go/mautrix/federation" @@ -178,7 +180,7 @@ type ServerConfig struct { func (mp *MediaProxy) Listen(cfg ServerConfig) error { router := http.NewServeMux() - mp.RegisterRoutes(router) + mp.RegisterRoutes(router, zerolog.Nop()) return http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Hostname, cfg.Port), router) } @@ -203,23 +205,29 @@ func (mp *MediaProxy) EnableServerAuth(client *federation.Client, keyCache feder }) } -func (mp *MediaProxy) RegisterRoutes(router *http.ServeMux) { +func (mp *MediaProxy) RegisterRoutes(router *http.ServeMux, log zerolog.Logger) { errorBodies := exhttp.ErrorBodies{ - NotFound: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint"))), - MethodNotAllowed: exerrors.Must(json.Marshal(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint"))), + NotFound: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Unrecognized endpoint")).MarshalJSON()), + MethodNotAllowed: exerrors.Must(ptr.Ptr(mautrix.MUnrecognized.WithMessage("Invalid method for endpoint")).MarshalJSON()), } router.Handle("/_matrix/federation/", exhttp.ApplyMiddleware( mp.FederationRouter, exhttp.StripPrefix("/_matrix/federation"), + hlog.NewHandler(log), + hlog.RequestIDHandler("request_id", "Request-Id"), + requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), exhttp.HandleErrors(errorBodies), )) router.Handle("/_matrix/client/v1/media/", exhttp.ApplyMiddleware( mp.ClientMediaRouter, exhttp.StripPrefix("/_matrix/client/v1/media"), + hlog.NewHandler(log), + hlog.RequestIDHandler("request_id", "Request-Id"), exhttp.CORSMiddleware, + requestlog.AccessLogger(requestlog.Options{TrustXForwardedFor: true}), exhttp.HandleErrors(errorBodies), )) - mp.KeyServer.Register(router) + mp.KeyServer.Register(router, log) } var ErrInvalidMediaIDSyntax = errors.New("invalid media ID syntax") From 7bd136196d9dddd3432cc8b438a2efd3a63c723e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 29 Jul 2025 17:22:17 +0300 Subject: [PATCH 1265/1647] format/htmlparser: don't add link suffix if plaintext is only missing protocol Auto-linkification will add a protocol in the `href`, but usually won't touch the text part. We want to undo the linkification here since it doesn't carry any additional information. --- format/htmlparser.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/format/htmlparser.go b/format/htmlparser.go index f9d51e39..b4b1b9a4 100644 --- a/format/htmlparser.go +++ b/format/htmlparser.go @@ -286,7 +286,10 @@ func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) string { } if parser.LinkConverter != nil { return parser.LinkConverter(str, href, ctx) - } else if str == href { + } else if str == href || + str == strings.TrimPrefix(href, "mailto:") || + str == strings.TrimPrefix(href, "http://") || + str == strings.TrimPrefix(href, "https://") { return str } return fmt.Sprintf("%s (%s)", str, href) From 3a2815178038567ba1b2ab937c626b9a619950a6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 29 Jul 2025 17:41:51 +0300 Subject: [PATCH 1266/1647] client: log method/url when retrying requests --- client.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/client.go b/client.go index 1f907608..f0aa3467 100644 --- a/client.go +++ b/client.go @@ -556,6 +556,8 @@ func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff } } log.Warn().Err(cause). + Str("method", req.Method). + Str("url", req.URL.String()). Int("retry_in_seconds", int(backoff.Seconds())). Msg("Request failed, retrying") select { From bcf92ba0e80a9329ae0e2be071930633d46f2d53 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 29 Jul 2025 17:42:04 +0300 Subject: [PATCH 1267/1647] appservice/intent: don't download avatar before setting on hungry --- appservice/intent.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appservice/intent.go b/appservice/intent.go index a1245d74..fa9d9e7a 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -511,7 +511,7 @@ func (intent *IntentAPI) SetAvatarURL(ctx context.Context, avatarURL id.ContentU // No need to update return nil } - if !avatarURL.IsEmpty() { + if !avatarURL.IsEmpty() && !intent.SpecVersions.Supports(mautrix.BeeperFeatureHungry) { // Some homeservers require the avatar to be downloaded before setting it resp, _ := intent.Download(ctx, avatarURL) if resp != nil { From 91b2bcdb9fb75088e1ffff7bd10d412f9d7d2ea7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 31 Jul 2025 13:01:08 +0300 Subject: [PATCH 1268/1647] bridgev2/matrix: don't send connecting bridge states to cloud --- bridgev2/matrix/connector.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 158148f3..50d493f3 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -418,6 +418,10 @@ func (br *Connector) SendBridgeStatus(ctx context.Context, state *status.BridgeS Data: state, }) } else if br.Config.Homeserver.StatusEndpoint != "" { + // Connecting states aren't really relevant unless the bridge runs somewhere with an unreliable network + if state.StateEvent == status.StateConnecting { + return nil + } return state.SendHTTP(ctx, br.Config.Homeserver.StatusEndpoint, br.Config.AppService.ASToken) } else { return nil From 66e0ed47c0715e6ed3210e15c02f1b9dd41044e8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 31 Jul 2025 13:40:18 +0300 Subject: [PATCH 1269/1647] bridgev2/portal: include error in event handling results --- bridgev2/portal.go | 66 +++++++++++++++++++++++----------------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 2f973aae..5aae45e9 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -695,7 +695,7 @@ func (portal *Portal) handleMatrixReceipts(ctx context.Context, evt *event.Event sender, err := portal.Bridge.GetUserByMXID(ctx, userID) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get user to handle read receipt") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } portal.handleMatrixReadReceipt(ctx, sender, evtID, receipt) } @@ -1746,7 +1746,7 @@ func (portal *Portal) handleMatrixTombstone(ctx context.Context, evt *event.Even senderUser, err = portal.Bridge.GetUserByMXID(ctx, evt.Sender) if err != nil { log.Err(err).Msg("Failed to get tombstone sender user") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } } content, ok := evt.Content.Parsed.(*event.TombstoneEventContent) @@ -1767,14 +1767,14 @@ func (portal *Portal) handleMatrixTombstone(ctx context.Context, evt *event.Even err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true) if err != nil { log.Err(err).Msg("Failed to clean up Matrix room") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } return EventHandlingResultSuccess } existingMemberEvt, err := portal.Bridge.Matrix.GetMemberInfo(ctx, content.ReplacementRoom, portal.Bridge.Bot.GetMXID()) if err != nil { log.Err(err).Msg("Failed to get member info of bot in replacement room") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } leaveOnError := func() { if existingMemberEvt != nil && existingMemberEvt.Membership == event.MembershipJoin { @@ -1805,14 +1805,14 @@ func (portal *Portal) handleMatrixTombstone(ctx context.Context, evt *event.Even err = portal.Bridge.Bot.EnsureJoined(ctx, content.ReplacementRoom, EnsureJoinedParams{Via: via}) if err != nil { log.Err(err).Msg("Failed to join replacement room from tombstone") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } if !sentByBridge && !senderUser.Permissions.Admin { powers, err := portal.Bridge.Matrix.GetPowerLevels(ctx, content.ReplacementRoom) if err != nil { log.Err(err).Msg("Failed to get power levels in replacement room") leaveOnError() - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } if powers.GetUserLevel(evt.Sender) < powers.Invite() { log.Warn().Msg("Tombstone sender doesn't have enough power to invite the bot to the replacement room") @@ -1840,7 +1840,7 @@ func (portal *Portal) handleMatrixTombstone(ctx context.Context, evt *event.Even err = portal.Save(ctx) if err != nil { log.Err(err).Msg("Failed to save portal after tombstone") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } log.Info().Msg("Successfully followed tombstone and updated portal MXID") err = portal.Bridge.DB.UserPortal.MarkAllNotInSpace(ctx, portal.PortalKey) @@ -1989,7 +1989,7 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, err = portal.createMatrixRoomInLoop(ctx, source, info, bundle) if err != nil { log.Err(err).Msg("Failed to create portal to handle event") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } if evtType == RemoteEventChatResync { log.Debug().Msg("Not handling chat resync event further as portal was created by it") @@ -2420,7 +2420,7 @@ func (portal *Portal) handleRemoteUpsert(ctx context.Context, source *UserLogin, err = portal.Bridge.DB.Message.Update(ctx, part) if err != nil { log.Err(err).Str("part_id", string(part.PartID)).Msg("Failed to update message part in database") - handleRes = EventHandlingResultFailed + handleRes = EventHandlingResultFailed.WithError(err) } } } @@ -2484,7 +2484,7 @@ func (portal *Portal) handleRemoteMessage(ctx context.Context, source *UserLogin } else { log.Err(err).Msg("Failed to convert remote message") portal.sendRemoteErrorNotice(ctx, intent, err, ts, "message") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } } _, res = portal.sendConvertedMessage(ctx, evt.GetID(), intent, evt.GetSender().Sender, converted, ts, getStreamOrder(evt), nil) @@ -2529,7 +2529,7 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e existing, err = portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, targetID) if err != nil { log.Err(err).Msg("Failed to get edit target message") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } } if existing == nil { @@ -2554,7 +2554,7 @@ func (portal *Portal) handleRemoteEdit(ctx context.Context, source *UserLogin, e } else if err != nil { log.Err(err).Msg("Failed to convert remote edit") portal.sendRemoteErrorNotice(ctx, intent, err, ts, "edit") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } res := portal.sendConvertedEdit(ctx, existing[0].ID, evt.GetSender().Sender, converted, intent, ts, getStreamOrder(evt)) if portal.currentlyTypingGhosts.Pop(intent.GetMXID()) { @@ -2703,7 +2703,7 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User targetMessage, err := portal.getTargetMessagePart(ctx, evt) if err != nil { log.Err(err).Msg("Failed to get target message for reaction") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } else if targetMessage == nil { // TODO use deterministic event ID as target if applicable? log.Warn().Msg("Target message for reaction not found") @@ -2717,7 +2717,7 @@ func (portal *Portal) handleRemoteReactionSync(ctx context.Context, source *User } if err != nil { log.Err(err).Msg("Failed to get existing reactions for reaction sync") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } existing := make(map[networkid.UserID]map[networkid.EmojiID]*database.Reaction) for _, existingReaction := range existingReactions { @@ -2839,7 +2839,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi targetMessage, err := portal.getTargetMessagePart(ctx, evt) if err != nil { log.Err(err).Msg("Failed to get target message for reaction") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } else if targetMessage == nil { // TODO use deterministic event ID as target if applicable? log.Warn().Msg("Target message for reaction not found") @@ -2849,7 +2849,7 @@ func (portal *Portal) handleRemoteReaction(ctx context.Context, source *UserLogi existingReaction, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, targetMessage.ID, targetMessage.PartID, evt.GetSender().Sender, emojiID) if err != nil { log.Err(err).Msg("Failed to check if reaction is a duplicate") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } else if existingReaction != nil && (emojiID != "" || existingReaction.Emoji == emoji) { log.Debug().Msg("Ignoring duplicate reaction") return EventHandlingResultIgnored @@ -2919,7 +2919,7 @@ func (portal *Portal) sendConvertedReaction( }) if err != nil { logContext(log.Err(err)).Msg("Failed to send reaction to Matrix") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } logContext(log.Debug()). Stringer("event_id", resp.EventID). @@ -2928,7 +2928,7 @@ func (portal *Portal) sendConvertedReaction( err = portal.Bridge.DB.Reaction.Upsert(ctx, dbReaction) if err != nil { logContext(log.Err(err)).Msg("Failed to save reaction to database") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } return EventHandlingResultSuccess } @@ -2954,7 +2954,7 @@ func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *Us targetReaction, err := portal.getTargetReaction(ctx, evt) if err != nil { log.Err(err).Msg("Failed to get target reaction for removal") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } else if targetReaction == nil { log.Warn().Msg("Target reaction not found") return EventHandlingResultIgnored @@ -2978,7 +2978,7 @@ func (portal *Portal) handleRemoteReactionRemove(ctx context.Context, source *Us }, &MatrixSendExtra{Timestamp: ts, ReactionMeta: targetReaction}) if err != nil { log.Err(err).Stringer("reaction_mxid", targetReaction.MXID).Msg("Failed to redact reaction") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } err = portal.Bridge.DB.Reaction.Delete(ctx, targetReaction) if err != nil { @@ -2992,7 +2992,7 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, evt.GetTargetMessage()) if err != nil { log.Err(err).Msg("Failed to get target message for removal") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } else if len(targetParts) == 0 { log.Debug().Msg("Target message not found") return EventHandlingResultIgnored @@ -3003,7 +3003,7 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use logins, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) if err != nil { log.Err(err).Msg("Failed to check if portal has other logins") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } else if len(logins) > 1 { log.Debug().Msg("Ignoring delete for me event in portal with multiple logins") return EventHandlingResultIgnored @@ -3069,7 +3069,7 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL if err != nil { log.Err(err).Str("last_target_id", string(lastTargetID)). Msg("Failed to get last target message for read receipt") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } else if lastTarget == nil { log.Debug().Str("last_target_id", string(lastTargetID)). Msg("Last target message not found") @@ -3088,7 +3088,7 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL if err != nil { log.Err(err).Str("target_id", string(targetID)). Msg("Failed to get target message for read receipt") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } else if target != nil && !target.HasFakeMXID() && (lastTarget == nil || target.Timestamp.After(lastTarget.Timestamp)) { lastTarget = target } @@ -3126,7 +3126,7 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL } if err != nil { addTargetLog(log.Err(err)).Msg("Failed to bridge read receipt") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } else { addTargetLog(log.Debug()).Msg("Bridged read receipt") } @@ -3148,7 +3148,7 @@ func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLo err := dp.MarkUnread(ctx, portal.MXID, evt.GetUnread()) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge mark unread event") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } return EventHandlingResultSuccess } @@ -3166,7 +3166,7 @@ func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *U targetParts, err := portal.Bridge.DB.Message.GetAllPartsByID(ctx, portal.Receiver, target) if err != nil { log.Err(err).Str("target_id", string(target)).Msg("Failed to get target message for delivery receipt") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } else if len(targetParts) == 0 { continue } else if _, sentByGhost := portal.Bridge.Matrix.ParseGhostMXID(targetParts[0].SenderMXID); sentByGhost { @@ -3201,7 +3201,7 @@ func (portal *Portal) handleRemoteTyping(ctx context.Context, source *UserLogin, err := intent.MarkTyping(ctx, portal.MXID, typingType, timeout) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to bridge typing event") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } if timeout == 0 { portal.currentlyTypingGhosts.Remove(intent.GetMXID()) @@ -3215,7 +3215,7 @@ func (portal *Portal) handleRemoteChatInfoChange(ctx context.Context, source *Us info, err := evt.GetChatInfoChange(ctx) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get chat info change") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } portal.ProcessChatInfoChange(ctx, evt.GetSender(), source, info, getEventTS(evt)) return EventHandlingResultSuccess @@ -3259,7 +3259,7 @@ func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLo logins, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) if err != nil { log.Err(err).Msg("Failed to check if portal has other logins") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } var ownUP *database.UserPortal logins = slices.DeleteFunc(logins, func(up *database.UserPortal) bool { @@ -3289,7 +3289,7 @@ func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLo ) if err != nil { log.Err(err).Msg("Failed to send leave state event for user after remote chat delete") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } else { log.Debug().Msg("Sent leave state event for user after remote chat delete") return EventHandlingResultSuccess @@ -3299,12 +3299,12 @@ func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLo err := portal.Delete(ctx) if err != nil { log.Err(err).Msg("Failed to delete portal from database") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false) if err != nil { log.Err(err).Msg("Failed to delete Matrix room") - return EventHandlingResultFailed + return EventHandlingResultFailed.WithError(err) } else { log.Info().Msg("Deleted room after remote chat delete event") return EventHandlingResultSuccess From 94f53c5853c63065369dc4b23ad920124e806dec Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 31 Jul 2025 14:00:00 +0300 Subject: [PATCH 1270/1647] bridgev2/cryptostore: add missing escape clause to not like --- bridgev2/matrix/cryptostore.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/cryptostore.go b/bridgev2/matrix/cryptostore.go index 234797a6..4c3b5d30 100644 --- a/bridgev2/matrix/cryptostore.go +++ b/bridgev2/matrix/cryptostore.go @@ -45,7 +45,7 @@ func (store *SQLCryptoStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, WHERE room_id=$1 AND (membership='join' OR membership='invite') AND user_id<>$2 - AND user_id NOT LIKE $3 + AND user_id NOT LIKE $3 ESCAPE '\' `, roomID, store.UserID, store.GhostIDFormat) if err != nil { return From 10b26b507df823bee68b80c15832eca9fa383e7e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 1 Aug 2025 10:38:02 +0300 Subject: [PATCH 1271/1647] client: fix updating state store in CreateRoom --- client.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/client.go b/client.go index f0aa3467..53ac6e10 100644 --- a/client.go +++ b/client.go @@ -1369,6 +1369,10 @@ func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *Re Msg("Failed to update creator membership in state store after creating room") } for _, evt := range req.InitialState { + evt.RoomID = resp.RoomID + if evt.StateKey == nil { + evt.StateKey = ptr.Ptr("") + } UpdateStateStore(ctx, cli.StateStore, evt) } inviteMembership := event.MembershipInvite From 190c0de94f19989a0169cf1e95cdb9664e36ec7d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 1 Aug 2025 10:51:00 +0300 Subject: [PATCH 1272/1647] bridgev2/matrix: always clear mx_user_profile when deleting room --- bridgev2/matrix/intent.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 4a337e53..7d78b5a2 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -578,7 +578,15 @@ func (as *ASIntent) MarkAsDM(ctx context.Context, roomID id.RoomID, withUser id. func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error { if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) { - return as.Matrix.BeeperDeleteRoom(ctx, roomID) + err := as.Matrix.BeeperDeleteRoom(ctx, roomID) + if err != nil { + return err + } + err = as.Matrix.StateStore.ClearCachedMembers(ctx, roomID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to clear cached members while cleaning up portal") + } + return nil } members, err := as.Matrix.JoinedMembers(ctx, roomID) if err != nil { From 66ec881a741661b98ef8ceb62e3d0bbdabc4e486 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 1 Aug 2025 11:00:37 +0300 Subject: [PATCH 1273/1647] bridgev2/matrix: add hack for resyncing encryption state cache --- bridgev2/database/kvstore.go | 5 ++-- bridgev2/matrix/connector.go | 50 ++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/bridgev2/database/kvstore.go b/bridgev2/database/kvstore.go index 5a1af019..52b4984e 100644 --- a/bridgev2/database/kvstore.go +++ b/bridgev2/database/kvstore.go @@ -20,8 +20,9 @@ import ( type Key string const ( - KeySplitPortalsEnabled Key = "split_portals_enabled" - KeyBridgeInfoVersion Key = "bridge_info_version" + KeySplitPortalsEnabled Key = "split_portals_enabled" + KeyBridgeInfoVersion Key = "bridge_info_version" + KeyEncryptionStateResynced Key = "encryption_state_resynced" ) type KVQuery struct { diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 50d493f3..9c19c472 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -170,6 +170,17 @@ func (br *Connector) Start(ctx context.Context) error { if err != nil { return err } + needsStateResync := br.Config.Encryption.Default && + br.Bridge.DB.KV.Get(ctx, database.KeyEncryptionStateResynced) != "true" + if needsStateResync { + dbExists, err := br.StateStore.TableExists(ctx, "mx_version") + if err != nil { + return fmt.Errorf("failed to check if mx_version table exists: %w", err) + } else if !dbExists { + needsStateResync = false + br.Bridge.DB.KV.Set(ctx, database.KeyEncryptionStateResynced, "true") + } + } err = br.StateStore.Upgrade(ctx) if err != nil { return bridgev2.DBUpgradeError{Section: "matrix_state", Err: err} @@ -213,9 +224,48 @@ func (br *Connector) Start(ctx context.Context) error { br.wsStopPinger = make(chan struct{}, 1) go br.websocketServerPinger() } + if needsStateResync { + br.ResyncEncryptionState(ctx) + } return nil } +func (br *Connector) ResyncEncryptionState(ctx context.Context) { + log := zerolog.Ctx(ctx) + roomIDScanner := dbutil.ConvertRowFn[id.RoomID](dbutil.ScanSingleColumn[id.RoomID]) + rooms, err := roomIDScanner.NewRowIter(br.Bridge.DB.Query(ctx, ` + SELECT rooms.room_id + FROM (SELECT DISTINCT(room_id) FROM mx_user_profile) rooms + LEFT JOIN mx_room_state ON rooms.room_id = mx_room_state.room_id + WHERE mx_room_state.encryption IS NULL + `)).AsList() + if err != nil { + log.Err(err).Msg("Failed to get room list to resync state") + return + } + var failedCount, successCount, forbiddenCount int + for _, roomID := range rooms { + var outContent *event.EncryptionEventContent + err = br.Bot.StateEvent(ctx, roomID, event.StateEncryption, "", &outContent) + if errors.Is(err, mautrix.MForbidden) { + // Most likely non-existent room + log.Debug().Err(err).Stringer("room_id", roomID).Msg("Failed to get state for room") + forbiddenCount++ + } else if err != nil { + log.Err(err).Stringer("room_id", roomID).Msg("Failed to get state for room") + failedCount++ + } else { + successCount++ + } + } + br.Bridge.DB.KV.Set(ctx, database.KeyEncryptionStateResynced, "true") + log.Info(). + Int("success_count", successCount). + Int("forbidden_count", forbiddenCount). + Int("failed_count", failedCount). + Msg("Resynced rooms") +} + func (br *Connector) GetPublicAddress() string { if br.Config.AppService.PublicAddress == "https://bridge.example.com" { return "" From 196164ed6749a91b3ae12c6ce2fd836697afcd38 Mon Sep 17 00:00:00 2001 From: "timedout (aka nexy7574)" Date: Fri, 1 Aug 2025 09:47:53 +0100 Subject: [PATCH 1274/1647] event: add join_authorised_via_users_server to MemberEventContent (#395) Adds `JoinAuthorisedViaUsersServer` (`join_authorised_via_users_server`) to `MemberEventContent`, introduced in room version 8 --- event/member.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/event/member.go b/event/member.go index 02b7cae9..53387e8b 100644 --- a/event/member.go +++ b/event/member.go @@ -35,13 +35,14 @@ const ( // MemberEventContent represents the content of a m.room.member state event. // https://spec.matrix.org/v1.2/client-server-api/#mroommember type MemberEventContent struct { - Membership Membership `json:"membership"` - AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` - Displayname string `json:"displayname,omitempty"` - IsDirect bool `json:"is_direct,omitempty"` - ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"` - Reason string `json:"reason,omitempty"` - MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"` + Membership Membership `json:"membership"` + AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` + Displayname string `json:"displayname,omitempty"` + IsDirect bool `json:"is_direct,omitempty"` + ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"` + Reason string `json:"reason,omitempty"` + JoinAuthorisedViaUsersServer id.UserID `json:"join_authorised_via_users_server,omitempty"` + MSC3414File *EncryptedFileInfo `json:"org.matrix.msc3414.file,omitempty"` MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"` } From 0a804c58a13af89e0bcfd2c998999be31cc9f5d0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 1 Aug 2025 12:15:38 +0300 Subject: [PATCH 1275/1647] bridgev2/matrix: don't ensure joined for state resync --- bridgev2/matrix/connector.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 9c19c472..51eeb42b 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -246,7 +246,7 @@ func (br *Connector) ResyncEncryptionState(ctx context.Context) { var failedCount, successCount, forbiddenCount int for _, roomID := range rooms { var outContent *event.EncryptionEventContent - err = br.Bot.StateEvent(ctx, roomID, event.StateEncryption, "", &outContent) + err = br.Bot.Client.StateEvent(ctx, roomID, event.StateEncryption, "", &outContent) if errors.Is(err, mautrix.MForbidden) { // Most likely non-existent room log.Debug().Err(err).Stringer("room_id", roomID).Msg("Failed to get state for room") From aeeea095495b4c4e9e454ad1d415970d728f572d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 1 Aug 2025 12:19:51 +0300 Subject: [PATCH 1276/1647] sqlstatestore: ensure empty room/user ids aren't stored in db --- bridgev2/matrix/connector.go | 5 +++- sqlstatestore/statestore.go | 27 ++++++++++++++++++++++ sqlstatestore/v00-latest-revision.sql | 2 +- sqlstatestore/v09-clear-empty-room-ids.sql | 3 +++ 4 files changed, 35 insertions(+), 2 deletions(-) create mode 100644 sqlstatestore/v09-clear-empty-room-ids.sql diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 51eeb42b..a1f7d140 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -235,7 +235,7 @@ func (br *Connector) ResyncEncryptionState(ctx context.Context) { roomIDScanner := dbutil.ConvertRowFn[id.RoomID](dbutil.ScanSingleColumn[id.RoomID]) rooms, err := roomIDScanner.NewRowIter(br.Bridge.DB.Query(ctx, ` SELECT rooms.room_id - FROM (SELECT DISTINCT(room_id) FROM mx_user_profile) rooms + FROM (SELECT DISTINCT(room_id) FROM mx_user_profile WHERE room_id<>'') rooms LEFT JOIN mx_room_state ON rooms.room_id = mx_room_state.room_id WHERE mx_room_state.encryption IS NULL `)).AsList() @@ -245,6 +245,9 @@ func (br *Connector) ResyncEncryptionState(ctx context.Context) { } var failedCount, successCount, forbiddenCount int for _, roomID := range rooms { + if roomID == "" { + continue + } var outContent *event.EncryptionEventContent err = br.Bot.Client.StateEvent(ctx, roomID, event.StateEncryption, "", &outContent) if errors.Is(err, mautrix.MForbidden) { diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index f9a7e421..0ed4b698 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -62,6 +62,9 @@ func (store *SQLStateStore) IsRegistered(ctx context.Context, userID id.UserID) } func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID) error { + if userID == "" { + return fmt.Errorf("user ID is empty") + } _, err := store.Exec(ctx, "INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) return err } @@ -182,6 +185,11 @@ func (store *SQLStateStore) IsMembership(ctx context.Context, roomID id.RoomID, } func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } else if userID == "" { + return fmt.Errorf("user ID is empty") + } _, err := store.Exec(ctx, ` INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, '', '') ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership @@ -214,6 +222,11 @@ func (u *userProfileRow) GetMassInsertValues() [5]any { var userProfileMassInserter = dbutil.NewMassInsertBuilder[*userProfileRow, [1]any](insertUserProfileQuery, "($1, $%d, $%d, $%d, $%d, $%d)") func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } else if userID == "" { + return fmt.Errorf("user ID is empty") + } var nameSkeleton []byte if !store.DisableNameDisambiguation && len(member.Displayname) > 0 { nameSkeletonArr := confusable.SkeletonHash(member.Displayname) @@ -235,6 +248,9 @@ func (store *SQLStateStore) IsConfusableName(ctx context.Context, roomID id.Room const userProfileMassInsertBatchSize = 500 func (store *SQLStateStore) ReplaceCachedMembers(ctx context.Context, roomID id.RoomID, evts []*event.Event, onlyMemberships ...event.Membership) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } return store.DoTxn(ctx, nil, func(ctx context.Context) error { err := store.ClearCachedMembers(ctx, roomID, onlyMemberships...) if err != nil { @@ -305,6 +321,9 @@ func (store *SQLStateStore) HasFetchedMembers(ctx context.Context, roomID id.Roo } func (store *SQLStateStore) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } _, err := store.Exec(ctx, ` INSERT INTO mx_room_state (room_id, members_fetched) VALUES ($1, true) ON CONFLICT (room_id) DO UPDATE SET members_fetched=true @@ -334,6 +353,9 @@ func (store *SQLStateStore) GetAllMembers(ctx context.Context, roomID id.RoomID) } func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } contentBytes, err := json.Marshal(content) if err != nil { return fmt.Errorf("failed to marshal content JSON: %w", err) @@ -371,6 +393,9 @@ func (store *SQLStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) ( } func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } _, err := store.Exec(ctx, ` INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2) ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels @@ -421,6 +446,8 @@ func (store *SQLStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, func (store *SQLStateStore) SetCreate(ctx context.Context, evt *event.Event) error { if evt.Type != event.StateCreate { return fmt.Errorf("invalid event type for create event: %s", evt.Type) + } else if evt.RoomID == "" { + return fmt.Errorf("room ID is empty") } _, err := store.Exec(ctx, ` INSERT INTO mx_room_state (room_id, create_event) VALUES ($1, $2) diff --git a/sqlstatestore/v00-latest-revision.sql b/sqlstatestore/v00-latest-revision.sql index 132ed1ab..b5a858ec 100644 --- a/sqlstatestore/v00-latest-revision.sql +++ b/sqlstatestore/v00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v8 (compatible with v3+): Latest revision +-- v0 -> v9 (compatible with v3+): Latest revision CREATE TABLE mx_registrations ( user_id TEXT PRIMARY KEY diff --git a/sqlstatestore/v09-clear-empty-room-ids.sql b/sqlstatestore/v09-clear-empty-room-ids.sql new file mode 100644 index 00000000..ca951068 --- /dev/null +++ b/sqlstatestore/v09-clear-empty-room-ids.sql @@ -0,0 +1,3 @@ +-- v9 (compatible with v3+): Clear invalid rows +DELETE FROM mx_room_state WHERE room_id=''; +DELETE FROM mx_user_profile WHERE room_id='' OR user_id=''; From 09e4706fdba6a6f05900e69a723b1c24fa983cdb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 1 Aug 2025 14:13:55 +0300 Subject: [PATCH 1277/1647] crypto/backup: allow encrypting session without private key --- crypto/backup/encryptedsessiondata.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/crypto/backup/encryptedsessiondata.go b/crypto/backup/encryptedsessiondata.go index ec551dbe..25250178 100644 --- a/crypto/backup/encryptedsessiondata.go +++ b/crypto/backup/encryptedsessiondata.go @@ -68,6 +68,10 @@ func calculateCompatMAC(macKey []byte) []byte { // // [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2 func EncryptSessionData[T any](backupKey *MegolmBackupKey, sessionData T) (*EncryptedSessionData[T], error) { + return EncryptSessionDataWithPubkey(backupKey.PublicKey(), sessionData) +} + +func EncryptSessionDataWithPubkey[T any](pubkey *ecdh.PublicKey, sessionData T) (*EncryptedSessionData[T], error) { sessionJSON, err := json.Marshal(sessionData) if err != nil { return nil, err @@ -78,7 +82,7 @@ func EncryptSessionData[T any](backupKey *MegolmBackupKey, sessionData T) (*Encr return nil, err } - sharedSecret, err := ephemeralKey.ECDH(backupKey.PublicKey()) + sharedSecret, err := ephemeralKey.ECDH(pubkey) if err != nil { return nil, err } From 654b6b1d4574e926fc6534654a2d18446f93b168 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Sat, 2 Aug 2025 11:39:18 -0600 Subject: [PATCH 1278/1647] crypto: replace t.Fatal and t.Error with require and assert Signed-off-by: Sumner Evans --- crypto/aescbc/aes_cbc_test.go | 41 ++--- crypto/canonicaljson/json_test.go | 79 +++++---- crypto/cross_sign_test.go | 72 ++++---- crypto/machine_test.go | 59 ++----- crypto/store_test.go | 273 ++++++++++-------------------- crypto/utils/utils_test.go | 36 ++-- 6 files changed, 210 insertions(+), 350 deletions(-) diff --git a/crypto/aescbc/aes_cbc_test.go b/crypto/aescbc/aes_cbc_test.go index bb03f706..d6611dc9 100644 --- a/crypto/aescbc/aes_cbc_test.go +++ b/crypto/aescbc/aes_cbc_test.go @@ -7,11 +7,13 @@ package aescbc_test import ( - "bytes" "crypto/aes" "crypto/rand" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "maunium.net/go/mautrix/crypto/aescbc" ) @@ -22,32 +24,23 @@ func TestAESCBC(t *testing.T) { // The key length can be 32, 24, 16 bytes (OR in bits: 128, 192 or 256) key := make([]byte, 32) _, err = rand.Read(key) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) iv := make([]byte, aes.BlockSize) _, err = rand.Read(iv) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) plaintext = []byte("secret message for testing") //increase to next block size for len(plaintext)%8 != 0 { plaintext = append(plaintext, []byte("-")...) } - if ciphertext, err = aescbc.Encrypt(key, iv, plaintext); err != nil { - t.Fatal(err) - } + ciphertext, err = aescbc.Encrypt(key, iv, plaintext) + require.NoError(t, err) resultPlainText, err := aescbc.Decrypt(key, iv, ciphertext) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) - if string(resultPlainText) != string(plaintext) { - t.Fatalf("message '%s' (length %d) != '%s'", resultPlainText, len(resultPlainText), plaintext) - } + assert.Equal(t, string(resultPlainText), string(plaintext)) } func TestAESCBCCase1(t *testing.T) { @@ -61,18 +54,10 @@ func TestAESCBCCase1(t *testing.T) { key := make([]byte, 32) iv := make([]byte, aes.BlockSize) encrypted, err := aescbc.Encrypt(key, iv, input) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(expected, encrypted) { - t.Fatalf("encrypted did not match expected:\n%v\n%v\n", encrypted, expected) - } + require.NoError(t, err) + assert.Equal(t, expected, encrypted, "encrypted output does not match expected") decrypted, err := aescbc.Decrypt(key, iv, encrypted) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(input, decrypted) { - t.Fatalf("decrypted did not match expected:\n%v\n%v\n", decrypted, input) - } + require.NoError(t, err) + assert.Equal(t, input, decrypted, "decrypted output does not match input") } diff --git a/crypto/canonicaljson/json_test.go b/crypto/canonicaljson/json_test.go index d1a7f0a5..36476aa4 100644 --- a/crypto/canonicaljson/json_test.go +++ b/crypto/canonicaljson/json_test.go @@ -17,31 +17,43 @@ package canonicaljson import ( "testing" + + "github.com/stretchr/testify/assert" ) -func testSortJSON(t *testing.T, input, want string) { - got := SortJSON([]byte(input), nil) - - // Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace. - if string(CompactJSON(got, nil)) != want { - t.Errorf("SortJSON(%q): want %q got %q", input, want, got) - } -} - func TestSortJSON(t *testing.T) { - testSortJSON(t, `[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`) - testSortJSON(t, `{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, - `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`) - testSortJSON(t, `[true,false,null]`, `[true,false,null]`) - testSortJSON(t, `[9007199254740991]`, `[9007199254740991]`) - testSortJSON(t, "\t\n[9007199254740991]", `[9007199254740991]`) + var tests = []struct { + input string + want string + }{ + {"{}", "{}"}, + {`[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`}, + {`{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`}, + {`[true,false,null]`, `[true,false,null]`}, + {`[9007199254740991]`, `[9007199254740991]`}, + {"\t\n[9007199254740991]", `[9007199254740991]`}, + {`[true,false,null]`, `[true,false,null]`}, + {`[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`}, + {`{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`}, + {`[true,false,null]`, `[true,false,null]`}, + {`[9007199254740991]`, `[9007199254740991]`}, + {"\t\n[9007199254740991]", `[9007199254740991]`}, + {`[true,false,null]`, `[true,false,null]`}, + } + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + got := SortJSON([]byte(test.input), nil) + + // Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace. + assert.EqualValues(t, test.want, string(CompactJSON(got, nil))) + }) + } } func testCompactJSON(t *testing.T, input, want string) { + t.Helper() got := string(CompactJSON([]byte(input), nil)) - if got != want { - t.Errorf("CompactJSON(%q): want %q got %q", input, want, got) - } + assert.EqualValues(t, want, got) } func TestCompactJSON(t *testing.T) { @@ -74,18 +86,23 @@ func TestCompactJSON(t *testing.T) { testCompactJSON(t, `["\"\\\/"]`, `["\"\\/"]`) } -func testReadHex(t *testing.T, input string, want uint32) { - got := readHexDigits([]byte(input)) - if want != got { - t.Errorf("readHexDigits(%q): want 0x%x got 0x%x", input, want, got) +func TestReadHex(t *testing.T) { + tests := []struct { + input string + want uint32 + }{ + + {"0123", 0x0123}, + {"4567", 0x4567}, + {"89AB", 0x89AB}, + {"CDEF", 0xCDEF}, + {"89ab", 0x89AB}, + {"cdef", 0xCDEF}, + } + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + got := readHexDigits([]byte(test.input)) + assert.Equal(t, test.want, got) + }) } } - -func TestReadHex(t *testing.T) { - testReadHex(t, "0123", 0x0123) - testReadHex(t, "4567", 0x4567) - testReadHex(t, "89AB", 0x89AB) - testReadHex(t, "CDEF", 0xCDEF) - testReadHex(t, "89ab", 0x89AB) - testReadHex(t, "cdef", 0xCDEF) -} diff --git a/crypto/cross_sign_test.go b/crypto/cross_sign_test.go index 5e1ffd50..b70370a2 100644 --- a/crypto/cross_sign_test.go +++ b/crypto/cross_sign_test.go @@ -13,6 +13,8 @@ import ( "testing" "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix" @@ -24,17 +26,12 @@ var noopLogger = zerolog.Nop() func getOlmMachine(t *testing.T) *OlmMachine { rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000") - if err != nil { - t.Fatalf("Error opening db: %v", err) - } + require.NoError(t, err, "Error opening raw database") db, err := dbutil.NewWithDB(rawDB, "sqlite3") - if err != nil { - t.Fatalf("Error opening db: %v", err) - } + require.NoError(t, err, "Error creating database wrapper") sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test")) - if err = sqlStore.DB.Upgrade(context.TODO()); err != nil { - t.Fatalf("Error creating tables: %v", err) - } + err = sqlStore.DB.Upgrade(context.TODO()) + require.NoError(t, err, "Error upgrading database") userID := id.UserID("@mautrix") mk, _ := olm.NewPKSigning() @@ -66,29 +63,25 @@ func TestTrustOwnDevice(t *testing.T) { DeviceID: "device", SigningKey: id.Ed25519("deviceKey"), } - if m.IsDeviceTrusted(context.TODO(), ownDevice) { - t.Error("Own device trusted while it shouldn't be") - } + assert.False(t, m.IsDeviceTrusted(context.TODO(), ownDevice), "Own device trusted while it shouldn't be") m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey(), "sig1") m.CryptoStore.PutSignature(context.TODO(), ownDevice.UserID, ownDevice.SigningKey, ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "sig2") - if trusted, _ := m.IsUserTrusted(context.TODO(), ownDevice.UserID); !trusted { - t.Error("Own user not trusted while they should be") - } - if !m.IsDeviceTrusted(context.TODO(), ownDevice) { - t.Error("Own device not trusted while it should be") - } + trusted, err := m.IsUserTrusted(context.TODO(), ownDevice.UserID) + require.NoError(t, err, "Error checking if own user is trusted") + assert.True(t, trusted, "Own user not trusted while they should be") + assert.True(t, m.IsDeviceTrusted(context.TODO(), ownDevice), "Own device not trusted while it should be") } func TestTrustOtherUser(t *testing.T) { m := getOlmMachine(t) otherUser := id.UserID("@user") - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { - t.Error("Other user trusted while they shouldn't be") - } + trusted, err := m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.False(t, trusted, "Other user trusted while they shouldn't be") theirMasterKey, _ := olm.NewPKSigning() m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey()) @@ -100,16 +93,16 @@ func TestTrustOtherUser(t *testing.T) { m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey(), "invalid_sig") - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { - t.Error("Other user trusted before their master key has been signed with our user-signing key") - } + trusted, err = m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.False(t, trusted, "Other user trusted before their master key has been signed with our user-signing key") m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2") - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted { - t.Error("Other user not trusted while they should be") - } + trusted, err = m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.True(t, trusted, "Other user not trusted while they should be") } func TestTrustOtherDevice(t *testing.T) { @@ -120,12 +113,11 @@ func TestTrustOtherDevice(t *testing.T) { DeviceID: "theirDevice", SigningKey: id.Ed25519("theirDeviceKey"), } - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); trusted { - t.Error("Other user trusted while they shouldn't be") - } - if m.IsDeviceTrusted(context.TODO(), theirDevice) { - t.Error("Other device trusted while it shouldn't be") - } + + trusted, err := m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.False(t, trusted, "Other user trusted while they shouldn't be") + assert.False(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device trusted while it shouldn't be") theirMasterKey, _ := olm.NewPKSigning() m.CryptoStore.PutCrossSigningKey(context.TODO(), otherUser, id.XSUsageMaster, theirMasterKey.PublicKey()) @@ -137,21 +129,17 @@ func TestTrustOtherDevice(t *testing.T) { m.CryptoStore.PutSignature(context.TODO(), otherUser, theirMasterKey.PublicKey(), m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey(), "sig2") - if trusted, _ := m.IsUserTrusted(context.TODO(), otherUser); !trusted { - t.Error("Other user not trusted while they should be") - } + trusted, err = m.IsUserTrusted(context.TODO(), otherUser) + require.NoError(t, err, "Error checking if other user is trusted") + assert.True(t, trusted, "Other user not trusted while they should be") m.CryptoStore.PutSignature(context.TODO(), otherUser, theirSSK.PublicKey(), otherUser, theirMasterKey.PublicKey(), "sig3") - if m.IsDeviceTrusted(context.TODO(), theirDevice) { - t.Error("Other device trusted before it has been signed with user's SSK") - } + assert.False(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device trusted before it has been signed with user's SSK") m.CryptoStore.PutSignature(context.TODO(), otherUser, theirDevice.SigningKey, otherUser, theirSSK.PublicKey(), "sig4") - if !m.IsDeviceTrusted(context.TODO(), theirDevice) { - t.Error("Other device not trusted while it should be") - } + assert.True(t, m.IsDeviceTrusted(context.TODO(), theirDevice), "Other device not trusted after it has been signed with user's SSK") } diff --git a/crypto/machine_test.go b/crypto/machine_test.go index 59c86236..872c3ac4 100644 --- a/crypto/machine_test.go +++ b/crypto/machine_test.go @@ -36,20 +36,15 @@ func (mockStateStore) FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, func newMachine(t *testing.T, userID id.UserID) *OlmMachine { client, err := mautrix.NewClient("http://localhost", userID, "token") - if err != nil { - t.Fatalf("Error creating client: %v", err) - } + require.NoError(t, err, "Error creating client") client.DeviceID = "device1" gobStore := NewMemoryStore(nil) - if err != nil { - t.Fatalf("Error creating Gob store: %v", err) - } + require.NoError(t, err, "Error creating Gob store") machine := NewOlmMachine(client, nil, gobStore, mockStateStore{}) - if err := machine.Load(context.TODO()); err != nil { - t.Fatalf("Error creating account: %v", err) - } + err = machine.Load(context.TODO()) + require.NoError(t, err, "Error creating account") return machine } @@ -82,9 +77,7 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { // create outbound olm session for sending machine using OTK olmSession, err := machineOut.account.Internal.NewOutboundSession(machineIn.account.IdentityKey(), otk.Key) - if err != nil { - t.Errorf("Failed to create outbound olm session: %v", err) - } + require.NoError(t, err, "Error creating outbound olm session") // store sender device identity in receiving machine store machineIn.CryptoStore.PutDevices(context.TODO(), "user1", map[id.DeviceID]*id.Device{ @@ -121,29 +114,21 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { Type: event.ToDeviceEncrypted, Sender: "user1", }, senderKey, content.Type, content.Body) - if err != nil { - t.Errorf("Error decrypting olm content: %v", err) - } + require.NoError(t, err, "Error decrypting olm ciphertext") + // store room key in new inbound group session roomKeyEvt := decrypted.Content.AsRoomKey() igs, err := NewInboundGroupSession(senderKey, signingKey, "room1", roomKeyEvt.SessionKey, 0, 0, false) - if err != nil { - t.Errorf("Error creating inbound megolm session: %v", err) - } - if err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs); err != nil { - t.Errorf("Error storing inbound megolm session: %v", err) - } + require.NoError(t, err, "Error creating inbound group session") + err = machineIn.CryptoStore.PutGroupSession(context.TODO(), igs) + require.NoError(t, err, "Error storing inbound group session") } // encrypt event with megolm session in sending machine eventContent := map[string]string{"hello": "world"} encryptedEvtContent, err := machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent) - if err != nil { - t.Errorf("Error encrypting megolm event: %v", err) - } - if megolmOutSession.MessageCount != 1 { - t.Errorf("Megolm outbound session message count is not 1 but %d", megolmOutSession.MessageCount) - } + require.NoError(t, err, "Error encrypting megolm event") + assert.Equal(t, 1, megolmOutSession.MessageCount) encryptedEvt := &event.Event{ Content: event.Content{Parsed: encryptedEvtContent}, @@ -155,22 +140,12 @@ func TestOlmMachineOlmMegolmSessions(t *testing.T) { // decrypt event on receiving machine and confirm decryptedEvt, err := machineIn.DecryptMegolmEvent(context.TODO(), encryptedEvt) - if err != nil { - t.Errorf("Error decrypting megolm event: %v", err) - } - if decryptedEvt.Type != event.EventMessage { - t.Errorf("Expected event type %v, got %v", event.EventMessage, decryptedEvt.Type) - } - if decryptedEvt.Content.Raw["hello"] != "world" { - t.Errorf("Expected event content %v, got %v", eventContent, decryptedEvt.Content.Raw) - } + require.NoError(t, err, "Error decrypting megolm event") + assert.Equal(t, event.EventMessage, decryptedEvt.Type) + assert.Equal(t, "world", decryptedEvt.Content.Raw["hello"]) machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent) - if megolmOutSession.Expired() { - t.Error("Megolm outbound session expired before 3rd message") - } + assert.False(t, megolmOutSession.Expired(), "Megolm outbound session expired before 3rd message") machineOut.EncryptMegolmEvent(context.TODO(), "room1", event.EventMessage, eventContent) - if !megolmOutSession.Expired() { - t.Error("Megolm outbound session not expired after 3rd message") - } + assert.True(t, megolmOutSession.Expired(), "Megolm outbound session not expired after 3rd message") } diff --git a/crypto/store_test.go b/crypto/store_test.go index a7c4d75a..8aeae7af 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -13,6 +13,7 @@ import ( "testing" _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.mau.fi/util/dbutil" @@ -29,22 +30,14 @@ const groupSession = "9ZbsRqJuETbjnxPpKv29n3dubP/m5PSLbr9I9CIWS2O86F/Og1JZXhqT+4 func getCryptoStores(t *testing.T) map[string]Store { rawDB, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000") - if err != nil { - t.Fatalf("Error opening db: %v", err) - } + require.NoError(t, err, "Error opening raw database") db, err := dbutil.NewWithDB(rawDB, "sqlite3") - if err != nil { - t.Fatalf("Error opening db: %v", err) - } + require.NoError(t, err, "Error creating database wrapper") sqlStore := NewSQLCryptoStore(db, nil, "accid", id.DeviceID("dev"), []byte("test")) - if err = sqlStore.DB.Upgrade(context.TODO()); err != nil { - t.Fatalf("Error creating tables: %v", err) - } + err = sqlStore.DB.Upgrade(context.TODO()) + require.NoError(t, err, "Error upgrading database") gobStore := NewMemoryStore(nil) - if err != nil { - t.Fatalf("Error creating Gob store: %v", err) - } return map[string]Store{ "sql": sqlStore, @@ -56,9 +49,10 @@ func TestPutNextBatch(t *testing.T) { stores := getCryptoStores(t) store := stores["sql"].(*SQLCryptoStore) store.PutNextBatch(context.Background(), "batch1") - if batch, _ := store.GetNextBatch(context.Background()); batch != "batch1" { - t.Errorf("Expected batch1, got %v", batch) - } + + batch, err := store.GetNextBatch(context.Background()) + require.NoError(t, err, "Error retrieving next batch") + assert.Equal(t, "batch1", batch) } func TestPutAccount(t *testing.T) { @@ -68,15 +62,9 @@ func TestPutAccount(t *testing.T) { acc := NewOlmAccount() store.PutAccount(context.TODO(), acc) retrieved, err := store.GetAccount(context.TODO()) - if err != nil { - t.Fatalf("Error retrieving account: %v", err) - } - if acc.IdentityKey() != retrieved.IdentityKey() { - t.Errorf("Stored identity key %v, got %v", acc.IdentityKey(), retrieved.IdentityKey()) - } - if acc.SigningKey() != retrieved.SigningKey() { - t.Errorf("Stored signing key %v, got %v", acc.SigningKey(), retrieved.SigningKey()) - } + require.NoError(t, err, "Error retrieving account") + assert.Equal(t, acc.IdentityKey(), retrieved.IdentityKey(), "Identity key does not match") + assert.Equal(t, acc.SigningKey(), retrieved.SigningKey(), "Signing key does not match") }) } } @@ -86,18 +74,26 @@ func TestValidateMessageIndex(t *testing.T) { for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { acc := NewOlmAccount() - if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000); !ok { - t.Error("First message not validated successfully") - } - if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1001); ok { - t.Error("First message validated successfully after changing timestamp") - } - if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event2", 0, 1000); ok { - t.Error("First message validated successfully after changing event ID") - } - if ok, _ := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000); !ok { - t.Error("First message not validated successfully for a second time") - } + + // First message should validate successfully + ok, err := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) + require.NoError(t, err, "Error validating message index") + assert.True(t, ok, "First message validation should be valid") + + // Edit the timestamp and ensure validate fails + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1001) + require.NoError(t, err, "Error validating message index after timestamp change") + assert.False(t, ok, "First message validation should fail after timestamp change") + + // Edit the event ID and ensure validate fails + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event2", 0, 1000) + require.NoError(t, err, "Error validating message index after event ID change") + assert.False(t, ok, "First message validation should fail after event ID change") + + // Validate again with the original parameters and ensure that it still passes + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) + require.NoError(t, err, "Error validating message index") + assert.True(t, ok, "First message validation should be valid") }) } } @@ -106,43 +102,26 @@ func TestStoreOlmSession(t *testing.T) { stores := getCryptoStores(t) for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { - if store.HasSession(context.TODO(), olmSessID) { - t.Error("Found Olm session before inserting it") - } + require.False(t, store.HasSession(context.TODO(), olmSessID), "Found Olm session before inserting it") + olmInternal, err := olm.SessionFromPickled([]byte(olmPickled), []byte("test")) - if err != nil { - t.Fatalf("Error creating internal Olm session: %v", err) - } + require.NoError(t, err, "Error creating internal Olm session") olmSess := OlmSession{ id: olmSessID, Internal: olmInternal, } err = store.AddSession(context.TODO(), olmSessID, &olmSess) - if err != nil { - t.Errorf("Error storing Olm session: %v", err) - } - if !store.HasSession(context.TODO(), olmSessID) { - t.Error("Not found Olm session after inserting it") - } + require.NoError(t, err, "Error storing Olm session") + assert.True(t, store.HasSession(context.TODO(), olmSessID), "Olm session not found after inserting it") retrieved, err := store.GetLatestSession(context.TODO(), olmSessID) - if err != nil { - t.Errorf("Failed retrieving Olm session: %v", err) - } - - if retrieved.ID() != olmSessID { - t.Errorf("Expected session ID to be %v, got %v", olmSessID, retrieved.ID()) - } + require.NoError(t, err, "Error retrieving Olm session") + assert.EqualValues(t, olmSessID, retrieved.ID()) pickled, err := retrieved.Internal.Pickle([]byte("test")) - if err != nil { - t.Fatalf("Error pickling Olm session: %v", err) - } - - if string(pickled) != olmPickled { - t.Error("Pickled Olm session does not match original") - } + require.NoError(t, err, "Error pickling Olm session") + assert.EqualValues(t, pickled, olmPickled, "Pickled Olm session does not match original") }) } } @@ -154,9 +133,7 @@ func TestStoreMegolmSession(t *testing.T) { acc := NewOlmAccount() internal, err := olm.InboundGroupSessionFromPickled([]byte(groupSession), []byte("test")) - if err != nil { - t.Fatalf("Error creating internal inbound group session: %v", err) - } + require.NoError(t, err, "Error creating internal inbound group session") igs := &InboundGroupSession{ Internal: internal, @@ -166,20 +143,14 @@ func TestStoreMegolmSession(t *testing.T) { } err = store.PutGroupSession(context.TODO(), igs) - if err != nil { - t.Errorf("Error storing inbound group session: %v", err) - } + require.NoError(t, err, "Error storing inbound group session") retrieved, err := store.GetGroupSession(context.TODO(), "room1", igs.ID()) - if err != nil { - t.Errorf("Error retrieving inbound group session: %v", err) - } + require.NoError(t, err, "Error retrieving inbound group session") - if pickled, err := retrieved.Internal.Pickle([]byte("test")); err != nil { - t.Fatalf("Error pickling inbound group session: %v", err) - } else if string(pickled) != groupSession { - t.Error("Pickled inbound group session does not match original") - } + pickled, err := retrieved.Internal.Pickle([]byte("test")) + require.NoError(t, err, "Error pickling inbound group session") + assert.EqualValues(t, pickled, groupSession, "Pickled inbound group session does not match original") }) } } @@ -189,40 +160,24 @@ func TestStoreOutboundMegolmSession(t *testing.T) { for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { sess, err := store.GetOutboundGroupSession(context.TODO(), "room1") - if sess != nil { - t.Error("Got outbound session before inserting") - } - if err != nil { - t.Errorf("Error retrieving outbound session: %v", err) - } + require.NoError(t, err, "Error retrieving outbound session") + require.Nil(t, sess, "Got outbound session before inserting") outbound, err := NewOutboundGroupSession("room1", nil) require.NoError(t, err) err = store.AddOutboundGroupSession(context.TODO(), outbound) - if err != nil { - t.Errorf("Error inserting outbound session: %v", err) - } + require.NoError(t, err, "Error inserting outbound session") sess, err = store.GetOutboundGroupSession(context.TODO(), "room1") - if sess == nil { - t.Error("Did not get outbound session after inserting") - } - if err != nil { - t.Errorf("Error retrieving outbound session: %v", err) - } + require.NoError(t, err, "Error retrieving outbound session") + assert.NotNil(t, sess, "Did not get outbound session after inserting") err = store.RemoveOutboundGroupSession(context.TODO(), "room1") - if err != nil { - t.Errorf("Error deleting outbound session: %v", err) - } + require.NoError(t, err, "Error deleting outbound session") sess, err = store.GetOutboundGroupSession(context.TODO(), "room1") - if sess != nil { - t.Error("Got outbound session after deleting") - } - if err != nil { - t.Errorf("Error retrieving outbound session: %v", err) - } + require.NoError(t, err, "Error retrieving outbound session after deletion") + assert.Nil(t, sess, "Got outbound session after deleting") }) } } @@ -244,58 +199,41 @@ func TestStoreOutboundMegolmSessionSharing(t *testing.T) { t.Run(storeName, func(t *testing.T) { device := resetDevice() err := store.PutDevice(context.TODO(), "user1", device) - if err != nil { - t.Errorf("Error storing devices: %v", err) - } + require.NoError(t, err, "Error storing device") shared, err := store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - if err != nil { - t.Errorf("Error checking if outbound group session is shared: %v", err) - } else if shared { - t.Errorf("Outbound group session shared when it shouldn't") - } + require.NoError(t, err, "Error checking if outbound group session is shared") + assert.False(t, shared, "Outbound group session should not be shared initially") err = store.MarkOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - if err != nil { - t.Errorf("Error marking outbound group session as shared: %v", err) - } + require.NoError(t, err, "Error marking outbound group session as shared") shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - if err != nil { - t.Errorf("Error checking if outbound group session is shared: %v", err) - } else if !shared { - t.Errorf("Outbound group session not shared when it should") - } + require.NoError(t, err, "Error checking if outbound group session is shared") + assert.True(t, shared, "Outbound group session should be shared after marking it as such") device = resetDevice() err = store.PutDevice(context.TODO(), "user1", device) - if err != nil { - t.Errorf("Error storing devices: %v", err) - } + require.NoError(t, err, "Error storing device after resetting") shared, err = store.IsOutboundGroupSessionShared(context.TODO(), device.UserID, device.IdentityKey, "session1") - if err != nil { - t.Errorf("Error checking if outbound group session is shared: %v", err) - } else if shared { - t.Errorf("Outbound group session shared when it shouldn't") - } + require.NoError(t, err, "Error checking if outbound group session is shared") + assert.False(t, shared, "Outbound group session should not be shared after resetting device") }) } } func TestStoreDevices(t *testing.T) { + devicesToCreate := 17 stores := getCryptoStores(t) for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { outdated, err := store.GetOutdatedTrackedUsers(context.TODO()) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } - if len(outdated) > 0 { - t.Errorf("Got %d outdated tracked users when expected none", len(outdated)) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Empty(t, outdated, "Expected no outdated tracked users initially") + deviceMap := make(map[id.DeviceID]*id.Device) - for i := 0; i < 17; i++ { + for i := 0; i < devicesToCreate; i++ { iStr := strconv.Itoa(i) acc := NewOlmAccount() deviceMap[id.DeviceID("dev"+iStr)] = &id.Device{ @@ -306,59 +244,33 @@ func TestStoreDevices(t *testing.T) { } } err = store.PutDevices(context.TODO(), "user1", deviceMap) - if err != nil { - t.Errorf("Error storing devices: %v", err) - } + require.NoError(t, err, "Error storing devices") devs, err := store.GetDevices(context.TODO(), "user1") - if err != nil { - t.Errorf("Error getting devices: %v", err) - } - if len(devs) != 17 { - t.Errorf("Stored 17 devices, got back %v", len(devs)) - } - if devs["dev0"].IdentityKey != deviceMap["dev0"].IdentityKey { - t.Errorf("First device identity key does not match") - } - if devs["dev16"].IdentityKey != deviceMap["dev16"].IdentityKey { - t.Errorf("Last device identity key does not match") - } + require.NoError(t, err, "Error getting devices") + assert.Len(t, devs, devicesToCreate, "Expected to get %d devices back", devicesToCreate) + assert.Equal(t, deviceMap, devs, "Stored devices do not match retrieved devices") filtered, err := store.FilterTrackedUsers(context.TODO(), []id.UserID{"user0", "user1", "user2"}) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } else if len(filtered) != 1 || filtered[0] != "user1" { - t.Errorf("Expected to get 'user1' from filter, got %v", filtered) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Equal(t, []id.UserID{"user1"}, filtered, "Expected to get 'user1' from filter") outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } - if len(outdated) > 0 { - t.Errorf("Got %d outdated tracked users when expected none", len(outdated)) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Empty(t, outdated, "Expected no outdated tracked users after initial storage") + err = store.MarkTrackedUsersOutdated(context.TODO(), []id.UserID{"user0", "user1"}) - if err != nil { - t.Errorf("Error marking tracked users outdated: %v", err) - } + require.NoError(t, err, "Error marking tracked users outdated") + outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } - if len(outdated) != 1 || outdated[0] != id.UserID("user1") { - t.Errorf("Got outdated tracked users %v when expected 'user1'", outdated) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Equal(t, []id.UserID{"user1"}, outdated, "Expected 'user1' to be marked as outdated") + err = store.PutDevices(context.TODO(), "user1", deviceMap) - if err != nil { - t.Errorf("Error storing devices: %v", err) - } + require.NoError(t, err, "Error storing devices again") + outdated, err = store.GetOutdatedTrackedUsers(context.TODO()) - if err != nil { - t.Errorf("Error filtering tracked users: %v", err) - } - if len(outdated) > 0 { - t.Errorf("Got outdated tracked users %v when expected none", outdated) - } + require.NoError(t, err, "Error filtering tracked users") + assert.Empty(t, outdated, "Expected no outdated tracked users after re-storing devices") }) } } @@ -369,16 +281,11 @@ func TestStoreSecrets(t *testing.T) { t.Run(storeName, func(t *testing.T) { storedSecret := "trustno1" err := store.PutSecret(context.TODO(), id.SecretMegolmBackupV1, storedSecret) - if err != nil { - t.Errorf("Error storing secret: %v", err) - } + require.NoError(t, err, "Error storing secret") secret, err := store.GetSecret(context.TODO(), id.SecretMegolmBackupV1) - if err != nil { - t.Errorf("Error storing secret: %v", err) - } else if secret != storedSecret { - t.Errorf("Stored secret did not match: '%s' != '%s'", secret, storedSecret) - } + require.NoError(t, err, "Error retrieving secret") + assert.Equal(t, storedSecret, secret, "Retrieved secret does not match stored secret") }) } } diff --git a/crypto/utils/utils_test.go b/crypto/utils/utils_test.go index c4f01a68..b12fd9e2 100644 --- a/crypto/utils/utils_test.go +++ b/crypto/utils/utils_test.go @@ -9,6 +9,9 @@ package utils import ( "encoding/base64" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAES256Ctr(t *testing.T) { @@ -16,9 +19,7 @@ func TestAES256Ctr(t *testing.T) { key, iv := GenAttachmentA256CTR() enc := XorA256CTR([]byte(expected), key, iv) dec := XorA256CTR(enc, key, iv) - if string(dec) != expected { - t.Errorf("Expected decrypted using generated key/iv to be `%v`, got %v", expected, string(dec)) - } + assert.EqualValues(t, expected, dec, "Decrypted text should match original") var key2 [AESCTRKeyLength]byte var iv2 [AESCTRIVLength]byte @@ -29,9 +30,7 @@ func TestAES256Ctr(t *testing.T) { iv2[i] = byte(i) + 32 } dec2 := XorA256CTR([]byte{0x29, 0xc3, 0xff, 0x02, 0x21, 0xaf, 0x67, 0x73, 0x6e, 0xad, 0x9d}, key2, iv2) - if string(dec2) != expected { - t.Errorf("Expected decrypted using constant key/iv to be `%v`, got %v", expected, string(dec2)) - } + assert.EqualValues(t, expected, dec2, "Decrypted text with constant key/iv should match original") } func TestPBKDF(t *testing.T) { @@ -42,9 +41,7 @@ func TestPBKDF(t *testing.T) { key := PBKDF2SHA512([]byte("Hello world"), salt, 1000, 256) expected := "ffk9YdbVE1cgqOWgDaec0lH+rJzO+MuCcxpIn3Z6D0E=" keyB64 := base64.StdEncoding.EncodeToString([]byte(key)) - if keyB64 != expected { - t.Errorf("Expected base64 of generated key to be `%v`, got `%v`", expected, keyB64) - } + assert.Equal(t, expected, keyB64) } func TestDecodeSSSSKey(t *testing.T) { @@ -53,13 +50,10 @@ func TestDecodeSSSSKey(t *testing.T) { expected := "QCFDrXZYLEFnwf4NikVm62rYGJS2mNBEmAWLC3CgNPw=" decodedB64 := base64.StdEncoding.EncodeToString(decoded[:]) - if expected != decodedB64 { - t.Errorf("Expected decoded recovery key b64 to be `%v`, got `%v`", expected, decodedB64) - } + assert.Equal(t, expected, decodedB64) - if encoded := EncodeBase58RecoveryKey(decoded); encoded != recoveryKey { - t.Errorf("Expected recovery key to be `%v`, got `%v`", recoveryKey, encoded) - } + encoded := EncodeBase58RecoveryKey(decoded) + assert.Equal(t, recoveryKey, encoded) } func TestKeyDerivationAndHMAC(t *testing.T) { @@ -69,15 +63,11 @@ func TestKeyDerivationAndHMAC(t *testing.T) { aesKey, hmacKey := DeriveKeysSHA256(decoded[:], "m.cross_signing.master") ciphertextBytes, err := base64.StdEncoding.DecodeString("Fx16KlJ9vkd3Dd6CafIq5spaH5QmK5BALMzbtFbQznG2j1VARKK+klc4/Qo=") - if err != nil { - t.Error(err) - } + require.NoError(t, err) calcMac := HMACSHA256B64(ciphertextBytes, hmacKey) expectedMac := "0DABPNIZsP9iTOh1o6EM0s7BfHHXb96dN7Eca88jq2E" - if calcMac != expectedMac { - t.Errorf("Expected MAC `%v`, got `%v`", expectedMac, calcMac) - } + assert.Equal(t, expectedMac, calcMac) var ivBytes [AESCTRIVLength]byte decodedIV, _ := base64.StdEncoding.DecodeString("zxT/W5LpZ0Q819pfju6hZw==") @@ -85,7 +75,5 @@ func TestKeyDerivationAndHMAC(t *testing.T) { decrypted := string(XorA256CTR(ciphertextBytes, aesKey, ivBytes)) expectedDec := "Ec8eZDyvVkO3EDsEG6ej5c0cCHnX7PINqFXZjnaTV2s=" - if expectedDec != decrypted { - t.Errorf("Expected decrypted text to be `%v`, got `%v`", expectedDec, decrypted) - } + assert.Equal(t, expectedDec, decrypted) } From e27e00b391746627eb000da85b5e38704bc2046a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 3 Aug 2025 11:51:47 +0300 Subject: [PATCH 1279/1647] id: move room version from event package and add flags --- bridgev2/portal.go | 2 +- bridgev2/space.go | 2 +- bridgev2/user.go | 2 +- event/state.go | 43 ++++---- id/roomversion.go | 265 +++++++++++++++++++++++++++++++++++++++++++++ requests.go | 2 +- responses.go | 14 +-- 7 files changed, 297 insertions(+), 33 deletions(-) create mode 100644 id/roomversion.go diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 5aae45e9..25865080 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -4296,7 +4296,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo IsDirect: portal.RoomType == database.RoomTypeDM, PowerLevelOverride: powerLevels, BeeperLocalRoomID: portal.Bridge.Matrix.GenerateDeterministicRoomID(portal.PortalKey), - RoomVersion: event.RoomV11, + RoomVersion: id.RoomV11, } autoJoinInvites := portal.Bridge.Matrix.GetCapabilities().AutoJoinInvites if autoJoinInvites { diff --git a/bridgev2/space.go b/bridgev2/space.go index ccb74b26..ae9013cb 100644 --- a/bridgev2/space.go +++ b/bridgev2/space.go @@ -164,7 +164,7 @@ func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) { ul.UserMXID: 50, }, }, - RoomVersion: event.RoomV11, + RoomVersion: id.RoomV11, Invite: []id.UserID{ul.UserMXID}, } if autoJoin { diff --git a/bridgev2/user.go b/bridgev2/user.go index 350cecd1..87ced1d7 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -225,7 +225,7 @@ func (user *User) GetManagementRoom(ctx context.Context) (id.RoomID, error) { user.MXID: 50, }, }, - RoomVersion: event.RoomV11, + RoomVersion: id.RoomV11, Invite: []id.UserID{user.MXID}, IsDirect: true, } diff --git a/event/state.go b/event/state.go index 83390c90..44a45a57 100644 --- a/event/state.go +++ b/event/state.go @@ -75,30 +75,32 @@ type Predecessor struct { EventID id.EventID `json:"event_id"` } -type RoomVersion string +// Deprecated: use id.RoomVersion instead +type RoomVersion = id.RoomVersion +// Deprecated: use id.RoomVX constants instead const ( - RoomV1 RoomVersion = "1" - RoomV2 RoomVersion = "2" - RoomV3 RoomVersion = "3" - RoomV4 RoomVersion = "4" - RoomV5 RoomVersion = "5" - RoomV6 RoomVersion = "6" - RoomV7 RoomVersion = "7" - RoomV8 RoomVersion = "8" - RoomV9 RoomVersion = "9" - RoomV10 RoomVersion = "10" - RoomV11 RoomVersion = "11" - RoomV12 RoomVersion = "12" + RoomV1 = id.RoomV1 + RoomV2 = id.RoomV2 + RoomV3 = id.RoomV3 + RoomV4 = id.RoomV4 + RoomV5 = id.RoomV5 + RoomV6 = id.RoomV6 + RoomV7 = id.RoomV7 + RoomV8 = id.RoomV8 + RoomV9 = id.RoomV9 + RoomV10 = id.RoomV10 + RoomV11 = id.RoomV11 + RoomV12 = id.RoomV12 ) // CreateEventContent represents the content of a m.room.create state event. // https://spec.matrix.org/v1.2/client-server-api/#mroomcreate type CreateEventContent struct { - Type RoomType `json:"type,omitempty"` - Federate *bool `json:"m.federate,omitempty"` - RoomVersion RoomVersion `json:"room_version,omitempty"` - Predecessor *Predecessor `json:"predecessor,omitempty"` + Type RoomType `json:"type,omitempty"` + Federate *bool `json:"m.federate,omitempty"` + RoomVersion id.RoomVersion `json:"room_version,omitempty"` + Predecessor *Predecessor `json:"predecessor,omitempty"` // Room v12+ only AdditionalCreators []id.UserID `json:"additional_creators,omitempty"` @@ -108,13 +110,10 @@ type CreateEventContent struct { } func (cec *CreateEventContent) SupportsCreatorPower() bool { - switch cec.RoomVersion { - case "", RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11: + if cec == nil { return false - default: - // Assume anything except known old versions supports creator power. - return true } + return cec.RoomVersion.PrivilegedRoomCreators() } // JoinRule specifies how open a room is to new members. diff --git a/id/roomversion.go b/id/roomversion.go new file mode 100644 index 00000000..578c10bd --- /dev/null +++ b/id/roomversion.go @@ -0,0 +1,265 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package id + +import ( + "errors" + "fmt" + "slices" +) + +type RoomVersion string + +const ( + RoomV0 RoomVersion = "" // No room version, used for rooms created before room versions were introduced, equivalent to v1 + RoomV1 RoomVersion = "1" + RoomV2 RoomVersion = "2" + RoomV3 RoomVersion = "3" + RoomV4 RoomVersion = "4" + RoomV5 RoomVersion = "5" + RoomV6 RoomVersion = "6" + RoomV7 RoomVersion = "7" + RoomV8 RoomVersion = "8" + RoomV9 RoomVersion = "9" + RoomV10 RoomVersion = "10" + RoomV11 RoomVersion = "11" + RoomV12 RoomVersion = "12" +) + +func (rv RoomVersion) Equals(versions ...RoomVersion) bool { + return slices.Contains(versions, rv) +} + +func (rv RoomVersion) NotEquals(versions ...RoomVersion) bool { + return !rv.Equals(versions...) +} + +var ErrUnknownRoomVersion = errors.New("unknown room version") + +func (rv RoomVersion) unknownVersionError() error { + return fmt.Errorf("%w %s", ErrUnknownRoomVersion, rv) +} + +func (rv RoomVersion) IsKnown() bool { + switch rv { + case RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11, RoomV12: + return true + default: + return false + } +} + +type StateResVersion int + +const ( + // StateResV1 is the original state resolution algorithm. + StateResV1 StateResVersion = 0 + // StateResV2 is state resolution v2 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1759 + StateResV2 StateResVersion = 1 + // StateResV2_1 is state resolution v2.1 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/4297 + StateResV2_1 StateResVersion = 2 +) + +// StateResVersion returns the version of the state resolution algorithm used by this room version. +func (rv RoomVersion) StateResVersion() StateResVersion { + switch rv { + case RoomV0, RoomV1: + return StateResV1 + case RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11: + return StateResV2 + case RoomV12: + return StateResV2_1 + default: + panic(rv.unknownVersionError()) + } +} + +type EventIDFormat int + +const ( + // EventIDFormatCustom is the original format used by room v1 and v2. + // Event IDs in this format are an arbitrary string followed by a colon and the server name. + EventIDFormatCustom EventIDFormat = 0 + // EventIDFormatBase64 is the format used by room v3 introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/1659. + // Event IDs in this format are the standard unpadded base64-encoded SHA256 reference hash of the event. + EventIDFormatBase64 EventIDFormat = 1 + // EventIDFormatURLSafeBase64 is the format used by room v4 and later introduced by https://github.com/matrix-org/matrix-spec-proposals/pull/2002. + // Event IDs in this format are the url-safe unpadded base64-encoded SHA256 reference hash of the event. + EventIDFormatURLSafeBase64 EventIDFormat = 2 +) + +// EventIDFormat returns the format of event IDs used by this room version. +func (rv RoomVersion) EventIDFormat() EventIDFormat { + switch rv { + case RoomV0, RoomV1, RoomV2: + return EventIDFormatCustom + case RoomV3: + return EventIDFormatBase64 + default: + return EventIDFormatURLSafeBase64 + } +} + +///////////////////// +// Room v5 changes // +///////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/2077 + +// EnforceSigningKeyValidity returns true if the `valid_until_ts` field of federation signing keys +// must be enforced on received events. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2076 +func (rv RoomVersion) EnforceSigningKeyValidity() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4) +} + +///////////////////// +// Room v6 changes // +///////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/2240 + +// SpecialCasedAliasesAuth returns true if the `m.room.aliases` event authorization is special cased +// to only always allow servers to modify the state event with their own server name as state key. +// This also implies that the `aliases` field is protected from redactions. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2432 +func (rv RoomVersion) SpecialCasedAliasesAuth() bool { + return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5) +} + +// ForbidFloatsAndBigInts returns true if floats and integers greater than 2^53-1 or lower than -2^53+1 are forbidden everywhere. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2540 +func (rv RoomVersion) ForbidFloatsAndBigInts() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5) +} + +// NotificationsPowerLevels returns true if the `notifications` field in `m.room.power_levels` is validated in event auth. +// However, the field is not protected from redactions. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2209 +func (rv RoomVersion) NotificationsPowerLevels() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5) +} + +///////////////////// +// Room v7 changes // +///////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/2998 + +// Knocks returns true if the `knock` join rule is supported. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2403 +func (rv RoomVersion) Knocks() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6) +} + +///////////////////// +// Room v8 changes // +///////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/3289 + +// RestrictedJoins returns true if the `restricted` join rule is supported. +// This also implies that the `allow` field in the `m.room.join_rules` event is supported and protected from redactions. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/3083 +func (rv RoomVersion) RestrictedJoins() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7) +} + +///////////////////// +// Room v9 changes // +///////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/3375 + +// RestrictedJoinsFix returns true if the `join_authorised_via_users_server` field in `m.room.member` events is protected from redactions. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/3375 +func (rv RoomVersion) RestrictedJoinsFix() bool { + return rv.RestrictedJoins() && rv != RoomV8 +} + +////////////////////// +// Room v10 changes // +////////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/3604 + +// ValidatePowerLevelInts returns true if the known values in `m.room.power_levels` must be integers (and not strings). +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/3667 +func (rv RoomVersion) ValidatePowerLevelInts() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9) +} + +// KnockRestricted returns true if the `knock_restricted` join rule is supported. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/3787 +func (rv RoomVersion) KnockRestricted() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9) +} + +////////////////////// +// Room v11 changes // +////////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/3820 + +// CreatorInContent returns true if the `m.room.create` event has a `creator` field in content. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2175 +func (rv RoomVersion) CreatorInContent() bool { + return rv.Equals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10) +} + +// RedactsInContent returns true if the `m.room.redaction` event has the `redacts` field in content instead of at the top level. +// The redaction protection is also moved from the top level to the content field. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2174 +// (and https://github.com/matrix-org/matrix-spec-proposals/pull/2176 for the redaction protection). +func (rv RoomVersion) RedactsInContent() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10) +} + +// UpdatedRedactionRules returns true if various updates to the redaction algorithm are applied. +// +// Specifically: +// +// * the `membership`, `origin`, and `prev_state` fields at the top level of all events are no longer protected. +// * the entire content of `m.room.create` is protected. +// * the `redacts` field in `m.room.redaction` content is protected instead of the top-level field. +// * the `m.room.power_levels` event protects the `invite` field in content. +// * the `signed` field inside the `third_party_invite` field in content of `m.room.member` events is protected. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/2176, +// https://github.com/matrix-org/matrix-spec-proposals/pull/3821, and +// https://github.com/matrix-org/matrix-spec-proposals/pull/3989 +func (rv RoomVersion) UpdatedRedactionRules() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10) +} + +////////////////////// +// Room v12 changes // +////////////////////// +// https://github.com/matrix-org/matrix-spec-proposals/pull/4304 + +// Return value of StateResVersion was changed to StateResV2_1 + +// PrivilegedRoomCreators returns true if the creator(s) of a room always have infinite power level. +// This also implies that the `m.room.create` event has an `additional_creators` field, +// and that the creators can't be present in the `m.room.power_levels` event. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/4289 +func (rv RoomVersion) PrivilegedRoomCreators() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11) +} + +// RoomIDIsCreateEventID returns true if the ID of rooms is the same as the ID of the `m.room.create` event. +// This also implies that `m.room.create` events do not have a `room_id` field. +// +// See https://github.com/matrix-org/matrix-spec-proposals/pull/4291 +func (rv RoomVersion) RoomIDIsCreateEventID() bool { + return rv.NotEquals(RoomV0, RoomV1, RoomV2, RoomV3, RoomV4, RoomV5, RoomV6, RoomV7, RoomV8, RoomV9, RoomV10, RoomV11) +} diff --git a/requests.go b/requests.go index 17eda7d2..eade0757 100644 --- a/requests.go +++ b/requests.go @@ -120,7 +120,7 @@ type ReqCreateRoom struct { InitialState []*event.Event `json:"initial_state,omitempty"` Preset string `json:"preset,omitempty"` IsDirect bool `json:"is_direct,omitempty"` - RoomVersion event.RoomVersion `json:"room_version,omitempty"` + RoomVersion id.RoomVersion `json:"room_version,omitempty"` PowerLevelOverride *event.PowerLevelsEventContent `json:"power_level_content_override,omitempty"` diff --git a/responses.go b/responses.go index 2e8005d4..27d96ffe 100644 --- a/responses.go +++ b/responses.go @@ -221,14 +221,14 @@ type RespMutualRooms struct { type RespRoomSummary struct { PublicRoomInfo - Membership event.Membership `json:"membership,omitempty"` - RoomVersion event.RoomVersion `json:"room_version,omitempty"` - Encryption id.Algorithm `json:"encryption,omitempty"` - AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"` + Membership event.Membership `json:"membership,omitempty"` + RoomVersion id.RoomVersion `json:"room_version,omitempty"` + Encryption id.Algorithm `json:"encryption,omitempty"` + AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"` - UnstableRoomVersion event.RoomVersion `json:"im.nheko.summary.room_version,omitempty"` - UnstableRoomVersionOld event.RoomVersion `json:"im.nheko.summary.version,omitempty"` - UnstableEncryption id.Algorithm `json:"im.nheko.summary.encryption,omitempty"` + UnstableRoomVersion id.RoomVersion `json:"im.nheko.summary.room_version,omitempty"` + UnstableRoomVersionOld id.RoomVersion `json:"im.nheko.summary.version,omitempty"` + UnstableEncryption id.Algorithm `json:"im.nheko.summary.encryption,omitempty"` } // RespRegisterAvailable is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv3registeravailable From 1215f6237ee7d517f36d5159298ad6fa65391c5e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 3 Aug 2025 15:16:24 +0300 Subject: [PATCH 1280/1647] event: fix json tag in power levels --- event/powerlevels.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/event/powerlevels.go b/event/powerlevels.go index 79dbd1f3..50df2c1f 100644 --- a/event/powerlevels.go +++ b/event/powerlevels.go @@ -39,7 +39,7 @@ type PowerLevelsEventContent struct { // This is not a part of power levels, it's added by mautrix-go internally in certain places // in order to detect creator power accurately. - CreateEvent *Event `json:"-,omitempty"` + CreateEvent *Event `json:"-"` } func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { From 7a791e908c99fc61f378d47a92650783fa3cd316 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 3 Aug 2025 20:37:11 +0300 Subject: [PATCH 1281/1647] federation: extract VerifyJSON into subpackage --- federation/client.go | 3 +- federation/signingkey.go | 53 +++------------------------- federation/signutil/verify.go | 65 +++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 50 deletions(-) create mode 100644 federation/signutil/verify.go diff --git a/federation/client.go b/federation/client.go index 7c460d44..8f454516 100644 --- a/federation/client.go +++ b/federation/client.go @@ -21,6 +21,7 @@ import ( "go.mau.fi/util/jsontime" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/federation/signutil" "maunium.net/go/mautrix/id" ) @@ -404,7 +405,7 @@ func (r *signableRequest) Verify(key id.SigningKey, sig string) error { if err != nil { return fmt.Errorf("failed to marshal data: %w", err) } - return VerifyJSONRaw(key, sig, message) + return signutil.VerifyJSONRaw(key, sig, message) } func (r *signableRequest) Sign(key *SigningKey) (string, error) { diff --git a/federation/signingkey.go b/federation/signingkey.go index 0ae6a571..a4ad9679 100644 --- a/federation/signingkey.go +++ b/federation/signingkey.go @@ -10,17 +10,15 @@ import ( "crypto/ed25519" "encoding/base64" "encoding/json" - "errors" "fmt" "strings" "time" - "github.com/tidwall/gjson" "github.com/tidwall/sjson" - "go.mau.fi/util/exgjson" "go.mau.fi/util/jsontime" "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/federation/signutil" "maunium.net/go/mautrix/id" ) @@ -35,8 +33,8 @@ type SigningKey struct { // // The output of this function can be parsed back into a [SigningKey] using the [ParseSynapseKey] function. func (sk *SigningKey) SynapseString() string { - alg, id := sk.ID.Parse() - return fmt.Sprintf("%s %s %s", alg, id, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed())) + alg, keyID := sk.ID.Parse() + return fmt.Sprintf("%s %s %s", alg, keyID, base64.RawStdEncoding.EncodeToString(sk.Priv.Seed())) } // ParseSynapseKey parses a Synapse-compatible private key string into a SigningKey. @@ -100,56 +98,13 @@ func (skr *ServerKeyResponse) HasKey(keyID id.KeyID) bool { func (skr *ServerKeyResponse) VerifySelfSignature() error { for keyID, key := range skr.VerifyKeys { - if err := VerifyJSON(skr.ServerName, keyID, key.Key, skr.Raw); err != nil { + if err := signutil.VerifyJSON(skr.ServerName, keyID, key.Key, skr.Raw); err != nil { return fmt.Errorf("failed to verify self signature for key %s: %w", keyID, err) } } return nil } -func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) error { - var err error - message, ok := data.(json.RawMessage) - if !ok { - message, err = json.Marshal(data) - if err != nil { - return fmt.Errorf("failed to marshal data: %w", err) - } - } - sigVal := gjson.GetBytes(message, exgjson.Path("signatures", serverName, string(keyID))) - if sigVal.Type != gjson.String { - return ErrSignatureNotFound - } - message, err = sjson.DeleteBytes(message, "signatures") - if err != nil { - return fmt.Errorf("failed to delete signatures: %w", err) - } - message, err = sjson.DeleteBytes(message, "unsigned") - if err != nil { - return fmt.Errorf("failed to delete unsigned: %w", err) - } - return VerifyJSONRaw(key, sigVal.Str, message) -} - -var ErrSignatureNotFound = errors.New("signature not found") -var ErrInvalidSignature = errors.New("invalid signature") - -func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) error { - sigBytes, err := base64.RawStdEncoding.DecodeString(sig) - if err != nil { - return fmt.Errorf("failed to decode signature: %w", err) - } - keyBytes, err := base64.RawStdEncoding.DecodeString(string(key)) - if err != nil { - return fmt.Errorf("failed to decode key: %w", err) - } - message = canonicaljson.CanonicalJSONAssumeValid(message) - if !ed25519.Verify(keyBytes, message, sigBytes) { - return ErrInvalidSignature - } - return nil -} - type marshalableSKR ServerKeyResponse func (skr *ServerKeyResponse) UnmarshalJSON(data []byte) error { diff --git a/federation/signutil/verify.go b/federation/signutil/verify.go new file mode 100644 index 00000000..8fe55b2f --- /dev/null +++ b/federation/signutil/verify.go @@ -0,0 +1,65 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package signutil + +import ( + "crypto/ed25519" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.mau.fi/util/exgjson" + + "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/id" +) + +var ErrSignatureNotFound = errors.New("signature not found") +var ErrInvalidSignature = errors.New("invalid signature") + +func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) error { + var err error + message, ok := data.(json.RawMessage) + if !ok { + message, err = json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal data: %w", err) + } + } + sigVal := gjson.GetBytes(message, exgjson.Path("signatures", serverName, string(keyID))) + if sigVal.Type != gjson.String { + return ErrSignatureNotFound + } + message, err = sjson.DeleteBytes(message, "signatures") + if err != nil { + return fmt.Errorf("failed to delete signatures: %w", err) + } + message, err = sjson.DeleteBytes(message, "unsigned") + if err != nil { + return fmt.Errorf("failed to delete unsigned: %w", err) + } + return VerifyJSONRaw(key, sigVal.Str, message) +} + +func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) error { + sigBytes, err := base64.RawStdEncoding.DecodeString(sig) + if err != nil { + return fmt.Errorf("failed to decode signature: %w", err) + } + keyBytes, err := base64.RawStdEncoding.DecodeString(string(key)) + if err != nil { + return fmt.Errorf("failed to decode key: %w", err) + } + message = canonicaljson.CanonicalJSONAssumeValid(message) + if !ed25519.Verify(keyBytes, message, sigBytes) { + return ErrInvalidSignature + } + return nil +} From 90e3427ac519001ce1ce917cb59be5c6c7f85a89 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 7 Aug 2025 12:08:25 +0300 Subject: [PATCH 1282/1647] bridgev2: check that avatar mxc is set before ignoring update --- bridgev2/ghost.go | 4 ++-- bridgev2/portal.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index f06c0363..6cef6f06 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -158,7 +158,7 @@ func (ghost *Ghost) UpdateName(ctx context.Context, name string) bool { } func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool { - if ghost.AvatarID == avatar.ID && ghost.AvatarSet { + if ghost.AvatarID == avatar.ID && (avatar.Remove || ghost.AvatarMXC != "") && ghost.AvatarSet { return false } ghost.AvatarID = avatar.ID @@ -168,7 +168,7 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool { ghost.AvatarSet = false zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload avatar") return true - } else if newHash == ghost.AvatarHash && ghost.AvatarSet { + } else if newHash == ghost.AvatarHash && ghost.AvatarMXC != "" && ghost.AvatarSet { return true } ghost.AvatarHash = newHash diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 25865080..d343a651 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3533,7 +3533,7 @@ func (portal *Portal) updateTopic(ctx context.Context, topic string, sender Matr } func (portal *Portal) updateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool { - if portal.AvatarID == avatar.ID && (portal.AvatarSet || portal.MXID == "") { + if portal.AvatarID == avatar.ID && (avatar.Remove || portal.AvatarMXC != "") && (portal.AvatarSet || portal.MXID == "") { return false } portal.AvatarID = avatar.ID @@ -3549,7 +3549,7 @@ func (portal *Portal) updateAvatar(ctx context.Context, avatar *Avatar, sender M portal.AvatarSet = false zerolog.Ctx(ctx).Err(err).Msg("Failed to reupload room avatar") return true - } else if newHash == portal.AvatarHash && portal.AvatarSet { + } else if newHash == portal.AvatarHash && portal.AvatarMXC != "" && portal.AvatarSet { return true } portal.AvatarMXC = newMXC From 3865abb3b820bfa24b676737f00b7051f299dd31 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 9 Aug 2025 13:10:18 +0300 Subject: [PATCH 1283/1647] dependencies: update go-util and use new UnsafeString helper --- bridgev2/matrix/connector.go | 4 ++-- go.mod | 2 +- go.sum | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index a1f7d140..19eb399b 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -19,12 +19,12 @@ import ( "strings" "sync" "time" - "unsafe" _ "github.com/lib/pq" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" _ "go.mau.fi/util/dbutil/litestream" + "go.mau.fi/util/exbytes" "go.mau.fi/util/exsync" "go.mau.fi/util/random" "golang.org/x/sync/semaphore" @@ -674,7 +674,7 @@ func (br *Connector) GenerateDeterministicEventID(roomID id.RoomID, _ networkid. eventID[1+hashB64Len] = ':' copy(eventID[1+hashB64Len+1:], br.deterministicEventIDServer) - return id.EventID(unsafe.String(unsafe.SliceData(eventID), len(eventID))) + return id.EventID(exbytes.UnsafeString(eventID)) } func (br *Connector) GenerateDeterministicRoomID(key networkid.PortalKey) id.RoomID { diff --git a/go.mod b/go.mod index 1133313f..a8c1f26d 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.12 - go.mau.fi/util v0.8.9-0.20250723171559-474867266038 + go.mau.fi/util v0.8.9-0.20250808135321-09699c48d2fa go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.40.0 golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 diff --git a/go.sum b/go.sum index 461ee542..4f9bfaeb 100644 --- a/go.sum +++ b/go.sum @@ -51,8 +51,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.12 h1:YwGP/rrea2/CnCtUHgjuolG/PnMxdQtPMO5PvaE2/nY= github.com/yuin/goldmark v1.7.12/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.8.9-0.20250723171559-474867266038 h1:RVL8TVaYc3LTBBopfjCNDtD+6eZks0O+qgXN/9hsz7k= -go.mau.fi/util v0.8.9-0.20250723171559-474867266038/go.mod h1:GZZp5f9r2MgEu4GDvtB0XxCF7i6Z7Z8fM0w9a5oZH3Y= +go.mau.fi/util v0.8.9-0.20250808135321-09699c48d2fa h1:xVnyD0gaIvK+7xA5lWSqWJf5EB2URW2Y0R4ABisAHD0= +go.mau.fi/util v0.8.9-0.20250808135321-09699c48d2fa/go.mod h1:GZZp5f9r2MgEu4GDvtB0XxCF7i6Z7Z8fM0w9a5oZH3Y= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= From 135cffc7c1b9d7796f72bc5b5d9f335395002795 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 9 Aug 2025 13:10:31 +0300 Subject: [PATCH 1284/1647] requests: add json un/marshaler for Direction rune --- requests.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/requests.go b/requests.go index eade0757..8f31e52f 100644 --- a/requests.go +++ b/requests.go @@ -2,6 +2,7 @@ package mautrix import ( "encoding/json" + "fmt" "strconv" "time" @@ -39,6 +40,26 @@ const ( type Direction rune +func (d Direction) MarshalJSON() ([]byte, error) { + return json.Marshal(string(d)) +} + +func (d *Direction) UnmarshalJSON(data []byte) error { + var str string + if err := json.Unmarshal(data, &str); err != nil { + return err + } + switch str { + case "f": + *d = DirectionForward + case "b": + *d = DirectionBackward + default: + return fmt.Errorf("invalid direction %q, must be 'f' or 'b'", str) + } + return nil +} + const ( DirectionForward Direction = 'f' DirectionBackward Direction = 'b' From 87d599c491fe4ce8132d798939fac540328e3277 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 9 Aug 2025 17:42:34 +0300 Subject: [PATCH 1285/1647] crypto: remove group session already shared error --- crypto/encryptmegolm.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 804e15de..14ba2449 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -25,7 +25,6 @@ import ( ) var ( - AlreadyShared = errors.New("group session already shared") NoGroupSession = errors.New("no group session created") ) @@ -209,7 +208,8 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, if err != nil { return fmt.Errorf("failed to get previous outbound group session: %w", err) } else if session != nil && session.Shared && !session.Expired() { - return AlreadyShared + mach.machOrContextLog(ctx).Debug().Stringer("room_id", roomID).Msg("Not re-sharing group session, already shared") + return nil } log := mach.machOrContextLog(ctx).With(). Str("room_id", roomID.String()). From 6ea2337283856db17207b422ab1d1daf9a5ad676 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 10 Aug 2025 23:22:25 +0300 Subject: [PATCH 1286/1647] event: add policy server spammy flag to unsigned --- event/events.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/event/events.go b/event/events.go index 1a57fb4b..72c1e161 100644 --- a/event/events.go +++ b/event/events.go @@ -146,7 +146,8 @@ type Unsigned struct { BeeperHSOrderString *BeeperEncodedOrder `json:"com.beeper.hs.order_string,omitempty"` BeeperFromBackup bool `json:"com.beeper.from_backup,omitempty"` - ElementSoftFailed bool `json:"io.element.synapse.soft_failed,omitempty"` + ElementSoftFailed bool `json:"io.element.synapse.soft_failed,omitempty"` + ElementPolicyServerSpammy bool `json:"io.element.synapse.policy_server_spammy,omitempty"` } func (us *Unsigned) IsEmpty() bool { From 78aea00999ceb8a9440f411d0e5a79d73116bf32 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 10 Aug 2025 23:23:15 +0300 Subject: [PATCH 1287/1647] format/htmlparser: collapse spaces when parsing html --- format/htmlparser.go | 3 ++- go.mod | 2 +- go.sum | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/format/htmlparser.go b/format/htmlparser.go index b4b1b9a4..e5f92896 100644 --- a/format/htmlparser.go +++ b/format/htmlparser.go @@ -13,6 +13,7 @@ import ( "strconv" "strings" + "go.mau.fi/util/exstrings" "golang.org/x/net/html" "maunium.net/go/mautrix/event" @@ -371,7 +372,7 @@ func (parser *HTMLParser) singleNodeToString(node *html.Node, ctx Context) Tagge switch node.Type { case html.TextNode: if !ctx.PreserveWhitespace { - node.Data = strings.Replace(node.Data, "\n", "", -1) + node.Data = exstrings.CollapseSpaces(strings.ReplaceAll(node.Data, "\n", "")) } if parser.TextConverter != nil { node.Data = parser.TextConverter(node.Data, ctx) diff --git a/go.mod b/go.mod index a8c1f26d..c109af31 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.12 - go.mau.fi/util v0.8.9-0.20250808135321-09699c48d2fa + go.mau.fi/util v0.8.9-0.20250810202017-1d053aac320a go.mau.fi/zeroconfig v0.1.3 golang.org/x/crypto v0.40.0 golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 diff --git a/go.sum b/go.sum index 4f9bfaeb..dae44df7 100644 --- a/go.sum +++ b/go.sum @@ -51,8 +51,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.12 h1:YwGP/rrea2/CnCtUHgjuolG/PnMxdQtPMO5PvaE2/nY= github.com/yuin/goldmark v1.7.12/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.8.9-0.20250808135321-09699c48d2fa h1:xVnyD0gaIvK+7xA5lWSqWJf5EB2URW2Y0R4ABisAHD0= -go.mau.fi/util v0.8.9-0.20250808135321-09699c48d2fa/go.mod h1:GZZp5f9r2MgEu4GDvtB0XxCF7i6Z7Z8fM0w9a5oZH3Y= +go.mau.fi/util v0.8.9-0.20250810202017-1d053aac320a h1:AviXwC+XRYNvlmLieSQxBjj5/K5JUIjBgduYNVSrPTo= +go.mau.fi/util v0.8.9-0.20250810202017-1d053aac320a/go.mod h1:GZZp5f9r2MgEu4GDvtB0XxCF7i6Z7Z8fM0w9a5oZH3Y= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= From 23df81f1cccde05094a7b26aabedff9f1e987be4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 11 Aug 2025 10:45:32 +0300 Subject: [PATCH 1288/1647] crypto/attachments: fix hash check when decrypting --- crypto/attachment/attachments.go | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/crypto/attachment/attachments.go b/crypto/attachment/attachments.go index cfa1c3e5..65c76f5a 100644 --- a/crypto/attachment/attachments.go +++ b/crypto/attachment/attachments.go @@ -9,6 +9,7 @@ package attachment import ( "crypto/aes" "crypto/cipher" + "crypto/hmac" "crypto/sha256" "encoding/base64" "errors" @@ -217,9 +218,7 @@ func (r *encryptingReader) Close() (err error) { err = closer.Close() } if r.isDecrypting { - var downloadedChecksum [utils.SHAHashLength]byte - r.hash.Sum(downloadedChecksum[:]) - if downloadedChecksum != r.file.decoded.sha256 { + if !hmac.Equal(r.hash.Sum(nil), r.file.decoded.sha256[:]) { return HashMismatch } } else { @@ -274,12 +273,13 @@ func (ef *EncryptedFile) PrepareForDecryption() error { func (ef *EncryptedFile) DecryptInPlace(data []byte) error { if err := ef.PrepareForDecryption(); err != nil { return err - } else if ef.decoded.sha256 != sha256.Sum256(data) { - return HashMismatch - } else { - utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv) - return nil } + dataHash := sha256.Sum256(data) + if !hmac.Equal(ef.decoded.sha256[:], dataHash[:]) { + return HashMismatch + } + utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv) + return nil } // DecryptStream wraps the given io.Reader in order to decrypt the data. @@ -292,9 +292,10 @@ func (ef *EncryptedFile) DecryptInPlace(data []byte) error { func (ef *EncryptedFile) DecryptStream(reader io.Reader) io.ReadSeekCloser { block, _ := aes.NewCipher(ef.decoded.key[:]) return &encryptingReader{ - stream: cipher.NewCTR(block, ef.decoded.iv[:]), - hash: sha256.New(), - source: reader, - file: ef, + isDecrypting: true, + stream: cipher.NewCTR(block, ef.decoded.iv[:]), + hash: sha256.New(), + source: reader, + file: ef, } } From 5d84bddc62e658c946fbba67265fe13d1e2705b2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 11 Aug 2025 10:58:24 +0300 Subject: [PATCH 1289/1647] crypto/attachments: hash correct data while decrypting --- crypto/attachment/attachments.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/crypto/attachment/attachments.go b/crypto/attachment/attachments.go index 65c76f5a..155cca5c 100644 --- a/crypto/attachment/attachments.go +++ b/crypto/attachment/attachments.go @@ -207,8 +207,13 @@ func (r *encryptingReader) Read(dst []byte) (n int, err error) { } } n, err = r.source.Read(dst) + if r.isDecrypting { + r.hash.Write(dst[:n]) + } r.stream.XORKeyStream(dst[:n], dst[:n]) - r.hash.Write(dst[:n]) + if !r.isDecrypting { + r.hash.Write(dst[:n]) + } return } From 7dcd45eba21c1740e63bdf34462d84edb4321558 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 12 Aug 2025 23:50:50 +0300 Subject: [PATCH 1290/1647] changelog: update --- CHANGELOG.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e71381e..08749f34 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,36 @@ +## v0.25.0 (unreleased) + +* **Breaking change *(appservice,bridgev2,federation)** Replaced gorilla/mux + with standard library ServeMux. +* *(client,bridgev2)* Added support for creator power in room v12. +* *(bridgev2)* Added support for following tombstones. +* *(bridgev2)* Added interface for getting arbitrary state event from Matrix. +* *(bridgev2)* Added batching to disappearing message queue to ensure it doesn't + use too many resources even if there are a large number of messages. +* *(bridgev2/commands)* Added support for canceling QR login with `cancel` + command. +* *(client)* Added option to override HTTP client used for .well-known + resolution. +* *(crypto/backup)* Added method for encrypting key backup session without + private keys. +* *(event->id)* Moved room version type and constants to id package. +* *(bridgev2)* Bots in DM portals will now be added to the functional members + state event to hide them from the room name calculation. +* *(bridgev2)* Changed message delete handling to ignore "delete for me" events + if there are multiple Matrix users in the room. +* *(format/htmlparser)* Changed text processing to collapse multiple spaces into + one when outside `pre`/`code` tags. +* *(format/htmlparser)* Removed link suffix in plaintext output when link text + is only missing protocol part of href. + * e.g. `example.com` will turn into + `example.com` rather than `example.com (https://example.com)` +* *(appservice)* Switched appservice websockets from gorilla/websocket to + coder/websocket. +* *(bridgev2/matrix)* Fixed encryption key sharing not ignoring ghosts properly. +* *(crypto/attachments)* Fixed hash check when decrypting file streams. +* *(crypto)* Removed unnecessary `AlreadyShared` error in `ShareGroupSession`. + The function will now act as if it was successful instead. + ## v0.24.2 (2025-07-16) * *(bridgev2)* Added support for return values from portal event handlers. Note From 809333fcc57269669f4589a65c73fcb505cb7049 Mon Sep 17 00:00:00 2001 From: V02460 Date: Wed, 13 Aug 2025 19:32:21 +0200 Subject: [PATCH 1291/1647] verificationhelper: use static format strings (#390) --- crypto/verificationhelper/verificationhelper.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/verificationhelper/verificationhelper.go b/crypto/verificationhelper/verificationhelper.go index 9d843ea8..0a781c16 100644 --- a/crypto/verificationhelper/verificationhelper.go +++ b/crypto/verificationhelper/verificationhelper.go @@ -848,7 +848,7 @@ func (vh *VerificationHelper) onVerificationStart(ctx context.Context, txn Verif // here, since the start command for scanning and showing QR codes // should be of type m.reciprocate.v1. log.Error().Str("method", string(txn.StartEventContent.Method)).Msg("Unsupported verification method in start event") - vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownMethod, fmt.Sprintf("unknown method %s", txn.StartEventContent.Method)) + vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeUnknownMethod, "unknown method %s", txn.StartEventContent.Method) } } From ee869b97e6c241819579ca9e84291d4835af7d8b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 13 Aug 2025 20:33:04 +0300 Subject: [PATCH 1292/1647] dependencies: update --- .github/workflows/go.yml | 8 ++++---- CHANGELOG.md | 1 + go.mod | 22 +++++++++++----------- go.sum | 36 ++++++++++++++++++------------------ 4 files changed, 34 insertions(+), 33 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 71c1988b..3cf412b4 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -34,8 +34,8 @@ jobs: strategy: fail-fast: false matrix: - go-version: ["1.23", "1.24"] - name: Build (${{ matrix.go-version == '1.24' && 'latest' || 'old' }}, libolm) + go-version: ["1.24", "1.25"] + name: Build (${{ matrix.go-version == '1.25' && 'latest' || 'old' }}, libolm) steps: - uses: actions/checkout@v4 @@ -65,8 +65,8 @@ jobs: strategy: fail-fast: false matrix: - go-version: ["1.23", "1.24"] - name: Build (${{ matrix.go-version == '1.24' && 'latest' || 'old' }}, goolm) + go-version: ["1.24", "1.25"] + name: Build (${{ matrix.go-version == '1.25' && 'latest' || 'old' }}, goolm) steps: - uses: actions/checkout@v4 diff --git a/CHANGELOG.md b/CHANGELOG.md index 08749f34..6e87ffdf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ ## v0.25.0 (unreleased) +* Bumped minimum Go version to 1.24. * **Breaking change *(appservice,bridgev2,federation)** Replaced gorilla/mux with standard library ServeMux. * *(client,bridgev2)* Added support for creator power in room v12. diff --git a/go.mod b/go.mod index c109af31..5e351b73 100644 --- a/go.mod +++ b/go.mod @@ -1,27 +1,27 @@ module maunium.net/go/mautrix -go 1.23.0 +go 1.24.0 -toolchain go1.24.5 +toolchain go1.25.0 require ( filippo.io/edwards25519 v1.1.0 github.com/chzyer/readline v1.5.1 github.com/coder/websocket v1.8.13 github.com/lib/pq v1.10.9 - github.com/mattn/go-sqlite3 v1.14.28 + github.com/mattn/go-sqlite3 v1.14.31 github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.34.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stretchr/testify v1.10.0 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.7.12 - go.mau.fi/util v0.8.9-0.20250810202017-1d053aac320a + github.com/yuin/goldmark v1.7.13 + go.mau.fi/util v0.8.9-0.20250813172851-79bf3eba563d go.mau.fi/zeroconfig v0.1.3 - golang.org/x/crypto v0.40.0 - golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 - golang.org/x/net v0.42.0 + golang.org/x/crypto v0.41.0 + golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 + golang.org/x/net v0.43.0 golang.org/x/sync v0.16.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 @@ -32,11 +32,11 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/petermattis/goid v0.0.0-20250721140440-ea1c0173183e // indirect + github.com/petermattis/goid v0.0.0-20250813065127-a731cc31b4fe // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - golang.org/x/sys v0.34.0 // indirect - golang.org/x/text v0.27.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index dae44df7..dd517202 100644 --- a/go.sum +++ b/go.sum @@ -24,10 +24,10 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A= -github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/petermattis/goid v0.0.0-20250721140440-ea1c0173183e h1:D0bJD+4O3G4izvrQUmzCL80zazlN7EwJ0PPDhpJWC/I= -github.com/petermattis/goid v0.0.0-20250721140440-ea1c0173183e/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/mattn/go-sqlite3 v1.14.31 h1:ldt6ghyPJsokUIlksH63gWZkG6qVGeEAu4zLeS4aVZM= +github.com/mattn/go-sqlite3 v1.14.31/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/petermattis/goid v0.0.0-20250813065127-a731cc31b4fe h1:vHpqOnPlnkba8iSxU4j/CvDSS9J4+F4473esQsYLGoE= +github.com/petermattis/goid v0.0.0-20250813065127-a731cc31b4fe/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -49,28 +49,28 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.12 h1:YwGP/rrea2/CnCtUHgjuolG/PnMxdQtPMO5PvaE2/nY= -github.com/yuin/goldmark v1.7.12/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.8.9-0.20250810202017-1d053aac320a h1:AviXwC+XRYNvlmLieSQxBjj5/K5JUIjBgduYNVSrPTo= -go.mau.fi/util v0.8.9-0.20250810202017-1d053aac320a/go.mod h1:GZZp5f9r2MgEu4GDvtB0XxCF7i6Z7Z8fM0w9a5oZH3Y= +github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= +github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= +go.mau.fi/util v0.8.9-0.20250813172851-79bf3eba563d h1:lQCuHA1gVXxIxRhkNSOmhNWuqa8XwMX1mynD6IUELuk= +go.mau.fi/util v0.8.9-0.20250813172851-79bf3eba563d/go.mod h1:FtuGEQbVcfzQpTMDclFsq0NQ9GMtB2Gkd54Uq+TmsMk= go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= -golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= -golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= -golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 h1:R9PFI6EUdfVKgwKjZef7QIwGcBKu86OEFpJ9nUEP2l4= -golang.org/x/exp v0.0.0-20250718183923-645b1fa84792/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc= -golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= -golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= +golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 h1:SbTAbRFnd5kjQXbczszQ0hdk3ctwYf3qBNH9jIsGclE= +golang.org/x/exp v0.0.0-20250813145105-42675adae3e6/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= -golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= -golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From cd022c9010d5036f54831b57923b653cf91b1bc8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 15 Aug 2025 16:45:18 +0300 Subject: [PATCH 1293/1647] client: don't set user-agent header on wasm --- client.go | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/client.go b/client.go index 53ac6e10..4906169f 100644 --- a/client.go +++ b/client.go @@ -13,6 +13,7 @@ import ( "net/http" "net/url" "os" + "runtime" "slices" "strconv" "strings" @@ -154,8 +155,10 @@ func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serve return nil, err } - req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", DefaultUserAgent+" (.well-known fetcher)") + if runtime.GOOS != "js" { + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", DefaultUserAgent+" (.well-known fetcher)") + } resp, err := client.Do(req) if err != nil { @@ -516,7 +519,9 @@ func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullReque params.Handler = handleNormalResponse } } - req.Header.Set("User-Agent", cli.UserAgent) + if cli.UserAgent != "" { + req.Header.Set("User-Agent", cli.UserAgent) + } if len(cli.AccessToken) > 0 { req.Header.Set("Authorization", "Bearer "+cli.AccessToken) } @@ -1803,7 +1808,9 @@ func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType str } req.ContentLength = contentLength req.Header.Set("Content-Type", contentType) - req.Header.Set("User-Agent", cli.UserAgent+" (external media uploader)") + if cli.UserAgent != "" { + req.Header.Set("User-Agent", cli.UserAgent+" (external media uploader)") + } if cli.ExternalClient != nil { return cli.ExternalClient.Do(req) From 0bbfafe02f50c5d4d641f2f95e35f48939fffced Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 16 Aug 2025 13:12:52 +0300 Subject: [PATCH 1294/1647] Bump version to v0.25.0 --- CHANGELOG.md | 4 +++- go.mod | 6 +++--- go.sum | 12 ++++++------ version.go | 2 +- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e87ffdf..22ff47f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,11 @@ ## v0.25.0 (unreleased) * Bumped minimum Go version to 1.24. -* **Breaking change *(appservice,bridgev2,federation)** Replaced gorilla/mux +* **Breaking change *(appservice,bridgev2,federation)*** Replaced gorilla/mux with standard library ServeMux. * *(client,bridgev2)* Added support for creator power in room v12. +* *(client)* Added option to not set `User-Agent` header for improved Wasm + compatibility. * *(bridgev2)* Added support for following tombstones. * *(bridgev2)* Added interface for getting arbitrary state event from Matrix. * *(bridgev2)* Added batching to disappearing message queue to ensure it doesn't diff --git a/go.mod b/go.mod index 5e351b73..4abdc4ff 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/chzyer/readline v1.5.1 github.com/coder/websocket v1.8.13 github.com/lib/pq v1.10.9 - github.com/mattn/go-sqlite3 v1.14.31 + github.com/mattn/go-sqlite3 v1.14.32 github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.34.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e @@ -17,8 +17,8 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.13 - go.mau.fi/util v0.8.9-0.20250813172851-79bf3eba563d - go.mau.fi/zeroconfig v0.1.3 + go.mau.fi/util v0.9.0 + go.mau.fi/zeroconfig v0.2.0 golang.org/x/crypto v0.41.0 golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 golang.org/x/net v0.43.0 diff --git a/go.sum b/go.sum index dd517202..bb5d5cdb 100644 --- a/go.sum +++ b/go.sum @@ -24,8 +24,8 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.31 h1:ldt6ghyPJsokUIlksH63gWZkG6qVGeEAu4zLeS4aVZM= -github.com/mattn/go-sqlite3 v1.14.31/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= +github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/petermattis/goid v0.0.0-20250813065127-a731cc31b4fe h1:vHpqOnPlnkba8iSxU4j/CvDSS9J4+F4473esQsYLGoE= github.com/petermattis/goid v0.0.0-20250813065127-a731cc31b4fe/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -51,10 +51,10 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.8.9-0.20250813172851-79bf3eba563d h1:lQCuHA1gVXxIxRhkNSOmhNWuqa8XwMX1mynD6IUELuk= -go.mau.fi/util v0.8.9-0.20250813172851-79bf3eba563d/go.mod h1:FtuGEQbVcfzQpTMDclFsq0NQ9GMtB2Gkd54Uq+TmsMk= -go.mau.fi/zeroconfig v0.1.3 h1:As9wYDKmktjmNZW5i1vn8zvJlmGKHeVxHVIBMXsm4kM= -go.mau.fi/zeroconfig v0.1.3/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= +go.mau.fi/util v0.9.0 h1:ya3s3pX+Y8R2fgp0DbE7a0o3FwncoelDX5iyaeVE8ls= +go.mau.fi/util v0.9.0/go.mod h1:pdL3lg2aaeeHIreGXNnPwhJPXkXdc3ZxsI6le8hOWEA= +go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= +go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 h1:SbTAbRFnd5kjQXbczszQ0hdk3ctwYf3qBNH9jIsGclE= diff --git a/version.go b/version.go index 6b8af5ef..fd0d0a8d 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.24.2" +const Version = "v0.25.0" var GoModVersion = "" var Commit = "" From 2d4850a188fbc027a5e785471cddf3e69c8284f9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 16 Aug 2025 13:21:06 +0300 Subject: [PATCH 1295/1647] changelog: fix date --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 22ff47f7..f8a15550 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -## v0.25.0 (unreleased) +## v0.25.0 (2025-08-16) * Bumped minimum Go version to 1.24. * **Breaking change *(appservice,bridgev2,federation)*** Replaced gorilla/mux From 80c0b950dc0ca367cef4fc56e7eaff9efb8e65c3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 17 Aug 2025 12:52:58 +0300 Subject: [PATCH 1296/1647] federation/pdu: add utilities for PDU generation and validation --- federation/pdu/pdu.go | 215 +++++++++++++++++++++++++++++++++++++ federation/pdu/pdu_test.go | 197 +++++++++++++++++++++++++++++++++ federation/pdu/redact.go | 108 +++++++++++++++++++ federation/pdu/v1.go | 59 ++++++++++ 4 files changed, 579 insertions(+) create mode 100644 federation/pdu/pdu.go create mode 100644 federation/pdu/pdu_test.go create mode 100644 federation/pdu/redact.go create mode 100644 federation/pdu/v1.go diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go new file mode 100644 index 00000000..2ac970dc --- /dev/null +++ b/federation/pdu/pdu.go @@ -0,0 +1,215 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "bytes" + "crypto/ed25519" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json/jsontext" + "encoding/json/v2" + "errors" + "fmt" + "time" + + "github.com/tidwall/gjson" + "go.mau.fi/util/jsonbytes" + + "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/federation/signutil" + "maunium.net/go/mautrix/id" +) + +var ErrPDUIsNil = errors.New("PDU is nil") + +type Hashes struct { + SHA256 jsonbytes.UnpaddedBytes `json:"sha256"` + + Unknown jsontext.Value `json:",unknown"` +} + +type PDU struct { + AuthEvents []id.EventID `json:"auth_events"` + Content jsontext.Value `json:"content"` + Depth int64 `json:"depth"` + Hashes *Hashes `json:"hashes,omitzero"` + OriginServerTS int64 `json:"origin_server_ts"` + PrevEvents []id.EventID `json:"prev_events"` + Redacts *id.EventID `json:"redacts,omitzero"` + RoomID id.RoomID `json:"room_id,omitzero"` // not present for room v12+ create events + Sender id.UserID `json:"sender"` + Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"` + StateKey *string `json:"state_key,omitzero"` + Type string `json:"type"` + Unsigned jsontext.Value `json:"unsigned,omitzero"` + + Unknown jsontext.Value `json:",unknown"` + + // Deprecated legacy fields + DeprecatedPrevState any `json:"prev_state,omitzero"` + DeprecatedOrigin any `json:"origin,omitzero"` + DeprecatedMembership any `json:"membership,omitzero"` +} + +func (pdu *PDU) CalculateContentHash() ([32]byte, error) { + if pdu == nil { + return [32]byte{}, ErrPDUIsNil + } + pduClone := pdu.Clone() + pduClone.Signatures = nil + pduClone.Unsigned = nil + pduClone.Hashes = nil + rawJSON, err := marshalCanonical(pduClone) + if err != nil { + return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err) + } + return sha256.Sum256(rawJSON), nil +} + +func (pdu *PDU) FillContentHash() error { + if pdu == nil { + return ErrPDUIsNil + } else if pdu.Hashes != nil { + return nil + } else if hash, err := pdu.CalculateContentHash(); err != nil { + return err + } else { + pdu.Hashes = &Hashes{SHA256: hash[:]} + return nil + } +} + +func (pdu *PDU) VerifyContentHash() bool { + if pdu == nil || pdu.Hashes == nil { + return false + } + calculatedHash, err := pdu.CalculateContentHash() + if err != nil { + return false + } + return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256) +} + +func (pdu *PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error { + err := pdu.FillContentHash() + if err != nil { + return err + } + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err) + } + signature := ed25519.Sign(privateKey, rawJSON) + if pdu.Signatures == nil { + pdu.Signatures = make(map[string]map[id.KeyID]string) + } + if _, ok := pdu.Signatures[serverName]; !ok { + pdu.Signatures[serverName] = make(map[id.KeyID]string) + } + pdu.Signatures[serverName][keyID] = base64.RawStdEncoding.EncodeToString(signature) + return nil +} + +func marshalCanonical(data any) (jsontext.Value, error) { + marshaledBytes, err := json.Marshal(data) + if err != nil { + return nil, err + } + marshaled := jsontext.Value(marshaledBytes) + err = marshaled.Canonicalize() + if err != nil { + return nil, err + } + check := canonicaljson.CanonicalJSONAssumeValid(marshaled) + if !bytes.Equal(marshaled, check) { + fmt.Println(string(marshaled)) + fmt.Println(string(check)) + return nil, fmt.Errorf("canonical JSON mismatch for %s", string(marshaled)) + } + return marshaled, nil +} + +func (pdu *PDU) VerifySignature( + roomVersion id.RoomVersion, + serverName string, + getKey func(keyID id.KeyID, minValidUntil time.Time) (id.SigningKey, time.Time, error), +) error { + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err) + } + verified := false + for keyID, sig := range pdu.Signatures[serverName] { + originServerTS := time.UnixMilli(pdu.OriginServerTS) + key, validUntil, err := getKey(keyID, originServerTS) + if err != nil { + return fmt.Errorf("failed to get key %s: %w", keyID, err) + } else if key == "" || validUntil.Before(originServerTS) { + continue + } + err = signutil.VerifyJSONRaw(key, sig, rawJSON) + if err != nil { + return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err) + } + verified = true + } + if !verified { + return fmt.Errorf("no verifiable signatures found for server %s", serverName) + } + return nil +} + +func (pdu *PDU) CalculateRoomID() (id.RoomID, error) { + if pdu == nil { + return "", ErrPDUIsNil + } else if pdu.Type != "m.room.create" { + return "", fmt.Errorf("room ID can only be calculated for m.room.create events") + } else if roomVersion := id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str); !roomVersion.RoomIDIsCreateEventID() { + return "", fmt.Errorf("room version %s does not use m.room.create event ID as room ID", roomVersion) + } else if evtID, err := pdu.calculateEventID(roomVersion, '!'); err != nil { + return "", fmt.Errorf("failed to calculate event ID: %w", err) + } else { + return id.RoomID(evtID), nil + } +} + +func (pdu *PDU) CalculateEventID(roomVersion id.RoomVersion) (id.EventID, error) { + return pdu.calculateEventID(roomVersion, '$') +} + +func (pdu *PDU) calculateEventID(roomVersion id.RoomVersion, prefix byte) (id.EventID, error) { + if pdu == nil { + return "", ErrPDUIsNil + } + if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil { + if err := pdu.FillContentHash(); err != nil { + return "", err + } + } + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return "", fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err) + } + referenceHash := sha256.Sum256(rawJSON) + eventID := make([]byte, 44) + eventID[0] = prefix + switch roomVersion.EventIDFormat() { + case id.EventIDFormatCustom: + return "", fmt.Errorf("*pdu.PDU can only be used for room v3+") + case id.EventIDFormatBase64: + base64.RawStdEncoding.Encode(eventID[1:], referenceHash[:]) + case id.EventIDFormatURLSafeBase64: + base64.RawURLEncoding.Encode(eventID[1:], referenceHash[:]) + default: + return "", fmt.Errorf("unknown event ID format %v", roomVersion.EventIDFormat()) + } + return id.EventID(eventID), nil +} diff --git a/federation/pdu/pdu_test.go b/federation/pdu/pdu_test.go new file mode 100644 index 00000000..a650672c --- /dev/null +++ b/federation/pdu/pdu_test.go @@ -0,0 +1,197 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu_test + +import ( + "encoding/base64" + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/federation/pdu" + "maunium.net/go/mautrix/id" +) + +type serverKey struct { + key id.SigningKey + validUntilTS time.Time +} + +type serverDetails struct { + serverName string + keys map[id.KeyID]serverKey +} + +var mauniumNet = serverDetails{ + serverName: "maunium.net", + keys: map[id.KeyID]serverKey{ + "ed25519:a_xxeS": { + key: "lVt/CC3tv74OH6xTph2JrUmeRj/j+1q0HVa0Xf4QlCg", + validUntilTS: time.Now(), + }, + }, +} +var envsNet = serverDetails{ + serverName: "envs.net", + keys: map[id.KeyID]serverKey{ + "ed25519:a_zIqy": { + key: "vCUcZpt9hUn0aabfh/9GP/6sZvXcydww8DUstPHdJm0", + validUntilTS: time.UnixMilli(1722360538068), + }, + "ed25519:wuJyKT": { + key: "xbE1QssgomL4wCSlyMYF5/7KxVyM4HPwAbNa+nFFnx0", + validUntilTS: time.Now(), + }, + }, +} +var matrixOrg = serverDetails{ + serverName: "matrix.org", + keys: map[id.KeyID]serverKey{ + "ed25519:auto": { + key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw", + validUntilTS: time.UnixMilli(1576767829750), + }, + "ed25519:a_RXGa": { + key: "l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ", + validUntilTS: time.Now(), + }, + }, +} + +type testPDU struct { + name string + pdu string + eventID id.EventID + roomVersion id.RoomVersion + serverDetails +} + +var testPDUs = []testPDU{{ + name: "m.room.message in v5 room", + pdu: `{"auth_events":["$hp0ImHqYgHTRbLeWKPeTeFmxdb5SdMJN9cfmTrTk7d0","$KAj7X7tnJbR9qYYMWJSw-1g414_KlPptbbkZm7_kUtg","$V-2ShOwZYhA_nxMijaf3lqFgIJgzE2UMeFPtOLnoBYM"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":2248,"hashes":{"sha256":"kV+JuLbWXJ2r6PjHT3wt8bFc/TfI1nTaSN3Lamg/xHs"},"origin_server_ts":1755422945654,"prev_events":["$49lFLem2Nk4dxHk9RDXxTdaq9InIJpmkHpzVnjKcYwg"],"room_id":"!vzBgJsjNzgHSdWsmki:mozilla.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"JIl60uVgfCLBZLPoSiE7wVkJ9U5cNEPVPuv1sCCYUOq5yOW56WD1adgpBUdX2UFpYkCHvkRnyQGxU0+6HBp5BA"}},"unsigned":{"age_ts":1755422945673}}`, + eventID: "$Qn4tHfuAe6PlnKXPZnygAU9wd6RXqMKtt_ZzstHTSgA", + roomVersion: id.RoomV5, + serverDetails: mauniumNet, +}, { + name: "m.room.message in v10 room", + pdu: `{"auth_events":["$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ","$Z-qMWmiMvm-aIEffcfSO6lN7TyjyTOsIcHIymfzoo20"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":100885,"hashes":{"sha256":"jc9272JPpPIVreJC3UEAm3BNVnLX8sm3U/TZs23wsHo"},"origin_server_ts":1755422792518,"prev_events":["$HDtbzpSys36Hk-F2NsiXfp9slsGXBH0b58qyddj_q5E"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"sAMLo9jPtNB0Jq67IQm06siEBx82qZa2edu56IDQ4tDylEV4Mq7iFO23gCghqXA7B/MqBsjXotGBxv6AvlJ2Dw"}},"unsigned":{"age_ts":1755422792540}}`, + eventID: "$4ZFr_ypfp4DyZQP4zyxM_cvuOMFkl07doJmwi106YFY", + roomVersion: id.RoomV10, + serverDetails: mauniumNet, +}, { + name: "m.room.message in v11 room", + pdu: `{"auth_events":["$L8Ak6A939llTRIsZrytMlLDXQhI4uLEjx-wb1zSg-Bw","$QJmr7mmGeXGD4Tof0ZYSPW2oRGklseyHTKtZXnF-YNM","$7bkKK_Z-cGQ6Ae4HXWGBwXyZi3YjC6rIcQzGfVyl3Eo"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":3212,"hashes":{"sha256":"K549YdTnv62Jn84Y7sS5ZN3+AdmhleZHbenbhUpR2R8"},"origin_server_ts":1754242687127,"prev_events":["$DAhJg4jVsqk5FRatE2hbT1dSA8D2ASy5DbjEHIMSHwY"],"room_id":"!offtopic-2:continuwuity.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"SkzZdZ+rH22kzCBBIAErTdB0Vg6vkFmzvwjlOarGul72EnufgtE/tJcd3a8szAdK7f1ZovRyQxDgVm/Ib2u0Aw"}},"unsigned":{"age_ts":1754242687146}}`, + eventID: `$qkWfTL7_l3oRZO2CItW8-Q0yAmi_l_1ua629ZDqponE`, + roomVersion: id.RoomV11, + serverDetails: mauniumNet, +}, { + name: "m.room.message in v12 room", + pdu: `{"auth_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA","$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":122,"hashes":{"sha256":"IQ0zlc+PXeEs6R3JvRkW3xTPV3zlGKSSd3x07KXGjzs"},"origin_server_ts":1755384351627,"prev_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir_test:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"0GDMddL2k7gF4V1VU8sL3wTfhAIzAu5iVH5jeavZ2VEg3J9/tHLWXAOn2tzkLaMRWl0/XpINT2YlH/rd2U21Ag"}},"unsigned":{"age_ts":1755384351627}}`, + eventID: "$xmP-wZfpannuHG-Akogi6c4YvqxChMtdyYbUMGOrMWc", + roomVersion: id.RoomV12, + serverDetails: mauniumNet, +}, { + name: "m.room.create in v4 room", + pdu: `{"auth_events": [], "prev_events": [], "type": "m.room.create", "room_id": "!jxlRxnrZCsjpjDubDX:matrix.org", "sender": "@neilj:matrix.org", "content": {"room_version": "4", "predecessor": {"room_id": "!DYgXKezaHgMbiPMzjX:matrix.org", "event_id": "$156171636353XwPJT:matrix.org"}, "creator": "@neilj:matrix.org"}, "depth": 1, "prev_state": [], "state_key": "", "origin": "matrix.org", "origin_server_ts": 1561716363993, "hashes": {"sha256": "9tj8GpXjTAJvdNAbnuKLemZZk+Tjv2LAbGodSX6nJAo"}, "signatures": {"matrix.org": {"ed25519:auto": "2+sNt8uJUhzU4GPxnFVYtU2ZRgFdtVLT1vEZGUdJYN40zBpwYEGJy+kyb5matA+8/yLeYD9gu1O98lhleH0aCA"}}, "unsigned": {"age": 104769}}`, + eventID: "$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY", + roomVersion: id.RoomV4, + serverDetails: matrixOrg, +}, { + name: "m.room.create in v10 room", + pdu: `{"auth_events":[],"content":{"creator":"@creme:envs.net","predecessor":{"event_id":"$BxYNisKcyBDhPLiVC06t18qhv7wsT72MzMCqn5vRhfY","room_id":"!tEyFYiMHhwJlDXTxwf:envs.net"},"room_version":"10"},"depth":1,"hashes":{"sha256":"us3TrsIjBWpwbm+k3F9fUVnz9GIuhnb+LcaY47fWwUI"},"origin":"envs.net","origin_server_ts":1664394769527,"prev_events":[],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@creme:envs.net","state_key":"","type":"m.room.create","signatures":{"envs.net":{"ed25519:a_zIqy":"0g3FDaD1e5BekJYW2sR7dgxuKoZshrf8P067c9+jmH6frsWr2Ua86Ax08CFa/n46L8uvV2SGofP8iiVYgXCRBg"}},"unsigned":{"age":2060}}`, + eventID: "$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ", + roomVersion: id.RoomV10, + serverDetails: envsNet, +}, { + name: "m.room.create in v12 room", + pdu: `{"auth_events":[],"content":{"fi.mau.randomness":"AAXZ6aIc","predecessor":{"room_id":"!#test/room\nversion 11, with @\ud83d\udc08\ufe0f:maunium.net"},"room_version":"12"},"depth":1,"hashes":{"sha256":"d3L1M3KUdyIKWcShyW6grUoJ8GOjCdSIEvQrDVHSpE8"},"origin_server_ts":1754940000000,"prev_events":[],"sender":"@tulir:maunium.net","state_key":"","type":"m.room.create","signatures":{"maunium.net":{"ed25519:a_xxeS":"ebjIRpzToc82cjb/RGY+VUzZic0yeRZrjctgx0SUTJxkprXn3/i1KdiYULfl/aD0cUJ5eL8gLakOSk2glm+sBw"}},"unsigned":{"age_ts":1754939139045}}`, + eventID: "$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + roomVersion: id.RoomV12, + serverDetails: mauniumNet, +}, { + name: "m.room.member in v4 room", + pdu: `{"auth_events":["$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4","$wMGMP4Ucij2_d4h_fVDgIT2xooLZAgMcBruT9oo3Jio","$yyDgV8w0_e8qslmn0nh9OeSq_fO0zjpjTjSEdKFxDso"],"prev_events":["$zSjNuTXhUe3Rq6NpKD3sNyl8a_asMnBhGC5IbacHlJ4"],"type":"m.room.member","room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","content":{"membership":"join","displayname":"tulir","avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","clicked \"send membership event with no changes\"":true},"depth":14370,"prev_state":[],"state_key":"@tulir:maunium.net","origin":"maunium.net","origin_server_ts":1600871136259,"hashes":{"sha256":"Ga6bG9Mk0887ruzM9TAAfa1O3DbNssb+qSFtE9oeRL4"},"signatures":{"maunium.net":{"ed25519:a_xxeS":"fzOyDG3G3pEzixtWPttkRA1DfnHETiKbiG8SEBQe2qycQbZWPky7xX8WujSrUJH/+bxTABpQwEH49d+RakxtBw"}},"unsigned":{"age_ts":1600871136259,"replaces_state":"$jg2AgCfnwnjR-osoyM0lVYS21QrtfmZxhGO90PRkmO4"}}`, + eventID: "$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo", + roomVersion: id.RoomV4, + serverDetails: mauniumNet, +}, { + name: "m.room.member in v10 room", + pdu: `{"auth_events":["$HQC4hWaioLKVbMH94qKbfb3UnL4ocql2vi-VdUYI48I","$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs","$kEPF8Aj87EzRmFPriu2zdyEY0rY15XSqywTYVLUUlCA","$tn1FZUI_YUpfTr_a3Y_r8kC3inliIZZratzg0UsNdCQ"],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":182,"hashes":{"sha256":"0HscBc921QV2dxK2qY7qrnyoAgfxBM7kKvqAXlEk+GE"},"origin":"maunium.net","origin_server_ts":1665402609039,"prev_events":["$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"],"room_id":"!UzZHbJYcgggctGnlzr:envs.net","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"lkOW0FSJ8MJ0wZpdwLH1Uf6FSl2q9/u6KthRIlM0CwHDJG4sIZ9DrMA8BdU8L/PWoDS/CoDUlLanDh99SplgBw"}},"unsigned":{"age_ts":1665402609039,"replaces_state":"$R9FUDgNAp9ms7b6ASunZOIkpqmsIRq_ROrNEznu62fs"}}`, + eventID: "$--ilpwnsHaEdHrwiMrZNu5xHP6TthWG0FIXMHnlHCcs", + roomVersion: id.RoomV10, + serverDetails: mauniumNet, +}, { + name: "m.room.member of creator in v12 room", + pdu: `{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"IebdOBYaaWYIx2zq/lkVCnjWIXTLk1g+vgFpJMgd2/E"},"origin_server_ts":1754939139117,"prev_events":["$mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member","signatures":{"maunium.net":{"ed25519:a_xxeS":"rFCgF2hmavdm6+P6/f7rmuOdoSOmELFaH3JdWjgBLZXS2z51Ma7fa2v2+BkAH1FvBo9FLhvEoFVM4WbNQLXtAA"}},"unsigned":{"age_ts":1754939139117}}`, + eventID: "$accqGxfvhBvMP4Sf6P7t3WgnaJK6UbonO2ZmwqSE5Sg", + roomVersion: id.RoomV12, + serverDetails: mauniumNet, +}, { + name: "custom message event in v4 room", + pdu: `{"auth_events":["$VtuCNOfAWGow-cxy0ajeK3fvONcC8QzF2yWa43g0Gwo","$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY","$Gau_XwziYsr-rt3SouhbKN14twgmbKjcZZc_hz-nOgU"],"content":{"\ud83d\udc08\ufe0f":true,"\ud83d\udc15\ufe0f":false},"depth":69645,"hashes":{"sha256":"VHtWyCt+15ZesNnStU3FOkxrjzHJYZfd3JUgO9JWe0s"},"origin_server_ts":1755423939146,"prev_events":["$exmp4cj0OKOFSxuqBYiOYwQi5j_0XRc78d6EavAkhy0"],"room_id":"!jxlRxnrZCsjpjDubDX:matrix.org","sender":"@tulir:maunium.net","type":"\ud83d\udc08\ufe0f","signatures":{"maunium.net":{"ed25519:a_xxeS":"wfmP1XN4JBkKVkqrQnwysyEUslXt8hQRFwN9NC9vJaIeDMd0OJ6uqCas75808DuG71p23fzqbzhRnHckst6FCQ"}},"unsigned":{"age_ts":1755423939164}}`, + eventID: "$kAagtZAIEeZaLVCUSl74tAxQbdKbE22GU7FM-iAJBc0", + roomVersion: id.RoomV4, + serverDetails: mauniumNet, +}} + +func parsePDU(pdu string) (out *pdu.PDU) { + exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out)) + return +} + +func TestPDU_CalculateContentHash(t *testing.T) { + for _, test := range testPDUs { + t.Run(test.name, func(t *testing.T) { + parsed := parsePDU(test.pdu) + contentHash := exerrors.Must(parsed.CalculateContentHash()) + assert.Equal( + t, + base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256), + base64.RawStdEncoding.EncodeToString(contentHash[:]), + ) + }) + } +} + +func TestPDU_VerifyContentHash(t *testing.T) { + for _, test := range testPDUs { + t.Run(test.name, func(t *testing.T) { + parsed := parsePDU(test.pdu) + assert.True(t, parsed.VerifyContentHash()) + }) + } +} + +func TestPDU_CalculateEventID(t *testing.T) { + for _, test := range testPDUs { + t.Run(test.name, func(t *testing.T) { + gotEventID := exerrors.Must(parsePDU(test.pdu).CalculateEventID(test.roomVersion)) + assert.Equal(t, test.eventID, gotEventID) + }) + } +} + +func TestPDU_VerifySignature(t *testing.T) { + for _, test := range testPDUs { + t.Run(test.name, func(t *testing.T) { + parsed := parsePDU(test.pdu) + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { + key, ok := test.keys[keyID] + if ok { + return key.key, key.validUntilTS, nil + } + return "", time.Time{}, nil + }) + assert.NoError(t, err) + }) + } +} diff --git a/federation/pdu/redact.go b/federation/pdu/redact.go new file mode 100644 index 00000000..56aaee1c --- /dev/null +++ b/federation/pdu/redact.go @@ -0,0 +1,108 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "encoding/json/jsontext" + + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.mau.fi/util/exgjson" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/id" +) + +func filteredObject(object jsontext.Value, allowedPaths ...string) jsontext.Value { + filtered := jsontext.Value("{}") + var err error + for _, path := range allowedPaths { + res := gjson.GetBytes(object, path) + if res.Exists() { + var raw jsontext.Value + if res.Index > 0 { + raw = object[res.Index : res.Index+len(res.Raw)] + } else { + raw = jsontext.Value(res.Raw) + } + filtered, err = sjson.SetRawBytes(filtered, path, raw) + if err != nil { + panic(err) + } + } + } + return filtered +} + +func (pdu *PDU) Clone() *PDU { + return ptr.Clone(pdu) +} + +func (pdu *PDU) RedactForSignature(roomVersion id.RoomVersion) *PDU { + pdu.Signatures = nil + return pdu.Redact(roomVersion) +} + +var emptyObject = jsontext.Value("{}") + +func (pdu *PDU) Redact(roomVersion id.RoomVersion) *PDU { + pdu.Unknown = nil + pdu.Unsigned = nil + if roomVersion.UpdatedRedactionRules() { + pdu.DeprecatedPrevState = nil + pdu.DeprecatedOrigin = nil + pdu.DeprecatedMembership = nil + } + + switch pdu.Type { + case "m.room.member": + allowedPaths := []string{"membership"} + if roomVersion.RestrictedJoinsFix() { + allowedPaths = append(allowedPaths, "join_authorised_via_users_server") + } + if roomVersion.UpdatedRedactionRules() { + allowedPaths = append(allowedPaths, exgjson.Path("third_party_invite", "signed")) + } + pdu.Content = filteredObject(pdu.Content, allowedPaths...) + case "m.room.create": + if !roomVersion.UpdatedRedactionRules() { + pdu.Content = filteredObject(pdu.Content, "creator") + } // else: all fields are protected + case "m.room.join_rules": + if roomVersion.RestrictedJoins() { + pdu.Content = filteredObject(pdu.Content, "join_rule", "allow") + } else { + pdu.Content = filteredObject(pdu.Content, "join_rule") + } + case "m.room.power_levels": + allowedKeys := []string{"ban", "events", "events_default", "kick", "redact", "state_default", "users", "users_default"} + if roomVersion.UpdatedRedactionRules() { + allowedKeys = append(allowedKeys, "invite") + } + pdu.Content = filteredObject(pdu.Content, allowedKeys...) + case "m.room.history_visibility": + pdu.Content = filteredObject(pdu.Content, "history_visibility") + case "m.room.redaction": + if roomVersion.RedactsInContent() { + pdu.Content = filteredObject(pdu.Content, "redacts") + pdu.Redacts = nil + } else { + pdu.Content = emptyObject + } + case "m.room.aliases": + if roomVersion.SpecialCasedAliasesAuth() { + pdu.Content = filteredObject(pdu.Content, "aliases") + } else { + pdu.Content = emptyObject + } + default: + pdu.Content = emptyObject + } + return pdu +} diff --git a/federation/pdu/v1.go b/federation/pdu/v1.go new file mode 100644 index 00000000..e8aa82fc --- /dev/null +++ b/federation/pdu/v1.go @@ -0,0 +1,59 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "encoding/json/jsontext" + + "maunium.net/go/mautrix/id" +) + +type RoomV1PDU struct { + AuthEvents [][]string `json:"auth_events"` + Content jsontext.Value `json:"content"` + Depth int64 `json:"depth"` + EventID id.EventID `json:"event_id"` + Hashes *Hashes `json:"hashes,omitempty"` + OriginServerTS int64 `json:"origin_server_ts"` + PrevEvents [][]string `json:"prev_events"` + Redacts *id.EventID `json:"redacts,omitempty"` + RoomID id.RoomID `json:"room_id"` + Sender id.UserID `json:"sender"` + Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"` + StateKey *string `json:"state_key,omitempty"` + Type string `json:"type"` + + Unknown jsontext.Value `json:",unknown"` + + // Deprecated legacy fields + DeprecatedPrevState any `json:"prev_state,omitempty"` + DeprecatedOrigin any `json:"origin,omitempty"` + DeprecatedMembership any `json:"membership,omitempty"` +} + +func (pdu *RoomV1PDU) Redact() { + pdu.Unknown = nil + + switch pdu.Type { + case "m.room.member": + pdu.Content = filteredObject(pdu.Content, "membership") + case "m.room.create": + pdu.Content = filteredObject(pdu.Content, "creator") + case "m.room.join_rules": + pdu.Content = filteredObject(pdu.Content, "join_rule") + case "m.room.power_levels": + pdu.Content = filteredObject(pdu.Content, "ban", "events", "events_default", "kick", "redact", "state_default", "users", "users_default") + case "m.room.history_visibility": + pdu.Content = filteredObject(pdu.Content, "history_visibility") + case "m.room.aliases": + pdu.Content = filteredObject(pdu.Content, "aliases") + default: + pdu.Content = jsontext.Value("{}") + } +} From d2e7302daeb36bc352a02262fe1c3cbc72601093 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 17 Aug 2025 12:53:06 +0300 Subject: [PATCH 1297/1647] ci: test goolm and jsonv2 --- .github/workflows/go.yml | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 3cf412b4..fd1dfa92 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -15,7 +15,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v5 with: - go-version: "1.24" + go-version: "1.25" cache: true - name: Install libolm @@ -60,6 +60,11 @@ jobs: - name: Test run: go test -json -v ./... 2>&1 | gotestfmt + - name: Test (jsonv2) + env: + GOEXPERIMENT: jsonv2 + run: go test -json -v ./... 2>&1 | gotestfmt + build-goolm: runs-on: ubuntu-latest strategy: @@ -86,3 +91,11 @@ jobs: run: | rm -rf crypto/libolm go build -tags=goolm -v ./... + + - name: Test + run: go test -tags goolm -json -v ./... 2>&1 | gotestfmt + + - name: Test (jsonv2) + env: + GOEXPERIMENT: jsonv2 + run: go test -tags goolm -json -v ./... 2>&1 | gotestfmt From e85276fc0b014e324c78fcc21c29bf6b35788057 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 17 Aug 2025 12:59:15 +0300 Subject: [PATCH 1298/1647] ci: disable gotestfmt in goolm It explodes with `panic: BUG: Empty package name encountered.` --- .github/workflows/go.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index fd1dfa92..3d58aabc 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -58,12 +58,12 @@ jobs: run: go build -v ./... - name: Test - run: go test -json -v ./... 2>&1 | gotestfmt + run: go test -v ./... - name: Test (jsonv2) env: GOEXPERIMENT: jsonv2 - run: go test -json -v ./... 2>&1 | gotestfmt + run: go test -v ./... build-goolm: runs-on: ubuntu-latest From 86802be0f788f865d655e656b5e6fd87b1728a36 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 17 Aug 2025 13:00:42 +0300 Subject: [PATCH 1299/1647] federation/pdu: gate signing key validity check by room version --- federation/pdu/pdu.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go index 2ac970dc..bef7b344 100644 --- a/federation/pdu/pdu.go +++ b/federation/pdu/pdu.go @@ -152,7 +152,7 @@ func (pdu *PDU) VerifySignature( key, validUntil, err := getKey(keyID, originServerTS) if err != nil { return fmt.Errorf("failed to get key %s: %w", keyID, err) - } else if key == "" || validUntil.Before(originServerTS) { + } else if key == "" || (validUntil.Before(originServerTS) && roomVersion.EnforceSigningKeyValidity()) { continue } err = signutil.VerifyJSONRaw(key, sig, rawJSON) From 31178e9f424f005671aa8e4ed7042ea7fb17d1a8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 17 Aug 2025 13:02:47 +0300 Subject: [PATCH 1300/1647] federation/pdu: fail on any signature check error --- federation/pdu/pdu.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go index bef7b344..0027f02f 100644 --- a/federation/pdu/pdu.go +++ b/federation/pdu/pdu.go @@ -151,15 +151,16 @@ func (pdu *PDU) VerifySignature( originServerTS := time.UnixMilli(pdu.OriginServerTS) key, validUntil, err := getKey(keyID, originServerTS) if err != nil { - return fmt.Errorf("failed to get key %s: %w", keyID, err) - } else if key == "" || (validUntil.Before(originServerTS) && roomVersion.EnforceSigningKeyValidity()) { - continue - } - err = signutil.VerifyJSONRaw(key, sig, rawJSON) - if err != nil { + return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err) + } else if key == "" { + return fmt.Errorf("key %s not found for %s", keyID, serverName) + } else if validUntil.Before(originServerTS) && roomVersion.EnforceSigningKeyValidity() { + return fmt.Errorf("key %s for %s is only valid until %s, but event is from %s", keyID, serverName, validUntil, originServerTS) + } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil { return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err) + } else { + verified = true } - verified = true } if !verified { return fmt.Errorf("no verifiable signatures found for server %s", serverName) From 0dc957fa30a8021f8a4b8405545efe9a61c5713c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 17 Aug 2025 13:11:26 +0300 Subject: [PATCH 1301/1647] ci: fix more things --- .github/workflows/go.yml | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 3d58aabc..87bde5f8 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -65,6 +65,15 @@ jobs: GOEXPERIMENT: jsonv2 run: go test -v ./... + - name: Test + run: go test -tags goolm -json -v ./... 2>&1 | gotestfmt + + - name: Test (jsonv2) + if: matrix.go-version == '1.25' + env: + GOEXPERIMENT: jsonv2 + run: go test -tags goolm -json -v ./... 2>&1 | gotestfmt + build-goolm: runs-on: ubuntu-latest strategy: @@ -93,9 +102,10 @@ jobs: go build -tags=goolm -v ./... - name: Test - run: go test -tags goolm -json -v ./... 2>&1 | gotestfmt + run: go test -v ./... - name: Test (jsonv2) + if: matrix.go-version == '1.25' env: GOEXPERIMENT: jsonv2 - run: go test -tags goolm -json -v ./... 2>&1 | gotestfmt + run: go test -v ./... From 0f177058c17a20812c9c2653c12a37d32ae49992 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 17 Aug 2025 13:12:53 +0300 Subject: [PATCH 1302/1647] ci: move tags to correct place --- .github/workflows/go.yml | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 87bde5f8..1eeff30c 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -58,21 +58,13 @@ jobs: run: go build -v ./... - name: Test - run: go test -v ./... - - - name: Test (jsonv2) - env: - GOEXPERIMENT: jsonv2 - run: go test -v ./... - - - name: Test - run: go test -tags goolm -json -v ./... 2>&1 | gotestfmt + run: go test -json -v ./... 2>&1 | gotestfmt - name: Test (jsonv2) if: matrix.go-version == '1.25' env: GOEXPERIMENT: jsonv2 - run: go test -tags goolm -json -v ./... 2>&1 | gotestfmt + run: go test -json -v ./... 2>&1 | gotestfmt build-goolm: runs-on: ubuntu-latest @@ -102,10 +94,10 @@ jobs: go build -tags=goolm -v ./... - name: Test - run: go test -v ./... + run: go test -tags goolm -v ./... - name: Test (jsonv2) if: matrix.go-version == '1.25' env: GOEXPERIMENT: jsonv2 - run: go test -v ./... + run: go test -tags goolm -v ./... From 9b075f8bb9ab32ae6f39a9ef728eda16f1d15f45 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 17 Aug 2025 13:15:53 +0300 Subject: [PATCH 1303/1647] ci: disable tests on goolm again --- .github/workflows/go.yml | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 1eeff30c..dc4f17e2 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -83,21 +83,7 @@ jobs: go-version: ${{ matrix.go-version }} cache: true - - name: Set up gotestfmt - uses: GoTestTools/gotestfmt-action@v2 - with: - token: ${{ secrets.GITHUB_TOKEN }} - - name: Build run: | rm -rf crypto/libolm go build -tags=goolm -v ./... - - - name: Test - run: go test -tags goolm -v ./... - - - name: Test (jsonv2) - if: matrix.go-version == '1.25' - env: - GOEXPERIMENT: jsonv2 - run: go test -tags goolm -v ./... From 6eced49860c126055dcbe91464c647439202b6cd Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 17 Aug 2025 13:32:07 +0300 Subject: [PATCH 1304/1647] client,event: remove deprecated MSC2716 structs --- client.go | 21 --------------------- event/content.go | 1 - event/state.go | 6 ------ event/type.go | 5 +---- requests.go | 12 ------------ responses.go | 12 ------------ 6 files changed, 1 insertion(+), 56 deletions(-) diff --git a/client.go b/client.go index 4906169f..1536ae52 100644 --- a/client.go +++ b/client.go @@ -2530,27 +2530,6 @@ func (cli *Client) ReportRoom(ctx context.Context, roomID id.RoomID, reason stri return err } -// BatchSend sends a batch of historical events into a room. This is only available for appservices. -// -// Deprecated: MSC2716 has been abandoned, so this is now Beeper-specific. BeeperBatchSend should be used instead. -func (cli *Client) BatchSend(ctx context.Context, roomID id.RoomID, req *ReqBatchSend) (resp *RespBatchSend, err error) { - path := ClientURLPath{"unstable", "org.matrix.msc2716", "rooms", roomID, "batch_send"} - query := map[string]string{ - "prev_event_id": req.PrevEventID.String(), - } - if req.BeeperNewMessages { - query["com.beeper.new_messages"] = "true" - } - if req.BeeperMarkReadBy != "" { - query["com.beeper.mark_read_by"] = req.BeeperMarkReadBy.String() - } - if len(req.BatchID) > 0 { - query["batch_id"] = req.BatchID.String() - } - _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildURLWithQuery(path, query), req, &resp) - return -} - func (cli *Client) AppservicePing(ctx context.Context, id, txnID string) (resp *RespAppservicePing, err error) { _, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, diff --git a/event/content.go b/event/content.go index b56e35f2..e50dfea5 100644 --- a/event/content.go +++ b/event/content.go @@ -38,7 +38,6 @@ var TypeMap = map[Type]reflect.Type{ StateHalfShotBridge: reflect.TypeOf(BridgeEventContent{}), StateSpaceParent: reflect.TypeOf(SpaceParentEventContent{}), StateSpaceChild: reflect.TypeOf(SpaceChildEventContent{}), - StateInsertionMarker: reflect.TypeOf(InsertionMarkerContent{}), StateLegacyPolicyRoom: reflect.TypeOf(ModPolicyContent{}), StateLegacyPolicyServer: reflect.TypeOf(ModPolicyContent{}), diff --git a/event/state.go b/event/state.go index 44a45a57..de46c57d 100644 --- a/event/state.go +++ b/event/state.go @@ -258,12 +258,6 @@ func (mpc *ModPolicyContent) EntityOrHash() string { return mpc.Entity } -// Deprecated: MSC2716 has been abandoned -type InsertionMarkerContent struct { - InsertionID id.EventID `json:"org.matrix.msc2716.marker.insertion"` - Timestamp int64 `json:"com.beeper.timestamp,omitempty"` -} - type ElementFunctionalMembersContent struct { ServiceMembers []id.UserID `json:"service_members"` } diff --git a/event/type.go b/event/type.go index 591d598d..b097cfe1 100644 --- a/event/type.go +++ b/event/type.go @@ -112,7 +112,7 @@ func (et *Type) GuessClass() TypeClass { StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type, StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type, StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type, - StateInsertionMarker.Type, StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type: + StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type: return StateEventType case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type: return EphemeralEventType @@ -200,9 +200,6 @@ var ( StateUnstablePolicyServer = Type{"org.matrix.mjolnir.rule.server", StateEventType} StateUnstablePolicyUser = Type{"org.matrix.mjolnir.rule.user", StateEventType} - // Deprecated: MSC2716 has been abandoned - StateInsertionMarker = Type{"org.matrix.msc2716.marker", StateEventType} - StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType} StateBeeperRoomFeatures = Type{"com.beeper.room_features", StateEventType} ) diff --git a/requests.go b/requests.go index 8f31e52f..9871f044 100644 --- a/requests.go +++ b/requests.go @@ -401,18 +401,6 @@ type ReqPutPushRule struct { Pattern string `json:"pattern"` } -// Deprecated: MSC2716 was abandoned -type ReqBatchSend struct { - PrevEventID id.EventID `json:"-"` - BatchID id.BatchID `json:"-"` - - BeeperNewMessages bool `json:"-"` - BeeperMarkReadBy id.UserID `json:"-"` - - StateEventsAtStart []*event.Event `json:"state_events_at_start"` - Events []*event.Event `json:"events"` -} - type ReqBeeperBatchSend struct { // ForwardIfNoMessages should be set to true if the batch should be forward // backfilled if there are no messages currently in the room. diff --git a/responses.go b/responses.go index 27d96ffe..5b97b293 100644 --- a/responses.go +++ b/responses.go @@ -488,18 +488,6 @@ type RespDeviceInfo struct { LastSeenTS int64 `json:"last_seen_ts"` } -// Deprecated: MSC2716 was abandoned -type RespBatchSend struct { - StateEventIDs []id.EventID `json:"state_event_ids"` - EventIDs []id.EventID `json:"event_ids"` - - InsertionEventID id.EventID `json:"insertion_event_id"` - BatchEventID id.EventID `json:"batch_event_id"` - BaseInsertionEventID id.EventID `json:"base_insertion_event_id"` - - NextBatchID id.BatchID `json:"next_batch_id"` -} - type RespBeeperBatchSend struct { EventIDs []id.EventID `json:"event_ids"` } From cc80be150059fce47f46cfa9257f65c3103bb7ba Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 17 Aug 2025 13:45:34 +0300 Subject: [PATCH 1305/1647] federation/pdu: add method to convert to client event --- federation/pdu/pdu.go | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go index 0027f02f..22f7212d 100644 --- a/federation/pdu/pdu.go +++ b/federation/pdu/pdu.go @@ -18,12 +18,15 @@ import ( "encoding/json/v2" "errors" "fmt" + "strings" "time" "github.com/tidwall/gjson" "go.mau.fi/util/jsonbytes" + "go.mau.fi/util/ptr" "maunium.net/go/mautrix/crypto/canonicaljson" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/federation/signutil" "maunium.net/go/mautrix/id" ) @@ -98,6 +101,38 @@ func (pdu *PDU) VerifyContentHash() bool { return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256) } +func (pdu *PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) { + if pdu.Type == "m.room.create" && roomVersion == "" { + roomVersion = id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str) + } + evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType} + if pdu.StateKey != nil { + evtType.Class = event.StateEventType + } + eventID, err := pdu.CalculateEventID(roomVersion) + if err != nil { + return nil, err + } + roomID := pdu.RoomID + if pdu.Type == "m.room.create" && roomVersion.RoomIDIsCreateEventID() { + roomID = id.RoomID(strings.Replace(string(eventID), "$", "!", 1)) + } + evt := &event.Event{ + StateKey: pdu.StateKey, + Sender: pdu.Sender, + Type: evtType, + Timestamp: pdu.OriginServerTS, + ID: eventID, + RoomID: roomID, + Redacts: ptr.Val(pdu.Redacts), + } + err = json.Unmarshal(pdu.Content, &evt.Content) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal content: %w", err) + } + return evt, nil +} + func (pdu *PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error { err := pdu.FillContentHash() if err != nil { From ec663b53d4774335e971951cd55b0049d8489a27 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 17 Aug 2025 20:12:38 +0300 Subject: [PATCH 1306/1647] federation/pdu: reorganize code and add methods to v1 struct --- federation/pdu/hash.go | 113 ++++++++++++++++++ federation/pdu/pdu.go | 164 ++++---------------------- federation/pdu/pdu_test.go | 6 +- federation/pdu/signature.go | 66 +++++++++++ federation/pdu/v1.go | 228 ++++++++++++++++++++++++++++++++++-- federation/pdu/v1_test.go | 86 ++++++++++++++ 6 files changed, 506 insertions(+), 157 deletions(-) create mode 100644 federation/pdu/hash.go create mode 100644 federation/pdu/signature.go create mode 100644 federation/pdu/v1_test.go diff --git a/federation/pdu/hash.go b/federation/pdu/hash.go new file mode 100644 index 00000000..050029df --- /dev/null +++ b/federation/pdu/hash.go @@ -0,0 +1,113 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "fmt" + + "github.com/tidwall/gjson" + + "maunium.net/go/mautrix/id" +) + +func (pdu *PDU) CalculateContentHash() ([32]byte, error) { + if pdu == nil { + return [32]byte{}, ErrPDUIsNil + } + pduClone := pdu.Clone() + pduClone.Signatures = nil + pduClone.Unsigned = nil + pduClone.Hashes = nil + rawJSON, err := marshalCanonical(pduClone) + if err != nil { + return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err) + } + return sha256.Sum256(rawJSON), nil +} + +func (pdu *PDU) FillContentHash() error { + if pdu == nil { + return ErrPDUIsNil + } else if pdu.Hashes != nil { + return nil + } else if hash, err := pdu.CalculateContentHash(); err != nil { + return err + } else { + pdu.Hashes = &Hashes{SHA256: hash[:]} + return nil + } +} + +func (pdu *PDU) VerifyContentHash() bool { + if pdu == nil || pdu.Hashes == nil { + return false + } + calculatedHash, err := pdu.CalculateContentHash() + if err != nil { + return false + } + return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256) +} + +func (pdu *PDU) GetRoomID() (id.RoomID, error) { + if pdu == nil { + return "", ErrPDUIsNil + } else if pdu.Type != "m.room.create" { + return "", fmt.Errorf("room ID can only be calculated for m.room.create events") + } else if roomVersion := id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str); !roomVersion.RoomIDIsCreateEventID() { + return "", fmt.Errorf("room version %s does not use m.room.create event ID as room ID", roomVersion) + } else if evtID, err := pdu.calculateEventID(roomVersion, '!'); err != nil { + return "", fmt.Errorf("failed to calculate event ID: %w", err) + } else { + return id.RoomID(evtID), nil + } +} + +func (pdu *PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) { + return pdu.calculateEventID(roomVersion, '$') +} + +func (pdu *PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) { + if pdu == nil { + return [32]byte{}, ErrPDUIsNil + } + if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil { + if err := pdu.FillContentHash(); err != nil { + return [32]byte{}, err + } + } + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err) + } + return sha256.Sum256(rawJSON), nil +} + +func (pdu *PDU) calculateEventID(roomVersion id.RoomVersion, prefix byte) (id.EventID, error) { + referenceHash, err := pdu.GetReferenceHash(roomVersion) + if err != nil { + return "", err + } + eventID := make([]byte, 44) + eventID[0] = prefix + switch roomVersion.EventIDFormat() { + case id.EventIDFormatCustom: + return "", fmt.Errorf("*pdu.PDU can only be used for room v3+") + case id.EventIDFormatBase64: + base64.RawStdEncoding.Encode(eventID[1:], referenceHash[:]) + case id.EventIDFormatURLSafeBase64: + base64.RawURLEncoding.Encode(eventID[1:], referenceHash[:]) + default: + return "", fmt.Errorf("unknown event ID format %v", roomVersion.EventIDFormat()) + } + return id.EventID(eventID), nil +} diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go index 22f7212d..dbd4bff1 100644 --- a/federation/pdu/pdu.go +++ b/federation/pdu/pdu.go @@ -11,9 +11,6 @@ package pdu import ( "bytes" "crypto/ed25519" - "crypto/hmac" - "crypto/sha256" - "encoding/base64" "encoding/json/jsontext" "encoding/json/v2" "errors" @@ -27,18 +24,28 @@ import ( "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/federation/signutil" "maunium.net/go/mautrix/id" ) -var ErrPDUIsNil = errors.New("PDU is nil") +type GetKeyFunc = func(keyID id.KeyID, minValidUntil time.Time) (id.SigningKey, time.Time, error) -type Hashes struct { - SHA256 jsonbytes.UnpaddedBytes `json:"sha256"` - - Unknown jsontext.Value `json:",unknown"` +type AnyPDU interface { + GetRoomID() (id.RoomID, error) + GetEventID(roomVersion id.RoomVersion) (id.EventID, error) + GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) + CalculateContentHash() ([32]byte, error) + FillContentHash() error + VerifyContentHash() bool + Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error + VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error + ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) } +var ( + _ AnyPDU = (*PDU)(nil) + _ AnyPDU = (*RoomV1PDU)(nil) +) + type PDU struct { AuthEvents []id.EventID `json:"auth_events"` Content jsontext.Value `json:"content"` @@ -62,43 +69,12 @@ type PDU struct { DeprecatedMembership any `json:"membership,omitzero"` } -func (pdu *PDU) CalculateContentHash() ([32]byte, error) { - if pdu == nil { - return [32]byte{}, ErrPDUIsNil - } - pduClone := pdu.Clone() - pduClone.Signatures = nil - pduClone.Unsigned = nil - pduClone.Hashes = nil - rawJSON, err := marshalCanonical(pduClone) - if err != nil { - return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err) - } - return sha256.Sum256(rawJSON), nil -} +var ErrPDUIsNil = errors.New("PDU is nil") -func (pdu *PDU) FillContentHash() error { - if pdu == nil { - return ErrPDUIsNil - } else if pdu.Hashes != nil { - return nil - } else if hash, err := pdu.CalculateContentHash(); err != nil { - return err - } else { - pdu.Hashes = &Hashes{SHA256: hash[:]} - return nil - } -} +type Hashes struct { + SHA256 jsonbytes.UnpaddedBytes `json:"sha256"` -func (pdu *PDU) VerifyContentHash() bool { - if pdu == nil || pdu.Hashes == nil { - return false - } - calculatedHash, err := pdu.CalculateContentHash() - if err != nil { - return false - } - return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256) + Unknown jsontext.Value `json:",unknown"` } func (pdu *PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) { @@ -109,7 +85,7 @@ func (pdu *PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) if pdu.StateKey != nil { evtType.Class = event.StateEventType } - eventID, err := pdu.CalculateEventID(roomVersion) + eventID, err := pdu.GetEventID(roomVersion) if err != nil { return nil, err } @@ -133,26 +109,6 @@ func (pdu *PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) return evt, nil } -func (pdu *PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error { - err := pdu.FillContentHash() - if err != nil { - return err - } - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) - if err != nil { - return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err) - } - signature := ed25519.Sign(privateKey, rawJSON) - if pdu.Signatures == nil { - pdu.Signatures = make(map[string]map[id.KeyID]string) - } - if _, ok := pdu.Signatures[serverName]; !ok { - pdu.Signatures[serverName] = make(map[id.KeyID]string) - } - pdu.Signatures[serverName][keyID] = base64.RawStdEncoding.EncodeToString(signature) - return nil -} - func marshalCanonical(data any) (jsontext.Value, error) { marshaledBytes, err := json.Marshal(data) if err != nil { @@ -171,81 +127,3 @@ func marshalCanonical(data any) (jsontext.Value, error) { } return marshaled, nil } - -func (pdu *PDU) VerifySignature( - roomVersion id.RoomVersion, - serverName string, - getKey func(keyID id.KeyID, minValidUntil time.Time) (id.SigningKey, time.Time, error), -) error { - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) - if err != nil { - return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err) - } - verified := false - for keyID, sig := range pdu.Signatures[serverName] { - originServerTS := time.UnixMilli(pdu.OriginServerTS) - key, validUntil, err := getKey(keyID, originServerTS) - if err != nil { - return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err) - } else if key == "" { - return fmt.Errorf("key %s not found for %s", keyID, serverName) - } else if validUntil.Before(originServerTS) && roomVersion.EnforceSigningKeyValidity() { - return fmt.Errorf("key %s for %s is only valid until %s, but event is from %s", keyID, serverName, validUntil, originServerTS) - } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil { - return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err) - } else { - verified = true - } - } - if !verified { - return fmt.Errorf("no verifiable signatures found for server %s", serverName) - } - return nil -} - -func (pdu *PDU) CalculateRoomID() (id.RoomID, error) { - if pdu == nil { - return "", ErrPDUIsNil - } else if pdu.Type != "m.room.create" { - return "", fmt.Errorf("room ID can only be calculated for m.room.create events") - } else if roomVersion := id.RoomVersion(gjson.GetBytes(pdu.Content, "room_version").Str); !roomVersion.RoomIDIsCreateEventID() { - return "", fmt.Errorf("room version %s does not use m.room.create event ID as room ID", roomVersion) - } else if evtID, err := pdu.calculateEventID(roomVersion, '!'); err != nil { - return "", fmt.Errorf("failed to calculate event ID: %w", err) - } else { - return id.RoomID(evtID), nil - } -} - -func (pdu *PDU) CalculateEventID(roomVersion id.RoomVersion) (id.EventID, error) { - return pdu.calculateEventID(roomVersion, '$') -} - -func (pdu *PDU) calculateEventID(roomVersion id.RoomVersion, prefix byte) (id.EventID, error) { - if pdu == nil { - return "", ErrPDUIsNil - } - if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil { - if err := pdu.FillContentHash(); err != nil { - return "", err - } - } - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) - if err != nil { - return "", fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err) - } - referenceHash := sha256.Sum256(rawJSON) - eventID := make([]byte, 44) - eventID[0] = prefix - switch roomVersion.EventIDFormat() { - case id.EventIDFormatCustom: - return "", fmt.Errorf("*pdu.PDU can only be used for room v3+") - case id.EventIDFormatBase64: - base64.RawStdEncoding.Encode(eventID[1:], referenceHash[:]) - case id.EventIDFormatURLSafeBase64: - base64.RawURLEncoding.Encode(eventID[1:], referenceHash[:]) - default: - return "", fmt.Errorf("unknown event ID format %v", roomVersion.EventIDFormat()) - } - return id.EventID(eventID), nil -} diff --git a/federation/pdu/pdu_test.go b/federation/pdu/pdu_test.go index a650672c..9f6fe74a 100644 --- a/federation/pdu/pdu_test.go +++ b/federation/pdu/pdu_test.go @@ -10,7 +10,7 @@ package pdu_test import ( "encoding/base64" - "encoding/json" + "encoding/json/v2" "testing" "time" @@ -171,10 +171,10 @@ func TestPDU_VerifyContentHash(t *testing.T) { } } -func TestPDU_CalculateEventID(t *testing.T) { +func TestPDU_GetEventID(t *testing.T) { for _, test := range testPDUs { t.Run(test.name, func(t *testing.T) { - gotEventID := exerrors.Must(parsePDU(test.pdu).CalculateEventID(test.roomVersion)) + gotEventID := exerrors.Must(parsePDU(test.pdu).GetEventID(test.roomVersion)) assert.Equal(t, test.eventID, gotEventID) }) } diff --git a/federation/pdu/signature.go b/federation/pdu/signature.go new file mode 100644 index 00000000..1f8ae0b5 --- /dev/null +++ b/federation/pdu/signature.go @@ -0,0 +1,66 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "crypto/ed25519" + "encoding/base64" + "fmt" + "time" + + "maunium.net/go/mautrix/federation/signutil" + "maunium.net/go/mautrix/id" +) + +func (pdu *PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error { + err := pdu.FillContentHash() + if err != nil { + return err + } + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err) + } + signature := ed25519.Sign(privateKey, rawJSON) + if pdu.Signatures == nil { + pdu.Signatures = make(map[string]map[id.KeyID]string) + } + if _, ok := pdu.Signatures[serverName]; !ok { + pdu.Signatures[serverName] = make(map[id.KeyID]string) + } + pdu.Signatures[serverName][keyID] = base64.RawStdEncoding.EncodeToString(signature) + return nil +} + +func (pdu *PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error { + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) + if err != nil { + return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err) + } + verified := false + for keyID, sig := range pdu.Signatures[serverName] { + originServerTS := time.UnixMilli(pdu.OriginServerTS) + key, validUntil, err := getKey(keyID, originServerTS) + if err != nil { + return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err) + } else if key == "" { + return fmt.Errorf("key %s not found for %s", keyID, serverName) + } else if validUntil.Before(originServerTS) && roomVersion.EnforceSigningKeyValidity() { + return fmt.Errorf("key %s for %s is only valid until %s, but event is from %s", keyID, serverName, validUntil, originServerTS) + } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil { + return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err) + } else { + verified = true + } + } + if !verified { + return fmt.Errorf("no verifiable signatures found for server %s", serverName) + } + return nil +} diff --git a/federation/pdu/v1.go b/federation/pdu/v1.go index e8aa82fc..795253db 100644 --- a/federation/pdu/v1.go +++ b/federation/pdu/v1.go @@ -9,36 +9,95 @@ package pdu import ( + "crypto/ed25519" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" "encoding/json/jsontext" + "encoding/json/v2" + "fmt" + "time" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/federation/signutil" "maunium.net/go/mautrix/id" ) +type V1EventReference struct { + ID id.EventID + Hashes Hashes +} + +var ( + _ json.UnmarshalerFrom = (*V1EventReference)(nil) + _ json.MarshalerTo = (*V1EventReference)(nil) +) + +func (er *V1EventReference) MarshalJSONTo(enc *jsontext.Encoder) error { + return json.MarshalEncode(enc, []any{er.ID, er.Hashes}) +} + +func (er *V1EventReference) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + var ref V1EventReference + var data []jsontext.Value + if err := json.UnmarshalDecode(dec, &data); err != nil { + return err + } else if len(data) != 2 { + return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: expected array with 2 elements, got %d", len(data)) + } else if err = json.Unmarshal(data[0], &ref.ID); err != nil { + return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal event ID: %w", err) + } else if err = json.Unmarshal(data[1], &ref.Hashes); err != nil { + return fmt.Errorf("V1EventReference.UnmarshalJSONFrom: failed to unmarshal hashes: %w", err) + } + *er = ref + return nil +} + type RoomV1PDU struct { - AuthEvents [][]string `json:"auth_events"` + AuthEvents []V1EventReference `json:"auth_events"` Content jsontext.Value `json:"content"` Depth int64 `json:"depth"` EventID id.EventID `json:"event_id"` - Hashes *Hashes `json:"hashes,omitempty"` + Hashes *Hashes `json:"hashes,omitzero"` OriginServerTS int64 `json:"origin_server_ts"` - PrevEvents [][]string `json:"prev_events"` - Redacts *id.EventID `json:"redacts,omitempty"` + PrevEvents []V1EventReference `json:"prev_events"` + Redacts *id.EventID `json:"redacts,omitzero"` RoomID id.RoomID `json:"room_id"` Sender id.UserID `json:"sender"` - Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"` - StateKey *string `json:"state_key,omitempty"` + Signatures map[string]map[id.KeyID]string `json:"signatures,omitzero"` + StateKey *string `json:"state_key,omitzero"` Type string `json:"type"` + Unsigned jsontext.Value `json:"unsigned,omitzero"` Unknown jsontext.Value `json:",unknown"` // Deprecated legacy fields - DeprecatedPrevState any `json:"prev_state,omitempty"` - DeprecatedOrigin any `json:"origin,omitempty"` - DeprecatedMembership any `json:"membership,omitempty"` + DeprecatedPrevState any `json:"prev_state,omitzero"` + DeprecatedOrigin any `json:"origin,omitzero"` + DeprecatedMembership any `json:"membership,omitzero"` } -func (pdu *RoomV1PDU) Redact() { +func (pdu *RoomV1PDU) GetRoomID() (id.RoomID, error) { + return pdu.RoomID, nil +} + +func (pdu *RoomV1PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) { + if !pdu.SupportsRoomVersion(roomVersion) { + return "", fmt.Errorf("RoomV1PDU.GetEventID: unsupported room version %s", roomVersion) + } + return pdu.EventID, nil +} + +func (pdu *RoomV1PDU) RedactForSignature() *RoomV1PDU { + pdu.Signatures = nil + return pdu.Redact() +} + +func (pdu *RoomV1PDU) Redact() *RoomV1PDU { pdu.Unknown = nil + pdu.Unsigned = nil switch pdu.Type { case "m.room.member": @@ -54,6 +113,153 @@ func (pdu *RoomV1PDU) Redact() { case "m.room.aliases": pdu.Content = filteredObject(pdu.Content, "aliases") default: - pdu.Content = jsontext.Value("{}") + pdu.Content = emptyObject + } + return pdu +} + +func (pdu *RoomV1PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, error) { + if !pdu.SupportsRoomVersion(roomVersion) { + return [32]byte{}, fmt.Errorf("RoomV1PDU.GetReferenceHash: unsupported room version %s", roomVersion) + } + if pdu == nil { + return [32]byte{}, ErrPDUIsNil + } + if pdu.Hashes == nil || pdu.Hashes.SHA256 == nil { + if err := pdu.FillContentHash(); err != nil { + return [32]byte{}, err + } + } + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature()) + if err != nil { + return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err) + } + return sha256.Sum256(rawJSON), nil +} + +func (pdu *RoomV1PDU) CalculateContentHash() ([32]byte, error) { + if pdu == nil { + return [32]byte{}, ErrPDUIsNil + } + pduClone := pdu.Clone() + pduClone.Signatures = nil + pduClone.Unsigned = nil + pduClone.Hashes = nil + rawJSON, err := marshalCanonical(pduClone) + if err != nil { + return [32]byte{}, fmt.Errorf("failed to marshal PDU to calculate content hash: %w", err) + } + return sha256.Sum256(rawJSON), nil +} + +func (pdu *RoomV1PDU) FillContentHash() error { + if pdu == nil { + return ErrPDUIsNil + } else if pdu.Hashes != nil { + return nil + } else if hash, err := pdu.CalculateContentHash(); err != nil { + return err + } else { + pdu.Hashes = &Hashes{SHA256: hash[:]} + return nil } } + +func (pdu *RoomV1PDU) VerifyContentHash() bool { + if pdu == nil || pdu.Hashes == nil { + return false + } + calculatedHash, err := pdu.CalculateContentHash() + if err != nil { + return false + } + return hmac.Equal(calculatedHash[:], pdu.Hashes.SHA256) +} + +func (pdu *RoomV1PDU) Clone() *RoomV1PDU { + return ptr.Clone(pdu) +} + +func (pdu *RoomV1PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error { + if !pdu.SupportsRoomVersion(roomVersion) { + return fmt.Errorf("RoomV1PDU.Sign: unsupported room version %s", roomVersion) + } + err := pdu.FillContentHash() + if err != nil { + return err + } + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature()) + if err != nil { + return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err) + } + signature := ed25519.Sign(privateKey, rawJSON) + if pdu.Signatures == nil { + pdu.Signatures = make(map[string]map[id.KeyID]string) + } + if _, ok := pdu.Signatures[serverName]; !ok { + pdu.Signatures[serverName] = make(map[id.KeyID]string) + } + pdu.Signatures[serverName][keyID] = base64.RawStdEncoding.EncodeToString(signature) + return nil +} + +func (pdu *RoomV1PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error { + if !pdu.SupportsRoomVersion(roomVersion) { + return fmt.Errorf("RoomV1PDU.VerifySignature: unsupported room version %s", roomVersion) + } + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature()) + if err != nil { + return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err) + } + verified := false + for keyID, sig := range pdu.Signatures[serverName] { + originServerTS := time.UnixMilli(pdu.OriginServerTS) + key, _, err := getKey(keyID, originServerTS) + if err != nil { + return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err) + } else if key == "" { + return fmt.Errorf("key %s not found for %s", keyID, serverName) + } else if err = signutil.VerifyJSONRaw(key, sig, rawJSON); err != nil { + return fmt.Errorf("failed to verify signature from key %s: %w", keyID, err) + } else { + verified = true + } + } + if !verified { + return fmt.Errorf("no verifiable signatures found for server %s", serverName) + } + return nil +} + +func (pdu *RoomV1PDU) SupportsRoomVersion(roomVersion id.RoomVersion) bool { + switch roomVersion { + case id.RoomV0, id.RoomV1, id.RoomV2: + return true + default: + return false + } +} + +func (pdu *RoomV1PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) { + if !pdu.SupportsRoomVersion(roomVersion) { + return nil, fmt.Errorf("RoomV1PDU.ToClientEvent: unsupported room version %s", roomVersion) + } + evtType := event.Type{Type: pdu.Type, Class: event.MessageEventType} + if pdu.StateKey != nil { + evtType.Class = event.StateEventType + } + evt := &event.Event{ + StateKey: pdu.StateKey, + Sender: pdu.Sender, + Type: evtType, + Timestamp: pdu.OriginServerTS, + ID: pdu.EventID, + RoomID: pdu.RoomID, + Redacts: ptr.Val(pdu.Redacts), + } + err := json.Unmarshal(pdu.Content, &evt.Content) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal content: %w", err) + } + return evt, nil +} diff --git a/federation/pdu/v1_test.go b/federation/pdu/v1_test.go new file mode 100644 index 00000000..e5531b0b --- /dev/null +++ b/federation/pdu/v1_test.go @@ -0,0 +1,86 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu_test + +import ( + "encoding/base64" + "encoding/json/v2" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/federation/pdu" + "maunium.net/go/mautrix/id" +) + +var testV1PDUs = []testPDU{{ + name: "m.room.message in v1 room", + pdu: `{"auth_events":[["$159234730483190eXavq:matrix.org",{"sha256":"VprZrhMqOQyKbfF3UE26JXE8D27ih4R/FGGc8GZ0Whs"}],["$143454825711DhCxH:matrix.org",{"sha256":"3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}],["$156837651426789wiPdh:maunium.net",{"sha256":"FGyR3sxJ/VxYabDkO/5qtwrPR3hLwGknJ0KX0w3GUHE"}]],"content":{"body":"photo-1526336024174-e58f5cdd8e13.jpg","info":{"h":1620,"mimetype":"image/jpeg","size":208053,"w":1080},"msgtype":"m.image","url":"mxc://maunium.net/aEqEghIjFPAerIhCxJCYpQeC"},"depth":16669,"event_id":"$16738169022163bokdi:maunium.net","hashes":{"sha256":"XYB47Gf2vAci3BTguIJaC75ZYGMuVY65jcvoUVgpcLA"},"origin":"maunium.net","origin_server_ts":1673816902100,"prev_events":[["$1673816901121325UMCjA:matrix.org",{"sha256":"t7e0IYHLI3ydIPoIU8a8E/pIWXH9cNLlQBEtGyGtHwc"}]],"room_id":"!jhpZBTbckszblMYjMK:matrix.org","sender":"@cat:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"uRZbEm+P+Y1ZVgwBn5I6SlaUZdzlH1bB4nv81yt5EIQ0b1fZ8YgM4UWMijrrXp3+NmqRFl0cakSM3MneJOtFCw"}},"unsigned":{"age_ts":1673816902100}}`, + eventID: "$16738169022163bokdi:maunium.net", + roomVersion: id.RoomV1, + serverDetails: mauniumNet, +}, { + name: "m.room.create in v1 room", + pdu: `{"origin": "matrix.org", "signatures": {"matrix.org": {"ed25519:auto": "XTejpXn5REoHrZWgCpJglGX7MfOWS2zUjYwJRLrwW2PQPbFdqtL+JnprBXwIP2C1NmgWSKG+am1QdApu0KoHCQ"}}, "origin_server_ts": 1434548257426, "sender": "@appservice-irc:matrix.org", "event_id": "$143454825711DhCxH:matrix.org", "prev_events": [], "unsigned": {"age": 12872287834}, "state_key": "", "content": {"creator": "@appservice-irc:matrix.org"}, "depth": 1, "prev_state": [], "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "auth_events": [], "hashes": {"sha256": "+SSdmeeoKI/6yK6sY4XAFljWFiugSlCiXQf0QMCZjTs"}, "type": "m.room.create"}`, + eventID: "$143454825711DhCxH:matrix.org", + roomVersion: id.RoomV1, + serverDetails: matrixOrg, +}, { + name: "m.room.member in v1 room", + pdu: `{"auth_events": [["$1536447669931522zlyWe:matrix.org", {"sha256": "UkzPGd7cPAGvC0FVx3Yy2/Q0GZhA2kcgj8MGp5pjYV8"}], ["$143454825711DhCxH:matrix.org", {"sha256": "3sJh/5GOB094OKuhbjL634Gt69YIcge9GD55ciJa9ok"}], ["$143454825714nUEqZ:matrix.org", {"sha256": "NjuZXu8EDMfIfejPcNlC/IdnKQAGpPIcQjHaf0BZaHk"}]], "prev_events": [["$15660585503271JRRMm:maunium.net", {"sha256": "/Sm7uSLkYMHapp6I3NuEVJlk2JucW2HqjsQy9vzhciA"}]], "type": "m.room.member", "room_id": "!jhpZBTbckszblMYjMK:matrix.org", "sender": "@tulir:maunium.net", "content": {"membership": "join", "avatar_url": "mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO", "displayname": "tulir"}, "depth": 10485, "prev_state": [], "state_key": "@tulir:maunium.net", "event_id": "$15660585693272iEryv:maunium.net", "origin": "maunium.net", "origin_server_ts": 1566058569201, "hashes": {"sha256": "1D6fdDzKsMGCxSqlXPA7I9wGQNTutVuJke1enGHoWK8"}, "signatures": {"maunium.net": {"ed25519:a_xxeS": "Lj/zDK6ozr4vgsxyL8jY56wTGWoA4jnlvkTs5paCX1w3nNKHnQnSMi+wuaqI6yv5vYh9usGWco2LLMuMzYXcBg"}}, "unsigned": {"age_ts": 1566058569201, "replaces_state": "$15660585383268liyBc:maunium.net"}}`, + eventID: "$15660585693272iEryv:maunium.net", + roomVersion: id.RoomV1, + serverDetails: mauniumNet, +}} + +func parseV1PDU(pdu string) (out *pdu.RoomV1PDU) { + exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out)) + return +} + +func TestRoomV1PDU_CalculateContentHash(t *testing.T) { + for _, test := range testV1PDUs { + t.Run(test.name, func(t *testing.T) { + parsed := parseV1PDU(test.pdu) + contentHash := exerrors.Must(parsed.CalculateContentHash()) + assert.Equal( + t, + base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256), + base64.RawStdEncoding.EncodeToString(contentHash[:]), + ) + }) + } +} + +func TestRoomV1PDU_VerifyContentHash(t *testing.T) { + for _, test := range testV1PDUs { + t.Run(test.name, func(t *testing.T) { + parsed := parseV1PDU(test.pdu) + assert.True(t, parsed.VerifyContentHash()) + }) + } +} + +func TestRoomV1PDU_VerifySignature(t *testing.T) { + for _, test := range testV1PDUs { + t.Run(test.name, func(t *testing.T) { + parsed := parseV1PDU(test.pdu) + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { + key, ok := test.keys[keyID] + if ok { + return key.key, key.validUntilTS, nil + } + return "", time.Time{}, nil + }) + assert.NoError(t, err) + }) + } +} From ca4ca62249f9906b2bf900b3dc0f60630abeb7b6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 17 Aug 2025 20:24:13 +0300 Subject: [PATCH 1307/1647] federation/pdu: add docs for GetKeyFunc --- federation/pdu/pdu.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go index dbd4bff1..860201c5 100644 --- a/federation/pdu/pdu.go +++ b/federation/pdu/pdu.go @@ -27,7 +27,13 @@ import ( "maunium.net/go/mautrix/id" ) -type GetKeyFunc = func(keyID id.KeyID, minValidUntil time.Time) (id.SigningKey, time.Time, error) +// GetKeyFunc is a callback for retrieving the key corresponding to a given key ID when verifying the signature of a PDU. +// +// The input time is the timestamp of the event. The function should attempt to fetch a key that is +// valid at or after this time, but if that is not possible, the latest available key should be +// returned without an error. The verify function will do its own validity checking based on the +// returned valid until timestamp. +type GetKeyFunc = func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) type AnyPDU interface { GetRoomID() (id.RoomID, error) From d1004d42b090e1d566815f4ea57608ece459d036 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 18 Aug 2025 00:24:57 +0300 Subject: [PATCH 1308/1647] client: add method to download media thumbnail --- client.go | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/client.go b/client.go index 1536ae52..78f83b85 100644 --- a/client.go +++ b/client.go @@ -1704,6 +1704,38 @@ func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Re return resp, err } +type DownloadThumbnailExtra struct { + Method string + Animated bool +} + +func (cli *Client) DownloadThumbnail(ctx context.Context, mxcURL id.ContentURI, height, width int, extras ...DownloadThumbnailExtra) (*http.Response, error) { + if len(extras) > 1 { + panic(fmt.Errorf("invalid number of arguments to DownloadThumbnail: %d", len(extras))) + } + var extra DownloadThumbnailExtra + if len(extras) == 1 { + extra = extras[0] + } + path := ClientURLPath{"v1", "media", "thumbnail", mxcURL.Homeserver, mxcURL.FileID} + query := map[string]string{ + "height": strconv.Itoa(height), + "width": strconv.Itoa(width), + } + if extra.Method != "" { + query["method"] = extra.Method + } + if extra.Animated { + query["animated"] = "true" + } + _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{ + Method: http.MethodGet, + URL: cli.BuildURLWithQuery(path, query), + DontReadResponse: true, + }) + return resp, err +} + func (cli *Client) DownloadBytes(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) { resp, err := cli.Download(ctx, mxcURL) if err != nil { From 05b711d1816b1603df459995b4350008238b676c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 18 Aug 2025 00:53:23 +0300 Subject: [PATCH 1309/1647] federation/pdu: add more tests for signature checks --- federation/pdu/hash_test.go | 49 +++++++++++++++ federation/pdu/pdu_test.go | 85 ++++++++------------------ federation/pdu/signature_test.go | 102 +++++++++++++++++++++++++++++++ 3 files changed, 177 insertions(+), 59 deletions(-) create mode 100644 federation/pdu/hash_test.go create mode 100644 federation/pdu/signature_test.go diff --git a/federation/pdu/hash_test.go b/federation/pdu/hash_test.go new file mode 100644 index 00000000..35ea49df --- /dev/null +++ b/federation/pdu/hash_test.go @@ -0,0 +1,49 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu_test + +import ( + "encoding/base64" + "testing" + + "github.com/stretchr/testify/assert" + "go.mau.fi/util/exerrors" +) + +func TestPDU_CalculateContentHash(t *testing.T) { + for _, test := range testPDUs { + t.Run(test.name, func(t *testing.T) { + parsed := parsePDU(test.pdu) + contentHash := exerrors.Must(parsed.CalculateContentHash()) + assert.Equal( + t, + base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256), + base64.RawStdEncoding.EncodeToString(contentHash[:]), + ) + }) + } +} + +func TestPDU_VerifyContentHash(t *testing.T) { + for _, test := range testPDUs { + t.Run(test.name, func(t *testing.T) { + parsed := parsePDU(test.pdu) + assert.True(t, parsed.VerifyContentHash()) + }) + } +} + +func TestPDU_GetEventID(t *testing.T) { + for _, test := range testPDUs { + t.Run(test.name, func(t *testing.T) { + gotEventID := exerrors.Must(parsePDU(test.pdu).GetEventID(test.roomVersion)) + assert.Equal(t, test.eventID, gotEventID) + }) + } +} diff --git a/federation/pdu/pdu_test.go b/federation/pdu/pdu_test.go index 9f6fe74a..f4920098 100644 --- a/federation/pdu/pdu_test.go +++ b/federation/pdu/pdu_test.go @@ -9,12 +9,9 @@ package pdu_test import ( - "encoding/base64" "encoding/json/v2" - "testing" "time" - "github.com/stretchr/testify/assert" "go.mau.fi/util/exerrors" "maunium.net/go/mautrix/federation/pdu" @@ -31,6 +28,14 @@ type serverDetails struct { keys map[id.KeyID]serverKey } +func (sd serverDetails) getKey(keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { + key, ok := sd.keys[keyID] + if ok { + return key.key, key.validUntilTS, nil + } + return "", time.Time{}, nil +} + var mauniumNet = serverDetails{ serverName: "maunium.net", keys: map[id.KeyID]serverKey{ @@ -75,7 +80,23 @@ type testPDU struct { serverDetails } -var testPDUs = []testPDU{{ +var roomV4MessageTestPDU = testPDU{ + name: "m.room.message in v4 room", + pdu: `{"auth_events":["$OB87jNemaIVDHAfu0-pa_cP7OPFXUXCbFpjYVi8gll4","$RaWbTF9wQfGQgUpe1S13wzICtGTB2PNKRHUNHu9IO1c","$ZmEWOXw6cC4Rd1wTdY5OzeLJVzjhrkxFPwwKE4gguGk"],"content":{"body":"the last one is saying it shouldn't have effects","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":13103,"hashes":{"sha256":"c2wb8qMlvzIPCP1Wd+eYZ4BRgnGYxS97dR1UlJjVMeg"},"origin_server_ts":1752875275263,"prev_events":["$-7_BMI3BXwj3ayoxiJvraJxYWTKwjiQ6sh7CW_Brvj0"],"room_id":"!JiiOHXrIUCtcOJsZCa:matrix.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"99TAqHpBkUEtgCraXsVXogmf/hnijPbgbG9eACtA+mbix3Y6gURI4QGQgcX/NhcE3pJQZ/YDjmbuvCnKvEccAA"}},"unsigned":{"age_ts":1752875275281}}`, + eventID: "$Jo_lmFR-e6lzrimzCA7DevIn2OwhuQYmd9xkcJBoqAA", + roomVersion: id.RoomV4, + serverDetails: mauniumNet, +} + +var roomV12MessageTestPDU = testPDU{ + name: "m.room.message in v12 room", + pdu: `{"auth_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA","$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":122,"hashes":{"sha256":"IQ0zlc+PXeEs6R3JvRkW3xTPV3zlGKSSd3x07KXGjzs"},"origin_server_ts":1755384351627,"prev_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir_test:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"0GDMddL2k7gF4V1VU8sL3wTfhAIzAu5iVH5jeavZ2VEg3J9/tHLWXAOn2tzkLaMRWl0/XpINT2YlH/rd2U21Ag"}},"unsigned":{"age_ts":1755384351627}}`, + eventID: "$xmP-wZfpannuHG-Akogi6c4YvqxChMtdyYbUMGOrMWc", + roomVersion: id.RoomV12, + serverDetails: mauniumNet, +} + +var testPDUs = []testPDU{roomV4MessageTestPDU, { name: "m.room.message in v5 room", pdu: `{"auth_events":["$hp0ImHqYgHTRbLeWKPeTeFmxdb5SdMJN9cfmTrTk7d0","$KAj7X7tnJbR9qYYMWJSw-1g414_KlPptbbkZm7_kUtg","$V-2ShOwZYhA_nxMijaf3lqFgIJgzE2UMeFPtOLnoBYM"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":2248,"hashes":{"sha256":"kV+JuLbWXJ2r6PjHT3wt8bFc/TfI1nTaSN3Lamg/xHs"},"origin_server_ts":1755422945654,"prev_events":["$49lFLem2Nk4dxHk9RDXxTdaq9InIJpmkHpzVnjKcYwg"],"room_id":"!vzBgJsjNzgHSdWsmki:mozilla.org","sender":"@tulir:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"JIl60uVgfCLBZLPoSiE7wVkJ9U5cNEPVPuv1sCCYUOq5yOW56WD1adgpBUdX2UFpYkCHvkRnyQGxU0+6HBp5BA"}},"unsigned":{"age_ts":1755422945673}}`, eventID: "$Qn4tHfuAe6PlnKXPZnygAU9wd6RXqMKtt_ZzstHTSgA", @@ -93,13 +114,7 @@ var testPDUs = []testPDU{{ eventID: `$qkWfTL7_l3oRZO2CItW8-Q0yAmi_l_1ua629ZDqponE`, roomVersion: id.RoomV11, serverDetails: mauniumNet, -}, { - name: "m.room.message in v12 room", - pdu: `{"auth_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA","$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":122,"hashes":{"sha256":"IQ0zlc+PXeEs6R3JvRkW3xTPV3zlGKSSd3x07KXGjzs"},"origin_server_ts":1755384351627,"prev_events":["$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"],"room_id":"!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ","sender":"@tulir_test:maunium.net","type":"m.room.message","signatures":{"maunium.net":{"ed25519:a_xxeS":"0GDMddL2k7gF4V1VU8sL3wTfhAIzAu5iVH5jeavZ2VEg3J9/tHLWXAOn2tzkLaMRWl0/XpINT2YlH/rd2U21Ag"}},"unsigned":{"age_ts":1755384351627}}`, - eventID: "$xmP-wZfpannuHG-Akogi6c4YvqxChMtdyYbUMGOrMWc", - roomVersion: id.RoomV12, - serverDetails: mauniumNet, -}, { +}, roomV12MessageTestPDU, { name: "m.room.create in v4 room", pdu: `{"auth_events": [], "prev_events": [], "type": "m.room.create", "room_id": "!jxlRxnrZCsjpjDubDX:matrix.org", "sender": "@neilj:matrix.org", "content": {"room_version": "4", "predecessor": {"room_id": "!DYgXKezaHgMbiPMzjX:matrix.org", "event_id": "$156171636353XwPJT:matrix.org"}, "creator": "@neilj:matrix.org"}, "depth": 1, "prev_state": [], "state_key": "", "origin": "matrix.org", "origin_server_ts": 1561716363993, "hashes": {"sha256": "9tj8GpXjTAJvdNAbnuKLemZZk+Tjv2LAbGodSX6nJAo"}, "signatures": {"matrix.org": {"ed25519:auto": "2+sNt8uJUhzU4GPxnFVYtU2ZRgFdtVLT1vEZGUdJYN40zBpwYEGJy+kyb5matA+8/yLeYD9gu1O98lhleH0aCA"}}, "unsigned": {"age": 104769}}`, eventID: "$ay_9_nPilrTpb3UxIwHHBBfFjTJb6hBAE_JzQwSjqeY", @@ -147,51 +162,3 @@ func parsePDU(pdu string) (out *pdu.PDU) { exerrors.PanicIfNotNil(json.Unmarshal([]byte(pdu), &out)) return } - -func TestPDU_CalculateContentHash(t *testing.T) { - for _, test := range testPDUs { - t.Run(test.name, func(t *testing.T) { - parsed := parsePDU(test.pdu) - contentHash := exerrors.Must(parsed.CalculateContentHash()) - assert.Equal( - t, - base64.RawStdEncoding.EncodeToString(parsed.Hashes.SHA256), - base64.RawStdEncoding.EncodeToString(contentHash[:]), - ) - }) - } -} - -func TestPDU_VerifyContentHash(t *testing.T) { - for _, test := range testPDUs { - t.Run(test.name, func(t *testing.T) { - parsed := parsePDU(test.pdu) - assert.True(t, parsed.VerifyContentHash()) - }) - } -} - -func TestPDU_GetEventID(t *testing.T) { - for _, test := range testPDUs { - t.Run(test.name, func(t *testing.T) { - gotEventID := exerrors.Must(parsePDU(test.pdu).GetEventID(test.roomVersion)) - assert.Equal(t, test.eventID, gotEventID) - }) - } -} - -func TestPDU_VerifySignature(t *testing.T) { - for _, test := range testPDUs { - t.Run(test.name, func(t *testing.T) { - parsed := parsePDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { - key, ok := test.keys[keyID] - if ok { - return key.key, key.validUntilTS, nil - } - return "", time.Time{}, nil - }) - assert.NoError(t, err) - }) - } -} diff --git a/federation/pdu/signature_test.go b/federation/pdu/signature_test.go new file mode 100644 index 00000000..68e7a773 --- /dev/null +++ b/federation/pdu/signature_test.go @@ -0,0 +1,102 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu_test + +import ( + "crypto/ed25519" + "encoding/base64" + "encoding/json/jsontext" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/federation/pdu" + "maunium.net/go/mautrix/id" +) + +func TestPDU_VerifySignature(t *testing.T) { + for _, test := range testPDUs { + t.Run(test.name, func(t *testing.T) { + parsed := parsePDU(test.pdu) + err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey) + assert.NoError(t, err) + }) + } +} + +func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) { + test := roomV12MessageTestPDU + parsed := parsePDU(test.pdu) + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + return + }) + assert.Error(t, err) +} + +func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) { + test := roomV4MessageTestPDU + parsed := parsePDU(test.pdu) + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + key = test.keys[keyID].key + validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + return + }) + assert.NoError(t, err) +} + +func TestPDU_VerifySignature_V12ExpiredKey(t *testing.T) { + test := roomV12MessageTestPDU + parsed := parsePDU(test.pdu) + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + key = test.keys[keyID].key + validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + return + }) + assert.Error(t, err) +} + +func TestPDU_VerifySignature_V12InvalidSignature(t *testing.T) { + test := roomV12MessageTestPDU + parsed := parsePDU(test.pdu) + for _, sigs := range parsed.Signatures { + for key := range sigs { + sigs[key] = sigs[key][:len(sigs[key])-3] + "ABC" + } + } + err := parsed.VerifySignature(test.roomVersion, test.serverName, test.getKey) + assert.Error(t, err) +} + +func TestPDU_Sign(t *testing.T) { + pubKey, privKey := exerrors.Must2(ed25519.GenerateKey(nil)) + evt := &pdu.PDU{ + AuthEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA", "$hyeL_nU_L3tsZ2dtZZpAHk0Skv-PqFQIipuII_By584"}, + Content: jsontext.Value(`{"msgtype":"m.text","body":"Hello, world!"}`), + Depth: 123, + OriginServerTS: 1755384351627, + PrevEvents: []id.EventID{"$gCzdJUVV93Qory0x7p_PLG5UUiDjPJNe1H12qbHTuFA"}, + RoomID: "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + Sender: "@tulir:example.com", + Type: "m.room.message", + } + err := evt.Sign(id.RoomV12, "example.com", "ed25519:rand", privKey) + require.NoError(t, err) + err = evt.VerifySignature(id.RoomV11, "example.com", func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + if keyID == "ed25519:rand" { + key = id.SigningKey(base64.RawStdEncoding.EncodeToString(pubKey)) + validUntil = time.Now() + } + return + }) + require.NoError(t, err) + +} From baf54f57b61d1ee091b96e974c7aa9201a82a588 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 19 Aug 2025 19:44:51 +0300 Subject: [PATCH 1310/1647] crypto/encryptmegolm: add fallback for copying `m.relates_to` --- crypto/encryptmegolm.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 14ba2449..cd211af5 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -41,7 +41,7 @@ func getRawJSON[T any](content json.RawMessage, path ...string) *T { return &result } -func getRelatesTo(content any) *event.RelatesTo { +func getRelatesTo(content any, plaintext json.RawMessage) *event.RelatesTo { contentJSON, ok := content.(json.RawMessage) if ok { return getRawJSON[event.RelatesTo](contentJSON, "m.relates_to") @@ -54,7 +54,7 @@ func getRelatesTo(content any) *event.RelatesTo { if ok { return relatable.OptionalGetRelatesTo() } - return nil + return getRawJSON[event.RelatesTo](plaintext, "content", "m.relates_to") } func getMentions(content any) *event.Mentions { @@ -158,7 +158,7 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room Algorithm: id.AlgorithmMegolmV1, SessionID: session.ID(), MegolmCiphertext: ciphertext, - RelatesTo: getRelatesTo(content), + RelatesTo: getRelatesTo(content, plaintext), // These are deprecated SenderKey: mach.account.IdentityKey(), From 29780ffb183c92b9526c6f15e1810f9c3945b19e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 21 Aug 2025 13:18:11 +0300 Subject: [PATCH 1311/1647] federation/pdu: refactor redaction to allow reuse of RedactContent --- federation/pdu/redact.go | 55 +++++++++++++++++++++------------------- federation/pdu/v1.go | 31 +++++++--------------- 2 files changed, 38 insertions(+), 48 deletions(-) diff --git a/federation/pdu/redact.go b/federation/pdu/redact.go index 56aaee1c..d7ee0c15 100644 --- a/federation/pdu/redact.go +++ b/federation/pdu/redact.go @@ -51,16 +51,8 @@ func (pdu *PDU) RedactForSignature(roomVersion id.RoomVersion) *PDU { var emptyObject = jsontext.Value("{}") -func (pdu *PDU) Redact(roomVersion id.RoomVersion) *PDU { - pdu.Unknown = nil - pdu.Unsigned = nil - if roomVersion.UpdatedRedactionRules() { - pdu.DeprecatedPrevState = nil - pdu.DeprecatedOrigin = nil - pdu.DeprecatedMembership = nil - } - - switch pdu.Type { +func RedactContent(eventType string, content jsontext.Value, roomVersion id.RoomVersion) jsontext.Value { + switch eventType { case "m.room.member": allowedPaths := []string{"membership"} if roomVersion.RestrictedJoinsFix() { @@ -69,40 +61,51 @@ func (pdu *PDU) Redact(roomVersion id.RoomVersion) *PDU { if roomVersion.UpdatedRedactionRules() { allowedPaths = append(allowedPaths, exgjson.Path("third_party_invite", "signed")) } - pdu.Content = filteredObject(pdu.Content, allowedPaths...) + return filteredObject(content, allowedPaths...) case "m.room.create": if !roomVersion.UpdatedRedactionRules() { - pdu.Content = filteredObject(pdu.Content, "creator") - } // else: all fields are protected + return filteredObject(content, "creator") + } + return content case "m.room.join_rules": if roomVersion.RestrictedJoins() { - pdu.Content = filteredObject(pdu.Content, "join_rule", "allow") - } else { - pdu.Content = filteredObject(pdu.Content, "join_rule") + return filteredObject(content, "join_rule", "allow") } + return filteredObject(content, "join_rule") case "m.room.power_levels": allowedKeys := []string{"ban", "events", "events_default", "kick", "redact", "state_default", "users", "users_default"} if roomVersion.UpdatedRedactionRules() { allowedKeys = append(allowedKeys, "invite") } - pdu.Content = filteredObject(pdu.Content, allowedKeys...) + return filteredObject(content, allowedKeys...) case "m.room.history_visibility": - pdu.Content = filteredObject(pdu.Content, "history_visibility") + return filteredObject(content, "history_visibility") case "m.room.redaction": if roomVersion.RedactsInContent() { - pdu.Content = filteredObject(pdu.Content, "redacts") - pdu.Redacts = nil - } else { - pdu.Content = emptyObject + return filteredObject(content, "redacts") } + return emptyObject case "m.room.aliases": if roomVersion.SpecialCasedAliasesAuth() { - pdu.Content = filteredObject(pdu.Content, "aliases") - } else { - pdu.Content = emptyObject + return filteredObject(content, "aliases") } + return emptyObject default: - pdu.Content = emptyObject + return emptyObject } +} + +func (pdu *PDU) Redact(roomVersion id.RoomVersion) *PDU { + pdu.Unknown = nil + pdu.Unsigned = nil + if roomVersion.UpdatedRedactionRules() { + pdu.DeprecatedPrevState = nil + pdu.DeprecatedOrigin = nil + pdu.DeprecatedMembership = nil + } + if pdu.Type != "m.room.redaction" || roomVersion.RedactsInContent() { + pdu.Redacts = nil + } + pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion) return pdu } diff --git a/federation/pdu/v1.go b/federation/pdu/v1.go index 795253db..1bc324ed 100644 --- a/federation/pdu/v1.go +++ b/federation/pdu/v1.go @@ -90,31 +90,18 @@ func (pdu *RoomV1PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) return pdu.EventID, nil } -func (pdu *RoomV1PDU) RedactForSignature() *RoomV1PDU { +func (pdu *RoomV1PDU) RedactForSignature(roomVersion id.RoomVersion) *RoomV1PDU { pdu.Signatures = nil - return pdu.Redact() + return pdu.Redact(roomVersion) } -func (pdu *RoomV1PDU) Redact() *RoomV1PDU { +func (pdu *RoomV1PDU) Redact(roomVersion id.RoomVersion) *RoomV1PDU { pdu.Unknown = nil pdu.Unsigned = nil - - switch pdu.Type { - case "m.room.member": - pdu.Content = filteredObject(pdu.Content, "membership") - case "m.room.create": - pdu.Content = filteredObject(pdu.Content, "creator") - case "m.room.join_rules": - pdu.Content = filteredObject(pdu.Content, "join_rule") - case "m.room.power_levels": - pdu.Content = filteredObject(pdu.Content, "ban", "events", "events_default", "kick", "redact", "state_default", "users", "users_default") - case "m.room.history_visibility": - pdu.Content = filteredObject(pdu.Content, "history_visibility") - case "m.room.aliases": - pdu.Content = filteredObject(pdu.Content, "aliases") - default: - pdu.Content = emptyObject + if pdu.Type != "m.room.redaction" { + pdu.Redacts = nil } + pdu.Content = RedactContent(pdu.Type, pdu.Content, roomVersion) return pdu } @@ -130,7 +117,7 @@ func (pdu *RoomV1PDU) GetReferenceHash(roomVersion id.RoomVersion) ([32]byte, er return [32]byte{}, err } } - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature()) + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) if err != nil { return [32]byte{}, fmt.Errorf("failed to marshal redacted PDU to calculate event ID: %w", err) } @@ -188,7 +175,7 @@ func (pdu *RoomV1PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID if err != nil { return err } - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature()) + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) if err != nil { return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err) } @@ -207,7 +194,7 @@ func (pdu *RoomV1PDU) VerifySignature(roomVersion id.RoomVersion, serverName str if !pdu.SupportsRoomVersion(roomVersion) { return fmt.Errorf("RoomV1PDU.VerifySignature: unsupported room version %s", roomVersion) } - rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature()) + rawJSON, err := marshalCanonical(pdu.Clone().RedactForSignature(roomVersion)) if err != nil { return fmt.Errorf("failed to marshal redacted PDU to verify signature: %w", err) } From a547c0636c72a215e6c56ef6cb72959d5fccaeee Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 21 Aug 2025 13:19:11 +0300 Subject: [PATCH 1312/1647] event,pushrules: replace assert.Nil with assert.NoError --- event/message_test.go | 14 +++++++------- pushrules/action_test.go | 12 ++++++------ pushrules/pushrules_test.go | 2 +- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/event/message_test.go b/event/message_test.go index 562a6622..c721df35 100644 --- a/event/message_test.go +++ b/event/message_test.go @@ -33,7 +33,7 @@ const invalidMessageEvent = `{ func TestMessageEventContent__ParseInvalid(t *testing.T) { var evt *event.Event err := json.Unmarshal([]byte(invalidMessageEvent), &evt) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender) assert.Equal(t, event.EventMessage, evt.Type) @@ -42,7 +42,7 @@ func TestMessageEventContent__ParseInvalid(t *testing.T) { assert.Equal(t, id.RoomID("!bar"), evt.RoomID) err = evt.Content.ParseRaw(evt.Type) - assert.NotNil(t, err) + assert.Error(t, err) } const messageEvent = `{ @@ -68,7 +68,7 @@ const messageEvent = `{ func TestMessageEventContent__ParseEdit(t *testing.T) { var evt *event.Event err := json.Unmarshal([]byte(messageEvent), &evt) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender) assert.Equal(t, event.EventMessage, evt.Type) @@ -110,7 +110,7 @@ const imageMessageEvent = `{ func TestMessageEventContent__ParseMedia(t *testing.T) { var evt *event.Event err := json.Unmarshal([]byte(imageMessageEvent), &evt) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender) assert.Equal(t, event.EventMessage, evt.Type) @@ -125,7 +125,7 @@ func TestMessageEventContent__ParseMedia(t *testing.T) { content := evt.Content.Parsed.(*event.MessageEventContent) assert.Equal(t, event.MsgImage, content.MsgType) parsedURL, err := content.URL.Parse() - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, id.ContentURI{Homeserver: "example.com", FileID: "image"}, parsedURL) assert.Nil(t, content.NewContent) assert.Equal(t, "image/png", content.GetInfo().MimeType) @@ -145,7 +145,7 @@ const expectedMarshalResult = `{"msgtype":"m.text","body":"test"}` func TestMessageEventContent__Marshal(t *testing.T) { data, err := json.Marshal(parsedMessage) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, expectedMarshalResult, string(data)) } @@ -163,6 +163,6 @@ const expectedCustomMarshalResult = `{"body":"test","msgtype":"m.text","net.maun func TestMessageEventContent__Marshal_Custom(t *testing.T) { data, err := json.Marshal(customParsedMessage) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, expectedCustomMarshalResult, string(data)) } diff --git a/pushrules/action_test.go b/pushrules/action_test.go index a8f68415..3c0aa168 100644 --- a/pushrules/action_test.go +++ b/pushrules/action_test.go @@ -139,9 +139,9 @@ func TestPushAction_UnmarshalJSON_InvalidTypeDoesNothing(t *testing.T) { } err := pa.UnmarshalJSON([]byte(`{"foo": "bar"}`)) - assert.Nil(t, err) + assert.NoError(t, err) err = pa.UnmarshalJSON([]byte(`9001`)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, pushrules.PushActionType("unchanged"), pa.Action) assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak) @@ -156,7 +156,7 @@ func TestPushAction_UnmarshalJSON_StringChangesActionType(t *testing.T) { } err := pa.UnmarshalJSON([]byte(`"foo"`)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, pushrules.PushActionType("foo"), pa.Action) assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak) @@ -171,7 +171,7 @@ func TestPushAction_UnmarshalJSON_SetTweakChangesTweak(t *testing.T) { } err := pa.UnmarshalJSON([]byte(`{"set_tweak": "foo", "value": 123.0}`)) - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, pushrules.ActionSetTweak, pa.Action) assert.Equal(t, pushrules.PushActionTweak("foo"), pa.Tweak) @@ -185,7 +185,7 @@ func TestPushAction_MarshalJSON_TweakOutputWorks(t *testing.T) { Value: "bar", } data, err := pa.MarshalJSON() - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, []byte(`{"set_tweak":"foo","value":"bar"}`), data) } @@ -196,6 +196,6 @@ func TestPushAction_MarshalJSON_OtherOutputWorks(t *testing.T) { Value: "bar", } data, err := pa.MarshalJSON() - assert.Nil(t, err) + assert.NoError(t, err) assert.Equal(t, []byte(`"something else"`), data) } diff --git a/pushrules/pushrules_test.go b/pushrules/pushrules_test.go index a531ca28..a5a0f5e7 100644 --- a/pushrules/pushrules_test.go +++ b/pushrules/pushrules_test.go @@ -25,7 +25,7 @@ func TestEventToPushRules(t *testing.T) { }, } pushRuleset, err := pushrules.EventToPushRules(evt) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotNil(t, pushRuleset) assert.IsType(t, pushRuleset.Override, pushrules.PushRuleArray{}) From 1d484e01d071de51f5eccda63c7013f3f4c31de8 Mon Sep 17 00:00:00 2001 From: Kishan Bagaria <1093313+KishanBagaria@users.noreply.github.com> Date: Fri, 22 Aug 2025 14:46:56 +0530 Subject: [PATCH 1313/1647] event: implement disappearing timer types (#399) Co-authored-by: Tulir Asokan --- bridgev2/database/disappear.go | 25 +++++++++++++++++++------ bridgev2/database/portal.go | 14 +++++++++++++- bridgev2/portal.go | 28 ++++++++++++++++++++++++---- bridgev2/portalbackfill.go | 5 +++-- event/capabilities.d.ts | 21 +++++++++++++++++++++ event/capabilities.go | 26 ++++++++++++++++++++++++++ event/content.go | 1 + event/message.go | 2 ++ event/state.go | 16 ++++++++++++++++ event/type.go | 3 ++- 10 files changed, 127 insertions(+), 14 deletions(-) diff --git a/bridgev2/database/disappear.go b/bridgev2/database/disappear.go index 4e6f5e0a..e830cb14 100644 --- a/bridgev2/database/disappear.go +++ b/bridgev2/database/disappear.go @@ -12,28 +12,41 @@ import ( "time" "go.mau.fi/util/dbutil" + "go.mau.fi/util/jsontime" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) -// DisappearingType represents the type of a disappearing message timer. -type DisappearingType string +// Deprecated: use [event.DisappearingType] +type DisappearingType = event.DisappearingType +// Deprecated: use constants in event package const ( - DisappearingTypeNone DisappearingType = "" - DisappearingTypeAfterRead DisappearingType = "after_read" - DisappearingTypeAfterSend DisappearingType = "after_send" + DisappearingTypeNone = event.DisappearingTypeNone + DisappearingTypeAfterRead = event.DisappearingTypeAfterRead + DisappearingTypeAfterSend = event.DisappearingTypeAfterSend ) // DisappearingSetting represents a disappearing message timer setting // by combining a type with a timer and an optional start timestamp. type DisappearingSetting struct { - Type DisappearingType + Type event.DisappearingType Timer time.Duration DisappearAt time.Time } +func (ds DisappearingSetting) ToEventContent() *event.BeeperDisappearingTimer { + if ds.Type == event.DisappearingTypeNone || ds.Timer == 0 { + return nil + } + return &event.BeeperDisappearingTimer{ + Type: ds.Type, + Timer: jsontime.MS(ds.Timer), + } +} + type DisappearingMessageQuery struct { BridgeID networkid.BridgeID *dbutil.QueryHelper[*DisappearingMessage] diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 17e44b09..c3aa7121 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -16,6 +16,7 @@ import ( "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -34,9 +35,20 @@ type PortalQuery struct { *dbutil.QueryHelper[*Portal] } +type CapStateFlags uint32 + +func (csf CapStateFlags) Has(flag CapStateFlags) bool { + return csf&flag != 0 +} + +const ( + CapStateFlagDisappearingTimerSet CapStateFlags = 1 << iota +) + type CapabilityState struct { Source networkid.UserLoginID `json:"source"` ID string `json:"id"` + Flags CapStateFlags `json:"flags"` } type Portal struct { @@ -208,7 +220,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { } if disappearType.Valid { p.Disappear = DisappearingSetting{ - Type: DisappearingType(disappearType.String), + Type: event.DisappearingType(disappearType.String), Timer: time.Duration(disappearTimer.Int64), } } diff --git a/bridgev2/portal.go b/bridgev2/portal.go index d343a651..7c3a56c2 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1101,7 +1101,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } portal.sendSuccessStatus(ctx, evt, resp.StreamOrder, message.MXID) } - if portal.Disappear.Type != database.DisappearingTypeNone { + if portal.Disappear.Type != event.DisappearingTypeNone { go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ RoomID: portal.MXID, EventID: message.MXID, @@ -2281,6 +2281,7 @@ func (portal *Portal) sendConvertedMessage( allSuccess := true for i, part := range converted.Parts { portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent) + part.Content.BeeperDisappearingTimer = converted.Disappear.ToEventContent() dbMessage := &database.Message{ ID: id, PartID: part.ID, @@ -2325,8 +2326,8 @@ func (portal *Portal) sendConvertedMessage( logContext(log.Err(err)).Str("part_id", string(part.ID)).Msg("Failed to save message part to database") allSuccess = false } - if converted.Disappear.Type != database.DisappearingTypeNone && !dbMessage.HasFakeMXID() { - if converted.Disappear.Type == database.DisappearingTypeAfterSend && converted.Disappear.DisappearAt.IsZero() { + if converted.Disappear.Type != event.DisappearingTypeNone && !dbMessage.HasFakeMXID() { + if converted.Disappear.Type == event.DisappearingTypeAfterSend && converted.Disappear.DisappearAt.IsZero() { converted.Disappear.DisappearAt = dbMessage.Timestamp.Add(converted.Disappear.Timer) } portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ @@ -3648,6 +3649,15 @@ func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, portal.CapState = database.CapabilityState{ Source: source.ID, ID: capID, + Flags: portal.CapState.Flags, + } + if caps.DisappearingTimer != nil && !portal.CapState.Flags.Has(database.CapStateFlagDisappearingTimerSet) { + zerolog.Ctx(ctx).Debug().Msg("Disappearing timer capability was added, sending disappearing timer state event") + success = portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent()) + if !success { + return false + } + portal.CapState.Flags |= database.CapStateFlagDisappearingTimerSet } portal.lastCapUpdate = time.Now() if implicit { @@ -4030,7 +4040,7 @@ func DisappearingMessageNotice(expiration time.Duration, implicit bool) *event.M func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting database.DisappearingSetting, sender MatrixAPI, ts time.Time, implicit, save bool) bool { if setting.Timer == 0 { - setting.Type = "" + setting.Type = event.DisappearingTypeNone } if portal.Disappear.Timer == setting.Timer && portal.Disappear.Type == setting.Type { return false @@ -4046,6 +4056,9 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat if portal.MXID == "" { return true } + + portal.sendRoomMeta(ctx, sender, ts, event.StateBeeperDisappearingTimer, "", setting.ToEventContent()) + content := DisappearingMessageNotice(setting.Timer, implicit) if sender == nil { sender = portal.Bridge.Bot @@ -4333,6 +4346,13 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo Type: event.StateBeeperRoomFeatures, Content: event.Content{Parsed: roomFeatures}, }) + if roomFeatures.DisappearingTimer != nil { + req.InitialState = append(req.InitialState, &event.Event{ + Type: event.StateBeeperDisappearingTimer, + Content: event.Content{Parsed: portal.Disappear.ToEventContent()}, + }) + portal.CapState.Flags |= database.CapStateFlagDisappearingTimerSet + } if req.Topic == "" { // Add explicit topic event if topic is empty to ensure the event is set. // This ensures that there won't be an extra event later if PUT /state/... is called. diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 9883fb12..f7819968 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -339,6 +339,7 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin for i, part := range msg.Parts { partIDs = append(partIDs, part.ID) portal.applyRelationMeta(ctx, part.Content, replyTo, threadRoot, prevThreadEvent) + part.Content.BeeperDisappearingTimer = msg.Disappear.ToEventContent() evtID := portal.Bridge.Matrix.GenerateDeterministicEventID(portal.MXID, portal.PortalKey, msg.ID, part.ID) dbMessage := &database.Message{ ID: msg.ID, @@ -379,8 +380,8 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin prevThreadEvent.MXID = evtID out.PrevThreadEvents[*msg.ThreadRoot] = evtID } - if msg.Disappear.Type != database.DisappearingTypeNone { - if msg.Disappear.Type == database.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() { + if msg.Disappear.Type != event.DisappearingTypeNone { + if msg.Disappear.Type == event.DisappearingTypeAfterSend && msg.Disappear.DisappearAt.IsZero() { msg.Disappear.DisappearAt = msg.Timestamp.Add(msg.Disappear.Timer) } out.Disappear = append(out.Disappear, &database.DisappearingMessage{ diff --git a/event/capabilities.d.ts b/event/capabilities.d.ts index 4cf29de7..7f1dce05 100644 --- a/event/capabilities.d.ts +++ b/event/capabilities.d.ts @@ -41,6 +41,8 @@ export interface RoomFeatures { delete_max_age?: seconds /** Whether deleting messages just for yourself is supported. No message age limit. */ delete_for_me?: boolean + /** Allowed configuration options for disappearing timers. */ + disappearing_timer?: DisappearingTimerCapability /** Whether reactions are supported. */ reaction?: CapabilitySupportLevel @@ -57,6 +59,7 @@ export interface RoomFeatures { declare type integer = number declare type seconds = integer +declare type milliseconds = integer declare type MIMEClass = "image" | "audio" | "video" | "text" | "font" | "model" | "application" declare type MIMETypeOrPattern = "*/*" @@ -106,6 +109,24 @@ export interface FileFeatures { view_once?: boolean } +export enum DisappearingType { + None = "", + AfterRead = "after_read", + AfterSend = "after_send", +} + +export interface DisappearingTimerCapability { + types: DisappearingType[] + timers: milliseconds[] + /** + * Whether clients should omit the empty disappearing_timer object in messages that they don't want to disappear + * + * Generally, bridged rooms will want the object to be always present, while native Matrix rooms don't, + * so the hardcoded features for Matrix rooms should set this to true, while bridges will not. + */ + omit_empty_timer?: true +} + /** * The support level for a feature. These are integers rather than booleans * to accurately represent what the bridge is doing and hopefully make the diff --git a/event/capabilities.go b/event/capabilities.go index 9c9eb09a..f44d6600 100644 --- a/event/capabilities.go +++ b/event/capabilities.go @@ -44,6 +44,8 @@ type RoomFeatures struct { DeleteForMe bool `json:"delete_for_me,omitempty"` DeleteMaxAge *jsontime.Seconds `json:"delete_max_age,omitempty"` + DisappearingTimer *DisappearingTimerCapability `json:"disappearing_timer,omitempty"` + Reaction CapabilitySupportLevel `json:"reaction,omitempty"` ReactionCount int `json:"reaction_count,omitempty"` AllowedReactions []string `json:"allowed_reactions,omitempty"` @@ -67,6 +69,13 @@ type FormattingFeatureMap map[FormattingFeature]CapabilitySupportLevel type FileFeatureMap map[CapabilityMsgType]*FileFeatures +type DisappearingTimerCapability struct { + Types []DisappearingType `json:"types"` + Timers []jsontime.Milliseconds `json:"timers"` + + OmitEmptyTimer bool `json:"omit_empty_timer,omitempty"` +} + type CapabilityMsgType = MessageType // Message types which are used for event capability signaling, but aren't real values for the msgtype field. @@ -231,6 +240,7 @@ func (rf *RoomFeatures) Hash() []byte { hashValue(hasher, "delete", rf.Delete) hashBool(hasher, "delete_for_me", rf.DeleteForMe) hashInt(hasher, "delete_max_age", rf.DeleteMaxAge.Get()) + hashValue(hasher, "disappearing_timer", rf.DisappearingTimer) hashValue(hasher, "reaction", rf.Reaction) hashInt(hasher, "reaction_count", rf.ReactionCount) @@ -249,6 +259,22 @@ func (rf *RoomFeatures) Hash() []byte { return hasher.Sum(nil) } +func (dtc *DisappearingTimerCapability) Hash() []byte { + if dtc == nil { + return nil + } + hasher := sha256.New() + hasher.Write([]byte("types")) + for _, t := range dtc.Types { + hasher.Write([]byte(t)) + } + hasher.Write([]byte("timers")) + for _, timer := range dtc.Timers { + hashInt(hasher, "", timer.Milliseconds()) + } + return hasher.Sum(nil) +} + func (ff *FileFeatures) Hash() []byte { hasher := sha256.New() hashMap(hasher, "mime_types", ff.MimeTypes) diff --git a/event/content.go b/event/content.go index e50dfea5..779330af 100644 --- a/event/content.go +++ b/event/content.go @@ -48,6 +48,7 @@ var TypeMap = map[Type]reflect.Type{ StateElementFunctionalMembers: reflect.TypeOf(ElementFunctionalMembersContent{}), StateBeeperRoomFeatures: reflect.TypeOf(RoomFeatures{}), + StateBeeperDisappearingTimer: reflect.TypeOf(BeeperDisappearingTimer{}), EventMessage: reflect.TypeOf(MessageEventContent{}), EventSticker: reflect.TypeOf(MessageEventContent{}), diff --git a/event/message.go b/event/message.go index 51403889..f16822f2 100644 --- a/event/message.go +++ b/event/message.go @@ -138,6 +138,8 @@ type MessageEventContent struct { BeeperLinkPreviews []*BeeperLinkPreview `json:"com.beeper.linkpreviews,omitempty"` + BeeperDisappearingTimer *BeeperDisappearingTimer `json:"com.beeper.disappearing_timer,omitempty"` + MSC1767Audio *MSC1767Audio `json:"org.matrix.msc1767.audio,omitempty"` MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"` } diff --git a/event/state.go b/event/state.go index de46c57d..66b06b14 100644 --- a/event/state.go +++ b/event/state.go @@ -10,6 +10,8 @@ import ( "encoding/base64" "slices" + "go.mau.fi/util/jsontime" + "maunium.net/go/mautrix/id" ) @@ -207,6 +209,20 @@ type BridgeEventContent struct { BeeperRoomTypeV2 string `json:"com.beeper.room_type.v2,omitempty"` } +// DisappearingType represents the type of a disappearing message timer. +type DisappearingType string + +const ( + DisappearingTypeNone DisappearingType = "" + DisappearingTypeAfterRead DisappearingType = "after_read" + DisappearingTypeAfterSend DisappearingType = "after_send" +) + +type BeeperDisappearingTimer struct { + Type DisappearingType `json:"type"` + Timer jsontime.Milliseconds `json:"timer"` +} + type SpaceChildEventContent struct { Via []string `json:"via,omitempty"` Order string `json:"order,omitempty"` diff --git a/event/type.go b/event/type.go index b097cfe1..35bf2669 100644 --- a/event/type.go +++ b/event/type.go @@ -112,7 +112,7 @@ func (et *Type) GuessClass() TypeClass { StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type, StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type, StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type, - StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type: + StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type: return StateEventType case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type: return EphemeralEventType @@ -202,6 +202,7 @@ var ( StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType} StateBeeperRoomFeatures = Type{"com.beeper.room_features", StateEventType} + StateBeeperDisappearingTimer = Type{"com.beeper.disappearing_timer", StateEventType} ) // Message events From 206071ec034f3e59e413417ad6f08deb38a441d2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 22 Aug 2025 18:38:38 +0300 Subject: [PATCH 1314/1647] federation/pdu: add redacted member event --- federation/pdu/hash_test.go | 6 ++++++ federation/pdu/pdu_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/federation/pdu/hash_test.go b/federation/pdu/hash_test.go index 35ea49df..17417e12 100644 --- a/federation/pdu/hash_test.go +++ b/federation/pdu/hash_test.go @@ -18,6 +18,9 @@ import ( func TestPDU_CalculateContentHash(t *testing.T) { for _, test := range testPDUs { + if test.redacted { + continue + } t.Run(test.name, func(t *testing.T) { parsed := parsePDU(test.pdu) contentHash := exerrors.Must(parsed.CalculateContentHash()) @@ -32,6 +35,9 @@ func TestPDU_CalculateContentHash(t *testing.T) { func TestPDU_VerifyContentHash(t *testing.T) { for _, test := range testPDUs { + if test.redacted { + continue + } t.Run(test.name, func(t *testing.T) { parsed := parsePDU(test.pdu) assert.True(t, parsed.VerifyContentHash()) diff --git a/federation/pdu/pdu_test.go b/federation/pdu/pdu_test.go index f4920098..93244741 100644 --- a/federation/pdu/pdu_test.go +++ b/federation/pdu/pdu_test.go @@ -71,12 +71,31 @@ var matrixOrg = serverDetails{ }, }, } +var continuwuityOrg = serverDetails{ + serverName: "continuwuity.org", + keys: map[id.KeyID]serverKey{ + "ed25519:PwHlNsFu": { + key: "8eNx2s0zWW+heKAmOH5zKv/nCPkEpraDJfGHxDu6hFI", + validUntilTS: time.Now(), + }, + }, +} +var novaAstraltechOrg = serverDetails{ + serverName: "nova.astraltech.org", + keys: map[id.KeyID]serverKey{ + "ed25519:a_afpo": { + key: "O1Y9GWuKo9xkuzuQef6gROxtTgxxAbS3WPNghPYXF3o", + validUntilTS: time.Now(), + }, + }, +} type testPDU struct { name string pdu string eventID id.EventID roomVersion id.RoomVersion + redacted bool serverDetails } @@ -156,6 +175,13 @@ var testPDUs = []testPDU{roomV4MessageTestPDU, { eventID: "$kAagtZAIEeZaLVCUSl74tAxQbdKbE22GU7FM-iAJBc0", roomVersion: id.RoomV4, serverDetails: mauniumNet, +}, { + name: "redacted m.room.member event in v11 room with 2 signatures", + pdu: `{"auth_events":["$9f12-_stoY07BOTmyguE1QlqvghLBh9Rk6PWRLoZn_M","$IP8hyjBkIDREVadyv0fPCGAW9IXGNllaZyxqQwiY_tA","$7dN5J8EveliaPkX6_QSejl4GQtem4oieavgALMeWZyE"],"content":{"membership":"join"},"depth":96978,"hashes":{"sha256":"APYA/aj3u+P0EwNaEofuSIlfqY3cK3lBz6RkwHX+Zak"},"origin_server_ts":1755664164485,"prev_events":["$XBN9W5Ll8VEH3eYqJaemxCBTDdy0hZB0sWpmyoUp93c"],"room_id":"!main-1:continuwuity.org","sender":"@6a19abdd4766:nova.astraltech.org","state_key":"@6a19abdd4766:nova.astraltech.org","type":"m.room.member","signatures":{"continuwuity.org":{"ed25519:PwHlNsFu":"+b/Fp2vWnC+Z2lI3GnCu7ZHdo3iWNDZ2AJqMoU9owMtLBPMxs4dVIsJXvaFq0ryawsgwDwKZ7f4xaFUNARJSDg"},"nova.astraltech.org":{"ed25519:a_afpo":"pXIngyxKukCPR7WOIIy8FTZxQ5L2dLiou5Oc8XS4WyY4YzJuckQzOaToigLLZxamfbN/jXbO+XUizpRpYccDAA"}},"unsigned":{}}`, + eventID: "$r6d9m125YWG28-Tln47bWtm6Jlv4mcSUWJTHijBlXLQ", + roomVersion: id.RoomV11, + serverDetails: novaAstraltechOrg, + redacted: true, }} func parsePDU(pdu string) (out *pdu.PDU) { From 35b805440f0bad84c35b77f0df1af436cab32e10 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 22 Aug 2025 19:37:53 +0300 Subject: [PATCH 1315/1647] federation/pdu: add auth event selection --- event/type.go | 3 +- federation/pdu/auth.go | 71 ++++++++++++++++++++++++++++++++++++++++++ federation/pdu/pdu.go | 1 + federation/pdu/v1.go | 25 +++++++++++++++ 4 files changed, 99 insertions(+), 1 deletion(-) create mode 100644 federation/pdu/auth.go diff --git a/event/type.go b/event/type.go index 35bf2669..1ab8c517 100644 --- a/event/type.go +++ b/event/type.go @@ -108,7 +108,7 @@ func (et *Type) IsCustom() bool { func (et *Type) GuessClass() TypeClass { switch et.Type { - case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type, + case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type, StateThirdPartyInvite.Type, StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type, StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type, StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type, @@ -177,6 +177,7 @@ var ( StateHistoryVisibility = Type{"m.room.history_visibility", StateEventType} StateGuestAccess = Type{"m.room.guest_access", StateEventType} StateMember = Type{"m.room.member", StateEventType} + StateThirdPartyInvite = Type{"m.room.third_party_invite", StateEventType} StatePowerLevels = Type{"m.room.power_levels", StateEventType} StateRoomName = Type{"m.room.name", StateEventType} StateTopic = Type{"m.room.topic", StateEventType} diff --git a/federation/pdu/auth.go b/federation/pdu/auth.go new file mode 100644 index 00000000..1f98de06 --- /dev/null +++ b/federation/pdu/auth.go @@ -0,0 +1,71 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package pdu + +import ( + "slices" + + "github.com/tidwall/gjson" + "go.mau.fi/util/exgjson" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type StateKey struct { + Type string + StateKey string +} + +var thirdPartyInviteTokenPath = exgjson.Path("third_party_invite", "signed", "token") + +type AuthEventSelection []StateKey + +func (aes *AuthEventSelection) Add(evtType, stateKey string) { + key := StateKey{Type: evtType, StateKey: stateKey} + if !aes.Has(key) { + *aes = append(*aes, key) + } +} + +func (aes *AuthEventSelection) Has(key StateKey) bool { + return slices.Contains(*aes, key) +} + +func (pdu *PDU) AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection) { + if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil { + return AuthEventSelection{} + } + keys = make(AuthEventSelection, 0, 3) + if !roomVersion.RoomIDIsCreateEventID() { + keys.Add(event.StateCreate.Type, "") + } + keys.Add(event.StatePowerLevels.Type, "") + keys.Add(event.StateMember.Type, pdu.Sender.String()) + if pdu.Type == event.StateMember.Type && pdu.StateKey != nil { + keys.Add(event.StateMember.Type, *pdu.StateKey) + membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str) + if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock { + keys.Add(event.StateJoinRules.Type, "") + } + if membership == event.MembershipInvite { + thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str + if thirdPartyInviteToken != "" { + keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken) + } + } + if membership == event.MembershipJoin && roomVersion.RestrictedJoins() { + authorizedVia := gjson.GetBytes(pdu.Content, "authorized_via_users_server").Str + if authorizedVia != "" { + keys.Add(event.StateMember.Type, authorizedVia) + } + } + } + return +} diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go index 860201c5..0e63ea7c 100644 --- a/federation/pdu/pdu.go +++ b/federation/pdu/pdu.go @@ -45,6 +45,7 @@ type AnyPDU interface { Sign(roomVersion id.RoomVersion, serverName string, keyID id.KeyID, privateKey ed25519.PrivateKey) error VerifySignature(roomVersion id.RoomVersion, serverName string, getKey GetKeyFunc) error ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) + AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSelection) } var ( diff --git a/federation/pdu/v1.go b/federation/pdu/v1.go index 1bc324ed..fc958b03 100644 --- a/federation/pdu/v1.go +++ b/federation/pdu/v1.go @@ -18,6 +18,7 @@ import ( "fmt" "time" + "github.com/tidwall/gjson" "go.mau.fi/util/ptr" "maunium.net/go/mautrix/event" @@ -250,3 +251,27 @@ func (pdu *RoomV1PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, e } return evt, nil } + +func (pdu *RoomV1PDU) AuthEventSelection(_ id.RoomVersion) (keys AuthEventSelection) { + if pdu.Type == event.StateCreate.Type && pdu.StateKey != nil { + return AuthEventSelection{} + } + keys = make(AuthEventSelection, 0, 3) + keys.Add(event.StateCreate.Type, "") + keys.Add(event.StatePowerLevels.Type, "") + keys.Add(event.StateMember.Type, pdu.Sender.String()) + if pdu.Type == event.StateMember.Type && pdu.StateKey != nil { + keys.Add(event.StateMember.Type, *pdu.StateKey) + membership := event.Membership(gjson.GetBytes(pdu.Content, "membership").Str) + if membership == event.MembershipJoin || membership == event.MembershipInvite || membership == event.MembershipKnock { + keys.Add(event.StateJoinRules.Type, "") + } + if membership == event.MembershipInvite { + thirdPartyInviteToken := gjson.GetBytes(pdu.Content, thirdPartyInviteTokenPath).Str + if thirdPartyInviteToken != "" { + keys.Add(event.StateThirdPartyInvite.Type, thirdPartyInviteToken) + } + } + } + return +} From fd20a61d87483e212ccefd76db32eb5782758891 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 23 Aug 2025 03:08:44 +0300 Subject: [PATCH 1316/1647] event: add json struct tag to third party signed object --- event/member.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/event/member.go b/event/member.go index 53387e8b..3e53893a 100644 --- a/event/member.go +++ b/event/member.go @@ -53,5 +53,5 @@ type ThirdPartyInvite struct { Token string `json:"token"` Signatures json.RawMessage `json:"signatures"` MXID string `json:"mxid"` - } + } `json:"signed"` } From 363aa943895876bff46f2cae3a310ace981058f5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 23 Aug 2025 03:13:10 +0300 Subject: [PATCH 1317/1647] federation/pdu: add server name parameter to GetKeyFunc --- federation/pdu/pdu.go | 2 +- federation/pdu/pdu_test.go | 5 ++++- federation/pdu/signature.go | 2 +- federation/pdu/signature_test.go | 10 +++++----- federation/pdu/v1.go | 2 +- federation/pdu/v1_test.go | 2 +- 6 files changed, 13 insertions(+), 10 deletions(-) diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go index 0e63ea7c..c6faf3d0 100644 --- a/federation/pdu/pdu.go +++ b/federation/pdu/pdu.go @@ -33,7 +33,7 @@ import ( // valid at or after this time, but if that is not possible, the latest available key should be // returned without an error. The verify function will do its own validity checking based on the // returned valid until timestamp. -type GetKeyFunc = func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) +type GetKeyFunc = func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) type AnyPDU interface { GetRoomID() (id.RoomID, error) diff --git a/federation/pdu/pdu_test.go b/federation/pdu/pdu_test.go index 93244741..59d7c3a6 100644 --- a/federation/pdu/pdu_test.go +++ b/federation/pdu/pdu_test.go @@ -28,7 +28,10 @@ type serverDetails struct { keys map[id.KeyID]serverKey } -func (sd serverDetails) getKey(keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { +func (sd serverDetails) getKey(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { + if serverName != sd.serverName { + return "", time.Time{}, nil + } key, ok := sd.keys[keyID] if ok { return key.key, key.validUntilTS, nil diff --git a/federation/pdu/signature.go b/federation/pdu/signature.go index 1f8ae0b5..a7685cc6 100644 --- a/federation/pdu/signature.go +++ b/federation/pdu/signature.go @@ -46,7 +46,7 @@ func (pdu *PDU) VerifySignature(roomVersion id.RoomVersion, serverName string, g verified := false for keyID, sig := range pdu.Signatures[serverName] { originServerTS := time.UnixMilli(pdu.OriginServerTS) - key, validUntil, err := getKey(keyID, originServerTS) + key, validUntil, err := getKey(serverName, keyID, originServerTS) if err != nil { return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err) } else if key == "" { diff --git a/federation/pdu/signature_test.go b/federation/pdu/signature_test.go index 68e7a773..01df5076 100644 --- a/federation/pdu/signature_test.go +++ b/federation/pdu/signature_test.go @@ -36,7 +36,7 @@ func TestPDU_VerifySignature(t *testing.T) { func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) { test := roomV12MessageTestPDU parsed := parsePDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { return }) assert.Error(t, err) @@ -45,7 +45,7 @@ func TestPDU_VerifySignature_Fail_NoKey(t *testing.T) { func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) { test := roomV4MessageTestPDU parsed := parsePDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { key = test.keys[keyID].key validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) return @@ -56,7 +56,7 @@ func TestPDU_VerifySignature_V4ExpiredKey(t *testing.T) { func TestPDU_VerifySignature_V12ExpiredKey(t *testing.T) { test := roomV12MessageTestPDU parsed := parsePDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { key = test.keys[keyID].key validUntil = time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) return @@ -90,8 +90,8 @@ func TestPDU_Sign(t *testing.T) { } err := evt.Sign(id.RoomV12, "example.com", "ed25519:rand", privKey) require.NoError(t, err) - err = evt.VerifySignature(id.RoomV11, "example.com", func(keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { - if keyID == "ed25519:rand" { + err = evt.VerifySignature(id.RoomV11, "example.com", func(serverName string, keyID id.KeyID, minValidUntil time.Time) (key id.SigningKey, validUntil time.Time, err error) { + if serverName == "example.com" && keyID == "ed25519:rand" { key = id.SigningKey(base64.RawStdEncoding.EncodeToString(pubKey)) validUntil = time.Now() } diff --git a/federation/pdu/v1.go b/federation/pdu/v1.go index fc958b03..0e4c95e9 100644 --- a/federation/pdu/v1.go +++ b/federation/pdu/v1.go @@ -202,7 +202,7 @@ func (pdu *RoomV1PDU) VerifySignature(roomVersion id.RoomVersion, serverName str verified := false for keyID, sig := range pdu.Signatures[serverName] { originServerTS := time.UnixMilli(pdu.OriginServerTS) - key, _, err := getKey(keyID, originServerTS) + key, _, err := getKey(serverName, keyID, originServerTS) if err != nil { return fmt.Errorf("failed to get key %s for %s: %w", keyID, serverName, err) } else if key == "" { diff --git a/federation/pdu/v1_test.go b/federation/pdu/v1_test.go index e5531b0b..ecf2dbd2 100644 --- a/federation/pdu/v1_test.go +++ b/federation/pdu/v1_test.go @@ -73,7 +73,7 @@ func TestRoomV1PDU_VerifySignature(t *testing.T) { for _, test := range testV1PDUs { t.Run(test.name, func(t *testing.T) { parsed := parseV1PDU(test.pdu) - err := parsed.VerifySignature(test.roomVersion, test.serverName, func(keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { + err := parsed.VerifySignature(test.roomVersion, test.serverName, func(serverName string, keyID id.KeyID, _ time.Time) (id.SigningKey, time.Time, error) { key, ok := test.keys[keyID] if ok { return key.key, key.validUntilTS, nil From 71bbbdb3c31e5b49534434eb16dca3d0af147c80 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 23 Aug 2025 23:22:43 +0300 Subject: [PATCH 1318/1647] federation/pdu: use jsontext.Value instead of any for deprecated fields --- federation/pdu/pdu.go | 6 +++--- federation/pdu/v1.go | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go index c6faf3d0..b5210550 100644 --- a/federation/pdu/pdu.go +++ b/federation/pdu/pdu.go @@ -71,9 +71,9 @@ type PDU struct { Unknown jsontext.Value `json:",unknown"` // Deprecated legacy fields - DeprecatedPrevState any `json:"prev_state,omitzero"` - DeprecatedOrigin any `json:"origin,omitzero"` - DeprecatedMembership any `json:"membership,omitzero"` + DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"` + DeprecatedOrigin jsontext.Value `json:"origin,omitzero"` + DeprecatedMembership jsontext.Value `json:"membership,omitzero"` } var ErrPDUIsNil = errors.New("PDU is nil") diff --git a/federation/pdu/v1.go b/federation/pdu/v1.go index 0e4c95e9..9557f8ab 100644 --- a/federation/pdu/v1.go +++ b/federation/pdu/v1.go @@ -75,9 +75,9 @@ type RoomV1PDU struct { Unknown jsontext.Value `json:",unknown"` // Deprecated legacy fields - DeprecatedPrevState any `json:"prev_state,omitzero"` - DeprecatedOrigin any `json:"origin,omitzero"` - DeprecatedMembership any `json:"membership,omitzero"` + DeprecatedPrevState jsontext.Value `json:"prev_state,omitzero"` + DeprecatedOrigin jsontext.Value `json:"origin,omitzero"` + DeprecatedMembership jsontext.Value `json:"membership,omitzero"` } func (pdu *RoomV1PDU) GetRoomID() (id.RoomID, error) { From d2cad8c57ec07ff8b0bce06e5485eeb0071e2e1b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 24 Aug 2025 00:44:50 +0300 Subject: [PATCH 1319/1647] format: add MarkdownMentionWithName helper --- format/markdown.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/format/markdown.go b/format/markdown.go index 3d9979b4..3b1c1f51 100644 --- a/format/markdown.go +++ b/format/markdown.go @@ -57,7 +57,11 @@ type uriAble interface { } func MarkdownMention(id uriAble) string { - return MarkdownLink(id.String(), id.URI().MatrixToURL()) + return MarkdownMentionWithName(id.String(), id) +} + +func MarkdownMentionWithName(name string, id uriAble) string { + return MarkdownLink(name, id.URI().MatrixToURL()) } func MarkdownLink(name string, url string) string { From 7e07700a69437cea25ee99dfd3e2f213bcd70f94 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 24 Aug 2025 00:47:55 +0300 Subject: [PATCH 1320/1647] format: add MarkdownMentionRoomID helper --- format/markdown.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/format/markdown.go b/format/markdown.go index 3b1c1f51..77ced0dc 100644 --- a/format/markdown.go +++ b/format/markdown.go @@ -64,6 +64,13 @@ func MarkdownMentionWithName(name string, id uriAble) string { return MarkdownLink(name, id.URI().MatrixToURL()) } +func MarkdownMentionRoomID(name string, id id.RoomID, via ...string) string { + if name == "" { + name = id.String() + } + return MarkdownLink(name, id.URI(via...).MatrixToURL()) +} + func MarkdownLink(name string, url string) string { return fmt.Sprintf("[%s](%s)", EscapeMarkdown(name), EscapeMarkdown(url)) } From fa7c1ae2bcd716f29a823a4167ec6a0c9206a5d2 Mon Sep 17 00:00:00 2001 From: Brad Murray Date: Mon, 25 Aug 2025 08:03:13 -0400 Subject: [PATCH 1321/1647] crypto/sqlstore: add index to make finding megolm sessions to backup faster (#402) ``` 2025-08-24T22:23:19Z debug [MatrixBridgeV2] {"level":"warn","component":"matrix","component":"client_loop","subcomponent":"sync_key_backup_loop","rows":0,"duration_seconds":1.046191042,"method":"EndRows","query":"SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE account_id=?1 AND session IS NOT NULL AND key_backup_version != ?2","time":"2025-08-24T22:23:19.22077Z","message":"Query took long"} ``` before: ``` sqlite> EXPLAIN SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE account_id='@brad:beeper.com/CHNWOJWEUC' AND sessi addr opcode p1 p2 p3 p4 p5 comment ---- ------------- ---- ---- ---- ------------- -- ------------- 0 Init 0 25 0 0 Start at 25 1 OpenRead 0 48 0 15 0 root=48 iDb=0; crypto_megolm_inbound_session 2 OpenRead 1 49 0 k(3,,,) 2 root=49 iDb=0; sqlite_autoindex_crypto_megolm_inbound_session_1 3 String8 0 1 0 @brad:beeper.com/CHNWOJWEUC 0 r[1]='@brad:beeper.com/CHNWOJWEUC' 4 SeekGE 1 24 1 1 0 key=r[1] 5 IdxGT 1 24 1 1 0 key=r[1] 6 DeferredSeek 1 0 0 0 Move 0 to 1.rowid if needed 7 Column 0 5 2 128 r[2]= cursor 0 column 5 8 IsNull 2 23 0 0 if r[2]==NULL goto 23 9 Column 0 14 2 0 r[2]=crypto_megolm_inbound_session.key_backup_version 10 Eq 3 23 2 BINARY-8 82 if r[2]==r[3] goto 23 11 Column 0 4 4 0 r[4]= cursor 0 column 4 12 Column 0 2 5 0 r[5]= cursor 0 column 2 13 Column 0 3 6 0 r[6]= cursor 0 column 3 14 Column 0 5 7 0 r[7]= cursor 0 column 5 15 Column 0 6 8 0 r[8]= cursor 0 column 6 16 Column 0 9 9 0 r[9]= cursor 0 column 9 17 Column 0 10 10 0 r[10]= cursor 0 column 10 18 Column 0 11 11 0 r[11]= cursor 0 column 11 19 Column 0 12 12 0 r[12]= cursor 0 column 12 20 Column 0 13 13 0 0 r[13]=crypto_megolm_inbound_session.is_scheduled 21 Column 0 14 14 0 r[14]=crypto_megolm_inbound_session.key_backup_version 22 ResultRow 4 11 0 0 output=r[4..14] 23 Next 1 5 0 0 24 Halt 0 0 0 0 25 Transaction 0 0 55 0 1 usesStmtJournal=0 26 Integer 1 3 0 0 r[3]=1 27 Goto 0 1 0 0 sqlite> SELECT COUNT(*) FROM crypto_megolm_inbound_session ; +----------+ | COUNT(*) | +----------+ | 168792 | +----------+ sqlite> SELECT COUNT(*) FROM crypto_megolm_inbound_session WHERE session IS NULL; +----------+ | COUNT(*) | +----------+ | 39 | +----------+ sqlite> SELECT COUNT(*) FROM crypto_megolm_inbound_session WHERE key_backup_version != 1; +----------+ | COUNT(*) | +----------+ | 39 | +----------+ ``` after: ``` sqlite> CREATE INDEX idx_megolm_filtered ...> ON crypto_megolm_inbound_session(account_id, key_backup_version, session); sqlite> EXPLAIN SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version FROM crypto_megolm_inbound_session WHERE account_id='@brad:beeper.com/CHNWOJWEUC' AND session IS NOT NULL AND key_backup_version != 1; addr opcode p1 p2 p3 p4 p5 comment ---- ------------- ---- ---- ---- ------------- -- ------------- 0 Init 0 25 0 0 Start at 25 1 OpenRead 0 48 0 15 0 root=48 iDb=0; crypto_megolm_inbound_session 2 OpenRead 1 91264 0 k(4,,,,) 2 root=91264 iDb=0; idx_megolm_filtered 3 String8 0 1 0 @brad:beeper.com/CHNWOJWEUC 0 r[1]='@brad:beeper.com/CHNWOJWEUC' 4 SeekGE 1 24 1 1 0 key=r[1] 5 IdxGT 1 24 1 1 0 key=r[1] 6 DeferredSeek 1 0 0 0 Move 0 to 1.rowid if needed 7 Column 1 2 2 128 r[2]= cursor 1 column 2 8 IsNull 2 23 0 0 if r[2]==NULL goto 23 9 Column 1 1 2 0 r[2]=crypto_megolm_inbound_session.key_backup_version 10 Eq 3 23 2 BINARY-8 82 if r[2]==r[3] goto 23 11 Column 0 4 4 0 r[4]= cursor 0 column 4 12 Column 0 2 5 0 r[5]= cursor 0 column 2 13 Column 0 3 6 0 r[6]= cursor 0 column 3 14 Column 1 2 7 0 r[7]= cursor 1 column 2 15 Column 0 6 8 0 r[8]= cursor 0 column 6 16 Column 0 9 9 0 r[9]= cursor 0 column 9 17 Column 0 10 10 0 r[10]= cursor 0 column 10 18 Column 0 11 11 0 r[11]= cursor 0 column 11 19 Column 0 12 12 0 r[12]= cursor 0 column 12 20 Column 0 13 13 0 0 r[13]=crypto_megolm_inbound_session.is_scheduled 21 Column 1 1 14 0 r[14]=crypto_megolm_inbound_session.key_backup_version 22 ResultRow 4 11 0 0 output=r[4..14] 23 Next 1 5 0 0 24 Halt 0 0 0 0 25 Transaction 0 0 56 0 1 usesStmtJournal=0 26 Integer 1 3 0 0 r[3]=1 27 Goto 0 1 0 0 sqlite> ``` --- crypto/sql_store_upgrade/00-latest-revision.sql | 4 +++- .../18-megolm-inbound-session-backup-index.sql | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql index 00dd1387..af8ab5cc 100644 --- a/crypto/sql_store_upgrade/00-latest-revision.sql +++ b/crypto/sql_store_upgrade/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v17 (compatible with v15+): Latest revision +-- v0 -> v18 (compatible with v15+): Latest revision CREATE TABLE IF NOT EXISTS crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, @@ -73,6 +73,8 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( key_backup_version TEXT NOT NULL DEFAULT '', PRIMARY KEY (account_id, session_id) ); +-- Useful index to find keys that need backing up +CREATE INDEX crypto_megolm_inbound_session_backup_idx ON crypto_megolm_inbound_session(account_id, key_backup_version) WHERE session IS NOT NULL; CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session ( account_id TEXT, diff --git a/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql b/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql new file mode 100644 index 00000000..da26da0f --- /dev/null +++ b/crypto/sql_store_upgrade/18-megolm-inbound-session-backup-index.sql @@ -0,0 +1,2 @@ +-- v18 (compatible with v15+): Add an index to the megolm_inbound_session table to make finding sessions to backup faster +CREATE INDEX crypto_megolm_inbound_session_backup_idx ON crypto_megolm_inbound_session(account_id, key_backup_version) WHERE session IS NOT NULL; From c04d0b66819b5ca54425140c1bd77c100262a9c7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 25 Aug 2025 12:58:28 +0300 Subject: [PATCH 1322/1647] bridgev2: merge mentions and url previews when merging caption --- bridgev2/networkinterface.go | 4 ++++ event/message.go | 12 ++++++++++++ 2 files changed, 16 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index eb38bd2d..d792ed0d 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -117,11 +117,15 @@ func MergeCaption(textPart, mediaPart *ConvertedMessagePart) *ConvertedMessagePa mediaPart.Content.EnsureHasHTML() mediaPart.Content.Body += "\n\n" + textPart.Content.Body mediaPart.Content.FormattedBody += "

      " + textPart.Content.FormattedBody + mediaPart.Content.Mentions = mediaPart.Content.Mentions.Merge(textPart.Content.Mentions) + mediaPart.Content.BeeperLinkPreviews = append(mediaPart.Content.BeeperLinkPreviews, textPart.Content.BeeperLinkPreviews...) } else { mediaPart.Content.FileName = mediaPart.Content.Body mediaPart.Content.Body = textPart.Content.Body mediaPart.Content.Format = textPart.Content.Format mediaPart.Content.FormattedBody = textPart.Content.FormattedBody + mediaPart.Content.Mentions = textPart.Content.Mentions + mediaPart.Content.BeeperLinkPreviews = textPart.Content.BeeperLinkPreviews } if metaMerger, ok := mediaPart.DBMetadata.(database.MetaMerger); ok { metaMerger.CopyFrom(textPart.DBMetadata) diff --git a/event/message.go b/event/message.go index f16822f2..cc7c8261 100644 --- a/event/message.go +++ b/event/message.go @@ -273,6 +273,18 @@ func (m *Mentions) Has(userID id.UserID) bool { return m != nil && slices.Contains(m.UserIDs, userID) } +func (m *Mentions) Merge(other *Mentions) *Mentions { + if m == nil { + return other + } else if other == nil { + return m + } + return &Mentions{ + UserIDs: slices.Concat(m.UserIDs, other.UserIDs), + Room: m.Room || other.Room, + } +} + type EncryptedFileInfo struct { attachment.EncryptedFile URL id.ContentURIString `json:"url"` From 0fab92dbc1cafb65688d02273c3553c01cd0dc4f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 25 Aug 2025 12:58:40 +0300 Subject: [PATCH 1323/1647] event: add third party invite state event content --- event/content.go | 1 + event/member.go | 30 +++++++++++++++++++++--------- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/event/content.go b/event/content.go index 779330af..5924ffe3 100644 --- a/event/content.go +++ b/event/content.go @@ -18,6 +18,7 @@ import ( // This is used by Content.ParseRaw() for creating the correct type of struct. var TypeMap = map[Type]reflect.Type{ StateMember: reflect.TypeOf(MemberEventContent{}), + StateThirdPartyInvite: reflect.TypeOf(ThirdPartyInviteEventContent{}), StatePowerLevels: reflect.TypeOf(PowerLevelsEventContent{}), StateCanonicalAlias: reflect.TypeOf(CanonicalAliasEventContent{}), StateRoomName: reflect.TypeOf(RoomNameEventContent{}), diff --git a/event/member.go b/event/member.go index 3e53893a..9956a36b 100644 --- a/event/member.go +++ b/event/member.go @@ -7,8 +7,6 @@ package event import ( - "encoding/json" - "maunium.net/go/mautrix/id" ) @@ -47,11 +45,25 @@ type MemberEventContent struct { MSC4293RedactEvents bool `json:"org.matrix.msc4293.redact_events,omitempty"` } -type ThirdPartyInvite struct { - DisplayName string `json:"display_name"` - Signed struct { - Token string `json:"token"` - Signatures json.RawMessage `json:"signatures"` - MXID string `json:"mxid"` - } `json:"signed"` +type SignedThirdPartyInvite struct { + Token string `json:"token"` + Signatures map[string]map[id.KeyID]string `json:"signatures,omitempty"` + MXID string `json:"mxid"` +} + +type ThirdPartyInvite struct { + DisplayName string `json:"display_name"` + Signed SignedThirdPartyInvite `json:"signed"` +} + +type ThirdPartyInviteEventContent struct { + DisplayName string `json:"display_name"` + KeyValidityURL string `json:"key_validity_url"` + PublicKey id.Ed25519 `json:"public_key"` + PublicKeys []ThirdPartyInviteKey `json:"public_keys,omitempty"` +} + +type ThirdPartyInviteKey struct { + KeyValidityURL string `json:"key_validity_url,omitempty"` + PublicKey id.Ed25519 `json:"public_key"` } From 5ac8a888a3a5b11165172a4bab0e65c74ade1737 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 25 Aug 2025 17:15:19 +0300 Subject: [PATCH 1324/1647] bridgev2/portal: make UpdateDisappearingSetting more versatile --- bridgev2/portal.go | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 7c3a56c2..0aae674d 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -4038,7 +4038,15 @@ func DisappearingMessageNotice(expiration time.Duration, implicit bool) *event.M return content } -func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting database.DisappearingSetting, sender MatrixAPI, ts time.Time, implicit, save bool) bool { +type UpdateDisappearingSettingOpts struct { + Sender MatrixAPI + Timestamp time.Time + Implicit bool + Save bool + SendNotice bool +} + +func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting database.DisappearingSetting, opts UpdateDisappearingSettingOpts) bool { if setting.Timer == 0 { setting.Type = event.DisappearingTypeNone } @@ -4047,7 +4055,7 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat } portal.Disappear.Type = setting.Type portal.Disappear.Timer = setting.Timer - if save { + if opts.Save { err := portal.Save(ctx) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal to database after updating disappearing setting") @@ -4057,21 +4065,21 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat return true } - portal.sendRoomMeta(ctx, sender, ts, event.StateBeeperDisappearingTimer, "", setting.ToEventContent()) - - content := DisappearingMessageNotice(setting.Timer, implicit) - if sender == nil { - sender = portal.Bridge.Bot + if opts.Sender == nil { + opts.Sender = portal.Bridge.Bot } - _, err := sender.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{ + portal.sendRoomMeta(ctx, opts.Sender, opts.Timestamp, event.StateBeeperDisappearingTimer, "", setting.ToEventContent()) + + content := DisappearingMessageNotice(setting.Timer, opts.Implicit) + _, err := opts.Sender.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{ Parsed: content, - }, &MatrixSendExtra{Timestamp: ts}) + }, &MatrixSendExtra{Timestamp: opts.Timestamp}) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to send disappearing messages notice") } else { zerolog.Ctx(ctx).Debug(). Dur("new_timer", portal.Disappear.Timer). - Bool("implicit", implicit). + Bool("implicit", opts.Implicit). Msg("Sent disappearing messages notice") } return true @@ -4162,7 +4170,13 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us changed = portal.updateAvatar(ctx, info.Avatar, sender, ts) || changed } if info.Disappear != nil { - changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, sender, ts, false, false) || changed + changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, UpdateDisappearingSettingOpts{ + Sender: sender, + Timestamp: ts, + Implicit: false, + Save: false, + SendNotice: true, + }) || changed } if info.ParentID != nil { changed = portal.updateParent(ctx, *info.ParentID, source) || changed From 8e703410f48ca9c94b2288c6608a4d8f5c39ff3c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 25 Aug 2025 17:21:55 +0300 Subject: [PATCH 1325/1647] bridgev2/portal: always set timestamp for disappearing message timer update --- bridgev2/portal.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 0aae674d..e523b7bd 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -4068,6 +4068,9 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat if opts.Sender == nil { opts.Sender = portal.Bridge.Bot } + if opts.Timestamp.IsZero() { + opts.Timestamp = time.Now() + } portal.sendRoomMeta(ctx, opts.Sender, opts.Timestamp, event.StateBeeperDisappearingTimer, "", setting.ToEventContent()) content := DisappearingMessageNotice(setting.Timer, opts.Implicit) From f860b0e2386ae669fb2b9829e46dc9d874c8c599 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 25 Aug 2025 17:23:25 +0300 Subject: [PATCH 1326/1647] bridgev2/portal: fix send notice option when updating disappearing message timer --- bridgev2/portal.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index e523b7bd..f0d2d0a1 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -4073,6 +4073,9 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat } portal.sendRoomMeta(ctx, opts.Sender, opts.Timestamp, event.StateBeeperDisappearingTimer, "", setting.ToEventContent()) + if !opts.SendNotice { + return true + } content := DisappearingMessageNotice(setting.Timer, opts.Implicit) _, err := opts.Sender.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{ Parsed: content, From a6bbe978bd5520b3722518a54ff2c60d7588beeb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 25 Aug 2025 17:31:12 +0300 Subject: [PATCH 1327/1647] bridgev2/networkinterface: add interface for handling disappearing timer changes from Matrix --- bridgev2/matrix/connector.go | 1 + bridgev2/networkinterface.go | 9 +++++++++ bridgev2/portal.go | 15 ++++++++++++++- 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 19eb399b..c5ee40fe 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -147,6 +147,7 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.EventProcessor.On(event.StateRoomAvatar, br.handleRoomEvent) br.EventProcessor.On(event.StateTopic, br.handleRoomEvent) br.EventProcessor.On(event.StateTombstone, br.handleRoomEvent) + br.EventProcessor.On(event.StateBeeperDisappearingTimer, br.handleRoomEvent) br.EventProcessor.On(event.EphemeralEventReceipt, br.handleEphemeralEvent) br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent) br.Bot = br.AS.BotIntent() diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index d792ed0d..dcbcbad5 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -683,6 +683,14 @@ type RoomTopicHandlingNetworkAPI interface { HandleMatrixRoomTopic(ctx context.Context, msg *MatrixRoomTopic) (bool, error) } +type DisappearTimerChangingNetworkAPI interface { + NetworkAPI + // HandleMatrixDisappearingTimer is called when the disappearing timer of a portal room is changed. + // This method should update the Disappear field of the Portal with the new timer and return true + // if the change was successful. If the change is not successful, then the field should not be updated. + HandleMatrixDisappearingTimer(ctx context.Context, msg *MatrixDisappearingTimer) (bool, error) +} + type ResolveIdentifierResponse struct { // Ghost is the ghost of the user that the identifier resolves to. // This field should be set whenever possible. However, it is not required, @@ -1270,6 +1278,7 @@ type MatrixRoomMeta[ContentType any] struct { type MatrixRoomName = MatrixRoomMeta[*event.RoomNameEventContent] type MatrixRoomAvatar = MatrixRoomMeta[*event.RoomAvatarEventContent] type MatrixRoomTopic = MatrixRoomMeta[*event.TopicEventContent] +type MatrixDisappearingTimer = MatrixRoomMeta[*event.BeeperDisappearingTimer] type MatrixReadReceipt struct { Portal *Portal diff --git a/bridgev2/portal.go b/bridgev2/portal.go index f0d2d0a1..24365df9 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -663,6 +663,8 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic) case event.StateRoomAvatar: return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar) + case event.StateBeeperDisappearingTimer: + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, DisappearTimerChangingNetworkAPI.HandleMatrixDisappearingTimer) case event.StateEncryption: // TODO? return EventHandlingResultIgnored @@ -1477,6 +1479,15 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( portal.sendSuccessStatus(ctx, evt, 0, "") return EventHandlingResultIgnored } + case *event.BeeperDisappearingTimer: + if typedContent.Type == event.DisappearingTypeNone || typedContent.Timer.Duration <= 0 { + typedContent.Type = event.DisappearingTypeNone + typedContent.Timer.Duration = 0 + } + if typedContent.Type == portal.Disappear.Type && typedContent.Timer.Duration == portal.Disappear.Timer { + portal.sendSuccessStatus(ctx, evt, 0, "") + return EventHandlingResultIgnored + } } var prevContent ContentType if evt.Unsigned.PrevContent != nil { @@ -1500,7 +1511,9 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( return EventHandlingResultFailed.WithMSSError(err) } if changed { - portal.UpdateBridgeInfo(ctx) + if evt.Type != event.StateBeeperDisappearingTimer { + portal.UpdateBridgeInfo(ctx) + } err = portal.Save(ctx) if err != nil { log.Err(err).Msg("Failed to save portal after updating room metadata") From 4f7c7dafdc6af65fcada0c881c669405e374cf3f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 25 Aug 2025 17:42:20 +0300 Subject: [PATCH 1328/1647] bridgev2/matrix: fix encryption error notice not being redacted after retry success --- bridgev2/matrix/matrix.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index 49c377db..64165941 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -87,17 +87,18 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) decryptionStart := time.Now() decrypted, err := br.Crypto.Decrypt(ctx, evt) decryptionRetryCount := 0 + var errorEventID id.EventID if errors.Is(err, NoSessionFound) { decryptionRetryCount = 1 log.Debug(). Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). Msg("Couldn't find session, waiting for keys to arrive...") - go br.sendCryptoStatusError(ctx, evt, err, nil, 0, false) + go br.sendCryptoStatusError(ctx, evt, err, &errorEventID, 0, false) if br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { log.Debug().Msg("Got keys after waiting, trying to decrypt event again") decrypted, err = br.Crypto.Decrypt(ctx, evt) } else { - go br.waitLongerForSession(ctx, evt, decryptionStart) + go br.waitLongerForSession(ctx, evt, decryptionStart, &errorEventID) return } } @@ -106,10 +107,10 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) go br.sendCryptoStatusError(ctx, evt, err, nil, decryptionRetryCount, true) return } - br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, nil, time.Since(decryptionStart)) + br.postDecrypt(ctx, evt, decrypted, decryptionRetryCount, &errorEventID, time.Since(decryptionStart)) } -func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time) { +func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, decryptionStart time.Time, errorEventID *id.EventID) { log := zerolog.Ctx(ctx) content := evt.Content.AsEncrypted() log.Debug(). @@ -117,7 +118,6 @@ func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, Msg("Couldn't find session, requesting keys and waiting longer...") go br.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) - var errorEventID *id.EventID go br.sendCryptoStatusError(ctx, evt, fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), errorEventID, 1, false) if !br.Crypto.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { From bca8b0528c976aca0e1df4d284918dc60c1c01a4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 25 Aug 2025 18:27:49 +0300 Subject: [PATCH 1329/1647] sqlstatestore: fix GetPowerLevels returning non-nil even if power levels weren't found --- sqlstatestore/statestore.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index 0ed4b698..c4126802 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -370,7 +370,7 @@ func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.Ro func (store *SQLStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) { var data []byte err := store. - QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID). + QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1 AND encryption IS NOT NULL", roomID). Scan(&data) if errors.Is(err, sql.ErrNoRows) { return nil, nil @@ -406,7 +406,7 @@ func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID func (store *SQLStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) { levels = &event.PowerLevelsEventContent{} err = store. - QueryRow(ctx, "SELECT power_levels, create_event FROM mx_room_state WHERE room_id=$1", roomID). + QueryRow(ctx, "SELECT power_levels, create_event FROM mx_room_state WHERE room_id=$1 AND power_levels IS NOT NULL", roomID). Scan(&dbutil.JSON{Data: &levels}, &dbutil.JSON{Data: &levels.CreateEvent}) if errors.Is(err, sql.ErrNoRows) { return nil, nil @@ -458,7 +458,7 @@ func (store *SQLStateStore) SetCreate(ctx context.Context, evt *event.Event) err func (store *SQLStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (evt *event.Event, err error) { err = store. - QueryRow(ctx, "SELECT create_event FROM mx_room_state WHERE room_id=$1", roomID). + QueryRow(ctx, "SELECT create_event FROM mx_room_state WHERE room_id=$1 AND create_event IS NOT NULL", roomID). Scan(&dbutil.JSON{Data: &evt}) if errors.Is(err, sql.ErrNoRows) { return nil, nil From c3a422347ce3da25637a6f1239e676db530df01d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 25 Aug 2025 18:36:03 +0300 Subject: [PATCH 1330/1647] bridgev2/portal: validate capabilities when updating disappearing timer --- bridgev2/errors.go | 2 ++ bridgev2/portal.go | 7 +++++++ event/capabilities.go | 7 +++++++ 3 files changed, 16 insertions(+) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index c023dcdf..026a95f4 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -67,6 +67,8 @@ var ( ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) + + ErrDisappearingTimerUnsupported error = WrapErrorInStatus(errors.New("invalid disappearing timer")).WithIsCertain(true) ) // Common login interface errors diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 24365df9..5e0a9137 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1488,6 +1488,10 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( portal.sendSuccessStatus(ctx, evt, 0, "") return EventHandlingResultIgnored } + if !sender.Client.GetCapabilities(ctx, portal).DisappearingTimer.Supports(typedContent) { + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent()) + return EventHandlingResultFailed.WithMSSError(ErrDisappearingTimerUnsupported) + } } var prevContent ContentType if evt.Unsigned.PrevContent != nil { @@ -1508,6 +1512,9 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( }) if err != nil { log.Err(err).Msg("Failed to handle Matrix room metadata") + if evt.Type == event.StateBeeperDisappearingTimer { + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent()) + } return EventHandlingResultFailed.WithMSSError(err) } if changed { diff --git a/event/capabilities.go b/event/capabilities.go index f44d6600..20f87bce 100644 --- a/event/capabilities.go +++ b/event/capabilities.go @@ -76,6 +76,13 @@ type DisappearingTimerCapability struct { OmitEmptyTimer bool `json:"omit_empty_timer,omitempty"` } +func (dtc *DisappearingTimerCapability) Supports(content *BeeperDisappearingTimer) bool { + if dtc == nil || content.Type == DisappearingTypeNone { + return true + } + return slices.Contains(dtc.Types, content.Type) && slices.Contains(dtc.Timers, content.Timer) +} + type CapabilityMsgType = MessageType // Message types which are used for event capability signaling, but aren't real values for the msgtype field. From 63b654187d40ed2538ad16e48e3b16fe2280df63 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 25 Aug 2025 19:03:07 +0300 Subject: [PATCH 1331/1647] event: marshal zero disappearing timers as empty object --- event/state.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/event/state.go b/event/state.go index 66b06b14..8711f857 100644 --- a/event/state.go +++ b/event/state.go @@ -8,6 +8,7 @@ package event import ( "encoding/base64" + "encoding/json" "slices" "go.mau.fi/util/jsontime" @@ -223,6 +224,15 @@ type BeeperDisappearingTimer struct { Timer jsontime.Milliseconds `json:"timer"` } +type marshalableBeeperDisappearingTimer BeeperDisappearingTimer + +func (bdt *BeeperDisappearingTimer) MarshalJSON() ([]byte, error) { + if bdt == nil || bdt.Type == DisappearingTypeNone { + return []byte("{}"), nil + } + return json.Marshal((*marshalableBeeperDisappearingTimer)(bdt)) +} + type SpaceChildEventContent struct { Via []string `json:"via,omitempty"` Order string `json:"order,omitempty"` From e9d4eeb33266ec3cc7ae9dfe0d854dd2bef8ae7c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 26 Aug 2025 15:56:27 +0300 Subject: [PATCH 1332/1647] bridgev2/status: add avatar_keys to remote profile --- bridgev2/status/bridgestate.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bridgev2/status/bridgestate.go b/bridgev2/status/bridgestate.go index 01a235a0..671303e0 100644 --- a/bridgev2/status/bridgestate.go +++ b/bridgev2/status/bridgestate.go @@ -22,6 +22,7 @@ import ( "go.mau.fi/util/ptr" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto/attachment" "maunium.net/go/mautrix/id" ) @@ -87,6 +88,9 @@ type RemoteProfile struct { Username string `json:"username,omitempty"` Name string `json:"name,omitempty"` Avatar id.ContentURIString `json:"avatar,omitempty"` + + // Only used for backups of local bridge states + AvatarKeys *attachment.EncryptedFile `json:"avatar_keys,omitempty"` } func coalesce[T ~string](a, b T) T { @@ -102,6 +106,9 @@ func (rp *RemoteProfile) Merge(other RemoteProfile) RemoteProfile { other.Username = coalesce(rp.Username, other.Username) other.Name = coalesce(rp.Name, other.Name) other.Avatar = coalesce(rp.Avatar, other.Avatar) + if rp.AvatarKeys != nil { + other.AvatarKeys = rp.AvatarKeys + } return other } From 7b3a60742eb0f0b7e7a3a003578fa91cf4a787ad Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 26 Aug 2025 15:56:38 +0300 Subject: [PATCH 1333/1647] event: allow omitting timers from disappearing timer capability --- event/capabilities.d.ts | 3 ++- event/capabilities.go | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/event/capabilities.d.ts b/event/capabilities.d.ts index 7f1dce05..27164a5f 100644 --- a/event/capabilities.d.ts +++ b/event/capabilities.d.ts @@ -117,7 +117,8 @@ export enum DisappearingType { export interface DisappearingTimerCapability { types: DisappearingType[] - timers: milliseconds[] + /** Allowed timer values. If omitted, any timer is allowed. */ + timers?: milliseconds[] /** * Whether clients should omit the empty disappearing_timer object in messages that they don't want to disappear * diff --git a/event/capabilities.go b/event/capabilities.go index 20f87bce..ebedb6a2 100644 --- a/event/capabilities.go +++ b/event/capabilities.go @@ -71,7 +71,7 @@ type FileFeatureMap map[CapabilityMsgType]*FileFeatures type DisappearingTimerCapability struct { Types []DisappearingType `json:"types"` - Timers []jsontime.Milliseconds `json:"timers"` + Timers []jsontime.Milliseconds `json:"timers,omitempty"` OmitEmptyTimer bool `json:"omit_empty_timer,omitempty"` } @@ -80,7 +80,7 @@ func (dtc *DisappearingTimerCapability) Supports(content *BeeperDisappearingTime if dtc == nil || content.Type == DisappearingTypeNone { return true } - return slices.Contains(dtc.Types, content.Type) && slices.Contains(dtc.Timers, content.Timer) + return slices.Contains(dtc.Types, content.Type) && (dtc.Timers == nil || slices.Contains(dtc.Timers, content.Timer)) } type CapabilityMsgType = MessageType From 0345a5356de17d1006b462cd633d6d25be82a1c1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 26 Aug 2025 17:07:16 +0300 Subject: [PATCH 1334/1647] bridgev2/database: don't set disappearing timer content to nil --- bridgev2/database/disappear.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/database/disappear.go b/bridgev2/database/disappear.go index e830cb14..537d0552 100644 --- a/bridgev2/database/disappear.go +++ b/bridgev2/database/disappear.go @@ -39,7 +39,7 @@ type DisappearingSetting struct { func (ds DisappearingSetting) ToEventContent() *event.BeeperDisappearingTimer { if ds.Type == event.DisappearingTypeNone || ds.Timer == 0 { - return nil + return &event.BeeperDisappearingTimer{} } return &event.BeeperDisappearingTimer{ Type: ds.Type, From ba16c30a8cd11d91086f6efa77014c97a4546761 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 27 Aug 2025 00:45:33 +0200 Subject: [PATCH 1335/1647] federation/eventauth: add v3-v12 event auth rules (#401) --- federation/eventauth/eventauth.go | 838 ++++++++++++++++++ federation/eventauth/eventauth_test.go | 85 ++ .../eventauth/testroom-v12-success.jsonl | 17 + federation/pdu/pdu.go | 5 + federation/signutil/verify.go | 41 + 5 files changed, 986 insertions(+) create mode 100644 federation/eventauth/eventauth.go create mode 100644 federation/eventauth/eventauth_test.go create mode 100644 federation/eventauth/testroom-v12-success.jsonl diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go new file mode 100644 index 00000000..d4a50969 --- /dev/null +++ b/federation/eventauth/eventauth.go @@ -0,0 +1,838 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package eventauth + +import ( + "encoding/json" + "encoding/json/jsontext" + "errors" + "fmt" + "slices" + "strconv" + "strings" + + "github.com/tidwall/gjson" + "go.mau.fi/util/exgjson" + "go.mau.fi/util/exstrings" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/federation/pdu" + "maunium.net/go/mautrix/federation/signutil" + "maunium.net/go/mautrix/id" +) + +type AuthFailError struct { + Index string + Message string + Wrapped error +} + +func (afe AuthFailError) Error() string { + if afe.Message != "" { + return fmt.Sprintf("fail %s: %s", afe.Index, afe.Message) + } else if afe.Wrapped != nil { + return fmt.Sprintf("fail %s: %s", afe.Index, afe.Wrapped.Error()) + } + return fmt.Sprintf("fail %s", afe.Index) +} + +func (afe AuthFailError) Unwrap() error { + return afe.Wrapped +} + +var mFederatePath = exgjson.Path("m.federate") + +var ( + ErrCreateHasPrevEvents = AuthFailError{Index: "1.1", Message: "m.room.create event has prev_events"} + ErrCreateHasRoomID = AuthFailError{Index: "1.2", Message: "m.room.create event has room_id set"} + ErrRoomIDDoesntMatchSender = AuthFailError{Index: "1.2", Message: "room ID server doesn't match sender server"} + ErrUnknownRoomVersion = AuthFailError{Index: "1.3", Wrapped: id.ErrUnknownRoomVersion} + ErrInvalidAdditionalCreators = AuthFailError{Index: "1.4", Message: "m.room.create event has invalid additional_creators"} + ErrMissingCreator = AuthFailError{Index: "1.4", Message: "m.room.create event is missing creator field"} + + ErrInvalidRoomIDLength = AuthFailError{Index: "2", Message: "room ID length is invalid"} + ErrFailedToGetCreateEvent = AuthFailError{Index: "2", Message: "failed to get m.room.create event"} + ErrCreateEventNotFound = AuthFailError{Index: "2", Message: "m.room.create event not found using room ID as event ID"} + ErrRejectedCreateEvent = AuthFailError{Index: "2", Message: "m.room.create event was rejected"} + + ErrFailedToGetAuthEvents = AuthFailError{Index: "3", Message: "failed to get auth events"} + ErrFailedToParsePowerLevels = AuthFailError{Index: "?", Message: "failed to parse power levels"} + ErrDuplicateAuthEvent = AuthFailError{Index: "3.1", Message: "duplicate type/state key pair in auth events"} + ErrNonStateAuthEvent = AuthFailError{Index: "3.2", Message: "non-state event in auth events"} + ErrMissingAuthEvent = AuthFailError{Index: "3.2", Message: "missing auth event"} + ErrUnexpectedAuthEvent = AuthFailError{Index: "3.2", Message: "unexpected type/state key pair in auth events"} + ErrNoCreateEvent = AuthFailError{Index: "3.2", Message: "no m.room.create event found in auth events"} + ErrRejectedAuthEvent = AuthFailError{Index: "3.3", Message: "auth event was rejected"} + ErrMismatchingRoomIDInAuthEvent = AuthFailError{Index: "3.4", Message: "auth event room ID does not match event room ID"} + + ErrFederationDisabled = AuthFailError{Index: "4", Message: "federation is disabled for this room"} + + ErrMemberNotState = AuthFailError{Index: "5.1", Message: "m.room.member event is not a state event"} + ErrNotSignedByAuthoriser = AuthFailError{Index: "5.2", Message: "m.room.member event is not signed by server of join_authorised_via_users_server"} + ErrCantJoinOtherUser = AuthFailError{Index: "5.3.2", Message: "can't send join event with different state key"} + ErrCantJoinBanned = AuthFailError{Index: "5.3.3", Message: "user is banned from the room"} + ErrAuthoriserCantInvite = AuthFailError{Index: "5.3.5.2", Message: "authoriser doesn't have sufficient power level to invite"} + ErrCantJoinWithoutInvite = AuthFailError{Index: "5.3.7", Message: "can't join invite-only room without invite"} + ErrInvalidJoinRule = AuthFailError{Index: "5.3.7", Message: "invalid join rule in room"} + ErrThirdPartyInviteBanned = AuthFailError{Index: "5.4.1.1", Message: "third party invite target user is banned"} + ErrThirdPartyInviteMissingFields = AuthFailError{Index: "5.4.1.3", Message: "third party invite is missing mxid or token fields"} + ErrThirdPartyInviteMXIDMismatch = AuthFailError{Index: "5.4.1.4", Message: "mxid in signed third party invite doesn't match event state key"} + ErrThirdPartyInviteNotFound = AuthFailError{Index: "5.4.1.5", Message: "matching m.room.third_party_invite event not found in auth events"} + ErrThirdPartyInviteSenderMismatch = AuthFailError{Index: "5.4.1.6", Message: "sender of third party invite doesn't match sender of member event"} + ErrThirdPartyInviteNotSigned = AuthFailError{Index: "5.4.1.8", Message: "no valid signatures found for third party invite"} + ErrInviterNotInRoom = AuthFailError{Index: "5.4.2", Message: "inviter's membership is not join"} + ErrInviteTargetAlreadyInRoom = AuthFailError{Index: "5.4.3", Message: "invite target user is already in the room"} + ErrInviteTargetBanned = AuthFailError{Index: "5.4.3", Message: "invite target user is banned"} + ErrInsufficientPermissionForInvite = AuthFailError{Index: "5.4.5", Message: "inviter does not have sufficient permission to send invites"} + ErrCantLeaveWithoutBeingInRoom = AuthFailError{Index: "5.5.1", Message: "can't leave room without being in it"} + ErrCantKickWithoutBeingInRoom = AuthFailError{Index: "5.5.2", Message: "can't kick another user without being in the room"} + ErrInsufficientPermissionForUnban = AuthFailError{Index: "5.5.3", Message: "sender does not have sufficient permission to unban users"} + ErrInsufficientPermissionForKick = AuthFailError{Index: "5.5.5", Message: "sender does not have sufficient permission to kick the user"} + ErrCantBanWithoutBeingInRoom = AuthFailError{Index: "5.6.1", Message: "can't ban another user without being in the room"} + ErrInsufficientPermissionForBan = AuthFailError{Index: "5.6.3", Message: "sender does not have sufficient permission to ban the user"} + ErrNotKnockableRoom = AuthFailError{Index: "5.7.1", Message: "join rule doesn't allow knocking"} + ErrCantKnockOtherUser = AuthFailError{Index: "5.7.1", Message: "can't send knock event with different state key"} + ErrCantKnockWhileInRoom = AuthFailError{Index: "5.7.2", Message: "can't knock while joined, invited or banned"} + ErrUnknownMembership = AuthFailError{Index: "5.8", Message: "unknown membership in m.room.member event"} + + ErrNotInRoom = AuthFailError{Index: "6", Message: "sender is not a member of the room"} + + ErrInsufficientPowerForThirdPartyInvite = AuthFailError{Index: "7.1", Message: "sender does not have sufficient power level to send third party invite"} + + ErrInsufficientPowerLevel = AuthFailError{Index: "8", Message: "sender does not have sufficient power level to send event"} + + ErrMismatchingPrivateStateKey = AuthFailError{Index: "9", Message: "state keys starting with @ must match sender user ID"} + + ErrTopLevelPLNotInteger = AuthFailError{Index: "10.1", Message: "invalid type for top-level power level field"} + ErrPLNotInteger = AuthFailError{Index: "10.2", Message: "invalid type for power level"} + ErrInvalidUserIDInPL = AuthFailError{Index: "10.3", Message: "invalid user ID in power levels"} + ErrUserPLNotInteger = AuthFailError{Index: "10.3", Message: "invalid type for user power level"} + ErrCreatorInPowerLevels = AuthFailError{Index: "10.4", Message: "room creators must not be specified in power levels"} + ErrInvalidPowerChange = AuthFailError{Index: "10.x", Message: "illegal power level change"} +) + +func isRejected(evt *pdu.PDU) bool { + return evt.InternalMeta.Rejected +} + +type GetEventsFunc = func(ids []id.EventID) ([]*pdu.PDU, error) + +func Authorize(roomVersion id.RoomVersion, evt *pdu.PDU, getEvents GetEventsFunc, getKey pdu.GetKeyFunc) error { + if evt.Type == event.StateCreate.Type { + // 1. If type is m.room.create: + return authorizeCreate(roomVersion, evt) + } + var createEvt *pdu.PDU + if roomVersion.RoomIDIsCreateEventID() { + // 2. If the event’s room_id is not an event ID for an accepted (not rejected) m.room.create event, + // with the sigil ! instead of $, reject. + if len(evt.RoomID) != 44 { + return fmt.Errorf("%w (%d)", ErrInvalidRoomIDLength, len(evt.RoomID)) + } else if createEvts, err := getEvents([]id.EventID{id.EventID("$" + evt.RoomID[1:])}); err != nil { + return fmt.Errorf("%w: %w", ErrFailedToGetCreateEvent, err) + } else if len(createEvts) != 1 { + return fmt.Errorf("%w (%s)", ErrCreateEventNotFound, evt.RoomID) + } else if isRejected(createEvts[0]) { + return ErrRejectedCreateEvent + } else { + createEvt = createEvts[0] + } + } + authEvents, err := getEvents(evt.AuthEvents) + if err != nil { + return fmt.Errorf("%w: %w", ErrFailedToGetAuthEvents, err) + } + expectedAuthEvents := evt.AuthEventSelection(roomVersion) + deduplicator := make(map[pdu.StateKey]id.EventID, len(expectedAuthEvents)) + // 3. Considering the event’s auth_events: + for i, ae := range authEvents { + authEvtID := evt.AuthEvents[i] + if ae == nil { + return fmt.Errorf("%w (%s)", ErrMissingAuthEvent, authEvtID) + } else if ae.StateKey == nil { + // This approximately falls under rule 3.2. + return fmt.Errorf("%w (%s)", ErrNonStateAuthEvent, authEvtID) + } + key := pdu.StateKey{Type: ae.Type, StateKey: *ae.StateKey} + if prevEvtID, alreadyFound := deduplicator[key]; alreadyFound { + // 3.1. If there are duplicate entries for a given type and state_key pair, reject. + return fmt.Errorf("%w for %s/%s: found %s and %s", ErrDuplicateAuthEvent, ae.Type, *ae.StateKey, prevEvtID, authEvtID) + } else if !expectedAuthEvents.Has(key) { + // 3.2. If there are entries whose type and state_key don’t match those specified by + // the auth events selection algorithm described in the server specification, reject. + return fmt.Errorf("%w: found %s with key %s/%s", ErrUnexpectedAuthEvent, authEvtID, ae.Type, *ae.StateKey) + } else if isRejected(ae) { + // 3.3. If there are entries which were themselves rejected under the checks performed on receipt of a PDU, reject. + return fmt.Errorf("%w (%s)", ErrRejectedAuthEvent, authEvtID) + } else if ae.RoomID != evt.RoomID { + // 3.4. If any event in auth_events has a room_id which does not match that of the event being authorised, reject. + return fmt.Errorf("%w (%s)", ErrMismatchingRoomIDInAuthEvent, authEvtID) + } else { + deduplicator[key] = authEvtID + } + if ae.Type == event.StateCreate.Type { + if createEvt == nil { + createEvt = ae + } else { + // Duplicates are prevented by deduplicator, AuthEventSelection also won't allow a create event at all for v12+ + panic(fmt.Errorf("impossible case: multiple create events found in auth events")) + } + } + } + if createEvt == nil { + // This comes either from auth_events or room_id depending on the room version. + // The checks above make sure it's from the right source. + return ErrNoCreateEvent + } + if federateVal := gjson.GetBytes(createEvt.Content, mFederatePath); federateVal.Type == gjson.False && createEvt.Sender.Homeserver() != evt.Sender.Homeserver() { + // 4. If the content of the m.room.create event in the room state has the property m.federate set to false, + // and the sender domain of the event does not match the sender domain of the create event, reject. + return ErrFederationDisabled + } + if evt.Type == event.StateMember.Type { + // 5. If type is m.room.member: + return authorizeMember(roomVersion, evt, createEvt, authEvents, getKey) + } + senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave")) + if senderMembership != event.MembershipJoin { + // 6. If the sender’s current membership state is not join, reject. + return ErrNotInRoom + } + powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) + if err != nil { + return err + } + senderPL := powerLevels.GetUserLevel(evt.Sender) + if evt.Type == event.StateThirdPartyInvite.Type { + // 7.1. Allow if and only if sender’s current power level is greater than or equal to the invite level. + if senderPL >= powerLevels.Invite() { + return nil + } + return ErrInsufficientPowerForThirdPartyInvite + } + typeClass := event.MessageEventType + if evt.StateKey != nil { + typeClass = event.StateEventType + } + evtLevel := powerLevels.GetEventLevel(event.Type{Type: evt.Type, Class: typeClass}) + if evtLevel > senderPL { + // 8. If the event type’s required power level is greater than the sender’s power level, reject. + return fmt.Errorf("%w (%d > %d)", ErrInsufficientPowerLevel, evtLevel, senderPL) + } + + if evt.StateKey != nil && strings.HasPrefix(*evt.StateKey, "@") && *evt.StateKey != evt.Sender.String() { + // 9. If the event has a state_key that starts with an @ and does not match the sender, reject. + return ErrMismatchingPrivateStateKey + } + + if evt.Type == event.StatePowerLevels.Type { + // 10. If type is m.room.power_levels: + return authorizePowerLevels(roomVersion, evt, createEvt, authEvents) + } + + // 11. Otherwise, allow. + return nil +} + +var ErrUserIDNotAString = errors.New("not a string") +var ErrUserIDNotValid = errors.New("not a valid user ID") + +func isValidUserID(roomVersion id.RoomVersion, userID gjson.Result) error { + if userID.Type != gjson.String { + return ErrUserIDNotAString + } + // In a future room version, user IDs will have stricter validation + _, _, err := id.UserID(userID.Str).Parse() + if err != nil { + return ErrUserIDNotValid + } + return nil +} + +func authorizeCreate(roomVersion id.RoomVersion, evt *pdu.PDU) error { + if len(evt.PrevEvents) > 0 { + // 1.1. If it has any prev_events, reject. + return ErrCreateHasPrevEvents + } + if roomVersion.RoomIDIsCreateEventID() { + if evt.RoomID != "" { + // 1.2. If the event has a room_id, reject. + return ErrCreateHasRoomID + } + } else { + _, _, server := id.ParseCommonIdentifier(evt.RoomID) + if server == "" || server != evt.Sender.Homeserver() { + // 1.2. (v11 and below) If the domain of the room_id does not match the domain of the sender, reject. + return ErrRoomIDDoesntMatchSender + } + } + if !roomVersion.IsKnown() { + // 1.3. If content.room_version is present and is not a recognised version, reject. + return fmt.Errorf("%w %s", ErrUnknownRoomVersion, roomVersion) + } + if roomVersion.PrivilegedRoomCreators() { + additionalCreators := gjson.GetBytes(evt.Content, "additional_creators") + if additionalCreators.Exists() { + if !additionalCreators.IsArray() { + return fmt.Errorf("%w: not an array", ErrInvalidAdditionalCreators) + } + for i, item := range additionalCreators.Array() { + // 1.4. If additional_creators is present in content and is not an array of strings + // where each string passes the same user ID validation applied to sender, reject. + if err := isValidUserID(roomVersion, item); err != nil { + return fmt.Errorf("%w: item #%d %w", ErrInvalidAdditionalCreators, i+1, err) + } + } + } + } + if roomVersion.CreatorInContent() { + // 1.4. (v10 and below) If content has no creator property, reject. + if !gjson.GetBytes(evt.Content, "creator").Exists() { + return ErrMissingCreator + } + } + // 1.5. Otherwise, allow. + return nil +} + +func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU, getKey pdu.GetKeyFunc) error { + membership := event.Membership(gjson.GetBytes(evt.Content, "membership").Str) + if evt.StateKey == nil { + // 5.1. If there is no state_key property, or no membership property in content, reject. + return ErrMemberNotState + } + authorizedVia := id.UserID(gjson.GetBytes(evt.Content, "authorized_via_users_server").Str) + if authorizedVia != "" { + homeserver := authorizedVia.Homeserver() + err := evt.VerifySignature(roomVersion, homeserver, getKey) + if err != nil { + // 5.2. If content has a join_authorised_via_users_server key: + // 5.2.1. If the event is not validly signed by the homeserver of the user ID denoted by the key, reject. + return fmt.Errorf("%w: %w", ErrNotSignedByAuthoriser, err) + } + } + targetPrevMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, *evt.StateKey, "membership", "leave")) + senderMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, evt.Sender.String(), "membership", "leave")) + switch membership { + case event.MembershipJoin: + createEvtID, err := createEvt.GetEventID(roomVersion) + if err != nil { + return fmt.Errorf("failed to get create event ID: %w", err) + } + creator := createEvt.Sender.String() + if roomVersion.CreatorInContent() { + creator = gjson.GetBytes(evt.Content, "creator").Str + } + if len(evt.PrevEvents) == 1 && + len(evt.AuthEvents) <= 1 && + evt.PrevEvents[0] == createEvtID && + *evt.StateKey == creator { + // 5.3.1. If the only previous event is an m.room.create and the state_key is the sender of the m.room.create, allow. + return nil + } + // Spec wart: this would make more sense before the check above. + // Now you can set anyone as the sender of the first join. + if evt.Sender.String() != *evt.StateKey { + // 5.3.2. If the sender does not match state_key, reject. + return ErrCantJoinOtherUser + } + + if senderMembership == event.MembershipBan { + // 5.3.3. If the sender is banned, reject. + return ErrCantJoinBanned + } + + joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite")) + switch joinRule { + case event.JoinRuleKnock: + if !roomVersion.Knocks() { + return ErrInvalidJoinRule + } + fallthrough + case event.JoinRuleInvite: + // 5.3.4. If the join_rule is invite or knock then allow if membership state is invite or join. + if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite { + return nil + } + return ErrCantJoinWithoutInvite + case event.JoinRuleKnockRestricted: + if !roomVersion.KnockRestricted() { + return ErrInvalidJoinRule + } + fallthrough + case event.JoinRuleRestricted: + if joinRule == event.JoinRuleRestricted && !roomVersion.RestrictedJoins() { + return ErrInvalidJoinRule + } + if targetPrevMembership == event.MembershipJoin || targetPrevMembership == event.MembershipInvite { + // 5.3.5.1. If membership state is join or invite, allow. + return nil + } + powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) + if err != nil { + return err + } + if powerLevels.GetUserLevel(authorizedVia) < powerLevels.Invite() { + // 5.3.5.2. If the join_authorised_via_users_server key in content is not a user with sufficient permission to invite other users, reject. + return ErrAuthoriserCantInvite + } + // 5.3.5.3. Otherwise, allow. + return nil + case event.JoinRulePublic: + // 5.3.6. If the join_rule is public, allow. + return nil + default: + // 5.3.7. Otherwise, reject. + return ErrInvalidJoinRule + } + case event.MembershipInvite: + tpiVal := gjson.GetBytes(evt.Content, "third_party_invite") + if tpiVal.Exists() { + if targetPrevMembership == event.MembershipBan { + return ErrThirdPartyInviteBanned + } + signed := tpiVal.Get("signed") + mxid := signed.Get("mxid").Str + token := signed.Get("token").Str + if mxid == "" || token == "" { + // 5.4.1.2. If content.third_party_invite does not have a signed property, reject. + // 5.4.1.3. If signed does not have mxid and token properties, reject. + return ErrThirdPartyInviteMissingFields + } + if mxid != *evt.StateKey { + // 5.4.1.4. If mxid does not match state_key, reject. + return ErrThirdPartyInviteMXIDMismatch + } + tpiEvt := findEvent(authEvents, event.StateThirdPartyInvite.Type, token) + if tpiEvt == nil { + // 5.4.1.5. If there is no m.room.third_party_invite event in the current room state with state_key matching token, reject. + return ErrThirdPartyInviteNotFound + } + if tpiEvt.Sender != evt.Sender { + // 5.4.1.6. If sender does not match sender of the m.room.third_party_invite, reject. + return ErrThirdPartyInviteSenderMismatch + } + var keys []id.Ed25519 + const ed25519Base64Len = 43 + oldPubKey := gjson.GetBytes(evt.Content, "public_key.token") + if oldPubKey.Type == gjson.String && len(oldPubKey.Str) == ed25519Base64Len { + keys = append(keys, id.Ed25519(oldPubKey.Str)) + } + gjson.GetBytes(evt.Content, "public_keys").ForEach(func(key, value gjson.Result) bool { + if key.Type != gjson.Number { + return false + } + if value.Type == gjson.String && len(value.Str) == ed25519Base64Len { + keys = append(keys, id.Ed25519(value.Str)) + } + return true + }) + rawSigned := jsontext.Value(exstrings.UnsafeBytes(signed.Str)) + var validated bool + for _, key := range keys { + if signutil.VerifyJSONAny(key, rawSigned) == nil { + validated = true + } + } + if validated { + // 4.4.1.7. If any signature in signed matches any public key in the m.room.third_party_invite event, allow. + return nil + } + // 4.4.1.8. Otherwise, reject. + return ErrThirdPartyInviteNotSigned + } + if senderMembership != event.MembershipJoin { + // 5.4.2. If the sender’s current membership state is not join, reject. + return ErrInviterNotInRoom + } + // 5.4.3. If target user’s current membership state is join or ban, reject. + if targetPrevMembership == event.MembershipJoin { + return ErrInviteTargetAlreadyInRoom + } else if targetPrevMembership == event.MembershipBan { + return ErrInviteTargetBanned + } + powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) + if err != nil { + return err + } + if powerLevels.GetUserLevel(evt.Sender) >= powerLevels.Invite() { + // 5.4.4. If the sender’s power level is greater than or equal to the invite level, allow. + return nil + } + // 5.4.5. Otherwise, reject. + return ErrInsufficientPermissionForInvite + case event.MembershipLeave: + if evt.Sender.String() == *evt.StateKey { + // 5.5.1. If the sender matches state_key, allow if and only if that user’s current membership state is invite, join, or knock. + if senderMembership == event.MembershipInvite || + senderMembership == event.MembershipJoin || + (senderMembership == event.MembershipKnock && roomVersion.Knocks()) { + return nil + } + return ErrCantLeaveWithoutBeingInRoom + } + if senderMembership != event.MembershipLeave { + // 5.5.2. If the sender’s current membership state is not join, reject. + return ErrCantKickWithoutBeingInRoom + } + powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) + if err != nil { + return err + } + senderLevel := powerLevels.GetUserLevel(evt.Sender) + if targetPrevMembership == event.MembershipBan && senderLevel < powerLevels.Ban() { + // 5.5.3. If the target user’s current membership state is ban, and the sender’s power level is less than the ban level, reject. + return ErrInsufficientPermissionForUnban + } + if senderLevel >= powerLevels.Kick() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel { + // 5.5.4. If the sender’s power level is greater than or equal to the kick level, and the target user’s power level is less than the sender’s power level, allow. + return nil + } + // TODO separate errors for < kick and < target user level? + // 5.5.5. Otherwise, reject. + return ErrInsufficientPermissionForKick + case event.MembershipBan: + if senderMembership != event.MembershipLeave { + // 5.6.1. If the sender’s current membership state is not join, reject. + return ErrCantBanWithoutBeingInRoom + } + powerLevels, err := getPowerLevels(roomVersion, authEvents, createEvt) + if err != nil { + return err + } + senderLevel := powerLevels.GetUserLevel(evt.Sender) + if senderLevel >= powerLevels.Ban() && powerLevels.GetUserLevel(id.UserID(*evt.StateKey)) < senderLevel { + // 5.6.2. If the sender’s power level is greater than or equal to the ban level, and the target user’s power level is less than the sender’s power level, allow. + return nil + } + // 5.6.3. Otherwise, reject. + return ErrInsufficientPermissionForBan + case event.MembershipKnock: + joinRule := event.JoinRule(findEventAndReadString(authEvents, event.StateJoinRules.Type, "", "join_rule", "invite")) + validKnockRule := roomVersion.Knocks() && joinRule == event.JoinRuleKnock + validKnockRestrictedRule := roomVersion.KnockRestricted() && joinRule == event.JoinRuleKnockRestricted + if !validKnockRule && !validKnockRestrictedRule { + // 5.7.1. If the join_rule is anything other than knock or knock_restricted, reject. + return ErrNotKnockableRoom + } + if evt.Sender.String() != *evt.StateKey { + // 5.7.2. If the sender does not match state_key, reject. + return ErrCantKnockOtherUser + } + if senderMembership != event.MembershipBan && senderMembership != event.MembershipInvite && senderMembership != event.MembershipJoin { + // 5.7.3. If the sender’s current membership is not ban, invite, or join, allow. + return nil + } + // 5.7.4. Otherwise, reject. + return ErrCantKnockWhileInRoom + default: + // 5.8. Otherwise, the membership is unknown. Reject. + return ErrUnknownMembership + } +} + +func authorizePowerLevels(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEvents []*pdu.PDU) error { + if roomVersion.ValidatePowerLevelInts() { + for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} { + res := gjson.GetBytes(evt.Content, key) + if !res.Exists() { + continue + } + if parseIntWithVersion(roomVersion, res) == nil { + // 10.1. If any of the properties users_default, events_default, state_default, ban, redact, kick, or invite in content are present and not an integer, reject. + return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key) + } + } + for _, key := range []string{"events", "notifications"} { + obj := gjson.GetBytes(evt.Content, key) + if !obj.Exists() { + continue + } + // 10.2. If either of the properties events or notifications in content are present and not an object [...], reject. + if !obj.IsObject() { + return fmt.Errorf("%w %s", ErrTopLevelPLNotInteger, key) + } + var err error + // 10.2. [...] are not an object with values that are integers, reject. + obj.ForEach(func(innerKey, value gjson.Result) bool { + if parseIntWithVersion(roomVersion, value) == nil { + err = fmt.Errorf("%w %s.%s", ErrPLNotInteger, key, innerKey.Str) + return false + } + return true + }) + if err != nil { + return err + } + } + } + var creators []id.UserID + if roomVersion.PrivilegedRoomCreators() { + creators = append(creators, createEvt.Sender) + gjson.GetBytes(createEvt.Content, "additional_creators").ForEach(func(key, value gjson.Result) bool { + creators = append(creators, id.UserID(value.Str)) + return true + }) + } + users := gjson.GetBytes(evt.Content, "users") + if users.Exists() { + if !users.IsObject() { + // 10.3. If the users property in content is not an object [...], reject. + return fmt.Errorf("%w users", ErrTopLevelPLNotInteger) + } + var err error + users.ForEach(func(key, value gjson.Result) bool { + if validatorErr := isValidUserID(roomVersion, key); validatorErr != nil { + // 10.3. [...] is not an object with keys that are valid user IDs [...], reject. + err = fmt.Errorf("%w: %q %w", ErrInvalidUserIDInPL, key.Str, validatorErr) + return false + } + if parseIntWithVersion(roomVersion, value) == nil { + // 10.3. [...] is not an object [...] with values that are integers, reject. + err = fmt.Errorf("%w %q", ErrUserPLNotInteger, key.Str) + return false + } + // creators is only filled if the room version has privileged room creators + if slices.Contains(creators, id.UserID(key.Str)) { + // 10.4. If the users property in content contains the sender of the m.room.create event or any of + // the additional_creators array (if present) from the content of the m.room.create event, reject. + err = fmt.Errorf("%w: %q", ErrCreatorInPowerLevels, key.Str) + return false + } + return true + }) + if err != nil { + return err + } + } + oldPL := findEvent(authEvents, event.StatePowerLevels.Type, "") + if oldPL == nil { + // 10.5. If there is no previous m.room.power_levels event in the room, allow. + return nil + } + if slices.Contains(creators, evt.Sender) { + // Skip remaining checks for creators + return nil + } + senderPLPtr := parsePythonInt(gjson.GetBytes(oldPL.Content, exgjson.Path("users", evt.Sender.String()))) + if senderPLPtr == nil { + senderPLPtr = parsePythonInt(gjson.GetBytes(oldPL.Content, "users_default")) + if senderPLPtr == nil { + senderPLPtr = ptr.Ptr(0) + } + } + for _, key := range []string{"users_default", "events_default", "state_default", "ban", "redact", "kick", "invite"} { + oldVal := gjson.GetBytes(oldPL.Content, key) + newVal := gjson.GetBytes(evt.Content, key) + if err := allowPowerChange(roomVersion, *senderPLPtr, key, oldVal, newVal); err != nil { + return err + } + } + if err := allowPowerChangeMap( + roomVersion, *senderPLPtr, "events", "", + gjson.GetBytes(oldPL.Content, "events"), + gjson.GetBytes(evt.Content, "events"), + ); err != nil { + return err + } + if err := allowPowerChangeMap( + roomVersion, *senderPLPtr, "notifications", "", + gjson.GetBytes(oldPL.Content, "notifications"), + gjson.GetBytes(evt.Content, "notifications"), + ); err != nil { + return err + } + if err := allowPowerChangeMap( + roomVersion, *senderPLPtr, "users", evt.Sender.String(), + gjson.GetBytes(oldPL.Content, "users"), + gjson.GetBytes(evt.Content, "users"), + ); err != nil { + return err + } + return nil +} + +func allowPowerChangeMap(roomVersion id.RoomVersion, maxVal int, path, ownID string, old, new gjson.Result) (err error) { + old.ForEach(func(key, value gjson.Result) bool { + newVal := new.Get(exgjson.Path(key.Str)) + err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, value, newVal) + if err == nil && ownID != "" && key.Str != ownID { + val := parseIntWithVersion(roomVersion, value) + if *val >= maxVal { + err = fmt.Errorf("%w: can't change users.%s from %s to %s with sender level %d", ErrInvalidPowerChange, key.Str, stringifyForError(value), stringifyForError(newVal), maxVal) + } + } + return err == nil + }) + if err != nil { + return + } + new.ForEach(func(key, value gjson.Result) bool { + err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, old.Get(exgjson.Path(key.Path(key.Str))), value) + return err == nil + }) + return +} + +func allowPowerChange(roomVersion id.RoomVersion, maxVal int, path string, old, new gjson.Result) error { + oldVal := parseIntWithVersion(roomVersion, old) + newVal := parseIntWithVersion(roomVersion, new) + if oldVal == nil { + if newVal == nil || *newVal <= maxVal { + return nil + } + } else if newVal == nil { + if *oldVal <= maxVal { + return nil + } + } else if *oldVal == *newVal || (*oldVal <= maxVal && *newVal <= maxVal) { + return nil + } + return fmt.Errorf("%w can't change %s from %s to %s with sender level %d", ErrInvalidPowerChange, path, stringifyForError(old), stringifyForError(new), maxVal) +} + +func stringifyForError(val gjson.Result) string { + if !val.Exists() { + return "null" + } + return val.Raw +} + +func findEvent(events []*pdu.PDU, evtType, stateKey string) *pdu.PDU { + for _, evt := range events { + if evt.Type == evtType && *evt.StateKey == stateKey { + return evt + } + } + return nil +} + +func findEventAndReadData[T any](events []*pdu.PDU, evtType, stateKey string, reader func(evt *pdu.PDU) T) T { + return reader(findEvent(events, evtType, stateKey)) +} + +func findEventAndReadString(events []*pdu.PDU, evtType, stateKey, fieldPath, defVal string) string { + return findEventAndReadData(events, evtType, stateKey, func(evt *pdu.PDU) string { + if evt == nil { + return defVal + } + res := gjson.GetBytes(evt.Content, fieldPath) + if res.Type != gjson.String { + return defVal + } + return res.Str + }) +} + +func getPowerLevels(roomVersion id.RoomVersion, authEvents []*pdu.PDU, createEvt *pdu.PDU) (*event.PowerLevelsEventContent, error) { + var err error + powerLevels := findEventAndReadData(authEvents, event.StatePowerLevels.Type, "", func(evt *pdu.PDU) (out event.PowerLevelsEventContent) { + if evt == nil { + return + } + content := evt.Content + if !roomVersion.ValidatePowerLevelInts() { + safeParsePowerLevels(content, &out) + } else { + err = json.Unmarshal(content, &out) + } + return + }) + if err != nil { + // This should never happen thanks to safeParsePowerLevels for v1-9 and strict validation in v10+ + return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) + } + if roomVersion.PrivilegedRoomCreators() { + powerLevels.CreateEvent, err = createEvt.ToClientEvent(roomVersion) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) + } + err = powerLevels.CreateEvent.Content.ParseRaw(powerLevels.CreateEvent.Type) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) + } + } else { + powerLevels.Users = map[id.UserID]int{ + createEvt.Sender: (1 << 53) - 1, + } + } + return &powerLevels, nil +} + +func parseIntWithVersion(roomVersion id.RoomVersion, val gjson.Result) *int { + if roomVersion.ValidatePowerLevelInts() { + if val.Type != gjson.Number { + return nil + } + return ptr.Ptr(int(val.Int())) + } + return parsePythonInt(val) +} + +func parsePythonInt(val gjson.Result) *int { + switch val.Type { + case gjson.True: + return ptr.Ptr(1) + case gjson.False: + return ptr.Ptr(0) + case gjson.Number: + return ptr.Ptr(int(val.Int())) + case gjson.String: + // strconv.Atoi accepts signs as well as leading zeroes, so we just need to trim spaces beforehand + num, err := strconv.Atoi(strings.TrimSpace(val.Str)) + if err != nil { + return nil + } + return &num + default: + // Python int() doesn't accept nulls, arrays or dicts + return nil + } +} + +func safeParsePowerLevels(content jsontext.Value, into *event.PowerLevelsEventContent) { + *into = event.PowerLevelsEventContent{ + Users: make(map[id.UserID]int), + UsersDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "users_default"))), + Events: make(map[string]int), + EventsDefault: ptr.Val(parsePythonInt(gjson.GetBytes(content, "events_default"))), + Notifications: nil, // irrelevant for event auth + StateDefaultPtr: parsePythonInt(gjson.GetBytes(content, "state_default")), + InvitePtr: parsePythonInt(gjson.GetBytes(content, "invite")), + KickPtr: parsePythonInt(gjson.GetBytes(content, "kick")), + BanPtr: parsePythonInt(gjson.GetBytes(content, "ban")), + RedactPtr: parsePythonInt(gjson.GetBytes(content, "redact")), + } + gjson.GetBytes(content, "events").ForEach(func(key, value gjson.Result) bool { + if key.Type != gjson.String { + return false + } + val := parsePythonInt(value) + if val != nil { + into.Events[key.Str] = *val + } + return true + }) + gjson.GetBytes(content, "users").ForEach(func(key, value gjson.Result) bool { + if key.Type != gjson.String { + return false + } + val := parsePythonInt(value) + if val == nil { + return false + } + userID := id.UserID(key.Str) + if _, _, err := userID.Parse(); err != nil { + return false + } + into.Users[userID] = *val + return true + }) +} diff --git a/federation/eventauth/eventauth_test.go b/federation/eventauth/eventauth_test.go new file mode 100644 index 00000000..e3c5cd76 --- /dev/null +++ b/federation/eventauth/eventauth_test.go @@ -0,0 +1,85 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package eventauth_test + +import ( + "embed" + "encoding/json/jsontext" + "encoding/json/v2" + "errors" + "io" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/federation/eventauth" + "maunium.net/go/mautrix/federation/pdu" + "maunium.net/go/mautrix/id" +) + +//go:embed *.jsonl +var data embed.FS + +type eventMap map[id.EventID]*pdu.PDU + +func (em eventMap) Get(ids []id.EventID) ([]*pdu.PDU, error) { + output := make([]*pdu.PDU, len(ids)) + for i, evtID := range ids { + output[i] = em[evtID] + } + return output, nil +} + +func GetKey(serverName string, keyID id.KeyID, validUntilTS time.Time) (id.SigningKey, time.Time, error) { + return "", time.Time{}, nil +} + +func TestAuthorize(t *testing.T) { + files := exerrors.Must(data.ReadDir(".")) + for _, file := range files { + t.Run(file.Name(), func(t *testing.T) { + decoder := jsontext.NewDecoder(exerrors.Must(data.Open(file.Name()))) + events := make(eventMap) + var roomVersion *id.RoomVersion + for i := 1; ; i++ { + var evt *pdu.PDU + err := json.UnmarshalDecode(decoder, &evt) + if errors.Is(err, io.EOF) { + break + } + require.NoError(t, err) + if roomVersion == nil { + require.Equal(t, evt.Type, "m.room.create") + roomVersion = ptr.Ptr(id.RoomVersion(gjson.GetBytes(evt.Content, "room_version").Str)) + } + expectedEventID := gjson.GetBytes(evt.Unsigned, "event_id").Str + evtID, err := evt.GetEventID(*roomVersion) + require.NoError(t, err) + require.Equalf(t, id.EventID(expectedEventID), evtID, "Event ID mismatch for event #%d", i) + + // TODO allow redacted events + assert.True(t, evt.VerifyContentHash(), i) + + events[evtID] = evt + err = eventauth.Authorize(*roomVersion, evt, events.Get, GetKey) + if err != nil { + evt.InternalMeta.Rejected = true + } + // TODO allow testing intentionally rejected events + assert.NoErrorf(t, err, "Failed to authorize event #%d / %s of type %s", i, evtID, evt.Type) + } + }) + } + +} diff --git a/federation/eventauth/testroom-v12-success.jsonl b/federation/eventauth/testroom-v12-success.jsonl new file mode 100644 index 00000000..1f0b5357 --- /dev/null +++ b/federation/eventauth/testroom-v12-success.jsonl @@ -0,0 +1,17 @@ +{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186,"event_id":"$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"}} +{"auth_events":[],"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"depth":2,"hashes":{"sha256":"MXmgq0e4J9CdIP0IVKVvueFhOb+ndlsXpeyI+6l/2FI"},"origin_server_ts":1756071567259,"prev_events":["$lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"xMgRzyRg9VM9XCKpfFJA+MrYoI68b8PIddKpMTcxz/fDzmGSHEy6Ta2b59VxiX3NoJe2CigkDZ3+jVsQoZYIBA"}},"state_key":"@tulir:maunium.net","type":"m.room.member","unsigned":{"age_ts":1756071567259,"event_id":"$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"}} +{"auth_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001},"users_default":0},"depth":3,"hashes":{"sha256":"/JzQNBNqJ/i8vwj6xESDaD5EDdOqB4l/LmKlvAVl5jY"},"origin_server_ts":1756071567319,"prev_events":["$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"W3N3X/enja+lumXw3uz66/wT9oczoxrmHbAD5/RF069cX4wkCtqtDd61VWPkSGmKxdV1jurgbCqSX6+Q9/t3AA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"age_ts":1756071567319,"event_id":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"join_rule":"invite"},"depth":4,"hashes":{"sha256":"GBu5AySj75ZXlOLd65mB03KueFKOHNgvtg2o/LUnLyI"},"origin_server_ts":1756071567320,"prev_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"XqWEnFREo2PhRnaebGjNzdHdtD691BtCQKkLnpKd8P3lVDewDt8OkCbDSk/Uzh9rDtzwWEsbsIoKSYuOm+G6CA"}},"state_key":"","type":"m.room.join_rules","unsigned":{"age_ts":1756071567320,"event_id":"$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"history_visibility":"shared"},"depth":5,"hashes":{"sha256":"niDi5vG2akQm0f5pm0aoCYXqmWjXRfmP1ulr/ZEPm/k"},"origin_server_ts":1756071567320,"prev_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"PTIrNke/fc9+ObKAl/K0PGZfmpe8dwREyoA5rXffOXWdRHSaBifn9UIiJUqd68Bzvrv4RcADTR/ci7lUquFBBw"}},"state_key":"","type":"m.room.history_visibility","unsigned":{"age_ts":1756071567320,"event_id":"$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"guest_access":"can_join"},"depth":6,"hashes":{"sha256":"sZ9QqsId4oarFF724esTohXuRxDNnaXPl+QmTDG60dw"},"origin_server_ts":1756071567321,"prev_events":["$Wmy3G9yxl9ArVg5ZsdeIDPxBsNAdgseuvHoqHTZ2vug"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"Eh2P9/hl38wfZx2AQbeS5VCD4wldXPfeP2sQsJsLtfmdwFV74jrlGVBaKIkaYcXY4eA08iDp8HW5jqttZqKKDg"}},"state_key":"","type":"m.room.guest_access","unsigned":{"age_ts":1756071567321,"event_id":"$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"name":"event auth test v12"},"depth":7,"hashes":{"sha256":"tjwPo38yR+23Was6SbxLvPMhNx44DaXLhF3rKgngepU"},"origin_server_ts":1756071567321,"prev_events":["$hYVRH7F4P5mB5IqvBDDU5aXY7pYGG0ApstrryiVPKmQ"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"q1rk0c5m8TJYE9tePsMaLeaigatNNbvaLRom0X8KiZY0EH+itujfA+/UnksvmPmMmThfAXWlFLx5u8tcuSVyCQ"}},"state_key":"","type":"m.room.name","unsigned":{"age_ts":1756071567321,"event_id":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"invite"},"depth":8,"hashes":{"sha256":"r5EBUZN/4LbVcMYwuffDcVV9G4OMHzAQuNbnjigL+OE"},"origin_server_ts":1756071567548,"prev_events":["$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"envs.net":{"ed25519:wuJyKT":"svB+uW4Tsj8/I+SYbLl+LPPjBlqxGNXE4wGyAxlP7vfyJtFf7Kn/19jx65wT9ebeCq5sTGlEDV4Fabwma9LhDA"},"maunium.net":{"ed25519:a_xxeS":"LBYMcdJVSNsLd6SmOgx5oOU/0xOeCl03o4g83VwJfHWlRuTT5l9+qlpNED28wY07uxoU9MgLgXXICJ0EezMBCg"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age_ts":1756071567548,"event_id":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age_ts":1756071567186}},{"content":{"avatar_url":"mxc://maunium.net/jdlSfvudiMSmcRrleeiYjjFO","displayname":"tulir","membership":"join"},"sender":"@tulir:maunium.net","state_key":"@tulir:maunium.net","type":"m.room.member"}]}} +{"auth_events":["$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$mmqm2KS4UExkNL65c6CIhKofn_L9fzF2OhghVqajksU"],"content":{"body":"meow","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":9,"hashes":{"sha256":"23rgMf7EGJcYt3Aj0qAFnmBWCxuU9Uk+ReidqtIJDKQ"},"origin_server_ts":1756071575986,"prev_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"p+Fm/uWO8VXJdCYvN/dVb8HF8W3t1sssNCBiOWbzAeuS3QqYjoMKHyixLuN1mOdnCyATv7SsHHmA4+cELRGdAA"}},"type":"m.room.message","unsigned":{"age_ts":1756071576002,"event_id":"$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"}} +{"auth_events":["$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M"],"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"depth":10,"hashes":{"sha256":"2kJPx2UsysNzTH8QGYHUKTO/05yetxKRlI0nKFeGbts"},"origin_server_ts":1756071578631,"prev_events":["$eZDCydRWSRnR5od0c7ahz2qSZQDHbl5g5PITT0OMC3E"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"Wuzxkh8nEEX6mdJzph6Bt5ku+odFkEg2RIpFAAirOqxgcrwRaz42PsJni3YbfzH1qneF+iWQ/neA+up6jLXFBw"}},"state_key":"@tulir:envs.net","type":"m.room.member","unsigned":{"age":6,"event_id":"$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","replaces_state":"$qYZqSKiKMCNjzH6Trhr6nBSvbfuwr8Sh2bC4USSAxok"}} +{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"invite"},"depth":11,"hashes":{"sha256":"dRE11R2hBfFalQ5tIJdyaElUIiSE5aCKMddjek4wR3c"},"origin_server_ts":1756071591449,"prev_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"/Mi4kX40fbR+V3DCJJGI/9L3Uuf8y5Un8LHlCQv1T0O5gnFZGQ3qN6rRNaZ1Kdh3QJBU6H4NTfnd+SVj3wt3CQ"},"matrix.org":{"ed25519:a_RXGa":"ZeLm/oxP3/Cds/uCL2FaZpgjUp0vTDBlGG6YVFNl76yIVlyIKKQKR6BSVw2u5KC5Mu9M1f+0lDmLGQujR5NkBg"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"event_id":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","invite_room_state":[{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"},{"content":{"name":"event auth test v12"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.name"},{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://envs.net/000cf1510b7c61018f9c72ca4cc63668370782c81725865933316030464","displayname":"tulir[e]","membership":"join"},"sender":"@tulir:envs.net","state_key":"@tulir:envs.net","type":"m.room.member"}]}} +{"auth_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4","$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"depth":12,"hashes":{"sha256":"hR/fRIyFkxKnA1XNxIB+NKC0VR0vHs82EDgydhmmZXU"},"origin_server_ts":1756071609205,"prev_events":["$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"keWbZHm+LPW22XWxb14Att4Ae4GVc6XAKAnxFRr3hxhrgEhsnMcxUx7fjqlA1dk3As6kjLKdekcyCef+AQCXCA"}},"state_key":"@tulir:matrix.org","type":"m.room.member","unsigned":{"age":19,"event_id":"$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","replaces_state":"$g4eBtA9EFNGLkHOofvQ4U87GNt4W8NmfmNRyR0wOUO4"}} +{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":150},"events_default":0,"historical":100,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":13,"hashes":{"sha256":"30Wuw3xIbA8+eXQBa4nFDKcyHtMbKPBYhLW1zft9/fE"},"origin_server_ts":1756071643928,"prev_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"x6Y4uViq4nK8LVPqtMLdCuvNET2bnjxYTgiKuEe1JYfwB4jPBnPuqvrt1O9oaanMpcRWbnuiZjckq4bUlRZ7Cw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","replaces_state":"$v3gylw64IK4PohOe0M8XO1PZthibpBCKVBI3x_8xiUU"}} +{"auth_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw"],"content":{"name":"event auth test v12!"},"depth":14,"hashes":{"sha256":"WT0gz7KYXvbdNruRavqIi9Hhul3rxCdZ+YY9yMGN+Fw"},"origin_server_ts":1756071656988,"prev_events":["$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"bSplmqtXVhO2Z3hJ8JMQ/u7G2Wmg6yt7SwhYXObRQJfthekddJN152ME4YJIwy7YD8WFq7EkyB/NMyQoliYyCg"}},"state_key":"","type":"m.room.name","unsigned":{"event_id":"$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI","replaces_state":"$fFDwIavLTEIfcnggWuryB6JwfS-L2KT6vP1ap3P6ctE"}} +{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":9001},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":15,"hashes":{"sha256":"FnGzbcXc8YOiB1TY33QunGA17Axoyuu3sdVOj5Z408o"},"origin_server_ts":1756071804931,"prev_events":["$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"uyTUsPR+CzCtlevzB5+sNXvmfbPSp6u7RZC4E4TLVsj45+pjmMRswAvuHP9PT2+Tkl6Hu8ZPigsXgbKZtR35Aw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw","replaces_state":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"}} +{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":16,"hashes":{"sha256":"KcivsiLesdnUnKX23Akk3OJEJFGRSY0g4H+p7XIThnw"},"origin_server_ts":1756071812688,"prev_events":["$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"cAK8dO2AVZklY9te5aVKbF1jR/eB5rzeNOXfYPjBLf+aSAS4Z6R2aMKW6hJB9PqRS4S+UZc24DTrjUjnvMzeBA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU","replaces_state":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"}} +{"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"body":"meow #2","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":17,"hashes":{"sha256":"SgH9fOXGdbdqpRfYmoz1t29+gX8Ze4ThSoj6klZs3Og"},"origin_server_ts":1756247476706,"prev_events":["$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"SMYK7zP3SaQOKhzZUKUBVCKwffYqi3PFAlPM34kRJtmfGU3KZXNBT0zi+veXDMmxkMunqhF2RTHBD6joa0kBAQ"}},"type":"m.room.message","unsigned":{"event_id":"$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"}} diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go index b5210550..2dbdefc1 100644 --- a/federation/pdu/pdu.go +++ b/federation/pdu/pdu.go @@ -53,6 +53,10 @@ var ( _ AnyPDU = (*RoomV1PDU)(nil) ) +type InternalMeta struct { + Rejected bool `json:"rejected,omitempty"` +} + type PDU struct { AuthEvents []id.EventID `json:"auth_events"` Content jsontext.Value `json:"content"` @@ -67,6 +71,7 @@ type PDU struct { StateKey *string `json:"state_key,omitzero"` Type string `json:"type"` Unsigned jsontext.Value `json:"unsigned,omitzero"` + InternalMeta InternalMeta `json:"-"` Unknown jsontext.Value `json:",unknown"` diff --git a/federation/signutil/verify.go b/federation/signutil/verify.go index 8fe55b2f..ea0e7886 100644 --- a/federation/signutil/verify.go +++ b/federation/signutil/verify.go @@ -48,6 +48,47 @@ func VerifyJSON(serverName string, keyID id.KeyID, key id.SigningKey, data any) return VerifyJSONRaw(key, sigVal.Str, message) } +func VerifyJSONAny(key id.SigningKey, data any) error { + var err error + message, ok := data.(json.RawMessage) + if !ok { + message, err = json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal data: %w", err) + } + } + sigs := gjson.GetBytes(message, "signatures") + if !sigs.IsObject() { + return ErrSignatureNotFound + } + message, err = sjson.DeleteBytes(message, "signatures") + if err != nil { + return fmt.Errorf("failed to delete signatures: %w", err) + } + message, err = sjson.DeleteBytes(message, "unsigned") + if err != nil { + return fmt.Errorf("failed to delete unsigned: %w", err) + } + var validated bool + sigs.ForEach(func(_, value gjson.Result) bool { + if !value.IsObject() { + return true + } + value.ForEach(func(_, value gjson.Result) bool { + if value.Type != gjson.String { + return true + } + validated = VerifyJSONRaw(key, value.Str, message) == nil + return !validated + }) + return !validated + }) + if !validated { + return ErrInvalidSignature + } + return nil +} + func VerifyJSONRaw(key id.SigningKey, sig string, message json.RawMessage) error { sigBytes, err := base64.RawStdEncoding.DecodeString(sig) if err != nil { From f131ae5aa4b25f6f3ada0db00aa3df89b84d91d6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 27 Aug 2025 12:24:15 +0300 Subject: [PATCH 1336/1647] federation/pdu: add cached event ID to internal metadata --- federation/pdu/pdu.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go index 2dbdefc1..218dd78a 100644 --- a/federation/pdu/pdu.go +++ b/federation/pdu/pdu.go @@ -54,7 +54,8 @@ var ( ) type InternalMeta struct { - Rejected bool `json:"rejected,omitempty"` + EventID id.EventID `json:"event_id,omitempty"` + Rejected bool `json:"rejected,omitempty"` } type PDU struct { From 9f693702b06b0fb2abfba280ad5307694d30c089 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 27 Aug 2025 12:25:08 +0300 Subject: [PATCH 1337/1647] federation/pdu: add extra field to internal metadata --- federation/pdu/pdu.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go index 218dd78a..cecee5b9 100644 --- a/federation/pdu/pdu.go +++ b/federation/pdu/pdu.go @@ -54,8 +54,9 @@ var ( ) type InternalMeta struct { - EventID id.EventID `json:"event_id,omitempty"` - Rejected bool `json:"rejected,omitempty"` + EventID id.EventID `json:"event_id,omitempty"` + Rejected bool `json:"rejected,omitempty"` + Extra map[string]any `json:",unknown"` } type PDU struct { From febca20dd780913e580110c46102595c34b95d7a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 27 Aug 2025 17:11:50 +0300 Subject: [PATCH 1338/1647] bridgev2/status: use _file pattern for avatar instead of splitting url and keys --- bridgev2/status/bridgestate.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/bridgev2/status/bridgestate.go b/bridgev2/status/bridgestate.go index 671303e0..3bc5a59b 100644 --- a/bridgev2/status/bridgestate.go +++ b/bridgev2/status/bridgestate.go @@ -22,7 +22,7 @@ import ( "go.mau.fi/util/ptr" "maunium.net/go/mautrix" - "maunium.net/go/mautrix/crypto/attachment" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -89,8 +89,7 @@ type RemoteProfile struct { Name string `json:"name,omitempty"` Avatar id.ContentURIString `json:"avatar,omitempty"` - // Only used for backups of local bridge states - AvatarKeys *attachment.EncryptedFile `json:"avatar_keys,omitempty"` + AvatarFile *event.EncryptedFileInfo `json:"avatar_file,omitempty"` } func coalesce[T ~string](a, b T) T { @@ -106,14 +105,14 @@ func (rp *RemoteProfile) Merge(other RemoteProfile) RemoteProfile { other.Username = coalesce(rp.Username, other.Username) other.Name = coalesce(rp.Name, other.Name) other.Avatar = coalesce(rp.Avatar, other.Avatar) - if rp.AvatarKeys != nil { - other.AvatarKeys = rp.AvatarKeys + if rp.AvatarFile != nil { + other.AvatarFile = rp.AvatarFile } return other } func (rp *RemoteProfile) IsEmpty() bool { - return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "") + return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "" && rp.AvatarFile == nil) } type BridgeState struct { From 359afbea2bba3a016ff54e50a381c69ed4d91b92 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 28 Aug 2025 02:19:15 +0300 Subject: [PATCH 1339/1647] bridgev2/matrix: remove provisioning API prefix option Reverse proxy configuration should be used instead when adding prefixes to the path. Changing the path entirely is not recommended even with reverse proxies. Fixes #403 --- bridgev2/bridgeconfig/config.go | 1 - bridgev2/bridgeconfig/legacymigrate.go | 2 -- bridgev2/bridgeconfig/upgrade.go | 1 - bridgev2/matrix/mxmain/example-config.yaml | 2 -- 4 files changed, 6 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 9bdee5fe..13ec738c 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -99,7 +99,6 @@ type AnalyticsConfig struct { } type ProvisioningConfig struct { - Prefix string `yaml:"prefix"` SharedSecret string `yaml:"shared_secret"` DebugEndpoints bool `yaml:"debug_endpoints"` EnableSessionTransfers bool `yaml:"enable_session_transfers"` diff --git a/bridgev2/bridgeconfig/legacymigrate.go b/bridgev2/bridgeconfig/legacymigrate.go index fb2a86d6..954a37c3 100644 --- a/bridgev2/bridgeconfig/legacymigrate.go +++ b/bridgev2/bridgeconfig/legacymigrate.go @@ -133,9 +133,7 @@ func doMigrateLegacy(helper up.Helper, python bool) { CopyToOtherLocation(helper, up.Bool, []string{"bridge", "sync_direct_chat_list"}, []string{"matrix", "sync_direct_chat_list"}) CopyToOtherLocation(helper, up.Bool, []string{"bridge", "federate_rooms"}, []string{"matrix", "federate_rooms"}) - CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "prefix"}, []string{"provisioning", "prefix"}) CopyToOtherLocation(helper, up.Str, []string{"bridge", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) - CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "prefix"}, []string{"provisioning", "prefix"}) CopyToOtherLocation(helper, up.Str, []string{"appservice", "provisioning", "shared_secret"}, []string{"provisioning", "shared_secret"}) CopyToOtherLocation(helper, up.Bool, []string{"bridge", "provisioning", "debug_endpoints"}, []string{"provisioning", "debug_endpoints"}) diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index b69a1fdb..f41f77d8 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -103,7 +103,6 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Str|up.Null, "analytics", "url") helper.Copy(up.Str|up.Null, "analytics", "user_id") - helper.Copy(up.Str, "provisioning", "prefix") if secret, ok := helper.Get(up.Str, "provisioning", "shared_secret"); !ok || secret == "generate" { sharedSecret := random.String(64) helper.Set(up.Str, sharedSecret, "provisioning", "shared_secret") diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 48e0d528..5da1407d 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -247,8 +247,6 @@ analytics: # Settings for provisioning API provisioning: - # Prefix for the provisioning API paths. - prefix: /_matrix/provision # Shared secret for authentication. If set to "generate" or null, a random secret will be generated, # or if set to "disable", the provisioning API will be disabled. shared_secret: generate From 3048d2edab7a78fc575e523deac4271d257ff889 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 28 Aug 2025 02:20:41 +0300 Subject: [PATCH 1340/1647] bridgev2/provisioning: add minimum length for shared secret --- bridgev2/matrix/mxmain/example-config.yaml | 2 +- bridgev2/matrix/provisioning.go | 16 ++++++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 5da1407d..488f0b4c 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -248,7 +248,7 @@ analytics: # Settings for provisioning API provisioning: # Shared secret for authentication. If set to "generate" or null, a random secret will be generated, - # or if set to "disable", the provisioning API will be disabled. + # or if set to "disable", the provisioning API will be disabled. Must be at least 16 characters. shared_secret: generate # Whether to allow provisioning API requests to be authed using Matrix access tokens. # This follows the same rules as double puppeting to determine which server to contact to check the token, diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index df3e1bdf..2f202f4e 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -210,12 +210,20 @@ func (prov *ProvisioningAPI) checkFederatedMatrixAuth(ctx context.Context, userI } } +func disabledAuth(w http.ResponseWriter, r *http.Request) { + mautrix.MForbidden.WithMessage("Provisioning API is disabled").Write(w) +} + func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler { + secret := prov.br.Config.Provisioning.SharedSecret + if len(secret) < 16 { + return http.HandlerFunc(disabledAuth) + } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") if auth == "" { mautrix.MMissingToken.WithMessage("Missing auth token").Write(w) - } else if !exstrings.ConstantTimeEqual(auth, prov.br.Config.Provisioning.SharedSecret) { + } else if !exstrings.ConstantTimeEqual(auth, secret) { mautrix.MUnknownToken.WithMessage("Invalid auth token").Write(w) } else { h.ServeHTTP(w, r) @@ -224,6 +232,10 @@ func (prov *ProvisioningAPI) DebugAuthMiddleware(h http.Handler) http.Handler { } func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { + secret := prov.br.Config.Provisioning.SharedSecret + if len(secret) < 16 { + return http.HandlerFunc(disabledAuth) + } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") if auth == "" && prov.GetAuthFromRequest != nil { @@ -237,7 +249,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { if userID == "" && prov.GetUserIDFromRequest != nil { userID = prov.GetUserIDFromRequest(r) } - if !exstrings.ConstantTimeEqual(auth, prov.br.Config.Provisioning.SharedSecret) { + if !exstrings.ConstantTimeEqual(auth, secret) { var err error if strings.HasPrefix(auth, "openid:") { err = prov.checkFederatedMatrixAuth(r.Context(), userID, strings.TrimPrefix(auth, "openid:")) From 19f3b2179cb8e00806193a1e54b8afcd30ea7dbe Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 29 Aug 2025 11:07:16 +0300 Subject: [PATCH 1341/1647] pre-commit: ban `log.Str(x.String())` --- .pre-commit-config.yaml | 1 + appservice/appservice.go | 2 +- appservice/http.go | 2 +- bridgev2/matrix/crypto.go | 10 +++++----- client.go | 2 +- crypto/cross_sign_store.go | 6 +++--- crypto/decryptolm.go | 2 +- crypto/devicelist.go | 6 +++--- crypto/encryptmegolm.go | 4 ++-- crypto/machine.go | 4 ++-- crypto/verificationhelper/sas.go | 2 +- example/main.go | 2 +- 12 files changed, 22 insertions(+), 21 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 81701203..0b9785ae 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,3 +27,4 @@ repos: - id: prevent-literal-http-methods - id: zerolog-ban-global-log - id: zerolog-ban-msgf + - id: zerolog-use-stringer diff --git a/appservice/appservice.go b/appservice/appservice.go index b0af02cd..33b53d7d 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -360,7 +360,7 @@ func (as *AppService) NewMautrixClient(userID id.UserID) *mautrix.Client { AccessToken: as.Registration.AppToken, UserAgent: as.UserAgent, StateStore: as.StateStore, - Log: as.Log.With().Str("as_user_id", userID.String()).Logger(), + Log: as.Log.With().Stringer("as_user_id", userID).Logger(), Client: as.HTTPClient, DefaultHTTPRetries: as.DefaultHTTPRetries, SpecVersions: as.SpecVersions, diff --git a/appservice/http.go b/appservice/http.go index 862de7fd..27ce6288 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -201,7 +201,7 @@ func (as *AppService) handleEvents(ctx context.Context, evts []*event.Event, def } err := evt.Content.ParseRaw(evt.Type) if errors.Is(err, event.ErrUnsupportedContentType) { - log.Debug().Str("event_id", evt.ID.String()).Msg("Not parsing content of unsupported event") + log.Debug().Stringer("event_id", evt.ID).Msg("Not parsing content of unsupported event") } else if err != nil { log.Warn().Err(err). Str("event_id", evt.ID.String()). diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index 47226625..2325ddfa 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -157,12 +157,12 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { var evt event.EncryptionEventContent err = helper.client.StateEvent(ctx, roomID, event.StateEncryption, "", &evt) if err != nil { - log.Err(err).Str("room_id", roomID.String()).Msg("Failed to get encryption event") + log.Err(err).Stringer("room_id", roomID).Msg("Failed to get encryption event") _, err = helper.store.DB.Exec(ctx, ` UPDATE mx_room_state SET encryption=NULL WHERE room_id=$1 AND encryption='{"resync":true}' `, roomID) if err != nil { - log.Err(err).Str("room_id", roomID.String()).Msg("Failed to unmark room for resync after failed sync") + log.Err(err).Stringer("room_id", roomID).Msg("Failed to unmark room for resync after failed sync") } } else { maxAge := evt.RotationPeriodMillis @@ -185,9 +185,9 @@ func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { WHERE room_id=$3 AND max_age IS NULL AND max_messages IS NULL `, maxAge, maxMessages, roomID) if err != nil { - log.Err(err).Str("room_id", roomID.String()).Msg("Failed to update megolm session table") + log.Err(err).Stringer("room_id", roomID).Msg("Failed to update megolm session table") } else { - log.Debug().Str("room_id", roomID.String()).Msg("Updated megolm session table") + log.Debug().Stringer("room_id", roomID).Msg("Updated megolm session table") } } } @@ -233,7 +233,7 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool if err != nil { return nil, false, fmt.Errorf("failed to find existing device ID: %w", err) } else if len(deviceID) > 0 { - helper.log.Debug().Str("device_id", deviceID.String()).Msg("Found existing device ID for bot in database") + helper.log.Debug().Stringer("device_id", deviceID).Msg("Found existing device ID for bot in database") } // Create a new client instance with the default AS settings (including as_token), // the Login call will then override the access token in the client. diff --git a/client.go b/client.go index 78f83b85..45230c1e 100644 --- a/client.go +++ b/client.go @@ -1785,7 +1785,7 @@ func (cli *Client) UploadAsync(ctx context.Context, req ReqUploadMedia) (*RespCr go func() { _, err = cli.UploadMedia(ctx, req) if err != nil { - cli.Log.Error().Str("mxc", req.MXC.String()).Err(err).Msg("Async upload of media failed") + cli.Log.Error().Stringer("mxc", req.MXC).Err(err).Msg("Async upload of media failed") } }() return resp, nil diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go index b583bada..d30b7e32 100644 --- a/crypto/cross_sign_store.go +++ b/crypto/cross_sign_store.go @@ -20,7 +20,7 @@ import ( func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningKeys map[id.UserID]mautrix.CrossSigningKeys, deviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys) { log := mach.machOrContextLog(ctx) for userID, userKeys := range crossSigningKeys { - log := log.With().Str("user_id", userID.String()).Logger() + log := log.With().Stringer("user_id", userID).Logger() currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID) if err != nil { log.Error().Err(err). @@ -28,7 +28,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK } if currentKeys != nil { for curKeyUsage, curKey := range currentKeys { - log := log.With().Str("old_key", curKey.Key.String()).Str("old_key_usage", string(curKeyUsage)).Logger() + log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger() // got a new key with the same usage as an existing key for _, newKeyUsage := range userKeys.Usage { if newKeyUsage == curKeyUsage { @@ -49,7 +49,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK } for _, key := range userKeys.Keys { - log := log.With().Str("key", key.String()).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger() + log := log.With().Stringer("key", key).Array("usages", exzerolog.ArrayOfStrs(userKeys.Usage)).Logger() for _, usage := range userKeys.Usage { log.Trace().Str("usage", string(usage)).Msg("Storing cross-signing key") if err = mach.CryptoStore.PutCrossSigningKey(ctx, userID, usage, key); err != nil { diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index b737e4e1..b961a7b4 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -340,7 +340,7 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send return } - log.Debug().Str("device_id", deviceIdentity.DeviceID.String()).Msg("Creating new Olm session") + log.Debug().Stringer("device_id", deviceIdentity.DeviceID).Msg("Creating new Olm session") mach.devicesToUnwedgeLock.Lock() mach.devicesToUnwedge[senderKey] = true mach.devicesToUnwedgeLock.Unlock() diff --git a/crypto/devicelist.go b/crypto/devicelist.go index a2116ed5..61a22522 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -206,7 +206,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ log.Trace().Int("user_count", len(resp.DeviceKeys)).Msg("Query key result received") data = make(map[id.UserID]map[id.DeviceID]*id.Device) for userID, devices := range resp.DeviceKeys { - log := log.With().Str("user_id", userID.String()).Logger() + log := log.With().Stringer("user_id", userID).Logger() delete(req.DeviceKeys, userID) newDevices := make(map[id.DeviceID]*id.Device) @@ -222,7 +222,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ Msg("Updating devices in store") changed := false for deviceID, deviceKeys := range devices { - log := log.With().Str("device_id", deviceID.String()).Logger() + log := log.With().Stringer("device_id", deviceID).Logger() existing, ok := existingDevices[deviceID] if !ok { // New device @@ -270,7 +270,7 @@ func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includ } } for userID := range req.DeviceKeys { - log.Warn().Str("user_id", userID.String()).Msg("Didn't get any keys for user") + log.Warn().Stringer("user_id", userID).Msg("Didn't get any keys for user") } mach.storeCrossSigningKeys(ctx, resp.MasterKeys, resp.DeviceKeys) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index cd211af5..b3d19618 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -233,7 +233,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, var fetchKeysForUsers []id.UserID for _, userID := range users { - log := log.With().Str("target_user_id", userID.String()).Logger() + log := log.With().Stringer("target_user_id", userID).Logger() devices, err := mach.CryptoStore.GetDevices(ctx, userID) if err != nil { log.Err(err).Msg("Failed to get devices of user") @@ -305,7 +305,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, toDeviceWithheld.Messages[userID] = withheld } - log := log.With().Str("target_user_id", userID.String()).Logger() + log := log.With().Stringer("target_user_id", userID).Logger() log.Trace().Msg("Trying to find olm session to encrypt megolm session for user (post-fetch retry)") mach.findOlmSessionsForUser(ctx, session, userID, devices, output, withheld, nil) log.Debug(). diff --git a/crypto/machine.go b/crypto/machine.go index cac91bf8..e791e70d 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -361,7 +361,7 @@ func (mach *OlmMachine) HandleMemberEvent(ctx context.Context, evt *event.Event) Msg("Got membership state change, invalidating group session in room") err := mach.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID) if err != nil { - mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session") + mach.Log.Warn().Stringer("room_id", evt.RoomID).Msg("Failed to invalidate outbound group session") } } @@ -581,7 +581,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen } err = mach.CryptoStore.PutGroupSession(ctx, igs) if err != nil { - log.Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session") + log.Err(err).Stringer("session_id", sessionID).Msg("Failed to store new inbound group session") return fmt.Errorf("failed to store new inbound group session: %w", err) } mach.MarkSessionReceived(ctx, roomID, sessionID, igs.Internal.FirstKnownIndex()) diff --git a/crypto/verificationhelper/sas.go b/crypto/verificationhelper/sas.go index 1313a613..e6392c79 100644 --- a/crypto/verificationhelper/sas.go +++ b/crypto/verificationhelper/sas.go @@ -695,7 +695,7 @@ func (vh *VerificationHelper) onVerificationMAC(ctx context.Context, txn Verific // Verify the MAC for each key var theirDevice *id.Device for keyID, mac := range macEvt.MAC { - log.Info().Str("key_id", keyID.String()).Msg("Received MAC for key") + log.Info().Stringer("key_id", keyID).Msg("Received MAC for key") alg, kID := keyID.Parse() if alg != id.KeyAlgorithmEd25519 { diff --git a/example/main.go b/example/main.go index d8006d46..2bf4bef3 100644 --- a/example/main.go +++ b/example/main.go @@ -143,7 +143,7 @@ func main() { if err != nil { log.Error().Err(err).Msg("Failed to send event") } else { - log.Info().Str("event_id", resp.EventID.String()).Msg("Event sent") + log.Info().Stringer("event_id", resp.EventID).Msg("Event sent") } } cancelSync() From c18d2e2565c89512f9dc49f786bbb131636f789e Mon Sep 17 00:00:00 2001 From: Ping Chen Date: Fri, 29 Aug 2025 17:20:11 +0900 Subject: [PATCH 1342/1647] bridgev2/matrixinterface: add GetEvent interface for linkedin reply (#406) Co-authored-by: Tulir Asokan --- bridgev2/matrix/intent.go | 20 ++++++++++++++++++++ bridgev2/matrixinterface.go | 2 ++ 2 files changed, 22 insertions(+) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 7d78b5a2..2c68a692 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -674,3 +674,23 @@ func (as *ASIntent) MuteRoom(ctx context.Context, roomID id.RoomID, until time.T }) } } + +func (as *ASIntent) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error) { + evt, err := as.Matrix.Client.GetEvent(ctx, roomID, eventID) + if err != nil { + return nil, err + } + err = evt.Content.ParseRaw(evt.Type) + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("room_id", roomID).Stringer("event_id", eventID).Msg("failed to parse event content") + } + + if evt.Type == event.EventEncrypted { + if as.Connector.Config.Encryption.DeleteKeys.RatchetOnDecrypt { + return nil, errors.New("can't decrypt the event") + } + return as.Matrix.Crypto.Decrypt(ctx, evt) + } + + return evt, nil +} diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index b30e274a..6fa5360c 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -176,6 +176,8 @@ type MatrixAPI interface { TagRoom(ctx context.Context, roomID id.RoomID, tag event.RoomTag, isTagged bool) error MuteRoom(ctx context.Context, roomID id.RoomID, until time.Time) error + + GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (*event.Event, error) } type StreamOrderReadingMatrixAPI interface { From 8f464b5b76efeb9dea12fd460a71c38b09b5b1c3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 29 Aug 2025 16:33:59 +0300 Subject: [PATCH 1343/1647] bridgev2: move shared SNC code to provisionutil --- bridgev2/commands/startchat.go | 107 +++++------------ bridgev2/matrix/provisioning.go | 122 ++------------------ bridgev2/provisionutil/listcontacts.go | 95 +++++++++++++++ bridgev2/provisionutil/resolveidentifier.go | 85 ++++++++++++++ 4 files changed, 218 insertions(+), 191 deletions(-) create mode 100644 bridgev2/provisionutil/listcontacts.go create mode 100644 bridgev2/provisionutil/resolveidentifier.go diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index 719d3dd5..da246f50 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Tulir Asokan +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -7,15 +7,13 @@ package commands import ( - "context" "fmt" "html" "strings" - "time" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" - "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/bridgev2/provisionutil" ) var CommandResolveIdentifier = &FullHandler{ @@ -57,24 +55,13 @@ func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Even return login, api, remainingArgs } -func formatResolveIdentifierResult(ctx context.Context, resp *bridgev2.ResolveIdentifierResponse) string { - var targetName string - var targetMXID id.UserID - if resp.Ghost != nil { - if resp.UserInfo != nil { - resp.Ghost.UpdateInfo(ctx, resp.UserInfo) - } - targetName = resp.Ghost.Name - targetMXID = resp.Ghost.Intent.GetMXID() - } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { - targetName = *resp.UserInfo.Name - } - if targetMXID != "" { - return fmt.Sprintf("`%s` / [%s](%s)", resp.UserID, targetName, targetMXID.URI().MatrixToURL()) - } else if targetName != "" { - return fmt.Sprintf("`%s` / %s", resp.UserID, targetName) +func formatResolveIdentifierResult(resp *provisionutil.RespResolveIdentifier) string { + if resp.MXID != "" { + return fmt.Sprintf("`%s` / [%s](%s)", resp.ID, resp.Name, resp.MXID.URI().MatrixToURL()) + } else if resp.Name != "" { + return fmt.Sprintf("`%s` / %s", resp.ID, resp.Name) } else { - return fmt.Sprintf("`%s`", resp.UserID) + return fmt.Sprintf("`%s`", resp.ID) } } @@ -89,57 +76,24 @@ func fnResolveIdentifier(ce *Event) { } createChat := ce.Command == "start-chat" || ce.Command == "pm" identifier := strings.Join(identifierParts, " ") - resp, err := api.ResolveIdentifier(ce.Ctx, identifier, createChat) + resp, err := provisionutil.ResolveIdentifier(ce.Ctx, login, identifier, createChat) if err != nil { - ce.Log.Err(err).Msg("Failed to resolve identifier") ce.Reply("Failed to resolve identifier: %v", err) return } else if resp == nil { ce.ReplyAdvanced(fmt.Sprintf("Identifier %s not found", html.EscapeString(identifier)), false, true) return } - formattedName := formatResolveIdentifierResult(ce.Ctx, resp) + formattedName := formatResolveIdentifierResult(resp) if createChat { - if resp.Chat == nil { - ce.Reply("Interface error: network connector did not return chat for create chat request") - return + name := resp.Portal.Name + if name == "" { + name = resp.Portal.MXID.String() } - portal := resp.Chat.Portal - if portal == nil { - portal, err = ce.Bridge.GetPortalByKey(ce.Ctx, resp.Chat.PortalKey) - if err != nil { - ce.Log.Err(err).Msg("Failed to get portal") - ce.Reply("Failed to get portal: %v", err) - return - } - } - if resp.Chat.PortalInfo == nil { - resp.Chat.PortalInfo, err = api.GetChatInfo(ce.Ctx, portal) - if err != nil { - ce.Log.Err(err).Msg("Failed to get portal info") - ce.Reply("Failed to get portal info: %v", err) - return - } - } - if portal.MXID != "" { - name := portal.Name - if name == "" { - name = portal.MXID.String() - } - portal.UpdateInfo(ce.Ctx, resp.Chat.PortalInfo, login, nil, time.Time{}) - ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL()) + if !resp.JustCreated { + ce.Reply("You already have a direct chat with %s at [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL()) } else { - err = portal.CreateMatrixRoom(ce.Ctx, login, resp.Chat.PortalInfo) - if err != nil { - ce.Log.Err(err).Msg("Failed to create room") - ce.Reply("Failed to create room: %v", err) - return - } - name := portal.Name - if name == "" { - name = portal.MXID.String() - } - ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, portal.MXID.URI().MatrixToURL()) + ce.Reply("Created chat with %s: [%s](%s)", formattedName, name, resp.Portal.MXID.URI().MatrixToURL()) } } else { ce.Reply("Found %s", formattedName) @@ -163,34 +117,25 @@ func fnSearch(ce *Event) { ce.Reply("Usage: `$cmdprefix search `") return } - _, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users") + login, api, queryParts := getClientForStartingChat[bridgev2.UserSearchingNetworkAPI](ce, "searching users") if api == nil { return } - results, err := api.SearchUsers(ce.Ctx, strings.Join(queryParts, " ")) + resp, err := provisionutil.SearchUsers(ce.Ctx, login, strings.Join(queryParts, " ")) if err != nil { - ce.Log.Err(err).Msg("Failed to search for users") ce.Reply("Failed to search for users: %v", err) return } - resultsString := make([]string, len(results)) - for i, res := range results { - formattedName := formatResolveIdentifierResult(ce.Ctx, res) + resultsString := make([]string, len(resp.Results)) + for i, res := range resp.Results { + formattedName := formatResolveIdentifierResult(res) resultsString[i] = fmt.Sprintf("* %s", formattedName) - if res.Chat != nil { - if res.Chat.Portal == nil { - res.Chat.Portal, err = ce.Bridge.GetExistingPortalByKey(ce.Ctx, res.Chat.PortalKey) - if err != nil { - ce.Log.Err(err).Object("portal_key", res.Chat.PortalKey).Msg("Failed to get DM portal") - } - } - if res.Chat.Portal != nil && res.Chat.Portal.MXID != "" { - portalName := res.Chat.Portal.Name - if portalName == "" { - portalName = res.Chat.Portal.MXID.String() - } - resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Chat.Portal.MXID.URI().MatrixToURL()) + if res.Portal != nil && res.Portal.MXID != "" { + portalName := res.Portal.Name + if portalName == "" { + portalName = res.Portal.MXID.String() } + resultsString[i] = fmt.Sprintf("%s - DM portal: [%s](%s)", resultsString[i], portalName, res.Portal.MXID.URI().MatrixToURL()) } } ce.Reply("Search results:\n\n%s", strings.Join(resultsString, "\n")) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 2f202f4e..02ad6abd 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -30,6 +30,7 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/bridgev2/provisionutil" "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/federation" "maunium.net/go/mautrix/id" @@ -608,101 +609,18 @@ func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http. if login == nil { return } - api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI) - if !ok { - mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers").Write(w) - return - } - resp, err := api.ResolveIdentifier(r.Context(), r.PathValue("identifier"), createChat) + resp, err := provisionutil.ResolveIdentifier(r.Context(), login, r.PathValue("identifier"), createChat) if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to resolve identifier") RespondWithError(w, err, "Internal error resolving identifier") - return } else if resp == nil { mautrix.MNotFound.WithMessage("Identifier not found").Write(w) - return - } - apiResp := &RespResolveIdentifier{ - ID: resp.UserID, - } - status := http.StatusOK - if resp.Ghost != nil { - if resp.UserInfo != nil { - resp.Ghost.UpdateInfo(r.Context(), resp.UserInfo) - } - apiResp.Name = resp.Ghost.Name - apiResp.AvatarURL = resp.Ghost.AvatarMXC - apiResp.Identifiers = resp.Ghost.Identifiers - apiResp.MXID = resp.Ghost.Intent.GetMXID() - } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { - apiResp.Name = *resp.UserInfo.Name - } - if resp.Chat != nil { - if resp.Chat.Portal == nil { - resp.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(r.Context(), resp.Chat.PortalKey) - if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get portal") - mautrix.MUnknown.WithMessage("Failed to get portal").Write(w) - return - } - } - if createChat && resp.Chat.Portal.MXID == "" { + } else { + status := http.StatusOK + if resp.JustCreated { status = http.StatusCreated - err = resp.Chat.Portal.CreateMatrixRoom(r.Context(), login, resp.Chat.PortalInfo) - if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create portal room") - mautrix.MUnknown.WithMessage("Failed to create portal room").Write(w) - return - } } - apiResp.DMRoomID = resp.Chat.Portal.MXID + exhttp.WriteJSONResponse(w, status, resp) } - exhttp.WriteJSONResponse(w, status, apiResp) -} - -type RespGetContactList struct { - Contacts []*RespResolveIdentifier `json:"contacts"` -} - -func (prov *ProvisioningAPI) processResolveIdentifiers(ctx context.Context, resp []*bridgev2.ResolveIdentifierResponse) (apiResp []*RespResolveIdentifier) { - apiResp = make([]*RespResolveIdentifier, len(resp)) - for i, contact := range resp { - apiContact := &RespResolveIdentifier{ - ID: contact.UserID, - } - apiResp[i] = apiContact - if contact.UserInfo != nil { - if contact.UserInfo.Name != nil { - apiContact.Name = *contact.UserInfo.Name - } - if contact.UserInfo.Identifiers != nil { - apiContact.Identifiers = contact.UserInfo.Identifiers - } - } - if contact.Ghost != nil { - if contact.Ghost.Name != "" { - apiContact.Name = contact.Ghost.Name - } - if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) { - apiContact.Identifiers = contact.Ghost.Identifiers - } - apiContact.AvatarURL = contact.Ghost.AvatarMXC - apiContact.MXID = contact.Ghost.Intent.GetMXID() - } - if contact.Chat != nil { - if contact.Chat.Portal == nil { - var err error - contact.Chat.Portal, err = prov.br.Bridge.GetPortalByKey(ctx, contact.Chat.PortalKey) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") - } - } - if contact.Chat.Portal != nil { - apiContact.DMRoomID = contact.Chat.Portal.MXID - } - } - } - return } func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Request) { @@ -710,20 +628,12 @@ func (prov *ProvisioningAPI) GetContactList(w http.ResponseWriter, r *http.Reque if login == nil { return } - api, ok := login.Client.(bridgev2.ContactListingNetworkAPI) - if !ok { - mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts").Write(w) - return - } - resp, err := api.GetContactList(r.Context()) + resp, err := provisionutil.GetContactList(r.Context(), login) if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list") - RespondWithError(w, err, "Internal error fetching contact list") + RespondWithError(w, err, "Internal error getting contact list") return } - exhttp.WriteJSONResponse(w, http.StatusOK, &RespGetContactList{ - Contacts: prov.processResolveIdentifiers(r.Context(), resp), - }) + exhttp.WriteJSONResponse(w, http.StatusOK, resp) } type ReqSearchUsers struct { @@ -746,20 +656,12 @@ func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Requ if login == nil { return } - api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI) - if !ok { - mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users").Write(w) - return - } - resp, err := api.SearchUsers(r.Context(), req.Query) + resp, err := provisionutil.SearchUsers(r.Context(), login, req.Query) if err != nil { - zerolog.Ctx(r.Context()).Err(err).Msg("Failed to get contact list") - RespondWithError(w, err, "Internal error fetching contact list") + RespondWithError(w, err, "Internal error searching users") return } - exhttp.WriteJSONResponse(w, http.StatusOK, &RespSearchUsers{ - Results: prov.processResolveIdentifiers(r.Context(), resp), - }) + exhttp.WriteJSONResponse(w, http.StatusOK, resp) } func (prov *ProvisioningAPI) GetResolveIdentifier(w http.ResponseWriter, r *http.Request) { diff --git a/bridgev2/provisionutil/listcontacts.go b/bridgev2/provisionutil/listcontacts.go new file mode 100644 index 00000000..d2cf5e90 --- /dev/null +++ b/bridgev2/provisionutil/listcontacts.go @@ -0,0 +1,95 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package provisionutil + +import ( + "context" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" +) + +type RespGetContactList struct { + Contacts []*RespResolveIdentifier `json:"contacts"` +} + +type RespSearchUsers struct { + Results []*RespResolveIdentifier `json:"results"` +} + +func GetContactList(ctx context.Context, login *bridgev2.UserLogin) (*RespGetContactList, error) { + api, ok := login.Client.(bridgev2.ContactListingNetworkAPI) + if !ok { + return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support listing contacts")) + } + resp, err := api.GetContactList(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list") + return nil, err + } + return &RespGetContactList{ + Contacts: processResolveIdentifiers(ctx, login.Bridge, resp), + }, nil +} + +func SearchUsers(ctx context.Context, login *bridgev2.UserLogin, query string) (*RespSearchUsers, error) { + api, ok := login.Client.(bridgev2.UserSearchingNetworkAPI) + if !ok { + return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support searching for users")) + } + resp, err := api.SearchUsers(ctx, query) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get contact list") + return nil, err + } + return &RespSearchUsers{ + Results: processResolveIdentifiers(ctx, login.Bridge, resp), + }, nil +} + +func processResolveIdentifiers(ctx context.Context, br *bridgev2.Bridge, resp []*bridgev2.ResolveIdentifierResponse) (apiResp []*RespResolveIdentifier) { + apiResp = make([]*RespResolveIdentifier, len(resp)) + for i, contact := range resp { + apiContact := &RespResolveIdentifier{ + ID: contact.UserID, + } + apiResp[i] = apiContact + if contact.UserInfo != nil { + if contact.UserInfo.Name != nil { + apiContact.Name = *contact.UserInfo.Name + } + if contact.UserInfo.Identifiers != nil { + apiContact.Identifiers = contact.UserInfo.Identifiers + } + } + if contact.Ghost != nil { + if contact.Ghost.Name != "" { + apiContact.Name = contact.Ghost.Name + } + if len(contact.Ghost.Identifiers) >= len(apiContact.Identifiers) { + apiContact.Identifiers = contact.Ghost.Identifiers + } + apiContact.AvatarURL = contact.Ghost.AvatarMXC + apiContact.MXID = contact.Ghost.Intent.GetMXID() + } + if contact.Chat != nil { + if contact.Chat.Portal == nil { + var err error + contact.Chat.Portal, err = br.GetPortalByKey(ctx, contact.Chat.PortalKey) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") + } + } + if contact.Chat.Portal != nil { + apiContact.DMRoomID = contact.Chat.Portal.MXID + } + } + } + return +} diff --git a/bridgev2/provisionutil/resolveidentifier.go b/bridgev2/provisionutil/resolveidentifier.go new file mode 100644 index 00000000..23813620 --- /dev/null +++ b/bridgev2/provisionutil/resolveidentifier.go @@ -0,0 +1,85 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package provisionutil + +import ( + "context" + + "github.com/rs/zerolog" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type RespResolveIdentifier struct { + ID networkid.UserID `json:"id"` + Name string `json:"name,omitempty"` + AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` + Identifiers []string `json:"identifiers,omitempty"` + MXID id.UserID `json:"mxid,omitempty"` + DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"` + + Portal *bridgev2.Portal `json:"-"` + Ghost *bridgev2.Ghost `json:"-"` + JustCreated bool `json:"-"` +} + +func ResolveIdentifier( + ctx context.Context, + login *bridgev2.UserLogin, + identifier string, + createChat bool, +) (*RespResolveIdentifier, error) { + api, ok := login.Client.(bridgev2.IdentifierResolvingNetworkAPI) + if !ok { + return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers")) + } + resp, err := api.ResolveIdentifier(ctx, identifier, createChat) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to resolve identifier") + return nil, err + } else if resp == nil { + return nil, nil + } + apiResp := &RespResolveIdentifier{ + ID: resp.UserID, + Ghost: resp.Ghost, + } + if resp.Ghost != nil { + if resp.UserInfo != nil { + resp.Ghost.UpdateInfo(ctx, resp.UserInfo) + } + apiResp.Name = resp.Ghost.Name + apiResp.AvatarURL = resp.Ghost.AvatarMXC + apiResp.Identifiers = resp.Ghost.Identifiers + apiResp.MXID = resp.Ghost.Intent.GetMXID() + } else if resp.UserInfo != nil && resp.UserInfo.Name != nil { + apiResp.Name = *resp.UserInfo.Name + } + if resp.Chat != nil { + if resp.Chat.Portal == nil { + resp.Chat.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.Chat.PortalKey) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") + return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal")) + } + } + if createChat && resp.Chat.Portal.MXID == "" { + apiResp.JustCreated = true + err = resp.Chat.Portal.CreateMatrixRoom(ctx, login, resp.Chat.PortalInfo) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room") + return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room")) + } + } + apiResp.Portal = resp.Chat.Portal + apiResp.DMRoomID = resp.Chat.Portal.MXID + } + return apiResp, nil +} From f9e3e8a30f2fd2794f015fd775ae5cc4845d8b0c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 29 Aug 2025 18:30:47 +0300 Subject: [PATCH 1344/1647] bridgev2/provisionutil: allow passing mxids to ResolveIdentifier Closes #398 --- bridgev2/provisionutil/resolveidentifier.go | 53 ++++++++++++++++++--- 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/bridgev2/provisionutil/resolveidentifier.go b/bridgev2/provisionutil/resolveidentifier.go index 23813620..5387347c 100644 --- a/bridgev2/provisionutil/resolveidentifier.go +++ b/bridgev2/provisionutil/resolveidentifier.go @@ -8,6 +8,7 @@ package provisionutil import ( "context" + "errors" "github.com/rs/zerolog" @@ -30,6 +31,8 @@ type RespResolveIdentifier struct { JustCreated bool `json:"-"` } +var ErrNoPortalKey = errors.New("network API didn't return portal key for createChat request") + func ResolveIdentifier( ctx context.Context, login *bridgev2.UserLogin, @@ -40,12 +43,44 @@ func ResolveIdentifier( if !ok { return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support resolving identifiers")) } - resp, err := api.ResolveIdentifier(ctx, identifier, createChat) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to resolve identifier") - return nil, err - } else if resp == nil { - return nil, nil + var resp *bridgev2.ResolveIdentifierResponse + parsedUserID, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(identifier)) + validator, vOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork) + if ok && (!vOK || validator.ValidateUserID(parsedUserID)) { + ghost, err := login.Bridge.GetGhostByID(ctx, parsedUserID) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost by ID") + return nil, err + } + resp = &bridgev2.ResolveIdentifierResponse{ + Ghost: ghost, + UserID: parsedUserID, + } + gdcAPI, ok := api.(bridgev2.GhostDMCreatingNetworkAPI) + if ok && createChat { + resp.Chat, err = gdcAPI.CreateChatWithGhost(ctx, ghost) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create chat") + return nil, err + } + } else if createChat || ghost.Name == "" { + zerolog.Ctx(ctx).Debug(). + Bool("create_chat", createChat). + Bool("has_name", ghost.Name != ""). + Msg("Falling back to resolving identifier") + resp = nil + identifier = string(parsedUserID) + } + } + if resp == nil { + var err error + resp, err = api.ResolveIdentifier(ctx, identifier, createChat) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to resolve identifier") + return nil, err + } else if resp == nil { + return nil, nil + } } apiResp := &RespResolveIdentifier{ ID: resp.UserID, @@ -63,7 +98,11 @@ func ResolveIdentifier( apiResp.Name = *resp.UserInfo.Name } if resp.Chat != nil { + if resp.Chat.PortalKey.IsEmpty() { + return nil, ErrNoPortalKey + } if resp.Chat.Portal == nil { + var err error resp.Chat.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.Chat.PortalKey) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") @@ -72,7 +111,7 @@ func ResolveIdentifier( } if createChat && resp.Chat.Portal.MXID == "" { apiResp.JustCreated = true - err = resp.Chat.Portal.CreateMatrixRoom(ctx, login, resp.Chat.PortalInfo) + err := resp.Chat.Portal.CreateMatrixRoom(ctx, login, resp.Chat.PortalInfo) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room") return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room")) From 050fbbd466a8791f16b3a98db433dbb168af996f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 29 Aug 2025 18:31:09 +0300 Subject: [PATCH 1345/1647] bridgev2/status: change RemoteID to a UserLoginID --- bridgev2/status/bridgestate.go | 9 +++++---- bridgev2/userlogin.go | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/bridgev2/status/bridgestate.go b/bridgev2/status/bridgestate.go index 3bc5a59b..430d4c7c 100644 --- a/bridgev2/status/bridgestate.go +++ b/bridgev2/status/bridgestate.go @@ -22,6 +22,7 @@ import ( "go.mau.fi/util/ptr" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -126,10 +127,10 @@ type BridgeState struct { UserAction BridgeStateUserAction `json:"user_action,omitempty"` - UserID id.UserID `json:"user_id,omitempty"` - RemoteID string `json:"remote_id,omitempty"` - RemoteName string `json:"remote_name,omitempty"` - RemoteProfile *RemoteProfile `json:"remote_profile,omitempty"` + UserID id.UserID `json:"user_id,omitempty"` + RemoteID networkid.UserLoginID `json:"remote_id,omitempty"` + RemoteName string `json:"remote_name,omitempty"` + RemoteProfile *RemoteProfile `json:"remote_profile,omitempty"` Reason string `json:"reason,omitempty"` Info map[string]interface{} `json:"info,omitempty"` diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 203dc122..b5fcfcd0 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -501,7 +501,7 @@ var _ status.BridgeStateFiller = (*UserLogin)(nil) func (ul *UserLogin) FillBridgeState(state status.BridgeState) status.BridgeState { state.UserID = ul.UserMXID - state.RemoteID = string(ul.ID) + state.RemoteID = ul.ID state.RemoteName = ul.RemoteName state.RemoteProfile = &ul.RemoteProfile filler, ok := ul.Client.(status.BridgeStateFiller) From 1d6bea5fe3ff73b79192f8eb2d40a7d859fb995d Mon Sep 17 00:00:00 2001 From: fmseals <115927730+fmseals@users.noreply.github.com> Date: Fri, 29 Aug 2025 16:34:06 +0000 Subject: [PATCH 1346/1647] client: fix v3/delete_devices method (#393) --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index 45230c1e..43fc3783 100644 --- a/client.go +++ b/client.go @@ -2471,7 +2471,7 @@ func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req * func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices) error { urlPath := cli.BuildClientURL("v3", "delete_devices") - _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil) + _, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, req, nil) return err } From cd927c27963f740ae92594d773221a7a6308afd7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 30 Aug 2025 00:06:20 +0300 Subject: [PATCH 1347/1647] event: add types for MSC4332 --- event/botcommand.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ event/content.go | 1 + event/message.go | 2 ++ event/state.go | 25 ++++++++++++++++++++++++- event/type.go | 4 +++- 5 files changed, 75 insertions(+), 2 deletions(-) create mode 100644 event/botcommand.go diff --git a/event/botcommand.go b/event/botcommand.go new file mode 100644 index 00000000..a052ebd4 --- /dev/null +++ b/event/botcommand.go @@ -0,0 +1,45 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package event + +type BotCommandsEventContent struct { + Sigil string `json:"sigil,omitempty"` + Commands []*BotCommand `json:"commands,omitempty"` +} + +type BotCommand struct { + Syntax string `json:"syntax"` + Aliases []string `json:"fi.mau.aliases,omitempty"` // Not in MSC (yet) + Arguments []*BotCommandArgument `json:"arguments,omitempty"` + Description *ExtensibleTextContainer `json:"description,omitempty"` +} + +type BotArgumentType string + +const ( + BotArgumentTypeString BotArgumentType = "string" + BotArgumentTypeEnum BotArgumentType = "enum" + BotArgumentTypeInteger BotArgumentType = "integer" + BotArgumentTypeBoolean BotArgumentType = "boolean" + BotArgumentTypeUserID BotArgumentType = "user_id" + BotArgumentTypeRoomID BotArgumentType = "room_id" + BotArgumentTypeRoomAlias BotArgumentType = "room_alias" + BotArgumentTypeEventID BotArgumentType = "event_id" +) + +type BotCommandArgument struct { + Type BotArgumentType `json:"type"` + DefaultValue any `json:"fi.mau.default_value,omitempty"` // Not in MSC (yet) + Description *ExtensibleTextContainer `json:"description,omitempty"` + Enum []string `json:"enum,omitempty"` + Variadic bool `json:"variadic,omitempty"` +} + +type BotCommandInput struct { + Syntax string `json:"syntax"` + Arguments map[string]any `json:"arguments,omitempty"` +} diff --git a/event/content.go b/event/content.go index 5924ffe3..5e093273 100644 --- a/event/content.go +++ b/event/content.go @@ -50,6 +50,7 @@ var TypeMap = map[Type]reflect.Type{ StateElementFunctionalMembers: reflect.TypeOf(ElementFunctionalMembersContent{}), StateBeeperRoomFeatures: reflect.TypeOf(RoomFeatures{}), StateBeeperDisappearingTimer: reflect.TypeOf(BeeperDisappearingTimer{}), + StateBotCommands: reflect.TypeOf(BotCommandsEventContent{}), EventMessage: reflect.TypeOf(MessageEventContent{}), EventSticker: reflect.TypeOf(MessageEventContent{}), diff --git a/event/message.go b/event/message.go index cc7c8261..b397623f 100644 --- a/event/message.go +++ b/event/message.go @@ -142,6 +142,8 @@ type MessageEventContent struct { MSC1767Audio *MSC1767Audio `json:"org.matrix.msc1767.audio,omitempty"` MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"` + + MSC4332BotCommand *BotCommandInput `json:"org.matrix.msc4332.command,omitempty"` } func (content *MessageEventContent) GetCapMsgType() CapabilityMsgType { diff --git a/event/state.go b/event/state.go index 8711f857..ba7c608d 100644 --- a/event/state.go +++ b/event/state.go @@ -56,10 +56,33 @@ type TopicEventContent struct { // m.room.topic state event as described in [MSC3765]. // // [MSC3765]: https://github.com/matrix-org/matrix-spec-proposals/pull/3765 -type ExtensibleTopic struct { +type ExtensibleTopic = ExtensibleTextContainer + +type ExtensibleTextContainer struct { Text []ExtensibleText `json:"m.text"` } +func MakeExtensibleText(text string) *ExtensibleTextContainer { + return &ExtensibleTextContainer{ + Text: []ExtensibleText{{ + Body: text, + MimeType: "text/plain", + }}, + } +} + +func MakeExtensibleFormattedText(plaintext, html string) *ExtensibleTextContainer { + return &ExtensibleTextContainer{ + Text: []ExtensibleText{{ + Body: plaintext, + MimeType: "text/plain", + }, { + Body: html, + MimeType: "text/html", + }}, + } +} + // ExtensibleText represents the contents of an m.text field. type ExtensibleText struct { MimeType string `json:"mimetype,omitempty"` diff --git a/event/type.go b/event/type.go index 1ab8c517..3f01a067 100644 --- a/event/type.go +++ b/event/type.go @@ -112,7 +112,8 @@ func (et *Type) GuessClass() TypeClass { StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type, StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type, StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type, - StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type: + StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type, + StateBotCommands.Type: return StateEventType case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type: return EphemeralEventType @@ -204,6 +205,7 @@ var ( StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType} StateBeeperRoomFeatures = Type{"com.beeper.room_features", StateEventType} StateBeeperDisappearingTimer = Type{"com.beeper.disappearing_timer", StateEventType} + StateBotCommands = Type{"org.matrix.msc4332.commands", StateEventType} ) // Message events From 61a90da14542ac4089cc4f0b7d1f79c48b5b46ec Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 1 Sep 2025 00:45:32 +0300 Subject: [PATCH 1348/1647] event: use RawMessage instead of map for bot command arguments --- event/botcommand.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/event/botcommand.go b/event/botcommand.go index a052ebd4..2b208656 100644 --- a/event/botcommand.go +++ b/event/botcommand.go @@ -6,6 +6,10 @@ package event +import ( + "encoding/json" +) + type BotCommandsEventContent struct { Sigil string `json:"sigil,omitempty"` Commands []*BotCommand `json:"commands,omitempty"` @@ -40,6 +44,6 @@ type BotCommandArgument struct { } type BotCommandInput struct { - Syntax string `json:"syntax"` - Arguments map[string]any `json:"arguments,omitempty"` + Syntax string `json:"syntax"` + Arguments json.RawMessage `json:"arguments,omitempty"` } From 0627c4227057baeec5040cd0024da179e2b7f982 Mon Sep 17 00:00:00 2001 From: "timedout (aka nexy7574)" Date: Mon, 1 Sep 2025 16:01:05 +0100 Subject: [PATCH 1349/1647] client: implement MSC4323 (#407) --- client.go | 28 ++++++++++++++++++++++++++++ requests.go | 10 ++++++++++ responses.go | 28 ++++++++++++++++++++++------ versions.go | 1 + 4 files changed, 61 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index 43fc3783..85b27923 100644 --- a/client.go +++ b/client.go @@ -2562,6 +2562,34 @@ func (cli *Client) ReportRoom(ctx context.Context, roomID id.RoomID, reason stri return err } +// UnstableGetSuspendedStatus uses MSC4323 to check if a user is suspended. +func (cli *Client) UnstableGetSuspendedStatus(ctx context.Context, userID id.UserID) (res *RespSuspended, err error) { + urlPath := cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", "suspend", userID) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res) + return +} + +// UnstableGetLockStatus uses MSC4323 to check if a user is locked. +func (cli *Client) UnstableGetLockStatus(ctx context.Context, userID id.UserID) (res *RespLocked, err error) { + urlPath := cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", "lock", userID) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res) + return +} + +// UnstableSetSuspendedStatus uses MSC4323 to set whether a user account is suspended. +func (cli *Client) UnstableSetSuspendedStatus(ctx context.Context, userID id.UserID, suspended bool) (res *RespSuspended, err error) { + urlPath := cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", "suspend", userID) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqSuspend{Suspended: suspended}, res) + return +} + +// UnstableSetLockStatus uses MSC4323 to set whether a user account is locked. +func (cli *Client) UnstableSetLockStatus(ctx context.Context, userID id.UserID, locked bool) (res *RespLocked, err error) { + urlPath := cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", "lock", userID) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqLocked{Locked: locked}, res) + return +} + func (cli *Client) AppservicePing(ctx context.Context, id, txnID string) (resp *RespAppservicePing, err error) { _, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, diff --git a/requests.go b/requests.go index 9871f044..4b5ce74b 100644 --- a/requests.go +++ b/requests.go @@ -596,3 +596,13 @@ func (rgr *ReqGetRelations) Query() map[string]string { } return query } + +// ReqSuspend is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 +type ReqSuspend struct { + Suspended bool `json:"suspended"` +} + +// ReqLocked is the request body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 +type ReqLocked struct { + Locked bool `json:"locked"` +} diff --git a/responses.go b/responses.go index 5b97b293..8ab78373 100644 --- a/responses.go +++ b/responses.go @@ -494,12 +494,13 @@ type RespBeeperBatchSend struct { // RespCapabilities is the JSON response for https://spec.matrix.org/v1.3/client-server-api/#get_matrixclientv3capabilities type RespCapabilities struct { - RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"` - ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"` - SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"` - SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"` - ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"` - GetLoginToken *CapBooleanTrue `json:"m.get_login_token,omitempty"` + RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"` + ChangePassword *CapBooleanTrue `json:"m.change_password,omitempty"` + SetDisplayname *CapBooleanTrue `json:"m.set_displayname,omitempty"` + SetAvatarURL *CapBooleanTrue `json:"m.set_avatar_url,omitempty"` + ThreePIDChanges *CapBooleanTrue `json:"m.3pid_changes,omitempty"` + GetLoginToken *CapBooleanTrue `json:"m.get_login_token,omitempty"` + UnstableAccountModeration *CapUnstableAccountModeration `json:"uk.timedout.msc4323,omitempty"` Custom map[string]interface{} `json:"-"` } @@ -608,6 +609,11 @@ func (vers *CapRoomVersions) IsAvailable(version string) bool { return available } +type CapUnstableAccountModeration struct { + Suspend bool `json:"suspend"` + Lock bool `json:"lock"` +} + type RespPublicRooms struct { Chunk []*PublicRoomInfo `json:"chunk"` NextBatch string `json:"next_batch,omitempty"` @@ -699,3 +705,13 @@ type RespGetRelations struct { PrevBatch string `json:"prev_batch,omitempty"` RecursionDepth int `json:"recursion_depth,omitempty"` } + +// RespSuspended is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 +type RespSuspended struct { + Suspended bool `json:"suspended"` +} + +// RespLocked is the response body for https://github.com/matrix-org/matrix-spec-proposals/pull/4323 +type RespLocked struct { + Locked bool `json:"locked"` +} diff --git a/versions.go b/versions.go index f87bddda..c3be86cc 100644 --- a/versions.go +++ b/versions.go @@ -66,6 +66,7 @@ var ( FeatureMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"} FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"} + FeatureAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"} BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} From f8c3a95de7a10ed3a63dffff37adf39c7b21077b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 1 Sep 2025 17:01:20 +0200 Subject: [PATCH 1350/1647] bridgev2: add support for creating groups (#405) --- bridgev2/commands/startchat.go | 98 +++++++++++++++++++ bridgev2/matrix/provisioning.go | 22 ++++- bridgev2/matrix/provisioning.yaml | 81 +++++++++++++++- bridgev2/matrixinvite.go | 130 +++++++++++--------------- bridgev2/networkid/bridgeid.go | 4 +- bridgev2/networkinterface.go | 70 +++++++++++++- bridgev2/portal.go | 124 ++++++++++++++++++++---- bridgev2/portalinternal.go | 20 +++- bridgev2/provisionutil/creategroup.go | 99 ++++++++++++++++++++ event/capabilities.go | 2 +- 10 files changed, 545 insertions(+), 105 deletions(-) create mode 100644 bridgev2/provisionutil/creategroup.go diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index da246f50..7b755064 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -7,13 +7,20 @@ package commands import ( + "context" "fmt" "html" + "maps" + "slices" "strings" + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/bridgev2/provisionutil" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" ) var CommandResolveIdentifier = &FullHandler{ @@ -100,6 +107,97 @@ func fnResolveIdentifier(ce *Event) { } } +var CommandCreateGroup = &FullHandler{ + Func: fnCreateGroup, + Name: "create-group", + Aliases: []string{"create"}, + Help: HelpMeta{ + Section: HelpSectionChats, + Description: "Create a new group chat for the current Matrix room", + Args: "[_group type_]", + }, + RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.GroupCreatingNetworkAPI], +} + +func getState[T any](ctx context.Context, roomID id.RoomID, evtType event.Type, provider bridgev2.MatrixConnectorWithArbitraryRoomState) (content T) { + evt, err := provider.GetStateEvent(ctx, roomID, evtType, "") + if err != nil { + zerolog.Ctx(ctx).Err(err).Stringer("event_type", evtType).Msg("Failed to get state event for group creation") + } else if evt != nil { + content, _ = evt.Content.Parsed.(T) + } + return +} + +func fnCreateGroup(ce *Event) { + ce.Bridge.Matrix.GetCapabilities() + login, api, remainingArgs := getClientForStartingChat[bridgev2.GroupCreatingNetworkAPI](ce, "creating group") + if api == nil { + return + } + stateProvider, ok := ce.Bridge.Matrix.(bridgev2.MatrixConnectorWithArbitraryRoomState) + if !ok { + ce.Reply("Matrix connector doesn't support fetching room state") + return + } + members, err := ce.Bridge.Matrix.GetMembers(ce.Ctx, ce.RoomID) + if err != nil { + ce.Log.Err(err).Msg("Failed to get room members for group creation") + ce.Reply("Failed to get room members: %v", err) + return + } + caps := ce.Bridge.Network.GetCapabilities() + params := &bridgev2.GroupCreateParams{ + Username: "", + Participants: make([]networkid.UserID, 0, len(members)-2), + Parent: nil, // TODO check space parent event + Name: getState[*event.RoomNameEventContent](ce.Ctx, ce.RoomID, event.StateRoomName, stateProvider), + Avatar: getState[*event.RoomAvatarEventContent](ce.Ctx, ce.RoomID, event.StateRoomAvatar, stateProvider), + Topic: getState[*event.TopicEventContent](ce.Ctx, ce.RoomID, event.StateTopic, stateProvider), + Disappear: getState[*event.BeeperDisappearingTimer](ce.Ctx, ce.RoomID, event.StateBeeperDisappearingTimer, stateProvider), + RoomID: ce.RoomID, + } + for userID, member := range members { + if userID == ce.User.MXID || userID == ce.Bot.GetMXID() || !member.Membership.IsInviteOrJoin() { + continue + } + if parsedUserID, ok := ce.Bridge.Matrix.ParseGhostMXID(userID); ok { + params.Participants = append(params.Participants, parsedUserID) + } else if !ce.Bridge.Config.SplitPortals { + if user, err := ce.Bridge.GetExistingUserByMXID(ce.Ctx, userID); err != nil { + ce.Log.Err(err).Stringer("user_id", userID).Msg("Failed to get user for room member") + } else if user != nil { + // TODO add user logins to participants + //for _, login := range user.GetUserLogins() { + // params.Participants = append(params.Participants, login.GetUserID()) + //} + } + } + } + + if len(caps.Provisioning.GroupCreation) == 0 { + ce.Reply("No group creation types defined in network capabilities") + return + } else if len(remainingArgs) > 0 { + params.Type = remainingArgs[0] + } else if len(caps.Provisioning.GroupCreation) == 1 { + for params.Type = range caps.Provisioning.GroupCreation { + // The loop assigns the variable we want + } + } else { + types := strings.Join(slices.Collect(maps.Keys(caps.Provisioning.GroupCreation)), "`, `") + ce.Reply("Please specify type of group to create: `%s`", types) + return + } + resp, err := provisionutil.CreateGroup(ce.Ctx, login, params) + if err != nil { + ce.Reply("Failed to create group: %v", err) + return + } + ce.Reply("Successfully created group `%s`", resp.ID) +} + var CommandSearch = &FullHandler{ Func: fnSearch, Name: "search", diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 02ad6abd..4e11aa22 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -120,6 +120,7 @@ func (prov *ProvisioningAPI) Init() { tp.Transport.TLSHandshakeTimeout = 10 * time.Second prov.Router = http.NewServeMux() prov.Router.HandleFunc("GET /v3/whoami", prov.GetWhoami) + prov.Router.HandleFunc("GET /v3/capabilities", prov.GetCapabilities) prov.Router.HandleFunc("GET /v3/login/flows", prov.GetLoginFlows) prov.Router.HandleFunc("POST /v3/login/start/{flowID}", prov.PostLoginStart) prov.Router.HandleFunc("POST /v3/login/step/{loginProcessID}/{stepID}/{stepType}", prov.PostLoginStep) @@ -129,7 +130,7 @@ func (prov *ProvisioningAPI) Init() { prov.Router.HandleFunc("POST /v3/search_users", prov.PostSearchUsers) prov.Router.HandleFunc("GET /v3/resolve_identifier/{identifier}", prov.GetResolveIdentifier) prov.Router.HandleFunc("POST /v3/create_dm/{identifier}", prov.PostCreateDM) - prov.Router.HandleFunc("POST /v3/create_group", prov.PostCreateGroup) + prov.Router.HandleFunc("POST /v3/create_group/{type}", prov.PostCreateGroup) if prov.br.Config.Provisioning.EnableSessionTransfers { prov.log.Debug().Msg("Enabling session transfer API") @@ -361,6 +362,10 @@ func (prov *ProvisioningAPI) GetLoginFlows(w http.ResponseWriter, r *http.Reques }) } +func (prov *ProvisioningAPI) GetCapabilities(w http.ResponseWriter, r *http.Request) { + exhttp.WriteJSONResponse(w, http.StatusOK, &prov.net.GetCapabilities().Provisioning) +} + var ErrNilStep = errors.New("bridge returned nil step with no error") func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Request) { @@ -673,11 +678,24 @@ func (prov *ProvisioningAPI) PostCreateDM(w http.ResponseWriter, r *http.Request } func (prov *ProvisioningAPI) PostCreateGroup(w http.ResponseWriter, r *http.Request) { + var req bridgev2.GroupCreateParams + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + zerolog.Ctx(r.Context()).Err(err).Msg("Failed to decode request body") + mautrix.MNotJSON.WithMessage("Failed to decode request body").Write(w) + return + } + req.Type = r.PathValue("type") login := prov.GetLoginForRequest(w, r) if login == nil { return } - mautrix.MUnrecognized.WithMessage("Creating groups is not yet implemented").Write(w) + resp, err := provisionutil.CreateGroup(r.Context(), login, &req) + if err != nil { + RespondWithError(w, err, "Internal error creating group") + return + } + exhttp.WriteJSONResponse(w, http.StatusOK, resp) } type ReqExportCredentials struct { diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml index b9879ea5..5bb27272 100644 --- a/bridgev2/matrix/provisioning.yaml +++ b/bridgev2/matrix/provisioning.yaml @@ -361,14 +361,25 @@ paths: $ref: '#/components/responses/InternalError' 501: $ref: '#/components/responses/NotSupported' - /v3/create_group: + /v3/create_group/{type}: post: tags: [ snc ] summary: Create a group chat on the remote network. operationId: createGroup parameters: - $ref: "#/components/parameters/loginID" + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/GroupCreateParams' responses: + 200: + description: Identifier resolved successfully + content: + application/json: + schema: + $ref: '#/components/schemas/CreatedGroup' 401: $ref: '#/components/responses/Unauthorized' 404: @@ -572,6 +583,74 @@ components: description: The Matrix room ID of the direct chat with the user. examples: - '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io' + GroupCreateParams: + type: object + description: | + Parameters for creating a group chat. + The /capabilities endpoint response must be checked to see which fields are actually allowed. + properties: + type: + type: string + description: The type of group to create. + examples: + - channel + username: + type: string + description: The public username for the created group. + participants: + type: array + description: The users to add to the group initially. + items: + type: string + parent: + type: object + name: + type: object + description: The `m.room.name` event content for the room. + properties: + name: + type: string + avatar: + type: object + description: The `m.room.avatar` event content for the room. + properties: + url: + type: string + format: mxc + topic: + type: object + description: The `m.room.topic` event content for the room. + properties: + topic: + type: string + disappear: + type: object + description: The `com.beeper.disappearing_timer` event content for the room. + properties: + type: + type: string + timer: + type: number + room_id: + type: string + format: matrix_room_id + description: | + An existing Matrix room ID to bridge to. + The other parameters must be already in sync with the room state when using this parameter. + CreatedGroup: + type: object + description: A successfully created group chat. + required: [id, mxid] + properties: + id: + type: string + description: The internal chat ID of the created group. + mxid: + type: string + format: matrix_room_id + description: The Matrix room ID of the portal. + examples: + - '!OKhS0I5q2fCzdnl2qgeozDQw:t2bot.io' LoginStep: type: object description: A step in a login process. diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go index bfbabd26..2c14cc7f 100644 --- a/bridgev2/matrixinvite.go +++ b/bridgev2/matrixinvite.go @@ -206,72 +206,64 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen return EventHandlingResultFailed } - didSetPortal := portal.setMXIDToExistingRoom(ctx, evt.RoomID) - if didSetPortal { - message := "Private chat portal created" - err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent) - hasWarning := false - if err != nil { - log.Warn().Err(err).Msg("Failed to give power to bot in new DM") - message += "\n\nWarning: failed to promote bot" - hasWarning = true - } - if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID { - log.Debug(). - Str("dm_redirected_to_id", string(resp.DMRedirectedTo)). - Msg("Created DM was redirected to another user ID") - _, err = invitedGhost.Intent.SendState(ctx, portal.MXID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{ - Parsed: &event.MemberEventContent{ - Membership: event.MembershipLeave, - Reason: "Direct chat redirected to another internal user ID", - }, - }, time.Time{}) - if err != nil { - log.Err(err).Msg("Failed to make incorrect ghost leave new DM room") - } - otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo) - if err != nil { - log.Err(err).Msg("Failed to get ghost of real portal other user ID") - } else { - invitedGhost = otherUserGhost - } - } - if resp.PortalInfo != nil { - portal.UpdateInfo(ctx, resp.PortalInfo, sourceLogin, nil, time.Time{}) - } else { - portal.UpdateCapabilities(ctx, sourceLogin, true) - portal.UpdateBridgeInfo(ctx) - } - // TODO this might become unnecessary if UpdateInfo starts taking care of it - _, err = br.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{ - Parsed: &event.ElementFunctionalMembersContent{ - ServiceMembers: []id.UserID{br.Bot.GetMXID()}, + portal.roomCreateLock.Lock() + defer portal.roomCreateLock.Unlock() + portalMXID := portal.MXID + if portalMXID != "" { + sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portalMXID, portalMXID.URI(br.Matrix.ServerName()).MatrixToURL()) + rejectInvite(ctx, evt, br.Bot, "") + return EventHandlingResultSuccess + } + err = br.givePowerToBot(ctx, evt.RoomID, invitedGhost.Intent) + if err != nil { + log.Err(err).Msg("Failed to give permissions to bridge bot") + sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "Failed to give permissions to bridge bot") + rejectInvite(ctx, evt, br.Bot, "") + return EventHandlingResultSuccess + } + if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID { + log.Debug(). + Str("dm_redirected_to_id", string(resp.DMRedirectedTo)). + Msg("Created DM was redirected to another user ID") + _, err = invitedGhost.Intent.SendState(ctx, portal.MXID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{ + Parsed: &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: "Direct chat redirected to another internal user ID", }, }, time.Time{}) if err != nil { - log.Warn().Err(err).Msg("Failed to set service members in room") - if !hasWarning { - message += "\n\nWarning: failed to set service members" - hasWarning = true - } + log.Err(err).Msg("Failed to make incorrect ghost leave new DM room") } - mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling) - if ok { - err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID) - if err != nil { - if hasWarning { - message += fmt.Sprintf(", %s", err.Error()) - } else { - message += fmt.Sprintf("\n\nWarning: %s", err.Error()) - } - } + otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo) + if err != nil { + log.Err(err).Msg("Failed to get ghost of real portal other user ID") + } else { + invitedGhost = otherUserGhost } - sendNotice(ctx, evt, invitedGhost.Intent, message) - } else { - // TODO ensure user is invited even if PortalInfo wasn't provided? - sendErrorAndLeave(ctx, evt, invitedGhost.Intent, "You already have a direct chat with me at [%s](%s)", portal.MXID, portal.MXID.URI(br.Matrix.ServerName()).MatrixToURL()) - rejectInvite(ctx, evt, br.Bot, "") } + err = portal.UpdateMatrixRoomID(ctx, evt.RoomID, UpdateMatrixRoomIDParams{ + // We locked it before checking the mxid + RoomCreateAlreadyLocked: true, + + FailIfMXIDSet: true, + ChatInfo: resp.PortalInfo, + ChatInfoSource: sourceLogin, + }) + if err != nil { + log.Err(err).Msg("Failed to update Matrix room ID for new DM portal") + sendNotice(ctx, evt, invitedGhost.Intent, "Failed to finish configuring portal. The chat may or may not work") + return EventHandlingResultSuccess + } + message := "Private chat portal created" + mx, ok := br.Matrix.(MatrixConnectorWithPostRoomBridgeHandling) + if ok { + err = mx.HandleNewlyBridgedRoom(ctx, evt.RoomID) + if err != nil { + log.Err(err).Msg("Error in connector newly bridged room handler") + message += fmt.Sprintf("\n\nWarning: %s", err.Error()) + } + } + sendNotice(ctx, evt, invitedGhost.Intent, message) return EventHandlingResultSuccess } @@ -294,21 +286,3 @@ func (br *Bridge) givePowerToBot(ctx context.Context, roomID id.RoomID, userWith } return nil } - -func (portal *Portal) setMXIDToExistingRoom(ctx context.Context, roomID id.RoomID) bool { - portal.roomCreateLock.Lock() - defer portal.roomCreateLock.Unlock() - if portal.MXID != "" { - return false - } - portal.MXID = roomID - portal.updateLogger() - portal.Bridge.cacheLock.Lock() - portal.Bridge.portalsByMXID[portal.MXID] = portal - portal.Bridge.cacheLock.Unlock() - err := portal.Save(ctx) - if err != nil { - zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after updating mxid") - } - return true -} diff --git a/bridgev2/networkid/bridgeid.go b/bridgev2/networkid/bridgeid.go index 443d3655..e3a6df70 100644 --- a/bridgev2/networkid/bridgeid.go +++ b/bridgev2/networkid/bridgeid.go @@ -47,8 +47,8 @@ type PortalID string // As a special case, Receiver MUST be set if the Bridge.Config.SplitPortals flag is set to true. // The flag is intended for puppeting-only bridges which want multiple logins to create separate portals for each user. type PortalKey struct { - ID PortalID - Receiver UserLoginID + ID PortalID `json:"portal_id"` + Receiver UserLoginID `json:"portal_receiver,omitempty"` } func (pk PortalKey) IsEmpty() bool { diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index dcbcbad5..8293be51 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -350,6 +350,8 @@ type NetworkGeneralCapabilities struct { // to handle asynchronous message responses, this field can be set to enable // automatic timeout errors in case the asynchronous response never arrives. OutgoingMessageTimeouts *OutgoingTimeoutConfig + // Capabilities related to the provisioning API. + Provisioning ProvisioningCapabilities } // NetworkAPI is an interface representing a remote network client for a single user login. @@ -750,9 +752,75 @@ type UserSearchingNetworkAPI interface { SearchUsers(ctx context.Context, query string) ([]*ResolveIdentifierResponse, error) } +type ProvisioningCapabilities struct { + ResolveIdentifier ResolveIdentifierCapabilities `json:"resolve_identifier"` + GroupCreation map[string]GroupTypeCapabilities `json:"group_creation"` +} + +type ResolveIdentifierCapabilities struct { + // Can DMs be created after resolving an identifier? + CreateDM bool `json:"create_dm"` + // Can users be looked up by phone number? + LookupPhone bool `json:"lookup_phone"` + // Can users be looked up by email address? + LookupEmail bool `json:"lookup_email"` + // Can users be looked up by network-specific username? + LookupUsername bool `json:"lookup_username"` + // Can any phone number be contacted without having to validate it via lookup first? + AnyPhone bool `json:"any_phone"` + // Can a contact list be retrieved from the bridge? + ContactList bool `json:"contact_list"` + // Can users be searched by name on the remote network? + Search bool `json:"search"` +} + +type GroupTypeCapabilities struct { + TypeDescription string `json:"type_description"` + + Name GroupFieldCapability `json:"name"` + Username GroupFieldCapability `json:"username"` + Avatar GroupFieldCapability `json:"avatar"` + Topic GroupFieldCapability `json:"topic"` + Disappear GroupFieldCapability `json:"disappear"` + Participants GroupFieldCapability `json:"participants"` + Parent GroupFieldCapability `json:"parent"` +} + +type GroupFieldCapability struct { + // Is setting this field allowed at all in the create request? + // Even if false, the network connector should attempt to set the metadata after group creation, + // as the allowed flag can't be enforced properly when creating a group for an existing Matrix room. + Allowed bool `json:"allowed"` + // Is setting this field mandatory for the creation to succeed? + Required bool `json:"required,omitempty"` + // The minimum/maximum length of the field, if applicable. + // For members, length means the number of members excluding the creator. + MinLength int `json:"min_length,omitempty"` + MaxLength int `json:"max_length,omitempty"` + + // Only for the disappear field: allowed disappearing settings + DisappearSettings *event.DisappearingTimerCapability `json:"settings,omitempty"` +} + +type GroupCreateParams struct { + Type string `json:"type"` + + Username string `json:"username"` + Participants []networkid.UserID `json:"participants"` + Parent *networkid.PortalKey `json:"parent"` + + Name *event.RoomNameEventContent `json:"name"` + Avatar *event.RoomAvatarEventContent `json:"avatar"` + Topic *event.TopicEventContent `json:"topic"` + Disappear *event.BeeperDisappearingTimer `json:"disappear"` + + // An existing room ID to bridge to. If unset, a new room will be created. + RoomID id.RoomID `json:"room_id"` +} + type GroupCreatingNetworkAPI interface { IdentifierResolvingNetworkAPI - CreateGroup(ctx context.Context, name string, users ...networkid.UserID) (*CreateChatResponse, error) + CreateGroup(ctx context.Context, params *GroupCreateParams) (*CreateChatResponse, error) } type MembershipChangeType struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 5e0a9137..85d670d9 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1840,42 +1840,134 @@ func (portal *Portal) handleMatrixTombstone(ctx context.Context, evt *event.Even return EventHandlingResultIgnored } } - - portal.Bridge.cacheLock.Lock() - if _, alreadyExists := portal.Bridge.portalsByMXID[content.ReplacementRoom]; alreadyExists { - log.Warn().Msg("Replacement room is already a portal, ignoring tombstone") - portal.Bridge.cacheLock.Unlock() + err = portal.UpdateMatrixRoomID(ctx, content.ReplacementRoom, UpdateMatrixRoomIDParams{ + DeleteOldRoom: true, + FetchInfoVia: senderUser, + }) + if errors.Is(err, ErrTargetRoomIsPortal) { return EventHandlingResultIgnored + } else if err != nil { + return EventHandlingResultFailed.WithError(err) } - delete(portal.Bridge.portalsByMXID, portal.MXID) - portal.MXID = content.ReplacementRoom + return EventHandlingResultSuccess +} + +var ErrTargetRoomIsPortal = errors.New("target room is already a portal") +var ErrRoomAlreadyExists = errors.New("this portal already has a room") + +type UpdateMatrixRoomIDParams struct { + SyncDBMetadata func() + FailIfMXIDSet bool + OverwriteOldPortal bool + TombstoneOldRoom bool + DeleteOldRoom bool + + RoomCreateAlreadyLocked bool + + FetchInfoVia *User + ChatInfo *ChatInfo + ChatInfoSource *UserLogin +} + +func (portal *Portal) UpdateMatrixRoomID( + ctx context.Context, + newRoomID id.RoomID, + params UpdateMatrixRoomIDParams, +) error { + if !params.RoomCreateAlreadyLocked { + portal.roomCreateLock.Lock() + defer portal.roomCreateLock.Unlock() + } + oldRoom := portal.MXID + if oldRoom == newRoomID { + return nil + } else if oldRoom != "" && params.FailIfMXIDSet { + return ErrRoomAlreadyExists + } + log := zerolog.Ctx(ctx) + portal.Bridge.cacheLock.Lock() + // Wrap unlock in a sync.OnceFunc because we want to both defer it to catch early returns + // and unlock it before return if nothing goes wrong. + unlockCacheLock := sync.OnceFunc(portal.Bridge.cacheLock.Unlock) + defer unlockCacheLock() + if existingPortal, alreadyExists := portal.Bridge.portalsByMXID[newRoomID]; alreadyExists && !params.OverwriteOldPortal { + log.Warn().Msg("Replacement room is already a portal, ignoring") + return ErrTargetRoomIsPortal + } else if alreadyExists { + log.Debug().Msg("Replacement room is already a portal, overwriting") + existingPortal.MXID = "" + err := existingPortal.Save(ctx) + if err != nil { + return fmt.Errorf("failed to clear mxid of existing portal: %w", err) + } + delete(portal.Bridge.portalsByMXID, portal.MXID) + } + portal.MXID = newRoomID portal.Bridge.portalsByMXID[portal.MXID] = portal portal.NameSet = false portal.AvatarSet = false portal.TopicSet = false portal.InSpace = false portal.CapState = database.CapabilityState{} - portal.Bridge.cacheLock.Unlock() + portal.lastCapUpdate = time.Time{} + if params.SyncDBMetadata != nil { + params.SyncDBMetadata() + } + unlockCacheLock() + portal.updateLogger() - err = portal.Save(ctx) + err := portal.Save(ctx) if err != nil { - log.Err(err).Msg("Failed to save portal after tombstone") - return EventHandlingResultFailed.WithError(err) + log.Err(err).Msg("Failed to save portal in UpdateMatrixRoomID") + return err } log.Info().Msg("Successfully followed tombstone and updated portal MXID") err = portal.Bridge.DB.UserPortal.MarkAllNotInSpace(ctx, portal.PortalKey) if err != nil { - log.Err(err).Msg("Failed to update in_space flag for user portals after tombstone") + log.Err(err).Msg("Failed to update in_space flag for user portals after updating portal MXID") } go portal.addToUserSpaces(ctx) - go portal.updateInfoAfterTombstone(ctx, senderUser) + if params.FetchInfoVia != nil { + go portal.updateInfoAfterTombstone(ctx, params.FetchInfoVia) + } else if params.ChatInfo != nil { + go portal.UpdateInfo(ctx, params.ChatInfo, params.ChatInfoSource, nil, time.Time{}) + } else if params.ChatInfoSource != nil { + portal.UpdateCapabilities(ctx, params.ChatInfoSource, true) + portal.UpdateBridgeInfo(ctx) + } go func() { - err = portal.Bridge.Bot.DeleteRoom(ctx, evt.RoomID, true) + // TODO this might become unnecessary if UpdateInfo starts taking care of it + _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateElementFunctionalMembers, "", &event.Content{ + Parsed: &event.ElementFunctionalMembersContent{ + ServiceMembers: []id.UserID{portal.Bridge.Bot.GetMXID()}, + }, + }, time.Time{}) if err != nil { - log.Err(err).Msg("Failed to clean up Matrix room after following tombstone") + if err != nil { + log.Warn().Err(err).Msg("Failed to set service members in new room") + } } }() - return EventHandlingResultSuccess + if params.TombstoneOldRoom && oldRoom != "" { + _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateTombstone, "", &event.Content{ + Parsed: &event.TombstoneEventContent{ + Body: "Room has been replaced.", + ReplacementRoom: newRoomID, + }, + }, time.Now()) + if err != nil { + log.Err(err).Msg("Failed to send tombstone event to old room") + } + } + if params.DeleteOldRoom && oldRoom != "" { + go func() { + err = portal.Bridge.Bot.DeleteRoom(ctx, oldRoom, true) + if err != nil { + log.Err(err).Msg("Failed to clean up old Matrix room after updating portal MXID") + } + }() + } + return nil } func (portal *Portal) updateInfoAfterTombstone(ctx context.Context, senderUser *User) { diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index e82c481a..0223b4f2 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -125,6 +125,14 @@ func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, send return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt) } +func (portal *PortalInternals) HandleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixTombstone(ctx, evt) +} + +func (portal *PortalInternals) UpdateInfoAfterTombstone(ctx context.Context, senderUser *User) { + (*Portal)(portal).updateInfoAfterTombstone(ctx, senderUser) +} + func (portal *PortalInternals) HandleMatrixRedaction(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { return (*Portal)(portal).handleMatrixRedaction(ctx, sender, origSender, evt) } @@ -133,6 +141,10 @@ func (portal *PortalInternals) HandleRemoteEvent(ctx context.Context, source *Us return (*Portal)(portal).handleRemoteEvent(ctx, source, evtType, evt) } +func (portal *PortalInternals) EnsureFunctionalMember(ctx context.Context, ghost *Ghost) { + (*Portal)(portal).ensureFunctionalMember(ctx, ghost) +} + func (portal *PortalInternals) GetIntentAndUserMXIDFor(ctx context.Context, sender EventSender, source *UserLogin, otherLogins []*UserLogin, evtType RemoteEventType) (intent MatrixAPI, extraUserID id.UserID, err error) { return (*Portal)(portal).getIntentAndUserMXIDFor(ctx, sender, source, otherLogins, evtType) } @@ -297,6 +309,10 @@ func (portal *PortalInternals) CreateMatrixRoomInLoop(ctx context.Context, sourc return (*Portal)(portal).createMatrixRoomInLoop(ctx, source, info, backfillBundle) } +func (portal *PortalInternals) AddToUserSpaces(ctx context.Context) { + (*Portal)(portal).addToUserSpaces(ctx) +} + func (portal *PortalInternals) RemoveInPortalCache(ctx context.Context) { (*Portal)(portal).removeInPortalCache(ctx) } @@ -360,7 +376,3 @@ func (portal *PortalInternals) AddToParentSpaceAndSave(ctx context.Context, save func (portal *PortalInternals) ToggleSpace(ctx context.Context, spaceID id.RoomID, canonical, remove bool) error { return (*Portal)(portal).toggleSpace(ctx, spaceID, canonical, remove) } - -func (portal *PortalInternals) SetMXIDToExistingRoom(ctx context.Context, roomID id.RoomID) bool { - return (*Portal)(portal).setMXIDToExistingRoom(ctx, roomID) -} diff --git a/bridgev2/provisionutil/creategroup.go b/bridgev2/provisionutil/creategroup.go new file mode 100644 index 00000000..891f9615 --- /dev/null +++ b/bridgev2/provisionutil/creategroup.go @@ -0,0 +1,99 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package provisionutil + +import ( + "context" + + "github.com/rs/zerolog" + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/id" +) + +type RespCreateGroup struct { + ID networkid.PortalID `json:"id"` + MXID id.RoomID `json:"mxid"` + Portal *bridgev2.Portal `json:"-"` +} + +func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev2.GroupCreateParams) (*RespCreateGroup, error) { + api, ok := login.Client.(bridgev2.GroupCreatingNetworkAPI) + if !ok { + return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support creating groups")) + } + caps := login.Bridge.Network.GetCapabilities() + typeSpec, validType := caps.Provisioning.GroupCreation[params.Type] + if !validType { + return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("Unrecognized group type %s", params.Type)) + } + if len(params.Participants) < typeSpec.Participants.MinLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at least %d members", typeSpec.Participants.MinLength)) + } + if (params.Name == nil || params.Name.Name == "") && typeSpec.Name.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name is required")) + } else if nameLen := len(ptr.Val(params.Name).Name); nameLen > 0 && nameLen < typeSpec.Name.MinLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at least %d characters", typeSpec.Name.MinLength)) + } else if nameLen > typeSpec.Name.MaxLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at most %d characters", typeSpec.Name.MaxLength)) + } + if (params.Avatar == nil || params.Avatar.URL == "") && typeSpec.Avatar.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Avatar is required")) + } + if (params.Topic == nil || params.Topic.Topic == "") && typeSpec.Topic.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic is required")) + } else if topicLen := len(ptr.Val(params.Topic).Topic); topicLen > 0 && topicLen < typeSpec.Topic.MinLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at least %d characters", typeSpec.Topic.MinLength)) + } else if topicLen > typeSpec.Topic.MaxLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at most %d characters", typeSpec.Topic.MaxLength)) + } + if (params.Disappear == nil || params.Disappear.Timer.Duration == 0) && typeSpec.Disappear.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Disappearing timer is required")) + } else if !typeSpec.Disappear.DisappearSettings.Supports(params.Disappear) { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Unsupported value for disappearing timer")) + } + if params.Username == "" && typeSpec.Username.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username is required")) + } else if len(params.Username) > 0 && len(params.Username) < typeSpec.Username.MinLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at least %d characters", typeSpec.Username.MinLength)) + } else if len(params.Username) > typeSpec.Username.MaxLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at most %d characters", typeSpec.Username.MaxLength)) + } + if params.Parent == nil && typeSpec.Parent.Required { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Parent is required")) + } + resp, err := api.CreateGroup(ctx, params) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create group") + return nil, err + } + if resp.PortalKey.IsEmpty() { + return nil, ErrNoPortalKey + } + if resp.Portal == nil { + resp.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.PortalKey) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get portal") + return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal")) + } + } + if resp.Portal.MXID == "" { + err = resp.Portal.CreateMatrixRoom(ctx, login, resp.PortalInfo) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to create portal room") + return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room")) + } + } + return &RespCreateGroup{ + ID: resp.Portal.ID, + MXID: resp.Portal.MXID, + Portal: resp.Portal, + }, nil +} diff --git a/event/capabilities.go b/event/capabilities.go index ebedb6a2..94662428 100644 --- a/event/capabilities.go +++ b/event/capabilities.go @@ -77,7 +77,7 @@ type DisappearingTimerCapability struct { } func (dtc *DisappearingTimerCapability) Supports(content *BeeperDisappearingTimer) bool { - if dtc == nil || content.Type == DisappearingTypeNone { + if dtc == nil || content == nil || content.Type == DisappearingTypeNone { return true } return slices.Contains(dtc.Types, content.Type) && (dtc.Timers == nil || slices.Contains(dtc.Timers, content.Timer)) From bcd0a70bdfb1d44b78e89465a2fb3b2de44cfb4e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 2 Sep 2025 00:31:15 +0300 Subject: [PATCH 1351/1647] appservice/websocket: override read limit --- appservice/websocket.go | 1 + 1 file changed, 1 insertion(+) diff --git a/appservice/websocket.go b/appservice/websocket.go index 18768098..309cc485 100644 --- a/appservice/websocket.go +++ b/appservice/websocket.go @@ -412,6 +412,7 @@ func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConn } }) } + ws.SetReadLimit(50 * 1024 * 1024) as.ws = ws as.StopWebsocket = stopFunc as.PrepareWebsocket() From 8f8b26d815b11a00847afdd4e8fe332b2e14f137 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 2 Sep 2025 10:33:49 +0300 Subject: [PATCH 1352/1647] event: add is_animated flag from MSC4230 --- event/message.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/event/message.go b/event/message.go index b397623f..692382cf 100644 --- a/event/message.go +++ b/event/message.go @@ -301,7 +301,8 @@ type FileInfo struct { Blurhash string AnoaBlurhash string - MauGIF bool + MauGIF bool + IsAnimated bool Width int Height int @@ -318,7 +319,8 @@ type serializableFileInfo struct { Blurhash string `json:"blurhash,omitempty"` AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"` - MauGIF bool `json:"fi.mau.gif,omitempty"` + MauGIF bool `json:"fi.mau.gif,omitempty"` + IsAnimated bool `json:"is_animated,omitempty"` Width json.Number `json:"w,omitempty"` Height json.Number `json:"h,omitempty"` @@ -336,7 +338,8 @@ func (sfi *serializableFileInfo) CopyFrom(fileInfo *FileInfo) *serializableFileI ThumbnailInfo: (&serializableFileInfo{}).CopyFrom(fileInfo.ThumbnailInfo), ThumbnailFile: fileInfo.ThumbnailFile, - MauGIF: fileInfo.MauGIF, + MauGIF: fileInfo.MauGIF, + IsAnimated: fileInfo.IsAnimated, Blurhash: fileInfo.Blurhash, AnoaBlurhash: fileInfo.AnoaBlurhash, @@ -367,6 +370,7 @@ func (sfi *serializableFileInfo) CopyTo(fileInfo *FileInfo) { ThumbnailURL: sfi.ThumbnailURL, ThumbnailFile: sfi.ThumbnailFile, MauGIF: sfi.MauGIF, + IsAnimated: sfi.IsAnimated, Blurhash: sfi.Blurhash, AnoaBlurhash: sfi.AnoaBlurhash, } From 709f48f2b3703a91871b2d7d0f4ac7fb1dc71d5d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 2 Sep 2025 18:24:24 +0300 Subject: [PATCH 1353/1647] bridgev2/provisioning: remove unused structs --- bridgev2/matrix/provisioning.go | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 4e11aa22..61aad869 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -600,15 +600,6 @@ func RespondWithError(w http.ResponseWriter, err error, message string) { } } -type RespResolveIdentifier struct { - ID networkid.UserID `json:"id"` - Name string `json:"name,omitempty"` - AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` - Identifiers []string `json:"identifiers,omitempty"` - MXID id.UserID `json:"mxid,omitempty"` - DMRoomID id.RoomID `json:"dm_room_mxid,omitempty"` -} - func (prov *ProvisioningAPI) doResolveIdentifier(w http.ResponseWriter, r *http.Request, createChat bool) { login := prov.GetLoginForRequest(w, r) if login == nil { @@ -645,10 +636,6 @@ type ReqSearchUsers struct { Query string `json:"query"` } -type RespSearchUsers struct { - Results []*RespResolveIdentifier `json:"results"` -} - func (prov *ProvisioningAPI) PostSearchUsers(w http.ResponseWriter, r *http.Request) { var req ReqSearchUsers err := json.NewDecoder(r.Body).Decode(&req) From 30ab68f7f18fe657bb0395ba141e0af284d01d3e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 4 Sep 2025 18:18:53 +0300 Subject: [PATCH 1354/1647] appservice: maybe fix url template raw path for unix sockets --- appservice/appservice.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appservice/appservice.go b/appservice/appservice.go index 33b53d7d..d7037ef6 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -334,7 +334,7 @@ func (as *AppService) SetHomeserverURL(homeserverURL string) error { } else if as.hsURLForClient.Scheme == "" { as.hsURLForClient.Scheme = "https" } - as.hsURLForClient.RawPath = parsedURL.EscapedPath() + as.hsURLForClient.RawPath = as.hsURLForClient.EscapedPath() jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) as.HTTPClient = &http.Client{Timeout: 180 * time.Second, Jar: jar} From 41bbe4ace4c7a35b942b2761fb221caed8ee3d8a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 9 Sep 2025 16:24:18 +0300 Subject: [PATCH 1355/1647] bridgev2/portal: add action message metadata to disappearing notices --- bridgev2/portal.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 85d670d9..f3797247 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -4191,6 +4191,14 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat content := DisappearingMessageNotice(setting.Timer, opts.Implicit) _, err := opts.Sender.SendMessage(ctx, portal.MXID, event.EventMessage, &event.Content{ Parsed: content, + Raw: map[string]any{ + "com.beeper.action_message": map[string]any{ + "type": "disappearing_timer", + "timer": setting.Timer.Milliseconds(), + "timer_type": setting.Type, + "implicit": opts.Implicit, + }, + }, }, &MatrixSendExtra{Timestamp: opts.Timestamp}) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to send disappearing messages notice") From e295028ffd4409d56b47b58d13606474f6187f51 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 9 Sep 2025 19:10:07 +0300 Subject: [PATCH 1356/1647] client: stabilize arbitrary profile field support --- client.go | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/client.go b/client.go index 85b27923..71c3fb18 100644 --- a/client.go +++ b/client.go @@ -1088,8 +1088,7 @@ func (cli *Client) GetRoomSummary(ctx context.Context, roomIDOrAlias string, via // GetDisplayName returns the display name of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname func (cli *Client) GetDisplayName(ctx context.Context, mxid id.UserID) (resp *RespUserDisplayName, err error) { - urlPath := cli.BuildClientURL("v3", "profile", mxid, "displayname") - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + err = cli.GetProfileField(ctx, mxid, "displayname", &resp) return } @@ -1100,41 +1099,38 @@ func (cli *Client) GetOwnDisplayName(ctx context.Context) (resp *RespUserDisplay // SetDisplayName sets the user's profile display name. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3profileuseriddisplayname func (cli *Client) SetDisplayName(ctx context.Context, displayName string) (err error) { - urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, "displayname") - s := struct { - DisplayName string `json:"displayname"` - }{displayName} - _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &s, nil) - return + return cli.SetProfileField(ctx, "displayname", displayName) } -// UnstableSetProfileField sets an arbitrary MSC4133 profile field. See https://github.com/matrix-org/matrix-spec-proposals/pull/4133 -func (cli *Client) UnstableSetProfileField(ctx context.Context, key string, value any) (err error) { - urlPath := cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) +// SetProfileField sets an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname +func (cli *Client) SetProfileField(ctx context.Context, key string, value any) (err error) { + urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, map[string]any{ key: value, }, nil) return } -// UnstableDeleteProfileField deletes an arbitrary MSC4133 profile field. See https://github.com/matrix-org/matrix-spec-proposals/pull/4133 -func (cli *Client) UnstableDeleteProfileField(ctx context.Context, key string) (err error) { - urlPath := cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) +// DeleteProfileField deletes an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname +func (cli *Client) DeleteProfileField(ctx context.Context, key string) (err error) { + urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key) _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil) return } +// GetProfileField gets an arbitrary profile field and parses the response into the given struct. See https://spec.matrix.org/unstable/client-server-api/#get_matrixclientv3profileuseridkeyname +func (cli *Client) GetProfileField(ctx context.Context, userID id.UserID, key string, into any) (err error) { + urlPath := cli.BuildClientURL("v3", "profile", userID, key) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, into) + return +} + // GetAvatarURL gets the avatar URL of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseridavatar_url func (cli *Client) GetAvatarURL(ctx context.Context, mxid id.UserID) (url id.ContentURI, err error) { - urlPath := cli.BuildClientURL("v3", "profile", mxid, "avatar_url") s := struct { AvatarURL id.ContentURI `json:"avatar_url"` }{} - - _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &s) - if err != nil { - return - } + err = cli.GetProfileField(ctx, mxid, "avatar_url", &s) url = s.AvatarURL return } From 22a908d8d63d0c7f3c98bffd9f07e5c245bc1907 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Sep 2025 16:24:43 +0300 Subject: [PATCH 1357/1647] crypto/decryptolm: add debug logs for failing to decrypt with new session --- crypto/decryptolm.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index b961a7b4..f54210a7 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -17,6 +17,8 @@ import ( "time" "github.com/rs/zerolog" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/ptr" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -180,6 +182,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U log = log.With().Str("new_olm_session_id", session.ID().String()).Logger() log.Debug(). Hex("ciphertext_hash", ciphertextHash[:]). + Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]). Str("olm_session_description", session.Describe()). Msg("Created inbound olm session") ctx = log.WithContext(ctx) @@ -189,6 +192,12 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U endTimeTrace() if err != nil { go mach.unwedgeDevice(log, sender, senderKey) + log.Debug(). + Hex("ciphertext_hash", ciphertextHash[:]). + Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]). + Str("ciphertext", ciphertext). + Str("olm_session_description", session.Describe()). + Msg("DEBUG: Failed to decrypt prekey olm message with newly created session") return nil, fmt.Errorf("failed to decrypt olm event with session created from prekey message: %w", err) } From faa1c5ff8d97be5236f8dd4a09ec70cc86c67ed5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Sep 2025 16:46:05 +0300 Subject: [PATCH 1358/1647] crypto/machine: log when loading olm account --- crypto/machine.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/crypto/machine.go b/crypto/machine.go index e791e70d..83ce024d 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -156,6 +156,10 @@ func (mach *OlmMachine) Load(ctx context.Context) (err error) { if mach.account == nil { mach.account = NewOlmAccount() } + zerolog.Ctx(ctx).Debug(). + Str("machine_ptr", fmt.Sprintf("%p", mach)). + Str("account_ptr", fmt.Sprintf("%p", mach.account.Internal)). + Msg("Loaded olm account") return nil } From bdb9e22a4372ed8a262b42238748a8670bacaa52 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 11 Sep 2025 13:22:45 +0300 Subject: [PATCH 1359/1647] crypto/libolm: clean up pointer management --- crypto/libolm/account.go | 73 +++++++++++++++++---------- crypto/libolm/inboundgroupsession.go | 60 +++++++++++++--------- crypto/libolm/outboundgroupsession.go | 44 ++++++++++------ crypto/libolm/pk.go | 63 +++++++++++++++++------ crypto/libolm/register.go | 12 +++-- crypto/libolm/session.go | 68 ++++++++++++++++--------- 6 files changed, 212 insertions(+), 108 deletions(-) diff --git a/crypto/libolm/account.go b/crypto/libolm/account.go index cddce7ce..a2212ccc 100644 --- a/crypto/libolm/account.go +++ b/crypto/libolm/account.go @@ -8,6 +8,7 @@ import ( "crypto/rand" "encoding/base64" "encoding/json" + "runtime" "unsafe" "github.com/tidwall/gjson" @@ -53,7 +54,7 @@ func AccountFromPickled(pickled, key []byte) (*Account, error) { func NewBlankAccount() *Account { memory := make([]byte, accountSize()) return &Account{ - int: C.olm_account(unsafe.Pointer(&memory[0])), + int: C.olm_account(unsafe.Pointer(unsafe.SliceData(memory))), mem: memory, } } @@ -68,8 +69,9 @@ func NewAccount() (*Account, error) { } ret := C.olm_create_account( (*C.OlmAccount)(a.int), - unsafe.Pointer(&random[0]), + unsafe.Pointer(unsafe.SliceData(random)), C.size_t(len(random))) + runtime.KeepAlive(random) if ret == errorVal() { return nil, a.lastError() } else { @@ -143,9 +145,9 @@ func (a *Account) Pickle(key []byte) ([]byte, error) { pickled := make([]byte, a.pickleLen()) r := C.olm_pickle_account( (*C.OlmAccount)(a.int), - unsafe.Pointer(&key[0]), + unsafe.Pointer(unsafe.SliceData(key)), C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), + unsafe.Pointer(unsafe.SliceData(pickled)), C.size_t(len(pickled))) if r == errorVal() { return nil, a.lastError() @@ -159,9 +161,9 @@ func (a *Account) Unpickle(pickled, key []byte) error { } r := C.olm_unpickle_account( (*C.OlmAccount)(a.int), - unsafe.Pointer(&key[0]), + unsafe.Pointer(unsafe.SliceData(key)), C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), + unsafe.Pointer(unsafe.SliceData(pickled)), C.size_t(len(pickled))) if r == errorVal() { return a.lastError() @@ -221,7 +223,7 @@ func (a *Account) IdentityKeysJSON() ([]byte, error) { identityKeys := make([]byte, a.identityKeysLen()) r := C.olm_account_identity_keys( (*C.OlmAccount)(a.int), - unsafe.Pointer(&identityKeys[0]), + unsafe.Pointer(unsafe.SliceData(identityKeys)), C.size_t(len(identityKeys))) if r == errorVal() { return nil, a.lastError() @@ -250,10 +252,11 @@ func (a *Account) Sign(message []byte) ([]byte, error) { signature := make([]byte, a.signatureLen()) r := C.olm_account_sign( (*C.OlmAccount)(a.int), - unsafe.Pointer(&message[0]), + unsafe.Pointer(unsafe.SliceData(message)), C.size_t(len(message)), - unsafe.Pointer(&signature[0]), + unsafe.Pointer(unsafe.SliceData(signature)), C.size_t(len(signature))) + runtime.KeepAlive(message) if r == errorVal() { panic(a.lastError()) } @@ -277,8 +280,9 @@ func (a *Account) OneTimeKeys() (map[string]id.Curve25519, error) { oneTimeKeysJSON := make([]byte, a.oneTimeKeysLen()) r := C.olm_account_one_time_keys( (*C.OlmAccount)(a.int), - unsafe.Pointer(&oneTimeKeysJSON[0]), - C.size_t(len(oneTimeKeysJSON))) + unsafe.Pointer(unsafe.SliceData(oneTimeKeysJSON)), + C.size_t(len(oneTimeKeysJSON)), + ) if r == errorVal() { return nil, a.lastError() } @@ -312,8 +316,10 @@ func (a *Account) GenOneTimeKeys(num uint) error { r := C.olm_account_generate_one_time_keys( (*C.OlmAccount)(a.int), C.size_t(num), - unsafe.Pointer(&random[0]), - C.size_t(len(random))) + unsafe.Pointer(unsafe.SliceData(random)), + C.size_t(len(random)), + ) + runtime.KeepAlive(random) if r == errorVal() { return a.lastError() } @@ -333,15 +339,21 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2 if err != nil { panic(olm.NotEnoughGoRandom) } + theirIdentityKeyCopy := []byte(theirIdentityKey) + theirOneTimeKeyCopy := []byte(theirOneTimeKey) r := C.olm_create_outbound_session( (*C.OlmSession)(s.int), (*C.OlmAccount)(a.int), - unsafe.Pointer(&([]byte(theirIdentityKey)[0])), - C.size_t(len(theirIdentityKey)), - unsafe.Pointer(&([]byte(theirOneTimeKey)[0])), - C.size_t(len(theirOneTimeKey)), - unsafe.Pointer(&random[0]), - C.size_t(len(random))) + unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)), + C.size_t(len(theirIdentityKeyCopy)), + unsafe.Pointer(unsafe.SliceData(theirOneTimeKeyCopy)), + C.size_t(len(theirOneTimeKeyCopy)), + unsafe.Pointer(unsafe.SliceData(random)), + C.size_t(len(random)), + ) + runtime.KeepAlive(random) + runtime.KeepAlive(theirIdentityKeyCopy) + runtime.KeepAlive(theirOneTimeKeyCopy) if r == errorVal() { return nil, s.lastError() } @@ -360,11 +372,14 @@ func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) { return nil, olm.EmptyInput } s := NewBlankSession() + oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) r := C.olm_create_inbound_session( (*C.OlmSession)(s.int), (*C.OlmAccount)(a.int), - unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])), - C.size_t(len(oneTimeKeyMsg))) + unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), + C.size_t(len(oneTimeKeyMsgCopy)), + ) + runtime.KeepAlive(oneTimeKeyMsgCopy) if r == errorVal() { return nil, s.lastError() } @@ -382,14 +397,19 @@ func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTime if theirIdentityKey == nil || len(oneTimeKeyMsg) == 0 { return nil, olm.EmptyInput } + theirIdentityKeyCopy := []byte(*theirIdentityKey) + oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) s := NewBlankSession() r := C.olm_create_inbound_session_from( (*C.OlmSession)(s.int), (*C.OlmAccount)(a.int), - unsafe.Pointer(&([]byte(*theirIdentityKey)[0])), - C.size_t(len(*theirIdentityKey)), - unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])), - C.size_t(len(oneTimeKeyMsg))) + unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)), + C.size_t(len(theirIdentityKeyCopy)), + unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), + C.size_t(len(oneTimeKeyMsgCopy)), + ) + runtime.KeepAlive(theirIdentityKeyCopy) + runtime.KeepAlive(oneTimeKeyMsgCopy) if r == errorVal() { return nil, s.lastError() } @@ -402,7 +422,8 @@ func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTime func (a *Account) RemoveOneTimeKeys(s olm.Session) error { r := C.olm_remove_one_time_keys( (*C.OlmAccount)(a.int), - (*C.OlmSession)(s.(*Session).int)) + (*C.OlmSession)(s.(*Session).int), + ) if r == errorVal() { return a.lastError() } diff --git a/crypto/libolm/inboundgroupsession.go b/crypto/libolm/inboundgroupsession.go index 1e25748d..d7912a7a 100644 --- a/crypto/libolm/inboundgroupsession.go +++ b/crypto/libolm/inboundgroupsession.go @@ -7,6 +7,7 @@ import "C" import ( "bytes" "encoding/base64" + "runtime" "unsafe" "maunium.net/go/mautrix/crypto/olm" @@ -67,8 +68,10 @@ func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { s := NewBlankInboundGroupSession() r := C.olm_init_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&sessionKey[0]), - C.size_t(len(sessionKey))) + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))), + C.size_t(len(sessionKey)), + ) + runtime.KeepAlive(sessionKey) if r == errorVal() { return nil, s.lastError() } @@ -86,8 +89,10 @@ func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) s := NewBlankInboundGroupSession() r := C.olm_import_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&sessionKey[0]), - C.size_t(len(sessionKey))) + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))), + C.size_t(len(sessionKey)), + ) + runtime.KeepAlive(sessionKey) if r == errorVal() { return nil, s.lastError() } @@ -104,7 +109,7 @@ func inboundGroupSessionSize() uint { func NewBlankInboundGroupSession() *InboundGroupSession { memory := make([]byte, inboundGroupSessionSize()) return &InboundGroupSession{ - int: C.olm_inbound_group_session(unsafe.Pointer(&memory[0])), + int: C.olm_inbound_group_session(unsafe.Pointer(unsafe.SliceData(memory))), mem: memory, } } @@ -139,10 +144,12 @@ func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) { pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), - unsafe.Pointer(&key[0]), + unsafe.Pointer(unsafe.SliceData(key)), C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) + unsafe.Pointer(unsafe.SliceData(pickled)), + C.size_t(len(pickled)), + ) + runtime.KeepAlive(key) if r == errorVal() { return nil, s.lastError() } @@ -157,10 +164,12 @@ func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { } r := C.olm_unpickle_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), - unsafe.Pointer(&key[0]), + unsafe.Pointer(unsafe.SliceData(key)), C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) + unsafe.Pointer(unsafe.SliceData(pickled)), + C.size_t(len(pickled)), + ) + runtime.KeepAlive(key) if r == errorVal() { return s.lastError() } @@ -226,11 +235,13 @@ func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, erro return 0, olm.EmptyInput } // olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it - message = bytes.Clone(message) + messageCopy := bytes.Clone(message) r := C.olm_group_decrypt_max_plaintext_length( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&message[0]), - C.size_t(len(message))) + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(messageCopy))), + C.size_t(len(messageCopy)), + ) + runtime.KeepAlive(messageCopy) if r == errorVal() { return 0, s.lastError() } @@ -254,17 +265,18 @@ func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { if err != nil { return nil, 0, err } - messageCopy := make([]byte, len(message)) - copy(messageCopy, message) + messageCopy := bytes.Clone(message) plaintext := make([]byte, decryptMaxPlaintextLen) var messageIndex uint32 r := C.olm_group_decrypt( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&messageCopy[0]), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(messageCopy))), C.size_t(len(messageCopy)), - (*C.uint8_t)(&plaintext[0]), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(plaintext))), C.size_t(len(plaintext)), - (*C.uint32_t)(&messageIndex)) + (*C.uint32_t)(unsafe.Pointer(&messageIndex)), + ) + runtime.KeepAlive(messageCopy) if r == errorVal() { return nil, 0, s.lastError() } @@ -281,8 +293,9 @@ func (s *InboundGroupSession) ID() id.SessionID { sessionID := make([]byte, s.sessionIdLen()) r := C.olm_inbound_group_session_id( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&sessionID[0]), - C.size_t(len(sessionID))) + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionID))), + C.size_t(len(sessionID)), + ) if r == errorVal() { panic(s.lastError()) } @@ -318,9 +331,10 @@ func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) { key := make([]byte, s.exportLen()) r := C.olm_export_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), - (*C.uint8_t)(&key[0]), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(key))), C.size_t(len(key)), - C.uint32_t(messageIndex)) + C.uint32_t(messageIndex), + ) if r == errorVal() { return nil, s.lastError() } diff --git a/crypto/libolm/outboundgroupsession.go b/crypto/libolm/outboundgroupsession.go index a21f8d4a..94df66d7 100644 --- a/crypto/libolm/outboundgroupsession.go +++ b/crypto/libolm/outboundgroupsession.go @@ -7,6 +7,7 @@ import "C" import ( "crypto/rand" "encoding/base64" + "runtime" "unsafe" "maunium.net/go/mautrix/crypto/olm" @@ -44,8 +45,10 @@ func NewOutboundGroupSession() (*OutboundGroupSession, error) { } r := C.olm_init_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(&random[0]), - C.size_t(len(random))) + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(random))), + C.size_t(len(random)), + ) + runtime.KeepAlive(random) if r == errorVal() { return nil, s.lastError() } @@ -62,7 +65,7 @@ func outboundGroupSessionSize() uint { func NewBlankOutboundGroupSession() *OutboundGroupSession { memory := make([]byte, outboundGroupSessionSize()) return &OutboundGroupSession{ - int: C.olm_outbound_group_session(unsafe.Pointer(&memory[0])), + int: C.olm_outbound_group_session(unsafe.Pointer(unsafe.SliceData(memory))), mem: memory, } } @@ -98,10 +101,12 @@ func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) { pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), - unsafe.Pointer(&key[0]), + unsafe.Pointer(unsafe.SliceData(key)), C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) + unsafe.Pointer(unsafe.SliceData(pickled)), + C.size_t(len(pickled)), + ) + runtime.KeepAlive(key) if r == errorVal() { return nil, s.lastError() } @@ -114,10 +119,13 @@ func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { } r := C.olm_unpickle_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), - unsafe.Pointer(&key[0]), + unsafe.Pointer(unsafe.SliceData(key)), C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), - C.size_t(len(pickled))) + unsafe.Pointer(unsafe.SliceData(pickled)), + C.size_t(len(pickled)), + ) + runtime.KeepAlive(pickled) + runtime.KeepAlive(key) if r == errorVal() { return s.lastError() } @@ -192,10 +200,12 @@ func (s *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { message := make([]byte, s.encryptMsgLen(len(plaintext))) r := C.olm_group_encrypt( (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(&plaintext[0]), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(plaintext))), C.size_t(len(plaintext)), - (*C.uint8_t)(&message[0]), - C.size_t(len(message))) + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(message))), + C.size_t(len(message)), + ) + runtime.KeepAlive(plaintext) if r == errorVal() { return nil, s.lastError() } @@ -212,8 +222,9 @@ func (s *OutboundGroupSession) ID() id.SessionID { sessionID := make([]byte, s.sessionIdLen()) r := C.olm_outbound_group_session_id( (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(&sessionID[0]), - C.size_t(len(sessionID))) + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionID))), + C.size_t(len(sessionID)), + ) if r == errorVal() { panic(s.lastError()) } @@ -236,8 +247,9 @@ func (s *OutboundGroupSession) Key() string { sessionKey := make([]byte, s.sessionKeyLen()) r := C.olm_outbound_group_session_key( (*C.OlmOutboundGroupSession)(s.int), - (*C.uint8_t)(&sessionKey[0]), - C.size_t(len(sessionKey))) + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(sessionKey))), + C.size_t(len(sessionKey)), + ) if r == errorVal() { panic(s.lastError()) } diff --git a/crypto/libolm/pk.go b/crypto/libolm/pk.go index db8d35c5..172f4191 100644 --- a/crypto/libolm/pk.go +++ b/crypto/libolm/pk.go @@ -14,6 +14,7 @@ import "C" import ( "crypto/rand" "encoding/json" + "runtime" "unsafe" "github.com/tidwall/sjson" @@ -63,7 +64,7 @@ func pkSigningSignatureLength() uint { func newBlankPKSigning() *PKSigning { memory := make([]byte, pkSigningSize()) return &PKSigning{ - int: C.olm_pk_signing(unsafe.Pointer(&memory[0])), + int: C.olm_pk_signing(unsafe.Pointer(unsafe.SliceData(memory))), mem: memory, } } @@ -73,9 +74,14 @@ func NewPKSigningFromSeed(seed []byte) (*PKSigning, error) { p := newBlankPKSigning() p.clear() pubKey := make([]byte, pkSigningPublicKeyLength()) - if C.olm_pk_signing_key_from_seed((*C.OlmPkSigning)(p.int), - unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)), - unsafe.Pointer(&seed[0]), C.size_t(len(seed))) == errorVal() { + r := C.olm_pk_signing_key_from_seed( + (*C.OlmPkSigning)(p.int), + unsafe.Pointer(unsafe.SliceData(pubKey)), + C.size_t(len(pubKey)), + unsafe.Pointer(unsafe.SliceData(seed)), + C.size_t(len(seed)), + ) + if r == errorVal() { return nil, p.lastError() } p.publicKey = id.Ed25519(pubKey) @@ -112,8 +118,15 @@ func (p *PKSigning) clear() { // Sign creates a signature for the given message using this key. func (p *PKSigning) Sign(message []byte) ([]byte, error) { signature := make([]byte, pkSigningSignatureLength()) - if C.olm_pk_sign((*C.OlmPkSigning)(p.int), (*C.uint8_t)(unsafe.Pointer(&message[0])), C.size_t(len(message)), - (*C.uint8_t)(unsafe.Pointer(&signature[0])), C.size_t(len(signature))) == errorVal() { + r := C.olm_pk_sign( + (*C.OlmPkSigning)(p.int), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(message))), + C.size_t(len(message)), + (*C.uint8_t)(unsafe.Pointer(unsafe.SliceData(signature))), + C.size_t(len(signature)), + ) + runtime.KeepAlive(message) + if r == errorVal() { return nil, p.lastError() } return signature, nil @@ -157,15 +170,21 @@ func pkDecryptionPublicKeySize() uint { func NewPkDecryption(privateKey []byte) (*PKDecryption, error) { memory := make([]byte, pkDecryptionSize()) p := &PKDecryption{ - int: C.olm_pk_decryption(unsafe.Pointer(&memory[0])), + int: C.olm_pk_decryption(unsafe.Pointer(unsafe.SliceData(memory))), mem: memory, } p.clear() pubKey := make([]byte, pkDecryptionPublicKeySize()) - if C.olm_pk_key_from_private((*C.OlmPkDecryption)(p.int), - unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)), - unsafe.Pointer(&privateKey[0]), C.size_t(len(privateKey))) == errorVal() { + r := C.olm_pk_key_from_private( + (*C.OlmPkDecryption)(p.int), + unsafe.Pointer(unsafe.SliceData(pubKey)), + C.size_t(len(pubKey)), + unsafe.Pointer(unsafe.SliceData(privateKey)), + C.size_t(len(privateKey)), + ) + runtime.KeepAlive(privateKey) + if r == errorVal() { return nil, p.lastError() } p.publicKey = pubKey @@ -178,14 +197,26 @@ func (p *PKDecryption) PublicKey() id.Curve25519 { } func (p *PKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) { - maxPlaintextLength := uint(C.olm_pk_max_plaintext_length((*C.OlmPkDecryption)(p.int), C.size_t(len(ciphertext)))) + maxPlaintextLength := uint(C.olm_pk_max_plaintext_length( + (*C.OlmPkDecryption)(p.int), + C.size_t(len(ciphertext)), + )) plaintext := make([]byte, maxPlaintextLength) - size := C.olm_pk_decrypt((*C.OlmPkDecryption)(p.int), - unsafe.Pointer(&ephemeralKey[0]), C.size_t(len(ephemeralKey)), - unsafe.Pointer(&mac[0]), C.size_t(len(mac)), - unsafe.Pointer(&ciphertext[0]), C.size_t(len(ciphertext)), - unsafe.Pointer(&plaintext[0]), C.size_t(len(plaintext))) + size := C.olm_pk_decrypt( + (*C.OlmPkDecryption)(p.int), + unsafe.Pointer(unsafe.SliceData(ephemeralKey)), + C.size_t(len(ephemeralKey)), + unsafe.Pointer(unsafe.SliceData(mac)), + C.size_t(len(mac)), + unsafe.Pointer(unsafe.SliceData(ciphertext)), + C.size_t(len(ciphertext)), + unsafe.Pointer(unsafe.SliceData(plaintext)), + C.size_t(len(plaintext)), + ) + runtime.KeepAlive(ephemeralKey) + runtime.KeepAlive(mac) + runtime.KeepAlive(ciphertext) if size == errorVal() { return nil, p.lastError() } diff --git a/crypto/libolm/register.go b/crypto/libolm/register.go index a423a7d0..6aaec61e 100644 --- a/crypto/libolm/register.go +++ b/crypto/libolm/register.go @@ -3,16 +3,20 @@ package libolm // #cgo LDFLAGS: -lolm -lstdc++ // #include import "C" -import "maunium.net/go/mautrix/crypto/olm" +import ( + "unsafe" + + "maunium.net/go/mautrix/crypto/olm" +) var pickleKey = []byte("maunium.net/go/mautrix/crypto/olm") func init() { olm.GetVersion = func() (major, minor, patch uint8) { C.olm_get_library_version( - (*C.uint8_t)(&major), - (*C.uint8_t)(&minor), - (*C.uint8_t)(&patch)) + (*C.uint8_t)(unsafe.Pointer(&major)), + (*C.uint8_t)(unsafe.Pointer(&minor)), + (*C.uint8_t)(unsafe.Pointer(&patch))) return 3, 2, 15 } olm.SetPickleKeyImpl = func(key []byte) { diff --git a/crypto/libolm/session.go b/crypto/libolm/session.go index 4cc22809..810dc7a6 100644 --- a/crypto/libolm/session.go +++ b/crypto/libolm/session.go @@ -23,6 +23,7 @@ import "C" import ( "crypto/rand" "encoding/base64" + "runtime" "unsafe" "maunium.net/go/mautrix/crypto/olm" @@ -68,7 +69,7 @@ func SessionFromPickled(pickled, key []byte) (*Session, error) { func NewBlankSession() *Session { memory := make([]byte, sessionSize()) return &Session{ - int: C.olm_session(unsafe.Pointer(&memory[0])), + int: C.olm_session(unsafe.Pointer(unsafe.SliceData(memory))), mem: memory, } } @@ -128,11 +129,14 @@ func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) if len(message) == 0 { return 0, olm.EmptyInput } + messageCopy := []byte(message) r := C.olm_decrypt_max_plaintext_length( (*C.OlmSession)(s.int), C.size_t(msgType), - unsafe.Pointer(C.CString(message)), - C.size_t(len(message))) + unsafe.Pointer(unsafe.SliceData((messageCopy))), + C.size_t(len(messageCopy)), + ) + runtime.KeepAlive(messageCopy) if r == errorVal() { return 0, s.lastError() } @@ -148,10 +152,11 @@ func (s *Session) Pickle(key []byte) ([]byte, error) { pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_session( (*C.OlmSession)(s.int), - unsafe.Pointer(&key[0]), + unsafe.Pointer(unsafe.SliceData(key)), C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), + unsafe.Pointer(unsafe.SliceData(pickled)), C.size_t(len(pickled))) + runtime.KeepAlive(key) if r == errorVal() { panic(s.lastError()) } @@ -166,10 +171,12 @@ func (s *Session) Unpickle(pickled, key []byte) error { } r := C.olm_unpickle_session( (*C.OlmSession)(s.int), - unsafe.Pointer(&key[0]), + unsafe.Pointer(unsafe.SliceData(key)), C.size_t(len(key)), - unsafe.Pointer(&pickled[0]), + unsafe.Pointer(unsafe.SliceData(pickled)), C.size_t(len(pickled))) + runtime.KeepAlive(pickled) + runtime.KeepAlive(key) if r == errorVal() { return s.lastError() } @@ -229,8 +236,9 @@ func (s *Session) ID() id.SessionID { sessionID := make([]byte, s.idLen()) r := C.olm_session_id( (*C.OlmSession)(s.int), - unsafe.Pointer(&sessionID[0]), - C.size_t(len(sessionID))) + unsafe.Pointer(unsafe.SliceData(sessionID)), + C.size_t(len(sessionID)), + ) if r == errorVal() { panic(s.lastError()) } @@ -259,10 +267,13 @@ func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { if len(oneTimeKeyMsg) == 0 { return false, olm.EmptyInput } + oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) r := C.olm_matches_inbound_session( (*C.OlmSession)(s.int), - unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]), - C.size_t(len(oneTimeKeyMsg))) + unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), + C.size_t(len(oneTimeKeyMsgCopy)), + ) + runtime.KeepAlive(oneTimeKeyMsgCopy) if r == 1 { return true, nil } else if r == 0 { @@ -284,12 +295,17 @@ func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg stri if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { return false, olm.EmptyInput } + theirIdentityKeyCopy := []byte(theirIdentityKey) + oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) r := C.olm_matches_inbound_session_from( (*C.OlmSession)(s.int), - unsafe.Pointer(&([]byte(theirIdentityKey))[0]), - C.size_t(len(theirIdentityKey)), - unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]), - C.size_t(len(oneTimeKeyMsg))) + unsafe.Pointer(unsafe.SliceData(theirIdentityKeyCopy)), + C.size_t(len(theirIdentityKeyCopy)), + unsafe.Pointer(unsafe.SliceData(oneTimeKeyMsgCopy)), + C.size_t(len(oneTimeKeyMsgCopy)), + ) + runtime.KeepAlive(theirIdentityKeyCopy) + runtime.KeepAlive(oneTimeKeyMsgCopy) if r == 1 { return true, nil } else if r == 0 { @@ -331,12 +347,15 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { message := make([]byte, s.encryptMsgLen(len(plaintext))) r := C.olm_encrypt( (*C.OlmSession)(s.int), - unsafe.Pointer(&plaintext[0]), + unsafe.Pointer(unsafe.SliceData(plaintext)), C.size_t(len(plaintext)), - unsafe.Pointer(&random[0]), + unsafe.Pointer(unsafe.SliceData(random)), C.size_t(len(random)), - unsafe.Pointer(&message[0]), - C.size_t(len(message))) + unsafe.Pointer(unsafe.SliceData(message)), + C.size_t(len(message)), + ) + runtime.KeepAlive(plaintext) + runtime.KeepAlive(random) if r == errorVal() { return 0, nil, s.lastError() } @@ -363,10 +382,12 @@ func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) r := C.olm_decrypt( (*C.OlmSession)(s.int), C.size_t(msgType), - unsafe.Pointer(&(messageCopy)[0]), + unsafe.Pointer(unsafe.SliceData(messageCopy)), C.size_t(len(messageCopy)), - unsafe.Pointer(&plaintext[0]), - C.size_t(len(plaintext))) + unsafe.Pointer(unsafe.SliceData(plaintext)), + C.size_t(len(plaintext)), + ) + runtime.KeepAlive(messageCopy) if r == errorVal() { return nil, s.lastError() } @@ -383,6 +404,7 @@ func (s *Session) Describe() string { C.meowlm_session_describe( (*C.OlmSession)(s.int), desc, - C.size_t(maxDescribeSize)) + C.size_t(maxDescribeSize), + ) return C.GoString(desc) } From 69869f7cb502a833ca87ee13fcd78303d69f9b4e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 11 Sep 2025 14:12:35 +0300 Subject: [PATCH 1360/1647] crypto: log active driver --- crypto/goolm/register.go | 2 ++ crypto/libolm/register.go | 2 ++ crypto/machine.go | 2 ++ crypto/olm/account.go | 2 ++ 4 files changed, 8 insertions(+) diff --git a/crypto/goolm/register.go b/crypto/goolm/register.go index 80ed206b..17e7f207 100644 --- a/crypto/goolm/register.go +++ b/crypto/goolm/register.go @@ -16,6 +16,8 @@ import ( ) func init() { + olm.Driver = "goolm" + olm.GetVersion = func() (major, minor, patch uint8) { return 3, 2, 15 } diff --git a/crypto/libolm/register.go b/crypto/libolm/register.go index 6aaec61e..06c07ea8 100644 --- a/crypto/libolm/register.go +++ b/crypto/libolm/register.go @@ -12,6 +12,8 @@ import ( var pickleKey = []byte("maunium.net/go/mautrix/crypto/olm") func init() { + olm.Driver = "libolm" + olm.GetVersion = func() (major, minor, patch uint8) { C.olm_get_library_version( (*C.uint8_t)(unsafe.Pointer(&major)), diff --git a/crypto/machine.go b/crypto/machine.go index 83ce024d..da3ebe67 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -19,6 +19,7 @@ import ( "go.mau.fi/util/exzerolog" "maunium.net/go/mautrix" + "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/ssss" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -159,6 +160,7 @@ func (mach *OlmMachine) Load(ctx context.Context) (err error) { zerolog.Ctx(ctx).Debug(). Str("machine_ptr", fmt.Sprintf("%p", mach)). Str("account_ptr", fmt.Sprintf("%p", mach.account.Internal)). + Str("olm_driver", olm.Driver). Msg("Loaded olm account") return nil } diff --git a/crypto/olm/account.go b/crypto/olm/account.go index 68393e8a..2ec5dd70 100644 --- a/crypto/olm/account.go +++ b/crypto/olm/account.go @@ -87,6 +87,8 @@ type Account interface { RemoveOneTimeKeys(s Session) error } +var Driver = "none" + var InitBlankAccount func() Account var InitNewAccount func() (Account, error) var InitNewAccountFromPickled func(pickled, key []byte) (Account, error) From 84e5d6bda1dfa792bc4098a6cf41ae25527c83e5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 11 Sep 2025 14:13:18 +0300 Subject: [PATCH 1361/1647] crypto/machine: allow canceling background context --- crypto/decryptolm.go | 3 ++- crypto/machine.go | 21 ++++++++++++++++++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index f54210a7..bd9f1753 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -272,6 +272,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession( if err != nil { log.Warn().Err(err). Hex("ciphertext_hash", ciphertextHash[:]). + Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]). Str("session_description", session.Describe()). Msg("Failed to decrypt olm message") if olmType == id.OlmMsgTypePreKey { @@ -315,7 +316,7 @@ const MinUnwedgeInterval = 1 * time.Hour func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, senderKey id.SenderKey) { log = log.With().Str("action", "unwedge olm session").Logger() - ctx := log.WithContext(mach.BackgroundCtx) + ctx := log.WithContext(mach.backgroundCtx) mach.recentlyUnwedgedLock.Lock() prevUnwedge, ok := mach.recentlyUnwedged[senderKey] delta := time.Now().Sub(prevUnwedge) diff --git a/crypto/machine.go b/crypto/machine.go index da3ebe67..eb238922 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -15,6 +15,7 @@ import ( "time" "github.com/rs/zerolog" + "go.mau.fi/util/ptr" "go.mau.fi/util/exzerolog" @@ -34,7 +35,8 @@ type OlmMachine struct { CryptoStore Store StateStore StateStore - BackgroundCtx context.Context + backgroundCtx context.Context + cancelBackgroundCtx context.CancelFunc PlaintextMentions bool AllowEncryptedState bool @@ -121,8 +123,6 @@ func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Stor CryptoStore: cryptoStore, StateStore: stateStore, - BackgroundCtx: context.Background(), - SendKeysMinTrust: id.TrustStateUnset, ShareKeysMinTrust: id.TrustStateCrossSignedTOFU, @@ -135,6 +135,7 @@ func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Stor recentlyUnwedged: make(map[id.IdentityKey]time.Time), secretListeners: make(map[string]chan<- string), } + mach.backgroundCtx, mach.cancelBackgroundCtx = context.WithCancel(context.Background()) mach.AllowKeyShare = mach.defaultAllowKeyShare return mach } @@ -147,6 +148,11 @@ func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger { return log } +func (mach *OlmMachine) SetBackgroundCtx(ctx context.Context) { + mach.cancelBackgroundCtx() + mach.backgroundCtx, mach.cancelBackgroundCtx = context.WithCancel(ctx) +} + // Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created. // This must be called before using the machine. func (mach *OlmMachine) Load(ctx context.Context) (err error) { @@ -165,6 +171,15 @@ func (mach *OlmMachine) Load(ctx context.Context) (err error) { return nil } +func (mach *OlmMachine) Destroy() { + mach.Log.Debug(). + Str("machine_ptr", fmt.Sprintf("%p", mach)). + Str("account_ptr", fmt.Sprintf("%p", ptr.Val(mach.account).Internal)). + Msg("Destroying olm machine") + mach.cancelBackgroundCtx() + mach.account = nil +} + func (mach *OlmMachine) saveAccount(ctx context.Context) error { err := mach.CryptoStore.PutAccount(ctx, mach.account) if err != nil { From c716f30959c011246078c077e00fcf858353af12 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 11 Sep 2025 14:14:08 +0300 Subject: [PATCH 1362/1647] crypto/register: don't use init in *olm packages --- crypto/goolm/account/register.go | 2 +- crypto/goolm/pk/register.go | 2 +- crypto/goolm/register.go | 14 ++++---- crypto/goolm/session/register.go | 2 +- crypto/libolm/account.go | 12 ------- crypto/libolm/inboundgroupsession.go | 15 -------- crypto/libolm/outboundgroupsession.go | 12 ------- crypto/libolm/pk.go | 10 ------ crypto/libolm/register.go | 50 ++++++++++++++++++++++++++- crypto/libolm/session.go | 9 ----- crypto/registergoolm.go | 8 ++++- crypto/registerlibolm.go | 6 +++- 12 files changed, 72 insertions(+), 70 deletions(-) diff --git a/crypto/goolm/account/register.go b/crypto/goolm/account/register.go index c6b9e523..ec392d7e 100644 --- a/crypto/goolm/account/register.go +++ b/crypto/goolm/account/register.go @@ -10,7 +10,7 @@ import ( "maunium.net/go/mautrix/crypto/olm" ) -func init() { +func Register() { olm.InitNewAccount = func() (olm.Account, error) { return NewAccount() } diff --git a/crypto/goolm/pk/register.go b/crypto/goolm/pk/register.go index b7af6a5b..0e27b568 100644 --- a/crypto/goolm/pk/register.go +++ b/crypto/goolm/pk/register.go @@ -8,7 +8,7 @@ package pk import "maunium.net/go/mautrix/crypto/olm" -func init() { +func Register() { olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) { return NewSigningFromSeed(seed) } diff --git a/crypto/goolm/register.go b/crypto/goolm/register.go index 17e7f207..800f567f 100644 --- a/crypto/goolm/register.go +++ b/crypto/goolm/register.go @@ -7,15 +7,13 @@ package goolm import ( - // Need to import these subpackages to ensure they are registered - _ "maunium.net/go/mautrix/crypto/goolm/account" - _ "maunium.net/go/mautrix/crypto/goolm/pk" - _ "maunium.net/go/mautrix/crypto/goolm/session" - + "maunium.net/go/mautrix/crypto/goolm/account" + "maunium.net/go/mautrix/crypto/goolm/pk" + "maunium.net/go/mautrix/crypto/goolm/session" "maunium.net/go/mautrix/crypto/olm" ) -func init() { +func Register() { olm.Driver = "goolm" olm.GetVersion = func() (major, minor, patch uint8) { @@ -24,4 +22,8 @@ func init() { olm.SetPickleKeyImpl = func(key []byte) { panic("gob and json encoding is deprecated and not supported with goolm") } + + account.Register() + pk.Register() + session.Register() } diff --git a/crypto/goolm/session/register.go b/crypto/goolm/session/register.go index 09ed42d4..a88d12f6 100644 --- a/crypto/goolm/session/register.go +++ b/crypto/goolm/session/register.go @@ -10,7 +10,7 @@ import ( "maunium.net/go/mautrix/crypto/olm" ) -func init() { +func Register() { // Inbound Session olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { if len(pickled) == 0 { diff --git a/crypto/libolm/account.go b/crypto/libolm/account.go index a2212ccc..f6f916e7 100644 --- a/crypto/libolm/account.go +++ b/crypto/libolm/account.go @@ -23,18 +23,6 @@ type Account struct { mem []byte } -func init() { - olm.InitNewAccount = func() (olm.Account, error) { - return NewAccount() - } - olm.InitBlankAccount = func() olm.Account { - return NewBlankAccount() - } - olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) { - return AccountFromPickled(pickled, key) - } -} - // Ensure that [Account] implements [olm.Account]. var _ olm.Account = (*Account)(nil) diff --git a/crypto/libolm/inboundgroupsession.go b/crypto/libolm/inboundgroupsession.go index d7912a7a..5606475d 100644 --- a/crypto/libolm/inboundgroupsession.go +++ b/crypto/libolm/inboundgroupsession.go @@ -21,21 +21,6 @@ type InboundGroupSession struct { mem []byte } -func init() { - olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { - return InboundGroupSessionFromPickled(pickled, key) - } - olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) { - return NewInboundGroupSession(sessionKey) - } - olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) { - return InboundGroupSessionImport(sessionKey) - } - olm.InitBlankInboundGroupSession = func() olm.InboundGroupSession { - return NewBlankInboundGroupSession() - } -} - // Ensure that [InboundGroupSession] implements [olm.InboundGroupSession]. var _ olm.InboundGroupSession = (*InboundGroupSession)(nil) diff --git a/crypto/libolm/outboundgroupsession.go b/crypto/libolm/outboundgroupsession.go index 94df66d7..646929eb 100644 --- a/crypto/libolm/outboundgroupsession.go +++ b/crypto/libolm/outboundgroupsession.go @@ -21,18 +21,6 @@ type OutboundGroupSession struct { mem []byte } -func init() { - olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { - if len(pickled) == 0 { - return nil, olm.EmptyInput - } - s := NewBlankOutboundGroupSession() - return s, s.Unpickle(pickled, key) - } - olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewOutboundGroupSession() } - olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return NewBlankOutboundGroupSession() } -} - // Ensure that [OutboundGroupSession] implements [olm.OutboundGroupSession]. var _ olm.OutboundGroupSession = (*OutboundGroupSession)(nil) diff --git a/crypto/libolm/pk.go b/crypto/libolm/pk.go index 172f4191..35532140 100644 --- a/crypto/libolm/pk.go +++ b/crypto/libolm/pk.go @@ -35,16 +35,6 @@ type PKSigning struct { // Ensure that [PKSigning] implements [olm.PKSigning]. var _ olm.PKSigning = (*PKSigning)(nil) -func init() { - olm.InitNewPKSigning = func() (olm.PKSigning, error) { return NewPKSigning() } - olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) { - return NewPKSigningFromSeed(seed) - } - olm.InitNewPKDecryptionFromPrivateKey = func(privateKey []byte) (olm.PKDecryption, error) { - return NewPkDecryption(privateKey) - } -} - func pkSigningSize() uint { return uint(C.olm_pk_signing_size()) } diff --git a/crypto/libolm/register.go b/crypto/libolm/register.go index 06c07ea8..f091d822 100644 --- a/crypto/libolm/register.go +++ b/crypto/libolm/register.go @@ -11,7 +11,7 @@ import ( var pickleKey = []byte("maunium.net/go/mautrix/crypto/olm") -func init() { +func Register() { olm.Driver = "libolm" olm.GetVersion = func() (major, minor, patch uint8) { @@ -24,4 +24,52 @@ func init() { olm.SetPickleKeyImpl = func(key []byte) { pickleKey = key } + + olm.InitNewAccount = func() (olm.Account, error) { + return NewAccount() + } + olm.InitBlankAccount = func() olm.Account { + return NewBlankAccount() + } + olm.InitNewAccountFromPickled = func(pickled, key []byte) (olm.Account, error) { + return AccountFromPickled(pickled, key) + } + + olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) { + return SessionFromPickled(pickled, key) + } + olm.InitNewBlankSession = func() olm.Session { + return NewBlankSession() + } + + olm.InitNewPKSigning = func() (olm.PKSigning, error) { return NewPKSigning() } + olm.InitNewPKSigningFromSeed = func(seed []byte) (olm.PKSigning, error) { + return NewPKSigningFromSeed(seed) + } + olm.InitNewPKDecryptionFromPrivateKey = func(privateKey []byte) (olm.PKDecryption, error) { + return NewPkDecryption(privateKey) + } + + olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { + return InboundGroupSessionFromPickled(pickled, key) + } + olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) { + return NewInboundGroupSession(sessionKey) + } + olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) { + return InboundGroupSessionImport(sessionKey) + } + olm.InitBlankInboundGroupSession = func() olm.InboundGroupSession { + return NewBlankInboundGroupSession() + } + + olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { + if len(pickled) == 0 { + return nil, olm.EmptyInput + } + s := NewBlankOutboundGroupSession() + return s, s.Unpickle(pickled, key) + } + olm.InitNewOutboundGroupSession = func() (olm.OutboundGroupSession, error) { return NewOutboundGroupSession() } + olm.InitNewBlankOutboundGroupSession = func() olm.OutboundGroupSession { return NewBlankOutboundGroupSession() } } diff --git a/crypto/libolm/session.go b/crypto/libolm/session.go index 810dc7a6..57e631c3 100644 --- a/crypto/libolm/session.go +++ b/crypto/libolm/session.go @@ -39,15 +39,6 @@ type Session struct { // Ensure that [Session] implements [olm.Session]. var _ olm.Session = (*Session)(nil) -func init() { - olm.InitSessionFromPickled = func(pickled, key []byte) (olm.Session, error) { - return SessionFromPickled(pickled, key) - } - olm.InitNewBlankSession = func() olm.Session { - return NewBlankSession() - } -} - // sessionSize is the size of a session object in bytes. func sessionSize() uint { return uint(C.olm_session_size()) diff --git a/crypto/registergoolm.go b/crypto/registergoolm.go index f5cecafc..6b5b65fd 100644 --- a/crypto/registergoolm.go +++ b/crypto/registergoolm.go @@ -2,4 +2,10 @@ package crypto -import _ "maunium.net/go/mautrix/crypto/goolm" +import ( + "maunium.net/go/mautrix/crypto/goolm" +) + +func init() { + goolm.Register() +} diff --git a/crypto/registerlibolm.go b/crypto/registerlibolm.go index ab388a5c..ef78b6b5 100644 --- a/crypto/registerlibolm.go +++ b/crypto/registerlibolm.go @@ -2,4 +2,8 @@ package crypto -import _ "maunium.net/go/mautrix/crypto/libolm" +import "maunium.net/go/mautrix/crypto/libolm" + +func init() { + libolm.Register() +} From 87fe12741427972dffcfbbdddd30a4f53f306419 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 11 Sep 2025 14:17:24 +0300 Subject: [PATCH 1363/1647] crypto/decryptolm: retry prekey decryption with goolm --- crypto/decryptolm.go | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index bd9f1753..ba3c9831 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -20,6 +20,7 @@ import ( "go.mau.fi/util/exerrors" "go.mau.fi/util/ptr" + "maunium.net/go/mautrix/crypto/goolm/account" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -171,6 +172,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U return nil, DecryptionFailedForNormalMessage } + accountBackup, err := mach.account.Internal.Pickle([]byte("tmp")) log.Trace().Msg("Trying to create inbound session") endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second) session, err := mach.createInboundSession(ctx, senderKey, ciphertext) @@ -191,13 +193,20 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U plaintext, err = session.Decrypt(ciphertext, olmType) endTimeTrace() if err != nil { - go mach.unwedgeDevice(log, sender, senderKey) log.Debug(). Hex("ciphertext_hash", ciphertextHash[:]). Hex("ciphertext_hash_repeat", ptr.Ptr(exerrors.Must(olmMessageHash(ciphertext)))[:]). Str("ciphertext", ciphertext). Str("olm_session_description", session.Describe()). Msg("DEBUG: Failed to decrypt prekey olm message with newly created session") + err2 := mach.goolmRetryHack(ctx, senderKey, ciphertext, accountBackup) + if err2 != nil { + log.Debug().Err(err2).Msg("Goolm confirmed decryption failure") + } else { + log.Warn().Msg("Goolm decryption was successful after libolm failure?") + } + + go mach.unwedgeDevice(log, sender, senderKey) return nil, fmt.Errorf("failed to decrypt olm event with session created from prekey message: %w", err) } @@ -214,6 +223,23 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U return plaintext, nil } +func (mach *OlmMachine) goolmRetryHack(ctx context.Context, senderKey id.SenderKey, ciphertext string, accountBackup []byte) error { + acc, err := account.AccountFromPickled(accountBackup, []byte("tmp")) + if err != nil { + return fmt.Errorf("failed to unpickle olm account: %w", err) + } + sess, err := acc.NewInboundSessionFrom(&senderKey, ciphertext) + if err != nil { + return fmt.Errorf("failed to create inbound session: %w", err) + } + _, err = sess.Decrypt(ciphertext, id.OlmMsgTypePreKey) + if err != nil { + // This is the expected result if libolm failed + return fmt.Errorf("failed to decrypt with new session: %w", err) + } + return nil +} + const MaxOlmSessionsPerDevice = 5 func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession( From 5dbab3ae9927fa9cabaf608aeed7661a7a3c5d62 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 11 Sep 2025 14:46:21 +0300 Subject: [PATCH 1364/1647] crypto/machine: don't clear account on Destroy() --- crypto/machine.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/machine.go b/crypto/machine.go index eb238922..ab3e4591 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -177,7 +177,7 @@ func (mach *OlmMachine) Destroy() { Str("account_ptr", fmt.Sprintf("%p", ptr.Val(mach.account).Internal)). Msg("Destroying olm machine") mach.cancelBackgroundCtx() - mach.account = nil + // TODO actually destroy something? } func (mach *OlmMachine) saveAccount(ctx context.Context) error { From 4603a344ce1daa95911106766911c52d65be3ee2 Mon Sep 17 00:00:00 2001 From: Tiago Loureiro Date: Thu, 11 Sep 2025 15:10:14 -0300 Subject: [PATCH 1365/1647] event: add org.matrix.msc3381.poll.end type (#412) --- event/type.go | 1 + 1 file changed, 1 insertion(+) diff --git a/event/type.go b/event/type.go index 3f01a067..5035f2fa 100644 --- a/event/type.go +++ b/event/type.go @@ -239,6 +239,7 @@ var ( EventUnstablePollStart = Type{Type: "org.matrix.msc3381.poll.start", Class: MessageEventType} EventUnstablePollResponse = Type{Type: "org.matrix.msc3381.poll.response", Class: MessageEventType} + EventUnstablePollEnd = Type{Type: "org.matrix.msc3381.poll.end", Class: MessageEventType} ) // Ephemeral events From 3a6f20bb623cac1b00ac07175904bb7d5af8614f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 12 Sep 2025 19:10:45 +0300 Subject: [PATCH 1366/1647] crypto/sqlstore: ignore unused sessions in olm unwedging --- crypto/decryptolm.go | 5 ++++- crypto/sql_store.go | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index ba3c9831..30cc4cfe 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -376,7 +376,10 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send return } - log.Debug().Stringer("device_id", deviceIdentity.DeviceID).Msg("Creating new Olm session") + log.Debug(). + Time("last_created", lastCreatedAt). + Stringer("device_id", deviceIdentity.DeviceID). + Msg("Creating new Olm session") mach.devicesToUnwedgeLock.Lock() mach.devicesToUnwedge[senderKey] = true mach.devicesToUnwedgeLock.Unlock() diff --git a/crypto/sql_store.go b/crypto/sql_store.go index b0625763..4405cc31 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -251,8 +251,9 @@ func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.Sender } // GetNewestSessionCreationTS gets the creation timestamp of the most recently created session with the given sender key. +// This will exclude sessions that have never been used to encrypt or decrypt a message. func (store *SQLCryptoStore) GetNewestSessionCreationTS(ctx context.Context, key id.SenderKey) (createdAt time.Time, err error) { - err = store.DB.QueryRow(ctx, "SELECT created_at FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY created_at DESC LIMIT 1", + err = store.DB.QueryRow(ctx, "SELECT created_at FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 AND (encrypted_at <> created_at OR decrypted_at <> created_at) ORDER BY created_at DESC LIMIT 1", key, store.AccountID).Scan(&createdAt) if errors.Is(err, sql.ErrNoRows) { err = nil From 717c8c3092609f156b671b50f4f5194ac2ba730a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 13 Sep 2025 01:38:06 +0300 Subject: [PATCH 1367/1647] bridgev2/database: normalize disappearing settings before insert --- bridgev2/database/disappear.go | 14 ++++++++++++++ bridgev2/portal.go | 14 ++++---------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/bridgev2/database/disappear.go b/bridgev2/database/disappear.go index 537d0552..9874e472 100644 --- a/bridgev2/database/disappear.go +++ b/bridgev2/database/disappear.go @@ -37,6 +37,20 @@ type DisappearingSetting struct { DisappearAt time.Time } +func (ds DisappearingSetting) Normalize() DisappearingSetting { + if ds.Type == event.DisappearingTypeNone { + ds.Timer = 0 + } else if ds.Timer == 0 { + ds.Type = event.DisappearingTypeNone + } + return ds +} + +func (ds DisappearingSetting) StartingAt(start time.Time) DisappearingSetting { + ds.DisappearAt = start.Add(ds.Timer) + return ds +} + func (ds DisappearingSetting) ToEventContent() *event.BeeperDisappearingTimer { if ds.Type == event.DisappearingTypeNone || ds.Timer == 0 { return &event.BeeperDisappearingTimer{} diff --git a/bridgev2/portal.go b/bridgev2/portal.go index f3797247..7961a223 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1105,13 +1105,9 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } if portal.Disappear.Type != event.DisappearingTypeNone { go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ - RoomID: portal.MXID, - EventID: message.MXID, - DisappearingSetting: database.DisappearingSetting{ - Type: portal.Disappear.Type, - Timer: portal.Disappear.Timer, - DisappearAt: message.Timestamp.Add(portal.Disappear.Timer), - }, + RoomID: portal.MXID, + EventID: message.MXID, + DisappearingSetting: portal.Disappear.StartingAt(message.Timestamp), }) } if resp.Pending { @@ -4159,9 +4155,7 @@ type UpdateDisappearingSettingOpts struct { } func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting database.DisappearingSetting, opts UpdateDisappearingSettingOpts) bool { - if setting.Timer == 0 { - setting.Type = event.DisappearingTypeNone - } + setting = setting.Normalize() if portal.Disappear.Timer == setting.Timer && portal.Disappear.Type == setting.Type { return false } From b5bec2e96c2c65f424fb7d5a813f362afc0782eb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 13 Sep 2025 13:44:46 +0300 Subject: [PATCH 1368/1647] client: stabilize support for state_after --- client.go | 2 +- responses.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 71c3fb18..edeab732 100644 --- a/client.go +++ b/client.go @@ -754,7 +754,7 @@ func (req *ReqSync) BuildQuery() map[string]string { query["full_state"] = "true" } if req.UseStateAfter { - query["org.matrix.msc4222.use_state_after"] = "true" + query["use_state_after"] = "true" } if req.BeeperStreaming { query["com.beeper.streaming"] = "true" diff --git a/responses.go b/responses.go index 8ab78373..e2627724 100644 --- a/responses.go +++ b/responses.go @@ -397,7 +397,7 @@ type BeeperInboxPreviewEvent struct { type SyncJoinedRoom struct { Summary LazyLoadSummary `json:"summary"` State SyncEventsList `json:"state"` - StateAfter *SyncEventsList `json:"org.matrix.msc4222.state_after,omitempty"` + StateAfter *SyncEventsList `json:"state_after,omitempty"` Timeline SyncTimeline `json:"timeline"` Ephemeral SyncEventsList `json:"ephemeral"` AccountData SyncEventsList `json:"account_data"` From c37ddcc3a5e9bba4525bc650d593a7b5ce700fd8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Sep 2025 14:45:37 +0300 Subject: [PATCH 1369/1647] Bump version to v0.25.1 --- CHANGELOG.md | 43 +++++++++++++++++++++++++++++++++++++++++++ go.mod | 22 +++++++++++----------- go.sum | 40 ++++++++++++++++++++-------------------- version.go | 2 +- 4 files changed, 75 insertions(+), 32 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f8a15550..5c33645f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,46 @@ +## v0.25.1 (2025-09-16) + +* *(client)* Fixed HTTP method of delete devices API call + (thanks to [@fmseals] in [#393]). +* *(client)* Added wrappers for [MSC4323]: User suspension & locking endpoints + (thanks to [@nexy7574] in [#407]). +* *(client)* Stabilized support for extensible profiles. +* *(client)* Stabilized support for `state_after` in sync. +* *(client)* Removed deprecated MSC2716 requests. +* *(crypto)* Added fallback to ensure `m.relates_to` is always copied even if + the content struct doesn't implement `Relatable`. +* *(crypto)* Changed olm unwedging to ignore newly created sessions if they + haven't been used successfully in either direction. +* *(federation)* Added utilities for generating, parsing, validating and + authorizing PDUs. + * Note: the new PDU code depends on `GOEXPERIMENT=jsonv2` +* *(event)* Added `is_animated` flag from [MSC4230] to file info. +* *(event)* Added types for [MSC4332]: In-room bot commands. +* *(event)* Added missing poll end event type for [MSC3381]. +* *(appservice)* Fixed URLs not being escaped properly when using unix socket + for homeserver connections. +* *(format)* Added more helpers for forming markdown links. +* *(event,bridgev2)* Added support for Beeper's disappearing message state event. +* *(bridgev2)* Redesigned group creation interface and added support in commands + and provisioning API. +* *(bridgev2)* Added GetEvent to Matrix interface to allow network connectors to + get an old event. The method is best effort only, as some configurations don't + allow fetching old events. +* *(bridgev2)* Added shared logic for provisioning that can be reused by the + API, commands and other sources. +* *(bridgev2)* Fixed mentions and URL previews not being copied over when + caption and media are merged. +* *(bridgev2)* Removed config option to change provisioning API prefix, which + had already broken in the previous release. + +[@fmseals]: https://github.com/fmseals +[#393]: https://github.com/mautrix/go/pull/393 +[#407]: https://github.com/mautrix/go/pull/407 +[MSC3381]: https://github.com/matrix-org/matrix-spec-proposals/pull/3381 +[MSC4230]: https://github.com/matrix-org/matrix-spec-proposals/pull/4230 +[MSC4323]: https://github.com/matrix-org/matrix-spec-proposals/pull/4323 +[MSC4332]: https://github.com/matrix-org/matrix-spec-proposals/pull/4332 + ## v0.25.0 (2025-08-16) * Bumped minimum Go version to 1.24. diff --git a/go.mod b/go.mod index 4abdc4ff..751e8015 100644 --- a/go.mod +++ b/go.mod @@ -2,27 +2,27 @@ module maunium.net/go/mautrix go 1.24.0 -toolchain go1.25.0 +toolchain go1.25.1 require ( filippo.io/edwards25519 v1.1.0 github.com/chzyer/readline v1.5.1 - github.com/coder/websocket v1.8.13 + github.com/coder/websocket v1.8.14 github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.32 github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.34.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e - github.com/stretchr/testify v1.10.0 + github.com/stretchr/testify v1.11.1 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.13 - go.mau.fi/util v0.9.0 + go.mau.fi/util v0.9.1 go.mau.fi/zeroconfig v0.2.0 - golang.org/x/crypto v0.41.0 - golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 - golang.org/x/net v0.43.0 - golang.org/x/sync v0.16.0 + golang.org/x/crypto v0.42.0 + golang.org/x/exp v0.0.0-20250911091902-df9299821621 + golang.org/x/net v0.44.0 + golang.org/x/sync v0.17.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) @@ -32,11 +32,11 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/petermattis/goid v0.0.0-20250813065127-a731cc31b4fe // indirect + github.com/petermattis/goid v0.0.0-20250904145737-900bdf8bb490 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - golang.org/x/sys v0.35.0 // indirect - golang.org/x/text v0.28.0 // indirect + golang.org/x/sys v0.36.0 // indirect + golang.org/x/text v0.29.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index bb5d5cdb..dafb9600 100644 --- a/go.sum +++ b/go.sum @@ -8,8 +8,8 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= -github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= -github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -26,8 +26,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/petermattis/goid v0.0.0-20250813065127-a731cc31b4fe h1:vHpqOnPlnkba8iSxU4j/CvDSS9J4+F4473esQsYLGoE= -github.com/petermattis/goid v0.0.0-20250813065127-a731cc31b4fe/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/petermattis/goid v0.0.0-20250904145737-900bdf8bb490 h1:QTvNkZ5ylY0PGgA+Lih+GdboMLY/G9SEGLMEGVjTVA4= +github.com/petermattis/goid v0.0.0-20250904145737-900bdf8bb490/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -37,8 +37,8 @@ github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -51,26 +51,26 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.0 h1:ya3s3pX+Y8R2fgp0DbE7a0o3FwncoelDX5iyaeVE8ls= -go.mau.fi/util v0.9.0/go.mod h1:pdL3lg2aaeeHIreGXNnPwhJPXkXdc3ZxsI6le8hOWEA= +go.mau.fi/util v0.9.1 h1:A+XKHRsjKkFi2qOm4RriR1HqY2hoOXNS3WFHaC89r2Y= +go.mau.fi/util v0.9.1/go.mod h1:M0bM9SyaOWJniaHs9hxEzz91r5ql6gYq6o1q5O1SsjQ= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= -golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= -golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= -golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 h1:SbTAbRFnd5kjQXbczszQ0hdk3ctwYf3qBNH9jIsGclE= -golang.org/x/exp v0.0.0-20250813145105-42675adae3e6/go.mod h1:4QTo5u+SEIbbKW1RacMZq1YEfOBqeXa19JeshGi+zc4= -golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= -golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= +golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= +golang.org/x/exp v0.0.0-20250911091902-df9299821621 h1:2id6c1/gto0kaHYyrixvknJ8tUK/Qs5IsmBtrc+FtgU= +golang.org/x/exp v0.0.0-20250911091902-df9299821621/go.mod h1:TwQYMMnGpvZyc+JpB/UAuTNIsVJifOlSkrZkhcvpVUk= +golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I= +golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= -golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= -golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= +golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= +golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/version.go b/version.go index fd0d0a8d..4821a354 100644 --- a/version.go +++ b/version.go @@ -7,7 +7,7 @@ import ( "strings" ) -const Version = "v0.25.0" +const Version = "v0.25.1" var GoModVersion = "" var Commit = "" From 5af25d2eb7371805df62a47dc52f2b72e956ecac Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Sep 2025 18:02:14 +0300 Subject: [PATCH 1370/1647] event/poll: add missing omitempty --- event/poll.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/event/poll.go b/event/poll.go index 47131a8f..9082f65e 100644 --- a/event/poll.go +++ b/event/poll.go @@ -35,7 +35,7 @@ type MSC1767Message struct { } type PollStartEventContent struct { - RelatesTo *RelatesTo `json:"m.relates_to"` + RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` Mentions *Mentions `json:"m.mentions,omitempty"` PollStart struct { Kind string `json:"kind"` From af2e6c7ce0dede45f4829689061c14db61d5a065 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 17 Sep 2025 14:47:09 +0300 Subject: [PATCH 1371/1647] bridgev2/portal: ensure state key is set when handling state events --- bridgev2/errors.go | 1 + bridgev2/portal.go | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index 026a95f4..52bebe81 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -55,6 +55,7 @@ var ( ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage() ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrInvalidStateKey error = WrapErrorInStatus(errors.New("room metadata state key is unset or non-empty")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 7961a223..39b8272b 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1449,6 +1449,9 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( evt *event.Event, fn func(APIType, context.Context, *MatrixRoomMeta[ContentType]) (bool, error), ) EventHandlingResult { + if evt.StateKey == nil || *evt.StateKey != "" { + return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) + } api, ok := sender.Client.(APIType) if !ok { return EventHandlingResultIgnored.WithMSSError(ErrRoomMetadataNotSupported) @@ -1583,6 +1586,9 @@ func (portal *Portal) handleMatrixMembership( origSender *OrigSender, evt *event.Event, ) EventHandlingResult { + if evt.StateKey == nil { + return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) + } log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.MemberEventContent) if !ok { @@ -1667,6 +1673,9 @@ func (portal *Portal) handleMatrixPowerLevels( origSender *OrigSender, evt *event.Event, ) EventHandlingResult { + if evt.StateKey == nil || *evt.StateKey != "" { + return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) + } log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent) if !ok { From e6a1fa6fd7bd906a176abc914456a5f1680a8b9b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 17 Sep 2025 14:18:43 +0200 Subject: [PATCH 1372/1647] bridgev2/provisioning: sync ghost info when searching (#413) --- bridgev2/provisionutil/listcontacts.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/bridgev2/provisionutil/listcontacts.go b/bridgev2/provisionutil/listcontacts.go index d2cf5e90..ce163e67 100644 --- a/bridgev2/provisionutil/listcontacts.go +++ b/bridgev2/provisionutil/listcontacts.go @@ -34,7 +34,7 @@ func GetContactList(ctx context.Context, login *bridgev2.UserLogin) (*RespGetCon return nil, err } return &RespGetContactList{ - Contacts: processResolveIdentifiers(ctx, login.Bridge, resp), + Contacts: processResolveIdentifiers(ctx, login.Bridge, resp, false), }, nil } @@ -49,11 +49,11 @@ func SearchUsers(ctx context.Context, login *bridgev2.UserLogin, query string) ( return nil, err } return &RespSearchUsers{ - Results: processResolveIdentifiers(ctx, login.Bridge, resp), + Results: processResolveIdentifiers(ctx, login.Bridge, resp, true), }, nil } -func processResolveIdentifiers(ctx context.Context, br *bridgev2.Bridge, resp []*bridgev2.ResolveIdentifierResponse) (apiResp []*RespResolveIdentifier) { +func processResolveIdentifiers(ctx context.Context, br *bridgev2.Bridge, resp []*bridgev2.ResolveIdentifierResponse, syncInfo bool) (apiResp []*RespResolveIdentifier) { apiResp = make([]*RespResolveIdentifier, len(resp)) for i, contact := range resp { apiContact := &RespResolveIdentifier{ @@ -69,6 +69,9 @@ func processResolveIdentifiers(ctx context.Context, br *bridgev2.Bridge, resp [] } } if contact.Ghost != nil { + if syncInfo && contact.UserInfo != nil { + contact.Ghost.UpdateInfo(ctx, contact.UserInfo) + } if contact.Ghost.Name != "" { apiContact.Name = contact.Ghost.Name } From 35ac4fcb8d91b38463521551a30040e1217678c4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 17 Sep 2025 21:45:30 +0300 Subject: [PATCH 1373/1647] bridgev2/matrix: don't encrypt reactions in batch sends --- bridgev2/matrix/connector.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index c5ee40fe..ab1764dd 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -643,7 +643,7 @@ func (br *Connector) BatchSend(ctx context.Context, roomID id.RoomID, req *mautr if intent != nil { intent.AddDoublePuppetValueWithTS(&evt.Content, evt.Timestamp) } - if evt.Type != event.EventEncrypted { + if evt.Type != event.EventEncrypted && evt.Type != event.EventReaction { err = br.Crypto.Encrypt(ctx, roomID, evt.Type, &evt.Content) if err != nil { return nil, err From 5b860f8bfb3c26c39e73db0ebf9fd816ae43c745 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 17 Sep 2025 22:30:16 +0300 Subject: [PATCH 1374/1647] responses: fix marshaling RespUserProfile --- responses.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/responses.go b/responses.go index e2627724..4d66cdb8 100644 --- a/responses.go +++ b/responses.go @@ -210,7 +210,7 @@ func (r *RespUserProfile) MarshalJSON() ([]byte, error) { } else { delete(marshalMap, "avatar_url") } - return json.Marshal(r.Extra) + return json.Marshal(marshalMap) } type RespMutualRooms struct { From e932aff2090f56bc1fbb07832cb2e150afd1a553 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 17 Sep 2025 22:30:32 +0300 Subject: [PATCH 1375/1647] crypto/ssss: use constant time comparison when decrypting account data --- crypto/ssss/key.go | 12 ++++++++++-- crypto/ssss/meta_test.go | 17 +++-------------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/crypto/ssss/key.go b/crypto/ssss/key.go index aa22360a..cd8e3fce 100644 --- a/crypto/ssss/key.go +++ b/crypto/ssss/key.go @@ -7,6 +7,8 @@ package ssss import ( + "crypto/hmac" + "crypto/sha256" "encoding/base64" "fmt" "strings" @@ -108,12 +110,18 @@ func (key *Key) Decrypt(eventType string, data EncryptedKeyData) ([]byte, error) return nil, err } + mac, err := base64.RawStdEncoding.DecodeString(strings.TrimRight(data.MAC, "=")) + if err != nil { + return nil, err + } + // derive the AES and HMAC keys for the requested event type using the SSSS key aesKey, hmacKey := utils.DeriveKeysSHA256(key.Key, eventType) // compare the stored MAC with the one we calculated from the ciphertext - calcMac := utils.HMACSHA256B64(payload, hmacKey) - if strings.TrimRight(data.MAC, "=") != calcMac { + h := hmac.New(sha256.New, hmacKey[:]) + h.Write(payload) + if !hmac.Equal(h.Sum(nil), mac) { return nil, ErrKeyDataMACMismatch } diff --git a/crypto/ssss/meta_test.go b/crypto/ssss/meta_test.go index 4f2ff378..7a5ef8b9 100644 --- a/crypto/ssss/meta_test.go +++ b/crypto/ssss/meta_test.go @@ -12,6 +12,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "go.mau.fi/util/exerrors" "maunium.net/go/mautrix/crypto/ssss" ) @@ -70,23 +71,11 @@ func getKeyMeta(meta string) *ssss.KeyMetadata { } func getKey1() *ssss.Key { - km := getKeyMeta(key1Meta) - key, err := km.VerifyRecoveryKey(key1ID, key1RecoveryKey) - if err != nil { - panic(err) - } - key.ID = key1ID - return key + return exerrors.Must(getKeyMeta(key1Meta).VerifyRecoveryKey(key1ID, key1RecoveryKey)) } func getKey2() *ssss.Key { - km := getKeyMeta(key2Meta) - key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - if err != nil { - panic(err) - } - key.ID = key2ID - return key + return exerrors.Must(getKeyMeta(key2Meta).VerifyRecoveryKey(key2ID, key2RecoveryKey)) } func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) { From e19d009d59ef914d76f6f3b2729fa5073084d1cb Mon Sep 17 00:00:00 2001 From: Tiago Loureiro Date: Thu, 18 Sep 2025 11:07:13 -0300 Subject: [PATCH 1376/1647] event: add EventUnstablePollEnd to GuessClass() (#414) --- event/type.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/event/type.go b/event/type.go index 5035f2fa..1b4fbf76 100644 --- a/event/type.go +++ b/event/type.go @@ -128,7 +128,7 @@ func (et *Type) GuessClass() TypeClass { InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type, CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type, CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type, - BeeperTranscription.Type: + EventUnstablePollEnd.Type, BeeperTranscription.Type: return MessageEventType case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type, ToDeviceBeeperRoomKeyAck.Type: From 8780c2eb449948f7ccaffffbbfe6c081c26266d1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Sep 2025 13:22:58 +0300 Subject: [PATCH 1377/1647] bridgev2/portal: set exclude from timeline flag for creation state --- bridgev2/portal.go | 41 ++++++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 39b8272b..bbe41e02 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -4444,8 +4444,6 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo req := mautrix.ReqCreateRoom{ Visibility: "private", - Name: portal.Name, - Topic: portal.Topic, CreationContent: make(map[string]any), InitialState: make([]*event.Event, 0, 6), Preset: "private_chat", @@ -4488,26 +4486,47 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo StateKey: &bridgeInfoStateKey, Type: event.StateBeeperRoomFeatures, Content: event.Content{Parsed: roomFeatures}, + }, &event.Event{ + Type: event.StateTopic, + Content: event.Content{ + Parsed: &event.TopicEventContent{Topic: portal.Topic}, + Raw: map[string]any{ + "com.beeper.exclude_from_timeline": true, + }, + }, }) if roomFeatures.DisappearingTimer != nil { req.InitialState = append(req.InitialState, &event.Event{ - Type: event.StateBeeperDisappearingTimer, - Content: event.Content{Parsed: portal.Disappear.ToEventContent()}, + Type: event.StateBeeperDisappearingTimer, + Content: event.Content{ + Parsed: portal.Disappear.ToEventContent(), + Raw: map[string]any{ + "com.beeper.exclude_from_timeline": true, + }, + }, }) portal.CapState.Flags |= database.CapStateFlagDisappearingTimerSet } - if req.Topic == "" { - // Add explicit topic event if topic is empty to ensure the event is set. - // This ensures that there won't be an extra event later if PUT /state/... is called. + if portal.Name != "" { req.InitialState = append(req.InitialState, &event.Event{ - Type: event.StateTopic, - Content: event.Content{Parsed: &event.TopicEventContent{Topic: ""}}, + Type: event.StateRoomName, + Content: event.Content{ + Parsed: &event.RoomNameEventContent{Name: portal.Name}, + Raw: map[string]any{ + "com.beeper.exclude_from_timeline": true, + }, + }, }) } if portal.AvatarMXC != "" { req.InitialState = append(req.InitialState, &event.Event{ - Type: event.StateRoomAvatar, - Content: event.Content{Parsed: &event.RoomAvatarEventContent{URL: portal.AvatarMXC}}, + Type: event.StateRoomAvatar, + Content: event.Content{ + Parsed: &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, + Raw: map[string]any{ + "com.beeper.exclude_from_timeline": true, + }, + }, }) } if portal.Parent != nil && portal.Parent.MXID != "" { From b760023dcaa3770d44557c4a9f99ad1dfa27a07f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Sep 2025 14:26:21 +0300 Subject: [PATCH 1378/1647] bridgev2/portal: add support for implicit read receipts to network --- bridgev2/networkinterface.go | 6 ++++ bridgev2/portal.go | 59 +++++++++++++++++++++++++++--------- 2 files changed, 50 insertions(+), 15 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 8293be51..fa87086a 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -346,6 +346,10 @@ type NetworkGeneralCapabilities struct { // Should the bridge re-request user info on incoming messages even if the ghost already has info? // By default, info is only requested for ghosts with no name, and other updating is left to events. AggressiveUpdateInfo bool + // Should the bridge call HandleMatrixReadReceipt with fake data when receiving a new message? + // This should be enabled if the network requires each message to be marked as read independently, + // and doesn't automatically do it when sending a message. + ImplicitReadReceipts bool // If the bridge uses the pending message mechanism ([MatrixMessage.AddPendingToSave]) // to handle asynchronous message responses, this field can be set to enable // automatic timeout errors in case the asynchronous response never arrives. @@ -1361,6 +1365,8 @@ type MatrixReadReceipt struct { LastRead time.Time // The receipt metadata. Receipt event.ReadReceipt + // Whether the receipt is implicit, i.e. triggered by an incoming timeline event rather than an explicit receipt. + Implicit bool } type MatrixTyping struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index bbe41e02..8924566f 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -587,7 +587,7 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * // Tombstones aren't bridged so they don't need a login return portal.handleMatrixTombstone(ctx, evt) } - login, _, err := portal.FindPreferredLogin(ctx, sender, true) + login, userPortal, err := portal.FindPreferredLogin(ctx, sender, true) if err != nil { log.Err(err).Msg("Failed to get user login to handle Matrix event") if errors.Is(err, ErrNotLoggedIn) { @@ -646,6 +646,21 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * } // Copy logger because many of the handlers will use UpdateContext ctx = log.With().Str("login_id", string(login.ID)).Logger().WithContext(ctx) + + if origSender == nil && portal.Bridge.Network.GetCapabilities().ImplicitReadReceipts { + rrLog := log.With().Str("subaction", "implicit read receipt").Logger() + rrCtx := rrLog.WithContext(ctx) + rrLog.Debug().Msg("Sending implicit read receipt for event") + evtTS := time.UnixMilli(evt.Timestamp) + portal.callReadReceiptHandler(rrCtx, login, nil, &MatrixReadReceipt{ + Portal: portal, + EventID: evt.ID, + Implicit: true, + ReadUpTo: evtTS, + Receipt: event.ReadReceipt{Timestamp: evtTS}, + }, userPortal) + } + switch evt.Type { case event.EventMessage, event.EventSticker, event.EventUnstablePollStart, event.EventUnstablePollResponse: return portal.handleMatrixMessage(ctx, login, origSender, evt) @@ -735,15 +750,10 @@ func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, e EventID: eventID, Receipt: receipt, } - if userPortal == nil { - userPortal = database.UserPortalFor(login.UserLogin, portal.PortalKey) - } else { - evt.LastRead = userPortal.LastRead - userPortal = userPortal.CopyWithoutValues() - } evt.ExactMessage, err = portal.Bridge.DB.Message.GetPartByMXID(ctx, eventID) if err != nil { log.Err(err).Msg("Failed to get exact message from database") + evt.ReadUpTo = receipt.Timestamp } else if evt.ExactMessage != nil { log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("exact_message_id", string(evt.ExactMessage.ID)).Time("exact_message_ts", evt.ExactMessage.Timestamp) @@ -752,19 +762,38 @@ func (portal *Portal) handleMatrixReadReceipt(ctx context.Context, user *User, e } else { evt.ReadUpTo = receipt.Timestamp } - err = rrClient.HandleMatrixReadReceipt(ctx, evt) + portal.callReadReceiptHandler(ctx, login, rrClient, evt, userPortal) +} + +func (portal *Portal) callReadReceiptHandler( + ctx context.Context, + login *UserLogin, + rrClient ReadReceiptHandlingNetworkAPI, + evt *MatrixReadReceipt, + userPortal *database.UserPortal, +) { + if rrClient == nil { + var ok bool + rrClient, ok = login.Client.(ReadReceiptHandlingNetworkAPI) + if !ok { + return + } + } + if userPortal == nil { + userPortal = database.UserPortalFor(login.UserLogin, portal.PortalKey) + } else { + evt.LastRead = userPortal.LastRead + userPortal = userPortal.CopyWithoutValues() + } + err := rrClient.HandleMatrixReadReceipt(ctx, evt) if err != nil { - log.Err(err).Msg("Failed to handle read receipt") + zerolog.Ctx(ctx).Err(err).Msg("Failed to handle read receipt") return } - if evt.ExactMessage != nil { - userPortal.LastRead = evt.ExactMessage.Timestamp - } else { - userPortal.LastRead = receipt.Timestamp - } + userPortal.LastRead = evt.ReadUpTo err = portal.Bridge.DB.UserPortal.Put(ctx, userPortal) if err != nil { - log.Err(err).Msg("Failed to save user portal metadata") + zerolog.Ctx(ctx).Err(err).Msg("Failed to save user portal metadata") } portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) } From 6acb04aa1e9aa21e361de21b28a7ab65739ea163 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Sep 2025 19:15:02 +0300 Subject: [PATCH 1379/1647] federation/pdu: use option to trust internal metadata for GetEventID --- federation/pdu/hash.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/federation/pdu/hash.go b/federation/pdu/hash.go index 050029df..38ef83e9 100644 --- a/federation/pdu/hash.go +++ b/federation/pdu/hash.go @@ -72,7 +72,12 @@ func (pdu *PDU) GetRoomID() (id.RoomID, error) { } } +var UseInternalMetaForGetEventID = false + func (pdu *PDU) GetEventID(roomVersion id.RoomVersion) (id.EventID, error) { + if UseInternalMetaForGetEventID && pdu.InternalMeta.EventID != "" { + return pdu.InternalMeta.EventID, nil + } return pdu.calculateEventID(roomVersion, '$') } From 2240aa0267fbf9b651a8604598d3d1842f24126f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Sep 2025 19:50:41 +0300 Subject: [PATCH 1380/1647] bridgev2/portal: log if room create event is taking long --- bridgev2/portal.go | 82 +++++++++++++++++++++++----------------------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 8924566f..d2a02188 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -329,55 +329,55 @@ func (portal *Portal) eventLoop() { func (portal *Portal) handleSingleEventAsync(idx int, rawEvt any) (outerRes EventHandlingResult) { ctx := portal.getEventCtxWithLog(rawEvt, idx) - if _, isCreate := rawEvt.(*portalCreateEvent); isCreate { - portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) { - outerRes = res - }) - } else if portal.Bridge.Config.AsyncEvents { - outerRes = EventHandlingResultQueued + if portal.Bridge.Config.AsyncEvents { go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) {}) - } else { - log := zerolog.Ctx(ctx) - doneCh := make(chan struct{}) - var backgrounded atomic.Bool - start := time.Now() - var handleDuration time.Duration - // Note: this will not set the success flag if the handler times out - outerRes = EventHandlingResult{Queued: true} - go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) { - outerRes = res - handleDuration = time.Since(start) - close(doneCh) - if backgrounded.Load() { + return EventHandlingResultQueued + } + log := zerolog.Ctx(ctx) + doneCh := make(chan struct{}) + var backgrounded atomic.Bool + start := time.Now() + var handleDuration time.Duration + // Note: this will not set the success flag if the handler times out + outerRes = EventHandlingResult{Queued: true} + go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) { + outerRes = res + handleDuration = time.Since(start) + close(doneCh) + if backgrounded.Load() { + log.Debug(). + Time("started_at", start). + Stringer("duration", handleDuration). + Msg("Event that took too long finally finished handling") + } + }) + tick := time.NewTicker(30 * time.Second) + _, isCreate := rawEvt.(*portalCreateEvent) + defer tick.Stop() + for i := 0; i < 10; i++ { + select { + case <-doneCh: + if i > 0 { log.Debug(). Time("started_at", start). Stringer("duration", handleDuration). - Msg("Event that took too long finally finished handling") + Msg("Event that took long finished handling") } - }) - tick := time.NewTicker(30 * time.Second) - defer tick.Stop() - for i := 0; i < 10; i++ { - select { - case <-doneCh: - if i > 0 { - log.Debug(). - Time("started_at", start). - Stringer("duration", handleDuration). - Msg("Event that took long finished handling") - } - return - case <-tick.C: - log.Warn(). - Time("started_at", start). - Msg("Event handling is taking long") + return + case <-tick.C: + log.Warn(). + Time("started_at", start). + Msg("Event handling is taking long") + if isCreate { + // Never background portal creation events + i = 1 } } - log.Warn(). - Time("started_at", start). - Msg("Event handling is taking too long, continuing in background") - backgrounded.Store(true) } + log.Warn(). + Time("started_at", start). + Msg("Event handling is taking too long, continuing in background") + backgrounded.Store(true) return } From b42fb5096aab3fe5e8f3c462433745030e17ba54 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Sep 2025 19:53:22 +0300 Subject: [PATCH 1381/1647] bridgev2/portal: also log long events when using async events --- bridgev2/portal.go | 14 +++++++------- bridgev2/portalinternal.go | 8 ++++++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index d2a02188..f1c06171 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -291,7 +291,7 @@ func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) EventHand portal.eventsLock.Lock() defer portal.eventsLock.Unlock() portal.eventIdx++ - return portal.handleSingleEventAsync(portal.eventIdx, evt) + return portal.handleSingleEventWithDelayLogging(portal.eventIdx, evt) } else { select { case portal.events <- evt: @@ -323,16 +323,16 @@ func (portal *Portal) eventLoop() { i := 0 for rawEvt := range portal.events { i++ - portal.handleSingleEventAsync(i, rawEvt) + if portal.Bridge.Config.AsyncEvents { + go portal.handleSingleEventWithDelayLogging(i, rawEvt) + } else { + portal.handleSingleEventWithDelayLogging(i, rawEvt) + } } } -func (portal *Portal) handleSingleEventAsync(idx int, rawEvt any) (outerRes EventHandlingResult) { +func (portal *Portal) handleSingleEventWithDelayLogging(idx int, rawEvt any) (outerRes EventHandlingResult) { ctx := portal.getEventCtxWithLog(rawEvt, idx) - if portal.Bridge.Config.AsyncEvents { - go portal.handleSingleEvent(ctx, rawEvt, func(res EventHandlingResult) {}) - return EventHandlingResultQueued - } log := zerolog.Ctx(ctx) doneCh := make(chan struct{}) var backgrounded atomic.Bool diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index 0223b4f2..ddbadc76 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -37,8 +37,8 @@ func (portal *PortalInternals) EventLoop() { (*Portal)(portal).eventLoop() } -func (portal *PortalInternals) HandleSingleEventAsync(idx int, rawEvt any) (outerRes EventHandlingResult) { - return (*Portal)(portal).handleSingleEventAsync(idx, rawEvt) +func (portal *PortalInternals) HandleSingleEventWithDelayLogging(idx int, rawEvt any) (outerRes EventHandlingResult) { + return (*Portal)(portal).handleSingleEventWithDelayLogging(idx, rawEvt) } func (portal *PortalInternals) GetEventCtxWithLog(rawEvt any, idx int) context.Context { @@ -73,6 +73,10 @@ func (portal *PortalInternals) HandleMatrixReadReceipt(ctx context.Context, user (*Portal)(portal).handleMatrixReadReceipt(ctx, user, eventID, receipt) } +func (portal *PortalInternals) CallReadReceiptHandler(ctx context.Context, login *UserLogin, rrClient ReadReceiptHandlingNetworkAPI, evt *MatrixReadReceipt, userPortal *database.UserPortal) { + (*Portal)(portal).callReadReceiptHandler(ctx, login, rrClient, evt, userPortal) +} + func (portal *PortalInternals) HandleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult { return (*Portal)(portal).handleMatrixTyping(ctx, evt) } From 9fbf1b85981027257d771ba895f24510f601d404 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Sep 2025 20:26:55 +0300 Subject: [PATCH 1382/1647] bridgev2: make split portal migration errors fatal --- bridgev2/bridge.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 24619c79..fe7bd107 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -9,6 +9,7 @@ package bridgev2 import ( "context" "fmt" + "os" "sync" "time" @@ -279,7 +280,8 @@ func (br *Bridge) MigrateToSplitPortals(ctx context.Context) bool { } affected, err := br.DB.Portal.MigrateToSplitPortals(ctx) if err != nil { - log.Err(err).Msg("Failed to migrate portals") + log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to migrate portals") + os.Exit(31) return false } log.Info().Int64("rows_affected", affected).Msg("Migrated to split portals") From f7bfa885c9c2299dfb83eb9c5ab51afbdeec05e8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Sep 2025 20:42:38 +0300 Subject: [PATCH 1383/1647] bridgev2: improve split portal migration --- bridgev2/bridge.go | 26 +++++++++++++++++++ bridgev2/database/portal.go | 30 +++++++++++++++------- bridgev2/matrix/intent.go | 3 +++ bridgev2/matrix/mxmain/example-config.yaml | 1 + 4 files changed, 51 insertions(+), 9 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index fe7bd107..5a0ae30c 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -285,7 +285,33 @@ func (br *Bridge) MigrateToSplitPortals(ctx context.Context) bool { return false } log.Info().Int64("rows_affected", affected).Msg("Migrated to split portals") + withoutReceiver, err := br.DB.Portal.GetAllWithoutReceiver(ctx) + if err != nil { + log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to get portals that failed to migrate") + os.Exit(32) + return false + } + log.Info().Int("remaining_portals", len(withoutReceiver)).Msg("Deleting remaining portals without receiver") + for _, portal := range withoutReceiver { + if err = br.DB.Portal.Delete(ctx, portal.PortalKey); err != nil { + log.Err(err). + Str("portal_id", string(portal.ID)). + Stringer("mxid", portal.MXID). + Msg("Failed to delete portal database row that failed to migrate") + } else if err = br.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil { + log.Err(err). + Str("portal_id", string(portal.ID)). + Stringer("mxid", portal.MXID). + Msg("Failed to delete portal room that failed to migrate") + } else { + log.Debug(). + Str("portal_id", string(portal.ID)). + Stringer("mxid", portal.MXID). + Msg("Deleted portal that wasn't updated by split portal migration query") + } + } br.DB.KV.Set(ctx, database.KeySplitPortalsEnabled, "true") + log.Info().Msg("Finished split portal migration successfully") return affected > 0 } diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index c3aa7121..8570d840 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -88,6 +88,7 @@ const ( getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')` getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL` + getAllPortalsWithoutReceiver = getPortalBaseQuery + `WHERE bridge_id=$1 AND receiver=''` getAllDMPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND other_user_id=$2` getAllPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1` getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2 AND parent_receiver=$3` @@ -123,15 +124,22 @@ const ( reIDPortalQuery = `UPDATE portal SET id=$4, receiver=$5 WHERE bridge_id=$1 AND id=$2 AND receiver=$3` migrateToSplitPortalsQuery = ` UPDATE portal - SET receiver=COALESCE(( - SELECT login_id - FROM user_portal - WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver='' - LIMIT 1 - ), ( - SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1 - ), '') - WHERE receiver='' AND bridge_id=$1 + SET receiver=new_receiver + FROM ( + SELECT bridge_id, id, COALESCE(( + SELECT login_id + FROM user_portal + WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver='' + LIMIT 1 + ), ( + SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1 + ), '') AS new_receiver + FROM portal + WHERE receiver='' AND bridge_id=$1 + ) updates + WHERE portal.bridge_id=updates.bridge_id AND portal.id=updates.id AND portal.receiver='' AND NOT EXISTS ( + SELECT 1 FROM portal p2 WHERE p2.bridge_id=updates.bridge_id AND p2.id=updates.id AND p2.receiver=updates.new_receiver + ) ` ) @@ -159,6 +167,10 @@ func (pq *PortalQuery) GetAllWithMXID(ctx context.Context) ([]*Portal, error) { return pq.QueryMany(ctx, getAllPortalsWithMXIDQuery, pq.BridgeID) } +func (pq *PortalQuery) GetAllWithoutReceiver(ctx context.Context) ([]*Portal, error) { + return pq.QueryMany(ctx, getAllPortalsWithoutReceiver, pq.BridgeID) +} + func (pq *PortalQuery) GetAll(ctx context.Context) ([]*Portal, error) { return pq.QueryMany(ctx, getAllPortalsQuery, pq.BridgeID) } diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 2c68a692..ab59a582 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -577,6 +577,9 @@ func (as *ASIntent) MarkAsDM(ctx context.Context, roomID id.RoomID, withUser id. } func (as *ASIntent) DeleteRoom(ctx context.Context, roomID id.RoomID, puppetsOnly bool) error { + if roomID == "" { + return nil + } if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) { err := as.Matrix.BeeperDeleteRoom(ctx, roomID) if err != nil { diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 488f0b4c..95fa13eb 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -15,6 +15,7 @@ bridge: # By default, users who are in the same group on the remote network will be # in the same Matrix room bridged to that group. If this is set to true, # every user will get their own Matrix room instead. + # SETTING THIS IS IRREVERSIBLE AND POTENTIALLY DESTRUCTIVE IF PORTALS ALREADY EXIST. split_portals: false # Should the bridge resend `m.bridge` events to all portals on startup? resend_bridge_info: false From 820d0ee66bbd0e9ee4eab776e634f533a9fb5ef8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Sep 2025 21:01:34 +0300 Subject: [PATCH 1384/1647] bridgev2: only delete rooms in split portal migration after starting connectors --- bridgev2/bridge.go | 40 +++++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 5a0ae30c..83418290 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -183,7 +183,11 @@ func (br *Bridge) StartConnectors(ctx context.Context) error { } } if !br.Background { - br.didSplitPortals = br.MigrateToSplitPortals(ctx) + var postMigrate func() + br.didSplitPortals, postMigrate = br.MigrateToSplitPortals(ctx) + if postMigrate != nil { + defer postMigrate() + } } br.Log.Info().Msg("Starting Matrix connector") err := br.Matrix.Start(ctx) @@ -272,25 +276,26 @@ func (br *Bridge) ResendBridgeInfo(ctx context.Context, resendInfo, resendCaps b Msg("Resent bridge info to all portals") } -func (br *Bridge) MigrateToSplitPortals(ctx context.Context) bool { +func (br *Bridge) MigrateToSplitPortals(ctx context.Context) (bool, func()) { log := zerolog.Ctx(ctx).With().Str("action", "migrate to split portals").Logger() ctx = log.WithContext(ctx) if !br.Config.SplitPortals || br.DB.KV.Get(ctx, database.KeySplitPortalsEnabled) == "true" { - return false + return false, nil } affected, err := br.DB.Portal.MigrateToSplitPortals(ctx) if err != nil { log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to migrate portals") os.Exit(31) - return false + return false, nil } log.Info().Int64("rows_affected", affected).Msg("Migrated to split portals") withoutReceiver, err := br.DB.Portal.GetAllWithoutReceiver(ctx) if err != nil { log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to get portals that failed to migrate") os.Exit(32) - return false + return false, nil } + var roomsToDelete []id.RoomID log.Info().Int("remaining_portals", len(withoutReceiver)).Msg("Deleting remaining portals without receiver") for _, portal := range withoutReceiver { if err = br.DB.Portal.Delete(ctx, portal.PortalKey); err != nil { @@ -298,21 +303,30 @@ func (br *Bridge) MigrateToSplitPortals(ctx context.Context) bool { Str("portal_id", string(portal.ID)). Stringer("mxid", portal.MXID). Msg("Failed to delete portal database row that failed to migrate") - } else if err = br.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil { - log.Err(err). - Str("portal_id", string(portal.ID)). - Stringer("mxid", portal.MXID). - Msg("Failed to delete portal room that failed to migrate") - } else { + } else if portal.MXID != "" { log.Debug(). Str("portal_id", string(portal.ID)). Stringer("mxid", portal.MXID). - Msg("Deleted portal that wasn't updated by split portal migration query") + Msg("Marked portal room for deletion from homeserver") + roomsToDelete = append(roomsToDelete, portal.MXID) + } else { + log.Debug(). + Str("portal_id", string(portal.ID)). + Msg("Deleted portal row with no Matrix room") } } br.DB.KV.Set(ctx, database.KeySplitPortalsEnabled, "true") log.Info().Msg("Finished split portal migration successfully") - return affected > 0 + return affected > 0, func() { + for _, roomID := range roomsToDelete { + if err = br.Bot.DeleteRoom(ctx, roomID, true); err != nil { + log.Err(err). + Stringer("mxid", roomID). + Msg("Failed to delete portal room that failed to migrate") + } + } + log.Info().Int("room_count", len(roomsToDelete)).Msg("Finished deleting rooms that failed to migrate") + } } func (br *Bridge) StartLogins(ctx context.Context) error { From 54c0e5c2f623459ff885a19453f16699378a8356 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Sep 2025 21:19:01 +0300 Subject: [PATCH 1385/1647] bridgev2/portal: remove portal from cache if loading parent/relay fails --- bridgev2/portal.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index f1c06171..be029f25 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -123,6 +123,8 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que currentlyTypingGhosts: exsync.NewSet[id.UserID](), outgoingMessages: make(map[networkid.TransactionID]*outgoingMessage), } + // Putting the portal in the cache before it's fully initialized is mildly dangerous, + // but loading the relay user login may depend on it. br.portalsByKey[portal.PortalKey] = portal if portal.MXID != "" { br.portalsByMXID[portal.MXID] = portal @@ -131,12 +133,20 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que if portal.ParentKey.ID != "" { portal.Parent, err = br.UnlockedGetPortalByKey(ctx, portal.ParentKey, false) if err != nil { + delete(br.portalsByKey, portal.PortalKey) + if portal.MXID != "" { + delete(br.portalsByMXID, portal.MXID) + } return nil, fmt.Errorf("failed to load parent portal (%s): %w", portal.ParentKey, err) } } if portal.RelayLoginID != "" { portal.Relay, err = br.unlockedGetExistingUserLoginByID(ctx, portal.RelayLoginID) if err != nil { + delete(br.portalsByKey, portal.PortalKey) + if portal.MXID != "" { + delete(br.portalsByMXID, portal.MXID) + } return nil, fmt.Errorf("failed to load relay login (%s): %w", portal.RelayLoginID, err) } } From fbf8718e229a5a2554bf04478ef72e65d1f8f96e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Sep 2025 21:19:27 +0300 Subject: [PATCH 1386/1647] bridgev2: also fix portal parent receivers in split portal migration --- bridgev2/bridge.go | 9 ++++++++- bridgev2/database/portal.go | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 83418290..2ad6a614 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -289,10 +289,17 @@ func (br *Bridge) MigrateToSplitPortals(ctx context.Context) (bool, func()) { return false, nil } log.Info().Int64("rows_affected", affected).Msg("Migrated to split portals") + affected2, err := br.DB.Portal.FixParentsAfterSplitPortalMigration(ctx) + if err != nil { + log.Err(err).Msg("Failed to fix parent portals after split portal migration") + os.Exit(31) + return false, nil + } + log.Info().Int64("rows_affected", affected2).Msg("Updated parent receivers after split portal migration") withoutReceiver, err := br.DB.Portal.GetAllWithoutReceiver(ctx) if err != nil { log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to get portals that failed to migrate") - os.Exit(32) + os.Exit(31) return false, nil } var roomsToDelete []id.RoomID diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 8570d840..e02b9e44 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -131,6 +131,11 @@ const ( FROM user_portal WHERE bridge_id=portal.bridge_id AND portal_id=portal.id AND portal_receiver='' LIMIT 1 + ), ( + SELECT login_id + FROM user_portal + WHERE portal.parent_id<>'' AND bridge_id=portal.bridge_id AND portal_id=portal.parent_id + LIMIT 1 ), ( SELECT id FROM user_login WHERE bridge_id=portal.bridge_id LIMIT 1 ), '') AS new_receiver @@ -141,6 +146,9 @@ const ( SELECT 1 FROM portal p2 WHERE p2.bridge_id=updates.bridge_id AND p2.id=updates.id AND p2.receiver=updates.new_receiver ) ` + fixParentsAfterSplitPortalMigrationQuery = ` + UPDATE portal SET parent_receiver=receiver WHERE parent_receiver='' AND receiver<>'' AND parent_id<>''; + ` ) func (pq *PortalQuery) GetByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { @@ -209,6 +217,14 @@ func (pq *PortalQuery) MigrateToSplitPortals(ctx context.Context) (int64, error) return res.RowsAffected() } +func (pq *PortalQuery) FixParentsAfterSplitPortalMigration(ctx context.Context) (int64, error) { + res, err := pq.GetDB().Exec(ctx, fixParentsAfterSplitPortalMigrationQuery, pq.BridgeID) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { var mxid, parentID, parentReceiver, relayLoginID, otherUserID, disappearType sql.NullString var disappearTimer sql.NullInt64 From 0012a23d85023945f94796c8efcc794920841e22 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Sep 2025 21:21:25 +0300 Subject: [PATCH 1387/1647] bridgev2/portal: don't allow queuing events into uninitialized portals --- bridgev2/portal.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index be029f25..575edfb8 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -303,6 +303,9 @@ func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) EventHand portal.eventIdx++ return portal.handleSingleEventWithDelayLogging(portal.eventIdx, evt) } else { + if portal.events == nil { + panic(fmt.Errorf("queueEvent into uninitialized portal %s", portal.PortalKey)) + } select { case portal.events <- evt: return EventHandlingResultQueued From 0a84c052dda8036a4bb59452234fd4211eadccc5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 21 Sep 2025 20:10:59 +0300 Subject: [PATCH 1388/1647] crypto: add utilities for cross-signing --- crypto/cross_sign_pubkey.go | 14 +++++++++++++ crypto/cross_sign_ssss.go | 40 +++++++++++++++++++++++++++++++++++-- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/crypto/cross_sign_pubkey.go b/crypto/cross_sign_pubkey.go index 77efab5b..f85d1ea3 100644 --- a/crypto/cross_sign_pubkey.go +++ b/crypto/cross_sign_pubkey.go @@ -20,6 +20,20 @@ type CrossSigningPublicKeysCache struct { UserSigningKey id.Ed25519 } +func (mach *OlmMachine) GetOwnVerificationStatus(ctx context.Context) (hasKeys, isVerified bool, err error) { + pubkeys := mach.GetOwnCrossSigningPublicKeys(ctx) + if pubkeys != nil { + hasKeys = true + isVerified, err = mach.CryptoStore.IsKeySignedBy( + ctx, mach.Client.UserID, mach.GetAccount().SigningKey(), mach.Client.UserID, pubkeys.SelfSigningKey, + ) + if err != nil { + err = fmt.Errorf("failed to check if current device is signed by own self-signing key: %w", err) + } + } + return +} + func (mach *OlmMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *CrossSigningPublicKeysCache { if mach.crossSigningPubkeys != nil { return mach.crossSigningPubkeys diff --git a/crypto/cross_sign_ssss.go b/crypto/cross_sign_ssss.go index 389a9fd2..50b58ea0 100644 --- a/crypto/cross_sign_ssss.go +++ b/crypto/cross_sign_ssss.go @@ -71,6 +71,42 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeysWithPassword(ctx contex }, passphrase) } +func (mach *OlmMachine) VerifyWithRecoveryKey(ctx context.Context, recoveryKey string) error { + keyID, keyData, err := mach.SSSS.GetDefaultKeyData(ctx) + if err != nil { + return fmt.Errorf("failed to get default SSSS key data: %w", err) + } + key, err := keyData.VerifyRecoveryKey(keyID, recoveryKey) + if err != nil { + return err + } + err = mach.FetchCrossSigningKeysFromSSSS(ctx, key) + if err != nil { + return fmt.Errorf("failed to fetch cross-signing keys from SSSS: %w", err) + } + err = mach.SignOwnDevice(ctx, mach.OwnIdentity()) + if err != nil { + return fmt.Errorf("failed to sign own device: %w", err) + } + err = mach.SignOwnMasterKey(ctx) + if err != nil { + return fmt.Errorf("failed to sign own master key: %w", err) + } + return nil +} + +func (mach *OlmMachine) GenerateAndVerifyWithRecoveryKey(ctx context.Context) (recoveryKey string, err error) { + recoveryKey, _, err = mach.GenerateAndUploadCrossSigningKeys(ctx, nil, "") + if err != nil { + err = fmt.Errorf("failed to generate and upload cross-signing keys: %w", err) + } else if err = mach.SignOwnDevice(ctx, mach.OwnIdentity()); err != nil { + err = fmt.Errorf("failed to sign own device: %w", err) + } else if err = mach.SignOwnMasterKey(ctx); err != nil { + err = fmt.Errorf("failed to sign own master key: %w", err) + } + return +} + // GenerateAndUploadCrossSigningKeys generates a new key with all corresponding cross-signing keys. // // A passphrase can be provided to generate the SSSS key. If the passphrase is empty, a random key @@ -97,12 +133,12 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, u // Publish cross-signing keys err = mach.PublishCrossSigningKeys(ctx, keysCache, uiaCallback) if err != nil { - return "", nil, fmt.Errorf("failed to publish cross-signing keys: %w", err) + return key.RecoveryKey(), keysCache, fmt.Errorf("failed to publish cross-signing keys: %w", err) } err = mach.SSSS.SetDefaultKeyID(ctx, key.ID) if err != nil { - return "", nil, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) + return key.RecoveryKey(), keysCache, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) } return key.RecoveryKey(), keysCache, nil From 6c37f2b21f24dcc186ac2a4d00db9708a21511f4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 21 Sep 2025 20:12:05 +0300 Subject: [PATCH 1389/1647] bridgev2/matrix: add config option to self-sign bot device --- bridgev2/bridgeconfig/encryption.go | 1 + bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/database/kvstore.go | 1 + bridgev2/matrix/crypto.go | 55 ++++++++++++++++++++-- bridgev2/matrix/mxmain/example-config.yaml | 4 ++ 5 files changed, 59 insertions(+), 3 deletions(-) diff --git a/bridgev2/bridgeconfig/encryption.go b/bridgev2/bridgeconfig/encryption.go index 1ef7e18f..5a19b3ad 100644 --- a/bridgev2/bridgeconfig/encryption.go +++ b/bridgev2/bridgeconfig/encryption.go @@ -16,6 +16,7 @@ type EncryptionConfig struct { Require bool `yaml:"require"` Appservice bool `yaml:"appservice"` MSC4190 bool `yaml:"msc4190"` + SelfSign bool `yaml:"self_sign"` PlaintextMentions bool `yaml:"plaintext_mentions"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index f41f77d8..6533338f 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -157,6 +157,7 @@ func doUpgrade(helper up.Helper) { } else { helper.Copy(up.Bool, "encryption", "msc4190") } + helper.Copy(up.Bool, "encryption", "self_sign") helper.Copy(up.Bool, "encryption", "allow_key_sharing") if secret, ok := helper.Get(up.Str, "encryption", "pickle_key"); !ok || secret == "generate" { helper.Set(up.Str, random.String(64), "encryption", "pickle_key") diff --git a/bridgev2/database/kvstore.go b/bridgev2/database/kvstore.go index 52b4984e..bca26ed5 100644 --- a/bridgev2/database/kvstore.go +++ b/bridgev2/database/kvstore.go @@ -23,6 +23,7 @@ const ( KeySplitPortalsEnabled Key = "split_portals_enabled" KeyBridgeInfoVersion Key = "bridge_info_version" KeyEncryptionStateResynced Key = "encryption_state_resynced" + KeyRecoveryKey Key = "recovery_key" ) type KVQuery struct { diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index 2325ddfa..d77f1d44 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -24,6 +24,7 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/crypto" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/event" @@ -135,7 +136,14 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { return err } if isExistingDevice { - helper.verifyKeysAreOnServer(ctx) + if !helper.verifyKeysAreOnServer(ctx) { + return nil + } + } + if helper.bridge.Config.Encryption.SelfSign { + if !helper.doSelfSign(ctx) { + os.Exit(34) + } } go helper.resyncEncryptionInfo(context.TODO()) @@ -143,6 +151,46 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { return nil } +func (helper *CryptoHelper) doSelfSign(ctx context.Context) bool { + log := zerolog.Ctx(ctx) + hasKeys, isVerified, err := helper.mach.GetOwnVerificationStatus(ctx) + if err != nil { + log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to check verification status") + return false + } + log.Debug().Bool("has_keys", hasKeys).Bool("is_verified", isVerified).Msg("Checked verification status") + keyInDB := helper.bridge.Bridge.DB.KV.Get(ctx, database.KeyRecoveryKey) + if !hasKeys || keyInDB == "overwrite" { + if keyInDB != "" && keyInDB != "overwrite" { + log.WithLevel(zerolog.FatalLevel). + Msg("No keys on server, but database already has recovery key. Delete `recovery_key` from `kv_store` manually to continue.") + return false + } + recoveryKey, err := helper.mach.GenerateAndVerifyWithRecoveryKey(ctx) + if recoveryKey != "" { + helper.bridge.Bridge.DB.KV.Set(ctx, database.KeyRecoveryKey, recoveryKey) + } + if err != nil { + log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to generate recovery key and self-sign") + return false + } + log.Info().Msg("Generated new recovery key and self-signed bot device") + } else if !isVerified { + if keyInDB == "" { + log.WithLevel(zerolog.FatalLevel). + Msg("Server already has cross-signing keys, but no key in database. Add `recovery_key` to `kv_store`, or set it to `overwrite` to generate new keys.") + return false + } + err = helper.mach.VerifyWithRecoveryKey(ctx, keyInDB) + if err != nil { + log.WithLevel(zerolog.FatalLevel).Err(err).Msg("Failed to verify with recovery key") + return false + } + log.Info().Msg("Verified bot device with existing recovery key") + } + return true +} + func (helper *CryptoHelper) resyncEncryptionInfo(ctx context.Context) { log := helper.log.With().Str("action", "resync encryption event").Logger() rows, err := helper.store.DB.Query(ctx, `SELECT room_id FROM mx_room_state WHERE encryption='{"resync":true}'`) @@ -274,7 +322,7 @@ func (helper *CryptoHelper) loginBot(ctx context.Context) (*mautrix.Client, bool return client, deviceID != "", nil } -func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) { +func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) bool { helper.log.Debug().Msg("Making sure keys are still on server") resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{ DeviceKeys: map[id.UserID]mautrix.DeviceIDList{ @@ -287,10 +335,11 @@ func (helper *CryptoHelper) verifyKeysAreOnServer(ctx context.Context) { } device, ok := resp.DeviceKeys[helper.client.UserID][helper.client.DeviceID] if ok && len(device.Keys) > 0 { - return + return true } helper.log.Warn().Msg("Existing device doesn't have keys on server, resetting crypto") helper.Reset(ctx, false) + return false } func (helper *CryptoHelper) Start() { diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 95fa13eb..d8634028 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -365,6 +365,10 @@ encryption: # Only relevant when using end-to-bridge encryption, required when using encryption with next-gen auth (MSC3861). # Changing this option requires updating the appservice registration file. msc4190: false + # Should the bridge bot generate a recovery key and cross-signing keys and verify itself? + # Note that without the latest version of MSC4190, this will fail if you reset the bridge database. + # The generated recovery key will be saved in the kv_store table under `recovery_key`. + self_sign: false # Enable key sharing? If enabled, key requests for rooms where users are in will be fulfilled. # You must use a client that supports requesting keys from other users to use this feature. allow_key_sharing: true From 658b2e1d1d9a67c4d6b52725732fde8b2b47d6cf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 21 Sep 2025 20:30:22 +0300 Subject: [PATCH 1390/1647] bridgev2/matrix: share device keys as part of e2ee init --- bridgev2/matrix/crypto.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index d77f1d44..f4a2e9a0 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -139,6 +139,11 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { if !helper.verifyKeysAreOnServer(ctx) { return nil } + } else { + err = helper.ShareKeys(ctx) + if err != nil { + return fmt.Errorf("failed to share device keys: %w", err) + } } if helper.bridge.Config.Encryption.SelfSign { if !helper.doSelfSign(ctx) { From 0198ef315c029b38cd303a15e85caf2eb364f00f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 21 Sep 2025 20:51:51 +0300 Subject: [PATCH 1391/1647] changelog: update --- CHANGELOG.md | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c33645f..794008c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,18 @@ +## unreleased + +* *(crypto)* Added helper methods for generating and verifying with recovery + keys. +* *(bridgev2/matrix)* Added config option to automatically generate a recovery + key for the bridge bot and self-sign the bridge's device. +* *(bridgev2/matrix)* Added initial support for using appservice/MSC3202 mode + for encryption with standard servers like Synapse. +* *(bridgev2)* Added optional support for implicit read receipts. +* *(bridgev2)* Extended event duration logging to log any event taking too long. +* *(bridgev2)* Fixed various bugs with migrating to split portals. +* *(event)* Fixed poll start events having incorrect null `m.relates_to`. +* *(event)* Added event type constant for poll end events. +* *(client)* Fixed `RespUserProfile` losing standard fields when re-marshaling. + ## v0.25.1 (2025-09-16) * *(client)* Fixed HTTP method of delete devices API call From cf814a5aaae1f9029c3c7e5d14f34231ec8b6723 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Mon, 22 Sep 2025 13:30:08 +0300 Subject: [PATCH 1392/1647] error: Add RespError WithExtraData convenience function (#416) To dynamically build errors with extra keys like returning `max_delay` for `M_MAX_DELAY_EXCEEDED`. --- error.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/error.go b/error.go index 6f4880df..bea4caae 100644 --- a/error.go +++ b/error.go @@ -177,6 +177,16 @@ func (e RespError) WithStatus(status int) RespError { return e } +func (e RespError) WithExtraData(extraData map[string]any) RespError { + if e.ExtraData == nil { + e.ExtraData = make(map[string]any) + } else { + e.ExtraData = maps.Clone(e.ExtraData) + } + maps.Copy(e.ExtraData, extraData) + return e +} + // Error returns the errcode and error message. func (e RespError) Error() string { return e.ErrCode + ": " + e.Err From f9fb77d6aad75604351cb09d846f448a9e22ac9d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 22 Sep 2025 13:46:37 +0300 Subject: [PATCH 1393/1647] client: add user directory search method --- CHANGELOG.md | 1 + client.go | 9 +++++++++ requests.go | 5 +++++ responses.go | 29 +++++++++++++++++++++++++++++ 4 files changed, 44 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 794008c9..831e3094 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ * *(event)* Fixed poll start events having incorrect null `m.relates_to`. * *(event)* Added event type constant for poll end events. * *(client)* Fixed `RespUserProfile` losing standard fields when re-marshaling. +* *(client)* Added wrapper for searching user directory. ## v0.25.1 (2025-09-16) diff --git a/client.go b/client.go index edeab732..bf41ffb9 100644 --- a/client.go +++ b/client.go @@ -1055,6 +1055,15 @@ func (cli *Client) GetProfile(ctx context.Context, mxid id.UserID) (resp *RespUs return } +func (cli *Client) SearchUserDirectory(ctx context.Context, query string, limit int) (resp *RespSearchUserDirectory, err error) { + urlPath := cli.BuildClientURL("v3", "user_directory", "search") + _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, &ReqSearchUserDirectory{ + SearchTerm: query, + Limit: limit, + }, &resp) + return +} + func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, extras ...ReqMutualRooms) (resp *RespMutualRooms, err error) { if cli.SpecVersions != nil && !cli.SpecVersions.Supports(FeatureMutualRooms) { err = fmt.Errorf("server does not support fetching mutual rooms") diff --git a/requests.go b/requests.go index 4b5ce74b..9dfe09ab 100644 --- a/requests.go +++ b/requests.go @@ -183,6 +183,11 @@ type ReqKnockRoom struct { Reason string `json:"reason,omitempty"` } +type ReqSearchUserDirectory struct { + SearchTerm string `json:"search_term"` + Limit int `json:"limit,omitempty"` +} + type ReqMutualRooms struct { From string `json:"-"` } diff --git a/responses.go b/responses.go index 4d66cdb8..82ba003a 100644 --- a/responses.go +++ b/responses.go @@ -213,6 +213,35 @@ func (r *RespUserProfile) MarshalJSON() ([]byte, error) { return json.Marshal(marshalMap) } +type RespSearchUserDirectory struct { + Limited bool `json:"limited"` + Results []*RespUserProfile `json:"results"` +} + +type UserDirectoryEntry struct { + RespUserProfile + UserID id.UserID `json:"user_id"` +} + +func (r *UserDirectoryEntry) UnmarshalJSON(data []byte) error { + err := r.RespUserProfile.UnmarshalJSON(data) + if err != nil { + return err + } + userIDStr, _ := r.Extra["user_id"].(string) + r.UserID = id.UserID(userIDStr) + delete(r.Extra, "user_id") + return nil +} + +func (r *UserDirectoryEntry) MarshalJSON() ([]byte, error) { + if r.Extra == nil { + r.Extra = make(map[string]any) + } + r.Extra["user_id"] = r.UserID.String() + return r.RespUserProfile.MarshalJSON() +} + type RespMutualRooms struct { Joined []id.RoomID `json:"joined"` NextBatch string `json:"next_batch,omitempty"` From c4701ba06c2bf7db5b8205dab7fe6c3ce58c7b42 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 22 Sep 2025 14:30:41 +0300 Subject: [PATCH 1394/1647] responses: fix RespSearchUserDirectory type --- responses.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/responses.go b/responses.go index 82ba003a..a79be28b 100644 --- a/responses.go +++ b/responses.go @@ -214,8 +214,8 @@ func (r *RespUserProfile) MarshalJSON() ([]byte, error) { } type RespSearchUserDirectory struct { - Limited bool `json:"limited"` - Results []*RespUserProfile `json:"results"` + Limited bool `json:"limited"` + Results []*UserDirectoryEntry `json:"results"` } type UserDirectoryEntry struct { From 23b18aa0ca2907903cf520be30e5541e1e3da82b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 22 Sep 2025 14:46:47 +0300 Subject: [PATCH 1395/1647] bridgev2/provisioning: fix login_id query param name --- bridgev2/matrix/provisioning.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml index 5bb27272..21c93ca4 100644 --- a/bridgev2/matrix/provisioning.yaml +++ b/bridgev2/matrix/provisioning.yaml @@ -400,7 +400,7 @@ components: - username - meow@example.com loginID: - name: loginID + name: login_id in: query description: An optional explicit login ID to do the action through. required: false From b3c883bc7fa39021bf8cfde6d51ce5860b8dded5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 22 Sep 2025 16:05:28 +0300 Subject: [PATCH 1396/1647] event: add beeper chat delete event --- event/beeper.go | 4 ++++ event/capabilities.d.ts | 5 +++++ event/capabilities.go | 12 +++++++----- event/content.go | 1 + event/type.go | 3 ++- 5 files changed, 19 insertions(+), 6 deletions(-) diff --git a/event/beeper.go b/event/beeper.go index 921e3466..95b4a571 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -86,6 +86,10 @@ type BeeperRoomKeyAckEventContent struct { FirstMessageIndex int `json:"first_message_index"` } +type BeeperChatDeleteEventContent struct { + DeleteForEveryone bool `json:"delete_for_everyone,omitempty"` +} + type IntOrString int func (ios *IntOrString) UnmarshalJSON(data []byte) error { diff --git a/event/capabilities.d.ts b/event/capabilities.d.ts index 27164a5f..37848575 100644 --- a/event/capabilities.d.ts +++ b/event/capabilities.d.ts @@ -55,6 +55,11 @@ export interface RoomFeatures { allowed_reactions?: string[] /** Whether custom emoji reactions are allowed. */ custom_emoji_reactions?: boolean + + /** Whether deleting the chat for yourself is supported. */ + delete_chat?: boolean + /** Whether deleting the chat for all participants is supported. */ + delete_chat_for_everyone?: boolean } declare type integer = number diff --git a/event/capabilities.go b/event/capabilities.go index 94662428..31a6b7aa 100644 --- a/event/capabilities.go +++ b/event/capabilities.go @@ -51,11 +51,12 @@ type RoomFeatures struct { AllowedReactions []string `json:"allowed_reactions,omitempty"` CustomEmojiReactions bool `json:"custom_emoji_reactions,omitempty"` - ReadReceipts bool `json:"read_receipts,omitempty"` - TypingNotifications bool `json:"typing_notifications,omitempty"` - Archive bool `json:"archive,omitempty"` - MarkAsUnread bool `json:"mark_as_unread,omitempty"` - DeleteChat bool `json:"delete_chat,omitempty"` + ReadReceipts bool `json:"read_receipts,omitempty"` + TypingNotifications bool `json:"typing_notifications,omitempty"` + Archive bool `json:"archive,omitempty"` + MarkAsUnread bool `json:"mark_as_unread,omitempty"` + DeleteChat bool `json:"delete_chat,omitempty"` + DeleteChatForEveryone bool `json:"delete_chat_for_everyone,omitempty"` } func (rf *RoomFeatures) GetID() string { @@ -262,6 +263,7 @@ func (rf *RoomFeatures) Hash() []byte { hashBool(hasher, "archive", rf.Archive) hashBool(hasher, "mark_as_unread", rf.MarkAsUnread) hashBool(hasher, "delete_chat", rf.DeleteChat) + hashBool(hasher, "delete_chat_for_everyone", rf.DeleteChatForEveryone) return hasher.Sum(nil) } diff --git a/event/content.go b/event/content.go index 5e093273..c0ff51ad 100644 --- a/event/content.go +++ b/event/content.go @@ -63,6 +63,7 @@ var TypeMap = map[Type]reflect.Type{ BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}), BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}), + BeeperDeleteChat: reflect.TypeOf(BeeperChatDeleteEventContent{}), AccountDataRoomTags: reflect.TypeOf(TagEventContent{}), AccountDataDirectChats: reflect.TypeOf(DirectChatsEventContent{}), diff --git a/event/type.go b/event/type.go index 1b4fbf76..56ea82f6 100644 --- a/event/type.go +++ b/event/type.go @@ -128,7 +128,7 @@ func (et *Type) GuessClass() TypeClass { InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type, CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type, CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type, - EventUnstablePollEnd.Type, BeeperTranscription.Type: + EventUnstablePollEnd.Type, BeeperTranscription.Type, BeeperDeleteChat.Type: return MessageEventType case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type, ToDeviceBeeperRoomKeyAck.Type: @@ -236,6 +236,7 @@ var ( BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType} BeeperTranscription = Type{"com.beeper.transcription", MessageEventType} + BeeperDeleteChat = Type{"com.beeper.delete_chat", MessageEventType} EventUnstablePollStart = Type{Type: "org.matrix.msc3381.poll.start", Class: MessageEventType} EventUnstablePollResponse = Type{Type: "org.matrix.msc3381.poll.response", Class: MessageEventType} From a9ff1443f70678599bb74a2c941de23593827bb2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 22 Sep 2025 16:05:42 +0300 Subject: [PATCH 1397/1647] bridgev2: add interface for deleting chats from Matrix Closes #408 --- bridgev2/errors.go | 2 + bridgev2/matrix/connector.go | 1 + bridgev2/networkinterface.go | 9 ++++ bridgev2/portal.go | 83 +++++++++++++++++++++++++++++++----- 4 files changed, 84 insertions(+), 11 deletions(-) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index 52bebe81..694224f1 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -44,6 +44,7 @@ var ( ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringDeleteChatRelayedUser error = WrapErrorInStatus(errors.New("ignoring delete chat event from relayed user")).WithIsCertain(true).WithSendNotice(false) ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage() ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage() ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage() @@ -65,6 +66,7 @@ var ( ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) + ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index ab1764dd..3dd9ae1a 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -148,6 +148,7 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.EventProcessor.On(event.StateTopic, br.handleRoomEvent) br.EventProcessor.On(event.StateTombstone, br.handleRoomEvent) br.EventProcessor.On(event.StateBeeperDisappearingTimer, br.handleRoomEvent) + br.EventProcessor.On(event.BeeperDeleteChat, br.handleRoomEvent) br.EventProcessor.On(event.EphemeralEventReceipt, br.handleEphemeralEvent) br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent) br.Bot = br.AS.BotIntent() diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index fa87086a..8dffbb34 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -697,6 +697,14 @@ type DisappearTimerChangingNetworkAPI interface { HandleMatrixDisappearingTimer(ctx context.Context, msg *MatrixDisappearingTimer) (bool, error) } +// DeleteChatHandlingNetworkAPI is an optional interface that network connectors +// can implement to delete a chat from the remote network. +type DeleteChatHandlingNetworkAPI interface { + NetworkAPI + // HandleMatrixDeleteChat is called when the user explicitly deletes a chat. + HandleMatrixDeleteChat(ctx context.Context, msg *MatrixDeleteChat) error +} + type ResolveIdentifierResponse struct { // Ghost is the ghost of the user that the identifier resolves to. // This field should be set whenever possible. However, it is not required, @@ -1380,6 +1388,7 @@ type MatrixViewingChat struct { Portal *Portal } +type MatrixDeleteChat = MatrixEventBase[*event.BeeperChatDeleteEventContent] type MatrixMarkedUnread = MatrixRoomMeta[*event.MarkedUnreadEventContent] type MatrixMute = MatrixRoomMeta[*event.BeeperMuteEventContent] type MatrixRoomTag = MatrixRoomMeta[*event.TagEventContent] diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 575edfb8..f53691fa 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -706,6 +706,8 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * return portal.handleMatrixMembership(ctx, login, origSender, evt) case event.StatePowerLevels: return portal.handleMatrixPowerLevels(ctx, login, origSender, evt) + case event.BeeperDeleteChat: + return portal.handleMatrixDeleteChat(ctx, login, origSender, evt) default: return EventHandlingResultIgnored } @@ -1622,6 +1624,58 @@ func (portal *Portal) getTargetUser(ctx context.Context, userID id.UserID) (Ghos } } +func (portal *Portal) handleMatrixDeleteChat( + ctx context.Context, + sender *UserLogin, + origSender *OrigSender, + evt *event.Event, +) EventHandlingResult { + if origSender != nil { + return EventHandlingResultFailed.WithMSSError(ErrIgnoringDeleteChatRelayedUser) + } + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(*event.BeeperChatDeleteEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + api, ok := sender.Client.(DeleteChatHandlingNetworkAPI) + if !ok { + return EventHandlingResultIgnored.WithMSSError(ErrDeleteChatNotSupported) + } + err := api.HandleMatrixDeleteChat(ctx, &MatrixDeleteChat{ + Event: evt, + Content: content, + Portal: portal, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix chat delete") + return EventHandlingResultFailed.WithMSSError(err) + } + if portal.Receiver == "" { + _, others, err := portal.findOtherLogins(ctx, sender) + if err != nil { + log.Err(err).Msg("Failed to check if portal has other logins") + return EventHandlingResultFailed.WithError(err) + } else if len(others) > 0 { + log.Debug().Msg("Not deleting portal after chat delete as other logins are present") + return EventHandlingResultSuccess + } + } + err = portal.Delete(ctx) + if err != nil { + log.Err(err).Msg("Failed to delete portal from database") + return EventHandlingResultFailed.WithMSSError(err) + } + err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, false) + if err != nil { + log.Err(err).Msg("Failed to delete Matrix room") + return EventHandlingResultFailed.WithMSSError(err) + } + // No MSS here as the portal was deleted + return EventHandlingResultSuccess +} + func (portal *Portal) handleMatrixMembership( ctx context.Context, sender *UserLogin, @@ -3160,11 +3214,11 @@ func (portal *Portal) handleRemoteMessageRemove(ctx context.Context, source *Use onlyForMeProvider, ok := evt.(RemoteDeleteOnlyForMe) onlyForMe := ok && onlyForMeProvider.DeleteOnlyForMe() if onlyForMe && portal.Receiver == "" { - logins, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) + _, others, err := portal.findOtherLogins(ctx, source) if err != nil { log.Err(err).Msg("Failed to check if portal has other logins") return EventHandlingResultFailed.WithError(err) - } else if len(logins) > 1 { + } else if len(others) > 0 { log.Debug().Msg("Ignoring delete for me event in portal with multiple logins") return EventHandlingResultIgnored } @@ -3413,22 +3467,29 @@ func (portal *Portal) handleRemoteChatResync(ctx context.Context, source *UserLo return EventHandlingResultSuccess } +func (portal *Portal) findOtherLogins(ctx context.Context, source *UserLogin) (ownUP *database.UserPortal, others []*database.UserPortal, err error) { + others, err = portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) + if err != nil { + return + } + others = slices.DeleteFunc(others, func(up *database.UserPortal) bool { + if up.LoginID == source.ID { + ownUP = up + return true + } + return false + }) + return +} + func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult { log := zerolog.Ctx(ctx) if portal.Receiver == "" && evt.DeleteOnlyForMe() { - logins, err := portal.Bridge.DB.UserPortal.GetAllInPortal(ctx, portal.PortalKey) + ownUP, logins, err := portal.findOtherLogins(ctx, source) if err != nil { log.Err(err).Msg("Failed to check if portal has other logins") return EventHandlingResultFailed.WithError(err) } - var ownUP *database.UserPortal - logins = slices.DeleteFunc(logins, func(up *database.UserPortal) bool { - if up.LoginID == source.ID { - ownUP = up - return true - } - return false - }) if len(logins) > 0 { log.Debug().Msg("Not deleting portal with other logins in remote chat delete event") if ownUP != nil { From d5c6393f2350a941d4a7041d902473f79209fd81 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 22 Sep 2025 16:11:21 +0300 Subject: [PATCH 1398/1647] bridgev2/portal: don't process any more events if portal is deleted --- bridgev2/portal.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index f53691fa..4637d6ba 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -90,7 +90,8 @@ type Portal struct { functionalMembersLock sync.Mutex functionalMembersCache *event.ElementFunctionalMembersContent - events chan portalEvent + events chan portalEvent + deleted bool eventsLock sync.Mutex eventIdx int @@ -335,6 +336,9 @@ func (portal *Portal) eventLoop() { } i := 0 for rawEvt := range portal.events { + if portal.deleted { + return + } i++ if portal.Bridge.Config.AsyncEvents { go portal.handleSingleEventWithDelayLogging(i, rawEvt) @@ -4811,6 +4815,7 @@ func (portal *Portal) unlockedDeleteCache() { // TODO there's a small risk of this racing with a queueEvent call close(portal.events) } + portal.deleted = true } func (portal *Portal) Save(ctx context.Context) error { From a8b5fa91566f680a66fef022649e85113fc81f38 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 22 Sep 2025 16:27:01 +0300 Subject: [PATCH 1399/1647] client: fix some footguns in compileRequest * add warning log if RequestBody is used without length instead of silently discarding the body * fix wrapping RequestBody in nopcloser * always set content length --- client.go | 41 ++++++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/client.go b/client.go index bf41ffb9..d274b007 100644 --- a/client.go +++ b/client.go @@ -418,8 +418,18 @@ var requestID int32 var logSensitiveContent = os.Getenv("MAUTRIX_LOG_SENSITIVE_CONTENT") == "yes" func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, error) { + reqID := atomic.AddInt32(&requestID, 1) + logger := zerolog.Ctx(ctx) + if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger { + logger = params.Logger + } + ctx = logger.With(). + Int32("req_id", reqID). + Logger().WithContext(ctx) + var logBody any - reqBody := params.RequestBody + var reqBody io.Reader + var reqLen int64 if params.RequestJSON != nil { jsonStr, err := json.Marshal(params.RequestJSON) if err != nil { @@ -434,12 +444,22 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e logBody = params.RequestJSON } reqBody = bytes.NewReader(jsonStr) + reqLen = int64(len(jsonStr)) } else if params.RequestBytes != nil { logBody = fmt.Sprintf("<%d bytes>", len(params.RequestBytes)) reqBody = bytes.NewReader(params.RequestBytes) - params.RequestLength = int64(len(params.RequestBytes)) - } else if params.RequestLength > 0 && params.RequestBody != nil { - logBody = fmt.Sprintf("<%d bytes>", params.RequestLength) + reqLen = int64(len(params.RequestBytes)) + } else if params.RequestBody != nil { + logBody = "" + reqLen = -1 + if params.RequestLength > 0 { + logBody = fmt.Sprintf("<%d bytes>", params.RequestLength) + reqLen = params.RequestLength + } else if params.RequestLength == 0 { + zerolog.Ctx(ctx).Warn(). + Msg("RequestBody passed without specifying request length") + } + reqBody = params.RequestBody if rsc, ok := params.RequestBody.(io.ReadSeekCloser); ok { // Prevent HTTP from closing the request body, it might be needed for retries reqBody = nopCloseSeeker{rsc} @@ -448,15 +468,8 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e params.RequestJSON = struct{}{} logBody = params.RequestJSON reqBody = bytes.NewReader([]byte("{}")) + reqLen = 2 } - reqID := atomic.AddInt32(&requestID, 1) - logger := zerolog.Ctx(ctx) - if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger { - logger = params.Logger - } - ctx = logger.With(). - Int32("req_id", reqID). - Logger().WithContext(ctx) ctx = context.WithValue(ctx, LogBodyContextKey, logBody) ctx = context.WithValue(ctx, LogRequestIDContextKey, int(reqID)) req, err := http.NewRequestWithContext(ctx, params.Method, params.URL, reqBody) @@ -472,9 +485,7 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e if params.RequestJSON != nil { req.Header.Set("Content-Type", "application/json") } - if params.RequestLength > 0 && params.RequestBody != nil { - req.ContentLength = params.RequestLength - } + req.ContentLength = reqLen return req, nil } From 4635590fca48c9f7584db00e84bb02a12cdb79fc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 22 Sep 2025 18:24:26 +0300 Subject: [PATCH 1400/1647] bridgev2/portal: add temporary flag to slack bridge info To let clients detect that https://github.com/mautrix/slack/commit/952806ea5204c420f771d0d51718384e4448370e is done --- bridgev2/portal.go | 3 +++ event/state.go | 2 ++ 2 files changed, 5 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 4637d6ba..5db45268 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3822,6 +3822,9 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { if portal.RoomType == database.RoomTypeDM || portal.RoomType == database.RoomTypeGroupDM { bridgeInfo.BeeperRoomType = "dm" } + if bridgeInfo.Protocol.ID == "slackgo" { + bridgeInfo.TempSlackRemoteIDMigratedFlag = true + } parent := portal.GetTopLevelParent() if parent != nil { bridgeInfo.Network = &event.BridgeInfoSection{ diff --git a/event/state.go b/event/state.go index ba7c608d..ed5434c9 100644 --- a/event/state.go +++ b/event/state.go @@ -231,6 +231,8 @@ type BridgeEventContent struct { BeeperRoomType string `json:"com.beeper.room_type,omitempty"` BeeperRoomTypeV2 string `json:"com.beeper.room_type.v2,omitempty"` + + TempSlackRemoteIDMigratedFlag bool `json:"com.beeper.slack_remote_id_migrated,omitempty"` } // DisappearingType represents the type of a disappearing message timer. From 5c580a7859038f636fa517c6428e83252bfb46fa Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 22 Sep 2025 20:28:44 +0300 Subject: [PATCH 1401/1647] crypto/sqlstore: fix query used for olm unwedging --- crypto/sql_store.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 4405cc31..13940d79 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -253,7 +253,7 @@ func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.Sender // GetNewestSessionCreationTS gets the creation timestamp of the most recently created session with the given sender key. // This will exclude sessions that have never been used to encrypt or decrypt a message. func (store *SQLCryptoStore) GetNewestSessionCreationTS(ctx context.Context, key id.SenderKey) (createdAt time.Time, err error) { - err = store.DB.QueryRow(ctx, "SELECT created_at FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 AND (encrypted_at <> created_at OR decrypted_at <> created_at) ORDER BY created_at DESC LIMIT 1", + err = store.DB.QueryRow(ctx, "SELECT created_at FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 AND (last_encrypted <> created_at OR last_decrypted <> created_at) ORDER BY created_at DESC LIMIT 1", key, store.AccountID).Scan(&createdAt) if errors.Is(err, sql.ErrNoRows) { err = nil From cf29b07f32ceedaad4e4511629eb436670c897b4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 24 Sep 2025 20:29:49 +0300 Subject: [PATCH 1402/1647] appservice/websocket: use io.ReadAll instead of json decoder --- appservice/websocket.go | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/appservice/websocket.go b/appservice/websocket.go index 309cc485..1e401c53 100644 --- a/appservice/websocket.go +++ b/appservice/websocket.go @@ -11,6 +11,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net/http" "net/url" "path/filepath" @@ -292,10 +293,16 @@ func (as *AppService) consumeWebsocket(ctx context.Context, stopFunc func(error) as.Log.Debug().Msg("Ignoring non-text message from websocket") continue } - var msg WebsocketMessage - err = json.NewDecoder(reader).Decode(&msg) + data, err := io.ReadAll(reader) if err != nil { - as.Log.Debug().Err(err).Msg("Error reading JSON from websocket") + as.Log.Debug().Err(err).Msg("Error reading data from websocket") + stopFunc(parseCloseError(err)) + return + } + var msg WebsocketMessage + err = json.Unmarshal(data, &msg) + if err != nil { + as.Log.Debug().Err(err).Msg("Error parsing JSON received from websocket") stopFunc(parseCloseError(err)) return } From b0481d4b4368eccb0ea2e6441832e71e100ec1e3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 26 Sep 2025 12:55:36 +0300 Subject: [PATCH 1403/1647] client: re-add support for unstable profile fields --- client.go | 9 +++++++++ versions.go | 17 ++++++++++------- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index d274b007..62843218 100644 --- a/client.go +++ b/client.go @@ -1125,6 +1125,9 @@ func (cli *Client) SetDisplayName(ctx context.Context, displayName string) (err // SetProfileField sets an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname func (cli *Client) SetProfileField(ctx context.Context, key string, value any) (err error) { urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key) + if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) { + urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) + } _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, map[string]any{ key: value, }, nil) @@ -1134,6 +1137,9 @@ func (cli *Client) SetProfileField(ctx context.Context, key string, value any) ( // DeleteProfileField deletes an arbitrary profile field. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3profileuseridkeyname func (cli *Client) DeleteProfileField(ctx context.Context, key string) (err error) { urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, key) + if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) { + urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) + } _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil) return } @@ -1141,6 +1147,9 @@ func (cli *Client) DeleteProfileField(ctx context.Context, key string) (err erro // GetProfileField gets an arbitrary profile field and parses the response into the given struct. See https://spec.matrix.org/unstable/client-server-api/#get_matrixclientv3profileuseridkeyname func (cli *Client) GetProfileField(ctx context.Context, userID id.UserID, key string, into any) (err error) { urlPath := cli.BuildClientURL("v3", "profile", userID, key) + if key != "displayname" && key != "avatar_url" && !cli.SpecVersions.Supports(FeatureArbitraryProfileFields) && cli.SpecVersions.Supports(FeatureUnstableProfileFields) { + urlPath = cli.BuildClientURL("unstable", "uk.tcpip.msc4133", "profile", cli.UserID, key) + } _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, into) return } diff --git a/versions.go b/versions.go index c3be86cc..0392532e 100644 --- a/versions.go +++ b/versions.go @@ -60,13 +60,15 @@ type UnstableFeature struct { } var ( - FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} - FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} - FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} - FeatureMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} - FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"} - FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"} - FeatureAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"} + FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} + FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} + FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} + FeatureMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} + FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"} + FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"} + FeatureAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"} + FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"} + FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116} BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} @@ -118,6 +120,7 @@ var ( SpecV113 = MustParseSpecVersion("v1.13") SpecV114 = MustParseSpecVersion("v1.14") SpecV115 = MustParseSpecVersion("v1.15") + SpecV116 = MustParseSpecVersion("v1.16") ) func (svf SpecVersionFormat) String() string { From 0685bd778619fd05d6455f1479d6e2369515f691 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 26 Sep 2025 16:56:48 +0300 Subject: [PATCH 1404/1647] crypto/verificationhelper: extract mockserver to new package --- .../verificationhelper_qr_crosssign_test.go | 13 ++- .../verificationhelper_qr_self_test.go | 32 +++--- .../verificationhelper_sas_test.go | 46 ++++---- .../verificationhelper_test.go | 56 +++++----- .../mockserver.go | 100 ++++++++---------- 5 files changed, 117 insertions(+), 130 deletions(-) rename crypto/verificationhelper/mockserver_test.go => mockserver/mockserver.go (68%) diff --git a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go index aace2230..5e3f146b 100644 --- a/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_crosssign_test.go @@ -32,7 +32,6 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingScansQR=%t", tc.sendingScansQR), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginAliceBob(t, ctx) - defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -51,10 +50,10 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, bobUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, receivingShownQRCode) @@ -83,7 +82,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { // Handle the start and done events on the receiving client and // confirm the scan. - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Ensure that the receiving device detected that its QR code // was scanned. @@ -98,7 +97,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { doneEvt = sendingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) } else { // receiving scans QR // Emulate scanning the QR code shown by the sending device on // the receiving device. @@ -121,7 +120,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { // Handle the start and done events on the receiving client and // confirm the scan. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // Ensure that the sending device detected that its QR code was // scanned. @@ -136,7 +135,7 @@ func TestCrossSignVerification_ScanQRAndConfirmScan(t *testing.T) { doneEvt = receivingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) } // Ensure that both devices have marked the verification as done. diff --git a/crypto/verificationhelper/verificationhelper_qr_self_test.go b/crypto/verificationhelper/verificationhelper_qr_self_test.go index 937cc414..ea918cd4 100644 --- a/crypto/verificationhelper/verificationhelper_qr_self_test.go +++ b/crypto/verificationhelper/verificationhelper_qr_self_test.go @@ -36,7 +36,6 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGenerated=%t receivingGenerated=%t err=%s", tc.sendingGeneratedCrossSigningKeys, tc.receivingGeneratedCrossSigningKeys, tc.expectedAcceptError), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -62,7 +61,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) if tc.expectedAcceptError != "" { @@ -72,7 +71,7 @@ func TestSelfVerification_Accept_QRContents(t *testing.T) { require.NoError(t, err) } - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, receivingShownQRCode) @@ -135,7 +134,6 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -152,10 +150,10 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) receivingShownQRCode := receivingCallbacks.GetQRCodeShown(txnID) require.NotNil(t, receivingShownQRCode) @@ -184,7 +182,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { // Handle the start and done events on the receiving client and // confirm the scan. - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Ensure that the receiving device detected that its QR code // was scanned. @@ -199,7 +197,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { doneEvt = sendingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) } else { // receiving scans QR // Emulate scanning the QR code shown by the sending device on // the receiving device. @@ -222,7 +220,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { // Handle the start and done events on the receiving client and // confirm the scan. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // Ensure that the sending device detected that its QR code was // scanned. @@ -237,7 +235,7 @@ func TestSelfVerification_ScanQRAndConfirmScan(t *testing.T) { doneEvt = receivingInbox[0].Content.AsVerificationDone() assert.Equal(t, txnID, doneEvt.TransactionID) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) } // Ensure that both devices have marked the verification as done. @@ -251,7 +249,6 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -263,10 +260,10 @@ func TestSelfVerification_ScanQRTransactionIDCorrupted(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes() sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes() @@ -310,7 +307,6 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGeneratedCrossSigningKeys=%t sendingScansQR=%t corrupt=%d", tc.sendingGeneratedCrossSigningKeys, tc.sendingScansQR, tc.corruptByte), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -327,10 +323,10 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) receivingShownQRCodeBytes := receivingCallbacks.GetQRCodeShown(txnID).Bytes() sendingShownQRCodeBytes := sendingCallbacks.GetQRCodeShown(txnID).Bytes() @@ -348,7 +344,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { // Ensure that the receiving device received a cancellation. receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] assert.Len(t, receivingInbox, 1) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) cancellation := receivingCallbacks.GetVerificationCancellation(txnID) require.NotNil(t, cancellation) assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code) @@ -362,7 +358,7 @@ func TestSelfVerification_ScanQRKeyCorrupted(t *testing.T) { // Ensure that the sending device received a cancellation. sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] assert.Len(t, sendingInbox, 1) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) cancellation := sendingCallbacks.GetVerificationCancellation(txnID) require.NotNil(t, cancellation) assert.Equal(t, event.VerificationCancelCodeKeyMismatch, cancellation.Code) diff --git a/crypto/verificationhelper/verificationhelper_sas_test.go b/crypto/verificationhelper/verificationhelper_sas_test.go index 5747ac34..283eca84 100644 --- a/crypto/verificationhelper/verificationhelper_sas_test.go +++ b/crypto/verificationhelper/verificationhelper_sas_test.go @@ -36,7 +36,6 @@ func TestVerification_SAS(t *testing.T) { for _, tc := range testCases { t.Run(fmt.Sprintf("sendingGenerated=%t sendingStartsSAS=%t sendingConfirmsFirst=%t", tc.sendingGeneratedCrossSigningKeys, tc.sendingStartsSAS, tc.sendingConfirmsFirst), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -60,10 +59,10 @@ func TestVerification_SAS(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // Test that the start event is correct var startEvt *event.VerificationStartEventContent @@ -102,7 +101,7 @@ func TestVerification_SAS(t *testing.T) { if tc.sendingStartsSAS { // Process the verification start event on the receiving // device. - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Receiving device sent the accept event to the sending device sendingInbox := ts.DeviceInbox[aliceUserID][sendingDeviceID] @@ -110,7 +109,7 @@ func TestVerification_SAS(t *testing.T) { acceptEvt = sendingInbox[0].Content.AsVerificationAccept() } else { // Process the verification start event on the sending device. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // Sending device sent the accept event to the receiving device receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] @@ -129,7 +128,7 @@ func TestVerification_SAS(t *testing.T) { var firstKeyEvt *event.VerificationKeyEventContent if tc.sendingStartsSAS { // Process the verification accept event on the sending device. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // Sending device sends first key event to the receiving // device. @@ -139,7 +138,7 @@ func TestVerification_SAS(t *testing.T) { } else { // Process the verification accept event on the receiving // device. - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Receiving device sends first key event to the sending // device. @@ -155,7 +154,7 @@ func TestVerification_SAS(t *testing.T) { var secondKeyEvt *event.VerificationKeyEventContent if tc.sendingStartsSAS { // Process the first key event on the receiving device. - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Receiving device sends second key event to the sending // device. @@ -170,7 +169,7 @@ func TestVerification_SAS(t *testing.T) { assert.Len(t, descriptions, 7) } else { // Process the first key event on the sending device. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // Sending device sends second key event to the receiving // device. @@ -191,10 +190,10 @@ func TestVerification_SAS(t *testing.T) { // Ensure that the SAS codes are the same. if tc.sendingStartsSAS { // Process the second key event on the sending device. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) } else { // Process the second key event on the receiving device. - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) } assert.Equal(t, sendingCallbacks.GetDecimalsShown(txnID), receivingCallbacks.GetDecimalsShown(txnID)) sendingEmojis, sendingDescriptions := sendingCallbacks.GetEmojisAndDescriptionsShown(txnID) @@ -274,10 +273,10 @@ func TestVerification_SAS(t *testing.T) { // Test the transaction is done on both sides. We have to dispatch // twice to process and drain all of the events. - ts.dispatchToDevice(t, ctx, sendingClient) - ts.dispatchToDevice(t, ctx, receivingClient) - ts.dispatchToDevice(t, ctx, sendingClient) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, receivingClient) assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) }) @@ -288,7 +287,6 @@ func TestVerification_SAS_BothCallStart(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) var err error @@ -305,10 +303,10 @@ func TestVerification_SAS_BothCallStart(t *testing.T) { // event on the sending device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) err = sendingHelper.StartSAS(ctx, txnID) require.NoError(t, err) @@ -325,7 +323,7 @@ func TestVerification_SAS_BothCallStart(t *testing.T) { assert.Equal(t, txnID, sendingInbox[0].Content.AsVerificationStart().TransactionID) // Process the start event from the receiving client to the sending client. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) receivingInbox = ts.DeviceInbox[aliceUserID][receivingDeviceID] assert.Len(t, receivingInbox, 2) assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationStart().TransactionID) @@ -333,13 +331,13 @@ func TestVerification_SAS_BothCallStart(t *testing.T) { // Process the rest of the events until we need to confirm the SAS. for len(ts.DeviceInbox[aliceUserID][sendingDeviceID]) > 0 || len(ts.DeviceInbox[aliceUserID][receivingDeviceID]) > 0 { - ts.dispatchToDevice(t, ctx, receivingClient) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, sendingClient) } // Confirm the SAS only the receiving device. receivingHelper.ConfirmSAS(ctx, txnID) - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // Verification is not done until both devices confirm the SAS. assert.False(t, sendingCallbacks.IsVerificationDone(txnID)) @@ -350,13 +348,13 @@ func TestVerification_SAS_BothCallStart(t *testing.T) { // Dispatching the events to the receiving device should get us to the done // state on the receiving device. - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) assert.False(t, sendingCallbacks.IsVerificationDone(txnID)) assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) // Dispatching the events to the sending client should get us to the done // state on the sending device. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) assert.True(t, sendingCallbacks.IsVerificationDone(txnID)) assert.True(t, receivingCallbacks.IsVerificationDone(txnID)) } diff --git a/crypto/verificationhelper/verificationhelper_test.go b/crypto/verificationhelper/verificationhelper_test.go index b4c21c18..ce5ec5b4 100644 --- a/crypto/verificationhelper/verificationhelper_test.go +++ b/crypto/verificationhelper/verificationhelper_test.go @@ -19,6 +19,7 @@ import ( "maunium.net/go/mautrix/crypto/verificationhelper" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/mockserver" ) var aliceUserID = id.UserID("@alice:example.org") @@ -31,9 +32,19 @@ func init() { zerolog.DefaultContextLogger = &log.Logger } -func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { +func addDeviceID(ctx context.Context, cryptoStore crypto.Store, userID id.UserID, deviceID id.DeviceID) { + err := cryptoStore.PutDevice(ctx, userID, &id.Device{ + UserID: userID, + DeviceID: deviceID, + }) + if err != nil { + panic(err) + } +} + +func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockserver.MockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { t.Helper() - ts = createMockServer(t) + ts = mockserver.Create(t) sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID) sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() @@ -47,9 +58,9 @@ func initServerAndLoginTwoAlice(t *testing.T, ctx context.Context) (ts *mockServ return } -func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { +func initServerAndLoginAliceBob(t *testing.T, ctx context.Context) (ts *mockserver.MockServer, sendingClient, receivingClient *mautrix.Client, sendingCryptoStore, receivingCryptoStore crypto.Store, sendingMachine, receivingMachine *crypto.OlmMachine) { t.Helper() - ts = createMockServer(t) + ts = mockserver.Create(t) sendingClient, sendingCryptoStore = ts.Login(t, ctx, aliceUserID, sendingDeviceID) sendingMachine = sendingClient.Crypto.(*cryptohelper.CryptoHelper).Machine() @@ -116,8 +127,7 @@ func TestVerification_Start(t *testing.T) { for i, tc := range testCases { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - ts := createMockServer(t) - defer ts.Close() + ts := mockserver.Create(t) client, cryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID) addDeviceID(ctx, cryptoStore, aliceUserID, sendingDeviceID) @@ -166,7 +176,6 @@ func TestVerification_StartThenCancel(t *testing.T) { for _, sendingCancels := range []bool{true, false} { t.Run(fmt.Sprintf("sendingCancels=%t", sendingCancels), func(t *testing.T) { ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) bystanderClient, _ := ts.Login(t, ctx, aliceUserID, bystanderDeviceID) @@ -186,13 +195,13 @@ func TestVerification_StartThenCancel(t *testing.T) { receivingInbox := ts.DeviceInbox[aliceUserID][receivingDeviceID] assert.Len(t, receivingInbox, 1) assert.Equal(t, txnID, receivingInbox[0].Content.AsVerificationRequest().TransactionID) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Process the request event on the bystander device. bystanderInbox := ts.DeviceInbox[aliceUserID][bystanderDeviceID] assert.Len(t, bystanderInbox, 1) assert.Equal(t, txnID, bystanderInbox[0].Content.AsVerificationRequest().TransactionID) - ts.dispatchToDevice(t, ctx, bystanderClient) + ts.DispatchToDevice(t, ctx, bystanderClient) // Cancel the verification request. var cancelEvt *event.VerificationCancelEventContent @@ -231,7 +240,7 @@ func TestVerification_StartThenCancel(t *testing.T) { if !sendingCancels { // Process the cancellation event on the sending device. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // Ensure that the cancellation event was sent to the bystander device. assert.Len(t, ts.DeviceInbox[aliceUserID][bystanderDeviceID], 1) @@ -247,8 +256,7 @@ func TestVerification_StartThenCancel(t *testing.T) { func TestVerification_Accept_NoSupportedMethods(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) - ts := createMockServer(t) - defer ts.Close() + ts := mockserver.Create(t) sendingClient, sendingCryptoStore := ts.Login(t, ctx, aliceUserID, sendingDeviceID) receivingClient, _ := ts.Login(t, ctx, aliceUserID, receivingDeviceID) @@ -274,7 +282,7 @@ func TestVerification_Accept_NoSupportedMethods(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, txnID) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Ensure that the receiver ignored the request because it // doesn't support any of the verification methods in the @@ -314,7 +322,6 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { for i, tc := range testCases { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() recoveryKey, sendingCrossSigningKeysCache, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") assert.NoError(t, err) @@ -333,7 +340,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { require.NoError(t, err) // Process the verification request on the receiving device. - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Ensure that the receiving device received a verification // request with the correct transaction ID. @@ -373,7 +380,7 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { // Receive the m.key.verification.ready event on the sending // device. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // Ensure that the sending device got a notification about the // transaction being ready. @@ -402,7 +409,6 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) { func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, sendingCryptoStore, receivingCryptoStore, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) nonParticipatingDeviceID1 := id.DeviceID("non-participating1") @@ -419,12 +425,12 @@ func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { // the receiving device. txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) // Receive the m.key.verification.ready event on the sending device. - ts.dispatchToDevice(t, ctx, sendingClient) + ts.DispatchToDevice(t, ctx, sendingClient) // The sending and receiving devices should not have any cancellation // events in their inboxes. @@ -444,7 +450,6 @@ func TestVerification_Accept_CancelOnNonParticipatingDevices(t *testing.T) { func TestVerification_ErrorOnDoubleAccept(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() _, _, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") @@ -452,7 +457,7 @@ func TestVerification_ErrorOnDoubleAccept(t *testing.T) { txnID, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID) require.NoError(t, err) err = receivingHelper.AcceptVerification(ctx, txnID) @@ -472,7 +477,6 @@ func TestVerification_ErrorOnDoubleAccept(t *testing.T) { func TestVerification_CancelOnDoubleStart(t *testing.T) { ctx := log.Logger.WithContext(context.TODO()) ts, sendingClient, receivingClient, _, _, sendingMachine, receivingMachine := initServerAndLoginTwoAlice(t, ctx) - defer ts.Close() sendingCallbacks, receivingCallbacks, sendingHelper, receivingHelper := initDefaultCallbacks(t, ctx, sendingClient, receivingClient, sendingMachine, receivingMachine) _, _, err := sendingMachine.GenerateAndUploadCrossSigningKeys(ctx, nil, "") @@ -481,15 +485,15 @@ func TestVerification_CancelOnDoubleStart(t *testing.T) { // Send and accept the first verification request. txnID1, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) err = receivingHelper.AcceptVerification(ctx, txnID1) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.ready event + ts.DispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.ready event // Send a second verification request txnID2, err := sendingHelper.StartVerification(ctx, aliceUserID) require.NoError(t, err) - ts.dispatchToDevice(t, ctx, receivingClient) + ts.DispatchToDevice(t, ctx, receivingClient) // Ensure that the sending device received a cancellation event for both of // the ongoing transactions. @@ -507,7 +511,7 @@ func TestVerification_CancelOnDoubleStart(t *testing.T) { assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID1)) assert.NotNil(t, receivingCallbacks.GetVerificationCancellation(txnID2)) - ts.dispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.cancel events + ts.DispatchToDevice(t, ctx, sendingClient) // Process the m.key.verification.cancel events assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID1)) assert.NotNil(t, sendingCallbacks.GetVerificationCancellation(txnID2)) } diff --git a/crypto/verificationhelper/mockserver_test.go b/mockserver/mockserver.go similarity index 68% rename from crypto/verificationhelper/mockserver_test.go rename to mockserver/mockserver.go index 45ca7781..9f62b567 100644 --- a/crypto/verificationhelper/mockserver_test.go +++ b/mockserver/mockserver.go @@ -1,10 +1,10 @@ -// Copyright (c) 2024 Sumner Evans +// Copyright (c) 2025 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -package verificationhelper_test +package mockserver import ( "context" @@ -15,7 +15,7 @@ import ( "strings" "testing" - "github.com/rs/zerolog/log" // zerolog-allow-global-log + globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log "github.com/stretchr/testify/require" "go.mau.fi/util/random" @@ -26,10 +26,9 @@ import ( "maunium.net/go/mautrix/id" ) -// mockServer is a mock Matrix server that wraps an [httptest.Server] to allow -// testing of the interactive verification process. -type mockServer struct { - *httptest.Server +type MockServer struct { + Router *http.ServeMux + Server *httptest.Server AccessTokenToUserID map[string]id.UserID DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event @@ -40,10 +39,10 @@ type mockServer struct { UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys } -func createMockServer(t *testing.T) *mockServer { +func Create(t *testing.T) *MockServer { t.Helper() - server := mockServer{ + server := MockServer{ AccessTokenToUserID: map[string]id.UserID{}, DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{}, AccountData: map[id.UserID]map[event.Type]json.RawMessage{}, @@ -61,12 +60,13 @@ func createMockServer(t *testing.T) *mockServer { router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload) router.HandleFunc("POST /_matrix/client/v3/keys/signatures/upload", server.emptyResp) router.HandleFunc("POST /_matrix/client/v3/keys/upload", server.postKeysUpload) - + server.Router = router server.Server = httptest.NewServer(router) + t.Cleanup(server.Server.Close) return &server } -func (ms *mockServer) getUserID(r *http.Request) id.UserID { +func (ms *MockServer) getUserID(r *http.Request) id.UserID { authHeader := r.Header.Get("Authorization") authHeader = strings.TrimPrefix(authHeader, "Bearer ") userID, ok := ms.AccessTokenToUserID[authHeader] @@ -76,11 +76,11 @@ func (ms *mockServer) getUserID(r *http.Request) id.UserID { return userID } -func (s *mockServer) emptyResp(w http.ResponseWriter, _ *http.Request) { +func (ms *MockServer) emptyResp(w http.ResponseWriter, _ *http.Request) { w.Write([]byte("{}")) } -func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) { +func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) { var loginReq mautrix.ReqLogin json.NewDecoder(r.Body).Decode(&loginReq) @@ -91,7 +91,7 @@ func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) { accessToken := random.String(30) userID := id.UserID(loginReq.Identifier.User) - s.AccessTokenToUserID[accessToken] = userID + ms.AccessTokenToUserID[accessToken] = userID json.NewEncoder(w).Encode(&mautrix.RespLogin{ AccessToken: accessToken, @@ -100,40 +100,40 @@ func (s *mockServer) postLogin(w http.ResponseWriter, r *http.Request) { }) } -func (s *mockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { +func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { var req mautrix.ReqSendToDevice json.NewDecoder(r.Body).Decode(&req) evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType} for user, devices := range req.Messages { for device, content := range devices { - if _, ok := s.DeviceInbox[user]; !ok { - s.DeviceInbox[user] = map[id.DeviceID][]event.Event{} + if _, ok := ms.DeviceInbox[user]; !ok { + ms.DeviceInbox[user] = map[id.DeviceID][]event.Event{} } content.ParseRaw(evtType) - s.DeviceInbox[user][device] = append(s.DeviceInbox[user][device], event.Event{ - Sender: s.getUserID(r), + ms.DeviceInbox[user][device] = append(ms.DeviceInbox[user][device], event.Event{ + Sender: ms.getUserID(r), Type: evtType, Content: *content, }) } } - s.emptyResp(w, r) + ms.emptyResp(w, r) } -func (s *mockServer) putAccountData(w http.ResponseWriter, r *http.Request) { +func (ms *MockServer) putAccountData(w http.ResponseWriter, r *http.Request) { userID := id.UserID(r.PathValue("userID")) eventType := event.Type{Type: r.PathValue("type"), Class: event.AccountDataEventType} jsonData, _ := io.ReadAll(r.Body) - if _, ok := s.AccountData[userID]; !ok { - s.AccountData[userID] = map[event.Type]json.RawMessage{} + if _, ok := ms.AccountData[userID]; !ok { + ms.AccountData[userID] = map[event.Type]json.RawMessage{} } - s.AccountData[userID][eventType] = json.RawMessage(jsonData) - s.emptyResp(w, r) + ms.AccountData[userID][eventType] = json.RawMessage(jsonData) + ms.emptyResp(w, r) } -func (s *mockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) { +func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) { var req mautrix.ReqQueryKeys json.NewDecoder(r.Body).Decode(&req) resp := mautrix.RespQueryKeys{ @@ -143,44 +143,44 @@ func (s *mockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) { DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, } for user := range req.DeviceKeys { - resp.MasterKeys[user] = s.MasterKeys[user] - resp.UserSigningKeys[user] = s.UserSigningKeys[user] - resp.SelfSigningKeys[user] = s.SelfSigningKeys[user] - resp.DeviceKeys[user] = s.DeviceKeys[user] + resp.MasterKeys[user] = ms.MasterKeys[user] + resp.UserSigningKeys[user] = ms.UserSigningKeys[user] + resp.SelfSigningKeys[user] = ms.SelfSigningKeys[user] + resp.DeviceKeys[user] = ms.DeviceKeys[user] } json.NewEncoder(w).Encode(&resp) } -func (s *mockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) { +func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) { var req mautrix.ReqUploadKeys json.NewDecoder(r.Body).Decode(&req) - userID := s.getUserID(r) - if _, ok := s.DeviceKeys[userID]; !ok { - s.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{} + userID := ms.getUserID(r) + if _, ok := ms.DeviceKeys[userID]; !ok { + ms.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{} } - s.DeviceKeys[userID][req.DeviceKeys.DeviceID] = *req.DeviceKeys + ms.DeviceKeys[userID][req.DeviceKeys.DeviceID] = *req.DeviceKeys json.NewEncoder(w).Encode(&mautrix.RespUploadKeys{ OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: 50}, }) } -func (s *mockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) { +func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) { var req mautrix.UploadCrossSigningKeysReq json.NewDecoder(r.Body).Decode(&req) - userID := s.getUserID(r) - s.MasterKeys[userID] = req.Master - s.SelfSigningKeys[userID] = req.SelfSigning - s.UserSigningKeys[userID] = req.UserSigning + userID := ms.getUserID(r) + ms.MasterKeys[userID] = req.Master + ms.SelfSigningKeys[userID] = req.SelfSigning + ms.UserSigningKeys[userID] = req.UserSigning - s.emptyResp(w, r) + ms.emptyResp(w, r) } -func (ms *mockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) { +func (ms *MockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) { t.Helper() - client, err := mautrix.NewClient(ms.URL, "", "") + client, err := mautrix.NewClient(ms.Server.URL, "", "") require.NoError(t, err) client.StateStore = mautrix.NewMemoryStateStore() @@ -204,7 +204,7 @@ func (ms *mockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, err = cryptoHelper.Init(ctx) require.NoError(t, err) - machineLog := log.Logger.With(). + machineLog := globallog.Logger.With(). Stringer("my_user_id", userID). Stringer("my_device_id", deviceID). Logger() @@ -216,7 +216,7 @@ func (ms *mockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, return client, cryptoStore } -func (ms *mockServer) dispatchToDevice(t *testing.T, ctx context.Context, client *mautrix.Client) { +func (ms *MockServer) DispatchToDevice(t *testing.T, ctx context.Context, client *mautrix.Client) { t.Helper() for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] { @@ -224,13 +224,3 @@ func (ms *mockServer) dispatchToDevice(t *testing.T, ctx context.Context, client ms.DeviceInbox[client.UserID][client.DeviceID] = ms.DeviceInbox[client.UserID][client.DeviceID][1:] } } - -func addDeviceID(ctx context.Context, cryptoStore crypto.Store, userID id.UserID, deviceID id.DeviceID) { - err := cryptoStore.PutDevice(ctx, userID, &id.Device{ - UserID: userID, - DeviceID: deviceID, - }) - if err != nil { - panic(err) - } -} From caca057b2304679bc1875c9355547b59a326c9c0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 26 Sep 2025 19:17:16 +0300 Subject: [PATCH 1405/1647] crypto/helper: always share keys when creating new device --- crypto/cryptohelper/cryptohelper.go | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index 56f8b484..74710678 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -225,13 +225,6 @@ func (helper *CryptoHelper) Init(ctx context.Context) error { helper.ASEventProcessor.On(event.EventEncrypted, helper.HandleEncrypted) } - if helper.client.SetAppServiceDeviceID { - err = helper.mach.ShareKeys(ctx, -1) - if err != nil { - return fmt.Errorf("failed to share keys: %w", err) - } - } - return nil } @@ -268,21 +261,21 @@ func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error if !ok || len(device.Keys) == 0 { if isShared { return fmt.Errorf("olm account is marked as shared, keys seem to have disappeared from the server") - } else { - helper.log.Debug().Msg("Olm account not shared and keys not on server, so device is probably fine") - return nil } + helper.log.Debug().Msg("Olm account not shared and keys not on server, sharing initial keys") + err = helper.mach.ShareKeys(ctx, -1) + if err != nil { + return fmt.Errorf("failed to share keys: %w", err) + } + return nil } else if !isShared { return fmt.Errorf("olm account is not marked as shared, but there are keys on the server") } else if ed := device.Keys.GetEd25519(helper.client.DeviceID); ownID.SigningKey != ed { return fmt.Errorf("mismatching identity key on server (%q != %q)", ownID.SigningKey, ed) - } - if !isShared { - helper.log.Debug().Msg("Olm account not marked as shared, but keys on server match?") } else { helper.log.Debug().Msg("Olm account marked as shared and keys on server match, device is fine") + return nil } - return nil } var NoSessionFound = crypto.NoSessionFound From fa90bba8205cc229ce767c82c809873c0f3bceb1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 26 Sep 2025 19:48:22 +0300 Subject: [PATCH 1406/1647] crypto: don't check otk count if sharing new keys --- crypto/machine.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/machine.go b/crypto/machine.go index ab3e4591..4d2e3880 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -729,7 +729,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro start := time.Now() mach.otkUploadLock.Lock() defer mach.otkUploadLock.Unlock() - if mach.lastOTKUpload.Add(1*time.Minute).After(start) || currentOTKCount < 0 { + if mach.lastOTKUpload.Add(1*time.Minute).After(start) || (currentOTKCount < 0 && mach.account.Shared) { log.Debug().Msg("Checking OTK count from server due to suspiciously close share keys requests or negative OTK count") resp, err := mach.Client.UploadKeys(ctx, &mautrix.ReqUploadKeys{}) if err != nil { From acc449daf42192fe090926f4aa65cfebadcdc5cf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 26 Sep 2025 20:37:58 +0300 Subject: [PATCH 1407/1647] crypto: add basic group session sharing benchmark --- crypto/machine_bench_test.go | 67 ++++++++++++++++++ mockserver/mockserver.go | 131 ++++++++++++++++++++++++++++------- responses_test.go | 2 - 3 files changed, 173 insertions(+), 27 deletions(-) create mode 100644 crypto/machine_bench_test.go diff --git a/crypto/machine_bench_test.go b/crypto/machine_bench_test.go new file mode 100644 index 00000000..fd40d795 --- /dev/null +++ b/crypto/machine_bench_test.go @@ -0,0 +1,67 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package crypto_test + +import ( + "context" + "fmt" + "math/rand/v2" + "testing" + + "github.com/rs/zerolog" + globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/crypto/cryptohelper" + "maunium.net/go/mautrix/id" + "maunium.net/go/mautrix/mockserver" +) + +func randomDeviceCount(r *rand.Rand) int { + k := 1 + for k < 10 && r.IntN(3) > 0 { + k++ + } + return k +} + +func BenchmarkOlmMachine_ShareGroupSession(b *testing.B) { + globallog.Logger = zerolog.Nop() + server := mockserver.Create(b) + server.PopOTKs = false + server.MemoryStore = false + var i int + var shareTargets []id.UserID + r := rand.New(rand.NewPCG(293, 0)) + var totalDeviceCount int + for i = 1; i < 1000; i++ { + userID := id.UserID(fmt.Sprintf("@user%d:localhost", i)) + deviceCount := randomDeviceCount(r) + for j := 0; j < deviceCount; j++ { + client, _ := server.Login(b, nil, userID, id.DeviceID(fmt.Sprintf("u%d_d%d", i, j))) + mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine() + keysCache, err := mach.GenerateCrossSigningKeys() + require.NoError(b, err) + err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil) + require.NoError(b, err) + } + totalDeviceCount += deviceCount + shareTargets = append(shareTargets, userID) + } + for b.Loop() { + client, _ := server.Login(b, nil, id.UserID(fmt.Sprintf("@benchuser%d:localhost", i)), id.DeviceID(fmt.Sprintf("u%d_d1", i))) + mach := client.Crypto.(*cryptohelper.CryptoHelper).Machine() + keysCache, err := mach.GenerateCrossSigningKeys() + require.NoError(b, err) + err = mach.PublishCrossSigningKeys(context.TODO(), keysCache, nil) + require.NoError(b, err) + err = mach.ShareGroupSession(context.TODO(), "!room:localhost", shareTargets) + require.NoError(b, err) + i++ + } + fmt.Println(totalDeviceCount, "devices total") +} diff --git a/mockserver/mockserver.go b/mockserver/mockserver.go index 9f62b567..e52c387a 100644 --- a/mockserver/mockserver.go +++ b/mockserver/mockserver.go @@ -9,7 +9,9 @@ package mockserver import ( "context" "encoding/json" + "fmt" "io" + "maps" "net/http" "net/http/httptest" "strings" @@ -17,6 +19,9 @@ import ( globallog "github.com/rs/zerolog/log" // zerolog-allow-global-log "github.com/stretchr/testify/require" + "go.mau.fi/util/dbutil" + "go.mau.fi/util/exerrors" + "go.mau.fi/util/exhttp" "go.mau.fi/util/random" "maunium.net/go/mautrix" @@ -26,35 +31,52 @@ import ( "maunium.net/go/mautrix/id" ) +func mustDecode(r *http.Request, data any) { + exerrors.PanicIfNotNil(json.NewDecoder(r.Body).Decode(data)) +} + +type userAndDeviceID struct { + UserID id.UserID + DeviceID id.DeviceID +} + type MockServer struct { Router *http.ServeMux Server *httptest.Server - AccessTokenToUserID map[string]id.UserID + AccessTokenToUserID map[string]userAndDeviceID DeviceInbox map[id.UserID]map[id.DeviceID][]event.Event AccountData map[id.UserID]map[event.Type]json.RawMessage DeviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys + OneTimeKeys map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey MasterKeys map[id.UserID]mautrix.CrossSigningKeys SelfSigningKeys map[id.UserID]mautrix.CrossSigningKeys UserSigningKeys map[id.UserID]mautrix.CrossSigningKeys + + PopOTKs bool + MemoryStore bool } -func Create(t *testing.T) *MockServer { +func Create(t testing.TB) *MockServer { t.Helper() server := MockServer{ - AccessTokenToUserID: map[string]id.UserID{}, + AccessTokenToUserID: map[string]userAndDeviceID{}, DeviceInbox: map[id.UserID]map[id.DeviceID][]event.Event{}, AccountData: map[id.UserID]map[event.Type]json.RawMessage{}, DeviceKeys: map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys{}, + OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}, MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, SelfSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, + PopOTKs: true, + MemoryStore: true, } router := http.NewServeMux() router.HandleFunc("POST /_matrix/client/v3/login", server.postLogin) router.HandleFunc("POST /_matrix/client/v3/keys/query", server.postKeysQuery) + router.HandleFunc("POST /_matrix/client/v3/keys/claim", server.postKeysClaim) router.HandleFunc("PUT /_matrix/client/v3/sendToDevice/{type}/{txn}", server.putSendToDevice) router.HandleFunc("PUT /_matrix/client/v3/user/{userID}/account_data/{type}", server.putAccountData) router.HandleFunc("POST /_matrix/client/v3/keys/device_signing/upload", server.postDeviceSigningUpload) @@ -66,7 +88,7 @@ func Create(t *testing.T) *MockServer { return &server } -func (ms *MockServer) getUserID(r *http.Request) id.UserID { +func (ms *MockServer) getUserID(r *http.Request) userAndDeviceID { authHeader := r.Header.Get("Authorization") authHeader = strings.TrimPrefix(authHeader, "Bearer ") userID, ok := ms.AccessTokenToUserID[authHeader] @@ -77,12 +99,12 @@ func (ms *MockServer) getUserID(r *http.Request) id.UserID { } func (ms *MockServer) emptyResp(w http.ResponseWriter, _ *http.Request) { - w.Write([]byte("{}")) + exhttp.WriteEmptyJSONResponse(w, http.StatusOK) } func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) { var loginReq mautrix.ReqLogin - json.NewDecoder(r.Body).Decode(&loginReq) + mustDecode(r, &loginReq) deviceID := loginReq.DeviceID if deviceID == "" { @@ -91,9 +113,12 @@ func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) { accessToken := random.String(30) userID := id.UserID(loginReq.Identifier.User) - ms.AccessTokenToUserID[accessToken] = userID + ms.AccessTokenToUserID[accessToken] = userAndDeviceID{ + UserID: userID, + DeviceID: deviceID, + } - json.NewEncoder(w).Encode(&mautrix.RespLogin{ + exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespLogin{ AccessToken: accessToken, DeviceID: deviceID, UserID: userID, @@ -102,7 +127,7 @@ func (ms *MockServer) postLogin(w http.ResponseWriter, r *http.Request) { func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { var req mautrix.ReqSendToDevice - json.NewDecoder(r.Body).Decode(&req) + mustDecode(r, &req) evtType := event.Type{Type: r.PathValue("type"), Class: event.ToDeviceEventType} for user, devices := range req.Messages { @@ -112,7 +137,7 @@ func (ms *MockServer) putSendToDevice(w http.ResponseWriter, r *http.Request) { } content.ParseRaw(evtType) ms.DeviceInbox[user][device] = append(ms.DeviceInbox[user][device], event.Event{ - Sender: ms.getUserID(r), + Sender: ms.getUserID(r).UserID, Type: evtType, Content: *content, }) @@ -135,7 +160,7 @@ func (ms *MockServer) putAccountData(w http.ResponseWriter, r *http.Request) { func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) { var req mautrix.ReqQueryKeys - json.NewDecoder(r.Body).Decode(&req) + mustDecode(r, &req) resp := mautrix.RespQueryKeys{ MasterKeys: map[id.UserID]mautrix.CrossSigningKeys{}, UserSigningKeys: map[id.UserID]mautrix.CrossSigningKeys{}, @@ -148,29 +173,68 @@ func (ms *MockServer) postKeysQuery(w http.ResponseWriter, r *http.Request) { resp.SelfSigningKeys[user] = ms.SelfSigningKeys[user] resp.DeviceKeys[user] = ms.DeviceKeys[user] } - json.NewEncoder(w).Encode(&resp) + exhttp.WriteJSONResponse(w, http.StatusOK, &resp) +} + +func (ms *MockServer) postKeysClaim(w http.ResponseWriter, r *http.Request) { + var req mautrix.ReqClaimKeys + mustDecode(r, &req) + resp := mautrix.RespClaimKeys{ + OneTimeKeys: map[id.UserID]map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{}, + } + for user, devices := range req.OneTimeKeys { + resp.OneTimeKeys[user] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{} + for device := range devices { + keys := ms.OneTimeKeys[user][device] + for keyID, key := range keys { + if ms.PopOTKs { + delete(keys, keyID) + } + resp.OneTimeKeys[user][device] = map[id.KeyID]mautrix.OneTimeKey{ + keyID: key, + } + break + } + } + } + exhttp.WriteJSONResponse(w, http.StatusOK, &resp) } func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) { var req mautrix.ReqUploadKeys - json.NewDecoder(r.Body).Decode(&req) + mustDecode(r, &req) - userID := ms.getUserID(r) + uid := ms.getUserID(r) + userID := uid.UserID if _, ok := ms.DeviceKeys[userID]; !ok { ms.DeviceKeys[userID] = map[id.DeviceID]mautrix.DeviceKeys{} } - ms.DeviceKeys[userID][req.DeviceKeys.DeviceID] = *req.DeviceKeys + if _, ok := ms.OneTimeKeys[userID]; !ok { + ms.OneTimeKeys[userID] = map[id.DeviceID]map[id.KeyID]mautrix.OneTimeKey{} + } - json.NewEncoder(w).Encode(&mautrix.RespUploadKeys{ - OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: 50}, + if req.DeviceKeys != nil { + ms.DeviceKeys[userID][uid.DeviceID] = *req.DeviceKeys + } + otks, ok := ms.OneTimeKeys[userID][uid.DeviceID] + if !ok { + otks = map[id.KeyID]mautrix.OneTimeKey{} + ms.OneTimeKeys[userID][uid.DeviceID] = otks + } + if req.OneTimeKeys != nil { + maps.Copy(otks, req.OneTimeKeys) + } + + exhttp.WriteJSONResponse(w, http.StatusOK, &mautrix.RespUploadKeys{ + OneTimeKeyCounts: mautrix.OTKCount{SignedCurve25519: len(otks)}, }) } func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) { var req mautrix.UploadCrossSigningKeysReq - json.NewDecoder(r.Body).Decode(&req) + mustDecode(r, &req) - userID := ms.getUserID(r) + userID := ms.getUserID(r).UserID ms.MasterKeys[userID] = req.Master ms.SelfSigningKeys[userID] = req.SelfSigning ms.UserSigningKeys[userID] = req.UserSigning @@ -178,11 +242,14 @@ func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Req ms.emptyResp(w, r) } -func (ms *MockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) { +func (ms *MockServer) Login(t testing.TB, ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*mautrix.Client, crypto.Store) { t.Helper() + if ctx == nil { + ctx = context.TODO() + } client, err := mautrix.NewClient(ms.Server.URL, "", "") require.NoError(t, err) - client.StateStore = mautrix.NewMemoryStateStore() + client.Client = ms.Server.Client() _, err = client.Login(ctx, &mautrix.ReqLogin{ Type: mautrix.AuthTypePassword, @@ -196,8 +263,22 @@ func (ms *MockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, }) require.NoError(t, err) - cryptoStore := crypto.NewMemoryStore(nil) - cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), cryptoStore) + var store any + if ms.MemoryStore { + store = crypto.NewMemoryStore(nil) + client.StateStore = mautrix.NewMemoryStateStore() + } else { + store, err = dbutil.NewFromConfig("", dbutil.Config{ + PoolConfig: dbutil.PoolConfig{ + Type: "sqlite3-fk-wal", + URI: fmt.Sprintf("file:%s?mode=memory&cache=shared&_txlock=immediate", random.String(10)), + MaxOpenConns: 5, + MaxIdleConns: 1, + }, + }, nil) + require.NoError(t, err) + } + cryptoHelper, err := cryptohelper.NewCryptoHelper(client, []byte("test"), store) require.NoError(t, err) client.Crypto = cryptoHelper @@ -213,10 +294,10 @@ func (ms *MockServer) Login(t *testing.T, ctx context.Context, userID id.UserID, err = cryptoHelper.Machine().ShareKeys(ctx, 50) require.NoError(t, err) - return client, cryptoStore + return client, cryptoHelper.Machine().CryptoStore } -func (ms *MockServer) DispatchToDevice(t *testing.T, ctx context.Context, client *mautrix.Client) { +func (ms *MockServer) DispatchToDevice(t testing.TB, ctx context.Context, client *mautrix.Client) { t.Helper() for _, evt := range ms.DeviceInbox[client.UserID][client.DeviceID] { diff --git a/responses_test.go b/responses_test.go index b23d85ad..73d82635 100644 --- a/responses_test.go +++ b/responses_test.go @@ -8,7 +8,6 @@ package mautrix_test import ( "encoding/json" - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -86,7 +85,6 @@ func TestRespCapabilities_UnmarshalJSON(t *testing.T) { var caps mautrix.RespCapabilities err := json.Unmarshal([]byte(sampleData), &caps) require.NoError(t, err) - fmt.Println(caps) require.NotNil(t, caps.RoomVersions) assert.Equal(t, "9", caps.RoomVersions.Default) From a3c6832c487fd07ee63bd136b2c63e637888fb0e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 26 Sep 2025 23:18:05 +0300 Subject: [PATCH 1408/1647] federation/eventauth: fix default power levels in pre-v12 rooms --- federation/eventauth/eventauth.go | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go index d4a50969..6c36d478 100644 --- a/federation/eventauth/eventauth.go +++ b/federation/eventauth/eventauth.go @@ -733,23 +733,27 @@ func findEventAndReadString(events []*pdu.PDU, evtType, stateKey, fieldPath, def func getPowerLevels(roomVersion id.RoomVersion, authEvents []*pdu.PDU, createEvt *pdu.PDU) (*event.PowerLevelsEventContent, error) { var err error - powerLevels := findEventAndReadData(authEvents, event.StatePowerLevels.Type, "", func(evt *pdu.PDU) (out event.PowerLevelsEventContent) { + powerLevels := findEventAndReadData(authEvents, event.StatePowerLevels.Type, "", func(evt *pdu.PDU) *event.PowerLevelsEventContent { if evt == nil { - return + return nil } content := evt.Content + out := &event.PowerLevelsEventContent{} if !roomVersion.ValidatePowerLevelInts() { - safeParsePowerLevels(content, &out) + safeParsePowerLevels(content, out) } else { - err = json.Unmarshal(content, &out) + err = json.Unmarshal(content, out) } - return + return out }) if err != nil { // This should never happen thanks to safeParsePowerLevels for v1-9 and strict validation in v10+ return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) } if roomVersion.PrivilegedRoomCreators() { + if powerLevels == nil { + powerLevels = &event.PowerLevelsEventContent{} + } powerLevels.CreateEvent, err = createEvt.ToClientEvent(roomVersion) if err != nil { return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) @@ -758,12 +762,14 @@ func getPowerLevels(roomVersion id.RoomVersion, authEvents []*pdu.PDU, createEvt if err != nil { return nil, fmt.Errorf("%w: %w", ErrFailedToParsePowerLevels, err) } - } else { - powerLevels.Users = map[id.UserID]int{ - createEvt.Sender: (1 << 53) - 1, + } else if powerLevels == nil { + powerLevels = &event.PowerLevelsEventContent{ + Users: map[id.UserID]int{ + createEvt.Sender: 100, + }, } } - return &powerLevels, nil + return powerLevels, nil } func parseIntWithVersion(roomVersion id.RoomVersion, val gjson.Result) *int { From ae6a0b4f512c635cb2ab3c0af5e88e4cc11e58f3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 26 Sep 2025 23:26:17 +0300 Subject: [PATCH 1409/1647] federation/eventauth: fix checking user power level changes --- federation/eventauth/eventauth.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go index 6c36d478..f8d90248 100644 --- a/federation/eventauth/eventauth.go +++ b/federation/eventauth/eventauth.go @@ -664,8 +664,9 @@ func allowPowerChangeMap(roomVersion id.RoomVersion, maxVal int, path, ownID str newVal := new.Get(exgjson.Path(key.Str)) err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, value, newVal) if err == nil && ownID != "" && key.Str != ownID { - val := parseIntWithVersion(roomVersion, value) - if *val >= maxVal { + parsedOldVal := parseIntWithVersion(roomVersion, value) + parsedNewVal := parseIntWithVersion(roomVersion, newVal) + if *parsedOldVal >= maxVal && *parsedOldVal != *parsedNewVal { err = fmt.Errorf("%w: can't change users.%s from %s to %s with sender level %d", ErrInvalidPowerChange, key.Str, stringifyForError(value), stringifyForError(newVal), maxVal) } } From 6e231a45e4e848f260f98fe242a417375d02a629 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 26 Sep 2025 23:36:03 +0300 Subject: [PATCH 1410/1647] federation/eventauth: fix gjson path construction in new power level check --- federation/eventauth/eventauth.go | 2 +- federation/eventauth/testroom-v12-success.jsonl | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go index f8d90248..3dfdeb48 100644 --- a/federation/eventauth/eventauth.go +++ b/federation/eventauth/eventauth.go @@ -676,7 +676,7 @@ func allowPowerChangeMap(roomVersion id.RoomVersion, maxVal int, path, ownID str return } new.ForEach(func(key, value gjson.Result) bool { - err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, old.Get(exgjson.Path(key.Path(key.Str))), value) + err = allowPowerChange(roomVersion, maxVal, path+"."+key.Str, old.Get(exgjson.Path(key.Str)), value) return err == nil }) return diff --git a/federation/eventauth/testroom-v12-success.jsonl b/federation/eventauth/testroom-v12-success.jsonl index 1f0b5357..2b751de3 100644 --- a/federation/eventauth/testroom-v12-success.jsonl +++ b/federation/eventauth/testroom-v12-success.jsonl @@ -15,3 +15,7 @@ {"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":9001},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":15,"hashes":{"sha256":"FnGzbcXc8YOiB1TY33QunGA17Axoyuu3sdVOj5Z408o"},"origin_server_ts":1756071804931,"prev_events":["$p4xvOczrhzQMtRW3-Tf86LYUb5aqpGFIgjwHBuxWIcI"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"uyTUsPR+CzCtlevzB5+sNXvmfbPSp6u7RZC4E4TLVsj45+pjmMRswAvuHP9PT2+Tkl6Hu8ZPigsXgbKZtR35Aw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw","replaces_state":"$Qg1xRB8nL8lGykGvt9_agu_WCWq8Y3rl_p_LKa6D2Hg"}} {"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":16,"hashes":{"sha256":"KcivsiLesdnUnKX23Akk3OJEJFGRSY0g4H+p7XIThnw"},"origin_server_ts":1756071812688,"prev_events":["$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"cAK8dO2AVZklY9te5aVKbF1jR/eB5rzeNOXfYPjBLf+aSAS4Z6R2aMKW6hJB9PqRS4S+UZc24DTrjUjnvMzeBA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU","replaces_state":"$uZ4OOtkM8RcbEkhjNp-YlEH0zBqgsRx1eI8b2YP7ovw"}} {"auth_events":["$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"body":"meow #2","com.beeper.linkpreviews":[],"m.mentions":{},"msgtype":"m.text"},"depth":17,"hashes":{"sha256":"SgH9fOXGdbdqpRfYmoz1t29+gX8Ze4ThSoj6klZs3Og"},"origin_server_ts":1756247476706,"prev_events":["$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"SMYK7zP3SaQOKhzZUKUBVCKwffYqi3PFAlPM34kRJtmfGU3KZXNBT0zi+veXDMmxkMunqhF2RTHBD6joa0kBAQ"}},"type":"m.room.message","unsigned":{"event_id":"$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"}} +{"auth_events":["$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":8999,"@tulir:envs.net":9001,"@tulir:matrix.org":9000},"users_default":0},"depth":18,"hashes":{"sha256":"l8Mw3VKn/Bvntg7bZ8uh5J8M2IBZM93Xg7hsdaSci8s"},"origin_server_ts":1758918656341,"prev_events":["$KFHLO0-ENYOGQXogp84C-ISSu1xtKUzIMaZ6LiBcR_w"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"matrix.org":{"ed25519:a_RXGa":"cg5LP0WuTnVB5jFhNERLLU5b+EhmyACiOq6cp3gKJnZsTAb1yajcgJybLWKrc8QQqxPa7hPnskRBgt4OBTFNAA"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","replaces_state":"$iwqRXQc2cx8K4AclTjU1Se-BMJpUl4DxrLm3nfUgeQU"}} +{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$_gYjNODWJdo5-S1IN0bmAk3rzIeXzr5W5cmXZSmUsNw","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"invite"},"depth":19,"hashes":{"sha256":"KpmaRUQnJju8TIDMPzakitUIKOWJxTvULpFB3a1CGgc"},"origin_server_ts":1758918665952,"prev_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:matrix.org","signatures":{"beeper.com":{"ed25519:a_zgvp":"mzI9rPkQ1xHl2/G5Yrn0qmIRt5OyjPNqRwilPfH4jmr1tP+vv3vC0m4mph/MCOq8S1c/DQaCWSpdOX1uWfchBQ"},"matrix.org":{"ed25519:a_RXGa":"kEdfr8DjxC/bdvGYxnniFI/pxDWeyG73OjG/Gu1uoHLhjdtAT/vEQ6lotJJs214/KX5eAaQWobE9qtMvtPwMDw"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"event_id":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","invite_room_state":[{"auth_events":[],"content":{"room_version":"12"},"depth":1,"hashes":{"sha256":"qJYytb+EqWPiiZ0ogDODcLeA8XYw/2hVTaLHihcVBZQ"},"origin_server_ts":1756071567186,"prev_events":[],"sender":"@tulir:maunium.net","signatures":{"maunium.net":{"ed25519:a_xxeS":"/9pp+2tkLo6XcZ3opqLeIpa3D96fh3QLpR2PQrZ6Z6j7wyRAvBrcgCpAeMtuyDCzW8Wh1QFEPG4FSsGvVaEFBg"}},"state_key":"","type":"m.room.create","unsigned":{"age":11553}},{"content":{"avatar_url":"mxc://matrix.org/BDYVQFSLvZHMaKHDGiRkvhVg","displayname":"tulir[m]","membership":"join"},"sender":"@tulir:matrix.org","state_key":"@tulir:matrix.org","type":"m.room.member"},{"content":{"name":"event auth test v12!"},"sender":"@tulir:matrix.org","state_key":"","type":"m.room.name"},{"content":{"join_rule":"invite"},"sender":"@tulir:maunium.net","state_key":"","type":"m.room.join_rules"}]}} +{"auth_events":["$deNVGs6Ef7OKVrvewhtPv7DCCqSip112cEJYp-jkP6M","$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro","$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"],"content":{"avatar_url":"mxc://beeper.com/eBdwbHbllONoAySQkXLjbfFM","displayname":"tulir[b]","membership":"join"},"depth":20,"hashes":{"sha256":"bmaHSm4mYPNBNlUfFsauSTxLrUH4CUSAKYvr1v76qkk"},"origin_server_ts":1758918670276,"prev_events":["$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:beeper.com","signatures":{"beeper.com":{"ed25519:a_zgvp":"D3cz3m15m89a3G4c5yWOBCjhtSeI5IxBfQKt5XOr9a44QHyc3nwjjvIJaRrKNcS5tLUJwZ2IpVzjlrpbPHpxDA"}},"state_key":"@tulir:beeper.com","type":"m.room.member","unsigned":{"age":6,"event_id":"$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw","replaces_state":"$PZJZoUwNySl0jY16DkHBHR0HyAppLdxc0rkSuYp5Mro"}} +{"auth_events":["$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0","$Bz2lxsbUYkeBDE7eMAsOm_TK_iuSuHNvQdrHnc-T1PE"],"content":{"ban":50,"events":{"m.room.avatar":50,"m.room.canonical_alias":50,"m.room.encryption":100,"m.room.history_visibility":100,"m.room.name":50,"m.room.power_levels":100,"m.room.server_acl":100,"m.room.tombstone":100},"events_default":0,"historical":12345,"invite":0,"kick":50,"redact":50,"state_default":50,"users":{"@tulir:beeper.com":9000,"@tulir:envs.net":9001,"@tulir:matrix.org":8999},"users_default":0},"depth":21,"hashes":{"sha256":"xCj9vszChHiXba9DaPzhtF79Tphek3pRViMp36DOurU"},"origin_server_ts":1758918689485,"prev_events":["$_hayW1Y0HRWp3VEGZZbsMf0Ncg9x6n0ikveD0lbCwMw"],"room_id":"!lVEL38waGAf4ggmWC3OVk_bbx8kZx-iOcTBKXTBnM54","sender":"@tulir:envs.net","signatures":{"envs.net":{"ed25519:wuJyKT":"odkrWD30+ObeYtagULtECB/QmGae7qNy66nmJMWYXiQMYUJw/GMzSmgAiLAWfVYlfD3aEvMb/CBdrhL07tfSBw"}},"state_key":"","type":"m.room.power_levels","unsigned":{"event_id":"$di6cI89-GxX8-Wbx-0T69l4wg6TUWITRkjWXzG7EBqo","replaces_state":"$x-CCUewbWOHQXqfcUsywOmbHvNnOSwNM1RyOu-c8SB0"}} From 9878c3d67542e48ab050e535d0a92688d8fac6f8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 26 Sep 2025 23:36:58 +0300 Subject: [PATCH 1411/1647] federation/eventauth: change error message for users-specific power level check --- federation/eventauth/eventauth.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go index 3dfdeb48..7d73abcd 100644 --- a/federation/eventauth/eventauth.go +++ b/federation/eventauth/eventauth.go @@ -110,12 +110,13 @@ var ( ErrMismatchingPrivateStateKey = AuthFailError{Index: "9", Message: "state keys starting with @ must match sender user ID"} - ErrTopLevelPLNotInteger = AuthFailError{Index: "10.1", Message: "invalid type for top-level power level field"} - ErrPLNotInteger = AuthFailError{Index: "10.2", Message: "invalid type for power level"} - ErrInvalidUserIDInPL = AuthFailError{Index: "10.3", Message: "invalid user ID in power levels"} - ErrUserPLNotInteger = AuthFailError{Index: "10.3", Message: "invalid type for user power level"} - ErrCreatorInPowerLevels = AuthFailError{Index: "10.4", Message: "room creators must not be specified in power levels"} - ErrInvalidPowerChange = AuthFailError{Index: "10.x", Message: "illegal power level change"} + ErrTopLevelPLNotInteger = AuthFailError{Index: "10.1", Message: "invalid type for top-level power level field"} + ErrPLNotInteger = AuthFailError{Index: "10.2", Message: "invalid type for power level"} + ErrInvalidUserIDInPL = AuthFailError{Index: "10.3", Message: "invalid user ID in power levels"} + ErrUserPLNotInteger = AuthFailError{Index: "10.3", Message: "invalid type for user power level"} + ErrCreatorInPowerLevels = AuthFailError{Index: "10.4", Message: "room creators must not be specified in power levels"} + ErrInvalidPowerChange = AuthFailError{Index: "10.x", Message: "illegal power level change"} + ErrInvalidUserPowerChange = AuthFailError{Index: "10.9", Message: "illegal power level change"} ) func isRejected(evt *pdu.PDU) bool { @@ -667,7 +668,7 @@ func allowPowerChangeMap(roomVersion id.RoomVersion, maxVal int, path, ownID str parsedOldVal := parseIntWithVersion(roomVersion, value) parsedNewVal := parseIntWithVersion(roomVersion, newVal) if *parsedOldVal >= maxVal && *parsedOldVal != *parsedNewVal { - err = fmt.Errorf("%w: can't change users.%s from %s to %s with sender level %d", ErrInvalidPowerChange, key.Str, stringifyForError(value), stringifyForError(newVal), maxVal) + err = fmt.Errorf("%w: can't change users.%s from %s to %s with sender level %d", ErrInvalidUserPowerChange, key.Str, stringifyForError(value), stringifyForError(newVal), maxVal) } } return err == nil From 743cbb5f2ce71aaa97a56e76adf9597383a1b3c2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 27 Sep 2025 16:26:15 +0300 Subject: [PATCH 1412/1647] bridgev2/mxmain: add option to mix calendar and semantic versioning --- bridgev2/matrix/mxmain/main.go | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index e6219c50..9fec278d 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -62,6 +62,9 @@ type BridgeMain struct { // git tag to see if the built version is the release or a dev build. // You can either bump this right after a release or right before, as long as it matches on the release commit. Version string + // SemCalVer defines whether this bridge uses a mix of semantic and calendar versioning, + // such that the Version field is YY.0M.patch, while git tags are major.YY0M.patch. + SemCalVer bool // PostInit is a function that will be called after the bridge has been initialized but before it is started. PostInit func() @@ -424,6 +427,21 @@ func (br *BridgeMain) Stop() { br.Bridge.StopWithTimeout(5 * time.Second) } +func semverToCalver(semver string) string { + parts := strings.SplitN(semver, ".", 3) + if len(parts) < 2 { + panic(fmt.Errorf("invalid semver for calendar versioning: %s", semver)) + } + if len(parts[1]) != 4 { + panic(fmt.Errorf("invalid minor semver component for calendar versioning: %s", parts[1])) + } + calver := parts[1][:2] + "." + parts[1][2:] + if len(parts) == 3 { + calver += "." + parts[2] + } + return calver +} + // InitVersion formats the bridge version and build time nicely for things like // the `version` bridge command on Matrix and the `--version` CLI flag. // @@ -447,9 +465,13 @@ func (br *BridgeMain) Stop() { // (to use both at the same time, simply merge the ldflags into one, `-ldflags "-X '...' -X ..."`) func (br *BridgeMain) InitVersion(tag, commit, rawBuildTime string) { br.baseVersion = br.Version + rawTag := tag if len(tag) > 0 && tag[0] == 'v' { tag = tag[1:] } + if br.SemCalVer && len(tag) > 0 { + tag = semverToCalver(tag) + } if tag != br.Version { suffix := "" if !strings.HasSuffix(br.Version, "+dev") { @@ -464,7 +486,7 @@ func (br *BridgeMain) InitVersion(tag, commit, rawBuildTime string) { br.LinkifiedVersion = fmt.Sprintf("v%s", br.Version) if tag == br.Version { - br.LinkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", br.Version, br.URL, tag) + br.LinkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", br.Version, br.URL, rawTag) } else if len(commit) > 8 { br.LinkifiedVersion = strings.Replace(br.LinkifiedVersion, commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", commit[:8], br.URL, commit), 1) } From d146b6caf80673aad09b26c4bfd117a1ca072dde Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 27 Sep 2025 16:49:42 +0300 Subject: [PATCH 1413/1647] bridgev2/mxmain: move version calculation to go-util --- bridgev2/matrix/mxmain/main.go | 98 ++++++---------------------------- go.mod | 2 +- go.sum | 4 +- 3 files changed, 19 insertions(+), 85 deletions(-) diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 9fec278d..9e409875 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -26,6 +26,7 @@ import ( "go.mau.fi/util/dbutil" "go.mau.fi/util/exerrors" "go.mau.fi/util/exzerolog" + "go.mau.fi/util/progver" "gopkg.in/yaml.v3" flag "maunium.net/go/mauflag" @@ -89,11 +90,7 @@ type BridgeMain struct { RegistrationPath string SaveConfig bool - baseVersion string - commit string - LinkifiedVersion string - VersionDesc string - BuildTime time.Time + ver progver.ProgramVersion AdditionalShortFlags string AdditionalLongFlags string @@ -102,14 +99,7 @@ type BridgeMain struct { } type VersionJSONOutput struct { - Name string - URL string - - Version string - IsRelease bool - Commit string - FormattedVersion string - BuildTime time.Time + progver.ProgramVersion OS string Arch string @@ -150,18 +140,11 @@ func (br *BridgeMain) PreInit() { flag.PrintHelp() os.Exit(0) } else if *version { - fmt.Println(br.VersionDesc) + fmt.Println(br.ver.FormattedVersion) os.Exit(0) } else if *versionJSON { output := VersionJSONOutput{ - URL: br.URL, - Name: br.Name, - - Version: br.baseVersion, - IsRelease: br.Version == br.baseVersion, - Commit: br.commit, - FormattedVersion: br.Version, - BuildTime: br.BuildTime, + ProgramVersion: br.ver, OS: runtime.GOOS, Arch: runtime.GOARCH, @@ -243,8 +226,8 @@ func (br *BridgeMain) Init() { br.Log.Info(). Str("name", br.Name). - Str("version", br.Version). - Time("built_at", br.BuildTime). + Str("version", br.ver.FormattedVersion). + Time("built_at", br.ver.BuildTime). Str("go_version", runtime.Version()). Msg("Initializing bridge") @@ -258,7 +241,7 @@ func (br *BridgeMain) Init() { br.Matrix.AS.DoublePuppetValue = br.Name br.Bridge.Commands.(*commands.Processor).AddHandler(&commands.FullHandler{ Func: func(ce *commands.Event) { - ce.Reply("[%s](%s) %s (%s)", br.Name, br.URL, br.LinkifiedVersion, br.BuildTime.Format(time.RFC1123)) + ce.Reply(br.ver.MarkdownDescription()) }, Name: "version", Help: commands.HelpMeta{ @@ -427,21 +410,6 @@ func (br *BridgeMain) Stop() { br.Bridge.StopWithTimeout(5 * time.Second) } -func semverToCalver(semver string) string { - parts := strings.SplitN(semver, ".", 3) - if len(parts) < 2 { - panic(fmt.Errorf("invalid semver for calendar versioning: %s", semver)) - } - if len(parts[1]) != 4 { - panic(fmt.Errorf("invalid minor semver component for calendar versioning: %s", parts[1])) - } - calver := parts[1][:2] + "." + parts[1][2:] - if len(parts) == 3 { - calver += "." + parts[2] - } - return calver -} - // InitVersion formats the bridge version and build time nicely for things like // the `version` bridge command on Matrix and the `--version` CLI flag. // @@ -464,46 +432,12 @@ func semverToCalver(semver string) string { // // (to use both at the same time, simply merge the ldflags into one, `-ldflags "-X '...' -X ..."`) func (br *BridgeMain) InitVersion(tag, commit, rawBuildTime string) { - br.baseVersion = br.Version - rawTag := tag - if len(tag) > 0 && tag[0] == 'v' { - tag = tag[1:] - } - if br.SemCalVer && len(tag) > 0 { - tag = semverToCalver(tag) - } - if tag != br.Version { - suffix := "" - if !strings.HasSuffix(br.Version, "+dev") { - suffix = "+dev" - } - if len(commit) > 8 { - br.Version = fmt.Sprintf("%s%s.%s", br.Version, suffix, commit[:8]) - } else { - br.Version = fmt.Sprintf("%s%s.unknown", br.Version, suffix) - } - } - - br.LinkifiedVersion = fmt.Sprintf("v%s", br.Version) - if tag == br.Version { - br.LinkifiedVersion = fmt.Sprintf("[v%s](%s/releases/v%s)", br.Version, br.URL, rawTag) - } else if len(commit) > 8 { - br.LinkifiedVersion = strings.Replace(br.LinkifiedVersion, commit[:8], fmt.Sprintf("[%s](%s/commit/%s)", commit[:8], br.URL, commit), 1) - } - var buildTime time.Time - if rawBuildTime != "unknown" { - buildTime, _ = time.Parse(time.RFC3339, rawBuildTime) - } - var builtWith string - if buildTime.IsZero() { - rawBuildTime = "unknown" - builtWith = runtime.Version() - } else { - rawBuildTime = buildTime.Format(time.RFC1123) - builtWith = fmt.Sprintf("built at %s with %s", rawBuildTime, runtime.Version()) - } - mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.Version, mautrix.DefaultUserAgent) - br.VersionDesc = fmt.Sprintf("%s %s (%s)", br.Name, br.Version, builtWith) - br.commit = commit - br.BuildTime = buildTime + br.ver = progver.ProgramVersion{ + Name: br.Name, + URL: br.URL, + BaseVersion: br.Version, + SemCalVer: br.SemCalVer, + }.Init(tag, commit, rawBuildTime) + mautrix.DefaultUserAgent = fmt.Sprintf("%s/%s %s", br.Name, br.ver.FormattedVersion, mautrix.DefaultUserAgent) + br.Version = br.ver.FormattedVersion } diff --git a/go.mod b/go.mod index 751e8015..c9d082a3 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.13 - go.mau.fi/util v0.9.1 + go.mau.fi/util v0.9.2-0.20250927140851-50bb0cc52015 go.mau.fi/zeroconfig v0.2.0 golang.org/x/crypto v0.42.0 golang.org/x/exp v0.0.0-20250911091902-df9299821621 diff --git a/go.sum b/go.sum index dafb9600..5133d5b6 100644 --- a/go.sum +++ b/go.sum @@ -51,8 +51,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.1 h1:A+XKHRsjKkFi2qOm4RriR1HqY2hoOXNS3WFHaC89r2Y= -go.mau.fi/util v0.9.1/go.mod h1:M0bM9SyaOWJniaHs9hxEzz91r5ql6gYq6o1q5O1SsjQ= +go.mau.fi/util v0.9.2-0.20250927140851-50bb0cc52015 h1:aRnDwmJNAP+/EspXpo7MhSJxfS+g49MzGvnLkcNFUEc= +go.mau.fi/util v0.9.2-0.20250927140851-50bb0cc52015/go.mod h1:M0bM9SyaOWJniaHs9hxEzz91r5ql6gYq6o1q5O1SsjQ= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= From f2b77f04330c97262d7b049ed394634ce21de941 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 28 Sep 2025 20:33:20 +0300 Subject: [PATCH 1414/1647] version: find from build info if unset --- go.mod | 2 +- go.sum | 4 ++-- version.go | 12 ++++++++++++ 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index c9d082a3..5001851f 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.13 - go.mau.fi/util v0.9.2-0.20250927140851-50bb0cc52015 + go.mau.fi/util v0.9.2-0.20250928173307-c0b5f4ee5899 go.mau.fi/zeroconfig v0.2.0 golang.org/x/crypto v0.42.0 golang.org/x/exp v0.0.0-20250911091902-df9299821621 diff --git a/go.sum b/go.sum index 5133d5b6..8d3fabfe 100644 --- a/go.sum +++ b/go.sum @@ -51,8 +51,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.2-0.20250927140851-50bb0cc52015 h1:aRnDwmJNAP+/EspXpo7MhSJxfS+g49MzGvnLkcNFUEc= -go.mau.fi/util v0.9.2-0.20250927140851-50bb0cc52015/go.mod h1:M0bM9SyaOWJniaHs9hxEzz91r5ql6gYq6o1q5O1SsjQ= +go.mau.fi/util v0.9.2-0.20250928173307-c0b5f4ee5899 h1:GoPWdX45WrJG/NC+/6u4km9X9UvrzqGGG78z4VlXI7o= +go.mau.fi/util v0.9.2-0.20250928173307-c0b5f4ee5899/go.mod h1:M0bM9SyaOWJniaHs9hxEzz91r5ql6gYq6o1q5O1SsjQ= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= diff --git a/version.go b/version.go index 4821a354..fb812121 100644 --- a/version.go +++ b/version.go @@ -4,6 +4,7 @@ import ( "fmt" "regexp" "runtime" + "runtime/debug" "strings" ) @@ -18,6 +19,17 @@ var DefaultUserAgent = "mautrix-go/" + Version + " go/" + strings.TrimPrefix(run var goModVersionRegex = regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`) func init() { + if GoModVersion == "" { + info, _ := debug.ReadBuildInfo() + if info != nil { + for _, mod := range info.Deps { + if mod.Path == "maunium.net/go/mautrix" { + GoModVersion = mod.Version + break + } + } + } + } if GoModVersion != "" { match := goModVersionRegex.FindStringSubmatch(GoModVersion) if match != nil { From b597f149b71119750faa95a0a79d87461ed317d5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 28 Sep 2025 20:39:07 +0300 Subject: [PATCH 1415/1647] version: initialize go.mod version regex lazily --- version.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/version.go b/version.go index fb812121..b76a548c 100644 --- a/version.go +++ b/version.go @@ -16,8 +16,6 @@ var VersionWithCommit = Version var DefaultUserAgent = "mautrix-go/" + Version + " go/" + strings.TrimPrefix(runtime.Version(), "go") -var goModVersionRegex = regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`) - func init() { if GoModVersion == "" { info, _ := debug.ReadBuildInfo() @@ -31,7 +29,7 @@ func init() { } } if GoModVersion != "" { - match := goModVersionRegex.FindStringSubmatch(GoModVersion) + match := regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`).FindStringSubmatch(GoModVersion) if match != nil { Commit = match[1] } From 329da10584baea98df2a3695348bf63f484f54af Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 30 Sep 2025 15:35:25 +0300 Subject: [PATCH 1416/1647] bridgev2/database: fix split portal parent migration query --- bridgev2/database/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index e02b9e44..97af4c4c 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -147,7 +147,7 @@ const ( ) ` fixParentsAfterSplitPortalMigrationQuery = ` - UPDATE portal SET parent_receiver=receiver WHERE parent_receiver='' AND receiver<>'' AND parent_id<>''; + UPDATE portal SET parent_receiver=receiver WHERE bridge_id=$1 AND parent_receiver='' AND receiver<>'' AND parent_id<>''; ` ) From 77682fb2920089fb9f89f3df303195eba261cffb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 1 Oct 2025 14:48:11 +0300 Subject: [PATCH 1417/1647] bridgev2,error: use NonNilClone instead of creating map manually --- bridgev2/portal.go | 6 ++---- error.go | 12 +++--------- go.mod | 2 +- go.sum | 4 ++-- 4 files changed, 8 insertions(+), 16 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 5db45268..3884303a 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -19,6 +19,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/exfmt" + "go.mau.fi/util/exmaps" "go.mau.fi/util/exslices" "go.mau.fi/util/exsync" "go.mau.fi/util/ptr" @@ -4076,10 +4077,7 @@ func (portal *Portal) syncParticipants( Displayname: currentMember.Displayname, AvatarURL: currentMember.AvatarURL, } - wrappedContent := &event.Content{Parsed: content, Raw: maps.Clone(member.MemberEventExtra)} - if wrappedContent.Raw == nil { - wrappedContent.Raw = make(map[string]any) - } + wrappedContent := &event.Content{Parsed: content, Raw: exmaps.NonNilClone(member.MemberEventExtra)} thisEvtSender := sender if member.Membership == event.MembershipJoin { content.Membership = event.MembershipInvite diff --git a/error.go b/error.go index bea4caae..b7c92a5f 100644 --- a/error.go +++ b/error.go @@ -13,6 +13,7 @@ import ( "net/http" "go.mau.fi/util/exhttp" + "go.mau.fi/util/exmaps" "golang.org/x/exp/maps" ) @@ -144,10 +145,7 @@ func (e *RespError) UnmarshalJSON(data []byte) error { } func (e *RespError) MarshalJSON() ([]byte, error) { - data := maps.Clone(e.ExtraData) - if data == nil { - data = make(map[string]any) - } + data := exmaps.NonNilClone(e.ExtraData) data["errcode"] = e.ErrCode data["error"] = e.Err return json.Marshal(data) @@ -178,11 +176,7 @@ func (e RespError) WithStatus(status int) RespError { } func (e RespError) WithExtraData(extraData map[string]any) RespError { - if e.ExtraData == nil { - e.ExtraData = make(map[string]any) - } else { - e.ExtraData = maps.Clone(e.ExtraData) - } + e.ExtraData = exmaps.NonNilClone(e.ExtraData) maps.Copy(e.ExtraData, extraData) return e } diff --git a/go.mod b/go.mod index 5001851f..70bf601e 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.13 - go.mau.fi/util v0.9.2-0.20250928173307-c0b5f4ee5899 + go.mau.fi/util v0.9.2-0.20251001114608-d99877b9cc10 go.mau.fi/zeroconfig v0.2.0 golang.org/x/crypto v0.42.0 golang.org/x/exp v0.0.0-20250911091902-df9299821621 diff --git a/go.sum b/go.sum index 8d3fabfe..639b30a2 100644 --- a/go.sum +++ b/go.sum @@ -51,8 +51,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.2-0.20250928173307-c0b5f4ee5899 h1:GoPWdX45WrJG/NC+/6u4km9X9UvrzqGGG78z4VlXI7o= -go.mau.fi/util v0.9.2-0.20250928173307-c0b5f4ee5899/go.mod h1:M0bM9SyaOWJniaHs9hxEzz91r5ql6gYq6o1q5O1SsjQ= +go.mau.fi/util v0.9.2-0.20251001114608-d99877b9cc10 h1:EvX/di02gOriKN0xGDJuQ5mgiNdAF4LJc8moffI7Svo= +go.mau.fi/util v0.9.2-0.20251001114608-d99877b9cc10/go.mod h1:M0bM9SyaOWJniaHs9hxEzz91r5ql6gYq6o1q5O1SsjQ= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= From 9ee13d136394e07aa20a02f3a665850a59afc3e2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 1 Oct 2025 14:48:28 +0300 Subject: [PATCH 1418/1647] bridgev2/portal: add option to exclude member changes from timeline by default --- bridgev2/portal.go | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 3884303a..817b3144 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3592,6 +3592,10 @@ type ChatMemberList struct { // Should the bridge call IsThisUser for every member in the list? // This should be used when SenderLogin can't be filled accurately. CheckAllLogins bool + // Should any changes have the `com.beeper.exclude_from_timeline` flag set by default? + // This is recommended for syncs with non-real-time changes. + // Real-time changes (e.g. a user joining) should not set this flag set. + ExcludeChangesFromTimeline bool // The total number of members in the chat, regardless of how many of those members are included in MemberMap. TotalMemberCount int @@ -4048,6 +4052,12 @@ func (portal *Portal) syncParticipants( } delete(currentMembers, portal.Bridge.Bot.GetMXID()) powerChanged := members.PowerLevels.Apply(portal.Bridge.Bot.GetMXID(), currentPower) + addExcludeFromTimeline := func(raw map[string]any) { + _, hasKey := raw["com.beeper.exclude_from_timeline"] + if !hasKey && members.ExcludeChangesFromTimeline { + raw["com.beeper.exclude_from_timeline"] = true + } + } syncUser := func(extraUserID id.UserID, member ChatMember, intent MatrixAPI) bool { if member.Membership == "" { member.Membership = event.MembershipJoin @@ -4078,6 +4088,7 @@ func (portal *Portal) syncParticipants( AvatarURL: currentMember.AvatarURL, } wrappedContent := &event.Content{Parsed: content, Raw: exmaps.NonNilClone(member.MemberEventExtra)} + addExcludeFromTimeline(wrappedContent.Raw) thisEvtSender := sender if member.Membership == event.MembershipJoin { content.Membership = event.MembershipInvite @@ -4122,7 +4133,8 @@ func (portal *Portal) syncParticipants( if intent != nil && content.Membership == event.MembershipInvite && member.Membership == event.MembershipJoin { content.Membership = event.MembershipJoin - wrappedJoinContent := &event.Content{Parsed: content, Raw: member.MemberEventExtra} + wrappedJoinContent := &event.Content{Parsed: content, Raw: exmaps.NonNilClone(member.MemberEventExtra)} + addExcludeFromTimeline(wrappedContent.Raw) _, err = intent.SendState(ctx, portal.MXID, event.StateMember, intent.GetMXID().String(), wrappedJoinContent, ts) if err != nil { addLogContext(log.Err(err)). From dd778ae0cdaf0c147dc106403de64d25780a0b60 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 1 Oct 2025 14:55:35 +0300 Subject: [PATCH 1419/1647] bridgev2/portal: add option to exclude metadata changes from timeline --- bridgev2/portal.go | 101 +++++++++++++++++++++++++------------ bridgev2/portalinternal.go | 24 ++++++--- 2 files changed, 86 insertions(+), 39 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 817b3144..51d6a294 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1537,7 +1537,7 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( return EventHandlingResultIgnored } if !sender.Client.GetCapabilities(ctx, portal).DisappearingTimer.Supports(typedContent) { - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent()) + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), false) return EventHandlingResultFailed.WithMSSError(ErrDisappearingTimerUnsupported) } } @@ -1561,7 +1561,7 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( if err != nil { log.Err(err).Msg("Failed to handle Matrix room metadata") if evt.Type == event.StateBeeperDisappearingTimer { - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent()) + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), false) } return EventHandlingResultFailed.WithMSSError(err) } @@ -3712,6 +3712,8 @@ type ChatInfo struct { CanBackfill bool + ExcludeChangesFromTimeline bool + ExtraUpdates ExtraUpdater[*Portal] } @@ -3744,25 +3746,35 @@ type UserLocalPortalInfo struct { Tag *event.RoomTag } -func (portal *Portal) updateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool { +func (portal *Portal) updateName( + ctx context.Context, name string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool, +) bool { if portal.Name == name && (portal.NameSet || portal.MXID == "") { return false } portal.Name = name - portal.NameSet = portal.sendRoomMeta(ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name}) + portal.NameSet = portal.sendRoomMeta( + ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name}, excludeFromTimeline, + ) return true } -func (portal *Portal) updateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool { +func (portal *Portal) updateTopic( + ctx context.Context, topic string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool, +) bool { if portal.Topic == topic && (portal.TopicSet || portal.MXID == "") { return false } portal.Topic = topic - portal.TopicSet = portal.sendRoomMeta(ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic}) + portal.TopicSet = portal.sendRoomMeta( + ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic}, excludeFromTimeline, + ) return true } -func (portal *Portal) updateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool { +func (portal *Portal) updateAvatar( + ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time, excludeFromTimeline bool, +) bool { if portal.AvatarID == avatar.ID && (avatar.Remove || portal.AvatarMXC != "") && (portal.AvatarSet || portal.MXID == "") { return false } @@ -3785,7 +3797,9 @@ func (portal *Portal) updateAvatar(ctx context.Context, avatar *Avatar, sender M portal.AvatarMXC = newMXC portal.AvatarHash = newHash } - portal.AvatarSet = portal.sendRoomMeta(ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}) + portal.AvatarSet = portal.sendRoomMeta( + ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, excludeFromTimeline, + ) return true } @@ -3851,8 +3865,8 @@ func (portal *Portal) UpdateBridgeInfo(ctx context.Context) { return } stateKey, bridgeInfo := portal.getBridgeInfo() - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo) - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo) + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo, false) + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo, false) } func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, implicit bool) bool { @@ -3874,7 +3888,7 @@ func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, Str("old_id", portal.CapState.ID). Str("new_id", capID). Msg("Sending new room capability event") - success := portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperRoomFeatures, portal.getBridgeInfoStateKey(), caps) + success := portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperRoomFeatures, portal.getBridgeInfoStateKey(), caps, false) if !success { return false } @@ -3885,7 +3899,7 @@ func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, } if caps.DisappearingTimer != nil && !portal.CapState.Flags.Has(database.CapStateFlagDisappearingTimerSet) { zerolog.Ctx(ctx).Debug().Msg("Disappearing timer capability was added, sending disappearing timer state event") - success = portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent()) + success = portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true) if !success { return false } @@ -3916,15 +3930,24 @@ func (portal *Portal) sendStateWithIntentOrBot(ctx context.Context, sender Matri return } -func (portal *Portal) sendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any) bool { +func (portal *Portal) sendRoomMeta( + ctx context.Context, + sender MatrixAPI, + ts time.Time, + eventType event.Type, + stateKey string, + content any, + excludeFromTimeline bool, +) bool { if portal.MXID == "" { return false } - var extra map[string]any + extra := make(map[string]any) + if excludeFromTimeline { + extra["com.beeper.exclude_from_timeline"] = true + } if !portal.NameIsCustom && (eventType == event.StateRoomName || eventType == event.StateRoomAvatar) { - extra = map[string]any{ - "fi.mau.implicit_name": true, - } + extra["fi.mau.implicit_name"] = true } _, err := portal.sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, &event.Content{ Parsed: content, @@ -4281,9 +4304,15 @@ type UpdateDisappearingSettingOpts struct { Implicit bool Save bool SendNotice bool + + ExcludeFromTimeline bool } -func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting database.DisappearingSetting, opts UpdateDisappearingSettingOpts) bool { +func (portal *Portal) UpdateDisappearingSetting( + ctx context.Context, + setting database.DisappearingSetting, + opts UpdateDisappearingSettingOpts, +) bool { setting = setting.Normalize() if portal.Disappear.Timer == setting.Timer && portal.Disappear.Type == setting.Type { return false @@ -4306,7 +4335,15 @@ func (portal *Portal) UpdateDisappearingSetting(ctx context.Context, setting dat if opts.Timestamp.IsZero() { opts.Timestamp = time.Now() } - portal.sendRoomMeta(ctx, opts.Sender, opts.Timestamp, event.StateBeeperDisappearingTimer, "", setting.ToEventContent()) + portal.sendRoomMeta( + ctx, + opts.Sender, + opts.Timestamp, + event.StateBeeperDisappearingTimer, + "", + setting.ToEventContent(), + opts.ExcludeFromTimeline, + ) if !opts.SendNotice { return true @@ -4390,13 +4427,13 @@ func (portal *Portal) UpdateInfoFromGhost(ctx context.Context, ghost *Ghost) (ch return } } - changed = portal.updateName(ctx, ghost.Name, nil, time.Time{}) || changed + changed = portal.updateName(ctx, ghost.Name, nil, time.Time{}, false) || changed changed = portal.updateAvatar(ctx, &Avatar{ ID: ghost.AvatarID, MXC: ghost.AvatarMXC, Hash: ghost.AvatarHash, Remove: ghost.AvatarID == "", - }, nil, time.Time{}) || changed + }, nil, time.Time{}, false) || changed return } @@ -4405,26 +4442,28 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us if info.Name == DefaultChatName { if portal.NameIsCustom { portal.NameIsCustom = false - changed = portal.updateName(ctx, "", sender, ts) || changed + changed = portal.updateName(ctx, "", sender, ts, info.ExcludeChangesFromTimeline) || changed } } else if info.Name != nil { portal.NameIsCustom = true - changed = portal.updateName(ctx, *info.Name, sender, ts) || changed + changed = portal.updateName(ctx, *info.Name, sender, ts, info.ExcludeChangesFromTimeline) || changed } if info.Topic != nil { - changed = portal.updateTopic(ctx, *info.Topic, sender, ts) || changed + changed = portal.updateTopic(ctx, *info.Topic, sender, ts, info.ExcludeChangesFromTimeline) || changed } if info.Avatar != nil { portal.NameIsCustom = true - changed = portal.updateAvatar(ctx, info.Avatar, sender, ts) || changed + changed = portal.updateAvatar(ctx, info.Avatar, sender, ts, info.ExcludeChangesFromTimeline) || changed } if info.Disappear != nil { changed = portal.UpdateDisappearingSetting(ctx, *info.Disappear, UpdateDisappearingSettingOpts{ - Sender: sender, - Timestamp: ts, - Implicit: false, - Save: false, - SendNotice: true, + Sender: sender, + Timestamp: ts, + Implicit: false, + Save: false, + + SendNotice: !info.ExcludeChangesFromTimeline, + ExcludeFromTimeline: info.ExcludeChangesFromTimeline, }) || changed } if info.ParentID != nil { @@ -4432,7 +4471,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us } if info.JoinRule != nil { // TODO change detection instead of spamming this every time? - portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule) + portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule, info.ExcludeChangesFromTimeline) } if info.Type != nil && portal.RoomType != *info.Type { if portal.MXID != "" && (*info.Type == database.RoomTypeSpace || portal.RoomType == database.RoomTypeSpace) { diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index ddbadc76..d9373eb6 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -121,6 +121,10 @@ func (portal *PortalInternals) GetTargetUser(ctx context.Context, userID id.User return (*Portal)(portal).getTargetUser(ctx, userID) } +func (portal *PortalInternals) HandleMatrixDeleteChat(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { + return (*Portal)(portal).handleMatrixDeleteChat(ctx, sender, origSender, evt) +} + func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt) } @@ -249,6 +253,10 @@ func (portal *PortalInternals) HandleRemoteChatResync(ctx context.Context, sourc return (*Portal)(portal).handleRemoteChatResync(ctx, source, evt) } +func (portal *PortalInternals) FindOtherLogins(ctx context.Context, source *UserLogin) (ownUP *database.UserPortal, others []*database.UserPortal, err error) { + return (*Portal)(portal).findOtherLogins(ctx, source) +} + func (portal *PortalInternals) HandleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult { return (*Portal)(portal).handleRemoteChatDelete(ctx, source, evt) } @@ -257,16 +265,16 @@ func (portal *PortalInternals) HandleRemoteBackfill(ctx context.Context, source return (*Portal)(portal).handleRemoteBackfill(ctx, source, backfill) } -func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time) bool { - return (*Portal)(portal).updateName(ctx, name, sender, ts) +func (portal *PortalInternals) UpdateName(ctx context.Context, name string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool { + return (*Portal)(portal).updateName(ctx, name, sender, ts, excludeFromTimeline) } -func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time) bool { - return (*Portal)(portal).updateTopic(ctx, topic, sender, ts) +func (portal *PortalInternals) UpdateTopic(ctx context.Context, topic string, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool { + return (*Portal)(portal).updateTopic(ctx, topic, sender, ts, excludeFromTimeline) } -func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time) bool { - return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts) +func (portal *PortalInternals) UpdateAvatar(ctx context.Context, avatar *Avatar, sender MatrixAPI, ts time.Time, excludeFromTimeline bool) bool { + return (*Portal)(portal).updateAvatar(ctx, avatar, sender, ts, excludeFromTimeline) } func (portal *PortalInternals) GetBridgeInfoStateKey() string { @@ -281,8 +289,8 @@ func (portal *PortalInternals) SendStateWithIntentOrBot(ctx context.Context, sen return (*Portal)(portal).sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, content, ts) } -func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any) bool { - return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content) +func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any, excludeFromTimeline bool) bool { + return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content, excludeFromTimeline) } func (portal *PortalInternals) GetInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) { From 97da8eb44dc993295c659c33eb7351a4ff36260f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 2 Oct 2025 14:45:46 +0300 Subject: [PATCH 1420/1647] event: add helper to get remaining mute duration --- event/accountdata.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/event/accountdata.go b/event/accountdata.go index 30ca35a2..223919a1 100644 --- a/event/accountdata.go +++ b/event/accountdata.go @@ -105,3 +105,15 @@ func (bmec *BeeperMuteEventContent) GetMutedUntilTime() time.Time { } return time.Time{} } + +func (bmec *BeeperMuteEventContent) GetMuteDuration() time.Duration { + ts := bmec.GetMutedUntilTime() + now := time.Now() + if ts.Before(now) { + return 0 + } else if ts == MutedForever { + return -1 + } else { + return ts.Sub(now) + } +} From 5d69963ab546c11fc0862f04f7c025d9d974fe57 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 2 Oct 2025 17:19:45 +0300 Subject: [PATCH 1421/1647] bridgev2/portal: add exclude from timeline flag for not in chat leaves --- bridgev2/portal.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 51d6a294..4d2e60a0 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -4230,6 +4230,9 @@ func (portal *Portal) syncParticipants( Displayname: memberEvt.Displayname, Reason: "User is not in remote chat", }, + Raw: map[string]any{ + "com.beeper.exclude_from_timeline": members.ExcludeChangesFromTimeline, + }, }, time.Now()) if err != nil { zerolog.Ctx(ctx).Err(err). From 9fc5d987743a56e6c4e376e23260cfde02508904 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 2 Oct 2025 21:57:25 +0300 Subject: [PATCH 1422/1647] bridgev2/mxmain: fix --version flag output --- bridgev2/matrix/mxmain/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index 9e409875..ca0ca5f7 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -140,7 +140,7 @@ func (br *BridgeMain) PreInit() { flag.PrintHelp() os.Exit(0) } else if *version { - fmt.Println(br.ver.FormattedVersion) + fmt.Println(br.ver.VersionDescription) os.Exit(0) } else if *versionJSON { output := VersionJSONOutput{ From 8e668586f9c23b80864a100bb1368be9e7d2d0cf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 2 Oct 2025 22:10:22 +0300 Subject: [PATCH 1423/1647] appservice/intent: add room ID to fake join response --- appservice/intent.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appservice/intent.go b/appservice/intent.go index fa9d9e7a..4635f59a 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -306,7 +306,7 @@ func (intent *IntentAPI) SendCustomMembershipEvent(ctx context.Context, roomID i func (intent *IntentAPI) JoinRoomByID(ctx context.Context, roomID id.RoomID, extraContent ...map[string]interface{}) (resp *mautrix.RespJoinRoom, err error) { if intent.IsCustomPuppet || len(extraContent) > 0 { _, err = intent.SendCustomMembershipEvent(ctx, roomID, intent.UserID, event.MembershipJoin, "", extraContent...) - return &mautrix.RespJoinRoom{}, err + return &mautrix.RespJoinRoom{RoomID: roomID}, err } return intent.Client.JoinRoomByID(ctx, roomID) } From ce667a65e5783aa047beb3e8ea739a078fa0d581 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 3 Oct 2025 03:10:29 +0300 Subject: [PATCH 1424/1647] bridgev2/simplevent: add simpler form of message event --- bridgev2/simplevent/message.go | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/bridgev2/simplevent/message.go b/bridgev2/simplevent/message.go index f648ab12..ac9f8d77 100644 --- a/bridgev2/simplevent/message.go +++ b/bridgev2/simplevent/message.go @@ -59,6 +59,31 @@ func (evt *Message[T]) GetTransactionID() networkid.TransactionID { return evt.TransactionID } +// PreConvertedMessage is a simple implementation of [bridgev2.RemoteMessage] with pre-converted data. +type PreConvertedMessage struct { + EventMeta + Data *bridgev2.ConvertedMessage + ID networkid.MessageID + TransactionID networkid.TransactionID +} + +var ( + _ bridgev2.RemoteMessage = (*PreConvertedMessage)(nil) + _ bridgev2.RemoteMessageWithTransactionID = (*PreConvertedMessage)(nil) +) + +func (evt *PreConvertedMessage) ConvertMessage(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI) (*bridgev2.ConvertedMessage, error) { + return evt.Data, nil +} + +func (evt *PreConvertedMessage) GetID() networkid.MessageID { + return evt.ID +} + +func (evt *PreConvertedMessage) GetTransactionID() networkid.TransactionID { + return evt.TransactionID +} + type MessageRemove struct { EventMeta From 4be60a002169a527ef6b1a751ea9bafe1b665b0f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 3 Oct 2025 03:14:51 +0300 Subject: [PATCH 1425/1647] bridgev2/simplevent: allow upserts with PreConvertedMessage --- bridgev2/simplevent/message.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/bridgev2/simplevent/message.go b/bridgev2/simplevent/message.go index ac9f8d77..f8f8d7e1 100644 --- a/bridgev2/simplevent/message.go +++ b/bridgev2/simplevent/message.go @@ -65,10 +65,13 @@ type PreConvertedMessage struct { Data *bridgev2.ConvertedMessage ID networkid.MessageID TransactionID networkid.TransactionID + + HandleExistingFunc func(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error) } var ( _ bridgev2.RemoteMessage = (*PreConvertedMessage)(nil) + _ bridgev2.RemoteMessageUpsert = (*PreConvertedMessage)(nil) _ bridgev2.RemoteMessageWithTransactionID = (*PreConvertedMessage)(nil) ) @@ -84,6 +87,13 @@ func (evt *PreConvertedMessage) GetTransactionID() networkid.TransactionID { return evt.TransactionID } +func (evt *PreConvertedMessage) HandleExisting(ctx context.Context, portal *bridgev2.Portal, intent bridgev2.MatrixAPI, existing []*database.Message) (bridgev2.UpsertResult, error) { + if evt.HandleExistingFunc == nil { + return bridgev2.UpsertResult{}, nil + } + return evt.HandleExistingFunc(ctx, portal, intent, existing) +} + type MessageRemove struct { EventMeta From 8a72af9f6b368e72366d2a689d552abfeafd065c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 3 Oct 2025 22:51:05 +0300 Subject: [PATCH 1426/1647] federation/eventauth: require that join authorizer is in the room --- federation/eventauth/eventauth.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go index 7d73abcd..bd102213 100644 --- a/federation/eventauth/eventauth.go +++ b/federation/eventauth/eventauth.go @@ -79,6 +79,7 @@ var ( ErrCantJoinOtherUser = AuthFailError{Index: "5.3.2", Message: "can't send join event with different state key"} ErrCantJoinBanned = AuthFailError{Index: "5.3.3", Message: "user is banned from the room"} ErrAuthoriserCantInvite = AuthFailError{Index: "5.3.5.2", Message: "authoriser doesn't have sufficient power level to invite"} + ErrAuthoriserNotInRoom = AuthFailError{Index: "5.3.5.2", Message: "authoriser isn't a member of the room"} ErrCantJoinWithoutInvite = AuthFailError{Index: "5.3.7", Message: "can't join invite-only room without invite"} ErrInvalidJoinRule = AuthFailError{Index: "5.3.7", Message: "invalid join rule in room"} ErrThirdPartyInviteBanned = AuthFailError{Index: "5.4.1.1", Message: "third party invite target user is banned"} @@ -384,6 +385,10 @@ func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEv // 5.3.5.2. If the join_authorised_via_users_server key in content is not a user with sufficient permission to invite other users, reject. return ErrAuthoriserCantInvite } + authorizerMembership := event.Membership(findEventAndReadString(authEvents, event.StateMember.Type, authorizedVia.String(), "membership", string(event.MembershipLeave))) + if authorizerMembership != event.MembershipJoin { + return ErrAuthoriserNotInRoom + } // 5.3.5.3. Otherwise, allow. return nil case event.JoinRulePublic: From 13f251fe607b6f9a4bbb87bc00995c518f1ba5af Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 5 Oct 2025 12:30:54 +0300 Subject: [PATCH 1427/1647] crypto/helper: don't block on decryption --- client.go | 1 + crypto/cryptohelper/cryptohelper.go | 44 ++++++++++++++++++----------- sync.go | 1 + 3 files changed, 30 insertions(+), 16 deletions(-) diff --git a/client.go b/client.go index 62843218..ec527dd0 100644 --- a/client.go +++ b/client.go @@ -323,6 +323,7 @@ const ( LogBodyContextKey contextKey = iota LogRequestIDContextKey MaxAttemptsContextKey + SyncTokenContextKey ) func (cli *Client) RequestStart(req *http.Request) { diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index 74710678..1939ea79 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -297,24 +297,14 @@ func (helper *CryptoHelper) HandleEncrypted(ctx context.Context, evt *event.Even ctx = log.WithContext(ctx) decrypted, err := helper.Decrypt(ctx, evt) - if errors.Is(err, NoSessionFound) { - log.Debug(). - Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). - Msg("Couldn't find session, waiting for keys to arrive...") - if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { - log.Debug().Msg("Got keys after waiting, trying to decrypt event again") - decrypted, err = helper.Decrypt(ctx, evt) - } else { - go helper.waitLongerForSession(ctx, log, evt) - return - } - } - if err != nil { + if errors.Is(err, NoSessionFound) && ctx.Value(mautrix.SyncTokenContextKey) != "" { + go helper.waitForSession(ctx, evt) + } else if err != nil { log.Warn().Err(err).Msg("Failed to decrypt event") helper.DecryptErrorCallback(evt, err) - return + } else { + helper.postDecrypt(ctx, decrypted) } - helper.postDecrypt(ctx, decrypted) } func (helper *CryptoHelper) postDecrypt(ctx context.Context, decrypted *event.Event) { @@ -355,7 +345,29 @@ func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID } } -func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, evt *event.Event) { +func (helper *CryptoHelper) waitForSession(ctx context.Context, evt *event.Event) { + log := zerolog.Ctx(ctx) + content := evt.Content.AsEncrypted() + + log.Debug(). + Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())). + Msg("Couldn't find session, waiting for keys to arrive...") + if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) { + log.Debug().Msg("Got keys after waiting, trying to decrypt event again") + decrypted, err := helper.Decrypt(ctx, evt) + if err != nil { + log.Warn().Err(err).Msg("Failed to decrypt event") + helper.DecryptErrorCallback(evt, err) + } else { + helper.postDecrypt(ctx, decrypted) + } + } else { + go helper.waitLongerForSession(ctx, evt) + } +} + +func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, evt *event.Event) { + log := zerolog.Ctx(ctx) content := evt.Content.AsEncrypted() log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...") diff --git a/sync.go b/sync.go index c52bd2f9..598df8e0 100644 --- a/sync.go +++ b/sync.go @@ -90,6 +90,7 @@ func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, sinc err = fmt.Errorf("ProcessResponse panicked! since=%s panic=%s\n%s", since, r, debug.Stack()) } }() + ctx = context.WithValue(ctx, SyncTokenContextKey, since) for _, listener := range s.syncListeners { if !listener(ctx, res, since) { From 07bc756971535211771ab59fd543f7ff03e40652 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 6 Oct 2025 16:51:41 +0300 Subject: [PATCH 1428/1647] changelog: update --- CHANGELOG.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 831e3094..43332ac8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -## unreleased +## v0.25.2 (unreleased) * *(crypto)* Added helper methods for generating and verifying with recovery keys. @@ -7,12 +7,17 @@ * *(bridgev2/matrix)* Added initial support for using appservice/MSC3202 mode for encryption with standard servers like Synapse. * *(bridgev2)* Added optional support for implicit read receipts. +* *(bridgev2)* Added interface for deleting chats on remote network. * *(bridgev2)* Extended event duration logging to log any event taking too long. +* *(event)* Added event type constant for poll end events. +* *(client)* Added wrapper for searching user directory. +* *(crypto/helper)* Changed default sync handling to not block on waiting for + decryption keys. On initial sync, keys won't be requested at all by default. +* *(crypto)* Fixed olm unwedging not working (regressed in v0.25.1). * *(bridgev2)* Fixed various bugs with migrating to split portals. * *(event)* Fixed poll start events having incorrect null `m.relates_to`. -* *(event)* Added event type constant for poll end events. * *(client)* Fixed `RespUserProfile` losing standard fields when re-marshaling. -* *(client)* Added wrapper for searching user directory. +* *(federation)* Fixed various bugs in event auth. ## v0.25.1 (2025-09-16) From 344b04c4075ef99be514e90400fceb270b41f45c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 6 Oct 2025 17:03:30 +0300 Subject: [PATCH 1429/1647] event: add Clone method for file features --- event/capabilities.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/event/capabilities.go b/event/capabilities.go index 31a6b7aa..bd0c3d27 100644 --- a/event/capabilities.go +++ b/event/capabilities.go @@ -18,6 +18,7 @@ import ( "go.mau.fi/util/exerrors" "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" "golang.org/x/exp/constraints" "golang.org/x/exp/maps" ) @@ -70,6 +71,14 @@ type FormattingFeatureMap map[FormattingFeature]CapabilitySupportLevel type FileFeatureMap map[CapabilityMsgType]*FileFeatures +func (ffm FileFeatureMap) Clone() FileFeatureMap { + dup := maps.Clone(ffm) + for key, value := range dup { + dup[key] = value.Clone() + } + return dup +} + type DisappearingTimerCapability struct { Types []DisappearingType `json:"types"` Timers []jsontime.Milliseconds `json:"timers,omitempty"` @@ -296,3 +305,10 @@ func (ff *FileFeatures) Hash() []byte { hashBool(hasher, "view_once", ff.ViewOnce) return hasher.Sum(nil) } + +func (ff *FileFeatures) Clone() *FileFeatures { + clone := *ff + clone.MimeTypes = maps.Clone(clone.MimeTypes) + clone.MaxDuration = ptr.Clone(clone.MaxDuration) + return &clone +} From 548970fd0f3dd17b3933c14c9cf268dab13518e9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 6 Oct 2025 17:05:46 +0300 Subject: [PATCH 1430/1647] event: add Clone for other capability types too --- event/capabilities.go | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/event/capabilities.go b/event/capabilities.go index bd0c3d27..42afe5b6 100644 --- a/event/capabilities.go +++ b/event/capabilities.go @@ -67,6 +67,20 @@ func (rf *RoomFeatures) GetID() string { return base64.RawURLEncoding.EncodeToString(rf.Hash()) } +func (rf *RoomFeatures) Clone() *RoomFeatures { + if rf == nil { + return nil + } + clone := *rf + clone.File = clone.File.Clone() + clone.Formatting = maps.Clone(clone.Formatting) + clone.EditMaxAge = ptr.Clone(clone.EditMaxAge) + clone.DeleteMaxAge = ptr.Clone(clone.DeleteMaxAge) + clone.DisappearingTimer = clone.DisappearingTimer.Clone() + clone.AllowedReactions = slices.Clone(clone.AllowedReactions) + return &clone +} + type FormattingFeatureMap map[FormattingFeature]CapabilitySupportLevel type FileFeatureMap map[CapabilityMsgType]*FileFeatures @@ -86,6 +100,16 @@ type DisappearingTimerCapability struct { OmitEmptyTimer bool `json:"omit_empty_timer,omitempty"` } +func (dtc *DisappearingTimerCapability) Clone() *DisappearingTimerCapability { + if dtc == nil { + return nil + } + clone := *dtc + clone.Types = slices.Clone(clone.Types) + clone.Timers = slices.Clone(clone.Timers) + return &clone +} + func (dtc *DisappearingTimerCapability) Supports(content *BeeperDisappearingTimer) bool { if dtc == nil || content == nil || content.Type == DisappearingTypeNone { return true @@ -307,6 +331,9 @@ func (ff *FileFeatures) Hash() []byte { } func (ff *FileFeatures) Clone() *FileFeatures { + if ff == nil { + return nil + } clone := *ff clone.MimeTypes = maps.Clone(clone.MimeTypes) clone.MaxDuration = ptr.Clone(clone.MaxDuration) From 51edfc27c097e99d1468c30a3a50d1e641ec059e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 6 Oct 2025 23:00:04 +0300 Subject: [PATCH 1431/1647] bridgev2: add omitempty for group create params struct --- bridgev2/networkinterface.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 8dffbb34..31647f63 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -815,19 +815,19 @@ type GroupFieldCapability struct { } type GroupCreateParams struct { - Type string `json:"type"` + Type string `json:"type,omitempty"` - Username string `json:"username"` - Participants []networkid.UserID `json:"participants"` - Parent *networkid.PortalKey `json:"parent"` + Username string `json:"username,omitempty"` + Participants []networkid.UserID `json:"participants,omitempty"` + Parent *networkid.PortalKey `json:"parent,omitempty"` - Name *event.RoomNameEventContent `json:"name"` - Avatar *event.RoomAvatarEventContent `json:"avatar"` - Topic *event.TopicEventContent `json:"topic"` - Disappear *event.BeeperDisappearingTimer `json:"disappear"` + Name *event.RoomNameEventContent `json:"name,omitempty"` + Avatar *event.RoomAvatarEventContent `json:"avatar,omitempty"` + Topic *event.TopicEventContent `json:"topic,omitempty"` + Disappear *event.BeeperDisappearingTimer `json:"disappear,omitempty"` // An existing room ID to bridge to. If unset, a new room will be created. - RoomID id.RoomID `json:"room_id"` + RoomID id.RoomID `json:"room_id,omitempty"` } type GroupCreatingNetworkAPI interface { From 3a300246ac3c7895cd6b8b66ee8bf2c1cdf3ebf6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 6 Oct 2025 23:10:04 +0300 Subject: [PATCH 1432/1647] id/userid: split validation into 2 functions --- id/userid.go | 20 ++++++++++++++------ id/userid_test.go | 18 +++++++++--------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/id/userid.go b/id/userid.go index 6d9f4080..859d2358 100644 --- a/id/userid.go +++ b/id/userid.go @@ -104,16 +104,24 @@ func ValidateUserLocalpart(localpart string) error { return nil } -// ParseAndValidate parses the user ID into the localpart and server name like Parse, -// and also validates that the localpart is allowed according to the user identifiers spec. -func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error) { - localpart, homeserver, err = userID.Parse() +// ParseAndValidateStrict is a stricter version of ParseAndValidateRelaxed that checks the localpart to only allow non-historical localparts. +// This should be used with care: there are real users still using historical localparts. +func (userID UserID) ParseAndValidateStrict() (localpart, homeserver string, err error) { + localpart, homeserver, err = userID.ParseAndValidateRelaxed() if err == nil { err = ValidateUserLocalpart(localpart) } - if err == nil && len(userID) > UserIDMaxLength { + return +} + +// ParseAndValidateRelaxed parses the user ID into the localpart and server name like Parse, +// and also validates that the user ID is not too long and that the server name is valid. +func (userID UserID) ParseAndValidateRelaxed() (localpart, homeserver string, err error) { + if len(userID) > UserIDMaxLength { err = ErrUserIDTooLong + return } + localpart, homeserver, err = userID.Parse() if err == nil && !ValidateServerName(homeserver) { err = fmt.Errorf("%q %q", homeserver, ErrNoncompliantServerPart) } @@ -121,7 +129,7 @@ func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error } func (userID UserID) ParseAndDecode() (localpart, homeserver string, err error) { - localpart, homeserver, err = userID.ParseAndValidate() + localpart, homeserver, err = userID.ParseAndValidateStrict() if err == nil { localpart, err = DecodeUserLocalpart(localpart) } diff --git a/id/userid_test.go b/id/userid_test.go index 359bc687..57a88066 100644 --- a/id/userid_test.go +++ b/id/userid_test.go @@ -38,30 +38,30 @@ func TestUserID_Parse_Invalid(t *testing.T) { assert.True(t, errors.Is(err, id.ErrInvalidUserID)) } -func TestUserID_ParseAndValidate_Invalid(t *testing.T) { +func TestUserID_ParseAndValidateStrict_Invalid(t *testing.T) { const inputUserID = "@s p a c e:maunium.net" - _, _, err := id.UserID(inputUserID).ParseAndValidate() + _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrNoncompliantLocalpart)) } -func TestUserID_ParseAndValidate_Empty(t *testing.T) { +func TestUserID_ParseAndValidateStrict_Empty(t *testing.T) { const inputUserID = "@:ponies.im" - _, _, err := id.UserID(inputUserID).ParseAndValidate() + _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrEmptyLocalpart)) } -func TestUserID_ParseAndValidate_Long(t *testing.T) { +func TestUserID_ParseAndValidateStrict_Long(t *testing.T) { const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com" - _, _, err := id.UserID(inputUserID).ParseAndValidate() + _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrUserIDTooLong)) } -func TestUserID_ParseAndValidate_NotLong(t *testing.T) { +func TestUserID_ParseAndValidateStrict_NotLong(t *testing.T) { const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com" - _, _, err := id.UserID(inputUserID).ParseAndValidate() + _, _, err := id.UserID(inputUserID).ParseAndValidateStrict() assert.NoError(t, err) } @@ -70,7 +70,7 @@ func TestUserIDEncoding(t *testing.T) { const encodedLocalpart = "_this=20local+part=20contains=20_il_le_ga_l=20ch=c3=a4racters=20=f0=9f=9a=a8" const inputServerName = "example.com" userID := id.NewEncodedUserID(inputLocalpart, inputServerName) - parsedLocalpart, parsedServerName, err := userID.ParseAndValidate() + parsedLocalpart, parsedServerName, err := userID.ParseAndValidateStrict() assert.NoError(t, err) assert.Equal(t, encodedLocalpart, parsedLocalpart) assert.Equal(t, inputServerName, parsedServerName) From d18142c7946f6c3332b7791f96814e860729480d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 8 Oct 2025 18:33:57 +0300 Subject: [PATCH 1433/1647] bridgev2/errors: add reason for unsupported errors --- bridgev2/errors.go | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index 694224f1..29bba71f 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -45,29 +45,29 @@ var ( ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false) ErrIgnoringDeleteChatRelayedUser error = WrapErrorInStatus(errors.New("ignoring delete chat event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage() - ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage() - ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage() - ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage() - ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage() - ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage() - ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage() - ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage() - ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) - ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage() + ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) ErrInvalidStateKey error = WrapErrorInStatus(errors.New("room metadata state key is unset or non-empty")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) - ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) - ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) - ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) - ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) - ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) + ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) From 9654a0b01e754932913b59f3ca420f59a9c83076 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 8 Oct 2025 18:47:55 +0300 Subject: [PATCH 1434/1647] bridgev2/portal: enforce media duration and size limits --- bridgev2/errors.go | 2 ++ bridgev2/portal.go | 11 +++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index 29bba71f..cf27ac6f 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -61,6 +61,8 @@ var ( ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrMediaDurationTooLong error = WrapErrorInStatus(errors.New("media duration too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrMediaTooLarge error = WrapErrorInStatus(errors.New("media too large")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 4d2e60a0..327e9815 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -944,8 +944,15 @@ func (portal *Portal) checkMessageContentCaps(caps *event.RoomFeatures, content feat.Caption.Reject() { return ErrCaptionsNotAllowed } - if content.Info != nil && content.Info.MimeType != "" { - if feat.GetMimeSupport(content.Info.MimeType).Reject() { + if content.Info != nil { + dur := time.Duration(content.Info.Duration) * time.Millisecond + if feat.MaxDuration != nil && dur > feat.MaxDuration.Duration { + return fmt.Errorf("%w: %s is longer than the maximum of %s", ErrMediaDurationTooLong, exfmt.Duration(dur), exfmt.Duration(feat.MaxDuration.Duration)) + } + if feat.MaxSize != 0 && int64(content.Info.Size) > feat.MaxSize { + return fmt.Errorf("%w: %.1f MiB is larger than the maximum of %.1f MiB", ErrMediaTooLarge, float64(content.Info.Size)/1024/1024, float64(feat.MaxSize)/1024/1024) + } + if content.Info.MimeType != "" && feat.GetMimeSupport(content.Info.MimeType).Reject() { return fmt.Errorf("%w (%s in %s)", ErrUnsupportedMediaType, content.Info.MimeType, capMsgType) } } From 91ea77b4d4123c6efe781cb31b9259240414a87c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 8 Oct 2025 19:16:00 +0300 Subject: [PATCH 1435/1647] bridgev2/portal: don't send implicit read receipts for account data --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 327e9815..067d92c2 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -665,7 +665,7 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * // Copy logger because many of the handlers will use UpdateContext ctx = log.With().Str("login_id", string(login.ID)).Logger().WithContext(ctx) - if origSender == nil && portal.Bridge.Network.GetCapabilities().ImplicitReadReceipts { + if origSender == nil && portal.Bridge.Network.GetCapabilities().ImplicitReadReceipts && !evt.Type.IsAccountData() { rrLog := log.With().Str("subaction", "implicit read receipt").Logger() rrCtx := rrLog.WithContext(ctx) rrLog.Debug().Msg("Sending implicit read receipt for event") From 5593d8afcd32c7f9faa0850a22ae7ff932cb329c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 13 Oct 2025 15:30:12 +0300 Subject: [PATCH 1436/1647] changelog: update --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 43332ac8..20ffbd06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ## v0.25.2 (unreleased) +* **Breaking change *(id)*** Split `UserID.ParseAndValidate` into + `ParseAndValidateRelaxed` and `ParseAndValidateStrict`. Strict is the old + behavior, but most users likely want the relaxed version, as there are real + users whose user IDs aren't valid under the strict rules. * *(crypto)* Added helper methods for generating and verifying with recovery keys. * *(bridgev2/matrix)* Added config option to automatically generate a recovery @@ -8,6 +12,7 @@ for encryption with standard servers like Synapse. * *(bridgev2)* Added optional support for implicit read receipts. * *(bridgev2)* Added interface for deleting chats on remote network. +* *(bridgev2)* Added local enforcement of media duration and size limits. * *(bridgev2)* Extended event duration logging to log any event taking too long. * *(event)* Added event type constant for poll end events. * *(client)* Added wrapper for searching user directory. From 097813c9b29bf81690fe3f6341eee7269c589b0a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 14 Oct 2025 00:19:57 +0300 Subject: [PATCH 1437/1647] bridgev2/provisionutil: validate user IDs in CreateGroup if network supports it --- bridgev2/provisionutil/creategroup.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/bridgev2/provisionutil/creategroup.go b/bridgev2/provisionutil/creategroup.go index 891f9615..7a21f682 100644 --- a/bridgev2/provisionutil/creategroup.go +++ b/bridgev2/provisionutil/creategroup.go @@ -37,6 +37,14 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev if len(params.Participants) < typeSpec.Participants.MinLength { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at least %d members", typeSpec.Participants.MinLength)) } + userIDValidatingNetwork, ok := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork) + if ok { + for _, participant := range params.Participants { + if !userIDValidatingNetwork.ValidateUserID(participant) { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("User ID %q is not valid on this network", participant)) + } + } + } if (params.Name == nil || params.Name.Name == "") && typeSpec.Name.Required { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name is required")) } else if nameLen := len(ptr.Val(params.Name).Name); nameLen > 0 && nameLen < typeSpec.Name.MinLength { From ab4a7852d6e022c38eca586ea57dcfbb3b36a837 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 14 Oct 2025 13:01:21 +0300 Subject: [PATCH 1438/1647] bridgev2/provisionutil: don't allow self in create group participants --- bridgev2/provisionutil/creategroup.go | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/bridgev2/provisionutil/creategroup.go b/bridgev2/provisionutil/creategroup.go index 7a21f682..f389ab42 100644 --- a/bridgev2/provisionutil/creategroup.go +++ b/bridgev2/provisionutil/creategroup.go @@ -37,12 +37,13 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev if len(params.Participants) < typeSpec.Participants.MinLength { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at least %d members", typeSpec.Participants.MinLength)) } - userIDValidatingNetwork, ok := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork) - if ok { - for _, participant := range params.Participants { - if !userIDValidatingNetwork.ValidateUserID(participant) { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("User ID %q is not valid on this network", participant)) - } + userIDValidatingNetwork, uidValOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork) + for _, participant := range params.Participants { + if uidValOK && !userIDValidatingNetwork.ValidateUserID(participant) { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("User ID %q is not valid on this network", participant)) + } + if api.IsThisUser(ctx, participant) { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("You can't include yourself in the participants list", participant)) } } if (params.Name == nil || params.Name.Name == "") && typeSpec.Name.Required { From 080ad4c0a0e8b81d243f82ce206bf656cdf0a6fe Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Tue, 14 Oct 2025 13:32:02 +0300 Subject: [PATCH 1439/1647] crypto: Allow decrypting message content without event id or ts Replay attack prevention shouldn't store empty event id or ts to database if we're decrypting without them. This may happen if we are looking into a future delayed event for example as it doesn't yet have those. We still prevent doing that if we already know them meaning we have gotten the actual event through sync as that's also when a delayed event would move from scheduled to finalised and then it also contains those fields. --- crypto/sql_store.go | 14 ++++++++++++++ crypto/store.go | 3 +++ crypto/store_test.go | 12 +++++++++++- 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/crypto/sql_store.go b/crypto/sql_store.go index 13940d79..ca75b3f6 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -664,6 +664,20 @@ func (store *SQLCryptoStore) IsOutboundGroupSessionShared(ctx context.Context, u // ValidateMessageIndex returns whether the given event information match the ones stored in the database // for the given sender key, session ID and index. If the index hasn't been stored, this will store it. func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) { + if eventID == "" && timestamp == 0 { + var notOK bool + const validateEmptyQuery = ` + SELECT EXISTS(SELECT 1 FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND "index"=$3) + ` + err := store.DB.QueryRow(ctx, validateEmptyQuery, senderKey, sessionID, index).Scan(¬OK) + if notOK { + zerolog.Ctx(ctx).Debug(). + Uint("message_index", index). + Msg("Rejecting event without event ID and timestamp due to already knowing them") + } + return !notOK, err + } + const validateQuery = ` INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp) VALUES ($1, $2, $3, $4, $5) diff --git a/crypto/store.go b/crypto/store.go index 8b7c0a96..7620cf35 100644 --- a/crypto/store.go +++ b/crypto/store.go @@ -525,6 +525,9 @@ func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.Send } val, ok := gs.MessageIndices[key] if !ok { + if eventID == "" && timestamp == 0 { + return true, nil + } gs.MessageIndices[key] = messageIndexValue{ EventID: eventID, Timestamp: timestamp, diff --git a/crypto/store_test.go b/crypto/store_test.go index 8aeae7af..7a47243e 100644 --- a/crypto/store_test.go +++ b/crypto/store_test.go @@ -75,8 +75,13 @@ func TestValidateMessageIndex(t *testing.T) { t.Run(storeName, func(t *testing.T) { acc := NewOlmAccount() + // Validating without event ID and timestamp before we have them should work + ok, err := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "", 0, 0) + require.NoError(t, err, "Error validating message index") + assert.True(t, ok, "First message validation should be valid") + // First message should validate successfully - ok, err := store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) require.NoError(t, err, "Error validating message index") assert.True(t, ok, "First message validation should be valid") @@ -94,6 +99,11 @@ func TestValidateMessageIndex(t *testing.T) { ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "event1", 0, 1000) require.NoError(t, err, "Error validating message index") assert.True(t, ok, "First message validation should be valid") + + // Validating without event ID and timestamp must fail if we already know them + ok, err = store.ValidateMessageIndex(context.TODO(), acc.IdentityKey(), "sess1", "", 0, 0) + require.NoError(t, err, "Error validating message index") + assert.False(t, ok, "First message validation should be invalid") }) } } From 22ea75db96fd968bbb7cd7938aa5b5c19628d945 Mon Sep 17 00:00:00 2001 From: Toni Spets Date: Fri, 10 Oct 2025 14:58:50 +0300 Subject: [PATCH 1440/1647] client,event: MSC4140: Delayed events Includes transparent migration from deprecated MSC fields still used in Synapse to later revision. --- client.go | 26 ++++++++++++++++++ event/delayed.go | 70 ++++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 2 +- go.sum | 4 +-- id/opaque.go | 3 +++ requests.go | 10 +++++-- responses.go | 13 ++++++++- 7 files changed, 122 insertions(+), 6 deletions(-) create mode 100644 event/delayed.go diff --git a/client.go b/client.go index ec527dd0..95cbacb5 100644 --- a/client.go +++ b/client.go @@ -1313,6 +1313,32 @@ func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, return } +func (cli *Client) DelayedEvents(ctx context.Context, req *ReqDelayedEvents) (resp *RespDelayedEvents, err error) { + query := map[string]string{} + if req.DelayID != "" { + query["delay_id"] = string(req.DelayID) + } + if req.Status != "" { + query["status"] = string(req.Status) + } + if req.NextBatch != "" { + query["next_batch"] = req.NextBatch + } + + urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "org.matrix.msc4140", "delayed_events"}, query) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, req, &resp) + + // Migration: merge old keys with new ones + if resp != nil { + resp.Scheduled = append(resp.Scheduled, resp.DelayedEvents...) + resp.DelayedEvents = nil + resp.Finalised = append(resp.Finalised, resp.FinalisedEvents...) + resp.FinalisedEvents = nil + } + + return +} + func (cli *Client) UpdateDelayedEvent(ctx context.Context, req *ReqUpdateDelayedEvent) (resp *RespUpdateDelayedEvent, err error) { urlPath := cli.BuildClientURL("unstable", "org.matrix.msc4140", "delayed_events", req.DelayID) _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp) diff --git a/event/delayed.go b/event/delayed.go new file mode 100644 index 00000000..fefb62af --- /dev/null +++ b/event/delayed.go @@ -0,0 +1,70 @@ +package event + +import ( + "encoding/json" + + "go.mau.fi/util/jsontime" + + "maunium.net/go/mautrix/id" +) + +type ScheduledDelayedEvent struct { + DelayID id.DelayID `json:"delay_id"` + RoomID id.RoomID `json:"room_id"` + Type Type `json:"type"` + StateKey *string `json:"state_key,omitempty"` + Delay int64 `json:"delay"` + RunningSince jsontime.UnixMilli `json:"running_since"` + Content Content `json:"content"` +} + +func (e ScheduledDelayedEvent) AsEvent(eventID id.EventID, ts jsontime.UnixMilli) (*Event, error) { + evt := &Event{ + ID: eventID, + RoomID: e.RoomID, + Type: e.Type, + StateKey: e.StateKey, + Content: e.Content, + Timestamp: ts.UnixMilli(), + } + return evt, evt.Content.ParseRaw(evt.Type) +} + +type FinalisedDelayedEvent struct { + DelayedEvent *ScheduledDelayedEvent `json:"scheduled_event"` + Outcome DelayOutcome `json:"outcome"` + Reason DelayReason `json:"reason"` + Error json.RawMessage `json:"error,omitempty"` + EventID id.EventID `json:"event_id,omitempty"` + Timestamp jsontime.UnixMilli `json:"origin_server_ts"` +} + +type DelayStatus string + +var ( + DelayStatusScheduled DelayStatus = "scheduled" + DelayStatusFinalised DelayStatus = "finalised" +) + +type DelayAction string + +var ( + DelayActionSend DelayAction = "send" + DelayActionCancel DelayAction = "cancel" + DelayActionRestart DelayAction = "restart" +) + +type DelayOutcome string + +var ( + DelayOutcomeSend DelayOutcome = "send" + DelayOutcomeCancel DelayOutcome = "cancel" +) + +type DelayReason string + +var ( + DelayReasonAction DelayReason = "action" + DelayReasonError DelayReason = "error" + DelayReasonDelay DelayReason = "delay" +) diff --git a/go.mod b/go.mod index 70bf601e..d77428d8 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.13 - go.mau.fi/util v0.9.2-0.20251001114608-d99877b9cc10 + go.mau.fi/util v0.9.2-0.20251014102252-c9ee13b043c8 go.mau.fi/zeroconfig v0.2.0 golang.org/x/crypto v0.42.0 golang.org/x/exp v0.0.0-20250911091902-df9299821621 diff --git a/go.sum b/go.sum index 639b30a2..dee6616c 100644 --- a/go.sum +++ b/go.sum @@ -51,8 +51,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.2-0.20251001114608-d99877b9cc10 h1:EvX/di02gOriKN0xGDJuQ5mgiNdAF4LJc8moffI7Svo= -go.mau.fi/util v0.9.2-0.20251001114608-d99877b9cc10/go.mod h1:M0bM9SyaOWJniaHs9hxEzz91r5ql6gYq6o1q5O1SsjQ= +go.mau.fi/util v0.9.2-0.20251014102252-c9ee13b043c8 h1:36oe41yPjz7QLjJWb72qHi82IOINqgp06eHIVRdalGs= +go.mau.fi/util v0.9.2-0.20251014102252-c9ee13b043c8/go.mod h1:M0bM9SyaOWJniaHs9hxEzz91r5ql6gYq6o1q5O1SsjQ= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= diff --git a/id/opaque.go b/id/opaque.go index 1d9f0dcf..c1ad4988 100644 --- a/id/opaque.go +++ b/id/opaque.go @@ -32,6 +32,9 @@ type EventID string // https://github.com/matrix-org/matrix-doc/pull/2716 type BatchID string +// A DelayID is a string identifying a delayed event. +type DelayID string + func (roomID RoomID) String() string { return string(roomID) } diff --git a/requests.go b/requests.go index 9dfe09ab..f0287b3c 100644 --- a/requests.go +++ b/requests.go @@ -376,9 +376,15 @@ type ReqSendEvent struct { MeowEventID id.EventID } +type ReqDelayedEvents struct { + DelayID id.DelayID `json:"-"` + Status event.DelayStatus `json:"-"` + NextBatch string `json:"-"` +} + type ReqUpdateDelayedEvent struct { - DelayID string `json:"-"` - Action string `json:"action"` // TODO use enum + DelayID id.DelayID `json:"-"` + Action event.DelayAction `json:"action"` } // ReqDeviceInfo is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3devicesdeviceid diff --git a/responses.go b/responses.go index a79be28b..3484c134 100644 --- a/responses.go +++ b/responses.go @@ -104,11 +104,22 @@ type RespContext struct { type RespSendEvent struct { EventID id.EventID `json:"event_id"` - UnstableDelayID string `json:"delay_id,omitempty"` + UnstableDelayID id.DelayID `json:"delay_id,omitempty"` } type RespUpdateDelayedEvent struct{} +type RespDelayedEvents struct { + Scheduled []*event.ScheduledDelayedEvent `json:"scheduled,omitempty"` + Finalised []*event.FinalisedDelayedEvent `json:"finalised,omitempty"` + NextBatch string `json:"next_batch,omitempty"` + + // Deprecated: Synapse implementation still returns this + DelayedEvents []*event.ScheduledDelayedEvent `json:"delayed_events,omitempty"` + // Deprecated: Synapse implementation still returns this + FinalisedEvents []*event.FinalisedDelayedEvent `json:"finalised_events,omitempty"` +} + type RespRedactUserEvents struct { IsMoreEvents bool `json:"is_more_events"` RedactedEvents struct { From 50a49e01f3ac07665309d0dc13b1349e004d2658 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 16 Oct 2025 11:26:46 +0200 Subject: [PATCH 1441/1647] Bump version to v0.25.2 --- CHANGELOG.md | 4 +++- go.mod | 14 +++++++------- go.sum | 24 ++++++++++++------------ version.go | 2 +- 4 files changed, 23 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 20ffbd06..f59e6853 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -## v0.25.2 (unreleased) +## v0.25.2 (2025-10-16) * **Breaking change *(id)*** Split `UserID.ParseAndValidate` into `ParseAndValidateRelaxed` and `ParseAndValidateStrict`. Strict is the old @@ -14,8 +14,10 @@ * *(bridgev2)* Added interface for deleting chats on remote network. * *(bridgev2)* Added local enforcement of media duration and size limits. * *(bridgev2)* Extended event duration logging to log any event taking too long. +* *(bridgev2)* Improved validation in group creation provisioning API. * *(event)* Added event type constant for poll end events. * *(client)* Added wrapper for searching user directory. +* *(client)* Improved support for managing [MSC4140] delayed events. * *(crypto/helper)* Changed default sync handling to not block on waiting for decryption keys. On initial sync, keys won't be requested at all by default. * *(crypto)* Fixed olm unwedging not working (regressed in v0.25.1). diff --git a/go.mod b/go.mod index d77428d8..fb63cf59 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module maunium.net/go/mautrix go 1.24.0 -toolchain go1.25.1 +toolchain go1.25.3 require ( filippo.io/edwards25519 v1.1.0 @@ -17,11 +17,11 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.13 - go.mau.fi/util v0.9.2-0.20251014102252-c9ee13b043c8 + go.mau.fi/util v0.9.2 go.mau.fi/zeroconfig v0.2.0 - golang.org/x/crypto v0.42.0 - golang.org/x/exp v0.0.0-20250911091902-df9299821621 - golang.org/x/net v0.44.0 + golang.org/x/crypto v0.43.0 + golang.org/x/exp v0.0.0-20251009144603-d2f985daa21b + golang.org/x/net v0.46.0 golang.org/x/sync v0.17.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 @@ -36,7 +36,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - golang.org/x/sys v0.36.0 // indirect - golang.org/x/text v0.29.0 // indirect + golang.org/x/sys v0.37.0 // indirect + golang.org/x/text v0.30.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index dee6616c..faa4ef4c 100644 --- a/go.sum +++ b/go.sum @@ -51,26 +51,26 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.2-0.20251014102252-c9ee13b043c8 h1:36oe41yPjz7QLjJWb72qHi82IOINqgp06eHIVRdalGs= -go.mau.fi/util v0.9.2-0.20251014102252-c9ee13b043c8/go.mod h1:M0bM9SyaOWJniaHs9hxEzz91r5ql6gYq6o1q5O1SsjQ= +go.mau.fi/util v0.9.2 h1:+S4Z03iCsGqU2WY8X2gySFsFjaLlUHFRDVCYvVwynKM= +go.mau.fi/util v0.9.2/go.mod h1:055elBBCJSdhRsmub7ci9hXZPgGr1U6dYg44cSgRgoU= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= -golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI= -golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8= -golang.org/x/exp v0.0.0-20250911091902-df9299821621 h1:2id6c1/gto0kaHYyrixvknJ8tUK/Qs5IsmBtrc+FtgU= -golang.org/x/exp v0.0.0-20250911091902-df9299821621/go.mod h1:TwQYMMnGpvZyc+JpB/UAuTNIsVJifOlSkrZkhcvpVUk= -golang.org/x/net v0.44.0 h1:evd8IRDyfNBMBTTY5XRF1vaZlD+EmWx6x8PkhR04H/I= -golang.org/x/net v0.44.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY= +golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= +golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= +golang.org/x/exp v0.0.0-20251009144603-d2f985daa21b h1:18qgiDvlvH7kk8Ioa8Ov+K6xCi0GMvmGfGW0sgd/SYA= +golang.org/x/exp v0.0.0-20251009144603-d2f985daa21b/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= +golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= +golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= -golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= -golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= +golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= +golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/version.go b/version.go index b76a548c..7b4eea41 100644 --- a/version.go +++ b/version.go @@ -8,7 +8,7 @@ import ( "strings" ) -const Version = "v0.25.1" +const Version = "v0.25.2" var GoModVersion = "" var Commit = "" From 572a704b04da6e82354ca8448c4019760a5aca50 Mon Sep 17 00:00:00 2001 From: Brad Murray Date: Sat, 18 Oct 2025 05:42:01 -0400 Subject: [PATCH 1442/1647] errors: Add M_WRONG_ROOM_KEYS_VERSION (#419) --- error.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/error.go b/error.go index b7c92a5f..59e574d7 100644 --- a/error.go +++ b/error.go @@ -67,6 +67,8 @@ var ( MIncompatibleRoomVersion = RespError{ErrCode: "M_INCOMPATIBLE_ROOM_VERSION"} // The client specified a parameter that has the wrong value. MInvalidParam = RespError{ErrCode: "M_INVALID_PARAM", StatusCode: http.StatusBadRequest} + // The client specified a room key backup version that is not the current room key backup version for the user. + MWrongRoomKeysVersion = RespError{ErrCode: "M_WRONG_ROOM_KEYS_VERSION", StatusCode: http.StatusForbidden} MURLNotSet = RespError{ErrCode: "M_URL_NOT_SET"} MBadStatus = RespError{ErrCode: "M_BAD_STATUS"} From a214af5bab636f1203e60c41155e971a926a8efa Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 16 Oct 2025 11:35:28 +0200 Subject: [PATCH 1443/1647] federation: fix server key query test --- federation/serverauth_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federation/serverauth_test.go b/federation/serverauth_test.go index 9fa15459..633a0f66 100644 --- a/federation/serverauth_test.go +++ b/federation/serverauth_test.go @@ -21,7 +21,7 @@ func TestServerKeyResponse_VerifySelfSignature(t *testing.T) { ctx := context.Background() for _, name := range []string{"matrix.org", "maunium.net", "continuwuity.org"} { t.Run(name, func(t *testing.T) { - resp, err := cli.ServerKeys(ctx, "matrix.org") + resp, err := cli.ServerKeys(ctx, name) require.NoError(t, err) assert.NoError(t, resp.VerifySelfSignature()) }) From df957301be579fb0f6eb8b4b2644ce84db2df334 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 18 Oct 2025 13:29:16 +0200 Subject: [PATCH 1444/1647] federation: don't allow redirects --- federation/client.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/federation/client.go b/federation/client.go index 8f454516..5c316e56 100644 --- a/federation/client.go +++ b/federation/client.go @@ -37,6 +37,10 @@ func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Clien HTTP: &http.Client{ Transport: NewServerResolvingTransport(cache), Timeout: 120 * time.Second, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + // Federation requests do not allow redirects. + return http.ErrUseLastResponse + }, }, UserAgent: mautrix.DefaultUserAgent, ServerName: serverName, @@ -310,7 +314,7 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b _ = resp.Body.Close() }() var body []byte - if resp.StatusCode >= 400 { + if resp.StatusCode >= 300 { body, err = mautrix.ParseErrorResponse(req, resp) return body, resp, err } else if params.ResponseJSON != nil || !params.DontReadBody { From 827bb4c6212ca273bf14e84dbe045132e755e44e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 18 Oct 2025 13:33:30 +0200 Subject: [PATCH 1445/1647] federation: add response size limit --- federation/client.go | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/federation/client.go b/federation/client.go index 5c316e56..c84b437a 100644 --- a/federation/client.go +++ b/federation/client.go @@ -30,6 +30,8 @@ type Client struct { ServerName string UserAgent string Key *SigningKey + + ResponseSizeLimit int64 } func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Client { @@ -45,6 +47,8 @@ func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Clien UserAgent: mautrix.DefaultUserAgent, ServerName: serverName, Key: key, + + ResponseSizeLimit: 128 * 1024 * 1024, } } @@ -318,7 +322,16 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b body, err = mautrix.ParseErrorResponse(req, resp) return body, resp, err } else if params.ResponseJSON != nil || !params.DontReadBody { - body, err = io.ReadAll(resp.Body) + if resp.ContentLength > c.ResponseSizeLimit { + return body, resp, mautrix.HTTPError{ + Request: req, + Response: resp, + + Message: "response body too long", + WrappedError: fmt.Errorf("%.2f MiB", float64(resp.ContentLength)/1024/1024), + } + } + body, err = io.ReadAll(io.LimitReader(resp.Body, c.ResponseSizeLimit+1)) if err != nil { return body, resp, mautrix.HTTPError{ Request: req, @@ -328,6 +341,15 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b WrappedError: err, } } + if len(body) > int(c.ResponseSizeLimit) { + return body, resp, mautrix.HTTPError{ + Request: req, + Response: resp, + + Message: "failed to read response body", + WrappedError: fmt.Errorf("exceeded read limit"), + } + } if params.ResponseJSON != nil { err = json.Unmarshal(body, params.ResponseJSON) if err != nil { From c50460cd6e3e70245026ca6f020c00e68fd3f0a0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 18 Oct 2025 13:37:19 +0200 Subject: [PATCH 1446/1647] client: add response size limits --- client.go | 145 +++++++++++++++++++++++++++++---------- error.go | 3 + federation/client.go | 18 ++--- federation/resolution.go | 6 +- 4 files changed, 124 insertions(+), 48 deletions(-) diff --git a/client.go b/client.go index 95cbacb5..85a4603e 100644 --- a/client.go +++ b/client.go @@ -111,6 +111,8 @@ type Client struct { // Set to true to disable automatically sleeping on 429 errors. IgnoreRateLimit bool + ResponseSizeLimit int64 + txnID int32 // Should the ?user_id= query parameter be set in requests? @@ -143,6 +145,8 @@ func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown return DiscoverClientAPIWithClient(ctx, &http.Client{Timeout: 30 * time.Second}, serverName) } +const WellKnownMaxSize = 64 * 1024 + func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serverName string) (*ClientWellKnown, error) { wellKnownURL := url.URL{ Scheme: "https", @@ -168,11 +172,15 @@ func DiscoverClientAPIWithClient(ctx context.Context, client *http.Client, serve if resp.StatusCode == http.StatusNotFound { return nil, nil + } else if resp.ContentLength > WellKnownMaxSize { + return nil, errors.New(".well-known response too large") } - data, err := io.ReadAll(resp.Body) + data, err := io.ReadAll(io.LimitReader(resp.Body, WellKnownMaxSize)) if err != nil { return nil, err + } else if len(data) >= WellKnownMaxSize { + return nil, errors.New(".well-known response too large") } var wellKnown ClientWellKnown @@ -395,24 +403,25 @@ func (cli *Client) MakeRequest(ctx context.Context, method string, httpURL strin return cli.MakeFullRequest(ctx, FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody}) } -type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) +type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON any, sizeLimit int64) ([]byte, error) type FullRequest struct { - Method string - URL string - Headers http.Header - RequestJSON interface{} - RequestBytes []byte - RequestBody io.Reader - RequestLength int64 - ResponseJSON interface{} - MaxAttempts int - BackoffDuration time.Duration - SensitiveContent bool - Handler ClientResponseHandler - DontReadResponse bool - Logger *zerolog.Logger - Client *http.Client + Method string + URL string + Headers http.Header + RequestJSON interface{} + RequestBytes []byte + RequestBody io.Reader + RequestLength int64 + ResponseJSON interface{} + MaxAttempts int + BackoffDuration time.Duration + SensitiveContent bool + Handler ClientResponseHandler + DontReadResponse bool + ResponseSizeLimit int64 + Logger *zerolog.Logger + Client *http.Client } var requestID int32 @@ -537,10 +546,25 @@ func (cli *Client) MakeFullRequestWithResp(ctx context.Context, params FullReque if len(cli.AccessToken) > 0 { req.Header.Set("Authorization", "Bearer "+cli.AccessToken) } + if params.ResponseSizeLimit == 0 { + params.ResponseSizeLimit = cli.ResponseSizeLimit + } + if params.ResponseSizeLimit == 0 { + params.ResponseSizeLimit = DefaultResponseSizeLimit + } if params.Client == nil { params.Client = cli.Client } - return cli.executeCompiledRequest(req, params.MaxAttempts-1, params.BackoffDuration, params.ResponseJSON, params.Handler, params.DontReadResponse, params.Client) + return cli.executeCompiledRequest( + req, + params.MaxAttempts-1, + params.BackoffDuration, + params.ResponseJSON, + params.Handler, + params.DontReadResponse, + params.ResponseSizeLimit, + params.Client, + ) } func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { @@ -551,7 +575,17 @@ func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger { return log } -func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) { +func (cli *Client) doRetry( + req *http.Request, + cause error, + retries int, + backoff time.Duration, + responseJSON any, + handler ClientResponseHandler, + dontReadResponse bool, + sizeLimit int64, + client *http.Client, +) ([]byte, *http.Response, error) { log := zerolog.Ctx(req.Context()) if req.Body != nil { var err error @@ -585,11 +619,23 @@ func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff if cli.UpdateRequestOnRetry != nil { req = cli.UpdateRequestOnRetry(req, cause) } - return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, client) + return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler, dontReadResponse, sizeLimit, client) } -func readResponseBody(req *http.Request, res *http.Response) ([]byte, error) { - contents, err := io.ReadAll(res.Body) +func readResponseBody(req *http.Request, res *http.Response, limit int64) ([]byte, error) { + if res.ContentLength > limit { + return nil, HTTPError{ + Request: req, + Response: res, + + Message: "not reading response", + WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024), + } + } + contents, err := io.ReadAll(io.LimitReader(res.Body, limit+1)) + if err == nil && len(contents) > int(limit) { + err = ErrBodyReadReachedLimit + } if err != nil { return nil, HTTPError{ Request: req, @@ -610,17 +656,20 @@ func closeTemp(log *zerolog.Logger, file *os.File) { } } -func streamResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { +func streamResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { log := zerolog.Ctx(req.Context()) file, err := os.CreateTemp("", "mautrix-response-") if err != nil { log.Warn().Err(err).Msg("Failed to create temporary file for streaming response") - _, err = handleNormalResponse(req, res, responseJSON) + _, err = handleNormalResponse(req, res, responseJSON, limit) return nil, err } defer closeTemp(log, file) - if _, err = io.Copy(file, res.Body); err != nil { + var n int64 + if n, err = io.Copy(file, io.LimitReader(res.Body, limit+1)); err != nil { return nil, fmt.Errorf("failed to copy response to file: %w", err) + } else if n > limit { + return nil, ErrBodyReadReachedLimit } else if _, err = file.Seek(0, 0); err != nil { return nil, fmt.Errorf("failed to seek to beginning of response file: %w", err) } else if err = json.NewDecoder(file).Decode(responseJSON); err != nil { @@ -630,12 +679,12 @@ func streamResponse(req *http.Request, res *http.Response, responseJSON interfac } } -func noopHandleResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { +func noopHandleResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { return nil, nil } -func handleNormalResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { - if contents, err := readResponseBody(req, res); err != nil { +func handleNormalResponse(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { + if contents, err := readResponseBody(req, res, limit); err != nil { return nil, err } else if responseJSON == nil { return contents, nil @@ -653,8 +702,12 @@ func handleNormalResponse(req *http.Request, res *http.Response, responseJSON in } } +const ErrorResponseSizeLimit = 512 * 1024 + +var DefaultResponseSizeLimit int64 = 512 * 1024 * 1024 + func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { - contents, err := readResponseBody(req, res) + contents, err := readResponseBody(req, res, ErrorResponseSizeLimit) if err != nil { return contents, err } @@ -673,7 +726,16 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { } } -func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON any, handler ClientResponseHandler, dontReadResponse bool, client *http.Client) ([]byte, *http.Response, error) { +func (cli *Client) executeCompiledRequest( + req *http.Request, + retries int, + backoff time.Duration, + responseJSON any, + handler ClientResponseHandler, + dontReadResponse bool, + sizeLimit int64, + client *http.Client, +) ([]byte, *http.Response, error) { cli.RequestStart(req) startTime := time.Now() res, err := client.Do(req) @@ -683,7 +745,9 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof } if err != nil { if retries > 0 && !errors.Is(err, context.Canceled) { - return cli.doRetry(req, err, retries, backoff, responseJSON, handler, dontReadResponse, client) + return cli.doRetry( + req, err, retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client, + ) } err = HTTPError{ Request: req, @@ -698,7 +762,9 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) { backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff) - return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, client) + return cli.doRetry( + req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client, + ) } var body []byte @@ -706,7 +772,7 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof body, err = ParseErrorResponse(req, res) cli.LogRequestDone(req, res, nil, nil, len(body), duration) } else { - body, err = handler(req, res, responseJSON) + body, err = handler(req, res, responseJSON, sizeLimit) cli.LogRequestDone(req, res, nil, err, len(body), duration) } return body, res, err @@ -1628,11 +1694,20 @@ func (cli *Client) FullStateEvent(ctx context.Context, roomID id.RoomID, eventTy } // parseRoomStateArray parses a JSON array as a stream and stores the events inside it in a room state map. -func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { +func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any, limit int64) ([]byte, error) { + if res.ContentLength > limit { + return nil, HTTPError{ + Request: req, + Response: res, + + Message: "not reading response", + WrappedError: fmt.Errorf("%w (%.2f MiB)", ErrResponseTooLong, float64(res.ContentLength)/1024/1024), + } + } response := make(RoomStateMap) responsePtr := responseJSON.(*map[event.Type]map[string]*event.Event) *responsePtr = response - dec := json.NewDecoder(res.Body) + dec := json.NewDecoder(io.LimitReader(res.Body, limit)) arrayStart, err := dec.Token() if err != nil { diff --git a/error.go b/error.go index 59e574d7..826af179 100644 --- a/error.go +++ b/error.go @@ -82,6 +82,9 @@ var ( var ( ErrClientIsNil = errors.New("client is nil") ErrClientHasNoHomeserver = errors.New("client has no homeserver set") + + ErrResponseTooLong = errors.New("response content length too long") + ErrBodyReadReachedLimit = errors.New("reached response size limit while reading body") ) // HTTPError An HTTP Error response, which may wrap an underlying native Go Error. diff --git a/federation/client.go b/federation/client.go index c84b437a..f3163f3a 100644 --- a/federation/client.go +++ b/federation/client.go @@ -48,7 +48,7 @@ func NewClient(serverName string, key *SigningKey, cache ResolutionCache) *Clien ServerName: serverName, Key: key, - ResponseSizeLimit: 128 * 1024 * 1024, + ResponseSizeLimit: mautrix.DefaultResponseSizeLimit, } } @@ -327,11 +327,14 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b Request: req, Response: resp, - Message: "response body too long", - WrappedError: fmt.Errorf("%.2f MiB", float64(resp.ContentLength)/1024/1024), + Message: "not reading response", + WrappedError: fmt.Errorf("%w (%.2f MiB)", mautrix.ErrResponseTooLong, float64(resp.ContentLength)/1024/1024), } } body, err = io.ReadAll(io.LimitReader(resp.Body, c.ResponseSizeLimit+1)) + if err == nil && len(body) > int(c.ResponseSizeLimit) { + err = mautrix.ErrBodyReadReachedLimit + } if err != nil { return body, resp, mautrix.HTTPError{ Request: req, @@ -341,15 +344,6 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b WrappedError: err, } } - if len(body) > int(c.ResponseSizeLimit) { - return body, resp, mautrix.HTTPError{ - Request: req, - Response: resp, - - Message: "failed to read response body", - WrappedError: fmt.Errorf("exceeded read limit"), - } - } if params.ResponseJSON != nil { err = json.Unmarshal(body, params.ResponseJSON) if err != nil { diff --git a/federation/resolution.go b/federation/resolution.go index 69d4d3bf..81e19cfb 100644 --- a/federation/resolution.go +++ b/federation/resolution.go @@ -20,6 +20,8 @@ import ( "time" "github.com/rs/zerolog" + + "maunium.net/go/mautrix" ) type ResolvedServerName struct { @@ -171,9 +173,11 @@ func RequestWellKnown(ctx context.Context, cli *http.Client, hostname string) (* defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, time.Time{}, fmt.Errorf("unexpected status code %d", resp.StatusCode) + } else if resp.ContentLength > mautrix.WellKnownMaxSize { + return nil, time.Time{}, fmt.Errorf("response too large: %d bytes", resp.ContentLength) } var respData RespWellKnown - err = json.NewDecoder(io.LimitReader(resp.Body, 50*1024)).Decode(&respData) + err = json.NewDecoder(io.LimitReader(resp.Body, mautrix.WellKnownMaxSize)).Decode(&respData) if err != nil { return nil, time.Time{}, fmt.Errorf("failed to decode response: %w", err) } else if respData.Server == "" { From e61c7b3f1e847d94de222998e22d74a46c6118ce Mon Sep 17 00:00:00 2001 From: timedout Date: Sat, 18 Oct 2025 20:30:43 +0100 Subject: [PATCH 1447/1647] client: Add AdminWhoIs func (#411) --- client.go | 9 +++++++++ responses.go | 20 ++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/client.go b/client.go index 85a4603e..d8bd5b80 100644 --- a/client.go +++ b/client.go @@ -2689,6 +2689,15 @@ func (cli *Client) ReportRoom(ctx context.Context, roomID id.RoomID, reason stri return err } +// AdminWhoIs fetches session information belonging to a specific user. Typically requires being a server admin. +// +// https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid +func (cli *Client) AdminWhoIs(ctx context.Context, userID id.UserID) (resp RespWhoIs, err error) { + urlPath := cli.BuildClientURL("v3", "admin", "whois", userID) + _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) + return +} + // UnstableGetSuspendedStatus uses MSC4323 to check if a user is suspended. func (cli *Client) UnstableGetSuspendedStatus(ctx context.Context, userID id.UserID) (res *RespSuspended, err error) { urlPath := cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", "suspend", userID) diff --git a/responses.go b/responses.go index 3484c134..943ea511 100644 --- a/responses.go +++ b/responses.go @@ -755,3 +755,23 @@ type RespSuspended struct { type RespLocked struct { Locked bool `json:"locked"` } + +type ConnectionInfo struct { + IP string `json:"ip,omitempty"` + LastSeen jsontime.UnixMilli `json:"last_seen,omitempty"` + UserAgent string `json:"user_agent,omitempty"` +} + +type SessionInfo struct { + Connections []ConnectionInfo `json:"connections,omitempty"` +} + +type DeviceInfo struct { + Sessions []SessionInfo `json:"sessions,omitempty"` +} + +// RespWhoIs is the response body for https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid +type RespWhoIs struct { + UserID id.UserID `json:"user_id,omitempty"` + Devices map[id.Device]DeviceInfo `json:"devices,omitempty"` +} From 2fd9e799d29ed3748c9454add53d13bcbd50b23e Mon Sep 17 00:00:00 2001 From: timedout Date: Sat, 18 Oct 2025 21:27:08 +0100 Subject: [PATCH 1448/1647] synapseadmin: Add force_purge option (#420) --- synapseadmin/roomapi.go | 1 + 1 file changed, 1 insertion(+) diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go index a09ba174..c360acab 100644 --- a/synapseadmin/roomapi.go +++ b/synapseadmin/roomapi.go @@ -117,6 +117,7 @@ func (cli *Client) RoomMessages(ctx context.Context, roomID id.RoomID, from, to type ReqDeleteRoom struct { Purge bool `json:"purge,omitempty"` + ForcePurge bool `json:"force_purge,omitempty"` Block bool `json:"block,omitempty"` Message string `json:"message,omitempty"` RoomName string `json:"room_name,omitempty"` From a661641bcb630585f96609697a13ea53ebd77eda Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 19 Oct 2025 19:53:10 +0300 Subject: [PATCH 1449/1647] bridgev2/matrix: don't sleep after registering bot on versions error --- bridgev2/matrix/connector.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 3dd9ae1a..64b5d6c7 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -337,16 +337,18 @@ func (br *Connector) logInitialRequestError(err error, defaultMessage string) { } func (br *Connector) ensureConnection(ctx context.Context) { + triedToRegister := false for { versions, err := br.Bot.Versions(ctx) if err != nil { - if errors.Is(err, mautrix.MForbidden) { + if errors.Is(err, mautrix.MForbidden) && !triedToRegister { br.Log.Debug().Msg("M_FORBIDDEN in /versions, trying to register before retrying") err = br.Bot.EnsureRegistered(ctx) if err != nil { br.logInitialRequestError(err, "Failed to register after /versions failed with M_FORBIDDEN") os.Exit(16) } + triedToRegister = true } else if errors.Is(err, mautrix.MUnknownToken) || errors.Is(err, mautrix.MExclusive) { br.logInitialRequestError(err, "/versions request failed with auth error") os.Exit(16) From 7b70ec6d523e6b032886a3aee243431bba4316a5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 20 Oct 2025 11:38:21 +0300 Subject: [PATCH 1450/1647] bridgev2/bridgestate: send transient disconnect notices if they persist --- bridgev2/bridgestate.go | 63 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 58 insertions(+), 5 deletions(-) diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index f31d4e92..612f228c 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -15,6 +15,7 @@ import ( "time" "github.com/rs/zerolog" + "go.mau.fi/util/exfmt" "maunium.net/go/mautrix/bridgev2/status" "maunium.net/go/mautrix/event" @@ -29,6 +30,9 @@ type BridgeStateQueue struct { bridge *Bridge login *UserLogin + firstTransientDisconnect time.Time + cancelScheduledNotice atomic.Pointer[context.CancelFunc] + stopChan chan struct{} stopReconnect atomic.Pointer[context.CancelFunc] } @@ -74,6 +78,9 @@ func (bsq *BridgeStateQueue) StopUnknownErrorReconnect() { if cancelFn := bsq.stopReconnect.Swap(nil); cancelFn != nil { (*cancelFn)() } + if cancelFn := bsq.cancelScheduledNotice.Swap(nil); cancelFn != nil { + (*cancelFn)() + } } func (bsq *BridgeStateQueue) loop() { @@ -91,14 +98,41 @@ func (bsq *BridgeStateQueue) loop() { } } -func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState) { +func (bsq *BridgeStateQueue) scheduleNotice(ctx context.Context, triggeredBy status.BridgeState) { + log := bsq.login.Log.With().Str("action", "transient disconnect notice").Logger() + ctx = log.WithContext(bsq.bridge.BackgroundCtx) + if !bsq.waitForTransientDisconnectReconnect(ctx) { + return + } + prevUnsent := bsq.GetPrevUnsent() + prev := bsq.GetPrev() + if triggeredBy.Timestamp != prev.Timestamp || len(bsq.ch) > 0 || bsq.errorSent || + prevUnsent.StateEvent != status.StateTransientDisconnect || prev.StateEvent != status.StateTransientDisconnect { + log.Trace().Any("triggered_by", triggeredBy).Msg("Not sending delayed transient disconnect notice") + return + } + log.Debug().Any("triggered_by", triggeredBy).Msg("Sending delayed transient disconnect notice") + bsq.sendNotice(ctx, triggeredBy, true) +} + +func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.BridgeState, isDelayed bool) { noticeConfig := bsq.bridge.Config.BridgeStatusNotices isError := state.StateEvent == status.StateBadCredentials || state.StateEvent == status.StateUnknownError || - state.UserAction == status.UserActionOpenNative + state.UserAction == status.UserActionOpenNative || + (isDelayed && state.StateEvent == status.StateTransientDisconnect) sendNotice := noticeConfig == "all" || (noticeConfig == "errors" && (isError || (bsq.errorSent && state.StateEvent == status.StateConnected))) + if state.StateEvent != status.StateTransientDisconnect && state.StateEvent != status.StateUnknownError { + bsq.firstTransientDisconnect = time.Time{} + } if !sendNotice { + if !isDelayed && noticeConfig == "errors" && state.StateEvent == status.StateTransientDisconnect { + if bsq.firstTransientDisconnect.IsZero() { + bsq.firstTransientDisconnect = time.Now() + } + go bsq.scheduleNotice(ctx, state) + } return } managementRoom, err := bsq.login.User.GetManagementRoom(ctx) @@ -114,6 +148,9 @@ func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.Bridge if state.Error != "" { message += fmt.Sprintf(" (`%s`)", state.Error) } + if isDelayed { + message += fmt.Sprintf(" not resolved after waiting %s", exfmt.Duration(TransientDisconnectNoticeDelay)) + } if state.Message != "" { message += fmt.Sprintf(": %s", state.Message) } @@ -171,14 +208,30 @@ func (bsq *BridgeStateQueue) waitForUnknownErrorReconnect(ctx context.Context) b return false } reconnectIn += time.Duration(rand.Int64N(int64(float64(reconnectIn)*0.4)) - int64(float64(reconnectIn)*0.2)) + return bsq.waitForReconnect(ctx, reconnectIn, &bsq.stopReconnect) +} + +const TransientDisconnectNoticeDelay = 3 * time.Minute + +func (bsq *BridgeStateQueue) waitForTransientDisconnectReconnect(ctx context.Context) bool { + timeUntilSchedule := time.Until(bsq.firstTransientDisconnect.Add(TransientDisconnectNoticeDelay)) + zerolog.Ctx(ctx).Trace(). + Stringer("duration", timeUntilSchedule). + Msg("Waiting before sending notice about transient disconnect") + return bsq.waitForReconnect(ctx, timeUntilSchedule, &bsq.cancelScheduledNotice) +} + +func (bsq *BridgeStateQueue) waitForReconnect( + ctx context.Context, reconnectIn time.Duration, ptr *atomic.Pointer[context.CancelFunc], +) bool { cancelCtx, cancel := context.WithCancel(ctx) defer cancel() - if oldCancel := bsq.stopReconnect.Swap(&cancel); oldCancel != nil { + if oldCancel := ptr.Swap(&cancel); oldCancel != nil { (*oldCancel)() } select { case <-time.After(reconnectIn): - return bsq.stopReconnect.CompareAndSwap(&cancel, nil) + return ptr.CompareAndSwap(&cancel, nil) case <-cancelCtx.Done(): return false case <-bsq.stopChan: @@ -198,7 +251,7 @@ func (bsq *BridgeStateQueue) immediateSendBridgeState(state status.BridgeState) } ctx := bsq.login.Log.WithContext(context.Background()) - bsq.sendNotice(ctx, state) + bsq.sendNotice(ctx, state, false) retryIn := 2 for { From 56b182f85d04c2102ec07ec307c0f2e5d8e478d2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 20 Oct 2025 11:48:45 +0300 Subject: [PATCH 1451/1647] bridgev2/bridgestate: only send one delayed transient disconnect notice --- bridgev2/bridgestate.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index 612f228c..63d5876b 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -127,7 +127,7 @@ func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.Bridge bsq.firstTransientDisconnect = time.Time{} } if !sendNotice { - if !isDelayed && noticeConfig == "errors" && state.StateEvent == status.StateTransientDisconnect { + if !bsq.errorSent && !isDelayed && noticeConfig == "errors" && state.StateEvent == status.StateTransientDisconnect { if bsq.firstTransientDisconnect.IsZero() { bsq.firstTransientDisconnect = time.Now() } From 36edccf61ab81edb1ce938934e9a61ba5f2232e1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 21 Oct 2025 16:59:18 +0300 Subject: [PATCH 1452/1647] bridgev2/provisionutil: allow mxids as participants in CreateGroup --- bridgev2/networkinterface.go | 3 ++- bridgev2/provisionutil/creategroup.go | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 31647f63..bf2d60ee 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -817,7 +817,8 @@ type GroupFieldCapability struct { type GroupCreateParams struct { Type string `json:"type,omitempty"` - Username string `json:"username,omitempty"` + Username string `json:"username,omitempty"` + // Clients may also provide MXIDs here, but provisionutil will normalize them, so bridges only need to handle network IDs Participants []networkid.UserID `json:"participants,omitempty"` Parent *networkid.PortalKey `json:"parent,omitempty"` diff --git a/bridgev2/provisionutil/creategroup.go b/bridgev2/provisionutil/creategroup.go index f389ab42..acae9360 100644 --- a/bridgev2/provisionutil/creategroup.go +++ b/bridgev2/provisionutil/creategroup.go @@ -38,7 +38,12 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at least %d members", typeSpec.Participants.MinLength)) } userIDValidatingNetwork, uidValOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork) - for _, participant := range params.Participants { + for i, participant := range params.Participants { + parsedParticipant, ok := login.Bridge.Matrix.ParseGhostMXID(id.UserID(participant)) + if ok { + participant = parsedParticipant + params.Participants[i] = participant + } if uidValOK && !userIDValidatingNetwork.ValidateUserID(participant) { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("User ID %q is not valid on this network", participant)) } From 8ee8fb1a200f7ae1320306534bdd6b8e56b2625c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 21 Oct 2025 17:20:22 +0300 Subject: [PATCH 1453/1647] bridgev2/provisioning: allow group creation to signal failed participants --- bridgev2/networkinterface.go | 7 +++++++ bridgev2/provisionutil/creategroup.go | 4 ++++ 2 files changed, 11 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index bf2d60ee..d1d4215d 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -732,6 +732,13 @@ type CreateChatResponse struct { // If a start DM request (CreateChatWithGhost or ResolveIdentifier) returns the DM to a different user, // this field should have the user ID of said different user. DMRedirectedTo networkid.UserID + + FailedParticipants map[networkid.UserID]CreateChatFailedParticipant +} + +type CreateChatFailedParticipant struct { + Reason string `json:"reason"` + InviteContent *event.Content `json:"invite_content,omitempty"` } // IdentifierResolvingNetworkAPI is an optional interface that network connectors can implement to support starting new direct chats. diff --git a/bridgev2/provisionutil/creategroup.go b/bridgev2/provisionutil/creategroup.go index acae9360..602ea9f8 100644 --- a/bridgev2/provisionutil/creategroup.go +++ b/bridgev2/provisionutil/creategroup.go @@ -22,6 +22,8 @@ type RespCreateGroup struct { ID networkid.PortalID `json:"id"` MXID id.RoomID `json:"mxid"` Portal *bridgev2.Portal `json:"-"` + + FailedParticipants map[networkid.UserID]bridgev2.CreateChatFailedParticipant `json:"failed_participants,omitempty"` } func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev2.GroupCreateParams) (*RespCreateGroup, error) { @@ -109,5 +111,7 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev ID: resp.Portal.ID, MXID: resp.Portal.MXID, Portal: resp.Portal, + + FailedParticipants: resp.FailedParticipants, }, nil } From 1aacf6e987b187507ba0cf0155bcb11d57529490 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 21 Oct 2025 17:39:57 +0300 Subject: [PATCH 1454/1647] bridgev2/commands: include failed participants in group create response --- bridgev2/commands/startchat.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index 7b755064..7abcddd1 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -20,6 +20,7 @@ import ( "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/bridgev2/provisionutil" "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/id" ) @@ -195,7 +196,17 @@ func fnCreateGroup(ce *Event) { ce.Reply("Failed to create group: %v", err) return } - ce.Reply("Successfully created group `%s`", resp.ID) + var postfix string + if len(resp.FailedParticipants) > 0 { + failedParticipantsStrings := make([]string, len(resp.FailedParticipants)) + i := 0 + for participantID, meta := range resp.FailedParticipants { + failedParticipantsStrings[i] = fmt.Sprintf("* %s: %s", format.SafeMarkdownCode(participantID), meta.Reason) + i++ + } + postfix += "\n\nFailed to add some participants:\n" + strings.Join(failedParticipantsStrings, "\n") + } + ce.Reply("Successfully created group `%s`%s", resp.ID, postfix) } var CommandSearch = &FullHandler{ From ef31dae082e55562ab1222eb8237dd4f2809ae52 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 21 Oct 2025 18:55:49 +0300 Subject: [PATCH 1455/1647] bridgev2/provisioning: include user and DM room MXID in failed participants --- bridgev2/database/portal.go | 5 +++++ bridgev2/networkinterface.go | 10 +++++++--- bridgev2/portal.go | 20 ++++++++++++++++++++ bridgev2/provisionutil/creategroup.go | 24 +++++++++++++++++++++++- 4 files changed, 55 insertions(+), 4 deletions(-) diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index 97af4c4c..a230df19 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -90,6 +90,7 @@ const ( getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL` getAllPortalsWithoutReceiver = getPortalBaseQuery + `WHERE bridge_id=$1 AND receiver=''` getAllDMPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND other_user_id=$2` + getDMPortalQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND receiver=$2 AND other_user_id=$3` getAllPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1` getChildPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND parent_id=$2 AND parent_receiver=$3` @@ -187,6 +188,10 @@ func (pq *PortalQuery) GetAllDMsWith(ctx context.Context, otherUserID networkid. return pq.QueryMany(ctx, getAllDMPortalsQuery, pq.BridgeID, otherUserID) } +func (pq *PortalQuery) GetDM(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) { + return pq.QueryOne(ctx, getDMPortalQuery, pq.BridgeID, receiver, otherUserID) +} + func (pq *PortalQuery) GetChildren(ctx context.Context, parentKey networkid.PortalKey) ([]*Portal, error) { return pq.QueryMany(ctx, getChildPortalsQuery, pq.BridgeID, parentKey.ID, parentKey.Receiver) } diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index d1d4215d..4d2f2edf 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -733,12 +733,16 @@ type CreateChatResponse struct { // this field should have the user ID of said different user. DMRedirectedTo networkid.UserID - FailedParticipants map[networkid.UserID]CreateChatFailedParticipant + FailedParticipants map[networkid.UserID]*CreateChatFailedParticipant } type CreateChatFailedParticipant struct { - Reason string `json:"reason"` - InviteContent *event.Content `json:"invite_content,omitempty"` + Reason string `json:"reason"` + InviteEventType string `json:"invite_event_type,omitempty"` + InviteContent *event.Content `json:"invite_content,omitempty"` + + UserMXID id.UserID `json:"user_mxid,omitempty"` + DMRoomMXID id.RoomID `json:"dm_room_mxid,omitempty"` } // IdentifierResolvingNetworkAPI is an optional interface that network connectors can implement to support starting new direct chats. diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 067d92c2..44e83133 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -185,6 +185,16 @@ func (br *Bridge) loadManyPortals(ctx context.Context, portals []*database.Porta return output, nil } +func (br *Bridge) loadPortalWithCacheCheck(ctx context.Context, dbPortal *database.Portal) (*Portal, error) { + if dbPortal == nil { + return nil, nil + } else if cached, ok := br.portalsByKey[dbPortal.PortalKey]; ok { + return cached, nil + } else { + return br.loadPortal(ctx, dbPortal, nil, nil) + } +} + func (br *Bridge) UnlockedGetPortalByKey(ctx context.Context, key networkid.PortalKey, onlyIfExists bool) (*Portal, error) { if br.Config.SplitPortals && key.Receiver == "" { return nil, fmt.Errorf("receiver must always be set when split portals is enabled") @@ -274,6 +284,16 @@ func (br *Bridge) GetDMPortalsWith(ctx context.Context, otherUserID networkid.Us return br.loadManyPortals(ctx, rows) } +func (br *Bridge) GetDMPortal(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + dbPortal, err := br.DB.Portal.GetDM(ctx, receiver, otherUserID) + if err != nil { + return nil, err + } + return br.loadPortalWithCacheCheck(ctx, dbPortal) +} + func (br *Bridge) GetPortalByKey(ctx context.Context, key networkid.PortalKey) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() diff --git a/bridgev2/provisionutil/creategroup.go b/bridgev2/provisionutil/creategroup.go index 602ea9f8..0df09ff5 100644 --- a/bridgev2/provisionutil/creategroup.go +++ b/bridgev2/provisionutil/creategroup.go @@ -15,6 +15,7 @@ import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -23,7 +24,7 @@ type RespCreateGroup struct { MXID id.RoomID `json:"mxid"` Portal *bridgev2.Portal `json:"-"` - FailedParticipants map[networkid.UserID]bridgev2.CreateChatFailedParticipant `json:"failed_participants,omitempty"` + FailedParticipants map[networkid.UserID]*bridgev2.CreateChatFailedParticipant `json:"failed_participants,omitempty"` } func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev2.GroupCreateParams) (*RespCreateGroup, error) { @@ -107,6 +108,27 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to create portal room")) } } + for key, fp := range resp.FailedParticipants { + if fp.InviteEventType == "" { + fp.InviteEventType = event.EventMessage.Type + } + if fp.UserMXID == "" { + ghost, err := login.Bridge.GetGhostByID(ctx, key) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get ghost for failed participant") + } else if ghost != nil { + fp.UserMXID = ghost.Intent.GetMXID() + } + } + if fp.DMRoomMXID == "" { + portal, err := login.Bridge.GetDMPortal(ctx, login.ID, key) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to get DM portal for failed participant") + } else if portal != nil { + fp.DMRoomMXID = portal.MXID + } + } + } return &RespCreateGroup{ ID: resp.Portal.ID, MXID: resp.Portal.MXID, From 237499fdf5d0cc04b7d8546128ee2188213ab706 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 21 Oct 2025 22:53:18 +0300 Subject: [PATCH 1456/1647] client: fix admin whois response body --- responses.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/responses.go b/responses.go index 943ea511..1a221f8a 100644 --- a/responses.go +++ b/responses.go @@ -772,6 +772,6 @@ type DeviceInfo struct { // RespWhoIs is the response body for https://spec.matrix.org/v1.15/client-server-api/#get_matrixclientv3adminwhoisuserid type RespWhoIs struct { - UserID id.UserID `json:"user_id,omitempty"` - Devices map[id.Device]DeviceInfo `json:"devices,omitempty"` + UserID id.UserID `json:"user_id,omitempty"` + Devices map[id.DeviceID]DeviceInfo `json:"devices,omitempty"` } From e805815e41204f8765daf90ecafc66d17a818925 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 22 Oct 2025 13:03:32 +0300 Subject: [PATCH 1457/1647] bridgev2/commands: add account data debug command --- bridgev2/commands/debug.go | 42 ++++++++++++++++++++++++++++++++++ bridgev2/commands/processor.go | 2 +- 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/bridgev2/commands/debug.go b/bridgev2/commands/debug.go index 4c93dbd4..ad773ac8 100644 --- a/bridgev2/commands/debug.go +++ b/bridgev2/commands/debug.go @@ -7,10 +7,13 @@ package commands import ( + "encoding/json" "strings" + "time" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" ) var CommandRegisterPush = &FullHandler{ @@ -59,3 +62,42 @@ var CommandRegisterPush = &FullHandler{ RequiresLogin: true, NetworkAPI: NetworkAPIImplements[bridgev2.PushableNetworkAPI], } + +var CommandSendAccountData = &FullHandler{ + Func: func(ce *Event) { + if len(ce.Args) < 2 { + ce.Reply("Usage: `$cmdprefix debug-account-data ") + return + } + var content event.Content + evtType := event.Type{Type: ce.Args[0], Class: event.AccountDataEventType} + ce.RawArgs = strings.TrimSpace(strings.Trim(ce.RawArgs, ce.Args[0])) + err := json.Unmarshal([]byte(ce.RawArgs), &content) + if err != nil { + ce.Reply("Failed to parse JSON: %v", err) + return + } + err = content.ParseRaw(evtType) + if err != nil { + ce.Reply("Failed to deserialize content: %v", err) + return + } + res := ce.Bridge.QueueMatrixEvent(ce.Ctx, &event.Event{ + Sender: ce.User.MXID, + Type: evtType, + Timestamp: time.Now().UnixMilli(), + RoomID: ce.RoomID, + Content: content, + }) + ce.Reply("Result: %+v", res) + }, + Name: "debug-account-data", + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Send a room account data event to the bridge", + Args: "<_type_> <_content_>", + }, + RequiresAdmin: true, + RequiresPortal: true, + RequiresLogin: true, +} diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index c28e3a32..290d4196 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -41,7 +41,7 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor { } proc.AddHandlers( CommandHelp, CommandCancel, - CommandRegisterPush, CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom, + CommandRegisterPush, CommandSendAccountData, CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom, CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, CommandSetRelay, CommandUnsetRelay, CommandResolveIdentifier, CommandStartChat, CommandSearch, From 1cd285dee0d19f48ec55143d7c37cca43dbdf075 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 22 Oct 2025 15:51:35 +0300 Subject: [PATCH 1458/1647] bridgev2/matrixinvite: allow redirecting created DM to no ghost --- bridgev2/matrixinvite.go | 11 +++++++---- bridgev2/networkinterface.go | 3 +++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go index 2c14cc7f..05479a3c 100644 --- a/bridgev2/matrixinvite.go +++ b/bridgev2/matrixinvite.go @@ -221,6 +221,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen rejectInvite(ctx, evt, br.Bot, "") return EventHandlingResultSuccess } + overrideIntent := invitedGhost.Intent if resp.DMRedirectedTo != "" && resp.DMRedirectedTo != invitedGhost.ID { log.Debug(). Str("dm_redirected_to_id", string(resp.DMRedirectedTo)). @@ -234,11 +235,13 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen if err != nil { log.Err(err).Msg("Failed to make incorrect ghost leave new DM room") } - otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo) - if err != nil { + if resp.DMRedirectedTo != SpecialValueDMRedirectedToBot { + overrideIntent = br.Bot + } else if otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo); err != nil { log.Err(err).Msg("Failed to get ghost of real portal other user ID") } else { invitedGhost = otherUserGhost + overrideIntent = otherUserGhost.Intent } } err = portal.UpdateMatrixRoomID(ctx, evt.RoomID, UpdateMatrixRoomIDParams{ @@ -251,7 +254,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen }) if err != nil { log.Err(err).Msg("Failed to update Matrix room ID for new DM portal") - sendNotice(ctx, evt, invitedGhost.Intent, "Failed to finish configuring portal. The chat may or may not work") + sendNotice(ctx, evt, overrideIntent, "Failed to finish configuring portal. The chat may or may not work") return EventHandlingResultSuccess } message := "Private chat portal created" @@ -263,7 +266,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen message += fmt.Sprintf("\n\nWarning: %s", err.Error()) } } - sendNotice(ctx, evt, invitedGhost.Intent, message) + sendNotice(ctx, evt, overrideIntent, message) return EventHandlingResultSuccess } diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 4d2f2edf..9ca2dc43 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -16,6 +16,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/configupgrade" "go.mau.fi/util/ptr" + "go.mau.fi/util/random" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -724,6 +725,8 @@ type ResolveIdentifierResponse struct { Chat *CreateChatResponse } +var SpecialValueDMRedirectedToBot = networkid.UserID("__fi.mau.bridgev2.dm_redirected_to_bot::" + random.String(10)) + type CreateChatResponse struct { PortalKey networkid.PortalKey // Portal and PortalInfo are not required, the caller will fetch them automatically based on PortalKey if necessary. From 2a015350302b07e71abc4b4e8635663fbb190a87 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 22 Oct 2025 16:50:27 +0300 Subject: [PATCH 1459/1647] bridgev2/portal: add helpers for chat member map --- bridgev2/portal.go | 40 +++++++++++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 44e83133..566847fb 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3604,12 +3604,42 @@ type PortalInfo = ChatInfo type ChatMember struct { EventSender Membership event.Membership - Nickname *string + // Per-room nickname for the user. Not yet used. + Nickname *string + // The power level to set for the user when syncing power levels. PowerLevel *int - UserInfo *UserInfo - + // Optional user info to sync the ghost user while updating membership. + UserInfo *UserInfo + // The user who sent the membership change (user who invited/kicked/banned this user). + // Not yet used. Not applicable if Membership is join or knock. + MemberSender EventSender + // Extra fields to include in the member event. MemberEventExtra map[string]any - PrevMembership event.Membership + // The expected previous membership. If this doesn't match, the change is ignored. + PrevMembership event.Membership +} + +type ChatMemberMap map[networkid.UserID]ChatMember + +// Set adds the given entry to this map, overwriting any existing entry with the same Sender field. +func (cmm ChatMemberMap) Set(member ChatMember) { + if member.Sender == "" && member.SenderLogin == "" && !member.IsFromMe { + return + } + cmm[member.Sender] = member +} + +// Add adds the given entry to this map, but will ignore it if an entry with the same Sender field already exists. +// It returns true if the entry was added, false otherwise. +func (cmm ChatMemberMap) Add(member ChatMember) bool { + if member.Sender == "" && member.SenderLogin == "" && !member.IsFromMe { + return false + } + if _, exists := cmm[member.Sender]; exists { + return false + } + cmm[member.Sender] = member + return true } type ChatMemberList struct { @@ -3633,7 +3663,7 @@ type ChatMemberList struct { // Deprecated: Use MemberMap instead to avoid duplicate entries Members []ChatMember - MemberMap map[networkid.UserID]ChatMember + MemberMap ChatMemberMap PowerLevels *PowerLevelOverrides } From 7f0f51ecf3afd15fbacd912cccc6756d79a08682 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 22 Oct 2025 18:13:21 +0300 Subject: [PATCH 1460/1647] bridgev2/commands: add command to sync single chat --- bridgev2/commands/processor.go | 2 +- bridgev2/commands/startchat.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index 290d4196..6062a39a 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -44,7 +44,7 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor { CommandRegisterPush, CommandSendAccountData, CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom, CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, CommandSetRelay, CommandUnsetRelay, - CommandResolveIdentifier, CommandStartChat, CommandSearch, + CommandResolveIdentifier, CommandStartChat, CommandSearch, CommandSyncChat, CommandSudo, CommandDoIn, ) return proc diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index 7abcddd1..b94236df 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -13,6 +13,7 @@ import ( "maps" "slices" "strings" + "time" "github.com/rs/zerolog" @@ -36,6 +37,35 @@ var CommandResolveIdentifier = &FullHandler{ NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI], } +var CommandSyncChat = &FullHandler{ + Func: func(ce *Event) { + login, _, err := ce.Portal.FindPreferredLogin(ce.Ctx, ce.User, false) + if err != nil { + ce.Log.Err(err).Msg("Failed to find login for sync") + ce.Reply("Failed to find login: %v", err) + return + } else if login == nil { + ce.Reply("No login found for sync") + return + } + info, err := login.Client.GetChatInfo(ce.Ctx, ce.Portal) + if err != nil { + ce.Log.Err(err).Msg("Failed to get chat info for sync") + ce.Reply("Failed to get chat info: %v", err) + return + } + ce.Portal.UpdateInfo(ce.Ctx, info, login, nil, time.Time{}) + ce.React("✅️") + }, + Name: "sync-portal", + Help: HelpMeta{ + Section: HelpSectionChats, + Description: "Sync the current portal room", + }, + RequiresPortal: true, + RequiresLogin: true, +} + var CommandStartChat = &FullHandler{ Func: fnResolveIdentifier, Name: "start-chat", From 9fd1e0f87cefddbd0e8b3c9db0073d2bb0a38048 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 22 Oct 2025 18:56:41 +0300 Subject: [PATCH 1461/1647] bridgev2/networkinterface: allow deleting children in chat delete event --- bridgev2/networkinterface.go | 5 ++++ bridgev2/portal.go | 49 ++++++++++++++++++++++++++++++++++++ bridgev2/simplevent/chat.go | 7 +++++- 3 files changed, 60 insertions(+), 1 deletion(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 9ca2dc43..da505435 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -1135,6 +1135,11 @@ type RemoteChatDelete interface { RemoteDeleteOnlyForMe } +type RemoteChatDeleteWithChildren interface { + RemoteChatDelete + DeleteChildren() bool +} + type RemoteEventThatMayCreatePortal interface { RemoteEvent ShouldCreatePortal() bool diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 566847fb..0bd23b9e 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -284,6 +284,16 @@ func (br *Bridge) GetDMPortalsWith(ctx context.Context, otherUserID networkid.Us return br.loadManyPortals(ctx, rows) } +func (br *Bridge) GetChildPortals(ctx context.Context, parent networkid.PortalKey) ([]*Portal, error) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + rows, err := br.DB.Portal.GetChildren(ctx, parent) + if err != nil { + return nil, err + } + return br.loadManyPortals(ctx, rows) +} + func (br *Bridge) GetDMPortal(ctx context.Context, receiver networkid.UserLoginID, otherUserID networkid.UserID) (*Portal, error) { br.cacheLock.Lock() defer br.cacheLock.Unlock() @@ -3514,6 +3524,20 @@ func (portal *Portal) findOtherLogins(ctx context.Context, source *UserLogin) (o return } +type childDeleteProxy struct { + RemoteChatDeleteWithChildren + child networkid.PortalKey + done func() +} + +func (cdp *childDeleteProxy) AddLogContext(c zerolog.Context) zerolog.Context { + return cdp.RemoteChatDeleteWithChildren.AddLogContext(c).Str("subaction", "delete children") +} +func (cdp *childDeleteProxy) GetPortalKey() networkid.PortalKey { return cdp.child } +func (cdp *childDeleteProxy) ShouldCreatePortal() bool { return false } +func (cdp *childDeleteProxy) PreHandle(ctx context.Context, portal *Portal) {} +func (cdp *childDeleteProxy) PostHandle(ctx context.Context, portal *Portal) { cdp.done() } + func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLogin, evt RemoteChatDelete) EventHandlingResult { log := zerolog.Ctx(ctx) if portal.Receiver == "" && evt.DeleteOnlyForMe() { @@ -3549,6 +3573,31 @@ func (portal *Portal) handleRemoteChatDelete(ctx context.Context, source *UserLo } } } + if childDeleter, ok := evt.(RemoteChatDeleteWithChildren); ok && childDeleter.DeleteChildren() && portal.RoomType == database.RoomTypeSpace { + children, err := portal.Bridge.GetChildPortals(ctx, portal.PortalKey) + if err != nil { + log.Err(err).Msg("Failed to fetch children to delete") + return EventHandlingResultFailed.WithError(err) + } + log.Debug(). + Int("portal_count", len(children)). + Msg("Deleting child portals before remote chat delete") + var wg sync.WaitGroup + wg.Add(len(children)) + for _, child := range children { + child.queueEvent(ctx, &portalRemoteEvent{ + evt: &childDeleteProxy{ + RemoteChatDeleteWithChildren: childDeleter, + child: child.PortalKey, + done: wg.Done, + }, + source: source, + evtType: RemoteEventChatDelete, + }) + } + wg.Wait() + log.Debug().Msg("Finished deleting child portals") + } err := portal.Delete(ctx) if err != nil { log.Err(err).Msg("Failed to delete portal from database") diff --git a/bridgev2/simplevent/chat.go b/bridgev2/simplevent/chat.go index c725141b..56e3a6b1 100644 --- a/bridgev2/simplevent/chat.go +++ b/bridgev2/simplevent/chat.go @@ -65,14 +65,19 @@ func (evt *ChatResync) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) type ChatDelete struct { EventMeta OnlyForMe bool + Children bool } -var _ bridgev2.RemoteChatDelete = (*ChatDelete)(nil) +var _ bridgev2.RemoteChatDeleteWithChildren = (*ChatDelete)(nil) func (evt *ChatDelete) DeleteOnlyForMe() bool { return evt.OnlyForMe } +func (evt *ChatDelete) DeleteChildren() bool { + return evt.Children +} + // ChatInfoChange is a simple implementation of [bridgev2.RemoteChatInfoChange]. type ChatInfoChange struct { EventMeta From bae61f955f55e087fe15242d19cbb516020f487e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 22 Oct 2025 20:54:53 +0300 Subject: [PATCH 1462/1647] bridgev2/matrixinvite: fix bugs in DM creation --- bridgev2/matrixinvite.go | 4 ++-- bridgev2/portal.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go index 05479a3c..b8a5aec6 100644 --- a/bridgev2/matrixinvite.go +++ b/bridgev2/matrixinvite.go @@ -226,7 +226,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen log.Debug(). Str("dm_redirected_to_id", string(resp.DMRedirectedTo)). Msg("Created DM was redirected to another user ID") - _, err = invitedGhost.Intent.SendState(ctx, portal.MXID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{ + _, err = invitedGhost.Intent.SendState(ctx, evt.RoomID, event.StateMember, invitedGhost.Intent.GetMXID().String(), &event.Content{ Parsed: &event.MemberEventContent{ Membership: event.MembershipLeave, Reason: "Direct chat redirected to another internal user ID", @@ -235,7 +235,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen if err != nil { log.Err(err).Msg("Failed to make incorrect ghost leave new DM room") } - if resp.DMRedirectedTo != SpecialValueDMRedirectedToBot { + if resp.DMRedirectedTo == SpecialValueDMRedirectedToBot { overrideIntent = br.Bot } else if otherUserGhost, err := br.GetGhostByID(ctx, resp.DMRedirectedTo); err != nil { log.Err(err).Msg("Failed to get ghost of real portal other user ID") diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 0bd23b9e..8fd29bb3 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -4885,7 +4885,7 @@ func (portal *Portal) addToUserSpaces(ctx context.Context) { if portal.Receiver != "" { login := portal.Bridge.GetCachedUserLoginByID(portal.Receiver) if login != nil { - up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey) + up, err := portal.Bridge.DB.UserPortal.GetOrCreate(ctx, login.UserLogin, portal.PortalKey) if err != nil { log.Err(err).Msg("Failed to get user portal to add portal to spaces") } else { From 34a65d3087f280c28cd6c91e3fd6b462830a348d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 22 Oct 2025 21:24:14 +0300 Subject: [PATCH 1463/1647] bridgev2/commands: enable create group command --- bridgev2/commands/processor.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index 6062a39a..13a35687 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -44,7 +44,7 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor { CommandRegisterPush, CommandSendAccountData, CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom, CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, CommandSetRelay, CommandUnsetRelay, - CommandResolveIdentifier, CommandStartChat, CommandSearch, CommandSyncChat, + CommandResolveIdentifier, CommandStartChat, CommandCreateGroup, CommandSearch, CommandSyncChat, CommandSudo, CommandDoIn, ) return proc From 33d8d658fe9825db0b97053900e159587aea6559 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 22 Oct 2025 21:25:46 +0300 Subject: [PATCH 1464/1647] bridgev2/commands: fix panic when creating group with no arguments --- bridgev2/commands/startchat.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index b94236df..99924851 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -80,8 +80,14 @@ var CommandStartChat = &FullHandler{ } func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) { - remainingArgs := ce.Args[1:] - login := ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + var remainingArgs []string + if len(ce.Args) > 1 { + remainingArgs = ce.Args[1:] + } + var login *bridgev2.UserLogin + if len(ce.Args) > 0 { + login = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + } if login == nil || login.UserMXID != ce.User.MXID { remainingArgs = ce.Args login = ce.User.GetDefaultLogin() From 756196ad4fd89989fdf5c0a30f130a0ded269390 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 23 Oct 2025 15:12:42 +0300 Subject: [PATCH 1465/1647] bridgev2/disappear: only start timers for read messages rather than all pending ones (#415) --- bridgev2/database/disappear.go | 27 ++++++++++--------- bridgev2/database/upgrades/00-latest.sql | 3 ++- .../upgrades/23-disappearing-timer-ts.sql | 2 ++ bridgev2/disappear.go | 4 +-- bridgev2/portal.go | 10 +++++-- bridgev2/portalbackfill.go | 1 + 6 files changed, 30 insertions(+), 17 deletions(-) create mode 100644 bridgev2/database/upgrades/23-disappearing-timer-ts.sql diff --git a/bridgev2/database/disappear.go b/bridgev2/database/disappear.go index 9874e472..c2d7d56c 100644 --- a/bridgev2/database/disappear.go +++ b/bridgev2/database/disappear.go @@ -67,26 +67,27 @@ type DisappearingMessageQuery struct { } type DisappearingMessage struct { - BridgeID networkid.BridgeID - RoomID id.RoomID - EventID id.EventID + BridgeID networkid.BridgeID + RoomID id.RoomID + EventID id.EventID + Timestamp time.Time DisappearingSetting } const ( upsertDisappearingMessageQuery = ` - INSERT INTO disappearing_message (bridge_id, mx_room, mxid, type, timer, disappear_at) - VALUES ($1, $2, $3, $4, $5, $6) + INSERT INTO disappearing_message (bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at) + VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (bridge_id, mxid) DO UPDATE SET timer=excluded.timer, disappear_at=excluded.disappear_at ` startDisappearingMessagesQuery = ` UPDATE disappearing_message SET disappear_at=$1 + timer - WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read' - RETURNING bridge_id, mx_room, mxid, type, timer, disappear_at + WHERE bridge_id=$2 AND mx_room=$3 AND disappear_at IS NULL AND type='after_read' AND timestamp<=$4 + RETURNING bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at ` getUpcomingDisappearingMessagesQuery = ` - SELECT bridge_id, mx_room, mxid, type, timer, disappear_at + SELECT bridge_id, mx_room, mxid, timestamp, type, timer, disappear_at FROM disappearing_message WHERE bridge_id = $1 AND disappear_at IS NOT NULL AND disappear_at < $2 ORDER BY disappear_at LIMIT $3 ` @@ -100,8 +101,8 @@ func (dmq *DisappearingMessageQuery) Put(ctx context.Context, dm *DisappearingMe return dmq.Exec(ctx, upsertDisappearingMessageQuery, dm.sqlVariables()...) } -func (dmq *DisappearingMessageQuery) StartAll(ctx context.Context, roomID id.RoomID) ([]*DisappearingMessage, error) { - return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID) +func (dmq *DisappearingMessageQuery) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, startDisappearingMessagesQuery, time.Now().UnixNano(), dmq.BridgeID, roomID, beforeTS.UnixNano()) } func (dmq *DisappearingMessageQuery) GetUpcoming(ctx context.Context, duration time.Duration, limit int) ([]*DisappearingMessage, error) { @@ -113,17 +114,19 @@ func (dmq *DisappearingMessageQuery) Delete(ctx context.Context, eventID id.Even } func (d *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) { + var timestamp int64 var disappearAt sql.NullInt64 - err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, &d.Type, &d.Timer, &disappearAt) + err := row.Scan(&d.BridgeID, &d.RoomID, &d.EventID, ×tamp, &d.Type, &d.Timer, &disappearAt) if err != nil { return nil, err } if disappearAt.Valid { d.DisappearAt = time.Unix(0, disappearAt.Int64) } + d.Timestamp = time.Unix(0, timestamp) return d, nil } func (d *DisappearingMessage) sqlVariables() []any { - return []any{d.BridgeID, d.RoomID, d.EventID, d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)} + return []any{d.BridgeID, d.RoomID, d.EventID, d.Timestamp.UnixNano(), d.Type, d.Timer, dbutil.ConvertedPtr(d.DisappearAt, time.Time.UnixNano)} } diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 4eea05bb..a8bb5c64 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v22 (compatible with v9+): Latest revision +-- v0 -> v23 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -127,6 +127,7 @@ CREATE TABLE disappearing_message ( bridge_id TEXT NOT NULL, mx_room TEXT NOT NULL, mxid TEXT NOT NULL, + timestamp BIGINT NOT NULL DEFAULT 0, type TEXT NOT NULL, timer BIGINT NOT NULL, disappear_at BIGINT, diff --git a/bridgev2/database/upgrades/23-disappearing-timer-ts.sql b/bridgev2/database/upgrades/23-disappearing-timer-ts.sql new file mode 100644 index 00000000..ecd00b8d --- /dev/null +++ b/bridgev2/database/upgrades/23-disappearing-timer-ts.sql @@ -0,0 +1,2 @@ +-- v23 (compatible with v9+): Add event timestamp for disappearing messages +ALTER TABLE disappearing_message ADD COLUMN timestamp BIGINT NOT NULL DEFAULT 0; diff --git a/bridgev2/disappear.go b/bridgev2/disappear.go index f072c01f..b5c37e8f 100644 --- a/bridgev2/disappear.go +++ b/bridgev2/disappear.go @@ -86,8 +86,8 @@ func (dl *DisappearLoop) Stop() { } } -func (dl *DisappearLoop) StartAll(ctx context.Context, roomID id.RoomID) { - startedMessages, err := dl.br.DB.DisappearingMessage.StartAll(ctx, roomID) +func (dl *DisappearLoop) StartAllBefore(ctx context.Context, roomID id.RoomID, beforeTS time.Time) { + startedMessages, err := dl.br.DB.DisappearingMessage.StartAllBefore(ctx, roomID, beforeTS) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to start disappearing messages") return diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 8fd29bb3..e87ce9d5 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -845,7 +845,7 @@ func (portal *Portal) callReadReceiptHandler( if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to save user portal metadata") } - portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) + portal.Bridge.DisappearLoop.StartAllBefore(ctx, portal.MXID, evt.ReadUpTo) } func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) EventHandlingResult { @@ -1193,6 +1193,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ RoomID: portal.MXID, EventID: message.MXID, + Timestamp: message.Timestamp, DisappearingSetting: portal.Disappear.StartingAt(message.Timestamp), }) } @@ -2588,6 +2589,7 @@ func (portal *Portal) sendConvertedMessage( portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ RoomID: portal.MXID, EventID: dbMessage.MXID, + Timestamp: dbMessage.Timestamp, DisappearingSetting: converted.Disappear, }) } @@ -3374,11 +3376,15 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL return evt.Int64("target_stream_order", targetStreamOrder) } err = soIntent.MarkStreamOrderRead(ctx, portal.MXID, targetStreamOrder, getEventTS(evt)) + if readUpTo.IsZero() { + readUpTo = getEventTS(evt) + } } else { addTargetLog = func(evt *zerolog.Event) *zerolog.Event { return evt.Stringer("target_mxid", lastTarget.MXID) } err = intent.MarkRead(ctx, portal.MXID, lastTarget.MXID, getEventTS(evt)) + readUpTo = lastTarget.Timestamp } if err != nil { addTargetLog(log.Err(err)).Msg("Failed to bridge read receipt") @@ -3387,7 +3393,7 @@ func (portal *Portal) handleRemoteReadReceipt(ctx context.Context, source *UserL addTargetLog(log.Debug()).Msg("Bridged read receipt") } if sender.IsFromMe { - portal.Bridge.DisappearLoop.StartAll(ctx, portal.MXID) + portal.Bridge.DisappearLoop.StartAllBefore(ctx, portal.MXID, readUpTo) } return EventHandlingResultSuccess } diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index f7819968..cbbce596 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -387,6 +387,7 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin out.Disappear = append(out.Disappear, &database.DisappearingMessage{ RoomID: portal.MXID, EventID: evtID, + Timestamp: msg.Timestamp, DisappearingSetting: msg.Disappear, }) } From 1be49d53e4f3fe27dbb7912f4ffbdbb6de986b1c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 23 Oct 2025 15:46:57 +0300 Subject: [PATCH 1466/1647] bridgev2/config: add option to limit maximum number of logins --- bridgev2/bridgeconfig/permissions.go | 1 + bridgev2/commands/login.go | 9 +++++++++ bridgev2/matrix/provisioning.go | 12 +++++++----- bridgev2/user.go | 4 ++++ 4 files changed, 21 insertions(+), 5 deletions(-) diff --git a/bridgev2/bridgeconfig/permissions.go b/bridgev2/bridgeconfig/permissions.go index 610051e0..898bf58a 100644 --- a/bridgev2/bridgeconfig/permissions.go +++ b/bridgev2/bridgeconfig/permissions.go @@ -24,6 +24,7 @@ type Permissions struct { DoublePuppet bool `yaml:"double_puppet"` Admin bool `yaml:"admin"` ManageRelay bool `yaml:"manage_relay"` + MaxLogins int `yaml:"max_logins"` } type PermissionConfig map[string]*Permissions diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index a18564c2..0f7bd821 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -70,6 +70,15 @@ func fnLogin(ce *Event) { } ce.Args = ce.Args[1:] } + if reauth == nil && ce.User.HasTooManyLogins() { + ce.Reply( + "You have reached the maximum number of logins (%d). "+ + "Please logout from an existing login before creating a new one. "+ + "If you want to re-authenticate an existing login, use the `$cmdprefix relogin` command.", + ce.User.Permissions.MaxLogins, + ) + return + } flows := ce.Bridge.Network.GetLoginFlows() var chosenFlowID string if len(ce.Args) > 0 { diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 61aad869..43d19380 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -367,17 +367,19 @@ func (prov *ProvisioningAPI) GetCapabilities(w http.ResponseWriter, r *http.Requ } var ErrNilStep = errors.New("bridge returned nil step with no error") +var ErrTooManyLogins = bridgev2.RespError{ErrCode: "FI.MAU.BRIDGE.TOO_MANY_LOGINS", Err: "Maximum number of logins exceeded"} func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Request) { overrideLogin, failed := prov.GetExplicitLoginForRequest(w, r) if failed { return } - login, err := prov.net.CreateLogin( - r.Context(), - prov.GetUser(r), - r.PathValue("flowID"), - ) + user := prov.GetUser(r) + if overrideLogin == nil && user.HasTooManyLogins() { + ErrTooManyLogins.AppendMessage(" (%d)", user.Permissions.MaxLogins).Write(w) + return + } + login, err := prov.net.CreateLogin(r.Context(), user, r.PathValue("flowID")) if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to create login process") RespondWithError(w, err, "Internal error creating login process") diff --git a/bridgev2/user.go b/bridgev2/user.go index 87ced1d7..af9e9694 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -176,6 +176,10 @@ func (user *User) GetUserLogins() []*UserLogin { return maps.Values(user.logins) } +func (user *User) HasTooManyLogins() bool { + return user.Permissions.MaxLogins > 0 && len(user.GetUserLoginIDs()) >= user.Permissions.MaxLogins +} + func (user *User) GetFormattedUserLogins() string { user.Bridge.cacheLock.Lock() logins := make([]string, len(user.logins)) From 75ad1961d570d7321ef69a821929aa59bad76ecc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 23 Oct 2025 17:35:08 +0300 Subject: [PATCH 1467/1647] bridgev2/errors: add special-cased message for too long voice messages --- bridgev2/errors.go | 1 + bridgev2/portal.go | 3 +++ 2 files changed, 4 insertions(+) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index cf27ac6f..76668a99 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -62,6 +62,7 @@ var ( ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) ErrMediaDurationTooLong error = WrapErrorInStatus(errors.New("media duration too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrVoiceMessageDurationTooLong error = WrapErrorInStatus(errors.New("voice message too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) ErrMediaTooLarge error = WrapErrorInStatus(errors.New("media too large")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index e87ce9d5..edc12fcc 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -977,6 +977,9 @@ func (portal *Portal) checkMessageContentCaps(caps *event.RoomFeatures, content if content.Info != nil { dur := time.Duration(content.Info.Duration) * time.Millisecond if feat.MaxDuration != nil && dur > feat.MaxDuration.Duration { + if capMsgType == event.CapMsgVoice { + return fmt.Errorf("%w: %s supports voice messages up to %s long", ErrVoiceMessageDurationTooLong, portal.Bridge.Network.GetName().DisplayName, exfmt.Duration(feat.MaxDuration.Duration)) + } return fmt.Errorf("%w: %s is longer than the maximum of %s", ErrMediaDurationTooLong, exfmt.Duration(dur), exfmt.Duration(feat.MaxDuration.Duration)) } if feat.MaxSize != 0 && int64(content.Info.Size) > feat.MaxSize { From 5d87d14b885818b8bb33f2b9350f362d00b00034 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 24 Oct 2025 12:41:20 +0300 Subject: [PATCH 1468/1647] event/powerlevels: fix some set user level calls in v12 rooms --- event/powerlevels.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/event/powerlevels.go b/event/powerlevels.go index 50df2c1f..50fe82df 100644 --- a/event/powerlevels.go +++ b/event/powerlevels.go @@ -135,6 +135,12 @@ func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int { func (pl *PowerLevelsEventContent) SetUserLevel(userID id.UserID, level int) { pl.usersLock.Lock() defer pl.usersLock.Unlock() + if pl.isCreator(userID) { + return + } + if level == math.MaxInt { + level = 1<<53 - 1 + } if level == pl.UsersDefault { delete(pl.Users, userID) } else { From ee1e05c3e8b51fca93c8cf5514d45817a708b71f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 24 Oct 2025 12:56:53 +0300 Subject: [PATCH 1469/1647] event: fix 32-bit compatibility --- event/powerlevels.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/event/powerlevels.go b/event/powerlevels.go index 50fe82df..708721f9 100644 --- a/event/powerlevels.go +++ b/event/powerlevels.go @@ -132,14 +132,18 @@ func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int { return level } +const maxPL = 1<<53 - 1 + func (pl *PowerLevelsEventContent) SetUserLevel(userID id.UserID, level int) { pl.usersLock.Lock() defer pl.usersLock.Unlock() if pl.isCreator(userID) { return } - if level == math.MaxInt { - level = 1<<53 - 1 + if level == math.MaxInt && maxPL < math.MaxInt { + // Hack to avoid breaking on 32-bit systems (they're only slightly supported) + x := int64(maxPL) + level = int(x) } if level == pl.UsersDefault { delete(pl.Users, userID) From 02a0aad583ed11275659e1e974f42ee32f1806f5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 24 Oct 2025 15:14:31 +0300 Subject: [PATCH 1470/1647] bridgev2/portal: add event for waiting for room creation --- bridgev2/portal.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index edc12fcc..50cd8d32 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -87,6 +87,7 @@ type Portal struct { lastCapUpdate time.Time roomCreateLock sync.Mutex + RoomCreated *exsync.Event functionalMembersLock sync.Mutex functionalMembersCache *event.ElementFunctionalMembersContent @@ -124,6 +125,11 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que currentlyTypingLogins: make(map[id.UserID]*UserLogin), currentlyTypingGhosts: exsync.NewSet[id.UserID](), outgoingMessages: make(map[networkid.TransactionID]*outgoingMessage), + + RoomCreated: exsync.NewEvent(), + } + if portal.MXID != "" { + portal.RoomCreated.Set() } // Putting the portal in the cache before it's fully initialized is mildly dangerous, // but loading the relay user login may depend on it. @@ -2043,6 +2049,7 @@ func (portal *Portal) UpdateMatrixRoomID( } else if alreadyExists { log.Debug().Msg("Replacement room is already a portal, overwriting") existingPortal.MXID = "" + existingPortal.RoomCreated.Clear() err := existingPortal.Save(ctx) if err != nil { return fmt.Errorf("failed to clear mxid of existing portal: %w", err) @@ -2050,6 +2057,7 @@ func (portal *Portal) UpdateMatrixRoomID( delete(portal.Bridge.portalsByMXID, portal.MXID) } portal.MXID = newRoomID + portal.RoomCreated.Set() portal.Bridge.portalsByMXID[portal.MXID] = portal portal.NameSet = false portal.AvatarSet = false @@ -4832,6 +4840,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo portal.TopicSet = true portal.NameSet = true portal.MXID = roomID + portal.RoomCreated.Set() portal.Bridge.cacheLock.Lock() portal.Bridge.portalsByMXID[roomID] = portal portal.Bridge.cacheLock.Unlock() @@ -4935,6 +4944,7 @@ func (portal *Portal) RemoveMXID(ctx context.Context) error { return nil } portal.MXID = "" + portal.RoomCreated.Clear() err := portal.Save(ctx) if err != nil { return err From 364ae39fefd435bcff1456830eb8f813d5b5e523 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 25 Oct 2025 15:34:48 +0300 Subject: [PATCH 1471/1647] responses: add Equal method for LazyLoadSummary --- responses.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/responses.go b/responses.go index 1a221f8a..e7b6b75e 100644 --- a/responses.go +++ b/responses.go @@ -6,12 +6,14 @@ import ( "fmt" "maps" "reflect" + "slices" "strconv" "strings" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" + "go.mau.fi/util/ptr" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -342,6 +344,17 @@ type LazyLoadSummary struct { InvitedMemberCount *int `json:"m.invited_member_count,omitempty"` } +func (lls *LazyLoadSummary) Equal(other *LazyLoadSummary) bool { + if lls == other { + return true + } else if lls == nil || other == nil { + return false + } + return ptr.Val(lls.JoinedMemberCount) == ptr.Val(other.JoinedMemberCount) && + ptr.Val(lls.InvitedMemberCount) == ptr.Val(other.InvitedMemberCount) && + slices.Equal(lls.Heroes, other.Heroes) +} + type SyncEventsList struct { Events []*event.Event `json:"events,omitempty"` } From d486dba9271c441532338ba3b5fdf12cd8a22623 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 25 Oct 2025 16:59:36 +0300 Subject: [PATCH 1472/1647] event: add some getters for state content --- event/state.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/event/state.go b/event/state.go index ed5434c9..db09db8e 100644 --- a/event/state.go +++ b/event/state.go @@ -96,6 +96,13 @@ type TombstoneEventContent struct { ReplacementRoom id.RoomID `json:"replacement_room"` } +func (tec *TombstoneEventContent) GetReplacementRoom() id.RoomID { + if tec == nil { + return "" + } + return tec.ReplacementRoom +} + type Predecessor struct { RoomID id.RoomID `json:"room_id"` EventID id.EventID `json:"event_id"` @@ -135,6 +142,13 @@ type CreateEventContent struct { Creator id.UserID `json:"creator,omitempty"` } +func (cec *CreateEventContent) GetPredecessor() (p Predecessor) { + if cec != nil && cec.Predecessor != nil { + p = *cec.Predecessor + } + return +} + func (cec *CreateEventContent) SupportsCreatorPower() bool { if cec == nil { return false From adc035b6a5551b3dc8dea19529dd4309a1c642e2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 27 Oct 2025 18:39:10 +0200 Subject: [PATCH 1473/1647] event: add state and member action maps to room features (#424) --- event/capabilities.d.ts | 17 ++++++++++++++ event/capabilities.go | 52 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/event/capabilities.d.ts b/event/capabilities.d.ts index 37848575..2d95cd50 100644 --- a/event/capabilities.d.ts +++ b/event/capabilities.d.ts @@ -16,6 +16,8 @@ export interface RoomFeatures { * If a message type isn't listed here, it should be treated as support level -2 (will be rejected). */ file?: Record + state?: Record + member_actions?: Record /** Maximum length of normal text messages. */ max_text_length?: integer @@ -72,6 +74,21 @@ declare type MIMETypeOrPattern = | `${MIMEClass}/${string}` | `${MIMEClass}/${string}; ${string}` +export enum MemberAction { + Ban = "ban", + Kick = "kick", + Leave = "leave", + RevokeInvite = "revoke_invite", + Invite = "invite", +} + +declare type EventType = string + +// This is an object for future extensibility (e.g. max name/topic length) +export interface StateFeatures { + level: CapabilitySupportLevel +} + export enum CapabilityMsgType { // Real message types used in the `msgtype` field Image = "m.image", diff --git a/event/capabilities.go b/event/capabilities.go index 42afe5b6..5ecea4a2 100644 --- a/event/capabilities.go +++ b/event/capabilities.go @@ -28,8 +28,10 @@ type RoomFeatures struct { // N.B. New fields need to be added to the Hash function to be included in the deduplication hash. - Formatting FormattingFeatureMap `json:"formatting,omitempty"` - File FileFeatureMap `json:"file,omitempty"` + Formatting FormattingFeatureMap `json:"formatting,omitempty"` + File FileFeatureMap `json:"file,omitempty"` + State StateFeatureMap `json:"state,omitempty"` + MemberActions MemberFeatureMap `json:"member_actions,omitempty"` MaxTextLength int `json:"max_text_length,omitempty"` @@ -74,6 +76,8 @@ func (rf *RoomFeatures) Clone() *RoomFeatures { clone := *rf clone.File = clone.File.Clone() clone.Formatting = maps.Clone(clone.Formatting) + clone.State = clone.State.Clone() + clone.MemberActions = clone.MemberActions.Clone() clone.EditMaxAge = ptr.Clone(clone.EditMaxAge) clone.DeleteMaxAge = ptr.Clone(clone.DeleteMaxAge) clone.DisappearingTimer = clone.DisappearingTimer.Clone() @@ -81,6 +85,48 @@ func (rf *RoomFeatures) Clone() *RoomFeatures { return &clone } +type MemberFeatureMap map[MemberAction]CapabilitySupportLevel + +func (mfm MemberFeatureMap) Clone() MemberFeatureMap { + return maps.Clone(mfm) +} + +type MemberAction string + +const ( + MemberActionBan MemberAction = "ban" + MemberActionKick MemberAction = "kick" + MemberActionLeave MemberAction = "leave" + MemberActionRevokeInvite MemberAction = "revoke_invite" + MemberActionInvite MemberAction = "invite" +) + +type StateFeatureMap map[string]*StateFeatures + +func (sfm StateFeatureMap) Clone() StateFeatureMap { + dup := maps.Clone(sfm) + for key, value := range dup { + dup[key] = value.Clone() + } + return dup +} + +type StateFeatures struct { + Level CapabilitySupportLevel `json:"level"` +} + +func (sf *StateFeatures) Clone() *StateFeatures { + if sf == nil { + return nil + } + clone := *sf + return &clone +} + +func (sf *StateFeatures) Hash() []byte { + return sf.Level.Hash() +} + type FormattingFeatureMap map[FormattingFeature]CapabilitySupportLevel type FileFeatureMap map[CapabilityMsgType]*FileFeatures @@ -266,6 +312,8 @@ func (rf *RoomFeatures) Hash() []byte { hashMap(hasher, "formatting", rf.Formatting) hashMap(hasher, "file", rf.File) + hashMap(hasher, "state", rf.State) + hashMap(hasher, "member_actions", rf.MemberActions) hashInt(hasher, "max_text_length", rf.MaxTextLength) From bea28c1381cd2a0047b5bcbdb763109c1247fe5d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 28 Oct 2025 14:48:06 +0200 Subject: [PATCH 1474/1647] bridgev2/portal: log mismatching disappearing timers in events --- bridgev2/database/disappear.go | 10 ++++++++++ bridgev2/portal.go | 24 ++++++++++++++++++++++-- 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/bridgev2/database/disappear.go b/bridgev2/database/disappear.go index c2d7d56c..df36b205 100644 --- a/bridgev2/database/disappear.go +++ b/bridgev2/database/disappear.go @@ -37,6 +37,16 @@ type DisappearingSetting struct { DisappearAt time.Time } +func DisappearingSettingFromEvent(evt *event.BeeperDisappearingTimer) DisappearingSetting { + if evt == nil || evt.Type == event.DisappearingTypeNone { + return DisappearingSetting{} + } + return DisappearingSetting{ + Type: evt.Type, + Timer: evt.Timer.Duration, + } +} + func (ds DisappearingSetting) Normalize() DisappearingSetting { if ds.Type == event.DisappearingTypeNone { ds.Timer = 0 diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 50cd8d32..ed6756c9 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1122,6 +1122,16 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } } } + var messageTimer *event.BeeperDisappearingTimer + if msgContent != nil { + messageTimer = msgContent.BeeperDisappearingTimer + } + if messageTimer != nil && *portal.Disappear.ToEventContent() != *messageTimer { + log.Warn(). + Any("event_timer", messageTimer). + Any("portal_timer", portal.Disappear.ToEventContent()). + Msg("Mismatching disappearing timer in event") + } wrappedMsgEvt := &MatrixMessage{ MatrixEventBase: MatrixEventBase[*event.MessageEventContent]{ @@ -1198,12 +1208,16 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } portal.sendSuccessStatus(ctx, evt, resp.StreamOrder, message.MXID) } - if portal.Disappear.Type != event.DisappearingTypeNone { + ds := portal.Disappear + if messageTimer != nil { + ds = database.DisappearingSettingFromEvent(messageTimer) + } + if ds.Type != event.DisappearingTypeNone { go portal.Bridge.DisappearLoop.Add(ctx, &database.DisappearingMessage{ RoomID: portal.MXID, EventID: message.MXID, Timestamp: message.Timestamp, - DisappearingSetting: portal.Disappear.StartingAt(message.Timestamp), + DisappearingSetting: ds.StartingAt(message.Timestamp), }) } if resp.Pending { @@ -4082,6 +4096,12 @@ func (portal *Portal) sendRoomMeta( Msg("Failed to set room metadata") return false } + if eventType == event.StateBeeperDisappearingTimer { + // TODO remove this debug log at some point + zerolog.Ctx(ctx).Debug(). + Any("content", content). + Msg("Sent new disappearing timer event") + } return true } From 76cb8ee7d3c9d78945a648d00b535df129a0c0d1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 28 Oct 2025 22:46:29 +0200 Subject: [PATCH 1475/1647] bridgev2/provisioning: add option to skip identifier validation in create group --- bridgev2/networkinterface.go | 4 ++++ bridgev2/provisionutil/creategroup.go | 6 ++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index da505435..8a39c7f8 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -826,6 +826,10 @@ type GroupFieldCapability struct { // Only for the disappear field: allowed disappearing settings DisappearSettings *event.DisappearingTimerCapability `json:"settings,omitempty"` + + // This can be used to tell provisionutil not to call ValidateUserID on each participant. + // It only meant to allow hacks where ResolveIdentifier returns a fake ID that isn't actually valid for MXIDs. + SkipIdentifierValidation bool `json:"-"` } type GroupCreateParams struct { diff --git a/bridgev2/provisionutil/creategroup.go b/bridgev2/provisionutil/creategroup.go index 0df09ff5..55a21b1a 100644 --- a/bridgev2/provisionutil/creategroup.go +++ b/bridgev2/provisionutil/creategroup.go @@ -47,8 +47,10 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev participant = parsedParticipant params.Participants[i] = participant } - if uidValOK && !userIDValidatingNetwork.ValidateUserID(participant) { - return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("User ID %q is not valid on this network", participant)) + if !typeSpec.Participants.SkipIdentifierValidation { + if uidValOK && !userIDValidatingNetwork.ValidateUserID(participant) { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("User ID %q is not valid on this network", participant)) + } } if api.IsThisUser(ctx, participant) { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("You can't include yourself in the participants list", participant)) From 1edfccb4e2a941d6120bd57ba617e6566d916e0d Mon Sep 17 00:00:00 2001 From: timedout Date: Wed, 29 Oct 2025 17:55:12 +0000 Subject: [PATCH 1476/1647] federation/client: Use PUT instead of POST to send transactions (#426) --- federation/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federation/client.go b/federation/client.go index f3163f3a..b20af4ab 100644 --- a/federation/client.go +++ b/federation/client.go @@ -89,7 +89,7 @@ type RespSendTransaction struct { } func (c *Client) SendTransaction(ctx context.Context, req *ReqSendTransaction) (resp *RespSendTransaction, err error) { - err = c.MakeRequest(ctx, req.Destination, true, http.MethodPost, URLPath{"v1", "send", req.TxnID}, req, &resp) + err = c.MakeRequest(ctx, req.Destination, true, http.MethodPut, URLPath{"v1", "send", req.TxnID}, req, &resp) return } From 0da017515743847bb4e5225a5f8ea7fcb4ae53b7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 29 Oct 2025 20:52:27 +0200 Subject: [PATCH 1477/1647] bridgev2: add new flag for slack remote ID migration --- bridgev2/portal.go | 1 + event/state.go | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index ed6756c9..4943ab00 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3980,6 +3980,7 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { } if bridgeInfo.Protocol.ID == "slackgo" { bridgeInfo.TempSlackRemoteIDMigratedFlag = true + bridgeInfo.TempSlackRemoteIDMigratedFlag2 = true } parent := portal.GetTopLevelParent() if parent != nil { diff --git a/event/state.go b/event/state.go index db09db8e..6df3b143 100644 --- a/event/state.go +++ b/event/state.go @@ -246,7 +246,8 @@ type BridgeEventContent struct { BeeperRoomType string `json:"com.beeper.room_type,omitempty"` BeeperRoomTypeV2 string `json:"com.beeper.room_type.v2,omitempty"` - TempSlackRemoteIDMigratedFlag bool `json:"com.beeper.slack_remote_id_migrated,omitempty"` + TempSlackRemoteIDMigratedFlag bool `json:"com.beeper.slack_remote_id_migrated,omitempty"` + TempSlackRemoteIDMigratedFlag2 bool `json:"com.beeper.slack_remote_id_really_migrated,omitempty"` } // DisappearingType represents the type of a disappearing message timer. From be9bbf8d098f8fa0f9f7f4d3da968c1efe94f83a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 29 Oct 2025 22:50:02 +0200 Subject: [PATCH 1478/1647] bridgev2/provisioning: fix max length checks in group creation --- bridgev2/provisionutil/creategroup.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bridgev2/provisionutil/creategroup.go b/bridgev2/provisionutil/creategroup.go index 55a21b1a..fbe0a513 100644 --- a/bridgev2/provisionutil/creategroup.go +++ b/bridgev2/provisionutil/creategroup.go @@ -39,6 +39,8 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev } if len(params.Participants) < typeSpec.Participants.MinLength { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at least %d members", typeSpec.Participants.MinLength)) + } else if typeSpec.Participants.MaxLength > 0 && len(params.Participants) > typeSpec.Participants.MaxLength { + return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Must have at most %d members", typeSpec.Participants.MaxLength)) } userIDValidatingNetwork, uidValOK := login.Bridge.Network.(bridgev2.IdentifierValidatingNetwork) for i, participant := range params.Participants { @@ -60,7 +62,7 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name is required")) } else if nameLen := len(ptr.Val(params.Name).Name); nameLen > 0 && nameLen < typeSpec.Name.MinLength { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at least %d characters", typeSpec.Name.MinLength)) - } else if nameLen > typeSpec.Name.MaxLength { + } else if typeSpec.Name.MaxLength > 0 && nameLen > typeSpec.Name.MaxLength { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Name must be at most %d characters", typeSpec.Name.MaxLength)) } if (params.Avatar == nil || params.Avatar.URL == "") && typeSpec.Avatar.Required { @@ -70,7 +72,7 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic is required")) } else if topicLen := len(ptr.Val(params.Topic).Topic); topicLen > 0 && topicLen < typeSpec.Topic.MinLength { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at least %d characters", typeSpec.Topic.MinLength)) - } else if topicLen > typeSpec.Topic.MaxLength { + } else if typeSpec.Topic.MaxLength > 0 && topicLen > typeSpec.Topic.MaxLength { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Topic must be at most %d characters", typeSpec.Topic.MaxLength)) } if (params.Disappear == nil || params.Disappear.Timer.Duration == 0) && typeSpec.Disappear.Required { @@ -82,7 +84,7 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username is required")) } else if len(params.Username) > 0 && len(params.Username) < typeSpec.Username.MinLength { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at least %d characters", typeSpec.Username.MinLength)) - } else if len(params.Username) > typeSpec.Username.MaxLength { + } else if typeSpec.Username.MaxLength > 0 && len(params.Username) > typeSpec.Username.MaxLength { return nil, bridgev2.RespError(mautrix.MInvalidParam.WithMessage("Username must be at most %d characters", typeSpec.Username.MaxLength)) } if params.Parent == nil && typeSpec.Parent.Required { From 2ece053b2bfae533e0c6368c2118d7101b8ae932 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 31 Oct 2025 00:07:24 +0200 Subject: [PATCH 1479/1647] bridgev2: roll back failed room metadata changes (#425) --- bridgev2/bridgeconfig/config.go | 1 + bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/errors.go | 1 + bridgev2/matrix/connector.go | 1 + bridgev2/matrix/intent.go | 4 +- bridgev2/matrix/mxmain/example-config.yaml | 2 + bridgev2/matrixinterface.go | 5 +- bridgev2/portal.go | 75 ++++++++++++++++++---- bridgev2/portalinternal.go | 8 ++- versions.go | 15 +++-- 10 files changed, 86 insertions(+), 27 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 13ec738c..01819945 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -77,6 +77,7 @@ type BridgeConfig struct { DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"` CrossRoomReplies bool `yaml:"cross_room_replies"` OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` + RevertFailedStateChanges bool `yaml:"revert_failed_state_changes"` CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` Relay RelayConfig `yaml:"relay"` Permissions PermissionConfig `yaml:"permissions"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 6533338f..be8a8f96 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -40,6 +40,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "mute_only_on_create") helper.Copy(up.Bool, "bridge", "deduplicate_matrix_messages") helper.Copy(up.Bool, "bridge", "cross_room_replies") + helper.Copy(up.Bool, "bridge", "revert_failed_state_changes") helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "relayed") diff --git a/bridgev2/errors.go b/bridgev2/errors.go index 76668a99..a06f30ed 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -54,6 +54,7 @@ var ( ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrRoomMetadataNotAllowed error = WrapErrorInStatus(errors.New("changes are not allowed here")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) ErrInvalidStateKey error = WrapErrorInStatus(errors.New("room metadata state key is unset or non-empty")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 64b5d6c7..edd98045 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -361,6 +361,7 @@ func (br *Connector) ensureConnection(ctx context.Context) { *br.AS.SpecVersions = *versions br.Capabilities.AutoJoinInvites = br.SpecVersions.Supports(mautrix.BeeperFeatureAutojoinInvites) br.Capabilities.BatchSending = br.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending) + br.Capabilities.ArbitraryMemberChange = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryMemberChange) break } } diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index ab59a582..27892fb6 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -90,8 +90,8 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType } func (as *ASIntent) fillMemberEvent(ctx context.Context, roomID id.RoomID, userID id.UserID, content *event.Content) { - targetContent := content.Parsed.(*event.MemberEventContent) - if targetContent.Displayname != "" || targetContent.AvatarURL != "" { + targetContent, ok := content.Parsed.(*event.MemberEventContent) + if !ok || targetContent.Displayname != "" || targetContent.AvatarURL != "" { return } memberContent, err := as.Matrix.StateStore.TryGetMember(ctx, roomID, userID) diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index d8634028..aeb5b7db 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -47,6 +47,8 @@ bridge: # Should cross-room reply metadata be bridged? # Most Matrix clients don't support this and servers may reject such messages too. cross_room_replies: false + # If a state event fails to bridge, should the bridge revert any state changes made by that event? + revert_failed_state_changes: false # What should be done to portal rooms when a user logs out or is logged out? # Permitted values: diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 6fa5360c..e8489dc1 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -23,8 +23,9 @@ import ( ) type MatrixCapabilities struct { - AutoJoinInvites bool - BatchSending bool + AutoJoinInvites bool + BatchSending bool + ArbitraryMemberChange bool } type MatrixConnector interface { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 4943ab00..67199ada 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -520,6 +520,9 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal portal.sendSuccessStatus(ctx, evt.evt, 0, "") } } + if res.Error != nil && evt.evt.StateKey != nil { + portal.revertRoomMeta(ctx, evt.evt) + } case *portalRemoteEvent: res = portal.handleRemoteEvent(ctx, evt.source, evt.evtType, evt.evt) case *portalCreateEvent: @@ -1562,9 +1565,13 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( if evt.StateKey == nil || *evt.StateKey != "" { return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) } + //caps := sender.Client.GetCapabilities(ctx, portal) + //if stateCap, ok := caps.State[evt.Type.Type]; !ok || stateCap.Level <= event.CapLevelUnsupported { + // return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("%s %w", evt.Type.Type, ErrRoomMetadataNotAllowed)) + //} api, ok := sender.Client.(APIType) if !ok { - return EventHandlingResultIgnored.WithMSSError(ErrRoomMetadataNotSupported) + return EventHandlingResultIgnored.WithMSSError(fmt.Errorf("%w of type %s", ErrRoomMetadataNotSupported, evt.Type)) } log := zerolog.Ctx(ctx) content, ok := evt.Content.Parsed.(ContentType) @@ -1598,7 +1605,6 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( return EventHandlingResultIgnored } if !sender.Client.GetCapabilities(ctx, portal).DisappearingTimer.Supports(typedContent) { - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), false) return EventHandlingResultFailed.WithMSSError(ErrDisappearingTimerUnsupported) } } @@ -1621,9 +1627,6 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( }) if err != nil { log.Err(err).Msg("Failed to handle Matrix room metadata") - if evt.Type == event.StateBeeperDisappearingTimer { - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), false) - } return EventHandlingResultFailed.WithMSSError(err) } if changed { @@ -3891,7 +3894,7 @@ func (portal *Portal) updateName( } portal.Name = name portal.NameSet = portal.sendRoomMeta( - ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name}, excludeFromTimeline, + ctx, sender, ts, event.StateRoomName, "", &event.RoomNameEventContent{Name: name}, excludeFromTimeline, nil, ) return true } @@ -3904,7 +3907,7 @@ func (portal *Portal) updateTopic( } portal.Topic = topic portal.TopicSet = portal.sendRoomMeta( - ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic}, excludeFromTimeline, + ctx, sender, ts, event.StateTopic, "", &event.TopicEventContent{Topic: topic}, excludeFromTimeline, nil, ) return true } @@ -3935,7 +3938,7 @@ func (portal *Portal) updateAvatar( portal.AvatarHash = newHash } portal.AvatarSet = portal.sendRoomMeta( - ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, excludeFromTimeline, + ctx, sender, ts, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, excludeFromTimeline, nil, ) return true } @@ -4003,8 +4006,8 @@ func (portal *Portal) UpdateBridgeInfo(ctx context.Context) { return } stateKey, bridgeInfo := portal.getBridgeInfo() - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo, false) - portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo, false) + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBridge, stateKey, &bridgeInfo, false, nil) + portal.sendRoomMeta(ctx, nil, time.Now(), event.StateHalfShotBridge, stateKey, &bridgeInfo, false, nil) } func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, implicit bool) bool { @@ -4026,7 +4029,7 @@ func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, Str("old_id", portal.CapState.ID). Str("new_id", capID). Msg("Sending new room capability event") - success := portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperRoomFeatures, portal.getBridgeInfoStateKey(), caps, false) + success := portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperRoomFeatures, portal.getBridgeInfoStateKey(), caps, false, nil) if !success { return false } @@ -4037,7 +4040,7 @@ func (portal *Portal) UpdateCapabilities(ctx context.Context, source *UserLogin, } if caps.DisappearingTimer != nil && !portal.CapState.Flags.Has(database.CapStateFlagDisappearingTimerSet) { zerolog.Ctx(ctx).Debug().Msg("Disappearing timer capability was added, sending disappearing timer state event") - success = portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true) + success = portal.sendRoomMeta(ctx, nil, time.Now(), event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true, nil) if !success { return false } @@ -4076,11 +4079,14 @@ func (portal *Portal) sendRoomMeta( stateKey string, content any, excludeFromTimeline bool, + extra map[string]any, ) bool { if portal.MXID == "" { return false } - extra := make(map[string]any) + if extra == nil { + extra = make(map[string]any) + } if excludeFromTimeline { extra["com.beeper.exclude_from_timeline"] = true } @@ -4106,6 +4112,46 @@ func (portal *Portal) sendRoomMeta( return true } +func (portal *Portal) revertRoomMeta(ctx context.Context, evt *event.Event) { + if !portal.Bridge.Config.RevertFailedStateChanges { + return + } + if evt.GetStateKey() != "" && evt.Type != event.StateMember { + return + } + switch evt.Type { + case event.StateRoomName: + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateRoomName, "", &event.RoomNameEventContent{Name: portal.Name}, true, nil) + case event.StateRoomAvatar: + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateRoomAvatar, "", &event.RoomAvatarEventContent{URL: portal.AvatarMXC}, true, nil) + case event.StateTopic: + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateTopic, "", &event.TopicEventContent{Topic: portal.Topic}, true, nil) + case event.StateBeeperDisappearingTimer: + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateBeeperDisappearingTimer, "", portal.Disappear.ToEventContent(), true, nil) + case event.StateMember: + var prevContent *event.MemberEventContent + var extra map[string]any + if evt.Unsigned.PrevContent != nil { + _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) + prevContent = evt.Unsigned.PrevContent.AsMember() + newContent := evt.Content.AsMember() + if prevContent.Membership == newContent.Membership { + return + } + extra = evt.Unsigned.PrevContent.Raw + } else { + prevContent = &event.MemberEventContent{Membership: event.MembershipLeave} + } + if portal.Bridge.Matrix.GetCapabilities().ArbitraryMemberChange { + if extra == nil { + extra = make(map[string]any) + } + extra["com.beeper.member_rollback"] = true + portal.sendRoomMeta(ctx, nil, time.Time{}, event.StateMember, evt.GetStateKey(), prevContent, true, extra) + } + } +} + func (portal *Portal) getInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) { if members == nil { invite = []id.UserID{source.UserMXID} @@ -4490,6 +4536,7 @@ func (portal *Portal) UpdateDisappearingSetting( "", setting.ToEventContent(), opts.ExcludeFromTimeline, + nil, ) if !opts.SendNotice { @@ -4618,7 +4665,7 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us } if info.JoinRule != nil { // TODO change detection instead of spamming this every time? - portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule, info.ExcludeChangesFromTimeline) + portal.sendRoomMeta(ctx, sender, ts, event.StateJoinRules, "", info.JoinRule, info.ExcludeChangesFromTimeline, nil) } if info.Type != nil && portal.RoomType != *info.Type { if portal.MXID != "" && (*info.Type == database.RoomTypeSpace || portal.RoomType == database.RoomTypeSpace) { diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index d9373eb6..749ee389 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -289,8 +289,12 @@ func (portal *PortalInternals) SendStateWithIntentOrBot(ctx context.Context, sen return (*Portal)(portal).sendStateWithIntentOrBot(ctx, sender, eventType, stateKey, content, ts) } -func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any, excludeFromTimeline bool) bool { - return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content, excludeFromTimeline) +func (portal *PortalInternals) SendRoomMeta(ctx context.Context, sender MatrixAPI, ts time.Time, eventType event.Type, stateKey string, content any, excludeFromTimeline bool, extra map[string]any) bool { + return (*Portal)(portal).sendRoomMeta(ctx, sender, ts, eventType, stateKey, content, excludeFromTimeline, extra) +} + +func (portal *PortalInternals) RevertRoomMeta(ctx context.Context, evt *event.Event) { + (*Portal)(portal).revertRoomMeta(ctx, evt) } func (portal *PortalInternals) GetInitialMemberList(ctx context.Context, members *ChatMemberList, source *UserLogin, pl *event.PowerLevelsEventContent) (invite, functional []id.UserID, err error) { diff --git a/versions.go b/versions.go index 0392532e..8c1c49aa 100644 --- a/versions.go +++ b/versions.go @@ -70,13 +70,14 @@ var ( FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"} FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116} - BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} - BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} - BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"} - BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"} - BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"} - BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"} - BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"} + BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} + BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} + BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"} + BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"} + BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"} + BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"} + BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"} + BeeperFeatureArbitraryMemberChange = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_member_change"} ) func (versions *RespVersions) Supports(feature UnstableFeature) bool { From 8e23192a7d6664beb9d0da1a40843bcbb2aaaf1c Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Fri, 31 Oct 2025 10:01:49 +0000 Subject: [PATCH 1480/1647] client: support sending custom txn ID query param with state events --- client.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/client.go b/client.go index d8bd5b80..dcc3fe5e 100644 --- a/client.go +++ b/client.go @@ -1353,6 +1353,9 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy if req.MeowEventID != "" { queryParams["fi.mau.event_id"] = req.MeowEventID.String() } + if req.TransactionID != "" { + queryParams["fi.mau.transaction_id"] = req.TransactionID + } if req.UnstableDelay > 0 { queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10) } From 175f5a1c61df3c4dd7907712193f4dcb3cefff88 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 31 Oct 2025 21:11:24 +0100 Subject: [PATCH 1481/1647] federation/serverauth: fix request uri --- federation/serverauth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federation/serverauth.go b/federation/serverauth.go index f46c7991..cd300341 100644 --- a/federation/serverauth.go +++ b/federation/serverauth.go @@ -231,7 +231,7 @@ func (sa *ServerAuth) Authenticate(r *http.Request) (*http.Request, *mautrix.Res } err = (&signableRequest{ Method: r.Method, - URI: r.URL.EscapedPath(), + URI: r.URL.RequestURI(), Origin: parsed.Origin, Destination: destination, Content: reqBody, From 4ec3fbb4ab40dc77317205bac7b2898a6ed5f4e8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 1 Nov 2025 22:10:43 +0100 Subject: [PATCH 1482/1647] crypto/goolm: fix var bytes read overflow --- crypto/goolm/message/decoder.go | 5 +++++ crypto/goolm/message/group_message.go | 5 +++++ crypto/goolm/message/message.go | 5 +++++ crypto/goolm/message/prekey_message.go | 9 +++++++++ 4 files changed, 24 insertions(+) diff --git a/crypto/goolm/message/decoder.go b/crypto/goolm/message/decoder.go index a71cf302..b06756a9 100644 --- a/crypto/goolm/message/decoder.go +++ b/crypto/goolm/message/decoder.go @@ -3,6 +3,9 @@ package message import ( "bytes" "encoding/binary" + "fmt" + + "maunium.net/go/mautrix/crypto/olm" ) type Decoder struct { @@ -20,6 +23,8 @@ func (d *Decoder) ReadVarInt() (uint64, error) { func (d *Decoder) ReadVarBytes() ([]byte, error) { if n, err := d.ReadVarInt(); err != nil { return nil, err + } else if n > uint64(d.Len()) { + return nil, fmt.Errorf("%w: var bytes length says %d, but only %d bytes left", olm.ErrInputToSmall, n, d.Available()) } else { out := make([]byte, n) _, err = d.Read(out) diff --git a/crypto/goolm/message/group_message.go b/crypto/goolm/message/group_message.go index c2a43b1f..f3d22500 100644 --- a/crypto/goolm/message/group_message.go +++ b/crypto/goolm/message/group_message.go @@ -2,10 +2,12 @@ package message import ( "bytes" + "fmt" "io" "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -36,6 +38,9 @@ func (r *GroupMessage) Decode(input []byte) (err error) { if err != nil { return } + if r.Version != protocolVersion { + return fmt.Errorf("GroupMessage.Decode: %w", olm.ErrWrongProtocolVersion) + } for { // Read Key diff --git a/crypto/goolm/message/message.go b/crypto/goolm/message/message.go index 8bb6e0cd..9ef93630 100644 --- a/crypto/goolm/message/message.go +++ b/crypto/goolm/message/message.go @@ -2,10 +2,12 @@ package message import ( "bytes" + "fmt" "io" "maunium.net/go/mautrix/crypto/goolm/aessha2" "maunium.net/go/mautrix/crypto/goolm/crypto" + "maunium.net/go/mautrix/crypto/olm" ) const ( @@ -40,6 +42,9 @@ func (r *Message) Decode(input []byte) (err error) { if err != nil { return } + if r.Version != protocolVersion { + return fmt.Errorf("Message.Decode: %w", olm.ErrWrongProtocolVersion) + } for { // Read Key diff --git a/crypto/goolm/message/prekey_message.go b/crypto/goolm/message/prekey_message.go index 22ebf9c3..760be4c9 100644 --- a/crypto/goolm/message/prekey_message.go +++ b/crypto/goolm/message/prekey_message.go @@ -1,6 +1,7 @@ package message import ( + "fmt" "io" "maunium.net/go/mautrix/crypto/goolm/crypto" @@ -22,6 +23,11 @@ type PreKeyMessage struct { Message []byte `json:"message"` } +// TODO deduplicate constant with one in session/olm_session.go +const ( + protocolVersion = 0x3 +) + // Decodes decodes the input and populates the corresponding fileds. func (r *PreKeyMessage) Decode(input []byte) (err error) { r.Version = 0 @@ -41,6 +47,9 @@ func (r *PreKeyMessage) Decode(input []byte) (err error) { } return } + if r.Version != protocolVersion { + return fmt.Errorf("PreKeyMessage.Decode: %w", olm.ErrWrongProtocolVersion) + } for { // Read Key From 6e7b692098a170c89b14b72b0d1a2a4b85301a9e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 1 Nov 2025 22:19:57 +0100 Subject: [PATCH 1483/1647] federation/eventauth: fix restricted joins typo --- federation/eventauth/eventauth.go | 2 +- federation/pdu/auth.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go index bd102213..32b4424b 100644 --- a/federation/eventauth/eventauth.go +++ b/federation/eventauth/eventauth.go @@ -310,7 +310,7 @@ func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEv // 5.1. If there is no state_key property, or no membership property in content, reject. return ErrMemberNotState } - authorizedVia := id.UserID(gjson.GetBytes(evt.Content, "authorized_via_users_server").Str) + authorizedVia := id.UserID(gjson.GetBytes(evt.Content, "authorised_via_users_server").Str) if authorizedVia != "" { homeserver := authorizedVia.Homeserver() err := evt.VerifySignature(roomVersion, homeserver, getKey) diff --git a/federation/pdu/auth.go b/federation/pdu/auth.go index 1f98de06..16706fe5 100644 --- a/federation/pdu/auth.go +++ b/federation/pdu/auth.go @@ -61,7 +61,7 @@ func (pdu *PDU) AuthEventSelection(roomVersion id.RoomVersion) (keys AuthEventSe } } if membership == event.MembershipJoin && roomVersion.RestrictedJoins() { - authorizedVia := gjson.GetBytes(pdu.Content, "authorized_via_users_server").Str + authorizedVia := gjson.GetBytes(pdu.Content, "authorised_via_users_server").Str if authorizedVia != "" { keys.Add(event.StateMember.Type, authorizedVia) } From cfa47299df03606ae04fe56ecec175c71cd5349a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 6 Nov 2025 09:26:28 +0100 Subject: [PATCH 1484/1647] bridgev2/provisioning: add select type for login user input --- bridgev2/login.go | 4 ++++ bridgev2/matrix/provisioning.yaml | 7 ++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/bridgev2/login.go b/bridgev2/login.go index 1fa3afbc..46dcf7da 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -178,6 +178,7 @@ const ( LoginInputFieldTypeToken LoginInputFieldType = "token" LoginInputFieldTypeURL LoginInputFieldType = "url" LoginInputFieldTypeDomain LoginInputFieldType = "domain" + LoginInputFieldTypeSelect LoginInputFieldType = "select" ) type LoginInputDataField struct { @@ -191,6 +192,9 @@ type LoginInputDataField struct { Description string `json:"description"` // A regex pattern that the client can use to validate input client-side. Pattern string `json:"pattern,omitempty"` + // For fields of type select, the valid options. + // Pattern may also be filled with a regex that matches the same options. + Options []string `json:"options,omitempty"` // A function that validates the input and optionally cleans it up before it's submitted to the connector. Validate func(string) (string, error) `json:"-"` } diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml index 21c93ca4..50b73c66 100644 --- a/bridgev2/matrix/provisioning.yaml +++ b/bridgev2/matrix/provisioning.yaml @@ -714,7 +714,7 @@ components: type: type: string description: The type of field. - enum: [ username, phone_number, email, password, 2fa_code, token, url, domain ] + enum: [ username, phone_number, email, password, 2fa_code, token, url, domain, select ] id: type: string description: The internal ID of the field. This must be used as the key in the object when submitting the data back to the bridge. @@ -732,6 +732,11 @@ components: type: string format: regex description: A regular expression that the field value must match. + select: + type: array + description: For fields of type select, the valid options. + items: + type: string - description: Cookie login step required: [ type, cookies ] properties: From 36d4e1f99c22aef40765ef8dcd4414ffa8d89399 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 6 Nov 2025 16:37:27 +0100 Subject: [PATCH 1485/1647] federation: don't close body when not reading it Closes #431 --- client.go | 1 + federation/client.go | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index dcc3fe5e..3c60a2d1 100644 --- a/client.go +++ b/client.go @@ -707,6 +707,7 @@ const ErrorResponseSizeLimit = 512 * 1024 var DefaultResponseSizeLimit int64 = 512 * 1024 * 1024 func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) { + defer res.Body.Close() contents, err := readResponseBody(req, res, ErrorResponseSizeLimit) if err != nil { return contents, err diff --git a/federation/client.go b/federation/client.go index b20af4ab..b24fd2d2 100644 --- a/federation/client.go +++ b/federation/client.go @@ -314,9 +314,9 @@ func (c *Client) MakeFullRequest(ctx context.Context, params RequestParams) ([]b WrappedError: err, } } - defer func() { - _ = resp.Body.Close() - }() + if !params.DontReadBody { + defer resp.Body.Close() + } var body []byte if resp.StatusCode >= 300 { body, err = mautrix.ParseErrorResponse(req, resp) From 3014bf966c6bf81ae845347ccf3ae52a36a6161d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 6 Nov 2025 16:37:50 +0100 Subject: [PATCH 1486/1647] bridgev2/commands: include options in user input prompt --- bridgev2/commands/login.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index 0f7bd821..80a7c733 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -199,11 +199,14 @@ type userInputLoginCommandState struct { func (uilcs *userInputLoginCommandState) promptNext(ce *Event) { field := uilcs.RemainingFields[0] + parts := []string{fmt.Sprintf("Please enter your %s", field.Name)} if field.Description != "" { - ce.Reply("Please enter your %s\n%s", field.Name, field.Description) - } else { - ce.Reply("Please enter your %s", field.Name) + parts = append(parts, field.Description) } + if len(field.Options) > 0 { + parts = append(parts, fmt.Sprintf("Options: `%s`", strings.Join(field.Options, "`, `"))) + } + ce.Reply(strings.Join(parts, "\n")) StoreCommandState(ce.User, &CommandState{ Next: MinimalCommandHandlerFunc(uilcs.submitNext), Action: "Login", From bade596e495e2aa28cb59350602ec6fe221dc6c6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 7 Nov 2025 14:33:00 +0100 Subject: [PATCH 1487/1647] bridgev2/portal: allow chaining ChatMembermap.Set calls --- bridgev2/portal.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 67199ada..8d846f43 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3705,11 +3705,12 @@ type ChatMember struct { type ChatMemberMap map[networkid.UserID]ChatMember // Set adds the given entry to this map, overwriting any existing entry with the same Sender field. -func (cmm ChatMemberMap) Set(member ChatMember) { +func (cmm ChatMemberMap) Set(member ChatMember) ChatMemberMap { if member.Sender == "" && member.SenderLogin == "" && !member.IsFromMe { - return + return cmm } cmm[member.Sender] = member + return cmm } // Add adds the given entry to this map, but will ignore it if an entry with the same Sender field already exists. From a973e5dc94c7fcaa43e6e181d04a539b005a5e28 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 8 Nov 2025 09:49:15 +0100 Subject: [PATCH 1488/1647] event/reply: only remove plaintext reply fallback if there is one in HTML --- event/reply.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/event/reply.go b/event/reply.go index 9ae1c110..5f55bb80 100644 --- a/event/reply.go +++ b/event/reply.go @@ -32,12 +32,13 @@ func TrimReplyFallbackText(text string) string { } func (content *MessageEventContent) RemoveReplyFallback() { - if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved { - if content.Format == FormatHTML { - content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody) + if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved && content.Format == FormatHTML { + origHTML := content.FormattedBody + content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody) + if content.FormattedBody != origHTML { + content.Body = TrimReplyFallbackText(content.Body) + content.replyFallbackRemoved = true } - content.Body = TrimReplyFallbackText(content.Body) - content.replyFallbackRemoved = true } } From fdd7632e53874af2546d43ea5c43614377fc02cf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 9 Nov 2025 11:33:39 +0200 Subject: [PATCH 1489/1647] bridgev2/matrix: avoid sending message status notices for m.notice events --- bridgev2/matrix/connector.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index edd98045..362f74aa 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -514,7 +514,8 @@ func (br *Connector) internalSendMessageStatus(ctx context.Context, ms *bridgev2 Msg("Failed to send MSS event") } } - if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) { + if ms.SendNotice && br.Config.Matrix.MessageErrorNotices && evt.MessageType != event.MsgNotice && + (ms.Status == event.MessageStatusFail || ms.Status == event.MessageStatusRetriable || ms.Step == status.MsgStepDecrypted) { content := ms.ToNoticeEvent(evt) if editEvent != "" { content.SetEdit(editEvent) From 14e16a3a8190e6e4e0600b7427e4dfaed90adb9c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 9 Nov 2025 11:40:10 +0200 Subject: [PATCH 1490/1647] bridgev2/matrix: drop events from users without permission earlier --- bridgev2/matrix/matrix.go | 10 ++++++++++ bridgev2/queue.go | 18 ++++++++++-------- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index 64165941..6c94bccc 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -27,6 +27,11 @@ func (br *Connector) handleRoomEvent(ctx context.Context, evt *event.Event) { if br.shouldIgnoreEvent(evt) { return } + if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents && evt.Type != event.StateMember { + zerolog.Ctx(ctx).Debug().Msg("Dropping event from user with no permission to send events") + br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt)) + return + } if (evt.Type == event.EventMessage || evt.Type == event.EventSticker) && !evt.Mautrix.WasEncrypted && br.Config.Encryption.Require { zerolog.Ctx(ctx).Warn().Msg("Dropping unencrypted event as encryption is configured to be required") br.sendCryptoStatusError(ctx, evt, errMessageNotEncrypted, nil, 0, true) @@ -76,6 +81,11 @@ func (br *Connector) handleEncryptedEvent(ctx context.Context, evt *event.Event) Str("event_id", evt.ID.String()). Str("session_id", content.SessionID.String()). Logger() + if !br.Config.Bridge.Permissions.Get(evt.Sender).SendEvents { + log.Debug().Msg("Dropping event from user with no permission to send events") + br.SendMessageStatus(ctx, &bridgev2.ErrNoPermissionToInteract, bridgev2.StatusEventInfoFromEvent(evt)) + return + } ctx = log.WithContext(ctx) if br.Crypto == nil { br.sendCryptoStatusError(ctx, evt, errNoCrypto, nil, 0, true) diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 95011cda..e1fb61c0 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -63,6 +63,12 @@ func (br *Bridge) rejectInviteOnNoPermission(ctx context.Context, evt *event.Eve return true } +var ( + ErrEventSenderUserNotFound = WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() + ErrNoPermissionToInteract = WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() + ErrNoPermissionForCommands = WrapErrorInStatus(WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()) +) + func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventHandlingResult { // TODO maybe HandleMatrixEvent would be more appropriate as this also handles bot invites and commands @@ -78,13 +84,11 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH return EventHandlingResultFailed } else if sender == nil { log.Error().Msg("Couldn't get sender for incoming non-ephemeral Matrix event") - status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() - br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt)) return EventHandlingResultFailed } else if !sender.Permissions.SendEvents { if !br.rejectInviteOnNoPermission(ctx, evt, "interact with") { - status := WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() - br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionToInteract, StatusEventInfoFromEvent(evt)) } return EventHandlingResultIgnored } else if !sender.Permissions.Commands && br.rejectInviteOnNoPermission(ctx, evt, "send commands to") { @@ -92,8 +96,7 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH } } else if evt.Type.Class != event.EphemeralEventType { log.Error().Msg("Missing sender for incoming non-ephemeral Matrix event") - status := WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() - br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + br.Matrix.SendMessageStatus(ctx, &ErrEventSenderUserNotFound, StatusEventInfoFromEvent(evt)) return EventHandlingResultIgnored } if evt.Type == event.EventMessage && sender != nil { @@ -102,8 +105,7 @@ func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventH msg.RemovePerMessageProfileFallback() if strings.HasPrefix(msg.Body, br.Config.CommandPrefix) || evt.RoomID == sender.ManagementRoom { if !sender.Permissions.Commands { - status := WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() - br.Matrix.SendMessageStatus(ctx, &status, StatusEventInfoFromEvent(evt)) + br.Matrix.SendMessageStatus(ctx, &ErrNoPermissionForCommands, StatusEventInfoFromEvent(evt)) return EventHandlingResultIgnored } go br.Commands.Handle( From 60cbe66e2f2877754662ad02462704b33b8d8ffa Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 9 Nov 2025 22:43:35 +0200 Subject: [PATCH 1491/1647] bridgev2/publicmedia: add support for custom path prefixes --- bridgev2/bridgeconfig/config.go | 3 ++- bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/matrix/connector.go | 2 +- bridgev2/matrix/mxmain/example-config.yaml | 3 +++ bridgev2/matrix/publicmedia.go | 16 ++++++++++------ 5 files changed, 17 insertions(+), 8 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 01819945..7d5ad46c 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -114,8 +114,9 @@ type DirectMediaConfig struct { type PublicMediaConfig struct { Enabled bool `yaml:"enabled"` SigningKey string `yaml:"signing_key"` - HashLength int `yaml:"hash_length"` Expiry int `yaml:"expiry"` + HashLength int `yaml:"hash_length"` + PathPrefix string `yaml:"path_prefix"` } type DoublePuppetConfig struct { diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index be8a8f96..1cec0f1e 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -133,6 +133,7 @@ func doUpgrade(helper up.Helper) { } helper.Copy(up.Int, "public_media", "expiry") helper.Copy(up.Int, "public_media", "hash_length") + helper.Copy(up.Str|up.Null, "public_media", "path_prefix") helper.Copy(up.Bool, "backfill", "enabled") helper.Copy(up.Int, "backfill", "max_initial_messages") diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 362f74aa..d81c34d2 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -275,7 +275,7 @@ func (br *Connector) GetPublicAddress() string { if br.Config.AppService.PublicAddress == "https://bridge.example.com" { return "" } - return br.Config.AppService.PublicAddress + return strings.TrimRight(br.Config.AppService.PublicAddress, "/") } func (br *Connector) GetRouter() *http.ServeMux { diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index aeb5b7db..59a307a0 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -277,6 +277,9 @@ public_media: expiry: 0 # Length of hash to use for public media URLs. Must be between 0 and 32. hash_length: 32 + # The path prefix for generated URLs. Note that this will NOT change the path where media is actually served. + # If you change this, you must configure your reverse proxy to rewrite the path accordingly. + path_prefix: /_mautrix/publicmedia # Settings for converting remote media to custom mxc:// URIs instead of reuploading. # More details can be found at https://docs.mau.fi/bridges/go/discord/direct-media.html diff --git a/bridgev2/matrix/publicmedia.go b/bridgev2/matrix/publicmedia.go index 95e37262..956a1eb7 100644 --- a/bridgev2/matrix/publicmedia.go +++ b/bridgev2/matrix/publicmedia.go @@ -14,6 +14,7 @@ import ( "fmt" "io" "net/http" + "strings" "time" "maunium.net/go/mautrix/bridgev2" @@ -115,11 +116,14 @@ func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) strin if err != nil || !parsed.IsValid() { return "" } - return fmt.Sprintf( - "%s/_mautrix/publicmedia/%s/%s/%s", - br.GetPublicAddress(), - parsed.Homeserver, - parsed.FileID, - base64.RawURLEncoding.EncodeToString(br.makePublicMediaChecksum(parsed)), + return strings.Join( + []string{ + br.GetPublicAddress(), + strings.Trim(br.Config.PublicMedia.PathPrefix, "/"), + parsed.Homeserver, + parsed.FileID, + base64.RawURLEncoding.EncodeToString(br.makePublicMediaChecksum(parsed)), + }, + "/", ) } From 2eea2e74125fae8c4f8848774fc22a272f9ef884 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 9 Nov 2025 23:02:23 +0200 Subject: [PATCH 1492/1647] bridgev2/publicmedia: add support for file name in content disposition --- bridgev2/matrix/publicmedia.go | 42 ++++++++++++++++++++++++++-------- bridgev2/matrixinterface.go | 1 + 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/bridgev2/matrix/publicmedia.go b/bridgev2/matrix/publicmedia.go index 956a1eb7..1212f5f9 100644 --- a/bridgev2/matrix/publicmedia.go +++ b/bridgev2/matrix/publicmedia.go @@ -13,7 +13,9 @@ import ( "encoding/binary" "fmt" "io" + "mime" "net/http" + "net/url" "strings" "time" @@ -35,6 +37,7 @@ func (br *Connector) initPublicMedia() error { } br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey) br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia) + br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}/{filename}", br.servePublicMedia) return nil } @@ -104,11 +107,24 @@ func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) { for _, hdr := range proxyHeadersToCopy { w.Header()[hdr] = resp.Header[hdr] } + if filename := r.PathValue("filename"); filename != "" { + contentDisposition, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Disposition")) + if contentDisposition == "" { + contentDisposition = "attachment" + } + w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, map[string]string{ + "filename": filename, + })) + } w.WriteHeader(http.StatusOK) _, _ = io.Copy(w, resp.Body) } func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) string { + return br.GetPublicMediaAddressWithFileName(contentURI, "") +} + +func (br *Connector) GetPublicMediaAddressWithFileName(contentURI id.ContentURIString, fileName string) string { if br.pubMediaSigKey == nil { return "" } @@ -116,14 +132,20 @@ func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) strin if err != nil || !parsed.IsValid() { return "" } - return strings.Join( - []string{ - br.GetPublicAddress(), - strings.Trim(br.Config.PublicMedia.PathPrefix, "/"), - parsed.Homeserver, - parsed.FileID, - base64.RawURLEncoding.EncodeToString(br.makePublicMediaChecksum(parsed)), - }, - "/", - ) + fileName = url.PathEscape(strings.ReplaceAll(fileName, "/", "_")) + if fileName == ".." { + fileName = "" + } + parts := []string{ + br.GetPublicAddress(), + strings.Trim(br.Config.PublicMedia.PathPrefix, "/"), + parsed.Homeserver, + parsed.FileID, + base64.RawURLEncoding.EncodeToString(br.makePublicMediaChecksum(parsed)), + fileName, + } + if fileName == "" { + parts = parts[:len(parts)-1] + } + return strings.Join(parts, "/") } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index e8489dc1..e388b6c2 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -69,6 +69,7 @@ type MatrixConnectorWithServer interface { type MatrixConnectorWithPublicMedia interface { GetPublicMediaAddress(contentURI id.ContentURIString) string + GetPublicMediaAddressWithFileName(contentURI id.ContentURIString, fileName string) string } type MatrixConnectorWithNameDisambiguation interface { From aa53cbc5285042228581d3add1dd68d1101b4d41 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 10 Nov 2025 00:11:39 +0200 Subject: [PATCH 1493/1647] bridgev2/publicmedia: add support for encrypted files --- bridgev2/bridgeconfig/config.go | 11 +- bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/database/database.go | 7 + bridgev2/database/publicmedia.go | 72 +++++++++ bridgev2/database/upgrades/00-latest.sql | 11 ++ .../database/upgrades/24-public-media.sql | 11 ++ bridgev2/errors.go | 4 + bridgev2/matrix/mxmain/example-config.yaml | 5 + bridgev2/matrix/publicmedia.go | 137 +++++++++++++++++- bridgev2/matrixinterface.go | 2 +- 10 files changed, 250 insertions(+), 11 deletions(-) create mode 100644 bridgev2/database/publicmedia.go create mode 100644 bridgev2/database/upgrades/24-public-media.sql diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 7d5ad46c..1bf4dfcc 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -112,11 +112,12 @@ type DirectMediaConfig struct { } type PublicMediaConfig struct { - Enabled bool `yaml:"enabled"` - SigningKey string `yaml:"signing_key"` - Expiry int `yaml:"expiry"` - HashLength int `yaml:"hash_length"` - PathPrefix string `yaml:"path_prefix"` + Enabled bool `yaml:"enabled"` + SigningKey string `yaml:"signing_key"` + Expiry int `yaml:"expiry"` + HashLength int `yaml:"hash_length"` + PathPrefix string `yaml:"path_prefix"` + UseDatabase bool `yaml:"use_database"` } type DoublePuppetConfig struct { diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 1cec0f1e..8a9b6f4b 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -134,6 +134,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Int, "public_media", "expiry") helper.Copy(up.Int, "public_media", "hash_length") helper.Copy(up.Str|up.Null, "public_media", "path_prefix") + helper.Copy(up.Bool, "public_media", "use_database") helper.Copy(up.Bool, "backfill", "enabled") helper.Copy(up.Int, "backfill", "max_initial_messages") diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go index f1789441..0729cb83 100644 --- a/bridgev2/database/database.go +++ b/bridgev2/database/database.go @@ -34,6 +34,7 @@ type Database struct { UserPortal *UserPortalQuery BackfillTask *BackfillTaskQuery KV *KVQuery + PublicMedia *PublicMediaQuery } type MetaMerger interface { @@ -141,6 +142,12 @@ func New(bridgeID networkid.BridgeID, mt MetaTypes, db *dbutil.Database) *Databa BridgeID: bridgeID, Database: db, }, + PublicMedia: &PublicMediaQuery{ + BridgeID: bridgeID, + QueryHelper: dbutil.MakeQueryHelper(db, func(_ *dbutil.QueryHelper[*PublicMedia]) *PublicMedia { + return &PublicMedia{} + }), + }, } } diff --git a/bridgev2/database/publicmedia.go b/bridgev2/database/publicmedia.go new file mode 100644 index 00000000..b667399c --- /dev/null +++ b/bridgev2/database/publicmedia.go @@ -0,0 +1,72 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package database + +import ( + "context" + "database/sql" + "time" + + "go.mau.fi/util/dbutil" + + "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/crypto/attachment" + "maunium.net/go/mautrix/id" +) + +type PublicMediaQuery struct { + BridgeID networkid.BridgeID + *dbutil.QueryHelper[*PublicMedia] +} + +type PublicMedia struct { + BridgeID networkid.BridgeID + PublicID string + MXC id.ContentURI + Keys *attachment.EncryptedFile + MimeType string + Expiry time.Time +} + +const ( + upsertPublicMediaQuery = ` + INSERT INTO public_media (bridge_id, public_id, mxc, keys, mimetype, expiry) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (bridge_id, public_id) DO UPDATE SET expiry=EXCLUDED.expiry + ` + getPublicMediaQuery = ` + SELECT bridge_id, public_id, mxc, keys, mimetype, expiry + FROM public_media WHERE bridge_id=$1 AND public_id=$2 + ` +) + +func (pmq *PublicMediaQuery) Put(ctx context.Context, pm *PublicMedia) error { + ensureBridgeIDMatches(&pm.BridgeID, pmq.BridgeID) + return pmq.Exec(ctx, upsertPublicMediaQuery, pm.sqlVariables()...) +} + +func (pmq *PublicMediaQuery) Get(ctx context.Context, publicID string) (*PublicMedia, error) { + return pmq.QueryOne(ctx, getPublicMediaQuery, pmq.BridgeID, publicID) +} + +func (pm *PublicMedia) Scan(row dbutil.Scannable) (*PublicMedia, error) { + var expiry sql.NullInt64 + var mimetype sql.NullString + err := row.Scan(&pm.BridgeID, &pm.PublicID, &pm.MXC, dbutil.JSON{Data: &pm.Keys}, &mimetype, &expiry) + if err != nil { + return nil, err + } + if expiry.Valid { + pm.Expiry = time.Unix(0, expiry.Int64) + } + pm.MimeType = mimetype.String + return pm, nil +} + +func (pm *PublicMedia) sqlVariables() []any { + return []any{pm.BridgeID, pm.PublicID, &pm.MXC, dbutil.JSONPtr(pm.Keys), dbutil.StrPtr(pm.MimeType), dbutil.ConvertedPtr(pm.Expiry, time.Time.UnixNano)} +} diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index a8bb5c64..786ef5ff 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -216,3 +216,14 @@ CREATE TABLE kv_store ( PRIMARY KEY (bridge_id, key) ); + +CREATE TABLE public_media ( + bridge_id TEXT NOT NULL, + public_id TEXT NOT NULL, + mxc TEXT NOT NULL, + keys jsonb, + mimetype TEXT, + expiry BIGINT, + + PRIMARY KEY (bridge_id, public_id) +); diff --git a/bridgev2/database/upgrades/24-public-media.sql b/bridgev2/database/upgrades/24-public-media.sql new file mode 100644 index 00000000..c4290090 --- /dev/null +++ b/bridgev2/database/upgrades/24-public-media.sql @@ -0,0 +1,11 @@ +-- v24 (compatible with v9+): Custom URLs for public media +CREATE TABLE public_media ( + bridge_id TEXT NOT NULL, + public_id TEXT NOT NULL, + mxc TEXT NOT NULL, + keys jsonb, + mimetype TEXT, + expiry BIGINT, + + PRIMARY KEY (bridge_id, public_id) +); diff --git a/bridgev2/errors.go b/bridgev2/errors.go index a06f30ed..ae13086d 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -75,6 +75,10 @@ var ( ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) + ErrPublicMediaDisabled = WrapErrorInStatus(errors.New("public media is not enabled in the bridge config")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrPublicMediaDatabaseDisabled = WrapErrorInStatus(errors.New("public media database storage is disabled")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrPublicMediaGenerateFailed = WrapErrorInStatus(errors.New("failed to generate public media URL")).WithIsCertain(true).WithMessage("failed to generate public media URL").WithErrorReason(event.MessageStatusUnsupported) + ErrDisappearingTimerUnsupported error = WrapErrorInStatus(errors.New("invalid disappearing timer")).WithIsCertain(true) ) diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 59a307a0..60d41772 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -280,6 +280,11 @@ public_media: # The path prefix for generated URLs. Note that this will NOT change the path where media is actually served. # If you change this, you must configure your reverse proxy to rewrite the path accordingly. path_prefix: /_mautrix/publicmedia + # Should the bridge store media metadata in the database in order to support encrypted media and generate shorter URLs? + # If false, the generated URLs will just have the MXC URI and a HMAC signature. + # The hash_length field will be used to decide the length of the generated URL. + # This also allows invalidating URLs by deleting the database entry. + use_database: false # Settings for converting remote media to custom mxc:// URIs instead of reuploading. # More details can be found at https://docs.mau.fi/bridges/go/discord/direct-media.html diff --git a/bridgev2/matrix/publicmedia.go b/bridgev2/matrix/publicmedia.go index 1212f5f9..82ea8c2b 100644 --- a/bridgev2/matrix/publicmedia.go +++ b/bridgev2/matrix/publicmedia.go @@ -7,6 +7,7 @@ package matrix import ( + "context" "crypto/hmac" "crypto/sha256" "encoding/base64" @@ -16,10 +17,16 @@ import ( "mime" "net/http" "net/url" + "slices" "strings" "time" + "github.com/rs/zerolog" + "maunium.net/go/mautrix/bridgev2" + "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/crypto/attachment" + "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -36,6 +43,8 @@ func (br *Connector) initPublicMedia() error { return fmt.Errorf("public media hash length is negative") } br.pubMediaSigKey = []byte(br.Config.PublicMedia.SigningKey) + br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}", br.serveDatabasePublicMedia) + br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{customID}/{filename}", br.serveDatabasePublicMedia) br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}", br.servePublicMedia) br.AS.Router.HandleFunc("GET /_mautrix/publicmedia/{server}/{mediaID}/{checksum}/{filename}", br.servePublicMedia) return nil @@ -48,6 +57,20 @@ func (br *Connector) hashContentURI(uri id.ContentURI, expiry []byte) []byte { return hasher.Sum(expiry)[:br.Config.PublicMedia.HashLength+len(expiry)] } +func (br *Connector) hashDBPublicMedia(pm *database.PublicMedia) []byte { + hasher := hmac.New(sha256.New, br.pubMediaSigKey) + hasher.Write([]byte(pm.MXC.String())) + hasher.Write([]byte(pm.MimeType)) + if pm.Keys != nil { + hasher.Write([]byte(pm.Keys.Version)) + hasher.Write([]byte(pm.Keys.Key.Algorithm)) + hasher.Write([]byte(pm.Keys.Key.Key)) + hasher.Write([]byte(pm.Keys.InitVector)) + hasher.Write([]byte(pm.Keys.Hashes.SHA256)) + } + return hasher.Sum(nil)[:br.Config.PublicMedia.HashLength] +} + func (br *Connector) makePublicMediaChecksum(uri id.ContentURI) []byte { var expiresAt []byte if br.Config.PublicMedia.Expiry > 0 { @@ -97,9 +120,47 @@ func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) { http.Error(w, "checksum expired", http.StatusGone) return } + br.doProxyMedia(w, r, contentURI, nil, "") +} + +func (br *Connector) serveDatabasePublicMedia(w http.ResponseWriter, r *http.Request) { + if !br.Config.PublicMedia.UseDatabase { + http.Error(w, "public media short links are disabled", http.StatusNotFound) + return + } + log := zerolog.Ctx(r.Context()) + media, err := br.Bridge.DB.PublicMedia.Get(r.Context(), r.PathValue("customID")) + if err != nil { + log.Err(err).Msg("Failed to get public media from database") + http.Error(w, "failed to get media metadata", http.StatusInternalServerError) + return + } else if media == nil { + http.Error(w, "media ID not found", http.StatusNotFound) + return + } else if !media.Expiry.IsZero() && media.Expiry.Before(time.Now()) { + // This is not gone as it can still be refreshed in the DB + http.Error(w, "media expired", http.StatusNotFound) + return + } else if media.Keys != nil && media.Keys.PrepareForDecryption() != nil { + http.Error(w, "media keys are malformed", http.StatusInternalServerError) + return + } + br.doProxyMedia(w, r, media.MXC, media.Keys, media.MimeType) +} + +var safeMimes = []string{ + "text/css", "text/plain", "text/csv", + "application/json", "application/ld+json", + "image/jpeg", "image/gif", "image/png", "image/apng", "image/webp", "image/avif", + "video/mp4", "video/webm", "video/ogg", "video/quicktime", + "audio/mp4", "audio/webm", "audio/aac", "audio/mpeg", "audio/ogg", "audio/wave", + "audio/wav", "audio/x-wav", "audio/x-pn-wav", "audio/flac", "audio/x-flac", +} + +func (br *Connector) doProxyMedia(w http.ResponseWriter, r *http.Request, contentURI id.ContentURI, encInfo *attachment.EncryptedFile, mimeType string) { resp, err := br.Bot.Download(r.Context(), contentURI) if err != nil { - br.Log.Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy") + zerolog.Ctx(r.Context()).Warn().Stringer("uri", contentURI).Err(err).Msg("Failed to download media to proxy") http.Error(w, "failed to download media", http.StatusInternalServerError) return } @@ -107,7 +168,24 @@ func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) { for _, hdr := range proxyHeadersToCopy { w.Header()[hdr] = resp.Header[hdr] } - if filename := r.PathValue("filename"); filename != "" { + stream := resp.Body + if encInfo != nil { + if mimeType == "" { + mimeType = "application/octet-stream" + } + contentDisposition := "attachment" + if slices.Contains(safeMimes, mimeType) { + contentDisposition = "inline" + } + dispositionArgs := map[string]string{} + if filename := r.PathValue("filename"); filename != "" { + dispositionArgs["filename"] = filename + } + w.Header().Set("Content-Type", mimeType) + w.Header().Set("Content-Disposition", mime.FormatMediaType(contentDisposition, dispositionArgs)) + // Note: this won't check the Close result like it should, but it's probably not a big deal here + stream = encInfo.DecryptStream(stream) + } else if filename := r.PathValue("filename"); filename != "" { contentDisposition, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Disposition")) if contentDisposition == "" { contentDisposition = "attachment" @@ -117,14 +195,14 @@ func (br *Connector) servePublicMedia(w http.ResponseWriter, r *http.Request) { })) } w.WriteHeader(http.StatusOK) - _, _ = io.Copy(w, resp.Body) + _, _ = io.Copy(w, stream) } func (br *Connector) GetPublicMediaAddress(contentURI id.ContentURIString) string { - return br.GetPublicMediaAddressWithFileName(contentURI, "") + return br.getPublicMediaAddressWithFileName(contentURI, "") } -func (br *Connector) GetPublicMediaAddressWithFileName(contentURI id.ContentURIString, fileName string) string { +func (br *Connector) getPublicMediaAddressWithFileName(contentURI id.ContentURIString, fileName string) string { if br.pubMediaSigKey == nil { return "" } @@ -149,3 +227,52 @@ func (br *Connector) GetPublicMediaAddressWithFileName(contentURI id.ContentURIS } return strings.Join(parts, "/") } + +func (br *Connector) GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) { + if br.pubMediaSigKey == nil { + return "", bridgev2.ErrPublicMediaDisabled + } + if !br.Config.PublicMedia.UseDatabase { + if evt.File != nil { + return "", fmt.Errorf("can't generate address for encrypted file: %w", bridgev2.ErrPublicMediaDatabaseDisabled) + } + return br.getPublicMediaAddressWithFileName(evt.URL, evt.GetFileName()), nil + } + mxc := evt.URL + var keys *attachment.EncryptedFile + if evt.File != nil { + mxc = evt.File.URL + keys = &evt.File.EncryptedFile + } + parsedMXC, err := mxc.Parse() + if err != nil { + return "", fmt.Errorf("%w: failed to parse MXC: %w", bridgev2.ErrPublicMediaGenerateFailed, err) + } + pm := &database.PublicMedia{ + MXC: parsedMXC, + Keys: keys, + MimeType: evt.GetInfo().MimeType, + } + if br.Config.PublicMedia.Expiry > 0 { + pm.Expiry = time.Now().Add(time.Duration(br.Config.PublicMedia.Expiry) * time.Second) + } + pm.PublicID = base64.RawURLEncoding.EncodeToString(br.hashDBPublicMedia(pm)) + err = br.Bridge.DB.PublicMedia.Put(ctx, pm) + if err != nil { + return "", fmt.Errorf("%w: failed to store public media in database: %w", bridgev2.ErrPublicMediaGenerateFailed, err) + } + fileName := url.PathEscape(strings.ReplaceAll(evt.GetFileName(), "/", "_")) + if fileName == ".." { + fileName = "" + } + parts := []string{ + br.GetPublicAddress(), + strings.Trim(br.Config.PublicMedia.PathPrefix, "/"), + pm.PublicID, + fileName, + } + if fileName == "" { + parts = parts[:len(parts)-1] + } + return strings.Join(parts, "/"), nil +} diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index e388b6c2..07615daf 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -69,7 +69,7 @@ type MatrixConnectorWithServer interface { type MatrixConnectorWithPublicMedia interface { GetPublicMediaAddress(contentURI id.ContentURIString) string - GetPublicMediaAddressWithFileName(contentURI id.ContentURIString, fileName string) string + GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) } type MatrixConnectorWithNameDisambiguation interface { From 1779c723168a9a10179d2f871a9b83b1b5be26c3 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 4 Nov 2025 16:45:23 +0100 Subject: [PATCH 1494/1647] bridgev2: pass back event ID and stream order in send results --- bridgev2/portal.go | 4 ++-- bridgev2/queue.go | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 8d846f43..3319f874 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1227,7 +1227,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin // Not exactly queued, but not finished either return EventHandlingResultQueued } - return EventHandlingResultSuccess + return EventHandlingResultSuccess.WithEventID(message.MXID).WithStreamOrder(resp.StreamOrder) } // AddPendingToIgnore adds a transaction ID that should be ignored if encountered as a new message. @@ -1551,7 +1551,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi log.Err(err).Msg("Failed to save reaction to database") } portal.sendSuccessStatus(ctx, evt, 0, deterministicID) - return EventHandlingResultSuccess + return EventHandlingResultSuccess.WithEventID(deterministicID) } func handleMatrixRoomMeta[APIType any, ContentType any]( diff --git a/bridgev2/queue.go b/bridgev2/queue.go index e1fb61c0..308d03c5 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -163,6 +163,21 @@ type EventHandlingResult struct { Error error // Whether the Error should be sent as a MSS event. SendMSS bool + + // EventID from the network + EventID id.EventID + // Stream order from the network + StreamOrder int64 +} + +func (ehr EventHandlingResult) WithEventID(id id.EventID) EventHandlingResult { + ehr.EventID = id + return ehr +} + +func (ehr EventHandlingResult) WithStreamOrder(order int64) EventHandlingResult { + ehr.StreamOrder = order + return ehr } func (ehr EventHandlingResult) WithError(err error) EventHandlingResult { From 913a28fdce79c1e250c578bc554eafd9d672a021 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 4 Nov 2025 16:45:23 +0100 Subject: [PATCH 1495/1647] bridgev2: pass back event ID and stream order in send results --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 3319f874..955fd401 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1469,7 +1469,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if existing.EmojiID != "" || existing.Emoji == preResp.Emoji { log.Debug().Msg("Ignoring duplicate reaction") portal.sendSuccessStatus(ctx, evt, 0, deterministicID) - return EventHandlingResultIgnored + return EventHandlingResultIgnored.WithEventID(deterministicID) } react.ReactionToOverride = existing _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ From 77519b6de74742075ad3e4138ab74526b7621b33 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 10 Nov 2025 00:17:39 +0200 Subject: [PATCH 1496/1647] bridgev2/errors: send notice for public media errors --- bridgev2/errors.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index ae13086d..e81b8953 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -75,9 +75,9 @@ var ( ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) - ErrPublicMediaDisabled = WrapErrorInStatus(errors.New("public media is not enabled in the bridge config")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrPublicMediaDatabaseDisabled = WrapErrorInStatus(errors.New("public media database storage is disabled")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrPublicMediaGenerateFailed = WrapErrorInStatus(errors.New("failed to generate public media URL")).WithIsCertain(true).WithMessage("failed to generate public media URL").WithErrorReason(event.MessageStatusUnsupported) + ErrPublicMediaDisabled = WrapErrorInStatus(errors.New("public media is not enabled in the bridge config")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) + ErrPublicMediaDatabaseDisabled = WrapErrorInStatus(errors.New("public media database storage is disabled")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) + ErrPublicMediaGenerateFailed = WrapErrorInStatus(errors.New("failed to generate public media URL")).WithIsCertain(true).WithMessage("failed to generate public media URL").WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) ErrDisappearingTimerUnsupported error = WrapErrorInStatus(errors.New("invalid disappearing timer")).WithIsCertain(true) ) From bb0b26a58bbdcc1157b9d8cffbc474674f3ad480 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 10 Nov 2025 23:34:28 +0200 Subject: [PATCH 1497/1647] bridgev2/database: fix latest version --- bridgev2/database/upgrades/00-latest.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index 786ef5ff..efde8816 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v23 (compatible with v9+): Latest revision +-- v0 -> v24 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, From 19ed3ac40b850e916e8ad0b2688b21a5c70d7183 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 11 Nov 2025 01:32:27 +0200 Subject: [PATCH 1498/1647] changelog: update --- CHANGELOG.md | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f59e6853..7ee4a13d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,43 @@ +## v0.26.0 (unreleased) + +* *(client,federation)* Added size limits for responses to make it safer to send + requests to untrusted servers. +* *(client)* Added wrapper for `/admin/whois` client API + (thanks to [@nexy7574] in [#411]). +* *(synapseadmin)* Added `force_purge` option to DeleteRoom + (thanks to [@nexy7574] in [#420]). +* *(bridgev2)* Added optional automatic rollback of room state if bridging the + change to the remote network fails. +* *(bridgev2)* Added management room notices if transient disconnect state + doesn't resolve within 3 minutes. +* *(bridgev2)* Added interface to signal that certain participants couldn't be + invited when creating a group. +* *(bridgev2)* Added `select` type for user input fields in login. +* *(bridgev2/matrix)* Added checks to avoid sending error messages in reply to + other bots. +* *(bridgev2/publicmedia)* Added support for custom path prefixes, file names, + and encrypted files. +* *(bridgev2/commands)* Added command to resync a single portal. +* *(bridgev2/commands)* Added create group command. +* *(bridgev2/config)* Added option to limit maximum number of logins. +* *(bridgev2/disappear)* Changed read receipt handling to only start + disappearing timers for messages up to the read message (note: may not work in + all cases if the read receipt points at an unknown event). +* *(event/reply)* Changed plaintext reply fallback removal to only happen when + an HTML reply fallback is removed successfully. +* *(bridgev2/matrix)* Fixed unnecessary sleep after registering bot on first run. +* *(crypto/goolm)* Fixed panic when processing certain malformed Olm messages. +* *(federation)* Fixed HTTP method for sending transactions + (thanks to [@nexy7574] in [#426]). +* *(federation)* Fixed response body being closed even when using `DontReadBody` + parameter. +* *(federation)* Fixed validating auth for requests with query params. +* *(federation/eventauth)* Fixed typo causing restricted joins to not work. + +[#411]: github.com/mautrix/go/pull/411 +[#420]: github.com/mautrix/go/pull/420 +[#426]: github.com/mautrix/go/pull/426 + ## v0.25.2 (2025-10-16) * **Breaking change *(id)*** Split `UserID.ParseAndValidate` into From 7b33248d3dd019c340334781ec608a877c5f2ccc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Nov 2025 01:53:31 +0200 Subject: [PATCH 1499/1647] bridgev2: add flag to indicate when bridge is stopping --- bridgev2/bridge.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 2ad6a614..c84c2fd5 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -11,6 +11,7 @@ import ( "fmt" "os" "sync" + "sync/atomic" "time" "github.com/rs/zerolog" @@ -52,6 +53,7 @@ type Bridge struct { Background bool ExternallyManagedDB bool + stopping atomic.Bool wakeupBackfillQueue chan struct{} stopBackfillQueue *exsync.Event @@ -127,6 +129,7 @@ func (br *Bridge) Start(ctx context.Context) error { func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, params *ConnectBackgroundParams) error { br.Background = true + br.stopping.Store(false) err := br.StartConnectors(ctx) if err != nil { return err @@ -162,6 +165,7 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa case <-time.After(20 * time.Second): case <-ctx.Done(): } + br.stopping.Store(true) return nil } else { br.Log.Info().Str("user_login_id", string(login.ID)).Msg("Starting individual user login in background mode") @@ -171,6 +175,7 @@ func (br *Bridge) RunOnce(ctx context.Context, loginID networkid.UserLoginID, pa func (br *Bridge) StartConnectors(ctx context.Context) error { br.Log.Info().Msg("Starting bridge") + br.stopping.Store(false) if br.BackgroundCtx == nil || br.BackgroundCtx.Err() != nil { br.BackgroundCtx, br.cancelBackgroundCtx = context.WithCancel(context.Background()) br.BackgroundCtx = br.Log.WithContext(br.BackgroundCtx) @@ -368,6 +373,10 @@ func (br *Bridge) StartLogins(ctx context.Context) error { return nil } +func (br *Bridge) IsStopping() bool { + return br.stopping.Load() +} + func (br *Bridge) Stop() { br.stop(false, 0) } @@ -378,6 +387,7 @@ func (br *Bridge) StopWithTimeout(timeout time.Duration) { func (br *Bridge) stop(isRunOnce bool, timeout time.Duration) { br.Log.Info().Msg("Shutting down bridge") + br.stopping.Store(true) br.DisappearLoop.Stop() br.stopBackfillQueue.Set() br.Matrix.PreStop() From 4913b123f19b0534a936748e6ce921419bfb9994 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Nov 2025 14:57:18 +0200 Subject: [PATCH 1500/1647] bridgev2/space: let network connector customize personal filtering space --- bridgev2/networkinterface.go | 6 ++++++ bridgev2/space.go | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 8a39c7f8..9bbcf897 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -18,6 +18,7 @@ import ( "go.mau.fi/util/ptr" "go.mau.fi/util/random" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" @@ -294,6 +295,11 @@ type PortalBridgeInfoFillingNetwork interface { FillPortalBridgeInfo(portal *Portal, content *event.BridgeEventContent) } +type PersonalFilteringCustomizingNetworkAPI interface { + NetworkAPI + CustomizePersonalFilteringSpace(req *mautrix.ReqCreateRoom) +} + // ConfigValidatingNetwork is an optional interface that network connectors can implement to validate config fields // before the bridge is started. // diff --git a/bridgev2/space.go b/bridgev2/space.go index ae9013cb..f6d07922 100644 --- a/bridgev2/space.go +++ b/bridgev2/space.go @@ -172,6 +172,10 @@ func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) { // TODO remove this after initial_members is supported in hungryserv req.BeeperAutoJoinInvites = true } + pfc, ok := ul.Client.(PersonalFilteringCustomizingNetworkAPI) + if ok { + pfc.CustomizePersonalFilteringSpace(req) + } ul.SpaceRoom, err = ul.Bridge.Bot.CreateRoom(ctx, req) if err != nil { return "", fmt.Errorf("failed to create space room: %w", err) From 8b70baa3360b203c59e7d7c2b81151001f30365a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Nov 2025 15:34:25 +0200 Subject: [PATCH 1501/1647] bridgev2/commands: add support for ResolveIdentifierTryNext in pm command --- bridgev2/commands/startchat.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index 99924851..24586387 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -8,6 +8,7 @@ package commands import ( "context" + "errors" "fmt" "html" "maps" @@ -118,9 +119,13 @@ func fnResolveIdentifier(ce *Event) { if api == nil { return } + allLogins := ce.User.GetUserLogins() createChat := ce.Command == "start-chat" || ce.Command == "pm" identifier := strings.Join(identifierParts, " ") resp, err := provisionutil.ResolveIdentifier(ce.Ctx, login, identifier, createChat) + for i := 0; i < len(allLogins) && errors.Is(err, bridgev2.ErrResolveIdentifierTryNext); i++ { + resp, err = provisionutil.ResolveIdentifier(ce.Ctx, allLogins[i], identifier, createChat) + } if err != nil { ce.Reply("Failed to resolve identifier: %v", err) return From 981addddc91c38970f85ca886e8bc2bdeb550a36 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Nov 2025 19:38:08 +0200 Subject: [PATCH 1502/1647] bridgev2/config: add option to disable kicking matrix users --- bridgev2/bridgeconfig/config.go | 1 + bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/matrix/mxmain/example-config.yaml | 3 +++ bridgev2/portal.go | 2 +- 4 files changed, 6 insertions(+), 1 deletion(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 1bf4dfcc..b1718f30 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -78,6 +78,7 @@ type BridgeConfig struct { CrossRoomReplies bool `yaml:"cross_room_replies"` OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` RevertFailedStateChanges bool `yaml:"revert_failed_state_changes"` + KickMatrixUsers bool `yaml:"kick_matrix_users"` CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` Relay RelayConfig `yaml:"relay"` Permissions PermissionConfig `yaml:"permissions"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 8a9b6f4b..0dbff802 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -41,6 +41,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "deduplicate_matrix_messages") helper.Copy(up.Bool, "bridge", "cross_room_replies") helper.Copy(up.Bool, "bridge", "revert_failed_state_changes") + helper.Copy(up.Bool, "bridge", "kick_matrix_users") helper.Copy(up.Bool, "bridge", "cleanup_on_logout", "enabled") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "private") helper.Copy(up.Str, "bridge", "cleanup_on_logout", "manual", "relayed") diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 60d41772..27c3aa67 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -49,6 +49,9 @@ bridge: cross_room_replies: false # If a state event fails to bridge, should the bridge revert any state changes made by that event? revert_failed_state_changes: false + # In portals with no relay set, should Matrix users be kicked if they're + # not logged into an account that's in the remote chat? + kick_matrix_users: true # What should be done to portal rooms when a user logs out or is logged out? # Permitted values: diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 955fd401..c2d87d4e 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -4411,7 +4411,7 @@ func (portal *Portal) syncParticipants( if memberEvt.Membership == event.MembershipLeave || memberEvt.Membership == event.MembershipBan { continue } - if !portal.Bridge.IsGhostMXID(extraMember) && portal.Relay != nil { + if !portal.Bridge.IsGhostMXID(extraMember) && (portal.Relay != nil || !portal.Bridge.Config.KickMatrixUsers) { continue } _, err = portal.Bridge.Bot.SendState(ctx, portal.MXID, event.StateMember, extraMember.String(), &event.Content{ From e31d186dc8d3813b171564404cf3e7859d800748 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Nov 2025 21:44:04 +0200 Subject: [PATCH 1503/1647] statestore: save join rules for rooms --- sqlstatestore/statestore.go | 23 +++++++++++++++++++++++ sqlstatestore/v00-latest-revision.sql | 3 ++- sqlstatestore/v10-join-rules.sql | 2 ++ statestore.go | 20 ++++++++++++++++++++ 4 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 sqlstatestore/v10-join-rules.sql diff --git a/sqlstatestore/statestore.go b/sqlstatestore/statestore.go index c4126802..11957dfa 100644 --- a/sqlstatestore/statestore.go +++ b/sqlstatestore/statestore.go @@ -470,3 +470,26 @@ func (store *SQLStateStore) GetCreate(ctx context.Context, roomID id.RoomID) (ev } return } + +func (store *SQLStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, rules *event.JoinRulesEventContent) error { + if roomID == "" { + return fmt.Errorf("room ID is empty") + } + _, err := store.Exec(ctx, ` + INSERT INTO mx_room_state (room_id, join_rules) VALUES ($1, $2) + ON CONFLICT (room_id) DO UPDATE SET join_rules=excluded.join_rules + `, roomID, dbutil.JSON{Data: rules}) + return err +} + +func (store *SQLStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (levels *event.JoinRulesEventContent, err error) { + levels = &event.JoinRulesEventContent{} + err = store. + QueryRow(ctx, "SELECT join_rules FROM mx_room_state WHERE room_id=$1 AND join_rules IS NOT NULL", roomID). + Scan(&dbutil.JSON{Data: &levels}) + if errors.Is(err, sql.ErrNoRows) { + levels = nil + err = nil + } + return +} diff --git a/sqlstatestore/v00-latest-revision.sql b/sqlstatestore/v00-latest-revision.sql index b5a858ec..4679f1c6 100644 --- a/sqlstatestore/v00-latest-revision.sql +++ b/sqlstatestore/v00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v9 (compatible with v3+): Latest revision +-- v0 -> v10 (compatible with v3+): Latest revision CREATE TABLE mx_registrations ( user_id TEXT PRIMARY KEY @@ -27,5 +27,6 @@ CREATE TABLE mx_room_state ( power_levels jsonb, encryption jsonb, create_event jsonb, + join_rules jsonb, members_fetched BOOLEAN NOT NULL DEFAULT false ); diff --git a/sqlstatestore/v10-join-rules.sql b/sqlstatestore/v10-join-rules.sql new file mode 100644 index 00000000..3074c46a --- /dev/null +++ b/sqlstatestore/v10-join-rules.sql @@ -0,0 +1,2 @@ +-- v10 (compatible with v3+): Add join rules to room state table +ALTER TABLE mx_room_state ADD COLUMN join_rules jsonb; diff --git a/statestore.go b/statestore.go index 1933ab95..c6267c5b 100644 --- a/statestore.go +++ b/statestore.go @@ -37,6 +37,9 @@ type StateStore interface { SetCreate(ctx context.Context, evt *event.Event) error GetCreate(ctx context.Context, roomID id.RoomID) (*event.Event, error) + GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error) + SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error + HasFetchedMembers(ctx context.Context, roomID id.RoomID) (bool, error) MarkMembersFetched(ctx context.Context, roomID id.RoomID) error GetAllMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) @@ -73,6 +76,8 @@ func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) { err = store.SetEncryptionEvent(ctx, evt.RoomID, content) case *event.CreateEventContent: err = store.SetCreate(ctx, evt) + case *event.JoinRulesEventContent: + err = store.SetJoinRules(ctx, evt.RoomID, content) default: switch evt.Type { case event.StateMember, event.StatePowerLevels, event.StateEncryption, event.StateCreate: @@ -107,11 +112,13 @@ type MemoryStateStore struct { PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"` Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"` Create map[id.RoomID]*event.Event `json:"create"` + JoinRules map[id.RoomID]*event.JoinRulesEventContent `json:"join_rules"` registrationsLock sync.RWMutex membersLock sync.RWMutex powerLevelsLock sync.RWMutex encryptionLock sync.RWMutex + joinRulesLock sync.RWMutex } func NewMemoryStateStore() StateStore { @@ -354,6 +361,19 @@ func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.R return store.Encryption[roomID], nil } +func (store *MemoryStateStore) SetJoinRules(ctx context.Context, roomID id.RoomID, content *event.JoinRulesEventContent) error { + store.joinRulesLock.Lock() + store.JoinRules[roomID] = content + store.joinRulesLock.Unlock() + return nil +} + +func (store *MemoryStateStore) GetJoinRules(ctx context.Context, roomID id.RoomID) (*event.JoinRulesEventContent, error) { + store.joinRulesLock.RLock() + defer store.joinRulesLock.RUnlock() + return store.JoinRules[roomID], nil +} + func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) { cfg, err := store.GetEncryptionEvent(ctx, roomID) return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err From 6c7828afe37073991cad089d85efa1016c94dc94 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Nov 2025 21:44:23 +0200 Subject: [PATCH 1504/1647] bridgev2/portal: skip invite step if room is public --- bridgev2/matrix/connector.go | 24 ++++++++++++++++++++---- bridgev2/portal.go | 11 ++++++++++- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index d81c34d2..3e05837f 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -26,6 +26,7 @@ import ( _ "go.mau.fi/util/dbutil/litestream" "go.mau.fi/util/exbytes" "go.mau.fi/util/exsync" + "go.mau.fi/util/ptr" "go.mau.fi/util/random" "golang.org/x/sync/semaphore" @@ -599,10 +600,25 @@ func (br *Connector) GetPowerLevels(ctx context.Context, roomID id.RoomID) (*eve } func (br *Connector) GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) { - if eventType == event.StateCreate && stateKey == "" { - createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID) - if err != nil || createEvt != nil { - return createEvt, err + if stateKey == "" { + switch eventType { + case event.StateCreate: + createEvt, err := br.Bot.StateStore.GetCreate(ctx, roomID) + if err != nil || createEvt != nil { + return createEvt, err + } + case event.StateJoinRules: + joinRulesContent, err := br.Bot.StateStore.GetJoinRules(ctx, roomID) + if err != nil { + return nil, err + } else if joinRulesContent != nil { + return &event.Event{ + Type: event.StateJoinRules, + RoomID: roomID, + StateKey: ptr.Ptr(""), + Content: event.Content{Parsed: joinRulesContent}, + }, nil + } } } return br.Bot.FullStateEvent(ctx, roomID, eventType, "") diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c2d87d4e..344ca807 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -4236,6 +4236,15 @@ func (portal *Portal) updateOtherUser(ctx context.Context, members *ChatMemberLi return false } +func (portal *Portal) roomIsPublic(ctx context.Context) bool { + evt, err := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState).GetStateEvent(ctx, portal.MXID, event.StateJoinRules, "") + if err != nil { + zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get join rules to check if room is public") + return false + } + return evt != nil && evt.Content.AsJoinRules().JoinRule == event.JoinRulePublic +} + func (portal *Portal) syncParticipants( ctx context.Context, members *ChatMemberList, @@ -4304,7 +4313,7 @@ func (portal *Portal) syncParticipants( wrappedContent := &event.Content{Parsed: content, Raw: exmaps.NonNilClone(member.MemberEventExtra)} addExcludeFromTimeline(wrappedContent.Raw) thisEvtSender := sender - if member.Membership == event.MembershipJoin { + if member.Membership == event.MembershipJoin && (intent == nil || !portal.roomIsPublic(ctx)) { content.Membership = event.MembershipInvite if intent != nil { wrappedContent.Raw["fi.mau.will_auto_accept"] = true From e9bfa0c51912e6a25ea6d992861fe5abe081417a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Nov 2025 22:04:29 +0200 Subject: [PATCH 1505/1647] bridgev2/portal: treat spam checker join rule as public --- bridgev2/portal.go | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 344ca807..c59d21c7 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -4236,13 +4236,33 @@ func (portal *Portal) updateOtherUser(ctx context.Context, members *ChatMemberLi return false } +func looksDirectlyJoinable(rule *event.JoinRulesEventContent) bool { + switch rule.JoinRule { + case event.JoinRulePublic: + return true + case event.JoinRuleKnockRestricted, event.JoinRuleRestricted: + for _, allow := range rule.Allow { + if allow.Type == "fi.mau.spam_checker" { + return true + } + } + } + return false +} + func (portal *Portal) roomIsPublic(ctx context.Context) bool { evt, err := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState).GetStateEvent(ctx, portal.MXID, event.StateJoinRules, "") if err != nil { zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get join rules to check if room is public") return false + } else if evt == nil { + return false } - return evt != nil && evt.Content.AsJoinRules().JoinRule == event.JoinRulePublic + content, ok := evt.Content.Parsed.(*event.JoinRulesEventContent) + if !ok { + return false + } + return looksDirectlyJoinable(content) } func (portal *Portal) syncParticipants( From 85e25748a8a052825e79ee761ebfe9b910438581 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Nov 2025 23:09:49 +0200 Subject: [PATCH 1506/1647] bridgev2/portal: ensure join is sent using target intent --- bridgev2/portal.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c59d21c7..fcdfc02c 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -4363,7 +4363,11 @@ func (portal *Portal) syncParticipants( currentMember.Membership = event.MembershipLeave } } - _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedContent, ts) + if content.Membership == event.MembershipJoin && intent != nil && intent.GetMXID() == extraUserID { + _, err = intent.SendState(ctx, portal.MXID, event.StateMember, extraUserID.String(), wrappedContent, ts) + } else { + _, err = portal.sendStateWithIntentOrBot(ctx, thisEvtSender, event.StateMember, extraUserID.String(), wrappedContent, ts) + } if err != nil { addLogContext(log.Err(err)). Str("new_membership", string(content.Membership)). From 828ba3cec1012c3377dd23c6b07715543d02eb0b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 12 Nov 2025 23:14:37 +0200 Subject: [PATCH 1507/1647] bridgev2/portal: add capability to disable formatting relayed messages --- bridgev2/portal.go | 10 ++++++---- event/capabilities.go | 2 ++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index fcdfc02c..b664c8f6 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1054,10 +1054,12 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin log.Debug().Msg("Ignoring poll event from relayed user") return EventHandlingResultIgnored.WithMSSError(ErrIgnoringPollFromRelayedUser) } - msgContent, err = portal.Bridge.Config.Relay.FormatMessage(msgContent, origSender) - if err != nil { - log.Err(err).Msg("Failed to format message for relaying") - return EventHandlingResultFailed.WithMSSError(err) + if !caps.PerMessageProfileRelay { + msgContent, err = portal.Bridge.Config.Relay.FormatMessage(msgContent, origSender) + if err != nil { + log.Err(err).Msg("Failed to format message for relaying") + return EventHandlingResultFailed.WithMSSError(err) + } } } if msgContent != nil { diff --git a/event/capabilities.go b/event/capabilities.go index 5ecea4a2..4b7ff186 100644 --- a/event/capabilities.go +++ b/event/capabilities.go @@ -60,6 +60,8 @@ type RoomFeatures struct { MarkAsUnread bool `json:"mark_as_unread,omitempty"` DeleteChat bool `json:"delete_chat,omitempty"` DeleteChatForEveryone bool `json:"delete_chat_for_everyone,omitempty"` + + PerMessageProfileRelay bool `json:"-"` } func (rf *RoomFeatures) GetID() string { From 151d9456850166ea9ef839fb1906f23c9e09b04b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 13 Nov 2025 01:29:45 +0200 Subject: [PATCH 1508/1647] event/capabilities: add docstrings for state and member_actions --- event/capabilities.d.ts | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/event/capabilities.d.ts b/event/capabilities.d.ts index 2d95cd50..1fbc9610 100644 --- a/event/capabilities.d.ts +++ b/event/capabilities.d.ts @@ -16,7 +16,22 @@ export interface RoomFeatures { * If a message type isn't listed here, it should be treated as support level -2 (will be rejected). */ file?: Record + /** + * Supported state event types and their parameters. Currently, there are no parameters, + * but it is likely there will be some in the future (like max name/topic length, avatar mime types, etc.). + * + * Events that are not listed or have a support level of zero or below should be treated as unsupported. + * + * Clients should at least check `m.room.name`, `m.room.topic`, and `m.room.avatar` here. + * `m.room.member` will not be listed here, as it's controlled by the member_actions field. + * `com.beeper.disappearing_timer` should be listed here, but the parameters are in the disappearing_timer field for now. + */ state?: Record + /** + * Supported member actions and their support levels. + * + * Actions that are not listed or have a support level of zero or below should be treated as unsupported. + */ member_actions?: Record /** Maximum length of normal text messages. */ From eb2fb84009591af94f00f110970f59c108b0a875 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 13 Nov 2025 17:32:14 +0200 Subject: [PATCH 1509/1647] appservice/intent: don't EnsureJoined when sending massaged own join event --- appservice/intent.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/appservice/intent.go b/appservice/intent.go index 4635f59a..611bf6d8 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -243,7 +243,11 @@ func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, e } func (intent *IntentAPI) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(ctx, roomID); err != nil { + if eventType != event.StateMember || stateKey != string(intent.UserID) { + if err := intent.EnsureJoined(ctx, roomID); err != nil { + return nil, err + } + } else if err := intent.EnsureRegistered(ctx); err != nil { return nil, err } contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts) From 0b73e9e7bedbbe6201d7163901f0e52fe7b512f2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 13 Nov 2025 17:38:45 +0200 Subject: [PATCH 1510/1647] client,appservice: deprecate SendMassagedStateEvent in favor of SendStateEvent params --- appservice/intent.go | 26 ++++++++------------------ bridgev2/matrix/intent.go | 12 ++---------- client.go | 17 +++++++++-------- 3 files changed, 19 insertions(+), 36 deletions(-) diff --git a/appservice/intent.go b/appservice/intent.go index 611bf6d8..e4d8e100 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -214,23 +214,20 @@ func (intent *IntentAPI) AddDoublePuppetValueWithTS(into any, ts int64) any { } } -func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) { +func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON) + return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, extra...) } +// Deprecated: use SendMessageEvent with mautrix.ReqSendEvent.Timestamp instead func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { - if err := intent.EnsureJoined(ctx, roomID); err != nil { - return nil, err - } - contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts) - return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) + return intent.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) } -func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) { +func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { if eventType != event.StateMember || stateKey != string(intent.UserID) { if err := intent.EnsureJoined(ctx, roomID); err != nil { return nil, err @@ -239,19 +236,12 @@ func (intent *IntentAPI) SendStateEvent(ctx context.Context, roomID id.RoomID, e return nil, err } contentJSON = intent.AddDoublePuppetValue(contentJSON) - return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON) + return intent.Client.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, extra...) } +// Deprecated: use SendStateEvent with mautrix.ReqSendEvent.Timestamp instead func (intent *IntentAPI) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { - if eventType != event.StateMember || stateKey != string(intent.UserID) { - if err := intent.EnsureJoined(ctx, roomID); err != nil { - return nil, err - } - } else if err := intent.EnsureRegistered(ctx); err != nil { - return nil, err - } - contentJSON = intent.AddDoublePuppetValueWithTS(contentJSON, ts) - return intent.Client.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ts) + return intent.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) } func (intent *IntentAPI) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error { diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 27892fb6..cb4b9b8f 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -82,11 +82,7 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType eventType = event.EventEncrypted } } - if extra.Timestamp.IsZero() { - return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content) - } else { - return as.Matrix.SendMassagedMessageEvent(ctx, roomID, eventType, content, extra.Timestamp.UnixMilli()) - } + return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{Timestamp: extra.Timestamp.UnixMilli()}) } func (as *ASIntent) fillMemberEvent(ctx context.Context, roomID id.RoomID, userID id.UserID, content *event.Content) { @@ -126,11 +122,7 @@ func (as *ASIntent) SendState(ctx context.Context, roomID id.RoomID, eventType e if eventType == event.StateMember { as.fillMemberEvent(ctx, roomID, id.UserID(stateKey), content) } - if ts.IsZero() { - resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content) - } else { - resp, err = as.Matrix.SendMassagedStateEvent(ctx, roomID, eventType, stateKey, content, ts.UnixMilli()) - } + resp, err = as.Matrix.SendStateEvent(ctx, roomID, eventType, stateKey, content, mautrix.ReqSendEvent{Timestamp: ts.UnixMilli()}) if err != nil && eventType == event.StateMember { var httpErr mautrix.HTTPError if errors.As(err, &httpErr) && httpErr.RespError != nil && diff --git a/client.go b/client.go index 3c60a2d1..d07bede5 100644 --- a/client.go +++ b/client.go @@ -1342,9 +1342,9 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event return } -// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey +// SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. -func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { +func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { var req ReqSendEvent if len(extra) > 0 { req = extra[0] @@ -1360,6 +1360,9 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy if req.UnstableDelay > 0 { queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10) } + if req.Timestamp > 0 { + queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10) + } urlData := ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey} urlPath := cli.BuildURLWithQuery(urlData, queryParams) @@ -1372,14 +1375,12 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy // SendMassagedStateEvent sends a state event into a room with a custom timestamp. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. +// +// Deprecated: SendStateEvent accepts a timestamp via ReqSendEvent and should be used instead. func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) { - urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{ - "ts": strconv.FormatInt(ts, 10), + resp, err = cli.SendStateEvent(ctx, roomID, eventType, stateKey, contentJSON, ReqSendEvent{ + Timestamp: ts, }) - _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) - if err == nil && cli.StateStore != nil { - cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) - } return } From a61e4d05f868b147f10ff0b5c16f7ff21b17ddfc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 13 Nov 2025 17:39:27 +0200 Subject: [PATCH 1511/1647] bridgev2/matrix: use MSC4169 to send redactions when available --- bridgev2/matrix/intent.go | 3 +-- versions.go | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index cb4b9b8f..1f82f77f 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -48,8 +48,7 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType if extra == nil { extra = &bridgev2.MatrixSendExtra{} } - // TODO remove this once hungryserv and synapse support sending m.room.redactions directly in all room versions - if eventType == event.EventRedaction { + if eventType == event.EventRedaction && !as.Connector.SpecVersions.Supports(mautrix.FeatureRedactSendAsEvent) { parsedContent := content.Parsed.(*event.RedactionEventContent) as.Matrix.AddDoublePuppetValue(content) return as.Matrix.RedactEvent(ctx, roomID, parsedContent.Redacts, mautrix.ReqRedact{ diff --git a/versions.go b/versions.go index 8c1c49aa..2aaf6399 100644 --- a/versions.go +++ b/versions.go @@ -69,6 +69,7 @@ var ( FeatureAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"} FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"} FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116} + FeatureRedactSendAsEvent = UnstableFeature{UnstableFlag: "com.beeper.msc4169"} BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} From a0cb5c6129feda78147fff549eac84e488f3b2ad Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 13 Nov 2025 18:10:27 +0200 Subject: [PATCH 1512/1647] bridgev2/backfill: ignore nil reactions --- bridgev2/portalbackfill.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index cbbce596..88503380 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -394,6 +394,9 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin } slices.Sort(partIDs) for _, reaction := range msg.Reactions { + if reaction == nil { + continue + } reactionIntent, ok := portal.GetIntentFor(ctx, reaction.Sender, source, RemoteEventReactionRemove) if !ok { continue From 202c7f117634a595b948b005541bcd4f512e164c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 16 Nov 2025 12:43:52 +0200 Subject: [PATCH 1513/1647] dependencies: update --- go.mod | 16 ++++++++-------- go.sum | 28 ++++++++++++++-------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/go.mod b/go.mod index fb63cf59..c2acc7d6 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module maunium.net/go/mautrix go 1.24.0 -toolchain go1.25.3 +toolchain go1.25.4 require ( filippo.io/edwards25519 v1.1.0 @@ -17,12 +17,12 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.13 - go.mau.fi/util v0.9.2 + go.mau.fi/util v0.9.3 go.mau.fi/zeroconfig v0.2.0 - golang.org/x/crypto v0.43.0 - golang.org/x/exp v0.0.0-20251009144603-d2f985daa21b - golang.org/x/net v0.46.0 - golang.org/x/sync v0.17.0 + golang.org/x/crypto v0.44.0 + golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 + golang.org/x/net v0.47.0 + golang.org/x/sync v0.18.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) @@ -36,7 +36,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - golang.org/x/sys v0.37.0 // indirect - golang.org/x/text v0.30.0 // indirect + golang.org/x/sys v0.38.0 // indirect + golang.org/x/text v0.31.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index faa4ef4c..b5fbf85f 100644 --- a/go.sum +++ b/go.sum @@ -51,26 +51,26 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.2 h1:+S4Z03iCsGqU2WY8X2gySFsFjaLlUHFRDVCYvVwynKM= -go.mau.fi/util v0.9.2/go.mod h1:055elBBCJSdhRsmub7ci9hXZPgGr1U6dYg44cSgRgoU= +go.mau.fi/util v0.9.3 h1:aqNF8KDIN8bFpFbybSk+mEBil7IHeBwlujfyTnvP0uU= +go.mau.fi/util v0.9.3/go.mod h1:krWWfBM1jWTb5f8NCa2TLqWMQuM81X7TGQjhMjBeXmQ= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= -golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= -golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= -golang.org/x/exp v0.0.0-20251009144603-d2f985daa21b h1:18qgiDvlvH7kk8Ioa8Ov+K6xCi0GMvmGfGW0sgd/SYA= -golang.org/x/exp v0.0.0-20251009144603-d2f985daa21b/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= -golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= -golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= -golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= -golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU= +golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc= +golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0= +golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= -golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= -golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From 36029b762290538d82ef566fce87d7b72ff5732e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 16 Nov 2025 12:51:14 +0200 Subject: [PATCH 1514/1647] Bump version to v0.26.0 --- CHANGELOG.md | 12 +++++++++++- version.go | 2 +- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ee4a13d..b6c0ff70 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,14 @@ -## v0.26.0 (unreleased) +## v0.26.0 (2025-11-16) +* *(client,appservice)* Deprecated `SendMassagedStateEvent` as `SendStateEvent` + has been able to do the same for a while now. * *(client,federation)* Added size limits for responses to make it safer to send requests to untrusted servers. * *(client)* Added wrapper for `/admin/whois` client API (thanks to [@nexy7574] in [#411]). * *(synapseadmin)* Added `force_purge` option to DeleteRoom (thanks to [@nexy7574] in [#420]). +* *(statestore)* Added saving join rules for rooms. * *(bridgev2)* Added optional automatic rollback of room state if bridging the change to the remote network fails. * *(bridgev2)* Added management room notices if transient disconnect state @@ -13,13 +16,19 @@ * *(bridgev2)* Added interface to signal that certain participants couldn't be invited when creating a group. * *(bridgev2)* Added `select` type for user input fields in login. +* *(bridgev2)* Added interface to let network connector customize personal + filtering space. * *(bridgev2/matrix)* Added checks to avoid sending error messages in reply to other bots. +* *(bridgev2/matrix)* Switched to using [MSC4169] to send redactions whenever + possible. * *(bridgev2/publicmedia)* Added support for custom path prefixes, file names, and encrypted files. * *(bridgev2/commands)* Added command to resync a single portal. * *(bridgev2/commands)* Added create group command. * *(bridgev2/config)* Added option to limit maximum number of logins. +* *(bridgev2)* Changed ghost joining to skip unnecessary invite if portal room + is public. * *(bridgev2/disappear)* Changed read receipt handling to only start disappearing timers for messages up to the read message (note: may not work in all cases if the read receipt points at an unknown event). @@ -34,6 +43,7 @@ * *(federation)* Fixed validating auth for requests with query params. * *(federation/eventauth)* Fixed typo causing restricted joins to not work. +[MSC416]: https://github.com/matrix-org/matrix-spec-proposals/pull/4169 [#411]: github.com/mautrix/go/pull/411 [#420]: github.com/mautrix/go/pull/420 [#426]: github.com/mautrix/go/pull/426 diff --git a/version.go b/version.go index 7b4eea41..f6d20c3f 100644 --- a/version.go +++ b/version.go @@ -8,7 +8,7 @@ import ( "strings" ) -const Version = "v0.25.2" +const Version = "v0.26.0" var GoModVersion = "" var Commit = "" From 14b85e98a6a9b2c3aca5b94d04556cef474ab123 Mon Sep 17 00:00:00 2001 From: timedout Date: Mon, 17 Nov 2025 16:35:46 +0000 Subject: [PATCH 1515/1647] federation: Implement federated membership functions (make/send join/knock/leave) (#422) --- federation/client.go | 163 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 163 insertions(+) diff --git a/federation/client.go b/federation/client.go index b24fd2d2..183fb5d1 100644 --- a/federation/client.go +++ b/federation/client.go @@ -263,6 +263,169 @@ func (c *Client) GetOpenIDUserInfo(ctx context.Context, serverName, accessToken return } +type ReqMakeJoin struct { + RoomID id.RoomID + UserID id.UserID + Via string + SupportedVersions []id.RoomVersion +} + +type RespMakeJoin struct { + RoomVersion id.RoomVersion `json:"room_version"` + Event PDU `json:"event"` +} + +type ReqSendJoin struct { + RoomID id.RoomID + EventID id.EventID + OmitMembers bool + Event PDU + Via string +} + +type ReqSendKnock struct { + RoomID id.RoomID + EventID id.EventID + Event PDU + Via string +} + +type RespSendJoin struct { + AuthChain []PDU `json:"auth_chain"` + Event PDU `json:"event"` + MembersOmitted bool `json:"members_omitted"` + ServersInRoom []string `json:"servers_in_room"` + State []PDU `json:"state"` +} + +type RespSendKnock struct { + KnockRoomState []PDU `json:"knock_room_state"` +} + +type ReqSendInvite struct { + RoomID id.RoomID `json:"-"` + UserID id.UserID `json:"-"` + Event PDU `json:"event"` + InviteRoomState []PDU `json:"invite_room_state"` + RoomVersion id.RoomVersion `json:"room_version"` +} + +type RespSendInvite struct { + Event PDU `json:"event"` +} + +type ReqMakeLeave struct { + RoomID id.RoomID + UserID id.UserID + Via string +} + +type ReqSendLeave struct { + RoomID id.RoomID + EventID id.EventID + Event PDU + Via string +} + +type ( + ReqMakeKnock = ReqMakeJoin + RespMakeKnock = RespMakeJoin + RespMakeLeave = RespMakeJoin +) + +func (c *Client) MakeJoin(ctx context.Context, req *ReqMakeJoin) (resp *RespMakeJoin, err error) { + versions := make([]string, len(req.SupportedVersions)) + for i, v := range req.SupportedVersions { + versions[i] = string(v) + } + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodGet, + Path: URLPath{"v1", "make_join", req.RoomID, req.UserID}, + Query: url.Values{"ver": versions}, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) MakeKnock(ctx context.Context, req *ReqMakeKnock) (resp *RespMakeKnock, err error) { + versions := make([]string, len(req.SupportedVersions)) + for i, v := range req.SupportedVersions { + versions[i] = string(v) + } + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodGet, + Path: URLPath{"v1", "make_knock", req.RoomID, req.UserID}, + Query: url.Values{"ver": versions}, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) SendJoin(ctx context.Context, req *ReqSendJoin) (resp *RespSendJoin, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodPut, + Path: URLPath{"v2", "send_join", req.RoomID, req.EventID}, + Query: url.Values{ + "omit_members": {strconv.FormatBool(req.OmitMembers)}, + }, + Authenticate: true, + RequestJSON: req.Event, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) SendKnock(ctx context.Context, req *ReqSendKnock) (resp *RespSendKnock, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodPut, + Path: URLPath{"v1", "send_knock", req.RoomID, req.EventID}, + Authenticate: true, + RequestJSON: req.Event, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) SendInvite(ctx context.Context, req *ReqSendInvite) (resp *RespSendInvite, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.UserID.Homeserver(), + Method: http.MethodPut, + Path: URLPath{"v2", "invite", req.RoomID, req.UserID}, + Authenticate: true, + RequestJSON: req, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) MakeLeave(ctx context.Context, req *ReqMakeLeave) (resp *RespMakeLeave, err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodGet, + Path: URLPath{"v1", "make_leave", req.RoomID, req.UserID}, + Authenticate: true, + ResponseJSON: &resp, + }) + return +} + +func (c *Client) SendLeave(ctx context.Context, req *ReqSendLeave) (err error) { + _, _, err = c.MakeFullRequest(ctx, RequestParams{ + ServerName: req.Via, + Method: http.MethodPut, + Path: URLPath{"v2", "send_leave", req.RoomID, req.EventID}, + Authenticate: true, + RequestJSON: req.Event, + }) + return +} + type URLPath []any func (fup URLPath) FullPath() []any { From 346100cfd4fae875fd13b7e82b0ebe94c8e77a74 Mon Sep 17 00:00:00 2001 From: Finn Date: Mon, 17 Nov 2025 10:18:46 -0800 Subject: [PATCH 1516/1647] statestore: fix missing JoinRules map when initializing MemoryStateStore (#432) --- statestore.go | 1 + 1 file changed, 1 insertion(+) diff --git a/statestore.go b/statestore.go index c6267c5b..2bd498dd 100644 --- a/statestore.go +++ b/statestore.go @@ -129,6 +129,7 @@ func NewMemoryStateStore() StateStore { PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent), Encryption: make(map[id.RoomID]*event.EncryptionEventContent), Create: make(map[id.RoomID]*event.Event), + JoinRules: make(map[id.RoomID]*event.JoinRulesEventContent), } } From 606b627d48797c988884574473a884cbd220c438 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 16 Nov 2025 13:04:02 +0200 Subject: [PATCH 1517/1647] changelog: fix link --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b6c0ff70..b30e055e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,7 +43,7 @@ * *(federation)* Fixed validating auth for requests with query params. * *(federation/eventauth)* Fixed typo causing restricted joins to not work. -[MSC416]: https://github.com/matrix-org/matrix-spec-proposals/pull/4169 +[MSC4169]: https://github.com/matrix-org/matrix-spec-proposals/pull/4169 [#411]: github.com/mautrix/go/pull/411 [#420]: github.com/mautrix/go/pull/420 [#426]: github.com/mautrix/go/pull/426 From 8a59112eb1302b3d1429d096929289f2fca0c842 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 16 Nov 2025 13:04:13 +0200 Subject: [PATCH 1518/1647] client: move some room summary fields to public room info --- responses.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/responses.go b/responses.go index e7b6b75e..d822c84b 100644 --- a/responses.go +++ b/responses.go @@ -263,10 +263,7 @@ type RespMutualRooms struct { type RespRoomSummary struct { PublicRoomInfo - Membership event.Membership `json:"membership,omitempty"` - RoomVersion id.RoomVersion `json:"room_version,omitempty"` - Encryption id.Algorithm `json:"encryption,omitempty"` - AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"` + Membership event.Membership `json:"membership,omitempty"` UnstableRoomVersion id.RoomVersion `json:"im.nheko.summary.room_version,omitempty"` UnstableRoomVersionOld id.RoomVersion `json:"im.nheko.summary.version,omitempty"` @@ -685,6 +682,10 @@ type PublicRoomInfo struct { RoomType event.RoomType `json:"room_type"` Topic string `json:"topic,omitempty"` WorldReadable bool `json:"world_readable"` + + RoomVersion id.RoomVersion `json:"room_version,omitempty"` + Encryption id.Algorithm `json:"encryption,omitempty"` + AllowedRoomIDs []id.RoomID `json:"allowed_room_ids,omitempty"` } // RespHierarchy is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy From 57657d54eeac15038496c3df6c9388b7071ced0c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Nov 2025 12:15:38 +0100 Subject: [PATCH 1519/1647] bridgev2: add custom event for requesting state change (#428) --- bridgev2/matrix/connector.go | 1 + bridgev2/messagestatus.go | 9 +++++--- bridgev2/portal.go | 44 +++++++++++++++++++++++++++++++++++- event/beeper.go | 8 +++++++ event/content.go | 1 + event/type.go | 1 + 6 files changed, 60 insertions(+), 4 deletions(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index 3e05837f..dbddaff2 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -145,6 +145,7 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.EventProcessor.On(event.StateMember, br.handleRoomEvent) br.EventProcessor.On(event.StatePowerLevels, br.handleRoomEvent) br.EventProcessor.On(event.StateRoomName, br.handleRoomEvent) + br.EventProcessor.On(event.BeeperSendState, br.handleRoomEvent) br.EventProcessor.On(event.StateRoomAvatar, br.handleRoomEvent) br.EventProcessor.On(event.StateTopic, br.handleRoomEvent) br.EventProcessor.On(event.StateTombstone, br.handleRoomEvent) diff --git a/bridgev2/messagestatus.go b/bridgev2/messagestatus.go index 7118649d..df0c9e4d 100644 --- a/bridgev2/messagestatus.go +++ b/bridgev2/messagestatus.go @@ -20,6 +20,7 @@ import ( type MessageStatusEventInfo struct { RoomID id.RoomID + TransactionID string SourceEventID id.EventID NewEventID id.EventID EventType event.Type @@ -41,6 +42,7 @@ func StatusEventInfoFromEvent(evt *event.Event) *MessageStatusEventInfo { return &MessageStatusEventInfo{ RoomID: evt.RoomID, + TransactionID: evt.Unsigned.TransactionID, SourceEventID: evt.ID, EventType: evt.Type, MessageType: evt.Content.AsMessage().MsgType, @@ -182,9 +184,10 @@ func (ms *MessageStatus) ToMSSEvent(evt *MessageStatusEventInfo) *event.BeeperMe Type: event.RelReference, EventID: evt.SourceEventID, }, - Status: ms.Status, - Reason: ms.ErrorReason, - Message: ms.Message, + TargetTxnID: evt.TransactionID, + Status: ms.Status, + Reason: ms.ErrorReason, + Message: ms.Message, } if ms.InternalError != nil { content.InternalError = ms.InternalError.Error() diff --git a/bridgev2/portal.go b/bridgev2/portal.go index b664c8f6..0fae1724 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -512,6 +512,13 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal }() switch evt := rawEvt.(type) { case *portalMatrixEvent: + isStateRequest := evt.evt.Type == event.BeeperSendState + if isStateRequest { + if err := portal.unwrapBeeperSendState(ctx, evt.evt); err != nil { + portal.sendErrorStatus(ctx, evt.evt, err) + return + } + } res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt) if res.SendMSS { if res.Error != nil { @@ -520,9 +527,21 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal portal.sendSuccessStatus(ctx, evt.evt, 0, "") } } - if res.Error != nil && evt.evt.StateKey != nil { + if !isStateRequest && res.Error != nil && evt.evt.StateKey != nil { portal.revertRoomMeta(ctx, evt.evt) } + if isStateRequest && res.Success { + portal.sendRoomMeta( + ctx, + evt.sender.DoublePuppet(ctx), + time.UnixMilli(evt.evt.Timestamp), + evt.evt.Type, + evt.evt.GetStateKey(), + evt.evt.Content.Parsed, + false, + evt.evt.Content.Raw, + ) + } case *portalRemoteEvent: res = portal.handleRemoteEvent(ctx, evt.source, evt.evtType, evt.evt) case *portalCreateEvent: @@ -534,6 +553,29 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal } } +func (portal *Portal) unwrapBeeperSendState(ctx context.Context, evt *event.Event) error { + content, ok := evt.Content.Parsed.(*event.BeeperSendStateEventContent) + if !ok { + return fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed) + } + evt.Content = content.Content + evt.StateKey = &content.StateKey + evt.Type = event.Type{Type: content.Type, Class: event.StateEventType} + _ = evt.Content.ParseRaw(evt.Type) + mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) + if !ok { + return fmt.Errorf("matrix connector doesn't support fetching state") + } + prevEvt, err := mx.GetStateEvent(ctx, portal.MXID, evt.Type, evt.GetStateKey()) + if err != nil { + return fmt.Errorf("failed to get prev event: %w", err) + } else if prevEvt != nil { + evt.Unsigned.PrevContent = &prevEvt.Content + evt.Unsigned.PrevSender = prevEvt.Sender + } + return nil +} + func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowRelay bool) (*UserLogin, *database.UserPortal, error) { if portal.Receiver != "" { login, err := portal.Bridge.GetExistingUserLoginByID(ctx, portal.Receiver) diff --git a/event/beeper.go b/event/beeper.go index 95b4a571..94892de7 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -53,6 +53,8 @@ type BeeperMessageStatusEventContent struct { LastRetry id.EventID `json:"last_retry,omitempty"` + TargetTxnID string `json:"relates_to_txn_id,omitempty"` + MutateEventKey string `json:"mutate_event_key,omitempty"` // Indicates the set of users to whom the event was delivered. If nil, then @@ -90,6 +92,12 @@ type BeeperChatDeleteEventContent struct { DeleteForEveryone bool `json:"delete_for_everyone,omitempty"` } +type BeeperSendStateEventContent struct { + Type string `json:"type"` + StateKey string `json:"state_key"` + Content Content `json:"content"` +} + type IntOrString int func (ios *IntOrString) UnmarshalJSON(data []byte) error { diff --git a/event/content.go b/event/content.go index c0ff51ad..73fb0db5 100644 --- a/event/content.go +++ b/event/content.go @@ -64,6 +64,7 @@ var TypeMap = map[Type]reflect.Type{ BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}), BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}), BeeperDeleteChat: reflect.TypeOf(BeeperChatDeleteEventContent{}), + BeeperSendState: reflect.TypeOf(BeeperSendStateEventContent{}), AccountDataRoomTags: reflect.TypeOf(TagEventContent{}), AccountDataDirectChats: reflect.TypeOf(DirectChatsEventContent{}), diff --git a/event/type.go b/event/type.go index 56ea82f6..4fca07ea 100644 --- a/event/type.go +++ b/event/type.go @@ -237,6 +237,7 @@ var ( BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType} BeeperTranscription = Type{"com.beeper.transcription", MessageEventType} BeeperDeleteChat = Type{"com.beeper.delete_chat", MessageEventType} + BeeperSendState = Type{"com.beeper.send_state", MessageEventType} EventUnstablePollStart = Type{Type: "org.matrix.msc3381.poll.start", Class: MessageEventType} EventUnstablePollResponse = Type{Type: "org.matrix.msc3381.poll.response", Class: MessageEventType} From fa56255a06be1ea60ec6ed7b544fa123a640ed3c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Nov 2025 23:13:19 +0200 Subject: [PATCH 1520/1647] bridgev2/portal: ignore not found errors when fetching prev state --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 0fae1724..27faef73 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -567,7 +567,7 @@ func (portal *Portal) unwrapBeeperSendState(ctx context.Context, evt *event.Even return fmt.Errorf("matrix connector doesn't support fetching state") } prevEvt, err := mx.GetStateEvent(ctx, portal.MXID, evt.Type, evt.GetStateKey()) - if err != nil { + if err != nil && !errors.Is(err, mautrix.MNotFound) { return fmt.Errorf("failed to get prev event: %w", err) } else if prevEvt != nil { evt.Unsigned.PrevContent = &prevEvt.Content From 1fac8ceb66534a7e34b4ad070b69a71034dabcd3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 19 Nov 2025 23:21:56 +0200 Subject: [PATCH 1521/1647] bridgev2/matrix: fix GetStateEvent not passing state key through --- bridgev2/matrix/connector.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index dbddaff2..e34e3252 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -622,7 +622,7 @@ func (br *Connector) GetStateEvent(ctx context.Context, roomID id.RoomID, eventT } } } - return br.Bot.FullStateEvent(ctx, roomID, eventType, "") + return br.Bot.FullStateEvent(ctx, roomID, eventType, stateKey) } func (br *Connector) GetMembers(ctx context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) { From 75d54132ae2619e63db6f762a2452c4d6388260d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 21 Nov 2025 16:07:16 +0200 Subject: [PATCH 1522/1647] bridgev2/portal: fix getting state events in roomIsPublic --- bridgev2/portal.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 27faef73..032207e8 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -4295,7 +4295,11 @@ func looksDirectlyJoinable(rule *event.JoinRulesEventContent) bool { } func (portal *Portal) roomIsPublic(ctx context.Context) bool { - evt, err := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState).GetStateEvent(ctx, portal.MXID, event.StateJoinRules, "") + mx, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) + if !ok { + return false + } + evt, err := mx.GetStateEvent(ctx, portal.MXID, event.StateJoinRules, "") if err != nil { zerolog.Ctx(ctx).Warn().Err(err).Msg("Failed to get join rules to check if room is public") return false From 41b1dfc8c14150232ad162b66e544b8a5cbff6ed Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 23 Nov 2025 15:51:15 +0200 Subject: [PATCH 1523/1647] bridgev2/provisionutil: check for orphaned DMs in resolve identifier --- bridgev2/matrixinvite.go | 59 +++++++++++---------- bridgev2/provisionutil/resolveidentifier.go | 1 + 2 files changed, 32 insertions(+), 28 deletions(-) diff --git a/bridgev2/matrixinvite.go b/bridgev2/matrixinvite.go index b8a5aec6..75c00cb0 100644 --- a/bridgev2/matrixinvite.go +++ b/bridgev2/matrixinvite.go @@ -88,6 +88,36 @@ func sendErrorAndLeave(ctx context.Context, evt *event.Event, intent MatrixAPI, rejectInvite(ctx, evt, intent, "") } +func (portal *Portal) CleanupOrphanedDM(ctx context.Context, userMXID id.UserID) { + if portal.MXID == "" { + return + } + log := zerolog.Ctx(ctx) + existingPortalMembers, err := portal.Bridge.Matrix.GetMembers(ctx, portal.MXID) + if err != nil { + log.Err(err). + Stringer("old_portal_mxid", portal.MXID). + Msg("Failed to check existing portal members, deleting room") + } else if targetUserMember, ok := existingPortalMembers[userMXID]; !ok { + log.Debug(). + Stringer("old_portal_mxid", portal.MXID). + Msg("Inviter has no member event in old portal, deleting room") + } else if targetUserMember.Membership.IsInviteOrJoin() { + return + } else { + log.Debug(). + Stringer("old_portal_mxid", portal.MXID). + Str("membership", string(targetUserMember.Membership)). + Msg("Inviter is not in old portal, deleting room") + } + + if err = portal.RemoveMXID(ctx); err != nil { + log.Err(err).Msg("Failed to delete old portal mxid") + } else if err = portal.Bridge.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil { + log.Err(err).Msg("Failed to clean up old portal room") + } +} + func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sender *User) EventHandlingResult { ghostID, _ := br.Matrix.ParseGhostMXID(id.UserID(evt.GetStateKey())) validator, ok := br.Network.(IdentifierValidatingNetwork) @@ -165,34 +195,7 @@ func (br *Bridge) handleGhostDMInvite(ctx context.Context, evt *event.Event, sen return EventHandlingResultFailed } } - if portal.MXID != "" { - doCleanup := true - existingPortalMembers, err := br.Matrix.GetMembers(ctx, portal.MXID) - if err != nil { - log.Err(err). - Stringer("old_portal_mxid", portal.MXID). - Msg("Failed to check existing portal members, deleting room") - } else if targetUserMember, ok := existingPortalMembers[sender.MXID]; !ok { - log.Debug(). - Stringer("old_portal_mxid", portal.MXID). - Msg("Inviter has no member event in old portal, deleting room") - } else if targetUserMember.Membership.IsInviteOrJoin() { - doCleanup = false - } else { - log.Debug(). - Stringer("old_portal_mxid", portal.MXID). - Str("membership", string(targetUserMember.Membership)). - Msg("Inviter is not in old portal, deleting room") - } - - if doCleanup { - if err = portal.RemoveMXID(ctx); err != nil { - log.Err(err).Msg("Failed to delete old portal mxid") - } else if err = br.Bot.DeleteRoom(ctx, portal.MXID, true); err != nil { - log.Err(err).Msg("Failed to clean up old portal room") - } - } - } + portal.CleanupOrphanedDM(ctx, sender.MXID) err = invitedGhost.Intent.EnsureInvited(ctx, evt.RoomID, br.Bot.GetMXID()) if err != nil { log.Err(err).Msg("Failed to ensure bot is invited to room") diff --git a/bridgev2/provisionutil/resolveidentifier.go b/bridgev2/provisionutil/resolveidentifier.go index 5387347c..cfc388d0 100644 --- a/bridgev2/provisionutil/resolveidentifier.go +++ b/bridgev2/provisionutil/resolveidentifier.go @@ -109,6 +109,7 @@ func ResolveIdentifier( return nil, bridgev2.RespError(mautrix.MUnknown.WithMessage("Failed to get portal")) } } + resp.Chat.Portal.CleanupOrphanedDM(ctx, login.UserMXID) if createChat && resp.Chat.Portal.MXID == "" { apiResp.JustCreated = true err := resp.Chat.Portal.CreateMatrixRoom(ctx, login, resp.Chat.PortalInfo) From eaa4e07eae677740d0ce2ef5ea8c8f763d8f5ed5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 25 Nov 2025 14:23:09 +0200 Subject: [PATCH 1524/1647] bridgev2/portal: only allow setting receiver as relay in split portals --- bridgev2/commands/relay.go | 16 +++++++++++++--- bridgev2/portal.go | 3 +++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/bridgev2/commands/relay.go b/bridgev2/commands/relay.go index af756c87..94c19739 100644 --- a/bridgev2/commands/relay.go +++ b/bridgev2/commands/relay.go @@ -37,7 +37,7 @@ func fnSetRelay(ce *Event) { } onlySetDefaultRelays := !ce.User.Permissions.Admin && ce.Bridge.Config.Relay.AdminOnly var relay *bridgev2.UserLogin - if len(ce.Args) == 0 { + if len(ce.Args) == 0 && ce.Portal.Receiver == "" { relay = ce.User.GetDefaultLogin() isLoggedIn := relay != nil if onlySetDefaultRelays { @@ -73,9 +73,19 @@ func fnSetRelay(ce *Event) { } } } else { - relay = ce.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(ce.Args[0])) + var targetID networkid.UserLoginID + if ce.Portal.Receiver != "" { + targetID = ce.Portal.Receiver + if len(ce.Args) > 0 && ce.Args[0] != string(targetID) { + ce.Reply("In split portals, only the receiver (%s) can be set as relay", targetID) + return + } + } else { + targetID = networkid.UserLoginID(ce.Args[0]) + } + relay = ce.Bridge.GetCachedUserLoginByID(targetID) if relay == nil { - ce.Reply("User login with ID `%s` not found", ce.Args[0]) + ce.Reply("User login with ID `%s` not found", targetID) return } else if slices.Contains(ce.Bridge.Config.Relay.DefaultRelays, relay.ID) { // All good diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 032207e8..8c628880 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -5153,6 +5153,9 @@ func (portal *Portal) Save(ctx context.Context) error { } func (portal *Portal) SetRelay(ctx context.Context, relay *UserLogin) error { + if portal.Receiver != "" && relay.ID != portal.Receiver { + return fmt.Errorf("can't set non-receiver login as relay") + } portal.Relay = relay if relay == nil { portal.RelayLoginID = "" From 0f2ff4a090a1ca84b9adba07afa7a0dafee667ba Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 25 Nov 2025 14:23:21 +0200 Subject: [PATCH 1525/1647] bridgev2/portal: improve error messages in FindPreferredLogin when portal has receiver --- bridgev2/portal.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 8c628880..e777a717 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -582,12 +582,15 @@ func (portal *Portal) FindPreferredLogin(ctx context.Context, user *User, allowR if err != nil { return nil, nil, err } - if login == nil || login.UserMXID != user.MXID || !login.Client.IsLoggedIn() { + if login == nil { + return nil, nil, fmt.Errorf("%w (receiver login is nil)", ErrNotLoggedIn) + } else if !login.Client.IsLoggedIn() { + return nil, nil, fmt.Errorf("%w (receiver login is not logged in)", ErrNotLoggedIn) + } else if login.UserMXID != user.MXID { if allowRelay && portal.Relay != nil { return nil, nil, nil } - // TODO different error for this case? - return nil, nil, ErrNotLoggedIn + return nil, nil, fmt.Errorf("%w (relay not set and receiver login is owned by %s, not %s)", ErrNotLoggedIn, login.UserMXID, user.MXID) } up, err := portal.Bridge.DB.UserPortal.Get(ctx, login.UserLogin, portal.PortalKey) return login, up, err From dc38165473d052b59c967cd322d8c67a731730ab Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Wed, 26 Nov 2025 10:24:51 +0000 Subject: [PATCH 1526/1647] crypto: allow storing arbitrary metadata alongside encrypted account data For example, the creation time of a key. --- crypto/ssss/client.go | 16 ++++++++++++++++ crypto/ssss/types.go | 1 + 2 files changed, 17 insertions(+) diff --git a/crypto/ssss/client.go b/crypto/ssss/client.go index e30925d9..8691d032 100644 --- a/crypto/ssss/client.go +++ b/crypto/ssss/client.go @@ -95,6 +95,22 @@ func (mach *Machine) SetEncryptedAccountData(ctx context.Context, eventType even return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted}) } +// SetEncryptedAccountDataWithMetadata encrypts the given data with the given keys and stores it, +// alongside the unencrypted metadata, on the server. +func (mach *Machine) SetEncryptedAccountDataWithMetadata(ctx context.Context, eventType event.Type, data []byte, metadata map[string]any, keys ...*Key) error { + if len(keys) == 0 { + return ErrNoKeyGiven + } + encrypted := make(map[string]EncryptedKeyData, len(keys)) + for _, key := range keys { + encrypted[key.ID] = key.Encrypt(eventType.Type, data) + } + return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{ + Encrypted: encrypted, + Metadata: metadata, + }) +} + // GenerateAndUploadKey generates a new SSSS key and stores the metadata on the server. func (mach *Machine) GenerateAndUploadKey(ctx context.Context, passphrase string) (key *Key, err error) { key, err = NewKey(passphrase) diff --git a/crypto/ssss/types.go b/crypto/ssss/types.go index 345393b0..c08f107c 100644 --- a/crypto/ssss/types.go +++ b/crypto/ssss/types.go @@ -57,6 +57,7 @@ type EncryptedKeyData struct { type EncryptedAccountDataEventContent struct { Encrypted map[string]EncryptedKeyData `json:"encrypted"` + Metadata map[string]any `json:"com.beeper.metadata,omitzero"` } func (ed *EncryptedAccountDataEventContent) Decrypt(eventType string, key *Key) ([]byte, error) { From 016637ebf88a33d5c11c62e140f5a49b795db370 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 24 Nov 2025 19:07:56 +0000 Subject: [PATCH 1527/1647] bridgev2/bridgestate: add var to disable catching bridge state queue panics --- bridgev2/bridgestate.go | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index 63d5876b..a1d3e70b 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -22,6 +22,8 @@ import ( "maunium.net/go/mautrix/format" ) +var CatchBridgeStateQueuePanics = true + type BridgeStateQueue struct { prevUnsent *status.BridgeState prevSent *status.BridgeState @@ -84,15 +86,17 @@ func (bsq *BridgeStateQueue) StopUnknownErrorReconnect() { } func (bsq *BridgeStateQueue) loop() { - defer func() { - err := recover() - if err != nil { - bsq.login.Log.Error(). - Bytes(zerolog.ErrorStackFieldName, debug.Stack()). - Any(zerolog.ErrorFieldName, err). - Msg("Panic in bridge state loop") - } - }() + if CatchBridgeStateQueuePanics { + defer func() { + err := recover() + if err != nil { + bsq.login.Log.Error(). + Bytes(zerolog.ErrorStackFieldName, debug.Stack()). + Any(zerolog.ErrorFieldName, err). + Msg("Panic in bridge state loop") + } + }() + } for state := range bsq.ch { bsq.immediateSendBridgeState(state) } From c3b85e8e3c3999ceb8dd267b2a0d3aec35058c05 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Thu, 20 Nov 2025 15:54:04 +0000 Subject: [PATCH 1528/1647] client: add special error that indicates to retry canceled contexts On it's own this is useless since the retries would all immediately fail with the canceled context error. The caller is expected to also set a `UpdateRequestOnRetry` on the client which is used to swap out the context. --- client.go | 5 ++++- error.go | 4 ++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index d07bede5..ba67a205 100644 --- a/client.go +++ b/client.go @@ -745,7 +745,10 @@ func (cli *Client) executeCompiledRequest( defer res.Body.Close() } if err != nil { - if retries > 0 && !errors.Is(err, context.Canceled) { + // Either error is *not* canceled or the underlying cause of cancelation explicitly asks to retry + canRetry := !errors.Is(err, context.Canceled) || + errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) + if retries > 0 && canRetry { return cli.doRetry( req, err, retries, backoff, responseJSON, handler, dontReadResponse, sizeLimit, client, ) diff --git a/error.go b/error.go index 826af179..5ff671e0 100644 --- a/error.go +++ b/error.go @@ -85,6 +85,10 @@ var ( ErrResponseTooLong = errors.New("response content length too long") ErrBodyReadReachedLimit = errors.New("reached response size limit while reading body") + + // Special error that indicates we should retry canceled contexts. Note that on it's own this + // is useless, the context itself must also be replaced. + ErrContextCancelRetry = errors.New("retry canceled context") ) // HTTPError An HTTP Error response, which may wrap an underlying native Go Error. From 3293e2f8ff35ef032c6ddafef5b8abbdd72abf34 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 28 Nov 2025 13:38:05 +0200 Subject: [PATCH 1529/1647] dependencies: update --- go.mod | 8 ++++---- go.sum | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index c2acc7d6..d873892c 100644 --- a/go.mod +++ b/go.mod @@ -17,10 +17,10 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.13 - go.mau.fi/util v0.9.3 + go.mau.fi/util v0.9.4-0.20251128113707-115b8b18bd18 go.mau.fi/zeroconfig v0.2.0 - golang.org/x/crypto v0.44.0 - golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 + golang.org/x/crypto v0.45.0 + golang.org/x/exp v0.0.0-20251125195548-87e1e737ad39 golang.org/x/net v0.47.0 golang.org/x/sync v0.18.0 gopkg.in/yaml.v3 v3.0.1 @@ -32,7 +32,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/petermattis/goid v0.0.0-20250904145737-900bdf8bb490 // indirect + github.com/petermattis/goid v0.0.0-20251121121749-a11dd1a45f9a // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect diff --git a/go.sum b/go.sum index b5fbf85f..fae6084d 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/petermattis/goid v0.0.0-20250904145737-900bdf8bb490 h1:QTvNkZ5ylY0PGgA+Lih+GdboMLY/G9SEGLMEGVjTVA4= -github.com/petermattis/goid v0.0.0-20250904145737-900bdf8bb490/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/petermattis/goid v0.0.0-20251121121749-a11dd1a45f9a h1:VweslR2akb/ARhXfqSfRbj1vpWwYXf3eeAUyw/ndms0= +github.com/petermattis/goid v0.0.0-20251121121749-a11dd1a45f9a/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -51,14 +51,14 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.3 h1:aqNF8KDIN8bFpFbybSk+mEBil7IHeBwlujfyTnvP0uU= -go.mau.fi/util v0.9.3/go.mod h1:krWWfBM1jWTb5f8NCa2TLqWMQuM81X7TGQjhMjBeXmQ= +go.mau.fi/util v0.9.4-0.20251128113707-115b8b18bd18 h1:h1/wE/SLTuat12/SRsKyh+edWX2Aung1ZsiWnY3t5Zs= +go.mau.fi/util v0.9.4-0.20251128113707-115b8b18bd18/go.mod h1:viDmhBOAFfcqDdKSk53EPJV3N4Mi8Jst5/ahGJ/vwsA= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= -golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU= -golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc= -golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0= -golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/exp v0.0.0-20251125195548-87e1e737ad39 h1:DHNhtq3sNNzrvduZZIiFyXWOL9IWaDPHqTnLJp+rCBY= +golang.org/x/exp v0.0.0-20251125195548-87e1e737ad39/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= From 1d1ecb228668b819d04bb1e6299b7944626c0c17 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 28 Nov 2025 13:40:54 +0200 Subject: [PATCH 1530/1647] federation/eventauth: fix sender membership check when kicking --- federation/eventauth/eventauth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go index 32b4424b..eac110a3 100644 --- a/federation/eventauth/eventauth.go +++ b/federation/eventauth/eventauth.go @@ -484,7 +484,7 @@ func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEv } return ErrCantLeaveWithoutBeingInRoom } - if senderMembership != event.MembershipLeave { + if senderMembership != event.MembershipJoin { // 5.5.2. If the sender’s current membership state is not join, reject. return ErrCantKickWithoutBeingInRoom } From 6e402e8fd2c2b131affef3feb87c0931953f6215 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 1 Dec 2025 00:10:29 +0200 Subject: [PATCH 1531/1647] bridgev2/backfill: don't try to backfill empty threads --- bridgev2/portalbackfill.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index 88503380..e8292388 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -194,6 +194,9 @@ func (portal *Portal) doThreadBackfill(ctx context.Context, source *UserLogin, t if err != nil { log.Err(err).Msg("Failed to get last thread message") return + } else if anchorMessage == nil { + log.Warn().Msg("No messages found in thread?") + return } resp := portal.fetchThreadBackfill(ctx, source, anchorMessage) if resp != nil { From 09052986b2d3333446a6ca3b4d18553b8602447c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 1 Dec 2025 15:28:56 +0200 Subject: [PATCH 1532/1647] bridgev2/commands: add command for muting chat on remote network --- bridgev2/commands/processor.go | 2 +- bridgev2/commands/startchat.go | 43 +++++++++++++++++++++++++++++++++- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index 13a35687..692db80d 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -44,7 +44,7 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor { CommandRegisterPush, CommandSendAccountData, CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom, CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, CommandSetRelay, CommandUnsetRelay, - CommandResolveIdentifier, CommandStartChat, CommandCreateGroup, CommandSearch, CommandSyncChat, + CommandResolveIdentifier, CommandStartChat, CommandCreateGroup, CommandSearch, CommandSyncChat, CommandMute, CommandSudo, CommandDoIn, ) return proc diff --git a/bridgev2/commands/startchat.go b/bridgev2/commands/startchat.go index 24586387..c7b05a6e 100644 --- a/bridgev2/commands/startchat.go +++ b/bridgev2/commands/startchat.go @@ -80,7 +80,7 @@ var CommandStartChat = &FullHandler{ NetworkAPI: NetworkAPIImplements[bridgev2.IdentifierResolvingNetworkAPI], } -func getClientForStartingChat[T bridgev2.IdentifierResolvingNetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) { +func getClientForStartingChat[T bridgev2.NetworkAPI](ce *Event, thing string) (*bridgev2.UserLogin, T, []string) { var remainingArgs []string if len(ce.Args) > 1 { remainingArgs = ce.Args[1:] @@ -290,3 +290,44 @@ func fnSearch(ce *Event) { } ce.Reply("Search results:\n\n%s", strings.Join(resultsString, "\n")) } + +var CommandMute = &FullHandler{ + Func: fnMute, + Name: "mute", + Aliases: []string{"unmute"}, + Help: HelpMeta{ + Section: HelpSectionChats, + Description: "Mute or unmute a chat on the remote network", + Args: "[duration]", + }, + RequiresPortal: true, + RequiresLogin: true, + NetworkAPI: NetworkAPIImplements[bridgev2.MuteHandlingNetworkAPI], +} + +func fnMute(ce *Event) { + _, api, _ := getClientForStartingChat[bridgev2.MuteHandlingNetworkAPI](ce, "muting chats") + var mutedUntil int64 + if ce.Command == "mute" { + mutedUntil = -1 + if len(ce.Args) > 0 { + duration, err := time.ParseDuration(ce.Args[0]) + if err != nil { + ce.Reply("Invalid duration: %v", err) + return + } + mutedUntil = time.Now().Add(duration).UnixMilli() + } + } + err := api.HandleMute(ce.Ctx, &bridgev2.MatrixMute{ + MatrixEventBase: bridgev2.MatrixEventBase[*event.BeeperMuteEventContent]{ + Content: &event.BeeperMuteEventContent{MutedUntil: mutedUntil}, + Portal: ce.Portal, + }, + }) + if err != nil { + ce.Reply("Failed to %s chat: %v", ce.Command, err) + } else { + ce.React("✅️") + } +} From e22802b9bb27d05dcd766df9b46f3fb55db6027f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 1 Dec 2025 17:07:54 +0200 Subject: [PATCH 1533/1647] bridgev2/database: improve missing parents when migrating to split portals --- bridgev2/database/portal.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index a230df19..f6868be6 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -88,7 +88,7 @@ const ( getPortalByIDWithUncertainReceiverQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND id=$2 AND (receiver=$3 OR receiver='')` getPortalByMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid=$2` getAllPortalsWithMXIDQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND mxid IS NOT NULL` - getAllPortalsWithoutReceiver = getPortalBaseQuery + `WHERE bridge_id=$1 AND receiver=''` + getAllPortalsWithoutReceiver = getPortalBaseQuery + `WHERE bridge_id=$1 AND (receiver='' OR (parent_id<>'' AND parent_receiver='')) ORDER BY parent_id DESC` getAllDMPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND other_user_id=$2` getDMPortalQuery = getPortalBaseQuery + `WHERE bridge_id=$1 AND room_type='dm' AND receiver=$2 AND other_user_id=$3` getAllPortalsQuery = getPortalBaseQuery + `WHERE bridge_id=$1` @@ -148,7 +148,10 @@ const ( ) ` fixParentsAfterSplitPortalMigrationQuery = ` - UPDATE portal SET parent_receiver=receiver WHERE bridge_id=$1 AND parent_receiver='' AND receiver<>'' AND parent_id<>''; + UPDATE portal + SET parent_receiver=receiver + WHERE bridge_id=$1 AND parent_receiver='' AND receiver<>'' AND parent_id<>'' + AND EXISTS(SELECT 1 FROM portal pp WHERE pp.bridge_id=$1 AND pp.id=portal.parent_id AND pp.receiver=portal.receiver); ` ) From 5206439b83b35211304b49265e20f8c5b6361f4a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 2 Dec 2025 13:52:28 +0200 Subject: [PATCH 1534/1647] bridgev2/portal: pass is state request flag to event handlers --- bridgev2/networkinterface.go | 3 ++- bridgev2/portal.go | 31 ++++++++++++++++++++----------- bridgev2/portalinternal.go | 20 ++++++++++++++------ bridgev2/queue.go | 1 + 4 files changed, 37 insertions(+), 18 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 9bbcf897..193dc909 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -1382,7 +1382,8 @@ type MatrixMessageRemove struct { type MatrixRoomMeta[ContentType any] struct { MatrixEventBase[ContentType] - PrevContent ContentType + PrevContent ContentType + IsStateRequest bool } type MatrixRoomName = MatrixRoomMeta[*event.RoomNameEventContent] diff --git a/bridgev2/portal.go b/bridgev2/portal.go index e777a717..84fb5333 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -519,7 +519,7 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal return } } - res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt) + res = portal.handleMatrixEvent(ctx, evt.sender, evt.evt, isStateRequest) if res.SendMSS { if res.Error != nil { portal.sendErrorStatus(ctx, evt.evt, res.Error) @@ -673,7 +673,7 @@ func (portal *Portal) checkConfusableName(ctx context.Context, userID id.UserID, var fakePerMessageProfileEventType = event.Type{Class: event.StateEventType, Type: "m.per_message_profile"} -func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult { +func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult { log := zerolog.Ctx(ctx) if evt.Mautrix.EventSource&event.SourceEphemeral != 0 { switch evt.Type { @@ -705,6 +705,9 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * } var origSender *OrigSender if login == nil { + if isStateRequest { + return EventHandlingResultFailed.WithMSSError(ErrCantRelayStateRequest) + } login = portal.Relay origSender = &OrigSender{ User: sender, @@ -775,13 +778,13 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * case event.EventRedaction: return portal.handleMatrixRedaction(ctx, login, origSender, evt) case event.StateRoomName: - return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomNameHandlingNetworkAPI.HandleMatrixRoomName) + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomNameHandlingNetworkAPI.HandleMatrixRoomName) case event.StateTopic: - return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic) + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomTopicHandlingNetworkAPI.HandleMatrixRoomTopic) case event.StateRoomAvatar: - return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar) + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, RoomAvatarHandlingNetworkAPI.HandleMatrixRoomAvatar) case event.StateBeeperDisappearingTimer: - return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, DisappearTimerChangingNetworkAPI.HandleMatrixDisappearingTimer) + return handleMatrixRoomMeta(portal, ctx, login, origSender, evt, isStateRequest, DisappearTimerChangingNetworkAPI.HandleMatrixDisappearingTimer) case event.StateEncryption: // TODO? return EventHandlingResultIgnored @@ -792,9 +795,9 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * case event.AccountDataBeeperMute: return handleMatrixAccountData(portal, ctx, login, evt, MuteHandlingNetworkAPI.HandleMute) case event.StateMember: - return portal.handleMatrixMembership(ctx, login, origSender, evt) + return portal.handleMatrixMembership(ctx, login, origSender, evt, isStateRequest) case event.StatePowerLevels: - return portal.handleMatrixPowerLevels(ctx, login, origSender, evt) + return portal.handleMatrixPowerLevels(ctx, login, origSender, evt, isStateRequest) case event.BeeperDeleteChat: return portal.handleMatrixDeleteChat(ctx, login, origSender, evt) default: @@ -1607,6 +1610,7 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( sender *UserLogin, origSender *OrigSender, evt *event.Event, + isStateRequest bool, fn func(APIType, context.Context, *MatrixRoomMeta[ContentType]) (bool, error), ) EventHandlingResult { if evt.StateKey == nil || *evt.StateKey != "" { @@ -1670,7 +1674,8 @@ func handleMatrixRoomMeta[APIType any, ContentType any]( InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, - PrevContent: prevContent, + IsStateRequest: isStateRequest, + PrevContent: prevContent, }) if err != nil { log.Err(err).Msg("Failed to handle Matrix room metadata") @@ -1797,6 +1802,7 @@ func (portal *Portal) handleMatrixMembership( sender *UserLogin, origSender *OrigSender, evt *event.Event, + isStateRequest bool, ) EventHandlingResult { if evt.StateKey == nil { return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) @@ -1847,7 +1853,8 @@ func (portal *Portal) handleMatrixMembership( InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, - PrevContent: prevContent, + IsStateRequest: isStateRequest, + PrevContent: prevContent, }, Target: target, TargetGhost: targetGhost, @@ -1884,6 +1891,7 @@ func (portal *Portal) handleMatrixPowerLevels( sender *UserLogin, origSender *OrigSender, evt *event.Event, + isStateRequest bool, ) EventHandlingResult { if evt.StateKey == nil || *evt.StateKey != "" { return EventHandlingResultFailed.WithMSSError(ErrInvalidStateKey) @@ -1925,7 +1933,8 @@ func (portal *Portal) handleMatrixPowerLevels( InputTransactionID: portal.parseInputTransactionID(origSender, evt), }, - PrevContent: prevContent, + IsStateRequest: isStateRequest, + PrevContent: prevContent, }, Users: make(map[id.UserID]*UserPowerLevelChange), Events: make(map[string]*SinglePowerLevelChange), diff --git a/bridgev2/portalinternal.go b/bridgev2/portalinternal.go index 749ee389..4c7e2447 100644 --- a/bridgev2/portalinternal.go +++ b/bridgev2/portalinternal.go @@ -49,6 +49,10 @@ func (portal *PortalInternals) HandleSingleEvent(ctx context.Context, rawEvt any (*Portal)(portal).handleSingleEvent(ctx, rawEvt, doneCallback) } +func (portal *PortalInternals) UnwrapBeeperSendState(ctx context.Context, evt *event.Event) error { + return (*Portal)(portal).unwrapBeeperSendState(ctx, evt) +} + func (portal *PortalInternals) SendSuccessStatus(ctx context.Context, evt *event.Event, streamOrder int64, newEventID id.EventID) { (*Portal)(portal).sendSuccessStatus(ctx, evt, streamOrder, newEventID) } @@ -61,8 +65,8 @@ func (portal *PortalInternals) CheckConfusableName(ctx context.Context, userID i return (*Portal)(portal).checkConfusableName(ctx, userID, name) } -func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult { - return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt) +func (portal *PortalInternals) HandleMatrixEvent(ctx context.Context, sender *User, evt *event.Event, isStateRequest bool) EventHandlingResult { + return (*Portal)(portal).handleMatrixEvent(ctx, sender, evt, isStateRequest) } func (portal *PortalInternals) HandleMatrixReceipts(ctx context.Context, evt *event.Event) EventHandlingResult { @@ -125,12 +129,12 @@ func (portal *PortalInternals) HandleMatrixDeleteChat(ctx context.Context, sende return (*Portal)(portal).handleMatrixDeleteChat(ctx, sender, origSender, evt) } -func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { - return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt) +func (portal *PortalInternals) HandleMatrixMembership(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult { + return (*Portal)(portal).handleMatrixMembership(ctx, sender, origSender, evt, isStateRequest) } -func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event) EventHandlingResult { - return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt) +func (portal *PortalInternals) HandleMatrixPowerLevels(ctx context.Context, sender *UserLogin, origSender *OrigSender, evt *event.Event, isStateRequest bool) EventHandlingResult { + return (*Portal)(portal).handleMatrixPowerLevels(ctx, sender, origSender, evt, isStateRequest) } func (portal *PortalInternals) HandleMatrixTombstone(ctx context.Context, evt *event.Event) EventHandlingResult { @@ -305,6 +309,10 @@ func (portal *PortalInternals) UpdateOtherUser(ctx context.Context, members *Cha return (*Portal)(portal).updateOtherUser(ctx, members) } +func (portal *PortalInternals) RoomIsPublic(ctx context.Context) bool { + return (*Portal)(portal).roomIsPublic(ctx) +} + func (portal *PortalInternals) SyncParticipants(ctx context.Context, members *ChatMemberList, source *UserLogin, sender MatrixAPI, ts time.Time) error { return (*Portal)(portal).syncParticipants(ctx, members, source, sender, ts) } diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 308d03c5..8a3b707b 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -67,6 +67,7 @@ var ( ErrEventSenderUserNotFound = WrapErrorInStatus(errors.New("sender not found for event")).WithIsCertain(true).WithErrorAsMessage() ErrNoPermissionToInteract = WrapErrorInStatus(errors.New("you don't have permission to send messages")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage() ErrNoPermissionForCommands = WrapErrorInStatus(WrapErrorInStatus(errors.New("you don't have permission to use commands")).WithIsCertain(true).WithSendNotice(false).WithErrorAsMessage()) + ErrCantRelayStateRequest = WrapErrorInStatus(errors.New("relayed users can't use beeper state requests")).WithIsCertain(true).WithErrorAsMessage() ) func (br *Bridge) QueueMatrixEvent(ctx context.Context, evt *event.Event) EventHandlingResult { From dfd5485a0dafc809f2b9edc0b89e9dab85474aea Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 2 Dec 2025 14:16:22 +0200 Subject: [PATCH 1535/1647] bridgev2/networkinterface: remove deprecated fields in MatrixMembershipChange --- bridgev2/networkinterface.go | 5 ----- bridgev2/portal.go | 7 ++----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 193dc909..b4bf36ff 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -897,11 +897,6 @@ type MatrixMembershipChange struct { MatrixRoomMeta[*event.MemberEventContent] Target GhostOrUserLogin Type MembershipChangeType - - // Deprecated: Use Target instead - TargetGhost *Ghost - // Deprecated: Use Target instead - TargetUserLogin *UserLogin } type MembershipHandlingNetworkAPI interface { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 84fb5333..5b4e31ef 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1842,7 +1842,6 @@ func (portal *Portal) handleMatrixMembership( return EventHandlingResultIgnored //.WithMSSError(ErrIgnoringLeaveEvent) } targetGhost, _ := target.(*Ghost) - targetUserLogin, _ := target.(*UserLogin) membershipChange := &MatrixMembershipChange{ MatrixRoomMeta: MatrixRoomMeta[*event.MemberEventContent]{ MatrixEventBase: MatrixEventBase[*event.MemberEventContent]{ @@ -1856,10 +1855,8 @@ func (portal *Portal) handleMatrixMembership( IsStateRequest: isStateRequest, PrevContent: prevContent, }, - Target: target, - TargetGhost: targetGhost, - TargetUserLogin: targetUserLogin, - Type: membershipChangeType, + Target: target, + Type: membershipChangeType, } _, err = api.HandleMatrixMembership(ctx, membershipChange) if err != nil { From 2eeece6942544a2e53f196b03f0bfab42c14db02 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 2 Dec 2025 15:22:01 +0200 Subject: [PATCH 1536/1647] bridgev2/networkinterface: allow HandleMatrixMembership to redirect invites to another user ID --- bridgev2/networkinterface.go | 6 ++++- bridgev2/portal.go | 48 +++++++++++++++++++++++++++++++++--- bridgev2/queue.go | 7 ++++++ 3 files changed, 57 insertions(+), 4 deletions(-) diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index b4bf36ff..9c3f7d71 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -899,9 +899,13 @@ type MatrixMembershipChange struct { Type MembershipChangeType } +type MatrixMembershipResult struct { + RedirectTo networkid.UserID +} + type MembershipHandlingNetworkAPI interface { NetworkAPI - HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (bool, error) + HandleMatrixMembership(ctx context.Context, msg *MatrixMembershipChange) (*MatrixMembershipResult, error) } type SinglePowerLevelChange struct { diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 5b4e31ef..c0855c2d 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -530,7 +530,7 @@ func (portal *Portal) handleSingleEvent(ctx context.Context, rawEvt any, doneCal if !isStateRequest && res.Error != nil && evt.evt.StateKey != nil { portal.revertRoomMeta(ctx, evt.evt) } - if isStateRequest && res.Success { + if isStateRequest && res.Success && !res.SkipStateEcho { portal.sendRoomMeta( ctx, evt.sender.DoublePuppet(ctx), @@ -1858,12 +1858,54 @@ func (portal *Portal) handleMatrixMembership( Target: target, Type: membershipChangeType, } - _, err = api.HandleMatrixMembership(ctx, membershipChange) + res, err := api.HandleMatrixMembership(ctx, membershipChange) if err != nil { log.Err(err).Msg("Failed to handle Matrix membership change") return EventHandlingResultFailed.WithMSSError(err) } - return EventHandlingResultSuccess.WithMSS() + didRedirectInvite := membershipChangeType == Invite && + targetGhost != nil && + res != nil && + res.RedirectTo != "" && + res.RedirectTo != targetGhost.ID + if didRedirectInvite { + log.Debug(). + Str("orig_id", string(targetGhost.ID)). + Str("redirect_id", string(res.RedirectTo)). + Msg("Invite was redirected to different ghost") + var redirectGhost *Ghost + redirectGhost, err = portal.Bridge.GetGhostByID(ctx, res.RedirectTo) + if err != nil { + log.Err(err).Msg("Failed to get redirect target ghost") + return EventHandlingResultFailed.WithError(err) + } + if !isStateRequest { + portal.sendRoomMeta( + ctx, + sender.User.DoublePuppet(ctx), + time.UnixMilli(evt.Timestamp), + event.StateMember, + evt.GetStateKey(), + &event.MemberEventContent{ + Membership: event.MembershipLeave, + Reason: fmt.Sprintf("Invite redirected to %s", res.RedirectTo), + }, + true, + nil, + ) + } + portal.sendRoomMeta( + ctx, + sender.User.DoublePuppet(ctx), + time.UnixMilli(evt.Timestamp), + event.StateMember, + redirectGhost.Intent.GetMXID().String(), + content, + false, + nil, + ) + } + return EventHandlingResultSuccess.WithMSS().WithSkipStateEcho(didRedirectInvite) } func makePLChange(old, new int, newIsSet bool) *SinglePowerLevelChange { diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 8a3b707b..6667caea 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -160,6 +160,8 @@ type EventHandlingResult struct { Ignored bool Queued bool + SkipStateEcho bool + // Error is an optional reason for failure. It is not required, Success may be false even without a specific error. Error error // Whether the Error should be sent as a MSS event. @@ -195,6 +197,11 @@ func (ehr EventHandlingResult) WithMSS() EventHandlingResult { return ehr } +func (ehr EventHandlingResult) WithSkipStateEcho(skip bool) EventHandlingResult { + ehr.SkipStateEcho = skip + return ehr +} + func (ehr EventHandlingResult) WithMSSError(err error) EventHandlingResult { if err == nil { return ehr From 7d54edbfda13aac65a7499353f9f0043e2c6338a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 2 Dec 2025 18:15:24 +0200 Subject: [PATCH 1537/1647] bridgev2/mxmain: add support for reading env vars from config --- bridgev2/bridgeconfig/config.go | 2 + bridgev2/bridgeconfig/upgrade.go | 3 + bridgev2/matrix/mxmain/envconfig.go | 161 +++++++++++++++++++++ bridgev2/matrix/mxmain/example-config.yaml | 10 ++ bridgev2/matrix/mxmain/main.go | 7 + 5 files changed, 183 insertions(+) create mode 100644 bridgev2/matrix/mxmain/envconfig.go diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index b1718f30..8b9aa019 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -33,6 +33,8 @@ type Config struct { Encryption EncryptionConfig `yaml:"encryption"` Logging zeroconfig.Config `yaml:"logging"` + EnvConfigPrefix string `yaml:"env_config_prefix"` + ManagementRoomTexts ManagementRoomTexts `yaml:"management_room_texts"` } diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 0dbff802..a3ac8747 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -184,6 +184,8 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Int, "encryption", "rotation", "messages") helper.Copy(up.Bool, "encryption", "rotation", "disable_device_change_key_rotation") + helper.Copy(up.Str, "env_config_prefix") + helper.Copy(up.Map, "logging") } @@ -211,6 +213,7 @@ var SpacedBlocks = [][]string{ {"backfill"}, {"double_puppet"}, {"encryption"}, + {"env_config_prefix"}, {"logging"}, } diff --git a/bridgev2/matrix/mxmain/envconfig.go b/bridgev2/matrix/mxmain/envconfig.go new file mode 100644 index 00000000..1b4f1467 --- /dev/null +++ b/bridgev2/matrix/mxmain/envconfig.go @@ -0,0 +1,161 @@ +// Copyright (c) 2025 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mxmain + +import ( + "fmt" + "iter" + "os" + "reflect" + "strconv" + "strings" + + "go.mau.fi/util/random" +) + +var randomParseFilePrefix = random.String(16) + "READFILE:" + +func parseEnv(prefix string) iter.Seq2[[]string, string] { + return func(yield func([]string, string) bool) { + for _, s := range os.Environ() { + if !strings.HasPrefix(s, prefix) { + continue + } + kv := strings.SplitN(s, "=", 2) + key := strings.TrimPrefix(kv[0], prefix) + value := kv[1] + if strings.HasSuffix(key, "_FILE") { + key = strings.TrimSuffix(key, "_FILE") + value = randomParseFilePrefix + value + } + key = strings.ToLower(key) + if !strings.ContainsRune(key, '.') { + key = strings.ReplaceAll(key, "__", ".") + } + if !yield(strings.Split(key, "."), value) { + return + } + } + } +} + +func reflectYAMLFieldName(f *reflect.StructField) string { + parts := strings.SplitN(f.Tag.Get("yaml"), ",", 2) + fieldName := parts[0] + if fieldName == "-" && len(parts) == 1 { + return "" + } + if fieldName == "" { + return strings.ToLower(f.Name) + } + return fieldName +} + +type reflectGetResult struct { + val reflect.Value + valKind reflect.Kind + remainingPath []string +} + +func reflectGetYAML(rv reflect.Value, path []string) (*reflectGetResult, bool) { + if len(path) == 0 { + return &reflectGetResult{val: rv, valKind: rv.Kind()}, true + } + if rv.Kind() == reflect.Ptr { + rv = rv.Elem() + } + switch rv.Kind() { + case reflect.Map: + return &reflectGetResult{val: rv, remainingPath: path, valKind: rv.Type().Elem().Kind()}, true + case reflect.Struct: + fields := reflect.VisibleFields(rv.Type()) + for _, field := range fields { + fieldName := reflectYAMLFieldName(&field) + if fieldName != "" && fieldName == path[0] { + return reflectGetYAML(rv.FieldByIndex(field.Index), path[1:]) + } + } + default: + } + return nil, false +} + +func reflectGetFromMainOrNetwork(main, network reflect.Value, path []string) (*reflectGetResult, bool) { + if len(path) > 0 && path[0] == "network" { + return reflectGetYAML(network, path[1:]) + } + return reflectGetYAML(main, path) +} + +func formatKeyString(key []string) string { + return strings.Join(key, "->") +} + +func UpdateConfigFromEnv(cfg, networkData any, prefix string) error { + cfgVal := reflect.ValueOf(cfg) + networkVal := reflect.ValueOf(networkData) + for key, value := range parseEnv(prefix) { + field, ok := reflectGetFromMainOrNetwork(cfgVal, networkVal, key) + if !ok { + return fmt.Errorf("%s not found", formatKeyString(key)) + } + if strings.HasPrefix(value, randomParseFilePrefix) { + filepath := strings.TrimPrefix(value, randomParseFilePrefix) + fileData, err := os.ReadFile(filepath) + if err != nil { + return fmt.Errorf("failed to read file %s for %s: %w", filepath, formatKeyString(key), err) + } + value = strings.TrimSpace(string(fileData)) + } + var parsedVal any + var err error + switch field.valKind { + case reflect.String: + parsedVal = value + case reflect.Bool: + parsedVal, err = strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + parsedVal, err = strconv.ParseInt(value, 10, 64) + if err != nil { + return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + parsedVal, err = strconv.ParseUint(value, 10, 64) + if err != nil { + return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) + } + case reflect.Float32, reflect.Float64: + parsedVal, err = strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("invalid value for %s: %w", formatKeyString(key), err) + } + default: + return fmt.Errorf("unsupported type %s in %s", field.valKind, formatKeyString(key)) + } + if field.val.Kind() == reflect.Ptr { + if field.val.IsNil() { + field.val.Set(reflect.New(field.val.Type().Elem())) + } + field.val = field.val.Elem() + } + if field.val.Kind() == reflect.Map { + key = key[:len(key)-len(field.remainingPath)] + mapKeyStr := strings.Join(field.remainingPath, ".") + key = append(key, mapKeyStr) + if field.val.Type().Key().Kind() != reflect.String { + return fmt.Errorf("unsupported map key type %s in %s", field.val.Type().Key().Kind(), formatKeyString(key)) + } + field.val.SetMapIndex(reflect.ValueOf(mapKeyStr), reflect.ValueOf(parsedVal)) + } else { + field.val.Set(reflect.ValueOf(parsedVal)) + } + } + return nil +} diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 27c3aa67..947d771b 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -444,6 +444,16 @@ encryption: # You should not enable this option unless you understand all the implications. disable_device_change_key_rotation: false +# Prefix for environment variables. All variables with this prefix must map to valid config fields. +# Nesting in variable names is represented with a dot (.). +# If there are no dots in the name, two underscores (__) are replaced with a dot. +# +# e.g. if the prefix is set to `BRIDGE_`, then `BRIDGE_APPSERVICE__AS_TOKEN` will set appservice.as_token. +# `BRIDGE_appservice.as_token` would work as well, but can't be set in a shell as easily. +# +# If this is null, reading config fields from environment will be disabled. +env_config_prefix: null + # Logging config. See https://github.com/tulir/zeroconfig for details. logging: min_level: debug diff --git a/bridgev2/matrix/mxmain/main.go b/bridgev2/matrix/mxmain/main.go index ca0ca5f7..1e8b51d1 100644 --- a/bridgev2/matrix/mxmain/main.go +++ b/bridgev2/matrix/mxmain/main.go @@ -354,6 +354,13 @@ func (br *BridgeMain) LoadConfig() { } } cfg.Bridge.Backfill = cfg.Backfill + if cfg.EnvConfigPrefix != "" { + err = UpdateConfigFromEnv(&cfg, networkData, cfg.EnvConfigPrefix) + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "Failed to parse environment variables:", err) + os.Exit(10) + } + } br.Config = &cfg } From 02ce6ff9185113f31c9f3a55b7e2e5e6fbd4101c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 3 Dec 2025 21:59:41 +0200 Subject: [PATCH 1538/1647] mediaproxy: allow delayed mime type and redirects for file responses --- mediaproxy/mediaproxy.go | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index 07e30810..2063675a 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -95,9 +95,13 @@ func (d *GetMediaResponseCallback) GetContentType() string { return d.ContentType } +type FileMeta struct { + ContentType string + ReplacementFile string +} + type GetMediaResponseFile struct { - Callback func(w *os.File) error - ContentType string + Callback func(w *os.File) (*FileMeta, error) } type GetMediaFunc = func(ctx context.Context, mediaID string, params map[string]string) (response GetMediaResponse, err error) @@ -453,23 +457,35 @@ func doTempFileDownload( if err != nil { return false, fmt.Errorf("failed to create temp file: %w", err) } + origTempFile := tempFile defer func() { - _ = tempFile.Close() - _ = os.Remove(tempFile.Name()) + _ = origTempFile.Close() + _ = os.Remove(origTempFile.Name()) }() - err = data.Callback(tempFile) + meta, err := data.Callback(tempFile) if err != nil { return false, err } - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - return false, fmt.Errorf("failed to seek to start of temp file: %w", err) + if meta.ReplacementFile != "" { + tempFile, err = os.Open(meta.ReplacementFile) + if err != nil { + return false, fmt.Errorf("failed to open replacement file: %w", err) + } + defer func() { + _ = tempFile.Close() + _ = os.Remove(origTempFile.Name()) + }() + } else { + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return false, fmt.Errorf("failed to seek to start of temp file: %w", err) + } } fileInfo, err := tempFile.Stat() if err != nil { return false, fmt.Errorf("failed to stat temp file: %w", err) } - mimeType := data.ContentType + mimeType := meta.ContentType if mimeType == "" { buf := make([]byte, 512) n, err := tempFile.Read(buf) From f6d8362278ab843dc9f5c919a5af71ea39ebf993 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Fri, 5 Dec 2025 11:36:43 +0000 Subject: [PATCH 1539/1647] client: add missing retry cancel check while backing off requests --- client.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index ba67a205..9961e717 100644 --- a/client.go +++ b/client.go @@ -614,7 +614,9 @@ func (cli *Client) doRetry( select { case <-time.After(backoff): case <-req.Context().Done(): - return nil, nil, req.Context().Err() + if !errors.Is(context.Cause(req.Context()), ErrContextCancelRetry) { + return nil, nil, req.Context().Err() + } } if cli.UpdateRequestOnRetry != nil { req = cli.UpdateRequestOnRetry(req, cause) From 4efa4bdac5e3821232a04ac73c7270fd457fa764 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 6 Dec 2025 12:51:01 +0200 Subject: [PATCH 1540/1647] bridgev2/config: allow multiple prioritized backfill limit override keys --- bridgev2/bridgeconfig/backfill.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/bridgev2/bridgeconfig/backfill.go b/bridgev2/bridgeconfig/backfill.go index 53282e41..eedae1e8 100644 --- a/bridgev2/bridgeconfig/backfill.go +++ b/bridgev2/bridgeconfig/backfill.go @@ -34,10 +34,12 @@ type BackfillQueueConfig struct { MaxBatchesOverride map[string]int `yaml:"max_batches_override"` } -func (bqc *BackfillQueueConfig) GetOverride(name string) int { - override, ok := bqc.MaxBatchesOverride[name] - if !ok { - return bqc.MaxBatches +func (bqc *BackfillQueueConfig) GetOverride(names ...string) int { + for _, name := range names { + override, ok := bqc.MaxBatchesOverride[name] + if ok { + return override + } } - return override + return bqc.MaxBatches } From 3e07631f9e178807e204ab92687bd6f69f385b78 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 6 Dec 2025 22:58:11 +0200 Subject: [PATCH 1541/1647] bridgev2/mxmain: add better error for pre-megabridge dbs --- bridgev2/matrix/mxmain/dberror.go | 7 ++++++- go.mod | 2 +- go.sum | 4 ++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/bridgev2/matrix/mxmain/dberror.go b/bridgev2/matrix/mxmain/dberror.go index 0f6aa68c..f5e438de 100644 --- a/bridgev2/matrix/mxmain/dberror.go +++ b/bridgev2/matrix/mxmain/dberror.go @@ -66,7 +66,12 @@ func (br *BridgeMain) LogDBUpgradeErrorAndExit(name string, err error, message s } else if errors.Is(err, dbutil.ErrForeignTables) { br.Log.Info().Msg("See https://docs.mau.fi/faq/foreign-tables for more info") } else if errors.Is(err, dbutil.ErrNotOwned) { - br.Log.Info().Msg("Sharing the same database with different programs is not supported") + var noe dbutil.NotOwnedError + if errors.As(err, &noe) && noe.Owner == br.Name { + br.Log.Info().Msg("The database appears to be on a very old pre-megabridge schema. Perhaps you need to run an older version of the bridge with migration support first?") + } else { + br.Log.Info().Msg("Sharing the same database with different programs is not supported") + } } else if errors.Is(err, dbutil.ErrUnsupportedDatabaseVersion) { br.Log.Info().Msg("Downgrading the bridge is not supported") } diff --git a/go.mod b/go.mod index d873892c..bf56a014 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.13 - go.mau.fi/util v0.9.4-0.20251128113707-115b8b18bd18 + go.mau.fi/util v0.9.4-0.20251206205611-85e6fd6551e0 go.mau.fi/zeroconfig v0.2.0 golang.org/x/crypto v0.45.0 golang.org/x/exp v0.0.0-20251125195548-87e1e737ad39 diff --git a/go.sum b/go.sum index fae6084d..6ea3f378 100644 --- a/go.sum +++ b/go.sum @@ -51,8 +51,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.4-0.20251128113707-115b8b18bd18 h1:h1/wE/SLTuat12/SRsKyh+edWX2Aung1ZsiWnY3t5Zs= -go.mau.fi/util v0.9.4-0.20251128113707-115b8b18bd18/go.mod h1:viDmhBOAFfcqDdKSk53EPJV3N4Mi8Jst5/ahGJ/vwsA= +go.mau.fi/util v0.9.4-0.20251206205611-85e6fd6551e0 h1:ESebxPGULuuxxcZigjcBFyyU62tiyY6ivtX17P4BkvY= +go.mau.fi/util v0.9.4-0.20251206205611-85e6fd6551e0/go.mod h1:viDmhBOAFfcqDdKSk53EPJV3N4Mi8Jst5/ahGJ/vwsA= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= From a2522192ff84512fa671844094c65a4775dae435 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 7 Dec 2025 19:34:29 +0200 Subject: [PATCH 1542/1647] bridgev2/config: fix warning log for null env_config_prefix --- bridgev2/bridgeconfig/upgrade.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index a3ac8747..960e2fb4 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -184,7 +184,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Int, "encryption", "rotation", "messages") helper.Copy(up.Bool, "encryption", "rotation", "disable_device_change_key_rotation") - helper.Copy(up.Str, "env_config_prefix") + helper.Copy(up.Str|up.Null, "env_config_prefix") helper.Copy(up.Map, "logging") } From 0584fd0c0d6e43adae98ec256e288001de1d89c6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 7 Dec 2025 19:52:08 +0200 Subject: [PATCH 1543/1647] bridgev2/portal: don't forward backfill without CanBackfill flag --- bridgev2/portal.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index c0855c2d..b6f60b78 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -5089,7 +5089,10 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo } } portal.addToUserSpaces(ctx) - if portal.Bridge.Config.Backfill.Enabled && portal.RoomType != database.RoomTypeSpace && !portal.Bridge.Background { + if info.CanBackfill && + portal.Bridge.Config.Backfill.Enabled && + portal.RoomType != database.RoomTypeSpace && + !portal.Bridge.Background { portal.doForwardBackfill(ctx, source, nil, backfillBundle) } return nil From 00c58efc59068c72f08db7d01d854f40ec453812 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 7 Dec 2025 19:52:22 +0200 Subject: [PATCH 1544/1647] bridgev2/portal: don't try to update functional members if portal doesn't exist --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index b6f60b78..ad67b773 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2427,7 +2427,7 @@ func (portal *Portal) handleRemoteEvent(ctx context.Context, source *UserLogin, } func (portal *Portal) ensureFunctionalMember(ctx context.Context, ghost *Ghost) { - if !ghost.IsBot || portal.RoomType != database.RoomTypeDM || portal.OtherUserID == ghost.ID { + if !ghost.IsBot || portal.RoomType != database.RoomTypeDM || portal.OtherUserID == ghost.ID || portal.MXID == "" { return } ars, ok := portal.Bridge.Matrix.(MatrixConnectorWithArbitraryRoomState) From 6017612c552b6ca7b9f3786bd8cd669358153d0f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 7 Dec 2025 23:21:05 +0200 Subject: [PATCH 1545/1647] bridgev2/portal: only delete old reactions if new one is successful --- bridgev2/portal.go | 46 +++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index ad67b773..0d71535d 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1466,7 +1466,7 @@ func (portal *Portal) handleMatrixEdit( return EventHandlingResultSuccess } -func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) EventHandlingResult { +func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogin, evt *event.Event) (handleRes EventHandlingResult) { log := zerolog.Ctx(ctx) reactingAPI, ok := sender.Client.(ReactionHandlingNetworkAPI) if !ok { @@ -1511,6 +1511,25 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if portal.Bridge.Config.OutgoingMessageReID { deterministicID = portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, reactionTarget, preResp.SenderID, preResp.EmojiID) } + removeOutdatedReaction := func(oldReact *database.Reaction, deleteDB bool) { + if !handleRes.Success { + return + } + _, err := portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ + Parsed: &event.RedactionEventContent{ + Redacts: oldReact.MXID, + }, + }, nil) + if err != nil { + log.Err(err).Msg("Failed to remove old reaction") + } + if deleteDB { + err = portal.Bridge.DB.Reaction.Delete(ctx, oldReact) + if err != nil { + log.Err(err).Msg("Failed to delete old reaction from database") + } + } + } existing, err := portal.Bridge.DB.Reaction.GetByID(ctx, portal.Receiver, reactionTarget.ID, reactionTarget.PartID, preResp.SenderID, preResp.EmojiID) if err != nil { log.Err(err).Msg("Failed to check if reaction is a duplicate") @@ -1522,14 +1541,7 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi return EventHandlingResultIgnored.WithEventID(deterministicID) } react.ReactionToOverride = existing - _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ - Parsed: &event.RedactionEventContent{ - Redacts: existing.MXID, - }, - }, nil) - if err != nil { - log.Err(err).Msg("Failed to remove old reaction") - } + defer removeOutdatedReaction(existing, false) } react.PreHandleResp = &preResp if preResp.MaxReactions > 0 { @@ -1544,18 +1556,10 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi // Keep n-1 previous reactions and remove the rest react.ExistingReactionsToKeep = allReactions[:preResp.MaxReactions-1] for _, oldReaction := range allReactions[preResp.MaxReactions-1:] { - _, err = portal.Bridge.Bot.SendMessage(ctx, portal.MXID, event.EventRedaction, &event.Content{ - Parsed: &event.RedactionEventContent{ - Redacts: oldReaction.MXID, - }, - }, nil) - if err != nil { - log.Err(err).Msg("Failed to remove previous reaction after limit was exceeded") - } - err = portal.Bridge.DB.Reaction.Delete(ctx, oldReaction) - if err != nil { - log.Err(err).Msg("Failed to delete previous reaction from database after limit was exceeded") - } + // Intentionally defer in a loop, there won't be that many items, + // and we want all of them to be done after this function completes successfully + //goland:noinspection GoDeferInLoop + defer removeOutdatedReaction(oldReaction, true) } } } From 315d2ab17d338f6aef6026929e8671726cd76ba7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 8 Dec 2025 00:07:07 +0200 Subject: [PATCH 1546/1647] all: fix staticcheck issues --- .github/workflows/go.yml | 1 + .pre-commit-config.yaml | 3 +- appservice/websocket.go | 2 +- bridgev2/bridgeconfig/permissions.go | 5 +- bridgev2/bridgestate.go | 6 +- bridgev2/database/database.go | 58 ------------ bridgev2/matrix/crypto.go | 8 +- bridgev2/matrix/matrix.go | 1 + bridgev2/matrix/mxmain/legacymigrate.go | 5 +- bridgev2/matrix/provisioning.go | 3 +- bridgev2/portalbackfill.go | 1 + bridgev2/portalreid.go | 2 +- client.go | 10 +- crypto/attachment/attachments.go | 49 ++++++---- crypto/attachment/attachments_test.go | 10 +- crypto/cross_sign_pubkey.go | 4 +- crypto/cross_sign_store.go | 30 +++--- crypto/cryptohelper/cryptohelper.go | 5 +- crypto/decryptmegolm.go | 57 +++++++----- crypto/decryptolm.go | 54 ++++++----- crypto/devicelist.go | 35 ++++--- crypto/encryptmegolm.go | 11 ++- crypto/goolm/account/account.go | 2 +- crypto/goolm/account/account_test.go | 2 +- crypto/goolm/goolmbase64/base64.go | 4 +- crypto/goolm/libolmpickle/picklejson.go | 2 +- crypto/goolm/message/session_export.go | 2 +- crypto/goolm/message/session_sharing.go | 2 +- crypto/goolm/pk/decryption.go | 2 +- crypto/goolm/pk/encryption.go | 3 + .../goolm/session/megolm_inbound_session.go | 4 +- .../goolm/session/megolm_outbound_session.go | 6 +- crypto/goolm/session/olm_session.go | 13 ++- crypto/goolm/session/register.go | 8 +- crypto/keybackup.go | 11 ++- crypto/libolm/account.go | 22 ++--- crypto/libolm/error.go | 30 +++--- crypto/libolm/inboundgroupsession.go | 18 ++-- crypto/libolm/outboundgroupsession.go | 8 +- crypto/libolm/pk.go | 2 +- crypto/libolm/register.go | 2 +- crypto/libolm/session.go | 20 ++-- crypto/machine.go | 2 +- crypto/olm/errors.go | 93 +++++++++++-------- crypto/sessions.go | 14 ++- event/encryption.go | 2 +- federation/resolution.go | 5 +- filter.go | 2 +- id/contenturi.go | 20 ++-- id/matrixuri.go | 2 +- id/userid.go | 8 +- pushrules/action.go | 2 +- pushrules/condition_test.go | 8 -- room.go | 6 +- synapseadmin/roomapi.go | 3 +- url.go | 6 +- 56 files changed, 358 insertions(+), 338 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index dc4f17e2..8bce4484 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -24,6 +24,7 @@ jobs: - name: Install goimports run: | go install golang.org/x/tools/cmd/goimports@latest + go install honnef.co/go/tools/cmd/staticcheck@latest export PATH="$HOME/go/bin:$PATH" - name: Run pre-commit diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0b9785ae..4f769e56 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,8 +18,7 @@ repos: - "-w" - id: go-vet-repo-mod - id: go-mod-tidy - # TODO enable this - #- id: go-staticcheck-repo-mod + - id: go-staticcheck-repo-mod - repo: https://github.com/beeper/pre-commit-go rev: v0.4.2 diff --git a/appservice/websocket.go b/appservice/websocket.go index 1e401c53..4f2538bf 100644 --- a/appservice/websocket.go +++ b/appservice/websocket.go @@ -56,7 +56,7 @@ func (wsc *WebsocketCommand) MakeResponse(ok bool, data any) *WebsocketRequest { var prefixMessage string for unwrappedErr != nil { errorData, jsonErr = json.Marshal(unwrappedErr) - if errorData != nil && len(errorData) > 2 && jsonErr == nil { + if len(errorData) > 2 && jsonErr == nil { prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1) prefixMessage = strings.TrimRight(prefixMessage, ": ") break diff --git a/bridgev2/bridgeconfig/permissions.go b/bridgev2/bridgeconfig/permissions.go index 898bf58a..9efe068e 100644 --- a/bridgev2/bridgeconfig/permissions.go +++ b/bridgev2/bridgeconfig/permissions.go @@ -41,10 +41,7 @@ func (pc PermissionConfig) IsConfigured() bool { _, hasExampleDomain := pc["example.com"] _, hasExampleUser := pc["@admin:example.com"] exampleLen := boolToInt(hasWildcard) + boolToInt(hasExampleUser) + boolToInt(hasExampleDomain) - if len(pc) <= exampleLen { - return false - } - return true + return len(pc) > exampleLen } func (pc PermissionConfig) Get(userID id.UserID) Permissions { diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index a1d3e70b..babbccab 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -102,9 +102,9 @@ func (bsq *BridgeStateQueue) loop() { } } -func (bsq *BridgeStateQueue) scheduleNotice(ctx context.Context, triggeredBy status.BridgeState) { +func (bsq *BridgeStateQueue) scheduleNotice(triggeredBy status.BridgeState) { log := bsq.login.Log.With().Str("action", "transient disconnect notice").Logger() - ctx = log.WithContext(bsq.bridge.BackgroundCtx) + ctx := log.WithContext(bsq.bridge.BackgroundCtx) if !bsq.waitForTransientDisconnectReconnect(ctx) { return } @@ -135,7 +135,7 @@ func (bsq *BridgeStateQueue) sendNotice(ctx context.Context, state status.Bridge if bsq.firstTransientDisconnect.IsZero() { bsq.firstTransientDisconnect = time.Now() } - go bsq.scheduleNotice(ctx, state) + go bsq.scheduleNotice(state) } return } diff --git a/bridgev2/database/database.go b/bridgev2/database/database.go index 0729cb83..05abddf0 100644 --- a/bridgev2/database/database.go +++ b/bridgev2/database/database.go @@ -7,13 +7,7 @@ package database import ( - "encoding/json" - "reflect" - "strings" - "go.mau.fi/util/dbutil" - "golang.org/x/exp/constraints" - "golang.org/x/exp/maps" "maunium.net/go/mautrix/bridgev2/networkid" @@ -158,55 +152,3 @@ func ensureBridgeIDMatches(ptr *networkid.BridgeID, expected networkid.BridgeID) panic("bridge ID mismatch") } } - -func GetNumberFromMap[T constraints.Integer | constraints.Float](m map[string]any, key string) (T, bool) { - if val, found := m[key]; found { - floatVal, ok := val.(float64) - if ok { - return T(floatVal), true - } - tVal, ok := val.(T) - if ok { - return tVal, true - } - } - return 0, false -} - -func unmarshalMerge(input []byte, data any, extra *map[string]any) error { - err := json.Unmarshal(input, data) - if err != nil { - return err - } - err = json.Unmarshal(input, extra) - if err != nil { - return err - } - if *extra == nil { - *extra = make(map[string]any) - } - return nil -} - -func marshalMerge(data any, extra map[string]any) ([]byte, error) { - if extra == nil { - return json.Marshal(data) - } - merged := make(map[string]any) - maps.Copy(merged, extra) - dataRef := reflect.ValueOf(data).Elem() - dataType := dataRef.Type() - for _, field := range reflect.VisibleFields(dataType) { - parts := strings.Split(field.Tag.Get("json"), ",") - if len(parts) == 0 || len(parts[0]) == 0 || parts[0] == "-" { - continue - } - fieldVal := dataRef.FieldByIndex(field.Index) - if fieldVal.IsZero() { - delete(merged, parts[0]) - } else { - merged[parts[0]] = fieldVal.Interface() - } - } - return json.Marshal(merged) -} diff --git a/bridgev2/matrix/crypto.go b/bridgev2/matrix/crypto.go index f4a2e9a0..7f18f1f5 100644 --- a/bridgev2/matrix/crypto.go +++ b/bridgev2/matrix/crypto.go @@ -38,9 +38,9 @@ func init() { var _ crypto.StateStore = (*sqlstatestore.SQLStateStore)(nil) -var NoSessionFound = crypto.NoSessionFound -var DuplicateMessageIndex = crypto.DuplicateMessageIndex -var UnknownMessageIndex = olm.UnknownMessageIndex +var NoSessionFound = crypto.ErrNoSessionFound +var DuplicateMessageIndex = crypto.ErrDuplicateMessageIndex +var UnknownMessageIndex = olm.ErrUnknownMessageIndex type CryptoHelper struct { bridge *Connector @@ -439,7 +439,7 @@ func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtTy var encrypted *event.EncryptedEventContent encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content) if err != nil { - if !errors.Is(err, crypto.SessionExpired) && !errors.Is(err, crypto.SessionNotShared) && !errors.Is(err, crypto.NoGroupSession) { + if !errors.Is(err, crypto.ErrSessionExpired) && !errors.Is(err, crypto.ErrSessionNotShared) && !errors.Is(err, crypto.ErrNoGroupSession) { return } helper.log.Debug().Err(err). diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index 6c94bccc..570ae5f1 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -127,6 +127,7 @@ func (br *Connector) waitLongerForSession(ctx context.Context, evt *event.Event, Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())). Msg("Couldn't find session, requesting keys and waiting longer...") + //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank go br.Crypto.RequestSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) go br.sendCryptoStatusError(ctx, evt, fmt.Errorf("%w. The bridge will retry for %d seconds", errNoDecryptionKeys, int(extendedSessionWaitTimeout.Seconds())), errorEventID, 1, false) diff --git a/bridgev2/matrix/mxmain/legacymigrate.go b/bridgev2/matrix/mxmain/legacymigrate.go index c8eb820b..97cdeddf 100644 --- a/bridgev2/matrix/mxmain/legacymigrate.go +++ b/bridgev2/matrix/mxmain/legacymigrate.go @@ -135,7 +135,10 @@ func (br *BridgeMain) CheckLegacyDB( } var dbVersion int err = br.DB.QueryRow(ctx, "SELECT version FROM version").Scan(&dbVersion) - if dbVersion < expectedVersion { + if err != nil { + log.Fatal().Err(err).Msg("Failed to get database version") + return + } else if dbVersion < expectedVersion { log.Fatal(). Int("expected_version", expectedVersion). Int("version", dbVersion). diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 43d19380..44e00e64 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -85,10 +85,9 @@ const ( provisioningUserKey provisioningContextKey = iota provisioningUserLoginKey provisioningLoginProcessKey + ProvisioningKeyRequest ) -const ProvisioningKeyRequest = "fi.mau.provision.request" - func (prov *ProvisioningAPI) GetUser(r *http.Request) *bridgev2.User { return r.Context().Value(provisioningUserKey).(*bridgev2.User) } diff --git a/bridgev2/portalbackfill.go b/bridgev2/portalbackfill.go index e8292388..879f07ae 100644 --- a/bridgev2/portalbackfill.go +++ b/bridgev2/portalbackfill.go @@ -410,6 +410,7 @@ func (portal *Portal) compileBatchMessage(ctx context.Context, source *UserLogin if reaction.Timestamp.IsZero() { reaction.Timestamp = msg.Timestamp.Add(10 * time.Millisecond) } + //lint:ignore SA4006 it's a todo targetPart, ok := partMap[*reaction.TargetPart] if !ok { // TODO warning log and/or skip reaction? diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go index a25fe820..d1a9d5a6 100644 --- a/bridgev2/portalreid.go +++ b/bridgev2/portalreid.go @@ -96,7 +96,7 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta go func() { _, err := br.Bot.SendState(ctx, sourcePortal.MXID, event.StateTombstone, "", &event.Content{ Parsed: &event.TombstoneEventContent{ - Body: fmt.Sprintf("This room has been merged"), + Body: "This room has been merged", ReplacementRoom: targetPortal.MXID, }, }, time.Now()) diff --git a/client.go b/client.go index 9961e717..b740cba6 100644 --- a/client.go +++ b/client.go @@ -742,7 +742,7 @@ func (cli *Client) executeCompiledRequest( cli.RequestStart(req) startTime := time.Now() res, err := client.Do(req) - duration := time.Now().Sub(startTime) + duration := time.Since(startTime) if res != nil && !dontReadResponse { defer res.Body.Close() } @@ -862,7 +862,7 @@ func (cli *Client) FullSyncRequest(ctx context.Context, req ReqSync) (resp *Resp } start := time.Now() _, err = cli.MakeFullRequest(ctx, fullReq) - duration := time.Now().Sub(start) + duration := time.Since(start) timeout := time.Duration(req.Timeout) * time.Millisecond buffer := 10 * time.Second if req.Since == "" { @@ -966,7 +966,7 @@ func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRe // } // token := res.AccessToken func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRegister, error) { - res, uia, err := cli.Register(ctx, req) + _, uia, err := cli.Register(ctx, req) if err != nil && uia == nil { return nil, err } else if uia == nil { @@ -975,7 +975,7 @@ func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRe return nil, errors.New("server does not support m.login.dummy") } req.Auth = BaseAuthData{Type: AuthTypeDummy, Session: uia.Session} - res, _, err = cli.Register(ctx, req) + res, _, err := cli.Register(ctx, req) if err != nil { return nil, err } @@ -1751,6 +1751,8 @@ func parseRoomStateArray(req *http.Request, res *http.Response, responseJSON any return nil, nil } +type RoomStateMap = map[event.Type]map[string]*event.Event + // State gets all state in a room. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomStateMap, err error) { diff --git a/crypto/attachment/attachments.go b/crypto/attachment/attachments.go index 155cca5c..727aacbf 100644 --- a/crypto/attachment/attachments.go +++ b/crypto/attachment/attachments.go @@ -21,13 +21,24 @@ import ( ) var ( - HashMismatch = errors.New("mismatching SHA-256 digest") - UnsupportedVersion = errors.New("unsupported Matrix file encryption version") - UnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm") - InvalidKey = errors.New("failed to decode key") - InvalidInitVector = errors.New("failed to decode initialization vector") - InvalidHash = errors.New("failed to decode SHA-256 hash") - ReaderClosed = errors.New("encrypting reader was already closed") + ErrHashMismatch = errors.New("mismatching SHA-256 digest") + ErrUnsupportedVersion = errors.New("unsupported Matrix file encryption version") + ErrUnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm") + ErrInvalidKey = errors.New("failed to decode key") + ErrInvalidInitVector = errors.New("failed to decode initialization vector") + ErrInvalidHash = errors.New("failed to decode SHA-256 hash") + ErrReaderClosed = errors.New("encrypting reader was already closed") +) + +// Deprecated: use variables prefixed with Err +var ( + HashMismatch = ErrHashMismatch + UnsupportedVersion = ErrUnsupportedVersion + UnsupportedAlgorithm = ErrUnsupportedAlgorithm + InvalidKey = ErrInvalidKey + InvalidInitVector = ErrInvalidInitVector + InvalidHash = ErrInvalidHash + ReaderClosed = ErrReaderClosed ) var ( @@ -85,25 +96,25 @@ func (ef *EncryptedFile) decodeKeys(includeHash bool) error { if ef.decoded != nil { return nil } else if len(ef.Key.Key) != keyBase64Length { - return InvalidKey + return ErrInvalidKey } else if len(ef.InitVector) != ivBase64Length { - return InvalidInitVector + return ErrInvalidInitVector } else if includeHash && len(ef.Hashes.SHA256) != hashBase64Length { - return InvalidHash + return ErrInvalidHash } ef.decoded = &decodedKeys{} _, err := base64.RawURLEncoding.Decode(ef.decoded.key[:], []byte(ef.Key.Key)) if err != nil { - return InvalidKey + return ErrInvalidKey } _, err = base64.RawStdEncoding.Decode(ef.decoded.iv[:], []byte(ef.InitVector)) if err != nil { - return InvalidInitVector + return ErrInvalidInitVector } if includeHash { _, err = base64.RawStdEncoding.Decode(ef.decoded.sha256[:], []byte(ef.Hashes.SHA256)) if err != nil { - return InvalidHash + return ErrInvalidHash } } return nil @@ -179,7 +190,7 @@ var _ io.ReadSeekCloser = (*encryptingReader)(nil) func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) { if r.closed { - return 0, ReaderClosed + return 0, ErrReaderClosed } if offset != 0 || whence != io.SeekStart { return 0, fmt.Errorf("attachments.EncryptStream: only seeking to the beginning is supported") @@ -200,7 +211,7 @@ func (r *encryptingReader) Seek(offset int64, whence int) (int64, error) { func (r *encryptingReader) Read(dst []byte) (n int, err error) { if r.closed { - return 0, ReaderClosed + return 0, ErrReaderClosed } else if r.isDecrypting && r.file.decoded == nil { if err = r.file.PrepareForDecryption(); err != nil { return @@ -224,7 +235,7 @@ func (r *encryptingReader) Close() (err error) { } if r.isDecrypting { if !hmac.Equal(r.hash.Sum(nil), r.file.decoded.sha256[:]) { - return HashMismatch + return ErrHashMismatch } } else { r.file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(r.hash.Sum(nil)) @@ -265,9 +276,9 @@ func (ef *EncryptedFile) Decrypt(ciphertext []byte) ([]byte, error) { // DecryptInPlace will always call this automatically, so calling this manually is not necessary when using that function. func (ef *EncryptedFile) PrepareForDecryption() error { if ef.Version != "v2" { - return UnsupportedVersion + return ErrUnsupportedVersion } else if ef.Key.Algorithm != "A256CTR" { - return UnsupportedAlgorithm + return ErrUnsupportedAlgorithm } else if err := ef.decodeKeys(true); err != nil { return err } @@ -281,7 +292,7 @@ func (ef *EncryptedFile) DecryptInPlace(data []byte) error { } dataHash := sha256.Sum256(data) if !hmac.Equal(ef.decoded.sha256[:], dataHash[:]) { - return HashMismatch + return ErrHashMismatch } utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv) return nil diff --git a/crypto/attachment/attachments_test.go b/crypto/attachment/attachments_test.go index d7f1394a..9fe929ab 100644 --- a/crypto/attachment/attachments_test.go +++ b/crypto/attachment/attachments_test.go @@ -53,33 +53,33 @@ func TestUnsupportedVersion(t *testing.T) { file := parseHelloWorld() file.Version = "foo" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, UnsupportedVersion) + assert.ErrorIs(t, err, ErrUnsupportedVersion) } func TestUnsupportedAlgorithm(t *testing.T) { file := parseHelloWorld() file.Key.Algorithm = "bar" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, UnsupportedAlgorithm) + assert.ErrorIs(t, err, ErrUnsupportedAlgorithm) } func TestHashMismatch(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString([]byte(random32Bytes)) err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, HashMismatch) + assert.ErrorIs(t, err, ErrHashMismatch) } func TestTooLongHash(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = "TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNlY3RldHVlciBhZGlwaXNjaW5nIGVsaXQuIFNlZCBwb3N1ZXJlIGludGVyZHVtIHNlbS4gUXVpc3F1ZSBsaWd1bGEgZXJvcyB1bGxhbWNvcnBlciBxdWlzLCBsYWNpbmlhIHF1aXMgZmFjaWxpc2lzIHNlZCBzYXBpZW4uCg" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, InvalidHash) + assert.ErrorIs(t, err, ErrInvalidHash) } func TestTooShortHash(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = "5/Gy1JftyyQ" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) - assert.ErrorIs(t, err, InvalidHash) + assert.ErrorIs(t, err, ErrInvalidHash) } diff --git a/crypto/cross_sign_pubkey.go b/crypto/cross_sign_pubkey.go index f85d1ea3..223fc7b5 100644 --- a/crypto/cross_sign_pubkey.go +++ b/crypto/cross_sign_pubkey.go @@ -63,8 +63,8 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id if len(dbKeys) > 0 { masterKey, ok := dbKeys[id.XSUsageMaster] if ok { - selfSigning, _ := dbKeys[id.XSUsageSelfSigning] - userSigning, _ := dbKeys[id.XSUsageUserSigning] + selfSigning := dbKeys[id.XSUsageSelfSigning] + userSigning := dbKeys[id.XSUsageUserSigning] return &CrossSigningPublicKeysCache{ MasterKey: masterKey.Key, SelfSigningKey: selfSigning.Key, diff --git a/crypto/cross_sign_store.go b/crypto/cross_sign_store.go index d30b7e32..57406b11 100644 --- a/crypto/cross_sign_store.go +++ b/crypto/cross_sign_store.go @@ -26,24 +26,22 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK log.Error().Err(err). Msg("Error fetching current cross-signing keys of user") } - if currentKeys != nil { - for curKeyUsage, curKey := range currentKeys { - log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger() - // got a new key with the same usage as an existing key - for _, newKeyUsage := range userKeys.Usage { - if newKeyUsage == curKeyUsage { - if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok { - // old key is not in the new key map, so we drop signatures made by it - if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil { - log.Error().Err(err).Msg("Error deleting old signatures made by user") - } else { - log.Debug(). - Int64("signature_count", count). - Msg("Dropped signatures made by old key as it has been replaced") - } + for curKeyUsage, curKey := range currentKeys { + log := log.With().Stringer("old_key", curKey.Key).Str("old_key_usage", string(curKeyUsage)).Logger() + // got a new key with the same usage as an existing key + for _, newKeyUsage := range userKeys.Usage { + if newKeyUsage == curKeyUsage { + if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok { + // old key is not in the new key map, so we drop signatures made by it + if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil { + log.Error().Err(err).Msg("Error deleting old signatures made by user") + } else { + log.Debug(). + Int64("signature_count", count). + Msg("Dropped signatures made by old key as it has been replaced") } - break } + break } } } diff --git a/crypto/cryptohelper/cryptohelper.go b/crypto/cryptohelper/cryptohelper.go index 1939ea79..b62dc128 100644 --- a/crypto/cryptohelper/cryptohelper.go +++ b/crypto/cryptohelper/cryptohelper.go @@ -278,7 +278,7 @@ func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error } } -var NoSessionFound = crypto.NoSessionFound +var NoSessionFound = crypto.ErrNoSessionFound const initialSessionWaitTimeout = 3 * time.Second const extendedSessionWaitTimeout = 22 * time.Second @@ -371,6 +371,7 @@ func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, evt *event content := evt.Content.AsEncrypted() log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...") + //lint:ignore SA1019 RequestSession will gracefully request from all devices if DeviceID is blank go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID) if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) { @@ -418,7 +419,7 @@ func (helper *CryptoHelper) EncryptWithStateKey(ctx context.Context, roomID id.R defer helper.lock.RUnlock() encrypted, err = helper.mach.EncryptMegolmEventWithStateKey(ctx, roomID, evtType, stateKey, content) if err != nil { - if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) { + if !errors.Is(err, crypto.ErrSessionExpired) && err != crypto.ErrNoGroupSession && !errors.Is(err, crypto.ErrSessionNotShared) { return } helper.log.Debug(). diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 47279474..d8b419ab 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -24,13 +24,24 @@ import ( ) var ( - IncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent") - NoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found") - DuplicateMessageIndex = errors.New("duplicate megolm message index") - WrongRoom = errors.New("encrypted megolm event is not intended for this room") - DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") - SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match") - RatchetError = errors.New("failed to ratchet session after use") + ErrIncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent") + ErrNoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found") + ErrDuplicateMessageIndex = errors.New("duplicate megolm message index") + ErrWrongRoom = errors.New("encrypted megolm event is not intended for this room") + ErrDeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") + ErrSenderKeyMismatch = errors.New("sender keys in content and megolm session do not match") + ErrRatchetError = errors.New("failed to ratchet session after use") +) + +// Deprecated: use variables prefixed with Err +var ( + IncorrectEncryptedContentType = ErrIncorrectEncryptedContentType + NoSessionFound = ErrNoSessionFound + DuplicateMessageIndex = ErrDuplicateMessageIndex + WrongRoom = ErrWrongRoom + DeviceKeyMismatch = ErrDeviceKeyMismatch + SenderKeyMismatch = ErrSenderKeyMismatch + RatchetError = ErrRatchetError ) type megolmEvent struct { @@ -49,9 +60,9 @@ var ( func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) if !ok { - return nil, IncorrectEncryptedContentType + return nil, ErrIncorrectEncryptedContentType } else if content.Algorithm != id.AlgorithmMegolmV1 { - return nil, UnsupportedAlgorithm + return nil, ErrUnsupportedAlgorithm } log := mach.machOrContextLog(ctx).With(). Str("action", "decrypt megolm event"). @@ -97,7 +108,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event Msg("Couldn't resolve trust level of session: sent by unknown device") trustLevel = id.TrustStateUnknownDevice } else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey { - return nil, DeviceKeyMismatch + return nil, ErrDeviceKeyMismatch } else { trustLevel, err = mach.ResolveTrustContext(ctx, device) if err != nil { @@ -147,7 +158,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event if err != nil { return nil, fmt.Errorf("failed to parse megolm payload: %w", err) } else if megolmEvt.RoomID != encryptionRoomID { - return nil, WrongRoom + return nil, ErrWrongRoom } if evt.StateKey != nil && megolmEvt.StateKey != nil && mach.AllowEncryptedState { megolmEvt.Type.Class = event.StateEventType @@ -201,19 +212,19 @@ func (mach *OlmMachine) checkUndecryptableMessageIndexDuplication(ctx context.Co messageIndex, decodeErr := ParseMegolmMessageIndex(content.MegolmCiphertext) if decodeErr != nil { log.Warn().Err(decodeErr).Msg("Failed to parse message index to check if it's a duplicate for message that failed to decrypt") - return 0, fmt.Errorf("%w (also failed to parse message index)", olm.UnknownMessageIndex) + return 0, fmt.Errorf("%w (also failed to parse message index)", olm.ErrUnknownMessageIndex) } firstKnown := sess.Internal.FirstKnownIndex() log = log.With().Uint("message_index", messageIndex).Uint32("first_known_index", firstKnown).Logger() if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil { log.Debug().Err(err).Msg("Failed to check if message index is duplicate") - return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w (failed to check if index is duplicate; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown) } else if !ok { log.Debug().Msg("Failed to decrypt message due to unknown index and found duplicate") - return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", DuplicateMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w %d (also failed to decrypt because earliest known index is %d)", ErrDuplicateMessageIndex, messageIndex, firstKnown) } log.Debug().Msg("Failed to decrypt message due to unknown index, but index is not duplicate") - return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.UnknownMessageIndex, messageIndex, firstKnown) + return messageIndex, fmt.Errorf("%w (not duplicate index; received: %d, earliest known: %d)", olm.ErrUnknownMessageIndex, messageIndex, firstKnown) } func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) { @@ -224,13 +235,13 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve if err != nil { return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err) } else if sess == nil { - return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID) + return nil, nil, 0, fmt.Errorf("%w (ID %s)", ErrNoSessionFound, content.SessionID) } else if content.SenderKey != "" && content.SenderKey != sess.SenderKey { - return sess, nil, 0, SenderKeyMismatch + return sess, nil, 0, ErrSenderKeyMismatch } plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext) if err != nil { - if errors.Is(err, olm.UnknownMessageIndex) && mach.RatchetKeysOnDecrypt { + if errors.Is(err, olm.ErrUnknownMessageIndex) && mach.RatchetKeysOnDecrypt { messageIndex, err = mach.checkUndecryptableMessageIndexDuplication(ctx, sess, evt, content) return sess, nil, messageIndex, fmt.Errorf("failed to decrypt megolm event: %w", err) } @@ -238,7 +249,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve } else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil { return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err) } else if !ok { - return sess, nil, messageIndex, fmt.Errorf("%w %d", DuplicateMessageIndex, messageIndex) + return sess, nil, messageIndex, fmt.Errorf("%w %d", ErrDuplicateMessageIndex, messageIndex) } // Normal clients don't care about tracking the ratchet state, so let them bypass the rest of the function @@ -290,24 +301,24 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.ID(), "maximum messages reached") if err != nil { log.Err(err).Msg("Failed to delete fully used session") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else { log.Info().Msg("Deleted fully used session") } } else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt { if err = sess.RatchetTo(ratchetTargetIndex); err != nil { log.Err(err).Msg("Failed to ratchet session") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil { log.Err(err).Msg("Failed to store ratcheted session") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else { log.Info().Msg("Ratcheted session forward") } } else if didModify { if err = mach.CryptoStore.PutGroupSession(ctx, sess); err != nil { log.Err(err).Msg("Failed to store updated ratchet safety data") - return sess, plaintext, messageIndex, RatchetError + return sess, plaintext, messageIndex, ErrRatchetError } else { log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)") } diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index 30cc4cfe..cd02726d 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -26,15 +26,27 @@ import ( ) var ( - UnsupportedAlgorithm = errors.New("unsupported event encryption algorithm") - NotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device") - UnsupportedOlmMessageType = errors.New("unsupported olm message type") - DecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session") - DecryptionFailedForNormalMessage = errors.New("decryption failed for normal message") - SenderMismatch = errors.New("mismatched sender in olm payload") - RecipientMismatch = errors.New("mismatched recipient in olm payload") - RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload") - ErrDuplicateMessage = errors.New("duplicate olm message") + ErrUnsupportedAlgorithm = errors.New("unsupported event encryption algorithm") + ErrNotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device") + ErrUnsupportedOlmMessageType = errors.New("unsupported olm message type") + ErrDecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session") + ErrDecryptionFailedForNormalMessage = errors.New("decryption failed for normal message") + ErrSenderMismatch = errors.New("mismatched sender in olm payload") + ErrRecipientMismatch = errors.New("mismatched recipient in olm payload") + ErrRecipientKeyMismatch = errors.New("mismatched recipient key in olm payload") + ErrDuplicateMessage = errors.New("duplicate olm message") +) + +// Deprecated: use variables prefixed with Err +var ( + UnsupportedAlgorithm = ErrUnsupportedAlgorithm + NotEncryptedForMe = ErrNotEncryptedForMe + UnsupportedOlmMessageType = ErrUnsupportedOlmMessageType + DecryptionFailedWithMatchingSession = ErrDecryptionFailedWithMatchingSession + DecryptionFailedForNormalMessage = ErrDecryptionFailedForNormalMessage + SenderMismatch = ErrSenderMismatch + RecipientMismatch = ErrRecipientMismatch + RecipientKeyMismatch = ErrRecipientKeyMismatch ) // DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm. @@ -56,13 +68,13 @@ type DecryptedOlmEvent struct { func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (*DecryptedOlmEvent, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) if !ok { - return nil, IncorrectEncryptedContentType + return nil, ErrIncorrectEncryptedContentType } else if content.Algorithm != id.AlgorithmOlmV1 { - return nil, UnsupportedAlgorithm + return nil, ErrUnsupportedAlgorithm } ownContent, ok := content.OlmCiphertext[mach.account.IdentityKey()] if !ok { - return nil, NotEncryptedForMe + return nil, ErrNotEncryptedForMe } decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt, content.SenderKey, ownContent.Type, ownContent.Body) if err != nil { @@ -78,7 +90,7 @@ type OlmEventKeys struct { func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *event.Event, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) { if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg { - return nil, UnsupportedOlmMessageType + return nil, ErrUnsupportedOlmMessageType } log := mach.machOrContextLog(ctx).With(). @@ -102,11 +114,11 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e } olmEvt.Type.Class = evt.Type.Class if evt.Sender != olmEvt.Sender { - return nil, SenderMismatch + return nil, ErrSenderMismatch } else if mach.Client.UserID != olmEvt.Recipient { - return nil, RecipientMismatch + return nil, ErrRecipientMismatch } else if mach.account.SigningKey() != olmEvt.RecipientKeys.Ed25519 { - return nil, RecipientKeyMismatch + return nil, ErrRecipientKeyMismatch } if len(olmEvt.Content.VeryRaw) > 0 { @@ -151,7 +163,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext, ciphertextHash) if err != nil { - if err == DecryptionFailedWithMatchingSession { + if err == ErrDecryptionFailedWithMatchingSession { log.Warn().Msg("Found matching session, but decryption failed") go mach.unwedgeDevice(log, sender, senderKey) } @@ -169,10 +181,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U // if it isn't one at this point in time anymore, so return early. if olmType != id.OlmMsgTypePreKey { go mach.unwedgeDevice(log, sender, senderKey) - return nil, DecryptionFailedForNormalMessage + return nil, ErrDecryptionFailedForNormalMessage } - accountBackup, err := mach.account.Internal.Pickle([]byte("tmp")) + accountBackup, _ := mach.account.Internal.Pickle([]byte("tmp")) log.Trace().Msg("Trying to create inbound session") endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second) session, err := mach.createInboundSession(ctx, senderKey, ciphertext) @@ -302,7 +314,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession( Str("session_description", session.Describe()). Msg("Failed to decrypt olm message") if olmType == id.OlmMsgTypePreKey { - return nil, DecryptionFailedWithMatchingSession + return nil, ErrDecryptionFailedWithMatchingSession } } else { endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second) @@ -345,7 +357,7 @@ func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, send ctx := log.WithContext(mach.backgroundCtx) mach.recentlyUnwedgedLock.Lock() prevUnwedge, ok := mach.recentlyUnwedged[senderKey] - delta := time.Now().Sub(prevUnwedge) + delta := time.Since(prevUnwedge) if ok && delta < MinUnwedgeInterval { log.Debug(). Str("previous_recreation", delta.String()). diff --git a/crypto/devicelist.go b/crypto/devicelist.go index 61a22522..f0d2b129 100644 --- a/crypto/devicelist.go +++ b/crypto/devicelist.go @@ -22,14 +22,23 @@ import ( ) var ( - MismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object") - MismatchingUserID = errors.New("mismatching user ID in parameter and keys object") - MismatchingSigningKey = errors.New("received update for device with different signing key") - NoSigningKeyFound = errors.New("didn't find ed25519 signing key") - NoIdentityKeyFound = errors.New("didn't find curve25519 identity key") - InvalidKeySignature = errors.New("invalid signature on device keys") + ErrMismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object") + ErrMismatchingUserID = errors.New("mismatching user ID in parameter and keys object") + ErrMismatchingSigningKey = errors.New("received update for device with different signing key") + ErrNoSigningKeyFound = errors.New("didn't find ed25519 signing key") + ErrNoIdentityKeyFound = errors.New("didn't find curve25519 identity key") + ErrInvalidKeySignature = errors.New("invalid signature on device keys") + ErrUserNotTracked = errors.New("user is not tracked") +) - ErrUserNotTracked = errors.New("user is not tracked") +// Deprecated: use variables prefixed with Err +var ( + MismatchingDeviceID = ErrMismatchingDeviceID + MismatchingUserID = ErrMismatchingUserID + MismatchingSigningKey = ErrMismatchingSigningKey + NoSigningKeyFound = ErrNoSigningKeyFound + NoIdentityKeyFound = ErrNoIdentityKeyFound + InvalidKeySignature = ErrInvalidKeySignature ) func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) { @@ -312,28 +321,28 @@ func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID) func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, deviceKeys mautrix.DeviceKeys, existing *id.Device) (*id.Device, error) { if deviceID != deviceKeys.DeviceID { - return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingDeviceID, deviceID, deviceKeys.DeviceID) + return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingDeviceID, deviceID, deviceKeys.DeviceID) } else if userID != deviceKeys.UserID { - return nil, fmt.Errorf("%w (expected %s, got %s)", MismatchingUserID, userID, deviceKeys.UserID) + return nil, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingUserID, userID, deviceKeys.UserID) } signingKey := deviceKeys.Keys.GetEd25519(deviceID) identityKey := deviceKeys.Keys.GetCurve25519(deviceID) if signingKey == "" { - return nil, NoSigningKeyFound + return nil, ErrNoSigningKeyFound } else if identityKey == "" { - return nil, NoIdentityKeyFound + return nil, ErrNoIdentityKeyFound } if existing != nil && existing.SigningKey != signingKey { - return existing, fmt.Errorf("%w (expected %s, got %s)", MismatchingSigningKey, existing.SigningKey, signingKey) + return existing, fmt.Errorf("%w (expected %s, got %s)", ErrMismatchingSigningKey, existing.SigningKey, signingKey) } ok, err := signatures.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey) if err != nil { return existing, fmt.Errorf("failed to verify signature: %w", err) } else if !ok { - return existing, InvalidKeySignature + return existing, ErrInvalidKeySignature } name, ok := deviceKeys.Unsigned["device_display_name"].(string) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index b3d19618..ea97f767 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -25,7 +25,12 @@ import ( ) var ( - NoGroupSession = errors.New("no group session created") + ErrNoGroupSession = errors.New("no group session created") +) + +// Deprecated: use variables prefixed with Err +var ( + NoGroupSession = ErrNoGroupSession ) func getRawJSON[T any](content json.RawMessage, path ...string) *T { @@ -82,7 +87,7 @@ type rawMegolmEvent struct { // IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession func IsShareError(err error) bool { - return err == SessionExpired || err == SessionNotShared || err == NoGroupSession + return err == ErrSessionExpired || err == ErrSessionNotShared || err == ErrNoGroupSession } func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) { @@ -120,7 +125,7 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room if err != nil { return nil, fmt.Errorf("failed to get outbound group session: %w", err) } else if session == nil { - return nil, NoGroupSession + return nil, ErrNoGroupSession } plaintext, err := json.Marshal(&rawMegolmEvent{ RoomID: roomID, diff --git a/crypto/goolm/account/account.go b/crypto/goolm/account/account.go index 4da08a73..b48843a4 100644 --- a/crypto/goolm/account/account.go +++ b/crypto/goolm/account/account.go @@ -334,7 +334,7 @@ func (a *Account) UnpickleLibOlm(buf []byte) error { if err != nil { return err } else if pickledVersion != accountPickleVersionLibOLM && pickledVersion != 3 && pickledVersion != 2 { - return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrBadVersion, pickledVersion) + return fmt.Errorf("unpickle account: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } else if err = a.IdKeys.Ed25519.UnpickleLibOlm(decoder); err != nil { // read the ed25519 key pair return err } else if err = a.IdKeys.Curve25519.UnpickleLibOlm(decoder); err != nil { // read curve25519 key pair diff --git a/crypto/goolm/account/account_test.go b/crypto/goolm/account/account_test.go index e1c9b452..d0dec5f0 100644 --- a/crypto/goolm/account/account_test.go +++ b/crypto/goolm/account/account_test.go @@ -124,7 +124,7 @@ func TestOldAccountPickle(t *testing.T) { account, err := account.NewAccount() assert.NoError(t, err) err = account.Unpickle(pickled, pickleKey) - assert.ErrorIs(t, err, olm.ErrBadVersion) + assert.ErrorIs(t, err, olm.ErrUnknownOlmPickleVersion) } func TestLoopback(t *testing.T) { diff --git a/crypto/goolm/goolmbase64/base64.go b/crypto/goolm/goolmbase64/base64.go index 061a052a..58ee26f7 100644 --- a/crypto/goolm/goolmbase64/base64.go +++ b/crypto/goolm/goolmbase64/base64.go @@ -4,7 +4,8 @@ import ( "encoding/base64" ) -// Deprecated: base64.RawStdEncoding should be used directly +// These methods should only be used for raw byte operations, never with string conversion + func Decode(input []byte) ([]byte, error) { decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input))) writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input) @@ -14,7 +15,6 @@ func Decode(input []byte) ([]byte, error) { return decoded[:writtenBytes], nil } -// Deprecated: base64.RawStdEncoding should be used directly func Encode(input []byte) []byte { encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input))) base64.RawStdEncoding.Encode(encoded, input) diff --git a/crypto/goolm/libolmpickle/picklejson.go b/crypto/goolm/libolmpickle/picklejson.go index 308e472c..f765391f 100644 --- a/crypto/goolm/libolmpickle/picklejson.go +++ b/crypto/goolm/libolmpickle/picklejson.go @@ -50,7 +50,7 @@ func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error { } } if decrypted[0] != pickleVersion { - return fmt.Errorf("unpickle: %w", olm.ErrWrongPickleVersion) + return fmt.Errorf("unpickle: %w", olm.ErrUnknownJSONPickleVersion) } err = json.Unmarshal(decrypted[1:], object) if err != nil { diff --git a/crypto/goolm/message/session_export.go b/crypto/goolm/message/session_export.go index 956868b2..d58dbb21 100644 --- a/crypto/goolm/message/session_export.go +++ b/crypto/goolm/message/session_export.go @@ -35,7 +35,7 @@ func (s *MegolmSessionExport) Decode(input []byte) error { return fmt.Errorf("decrypt: %w", olm.ErrBadInput) } if input[0] != sessionExportVersion { - return fmt.Errorf("decrypt: %w", olm.ErrBadVersion) + return fmt.Errorf("decrypt: %w", olm.ErrUnknownOlmPickleVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/message/session_sharing.go b/crypto/goolm/message/session_sharing.go index 16240945..d04ef15a 100644 --- a/crypto/goolm/message/session_sharing.go +++ b/crypto/goolm/message/session_sharing.go @@ -42,7 +42,7 @@ func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error { } s.PublicKey = publicKey if input[0] != sessionSharingVersion { - return fmt.Errorf("verify: %w", olm.ErrBadVersion) + return fmt.Errorf("verify: %w", olm.ErrUnknownOlmPickleVersion) } s.Counter = binary.BigEndian.Uint32(input[1:5]) copy(s.RatchetData[:], input[5:133]) diff --git a/crypto/goolm/pk/decryption.go b/crypto/goolm/pk/decryption.go index afb01f74..cdb20eb1 100644 --- a/crypto/goolm/pk/decryption.go +++ b/crypto/goolm/pk/decryption.go @@ -103,7 +103,7 @@ func (a *Decryption) UnpickleLibOlm(unpickled []byte) error { if pickledVersion == decryptionPickleVersionLibOlm { return a.KeyPair.UnpickleLibOlm(decoder) } else { - return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrBadVersion, pickledVersion, decryptionPickleVersionLibOlm) + return fmt.Errorf("unpickle olmSession: %w (found %d, expected %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion, decryptionPickleVersionLibOlm) } } diff --git a/crypto/goolm/pk/encryption.go b/crypto/goolm/pk/encryption.go index 23f67ddf..2897d9b0 100644 --- a/crypto/goolm/pk/encryption.go +++ b/crypto/goolm/pk/encryption.go @@ -37,6 +37,9 @@ func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519Privat return nil, nil, err } cipher, err := aessha2.NewAESSHA2(sharedSecret, nil) + if err != nil { + return nil, nil, err + } ciphertext, err = cipher.Encrypt(plaintext) if err != nil { return nil, nil, err diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go index 80dd71cc..fb88b73c 100644 --- a/crypto/goolm/session/megolm_inbound_session.go +++ b/crypto/goolm/session/megolm_inbound_session.go @@ -99,7 +99,7 @@ func (o *MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, } if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) { // the counter is before our initial ratchet - we can't decode this - return nil, fmt.Errorf("decrypt: %w", olm.ErrRatchetNotAvailable) + return nil, fmt.Errorf("decrypt: %w", olm.ErrUnknownMessageIndex) } // otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet copiedRatchet := o.InitialRatchet @@ -206,7 +206,7 @@ func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) error { return err } if pickledVersion != megolmInboundSessionPickleVersionLibOlm && pickledVersion != 1 { - return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) + return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } if err = o.InitialRatchet.UnpickleLibOlm(decoder); err != nil { diff --git a/crypto/goolm/session/megolm_outbound_session.go b/crypto/goolm/session/megolm_outbound_session.go index 2b8e1c84..7f923534 100644 --- a/crypto/goolm/session/megolm_outbound_session.go +++ b/crypto/goolm/session/megolm_outbound_session.go @@ -101,8 +101,10 @@ func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error { func (o *MegolmOutboundSession) UnpickleLibOlm(buf []byte) error { decoder := libolmpickle.NewDecoder(buf) pickledVersion, err := decoder.ReadUInt32() - if pickledVersion != megolmOutboundSessionPickleVersionLibOlm { - return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) + if err != nil { + return fmt.Errorf("unpickle MegolmOutboundSession: failed to read version: %w", err) + } else if pickledVersion != megolmOutboundSessionPickleVersionLibOlm { + return fmt.Errorf("unpickle MegolmInboundSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } if err = o.Ratchet.UnpickleLibOlm(decoder); err != nil { return err diff --git a/crypto/goolm/session/olm_session.go b/crypto/goolm/session/olm_session.go index b99ab630..a1cb8d66 100644 --- a/crypto/goolm/session/olm_session.go +++ b/crypto/goolm/session/olm_session.go @@ -168,11 +168,11 @@ func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, received msg := message.Message{} err = msg.Decode(oneTimeMsg.Message) if err != nil { - return nil, fmt.Errorf("Message decode: %w", err) + return nil, fmt.Errorf("message decode: %w", err) } if len(msg.RatchetKey) == 0 { - return nil, fmt.Errorf("Message missing ratchet key: %w", olm.ErrBadMessageFormat) + return nil, fmt.Errorf("message missing ratchet key: %w", olm.ErrBadMessageFormat) } //Init Ratchet s.Ratchet.InitializeAsBob(secret, msg.RatchetKey) @@ -203,7 +203,7 @@ func (s *OlmSession) ID() id.SessionID { copy(message[crypto.Curve25519PrivateKeyLength:], s.AliceBaseKey) copy(message[2*crypto.Curve25519PrivateKeyLength:], s.BobOneTimeKey) hash := sha256.Sum256(message) - res := id.SessionID(goolmbase64.Encode(hash[:])) + res := id.SessionID(base64.RawStdEncoding.EncodeToString(hash[:])) return res } @@ -325,7 +325,7 @@ func (s *OlmSession) Decrypt(crypttext string, msgType id.OlmMsgType) ([]byte, e if len(crypttext) == 0 { return nil, fmt.Errorf("decrypt: %w", olm.ErrEmptyInput) } - decodedCrypttext, err := goolmbase64.Decode([]byte(crypttext)) + decodedCrypttext, err := base64.RawStdEncoding.DecodeString(crypttext) if err != nil { return nil, err } @@ -365,6 +365,9 @@ func (o *OlmSession) Unpickle(pickled, key []byte) error { func (o *OlmSession) UnpickleLibOlm(buf []byte) error { decoder := libolmpickle.NewDecoder(buf) pickledVersion, err := decoder.ReadUInt32() + if err != nil { + return fmt.Errorf("unpickle olmSession: failed to read version: %w", err) + } var includesChainIndex bool switch pickledVersion { @@ -373,7 +376,7 @@ func (o *OlmSession) UnpickleLibOlm(buf []byte) error { case uint32(0x80000001): includesChainIndex = true default: - return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrBadVersion, pickledVersion) + return fmt.Errorf("unpickle olmSession: %w (found version %d)", olm.ErrUnknownOlmPickleVersion, pickledVersion) } if o.ReceivedMessage, err = decoder.ReadBool(); err != nil { diff --git a/crypto/goolm/session/register.go b/crypto/goolm/session/register.go index a88d12f6..b95a44ac 100644 --- a/crypto/goolm/session/register.go +++ b/crypto/goolm/session/register.go @@ -14,7 +14,7 @@ func Register() { // Inbound Session olm.InitInboundGroupSessionFromPickled = func(pickled, key []byte) (olm.InboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } if len(key) == 0 { key = []byte(" ") @@ -23,13 +23,13 @@ func Register() { } olm.InitNewInboundGroupSession = func(sessionKey []byte) (olm.InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } return NewMegolmInboundSession(sessionKey) } olm.InitInboundGroupSessionImport = func(sessionKey []byte) (olm.InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } return NewMegolmInboundSessionFromExport(sessionKey) } @@ -40,7 +40,7 @@ func Register() { // Outbound Session olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } lenKey := len(key) if lenKey == 0 { diff --git a/crypto/keybackup.go b/crypto/keybackup.go index d8b3d715..ceec1d58 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -56,11 +56,12 @@ func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context, // ...by deriving the public key from a private key that it obtained from a trusted source. Trusted sources for the private // key include the user entering the key, retrieving the key stored in secret storage, or obtaining the key via secret sharing // from a verified device belonging to the same user." - megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes())) - if megolmBackupKey != nil && versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey { - log.Debug().Msg("key backup is trusted based on derived public key") - return versionInfo, nil - } else { + if megolmBackupKey != nil { + megolmBackupDerivedPublicKey := id.Ed25519(base64.RawStdEncoding.EncodeToString(megolmBackupKey.PublicKey().Bytes())) + if versionInfo.AuthData.PublicKey == megolmBackupDerivedPublicKey { + log.Debug().Msg("Key backup is trusted based on derived public key") + return versionInfo, nil + } log.Debug(). Stringer("expected_key", megolmBackupDerivedPublicKey). Stringer("actual_key", versionInfo.AuthData.PublicKey). diff --git a/crypto/libolm/account.go b/crypto/libolm/account.go index f6f916e7..0350f083 100644 --- a/crypto/libolm/account.go +++ b/crypto/libolm/account.go @@ -33,7 +33,7 @@ var _ olm.Account = (*Account)(nil) // "INVALID_BASE64". func AccountFromPickled(pickled, key []byte) (*Account, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } a := NewBlankAccount() return a, a.Unpickle(pickled, key) @@ -53,7 +53,7 @@ func NewAccount() (*Account, error) { random := make([]byte, a.createRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(olm.NotEnoughGoRandom) + panic(olm.ErrNotEnoughGoRandom) } ret := C.olm_create_account( (*C.OlmAccount)(a.int), @@ -128,7 +128,7 @@ func (a *Account) genOneTimeKeysRandomLen(num uint) uint { // supplied key. func (a *Account) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.NoKeyProvided + return nil, olm.ErrNoKeyProvided } pickled := make([]byte, a.pickleLen()) r := C.olm_pickle_account( @@ -145,7 +145,7 @@ func (a *Account) Pickle(key []byte) ([]byte, error) { func (a *Account) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.NoKeyProvided + return olm.ErrNoKeyProvided } r := C.olm_unpickle_account( (*C.OlmAccount)(a.int), @@ -198,7 +198,7 @@ func (a *Account) MarshalJSON() ([]byte, error) { // Deprecated func (a *Account) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.InputNotJSONString + return olm.ErrInputNotJSONString } if a.int == nil { *a = *NewBlankAccount() @@ -235,7 +235,7 @@ func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519, error) { // Account. func (a *Account) Sign(message []byte) ([]byte, error) { if len(message) == 0 { - panic(olm.EmptyInput) + panic(olm.ErrEmptyInput) } signature := make([]byte, a.signatureLen()) r := C.olm_account_sign( @@ -299,7 +299,7 @@ func (a *Account) GenOneTimeKeys(num uint) error { random := make([]byte, a.genOneTimeKeysRandomLen(num)+1) _, err := rand.Read(random) if err != nil { - return olm.NotEnoughGoRandom + return olm.ErrNotEnoughGoRandom } r := C.olm_account_generate_one_time_keys( (*C.OlmAccount)(a.int), @@ -319,13 +319,13 @@ func (a *Account) GenOneTimeKeys(num uint) error { // keys couldn't be decoded as base64 then the error will be "INVALID_BASE64" func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (olm.Session, error) { if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankSession() random := make([]byte, s.createOutboundRandomLen()+1) _, err := rand.Read(random) if err != nil { - panic(olm.NotEnoughGoRandom) + panic(olm.ErrNotEnoughGoRandom) } theirIdentityKeyCopy := []byte(theirIdentityKey) theirOneTimeKeyCopy := []byte(theirOneTimeKey) @@ -357,7 +357,7 @@ func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve2 // time key then the error will be "BAD_MESSAGE_KEY_ID". func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) { if len(oneTimeKeyMsg) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankSession() oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) @@ -383,7 +383,7 @@ func (a *Account) NewInboundSession(oneTimeKeyMsg string) (olm.Session, error) { // time key then the error will be "BAD_MESSAGE_KEY_ID". func (a *Account) NewInboundSessionFrom(theirIdentityKey *id.Curve25519, oneTimeKeyMsg string) (olm.Session, error) { if theirIdentityKey == nil || len(oneTimeKeyMsg) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } theirIdentityKeyCopy := []byte(*theirIdentityKey) oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) diff --git a/crypto/libolm/error.go b/crypto/libolm/error.go index 9ca415ee..6fb5512b 100644 --- a/crypto/libolm/error.go +++ b/crypto/libolm/error.go @@ -11,21 +11,21 @@ import ( ) var errorMap = map[string]error{ - "NOT_ENOUGH_RANDOM": olm.NotEnoughRandom, - "OUTPUT_BUFFER_TOO_SMALL": olm.OutputBufferTooSmall, - "BAD_MESSAGE_VERSION": olm.BadMessageVersion, - "BAD_MESSAGE_FORMAT": olm.BadMessageFormat, - "BAD_MESSAGE_MAC": olm.BadMessageMAC, - "BAD_MESSAGE_KEY_ID": olm.BadMessageKeyID, - "INVALID_BASE64": olm.InvalidBase64, - "BAD_ACCOUNT_KEY": olm.BadAccountKey, - "UNKNOWN_PICKLE_VERSION": olm.UnknownPickleVersion, - "CORRUPTED_PICKLE": olm.CorruptedPickle, - "BAD_SESSION_KEY": olm.BadSessionKey, - "UNKNOWN_MESSAGE_INDEX": olm.UnknownMessageIndex, - "BAD_LEGACY_ACCOUNT_PICKLE": olm.BadLegacyAccountPickle, - "BAD_SIGNATURE": olm.BadSignature, - "INPUT_BUFFER_TOO_SMALL": olm.InputBufferTooSmall, + "NOT_ENOUGH_RANDOM": olm.ErrLibolmNotEnoughRandom, + "OUTPUT_BUFFER_TOO_SMALL": olm.ErrLibolmOutputBufferTooSmall, + "BAD_MESSAGE_VERSION": olm.ErrWrongProtocolVersion, + "BAD_MESSAGE_FORMAT": olm.ErrBadMessageFormat, + "BAD_MESSAGE_MAC": olm.ErrBadMAC, + "BAD_MESSAGE_KEY_ID": olm.ErrBadMessageKeyID, + "INVALID_BASE64": olm.ErrLibolmInvalidBase64, + "BAD_ACCOUNT_KEY": olm.ErrLibolmBadAccountKey, + "UNKNOWN_PICKLE_VERSION": olm.ErrUnknownOlmPickleVersion, + "CORRUPTED_PICKLE": olm.ErrLibolmCorruptedPickle, + "BAD_SESSION_KEY": olm.ErrLibolmBadSessionKey, + "UNKNOWN_MESSAGE_INDEX": olm.ErrUnknownMessageIndex, + "BAD_LEGACY_ACCOUNT_PICKLE": olm.ErrLibolmBadLegacyAccountPickle, + "BAD_SIGNATURE": olm.ErrBadSignature, + "INPUT_BUFFER_TOO_SMALL": olm.ErrInputToSmall, } func convertError(errCode string) error { diff --git a/crypto/libolm/inboundgroupsession.go b/crypto/libolm/inboundgroupsession.go index 5606475d..8815ac32 100644 --- a/crypto/libolm/inboundgroupsession.go +++ b/crypto/libolm/inboundgroupsession.go @@ -31,7 +31,7 @@ var _ olm.InboundGroupSession = (*InboundGroupSession)(nil) // base64 couldn't be decoded then the error will be "INVALID_BASE64". func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } lenKey := len(key) if lenKey == 0 { @@ -48,7 +48,7 @@ func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, // "OLM_BAD_SESSION_KEY". func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankInboundGroupSession() r := C.olm_init_inbound_group_session( @@ -69,7 +69,7 @@ func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { // error will be "OLM_BAD_SESSION_KEY". func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankInboundGroupSession() r := C.olm_import_inbound_group_session( @@ -124,7 +124,7 @@ func (s *InboundGroupSession) pickleLen() uint { // InboundGroupSession using the supplied key. func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.NoKeyProvided + return nil, olm.ErrNoKeyProvided } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_inbound_group_session( @@ -143,9 +143,9 @@ func (s *InboundGroupSession) Pickle(key []byte) ([]byte, error) { func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.NoKeyProvided + return olm.ErrNoKeyProvided } else if len(pickled) == 0 { - return olm.EmptyInput + return olm.ErrEmptyInput } r := C.olm_unpickle_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), @@ -200,7 +200,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { // Deprecated func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.InputNotJSONString + return olm.ErrInputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankInboundGroupSession() @@ -217,7 +217,7 @@ func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { // will be "BAD_MESSAGE_FORMAT". func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) { if len(message) == 0 { - return 0, olm.EmptyInput + return 0, olm.ErrEmptyInput } // olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it messageCopy := bytes.Clone(message) @@ -244,7 +244,7 @@ func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, erro // was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX". func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { if len(message) == 0 { - return nil, 0, olm.EmptyInput + return nil, 0, olm.ErrEmptyInput } decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message) if err != nil { diff --git a/crypto/libolm/outboundgroupsession.go b/crypto/libolm/outboundgroupsession.go index 646929eb..ca5b68f7 100644 --- a/crypto/libolm/outboundgroupsession.go +++ b/crypto/libolm/outboundgroupsession.go @@ -84,7 +84,7 @@ func (s *OutboundGroupSession) pickleLen() uint { // OutboundGroupSession using the supplied key. func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.NoKeyProvided + return nil, olm.ErrNoKeyProvided } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_outbound_group_session( @@ -103,7 +103,7 @@ func (s *OutboundGroupSession) Pickle(key []byte) ([]byte, error) { func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.NoKeyProvided + return olm.ErrNoKeyProvided } r := C.olm_unpickle_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), @@ -159,7 +159,7 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { // Deprecated func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.InputNotJSONString + return olm.ErrInputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankOutboundGroupSession() @@ -183,7 +183,7 @@ func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint { // as base64. func (s *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { if len(plaintext) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } message := make([]byte, s.encryptMsgLen(len(plaintext))) r := C.olm_group_encrypt( diff --git a/crypto/libolm/pk.go b/crypto/libolm/pk.go index 35532140..2683cf15 100644 --- a/crypto/libolm/pk.go +++ b/crypto/libolm/pk.go @@ -86,7 +86,7 @@ func NewPKSigning() (*PKSigning, error) { seed := make([]byte, pkSigningSeedLength()) _, err := rand.Read(seed) if err != nil { - panic(olm.NotEnoughGoRandom) + panic(olm.ErrNotEnoughGoRandom) } pk, err := NewPKSigningFromSeed(seed) return pk, err diff --git a/crypto/libolm/register.go b/crypto/libolm/register.go index f091d822..ddf84613 100644 --- a/crypto/libolm/register.go +++ b/crypto/libolm/register.go @@ -65,7 +65,7 @@ func Register() { olm.InitNewOutboundGroupSessionFromPickled = func(pickled, key []byte) (olm.OutboundGroupSession, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankOutboundGroupSession() return s, s.Unpickle(pickled, key) diff --git a/crypto/libolm/session.go b/crypto/libolm/session.go index 57e631c3..1441df26 100644 --- a/crypto/libolm/session.go +++ b/crypto/libolm/session.go @@ -51,7 +51,7 @@ func sessionSize() uint { // "INVALID_BASE64". func SessionFromPickled(pickled, key []byte) (*Session, error) { if len(pickled) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } s := NewBlankSession() return s, s.Unpickle(pickled, key) @@ -118,7 +118,7 @@ func (s *Session) encryptMsgLen(plainTextLen int) uint { // will be "BAD_MESSAGE_FORMAT". func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) { if len(message) == 0 { - return 0, olm.EmptyInput + return 0, olm.ErrEmptyInput } messageCopy := []byte(message) r := C.olm_decrypt_max_plaintext_length( @@ -138,7 +138,7 @@ func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) // supplied key. func (s *Session) Pickle(key []byte) ([]byte, error) { if len(key) == 0 { - return nil, olm.NoKeyProvided + return nil, olm.ErrNoKeyProvided } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_session( @@ -158,7 +158,7 @@ func (s *Session) Pickle(key []byte) ([]byte, error) { // provided key. This function mutates the input pickled data slice. func (s *Session) Unpickle(pickled, key []byte) error { if len(key) == 0 { - return olm.NoKeyProvided + return olm.ErrNoKeyProvided } r := C.olm_unpickle_session( (*C.OlmSession)(s.int), @@ -213,7 +213,7 @@ func (s *Session) MarshalJSON() ([]byte, error) { // Deprecated func (s *Session) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { - return olm.InputNotJSONString + return olm.ErrInputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankSession() @@ -256,7 +256,7 @@ func (s *Session) HasReceivedMessage() bool { // decoded then then the error will be "BAD_MESSAGE_FORMAT". func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { if len(oneTimeKeyMsg) == 0 { - return false, olm.EmptyInput + return false, olm.ErrEmptyInput } oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) r := C.olm_matches_inbound_session( @@ -284,7 +284,7 @@ func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { // decoded then then the error will be "BAD_MESSAGE_FORMAT". func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { - return false, olm.EmptyInput + return false, olm.ErrEmptyInput } theirIdentityKeyCopy := []byte(theirIdentityKey) oneTimeKeyMsgCopy := []byte(oneTimeKeyMsg) @@ -325,14 +325,14 @@ func (s *Session) EncryptMsgType() id.OlmMsgType { // as base64. func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { if len(plaintext) == 0 { - return 0, nil, olm.EmptyInput + return 0, nil, olm.ErrEmptyInput } // Make the slice be at least length 1 random := make([]byte, s.encryptRandomLen()+1) _, err := rand.Read(random) if err != nil { // TODO can we just return err here? - return 0, nil, olm.NotEnoughGoRandom + return 0, nil, olm.ErrNotEnoughGoRandom } messageType := s.EncryptMsgType() message := make([]byte, s.encryptMsgLen(len(plaintext))) @@ -362,7 +362,7 @@ func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) { // "BAD_MESSAGE_MAC". func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { if len(message) == 0 { - return nil, olm.EmptyInput + return nil, olm.ErrEmptyInput } decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType) if err != nil { diff --git a/crypto/machine.go b/crypto/machine.go index 4d2e3880..f8ebe909 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -205,7 +205,7 @@ func (mach *OlmMachine) FlushStore(ctx context.Context) error { func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() { start := time.Now() return func() { - duration := time.Now().Sub(start) + duration := time.Since(start) if duration > expectedDuration { zerolog.Ctx(ctx).Warn(). Str("action", thing). diff --git a/crypto/olm/errors.go b/crypto/olm/errors.go index 957d7928..9e522b2a 100644 --- a/crypto/olm/errors.go +++ b/crypto/olm/errors.go @@ -10,50 +10,67 @@ import "errors" // Those are the most common used errors var ( - ErrBadSignature = errors.New("bad signature") - ErrBadMAC = errors.New("bad mac") - ErrBadMessageFormat = errors.New("bad message format") - ErrBadVerification = errors.New("bad verification") - ErrWrongProtocolVersion = errors.New("wrong protocol version") - ErrEmptyInput = errors.New("empty input") - ErrNoKeyProvided = errors.New("no key") - ErrBadMessageKeyID = errors.New("bad message key id") - ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key") - ErrMsgIndexTooHigh = errors.New("message index too high") - ErrProtocolViolation = errors.New("not protocol message order") - ErrMessageKeyNotFound = errors.New("message key not found") - ErrChainTooHigh = errors.New("chain index too high") - ErrBadInput = errors.New("bad input") - ErrBadVersion = errors.New("wrong version") - ErrWrongPickleVersion = errors.New("wrong pickle version") - ErrInputToSmall = errors.New("input too small (truncated?)") - ErrOverflow = errors.New("overflow") + ErrBadSignature = errors.New("bad signature") + ErrBadMAC = errors.New("the message couldn't be decrypted (bad mac)") + ErrBadMessageFormat = errors.New("the message couldn't be decoded") + ErrBadVerification = errors.New("bad verification") + ErrWrongProtocolVersion = errors.New("wrong protocol version") + ErrEmptyInput = errors.New("empty input") + ErrNoKeyProvided = errors.New("no key provided") + ErrBadMessageKeyID = errors.New("the message references an unknown key ID") + ErrUnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key") + ErrMsgIndexTooHigh = errors.New("message index too high") + ErrProtocolViolation = errors.New("not protocol message order") + ErrMessageKeyNotFound = errors.New("message key not found") + ErrChainTooHigh = errors.New("chain index too high") + ErrBadInput = errors.New("bad input") + ErrUnknownOlmPickleVersion = errors.New("unknown olm pickle version") + ErrUnknownJSONPickleVersion = errors.New("unknown JSON pickle version") + ErrInputToSmall = errors.New("input too small (truncated?)") ) // Error codes from go-olm var ( - EmptyInput = errors.New("empty input") - NoKeyProvided = errors.New("no pickle key provided") - NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") - SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") - InputNotJSONString = errors.New("input doesn't look like a JSON string") + ErrNotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") + ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") ) // Error codes from olm code var ( - NotEnoughRandom = errors.New("not enough entropy was supplied") - OutputBufferTooSmall = errors.New("supplied output buffer is too small") - BadMessageVersion = errors.New("the message version is unsupported") - BadMessageFormat = errors.New("the message couldn't be decoded") - BadMessageMAC = errors.New("the message couldn't be decrypted") - BadMessageKeyID = errors.New("the message references an unknown key ID") - InvalidBase64 = errors.New("the input base64 was invalid") - BadAccountKey = errors.New("the supplied account key is invalid") - UnknownPickleVersion = errors.New("the pickled object is too new") - CorruptedPickle = errors.New("the pickled object couldn't be decoded") - BadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key") - UnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key") - BadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1") - BadSignature = errors.New("received message had a bad signature") - InputBufferTooSmall = errors.New("the input data was too small to be valid") + ErrLibolmInvalidBase64 = errors.New("the input base64 was invalid") + + ErrLibolmNotEnoughRandom = errors.New("not enough entropy was supplied") + ErrLibolmOutputBufferTooSmall = errors.New("supplied output buffer is too small") + ErrLibolmBadAccountKey = errors.New("the supplied account key is invalid") + ErrLibolmCorruptedPickle = errors.New("the pickled object couldn't be decoded") + ErrLibolmBadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key") + ErrLibolmBadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1") +) + +// Deprecated: use variables prefixed with Err +var ( + EmptyInput = ErrEmptyInput + BadSignature = ErrBadSignature + InvalidBase64 = ErrLibolmInvalidBase64 + BadMessageKeyID = ErrBadMessageKeyID + BadMessageFormat = ErrBadMessageFormat + BadMessageVersion = ErrWrongProtocolVersion + BadMessageMAC = ErrBadMAC + UnknownPickleVersion = ErrUnknownOlmPickleVersion + NotEnoughRandom = ErrLibolmNotEnoughRandom + OutputBufferTooSmall = ErrLibolmOutputBufferTooSmall + BadAccountKey = ErrLibolmBadAccountKey + CorruptedPickle = ErrLibolmCorruptedPickle + BadSessionKey = ErrLibolmBadSessionKey + UnknownMessageIndex = ErrUnknownMessageIndex + BadLegacyAccountPickle = ErrLibolmBadLegacyAccountPickle + InputBufferTooSmall = ErrInputToSmall + NoKeyProvided = ErrNoKeyProvided + + NotEnoughGoRandom = ErrNotEnoughGoRandom + InputNotJSONString = ErrInputNotJSONString + + ErrBadVersion = ErrUnknownJSONPickleVersion + ErrWrongPickleVersion = ErrUnknownJSONPickleVersion + ErrRatchetNotAvailable = ErrUnknownMessageIndex ) diff --git a/crypto/sessions.go b/crypto/sessions.go index aecb0416..6b90c998 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -18,8 +18,14 @@ import ( ) var ( - SessionNotShared = errors.New("session has not been shared") - SessionExpired = errors.New("session has expired") + ErrSessionNotShared = errors.New("session has not been shared") + ErrSessionExpired = errors.New("session has expired") +) + +// Deprecated: use variables prefixed with Err +var ( + SessionNotShared = ErrSessionNotShared + SessionExpired = ErrSessionExpired ) // OlmSessionList is a list of OlmSessions. @@ -255,9 +261,9 @@ func (ogs *OutboundGroupSession) Expired() bool { func (ogs *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { if !ogs.Shared { - return nil, SessionNotShared + return nil, ErrSessionNotShared } else if ogs.Expired() { - return nil, SessionExpired + return nil, ErrSessionExpired } ogs.MessageCount++ ogs.LastEncryptedTime = time.Now() diff --git a/event/encryption.go b/event/encryption.go index cf9c2814..e07944af 100644 --- a/event/encryption.go +++ b/event/encryption.go @@ -63,7 +63,7 @@ func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error { return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext) case id.AlgorithmMegolmV1: if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' { - return id.InputNotJSONString + return id.ErrInputNotJSONString } content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1] } diff --git a/federation/resolution.go b/federation/resolution.go index 81e19cfb..a3188266 100644 --- a/federation/resolution.go +++ b/federation/resolution.go @@ -80,7 +80,10 @@ func ResolveServerName(ctx context.Context, serverName string, opts ...*ResolveS } else if wellKnown != nil { output.Expires = expiry output.HostHeader = wellKnown.Server - hostname, port, ok = ParseServerName(wellKnown.Server) + wkHost, wkPort, ok := ParseServerName(wellKnown.Server) + if ok { + hostname, port = wkHost, wkPort + } // Step 3.1 and 3.2: IP literals and hostnames with port inside .well-known if net.ParseIP(hostname) != nil || port != 0 { if port == 0 { diff --git a/filter.go b/filter.go index c6c8211b..54973dab 100644 --- a/filter.go +++ b/filter.go @@ -57,7 +57,7 @@ type FilterPart struct { // Validate checks if the filter contains valid property values func (filter *Filter) Validate() error { if filter.EventFormat != EventFormatClient && filter.EventFormat != EventFormatFederation { - return errors.New("Bad event_format value. Must be one of [\"client\", \"federation\"]") + return errors.New("bad event_format value") } return nil } diff --git a/id/contenturi.go b/id/contenturi.go index e6a313f5..be45eb2b 100644 --- a/id/contenturi.go +++ b/id/contenturi.go @@ -17,8 +17,14 @@ import ( ) var ( - InvalidContentURI = errors.New("invalid Matrix content URI") - InputNotJSONString = errors.New("input doesn't look like a JSON string") + ErrInvalidContentURI = errors.New("invalid Matrix content URI") + ErrInputNotJSONString = errors.New("input doesn't look like a JSON string") +) + +// Deprecated: use variables prefixed with Err +var ( + InvalidContentURI = ErrInvalidContentURI + InputNotJSONString = ErrInputNotJSONString ) // ContentURIString is a string that's expected to be a Matrix content URI. @@ -55,9 +61,9 @@ func ParseContentURI(uri string) (parsed ContentURI, err error) { if len(uri) == 0 { return } else if !strings.HasPrefix(uri, "mxc://") { - err = InvalidContentURI + err = ErrInvalidContentURI } else if index := strings.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 { - err = InvalidContentURI + err = ErrInvalidContentURI } else { parsed.Homeserver = uri[6 : 6+index] parsed.FileID = uri[6+index+1:] @@ -71,9 +77,9 @@ func ParseContentURIBytes(uri []byte) (parsed ContentURI, err error) { if len(uri) == 0 { return } else if !bytes.HasPrefix(uri, mxcBytes) { - err = InvalidContentURI + err = ErrInvalidContentURI } else if index := bytes.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 { - err = InvalidContentURI + err = ErrInvalidContentURI } else { parsed.Homeserver = string(uri[6 : 6+index]) parsed.FileID = string(uri[6+index+1:]) @@ -86,7 +92,7 @@ func (uri *ContentURI) UnmarshalJSON(raw []byte) (err error) { *uri = ContentURI{} return nil } else if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' { - return InputNotJSONString + return ErrInputNotJSONString } parsed, err := ParseContentURIBytes(raw[1 : len(raw)-1]) if err != nil { diff --git a/id/matrixuri.go b/id/matrixuri.go index 8f5ec849..d5c78bc7 100644 --- a/id/matrixuri.go +++ b/id/matrixuri.go @@ -54,7 +54,7 @@ var SigilToPathSegment = map[rune]string{ func (uri *MatrixURI) getQuery() url.Values { q := make(url.Values) - if uri.Via != nil && len(uri.Via) > 0 { + if len(uri.Via) > 0 { q["via"] = uri.Via } if len(uri.Action) > 0 { diff --git a/id/userid.go b/id/userid.go index 859d2358..726a0d58 100644 --- a/id/userid.go +++ b/id/userid.go @@ -219,15 +219,15 @@ func DecodeUserLocalpart(str string) (string, error) { for i := 0; i < len(strBytes); i++ { b := strBytes[i] if !isValidByte(b) { - return "", fmt.Errorf("Byte pos %d: Invalid byte", i) + return "", fmt.Errorf("invalid encoded byte at position %d: %c", i, b) } if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _ if i+1 >= len(strBytes) { - return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i) + return "", fmt.Errorf("unexpected end of string after underscore at %d", i) } if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping - return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i) + return "", fmt.Errorf("unexpected byte %c after underscore at %d", strBytes[i+1], i) } if strBytes[i+1] == '_' { outputBuffer.WriteByte('_') @@ -237,7 +237,7 @@ func DecodeUserLocalpart(str string) (string, error) { i++ // skip next byte since we just handled it } else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8 if i+2 >= len(strBytes) { - return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i) + return "", fmt.Errorf("unexpected end of string after equals sign at %d", i) } dst := make([]byte, 1) _, err := hex.Decode(dst, strBytes[i+1:i+3]) diff --git a/pushrules/action.go b/pushrules/action.go index 9838e88b..b5a884b2 100644 --- a/pushrules/action.go +++ b/pushrules/action.go @@ -105,7 +105,7 @@ func (action *PushAction) UnmarshalJSON(raw []byte) error { if ok { action.Action = ActionSetTweak action.Tweak = PushActionTweak(tweak) - action.Value, _ = val["value"] + action.Value = val["value"] } } return nil diff --git a/pushrules/condition_test.go b/pushrules/condition_test.go index 0d3eaf7a..37af3e34 100644 --- a/pushrules/condition_test.go +++ b/pushrules/condition_test.go @@ -102,14 +102,6 @@ func newEventPropertyIsPushCondition(key string, value any) *pushrules.PushCondi } } -func newEventPropertyContainsPushCondition(key string, value any) *pushrules.PushCondition { - return &pushrules.PushCondition{ - Kind: pushrules.KindEventPropertyContains, - Key: key, - Value: value, - } -} - func TestPushCondition_Match_InvalidKind(t *testing.T) { condition := &pushrules.PushCondition{ Kind: pushrules.PushCondKind("invalid"), diff --git a/room.go b/room.go index c3ddb7e6..4292bff5 100644 --- a/room.go +++ b/room.go @@ -5,8 +5,6 @@ import ( "maunium.net/go/mautrix/id" ) -type RoomStateMap = map[event.Type]map[string]*event.Event - // Room represents a single Matrix room. type Room struct { ID id.RoomID @@ -25,8 +23,8 @@ func (room Room) UpdateState(evt *event.Event) { // GetStateEvent returns the state event for the given type/state_key combo, or nil. func (room Room) GetStateEvent(eventType event.Type, stateKey string) *event.Event { - stateEventMap, _ := room.State[eventType] - evt, _ := stateEventMap[stateKey] + stateEventMap := room.State[eventType] + evt := stateEventMap[stateKey] return evt } diff --git a/synapseadmin/roomapi.go b/synapseadmin/roomapi.go index c360acab..0925b748 100644 --- a/synapseadmin/roomapi.go +++ b/synapseadmin/roomapi.go @@ -75,8 +75,7 @@ type RespListRooms struct { // https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#list-room-api func (cli *Client) ListRooms(ctx context.Context, req ReqListRoom) (RespListRooms, error) { var resp RespListRooms - var reqURL string - reqURL = cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) + reqURL := cli.Client.BuildURLWithQuery(mautrix.SynapseAdminURLPath{"v1", "rooms"}, req.BuildQuery()) _, err := cli.Client.MakeRequest(ctx, http.MethodGet, reqURL, nil, &resp) return resp, err } diff --git a/url.go b/url.go index d888956a..91b3d49d 100644 --- a/url.go +++ b/url.go @@ -98,10 +98,8 @@ func (saup SynapseAdminURLPath) FullPath() []any { // and appservice user ID set already. func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string { return cli.BuildURLWithFullQuery(urlPath, func(q url.Values) { - if urlQuery != nil { - for k, v := range urlQuery { - q.Set(k, v) - } + for k, v := range urlQuery { + q.Set(k, v) } }) } From e7a95b7f9732419e224843fc862b37e4359e726f Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 8 Dec 2025 14:33:02 +0000 Subject: [PATCH 1547/1647] client: backoff before retrying external upload requests --- client.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index b740cba6..90737581 100644 --- a/client.go +++ b/client.go @@ -2021,8 +2021,16 @@ func (cli *Client) uploadMediaToURL(ctx context.Context, data ReqUploadMedia) (* Msg("Error uploading media to external URL, not retrying") return nil, err } - cli.Log.Warn().Str("url", data.UnstableUploadURL).Err(err). + backoff := time.Second * time.Duration(cli.DefaultHTTPRetries-retries) + cli.Log.Warn().Err(err). + Str("url", data.UnstableUploadURL). + Int("retry_in_seconds", int(backoff.Seconds())). Msg("Error uploading media to external URL, retrying") + select { + case <-time.After(backoff): + case <-ctx.Done(): + return nil, ctx.Err() + } retries-- _, err = readerSeeker.Seek(0, io.SeekStart) if err != nil { From 31579be20ad8b53f442d49bdfecd061b8982df1d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 9 Dec 2025 16:37:17 +0200 Subject: [PATCH 1548/1647] bridgev2,event: add interface for message requests --- bridgev2/database/portal.go | 49 ++++++------- bridgev2/database/upgrades/00-latest.sql | 3 +- .../database/upgrades/25-message-requests.sql | 2 + bridgev2/errors.go | 69 ++++++++++--------- bridgev2/matrix/connector.go | 1 + bridgev2/networkinterface.go | 9 +++ bridgev2/portal.go | 60 ++++++++++++++-- event/beeper.go | 6 +- event/capabilities.d.ts | 5 ++ event/capabilities.go | 23 +++++++ event/content.go | 9 +-- event/state.go | 3 +- event/type.go | 11 +-- 13 files changed, 173 insertions(+), 77 deletions(-) create mode 100644 bridgev2/database/upgrades/25-message-requests.sql diff --git a/bridgev2/database/portal.go b/bridgev2/database/portal.go index f6868be6..0e6be286 100644 --- a/bridgev2/database/portal.go +++ b/bridgev2/database/portal.go @@ -56,30 +56,31 @@ type Portal struct { networkid.PortalKey MXID id.RoomID - ParentKey networkid.PortalKey - RelayLoginID networkid.UserLoginID - OtherUserID networkid.UserID - Name string - Topic string - AvatarID networkid.AvatarID - AvatarHash [32]byte - AvatarMXC id.ContentURIString - NameSet bool - TopicSet bool - AvatarSet bool - NameIsCustom bool - InSpace bool - RoomType RoomType - Disappear DisappearingSetting - CapState CapabilityState - Metadata any + ParentKey networkid.PortalKey + RelayLoginID networkid.UserLoginID + OtherUserID networkid.UserID + Name string + Topic string + AvatarID networkid.AvatarID + AvatarHash [32]byte + AvatarMXC id.ContentURIString + NameSet bool + TopicSet bool + AvatarSet bool + NameIsCustom bool + InSpace bool + MessageRequest bool + RoomType RoomType + Disappear DisappearingSetting + CapState CapabilityState + Metadata any } const ( getPortalBaseQuery = ` SELECT bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, - name_set, topic_set, avatar_set, name_is_custom, in_space, + name_set, topic_set, avatar_set, name_is_custom, in_space, message_request, room_type, disappear_type, disappear_timer, cap_state, metadata FROM portal @@ -101,11 +102,11 @@ const ( bridge_id, id, receiver, mxid, parent_id, parent_receiver, relay_login_id, other_user_id, name, topic, avatar_id, avatar_hash, avatar_mxc, - name_set, avatar_set, topic_set, name_is_custom, in_space, + name_set, avatar_set, topic_set, name_is_custom, in_space, message_request, room_type, disappear_type, disappear_timer, cap_state, metadata, relay_bridge_id ) VALUES ( - $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, + $1, $2, $3, $4, $5, $6, cast($7 AS TEXT), $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE $1 END ) ` @@ -114,8 +115,8 @@ const ( SET mxid=$4, parent_id=$5, parent_receiver=$6, relay_login_id=cast($7 AS TEXT), relay_bridge_id=CASE WHEN cast($7 AS TEXT) IS NULL THEN NULL ELSE bridge_id END, other_user_id=$8, name=$9, topic=$10, avatar_id=$11, avatar_hash=$12, avatar_mxc=$13, - name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18, - room_type=$19, disappear_type=$20, disappear_timer=$21, cap_state=$22, metadata=$23 + name_set=$14, avatar_set=$15, topic_set=$16, name_is_custom=$17, in_space=$18, message_request=$19, + room_type=$20, disappear_type=$21, disappear_timer=$22, cap_state=$23, metadata=$24 WHERE bridge_id=$1 AND id=$2 AND receiver=$3 ` deletePortalQuery = ` @@ -241,7 +242,7 @@ func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { &p.BridgeID, &p.ID, &p.Receiver, &mxid, &parentID, &parentReceiver, &relayLoginID, &otherUserID, &p.Name, &p.Topic, &p.AvatarID, &avatarHash, &p.AvatarMXC, - &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, + &p.NameSet, &p.TopicSet, &p.AvatarSet, &p.NameIsCustom, &p.InSpace, &p.MessageRequest, &p.RoomType, &disappearType, &disappearTimer, dbutil.JSON{Data: &p.CapState}, dbutil.JSON{Data: p.Metadata}, ) @@ -288,7 +289,7 @@ func (p *Portal) sqlVariables() []any { p.BridgeID, p.ID, p.Receiver, dbutil.StrPtr(p.MXID), dbutil.StrPtr(p.ParentKey.ID), p.ParentKey.Receiver, dbutil.StrPtr(p.RelayLoginID), dbutil.StrPtr(p.OtherUserID), p.Name, p.Topic, p.AvatarID, avatarHash, p.AvatarMXC, - p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, + p.NameSet, p.TopicSet, p.AvatarSet, p.NameIsCustom, p.InSpace, p.MessageRequest, p.RoomType, dbutil.StrPtr(p.Disappear.Type), dbutil.NumPtr(p.Disappear.Timer), dbutil.JSON{Data: p.CapState}, dbutil.JSON{Data: p.Metadata}, } diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index efde8816..b01cca44 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v24 (compatible with v9+): Latest revision +-- v0 -> v25 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -48,6 +48,7 @@ CREATE TABLE portal ( topic_set BOOLEAN NOT NULL, name_is_custom BOOLEAN NOT NULL DEFAULT false, in_space BOOLEAN NOT NULL, + message_request BOOLEAN NOT NULL DEFAULT false, room_type TEXT NOT NULL, disappear_type TEXT, disappear_timer BIGINT, diff --git a/bridgev2/database/upgrades/25-message-requests.sql b/bridgev2/database/upgrades/25-message-requests.sql new file mode 100644 index 00000000..b9d82a7a --- /dev/null +++ b/bridgev2/database/upgrades/25-message-requests.sql @@ -0,0 +1,2 @@ +-- v25 (compatible with v9+): Flag for message request portals +ALTER TABLE portal ADD COLUMN message_request BOOLEAN NOT NULL DEFAULT false; diff --git a/bridgev2/errors.go b/bridgev2/errors.go index e81b8953..a6cf4ceb 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -40,40 +40,41 @@ var ErrDirectMediaNotEnabled = errors.New("direct media is not enabled") // Common message status errors var ( - ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() - ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) - ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrIgnoringDeleteChatRelayedUser error = WrapErrorInStatus(errors.New("ignoring delete chat event from relayed user")).WithIsCertain(true).WithSendNotice(false) - ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrRoomMetadataNotAllowed error = WrapErrorInStatus(errors.New("changes are not allowed here")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) - ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) - ErrInvalidStateKey error = WrapErrorInStatus(errors.New("room metadata state key is unset or non-empty")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) - ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) - ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) - ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrMediaDurationTooLong error = WrapErrorInStatus(errors.New("media duration too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrVoiceMessageDurationTooLong error = WrapErrorInStatus(errors.New("voice message too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrMediaTooLarge error = WrapErrorInStatus(errors.New("media too large")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) - ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) - ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) - ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) - ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) - ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) - ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) - ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) + ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() + ErrNoPortal error = WrapErrorInStatus(errors.New("room is not a portal")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringReactionFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring reaction event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringPollFromRelayedUser error = WrapErrorInStatus(errors.New("ignoring poll event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringDeleteChatRelayedUser error = WrapErrorInStatus(errors.New("ignoring delete chat event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrIgnoringAcceptRequestRelayedUser error = WrapErrorInStatus(errors.New("ignoring accept message request event from relayed user")).WithIsCertain(true).WithSendNotice(false) + ErrEditsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support edits")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrEditsNotSupportedInPortal error = WrapErrorInStatus(errors.New("edits are not allowed in this chat")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrCaptionsNotAllowed error = WrapErrorInStatus(errors.New("captions are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrLocationMessagesNotAllowed error = WrapErrorInStatus(errors.New("location messages are not supported here")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrEditTargetTooOld error = WrapErrorInStatus(errors.New("the message is too old to be edited")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrEditTargetTooManyEdits error = WrapErrorInStatus(errors.New("the message has been edited too many times")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrReactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support reactions")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrPollsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support polls")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrRoomMetadataNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing room metadata")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrRoomMetadataNotAllowed error = WrapErrorInStatus(errors.New("changes are not allowed here")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrRedactionsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting messages")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported) + ErrUnexpectedParsedContentType error = WrapErrorInStatus(errors.New("unexpected parsed content type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true) + ErrInvalidStateKey error = WrapErrorInStatus(errors.New("room metadata state key is unset or non-empty")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) + ErrDatabaseError error = WrapErrorInStatus(errors.New("database error")).WithMessage("internal database error").WithIsCertain(true).WithSendNotice(true) + ErrTargetMessageNotFound error = WrapErrorInStatus(errors.New("target message not found")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(false) + ErrUnsupportedMessageType error = WrapErrorInStatus(errors.New("unsupported message type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrUnsupportedMediaType error = WrapErrorInStatus(errors.New("unsupported media type")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrMediaDurationTooLong error = WrapErrorInStatus(errors.New("media duration too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrVoiceMessageDurationTooLong error = WrapErrorInStatus(errors.New("voice message too long")).WithErrorAsMessage().WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrMediaTooLarge error = WrapErrorInStatus(errors.New("media too large")).WithErrorAsMessage().WithIsCertain(true).WithSendNotice(true).WithErrorReason(event.MessageStatusUnsupported) + ErrIgnoringMNotice error = WrapErrorInStatus(errors.New("ignoring m.notice message")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false) + ErrMediaDownloadFailed error = WrapErrorInStatus(errors.New("failed to download media")).WithMessage("failed to download media").WithIsCertain(true).WithSendNotice(true) + ErrMediaReuploadFailed error = WrapErrorInStatus(errors.New("failed to reupload media")).WithMessage("failed to reupload media").WithIsCertain(true).WithSendNotice(true) + ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) + ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) + ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) ErrPublicMediaDisabled = WrapErrorInStatus(errors.New("public media is not enabled in the bridge config")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) ErrPublicMediaDatabaseDisabled = WrapErrorInStatus(errors.New("public media database storage is disabled")).WithIsCertain(true).WithErrorAsMessage().WithErrorReason(event.MessageStatusUnsupported).WithSendNotice(true) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index e34e3252..cdfc2568 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -151,6 +151,7 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.EventProcessor.On(event.StateTombstone, br.handleRoomEvent) br.EventProcessor.On(event.StateBeeperDisappearingTimer, br.handleRoomEvent) br.EventProcessor.On(event.BeeperDeleteChat, br.handleRoomEvent) + br.EventProcessor.On(event.BeeperAcceptMessageRequest, br.handleRoomEvent) br.EventProcessor.On(event.EphemeralEventReceipt, br.handleEphemeralEvent) br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent) br.Bot = br.AS.BotIntent() diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 9c3f7d71..adbd3155 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -712,6 +712,14 @@ type DeleteChatHandlingNetworkAPI interface { HandleMatrixDeleteChat(ctx context.Context, msg *MatrixDeleteChat) error } +// MessageRequestAcceptingNetworkAPI is an optional interface that network connectors +// can implement to accept message requests from the remote network. +type MessageRequestAcceptingNetworkAPI interface { + NetworkAPI + // HandleMatrixAcceptMessageRequest is called when the user accepts a message request. + HandleMatrixAcceptMessageRequest(ctx context.Context, msg *MatrixAcceptMessageRequest) error +} + type ResolveIdentifierResponse struct { // Ghost is the ghost of the user that the identifier resolves to. // This field should be set whenever possible. However, it is not required, @@ -1419,6 +1427,7 @@ type MatrixViewingChat struct { } type MatrixDeleteChat = MatrixEventBase[*event.BeeperChatDeleteEventContent] +type MatrixAcceptMessageRequest = MatrixEventBase[*event.BeeperAcceptMessageRequestEventContent] type MatrixMarkedUnread = MatrixRoomMeta[*event.MarkedUnreadEventContent] type MatrixMute = MatrixRoomMeta[*event.BeeperMuteEventContent] type MatrixRoomTag = MatrixRoomMeta[*event.TagEventContent] diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 0d71535d..7ca3ffab 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -800,6 +800,8 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * return portal.handleMatrixPowerLevels(ctx, login, origSender, evt, isStateRequest) case event.BeeperDeleteChat: return portal.handleMatrixDeleteChat(ctx, login, origSender, evt) + case event.BeeperAcceptMessageRequest: + return portal.handleMatrixAcceptMessageRequest(ctx, login, origSender, evt) default: return EventHandlingResultIgnored } @@ -1749,6 +1751,45 @@ func (portal *Portal) getTargetUser(ctx context.Context, userID id.UserID) (Ghos } } +func (portal *Portal) handleMatrixAcceptMessageRequest( + ctx context.Context, + sender *UserLogin, + origSender *OrigSender, + evt *event.Event, +) EventHandlingResult { + if origSender != nil { + return EventHandlingResultFailed.WithMSSError(ErrIgnoringAcceptRequestRelayedUser) + } + log := zerolog.Ctx(ctx) + content, ok := evt.Content.Parsed.(*event.BeeperAcceptMessageRequestEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + api, ok := sender.Client.(MessageRequestAcceptingNetworkAPI) + if !ok { + return EventHandlingResultIgnored.WithMSSError(ErrDeleteChatNotSupported) + } + err := api.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{ + Event: evt, + Content: content, + Portal: portal, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix accept message request") + return EventHandlingResultFailed.WithMSSError(err) + } + if portal.MessageRequest { + portal.MessageRequest = false + portal.UpdateBridgeInfo(ctx) + err = portal.Save(ctx) + if err != nil { + log.Err(err).Msg("Failed to save portal after accepting message request") + } + } + return EventHandlingResultSuccess.WithMSS() +} + func (portal *Portal) handleMatrixDeleteChat( ctx context.Context, sender *UserLogin, @@ -3948,9 +3989,9 @@ type ChatInfo struct { Disappear *database.DisappearingSetting ParentID *networkid.PortalID - UserLocal *UserLocalPortalInfo - - CanBackfill bool + UserLocal *UserLocalPortalInfo + MessageRequest *bool + CanBackfill bool ExcludeChangesFromTimeline bool @@ -4070,10 +4111,11 @@ func (portal *Portal) getBridgeInfo() (string, event.BridgeEventContent) { Creator: portal.Bridge.Bot.GetMXID(), Protocol: portal.Bridge.Network.GetName().AsBridgeInfoSection(), Channel: event.BridgeInfoSection{ - ID: string(portal.ID), - DisplayName: portal.Name, - AvatarURL: portal.AvatarMXC, - Receiver: string(portal.Receiver), + ID: string(portal.ID), + DisplayName: portal.Name, + AvatarURL: portal.AvatarMXC, + Receiver: string(portal.Receiver), + MessageRequest: portal.MessageRequest, // TODO external URL? }, BeeperRoomTypeV2: string(portal.RoomType), @@ -4815,6 +4857,10 @@ func (portal *Portal) UpdateInfo(ctx context.Context, info *ChatInfo, source *Us portal.RoomType = *info.Type } } + if info.MessageRequest != nil && *info.MessageRequest != portal.MessageRequest { + changed = true + portal.MessageRequest = *info.MessageRequest + } if info.Members != nil && portal.MXID != "" && source != nil { err := portal.syncParticipants(ctx, info.Members, source, nil, time.Time{}) if err != nil { diff --git a/event/beeper.go b/event/beeper.go index 94892de7..75c18aa7 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -89,7 +89,11 @@ type BeeperRoomKeyAckEventContent struct { } type BeeperChatDeleteEventContent struct { - DeleteForEveryone bool `json:"delete_for_everyone,omitempty"` + DeleteForEveryone bool `json:"delete_for_everyone,omitempty"` + FromMessageRequest bool `json:"from_message_request,omitempty"` +} + +type BeeperAcceptMessageRequestEventContent struct { } type BeeperSendStateEventContent struct { diff --git a/event/capabilities.d.ts b/event/capabilities.d.ts index 1fbc9610..26aeb347 100644 --- a/event/capabilities.d.ts +++ b/event/capabilities.d.ts @@ -77,6 +77,11 @@ export interface RoomFeatures { delete_chat?: boolean /** Whether deleting the chat for all participants is supported. */ delete_chat_for_everyone?: boolean + /** What can be done with message requests? */ + message_request?: { + accept_with_message?: CapabilitySupportLevel + accept_with_button?: CapabilitySupportLevel + } } declare type integer = number diff --git a/event/capabilities.go b/event/capabilities.go index 4b7ff186..a86c726b 100644 --- a/event/capabilities.go +++ b/event/capabilities.go @@ -61,6 +61,8 @@ type RoomFeatures struct { DeleteChat bool `json:"delete_chat,omitempty"` DeleteChatForEveryone bool `json:"delete_chat_for_everyone,omitempty"` + MessageRequest *MessageRequestFeatures `json:"message_request,omitempty"` + PerMessageProfileRelay bool `json:"-"` } @@ -84,6 +86,7 @@ func (rf *RoomFeatures) Clone() *RoomFeatures { clone.DeleteMaxAge = ptr.Clone(clone.DeleteMaxAge) clone.DisappearingTimer = clone.DisappearingTimer.Clone() clone.AllowedReactions = slices.Clone(clone.AllowedReactions) + clone.MessageRequest = clone.MessageRequest.Clone() return &clone } @@ -165,6 +168,25 @@ func (dtc *DisappearingTimerCapability) Supports(content *BeeperDisappearingTime return slices.Contains(dtc.Types, content.Type) && (dtc.Timers == nil || slices.Contains(dtc.Timers, content.Timer)) } +type MessageRequestFeatures struct { + AcceptWithMessage CapabilitySupportLevel `json:"accept_with_message,omitempty"` + AcceptWithButton CapabilitySupportLevel `json:"accept_with_button,omitempty"` +} + +func (mrf *MessageRequestFeatures) Clone() *MessageRequestFeatures { + return ptr.Clone(mrf) +} + +func (mrf *MessageRequestFeatures) Hash() []byte { + if mrf == nil { + return nil + } + hasher := sha256.New() + hashValue(hasher, "accept_with_message", mrf.AcceptWithMessage) + hashValue(hasher, "accept_with_button", mrf.AcceptWithButton) + return hasher.Sum(nil) +} + type CapabilityMsgType = MessageType // Message types which are used for event capability signaling, but aren't real values for the msgtype field. @@ -347,6 +369,7 @@ func (rf *RoomFeatures) Hash() []byte { hashBool(hasher, "mark_as_unread", rf.MarkAsUnread) hashBool(hasher, "delete_chat", rf.DeleteChat) hashBool(hasher, "delete_chat_for_everyone", rf.DeleteChatForEveryone) + hashValue(hasher, "message_request", rf.MessageRequest) return hasher.Sum(nil) } diff --git a/event/content.go b/event/content.go index 73fb0db5..4929c6a5 100644 --- a/event/content.go +++ b/event/content.go @@ -61,10 +61,11 @@ var TypeMap = map[Type]reflect.Type{ EventUnstablePollStart: reflect.TypeOf(PollStartEventContent{}), EventUnstablePollResponse: reflect.TypeOf(PollResponseEventContent{}), - BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}), - BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}), - BeeperDeleteChat: reflect.TypeOf(BeeperChatDeleteEventContent{}), - BeeperSendState: reflect.TypeOf(BeeperSendStateEventContent{}), + BeeperMessageStatus: reflect.TypeOf(BeeperMessageStatusEventContent{}), + BeeperTranscription: reflect.TypeOf(BeeperTranscriptionEventContent{}), + BeeperDeleteChat: reflect.TypeOf(BeeperChatDeleteEventContent{}), + BeeperAcceptMessageRequest: reflect.TypeOf(BeeperAcceptMessageRequestEventContent{}), + BeeperSendState: reflect.TypeOf(BeeperSendStateEventContent{}), AccountDataRoomTags: reflect.TypeOf(TagEventContent{}), AccountDataDirectChats: reflect.TypeOf(DirectChatsEventContent{}), diff --git a/event/state.go b/event/state.go index 6df3b143..29e0e524 100644 --- a/event/state.go +++ b/event/state.go @@ -231,7 +231,8 @@ type BridgeInfoSection struct { AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` ExternalURL string `json:"external_url,omitempty"` - Receiver string `json:"fi.mau.receiver,omitempty"` + Receiver string `json:"fi.mau.receiver,omitempty"` + MessageRequest bool `json:"com.beeper.message_request,omitempty"` } // BridgeEventContent represents the content of a m.bridge state event. diff --git a/event/type.go b/event/type.go index 4fca07ea..f4d7592c 100644 --- a/event/type.go +++ b/event/type.go @@ -128,7 +128,7 @@ func (et *Type) GuessClass() TypeClass { InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type, CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type, CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type, EventUnstablePollStart.Type, EventUnstablePollResponse.Type, - EventUnstablePollEnd.Type, BeeperTranscription.Type, BeeperDeleteChat.Type: + EventUnstablePollEnd.Type, BeeperTranscription.Type, BeeperDeleteChat.Type, BeeperAcceptMessageRequest.Type: return MessageEventType case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type, ToDeviceBeeperRoomKeyAck.Type: @@ -234,10 +234,11 @@ var ( CallNegotiate = Type{"m.call.negotiate", MessageEventType} CallHangup = Type{"m.call.hangup", MessageEventType} - BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType} - BeeperTranscription = Type{"com.beeper.transcription", MessageEventType} - BeeperDeleteChat = Type{"com.beeper.delete_chat", MessageEventType} - BeeperSendState = Type{"com.beeper.send_state", MessageEventType} + BeeperMessageStatus = Type{"com.beeper.message_send_status", MessageEventType} + BeeperTranscription = Type{"com.beeper.transcription", MessageEventType} + BeeperDeleteChat = Type{"com.beeper.delete_chat", MessageEventType} + BeeperAcceptMessageRequest = Type{"com.beeper.accept_message_request", MessageEventType} + BeeperSendState = Type{"com.beeper.send_state", MessageEventType} EventUnstablePollStart = Type{Type: "org.matrix.msc3381.poll.start", Class: MessageEventType} EventUnstablePollResponse = Type{Type: "org.matrix.msc3381.poll.response", Class: MessageEventType} From 2c62641c739c36edf0c197efe89506d0c67c8c0c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 10 Dec 2025 13:15:33 +0200 Subject: [PATCH 1549/1647] bridgev2/portal: make queueEvent slightly safer when deleting portals --- bridgev2/errors.go | 2 ++ bridgev2/portal.go | 39 ++++++++++++++++++++++++++------------- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index a6cf4ceb..c39f8707 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -38,6 +38,8 @@ var ErrNotLoggedIn = errors.New("not logged in") // but direct media is not enabled. var ErrDirectMediaNotEnabled = errors.New("direct media is not enabled") +var ErrPortalIsDeleted = errors.New("portal is deleted") + // Common message status errors var ( ErrPanicInEventHandler error = WrapErrorInStatus(errors.New("panic in event handler")).WithSendNotice(true).WithErrorAsMessage() diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 7ca3ffab..273b1fd3 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -93,7 +93,7 @@ type Portal struct { functionalMembersCache *event.ElementFunctionalMembersContent events chan portalEvent - deleted bool + deleted *exsync.Event eventsLock sync.Mutex eventIdx int @@ -127,6 +127,7 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que outgoingMessages: make(map[networkid.TransactionID]*outgoingMessage), RoomCreated: exsync.NewEvent(), + deleted: exsync.NewEvent(), } if portal.MXID != "" { portal.RoomCreated.Set() @@ -335,6 +336,9 @@ func (br *Bridge) GetExistingPortalByKey(ctx context.Context, key networkid.Port } func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) EventHandlingResult { + if portal.deleted.IsSet() { + return EventHandlingResultIgnored + } if PortalEventBuffer == 0 { portal.eventsLock.Lock() defer portal.eventsLock.Unlock() @@ -347,6 +351,8 @@ func (portal *Portal) queueEvent(ctx context.Context, evt portalEvent) EventHand select { case portal.events <- evt: return EventHandlingResultQueued + case <-portal.deleted.GetChan(): + return EventHandlingResultIgnored default: zerolog.Ctx(ctx).Error(). Str("portal_id", string(portal.ID)). @@ -371,16 +377,16 @@ func (portal *Portal) eventLoop() { go portal.pendingMessageTimeoutLoop(ctx, cfg) defer cancel() } - i := 0 - for rawEvt := range portal.events { - if portal.deleted { - return - } - i++ - if portal.Bridge.Config.AsyncEvents { - go portal.handleSingleEventWithDelayLogging(i, rawEvt) - } else { - portal.handleSingleEventWithDelayLogging(i, rawEvt) + deleteCh := portal.deleted.GetChan() + for i := 0; ; i++ { + select { + case rawEvt := <-portal.events: + if portal.Bridge.Config.AsyncEvents { + go portal.handleSingleEventWithDelayLogging(i, rawEvt) + } else { + portal.handleSingleEventWithDelayLogging(i, rawEvt) + } + case <-deleteCh: } } } @@ -4902,6 +4908,9 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } return nil } + if portal.deleted.IsSet() { + return ErrPortalIsDeleted + } waiter := make(chan struct{}) closed := false evt := &portalCreateEvent{ @@ -4919,7 +4928,11 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i if PortalEventBuffer == 0 { go portal.queueEvent(ctx, evt) } else { - portal.events <- evt + select { + case portal.events <- evt: + case <-portal.deleted.GetChan(): + return ErrPortalIsDeleted + } } select { case <-ctx.Done(): @@ -5245,11 +5258,11 @@ func (portal *Portal) unlockedDeleteCache() { if portal.MXID != "" { delete(portal.Bridge.portalsByMXID, portal.MXID) } + portal.deleted.Set() if portal.events != nil { // TODO there's a small risk of this racing with a queueEvent call close(portal.events) } - portal.deleted = true } func (portal *Portal) Save(ctx context.Context) error { From efd4136c7a9361090dbbac69d9a7af59a568e68e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 11 Dec 2025 14:17:45 +0200 Subject: [PATCH 1550/1647] dependencies: update --- go.mod | 16 ++++++++-------- go.sum | 28 ++++++++++++++-------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/go.mod b/go.mod index bf56a014..0b86e5da 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module maunium.net/go/mautrix go 1.24.0 -toolchain go1.25.4 +toolchain go1.25.5 require ( filippo.io/edwards25519 v1.1.0 @@ -17,12 +17,12 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.13 - go.mau.fi/util v0.9.4-0.20251206205611-85e6fd6551e0 + go.mau.fi/util v0.9.4-0.20251211121531-f6527b4882ae go.mau.fi/zeroconfig v0.2.0 - golang.org/x/crypto v0.45.0 - golang.org/x/exp v0.0.0-20251125195548-87e1e737ad39 - golang.org/x/net v0.47.0 - golang.org/x/sync v0.18.0 + golang.org/x/crypto v0.46.0 + golang.org/x/exp v0.0.0-20251209150349-8475f28825e9 + golang.org/x/net v0.48.0 + golang.org/x/sync v0.19.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 ) @@ -36,7 +36,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - golang.org/x/sys v0.38.0 // indirect - golang.org/x/text v0.31.0 // indirect + golang.org/x/sys v0.39.0 // indirect + golang.org/x/text v0.32.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 6ea3f378..5e9eded0 100644 --- a/go.sum +++ b/go.sum @@ -51,26 +51,26 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.4-0.20251206205611-85e6fd6551e0 h1:ESebxPGULuuxxcZigjcBFyyU62tiyY6ivtX17P4BkvY= -go.mau.fi/util v0.9.4-0.20251206205611-85e6fd6551e0/go.mod h1:viDmhBOAFfcqDdKSk53EPJV3N4Mi8Jst5/ahGJ/vwsA= +go.mau.fi/util v0.9.4-0.20251211121531-f6527b4882ae h1:tocQOutgT+Z/V6w668Jpk3D5942K5p25XmRAvXg8s2E= +go.mau.fi/util v0.9.4-0.20251211121531-f6527b4882ae/go.mod h1:OwI76F1QINxtH/TOydGAAj5/VvtPG0RnZzB41rtnKcA= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= -golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= -golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= -golang.org/x/exp v0.0.0-20251125195548-87e1e737ad39 h1:DHNhtq3sNNzrvduZZIiFyXWOL9IWaDPHqTnLJp+rCBY= -golang.org/x/exp v0.0.0-20251125195548-87e1e737ad39/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0= -golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= -golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= -golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= -golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= +golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= +golang.org/x/exp v0.0.0-20251209150349-8475f28825e9 h1:MDfG8Cvcqlt9XXrmEiD4epKn7VJHZO84hejP9Jmp0MM= +golang.org/x/exp v0.0.0-20251209150349-8475f28825e9/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU= +golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= -golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= -golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From 9e3fa96fb42f287ace7369637ba4973883133df3 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 12 Dec 2025 17:31:56 +0200 Subject: [PATCH 1551/1647] bridgev2/portal: handle portal deletion edge cases --- bridgev2/errors.go | 1 + bridgev2/portal.go | 6 ++++++ bridgev2/queue.go | 6 +++--- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index c39f8707..514dc238 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -39,6 +39,7 @@ var ErrNotLoggedIn = errors.New("not logged in") var ErrDirectMediaNotEnabled = errors.New("direct media is not enabled") var ErrPortalIsDeleted = errors.New("portal is deleted") +var ErrPortalNotFoundInEventHandler = errors.New("portal not found to handle remote event") // Common message status errors var ( diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 273b1fd3..7d479be2 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -5195,6 +5195,9 @@ func (portal *Portal) addToUserSpaces(ctx context.Context) { } func (portal *Portal) Delete(ctx context.Context) error { + if portal.deleted.IsSet() { + return nil + } portal.removeInPortalCache(ctx) err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) if err != nil { @@ -5254,6 +5257,9 @@ func (portal *Portal) unlockedDelete(ctx context.Context) error { } func (portal *Portal) unlockedDeleteCache() { + if portal.deleted.IsSet() { + return + } delete(portal.Bridge.portalsByKey, portal.PortalKey) if portal.MXID != "" { delete(portal.Bridge.portalsByMXID, portal.MXID) diff --git a/bridgev2/queue.go b/bridgev2/queue.go index 6667caea..3775c825 100644 --- a/bridgev2/queue.go +++ b/bridgev2/queue.go @@ -220,7 +220,7 @@ func (ul *UserLogin) QueueRemoteEvent(evt RemoteEvent) EventHandlingResult { return ul.Bridge.QueueRemoteEvent(ul, evt) } -func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) (res EventHandlingResult) { +func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) EventHandlingResult { log := login.Log ctx := log.WithContext(br.BackgroundCtx) maybeUncertain, ok := evt.(RemoteEventWithUncertainPortalReceiver) @@ -236,14 +236,14 @@ func (br *Bridge) QueueRemoteEvent(login *UserLogin, evt RemoteEvent) (res Event if err != nil { log.Err(err).Object("portal_key", key).Bool("uncertain_receiver", isUncertain). Msg("Failed to get portal to handle remote event") - return + return EventHandlingResultFailed.WithError(fmt.Errorf("failed to get portal: %w", err)) } else if portal == nil { log.Warn(). Stringer("event_type", evt.GetType()). Object("portal_key", key). Bool("uncertain_receiver", isUncertain). Msg("Portal not found to handle remote event") - return + return EventHandlingResultFailed.WithError(ErrPortalNotFoundInEventHandler) } // TODO put this in a better place, and maybe cache to avoid constant db queries login.MarkInPortal(ctx, portal) From de52a753be8aad17ccfd4ec89ad4a16bd222be14 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 13 Dec 2025 10:47:37 +0200 Subject: [PATCH 1552/1647] bridgev2: remove hardcoded room version --- bridgev2/matrix/connector.go | 17 +++++++++++++++++ bridgev2/matrix/intent.go | 34 ++++++++++++++++++++++++++++++++++ bridgev2/portal.go | 1 - bridgev2/space.go | 3 +-- bridgev2/user.go | 5 ++--- 5 files changed, 54 insertions(+), 6 deletions(-) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index cdfc2568..aed6d3bd 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -81,6 +81,8 @@ type Connector struct { MediaConfig mautrix.RespMediaConfig SpecVersions *mautrix.RespVersions + SpecCaps *mautrix.RespCapabilities + specCapsLock sync.Mutex Capabilities *bridgev2.MatrixCapabilities IgnoreUnsupportedServer bool @@ -409,6 +411,21 @@ func (br *Connector) ensureConnection(ctx context.Context) { br.Bot.EnsureAppserviceConnection(ctx) } +func (br *Connector) fetchCapabilities(ctx context.Context) *mautrix.RespCapabilities { + br.specCapsLock.Lock() + defer br.specCapsLock.Unlock() + if br.SpecCaps != nil { + return br.SpecCaps + } + caps, err := br.Bot.Capabilities(ctx) + if err != nil { + br.Log.Err(err).Msg("Failed to fetch capabilities from homeserver") + return nil + } + br.SpecCaps = caps + return caps +} + func (br *Connector) fetchMediaConfig(ctx context.Context) { cfg, err := br.Bot.GetMediaConfig(ctx) if err != nil { diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 1f82f77f..44dcbc5b 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -512,6 +512,39 @@ func (br *Connector) getDefaultEncryptionEvent() *event.EncryptionEventContent { return content } +func (as *ASIntent) filterCreateRequestForV12(ctx context.Context, req *mautrix.ReqCreateRoom) { + if as.Connector.Config.Homeserver.Software == bridgeconfig.SoftwareHungry { + // Hungryserv doesn't override the capabilities endpoint nor do room versions + return + } + caps := as.Connector.fetchCapabilities(ctx) + roomVer := req.RoomVersion + if roomVer == "" && caps != nil && caps.RoomVersions != nil { + roomVer = id.RoomVersion(caps.RoomVersions.Default) + } + if roomVer != "" && !roomVer.PrivilegedRoomCreators() { + return + } + creators, _ := req.CreationContent["additional_creators"].([]id.UserID) + creators = append(slices.Clone(creators), as.GetMXID()) + if req.PowerLevelOverride != nil { + for _, creator := range creators { + delete(req.PowerLevelOverride.Users, creator) + } + } + for _, evt := range req.InitialState { + if evt.Type != event.StatePowerLevels { + continue + } + content, ok := evt.Content.Parsed.(*event.PowerLevelsEventContent) + if ok { + for _, creator := range creators { + delete(content.Users, creator) + } + } + } +} + func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) (id.RoomID, error) { if as.Connector.Config.Encryption.Default { req.InitialState = append(req.InitialState, &event.Event{ @@ -527,6 +560,7 @@ func (as *ASIntent) CreateRoom(ctx context.Context, req *mautrix.ReqCreateRoom) } req.CreationContent["m.federate"] = false } + as.filterCreateRequestForV12(ctx, req) resp, err := as.Matrix.CreateRoom(ctx, req) if err != nil { return "", err diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 7d479be2..8bd66b6a 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -4999,7 +4999,6 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo IsDirect: portal.RoomType == database.RoomTypeDM, PowerLevelOverride: powerLevels, BeeperLocalRoomID: portal.Bridge.Matrix.GenerateDeterministicRoomID(portal.PortalKey), - RoomVersion: id.RoomV11, } autoJoinInvites := portal.Bridge.Matrix.GetCapabilities().AutoJoinInvites if autoJoinInvites { diff --git a/bridgev2/space.go b/bridgev2/space.go index f6d07922..2ca2bce3 100644 --- a/bridgev2/space.go +++ b/bridgev2/space.go @@ -164,8 +164,7 @@ func (ul *UserLogin) GetSpaceRoom(ctx context.Context) (id.RoomID, error) { ul.UserMXID: 50, }, }, - RoomVersion: id.RoomV11, - Invite: []id.UserID{ul.UserMXID}, + Invite: []id.UserID{ul.UserMXID}, } if autoJoin { req.BeeperInitialMembers = []id.UserID{ul.UserMXID} diff --git a/bridgev2/user.go b/bridgev2/user.go index af9e9694..9a7896d6 100644 --- a/bridgev2/user.go +++ b/bridgev2/user.go @@ -229,9 +229,8 @@ func (user *User) GetManagementRoom(ctx context.Context) (id.RoomID, error) { user.MXID: 50, }, }, - RoomVersion: id.RoomV11, - Invite: []id.UserID{user.MXID}, - IsDirect: true, + Invite: []id.UserID{user.MXID}, + IsDirect: true, } if autoJoin { req.BeeperInitialMembers = []id.UserID{user.MXID} From 9dc3772c47bc3a89fee85903f44b5b84fa1676dc Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 13 Dec 2025 10:54:58 +0200 Subject: [PATCH 1553/1647] ci: update actions and pre-commit hooks --- .github/workflows/go.yml | 12 ++++++------ .github/workflows/stale.yml | 2 +- .pre-commit-config.yaml | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 8bce4484..deaa1f1d 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -10,10 +10,10 @@ jobs: runs-on: ubuntu-latest name: Lint (latest) steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Go - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: "1.25" cache: true @@ -39,10 +39,10 @@ jobs: name: Build (${{ matrix.go-version == '1.25' && 'latest' || 'old' }}, libolm) steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Go ${{ matrix.go-version }} - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: ${{ matrix.go-version }} cache: true @@ -76,10 +76,10 @@ jobs: name: Build (${{ matrix.go-version == '1.25' && 'latest' || 'old' }}, goolm) steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Set up Go ${{ matrix.go-version }} - uses: actions/setup-go@v5 + uses: actions/setup-go@v6 with: go-version: ${{ matrix.go-version }} cache: true diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index 578349c9..9a9e7375 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -17,7 +17,7 @@ jobs: lock-stale: runs-on: ubuntu-latest steps: - - uses: dessant/lock-threads@v5 + - uses: dessant/lock-threads@v6 id: lock with: issue-inactive-days: 90 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4f769e56..616fccb2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: trailing-whitespace exclude_types: [markdown] @@ -9,7 +9,7 @@ repos: - id: check-added-large-files - repo: https://github.com/tekwizely/pre-commit-golang - rev: v1.0.0-rc.1 + rev: v1.0.0-rc.4 hooks: - id: go-imports-repo args: From cb6f673e7a700eed2e70147c5338fd06b183ba5d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 13 Dec 2025 11:09:09 +0200 Subject: [PATCH 1554/1647] bridgev2/portal: fix event loop not stopping --- bridgev2/portal.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 8bd66b6a..9ee277b3 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -381,12 +381,16 @@ func (portal *Portal) eventLoop() { for i := 0; ; i++ { select { case rawEvt := <-portal.events: + if rawEvt == nil { + return + } if portal.Bridge.Config.AsyncEvents { go portal.handleSingleEventWithDelayLogging(i, rawEvt) } else { portal.handleSingleEventWithDelayLogging(i, rawEvt) } case <-deleteCh: + return } } } From 4be256229706fa06a220bfe0d75cef1414da1cb0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 14 Dec 2025 14:37:57 +0200 Subject: [PATCH 1555/1647] changelog: update --- CHANGELOG.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b30e055e..fa31c025 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,35 @@ +## v0.26.1 (unreleased) + +* **Breaking change *(mediaproxy)*** Changed `GetMediaResponseFile` to return + the mime type from the callback rather than in the return get media return + value. The callback can now also redirect the caller to a different file. +* *(federation)* Added join/knock/leave functions + (thanks to [@nexy7574] in [#422]). +* *(federation/eventauth)* Fixed various incorrect checks. +* *(client)* Added backoff for retrying media uploads to external URLs + (with MSC3870). +* *(bridgev2/config)* Added support for overriding config fields using + environment variables. +* *(bridgev2/commands)* Added command to mute chat on remote network. +* *(bridgev2)* Added interface for network connectors to redirect to a different + user ID when handling an invite from Matrix. +* *(bridgev2)* Added interface for signaling message request status of portals. +* *(bridgev2)* Changed portal creation to not backfill unless `CanBackfill` flag + is set in chat info. +* *(bridgev2)* Changed Matrix reaction handling to only delete old reaction if + bridging the new one is successful. +* *(bridgev2/mxmain)* Improved error message when trying to run bridge with + pre-megabridge database when no database migration exists. +* *(bridgev2)* Improved reliability of database migration when enabling split + portals. +* *(bridgev2)* Improved detection of orphaned DM rooms when starting new chats. +* *(bridgev2)* Stopped sending redundant invites when joining ghosts to public + portal rooms. +* *(bridgev2)* Stopped hardcoding room versions in favor of checking + server capabilities to determine appropriate `/createRoom` parameters. + +[#422]: https://github.com/mautrix/go/pull/422 + ## v0.26.0 (2025-11-16) * *(client,appservice)* Deprecated `SendMassagedStateEvent` as `SendStateEvent` From 950ce6636e10f59484593df957dabe660c2804db Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 15 Dec 2025 15:18:40 +0200 Subject: [PATCH 1556/1647] crypto/goolm: include version number in version mismatches --- crypto/goolm/message/group_message.go | 2 +- crypto/goolm/message/message.go | 2 +- crypto/goolm/message/prekey_message.go | 2 +- crypto/goolm/ratchet/olm.go | 2 +- crypto/goolm/session/megolm_inbound_session.go | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/crypto/goolm/message/group_message.go b/crypto/goolm/message/group_message.go index f3d22500..c83540c1 100644 --- a/crypto/goolm/message/group_message.go +++ b/crypto/goolm/message/group_message.go @@ -39,7 +39,7 @@ func (r *GroupMessage) Decode(input []byte) (err error) { return } if r.Version != protocolVersion { - return fmt.Errorf("GroupMessage.Decode: %w", olm.ErrWrongProtocolVersion) + return fmt.Errorf("GroupMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) } for { diff --git a/crypto/goolm/message/message.go b/crypto/goolm/message/message.go index 9ef93630..b161a2d1 100644 --- a/crypto/goolm/message/message.go +++ b/crypto/goolm/message/message.go @@ -43,7 +43,7 @@ func (r *Message) Decode(input []byte) (err error) { return } if r.Version != protocolVersion { - return fmt.Errorf("Message.Decode: %w", olm.ErrWrongProtocolVersion) + return fmt.Errorf("Message.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) } for { diff --git a/crypto/goolm/message/prekey_message.go b/crypto/goolm/message/prekey_message.go index 760be4c9..4e3d495d 100644 --- a/crypto/goolm/message/prekey_message.go +++ b/crypto/goolm/message/prekey_message.go @@ -48,7 +48,7 @@ func (r *PreKeyMessage) Decode(input []byte) (err error) { return } if r.Version != protocolVersion { - return fmt.Errorf("PreKeyMessage.Decode: %w", olm.ErrWrongProtocolVersion) + return fmt.Errorf("PreKeyMessage.Decode: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, r.Version, protocolVersion) } for { diff --git a/crypto/goolm/ratchet/olm.go b/crypto/goolm/ratchet/olm.go index 229c9bd2..9901ada8 100644 --- a/crypto/goolm/ratchet/olm.go +++ b/crypto/goolm/ratchet/olm.go @@ -142,7 +142,7 @@ func (r *Ratchet) Decrypt(input []byte) ([]byte, error) { return nil, err } if message.Version != protocolVersion { - return nil, fmt.Errorf("decrypt: %w", olm.ErrWrongProtocolVersion) + return nil, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, message.Version, protocolVersion) } if !message.HasCounter || len(message.RatchetKey) == 0 || len(message.Ciphertext) == 0 { return nil, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat) diff --git a/crypto/goolm/session/megolm_inbound_session.go b/crypto/goolm/session/megolm_inbound_session.go index fb88b73c..7ccbd26d 100644 --- a/crypto/goolm/session/megolm_inbound_session.go +++ b/crypto/goolm/session/megolm_inbound_session.go @@ -126,7 +126,7 @@ func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint, error) return nil, 0, err } if msg.Version != protocolVersion { - return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrWrongProtocolVersion) + return nil, 0, fmt.Errorf("decrypt: %w (got %d, expected %d)", olm.ErrWrongProtocolVersion, msg.Version, protocolVersion) } if msg.Ciphertext == nil || !msg.HasMessageIndex { return nil, 0, fmt.Errorf("decrypt: %w", olm.ErrBadMessageFormat) From b9635964a5e36349a9ff7364f5b1173cce7aedb6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Dec 2025 12:20:42 +0200 Subject: [PATCH 1557/1647] Bump version to v0.26.1 --- CHANGELOG.md | 2 +- go.mod | 4 ++-- go.sum | 7 ++++--- version.go | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fa31c025..0fb1a105 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -## v0.26.1 (unreleased) +## v0.26.1 (2025-12-16) * **Breaking change *(mediaproxy)*** Changed `GetMediaResponseFile` to return the mime type from the callback rather than in the return get media return diff --git a/go.mod b/go.mod index 0b86e5da..cdb62f20 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.13 - go.mau.fi/util v0.9.4-0.20251211121531-f6527b4882ae + go.mau.fi/util v0.9.4 go.mau.fi/zeroconfig v0.2.0 golang.org/x/crypto v0.46.0 golang.org/x/exp v0.0.0-20251209150349-8475f28825e9 @@ -28,7 +28,7 @@ require ( ) require ( - github.com/coreos/go-systemd/v22 v22.5.0 // indirect + github.com/coreos/go-systemd/v22 v22.6.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect diff --git a/go.sum b/go.sum index 5e9eded0..a55f0661 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,9 @@ github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= -github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/coreos/go-systemd/v22 v22.6.0 h1:aGVa/v8B7hpb0TKl0MWoAavPDmHvobFe5R5zn0bCJWo= +github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X7Lua8rrTWzYgWU= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -51,8 +52,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.4-0.20251211121531-f6527b4882ae h1:tocQOutgT+Z/V6w668Jpk3D5942K5p25XmRAvXg8s2E= -go.mau.fi/util v0.9.4-0.20251211121531-f6527b4882ae/go.mod h1:OwI76F1QINxtH/TOydGAAj5/VvtPG0RnZzB41rtnKcA= +go.mau.fi/util v0.9.4 h1:gWdUff+K2rCynRPysXalqqQyr2ahkSWaestH6YhSpso= +go.mau.fi/util v0.9.4/go.mod h1:647nVfwUvuhlZFOnro3aRNPmRd2y3iDha9USb8aKSmM= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= diff --git a/version.go b/version.go index f6d20c3f..46d3342c 100644 --- a/version.go +++ b/version.go @@ -8,7 +8,7 @@ import ( "strings" ) -const Version = "v0.26.0" +const Version = "v0.26.1" var GoModVersion = "" var Commit = "" From e9b262e67162251198e83bc863374e26d8546db5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Dec 2025 16:23:44 +0200 Subject: [PATCH 1558/1647] bridgev2/database: add index for disappearing messages and portal parents --- bridgev2/database/upgrades/00-latest.sql | 4 +++- .../upgrades/26-disappearing-message-portal-index.sql | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 bridgev2/database/upgrades/26-disappearing-message-portal-index.sql diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index b01cca44..b193d314 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v25 (compatible with v9+): Latest revision +-- v0 -> v26 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -65,6 +65,7 @@ CREATE TABLE portal ( ON DELETE SET NULL ON UPDATE CASCADE ); CREATE UNIQUE INDEX portal_bridge_mxid_idx ON portal (bridge_id, mxid); +CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver); CREATE TABLE ghost ( bridge_id TEXT NOT NULL, @@ -139,6 +140,7 @@ CREATE TABLE disappearing_message ( REFERENCES portal (bridge_id, mxid) ON DELETE CASCADE ); +CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room); CREATE TABLE reaction ( bridge_id TEXT NOT NULL, diff --git a/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql b/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql new file mode 100644 index 00000000..ae5d8cad --- /dev/null +++ b/bridgev2/database/upgrades/26-disappearing-message-portal-index.sql @@ -0,0 +1,3 @@ +-- v26 (compatible with v9+): Add room index for disappearing message table and portal parents +CREATE INDEX disappearing_message_portal_idx ON disappearing_message (bridge_id, mx_room); +CREATE INDEX portal_parent_idx ON portal (bridge_id, parent_id, parent_receiver); From e38d758a525a0b59786497f221b52684ea468e23 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Dec 2025 16:59:54 +0200 Subject: [PATCH 1559/1647] bridgev2/database: delete messages in chunks if portal has too many --- bridgev2/database/message.go | 75 ++++++++++++++++++++++++++++++++++++ bridgev2/portal.go | 14 +++++-- 2 files changed, 86 insertions(+), 3 deletions(-) diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 9b3b1493..a1af1556 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -11,9 +11,11 @@ import ( "crypto/sha256" "database/sql" "encoding/base64" + "fmt" "strings" "time" + "github.com/rs/zerolog" "go.mau.fi/util/dbutil" "maunium.net/go/mautrix/bridgev2/networkid" @@ -96,6 +98,10 @@ const ( deleteMessagePartByRowIDQuery = ` DELETE FROM message WHERE bridge_id=$1 AND rowid=$2 ` + deleteMessageChunkQuery = ` + DELETE FROM message WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND rowid > $4 AND rowid <= $5 + ` + getMaxMessageRowIDQuery = `SELECT MAX(rowid) FROM message WHERE bridge_id=$1` ) func (mq *MessageQuery) GetAllPartsByID(ctx context.Context, receiver networkid.UserLoginID, id networkid.MessageID) ([]*Message, error) { @@ -180,6 +186,75 @@ func (mq *MessageQuery) Delete(ctx context.Context, rowID int64) error { return mq.Exec(ctx, deleteMessagePartByRowIDQuery, mq.BridgeID, rowID) } +func (mq *MessageQuery) deleteChunk(ctx context.Context, portal networkid.PortalKey, minRowID, maxRowID int64) (int64, error) { + res, err := mq.GetDB().Exec(ctx, deleteMessageChunkQuery, mq.BridgeID, portal.ID, portal.Receiver, minRowID, maxRowID) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (mq *MessageQuery) getMaxRowID(ctx context.Context) (maxRowID int64, err error) { + err = mq.GetDB().QueryRow(ctx, getMaxMessageRowIDQuery, mq.BridgeID).Scan(&maxRowID) + return +} + +const deleteChunkSize = 100_000 + +func (mq *MessageQuery) DeleteInChunks(ctx context.Context, portal networkid.PortalKey) error { + if mq.GetDB().Dialect != dbutil.SQLite { + return nil + } + total, err := mq.CountMessagesInPortal(ctx, portal) + if err != nil { + return fmt.Errorf("failed to count messages in portal: %w", err) + } else if total < deleteChunkSize { + return nil + } + globalMaxRowID, err := mq.getMaxRowID(ctx) + if err != nil { + return fmt.Errorf("failed to get max row ID: %w", err) + } + zerolog.Ctx(ctx).Debug(). + Int("total_count", total). + Int64("global_max_row_id", globalMaxRowID). + Msg("Portal has lots of messages, deleting in chunks to avoid database locks") + maxRowID := int64(deleteChunkSize) + globalMaxRowID += deleteChunkSize * 1.2 + var dbTimeUsed time.Duration + globalStart := time.Now() + for total > 500 && maxRowID < globalMaxRowID { + start := time.Now() + count, err := mq.deleteChunk(ctx, portal, maxRowID-deleteChunkSize, maxRowID) + duration := time.Since(start) + dbTimeUsed += duration + if err != nil { + return fmt.Errorf("failed to delete chunk of messages before %d: %w", maxRowID, err) + } + total -= int(count) + maxRowID += deleteChunkSize + sleepTime := max(10*time.Millisecond, min(250*time.Millisecond, time.Duration(count/100)*time.Millisecond)) + zerolog.Ctx(ctx).Debug(). + Int64("max_row_id", maxRowID). + Int64("deleted_count", count). + Int("remaining_count", total). + Dur("duration", duration). + Dur("sleep_time", sleepTime). + Msg("Deleted chunk of messages") + select { + case <-time.After(sleepTime): + case <-ctx.Done(): + return ctx.Err() + } + } + zerolog.Ctx(ctx).Debug(). + Int("remaining_count", total). + Dur("db_time_used", dbTimeUsed). + Dur("total_duration", time.Since(globalStart)). + Msg("Finished chunked delete of messages in portal") + return nil +} + func (mq *MessageQuery) CountMessagesInPortal(ctx context.Context, key networkid.PortalKey) (count int, err error) { err = mq.GetDB().QueryRow(ctx, countMessagesInPortalQuery, mq.BridgeID, key.ID, key.Receiver).Scan(&count) return diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 9ee277b3..87b50c84 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -5202,7 +5202,7 @@ func (portal *Portal) Delete(ctx context.Context) error { return nil } portal.removeInPortalCache(ctx) - err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) + err := portal.safeDBDelete(ctx) if err != nil { return err } @@ -5212,6 +5212,15 @@ func (portal *Portal) Delete(ctx context.Context) error { return nil } +func (portal *Portal) safeDBDelete(ctx context.Context) error { + err := portal.Bridge.DB.Message.DeleteInChunks(ctx, portal.PortalKey) + if err != nil { + return fmt.Errorf("failed to delete messages in portal: %w", err) + } + // TODO delete child portals? + return portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) +} + func (portal *Portal) RemoveMXID(ctx context.Context) error { if portal.MXID == "" { return nil @@ -5250,8 +5259,7 @@ func (portal *Portal) removeInPortalCache(ctx context.Context) { } func (portal *Portal) unlockedDelete(ctx context.Context) error { - // TODO delete child portals? - err := portal.Bridge.DB.Portal.Delete(ctx, portal.PortalKey) + err := portal.safeDBDelete(ctx) if err != nil { return err } From b44f81d114cab8b153b78c454329365d9d684547 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Dec 2025 18:57:39 +0200 Subject: [PATCH 1560/1647] bridgev2/database: only allow one chunked portal deletion at a time --- bridgev2/database/message.go | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index a1af1556..2172c224 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -13,6 +13,7 @@ import ( "encoding/base64" "fmt" "strings" + "sync" "time" "github.com/rs/zerolog" @@ -26,6 +27,7 @@ type MessageQuery struct { BridgeID networkid.BridgeID MetaType MetaTypeCreator *dbutil.QueryHelper[*Message] + chunkDeleteLock sync.Mutex } type Message struct { @@ -205,6 +207,16 @@ func (mq *MessageQuery) DeleteInChunks(ctx context.Context, portal networkid.Por if mq.GetDB().Dialect != dbutil.SQLite { return nil } + log := zerolog.Ctx(ctx).With(). + Str("action", "delete messages in chunks"). + Stringer("portal_key", portal). + Logger() + if !mq.chunkDeleteLock.TryLock() { + log.Warn().Msg("Portal deletion lock is being held, waiting...") + mq.chunkDeleteLock.Lock() + log.Debug().Msg("Acquired portal deletion lock after waiting") + } + defer mq.chunkDeleteLock.Unlock() total, err := mq.CountMessagesInPortal(ctx, portal) if err != nil { return fmt.Errorf("failed to count messages in portal: %w", err) @@ -215,7 +227,7 @@ func (mq *MessageQuery) DeleteInChunks(ctx context.Context, portal networkid.Por if err != nil { return fmt.Errorf("failed to get max row ID: %w", err) } - zerolog.Ctx(ctx).Debug(). + log.Debug(). Int("total_count", total). Int64("global_max_row_id", globalMaxRowID). Msg("Portal has lots of messages, deleting in chunks to avoid database locks") @@ -234,7 +246,7 @@ func (mq *MessageQuery) DeleteInChunks(ctx context.Context, portal networkid.Por total -= int(count) maxRowID += deleteChunkSize sleepTime := max(10*time.Millisecond, min(250*time.Millisecond, time.Duration(count/100)*time.Millisecond)) - zerolog.Ctx(ctx).Debug(). + log.Debug(). Int64("max_row_id", maxRowID). Int64("deleted_count", count). Int("remaining_count", total). @@ -247,7 +259,7 @@ func (mq *MessageQuery) DeleteInChunks(ctx context.Context, portal networkid.Por return ctx.Err() } } - zerolog.Ctx(ctx).Debug(). + log.Debug(). Int("remaining_count", total). Dur("db_time_used", dbTimeUsed). Dur("total_duration", time.Since(globalStart)). From 33eb00fde0e21158361b3412784089b4bd14b20c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 16 Dec 2025 19:29:26 +0200 Subject: [PATCH 1561/1647] bridgev2/database: reduce limit for using chunked deletion --- bridgev2/database/message.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 2172c224..43f33666 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -220,7 +220,7 @@ func (mq *MessageQuery) DeleteInChunks(ctx context.Context, portal networkid.Por total, err := mq.CountMessagesInPortal(ctx, portal) if err != nil { return fmt.Errorf("failed to count messages in portal: %w", err) - } else if total < deleteChunkSize { + } else if total < deleteChunkSize/3 { return nil } globalMaxRowID, err := mq.getMaxRowID(ctx) From 80b4201ff1a6a6b7e7896b6a9f22f4ea2dc368c1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Dec 2025 13:03:19 +0200 Subject: [PATCH 1562/1647] bridgev2/portalreid: add more logs --- bridgev2/portalreid.go | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go index d1a9d5a6..e133b224 100644 --- a/bridgev2/portalreid.go +++ b/bridgev2/portalreid.go @@ -32,13 +32,22 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta if source == target { return ReIDResultError, nil, fmt.Errorf("illegal re-ID call: source and target are the same") } - log := zerolog.Ctx(ctx) - log.Debug().Msg("Re-ID'ing portal") + log := zerolog.Ctx(ctx).With(). + Str("action", "re-id portal"). + Stringer("source_portal_key", source). + Stringer("target_portal_key", target). + Logger() + ctx = log.WithContext(ctx) + if !br.cacheLock.TryLock() { + log.Debug().Msg("Waiting for cache lock") + br.cacheLock.Lock() + log.Debug().Msg("Acquired cache lock after waiting") + } defer func() { + br.cacheLock.Unlock() log.Debug().Msg("Finished handling portal re-ID") }() - br.cacheLock.Lock() - defer br.cacheLock.Unlock() + log.Debug().Msg("Re-ID'ing portal") sourcePortal, err := br.UnlockedGetPortalByKey(ctx, source, true) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to get source portal: %w", err) @@ -46,7 +55,11 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta log.Debug().Msg("Source portal not found, re-ID is no-op") return ReIDResultNoOp, nil, nil } - sourcePortal.roomCreateLock.Lock() + if !sourcePortal.roomCreateLock.TryLock() { + log.Debug().Msg("Waiting for source portal room creation lock") + sourcePortal.roomCreateLock.Lock() + log.Debug().Msg("Acquired source portal room creation lock after waiting") + } defer sourcePortal.roomCreateLock.Unlock() if sourcePortal.MXID == "" { log.Info().Msg("Source portal doesn't have Matrix room, deleting row") @@ -71,7 +84,11 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta } return ReIDResultSourceReIDd, sourcePortal, nil } - targetPortal.roomCreateLock.Lock() + if !targetPortal.roomCreateLock.TryLock() { + log.Debug().Msg("Waiting for target portal room creation lock") + targetPortal.roomCreateLock.Lock() + log.Debug().Msg("Acquired target portal room creation lock after waiting") + } defer targetPortal.roomCreateLock.Unlock() if targetPortal.MXID == "" { log.Info().Msg("Target portal row exists, but doesn't have a Matrix room. Deleting target portal row and re-ID'ing source portal") From af06098723ba6016a65f69a94bee630230b76829 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Dec 2025 13:06:34 +0200 Subject: [PATCH 1563/1647] bridgev2/simplevent: add method to merge log contexts --- bridgev2/simplevent/meta.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/bridgev2/simplevent/meta.go b/bridgev2/simplevent/meta.go index 8aa91866..449a8773 100644 --- a/bridgev2/simplevent/meta.go +++ b/bridgev2/simplevent/meta.go @@ -101,6 +101,18 @@ func (evt EventMeta) WithLogContext(f func(c zerolog.Context) zerolog.Context) E return evt } +func (evt EventMeta) WithMoreLogContext(f func(c zerolog.Context) zerolog.Context) EventMeta { + origFunc := evt.LogContext + if origFunc == nil { + evt.LogContext = f + return evt + } + evt.LogContext = func(c zerolog.Context) zerolog.Context { + return f(origFunc(c)) + } + return evt +} + func (evt EventMeta) WithPortalKey(p networkid.PortalKey) EventMeta { evt.PortalKey = p return evt From 4825e41d5c5f57a3b2d9f628cf391aa1b2b6540b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Dec 2025 13:32:55 +0200 Subject: [PATCH 1564/1647] bridgev2/portalreid: try to cancel room creation --- bridgev2/portal.go | 25 +++++++++++++++++-------- bridgev2/portalreid.go | 6 ++++++ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 87b50c84..e9feb448 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -86,8 +86,9 @@ type Portal struct { lastCapUpdate time.Time - roomCreateLock sync.Mutex - RoomCreated *exsync.Event + roomCreateLock sync.Mutex + cancelRoomCreate atomic.Pointer[context.CancelFunc] + RoomCreated *exsync.Event functionalMembersLock sync.Mutex functionalMembersCache *event.ElementFunctionalMembersContent @@ -4947,7 +4948,11 @@ func (portal *Portal) CreateMatrixRoom(ctx context.Context, source *UserLogin, i } func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLogin, info *ChatInfo, backfillBundle any) error { + cancellableCtx, cancel := context.WithCancel(ctx) + defer cancel() + portal.cancelRoomCreate.CompareAndSwap(nil, &cancel) portal.roomCreateLock.Lock() + portal.cancelRoomCreate.Store(&cancel) defer portal.roomCreateLock.Unlock() if portal.MXID != "" { if source != nil { @@ -4958,6 +4963,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo log := zerolog.Ctx(ctx).With(). Str("action", "create matrix room"). Logger() + cancellableCtx = log.WithContext(cancellableCtx) ctx = log.WithContext(ctx) log.Info().Msg("Creating Matrix room") @@ -4966,16 +4972,16 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo if info != nil { log.Warn().Msg("CreateMatrixRoom got info without members. Refetching info") } - info, err = source.Client.GetChatInfo(ctx, portal) + info, err = source.Client.GetChatInfo(cancellableCtx, portal) if err != nil { log.Err(err).Msg("Failed to update portal info for creation") return err } } - portal.UpdateInfo(ctx, info, source, nil, time.Time{}) - if ctx.Err() != nil { - return ctx.Err() + portal.UpdateInfo(cancellableCtx, info, source, nil, time.Time{}) + if cancellableCtx.Err() != nil { + return cancellableCtx.Err() } powerLevels := &event.PowerLevelsEventContent{ @@ -4988,7 +4994,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo portal.Bridge.Bot.GetMXID(): 9001, }, } - initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(ctx, info.Members, source, powerLevels) + initialMembers, extraFunctionalMembers, err := portal.getInitialMemberList(cancellableCtx, info.Members, source, powerLevels) if err != nil { log.Err(err).Msg("Failed to process participant list for portal creation") return err @@ -5015,7 +5021,7 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo req.CreationContent["type"] = event.RoomTypeSpace } bridgeInfoStateKey, bridgeInfo := portal.getBridgeInfo() - roomFeatures := source.Client.GetCapabilities(ctx, portal) + roomFeatures := source.Client.GetCapabilities(cancellableCtx, portal) portal.CapState = database.CapabilityState{ Source: source.ID, ID: roomFeatures.GetID(), @@ -5097,6 +5103,9 @@ func (portal *Portal) createMatrixRoomInLoop(ctx context.Context, source *UserLo Content: event.Content{Parsed: info.JoinRule}, }) } + if cancellableCtx.Err() != nil { + return cancellableCtx.Err() + } roomID, err := portal.Bridge.Bot.CreateRoom(ctx, &req) if err != nil { log.Err(err).Msg("Failed to create Matrix room") diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go index e133b224..6a5091fc 100644 --- a/bridgev2/portalreid.go +++ b/bridgev2/portalreid.go @@ -56,6 +56,9 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta return ReIDResultNoOp, nil, nil } if !sourcePortal.roomCreateLock.TryLock() { + if cancelCreate := sourcePortal.cancelRoomCreate.Swap(nil); cancelCreate != nil { + (*cancelCreate)() + } log.Debug().Msg("Waiting for source portal room creation lock") sourcePortal.roomCreateLock.Lock() log.Debug().Msg("Acquired source portal room creation lock after waiting") @@ -85,6 +88,9 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta return ReIDResultSourceReIDd, sourcePortal, nil } if !targetPortal.roomCreateLock.TryLock() { + if cancelCreate := targetPortal.cancelRoomCreate.Swap(nil); cancelCreate != nil { + (*cancelCreate)() + } log.Debug().Msg("Waiting for target portal room creation lock") targetPortal.roomCreateLock.Lock() log.Debug().Msg("Acquired target portal room creation lock after waiting") From 59ec890dcb92844825c9ebce4b80173db1e18ee2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 19 Dec 2025 15:15:23 +0200 Subject: [PATCH 1565/1647] changelog: add missing link --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0fb1a105..8017ef97 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -392,6 +392,7 @@ [MSC4156]: https://github.com/matrix-org/matrix-spec-proposals/pull/4156 [MSC4190]: https://github.com/matrix-org/matrix-spec-proposals/pull/4190 [#288]: https://github.com/mautrix/go/pull/288 +[@onestacked]: https://github.com/onestacked ## v0.22.0 (2024-11-16) From 788151bc505028ce7e50b217c06d86c3b5f0a246 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 30 Dec 2025 22:53:27 +0200 Subject: [PATCH 1566/1647] client: error if Download parameter is empty --- client.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/client.go b/client.go index 90737581..e12e45d3 100644 --- a/client.go +++ b/client.go @@ -1835,6 +1835,9 @@ func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUploa } func (cli *Client) Download(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) { + if mxcURL.IsEmpty() { + return nil, fmt.Errorf("empty mxc uri provided to Download") + } _, resp, err := cli.MakeFullRequestWithResp(ctx, FullRequest{ Method: http.MethodGet, URL: cli.BuildClientURL("v1", "media", "download", mxcURL.Homeserver, mxcURL.FileID), @@ -1849,6 +1852,9 @@ type DownloadThumbnailExtra struct { } func (cli *Client) DownloadThumbnail(ctx context.Context, mxcURL id.ContentURI, height, width int, extras ...DownloadThumbnailExtra) (*http.Response, error) { + if mxcURL.IsEmpty() { + return nil, fmt.Errorf("empty mxc uri provided to DownloadThumbnail") + } if len(extras) > 1 { panic(fmt.Errorf("invalid number of arguments to DownloadThumbnail: %d", len(extras))) } From 3a2c6ae865ca2b4384d0d28bb50f5b3069c68d51 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 5 Jan 2026 14:58:29 +0200 Subject: [PATCH 1567/1647] client: stabilize MSC4323 --- client.go | 45 +++++++++++++++++++++++++++++++++------------ versions.go | 22 ++++++++++++---------- 2 files changed, 45 insertions(+), 22 deletions(-) diff --git a/client.go b/client.go index e12e45d3..87b6d87e 100644 --- a/client.go +++ b/client.go @@ -2724,30 +2724,51 @@ func (cli *Client) AdminWhoIs(ctx context.Context, userID id.UserID) (resp RespW return } -// UnstableGetSuspendedStatus uses MSC4323 to check if a user is suspended. -func (cli *Client) UnstableGetSuspendedStatus(ctx context.Context, userID id.UserID) (res *RespSuspended, err error) { - urlPath := cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", "suspend", userID) +func (cli *Client) makeMSC4323URL(action string, target id.UserID) string { + if cli.SpecVersions.Supports(FeatureUnstableAccountModeration) { + return cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", action, target) + } else if cli.SpecVersions.Supports(FeatureStableAccountModeration) { + return cli.BuildClientURL("v1", "admin", action, target) + } + return "" +} + +// GetSuspendedStatus uses MSC4323 to check if a user is suspended. +func (cli *Client) GetSuspendedStatus(ctx context.Context, userID id.UserID) (res *RespSuspended, err error) { + urlPath := cli.makeMSC4323URL("suspend", userID) + if urlPath == "" { + return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") + } _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res) return } -// UnstableGetLockStatus uses MSC4323 to check if a user is locked. -func (cli *Client) UnstableGetLockStatus(ctx context.Context, userID id.UserID) (res *RespLocked, err error) { - urlPath := cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", "lock", userID) +// GetLockStatus uses MSC4323 to check if a user is locked. +func (cli *Client) GetLockStatus(ctx context.Context, userID id.UserID) (res *RespLocked, err error) { + urlPath := cli.makeMSC4323URL("lock", userID) + if urlPath == "" { + return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") + } _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, res) return } -// UnstableSetSuspendedStatus uses MSC4323 to set whether a user account is suspended. -func (cli *Client) UnstableSetSuspendedStatus(ctx context.Context, userID id.UserID, suspended bool) (res *RespSuspended, err error) { - urlPath := cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", "suspend", userID) +// SetSuspendedStatus uses MSC4323 to set whether a user account is suspended. +func (cli *Client) SetSuspendedStatus(ctx context.Context, userID id.UserID, suspended bool) (res *RespSuspended, err error) { + urlPath := cli.makeMSC4323URL("suspend", userID) + if urlPath == "" { + return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") + } _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqSuspend{Suspended: suspended}, res) return } -// UnstableSetLockStatus uses MSC4323 to set whether a user account is locked. -func (cli *Client) UnstableSetLockStatus(ctx context.Context, userID id.UserID, locked bool) (res *RespLocked, err error) { - urlPath := cli.BuildClientURL("unstable", "uk.timedout.msc4323", "admin", "lock", userID) +// SetLockStatus uses MSC4323 to set whether a user account is locked. +func (cli *Client) SetLockStatus(ctx context.Context, userID id.UserID, locked bool) (res *RespLocked, err error) { + urlPath := cli.makeMSC4323URL("lock", userID) + if urlPath == "" { + return nil, MUnrecognized.WithMessage("Homeserver does not advertise MSC4323 support") + } _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqLocked{Locked: locked}, res) return } diff --git a/versions.go b/versions.go index 2aaf6399..8ae82a06 100644 --- a/versions.go +++ b/versions.go @@ -60,16 +60,17 @@ type UnstableFeature struct { } var ( - FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} - FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} - FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} - FeatureMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} - FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"} - FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"} - FeatureAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"} - FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"} - FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116} - FeatureRedactSendAsEvent = UnstableFeature{UnstableFlag: "com.beeper.msc4169"} + FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} + FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} + FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} + FeatureMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} + FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"} + FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"} + FeatureUnstableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"} + FeatureStableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323.stable" /*, SpecVersion: SpecV118*/} + FeatureUnstableProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133"} + FeatureArbitraryProfileFields = UnstableFeature{UnstableFlag: "uk.tcpip.msc4133.stable", SpecVersion: SpecV116} + FeatureRedactSendAsEvent = UnstableFeature{UnstableFlag: "com.beeper.msc4169"} BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"} BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"} @@ -123,6 +124,7 @@ var ( SpecV114 = MustParseSpecVersion("v1.14") SpecV115 = MustParseSpecVersion("v1.15") SpecV116 = MustParseSpecVersion("v1.16") + SpecV117 = MustParseSpecVersion("v1.17") ) func (svf SpecVersionFormat) String() string { From f4434b33c638af2dc8eb8b186ba8e180a407c81c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 7 Jan 2026 19:22:32 +0200 Subject: [PATCH 1568/1647] crypto,bridgev2: add option to encrypt reactions and replies (#445) --- bridgev2/bridgeconfig/encryption.go | 1 + bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/matrix/intent.go | 2 +- bridgev2/matrix/mxmain/example-config.yaml | 2 ++ crypto/encryptmegolm.go | 9 +++++++++ crypto/machine.go | 1 + 6 files changed, 15 insertions(+), 1 deletion(-) diff --git a/bridgev2/bridgeconfig/encryption.go b/bridgev2/bridgeconfig/encryption.go index 5a19b3ad..934613ca 100644 --- a/bridgev2/bridgeconfig/encryption.go +++ b/bridgev2/bridgeconfig/encryption.go @@ -16,6 +16,7 @@ type EncryptionConfig struct { Require bool `yaml:"require"` Appservice bool `yaml:"appservice"` MSC4190 bool `yaml:"msc4190"` + MSC4392 bool `yaml:"msc4392"` SelfSign bool `yaml:"self_sign"` PlaintextMentions bool `yaml:"plaintext_mentions"` diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index 960e2fb4..a0278672 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -161,6 +161,7 @@ func doUpgrade(helper up.Helper) { } else { helper.Copy(up.Bool, "encryption", "msc4190") } + helper.Copy(up.Bool, "encryption", "msc4392") helper.Copy(up.Bool, "encryption", "self_sign") helper.Copy(up.Bool, "encryption", "allow_key_sharing") if secret, ok := helper.Get(up.Str, "encryption", "pickle_key"); !ok || secret == "generate" { diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 44dcbc5b..a4f73e6b 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -56,7 +56,7 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType Extra: content.Raw, }) } - if eventType != event.EventReaction && eventType != event.EventRedaction { + if (eventType != event.EventReaction || as.Connector.Config.Encryption.MSC4392) && eventType != event.EventRedaction { msgContent, ok := content.Parsed.(*event.MessageEventContent) if ok { msgContent.AddPerMessageProfileFallback() diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 947d771b..b0e83696 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -378,6 +378,8 @@ encryption: # Only relevant when using end-to-bridge encryption, required when using encryption with next-gen auth (MSC3861). # Changing this option requires updating the appservice registration file. msc4190: false + # Whether to encrypt reactions and reply metadata as per MSC4392. + msc4392: false # Should the bridge bot generate a recovery key and cross-signing keys and verify itself? # Note that without the latest version of MSC4190, this will fail if you reset the bridge database. # The generated recovery key will be saved in the kv_store table under `recovery_key`. diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index ea97f767..8ce70ca0 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -169,6 +169,15 @@ func (mach *OlmMachine) EncryptMegolmEventWithStateKey(ctx context.Context, room SenderKey: mach.account.IdentityKey(), DeviceID: mach.Client.DeviceID, } + if mach.MSC4392Relations && encrypted.RelatesTo != nil { + // When MSC4392 mode is enabled, reply and reaction metadata is stripped from the unencrypted content. + // Other relations like threads are still left unencrypted. + encrypted.RelatesTo.InReplyTo = nil + encrypted.RelatesTo.IsFallingBack = false + if evtType == event.EventReaction || encrypted.RelatesTo.Type == "" { + encrypted.RelatesTo = nil + } + } if mach.PlaintextMentions { encrypted.Mentions = getMentions(content) } diff --git a/crypto/machine.go b/crypto/machine.go index f8ebe909..fa051f94 100644 --- a/crypto/machine.go +++ b/crypto/machine.go @@ -39,6 +39,7 @@ type OlmMachine struct { cancelBackgroundCtx context.CancelFunc PlaintextMentions bool + MSC4392Relations bool AllowEncryptedState bool // Never ask the server for keys automatically as a side effect during Megolm decryption. From 9f327602f675ce5e721b49df5afd481e1a116b1e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 7 Jan 2026 20:05:42 +0200 Subject: [PATCH 1569/1647] event/beeper: add blurhash for link previews --- event/beeper.go | 1 + 1 file changed, 1 insertion(+) diff --git a/event/beeper.go b/event/beeper.go index 75c18aa7..b46106ab 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -144,6 +144,7 @@ type BeeperLinkPreview struct { MatchedURL string `json:"matched_url,omitempty"` ImageEncryption *EncryptedFileInfo `json:"beeper:image:encryption,omitempty"` + ImageBlurhash string `json:"beeper:image:blurhash,omitempty"` } type BeeperProfileExtra struct { From 32da107299ecd1fef1f09fbffdb96d951347ac85 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 8 Jan 2026 22:52:25 +0200 Subject: [PATCH 1570/1647] bridgev2/matrix: fix decrypting events in GetEvent --- bridgev2/matrix/intent.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index a4f73e6b..3d2692f9 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -714,10 +714,10 @@ func (as *ASIntent) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.E } if evt.Type == event.EventEncrypted { - if as.Connector.Config.Encryption.DeleteKeys.RatchetOnDecrypt { + if as.Connector.Crypto == nil || as.Connector.Config.Encryption.DeleteKeys.RatchetOnDecrypt { return nil, errors.New("can't decrypt the event") } - return as.Matrix.Crypto.Decrypt(ctx, evt) + return as.Connector.Crypto.Decrypt(ctx, evt) } return evt, nil From 6da5f6b5d0ffc64244cea8b2c7160314103e6288 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 10 Jan 2026 14:18:57 +0200 Subject: [PATCH 1571/1647] federation: change serverauth test domains --- federation/serverauth_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federation/serverauth_test.go b/federation/serverauth_test.go index 633a0f66..f99fc6cf 100644 --- a/federation/serverauth_test.go +++ b/federation/serverauth_test.go @@ -19,7 +19,7 @@ import ( func TestServerKeyResponse_VerifySelfSignature(t *testing.T) { cli := federation.NewClient("", nil, nil) ctx := context.Background() - for _, name := range []string{"matrix.org", "maunium.net", "continuwuity.org"} { + for _, name := range []string{"matrix.org", "maunium.net", "cd.mau.dev", "uwu.mau.dev"} { t.Run(name, func(t *testing.T) { resp, err := cli.ServerKeys(ctx, name) require.NoError(t, err) From c69518ab3c9c98f152d9e5db5793c150c1562754 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 10 Jan 2026 20:53:44 +0200 Subject: [PATCH 1572/1647] bridgev2/login: add default_value for user input fields --- bridgev2/login.go | 2 ++ bridgev2/matrix/provisioning.yaml | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/bridgev2/login.go b/bridgev2/login.go index 46dcf7da..4ddbf13e 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -190,6 +190,8 @@ type LoginInputDataField struct { Name string `json:"name"` // The description of the field shown to the user. Description string `json:"description"` + // A default value that the client can pre-fill the field with. + DefaultValue string `json:"default_value,omitempty"` // A regex pattern that the client can use to validate input client-side. Pattern string `json:"pattern,omitempty"` // For fields of type select, the valid options. diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml index 50b73c66..d19a7e83 100644 --- a/bridgev2/matrix/provisioning.yaml +++ b/bridgev2/matrix/provisioning.yaml @@ -728,11 +728,14 @@ components: description: A more detailed description of the field shown to the user. examples: - Include the country code with a + + default_value: + type: string + description: A default value that the client can pre-fill the field with. pattern: type: string format: regex description: A regular expression that the field value must match. - select: + options: type: array description: For fields of type select, the valid options. items: From be22286000926cc5549b03caeb1e4cb120b37676 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 10 Jan 2026 18:28:23 +0200 Subject: [PATCH 1573/1647] event: drop MSC4332 support --- event/botcommand.go | 49 --------------------------------------------- event/content.go | 1 - event/message.go | 2 -- event/type.go | 4 +--- 4 files changed, 1 insertion(+), 55 deletions(-) delete mode 100644 event/botcommand.go diff --git a/event/botcommand.go b/event/botcommand.go deleted file mode 100644 index 2b208656..00000000 --- a/event/botcommand.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright (c) 2025 Tulir Asokan -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package event - -import ( - "encoding/json" -) - -type BotCommandsEventContent struct { - Sigil string `json:"sigil,omitempty"` - Commands []*BotCommand `json:"commands,omitempty"` -} - -type BotCommand struct { - Syntax string `json:"syntax"` - Aliases []string `json:"fi.mau.aliases,omitempty"` // Not in MSC (yet) - Arguments []*BotCommandArgument `json:"arguments,omitempty"` - Description *ExtensibleTextContainer `json:"description,omitempty"` -} - -type BotArgumentType string - -const ( - BotArgumentTypeString BotArgumentType = "string" - BotArgumentTypeEnum BotArgumentType = "enum" - BotArgumentTypeInteger BotArgumentType = "integer" - BotArgumentTypeBoolean BotArgumentType = "boolean" - BotArgumentTypeUserID BotArgumentType = "user_id" - BotArgumentTypeRoomID BotArgumentType = "room_id" - BotArgumentTypeRoomAlias BotArgumentType = "room_alias" - BotArgumentTypeEventID BotArgumentType = "event_id" -) - -type BotCommandArgument struct { - Type BotArgumentType `json:"type"` - DefaultValue any `json:"fi.mau.default_value,omitempty"` // Not in MSC (yet) - Description *ExtensibleTextContainer `json:"description,omitempty"` - Enum []string `json:"enum,omitempty"` - Variadic bool `json:"variadic,omitempty"` -} - -type BotCommandInput struct { - Syntax string `json:"syntax"` - Arguments json.RawMessage `json:"arguments,omitempty"` -} diff --git a/event/content.go b/event/content.go index 4929c6a5..d1ced268 100644 --- a/event/content.go +++ b/event/content.go @@ -50,7 +50,6 @@ var TypeMap = map[Type]reflect.Type{ StateElementFunctionalMembers: reflect.TypeOf(ElementFunctionalMembersContent{}), StateBeeperRoomFeatures: reflect.TypeOf(RoomFeatures{}), StateBeeperDisappearingTimer: reflect.TypeOf(BeeperDisappearingTimer{}), - StateBotCommands: reflect.TypeOf(BotCommandsEventContent{}), EventMessage: reflect.TypeOf(MessageEventContent{}), EventSticker: reflect.TypeOf(MessageEventContent{}), diff --git a/event/message.go b/event/message.go index 692382cf..0af3a2c9 100644 --- a/event/message.go +++ b/event/message.go @@ -142,8 +142,6 @@ type MessageEventContent struct { MSC1767Audio *MSC1767Audio `json:"org.matrix.msc1767.audio,omitempty"` MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"` - - MSC4332BotCommand *BotCommandInput `json:"org.matrix.msc4332.command,omitempty"` } func (content *MessageEventContent) GetCapMsgType() CapabilityMsgType { diff --git a/event/type.go b/event/type.go index f4d7592c..2a9b382c 100644 --- a/event/type.go +++ b/event/type.go @@ -112,8 +112,7 @@ func (et *Type) GuessClass() TypeClass { StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type, StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type, StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type, - StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type, - StateBotCommands.Type: + StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type: return StateEventType case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type: return EphemeralEventType @@ -205,7 +204,6 @@ var ( StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType} StateBeeperRoomFeatures = Type{"com.beeper.room_features", StateEventType} StateBeeperDisappearingTimer = Type{"com.beeper.disappearing_timer", StateEventType} - StateBotCommands = Type{"org.matrix.msc4332.commands", StateEventType} ) // Message events From 5ac73563b0af64a03e416d31199675eb2d1f7b35 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 10 Jan 2026 18:29:12 +0200 Subject: [PATCH 1574/1647] event/cmdschema: add MSC4391 types, parser and stringifier --- event/cmdschema/content.go | 64 +++ event/cmdschema/parameter.go | 286 +++++++++++ event/cmdschema/parse.go | 471 ++++++++++++++++++ event/cmdschema/parse_test.go | 118 +++++ event/cmdschema/roomid.go | 135 +++++ event/cmdschema/stringify.go | 122 +++++ event/cmdschema/testdata/commands/flags.json | 153 ++++++ .../testdata/commands/room_id_or_alias.json | 84 ++++ .../commands/room_reference_list.json | 105 ++++ event/cmdschema/testdata/commands/simple.json | 45 ++ event/cmdschema/testdata/data.go | 14 + event/cmdschema/testdata/parse_quote.json | 20 + event/message.go | 9 + event/state.go | 7 + event/type.go | 4 +- 15 files changed, 1636 insertions(+), 1 deletion(-) create mode 100644 event/cmdschema/content.go create mode 100644 event/cmdschema/parameter.go create mode 100644 event/cmdschema/parse.go create mode 100644 event/cmdschema/parse_test.go create mode 100644 event/cmdschema/roomid.go create mode 100644 event/cmdschema/stringify.go create mode 100644 event/cmdschema/testdata/commands/flags.json create mode 100644 event/cmdschema/testdata/commands/room_id_or_alias.json create mode 100644 event/cmdschema/testdata/commands/room_reference_list.json create mode 100644 event/cmdschema/testdata/commands/simple.json create mode 100644 event/cmdschema/testdata/data.go create mode 100644 event/cmdschema/testdata/parse_quote.json diff --git a/event/cmdschema/content.go b/event/cmdschema/content.go new file mode 100644 index 00000000..b69f0c1f --- /dev/null +++ b/event/cmdschema/content.go @@ -0,0 +1,64 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "crypto/sha256" + "encoding/base64" + "fmt" + "reflect" + "slices" + + "go.mau.fi/util/ptr" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +type EventContent struct { + Command string `json:"command"` + Aliases []string `json:"aliases,omitempty"` + Parameters []*Parameter `json:"parameters,omitempty"` + Description *event.ExtensibleTextContainer `json:"description,omitempty"` +} + +func (ec *EventContent) Validate() error { + if ec == nil { + return fmt.Errorf("event content is nil") + } else if ec.Command == "" { + return fmt.Errorf("command is empty") + } + for i, p := range ec.Parameters { + if err := p.Validate(); err != nil { + return fmt.Errorf("parameter %q (#%d) is invalid: %w", ptr.Val(p).Key, i+1, err) + } + } + return nil +} + +func (ec *EventContent) IsValid() bool { + return ec.Validate() == nil +} + +func (ec *EventContent) StateKey(owner id.UserID) string { + hash := sha256.Sum256([]byte(ec.Command + owner.String())) + return base64.StdEncoding.EncodeToString(hash[:]) +} + +func (ec *EventContent) Equals(other *EventContent) bool { + if ec == nil || other == nil { + return ec == other + } + return ec.Command == other.Command && + slices.Equal(ec.Aliases, other.Aliases) && + slices.EqualFunc(ec.Parameters, other.Parameters, (*Parameter).Equals) && + ec.Description.Equals(other.Description) +} + +func init() { + event.TypeMap[event.StateMSC4391BotCommand] = reflect.TypeOf(EventContent{}) +} diff --git a/event/cmdschema/parameter.go b/event/cmdschema/parameter.go new file mode 100644 index 00000000..4193b297 --- /dev/null +++ b/event/cmdschema/parameter.go @@ -0,0 +1,286 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "encoding/json" + "fmt" + "slices" + + "go.mau.fi/util/exslices" + + "maunium.net/go/mautrix/event" +) + +type Parameter struct { + Key string `json:"key"` + Schema *ParameterSchema `json:"schema"` + Optional bool `json:"optional,omitempty"` + Description *event.ExtensibleTextContainer `json:"description,omitempty"` + DefaultValue any `json:"fi.mau.default_value,omitempty"` +} + +func (p *Parameter) Equals(other *Parameter) bool { + if p == nil || other == nil { + return p == other + } + return p.Key == other.Key && + p.Schema.Equals(other.Schema) && + p.Optional == other.Optional && + p.Description.Equals(other.Description) && + p.DefaultValue == other.DefaultValue // TODO this won't work for room/event ID values +} + +func (p *Parameter) Validate() error { + if p == nil { + return fmt.Errorf("parameter is nil") + } else if p.Key == "" { + return fmt.Errorf("key is empty") + } + return p.Schema.Validate() +} + +func (p *Parameter) IsValid() bool { + return p.Validate() == nil +} + +func (p *Parameter) GetDefaultValue() any { + if p != nil && p.DefaultValue != nil { + return p.DefaultValue + } else if p == nil || p.Optional { + return nil + } + return p.Schema.GetDefaultValue() +} + +type PrimitiveType string + +const ( + PrimitiveTypeString PrimitiveType = "string" + PrimitiveTypeInteger PrimitiveType = "integer" + PrimitiveTypeBoolean PrimitiveType = "boolean" + PrimitiveTypeServerName PrimitiveType = "server_name" + PrimitiveTypeUserID PrimitiveType = "user_id" + PrimitiveTypeRoomID PrimitiveType = "room_id" + PrimitiveTypeRoomAlias PrimitiveType = "room_alias" + PrimitiveTypeEventID PrimitiveType = "event_id" +) + +func (pt PrimitiveType) Schema() *ParameterSchema { + return &ParameterSchema{ + SchemaType: SchemaTypePrimitive, + Type: pt, + } +} + +func (pt PrimitiveType) IsValid() bool { + switch pt { + case PrimitiveTypeString, + PrimitiveTypeInteger, + PrimitiveTypeBoolean, + PrimitiveTypeServerName, + PrimitiveTypeUserID, + PrimitiveTypeRoomID, + PrimitiveTypeRoomAlias, + PrimitiveTypeEventID: + return true + default: + return false + } +} + +type SchemaType string + +const ( + SchemaTypePrimitive SchemaType = "primitive" + SchemaTypeArray SchemaType = "array" + SchemaTypeUnion SchemaType = "union" + SchemaTypeLiteral SchemaType = "literal" +) + +type ParameterSchema struct { + SchemaType SchemaType `json:"schema_type"` + Type PrimitiveType `json:"type,omitempty"` // Only for primitive + Items *ParameterSchema `json:"items,omitempty"` // Only for array + Variants []*ParameterSchema `json:"variants,omitempty"` // Only for union + Value any `json:"value,omitempty"` // Only for literal +} + +func Literal(value any) *ParameterSchema { + return &ParameterSchema{ + SchemaType: SchemaTypeLiteral, + Value: value, + } +} + +func Enum(values ...any) *ParameterSchema { + return Union(exslices.CastFunc(values, Literal)...) +} + +func flattenUnion(variants []*ParameterSchema) []*ParameterSchema { + var flattened []*ParameterSchema + for _, variant := range variants { + switch variant.SchemaType { + case SchemaTypeArray: + panic(fmt.Errorf("illegal array schema in union")) + case SchemaTypeUnion: + flattened = append(flattened, flattenUnion(variant.Variants)...) + default: + flattened = append(flattened, variant) + } + } + return flattened +} + +func Union(variants ...*ParameterSchema) *ParameterSchema { + needsFlattening := false + for _, variant := range variants { + if variant.SchemaType == SchemaTypeArray { + panic(fmt.Errorf("illegal array schema in union")) + } else if variant.SchemaType == SchemaTypeUnion { + needsFlattening = true + } + } + if needsFlattening { + variants = flattenUnion(variants) + } + return &ParameterSchema{ + SchemaType: SchemaTypeUnion, + Variants: variants, + } +} + +func Array(items *ParameterSchema) *ParameterSchema { + if items.SchemaType == SchemaTypeArray { + panic(fmt.Errorf("illegal array schema in array")) + } + return &ParameterSchema{ + SchemaType: SchemaTypeArray, + Items: items, + } +} + +func (ps *ParameterSchema) GetDefaultValue() any { + if ps == nil { + return nil + } + switch ps.SchemaType { + case SchemaTypePrimitive: + switch ps.Type { + case PrimitiveTypeInteger: + return 0 + case PrimitiveTypeBoolean: + return false + default: + return "" + } + case SchemaTypeArray: + return []any{} + case SchemaTypeUnion: + if len(ps.Variants) > 0 { + return ps.Variants[0].GetDefaultValue() + } + return nil + case SchemaTypeLiteral: + return ps.Value + default: + return nil + } +} + +func (ps *ParameterSchema) IsValid() bool { + return ps.validate("") == nil +} + +func (ps *ParameterSchema) Validate() error { + return ps.validate("") +} + +func (ps *ParameterSchema) validate(parent SchemaType) error { + if ps == nil { + return fmt.Errorf("schema is nil") + } + switch ps.SchemaType { + case SchemaTypePrimitive: + if !ps.Type.IsValid() { + return fmt.Errorf("invalid primitive type %s", ps.Type) + } else if ps.Items != nil || ps.Variants != nil || ps.Value != nil { + return fmt.Errorf("primitive schema has extra fields") + } + return nil + case SchemaTypeArray: + if parent != "" { + return fmt.Errorf("arrays can't be nested in other types") + } else if err := ps.Items.validate(ps.SchemaType); err != nil { + return fmt.Errorf("item schema is invalid: %w", err) + } else if ps.Type != "" || ps.Variants != nil || ps.Value != nil { + return fmt.Errorf("array schema has extra fields") + } + return nil + case SchemaTypeUnion: + if len(ps.Variants) == 0 { + return fmt.Errorf("no variants specified for union") + } else if parent != "" && parent != SchemaTypeArray { + return fmt.Errorf("unions can't be nested in anything other than arrays") + } + for i, v := range ps.Variants { + if err := v.validate(ps.SchemaType); err != nil { + return fmt.Errorf("variant #%d is invalid: %w", i+1, err) + } + } + if ps.Type != "" || ps.Items != nil || ps.Value != nil { + return fmt.Errorf("union schema has extra fields") + } + return nil + case SchemaTypeLiteral: + switch typedVal := ps.Value.(type) { + case string, float64, int, int64, json.Number, bool, RoomIDValue, *RoomIDValue: + // ok + case map[string]any: + if typedVal["type"] != "event_id" && typedVal["type"] != "room_id" { + return fmt.Errorf("literal value has invalid map data") + } + default: + return fmt.Errorf("literal value has unsupported type %T", ps.Value) + } + if ps.Type != "" || ps.Items != nil || ps.Variants != nil { + return fmt.Errorf("literal schema has extra fields") + } + return nil + default: + return fmt.Errorf("invalid schema type %s", ps.SchemaType) + } +} + +func (ps *ParameterSchema) Equals(other *ParameterSchema) bool { + if ps == nil || other == nil { + return ps == other + } + return ps.SchemaType == other.SchemaType && + ps.Type == other.Type && + ps.Items.Equals(other.Items) && + slices.EqualFunc(ps.Variants, other.Variants, (*ParameterSchema).Equals) && + ps.Value == other.Value // TODO this won't work for room/event ID values +} + +func (ps *ParameterSchema) AllowsPrimitive(prim PrimitiveType) bool { + switch ps.SchemaType { + case SchemaTypePrimitive: + return ps.Type == prim + case SchemaTypeUnion: + for _, variant := range ps.Variants { + if variant.AllowsPrimitive(prim) { + return true + } + } + return false + case SchemaTypeArray: + return ps.Items.AllowsPrimitive(prim) + default: + return false + } +} diff --git a/event/cmdschema/parse.go b/event/cmdschema/parse.go new file mode 100644 index 00000000..6536b410 --- /dev/null +++ b/event/cmdschema/parse.go @@ -0,0 +1,471 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "encoding/json" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +const botArrayOpener = "<" +const botArrayCloser = ">" + +func parseQuoted(val string) (parsed, remaining string, quoted bool) { + if len(val) == 0 { + return + } + if !strings.HasPrefix(val, `"`) { + spaceIdx := strings.IndexByte(val, ' ') + if spaceIdx == -1 { + parsed = val + } else { + parsed = val[:spaceIdx] + remaining = strings.TrimLeft(val[spaceIdx+1:], " ") + } + return + } + val = val[1:] + var buf strings.Builder + for { + quoteIdx := strings.IndexByte(val, '"') + escapeIdx := strings.IndexByte(val[:max(0, quoteIdx)], '\\') + if escapeIdx >= 0 { + buf.WriteString(val[:escapeIdx]) + if len(val) > escapeIdx+1 { + buf.WriteByte(val[escapeIdx+1]) + } + val = val[min(escapeIdx+2, len(val)):] + } else if quoteIdx >= 0 { + buf.WriteString(val[:quoteIdx]) + val = val[quoteIdx+1:] + break + } else if buf.Len() == 0 { + // Unterminated quote, no escape characters, val is the whole input + return val, "", true + } else { + // Unterminated quote, but there were escape characters previously + buf.WriteString(val) + val = "" + break + } + } + return buf.String(), strings.TrimLeft(val, " "), true +} + +// ParseInput tries to parse the given text into a bot command event matching this command definition. +// +// If the prefix doesn't match, this will return a nil content and nil error. +// If the prefix does match, some content is always returned, but there may still be an error if parsing failed. +func (ec *EventContent) ParseInput(owner id.UserID, sigils []string, input string) (content *event.MessageEventContent, err error) { + prefix := ec.parsePrefix(input, sigils, owner.String()) + if prefix == "" { + return nil, nil + } + content = &event.MessageEventContent{ + MsgType: event.MsgText, + Body: input, + Mentions: &event.Mentions{UserIDs: []id.UserID{owner}}, + MSC4391BotCommand: &event.MSC4391BotCommandInput{ + Command: ec.Command, + }, + } + content.MSC4391BotCommand.Arguments, err = ec.ParseArguments(input[len(prefix):]) + return content, err +} + +func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) { + args := make(map[string]any) + var retErr error + setError := func(err error) { + if err != nil && retErr == nil { + retErr = err + } + } + processParameter := func(param *Parameter, isLast, isNamed bool) { + origInput := input + var nextVal string + var wasQuoted bool + if param.Schema.SchemaType == SchemaTypeArray { + hasOpener := strings.HasPrefix(input, botArrayOpener) + arrayClosed := false + if hasOpener { + input = input[len(botArrayOpener):] + if strings.HasPrefix(input, botArrayCloser) { + input = strings.TrimLeft(input[len(botArrayCloser):], " ") + arrayClosed = true + } + } + var collector []any + for len(input) > 0 && !arrayClosed { + //origInput = input + nextVal, input, wasQuoted = parseQuoted(input) + if !wasQuoted && hasOpener && strings.HasSuffix(nextVal, botArrayCloser) { + // The value wasn't quoted and has the array delimiter at the end, close the array + nextVal = strings.TrimRight(nextVal, botArrayCloser) + arrayClosed = true + } else if hasOpener && strings.HasPrefix(input, botArrayCloser) { + // The value was quoted or there was a space, and the next character is the + // array delimiter, close the array + input = strings.TrimLeft(input[len(botArrayCloser):], " ") + arrayClosed = true + } else if !hasOpener && !isLast { + // For array arguments in the middle without the <> delimiters, stop after the first item + arrayClosed = true + } + parsedVal, err := param.Schema.Items.ParseString(nextVal) + if err == nil { + collector = append(collector, parsedVal) + } else if hasOpener || isLast { + setError(fmt.Errorf("failed to parse item #%d of array %s: %w", len(collector)+1, param.Key, err)) + } else { + //input = origInput + } + } + args[param.Key] = collector + } else { + nextVal, input, wasQuoted = parseQuoted(input) + if isLast && !wasQuoted && len(input) > 0 { + // If the last argument is not quoted and not variadic, just treat the rest of the string + // as the argument without escapes (arguments with escapes should be quoted). + nextVal += " " + input + input = "" + } + // Special case for named boolean parameters: if no value is given, treat it as true + if nextVal == "" && !wasQuoted && isNamed && param.Schema.AllowsPrimitive(PrimitiveTypeBoolean) { + args[param.Key] = true + return + } + if nextVal == "" && !param.Optional { + setError(fmt.Errorf("missing value for required parameter %s", param.Key)) + } + parsedVal, err := param.Schema.ParseString(nextVal) + if err != nil { + args[param.Key] = param.GetDefaultValue() + // For optional parameters that fail to parse, restore the input and try passing it as the next parameter + if param.Optional && !isLast && !isNamed { + input = strings.TrimLeft(origInput, " ") + } else if !param.Optional || isNamed { + setError(fmt.Errorf("failed to parse %s: %w", param.Key, err)) + } + } else { + args[param.Key] = parsedVal + } + } + } + skipParams := make([]bool, len(ec.Parameters)) + for i, param := range ec.Parameters { + for strings.HasPrefix(input, "--") { + nameEndIdx := strings.IndexAny(input, " =") + if nameEndIdx == -1 { + nameEndIdx = len(input) + } + overrideParam, paramIdx := ec.parameterByName(input[2:nameEndIdx]) + if overrideParam != nil { + // Trim the equals sign, but leave spaces alone to let parseQuoted treat it as empty input + input = strings.TrimPrefix(input[nameEndIdx:], "=") + skipParams[paramIdx] = true + processParameter(overrideParam, false, true) + } else { + break + } + } + if skipParams[i] { + continue + } + processParameter(param, i == len(ec.Parameters)-1, false) + } + jsonArgs, marshalErr := json.Marshal(args) + if marshalErr != nil { + return nil, fmt.Errorf("failed to marshal arguments: %w", marshalErr) + } + return jsonArgs, retErr +} + +func (ec *EventContent) parameterByName(name string) (*Parameter, int) { + for i, param := range ec.Parameters { + if strings.EqualFold(param.Key, name) { + return param, i + } + } + return nil, -1 +} + +func (ec *EventContent) parsePrefix(origInput string, sigils []string, owner string) (prefix string) { + input := origInput + var chosenSigil string + for _, sigil := range sigils { + if strings.HasPrefix(input, sigil) { + chosenSigil = sigil + break + } + } + if chosenSigil == "" { + return "" + } + input = input[len(chosenSigil):] + var chosenAlias string + if !strings.HasPrefix(input, ec.Command) { + for _, alias := range ec.Aliases { + if strings.HasPrefix(input, alias) { + chosenAlias = alias + break + } + } + if chosenAlias == "" { + return "" + } + } else { + chosenAlias = ec.Command + } + input = strings.TrimPrefix(input[len(chosenAlias):], owner) + if input == "" || input[0] == ' ' { + input = strings.TrimLeft(input, " ") + return origInput[:len(origInput)-len(input)] + } + return "" +} + +func (pt PrimitiveType) ValidateValue(value any) bool { + _, err := pt.NormalizeValue(value) + return err == nil +} + +func normalizeNumber(value any) (int, error) { + switch typedValue := value.(type) { + case int: + return typedValue, nil + case int64: + return int(typedValue), nil + case float64: + return int(typedValue), nil + case json.Number: + if i, err := typedValue.Int64(); err != nil { + return 0, fmt.Errorf("failed to parse json.Number: %w", err) + } else { + return int(i), nil + } + default: + return 0, fmt.Errorf("unsupported type %T for integer", value) + } +} + +func (pt PrimitiveType) NormalizeValue(value any) (any, error) { + switch pt { + case PrimitiveTypeInteger: + return normalizeNumber(value) + case PrimitiveTypeBoolean: + bv, ok := value.(bool) + if !ok { + return nil, fmt.Errorf("unsupported type %T for boolean", value) + } + return bv, nil + case PrimitiveTypeString, PrimitiveTypeServerName: + str, ok := value.(string) + if !ok { + return nil, fmt.Errorf("unsupported type %T for string", value) + } + return str, pt.validateStringValue(str) + case PrimitiveTypeUserID, PrimitiveTypeRoomAlias: + str, ok := value.(string) + if !ok { + return nil, fmt.Errorf("unsupported type %T for user ID or room alias", value) + } else if plainErr := pt.validateStringValue(str); plainErr == nil { + return str, nil + } else if parsed, err := id.ParseMatrixURIOrMatrixToURL(str); err != nil { + return nil, fmt.Errorf("couldn't parse %q as plain ID nor matrix URI: %w / %w", value, plainErr, err) + } else if parsed.Sigil1 == '@' && pt == PrimitiveTypeUserID { + return parsed.UserID(), nil + } else if parsed.Sigil1 == '#' && pt == PrimitiveTypeRoomAlias { + return parsed.RoomAlias(), nil + } else { + return nil, fmt.Errorf("unexpected sigil %c for user ID or room alias", parsed.Sigil1) + } + case PrimitiveTypeRoomID, PrimitiveTypeEventID: + riv, err := NormalizeRoomIDValue(value) + if err != nil { + return nil, err + } + return riv, riv.Validate() + default: + return nil, fmt.Errorf("cannot normalize value for argument type %s", pt) + } +} + +func (pt PrimitiveType) validateStringValue(value string) error { + switch pt { + case PrimitiveTypeString: + return nil + case PrimitiveTypeServerName: + if !id.ValidateServerName(value) { + return fmt.Errorf("invalid server name: %q", value) + } + return nil + case PrimitiveTypeUserID: + _, _, err := id.UserID(value).ParseAndValidateRelaxed() + return err + case PrimitiveTypeRoomAlias: + sigil, localpart, serverName := id.ParseCommonIdentifier(value) + if sigil != '#' || localpart == "" || serverName == "" { + return fmt.Errorf("invalid room alias: %q", value) + } else if !id.ValidateServerName(serverName) { + return fmt.Errorf("invalid server name in room alias: %q", serverName) + } + return nil + default: + panic(fmt.Errorf("validateStringValue called with invalid type %s", pt)) + } +} + +func parseBoolean(val string) (bool, error) { + if len(val) == 0 { + return false, fmt.Errorf("cannot parse empty string as boolean") + } + switch val[0] { + case 't', 'T', 'y', 'Y', '1': + return true, nil + case 'f', 'F', 'n', 'N', '0': + return false, nil + default: + return false, fmt.Errorf("invalid boolean string: %q", val) + } +} + +var markdownLinkRegex = regexp.MustCompile(`^\[.+]\(([^)]+)\)$`) + +func parseRoomOrEventID(value string) (*RoomIDValue, error) { + if strings.HasPrefix(value, "[") && strings.Contains(value, "](") && strings.HasSuffix(value, ")") { + matches := markdownLinkRegex.FindStringSubmatch(value) + if len(matches) == 2 { + value = matches[1] + } + } + parsed, err := id.ParseMatrixURIOrMatrixToURL(value) + if err != nil && strings.HasPrefix(value, "!") { + return &RoomIDValue{ + Type: PrimitiveTypeRoomID, + RoomID: id.RoomID(value), + }, nil + } + if err != nil { + return nil, err + } else if parsed.Sigil1 != '!' { + return nil, fmt.Errorf("unexpected sigil %c for room ID", parsed.Sigil1) + } else if parsed.MXID2 != "" && parsed.Sigil2 != '$' { + return nil, fmt.Errorf("unexpected sigil %c for event ID", parsed.Sigil2) + } + valType := PrimitiveTypeRoomID + if parsed.MXID2 != "" { + valType = PrimitiveTypeEventID + } + return &RoomIDValue{ + Type: valType, + RoomID: parsed.RoomID(), + Via: parsed.Via, + EventID: parsed.EventID(), + }, nil +} + +func (pt PrimitiveType) ParseString(value string) (any, error) { + switch pt { + case PrimitiveTypeInteger: + return strconv.Atoi(value) + case PrimitiveTypeBoolean: + return parseBoolean(value) + case PrimitiveTypeString, PrimitiveTypeServerName, PrimitiveTypeUserID: + return value, pt.validateStringValue(value) + case PrimitiveTypeRoomAlias: + plainErr := pt.validateStringValue(value) + if plainErr == nil { + return value, nil + } + parsed, err := id.ParseMatrixURIOrMatrixToURL(value) + if err != nil { + return nil, fmt.Errorf("couldn't parse %q as plain room alias nor matrix URI: %w / %w", value, plainErr, err) + } else if parsed.Sigil1 != '#' { + return nil, fmt.Errorf("unexpected sigil %c for room alias", parsed.Sigil1) + } + return parsed.RoomAlias(), nil + case PrimitiveTypeRoomID, PrimitiveTypeEventID: + parsed, err := parseRoomOrEventID(value) + if err != nil { + return nil, err + } else if pt != parsed.Type { + return nil, fmt.Errorf("mismatching argument type: expected %s but got %s", pt, parsed.Type) + } + return parsed, nil + default: + return nil, fmt.Errorf("cannot parse string for argument type %s", pt) + } +} + +func (ps *ParameterSchema) ParseString(value string) (any, error) { + if ps == nil { + return nil, fmt.Errorf("parameter schema is nil") + } + switch ps.SchemaType { + case SchemaTypePrimitive: + return ps.Type.ParseString(value) + case SchemaTypeLiteral: + switch typedValue := ps.Value.(type) { + case string: + if value == typedValue { + return typedValue, nil + } else { + return nil, fmt.Errorf("literal value %q does not match %q", typedValue, value) + } + case int, int64, float64, json.Number: + expectedVal, _ := normalizeNumber(typedValue) + intVal, err := strconv.Atoi(value) + if err != nil { + return nil, fmt.Errorf("failed to parse integer literal: %w", err) + } else if intVal != expectedVal { + return nil, fmt.Errorf("literal value %d does not match %d", expectedVal, intVal) + } + return intVal, nil + case bool: + boolVal, err := parseBoolean(value) + if err != nil { + return nil, fmt.Errorf("failed to parse boolean literal: %w", err) + } else if boolVal != typedValue { + return nil, fmt.Errorf("literal value %t does not match %t", typedValue, boolVal) + } + return boolVal, nil + case RoomIDValue, *RoomIDValue, map[string]any, json.RawMessage: + expectedVal, _ := NormalizeRoomIDValue(typedValue) + parsed, err := parseRoomOrEventID(value) + if err != nil { + return nil, fmt.Errorf("failed to parse room or event ID literal: %w", err) + } else if !parsed.Equals(expectedVal) { + return nil, fmt.Errorf("literal value %s does not match %s", expectedVal, parsed) + } + return parsed, nil + default: + return nil, fmt.Errorf("unsupported literal type %T", ps.Value) + } + case SchemaTypeUnion: + var errs []error + for _, variant := range ps.Variants { + if parsed, err := variant.ParseString(value); err == nil { + return parsed, nil + } else { + errs = append(errs, err) + } + } + return nil, fmt.Errorf("no union variant matched: %w", errors.Join(errs...)) + case SchemaTypeArray: + return nil, fmt.Errorf("cannot parse string for array schema type") + default: + return nil, fmt.Errorf("unknown schema type %s", ps.SchemaType) + } +} diff --git a/event/cmdschema/parse_test.go b/event/cmdschema/parse_test.go new file mode 100644 index 00000000..725b0150 --- /dev/null +++ b/event/cmdschema/parse_test.go @@ -0,0 +1,118 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "bytes" + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "go.mau.fi/util/exbytes" + "go.mau.fi/util/exerrors" + + "maunium.net/go/mautrix/event/cmdschema/testdata" +) + +type QuoteParseOutput struct { + Parsed string + Remaining string + Quoted bool +} + +func (qpo *QuoteParseOutput) UnmarshalJSON(data []byte) error { + var arr []any + if err := json.Unmarshal(data, &arr); err != nil { + return err + } + qpo.Parsed = arr[0].(string) + qpo.Remaining = arr[1].(string) + qpo.Quoted = arr[2].(bool) + return nil +} + +type QuoteParseTestData struct { + Name string `json:"name"` + Input string `json:"input"` + Output QuoteParseOutput `json:"output"` +} + +func loadFile[T any](name string) (into T) { + quoteData := exerrors.Must(testdata.FS.ReadFile(name)) + exerrors.PanicIfNotNil(json.Unmarshal(quoteData, &into)) + return +} + +func TestParseQuoted(t *testing.T) { + qptd := loadFile[[]QuoteParseTestData]("parse_quote.json") + for _, test := range qptd { + t.Run(test.Name, func(t *testing.T) { + parsed, remaining, quoted := parseQuoted(test.Input) + assert.Equalf(t, test.Output, QuoteParseOutput{ + Parsed: parsed, + Remaining: remaining, + Quoted: quoted, + }, "Failed with input `%s`", test.Input) + // Note: can't just test that requoted == input, because some inputs + // have unnecessary escapes which won't survive roundtripping + t.Run("roundtrip", func(t *testing.T) { + requoted := quoteString(parsed) + " " + remaining + reparsed, newRemaining, _ := parseQuoted(requoted) + assert.Equal(t, parsed, reparsed) + assert.Equal(t, remaining, newRemaining) + }) + }) + } +} + +type CommandTestData struct { + Spec *EventContent + Tests []*CommandTestUnit +} + +type CommandTestUnit struct { + Name string `json:"name"` + Input string `json:"input"` + Broken string `json:"broken,omitempty"` + Error bool `json:"error"` + Output json.RawMessage `json:"output"` +} + +func compactJSON(input json.RawMessage) json.RawMessage { + var buf bytes.Buffer + exerrors.PanicIfNotNil(json.Compact(&buf, input)) + return buf.Bytes() +} + +func TestMSC4391BotCommandEventContent_ParseInput(t *testing.T) { + for _, cmd := range exerrors.Must(testdata.FS.ReadDir("commands")) { + t.Run(strings.TrimSuffix(cmd.Name(), ".json"), func(t *testing.T) { + ctd := loadFile[CommandTestData]("commands/" + cmd.Name()) + for _, test := range ctd.Tests { + outputStr := exbytes.UnsafeString(compactJSON(test.Output)) + t.Run(test.Name, func(t *testing.T) { + if test.Broken != "" { + t.Skip(test.Broken) + } + output, err := ctd.Spec.ParseInput("@testbot", []string{"/"}, test.Input) + if test.Error { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + if outputStr == "null" { + assert.Nil(t, output) + } else { + assert.Equal(t, ctd.Spec.Command, output.MSC4391BotCommand.Command) + assert.Equal(t, outputStr, exbytes.UnsafeString(output.MSC4391BotCommand.Arguments)) + } + }) + } + }) + } +} diff --git a/event/cmdschema/roomid.go b/event/cmdschema/roomid.go new file mode 100644 index 00000000..98c421fc --- /dev/null +++ b/event/cmdschema/roomid.go @@ -0,0 +1,135 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "encoding/json" + "fmt" + "slices" + "strings" + + "maunium.net/go/mautrix/id" +) + +var ParameterSchemaJoinableRoom = Union( + PrimitiveTypeRoomID.Schema(), + PrimitiveTypeRoomAlias.Schema(), +) + +type RoomIDValue struct { + Type PrimitiveType `json:"type"` + RoomID id.RoomID `json:"id"` + Via []string `json:"via,omitempty"` + EventID id.EventID `json:"event_id,omitempty"` +} + +func NormalizeRoomIDValue(input any) (riv *RoomIDValue, err error) { + switch typedValue := input.(type) { + case map[string]any, json.RawMessage: + var raw json.RawMessage + if raw, err = json.Marshal(input); err != nil { + err = fmt.Errorf("failed to roundtrip room ID value: %w", err) + } else if err = json.Unmarshal(raw, &riv); err != nil { + err = fmt.Errorf("failed to roundtrip room ID value: %w", err) + } + case *RoomIDValue: + riv = typedValue + case RoomIDValue: + riv = &typedValue + default: + err = fmt.Errorf("unsupported type %T for room or event ID", input) + } + return +} + +func (riv *RoomIDValue) String() string { + return riv.URI().String() +} + +func (riv *RoomIDValue) URI() *id.MatrixURI { + if riv == nil { + return nil + } + switch riv.Type { + case PrimitiveTypeRoomID: + return riv.RoomID.URI(riv.Via...) + case PrimitiveTypeEventID: + return riv.RoomID.EventURI(riv.EventID, riv.Via...) + default: + return nil + } +} + +func (riv *RoomIDValue) Equals(other *RoomIDValue) bool { + if riv == nil || other == nil { + return riv == other + } + return riv.Type == other.Type && + riv.RoomID == other.RoomID && + riv.EventID == other.EventID && + slices.Equal(riv.Via, other.Via) +} + +func (riv *RoomIDValue) Validate() error { + if riv == nil { + return fmt.Errorf("value is nil") + } + switch riv.Type { + case PrimitiveTypeRoomID: + if riv.EventID != "" { + return fmt.Errorf("event ID must be empty for room ID type") + } + case PrimitiveTypeEventID: + if !strings.HasPrefix(riv.EventID.String(), "$") { + return fmt.Errorf("event ID not valid: %q", riv.EventID) + } + default: + return fmt.Errorf("unexpected type %s for room/event ID value", riv.Type) + } + for _, via := range riv.Via { + if !id.ValidateServerName(via) { + return fmt.Errorf("invalid server name %q in vias", via) + } + } + sigil, localpart, serverName := id.ParseCommonIdentifier(riv.RoomID) + if sigil != '!' { + return fmt.Errorf("room ID does not start with !: %q", riv.RoomID) + } else if localpart == "" && serverName == "" { + return fmt.Errorf("room ID has empty localpart and server name: %q", riv.RoomID) + } else if serverName != "" && !id.ValidateServerName(serverName) { + return fmt.Errorf("invalid server name %q in room ID", serverName) + } + return nil +} + +func (riv *RoomIDValue) IsValid() bool { + return riv.Validate() == nil +} + +type RoomIDOrString string + +func (ros *RoomIDOrString) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + return fmt.Errorf("empty data for room ID or string") + } + if data[0] == '"' { + var str string + if err := json.Unmarshal(data, &str); err != nil { + return err + } + *ros = RoomIDOrString(str) + return nil + } + var riv RoomIDValue + if err := json.Unmarshal(data, &riv); err != nil { + return err + } else if err = riv.Validate(); err != nil { + return err + } + *ros = RoomIDOrString(riv.String()) + return nil +} diff --git a/event/cmdschema/stringify.go b/event/cmdschema/stringify.go new file mode 100644 index 00000000..c5c57c53 --- /dev/null +++ b/event/cmdschema/stringify.go @@ -0,0 +1,122 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package cmdschema + +import ( + "encoding/json" + "strconv" + "strings" +) + +var quoteEscaper = strings.NewReplacer( + `"`, `\"`, + `\`, `\\`, +) + +const charsToQuote = ` \` + botArrayOpener + botArrayCloser + +func quoteString(val string) string { + if val == "" { + return `""` + } + val = quoteEscaper.Replace(val) + if strings.ContainsAny(val, charsToQuote) { + return `"` + val + `"` + } + return val +} + +func (ec *EventContent) StringifyArgs(args any) string { + var argMap map[string]any + switch typedArgs := args.(type) { + case json.RawMessage: + err := json.Unmarshal(typedArgs, &argMap) + if err != nil { + return "" + } + case map[string]any: + argMap = typedArgs + default: + if b, err := json.Marshal(args); err != nil { + return "" + } else if err = json.Unmarshal(b, &argMap); err != nil { + return "" + } + } + parts := make([]string, 0, len(ec.Parameters)) + for i, param := range ec.Parameters { + isLast := i == len(ec.Parameters)-1 + val := argMap[param.Key] + if val == nil { + val = param.DefaultValue + if val == nil && !param.Optional { + val = param.Schema.GetDefaultValue() + } + } + if val == nil { + continue + } + var stringified string + if param.Schema.SchemaType == SchemaTypeArray { + stringified = arrayArgumentToString(val, isLast) + } else { + stringified = singleArgumentToString(val) + } + if stringified != "" { + parts = append(parts, stringified) + } + } + return strings.Join(parts, " ") +} + +func arrayArgumentToString(val any, isLast bool) string { + valArr, ok := val.([]any) + if !ok { + return "" + } + parts := make([]string, 0, len(valArr)) + for _, elem := range valArr { + stringified := singleArgumentToString(elem) + if stringified != "" { + parts = append(parts, stringified) + } + } + joinedParts := strings.Join(parts, " ") + if isLast && len(parts) > 0 { + return joinedParts + } + return botArrayOpener + joinedParts + botArrayCloser +} + +func singleArgumentToString(val any) string { + switch typedVal := val.(type) { + case string: + return quoteString(typedVal) + case json.Number: + return typedVal.String() + case bool: + return strconv.FormatBool(typedVal) + case int: + return strconv.Itoa(typedVal) + case int64: + return strconv.FormatInt(typedVal, 10) + case float64: + return strconv.FormatInt(int64(typedVal), 10) + case map[string]any, json.RawMessage, RoomIDValue, *RoomIDValue: + normalized, err := NormalizeRoomIDValue(typedVal) + if err != nil { + return "" + } + uri := normalized.URI() + if uri == nil { + return "" + } + return quoteString(uri.String()) + default: + return "" + } +} diff --git a/event/cmdschema/testdata/commands/flags.json b/event/cmdschema/testdata/commands/flags.json new file mode 100644 index 00000000..dedde348 --- /dev/null +++ b/event/cmdschema/testdata/commands/flags.json @@ -0,0 +1,153 @@ +{ + "spec": { + "command": "flag", + "source": "@testbot", + "parameters": [ + { + "key": "meow", + "schema": { + "schema_type": "primitive", + "type": "string" + } + }, + { + "key": "user", + "schema": { + "schema_type": "primitive", + "type": "user_id" + }, + "optional": true + }, + { + "key": "woof", + "schema": { + "schema_type": "primitive", + "type": "boolean" + }, + "optional": true, + "fi.mau.default_value": false + } + ] + }, + "tests": [ + { + "name": "no flags", + "input": "/flag mrrp", + "output": { + "meow": "mrrp", + "user": null, + "woof": false + } + }, + { + "name": "positional flag", + "input": "/flag mrrp @user:example.com yes", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "named flag at start", + "input": "/flag --woof=yes mrrp @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "boolean flag without value", + "input": "/flag --woof mrrp @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "user id flag without value", + "input": "/flag --user --woof mrrp", + "error": true, + "output": { + "meow": "mrrp", + "user": null, + "woof": true + } + }, + { + "name": "named flag in the middle", + "input": "/flag mrrp --woof=yes @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "named flag in the middle with different value", + "input": "/flag mrrp --woof=no @user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": false + } + }, + { + "name": "named flag at end", + "input": "/flag mrrp @user:example.com --woof=yes", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "named flag at end without value", + "input": "/flag mrrp @user:example.com --woof", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "all variables named", + "input": "/flag --woof=no --meow=mrrp --user=@user:example.com", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": false + } + }, + { + "name": "all variables named with quotes", + "input": "/flag --woof --meow=\"meow meow mrrp\" --user=\"@user:example.com\"", + "output": { + "meow": "meow meow mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "only string variables named", + "input": "/flag --user=@user:example.com --meow=mrrp yes", + "output": { + "meow": "mrrp", + "user": "@user:example.com", + "woof": true + } + }, + { + "name": "invalid value for named parameter", + "input": "/flag --user=meowings mrrp yes", + "error": true, + "output": { + "meow": "mrrp", + "user": null, + "woof": true + } + } + ] +} diff --git a/event/cmdschema/testdata/commands/room_id_or_alias.json b/event/cmdschema/testdata/commands/room_id_or_alias.json new file mode 100644 index 00000000..0dc233b8 --- /dev/null +++ b/event/cmdschema/testdata/commands/room_id_or_alias.json @@ -0,0 +1,84 @@ +{ + "spec": { + "command": "test room reference", + "source": "@testbot", + "parameters": [ + { + "key": "room", + "schema": { + "schema_type": "union", + "variants": [ + { + "schema_type": "primitive", + "type": "room_id" + }, + { + "schema_type": "primitive", + "type": "room_alias" + } + ] + } + } + ] + }, + "tests": [ + { + "name": "room alias", + "input": "/test room reference #test:matrix.org", + "output": { + "room": "#test:matrix.org" + } + }, + { + "name": "room id", + "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org", + "output": { + "room": { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" + } + } + }, + { + "name": "room id matrix.to link", + "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com", + "output": { + "room": { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org", + "via": [ + "example.com" + ] + } + } + }, + { + "name": "room id matrix.to link with url encoding", + "input": "/test room reference https://matrix.to/#/!%23test%2Froom%0Aversion%20%3Cu%3E11%3C%2Fu%3E%2C%20with%20%40%F0%9F%90%88%EF%B8%8F%3Amaunium.net?via=maunium.net", + "broken": "Go's url.URL does url decoding on the fragment, which breaks splitting the path segments properly", + "output": { + "room": { + "type": "room_id", + "id": "!#test/room\nversion 11, with @🐈️:maunium.net", + "via": [ + "maunium.net" + ] + } + } + }, + { + "name": "room id matrix: URI", + "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", + "output": { + "room": { + "type": "room_id", + "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + "via": [ + "maunium.net", + "matrix.org" + ] + } + } + } + ] +} diff --git a/event/cmdschema/testdata/commands/room_reference_list.json b/event/cmdschema/testdata/commands/room_reference_list.json new file mode 100644 index 00000000..99388f90 --- /dev/null +++ b/event/cmdschema/testdata/commands/room_reference_list.json @@ -0,0 +1,105 @@ +{ + "spec": { + "command": "test room reference", + "source": "@testbot", + "parameters": [ + { + "key": "rooms", + "schema": { + "schema_type": "array", + "items": { + "schema_type": "union", + "variants": [ + { + "schema_type": "primitive", + "type": "room_id" + }, + { + "schema_type": "primitive", + "type": "room_alias" + } + ] + } + } + } + ] + }, + "tests": [ + { + "name": "room alias", + "input": "/test room reference #test:matrix.org", + "output": { + "rooms": [ + "#test:matrix.org" + ] + } + }, + { + "name": "room id", + "input": "/test room reference !aiwVrNhPwbGBNjqlNu:matrix.org", + "output": { + "rooms": [ + { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" + } + ] + } + }, + { + "name": "two room ids", + "input": "/test room reference !mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ !aiwVrNhPwbGBNjqlNu:matrix.org", + "output": { + "rooms": [ + { + "type": "room_id", + "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ" + }, + { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org" + } + ] + } + }, + { + "name": "room id matrix: URI", + "input": "/test room reference matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", + "output": { + "rooms": [ + { + "type": "room_id", + "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + "via": [ + "maunium.net", + "matrix.org" + ] + } + ] + } + }, + { + "name": "room id matrix: URI and matrix.to URL", + "input": "/test room reference https://matrix.to/#/!aiwVrNhPwbGBNjqlNu:matrix.org?via=example.com matrix:roomid/mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ?via=maunium.net&via=matrix.org", + "output": { + "rooms": [ + { + "type": "room_id", + "id": "!aiwVrNhPwbGBNjqlNu:matrix.org", + "via": [ + "example.com" + ] + }, + { + "type": "room_id", + "id": "!mauT12AzsoqxV7Abvy_ApA-HNPK1LcT4GbP70_AOPyQ", + "via": [ + "maunium.net", + "matrix.org" + ] + } + ] + } + } + ] +} diff --git a/event/cmdschema/testdata/commands/simple.json b/event/cmdschema/testdata/commands/simple.json new file mode 100644 index 00000000..8127aff1 --- /dev/null +++ b/event/cmdschema/testdata/commands/simple.json @@ -0,0 +1,45 @@ +{ + "spec": { + "command": "test simple", + "source": "@testbot", + "parameters": [ + { + "key": "meow", + "schema": { + "schema_type": "primitive", + "type": "string" + } + } + ] + }, + "tests": [ + { + "name": "success", + "input": "/test simple mrrp", + "output": { + "meow": "mrrp" + } + }, + { + "name": "directed success", + "input": "/test simple@testbot mrrp", + "output": { + "meow": "mrrp" + } + }, + { + "name": "missing parameter", + "input": "/test simple", + "error": true, + "output": { + "meow": "" + } + }, + { + "name": "directed at another bot", + "input": "/test simple@anotherbot mrrp", + "error": false, + "output": null + } + ] +} diff --git a/event/cmdschema/testdata/data.go b/event/cmdschema/testdata/data.go new file mode 100644 index 00000000..eceea3d2 --- /dev/null +++ b/event/cmdschema/testdata/data.go @@ -0,0 +1,14 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package testdata + +import ( + "embed" +) + +//go:embed * +var FS embed.FS diff --git a/event/cmdschema/testdata/parse_quote.json b/event/cmdschema/testdata/parse_quote.json new file mode 100644 index 00000000..d22f1299 --- /dev/null +++ b/event/cmdschema/testdata/parse_quote.json @@ -0,0 +1,20 @@ +[ + {"name": "single word", "input": "meow", "output": ["meow", "", false]}, + {"name": "two words", "input": "meow woof", "output": ["meow", "woof", false]}, + {"name": "many words", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]}, + {"name": "extra spaces", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]}, + {"name": "trailing space", "input": "meow ", "output": ["meow", "", false]}, + {"name": "quoted word", "input": "\"meow\" meow mrrp", "output": ["meow", "meow mrrp", true]}, + {"name": "quoted words", "input": "\"meow meow\" mrrp", "output": ["meow meow", "mrrp", true]}, + {"name": "spaces in quotes", "input": "\" meow meow \" mrrp", "output": [" meow meow ", "mrrp", true]}, + {"name": "quotes after word", "input": "meow \" meow mrrp \"", "output": ["meow", "\" meow mrrp \"", false]}, + {"name": "escaped quote", "input": "\"meow\\\" meow\" mrrp", "output": ["meow\" meow", "mrrp", true]}, + {"name": "missing end quote", "input": "\"meow meow mrrp", "output": ["meow meow mrrp", "", true]}, + {"name": "missing end quote with escaped quote", "input": "\"meow\\\" meow mrrp", "output": ["meow\" meow mrrp", "", true]}, + {"name": "quote in the middle", "input": "me\"ow meow mrrp", "output": ["me\"ow", "meow mrrp", false]}, + {"name": "backslash in the middle", "input": "me\\ow meow mrrp", "output": ["me\\ow", "meow mrrp", false]}, + {"name": "other escaped character", "input": "\"m\\eow\" meow mrrp", "output": ["meow", "meow mrrp", true]}, + {"name": "escaped backslashes", "input": "\"m\\\\e\\\"ow\\\\\" meow mrrp", "output": ["m\\e\"ow\\", "meow mrrp", true]}, + {"name": "just quotes", "input": "\"\\\"\\\"\\\\\\\"\" meow", "output": ["\"\"\\\"", "meow", true]}, + {"name": "eof escape", "input": "\"meow\\", "output": ["meow\\", "", true]} +] diff --git a/event/message.go b/event/message.go index 0af3a2c9..5e80d2ef 100644 --- a/event/message.go +++ b/event/message.go @@ -142,6 +142,8 @@ type MessageEventContent struct { MSC1767Audio *MSC1767Audio `json:"org.matrix.msc1767.audio,omitempty"` MSC3245Voice *MSC3245Voice `json:"org.matrix.msc3245.voice,omitempty"` + + MSC4391BotCommand *MSC4391BotCommandInput `json:"org.matrix.msc4391.command,omitempty"` } func (content *MessageEventContent) GetCapMsgType() CapabilityMsgType { @@ -285,6 +287,13 @@ func (m *Mentions) Merge(other *Mentions) *Mentions { } } +type MSC4391BotCommandInputCustom[T any] struct { + Command string `json:"command"` + Arguments T `json:"arguments,omitempty"` +} + +type MSC4391BotCommandInput = MSC4391BotCommandInputCustom[json.RawMessage] + type EncryptedFileInfo struct { attachment.EncryptedFile URL id.ContentURIString `json:"url"` diff --git a/event/state.go b/event/state.go index 29e0e524..6d027e04 100644 --- a/event/state.go +++ b/event/state.go @@ -62,6 +62,13 @@ type ExtensibleTextContainer struct { Text []ExtensibleText `json:"m.text"` } +func (c *ExtensibleTextContainer) Equals(description *ExtensibleTextContainer) bool { + if c == nil || description == nil { + return c == description + } + return slices.Equal(c.Text, description.Text) +} + func MakeExtensibleText(text string) *ExtensibleTextContainer { return &ExtensibleTextContainer{ Text: []ExtensibleText{{ diff --git a/event/type.go b/event/type.go index 2a9b382c..b193dc59 100644 --- a/event/type.go +++ b/event/type.go @@ -112,7 +112,8 @@ func (et *Type) GuessClass() TypeClass { StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type, StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type, StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type, - StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type: + StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type, + StateMSC4391BotCommand.Type: return StateEventType case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type: return EphemeralEventType @@ -204,6 +205,7 @@ var ( StateElementFunctionalMembers = Type{"io.element.functional_members", StateEventType} StateBeeperRoomFeatures = Type{"com.beeper.room_features", StateEventType} StateBeeperDisappearingTimer = Type{"com.beeper.disappearing_timer", StateEventType} + StateMSC4391BotCommand = Type{"org.matrix.msc4391.command_description", StateEventType} ) // Message events From d63a008ec6e63b5c1d2a3afdefea9806bbc3c13b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 10 Jan 2026 18:30:27 +0200 Subject: [PATCH 1575/1647] commands: add MSC4391 support --- commands/container.go | 40 +++++++++++++++++++++++++++++++-- commands/event.go | 51 +++++++++++++++++++++++++++++++++++++++++-- commands/handler.go | 46 +++++++++++++++++++++++++++++++++++++- commands/processor.go | 31 +++++++++++++++++++++++--- commands/reactions.go | 28 +++++++++++++++++++----- 5 files changed, 183 insertions(+), 13 deletions(-) diff --git a/commands/container.go b/commands/container.go index bc685b7b..9b909b75 100644 --- a/commands/container.go +++ b/commands/container.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2026 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,14 +8,20 @@ package commands import ( "fmt" + "slices" "strings" "sync" + + "go.mau.fi/util/exmaps" + + "maunium.net/go/mautrix/event/cmdschema" ) type CommandContainer[MetaType any] struct { commands map[string]*Handler[MetaType] aliases map[string]string lock sync.RWMutex + parent *Handler[MetaType] } func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] { @@ -25,6 +31,29 @@ func NewCommandContainer[MetaType any]() *CommandContainer[MetaType] { } } +func (cont *CommandContainer[MetaType]) AllSpecs() []*cmdschema.EventContent { + data := make(exmaps.Set[*Handler[MetaType]]) + cont.collectHandlers(data) + specs := make([]*cmdschema.EventContent, 0, data.Size()) + for handler := range data.Iter() { + if handler.Parameters != nil { + specs = append(specs, handler.Spec()) + } + } + return specs +} + +func (cont *CommandContainer[MetaType]) collectHandlers(into exmaps.Set[*Handler[MetaType]]) { + cont.lock.RLock() + defer cont.lock.RUnlock() + for _, handler := range cont.commands { + into.Add(handler) + if handler.subcommandContainer != nil { + handler.subcommandContainer.collectHandlers(into) + } + } +} + // Register registers the given command handlers. func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType]) { if cont == nil { @@ -32,7 +61,10 @@ func (cont *CommandContainer[MetaType]) Register(handlers ...*Handler[MetaType]) } cont.lock.Lock() defer cont.lock.Unlock() - for _, handler := range handlers { + for i, handler := range handlers { + if handler == nil { + panic(fmt.Errorf("handler #%d is nil", i+1)) + } cont.registerOne(handler) } } @@ -45,6 +77,10 @@ func (cont *CommandContainer[MetaType]) registerOne(handler *Handler[MetaType]) } else if aliasTarget, alreadyExists := cont.aliases[handler.Name]; alreadyExists { panic(fmt.Errorf("tried to register command %q, but it's already registered as an alias for %q", handler.Name, aliasTarget)) } + if !slices.Contains(handler.parents, cont.parent) { + handler.parents = append(handler.parents, cont.parent) + handler.nestedNameCache = nil + } cont.commands[handler.Name] = handler for _, alias := range handler.Aliases { if strings.ToLower(alias) != alias { diff --git a/commands/event.go b/commands/event.go index 77a3c0d2..76d6c9f0 100644 --- a/commands/event.go +++ b/commands/event.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2026 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,6 +8,7 @@ package commands import ( "context" + "encoding/json" "fmt" "strings" @@ -35,6 +36,8 @@ type Event[MetaType any] struct { // RawArgs is the same as args, but without the splitting by whitespace. RawArgs string + StructuredArgs json.RawMessage + Ctx context.Context Log *zerolog.Logger Proc *Processor[MetaType] @@ -61,7 +64,7 @@ var IDHTMLParser = &format.HTMLParser{ } // ParseEvent parses a message into a command event struct. -func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[MetaType] { +func (proc *Processor[MetaType]) ParseEvent(ctx context.Context, evt *event.Event) *Event[MetaType] { content, ok := evt.Content.Parsed.(*event.MessageEventContent) if !ok || content.MsgType == event.MsgNotice || content.RelatesTo.GetReplaceID() != "" { return nil @@ -70,12 +73,34 @@ func ParseEvent[MetaType any](ctx context.Context, evt *event.Event) *Event[Meta if content.Format == event.FormatHTML { text = IDHTMLParser.Parse(content.FormattedBody, format.NewContext(ctx)) } + if content.MSC4391BotCommand != nil { + if !content.Mentions.Has(proc.Client.UserID) || len(content.Mentions.UserIDs) != 1 { + return nil + } + wrapped := StructuredCommandToEvent[MetaType](ctx, evt, content.MSC4391BotCommand) + wrapped.RawInput = text + return wrapped + } if len(text) == 0 { return nil } return RawTextToEvent[MetaType](ctx, evt, text) } +func StructuredCommandToEvent[MetaType any](ctx context.Context, evt *event.Event, content *event.MSC4391BotCommandInput) *Event[MetaType] { + commandParts := strings.Split(content.Command, " ") + return &Event[MetaType]{ + Event: evt, + // Fake a command and args to let the subcommand finder in Process work. + Command: commandParts[0], + Args: commandParts[1:], + Ctx: ctx, + Log: zerolog.Ctx(ctx), + + StructuredArgs: content.Arguments, + } +} + func RawTextToEvent[MetaType any](ctx context.Context, evt *event.Event, text string) *Event[MetaType] { parts := strings.Fields(text) if len(parts) == 0 { @@ -188,3 +213,25 @@ func (evt *Event[MetaType]) UnshiftArg(arg string) { evt.RawArgs = arg + " " + evt.RawArgs evt.Args = append([]string{arg}, evt.Args...) } + +func (evt *Event[MetaType]) ParseArgs(into any) error { + return json.Unmarshal(evt.StructuredArgs, into) +} + +func ParseArgs[T, MetaType any](evt *Event[MetaType]) (into T, err error) { + err = evt.ParseArgs(&into) + return +} + +func WithParsedArgs[T, MetaType any](fn func(*Event[MetaType], T)) func(*Event[MetaType]) { + return func(evt *Event[MetaType]) { + parsed, err := ParseArgs[T, MetaType](evt) + if err != nil { + evt.Log.Debug().Err(err).Msg("Failed to parse structured args into struct") + // TODO better error, usage info? deduplicate with Process + evt.Reply("Failed to parse arguments: %v", err) + return + } + fn(evt, parsed) + } +} diff --git a/commands/handler.go b/commands/handler.go index b01d594f..3b92a908 100644 --- a/commands/handler.go +++ b/commands/handler.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2026 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,6 +8,9 @@ package commands import ( "strings" + + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/event/cmdschema" ) type Handler[MetaType any] struct { @@ -25,12 +28,53 @@ type Handler[MetaType any] struct { // Event.ShiftArg will likely be useful for implementing such parameters. PreFunc func(ce *Event[MetaType]) + // Description is a short description of the command. + Description *event.ExtensibleTextContainer + // Parameters is a description of structured command parameters. + // If set, the StructuredArgs field of Event will be populated. + Parameters []*cmdschema.Parameter + + parents []*Handler[MetaType] + nestedNameCache []string subcommandContainer *CommandContainer[MetaType] } +func (h *Handler[MetaType]) NestedNames() []string { + if h.nestedNameCache != nil { + return h.nestedNameCache + } + nestedNames := make([]string, 0, (1+len(h.Aliases))*len(h.parents)) + for _, parent := range h.parents { + if parent == nil { + nestedNames = append(nestedNames, h.Name) + nestedNames = append(nestedNames, h.Aliases...) + } else { + for _, parentName := range parent.NestedNames() { + nestedNames = append(nestedNames, parentName+" "+h.Name) + for _, alias := range h.Aliases { + nestedNames = append(nestedNames, parentName+" "+alias) + } + } + } + } + h.nestedNameCache = nestedNames + return nestedNames +} + +func (h *Handler[MetaType]) Spec() *cmdschema.EventContent { + names := h.NestedNames() + return &cmdschema.EventContent{ + Command: names[0], + Aliases: names[1:], + Parameters: h.Parameters, + Description: h.Description, + } +} + func (h *Handler[MetaType]) initSubcommandContainer() { if len(h.Subcommands) > 0 { h.subcommandContainer = NewCommandContainer[MetaType]() + h.subcommandContainer.parent = h h.subcommandContainer.Register(h.Subcommands...) } else { h.subcommandContainer = nil diff --git a/commands/processor.go b/commands/processor.go index 9341329b..0089226f 100644 --- a/commands/processor.go +++ b/commands/processor.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2026 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -72,9 +72,9 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) case event.EventReaction: parsed = proc.ParseReaction(ctx, evt) case event.EventMessage: - parsed = ParseEvent[MetaType](ctx, evt) + parsed = proc.ParseEvent(ctx, evt) } - if parsed == nil || !proc.PreValidator.Validate(parsed) { + if parsed == nil || (!proc.PreValidator.Validate(parsed) && parsed.StructuredArgs == nil) { return } parsed.Proc = proc @@ -107,6 +107,11 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) break } } + if parsed.StructuredArgs != nil && len(parsed.Args) > 0 { + // The client sent MSC4391 data, but the target command wasn't found + log.Debug().Msg("Didn't find handler for MSC4391 command") + return + } logWith := log.With(). Str("command", parsed.Command). @@ -116,11 +121,31 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) } if proc.LogArgs { logWith = logWith.Strs("args", parsed.Args) + if parsed.StructuredArgs != nil { + logWith = logWith.RawJSON("structured_args", parsed.StructuredArgs) + } } log = logWith.Logger() parsed.Ctx = log.WithContext(ctx) parsed.Log = &log + if handler.Parameters != nil && parsed.StructuredArgs == nil { + // The handler wants structured parameters, but the client didn't send MSC4391 data + var err error + parsed.StructuredArgs, err = handler.Spec().ParseArguments(parsed.RawArgs) + if err != nil { + log.Debug().Err(err).Msg("Failed to parse structured arguments") + // TODO better error, usage info? deduplicate with WithParsedArgs + parsed.Reply("Failed to parse arguments: %v", err) + return + } + if proc.LogArgs { + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.RawJSON("structured_args", parsed.StructuredArgs) + }) + } + } + log.Debug().Msg("Processing command") handler.Func(parsed) } diff --git a/commands/reactions.go b/commands/reactions.go index 0df372e5..0d316219 100644 --- a/commands/reactions.go +++ b/commands/reactions.go @@ -1,4 +1,4 @@ -// Copyright (c) 2025 Tulir Asokan +// Copyright (c) 2026 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -8,6 +8,7 @@ package commands import ( "context" + "encoding/json" "strings" "github.com/rs/zerolog" @@ -19,6 +20,11 @@ import ( const ReactionCommandsKey = "fi.mau.reaction_commands" const ReactionMultiUseKey = "fi.mau.reaction_multi_use" +type ReactionCommandData struct { + Command string `json:"command"` + Args any `json:"args,omitempty"` +} + func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.Event) *Event[MetaType] { content, ok := evt.Content.Parsed.(*event.ReactionEventContent) if !ok { @@ -67,21 +73,33 @@ func (proc *Processor[MetaType]) ParseReaction(ctx context.Context, evt *event.E Msg("Reaction command not found in target event") return nil } - cmdString, ok := rawCmd.(string) - if !ok { + var wrappedEvt *Event[MetaType] + switch typedCmd := rawCmd.(type) { + case string: + wrappedEvt = RawTextToEvent[MetaType](ctx, evt, typedCmd) + case map[string]any: + var input event.MSC4391BotCommandInput + if marshaled, err := json.Marshal(typedCmd); err != nil { + + } else if err = json.Unmarshal(marshaled, &input); err != nil { + + } else { + wrappedEvt = StructuredCommandToEvent[MetaType](ctx, evt, &input) + } + } + if wrappedEvt == nil { zerolog.Ctx(ctx).Debug(). Stringer("target_event_id", evtID). Str("reaction_key", content.RelatesTo.Key). Msg("Reaction command data is invalid") return nil } - wrappedEvt := RawTextToEvent[MetaType](ctx, evt, cmdString) wrappedEvt.Proc = proc wrappedEvt.Redact() if !isMultiUse { DeleteAllReactions(ctx, proc.Client, evt) } - if cmdString == "" { + if wrappedEvt.Command == "" { return nil } return wrappedEvt From 60be95440731a8809609171916d56bc0622521fe Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 11 Jan 2026 23:42:16 +0200 Subject: [PATCH 1576/1647] event/cmdschema: make boolean parsing stricter --- event/cmdschema/parse.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/event/cmdschema/parse.go b/event/cmdschema/parse.go index 6536b410..5269ab28 100644 --- a/event/cmdschema/parse.go +++ b/event/cmdschema/parse.go @@ -331,10 +331,10 @@ func parseBoolean(val string) (bool, error) { if len(val) == 0 { return false, fmt.Errorf("cannot parse empty string as boolean") } - switch val[0] { - case 't', 'T', 'y', 'Y', '1': + switch strings.ToLower(val) { + case "t", "true", "y", "yes", "1": return true, nil - case 'f', 'F', 'n', 'N', '0': + case "f", "false", "n", "no", "0": return false, nil default: return false, fmt.Errorf("invalid boolean string: %q", val) From 4cd376cd90553f21cf65aac3c9e3e8116c2cfeec Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 11 Jan 2026 23:42:24 +0200 Subject: [PATCH 1577/1647] event/cmdschema: disallow positional optional parameters and add tail parameters --- commands/handler.go | 10 ++++++++ commands/processor.go | 1 + event/cmdschema/content.go | 13 +++++++++++ event/cmdschema/parse.go | 11 +++++---- event/cmdschema/parse_test.go | 2 +- event/cmdschema/testdata/commands/flags.json | 24 ++++++-------------- 6 files changed, 38 insertions(+), 23 deletions(-) diff --git a/commands/handler.go b/commands/handler.go index 3b92a908..56f27f06 100644 --- a/commands/handler.go +++ b/commands/handler.go @@ -33,6 +33,7 @@ type Handler[MetaType any] struct { // Parameters is a description of structured command parameters. // If set, the StructuredArgs field of Event will be populated. Parameters []*cmdschema.Parameter + TailParam string parents []*Handler[MetaType] nestedNameCache []string @@ -68,9 +69,18 @@ func (h *Handler[MetaType]) Spec() *cmdschema.EventContent { Aliases: names[1:], Parameters: h.Parameters, Description: h.Description, + TailParam: h.TailParam, } } +func (h *Handler[MetaType]) CopyFrom(other *Handler[MetaType]) { + if h.Parameters == nil { + h.Parameters = other.Parameters + h.TailParam = other.TailParam + } + h.Func = other.Func +} + func (h *Handler[MetaType]) initSubcommandContainer() { if len(h.Subcommands) > 0 { h.subcommandContainer = NewCommandContainer[MetaType]() diff --git a/commands/processor.go b/commands/processor.go index 0089226f..80f6745d 100644 --- a/commands/processor.go +++ b/commands/processor.go @@ -108,6 +108,7 @@ func (proc *Processor[MetaType]) Process(ctx context.Context, evt *event.Event) } } if parsed.StructuredArgs != nil && len(parsed.Args) > 0 { + // TODO allow unknown command handlers to be called? // The client sent MSC4391 data, but the target command wasn't found log.Debug().Msg("Didn't find handler for MSC4391 command") return diff --git a/event/cmdschema/content.go b/event/cmdschema/content.go index b69f0c1f..e7f362ed 100644 --- a/event/cmdschema/content.go +++ b/event/cmdschema/content.go @@ -13,6 +13,7 @@ import ( "reflect" "slices" + "go.mau.fi/util/exsync" "go.mau.fi/util/ptr" "maunium.net/go/mautrix/event" @@ -24,6 +25,7 @@ type EventContent struct { Aliases []string `json:"aliases,omitempty"` Parameters []*Parameter `json:"parameters,omitempty"` Description *event.ExtensibleTextContainer `json:"description,omitempty"` + TailParam string `json:"fi.mau.tail_parameter,omitempty"` } func (ec *EventContent) Validate() error { @@ -32,11 +34,22 @@ func (ec *EventContent) Validate() error { } else if ec.Command == "" { return fmt.Errorf("command is empty") } + var tailFound bool + dupMap := exsync.NewSet[string]() for i, p := range ec.Parameters { if err := p.Validate(); err != nil { return fmt.Errorf("parameter %q (#%d) is invalid: %w", ptr.Val(p).Key, i+1, err) + } else if !dupMap.Add(p.Key) { + return fmt.Errorf("duplicate parameter key %q at #%d", p.Key, i+1) + } else if p.Key == ec.TailParam { + tailFound = true + } else if tailFound && !p.Optional { + return fmt.Errorf("required parameter %q (#%d) is after tail parameter %q", p.Key, i+1, ec.TailParam) } } + if ec.TailParam != "" && !tailFound { + return fmt.Errorf("tail parameter %q not found in parameters", ec.TailParam) + } return nil } diff --git a/event/cmdschema/parse.go b/event/cmdschema/parse.go index 5269ab28..91a02827 100644 --- a/event/cmdschema/parse.go +++ b/event/cmdschema/parse.go @@ -135,8 +135,8 @@ func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) { args[param.Key] = collector } else { nextVal, input, wasQuoted = parseQuoted(input) - if isLast && !wasQuoted && len(input) > 0 { - // If the last argument is not quoted and not variadic, just treat the rest of the string + if isLast && !wasQuoted && len(input) > 0 && !strings.Contains(input, "--") { + // If the last argument is not quoted and doesn't have flags, just treat the rest of the string // as the argument without escapes (arguments with escapes should be quoted). nextVal += " " + input input = "" @@ -146,7 +146,7 @@ func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) { args[param.Key] = true return } - if nextVal == "" && !param.Optional { + if nextVal == "" && !wasQuoted && !isNamed && !param.Optional { setError(fmt.Errorf("missing value for required parameter %s", param.Key)) } parsedVal, err := param.Schema.ParseString(nextVal) @@ -180,10 +180,11 @@ func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) { break } } - if skipParams[i] { + isTail := param.Key == ec.TailParam + if skipParams[i] || (param.Optional && !isTail) { continue } - processParameter(param, i == len(ec.Parameters)-1, false) + processParameter(param, i == len(ec.Parameters)-1 || isTail, false) } jsonArgs, marshalErr := json.Marshal(args) if marshalErr != nil { diff --git a/event/cmdschema/parse_test.go b/event/cmdschema/parse_test.go index 725b0150..1e0d1817 100644 --- a/event/cmdschema/parse_test.go +++ b/event/cmdschema/parse_test.go @@ -109,7 +109,7 @@ func TestMSC4391BotCommandEventContent_ParseInput(t *testing.T) { assert.Nil(t, output) } else { assert.Equal(t, ctd.Spec.Command, output.MSC4391BotCommand.Command) - assert.Equal(t, outputStr, exbytes.UnsafeString(output.MSC4391BotCommand.Arguments)) + assert.Equalf(t, outputStr, exbytes.UnsafeString(output.MSC4391BotCommand.Arguments), "Input: %s", test.Input) } }) } diff --git a/event/cmdschema/testdata/commands/flags.json b/event/cmdschema/testdata/commands/flags.json index dedde348..469986f0 100644 --- a/event/cmdschema/testdata/commands/flags.json +++ b/event/cmdschema/testdata/commands/flags.json @@ -27,7 +27,8 @@ "optional": true, "fi.mau.default_value": false } - ] + ], + "fi.mau.tail_parameter": "user" }, "tests": [ { @@ -35,17 +36,15 @@ "input": "/flag mrrp", "output": { "meow": "mrrp", - "user": null, - "woof": false + "user": null } }, { - "name": "positional flag", - "input": "/flag mrrp @user:example.com yes", + "name": "no flags, has tail", + "input": "/flag mrrp @user:example.com", "output": { "meow": "mrrp", - "user": "@user:example.com", - "woof": true + "user": "@user:example.com" } }, { @@ -130,18 +129,9 @@ "woof": true } }, - { - "name": "only string variables named", - "input": "/flag --user=@user:example.com --meow=mrrp yes", - "output": { - "meow": "mrrp", - "user": "@user:example.com", - "woof": true - } - }, { "name": "invalid value for named parameter", - "input": "/flag --user=meowings mrrp yes", + "input": "/flag --user=meowings mrrp --woof", "error": true, "output": { "meow": "mrrp", From e034c16753eb0b4a898e3a3aca1bfa2425856933 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 11 Jan 2026 23:54:10 +0200 Subject: [PATCH 1578/1647] event/cmdschema: don't allow flags after tail parameter --- event/cmdschema/content.go | 3 +- event/cmdschema/parse.go | 10 ++-- event/cmdschema/testdata/commands/flags.json | 18 ------ event/cmdschema/testdata/commands/tail.json | 59 ++++++++++++++++++++ 4 files changed, 66 insertions(+), 24 deletions(-) create mode 100644 event/cmdschema/testdata/commands/tail.json diff --git a/event/cmdschema/content.go b/event/cmdschema/content.go index e7f362ed..ce07c4c0 100644 --- a/event/cmdschema/content.go +++ b/event/cmdschema/content.go @@ -69,7 +69,8 @@ func (ec *EventContent) Equals(other *EventContent) bool { return ec.Command == other.Command && slices.Equal(ec.Aliases, other.Aliases) && slices.EqualFunc(ec.Parameters, other.Parameters, (*Parameter).Equals) && - ec.Description.Equals(other.Description) + ec.Description.Equals(other.Description) && + ec.TailParam == other.TailParam } func init() { diff --git a/event/cmdschema/parse.go b/event/cmdschema/parse.go index 91a02827..fbb8b671 100644 --- a/event/cmdschema/parse.go +++ b/event/cmdschema/parse.go @@ -92,7 +92,7 @@ func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) { retErr = err } } - processParameter := func(param *Parameter, isLast, isNamed bool) { + processParameter := func(param *Parameter, isLast, isTail, isNamed bool) { origInput := input var nextVal string var wasQuoted bool @@ -135,8 +135,8 @@ func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) { args[param.Key] = collector } else { nextVal, input, wasQuoted = parseQuoted(input) - if isLast && !wasQuoted && len(input) > 0 && !strings.Contains(input, "--") { - // If the last argument is not quoted and doesn't have flags, just treat the rest of the string + if (isLast || isTail) && !wasQuoted && len(input) > 0 { + // If the last argument is not quoted, just treat the rest of the string // as the argument without escapes (arguments with escapes should be quoted). nextVal += " " + input input = "" @@ -175,7 +175,7 @@ func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) { // Trim the equals sign, but leave spaces alone to let parseQuoted treat it as empty input input = strings.TrimPrefix(input[nameEndIdx:], "=") skipParams[paramIdx] = true - processParameter(overrideParam, false, true) + processParameter(overrideParam, false, false, true) } else { break } @@ -184,7 +184,7 @@ func (ec *EventContent) ParseArguments(input string) (json.RawMessage, error) { if skipParams[i] || (param.Optional && !isTail) { continue } - processParameter(param, i == len(ec.Parameters)-1 || isTail, false) + processParameter(param, i == len(ec.Parameters)-1, isTail, false) } jsonArgs, marshalErr := json.Marshal(args) if marshalErr != nil { diff --git a/event/cmdschema/testdata/commands/flags.json b/event/cmdschema/testdata/commands/flags.json index 469986f0..89f0f334 100644 --- a/event/cmdschema/testdata/commands/flags.json +++ b/event/cmdschema/testdata/commands/flags.json @@ -93,24 +93,6 @@ "woof": false } }, - { - "name": "named flag at end", - "input": "/flag mrrp @user:example.com --woof=yes", - "output": { - "meow": "mrrp", - "user": "@user:example.com", - "woof": true - } - }, - { - "name": "named flag at end without value", - "input": "/flag mrrp @user:example.com --woof", - "output": { - "meow": "mrrp", - "user": "@user:example.com", - "woof": true - } - }, { "name": "all variables named", "input": "/flag --woof=no --meow=mrrp --user=@user:example.com", diff --git a/event/cmdschema/testdata/commands/tail.json b/event/cmdschema/testdata/commands/tail.json new file mode 100644 index 00000000..db6f79d4 --- /dev/null +++ b/event/cmdschema/testdata/commands/tail.json @@ -0,0 +1,59 @@ +{ + "spec": { + "command": "tail", + "source": "@testbot", + "parameters": [ + { + "key": "meow", + "schema": { + "schema_type": "primitive", + "type": "string" + } + }, + { + "key": "reason", + "schema": { + "schema_type": "primitive", + "type": "string" + }, + "optional": true + }, + { + "key": "woof", + "schema": { + "schema_type": "primitive", + "type": "boolean" + }, + "optional": true + } + ], + "fi.mau.tail_parameter": "reason" + }, + "tests": [ + { + "name": "no tail or flag", + "input": "/tail mrrp", + "output": { + "meow": "mrrp", + "reason": "" + } + }, + { + "name": "tail, no flag", + "input": "/tail mrrp meow meow", + "output": { + "meow": "mrrp", + "reason": "meow meow" + } + }, + { + "name": "flag before tail", + "input": "/tail mrrp --woof meow meow", + "output": { + "meow": "mrrp", + "reason": "meow meow", + "woof": true + } + } + ] +} From 4c0b511c01e06d3dceaa5f2698fd6e9710b7181c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 12 Jan 2026 00:52:24 +0200 Subject: [PATCH 1579/1647] event/cmdschema: add JSON schemas for test data --- event/cmdschema/testdata/commands.schema.json | 281 ++++++++++++++++++ event/cmdschema/testdata/commands/flags.json | 1 + .../testdata/commands/room_id_or_alias.json | 1 + .../commands/room_reference_list.json | 1 + event/cmdschema/testdata/commands/simple.json | 1 + event/cmdschema/testdata/commands/tail.json | 1 + .../testdata/parse_quote.schema.json | 46 +++ 7 files changed, 332 insertions(+) create mode 100644 event/cmdschema/testdata/commands.schema.json create mode 100644 event/cmdschema/testdata/parse_quote.schema.json diff --git a/event/cmdschema/testdata/commands.schema.json b/event/cmdschema/testdata/commands.schema.json new file mode 100644 index 00000000..e53382db --- /dev/null +++ b/event/cmdschema/testdata/commands.schema.json @@ -0,0 +1,281 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema#", + "$id": "commands.schema.json", + "title": "ParseInput test cases", + "description": "JSON schema for test case files containing command specifications and test cases", + "type": "object", + "required": [ + "spec", + "tests" + ], + "additionalProperties": false, + "properties": { + "spec": { + "title": "MSC4391 Command Description", + "description": "JSON schema defining the structure of a bot command event content", + "type": "object", + "required": [ + "command" + ], + "additionalProperties": false, + "properties": { + "command": { + "type": "string", + "description": "The command name that triggers this bot command" + }, + "aliases": { + "type": "array", + "description": "Alternative names/aliases for this command", + "items": { + "type": "string" + } + }, + "parameters": { + "type": "array", + "description": "List of parameters accepted by this command", + "items": { + "$ref": "#/$defs/Parameter" + } + }, + "description": { + "$ref": "#/$defs/ExtensibleTextContainer", + "description": "Human-readable description of the command" + }, + "fi.mau.tail_parameter": { + "type": "string", + "description": "The key of the parameter that accepts remaining arguments as tail text" + }, + "source": { + "type": "string", + "description": "The user ID of the bot that responds to this command" + } + } + }, + "tests": { + "type": "array", + "description": "Array of test cases for the command", + "items": { + "type": "object", + "description": "A single test case for command parsing", + "required": [ + "name", + "input" + ], + "additionalProperties": false, + "properties": { + "name": { + "type": "string", + "description": "The name of the test case" + }, + "input": { + "type": "string", + "description": "The command input string to parse" + }, + "output": { + "description": "The expected parsed parameter values, or null if the parsing is expected to fail", + "oneOf": [ + { + "type": "object", + "additionalProperties": true + }, + { + "type": "null" + } + ] + }, + "error": { + "type": "boolean", + "description": "Whether parsing should result in an error. May still produce output.", + "default": false + } + } + } + } + }, + "$defs": { + "ExtensibleTextContainer": { + "type": "object", + "description": "Container for text that can have multiple representations", + "required": [ + "m.text" + ], + "properties": { + "m.text": { + "type": "array", + "description": "Array of text representations in different formats", + "items": { + "$ref": "#/$defs/ExtensibleText" + } + } + } + }, + "ExtensibleText": { + "type": "object", + "description": "A text representation with a specific MIME type", + "required": [ + "body" + ], + "properties": { + "body": { + "type": "string", + "description": "The text content" + }, + "mimetype": { + "type": "string", + "description": "The MIME type of the text (e.g., text/plain, text/html)", + "default": "text/plain", + "examples": [ + "text/plain", + "text/html" + ] + } + } + }, + "Parameter": { + "type": "object", + "description": "A parameter definition for a command", + "required": [ + "key", + "schema" + ], + "additionalProperties": false, + "properties": { + "key": { + "type": "string", + "description": "The identifier for this parameter" + }, + "schema": { + "$ref": "#/$defs/ParameterSchema", + "description": "The schema defining the type and structure of this parameter" + }, + "optional": { + "type": "boolean", + "description": "Whether this parameter is optional", + "default": false + }, + "description": { + "$ref": "#/$defs/ExtensibleTextContainer", + "description": "Human-readable description of this parameter" + }, + "fi.mau.default_value": { + "description": "Default value for this parameter if not provided" + } + } + }, + "ParameterSchema": { + "type": "object", + "description": "Schema definition for a parameter value", + "required": [ + "schema_type" + ], + "additionalProperties": false, + "properties": { + "schema_type": { + "type": "string", + "enum": [ + "primitive", + "array", + "union", + "literal" + ], + "description": "The type of schema" + } + }, + "allOf": [ + { + "if": { + "properties": { + "schema_type": { + "const": "primitive" + } + } + }, + "then": { + "required": [ + "type" + ], + "properties": { + "type": { + "type": "string", + "enum": [ + "string", + "integer", + "boolean", + "server_name", + "user_id", + "room_id", + "room_alias", + "event_id" + ], + "description": "The primitive type (only for schema_type: primitive)" + } + } + } + }, + { + "if": { + "properties": { + "schema_type": { + "const": "array" + } + } + }, + "then": { + "required": [ + "items" + ], + "properties": { + "items": { + "$ref": "#/$defs/ParameterSchema", + "description": "The schema for array items (only for schema_type: array)" + } + } + } + }, + { + "if": { + "properties": { + "schema_type": { + "const": "union" + } + } + }, + "then": { + "required": [ + "variants" + ], + "properties": { + "variants": { + "type": "array", + "description": "The possible variants (only for schema_type: union)", + "items": { + "$ref": "#/$defs/ParameterSchema" + }, + "minItems": 1 + } + } + } + }, + { + "if": { + "properties": { + "schema_type": { + "const": "literal" + } + } + }, + "then": { + "required": [ + "value" + ], + "properties": { + "value": { + "description": "The literal value (only for schema_type: literal)" + } + } + } + } + ] + } + } +} diff --git a/event/cmdschema/testdata/commands/flags.json b/event/cmdschema/testdata/commands/flags.json index 89f0f334..6ce1f4da 100644 --- a/event/cmdschema/testdata/commands/flags.json +++ b/event/cmdschema/testdata/commands/flags.json @@ -1,4 +1,5 @@ { + "$schema": "../commands.schema.json#", "spec": { "command": "flag", "source": "@testbot", diff --git a/event/cmdschema/testdata/commands/room_id_or_alias.json b/event/cmdschema/testdata/commands/room_id_or_alias.json index 0dc233b8..1351c292 100644 --- a/event/cmdschema/testdata/commands/room_id_or_alias.json +++ b/event/cmdschema/testdata/commands/room_id_or_alias.json @@ -1,4 +1,5 @@ { + "$schema": "../commands.schema.json#", "spec": { "command": "test room reference", "source": "@testbot", diff --git a/event/cmdschema/testdata/commands/room_reference_list.json b/event/cmdschema/testdata/commands/room_reference_list.json index 99388f90..aa266054 100644 --- a/event/cmdschema/testdata/commands/room_reference_list.json +++ b/event/cmdschema/testdata/commands/room_reference_list.json @@ -1,4 +1,5 @@ { + "$schema": "../commands.schema.json#", "spec": { "command": "test room reference", "source": "@testbot", diff --git a/event/cmdschema/testdata/commands/simple.json b/event/cmdschema/testdata/commands/simple.json index 8127aff1..94667323 100644 --- a/event/cmdschema/testdata/commands/simple.json +++ b/event/cmdschema/testdata/commands/simple.json @@ -1,4 +1,5 @@ { + "$schema": "../commands.schema.json#", "spec": { "command": "test simple", "source": "@testbot", diff --git a/event/cmdschema/testdata/commands/tail.json b/event/cmdschema/testdata/commands/tail.json index db6f79d4..9782f8ec 100644 --- a/event/cmdschema/testdata/commands/tail.json +++ b/event/cmdschema/testdata/commands/tail.json @@ -1,4 +1,5 @@ { + "$schema": "../commands.schema.json#", "spec": { "command": "tail", "source": "@testbot", diff --git a/event/cmdschema/testdata/parse_quote.schema.json b/event/cmdschema/testdata/parse_quote.schema.json new file mode 100644 index 00000000..9f249116 --- /dev/null +++ b/event/cmdschema/testdata/parse_quote.schema.json @@ -0,0 +1,46 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema#", + "$id": "parse_quote.schema.json", + "title": "parseQuote test cases", + "description": "Test cases for the parseQuoted function", + "type": "array", + "items": { + "type": "object", + "required": [ + "name", + "input", + "output" + ], + "properties": { + "name": { + "type": "string", + "description": "Name of the test case" + }, + "input": { + "type": "string", + "description": "Input string to be parsed" + }, + "output": { + "type": "array", + "description": "Expected output of parsing: [first word, remaining text, was quoted]", + "minItems": 3, + "maxItems": 3, + "prefixItems": [ + { + "type": "string", + "description": "First parsed word" + }, + { + "type": "string", + "description": "Remaining text after the first word" + }, + { + "type": "boolean", + "description": "Whether the first word was quoted" + } + ] + } + }, + "additionalProperties": false + } +} From 650f9c3139b5e45ef4df6bce914e1bb235607f8b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 12 Jan 2026 00:57:12 +0200 Subject: [PATCH 1580/1647] event/cmdschema: adjust handling of unterminated quotes --- event/cmdschema/parse.go | 8 +++++++- event/cmdschema/testdata/parse_quote.json | 12 +++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/event/cmdschema/parse.go b/event/cmdschema/parse.go index fbb8b671..92e69b60 100644 --- a/event/cmdschema/parse.go +++ b/event/cmdschema/parse.go @@ -39,7 +39,13 @@ func parseQuoted(val string) (parsed, remaining string, quoted bool) { var buf strings.Builder for { quoteIdx := strings.IndexByte(val, '"') - escapeIdx := strings.IndexByte(val[:max(0, quoteIdx)], '\\') + var valUntilQuote string + if quoteIdx == -1 { + valUntilQuote = val + } else { + valUntilQuote = val[:quoteIdx] + } + escapeIdx := strings.IndexByte(valUntilQuote, '\\') if escapeIdx >= 0 { buf.WriteString(val[:escapeIdx]) if len(val) > escapeIdx+1 { diff --git a/event/cmdschema/testdata/parse_quote.json b/event/cmdschema/testdata/parse_quote.json index d22f1299..8f52b7f5 100644 --- a/event/cmdschema/testdata/parse_quote.json +++ b/event/cmdschema/testdata/parse_quote.json @@ -1,12 +1,21 @@ [ + {"name": "empty string", "input": "", "output": ["", "", false]}, {"name": "single word", "input": "meow", "output": ["meow", "", false]}, {"name": "two words", "input": "meow woof", "output": ["meow", "woof", false]}, {"name": "many words", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]}, {"name": "extra spaces", "input": "meow meow mrrp", "output": ["meow", "meow mrrp", false]}, {"name": "trailing space", "input": "meow ", "output": ["meow", "", false]}, + {"name": "only spaces", "input": " ", "output": ["", "", false]}, + {"name": "leading spaces", "input": " meow woof", "output": ["", "meow woof", false]}, + {"name": "backslash at end unquoted", "input": "meow\\ woof", "output": ["meow\\", "woof", false]}, {"name": "quoted word", "input": "\"meow\" meow mrrp", "output": ["meow", "meow mrrp", true]}, {"name": "quoted words", "input": "\"meow meow\" mrrp", "output": ["meow meow", "mrrp", true]}, {"name": "spaces in quotes", "input": "\" meow meow \" mrrp", "output": [" meow meow ", "mrrp", true]}, + {"name": "empty quoted string", "input": "\"\"", "output": ["", "", true]}, + {"name": "empty quoted with trailing", "input": "\"\" meow", "output": ["", "meow", true]}, + {"name": "quote no space before next", "input": "\"meow\"woof", "output": ["meow", "woof", true]}, + {"name": "just opening quote", "input": "\"", "output": ["", "", true]}, + {"name": "quote then space then text", "input": "\" meow", "output": [" meow", "", true]}, {"name": "quotes after word", "input": "meow \" meow mrrp \"", "output": ["meow", "\" meow mrrp \"", false]}, {"name": "escaped quote", "input": "\"meow\\\" meow\" mrrp", "output": ["meow\" meow", "mrrp", true]}, {"name": "missing end quote", "input": "\"meow meow mrrp", "output": ["meow meow mrrp", "", true]}, @@ -16,5 +25,6 @@ {"name": "other escaped character", "input": "\"m\\eow\" meow mrrp", "output": ["meow", "meow mrrp", true]}, {"name": "escaped backslashes", "input": "\"m\\\\e\\\"ow\\\\\" meow mrrp", "output": ["m\\e\"ow\\", "meow mrrp", true]}, {"name": "just quotes", "input": "\"\\\"\\\"\\\\\\\"\" meow", "output": ["\"\"\\\"", "meow", true]}, - {"name": "eof escape", "input": "\"meow\\", "output": ["meow\\", "", true]} + {"name": "escape at eof", "input": "\"meow\\", "output": ["meow", "", true]}, + {"name": "escaped backslash at eof", "input": "\"meow\\\\", "output": ["meow\\", "", true]} ] From 9d70b2b845caf77f2e3793f548465f650c4b9755 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 12 Jan 2026 12:33:55 +0200 Subject: [PATCH 1581/1647] bridgev2/matrixinterface: properly expose GetProvisioning --- bridgev2/matrix/provisioning.go | 7 +------ bridgev2/matrixinterface.go | 9 +++++++++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 44e00e64..e3d3a0b4 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -96,12 +96,7 @@ func (prov *ProvisioningAPI) GetRouter() *http.ServeMux { return prov.Router } -type IProvisioningAPI interface { - GetRouter() *http.ServeMux - GetUser(r *http.Request) *bridgev2.User -} - -func (br *Connector) GetProvisioning() IProvisioningAPI { +func (br *Connector) GetProvisioning() bridgev2.IProvisioningAPI { return br.Provisioning } diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 07615daf..f24390bf 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -67,6 +67,15 @@ type MatrixConnectorWithServer interface { GetRouter() *http.ServeMux } +type IProvisioningAPI interface { + GetRouter() *http.ServeMux + GetUser(r *http.Request) *User +} + +type MatrixConnectorWithProvisioning interface { + GetProvisioning() IProvisioningAPI +} + type MatrixConnectorWithPublicMedia interface { GetPublicMediaAddress(contentURI id.ContentURIString) string GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) From 3d5de4ed2fb012767c9c7ba1227bac1c0f420880 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 13 Jan 2026 17:05:06 +0200 Subject: [PATCH 1582/1647] bridgev2/matrixinterface: add parent interface to MatrixConnector subinterfaces --- bridgev2/matrixinterface.go | 12 ++++++++++++ bridgev2/networkinterface.go | 21 +++++++++++---------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index f24390bf..f9695c19 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -59,10 +59,12 @@ type MatrixConnector interface { } type MatrixConnectorWithArbitraryRoomState interface { + MatrixConnector GetStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string) (*event.Event, error) } type MatrixConnectorWithServer interface { + MatrixConnector GetPublicAddress() string GetRouter() *http.ServeMux } @@ -73,31 +75,38 @@ type IProvisioningAPI interface { } type MatrixConnectorWithProvisioning interface { + MatrixConnector GetProvisioning() IProvisioningAPI } type MatrixConnectorWithPublicMedia interface { + MatrixConnector GetPublicMediaAddress(contentURI id.ContentURIString) string GetPublicMediaAddressForEvent(ctx context.Context, evt *event.MessageEventContent) (string, error) } type MatrixConnectorWithNameDisambiguation interface { + MatrixConnector IsConfusableName(ctx context.Context, roomID id.RoomID, userID id.UserID, name string) ([]id.UserID, error) } type MatrixConnectorWithBridgeIdentifier interface { + MatrixConnector GetUniqueBridgeID() string } type MatrixConnectorWithURLPreviews interface { + MatrixConnector GetURLPreview(ctx context.Context, url string) (*event.LinkPreview, error) } type MatrixConnectorWithPostRoomBridgeHandling interface { + MatrixConnector HandleNewlyBridgedRoom(ctx context.Context, roomID id.RoomID) error } type MatrixConnectorWithAnalytics interface { + MatrixConnector TrackAnalytics(userID id.UserID, event string, properties map[string]any) } @@ -112,6 +121,7 @@ type DirectNotificationData struct { } type MatrixConnectorWithNotifications interface { + MatrixConnector DisplayNotification(ctx context.Context, data *DirectNotificationData) } @@ -192,9 +202,11 @@ type MatrixAPI interface { } type StreamOrderReadingMatrixAPI interface { + MatrixAPI MarkStreamOrderRead(ctx context.Context, roomID id.RoomID, streamOrder int64, ts time.Time) error } type MarkAsDMMatrixAPI interface { + MatrixAPI MarkAsDM(ctx context.Context, roomID id.RoomID, otherUser id.UserID) error } diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index adbd3155..3e25031f 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -261,6 +261,7 @@ type NetworkConnector interface { } type StoppableNetwork interface { + NetworkConnector // Stop is called when the bridge is stopping, after all network clients have been disconnected. Stop() } @@ -295,11 +296,6 @@ type PortalBridgeInfoFillingNetwork interface { FillPortalBridgeInfo(portal *Portal, content *event.BridgeEventContent) } -type PersonalFilteringCustomizingNetworkAPI interface { - NetworkAPI - CustomizePersonalFilteringSpace(req *mautrix.ReqCreateRoom) -} - // ConfigValidatingNetwork is an optional interface that network connectors can implement to validate config fields // before the bridge is started. // @@ -792,6 +788,16 @@ type UserSearchingNetworkAPI interface { SearchUsers(ctx context.Context, query string) ([]*ResolveIdentifierResponse, error) } +type GroupCreatingNetworkAPI interface { + IdentifierResolvingNetworkAPI + CreateGroup(ctx context.Context, params *GroupCreateParams) (*CreateChatResponse, error) +} + +type PersonalFilteringCustomizingNetworkAPI interface { + NetworkAPI + CustomizePersonalFilteringSpace(req *mautrix.ReqCreateRoom) +} + type ProvisioningCapabilities struct { ResolveIdentifier ResolveIdentifierCapabilities `json:"resolve_identifier"` GroupCreation map[string]GroupTypeCapabilities `json:"group_creation"` @@ -863,11 +869,6 @@ type GroupCreateParams struct { RoomID id.RoomID `json:"room_id,omitempty"` } -type GroupCreatingNetworkAPI interface { - IdentifierResolvingNetworkAPI - CreateGroup(ctx context.Context, params *GroupCreateParams) (*CreateChatResponse, error) -} - type MembershipChangeType struct { From event.Membership To event.Membership From d77cb628ffd2e1921897e7379526f2c011a68817 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 13 Jan 2026 23:11:50 +0200 Subject: [PATCH 1583/1647] bridgev2/matrixinterface: let matrix connector suggest HTTP client settings --- bridgev2/matrixinterface.go | 7 +++++++ go.mod | 2 +- go.sum | 4 ++-- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index f9695c19..57f786bb 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -14,6 +14,8 @@ import ( "os" "time" + "go.mau.fi/util/exhttp" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -125,6 +127,11 @@ type MatrixConnectorWithNotifications interface { DisplayNotification(ctx context.Context, data *DirectNotificationData) } +type MatrixConnectorWithHTTPSettings interface { + MatrixConnector + GetHTTPClientSettings() exhttp.ClientSettings +} + type MatrixSendExtra struct { Timestamp time.Time MessageMeta *database.Message diff --git a/go.mod b/go.mod index cdb62f20..544a9ff4 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.13 - go.mau.fi/util v0.9.4 + go.mau.fi/util v0.9.5-0.20260113180831-8cda92561373 go.mau.fi/zeroconfig v0.2.0 golang.org/x/crypto v0.46.0 golang.org/x/exp v0.0.0-20251209150349-8475f28825e9 diff --git a/go.sum b/go.sum index a55f0661..70a1b5a9 100644 --- a/go.sum +++ b/go.sum @@ -52,8 +52,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.4 h1:gWdUff+K2rCynRPysXalqqQyr2ahkSWaestH6YhSpso= -go.mau.fi/util v0.9.4/go.mod h1:647nVfwUvuhlZFOnro3aRNPmRd2y3iDha9USb8aKSmM= +go.mau.fi/util v0.9.5-0.20260113180831-8cda92561373 h1:LjFGO80c9mGeYCvrBsASvK9jx3oPkXo++l9quy4YMls= +go.mau.fi/util v0.9.5-0.20260113180831-8cda92561373/go.mod h1:647nVfwUvuhlZFOnro3aRNPmRd2y3iDha9USb8aKSmM= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= From 38799be3ca6f9ae112ac33677435ada2df0bb50a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 13 Jan 2026 23:17:41 +0200 Subject: [PATCH 1584/1647] bridgev2/networkinterface: let matrix connector reset remote network connections --- bridgev2/bridge.go | 28 ++++++++++++++++++++++++++++ bridgev2/networkinterface.go | 10 ++++++++++ bridgev2/userlogin.go | 7 +++++++ 3 files changed, 45 insertions(+) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index c84c2fd5..3825333c 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -373,6 +373,34 @@ func (br *Bridge) StartLogins(ctx context.Context) error { return nil } +func (br *Bridge) ResetNetworkConnections() { + nrn, ok := br.Network.(NetworkResettingNetwork) + if ok { + br.Log.Info().Msg("Resetting network connections with NetworkConnector.ResetNetworkConnections") + nrn.ResetNetworkConnections() + return + } + + br.Log.Info().Msg("Network connector doesn't support ResetNetworkConnections, recreating clients manually") + for _, login := range br.GetAllCachedUserLogins() { + login.Log.Debug().Msg("Disconnecting and recreating client for network reset") + ctx := login.Log.WithContext(br.BackgroundCtx) + login.Client.Disconnect() + err := login.recreateClient(ctx) + if err != nil { + login.Log.Err(err).Msg("Failed to recreate client during network reset") + login.BridgeState.Send(status.BridgeState{ + StateEvent: status.StateUnknownError, + Error: "bridgev2-network-reset-fail", + Info: map[string]any{"go_error": err.Error()}, + }) + } else { + login.Client.Connect(ctx) + } + } + br.Log.Info().Msg("Finished resetting all user logins") +} + func (br *Bridge) IsStopping() bool { return br.stopping.Load() } diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 3e25031f..0e9a8543 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -318,6 +318,16 @@ type MaxFileSizeingNetwork interface { SetMaxFileSize(maxSize int64) } +type NetworkResettingNetwork interface { + NetworkConnector + // ResetHTTPTransport should recreate the HTTP client used by the bridge. + // It should refetch settings from the Matrix connector using GetHTTPClientSettings if applicable. + ResetHTTPTransport() + // ResetNetworkConnections should forcefully disconnect and restart any persistent network connections. + // ResetHTTPTransport will usually be called before this, so resetting the transport is not necessary here. + ResetNetworkConnections() +} + type RemoteEchoHandler func(RemoteMessage, *database.Message) (bool, error) type MatrixMessageResponse struct { diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index b5fcfcd0..c9102248 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -10,6 +10,7 @@ import ( "cmp" "context" "fmt" + "maps" "slices" "sync" "time" @@ -140,6 +141,12 @@ func (br *Bridge) GetCachedUserLoginByID(id networkid.UserLoginID) *UserLogin { return br.userLoginsByID[id] } +func (br *Bridge) GetAllCachedUserLogins() (logins []*UserLogin) { + br.cacheLock.Lock() + defer br.cacheLock.Unlock() + return slices.Collect(maps.Values(br.userLoginsByID)) +} + func (br *Bridge) GetCurrentBridgeStates() (states []status.BridgeState) { br.cacheLock.Lock() defer br.cacheLock.Unlock() From 75f9cb369bea0a3756e91f2ef9cdce86d6f4ffe9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 14 Jan 2026 17:06:32 +0200 Subject: [PATCH 1585/1647] bridgev2: add helper method for getting HTTP settings from matrix connector --- bridgev2/bridge.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/bridgev2/bridge.go b/bridgev2/bridge.go index 3825333c..226adc90 100644 --- a/bridgev2/bridge.go +++ b/bridgev2/bridge.go @@ -16,6 +16,7 @@ import ( "github.com/rs/zerolog" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exhttp" "go.mau.fi/util/exsync" "maunium.net/go/mautrix/bridgev2/bridgeconfig" @@ -401,6 +402,14 @@ func (br *Bridge) ResetNetworkConnections() { br.Log.Info().Msg("Finished resetting all user logins") } +func (br *Bridge) GetHTTPClientSettings() exhttp.ClientSettings { + mchs, ok := br.Matrix.(MatrixConnectorWithHTTPSettings) + if ok { + return mchs.GetHTTPClientSettings() + } + return exhttp.SensibleClientSettings +} + func (br *Bridge) IsStopping() bool { return br.stopping.Load() } From 34bcd027e54ce56d61c2f265bfa025d194f410df Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 15 Jan 2026 14:02:00 +0200 Subject: [PATCH 1586/1647] bridgev2/commands: add debug command for resetting connections --- bridgev2/commands/debug.go | 22 ++++++++++++++++++++++ bridgev2/commands/processor.go | 3 ++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/bridgev2/commands/debug.go b/bridgev2/commands/debug.go index ad773ac8..1cae98fe 100644 --- a/bridgev2/commands/debug.go +++ b/bridgev2/commands/debug.go @@ -101,3 +101,25 @@ var CommandSendAccountData = &FullHandler{ RequiresPortal: true, RequiresLogin: true, } + +var CommandResetNetwork = &FullHandler{ + Func: func(ce *Event) { + if strings.Contains(strings.ToLower(ce.RawArgs), "--reset-transport") { + nrn, ok := ce.Bridge.Network.(bridgev2.NetworkResettingNetwork) + if ok { + nrn.ResetHTTPTransport() + } else { + ce.Reply("Network connector does not support resetting HTTP transport") + } + } + ce.Bridge.ResetNetworkConnections() + ce.React("✅️") + }, + Name: "debug-reset-network", + Help: HelpMeta{ + Section: HelpSectionAdmin, + Description: "Reset network connections to the remote network", + Args: "[--reset-transport]", + }, + RequiresAdmin: true, +} diff --git a/bridgev2/commands/processor.go b/bridgev2/commands/processor.go index 692db80d..391c3685 100644 --- a/bridgev2/commands/processor.go +++ b/bridgev2/commands/processor.go @@ -41,7 +41,8 @@ func NewProcessor(bridge *bridgev2.Bridge) bridgev2.CommandProcessor { } proc.AddHandlers( CommandHelp, CommandCancel, - CommandRegisterPush, CommandSendAccountData, CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom, + CommandRegisterPush, CommandSendAccountData, CommandResetNetwork, + CommandDeletePortal, CommandDeleteAllPortals, CommandSetManagementRoom, CommandLogin, CommandRelogin, CommandListLogins, CommandLogout, CommandSetPreferredLogin, CommandSetRelay, CommandUnsetRelay, CommandResolveIdentifier, CommandStartChat, CommandCreateGroup, CommandSearch, CommandSyncChat, CommandMute, From 65d708f1b7d5ce7bbd4df93c611c457e381ce7ce Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 16 Jan 2026 14:50:43 +0200 Subject: [PATCH 1587/1647] Bump version to v0.26.2 --- CHANGELOG.md | 24 ++++++++++++++++++++++++ go.mod | 20 ++++++++++---------- go.sum | 36 ++++++++++++++++++------------------ version.go | 2 +- 4 files changed, 53 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8017ef97..dbc7c494 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,27 @@ +## v0.26.2 (2026-01-16) + +* *(bridgev2)* Added chunked portal deletion to avoid database locks when + deleting large portals. +* *(crypto,bridgev2)* Added option to encrypt reaction and reply metadata + as per [MSC4392]. +* *(bridgev2/login)* Added `default_value` for user input fields. +* *(bridgev2)* Added interfaces to let the Matrix connector provide suggested + HTTP client settings and to reset active connections of the network connector. +* *(bridgev2)* Added interface to let network connectors get the provisioning + API HTTP router and add new endpoints. +* *(event)* Added blurhash field to Beeper link preview objects. +* *(event)* Added [MSC4391] support for bot commands. +* *(event)* Dropped [MSC4332] support for bot commands. +* *(client)* Changed media download methods to return an error if the provided + MXC URI is empty. +* *(client)* Stabilized support for [MSC4323]. +* *(bridgev2/matrix)* Fixed `GetEvent` panicking when trying to decrypt events. +* *(bridgev2)* Fixed some deadlocks when room creation happens in parallel with + a portal re-ID call. + +[MSC4391]: https://github.com/matrix-org/matrix-spec-proposals/pull/4391 +[MSC4392]: https://github.com/matrix-org/matrix-spec-proposals/pull/4392 + ## v0.26.1 (2025-12-16) * **Breaking change *(mediaproxy)*** Changed `GetMediaResponseFile` to return diff --git a/go.mod b/go.mod index 544a9ff4..27acd21b 100644 --- a/go.mod +++ b/go.mod @@ -2,26 +2,26 @@ module maunium.net/go/mautrix go 1.24.0 -toolchain go1.25.5 +toolchain go1.25.6 require ( filippo.io/edwards25519 v1.1.0 github.com/chzyer/readline v1.5.1 github.com/coder/websocket v1.8.14 github.com/lib/pq v1.10.9 - github.com/mattn/go-sqlite3 v1.14.32 + github.com/mattn/go-sqlite3 v1.14.33 github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.34.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stretchr/testify v1.11.1 github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 - github.com/yuin/goldmark v1.7.13 - go.mau.fi/util v0.9.5-0.20260113180831-8cda92561373 + github.com/yuin/goldmark v1.7.16 + go.mau.fi/util v0.9.5 go.mau.fi/zeroconfig v0.2.0 - golang.org/x/crypto v0.46.0 - golang.org/x/exp v0.0.0-20251209150349-8475f28825e9 - golang.org/x/net v0.48.0 + golang.org/x/crypto v0.47.0 + golang.org/x/exp v0.0.0-20260112195511-716be5621a96 + golang.org/x/net v0.49.0 golang.org/x/sync v0.19.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 @@ -32,11 +32,11 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/petermattis/goid v0.0.0-20251121121749-a11dd1a45f9a // indirect + github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - golang.org/x/sys v0.39.0 // indirect - golang.org/x/text v0.32.0 // indirect + golang.org/x/sys v0.40.0 // indirect + golang.org/x/text v0.33.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index 70a1b5a9..9702337a 100644 --- a/go.sum +++ b/go.sum @@ -25,10 +25,10 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= -github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/petermattis/goid v0.0.0-20251121121749-a11dd1a45f9a h1:VweslR2akb/ARhXfqSfRbj1vpWwYXf3eeAUyw/ndms0= -github.com/petermattis/goid v0.0.0-20251121121749-a11dd1a45f9a/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= +github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14= +github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -50,28 +50,28 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= -github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= -github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.5-0.20260113180831-8cda92561373 h1:LjFGO80c9mGeYCvrBsASvK9jx3oPkXo++l9quy4YMls= -go.mau.fi/util v0.9.5-0.20260113180831-8cda92561373/go.mod h1:647nVfwUvuhlZFOnro3aRNPmRd2y3iDha9USb8aKSmM= +github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE= +github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= +go.mau.fi/util v0.9.5 h1:7AoWPCIZJGv4jvtFEuCe3GhAbI7uF9ckIooaXvwlIR4= +go.mau.fi/util v0.9.5/go.mod h1:g1uvZ03VQhtTt2BgaRGVytS/Zj67NV0YNIECch0sQCQ= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= -golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/exp v0.0.0-20251209150349-8475f28825e9 h1:MDfG8Cvcqlt9XXrmEiD4epKn7VJHZO84hejP9Jmp0MM= -golang.org/x/exp v0.0.0-20251209150349-8475f28825e9/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU= -golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= -golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU= +golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= diff --git a/version.go b/version.go index 46d3342c..b0e31c7e 100644 --- a/version.go +++ b/version.go @@ -8,7 +8,7 @@ import ( "strings" ) -const Version = "v0.26.1" +const Version = "v0.26.2" var GoModVersion = "" var Commit = "" From 0e4b074b571c1fa1f6cbd1f256b3b1572fe82fe1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 17 Jan 2026 00:43:41 +0200 Subject: [PATCH 1588/1647] event: add detail to not json string parse error --- event/encryption.go | 2 +- id/contenturi.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/event/encryption.go b/event/encryption.go index e07944af..8e386b60 100644 --- a/event/encryption.go +++ b/event/encryption.go @@ -63,7 +63,7 @@ func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error { return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext) case id.AlgorithmMegolmV1: if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' { - return id.ErrInputNotJSONString + return fmt.Errorf("ciphertext %w", id.ErrInputNotJSONString) } content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1] } diff --git a/id/contenturi.go b/id/contenturi.go index be45eb2b..67127b6c 100644 --- a/id/contenturi.go +++ b/id/contenturi.go @@ -92,7 +92,7 @@ func (uri *ContentURI) UnmarshalJSON(raw []byte) (err error) { *uri = ContentURI{} return nil } else if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' { - return ErrInputNotJSONString + return fmt.Errorf("ContentURI: %w", ErrInputNotJSONString) } parsed, err := ParseContentURIBytes(raw[1 : len(raw)-1]) if err != nil { From b226c03277ae43ffd88a1c4e6fbdb5fa0692170d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 17 Jan 2026 00:55:16 +0200 Subject: [PATCH 1589/1647] crypto: add length check to hacky megolm message index parser --- crypto/encryptmegolm.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 8ce70ca0..806a227d 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -91,11 +91,16 @@ func IsShareError(err error) bool { } func ParseMegolmMessageIndex(ciphertext []byte) (uint, error) { + if len(ciphertext) == 0 { + return 0, fmt.Errorf("empty ciphertext") + } decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(ciphertext))) var err error _, err = base64.RawStdEncoding.Decode(decoded, ciphertext) if err != nil { return 0, err + } else if len(decoded) < 2+binary.MaxVarintLen64 { + return 0, fmt.Errorf("decoded ciphertext too short: %d bytes", len(decoded)) } else if decoded[0] != 3 || decoded[1] != 8 { return 0, fmt.Errorf("unexpected initial bytes %d and %d", decoded[0], decoded[1]) } From ec3cf5fbdd0e9e1c80da4c6f5f6d11ef3fdd33ce Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 17 Jan 2026 01:02:39 +0200 Subject: [PATCH 1590/1647] crypto/decryptmegolm: add additional checks for megolm decryption --- crypto/decryptmegolm.go | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index d8b419ab..59ff67a8 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -31,6 +31,7 @@ var ( ErrDeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") ErrSenderKeyMismatch = errors.New("sender keys in content and megolm session do not match") ErrRatchetError = errors.New("failed to ratchet session after use") + ErrCorruptedMegolmPayload = errors.New("corrupted megolm payload") ) // Deprecated: use variables prefixed with Err @@ -56,6 +57,17 @@ var ( relatesToTopLevelPath = exgjson.Path("content", "m.relates_to") ) +const sessionIDLength = 43 + +func validateCiphertextCharacters(ciphertext []byte) bool { + for _, b := range ciphertext { + if (b < 'a' || b > 'z') && (b < 'A' || b > 'Z') && (b < '0' || b > '9') && b != '+' && b != '/' { + return false + } + } + return true +} + // DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2 func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) @@ -63,6 +75,12 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event return nil, ErrIncorrectEncryptedContentType } else if content.Algorithm != id.AlgorithmMegolmV1 { return nil, ErrUnsupportedAlgorithm + } else if len(content.MegolmCiphertext) < 74 { + return nil, fmt.Errorf("%w: ciphertext too short (%d bytes)", ErrCorruptedMegolmPayload, len(content.MegolmCiphertext)) + } else if len(content.SessionID) != sessionIDLength { + return nil, fmt.Errorf("%w: invalid session ID length %d", ErrCorruptedMegolmPayload, len(content.SessionID)) + } else if !validateCiphertextCharacters(content.MegolmCiphertext) { + return nil, fmt.Errorf("%w: invalid characters in ciphertext", ErrCorruptedMegolmPayload) } log := mach.machOrContextLog(ctx).With(). Str("action", "decrypt megolm event"). From b2b58f3a2972cf75ec44bc510c1cc68ad5b45dd6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 17 Jan 2026 01:36:36 +0200 Subject: [PATCH 1591/1647] bridgev2/provisioning: cancel logins on error and delete completed logins from map --- bridgev2/matrix/provisioning.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index e3d3a0b4..17e827e3 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -407,6 +407,7 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque } func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *ProvLogin, step *bridgev2.LoginStep) { + prov.deleteLogin(login, false) if login.Override == nil || login.Override.ID == step.CompleteParams.UserLoginID { return } @@ -420,6 +421,15 @@ func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *Prov }, bridgev2.DeleteOpts{LogoutRemote: true}) } +func (prov *ProvisioningAPI) deleteLogin(login *ProvLogin, cancel bool) { + if cancel { + login.Process.Cancel() + } + prov.loginsLock.Lock() + delete(prov.logins, login.ID) + prov.loginsLock.Unlock() +} + func (prov *ProvisioningAPI) PostLoginStep(w http.ResponseWriter, r *http.Request) { loginID := r.PathValue("loginProcessID") prov.loginsLock.RLock() @@ -490,6 +500,7 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to submit input") RespondWithError(w, err, "Internal error submitting input") + prov.deleteLogin(login, true) return } login.NextStep = nextStep @@ -508,6 +519,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques if err != nil { zerolog.Ctx(r.Context()).Err(err).Msg("Failed to wait") RespondWithError(w, err, "Internal error waiting for login") + prov.deleteLogin(login, true) return } login.NextStep = nextStep From 0b6fa137cead39d87aae5ffcee72715de9b6f698 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 18 Jan 2026 14:49:06 +0200 Subject: [PATCH 1592/1647] client: add support for sending MSC4354 sticky events --- client.go | 6 ++++++ requests.go | 13 ++++++------- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/client.go b/client.go index 87b6d87e..3aa1627d 100644 --- a/client.go +++ b/client.go @@ -1324,6 +1324,9 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event if req.UnstableDelay > 0 { queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10) } + if req.UnstableStickyDuration > 0 { + queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10) + } if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted { var isEncrypted bool @@ -1365,6 +1368,9 @@ func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventTy if req.UnstableDelay > 0 { queryParams["org.matrix.msc4140.delay"] = strconv.FormatInt(req.UnstableDelay.Milliseconds(), 10) } + if req.UnstableStickyDuration > 0 { + queryParams["org.matrix.msc4354.sticky_duration_ms"] = strconv.FormatInt(req.UnstableStickyDuration.Milliseconds(), 10) + } if req.Timestamp > 0 { queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10) } diff --git a/requests.go b/requests.go index f0287b3c..397d30de 100644 --- a/requests.go +++ b/requests.go @@ -367,13 +367,12 @@ type ReqSendToDevice struct { } type ReqSendEvent struct { - Timestamp int64 - TransactionID string - UnstableDelay time.Duration - - DontEncrypt bool - - MeowEventID id.EventID + Timestamp int64 + TransactionID string + UnstableDelay time.Duration + UnstableStickyDuration time.Duration + DontEncrypt bool + MeowEventID id.EventID } type ReqDelayedEvents struct { From 28bcc356db0962eb53f296f17c03f08a0fa0ac0b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 18 Jan 2026 22:41:34 +0200 Subject: [PATCH 1593/1647] client: add MemberCount helper method for lazy load summary --- responses.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/responses.go b/responses.go index d822c84b..20286431 100644 --- a/responses.go +++ b/responses.go @@ -341,6 +341,13 @@ type LazyLoadSummary struct { InvitedMemberCount *int `json:"m.invited_member_count,omitempty"` } +func (lls *LazyLoadSummary) MemberCount() int { + if lls == nil { + return 0 + } + return ptr.Val(lls.JoinedMemberCount) + ptr.Val(lls.InvitedMemberCount) +} + func (lls *LazyLoadSummary) Equal(other *LazyLoadSummary) bool { if lls == other { return true From e28f7170bc4bc9aab3cb8e04d1a94f677dc5f27b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 19 Jan 2026 14:58:18 +0200 Subject: [PATCH 1594/1647] bridgev2/portal: auto-accept message requests on message (#451) --- bridgev2/portal.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ event/beeper.go | 2 ++ 2 files changed, 46 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index e9feb448..6d90a9ed 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1223,6 +1223,12 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *UserLogin } } + err = portal.autoAcceptMessageRequest(ctx, evt, sender, origSender, caps) + if err != nil { + log.Warn().Err(err).Msg("Failed to auto-accept message request on message") + // TODO stop processing? + } + var resp *MatrixMessageResponse if msgContent != nil { resp, err = sender.Client.HandleMatrixMessage(ctx, wrappedMsgEvt) @@ -1502,6 +1508,12 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi log.Warn().Msg("Reaction target message not found in database") return EventHandlingResultFailed.WithMSSError(fmt.Errorf("reaction %w", ErrTargetMessageNotFound)) } + caps := sender.Client.GetCapabilities(ctx, portal) + err = portal.autoAcceptMessageRequest(ctx, evt, sender, nil, caps) + if err != nil { + log.Warn().Err(err).Msg("Failed to auto-accept message request on reaction") + // TODO stop processing? + } log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str("reaction_target_remote_id", string(reactionTarget.ID)) }) @@ -1801,6 +1813,38 @@ func (portal *Portal) handleMatrixAcceptMessageRequest( return EventHandlingResultSuccess.WithMSS() } +func (portal *Portal) autoAcceptMessageRequest( + ctx context.Context, evt *event.Event, sender *UserLogin, origSender *OrigSender, caps *event.RoomFeatures, +) error { + if !portal.MessageRequest || caps.MessageRequest == nil || caps.MessageRequest.AcceptWithMessage == event.CapLevelFullySupported { + return nil + } + mran, ok := sender.Client.(MessageRequestAcceptingNetworkAPI) + if !ok { + return nil + } + err := mran.HandleMatrixAcceptMessageRequest(ctx, &MatrixAcceptMessageRequest{ + Event: evt, + Content: &event.BeeperAcceptMessageRequestEventContent{ + IsImplicit: true, + }, + Portal: portal, + OrigSender: origSender, + }) + if err != nil { + return err + } + if portal.MessageRequest { + portal.MessageRequest = false + portal.UpdateBridgeInfo(ctx) + err = portal.Save(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Msg("Failed to save portal after accepting message request") + } + } + return nil +} + func (portal *Portal) handleMatrixDeleteChat( ctx context.Context, sender *UserLogin, diff --git a/event/beeper.go b/event/beeper.go index b46106ab..49aa964f 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -94,6 +94,8 @@ type BeeperChatDeleteEventContent struct { } type BeeperAcceptMessageRequestEventContent struct { + // Whether this was triggered by a message rather than an explicit event + IsImplicit bool `json:"-"` } type BeeperSendStateEventContent struct { From f32af79d208dded3330e60492f2bedeeafe21f61 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 19 Jan 2026 14:26:22 +0000 Subject: [PATCH 1595/1647] bridgev2/ghost: consider avatar being set in `Ghost.UpdateInfoIfNecessary` (#453) Co-authored-by: Tulir Asokan --- bridgev2/ghost.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index 6cef6f06..f7072a9c 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -234,7 +234,7 @@ func (br *Bridge) allowAggressiveUpdateForType(evtType RemoteEventType) bool { } func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin, evtType RemoteEventType) { - if ghost.Name != "" && ghost.NameSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) { + if ghost.Name != "" && ghost.NameSet && ghost.AvatarSet && !ghost.Bridge.allowAggressiveUpdateForType(evtType) { return } info, err := source.Client.GetUserInfo(ctx, ghost) @@ -244,12 +244,16 @@ func (ghost *Ghost) UpdateInfoIfNecessary(ctx context.Context, source *UserLogin zerolog.Ctx(ctx).Debug(). Bool("has_name", ghost.Name != ""). Bool("name_set", ghost.NameSet). + Bool("has_avatar", ghost.AvatarMXC != ""). + Bool("avatar_set", ghost.AvatarSet). Msg("Updating ghost info in IfNecessary call") ghost.UpdateInfo(ctx, info) } else { zerolog.Ctx(ctx).Trace(). Bool("has_name", ghost.Name != ""). Bool("name_set", ghost.NameSet). + Bool("has_avatar", ghost.AvatarMXC != ""). + Bool("avatar_set", ghost.AvatarSet). Msg("No ghost info received in IfNecessary call") } } @@ -277,6 +281,11 @@ func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) { } if info.Avatar != nil { update = ghost.UpdateAvatar(ctx, info.Avatar) || update + } else if oldAvatar == "" && !ghost.AvatarSet { + // Special case: nil avatar means we're not expecting one ever, if we don't currently have + // one we flag it as set to avoid constantly refetching in UpdateInfoIfNecessary. + ghost.AvatarSet = true + update = true } if info.Identifiers != nil || info.IsBot != nil { update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot) || update From a55693bbd7c616a8e9fa04fdd20cb36154997094 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 Jan 2026 12:06:55 +0200 Subject: [PATCH 1596/1647] client,bridgev2/matrix: fix context used for async uploads --- bridgev2/matrix/intent.go | 2 ++ client.go | 10 ++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 3d2692f9..173f7c15 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -403,6 +403,7 @@ func (as *ASIntent) UploadMediaStream( removeAndClose(replFile) removeAndClose(tempFile) } + req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx) startedAsyncUpload = true var resp *mautrix.RespCreateMXC resp, err = as.Matrix.UploadAsync(ctx, req) @@ -435,6 +436,7 @@ func (as *ASIntent) doUploadReq(ctx context.Context, file *event.EncryptedFileIn as.Connector.uploadSema.Release(int64(len(req.ContentBytes))) } } + req.AsyncContext = zerolog.Ctx(ctx).WithContext(as.Connector.Bridge.BackgroundCtx) var resp *mautrix.RespCreateMXC resp, err = as.Matrix.UploadAsync(ctx, req) if resp != nil { diff --git a/client.go b/client.go index 3aa1627d..2503556a 100644 --- a/client.go +++ b/client.go @@ -1933,10 +1933,15 @@ func (cli *Client) UploadAsync(ctx context.Context, req ReqUploadMedia) (*RespCr } req.MXC = resp.ContentURI req.UnstableUploadURL = resp.UnstableUploadURL + if req.AsyncContext == nil { + req.AsyncContext = cli.cliOrContextLog(ctx).WithContext(context.Background()) + } go func() { - _, err = cli.UploadMedia(ctx, req) + _, err = cli.UploadMedia(req.AsyncContext, req) if err != nil { - cli.Log.Error().Stringer("mxc", req.MXC).Err(err).Msg("Async upload of media failed") + zerolog.Ctx(req.AsyncContext).Err(err). + Stringer("mxc", req.MXC). + Msg("Async upload of media failed") } }() return resp, nil @@ -1972,6 +1977,7 @@ type ReqUploadMedia struct { ContentType string FileName string + AsyncContext context.Context DoneCallback func() // MXC specifies an existing MXC URI which doesn't have content yet to upload into. From a1236b65bea37ab97c550aff2c39411125833932 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 20 Jan 2026 14:28:21 +0200 Subject: [PATCH 1597/1647] crypto/keyimport: call session received callback for all sessions in import --- crypto/keyimport.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/crypto/keyimport.go b/crypto/keyimport.go index 36ad6b9c..aef3eca2 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -120,7 +120,9 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) firstKnownIndex := igs.Internal.FirstKnownIndex() if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= firstKnownIndex { - // We already have an equivalent or better session in the store, so don't override it. + // We already have an equivalent or better session in the store, so don't override it, + // but do notify the session received callback just in case. + mach.MarkSessionReceived(ctx, session.RoomID, igs.ID(), existingIGS.Internal.FirstKnownIndex()) return false, nil } err = mach.CryptoStore.PutGroupSession(ctx, igs) From d057f1c6732e1a0da45f7888984a02087709841c Mon Sep 17 00:00:00 2001 From: SpiritCroc Date: Fri, 23 Jan 2026 15:38:17 +0100 Subject: [PATCH 1598/1647] event: add action message content for rich call notifications (#454) --- event/beeper.go | 18 ++++++++++++++++++ event/message.go | 1 + 2 files changed, 19 insertions(+) diff --git a/event/beeper.go b/event/beeper.go index 49aa964f..2c7d9bf2 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -166,6 +166,24 @@ type BeeperPerMessageProfile struct { HasFallback bool `json:"has_fallback,omitempty"` } +type BeeperActionMessageType string + +const ( + BeeperActionMessageCall BeeperActionMessageType = "call" +) + +type BeeperActionMessageCallType string + +const ( + BeeperActionMessageCallTypeVoice BeeperActionMessageCallType = "voice" + BeeperActionMessageCallTypeVideo BeeperActionMessageCallType = "video" +) + +type BeeperActionMessage struct { + Type BeeperActionMessageType `json:"type"` + CallType BeeperActionMessageCallType `json:"call_type,omitempty"` +} + func (content *MessageEventContent) AddPerMessageProfileFallback() { if content.BeeperPerMessageProfile == nil || content.BeeperPerMessageProfile.HasFallback || content.BeeperPerMessageProfile.Displayname == "" { return diff --git a/event/message.go b/event/message.go index 5e80d2ef..3fb3dc82 100644 --- a/event/message.go +++ b/event/message.go @@ -135,6 +135,7 @@ type MessageEventContent struct { BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"` BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"` BeeperPerMessageProfile *BeeperPerMessageProfile `json:"com.beeper.per_message_profile,omitempty"` + BeeperActionMessage *BeeperActionMessage `json:"com.beeper.action_message,omitempty"` BeeperLinkPreviews []*BeeperLinkPreview `json:"com.beeper.linkpreviews,omitempty"` From 8b04430d84edbc3efb16b89a6c9e2c74ac5f0d7b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 23 Jan 2026 19:37:35 +0200 Subject: [PATCH 1599/1647] event: switch url preview image blurhash to use MSC2448 field --- event/beeper.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/event/beeper.go b/event/beeper.go index 2c7d9bf2..6de41df6 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -146,7 +146,7 @@ type BeeperLinkPreview struct { MatchedURL string `json:"matched_url,omitempty"` ImageEncryption *EncryptedFileInfo `json:"beeper:image:encryption,omitempty"` - ImageBlurhash string `json:"beeper:image:blurhash,omitempty"` + ImageBlurhash string `json:"matrix:image:blurhash,omitempty"` } type BeeperProfileExtra struct { From b041eb924ea508fece4f09546d83e336a8d3edf4 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 26 Jan 2026 01:20:30 +0200 Subject: [PATCH 1600/1647] error: allow storing extra headers in RespError --- error.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/error.go b/error.go index 5ff671e0..a284f7d1 100644 --- a/error.go +++ b/error.go @@ -140,7 +140,8 @@ type RespError struct { Err string ExtraData map[string]any - StatusCode int + StatusCode int + ExtraHeader map[string]string } func (e *RespError) UnmarshalJSON(data []byte) error { @@ -168,6 +169,9 @@ func (e RespError) Write(w http.ResponseWriter) { if statusCode == 0 { statusCode = http.StatusInternalServerError } + for key, value := range e.ExtraHeader { + w.Header().Set(key, value) + } exhttp.WriteJSONResponse(w, statusCode, &e) } @@ -190,6 +194,18 @@ func (e RespError) WithExtraData(extraData map[string]any) RespError { return e } +func (e RespError) WithExtraHeader(key, value string) RespError { + e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader) + e.ExtraHeader[key] = value + return e +} + +func (e RespError) WithExtraHeaders(headers map[string]string) RespError { + e.ExtraHeader = exmaps.NonNilClone(e.ExtraHeader) + maps.Copy(e.ExtraHeader, headers) + return e +} + // Error returns the errcode and error message. func (e RespError) Error() string { return e.ErrCode + ": " + e.Err From 074a2d8d4d5d9dede2ff847aec57939f913d0041 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 26 Jan 2026 01:38:03 +0200 Subject: [PATCH 1601/1647] crypto/keysharing: fix including sender key in forwards --- crypto/decryptmegolm.go | 4 ---- crypto/keysharing.go | 3 ++- event/encryption.go | 3 ++- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 59ff67a8..77a64b1e 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -29,7 +29,6 @@ var ( ErrDuplicateMessageIndex = errors.New("duplicate megolm message index") ErrWrongRoom = errors.New("encrypted megolm event is not intended for this room") ErrDeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") - ErrSenderKeyMismatch = errors.New("sender keys in content and megolm session do not match") ErrRatchetError = errors.New("failed to ratchet session after use") ErrCorruptedMegolmPayload = errors.New("corrupted megolm payload") ) @@ -41,7 +40,6 @@ var ( DuplicateMessageIndex = ErrDuplicateMessageIndex WrongRoom = ErrWrongRoom DeviceKeyMismatch = ErrDeviceKeyMismatch - SenderKeyMismatch = ErrSenderKeyMismatch RatchetError = ErrRatchetError ) @@ -254,8 +252,6 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err) } else if sess == nil { return nil, nil, 0, fmt.Errorf("%w (ID %s)", ErrNoSessionFound, content.SessionID) - } else if content.SenderKey != "" && content.SenderKey != sess.SenderKey { - return sess, nil, 0, ErrSenderKeyMismatch } plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext) if err != nil { diff --git a/crypto/keysharing.go b/crypto/keysharing.go index f1d427af..cde594c2 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -214,6 +214,7 @@ func (mach *OlmMachine) rejectKeyRequest(ctx context.Context, rejection KeyShare RoomID: request.RoomID, Algorithm: request.Algorithm, SessionID: request.SessionID, + //lint:ignore SA1019 This is just echoing back the deprecated field SenderKey: request.SenderKey, Code: rejection.Code, Reason: rejection.Reason, @@ -356,7 +357,7 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User SessionID: igs.ID(), SessionKey: string(exportedKey), }, - SenderKey: content.Body.SenderKey, + SenderKey: igs.SenderKey, ForwardingKeyChain: igs.ForwardingChains, SenderClaimedKey: igs.SigningKey, }, diff --git a/event/encryption.go b/event/encryption.go index 8e386b60..c60cb91a 100644 --- a/event/encryption.go +++ b/event/encryption.go @@ -132,8 +132,9 @@ type RoomKeyRequestEventContent struct { type RequestedKeyInfo struct { Algorithm id.Algorithm `json:"algorithm"` RoomID id.RoomID `json:"room_id"` - SenderKey id.SenderKey `json:"sender_key"` SessionID id.SessionID `json:"session_id"` + // Deprecated: Matrix v1.3 + SenderKey id.SenderKey `json:"sender_key"` } type RoomKeyWithheldCode string From 9d30203f6b9c6a14d751dbba73482465b5b49020 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 26 Jan 2026 13:42:33 +0200 Subject: [PATCH 1602/1647] bridgev2/userlogin: add todo --- bridgev2/userlogin.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index c9102248..35443025 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -51,6 +51,8 @@ func (br *Bridge) loadUserLogin(ctx context.Context, user *User, dbUserLogin *da if err != nil { return nil, fmt.Errorf("failed to get user: %w", err) } + // TODO if loading the user caused the provided userlogin to be loaded, cancel here? + // Currently this will double-load it } userLogin := &UserLogin{ UserLogin: dbUserLogin, From c4ce008c8eee58e7a9dc1978403e75292f1f0927 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 28 Jan 2026 12:51:46 +0200 Subject: [PATCH 1603/1647] crypto/ssss: skip verifying recovery key if MAC or IV are missing --- crypto/cross_sign_ssss.go | 7 ++++++- crypto/ssss/meta.go | 14 +++++++++++--- crypto/ssss/types.go | 3 ++- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/crypto/cross_sign_ssss.go b/crypto/cross_sign_ssss.go index 50b58ea0..fd42880d 100644 --- a/crypto/cross_sign_ssss.go +++ b/crypto/cross_sign_ssss.go @@ -8,6 +8,7 @@ package crypto import ( "context" + "errors" "fmt" "maunium.net/go/mautrix" @@ -77,7 +78,11 @@ func (mach *OlmMachine) VerifyWithRecoveryKey(ctx context.Context, recoveryKey s return fmt.Errorf("failed to get default SSSS key data: %w", err) } key, err := keyData.VerifyRecoveryKey(keyID, recoveryKey) - if err != nil { + if errors.Is(err, ssss.ErrUnverifiableKey) { + mach.machOrContextLog(ctx).Warn(). + Str("key_id", keyID). + Msg("SSSS key is unverifiable, trying to use without verifying") + } else if err != nil { return err } err = mach.FetchCrossSigningKeysFromSSSS(ctx, key) diff --git a/crypto/ssss/meta.go b/crypto/ssss/meta.go index 474c85d8..f2ae68eb 100644 --- a/crypto/ssss/meta.go +++ b/crypto/ssss/meta.go @@ -8,6 +8,7 @@ package ssss import ( "encoding/base64" + "errors" "fmt" "strings" @@ -33,7 +34,9 @@ func (kd *KeyMetadata) VerifyPassphrase(keyID, passphrase string) (*Key, error) ssssKey, err := kd.Passphrase.GetKey(passphrase) if err != nil { return nil, err - } else if err = kd.verifyKey(ssssKey); err != nil { + } + err = kd.verifyKey(ssssKey) + if err != nil && !errors.Is(err, ErrUnverifiableKey) { return nil, err } @@ -49,7 +52,9 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error ssssKey := utils.DecodeBase58RecoveryKey(recoveryKey) if ssssKey == nil { return nil, ErrInvalidRecoveryKey - } else if err := kd.verifyKey(ssssKey); err != nil { + } + err := kd.verifyKey(ssssKey) + if err != nil && !errors.Is(err, ErrUnverifiableKey) { return nil, err } @@ -57,10 +62,13 @@ func (kd *KeyMetadata) VerifyRecoveryKey(keyID, recoveryKey string) (*Key, error ID: keyID, Key: ssssKey, Metadata: kd, - }, nil + }, err } func (kd *KeyMetadata) verifyKey(key []byte) error { + if kd.MAC == "" || kd.IV == "" { + return ErrUnverifiableKey + } unpaddedMAC := strings.TrimRight(kd.MAC, "=") expectedMACLength := base64.RawStdEncoding.EncodedLen(utils.SHAHashLength) if len(unpaddedMAC) != expectedMACLength { diff --git a/crypto/ssss/types.go b/crypto/ssss/types.go index c08f107c..b7465d3e 100644 --- a/crypto/ssss/types.go +++ b/crypto/ssss/types.go @@ -26,7 +26,8 @@ var ( ErrUnsupportedPassphraseAlgorithm = errors.New("unsupported passphrase KDF algorithm") ErrIncorrectSSSSKey = errors.New("incorrect SSSS key") ErrInvalidRecoveryKey = errors.New("invalid recovery key") - ErrCorruptedKeyMetadata = errors.New("corrupted key metadata") + ErrCorruptedKeyMetadata = errors.New("corrupted recovery key metadata") + ErrUnverifiableKey = errors.New("cannot verify recovery key: missing MAC or IV in metadata") ) // Algorithm is the identifier for an SSSS encryption algorithm. From 2c0d51ee7d92a62334268c57035bc4153f3b4597 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 28 Jan 2026 14:39:52 +0200 Subject: [PATCH 1604/1647] crypto/ssss: handle slightly broken key metadata better --- crypto/ssss/key.go | 4 ++-- crypto/ssss/meta.go | 34 +++++++++++++++++++----------- crypto/ssss/meta_test.go | 45 +++++++++++++++++++++++++++++++++------- 3 files changed, 61 insertions(+), 22 deletions(-) diff --git a/crypto/ssss/key.go b/crypto/ssss/key.go index cd8e3fce..78ebd8f3 100644 --- a/crypto/ssss/key.go +++ b/crypto/ssss/key.go @@ -59,12 +59,12 @@ func NewKey(passphrase string) (*Key, error) { // We store a certain hash in the key metadata so that clients can check if the user entered the correct key. ivBytes := random.Bytes(utils.AESCTRIVLength) keyData.IV = base64.RawStdEncoding.EncodeToString(ivBytes) - var err error - keyData.MAC, err = keyData.calculateHash(ssssKey) + macBytes, err := keyData.calculateHash(ssssKey) if err != nil { // This should never happen because we just generated the IV and key. return nil, fmt.Errorf("failed to calculate hash: %w", err) } + keyData.MAC = base64.RawStdEncoding.EncodeToString(macBytes) return &Key{ Key: ssssKey, diff --git a/crypto/ssss/meta.go b/crypto/ssss/meta.go index f2ae68eb..34775fa7 100644 --- a/crypto/ssss/meta.go +++ b/crypto/ssss/meta.go @@ -7,6 +7,8 @@ package ssss import ( + "crypto/hmac" + "crypto/sha256" "encoding/base64" "errors" "fmt" @@ -74,11 +76,16 @@ func (kd *KeyMetadata) verifyKey(key []byte) error { if len(unpaddedMAC) != expectedMACLength { return fmt.Errorf("%w: invalid mac length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedMAC), expectedMACLength) } - hash, err := kd.calculateHash(key) + expectedMAC, err := base64.RawStdEncoding.DecodeString(unpaddedMAC) + if err != nil { + return fmt.Errorf("%w: failed to decode mac: %w", ErrCorruptedKeyMetadata, err) + } + calculatedMAC, err := kd.calculateHash(key) if err != nil { return err } - if unpaddedMAC != hash { + // This doesn't really need to be constant time since it's fully local, but might as well be. + if !hmac.Equal(expectedMAC, calculatedMAC) { return ErrIncorrectSSSSKey } return nil @@ -91,23 +98,26 @@ func (kd *KeyMetadata) VerifyKey(key []byte) bool { // calculateHash calculates the hash used for checking if the key is entered correctly as described // in the spec: https://matrix.org/docs/spec/client_server/unstable#m-secret-storage-v1-aes-hmac-sha2 -func (kd *KeyMetadata) calculateHash(key []byte) (string, error) { +func (kd *KeyMetadata) calculateHash(key []byte) ([]byte, error) { aesKey, hmacKey := utils.DeriveKeysSHA256(key, "") unpaddedIV := strings.TrimRight(kd.IV, "=") expectedIVLength := base64.RawStdEncoding.EncodedLen(utils.AESCTRIVLength) - if len(unpaddedIV) != expectedIVLength { - return "", fmt.Errorf("%w: invalid iv length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedIV), expectedIVLength) + if len(unpaddedIV) < expectedIVLength || len(unpaddedIV) > expectedIVLength*3 { + return nil, fmt.Errorf("%w: invalid iv length %d (expected %d)", ErrCorruptedKeyMetadata, len(unpaddedIV), expectedIVLength) } - - var ivBytes [utils.AESCTRIVLength]byte - _, err := base64.RawStdEncoding.Decode(ivBytes[:], []byte(unpaddedIV)) + rawIVBytes, err := base64.RawStdEncoding.DecodeString(unpaddedIV) if err != nil { - return "", fmt.Errorf("%w: failed to decode iv: %w", ErrCorruptedKeyMetadata, err) + return nil, fmt.Errorf("%w: failed to decode iv: %w", ErrCorruptedKeyMetadata, err) } + // TODO log a warning for non-16 byte IVs? + // Certain broken clients like nheko generated 32-byte IVs where only the first 16 bytes were used. + ivBytes := *(*[utils.AESCTRIVLength]byte)(rawIVBytes[:utils.AESCTRIVLength]) - cipher := utils.XorA256CTR(make([]byte, utils.AESCTRKeyLength), aesKey, ivBytes) - - return utils.HMACSHA256B64(cipher, hmacKey), nil + zeroes := make([]byte, utils.AESCTRKeyLength) + encryptedZeroes := utils.XorA256CTR(zeroes, aesKey, ivBytes) + h := hmac.New(sha256.New, hmacKey[:]) + h.Write(encryptedZeroes) + return h.Sum(nil), nil } // PassphraseMetadata represents server-side metadata about a SSSS key passphrase. diff --git a/crypto/ssss/meta_test.go b/crypto/ssss/meta_test.go index 7a5ef8b9..d59809c7 100644 --- a/crypto/ssss/meta_test.go +++ b/crypto/ssss/meta_test.go @@ -8,7 +8,6 @@ package ssss_test import ( "encoding/json" - "errors" "testing" "github.com/stretchr/testify/assert" @@ -42,10 +41,24 @@ const key2Meta = ` } ` +const key2MetaUnverified = ` +{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2" +} +` + +const key2MetaLongIV = ` +{ + "algorithm": "m.secret_storage.v1.aes-hmac-sha2", + "iv": "O0BOvTqiIAYjC+RMcyHfW2f/gdxjceTxoYtNlpPduJ8=", + "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" +} +` + const key2MetaBrokenIV = ` { "algorithm": "m.secret_storage.v1.aes-hmac-sha2", - "iv": "O0BOvTqiIAYjC+RMcyHfWwMeowMeowMeow", + "iv": "MeowMeowMeow", "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI=" } ` @@ -94,17 +107,33 @@ func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) { assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) } +func TestKeyMetadata_VerifyRecoveryKey_NonCompliant_LongIV(t *testing.T) { + km := getKeyMeta(key2MetaLongIV) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + assert.NoError(t, err) + assert.NotNil(t, key) + assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) +} + +func TestKeyMetadata_VerifyRecoveryKey_Unverified(t *testing.T) { + km := getKeyMeta(key2MetaUnverified) + key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) + assert.ErrorIs(t, err, ssss.ErrUnverifiableKey) + assert.NotNil(t, key) + assert.Equal(t, key2RecoveryKey, key.RecoveryKey()) +} + func TestKeyMetadata_VerifyRecoveryKey_Invalid(t *testing.T) { km := getKeyMeta(key1Meta) key, err := km.VerifyRecoveryKey(key1ID, "foo") - assert.True(t, errors.Is(err, ssss.ErrInvalidRecoveryKey), "unexpected error: %v", err) + assert.ErrorIs(t, err, ssss.ErrInvalidRecoveryKey) assert.Nil(t, key) } func TestKeyMetadata_VerifyRecoveryKey_Incorrect(t *testing.T) { km := getKeyMeta(key1Meta) key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error: %v", err) + assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey) assert.Nil(t, key) } @@ -119,27 +148,27 @@ func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) { func TestKeyMetadata_VerifyPassphrase_Incorrect(t *testing.T) { km := getKeyMeta(key1Meta) key, err := km.VerifyPassphrase(key1ID, "incorrect horse battery staple") - assert.True(t, errors.Is(err, ssss.ErrIncorrectSSSSKey), "unexpected error %v", err) + assert.ErrorIs(t, err, ssss.ErrIncorrectSSSSKey) assert.Nil(t, key) } func TestKeyMetadata_VerifyPassphrase_NotSet(t *testing.T) { km := getKeyMeta(key2Meta) key, err := km.VerifyPassphrase(key2ID, "hmm") - assert.True(t, errors.Is(err, ssss.ErrNoPassphrase), "unexpected error %v", err) + assert.ErrorIs(t, err, ssss.ErrNoPassphrase) assert.Nil(t, key) } func TestKeyMetadata_VerifyRecoveryKey_CorruptedIV(t *testing.T) { km := getKeyMeta(key2MetaBrokenIV) key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - assert.True(t, errors.Is(err, ssss.ErrCorruptedKeyMetadata), "unexpected error %v", err) + assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata) assert.Nil(t, key) } func TestKeyMetadata_VerifyRecoveryKey_CorruptedMAC(t *testing.T) { km := getKeyMeta(key2MetaBrokenMAC) key, err := km.VerifyRecoveryKey(key2ID, key2RecoveryKey) - assert.True(t, errors.Is(err, ssss.ErrCorruptedKeyMetadata), "unexpected error %v", err) + assert.ErrorIs(t, err, ssss.ErrCorruptedKeyMetadata) assert.Nil(t, key) } From b613f4d67647c6d03c09de03727f49b3a8d9a7f8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 28 Jan 2026 21:32:48 +0200 Subject: [PATCH 1605/1647] crypto/sessions: add missing field in export --- crypto/sessions.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/sessions.go b/crypto/sessions.go index 6b90c998..d7e68eb1 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -169,7 +169,7 @@ func (igs *InboundGroupSession) export() (*ExportedSession, error) { ForwardingChains: igs.ForwardingChains, RoomID: igs.RoomID, SenderKey: igs.SenderKey, - SenderClaimedKeys: SenderClaimedKeys{}, + SenderClaimedKeys: SenderClaimedKeys{Ed25519: igs.SigningKey}, SessionID: igs.ID(), SessionKey: string(key), }, nil From 2423716f83946e840ec3d28271d884470296cb27 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 28 Jan 2026 21:34:07 +0200 Subject: [PATCH 1606/1647] crypto/keysharing: don't send withheld response to some key requests --- crypto/keysharing.go | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/crypto/keysharing.go b/crypto/keysharing.go index cde594c2..c1f7171c 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -264,9 +264,14 @@ func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Dev log.Err(err).Msg("Rejecting key request due to internal error when checking session sharing") return &KeyShareRejectNoResponse } else if !isShared { - // TODO differentiate session not shared with requester vs session not created by this device? - log.Debug().Msg("Rejecting key request for unshared session") - return &KeyShareRejectNotRecipient + igs, _ := mach.CryptoStore.GetGroupSession(ctx, evt.RoomID, evt.SessionID) + if igs != nil && igs.SenderKey == mach.OwnIdentity().IdentityKey { + log.Debug().Msg("Rejecting key request for unshared session") + return &KeyShareRejectNotRecipient + } + // Note: this case will also happen for redacted sessions and database errors + log.Debug().Msg("Rejecting key request for session created by another device") + return &KeyShareRejectNoResponse } log.Debug().Msg("Accepting key request for shared session") return nil @@ -324,7 +329,9 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User if err != nil { if errors.Is(err, ErrGroupSessionWithheld) { log.Debug().Err(err).Msg("Requested group session not available") - mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) + if sender != mach.Client.UserID { + mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) + } } else { log.Error().Err(err).Msg("Failed to get group session to forward") mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body) @@ -332,7 +339,9 @@ func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.User return } else if igs == nil { log.Error().Msg("Didn't find group session to forward") - mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) + if sender != mach.Client.UserID { + mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body) + } return } if internalID := igs.ID(); internalID != content.Body.SessionID { From 60742c4b61a4839f2ae78d443edb2f22de78ca4e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 28 Jan 2026 21:37:23 +0200 Subject: [PATCH 1607/1647] crypto: update test --- crypto/keyexport_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crypto/keyexport_test.go b/crypto/keyexport_test.go index 47616a20..fd6f105d 100644 --- a/crypto/keyexport_test.go +++ b/crypto/keyexport_test.go @@ -31,5 +31,5 @@ func TestExportKeys(t *testing.T) { )) data, err := crypto.ExportKeys("meow", []*crypto.InboundGroupSession{sess}) assert.NoError(t, err) - assert.Len(t, data, 836) + assert.Len(t, data, 893) } From 4b387c305b43a59703b0483454f52b3271069539 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Thu, 29 Jan 2026 15:01:48 +0000 Subject: [PATCH 1608/1647] error: add `RespError.CanRetry` field (#456) --- error.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/error.go b/error.go index a284f7d1..4711b3dc 100644 --- a/error.go +++ b/error.go @@ -142,6 +142,8 @@ type RespError struct { StatusCode int ExtraHeader map[string]string + + CanRetry bool } func (e *RespError) UnmarshalJSON(data []byte) error { @@ -151,6 +153,7 @@ func (e *RespError) UnmarshalJSON(data []byte) error { } e.ErrCode, _ = e.ExtraData["errcode"].(string) e.Err, _ = e.ExtraData["error"].(string) + e.CanRetry, _ = e.ExtraData["com.beeper.can_retry"].(bool) return nil } @@ -158,6 +161,9 @@ func (e *RespError) MarshalJSON() ([]byte, error) { data := exmaps.NonNilClone(e.ExtraData) data["errcode"] = e.ErrCode data["error"] = e.Err + if e.CanRetry { + data["com.beeper.can_retry"] = e.CanRetry + } return json.Marshal(data) } @@ -188,6 +194,11 @@ func (e RespError) WithStatus(status int) RespError { return e } +func (e RespError) WithCanRetry(canRetry bool) RespError { + e.CanRetry = canRetry + return e +} + func (e RespError) WithExtraData(extraData map[string]any) RespError { e.ExtraData = exmaps.NonNilClone(e.ExtraData) maps.Copy(e.ExtraData, extraData) From d2364b3822751ea1863a886ad9f47ca60fea055c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 29 Jan 2026 19:47:10 +0200 Subject: [PATCH 1609/1647] bridgev2/portal: allow delivery receipts even if portal has no other user ID --- bridgev2/portal.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 6d90a9ed..b72f00a6 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -3639,7 +3639,7 @@ func (portal *Portal) handleRemoteMarkUnread(ctx context.Context, source *UserLo } func (portal *Portal) handleRemoteDeliveryReceipt(ctx context.Context, source *UserLogin, evt RemoteDeliveryReceipt) EventHandlingResult { - if portal.RoomType != database.RoomTypeDM || evt.GetSender().Sender != portal.OtherUserID { + if portal.RoomType != database.RoomTypeDM || (evt.GetSender().Sender != portal.OtherUserID && portal.OtherUserID != "") { return EventHandlingResultIgnored } intent, ok := portal.GetIntentFor(ctx, evt.GetSender(), source, RemoteEventDeliveryReceipt) From fe541df21769ea189e4a4c42a1046bf147c663b5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 11 Feb 2026 21:34:47 +0200 Subject: [PATCH 1610/1647] main: bump minimum Go version to 1.25 --- .github/workflows/go.yml | 11 +++++------ go.mod | 6 +++--- go.sum | 4 ++-- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index deaa1f1d..c0add220 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -15,7 +15,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v6 with: - go-version: "1.25" + go-version: "1.26" cache: true - name: Install libolm @@ -35,8 +35,8 @@ jobs: strategy: fail-fast: false matrix: - go-version: ["1.24", "1.25"] - name: Build (${{ matrix.go-version == '1.25' && 'latest' || 'old' }}, libolm) + go-version: ["1.25", "1.26"] + name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, libolm) steps: - uses: actions/checkout@v6 @@ -62,7 +62,6 @@ jobs: run: go test -json -v ./... 2>&1 | gotestfmt - name: Test (jsonv2) - if: matrix.go-version == '1.25' env: GOEXPERIMENT: jsonv2 run: go test -json -v ./... 2>&1 | gotestfmt @@ -72,8 +71,8 @@ jobs: strategy: fail-fast: false matrix: - go-version: ["1.24", "1.25"] - name: Build (${{ matrix.go-version == '1.25' && 'latest' || 'old' }}, goolm) + go-version: ["1.25", "1.26"] + name: Build (${{ matrix.go-version == '1.26' && 'latest' || 'old' }}, goolm) steps: - uses: actions/checkout@v6 diff --git a/go.mod b/go.mod index 27acd21b..a76d1ec7 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module maunium.net/go/mautrix -go 1.24.0 +go 1.25.0 -toolchain go1.25.6 +toolchain go1.26.0 require ( filippo.io/edwards25519 v1.1.0 @@ -17,7 +17,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.16 - go.mau.fi/util v0.9.5 + go.mau.fi/util v0.9.6-0.20260211193350-78c2ff4a9df8 go.mau.fi/zeroconfig v0.2.0 golang.org/x/crypto v0.47.0 golang.org/x/exp v0.0.0-20260112195511-716be5621a96 diff --git a/go.sum b/go.sum index 9702337a..a142a727 100644 --- a/go.sum +++ b/go.sum @@ -52,8 +52,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE= github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.5 h1:7AoWPCIZJGv4jvtFEuCe3GhAbI7uF9ckIooaXvwlIR4= -go.mau.fi/util v0.9.5/go.mod h1:g1uvZ03VQhtTt2BgaRGVytS/Zj67NV0YNIECch0sQCQ= +go.mau.fi/util v0.9.6-0.20260211193350-78c2ff4a9df8 h1:7McVSdP7wEpb1omjyKG5OjxCY2NPP5Ba1pJujkOZx7g= +go.mau.fi/util v0.9.6-0.20260211193350-78c2ff4a9df8/go.mod h1:DzglKWpYOxKq4h9noyJBMoUu72/XgbP8j/OPehS/l/U= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= From 7dbc4dd16aaf2063a81499fbf23b996a9fd85545 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 12 Feb 2026 17:34:40 +0200 Subject: [PATCH 1611/1647] appservice: fix building websocket url --- appservice/websocket.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/appservice/websocket.go b/appservice/websocket.go index 4f2538bf..ef65e65a 100644 --- a/appservice/websocket.go +++ b/appservice/websocket.go @@ -14,7 +14,7 @@ import ( "io" "net/http" "net/url" - "path/filepath" + "path" "strings" "sync" "sync/atomic" @@ -374,7 +374,7 @@ func (as *AppService) StartWebsocket(ctx context.Context, baseURL string, onConn copiedURL := *as.hsURLForClient parsed = &copiedURL } - parsed.Path = filepath.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync") + parsed.Path = path.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync") if parsed.Scheme == "http" { parsed.Scheme = "ws" } else if parsed.Scheme == "https" { From b97f989032a25fca236508655193ab822019e252 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 14 Feb 2026 23:37:20 +0200 Subject: [PATCH 1612/1647] federation/eventauth: add support for underscores in string power levels --- federation/eventauth/eventauth.go | 11 +++- .../eventauth/eventauth_internal_test.go | 61 +++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 federation/eventauth/eventauth_internal_test.go diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go index eac110a3..d2073607 100644 --- a/federation/eventauth/eventauth.go +++ b/federation/eventauth/eventauth.go @@ -799,7 +799,7 @@ func parsePythonInt(val gjson.Result) *int { return ptr.Ptr(int(val.Int())) case gjson.String: // strconv.Atoi accepts signs as well as leading zeroes, so we just need to trim spaces beforehand - num, err := strconv.Atoi(strings.TrimSpace(val.Str)) + num, err := strconv.Atoi(removeUnderscores(strings.TrimSpace(val.Str))) if err != nil { return nil } @@ -810,6 +810,15 @@ func parsePythonInt(val gjson.Result) *int { } } +func removeUnderscores(num string) string { + numWithoutSign := strings.TrimPrefix(strings.TrimPrefix(num, "+"), "-") + if strings.HasPrefix(numWithoutSign, "_") || strings.HasSuffix(numWithoutSign, "_") { + // Leading or trailing underscores are not valid, let strconv.Atoi fail + return num + } + return strings.ReplaceAll(num, "_", "") +} + func safeParsePowerLevels(content jsontext.Value, into *event.PowerLevelsEventContent) { *into = event.PowerLevelsEventContent{ Users: make(map[id.UserID]int), diff --git a/federation/eventauth/eventauth_internal_test.go b/federation/eventauth/eventauth_internal_test.go new file mode 100644 index 00000000..e8f61b76 --- /dev/null +++ b/federation/eventauth/eventauth_internal_test.go @@ -0,0 +1,61 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//go:build goexperiment.jsonv2 + +package eventauth + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +type pythonIntTest struct { + Name string + Input string + Expected int + Invalid bool +} + +var pythonIntTests = []pythonIntTest{ + {"True", `true`, 1, false}, + {"False", `false`, 0, false}, + {"SmallFloat", `3.1415`, 3, false}, + {"SmallFloatRoundDown", `10.999999999999999`, 10, false}, + {"SmallFloatRoundUp", `10.9999999999999999`, 11, false}, + {"BigFloatRoundDown", `1000000.9999999999`, 1000000, false}, + {"BigFloatRoundUp", `1000000.99999999999`, 1000001, false}, + {"String", `"123"`, 123, false}, + {"FloatInString", `"123.456"`, 0, true}, + {"StringWithPlusSign", `"+123"`, 123, false}, + {"StringWithMinusSign", `"-123"`, -123, false}, + {"StringWithSpaces", `" 123 "`, 123, false}, + {"StringWithSpacesAndSign", `" -123 "`, -123, false}, + {"StringWithUnderscores", `"123_456"`, 123456, false}, + {"StringWithUnderscores", `"123_456"`, 123456, false}, + {"StringWithTrailingUnderscore", `"123_456_"`, 0, true}, + {"StringWithLeadingUnderscore", `"_123_456"`, 0, true}, + {"StringWithUnderscoreAfterSign", `"+_123_456"`, 0, true}, + {"StringWithUnderscoreAfterSpace", `" _123_456"`, 0, true}, + {"StringWithUnderscoresAndSpaces", `" +1_2_3_4_5_6 "`, 123456, false}, +} + +func TestParsePythonInt(t *testing.T) { + for _, test := range pythonIntTests { + t.Run(test.Name, func(t *testing.T) { + output := parsePythonInt(gjson.Parse(test.Input)) + if test.Invalid { + assert.Nil(t, output) + } else { + require.NotNil(t, output) + assert.Equal(t, test.Expected, *output) + } + }) + } +} From bafba9b22773131e21a5be19f540e00fe8afb4ac Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 14 Feb 2026 23:39:57 +0200 Subject: [PATCH 1613/1647] federation/eventauth: make expected success a part of test name --- .../eventauth/eventauth_internal_test.go | 52 ++++++++++--------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/federation/eventauth/eventauth_internal_test.go b/federation/eventauth/eventauth_internal_test.go index e8f61b76..9dd36a7b 100644 --- a/federation/eventauth/eventauth_internal_test.go +++ b/federation/eventauth/eventauth_internal_test.go @@ -9,6 +9,7 @@ package eventauth import ( + "strings" "testing" "github.com/stretchr/testify/assert" @@ -19,42 +20,45 @@ import ( type pythonIntTest struct { Name string Input string - Expected int - Invalid bool + Expected int64 } var pythonIntTests = []pythonIntTest{ - {"True", `true`, 1, false}, - {"False", `false`, 0, false}, - {"SmallFloat", `3.1415`, 3, false}, - {"SmallFloatRoundDown", `10.999999999999999`, 10, false}, - {"SmallFloatRoundUp", `10.9999999999999999`, 11, false}, - {"BigFloatRoundDown", `1000000.9999999999`, 1000000, false}, - {"BigFloatRoundUp", `1000000.99999999999`, 1000001, false}, - {"String", `"123"`, 123, false}, - {"FloatInString", `"123.456"`, 0, true}, - {"StringWithPlusSign", `"+123"`, 123, false}, - {"StringWithMinusSign", `"-123"`, -123, false}, - {"StringWithSpaces", `" 123 "`, 123, false}, - {"StringWithSpacesAndSign", `" -123 "`, -123, false}, - {"StringWithUnderscores", `"123_456"`, 123456, false}, - {"StringWithUnderscores", `"123_456"`, 123456, false}, - {"StringWithTrailingUnderscore", `"123_456_"`, 0, true}, - {"StringWithLeadingUnderscore", `"_123_456"`, 0, true}, - {"StringWithUnderscoreAfterSign", `"+_123_456"`, 0, true}, - {"StringWithUnderscoreAfterSpace", `" _123_456"`, 0, true}, - {"StringWithUnderscoresAndSpaces", `" +1_2_3_4_5_6 "`, 123456, false}, + {"True", `true`, 1}, + {"False", `false`, 0}, + {"SmallFloat", `3.1415`, 3}, + {"SmallFloatRoundDown", `10.999999999999999`, 10}, + {"SmallFloatRoundUp", `10.9999999999999999`, 11}, + {"BigFloatRoundDown", `1000000.9999999999`, 1000000}, + {"BigFloatRoundUp", `1000000.99999999999`, 1000001}, + {"BigFloatPrecisionError", `9007199254740993.0`, 9007199254740992}, + {"BigFloatPrecisionError2", `9007199254740993.123`, 9007199254740994}, + {"Int64", `9223372036854775807`, 9223372036854775807}, + {"Int64String", `"9223372036854775807"`, 9223372036854775807}, + {"String", `"123"`, 123}, + {"InvalidFloatInString", `"123.456"`, 0}, + {"StringWithPlusSign", `"+123"`, 123}, + {"StringWithMinusSign", `"-123"`, -123}, + {"StringWithSpaces", `" 123 "`, 123}, + {"StringWithSpacesAndSign", `" -123 "`, -123}, + {"StringWithUnderscores", `"123_456"`, 123456}, + {"StringWithUnderscores", `"123_456"`, 123456}, + {"InvalidStringWithTrailingUnderscore", `"123_456_"`, 0}, + {"InvalidStringWithLeadingUnderscore", `"_123_456"`, 0}, + {"InvalidStringWithUnderscoreAfterSign", `"+_123_456"`, 0}, + {"InvalidStringWithUnderscoreAfterSpace", `" _123_456"`, 0}, + {"StringWithUnderscoresAndSpaces", `" +1_2_3_4_5_6 "`, 123456}, } func TestParsePythonInt(t *testing.T) { for _, test := range pythonIntTests { t.Run(test.Name, func(t *testing.T) { output := parsePythonInt(gjson.Parse(test.Input)) - if test.Invalid { + if strings.HasPrefix(test.Name, "Invalid") { assert.Nil(t, output) } else { require.NotNil(t, output) - assert.Equal(t, test.Expected, *output) + assert.Equal(t, int(test.Expected), *output) } }) } From c52d87b6ea999e55b5103e3d8e6629691dab16c8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 15 Feb 2026 21:47:10 +0200 Subject: [PATCH 1614/1647] mediaproxy: handle federation thumbnail requests --- mediaproxy/mediaproxy.go | 1 + 1 file changed, 1 insertion(+) diff --git a/mediaproxy/mediaproxy.go b/mediaproxy/mediaproxy.go index 2063675a..4d2bc7cf 100644 --- a/mediaproxy/mediaproxy.go +++ b/mediaproxy/mediaproxy.go @@ -143,6 +143,7 @@ func New(serverName string, serverKey string, getMedia GetMediaFunc) (*MediaProx } mp.FederationRouter = http.NewServeMux() mp.FederationRouter.HandleFunc("GET /v1/media/download/{mediaID}", mp.DownloadMediaFederation) + mp.FederationRouter.HandleFunc("GET /v1/media/thumbnail/{mediaID}", mp.DownloadMediaFederation) mp.FederationRouter.HandleFunc("GET /v1/version", mp.KeyServer.GetServerVersion) mp.ClientMediaRouter = http.NewServeMux() mp.ClientMediaRouter.HandleFunc("GET /download/{serverName}/{mediaID}", mp.DownloadMedia) From 53ed8526c6a9f0af9c1d72a239827e0876fa6f34 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 16 Feb 2026 14:29:09 +0200 Subject: [PATCH 1615/1647] federation/eventauth: disable underscore support in string power levels --- federation/eventauth/eventauth.go | 11 +---------- federation/eventauth/eventauth_internal_test.go | 7 ++++--- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go index d2073607..eac110a3 100644 --- a/federation/eventauth/eventauth.go +++ b/federation/eventauth/eventauth.go @@ -799,7 +799,7 @@ func parsePythonInt(val gjson.Result) *int { return ptr.Ptr(int(val.Int())) case gjson.String: // strconv.Atoi accepts signs as well as leading zeroes, so we just need to trim spaces beforehand - num, err := strconv.Atoi(removeUnderscores(strings.TrimSpace(val.Str))) + num, err := strconv.Atoi(strings.TrimSpace(val.Str)) if err != nil { return nil } @@ -810,15 +810,6 @@ func parsePythonInt(val gjson.Result) *int { } } -func removeUnderscores(num string) string { - numWithoutSign := strings.TrimPrefix(strings.TrimPrefix(num, "+"), "-") - if strings.HasPrefix(numWithoutSign, "_") || strings.HasSuffix(numWithoutSign, "_") { - // Leading or trailing underscores are not valid, let strconv.Atoi fail - return num - } - return strings.ReplaceAll(num, "_", "") -} - func safeParsePowerLevels(content jsontext.Value, into *event.PowerLevelsEventContent) { *into = event.PowerLevelsEventContent{ Users: make(map[id.UserID]int), diff --git a/federation/eventauth/eventauth_internal_test.go b/federation/eventauth/eventauth_internal_test.go index 9dd36a7b..d316f3c8 100644 --- a/federation/eventauth/eventauth_internal_test.go +++ b/federation/eventauth/eventauth_internal_test.go @@ -41,13 +41,14 @@ var pythonIntTests = []pythonIntTest{ {"StringWithMinusSign", `"-123"`, -123}, {"StringWithSpaces", `" 123 "`, 123}, {"StringWithSpacesAndSign", `" -123 "`, -123}, - {"StringWithUnderscores", `"123_456"`, 123456}, - {"StringWithUnderscores", `"123_456"`, 123456}, + //{"StringWithUnderscores", `"123_456"`, 123456}, + //{"StringWithUnderscores", `"123_456"`, 123456}, {"InvalidStringWithTrailingUnderscore", `"123_456_"`, 0}, + {"InvalidStringWithMultipleUnderscores", `"123__456"`, 0}, {"InvalidStringWithLeadingUnderscore", `"_123_456"`, 0}, {"InvalidStringWithUnderscoreAfterSign", `"+_123_456"`, 0}, {"InvalidStringWithUnderscoreAfterSpace", `" _123_456"`, 0}, - {"StringWithUnderscoresAndSpaces", `" +1_2_3_4_5_6 "`, 123456}, + //{"StringWithUnderscoresAndSpaces", `" +1_2_3_4_5_6 "`, 123456}, } func TestParsePythonInt(t *testing.T) { From 0b9471e1904d92ff3055384debd5b1287f3662cb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 16 Feb 2026 14:31:01 +0200 Subject: [PATCH 1616/1647] dependencies: update --- go.mod | 16 ++++++++-------- go.sum | 32 ++++++++++++++++---------------- 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/go.mod b/go.mod index a76d1ec7..647473cc 100644 --- a/go.mod +++ b/go.mod @@ -8,8 +8,8 @@ require ( filippo.io/edwards25519 v1.1.0 github.com/chzyer/readline v1.5.1 github.com/coder/websocket v1.8.14 - github.com/lib/pq v1.10.9 - github.com/mattn/go-sqlite3 v1.14.33 + github.com/lib/pq v1.11.2 + github.com/mattn/go-sqlite3 v1.14.34 github.com/rs/xid v1.6.0 github.com/rs/zerolog v1.34.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e @@ -17,11 +17,11 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/yuin/goldmark v1.7.16 - go.mau.fi/util v0.9.6-0.20260211193350-78c2ff4a9df8 + go.mau.fi/util v0.9.6 go.mau.fi/zeroconfig v0.2.0 - golang.org/x/crypto v0.47.0 - golang.org/x/exp v0.0.0-20260112195511-716be5621a96 - golang.org/x/net v0.49.0 + golang.org/x/crypto v0.48.0 + golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a + golang.org/x/net v0.50.0 golang.org/x/sync v0.19.0 gopkg.in/yaml.v3 v3.0.1 maunium.net/go/mauflag v1.0.0 @@ -36,7 +36,7 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect - golang.org/x/sys v0.40.0 // indirect - golang.org/x/text v0.33.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/text v0.34.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect ) diff --git a/go.sum b/go.sum index a142a727..dafa8c67 100644 --- a/go.sum +++ b/go.sum @@ -16,8 +16,8 @@ github.com/coreos/go-systemd/v22 v22.6.0/go.mod h1:iG+pp635Fo7ZmV/j14KUcmEyWF+0X github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= -github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.11.2 h1:x6gxUeu39V0BHZiugWe8LXZYZ+Utk7hSJGThs8sdzfs= +github.com/lib/pq v1.11.2/go.mod h1:/p+8NSbOcwzAEI7wiMXFlgydTwcgTr3OSKMsD2BitpA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= @@ -25,8 +25,8 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= -github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk= +github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741 h1:KPpdlQLZcHfTMQRi6bFQ7ogNO0ltFT4PmtwTLW4W+14= github.com/petermattis/goid v0.0.0-20260113132338-7c7de50cc741/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -52,26 +52,26 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.16 h1:n+CJdUxaFMiDUNnWC3dMWCIQJSkxH4uz3ZwQBkAlVNE= github.com/yuin/goldmark v1.7.16/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -go.mau.fi/util v0.9.6-0.20260211193350-78c2ff4a9df8 h1:7McVSdP7wEpb1omjyKG5OjxCY2NPP5Ba1pJujkOZx7g= -go.mau.fi/util v0.9.6-0.20260211193350-78c2ff4a9df8/go.mod h1:DzglKWpYOxKq4h9noyJBMoUu72/XgbP8j/OPehS/l/U= +go.mau.fi/util v0.9.6 h1:2nsvxm49KhI3wrFltr0+wSUBlnQ4CMtykuELjpIU+ts= +go.mau.fi/util v0.9.6/go.mod h1:sIJpRH7Iy5Ad1SBuxQoatxtIeErgzxCtjd/2hCMkYMI= go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= -golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= -golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= -golang.org/x/exp v0.0.0-20260112195511-716be5621a96 h1:Z/6YuSHTLOHfNFdb8zVZomZr7cqNgTJvA8+Qz75D8gU= -golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHiYkrJyT+2uy9YZJB7H1k68CXZU= -golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= -golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a h1:ovFr6Z0MNmU7nH8VaX5xqw+05ST2uO1exVfZPVqRC5o= +golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA= +golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= +golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= -golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= From 9cd7258764e3b17649887cc05c73b2ef90447650 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 16 Feb 2026 14:33:21 +0200 Subject: [PATCH 1617/1647] Bump version to v0.26.3 --- CHANGELOG.md | 21 +++++++++++++++++++++ version.go | 2 +- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dbc7c494..f2829199 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,24 @@ +## v0.26.3 (2026-02-16) + +* Bumped minimum Go version to 1.25. +* *(client)* Added fields for sending [MSC4354] sticky events. +* *(bridgev2)* Added automatic message request accepting when sending message. +* *(mediaproxy)* Added support for federation thumbnail endpoint. +* *(crypto/ssss)* Improved support for recovery keys with slightly broken + metadata. +* *(crypto)* Changed key import to call session received callback even for + sessions that already exist in the database. +* *(appservice)* Fixed building websocket URL accidentally using file path + separators instead of always `/`. +* *(crypto)* Fixed key exports not including the `sender_claimed_keys` field. +* *(client)* Fixed incorrect context usage in async uploads. +* *(crypto)* Fixed panic when passing invalid input to megolm message index + parser used for debugging. +* *(bridgev2/provisioning)* Fixed completed or failed logins not being cleaned + up properly. + +[MSC4354]: https://github.com/matrix-org/matrix-spec-proposals/pull/4354 + ## v0.26.2 (2026-01-16) * *(bridgev2)* Added chunked portal deletion to avoid database locks when diff --git a/version.go b/version.go index b0e31c7e..f00bbf39 100644 --- a/version.go +++ b/version.go @@ -8,7 +8,7 @@ import ( "strings" ) -const Version = "v0.26.2" +const Version = "v0.26.3" var GoModVersion = "" var Commit = "" From de0d12e26a7e548a1013c7e3ddd5e6c42b7feba8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 18 Feb 2026 12:41:16 +0200 Subject: [PATCH 1618/1647] goolm/crypto: add test to ensure shared secrets can't be zero --- crypto/goolm/crypto/curve25519.go | 1 + crypto/goolm/crypto/curve25519_test.go | 2 ++ 2 files changed, 3 insertions(+) diff --git a/crypto/goolm/crypto/curve25519.go b/crypto/goolm/crypto/curve25519.go index e9759501..6e42d886 100644 --- a/crypto/goolm/crypto/curve25519.go +++ b/crypto/goolm/crypto/curve25519.go @@ -53,6 +53,7 @@ func (c Curve25519KeyPair) B64Encoded() id.Curve25519 { // SharedSecret returns the shared secret between the key pair and the given public key. func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) { + // Note: the standard library checks that the output is non-zero return c.PrivateKey.SharedSecret(pubKey) } diff --git a/crypto/goolm/crypto/curve25519_test.go b/crypto/goolm/crypto/curve25519_test.go index 9039c126..2550f15e 100644 --- a/crypto/goolm/crypto/curve25519_test.go +++ b/crypto/goolm/crypto/curve25519_test.go @@ -25,6 +25,8 @@ func TestCurve25519(t *testing.T) { fromPrivate, err := crypto.Curve25519GenerateFromPrivate(firstKeypair.PrivateKey) assert.NoError(t, err) assert.Equal(t, fromPrivate, firstKeypair) + _, err = secondKeypair.SharedSecret(make([]byte, crypto.Curve25519PublicKeyLength)) + assert.Error(t, err) } func TestCurve25519Case1(t *testing.T) { From ae58161412b86a684d8c581d7323910211a72aea Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 19 Feb 2026 14:09:59 +0200 Subject: [PATCH 1619/1647] bridgev2/provisioning: log group create params --- bridgev2/provisionutil/creategroup.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/bridgev2/provisionutil/creategroup.go b/bridgev2/provisionutil/creategroup.go index fbe0a513..72bacaff 100644 --- a/bridgev2/provisionutil/creategroup.go +++ b/bridgev2/provisionutil/creategroup.go @@ -32,6 +32,9 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev if !ok { return nil, bridgev2.RespError(mautrix.MUnrecognized.WithMessage("This bridge does not support creating groups")) } + zerolog.Ctx(ctx).Debug(). + Any("create_params", params). + Msg("Creating group chat on remote network") caps := login.Bridge.Network.GetCapabilities() typeSpec, validType := caps.Provisioning.GroupCreation[params.Type] if !validType { @@ -98,6 +101,9 @@ func CreateGroup(ctx context.Context, login *bridgev2.UserLogin, params *bridgev if resp.PortalKey.IsEmpty() { return nil, ErrNoPortalKey } + zerolog.Ctx(ctx).Debug(). + Object("portal_key", resp.PortalKey). + Msg("Successfully created group on remote network") if resp.Portal == nil { resp.Portal, err = login.Bridge.GetPortalByKey(ctx, resp.PortalKey) if err != nil { From 974f7dc5446f25090b5cf35f53579a5bdd437d58 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 19 Feb 2026 14:10:20 +0200 Subject: [PATCH 1620/1647] crypto/decryptmegolm: allow device key mismatches, but mark as untrusted --- crypto/decryptmegolm.go | 8 +++++++- id/trust.go | 7 ++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 77a64b1e..9753eabd 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -124,7 +124,13 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event Msg("Couldn't resolve trust level of session: sent by unknown device") trustLevel = id.TrustStateUnknownDevice } else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey { - return nil, ErrDeviceKeyMismatch + log.Debug(). + Stringer("session_sender_key", sess.SenderKey). + Stringer("device_sender_key", device.IdentityKey). + Stringer("session_signing_key", sess.SigningKey). + Stringer("device_signing_key", device.SigningKey). + Msg("Device keys don't match keys in session, marking as untrusted") + trustLevel = id.TrustStateDeviceKeyMismatch } else { trustLevel, err = mach.ResolveTrustContext(ctx, device) if err != nil { diff --git a/id/trust.go b/id/trust.go index 04f6e36b..6255093e 100644 --- a/id/trust.go +++ b/id/trust.go @@ -16,6 +16,7 @@ type TrustState int const ( TrustStateBlacklisted TrustState = -100 + TrustStateDeviceKeyMismatch TrustState = -5 TrustStateUnset TrustState = 0 TrustStateUnknownDevice TrustState = 10 TrustStateForwarded TrustState = 20 @@ -23,7 +24,7 @@ const ( TrustStateCrossSignedTOFU TrustState = 100 TrustStateCrossSignedVerified TrustState = 200 TrustStateVerified TrustState = 300 - TrustStateInvalid TrustState = (1 << 31) - 1 + TrustStateInvalid TrustState = -2147483647 ) func (ts *TrustState) UnmarshalText(data []byte) error { @@ -44,6 +45,8 @@ func ParseTrustState(val string) TrustState { switch strings.ToLower(val) { case "blacklisted": return TrustStateBlacklisted + case "device-key-mismatch": + return TrustStateDeviceKeyMismatch case "unverified": return TrustStateUnset case "cross-signed-untrusted": @@ -67,6 +70,8 @@ func (ts TrustState) String() string { switch ts { case TrustStateBlacklisted: return "blacklisted" + case TrustStateDeviceKeyMismatch: + return "device-key-mismatch" case TrustStateUnset: return "unverified" case TrustStateCrossSignedUntrusted: From 67d30e054ccd982cfae117653fb90cb2d60c612f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 19 Feb 2026 22:51:31 +0200 Subject: [PATCH 1621/1647] dependencies: update --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 647473cc..49a1d4e4 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.25.0 toolchain go1.26.0 require ( - filippo.io/edwards25519 v1.1.0 + filippo.io/edwards25519 v1.2.0 github.com/chzyer/readline v1.5.1 github.com/coder/websocket v1.8.14 github.com/lib/pq v1.11.2 @@ -20,7 +20,7 @@ require ( go.mau.fi/util v0.9.6 go.mau.fi/zeroconfig v0.2.0 golang.org/x/crypto v0.48.0 - golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a + golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa golang.org/x/net v0.50.0 golang.org/x/sync v0.19.0 gopkg.in/yaml.v3 v3.0.1 diff --git a/go.sum b/go.sum index dafa8c67..871a5156 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo= +filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= @@ -58,8 +58,8 @@ go.mau.fi/zeroconfig v0.2.0 h1:e/OGEERqVRRKlgaro7E6bh8xXiKFSXB3eNNIud7FUjU= go.mau.fi/zeroconfig v0.2.0/go.mod h1:J0Vn0prHNOm493oZoQ84kq83ZaNCYZnq+noI1b1eN8w= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= -golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a h1:ovFr6Z0MNmU7nH8VaX5xqw+05ST2uO1exVfZPVqRC5o= -golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA= +golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa h1:Zt3DZoOFFYkKhDT3v7Lm9FDMEV06GpzjG2jrqW+QTE0= +golang.org/x/exp v0.0.0-20260218203240-3dfff04db8fa/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA= golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= From bc79822eab1546980a56681e4ad07f0ed69941ce Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 21 Feb 2026 00:51:44 +0200 Subject: [PATCH 1622/1647] crypto: save source of megolm sessions --- crypto/keybackup.go | 3 +- crypto/keyimport.go | 13 ++++----- crypto/keysharing.go | 1 + crypto/sessions.go | 2 ++ crypto/sql_store.go | 25 ++++++++++------- .../sql_store_upgrade/00-latest-revision.sql | 3 +- .../19-megolm-session-source.sql | 2 ++ id/crypto.go | 28 +++++++++++++++++++ 8 files changed, 58 insertions(+), 19 deletions(-) create mode 100644 crypto/sql_store_upgrade/19-megolm-session-source.sql diff --git a/crypto/keybackup.go b/crypto/keybackup.go index ceec1d58..7b3c30db 100644 --- a/crypto/keybackup.go +++ b/crypto/keybackup.go @@ -200,13 +200,14 @@ func (mach *OlmMachine) ImportRoomKeyFromBackupWithoutSaving( SigningKey: keyBackupData.SenderClaimedKeys.Ed25519, SenderKey: keyBackupData.SenderKey, RoomID: roomID, - ForwardingChains: append(keyBackupData.ForwardingKeyChain, keyBackupData.SenderKey.String()), + ForwardingChains: keyBackupData.ForwardingKeyChain, id: sessionID, ReceivedAt: time.Now().UTC(), MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, KeyBackupVersion: version, + KeySource: id.KeySourceBackup, }, nil } diff --git a/crypto/keyimport.go b/crypto/keyimport.go index aef3eca2..3ffc74a5 100644 --- a/crypto/keyimport.go +++ b/crypto/keyimport.go @@ -108,14 +108,13 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor return false, ErrMismatchingExportedSessionID } igs := &InboundGroupSession{ - Internal: igsInternal, - SigningKey: session.SenderClaimedKeys.Ed25519, - SenderKey: session.SenderKey, - RoomID: session.RoomID, - // TODO should we add something here to mark the signing key as unverified like key requests do? + Internal: igsInternal, + SigningKey: session.SenderClaimedKeys.Ed25519, + SenderKey: session.SenderKey, + RoomID: session.RoomID, ForwardingChains: session.ForwardingChains, - - ReceivedAt: time.Now().UTC(), + KeySource: id.KeySourceImport, + ReceivedAt: time.Now().UTC(), } existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) firstKnownIndex := igs.Internal.FirstKnownIndex() diff --git a/crypto/keysharing.go b/crypto/keysharing.go index c1f7171c..19a68c87 100644 --- a/crypto/keysharing.go +++ b/crypto/keysharing.go @@ -189,6 +189,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, IsScheduled: content.IsScheduled, + KeySource: id.KeySourceForward, } existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.ID()) if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() { diff --git a/crypto/sessions.go b/crypto/sessions.go index d7e68eb1..ccc7b784 100644 --- a/crypto/sessions.go +++ b/crypto/sessions.go @@ -117,6 +117,7 @@ type InboundGroupSession struct { MaxMessages int IsScheduled bool KeyBackupVersion id.KeyBackupVersion + KeySource id.KeySource id id.SessionID } @@ -136,6 +137,7 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI MaxAge: maxAge.Milliseconds(), MaxMessages: maxMessages, IsScheduled: isScheduled, + KeySource: id.KeySourceDirect, }, nil } diff --git a/crypto/sql_store.go b/crypto/sql_store.go index ca75b3f6..138cc557 100644 --- a/crypto/sql_store.go +++ b/crypto/sql_store.go @@ -346,22 +346,23 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, session *Inbou Int("max_messages", session.MaxMessages). Bool("is_scheduled", session.IsScheduled). Stringer("key_backup_version", session.KeyBackupVersion). + Stringer("key_source", session.KeySource). Msg("Upserting megolm inbound group session") _, err = store.DB.Exec(ctx, ` INSERT INTO crypto_megolm_inbound_session ( session_id, sender_key, signing_key, room_id, session, forwarding_chains, - ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, account_id - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source, account_id + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) ON CONFLICT (session_id, account_id) DO UPDATE SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains, ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at, max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled, - key_backup_version=excluded.key_backup_version + key_backup_version=excluded.key_backup_version, key_source=excluded.key_source `, session.ID(), session.SenderKey, session.SigningKey, session.RoomID, sessionBytes, forwardingChains, ratchetSafety, datePtr(session.ReceivedAt), dbutil.NumPtr(session.MaxAge), dbutil.NumPtr(session.MaxMessages), - session.IsScheduled, session.KeyBackupVersion, store.AccountID, + session.IsScheduled, session.KeyBackupVersion, session.KeySource, store.AccountID, ) return err } @@ -374,12 +375,13 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room var maxAge, maxMessages sql.NullInt64 var isScheduled bool var version id.KeyBackupVersion + var keySource id.KeySource err := store.DB.QueryRow(ctx, ` - SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source FROM crypto_megolm_inbound_session WHERE room_id=$1 AND session_id=$2 AND account_id=$3`, roomID, sessionID, store.AccountID, - ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) + ).Scan(&senderKey, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource) if errors.Is(err, sql.ErrNoRows) { return nil, nil } else if err != nil { @@ -410,6 +412,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room MaxMessages: int(maxMessages.Int64), IsScheduled: isScheduled, KeyBackupVersion: version, + KeySource: keySource, }, nil } @@ -534,7 +537,8 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In var maxAge, maxMessages sql.NullInt64 var isScheduled bool var version id.KeyBackupVersion - err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version) + var keySource id.KeySource + err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version, &keySource) if err != nil { return nil, err } @@ -554,12 +558,13 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In MaxMessages: int(maxMessages.Int64), IsScheduled: isScheduled, KeyBackupVersion: version, + KeySource: keySource, }, nil } func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`, roomID, store.AccountID, ) @@ -568,7 +573,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL`, store.AccountID, ) @@ -577,7 +582,7 @@ func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.Row func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] { rows, err := store.DB.Query(ctx, ` - SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version + SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, key_source FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL AND key_backup_version != $2`, store.AccountID, version, ) diff --git a/crypto/sql_store_upgrade/00-latest-revision.sql b/crypto/sql_store_upgrade/00-latest-revision.sql index af8ab5cc..3709f1e5 100644 --- a/crypto/sql_store_upgrade/00-latest-revision.sql +++ b/crypto/sql_store_upgrade/00-latest-revision.sql @@ -1,4 +1,4 @@ --- v0 -> v18 (compatible with v15+): Latest revision +-- v0 -> v19 (compatible with v15+): Latest revision CREATE TABLE IF NOT EXISTS crypto_account ( account_id TEXT PRIMARY KEY, device_id TEXT NOT NULL, @@ -71,6 +71,7 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( max_messages INTEGER, is_scheduled BOOLEAN NOT NULL DEFAULT false, key_backup_version TEXT NOT NULL DEFAULT '', + key_source TEXT NOT NULL DEFAULT '', PRIMARY KEY (account_id, session_id) ); -- Useful index to find keys that need backing up diff --git a/crypto/sql_store_upgrade/19-megolm-session-source.sql b/crypto/sql_store_upgrade/19-megolm-session-source.sql new file mode 100644 index 00000000..f624222f --- /dev/null +++ b/crypto/sql_store_upgrade/19-megolm-session-source.sql @@ -0,0 +1,2 @@ +-- v19 (compatible with v15+): Store megolm session source +ALTER TABLE crypto_megolm_inbound_session ADD COLUMN key_source TEXT NOT NULL DEFAULT ''; diff --git a/id/crypto.go b/id/crypto.go index 355a84a8..ee857f78 100644 --- a/id/crypto.go +++ b/id/crypto.go @@ -53,6 +53,34 @@ const ( KeyBackupAlgorithmMegolmBackupV1 KeyBackupAlgorithm = "m.megolm_backup.v1.curve25519-aes-sha2" ) +type KeySource string + +func (source KeySource) String() string { + return string(source) +} + +func (source KeySource) Int() int { + switch source { + case KeySourceDirect: + return 100 + case KeySourceBackup: + return 90 + case KeySourceImport: + return 80 + case KeySourceForward: + return 50 + default: + return 0 + } +} + +const ( + KeySourceDirect KeySource = "direct" + KeySourceBackup KeySource = "backup" + KeySourceImport KeySource = "import" + KeySourceForward KeySource = "forward" +) + // BackupVersion is an arbitrary string that identifies a server side key backup. type KeyBackupVersion string From 5779871f1b22e48433b37d68b9761d149422b590 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 21 Feb 2026 14:09:20 +0200 Subject: [PATCH 1623/1647] bridgev2/commands: add file info for QR codes --- bridgev2/commands/login.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index 80a7c733..c35b3952 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -251,14 +251,19 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error { return fmt.Errorf("failed to upload image: %w", err) } content := &event.MessageEventContent{ - MsgType: event.MsgImage, - FileName: "qr.png", - URL: qrMXC, - File: qrFile, - + MsgType: event.MsgImage, + FileName: "qr.png", + URL: qrMXC, + File: qrFile, Body: qr, Format: event.FormatHTML, FormattedBody: fmt.Sprintf("

      %s
      ", html.EscapeString(qr)), + Info: &event.FileInfo{ + MimeType: "image/png", + Width: qrSizePx, + Height: qrSizePx, + Size: len(qrData), + }, } if *prevEventID != "" { content.SetEdit(*prevEventID) From 28b7bf7e567ed5bf7e80ae3b0e0abbe6042566aa Mon Sep 17 00:00:00 2001 From: timedout Date: Sun, 22 Feb 2026 19:37:19 +0000 Subject: [PATCH 1624/1647] federation/eventauth: Fix inverted membership check for 5.6.1 (#464) --- federation/eventauth/eventauth.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federation/eventauth/eventauth.go b/federation/eventauth/eventauth.go index eac110a3..c72933c2 100644 --- a/federation/eventauth/eventauth.go +++ b/federation/eventauth/eventauth.go @@ -505,7 +505,7 @@ func authorizeMember(roomVersion id.RoomVersion, evt, createEvt *pdu.PDU, authEv // 5.5.5. Otherwise, reject. return ErrInsufficientPermissionForKick case event.MembershipBan: - if senderMembership != event.MembershipLeave { + if senderMembership != event.MembershipJoin { // 5.6.1. If the sender’s current membership state is not join, reject. return ErrCantBanWithoutBeingInRoom } From 3efa3ef73a8230cf5b63a84d9184c04cfa7412d0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 23 Feb 2026 22:13:57 +0200 Subject: [PATCH 1625/1647] bridgev2/portal: log remote event timestamps by default --- bridgev2/portal.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index b72f00a6..718a5cb2 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -484,6 +484,11 @@ func (portal *Portal) getEventCtxWithLog(rawEvt any, idx int) context.Context { logWith = logWith.Int64("remote_stream_order", remoteStreamOrder) } } + if remoteMsg, ok := evt.evt.(RemoteEventWithTimestamp); ok { + if remoteTimestamp := remoteMsg.GetTimestamp(); !remoteTimestamp.IsZero() { + logWith = logWith.Time("remote_timestamp", remoteTimestamp) + } + } case *portalCreateEvent: return evt.ctx } From 7f24c7800222741910f359ff713333a518de3d50 Mon Sep 17 00:00:00 2001 From: Radon Rosborough Date: Wed, 25 Feb 2026 08:52:29 -0800 Subject: [PATCH 1626/1647] bridgev2/login: add attachments option to user input step type (#465) --- bridgev2/commands/login.go | 34 ++++++++++++++++++++++++++++++ bridgev2/login.go | 19 +++++++++++++++++ bridgev2/matrix/provisioning.yaml | 35 +++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+) diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index c35b3952..9e706995 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -278,6 +278,36 @@ func sendQR(ce *Event, qr string, prevEventID *id.EventID) error { return nil } +func sendUserInputAttachments(ce *Event, atts []*bridgev2.LoginUserInputAttachment) error { + for _, att := range atts { + if att.FileName == "" { + return fmt.Errorf("missing attachment filename") + } + mxc, file, err := ce.Bot.UploadMedia(ce.Ctx, ce.RoomID, att.Content, att.FileName, att.Info.MimeType) + if err != nil { + return fmt.Errorf("failed to upload attachment %q: %w", att.FileName, err) + } + content := &event.MessageEventContent{ + MsgType: att.Type, + FileName: att.FileName, + URL: mxc, + File: file, + Info: &event.FileInfo{ + MimeType: att.Info.MimeType, + Width: att.Info.Width, + Height: att.Info.Height, + Size: att.Info.Size, + }, + Body: att.FileName, + } + _, err = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: content}, nil) + if err != nil { + return nil + } + } + return nil +} + type contextKey int const ( @@ -483,6 +513,10 @@ func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginSte Override: override, }).prompt(ce) case bridgev2.LoginStepTypeUserInput: + err := sendUserInputAttachments(ce, step.UserInputParams.Attachments) + if err != nil { + ce.Reply("Failed to send attachments: %v", err) + } (&userInputLoginCommandState{ Login: login.(bridgev2.LoginProcessUserInput), RemainingFields: step.UserInputParams.Fields, diff --git a/bridgev2/login.go b/bridgev2/login.go index 4ddbf13e..b8321719 100644 --- a/bridgev2/login.go +++ b/bridgev2/login.go @@ -13,6 +13,7 @@ import ( "strings" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/event" ) // LoginProcess represents a single occurrence of a user logging into the remote network. @@ -179,6 +180,7 @@ const ( LoginInputFieldTypeURL LoginInputFieldType = "url" LoginInputFieldTypeDomain LoginInputFieldType = "domain" LoginInputFieldTypeSelect LoginInputFieldType = "select" + LoginInputFieldTypeCaptchaCode LoginInputFieldType = "captcha_code" ) type LoginInputDataField struct { @@ -271,6 +273,23 @@ func (f *LoginInputDataField) FillDefaultValidate() { type LoginUserInputParams struct { // The fields that the user needs to fill in. Fields []LoginInputDataField `json:"fields"` + + // Attachments to display alongside the input fields. + Attachments []*LoginUserInputAttachment `json:"attachments"` +} + +type LoginUserInputAttachment struct { + Type event.MessageType `json:"type,omitempty"` + FileName string `json:"filename,omitempty"` + Content []byte `json:"content,omitempty"` + Info LoginUserInputAttachmentInfo `json:"info,omitempty"` +} + +type LoginUserInputAttachmentInfo struct { + MimeType string `json:"mimetype,omitempty"` + Width int `json:"w,omitempty"` + Height int `json:"h,omitempty"` + Size int `json:"size,omitempty"` } type LoginCompleteParams struct { diff --git a/bridgev2/matrix/provisioning.yaml b/bridgev2/matrix/provisioning.yaml index d19a7e83..26068db4 100644 --- a/bridgev2/matrix/provisioning.yaml +++ b/bridgev2/matrix/provisioning.yaml @@ -740,6 +740,41 @@ components: description: For fields of type select, the valid options. items: type: string + attachments: + type: array + description: A list of media attachments to show the user alongside the form fields. + items: + type: object + description: A media attachment to show the user. + required: [ type, filename, content ] + properties: + type: + type: string + description: The type of media attachment, using the same media type identifiers as Matrix attachments. Only some are supported. + enum: [ m.image, m.audio ] + filename: + type: string + description: The filename for the media attachment. + content: + type: string + description: The raw file content for the attachment encoded in base64. + info: + type: object + description: Optional but recommended metadata for the attachment. Can generally be derived from the raw content if omitted. + properties: + mimetype: + type: string + description: The MIME type for the media content. + examples: [ image/png, audio/mpeg ] + w: + type: number + description: The width of the media in pixels. Only applicable for images and videos. + h: + type: number + description: The height of the media in pixels. Only applicable for images and videos. + size: + type: number + description: The size of the media content in number of bytes. Strongly recommended to include. - description: Cookie login step required: [ type, cookies ] properties: From 98c830181ba1953d78b45761cce39e281b1d7089 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 26 Feb 2026 17:20:31 +0200 Subject: [PATCH 1627/1647] client: omit large request bodies from logs --- client.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 2503556a..0a43816c 100644 --- a/client.go +++ b/client.go @@ -386,7 +386,14 @@ func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err er } } if body := req.Context().Value(LogBodyContextKey); body != nil { - evt.Interface("req_body", body) + switch typedLogBody := body.(type) { + case json.RawMessage: + evt.RawJSON("req_body", typedLogBody) + case string: + evt.Str("req_body", typedLogBody) + default: + panic(fmt.Errorf("invalid type for LogBodyContextKey: %T", body)) + } } if errors.Is(err, context.Canceled) { evt.Msg("Request canceled") @@ -450,8 +457,10 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e } if params.SensitiveContent && !logSensitiveContent { logBody = "" + } else if len(jsonStr) > 32768 { + logBody = fmt.Sprintf("", len(jsonStr)) } else { - logBody = params.RequestJSON + logBody = json.RawMessage(jsonStr) } reqBody = bytes.NewReader(jsonStr) reqLen = int64(len(jsonStr)) @@ -476,7 +485,7 @@ func (params *FullRequest) compileRequest(ctx context.Context) (*http.Request, e } } else if params.Method != http.MethodGet && params.Method != http.MethodHead { params.RequestJSON = struct{}{} - logBody = params.RequestJSON + logBody = json.RawMessage("{}") reqBody = bytes.NewReader([]byte("{}")) reqLen = 2 } From dd51c562abb36f8e325acefe8d9fd6a43644f0b0 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 26 Feb 2026 17:21:10 +0200 Subject: [PATCH 1628/1647] crypto: log destination map when sharing megolm sessions --- crypto/decryptolm.go | 3 +++ crypto/encryptmegolm.go | 15 +++++---------- crypto/encryptolm.go | 16 ++++++++++------ 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/crypto/decryptolm.go b/crypto/decryptolm.go index cd02726d..aea5e6dc 100644 --- a/crypto/decryptolm.go +++ b/crypto/decryptolm.go @@ -134,6 +134,9 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *e } func olmMessageHash(ciphertext string) ([32]byte, error) { + if ciphertext == "" { + return [32]byte{}, fmt.Errorf("empty ciphertext") + } ciphertextBytes, err := base64.RawStdEncoding.DecodeString(ciphertext) return sha256.Sum256(ciphertextBytes), err } diff --git a/crypto/encryptmegolm.go b/crypto/encryptmegolm.go index 806a227d..88f9c8d4 100644 --- a/crypto/encryptmegolm.go +++ b/crypto/encryptmegolm.go @@ -370,26 +370,19 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session log.Trace().Msg("Encrypting group session for all found devices") deviceCount := 0 toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)} + logUsers := zerolog.Dict() for userID, sessions := range olmSessions { if len(sessions) == 0 { continue } + logDevices := zerolog.Dict() output := make(map[id.DeviceID]*event.Content) toDevice.Messages[userID] = output for deviceID, device := range sessions { - log.Trace(). - Stringer("target_user_id", userID). - Stringer("target_device_id", deviceID). - Stringer("target_identity_key", device.identity.IdentityKey). - Msg("Encrypting group session for device") content := mach.encryptOlmEvent(ctx, device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent()) output[deviceID] = &event.Content{Parsed: content} + logDevices.Str(string(deviceID), string(device.identity.IdentityKey)) deviceCount++ - log.Debug(). - Stringer("target_user_id", userID). - Stringer("target_device_id", deviceID). - Stringer("target_identity_key", device.identity.IdentityKey). - Msg("Encrypted group session for device") if !mach.DisableSharedGroupSessionTracking { err := mach.CryptoStore.MarkOutboundGroupSessionShared(ctx, userID, device.identity.IdentityKey, session.id) if err != nil { @@ -403,11 +396,13 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session } } } + logUsers.Dict(string(userID), logDevices) } log.Debug(). Int("device_count", deviceCount). Int("user_count", len(toDevice.Messages)). + Dict("destination_map", logUsers). Msg("Sending to-device messages to share group session") _, err := mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, toDevice) return err diff --git a/crypto/encryptolm.go b/crypto/encryptolm.go index 80b76dc5..765307af 100644 --- a/crypto/encryptolm.go +++ b/crypto/encryptolm.go @@ -96,15 +96,19 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession panic(err) } log := mach.machOrContextLog(ctx) - log.Debug(). - Str("recipient_identity_key", recipient.IdentityKey.String()). - Str("olm_session_id", session.ID().String()). - Str("olm_session_description", session.Describe()). - Msg("Encrypting olm message") msgType, ciphertext, err := session.Encrypt(plaintext) if err != nil { panic(err) } + ciphertextStr := string(ciphertext) + ciphertextHash, _ := olmMessageHash(ciphertextStr) + log.Debug(). + Stringer("event_type", evtType). + Str("recipient_identity_key", recipient.IdentityKey.String()). + Str("olm_session_id", session.ID().String()). + Str("olm_session_description", session.Describe()). + Hex("ciphertext_hash", ciphertextHash[:]). + Msg("Encrypted olm message") err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session) if err != nil { log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting") @@ -115,7 +119,7 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession OlmCiphertext: event.OlmCiphertexts{ recipient.IdentityKey: { Type: msgType, - Body: string(ciphertext), + Body: ciphertextStr, }, }, } From 36c353abc7b40d8d9a951286ca7824bd3bfc6744 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 1 Mar 2026 12:37:13 +0200 Subject: [PATCH 1629/1647] federation/pdu: add AddSignature helper method --- federation/pdu/pdu.go | 13 +++++++++++++ federation/pdu/signature.go | 8 +------- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/federation/pdu/pdu.go b/federation/pdu/pdu.go index cecee5b9..17db6995 100644 --- a/federation/pdu/pdu.go +++ b/federation/pdu/pdu.go @@ -123,6 +123,19 @@ func (pdu *PDU) ToClientEvent(roomVersion id.RoomVersion) (*event.Event, error) return evt, nil } +func (pdu *PDU) AddSignature(serverName string, keyID id.KeyID, signature string) { + if signature == "" { + return + } + if pdu.Signatures == nil { + pdu.Signatures = make(map[string]map[id.KeyID]string) + } + if _, ok := pdu.Signatures[serverName]; !ok { + pdu.Signatures[serverName] = make(map[id.KeyID]string) + } + pdu.Signatures[serverName][keyID] = signature +} + func marshalCanonical(data any) (jsontext.Value, error) { marshaledBytes, err := json.Marshal(data) if err != nil { diff --git a/federation/pdu/signature.go b/federation/pdu/signature.go index a7685cc6..04e7c5ef 100644 --- a/federation/pdu/signature.go +++ b/federation/pdu/signature.go @@ -28,13 +28,7 @@ func (pdu *PDU) Sign(roomVersion id.RoomVersion, serverName string, keyID id.Key return fmt.Errorf("failed to marshal redacted PDU to sign: %w", err) } signature := ed25519.Sign(privateKey, rawJSON) - if pdu.Signatures == nil { - pdu.Signatures = make(map[string]map[id.KeyID]string) - } - if _, ok := pdu.Signatures[serverName]; !ok { - pdu.Signatures[serverName] = make(map[id.KeyID]string) - } - pdu.Signatures[serverName][keyID] = base64.RawStdEncoding.EncodeToString(signature) + pdu.AddSignature(serverName, keyID, base64.RawStdEncoding.EncodeToString(signature)) return nil } From f8234ecf8556f72cf4711cf23e3d51411027c910 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 1 Mar 2026 13:23:32 +0200 Subject: [PATCH 1630/1647] event: add m.room.policy event type --- event/content.go | 3 +++ event/state.go | 12 ++++++++++++ event/type.go | 5 ++++- 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/event/content.go b/event/content.go index d1ced268..4aa0593d 100644 --- a/event/content.go +++ b/event/content.go @@ -40,6 +40,9 @@ var TypeMap = map[Type]reflect.Type{ StateSpaceParent: reflect.TypeOf(SpaceParentEventContent{}), StateSpaceChild: reflect.TypeOf(SpaceChildEventContent{}), + StateRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}), + StateUnstableRoomPolicy: reflect.TypeOf(RoomPolicyEventContent{}), + StateLegacyPolicyRoom: reflect.TypeOf(ModPolicyContent{}), StateLegacyPolicyServer: reflect.TypeOf(ModPolicyContent{}), StateLegacyPolicyUser: reflect.TypeOf(ModPolicyContent{}), diff --git a/event/state.go b/event/state.go index 6d027e04..1df43351 100644 --- a/event/state.go +++ b/event/state.go @@ -343,3 +343,15 @@ func (efmc *ElementFunctionalMembersContent) Add(mxid id.UserID) bool { efmc.ServiceMembers = append(efmc.ServiceMembers, mxid) return true } + +type PolicyServerPublicKeys struct { + Ed25519 id.Ed25519 `json:"ed25519,omitempty"` +} + +type RoomPolicyEventContent struct { + Via string `json:"via,omitempty"` + PublicKeys *PolicyServerPublicKeys `json:"public_keys,omitempty"` + + // Deprecated, only for legacy use + PublicKey id.Ed25519 `json:"public_key"` +} diff --git a/event/type.go b/event/type.go index b193dc59..f337c127 100644 --- a/event/type.go +++ b/event/type.go @@ -113,7 +113,7 @@ func (et *Type) GuessClass() TypeClass { StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type, StateHalfShotBridge.Type, StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type, StatePolicyServer.Type, StatePolicyUser.Type, StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type, - StateMSC4391BotCommand.Type: + StateMSC4391BotCommand.Type, StateRoomPolicy.Type, StateUnstableRoomPolicy.Type: return StateEventType case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type: return EphemeralEventType @@ -195,6 +195,9 @@ var ( StateSpaceChild = Type{"m.space.child", StateEventType} StateSpaceParent = Type{"m.space.parent", StateEventType} + StateRoomPolicy = Type{"m.room.policy", StateEventType} + StateUnstableRoomPolicy = Type{"org.matrix.msc4284.policy", StateEventType} + StateLegacyPolicyRoom = Type{"m.room.rule.room", StateEventType} StateLegacyPolicyServer = Type{"m.room.rule.server", StateEventType} StateLegacyPolicyUser = Type{"m.room.rule.user", StateEventType} From 26a62a7eec2b30cb88baffe30596e3ba0d278f9d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 1 Mar 2026 13:49:04 +0200 Subject: [PATCH 1631/1647] event: add missing omitempty --- event/state.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/event/state.go b/event/state.go index 1df43351..ace170a5 100644 --- a/event/state.go +++ b/event/state.go @@ -353,5 +353,5 @@ type RoomPolicyEventContent struct { PublicKeys *PolicyServerPublicKeys `json:"public_keys,omitempty"` // Deprecated, only for legacy use - PublicKey id.Ed25519 `json:"public_key"` + PublicKey id.Ed25519 `json:"public_key,omitempty"` } From e1529f9616a95ea18506fb99b8e835c44631735d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 3 Mar 2026 17:28:19 +0200 Subject: [PATCH 1632/1647] bridgev2/provisioning: log when returning login steps in provisioning API --- bridgev2/matrix/provisioning.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 17e827e3..8989ad51 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -407,6 +407,10 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque } func (prov *ProvisioningAPI) handleCompleteStep(ctx context.Context, login *ProvLogin, step *bridgev2.LoginStep) { + zerolog.Ctx(ctx).Info(). + Str("step_id", step.StepID). + Str("user_login_id", string(step.CompleteParams.UserLoginID)). + Msg("Login completed successfully") prov.deleteLogin(login, false) if login.Override == nil || login.Override.ID == step.CompleteParams.UserLoginID { return @@ -506,6 +510,8 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http login.NextStep = nextStep if nextStep.Type == bridgev2.LoginStepTypeComplete { prov.handleCompleteStep(r.Context(), login, nextStep) + } else { + zerolog.Ctx(r.Context()).Debug().Str("step_id", nextStep.StepID).Msg("Returning next login step") } exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } @@ -525,6 +531,8 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques login.NextStep = nextStep if nextStep.Type == bridgev2.LoginStepTypeComplete { prov.handleCompleteStep(r.Context(), login, nextStep) + } else { + zerolog.Ctx(r.Context()).Debug().Str("step_id", nextStep.StepID).Msg("Returning next login step") } exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } From 77f0658365509428ce4c4784e1bf2d192b4a483b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 3 Mar 2026 17:33:51 +0200 Subject: [PATCH 1633/1647] bridgev2/{commands,provisioning}: log full login step data --- bridgev2/commands/login.go | 2 ++ bridgev2/matrix/provisioning.go | 7 +++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/bridgev2/commands/login.go b/bridgev2/commands/login.go index 9e706995..96d62d3e 100644 --- a/bridgev2/commands/login.go +++ b/bridgev2/commands/login.go @@ -121,6 +121,7 @@ func fnLogin(ce *Event) { ce.Reply("Failed to start login: %v", err) return } + ce.Log.Debug().Any("first_step", nextStep).Msg("Created login process") nextStep = checkLoginCommandDirectParams(ce, login, nextStep) if nextStep != nil { @@ -499,6 +500,7 @@ func maybeURLDecodeCookie(val string, field *bridgev2.LoginCookieField) string { } func doLoginStep(ce *Event, login bridgev2.LoginProcess, step *bridgev2.LoginStep, override *bridgev2.UserLogin) { + ce.Log.Debug().Any("next_step", step).Msg("Got next login step") if step.Instructions != "" { ce.Reply(step.Instructions) } diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 8989ad51..02a0dac9 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -403,6 +403,9 @@ func (prov *ProvisioningAPI) PostLoginStart(w http.ResponseWriter, r *http.Reque Override: overrideLogin, } prov.loginsLock.Unlock() + zerolog.Ctx(r.Context()).Info(). + Any("first_step", firstStep). + Msg("Created login process") exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: loginID, LoginStep: firstStep}) } @@ -511,7 +514,7 @@ func (prov *ProvisioningAPI) PostLoginSubmitInput(w http.ResponseWriter, r *http if nextStep.Type == bridgev2.LoginStepTypeComplete { prov.handleCompleteStep(r.Context(), login, nextStep) } else { - zerolog.Ctx(r.Context()).Debug().Str("step_id", nextStep.StepID).Msg("Returning next login step") + zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step") } exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } @@ -532,7 +535,7 @@ func (prov *ProvisioningAPI) PostLoginWait(w http.ResponseWriter, r *http.Reques if nextStep.Type == bridgev2.LoginStepTypeComplete { prov.handleCompleteStep(r.Context(), login, nextStep) } else { - zerolog.Ctx(r.Context()).Debug().Str("step_id", nextStep.StepID).Msg("Returning next login step") + zerolog.Ctx(r.Context()).Debug().Any("next_step", nextStep).Msg("Returning next login step") } exhttp.WriteJSONResponse(w, http.StatusOK, &RespSubmitLogin{LoginID: login.ID, LoginStep: nextStep}) } From fef4326fbce6a20eac52028fb18a9da2ffd28061 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?batuhan=20i=C3=A7=C3=B6z?= Date: Wed, 4 Mar 2026 01:38:50 +0100 Subject: [PATCH 1634/1647] client,event,bridgev2: add support for Beeper's custom ephemeral events and AI stream events (#457) --- appservice/intent.go | 11 ++ bridgev2/errors.go | 1 + bridgev2/matrix/connector.go | 2 + bridgev2/matrix/intent.go | 16 +++ bridgev2/matrix/matrix.go | 5 +- bridgev2/matrixinterface.go | 5 + bridgev2/networkinterface.go | 6 ++ bridgev2/portal.go | 46 ++++++++ client.go | 42 ++++++++ client_ephemeral_test.go | 158 ++++++++++++++++++++++++++++ crypto/decryptmegolm.go | 1 + event/beeper.go | 9 ++ event/content.go | 8 +- event/powerlevels.go | 38 +++++++ event/powerlevels_ephemeral_test.go | 67 ++++++++++++ event/type.go | 10 +- versions.go | 1 + 17 files changed, 418 insertions(+), 8 deletions(-) create mode 100644 client_ephemeral_test.go create mode 100644 event/powerlevels_ephemeral_test.go diff --git a/appservice/intent.go b/appservice/intent.go index e4d8e100..0ec10b77 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -222,6 +222,17 @@ func (intent *IntentAPI) SendMessageEvent(ctx context.Context, roomID id.RoomID, return intent.Client.SendMessageEvent(ctx, roomID, eventType, contentJSON, extra...) } +func (intent *IntentAPI) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...mautrix.ReqSendEvent) (*mautrix.RespSendEvent, error) { + if err := intent.EnsureJoined(ctx, roomID); err != nil { + return nil, err + } + if !intent.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) { + return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support") + } + contentJSON = intent.AddDoublePuppetValue(contentJSON) + return intent.Client.BeeperSendEphemeralEvent(ctx, roomID, eventType, contentJSON, extra...) +} + // Deprecated: use SendMessageEvent with mautrix.ReqSendEvent.Timestamp instead func (intent *IntentAPI) SendMassagedMessageEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) { return intent.SendMessageEvent(ctx, roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts}) diff --git a/bridgev2/errors.go b/bridgev2/errors.go index 514dc238..f6677d2e 100644 --- a/bridgev2/errors.go +++ b/bridgev2/errors.go @@ -75,6 +75,7 @@ var ( ErrMediaConvertFailed error = WrapErrorInStatus(errors.New("failed to convert media")).WithMessage("failed to convert media").WithIsCertain(true).WithSendNotice(true) ErrMembershipNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group membership")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) ErrDeleteChatNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support deleting chats")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) + ErrBeeperAIStreamNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support Beeper AI stream events")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) ErrPowerLevelsNotSupported error = WrapErrorInStatus(errors.New("this bridge does not support changing group power levels")).WithIsCertain(true).WithErrorAsMessage().WithSendNotice(false).WithErrorReason(event.MessageStatusUnsupported) ErrRemoteEchoTimeout = WrapErrorInStatus(errors.New("remote echo timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) ErrRemoteAckTimeout = WrapErrorInStatus(errors.New("remote ack timed out")).WithIsCertain(false).WithSendNotice(true).WithErrorReason(event.MessageStatusTooOld) diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index aed6d3bd..b6da16ac 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -144,6 +144,7 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.EventProcessor.On(event.EventReaction, br.handleRoomEvent) br.EventProcessor.On(event.EventRedaction, br.handleRoomEvent) br.EventProcessor.On(event.EventEncrypted, br.handleEncryptedEvent) + br.EventProcessor.On(event.EphemeralEventEncrypted, br.handleEncryptedEvent) br.EventProcessor.On(event.StateMember, br.handleRoomEvent) br.EventProcessor.On(event.StatePowerLevels, br.handleRoomEvent) br.EventProcessor.On(event.StateRoomName, br.handleRoomEvent) @@ -156,6 +157,7 @@ func (br *Connector) Init(bridge *bridgev2.Bridge) { br.EventProcessor.On(event.BeeperAcceptMessageRequest, br.handleRoomEvent) br.EventProcessor.On(event.EphemeralEventReceipt, br.handleEphemeralEvent) br.EventProcessor.On(event.EphemeralEventTyping, br.handleEphemeralEvent) + br.EventProcessor.On(event.BeeperEphemeralEventAIStream, br.handleEphemeralEvent) br.Bot = br.AS.BotIntent() br.Crypto = NewCryptoHelper(br) br.Bridge.Commands.(*commands.Processor).AddHandlers( diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 173f7c15..83318493 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -43,6 +43,7 @@ type ASIntent struct { var _ bridgev2.MatrixAPI = (*ASIntent)(nil) var _ bridgev2.MarkAsDMMatrixAPI = (*ASIntent)(nil) +var _ bridgev2.EphemeralSendingMatrixAPI = (*ASIntent)(nil) func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, extra *bridgev2.MatrixSendExtra) (*mautrix.RespSendEvent, error) { if extra == nil { @@ -84,6 +85,21 @@ func (as *ASIntent) SendMessage(ctx context.Context, roomID id.RoomID, eventType return as.Matrix.SendMessageEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{Timestamp: extra.Timestamp.UnixMilli()}) } +func (as *ASIntent) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error) { + if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureEphemeralEvents) { + return nil, mautrix.MUnrecognized.WithMessage("Homeserver does not advertise com.beeper.ephemeral support") + } + if encrypted, err := as.Matrix.StateStore.IsEncrypted(ctx, roomID); err != nil { + return nil, fmt.Errorf("failed to check if room is encrypted: %w", err) + } else if encrypted && as.Connector.Crypto != nil { + if err = as.Connector.Crypto.Encrypt(ctx, roomID, eventType, content); err != nil { + return nil, err + } + eventType = event.EventEncrypted + } + return as.Matrix.BeeperSendEphemeralEvent(ctx, roomID, eventType, content, mautrix.ReqSendEvent{TransactionID: txnID}) +} + func (as *ASIntent) fillMemberEvent(ctx context.Context, roomID id.RoomID, userID id.UserID, content *event.Content) { targetContent, ok := content.Parsed.(*event.MemberEventContent) if !ok || targetContent.Displayname != "" || targetContent.AvatarURL != "" { diff --git a/bridgev2/matrix/matrix.go b/bridgev2/matrix/matrix.go index 570ae5f1..954d0ad9 100644 --- a/bridgev2/matrix/matrix.go +++ b/bridgev2/matrix/matrix.go @@ -68,6 +68,10 @@ func (br *Connector) handleEphemeralEvent(ctx context.Context, evt *event.Event) case event.EphemeralEventTyping: typingContent := evt.Content.AsTyping() typingContent.UserIDs = slices.DeleteFunc(typingContent.UserIDs, br.shouldIgnoreEventFromUser) + case event.BeeperEphemeralEventAIStream: + if br.shouldIgnoreEvent(evt) { + return + } } br.Bridge.QueueMatrixEvent(ctx, evt) } @@ -231,7 +235,6 @@ func (br *Connector) postDecrypt(ctx context.Context, original, decrypted *event go br.sendSuccessCheckpoint(ctx, decrypted, status.MsgStepDecrypted, retryCount) decrypted.Mautrix.CheckpointSent = true decrypted.Mautrix.DecryptionDuration = duration - decrypted.Mautrix.EventSource |= event.SourceDecrypted br.EventProcessor.Dispatch(ctx, decrypted) if errorEventID != nil && *errorEventID != "" { _, _ = br.Bot.RedactEvent(ctx, decrypted.RoomID, *errorEventID) diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 57f786bb..768c57d1 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -217,3 +217,8 @@ type MarkAsDMMatrixAPI interface { MatrixAPI MarkAsDM(ctx context.Context, roomID id.RoomID, otherUser id.UserID) error } + +type EphemeralSendingMatrixAPI interface { + MatrixAPI + BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, content *event.Content, txnID string) (*mautrix.RespSendEvent, error) +} diff --git a/bridgev2/networkinterface.go b/bridgev2/networkinterface.go index 0e9a8543..efc5f100 100644 --- a/bridgev2/networkinterface.go +++ b/bridgev2/networkinterface.go @@ -726,6 +726,11 @@ type MessageRequestAcceptingNetworkAPI interface { HandleMatrixAcceptMessageRequest(ctx context.Context, msg *MatrixAcceptMessageRequest) error } +type BeeperAIStreamHandlingNetworkAPI interface { + NetworkAPI + HandleMatrixBeeperAIStream(ctx context.Context, msg *MatrixBeeperAIStream) error +} + type ResolveIdentifierResponse struct { // Ghost is the ghost of the user that the identifier resolves to. // This field should be set whenever possible. However, it is not required, @@ -1439,6 +1444,7 @@ type MatrixViewingChat struct { type MatrixDeleteChat = MatrixEventBase[*event.BeeperChatDeleteEventContent] type MatrixAcceptMessageRequest = MatrixEventBase[*event.BeeperAcceptMessageRequestEventContent] +type MatrixBeeperAIStream = MatrixEventBase[*event.BeeperAIStreamEventContent] type MatrixMarkedUnread = MatrixRoomMeta[*event.MarkedUnreadEventContent] type MatrixMute = MatrixRoomMeta[*event.BeeperMuteEventContent] type MatrixRoomTag = MatrixRoomMeta[*event.TagEventContent] diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 718a5cb2..5c0a7695 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -697,6 +697,8 @@ func (portal *Portal) handleMatrixEvent(ctx context.Context, sender *User, evt * return portal.handleMatrixReceipts(ctx, evt) case event.EphemeralEventTyping: return portal.handleMatrixTyping(ctx, evt) + case event.BeeperEphemeralEventAIStream: + return portal.handleMatrixAIStream(ctx, sender, evt) default: return EventHandlingResultIgnored } @@ -941,6 +943,50 @@ func (portal *Portal) handleMatrixTyping(ctx context.Context, evt *event.Event) return EventHandlingResultSuccess } +func (portal *Portal) handleMatrixAIStream(ctx context.Context, sender *User, evt *event.Event) EventHandlingResult { + log := zerolog.Ctx(ctx) + if sender == nil { + log.Error().Msg("Missing sender for Matrix AI stream event") + return EventHandlingResultIgnored + } + login, _, err := portal.FindPreferredLogin(ctx, sender, true) + if err != nil { + log.Err(err).Msg("Failed to get user login to handle Matrix AI stream event") + return EventHandlingResultFailed.WithMSSError(err) + } + var origSender *OrigSender + if login == nil { + if portal.Relay == nil { + return EventHandlingResultIgnored + } + login = portal.Relay + origSender = &OrigSender{ + User: sender, + UserID: sender.MXID, + } + } + content, ok := evt.Content.Parsed.(*event.BeeperAIStreamEventContent) + if !ok { + log.Error().Type("content_type", evt.Content.Parsed).Msg("Unexpected parsed content type") + return EventHandlingResultFailed.WithMSSError(fmt.Errorf("%w: %T", ErrUnexpectedParsedContentType, evt.Content.Parsed)) + } + api, ok := login.Client.(BeeperAIStreamHandlingNetworkAPI) + if !ok { + return EventHandlingResultIgnored.WithMSSError(ErrBeeperAIStreamNotSupported) + } + err = api.HandleMatrixBeeperAIStream(ctx, &MatrixBeeperAIStream{ + Event: evt, + Content: content, + Portal: portal, + OrigSender: origSender, + }) + if err != nil { + log.Err(err).Msg("Failed to handle Matrix AI stream event") + return EventHandlingResultFailed.WithMSSError(err) + } + return EventHandlingResultSuccess.WithMSS() +} + func (portal *Portal) sendTypings(ctx context.Context, userIDs []id.UserID, typing bool) { for _, userID := range userIDs { login, ok := portal.currentlyTypingLogins[userID] diff --git a/client.go b/client.go index 0a43816c..982f7454 100644 --- a/client.go +++ b/client.go @@ -1359,6 +1359,48 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event return } +// BeeperSendEphemeralEvent sends an ephemeral event into a room using Beeper's unstable endpoint. +// contentJSON should be a value that can be encoded as JSON using json.Marshal. +func (cli *Client) BeeperSendEphemeralEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { + var req ReqSendEvent + if len(extra) > 0 { + req = extra[0] + } + + var txnID string + if len(req.TransactionID) > 0 { + txnID = req.TransactionID + } else { + txnID = cli.TxnID() + } + + queryParams := map[string]string{} + if req.Timestamp > 0 { + queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10) + } + + if !req.DontEncrypt && cli != nil && cli.Crypto != nil && eventType != event.EventEncrypted { + var isEncrypted bool + isEncrypted, err = cli.StateStore.IsEncrypted(ctx, roomID) + if err != nil { + err = fmt.Errorf("failed to check if room is encrypted: %w", err) + return + } + if isEncrypted { + if contentJSON, err = cli.Crypto.Encrypt(ctx, roomID, eventType, contentJSON); err != nil { + err = fmt.Errorf("failed to encrypt event: %w", err) + return + } + eventType = event.EventEncrypted + } + } + + urlData := ClientURLPath{"unstable", "com.beeper.ephemeral", "rooms", roomID, "ephemeral", eventType.String(), txnID} + urlPath := cli.BuildURLWithQuery(urlData, queryParams) + _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp) + return +} + // SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.16/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON any, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { diff --git a/client_ephemeral_test.go b/client_ephemeral_test.go new file mode 100644 index 00000000..c2846427 --- /dev/null +++ b/client_ephemeral_test.go @@ -0,0 +1,158 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package mautrix_test + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" +) + +func TestClient_SendEphemeralEvent_UsesUnstablePathTxnAndTS(t *testing.T) { + roomID := id.RoomID("!room:example.com") + evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} + txnID := "txn-123" + + var gotPath string + var gotQueryTS string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotQueryTS = r.URL.Query().Get("ts") + assert.Equal(t, http.MethodPut, r.Method) + _, _ = w.Write([]byte(`{"event_id":"$evt"}`)) + })) + defer ts.Close() + + cli, err := mautrix.NewClient(ts.URL, "", "") + require.NoError(t, err) + + _, err = cli.BeeperSendEphemeralEvent( + context.Background(), + roomID, + evtType, + map[string]any{"foo": "bar"}, + mautrix.ReqSendEvent{TransactionID: txnID, Timestamp: 1234}, + ) + require.NoError(t, err) + + assert.True(t, strings.Contains(gotPath, "/_matrix/client/unstable/com.beeper.ephemeral/rooms/")) + assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/com.example.ephemeral/"+txnID)) + assert.Equal(t, "1234", gotQueryTS) +} + +func TestClient_SendEphemeralEvent_UnsupportedReturnsMUnrecognized(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"errcode":"M_UNRECOGNIZED","error":"Unrecognized endpoint"}`)) + })) + defer ts.Close() + + cli, err := mautrix.NewClient(ts.URL, "", "") + require.NoError(t, err) + + _, err = cli.BeeperSendEphemeralEvent( + context.Background(), + id.RoomID("!room:example.com"), + event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType}, + map[string]any{"foo": "bar"}, + ) + require.Error(t, err) + assert.True(t, errors.Is(err, mautrix.MUnrecognized)) +} + +func TestClient_SendEphemeralEvent_EncryptsInEncryptedRooms(t *testing.T) { + roomID := id.RoomID("!room:example.com") + evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} + txnID := "txn-encrypted" + + stateStore := mautrix.NewMemoryStateStore() + err := stateStore.SetEncryptionEvent(context.Background(), roomID, &event.EncryptionEventContent{ + Algorithm: id.AlgorithmMegolmV1, + }) + require.NoError(t, err) + + fakeCrypto := &fakeCryptoHelper{ + encryptedContent: &event.EncryptedEventContent{ + Algorithm: id.AlgorithmMegolmV1, + MegolmCiphertext: []byte("ciphertext"), + }, + } + + var gotPath string + var gotBody map[string]any + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + assert.Equal(t, http.MethodPut, r.Method) + err := json.NewDecoder(r.Body).Decode(&gotBody) + require.NoError(t, err) + _, _ = w.Write([]byte(`{"event_id":"$evt"}`)) + })) + defer ts.Close() + + cli, err := mautrix.NewClient(ts.URL, "", "") + require.NoError(t, err) + cli.StateStore = stateStore + cli.Crypto = fakeCrypto + + _, err = cli.BeeperSendEphemeralEvent( + context.Background(), + roomID, + evtType, + map[string]any{"foo": "bar"}, + mautrix.ReqSendEvent{TransactionID: txnID}, + ) + require.NoError(t, err) + + assert.True(t, strings.HasSuffix(gotPath, "/ephemeral/m.room.encrypted/"+txnID)) + assert.Equal(t, string(id.AlgorithmMegolmV1), gotBody["algorithm"]) + assert.Equal(t, 1, fakeCrypto.encryptCalls) + assert.Equal(t, roomID, fakeCrypto.lastRoomID) + assert.Equal(t, evtType, fakeCrypto.lastEventType) +} + +type fakeCryptoHelper struct { + encryptCalls int + lastRoomID id.RoomID + lastEventType event.Type + lastEncryptInput any + encryptedContent *event.EncryptedEventContent +} + +func (f *fakeCryptoHelper) Encrypt(_ context.Context, roomID id.RoomID, eventType event.Type, content any) (*event.EncryptedEventContent, error) { + f.encryptCalls++ + f.lastRoomID = roomID + f.lastEventType = eventType + f.lastEncryptInput = content + return f.encryptedContent, nil +} + +func (f *fakeCryptoHelper) Decrypt(context.Context, *event.Event) (*event.Event, error) { + return nil, nil +} + +func (f *fakeCryptoHelper) WaitForSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool { + return false +} + +func (f *fakeCryptoHelper) RequestSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID) { +} + +func (f *fakeCryptoHelper) Init(context.Context) error { + return nil +} diff --git a/crypto/decryptmegolm.go b/crypto/decryptmegolm.go index 9753eabd..457d5a0c 100644 --- a/crypto/decryptmegolm.go +++ b/crypto/decryptmegolm.go @@ -213,6 +213,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event TrustSource: device, ForwardedKeys: forwardedKeys, WasEncrypted: true, + EventSource: evt.Mautrix.EventSource | event.SourceDecrypted, ReceivedAt: evt.Mautrix.ReceivedAt, }, }, nil diff --git a/event/beeper.go b/event/beeper.go index 6de41df6..a1a60b35 100644 --- a/event/beeper.go +++ b/event/beeper.go @@ -214,6 +214,15 @@ func (content *MessageEventContent) RemovePerMessageProfileFallback() { } } +type BeeperAIStreamEventContent struct { + TurnID string `json:"turn_id"` + Seq int `json:"seq"` + Part map[string]any `json:"part"` + TargetEvent id.EventID `json:"target_event,omitempty"` + AgentID string `json:"agent_id,omitempty"` + RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` +} + type BeeperEncodedOrder struct { order int64 suborder int16 diff --git a/event/content.go b/event/content.go index 4aa0593d..814aeec4 100644 --- a/event/content.go +++ b/event/content.go @@ -76,9 +76,11 @@ var TypeMap = map[Type]reflect.Type{ AccountDataMarkedUnread: reflect.TypeOf(MarkedUnreadEventContent{}), AccountDataBeeperMute: reflect.TypeOf(BeeperMuteEventContent{}), - EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}), - EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}), - EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}), + EphemeralEventTyping: reflect.TypeOf(TypingEventContent{}), + EphemeralEventReceipt: reflect.TypeOf(ReceiptEventContent{}), + EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}), + EphemeralEventEncrypted: reflect.TypeOf(EncryptedEventContent{}), + BeeperEphemeralEventAIStream: reflect.TypeOf(BeeperAIStreamEventContent{}), InRoomVerificationReady: reflect.TypeOf(VerificationReadyEventContent{}), InRoomVerificationStart: reflect.TypeOf(VerificationStartEventContent{}), diff --git a/event/powerlevels.go b/event/powerlevels.go index 708721f9..668eb6d3 100644 --- a/event/powerlevels.go +++ b/event/powerlevels.go @@ -28,6 +28,9 @@ type PowerLevelsEventContent struct { Events map[string]int `json:"events,omitempty"` EventsDefault int `json:"events_default,omitempty"` + beeperEphemeralLock sync.RWMutex + BeeperEphemeral map[string]int `json:"com.beeper.ephemeral,omitempty"` + Notifications *NotificationPowerLevels `json:"notifications,omitempty"` StateDefaultPtr *int `json:"state_default,omitempty"` @@ -37,6 +40,8 @@ type PowerLevelsEventContent struct { BanPtr *int `json:"ban,omitempty"` RedactPtr *int `json:"redact,omitempty"` + BeeperEphemeralDefaultPtr *int `json:"com.beeper.ephemeral_default,omitempty"` + // This is not a part of power levels, it's added by mautrix-go internally in certain places // in order to detect creator power accurately. CreateEvent *Event `json:"-"` @@ -51,6 +56,7 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { UsersDefault: pl.UsersDefault, Events: maps.Clone(pl.Events), EventsDefault: pl.EventsDefault, + BeeperEphemeral: maps.Clone(pl.BeeperEphemeral), StateDefaultPtr: ptr.Clone(pl.StateDefaultPtr), Notifications: pl.Notifications.Clone(), @@ -60,6 +66,8 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent { BanPtr: ptr.Clone(pl.BanPtr), RedactPtr: ptr.Clone(pl.RedactPtr), + BeeperEphemeralDefaultPtr: ptr.Clone(pl.BeeperEphemeralDefaultPtr), + CreateEvent: pl.CreateEvent, } } @@ -119,6 +127,13 @@ func (pl *PowerLevelsEventContent) StateDefault() int { return 50 } +func (pl *PowerLevelsEventContent) BeeperEphemeralDefault() int { + if pl.BeeperEphemeralDefaultPtr != nil { + return *pl.BeeperEphemeralDefaultPtr + } + return pl.EventsDefault +} + func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int { if pl.isCreator(userID) { return math.MaxInt @@ -202,6 +217,29 @@ func (pl *PowerLevelsEventContent) GetEventLevel(eventType Type) int { return level } +func (pl *PowerLevelsEventContent) GetBeeperEphemeralLevel(eventType Type) int { + pl.beeperEphemeralLock.RLock() + defer pl.beeperEphemeralLock.RUnlock() + level, ok := pl.BeeperEphemeral[eventType.String()] + if !ok { + return pl.BeeperEphemeralDefault() + } + return level +} + +func (pl *PowerLevelsEventContent) SetBeeperEphemeralLevel(eventType Type, level int) { + pl.beeperEphemeralLock.Lock() + defer pl.beeperEphemeralLock.Unlock() + if level == pl.BeeperEphemeralDefault() { + delete(pl.BeeperEphemeral, eventType.String()) + } else { + if pl.BeeperEphemeral == nil { + pl.BeeperEphemeral = make(map[string]int) + } + pl.BeeperEphemeral[eventType.String()] = level + } +} + func (pl *PowerLevelsEventContent) SetEventLevel(eventType Type, level int) { pl.eventsLock.Lock() defer pl.eventsLock.Unlock() diff --git a/event/powerlevels_ephemeral_test.go b/event/powerlevels_ephemeral_test.go new file mode 100644 index 00000000..f5861583 --- /dev/null +++ b/event/powerlevels_ephemeral_test.go @@ -0,0 +1,67 @@ +// Copyright (c) 2026 Tulir Asokan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +package event_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "maunium.net/go/mautrix/event" +) + +func TestPowerLevelsEventContent_BeeperEphemeralDefaultFallsBackToEventsDefault(t *testing.T) { + pl := &event.PowerLevelsEventContent{ + EventsDefault: 45, + } + + assert.Equal(t, 45, pl.BeeperEphemeralDefault()) + + override := 60 + pl.BeeperEphemeralDefaultPtr = &override + assert.Equal(t, 60, pl.BeeperEphemeralDefault()) +} + +func TestPowerLevelsEventContent_GetSetBeeperEphemeralLevel(t *testing.T) { + pl := &event.PowerLevelsEventContent{ + EventsDefault: 25, + } + evtType := event.Type{Type: "com.example.ephemeral", Class: event.EphemeralEventType} + + assert.Equal(t, 25, pl.GetBeeperEphemeralLevel(evtType)) + + pl.SetBeeperEphemeralLevel(evtType, 50) + assert.Equal(t, 50, pl.GetBeeperEphemeralLevel(evtType)) + require.NotNil(t, pl.BeeperEphemeral) + assert.Equal(t, 50, pl.BeeperEphemeral[evtType.String()]) + + pl.SetBeeperEphemeralLevel(evtType, 25) + _, exists := pl.BeeperEphemeral[evtType.String()] + assert.False(t, exists) +} + +func TestPowerLevelsEventContent_CloneCopiesBeeperEphemeralFields(t *testing.T) { + override := 70 + pl := &event.PowerLevelsEventContent{ + EventsDefault: 35, + BeeperEphemeral: map[string]int{"com.example.ephemeral": 90}, + BeeperEphemeralDefaultPtr: &override, + } + + cloned := pl.Clone() + require.NotNil(t, cloned) + require.NotNil(t, cloned.BeeperEphemeralDefaultPtr) + assert.Equal(t, 70, *cloned.BeeperEphemeralDefaultPtr) + assert.Equal(t, 90, cloned.BeeperEphemeral["com.example.ephemeral"]) + + cloned.BeeperEphemeral["com.example.ephemeral"] = 99 + *cloned.BeeperEphemeralDefaultPtr = 71 + + assert.Equal(t, 90, pl.BeeperEphemeral["com.example.ephemeral"]) + assert.Equal(t, 70, *pl.BeeperEphemeralDefaultPtr) +} diff --git a/event/type.go b/event/type.go index f337c127..80b86728 100644 --- a/event/type.go +++ b/event/type.go @@ -115,7 +115,7 @@ func (et *Type) GuessClass() TypeClass { StateElementFunctionalMembers.Type, StateBeeperRoomFeatures.Type, StateBeeperDisappearingTimer.Type, StateMSC4391BotCommand.Type, StateRoomPolicy.Type, StateUnstableRoomPolicy.Type: return StateEventType - case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type: + case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type, BeeperEphemeralEventAIStream.Type: return EphemeralEventType case AccountDataDirectChats.Type, AccountDataPushRules.Type, AccountDataRoomTags.Type, AccountDataFullyRead.Type, AccountDataIgnoredUserList.Type, AccountDataMarkedUnread.Type, @@ -250,9 +250,11 @@ var ( // Ephemeral events var ( - EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType} - EphemeralEventTyping = Type{"m.typing", EphemeralEventType} - EphemeralEventPresence = Type{"m.presence", EphemeralEventType} + EphemeralEventReceipt = Type{"m.receipt", EphemeralEventType} + EphemeralEventTyping = Type{"m.typing", EphemeralEventType} + EphemeralEventPresence = Type{"m.presence", EphemeralEventType} + EphemeralEventEncrypted = Type{"m.room.encrypted", EphemeralEventType} + BeeperEphemeralEventAIStream = Type{"com.beeper.ai.stream_event", EphemeralEventType} ) // Account data events diff --git a/versions.go b/versions.go index 8ae82a06..69233730 100644 --- a/versions.go +++ b/versions.go @@ -80,6 +80,7 @@ var ( BeeperFeatureAccountDataMute = UnstableFeature{UnstableFlag: "com.beeper.account_data_mute"} BeeperFeatureInboxState = UnstableFeature{UnstableFlag: "com.beeper.inbox_state"} BeeperFeatureArbitraryMemberChange = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_member_change"} + BeeperFeatureEphemeralEvents = UnstableFeature{UnstableFlag: "com.beeper.ephemeral"} ) func (versions *RespVersions) Supports(feature UnstableFeature) bool { From ed9820356e983f9c6489e7e1bb4b75514cf8f3e6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 4 Mar 2026 13:58:07 +0200 Subject: [PATCH 1635/1647] bridgev2/portalreid: try to fix deadlock when racing with room creation --- bridgev2/portal.go | 3 +++ bridgev2/portalreid.go | 28 +++++++++++++++++++++------- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 5c0a7695..8df41644 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -5363,6 +5363,9 @@ func (portal *Portal) removeInPortalCache(ctx context.Context) { } func (portal *Portal) unlockedDelete(ctx context.Context) error { + if portal.deleted.IsSet() { + return nil + } err := portal.safeDBDelete(ctx) if err != nil { return err diff --git a/bridgev2/portalreid.go b/bridgev2/portalreid.go index 6a5091fc..c976d97c 100644 --- a/bridgev2/portalreid.go +++ b/bridgev2/portalreid.go @@ -38,17 +38,20 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta Stringer("target_portal_key", target). Logger() ctx = log.WithContext(ctx) - if !br.cacheLock.TryLock() { - log.Debug().Msg("Waiting for cache lock") - br.cacheLock.Lock() - log.Debug().Msg("Acquired cache lock after waiting") - } defer func() { - br.cacheLock.Unlock() log.Debug().Msg("Finished handling portal re-ID") }() + acquireCacheLock := func() { + if !br.cacheLock.TryLock() { + log.Debug().Msg("Waiting for global cache lock") + br.cacheLock.Lock() + log.Debug().Msg("Acquired global cache lock after waiting") + } else { + log.Trace().Msg("Acquired global cache lock without waiting") + } + } log.Debug().Msg("Re-ID'ing portal") - sourcePortal, err := br.UnlockedGetPortalByKey(ctx, source, true) + sourcePortal, err := br.GetExistingPortalByKey(ctx, source) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to get source portal: %w", err) } else if sourcePortal == nil { @@ -75,18 +78,24 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Stringer("source_portal_mxid", sourcePortal.MXID) }) + + acquireCacheLock() targetPortal, err := br.UnlockedGetPortalByKey(ctx, target, true) if err != nil { + br.cacheLock.Unlock() return ReIDResultError, nil, fmt.Errorf("failed to get target portal: %w", err) } if targetPortal == nil { log.Info().Msg("Target portal doesn't exist, re-ID'ing source portal") err = sourcePortal.unlockedReID(ctx, target) + br.cacheLock.Unlock() if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to re-ID source portal: %w", err) } return ReIDResultSourceReIDd, sourcePortal, nil } + br.cacheLock.Unlock() + if !targetPortal.roomCreateLock.TryLock() { if cancelCreate := targetPortal.cancelRoomCreate.Swap(nil); cancelCreate != nil { (*cancelCreate)() @@ -98,6 +107,8 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta defer targetPortal.roomCreateLock.Unlock() if targetPortal.MXID == "" { log.Info().Msg("Target portal row exists, but doesn't have a Matrix room. Deleting target portal row and re-ID'ing source portal") + acquireCacheLock() + defer br.cacheLock.Unlock() err = targetPortal.unlockedDelete(ctx) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to delete target portal: %w", err) @@ -112,6 +123,9 @@ func (br *Bridge) ReIDPortal(ctx context.Context, source, target networkid.Porta return c.Stringer("target_portal_mxid", targetPortal.MXID) }) log.Info().Msg("Both target and source portals have Matrix rooms, tombstoning source portal") + sourcePortal.removeInPortalCache(ctx) + acquireCacheLock() + defer br.cacheLock.Unlock() err = sourcePortal.unlockedDelete(ctx) if err != nil { return ReIDResultError, nil, fmt.Errorf("failed to delete source portal row: %w", err) From ed6dbcaaeeeb8707c643c08a2f5990caa954a491 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 4 Mar 2026 22:50:43 +0200 Subject: [PATCH 1636/1647] client: log content length when uploading to external url --- client.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 982f7454..0a9704a9 100644 --- a/client.go +++ b/client.go @@ -2041,7 +2041,10 @@ type ReqUploadMedia struct { } func (cli *Client) tryUploadMediaToURL(ctx context.Context, url, contentType string, content io.Reader, contentLength int64) (*http.Response, error) { - cli.Log.Debug().Str("url", url).Msg("Uploading media to external URL") + cli.Log.Debug(). + Str("url", url). + Int64("content_length", contentLength). + Msg("Uploading media to external URL") req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, content) if err != nil { return nil, err From 0f6a779dd2b55916ee4a2b27a46d2bd6e0f9d592 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 5 Mar 2026 11:59:11 +0200 Subject: [PATCH 1637/1647] readme: update --- README.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index ac41ca78..b1a2edf8 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,9 @@ # mautrix-go [![GoDoc](https://pkg.go.dev/badge/maunium.net/go/mautrix)](https://pkg.go.dev/maunium.net/go/mautrix) -A Golang Matrix framework. Used by [gomuks](https://matrix.org/docs/projects/client/gomuks), -[go-neb](https://github.com/matrix-org/go-neb), [mautrix-whatsapp](https://github.com/mautrix/whatsapp) +A Golang Matrix framework. Used by [gomuks](https://gomuks.app), +[go-neb](https://github.com/matrix-org/go-neb), +[mautrix-whatsapp](https://github.com/mautrix/whatsapp) and others. Matrix room: [`#go:maunium.net`](https://matrix.to/#/#go:maunium.net) @@ -13,9 +14,10 @@ The original project is licensed under [Apache 2.0](https://github.com/matrix-or In addition to the basic client API features the original project has, this framework also has: * Appservice support (Intent API like mautrix-python, room state storage, etc) -* End-to-end encryption support (incl. interactive SAS verification) +* End-to-end encryption support (incl. key backup, cross-signing, interactive verification, etc) * High-level module for building puppeting bridges -* High-level module for building chat clients +* Partial federation module (making requests, PDU processing and event authorization) +* A media proxy server which can be used to expose anything as a Matrix media repo * Wrapper functions for the Synapse admin API * Structs for parsing event content * Helpers for parsing and generating Matrix HTML From 7836f35a1a7431a3eb7f1a09697d324058dbde01 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 5 Mar 2026 23:57:35 +0200 Subject: [PATCH 1638/1647] bridgev2/portal: fix third matrix reaction not removing previous one on single-reaction networks --- bridgev2/portal.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 8df41644..d8acf88e 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1632,6 +1632,10 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi // Keep n-1 previous reactions and remove the rest react.ExistingReactionsToKeep = allReactions[:preResp.MaxReactions-1] for _, oldReaction := range allReactions[preResp.MaxReactions-1:] { + if existing != nil && oldReaction.EmojiID == existing.EmojiID { + // Don't double-delete on networks that only allow one emoji + continue + } // Intentionally defer in a loop, there won't be that many items, // and we want all of them to be done after this function completes successfully //goland:noinspection GoDeferInLoop From 7a53f3928a01fa646cfdd5d1a950e04a687e09cb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 6 Mar 2026 14:25:52 +0200 Subject: [PATCH 1639/1647] bridgev2/portal: redact conflicting reactions before sending MSS success --- bridgev2/portal.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index d8acf88e..48a17e91 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -1587,6 +1587,12 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if portal.Bridge.Config.OutgoingMessageReID { deterministicID = portal.Bridge.Matrix.GenerateReactionEventID(portal.MXID, reactionTarget, preResp.SenderID, preResp.EmojiID) } + defer func() { + // Do this in a defer so that it happens after any potential defer calls to removeOutdatedReaction + if handleRes.Success { + portal.sendSuccessStatus(ctx, evt, 0, deterministicID) + } + }() removeOutdatedReaction := func(oldReact *database.Reaction, deleteDB bool) { if !handleRes.Success { return @@ -1684,7 +1690,6 @@ func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *UserLogi if err != nil { log.Err(err).Msg("Failed to save reaction to database") } - portal.sendSuccessStatus(ctx, evt, 0, deterministicID) return EventHandlingResultSuccess.WithEventID(deterministicID) } From 531822f6dcf54f82f1a93156c670ed33f8277b2b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 6 Mar 2026 16:08:15 +0200 Subject: [PATCH 1640/1647] bridgev2/config: add limit for unknown error auto-reconnects --- bridgev2/bridgeconfig/config.go | 47 +++++++++++----------- bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/bridgestate.go | 10 ++++- bridgev2/matrix/mxmain/example-config.yaml | 3 ++ 4 files changed, 37 insertions(+), 24 deletions(-) diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index 8b9aa019..c301b8d0 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -62,29 +62,30 @@ type CleanupOnLogouts struct { } type BridgeConfig struct { - CommandPrefix string `yaml:"command_prefix"` - PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` - PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` - AsyncEvents bool `yaml:"async_events"` - SplitPortals bool `yaml:"split_portals"` - ResendBridgeInfo bool `yaml:"resend_bridge_info"` - NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"` - BridgeStatusNotices string `yaml:"bridge_status_notices"` - UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"` - BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` - BridgeNotices bool `yaml:"bridge_notices"` - TagOnlyOnCreate bool `yaml:"tag_only_on_create"` - OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"` - MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` - DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"` - CrossRoomReplies bool `yaml:"cross_room_replies"` - OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` - RevertFailedStateChanges bool `yaml:"revert_failed_state_changes"` - KickMatrixUsers bool `yaml:"kick_matrix_users"` - CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` - Relay RelayConfig `yaml:"relay"` - Permissions PermissionConfig `yaml:"permissions"` - Backfill BackfillConfig `yaml:"backfill"` + CommandPrefix string `yaml:"command_prefix"` + PersonalFilteringSpaces bool `yaml:"personal_filtering_spaces"` + PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` + AsyncEvents bool `yaml:"async_events"` + SplitPortals bool `yaml:"split_portals"` + ResendBridgeInfo bool `yaml:"resend_bridge_info"` + NoBridgeInfoStateKey bool `yaml:"no_bridge_info_state_key"` + BridgeStatusNotices string `yaml:"bridge_status_notices"` + UnknownErrorAutoReconnect time.Duration `yaml:"unknown_error_auto_reconnect"` + UnknownErrorMaxAutoReconnects int `yaml:"unknown_error_max_auto_reconnects"` + BridgeMatrixLeave bool `yaml:"bridge_matrix_leave"` + BridgeNotices bool `yaml:"bridge_notices"` + TagOnlyOnCreate bool `yaml:"tag_only_on_create"` + OnlyBridgeTags []event.RoomTag `yaml:"only_bridge_tags"` + MuteOnlyOnCreate bool `yaml:"mute_only_on_create"` + DeduplicateMatrixMessages bool `yaml:"deduplicate_matrix_messages"` + CrossRoomReplies bool `yaml:"cross_room_replies"` + OutgoingMessageReID bool `yaml:"outgoing_message_re_id"` + RevertFailedStateChanges bool `yaml:"revert_failed_state_changes"` + KickMatrixUsers bool `yaml:"kick_matrix_users"` + CleanupOnLogout CleanupOnLogouts `yaml:"cleanup_on_logout"` + Relay RelayConfig `yaml:"relay"` + Permissions PermissionConfig `yaml:"permissions"` + Backfill BackfillConfig `yaml:"backfill"` } type MatrixConfig struct { diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index a0278672..ef51335e 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -33,6 +33,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "bridge", "no_bridge_info_state_key") helper.Copy(up.Str|up.Null, "bridge", "bridge_status_notices") helper.Copy(up.Str|up.Int|up.Null, "bridge", "unknown_error_auto_reconnect") + helper.Copy(up.Int, "bridge", "unknown_error_max_auto_reconnects") helper.Copy(up.Bool, "bridge", "bridge_matrix_leave") helper.Copy(up.Bool, "bridge", "bridge_notices") helper.Copy(up.Bool, "bridge", "tag_only_on_create") diff --git a/bridgev2/bridgestate.go b/bridgev2/bridgestate.go index babbccab..96d9fd5c 100644 --- a/bridgev2/bridgestate.go +++ b/bridgev2/bridgestate.go @@ -37,6 +37,8 @@ type BridgeStateQueue struct { stopChan chan struct{} stopReconnect atomic.Pointer[context.CancelFunc] + + unknownErrorReconnects int } func (br *Bridge) SendGlobalBridgeState(state status.BridgeState) { @@ -192,8 +194,14 @@ func (bsq *BridgeStateQueue) unknownErrorReconnect(triggeredBy status.BridgeStat } else if prevUnsent.StateEvent != status.StateUnknownError || prev.StateEvent != status.StateUnknownError { log.Debug().Msg("Not reconnecting as the previous state was not an unknown error") return + } else if bsq.unknownErrorReconnects > bsq.bridge.Config.UnknownErrorMaxAutoReconnects { + log.Warn().Msg("Not reconnecting as the maximum number of unknown error reconnects has been reached") + return } - log.Info().Msg("Disconnecting and reconnecting login due to unknown error") + bsq.unknownErrorReconnects++ + log.Info(). + Int("reconnect_num", bsq.unknownErrorReconnects). + Msg("Disconnecting and reconnecting login due to unknown error") bsq.login.Disconnect() log.Debug().Msg("Disconnection finished, recreating client and reconnecting") err := bsq.login.recreateClient(ctx) diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index b0e83696..75d0edbf 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -29,6 +29,9 @@ bridge: # How long after an unknown error should the bridge attempt a full reconnect? # Must be at least 1 minute. The bridge will add an extra ±20% jitter to this value. unknown_error_auto_reconnect: null + # Maximum number of times to do the auto-reconnect above. + # The counter is per login, but is never reset except on logout and restart. + unknown_error_max_auto_reconnects: 10 # Should leaving Matrix rooms be bridged as leaving groups on the remote network? bridge_matrix_leave: false From df24fb96e2e5bcbd451bc0b9340338415075519f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 6 Mar 2026 20:58:18 +0200 Subject: [PATCH 1641/1647] client: update MSC2666 implementation --- client.go | 9 +++++++-- responses.go | 1 + versions.go | 3 ++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index 0a9704a9..fbb9333f 100644 --- a/client.go +++ b/client.go @@ -1158,7 +1158,9 @@ func (cli *Client) SearchUserDirectory(ctx context.Context, query string, limit } func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, extras ...ReqMutualRooms) (resp *RespMutualRooms, err error) { - if cli.SpecVersions != nil && !cli.SpecVersions.Supports(FeatureMutualRooms) { + supportsStable := cli.SpecVersions.Supports(FeatureStableMutualRooms) + supportsUnstable := cli.SpecVersions.Supports(FeatureUnstableMutualRooms) + if cli.SpecVersions != nil && !supportsUnstable && !supportsStable { err = fmt.Errorf("server does not support fetching mutual rooms") return } @@ -1168,7 +1170,10 @@ func (cli *Client) GetMutualRooms(ctx context.Context, otherUserID id.UserID, ex if len(extras) > 0 { query["from"] = extras[0].From } - urlPath := cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query) + urlPath := cli.BuildURLWithQuery(ClientURLPath{"v1", "user", "mutual_rooms"}, query) + if !supportsStable && supportsUnstable { + urlPath = cli.BuildURLWithQuery(ClientURLPath{"unstable", "uk.half-shot.msc2666", "user", "mutual_rooms"}, query) + } _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp) return } diff --git a/responses.go b/responses.go index 20286431..4fbe1fbc 100644 --- a/responses.go +++ b/responses.go @@ -258,6 +258,7 @@ func (r *UserDirectoryEntry) MarshalJSON() ([]byte, error) { type RespMutualRooms struct { Joined []id.RoomID `json:"joined"` NextBatch string `json:"next_batch,omitempty"` + Count int `json:"count,omitempty"` } type RespRoomSummary struct { diff --git a/versions.go b/versions.go index 69233730..61b2e4ea 100644 --- a/versions.go +++ b/versions.go @@ -63,7 +63,8 @@ var ( FeatureAsyncUploads = UnstableFeature{UnstableFlag: "fi.mau.msc2246.stable", SpecVersion: SpecV17} FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17} FeatureAuthenticatedMedia = UnstableFeature{UnstableFlag: "org.matrix.msc3916.stable", SpecVersion: SpecV111} - FeatureMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} + FeatureUnstableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms"} + FeatureStableMutualRooms = UnstableFeature{UnstableFlag: "uk.half-shot.msc2666.query_mutual_rooms.stable" /*, SpecVersion: SpecV118*/} FeatureUserRedaction = UnstableFeature{UnstableFlag: "org.matrix.msc4194"} FeatureViewRedactedContent = UnstableFeature{UnstableFlag: "fi.mau.msc2815"} FeatureUnstableAccountModeration = UnstableFeature{UnstableFlag: "uk.timedout.msc4323"} From c107c25d078ee2de4304baa5b2fb109c70edae3d Mon Sep 17 00:00:00 2001 From: timedout Date: Sat, 7 Mar 2026 14:26:42 +0000 Subject: [PATCH 1642/1647] client: add type parameter to UIA request bodies (#469) --- appservice/intent.go | 2 +- client.go | 14 +++++++------- crypto/cross_sign_key.go | 2 +- mockserver/mockserver.go | 2 +- requests.go | 16 ++++++++-------- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/appservice/intent.go b/appservice/intent.go index 0ec10b77..5d43f190 100644 --- a/appservice/intent.go +++ b/appservice/intent.go @@ -51,7 +51,7 @@ func (as *AppService) NewIntentAPI(localpart string) *IntentAPI { } func (intent *IntentAPI) Register(ctx context.Context) error { - _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister{ + _, err := intent.Client.MakeRequest(ctx, http.MethodPost, intent.BuildClientURL("v3", "register"), &mautrix.ReqRegister[any]{ Username: intent.Localpart, Type: mautrix.AuthTypeAppservice, InhibitLogin: true, diff --git a/client.go b/client.go index fbb9333f..7062d9b9 100644 --- a/client.go +++ b/client.go @@ -918,7 +918,7 @@ func (cli *Client) RegisterAvailable(ctx context.Context, username string) (resp return } -func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { +func (cli *Client) register(ctx context.Context, url string, req *ReqRegister[any]) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { var bodyBytes []byte bodyBytes, err = cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, @@ -942,7 +942,7 @@ func (cli *Client) register(ctx context.Context, url string, req *ReqRegister) ( // Register makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register // // Registers with kind=user. For kind=guest, see RegisterGuest. -func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { +func (cli *Client) Register(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) { u := cli.BuildClientURL("v3", "register") return cli.register(ctx, u, req) } @@ -951,7 +951,7 @@ func (cli *Client) Register(ctx context.Context, req *ReqRegister) (*RespRegiste // with kind=guest. // // For kind=user, see Register. -func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { +func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister[any]) (*RespRegister, *RespUserInteractive, error) { query := map[string]string{ "kind": "guest", } @@ -974,7 +974,7 @@ func (cli *Client) RegisterGuest(ctx context.Context, req *ReqRegister) (*RespRe // panic(err) // } // token := res.AccessToken -func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRegister, error) { +func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister[any]) (*RespRegister, error) { _, uia, err := cli.Register(ctx, req) if err != nil && uia == nil { return nil, err @@ -2687,13 +2687,13 @@ func (cli *Client) SetDeviceInfo(ctx context.Context, deviceID id.DeviceID, req return err } -func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice) error { +func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice[any]) error { urlPath := cli.BuildClientURL("v3", "devices", deviceID) _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil) return err } -func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices) error { +func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices[any]) error { urlPath := cli.BuildClientURL("v3", "delete_devices") _, err := cli.MakeRequest(ctx, http.MethodPost, urlPath, req, nil) return err @@ -2704,7 +2704,7 @@ type UIACallback = func(*RespUserInteractive) interface{} // UploadCrossSigningKeys uploads the given cross-signing keys to the server. // Because the endpoint requires user-interactive authentication a callback must be provided that, // given the UI auth parameters, produces the required result (or nil to end the flow). -func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error { +func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCrossSigningKeysReq[any], uiaCallback UIACallback) error { content, err := cli.MakeFullRequest(ctx, FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v3", "keys", "device_signing", "upload"), diff --git a/crypto/cross_sign_key.go b/crypto/cross_sign_key.go index 4094f695..5d9bf5b3 100644 --- a/crypto/cross_sign_key.go +++ b/crypto/cross_sign_key.go @@ -135,7 +135,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *Cross } userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), userSig) - err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{ + err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq[any]{ Master: masterKey, SelfSigning: selfKey, UserSigning: userKey, diff --git a/mockserver/mockserver.go b/mockserver/mockserver.go index e52c387a..507c24a5 100644 --- a/mockserver/mockserver.go +++ b/mockserver/mockserver.go @@ -231,7 +231,7 @@ func (ms *MockServer) postKeysUpload(w http.ResponseWriter, r *http.Request) { } func (ms *MockServer) postDeviceSigningUpload(w http.ResponseWriter, r *http.Request) { - var req mautrix.UploadCrossSigningKeysReq + var req mautrix.UploadCrossSigningKeysReq[any] mustDecode(r, &req) userID := ms.getUserID(r).UserID diff --git a/requests.go b/requests.go index 397d30de..cc8b7266 100644 --- a/requests.go +++ b/requests.go @@ -66,14 +66,14 @@ const ( ) // ReqRegister is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register -type ReqRegister struct { +type ReqRegister[UIAType any] struct { Username string `json:"username,omitempty"` Password string `json:"password,omitempty"` DeviceID id.DeviceID `json:"device_id,omitempty"` InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"` InhibitLogin bool `json:"inhibit_login,omitempty"` RefreshToken bool `json:"refresh_token,omitempty"` - Auth interface{} `json:"auth,omitempty"` + Auth UIAType `json:"auth,omitempty"` // Type for registration, only used for appservice user registrations // https://spec.matrix.org/v1.2/application-service-api/#server-admin-style-permissions @@ -320,11 +320,11 @@ func (csk *CrossSigningKeys) FirstKey() id.Ed25519 { return "" } -type UploadCrossSigningKeysReq struct { +type UploadCrossSigningKeysReq[UIAType any] struct { Master CrossSigningKeys `json:"master_key"` SelfSigning CrossSigningKeys `json:"self_signing_key"` UserSigning CrossSigningKeys `json:"user_signing_key"` - Auth interface{} `json:"auth,omitempty"` + Auth UIAType `json:"auth,omitempty"` } type KeyMap map[id.DeviceKeyID]string @@ -392,14 +392,14 @@ type ReqDeviceInfo struct { } // ReqDeleteDevice is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#delete_matrixclientv3devicesdeviceid -type ReqDeleteDevice struct { - Auth interface{} `json:"auth,omitempty"` +type ReqDeleteDevice[UIAType any] struct { + Auth UIAType `json:"auth,omitempty"` } // ReqDeleteDevices is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3delete_devices -type ReqDeleteDevices struct { +type ReqDeleteDevices[UIAType any] struct { Devices []id.DeviceID `json:"devices"` - Auth interface{} `json:"auth,omitempty"` + Auth UIAType `json:"auth,omitempty"` } type ReqPutPushRule struct { From c243dad24a9cea4811cdf54b35c2df92f0428cf1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 9 Mar 2026 14:26:55 +0200 Subject: [PATCH 1643/1647] bridgev2/portal: include portal receiver in logs --- bridgev2/portal.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 48a17e91..155ca52b 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -169,7 +169,9 @@ func (br *Bridge) loadPortal(ctx context.Context, dbPortal *database.Portal, que } func (portal *Portal) updateLogger() { - logWith := portal.Bridge.Log.With().Str("portal_id", string(portal.ID)) + logWith := portal.Bridge.Log.With(). + Str("portal_id", string(portal.ID)). + Str("portal_receiver", string(portal.Receiver)) if portal.MXID != "" { logWith = logWith.Stringer("portal_mxid", portal.MXID) } From 8fb92239dc0a96ee73a5483dbc5ea1e2890acae9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 10 Mar 2026 13:00:00 +0200 Subject: [PATCH 1644/1647] bridgev2: fix bugs with threads --- bridgev2/database/message.go | 4 ++-- bridgev2/portal.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/bridgev2/database/message.go b/bridgev2/database/message.go index 43f33666..4fd599a8 100644 --- a/bridgev2/database/message.go +++ b/bridgev2/database/message.go @@ -68,8 +68,8 @@ const ( getFirstMessagePartByIDQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND (room_receiver=$2 OR room_receiver='') AND id=$3 ORDER BY part_id ASC LIMIT 1` getMessagesBetweenTimeQuery = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND timestamp>$4 AND timestamp<=$5` getOldestMessageInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp ASC, part_id ASC LIMIT 1` - getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp ASC, part_id ASC LIMIT 1` - getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY timestamp DESC, part_id DESC LIMIT 1` + getFirstMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS FIRST, timestamp ASC, part_id ASC LIMIT 1` + getLastMessageInThread = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 AND (id=$4 OR thread_root_id=$4) ORDER BY thread_root_id NULLS LAST, timestamp DESC, part_id DESC LIMIT 1` getLastNInPortal = getMessageBaseQuery + `WHERE bridge_id=$1 AND room_id=$2 AND room_receiver=$3 ORDER BY timestamp DESC, part_id DESC LIMIT $4` getLastMessagePartAtOrBeforeTimeQuery = getMessageBaseQuery + `WHERE bridge_id = $1 AND room_id=$2 AND room_receiver=$3 AND timestamp<=$4 ORDER BY timestamp DESC, part_id DESC LIMIT 1` diff --git a/bridgev2/portal.go b/bridgev2/portal.go index 155ca52b..16aa703b 100644 --- a/bridgev2/portal.go +++ b/bridgev2/portal.go @@ -2763,7 +2763,7 @@ func (portal *Portal) getRelationMeta( log.Err(err).Msg("Failed to get last thread message from database") } if prevThreadEvent == nil { - prevThreadEvent = threadRoot + prevThreadEvent = ptr.Clone(threadRoot) } } return From 92cfc0095df2b3621d6dd7830d8e98d058f18bca Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 13 Mar 2026 16:24:31 +0200 Subject: [PATCH 1645/1647] bridgev2: add support for custom profile fields for ghosts (#462) --- bridgev2/bridgeconfig/config.go | 13 ++-- bridgev2/bridgeconfig/upgrade.go | 1 + bridgev2/database/ghost.go | 68 +++++++++++++++++-- bridgev2/database/upgrades/00-latest.sql | 3 +- .../upgrades/27-ghost-extra-profile.sql | 2 + bridgev2/ghost.go | 48 ++++++++----- bridgev2/matrix/connector.go | 2 + bridgev2/matrix/intent.go | 61 +++++++++++++++-- bridgev2/matrix/mxmain/example-config.yaml | 3 + bridgev2/matrixinterface.go | 1 + 10 files changed, 169 insertions(+), 33 deletions(-) create mode 100644 bridgev2/database/upgrades/27-ghost-extra-profile.sql diff --git a/bridgev2/bridgeconfig/config.go b/bridgev2/bridgeconfig/config.go index c301b8d0..bd6b9c06 100644 --- a/bridgev2/bridgeconfig/config.go +++ b/bridgev2/bridgeconfig/config.go @@ -89,12 +89,13 @@ type BridgeConfig struct { } type MatrixConfig struct { - MessageStatusEvents bool `yaml:"message_status_events"` - DeliveryReceipts bool `yaml:"delivery_receipts"` - MessageErrorNotices bool `yaml:"message_error_notices"` - SyncDirectChatList bool `yaml:"sync_direct_chat_list"` - FederateRooms bool `yaml:"federate_rooms"` - UploadFileThreshold int64 `yaml:"upload_file_threshold"` + MessageStatusEvents bool `yaml:"message_status_events"` + DeliveryReceipts bool `yaml:"delivery_receipts"` + MessageErrorNotices bool `yaml:"message_error_notices"` + SyncDirectChatList bool `yaml:"sync_direct_chat_list"` + FederateRooms bool `yaml:"federate_rooms"` + UploadFileThreshold int64 `yaml:"upload_file_threshold"` + GhostExtraProfileInfo bool `yaml:"ghost_extra_profile_info"` } type AnalyticsConfig struct { diff --git a/bridgev2/bridgeconfig/upgrade.go b/bridgev2/bridgeconfig/upgrade.go index ef51335e..92515ea0 100644 --- a/bridgev2/bridgeconfig/upgrade.go +++ b/bridgev2/bridgeconfig/upgrade.go @@ -101,6 +101,7 @@ func doUpgrade(helper up.Helper) { helper.Copy(up.Bool, "matrix", "sync_direct_chat_list") helper.Copy(up.Bool, "matrix", "federate_rooms") helper.Copy(up.Int, "matrix", "upload_file_threshold") + helper.Copy(up.Bool, "matrix", "ghost_extra_profile_info") helper.Copy(up.Str|up.Null, "analytics", "token") helper.Copy(up.Str|up.Null, "analytics", "url") diff --git a/bridgev2/database/ghost.go b/bridgev2/database/ghost.go index c32929ad..16af35ca 100644 --- a/bridgev2/database/ghost.go +++ b/bridgev2/database/ghost.go @@ -7,12 +7,17 @@ package database import ( + "bytes" "context" "encoding/hex" + "encoding/json" + "fmt" "go.mau.fi/util/dbutil" + "go.mau.fi/util/exerrors" "maunium.net/go/mautrix/bridgev2/networkid" + "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/id" ) @@ -22,6 +27,55 @@ type GhostQuery struct { *dbutil.QueryHelper[*Ghost] } +type ExtraProfile map[string]json.RawMessage + +func (ep *ExtraProfile) Set(key string, value any) error { + if key == "displayname" || key == "avatar_url" { + return fmt.Errorf("cannot set reserved profile key %q", key) + } + marshaled, err := json.Marshal(value) + if err != nil { + return err + } + if *ep == nil { + *ep = make(ExtraProfile) + } + (*ep)[key] = canonicaljson.CanonicalJSONAssumeValid(marshaled) + return nil +} + +func (ep *ExtraProfile) With(key string, value any) *ExtraProfile { + exerrors.PanicIfNotNil(ep.Set(key, value)) + return ep +} + +func canonicalizeIfObject(data json.RawMessage) json.RawMessage { + if len(data) > 0 && (data[0] == '{' || data[0] == '[') { + return canonicaljson.CanonicalJSONAssumeValid(data) + } + return data +} + +func (ep *ExtraProfile) CopyTo(dest *ExtraProfile) (changed bool) { + if len(*ep) == 0 { + return + } + if *dest == nil { + *dest = make(ExtraProfile) + } + for key, val := range *ep { + if key == "displayname" || key == "avatar_url" { + continue + } + existing, exists := (*dest)[key] + if !exists || !bytes.Equal(canonicalizeIfObject(existing), val) { + (*dest)[key] = val + changed = true + } + } + return +} + type Ghost struct { BridgeID networkid.BridgeID ID networkid.UserID @@ -35,13 +89,14 @@ type Ghost struct { ContactInfoSet bool IsBot bool Identifiers []string + ExtraProfile ExtraProfile Metadata any } const ( getGhostBaseQuery = ` SELECT bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, - name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata + name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata FROM ghost ` getGhostByIDQuery = getGhostBaseQuery + `WHERE bridge_id=$1 AND id=$2` @@ -49,13 +104,14 @@ const ( insertGhostQuery = ` INSERT INTO ghost ( bridge_id, id, name, avatar_id, avatar_hash, avatar_mxc, - name_set, avatar_set, contact_info_set, is_bot, identifiers, metadata + name_set, avatar_set, contact_info_set, is_bot, identifiers, extra_profile, metadata ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) ` updateGhostQuery = ` UPDATE ghost SET name=$3, avatar_id=$4, avatar_hash=$5, avatar_mxc=$6, - name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10, identifiers=$11, metadata=$12 + name_set=$7, avatar_set=$8, contact_info_set=$9, is_bot=$10, + identifiers=$11, extra_profile=$12, metadata=$13 WHERE bridge_id=$1 AND id=$2 ` ) @@ -86,7 +142,7 @@ func (g *Ghost) Scan(row dbutil.Scannable) (*Ghost, error) { &g.BridgeID, &g.ID, &g.Name, &g.AvatarID, &avatarHash, &g.AvatarMXC, &g.NameSet, &g.AvatarSet, &g.ContactInfoSet, &g.IsBot, - dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata}, + dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: &g.ExtraProfile}, dbutil.JSON{Data: g.Metadata}, ) if err != nil { return nil, err @@ -116,6 +172,6 @@ func (g *Ghost) sqlVariables() []any { g.BridgeID, g.ID, g.Name, g.AvatarID, avatarHash, g.AvatarMXC, g.NameSet, g.AvatarSet, g.ContactInfoSet, g.IsBot, - dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.Metadata}, + dbutil.JSON{Data: &g.Identifiers}, dbutil.JSON{Data: g.ExtraProfile}, dbutil.JSON{Data: g.Metadata}, } } diff --git a/bridgev2/database/upgrades/00-latest.sql b/bridgev2/database/upgrades/00-latest.sql index b193d314..6092dc24 100644 --- a/bridgev2/database/upgrades/00-latest.sql +++ b/bridgev2/database/upgrades/00-latest.sql @@ -1,4 +1,4 @@ --- v0 -> v26 (compatible with v9+): Latest revision +-- v0 -> v27 (compatible with v9+): Latest revision CREATE TABLE "user" ( bridge_id TEXT NOT NULL, mxid TEXT NOT NULL, @@ -80,6 +80,7 @@ CREATE TABLE ghost ( contact_info_set BOOLEAN NOT NULL, is_bot BOOLEAN NOT NULL, identifiers jsonb NOT NULL, + extra_profile jsonb, metadata jsonb NOT NULL, PRIMARY KEY (bridge_id, id) diff --git a/bridgev2/database/upgrades/27-ghost-extra-profile.sql b/bridgev2/database/upgrades/27-ghost-extra-profile.sql new file mode 100644 index 00000000..e8e0549a --- /dev/null +++ b/bridgev2/database/upgrades/27-ghost-extra-profile.sql @@ -0,0 +1,2 @@ +-- v27 (compatible with v9+): Add column for extra ghost profile metadata +ALTER TABLE ghost ADD COLUMN extra_profile jsonb; diff --git a/bridgev2/ghost.go b/bridgev2/ghost.go index f7072a9c..590dd1dc 100644 --- a/bridgev2/ghost.go +++ b/bridgev2/ghost.go @@ -9,12 +9,15 @@ package bridgev2 import ( "context" "crypto/sha256" + "encoding/json" "fmt" + "maps" "net/http" + "slices" "github.com/rs/zerolog" + "go.mau.fi/util/exerrors" "go.mau.fi/util/exmime" - "golang.org/x/exp/slices" "maunium.net/go/mautrix/bridgev2/database" "maunium.net/go/mautrix/bridgev2/networkid" @@ -134,10 +137,11 @@ func (a *Avatar) Reupload(ctx context.Context, intent MatrixAPI, currentHash [32 } type UserInfo struct { - Identifiers []string - Name *string - Avatar *Avatar - IsBot *bool + Identifiers []string + Name *string + Avatar *Avatar + IsBot *bool + ExtraProfile database.ExtraProfile ExtraUpdates ExtraUpdater[*Ghost] } @@ -185,9 +189,9 @@ func (ghost *Ghost) UpdateAvatar(ctx context.Context, avatar *Avatar) bool { return true } -func (ghost *Ghost) getExtraProfileMeta() *event.BeeperProfileExtra { +func (ghost *Ghost) getExtraProfileMeta() any { bridgeName := ghost.Bridge.Network.GetName() - return &event.BeeperProfileExtra{ + baseExtra := &event.BeeperProfileExtra{ RemoteID: string(ghost.ID), Identifiers: ghost.Identifiers, Service: bridgeName.BeeperBridgeType, @@ -195,23 +199,35 @@ func (ghost *Ghost) getExtraProfileMeta() *event.BeeperProfileExtra { IsBridgeBot: false, IsNetworkBot: ghost.IsBot, } + if len(ghost.ExtraProfile) == 0 { + return baseExtra + } + mergedExtra := maps.Clone(ghost.ExtraProfile) + baseExtraMarshaled := exerrors.Must(json.Marshal(baseExtra)) + exerrors.PanicIfNotNil(json.Unmarshal(baseExtraMarshaled, &mergedExtra)) + return mergedExtra } -func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool) bool { - if identifiers != nil { - slices.Sort(identifiers) - } - if ghost.ContactInfoSet && - (identifiers == nil || slices.Equal(identifiers, ghost.Identifiers)) && - (isBot == nil || *isBot == ghost.IsBot) { +func (ghost *Ghost) UpdateContactInfo(ctx context.Context, identifiers []string, isBot *bool, extraProfile database.ExtraProfile) bool { + if !ghost.Bridge.Matrix.GetCapabilities().ExtraProfileMeta { + ghost.ContactInfoSet = false return false } if identifiers != nil { + slices.Sort(identifiers) + } + changed := extraProfile.CopyTo(&ghost.ExtraProfile) + if identifiers != nil { + changed = changed || !slices.Equal(identifiers, ghost.Identifiers) ghost.Identifiers = identifiers } if isBot != nil { + changed = changed || *isBot != ghost.IsBot ghost.IsBot = *isBot } + if ghost.ContactInfoSet && !changed { + return false + } err := ghost.Intent.SetExtraProfileMeta(ctx, ghost.getExtraProfileMeta()) if err != nil { zerolog.Ctx(ctx).Err(err).Msg("Failed to set extra profile metadata") @@ -287,8 +303,8 @@ func (ghost *Ghost) UpdateInfo(ctx context.Context, info *UserInfo) { ghost.AvatarSet = true update = true } - if info.Identifiers != nil || info.IsBot != nil { - update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot) || update + if info.Identifiers != nil || info.IsBot != nil || info.ExtraProfile != nil { + update = ghost.UpdateContactInfo(ctx, info.Identifiers, info.IsBot, info.ExtraProfile) || update } if info.ExtraUpdates != nil { update = info.ExtraUpdates(ctx, ghost) || update diff --git a/bridgev2/matrix/connector.go b/bridgev2/matrix/connector.go index b6da16ac..5a2df953 100644 --- a/bridgev2/matrix/connector.go +++ b/bridgev2/matrix/connector.go @@ -369,6 +369,8 @@ func (br *Connector) ensureConnection(ctx context.Context) { br.Capabilities.AutoJoinInvites = br.SpecVersions.Supports(mautrix.BeeperFeatureAutojoinInvites) br.Capabilities.BatchSending = br.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending) br.Capabilities.ArbitraryMemberChange = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryMemberChange) + br.Capabilities.ExtraProfileMeta = br.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) || + (br.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && br.Config.Matrix.GhostExtraProfileInfo) break } } diff --git a/bridgev2/matrix/intent.go b/bridgev2/matrix/intent.go index 83318493..f7254bd4 100644 --- a/bridgev2/matrix/intent.go +++ b/bridgev2/matrix/intent.go @@ -9,6 +9,7 @@ package matrix import ( "bytes" "context" + "encoding/json" "errors" "fmt" "io" @@ -27,6 +28,7 @@ import ( "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/bridgeconfig" "maunium.net/go/mautrix/crypto/attachment" + "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/pushrules" @@ -484,11 +486,62 @@ func (as *ASIntent) SetAvatarURL(ctx context.Context, avatarURL id.ContentURIStr return as.Matrix.SetAvatarURL(ctx, parsedAvatarURL) } -func (as *ASIntent) SetExtraProfileMeta(ctx context.Context, data any) error { - if !as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) { - return nil +func dataToFields(data any) (map[string]json.RawMessage, error) { + fields, ok := data.(map[string]json.RawMessage) + if ok { + return fields, nil } - return as.Matrix.BeeperUpdateProfile(ctx, data) + d, err := json.Marshal(data) + if err != nil { + return nil, err + } + d = canonicaljson.CanonicalJSONAssumeValid(d) + err = json.Unmarshal(d, &fields) + return fields, err +} + +func marshalField(val any) json.RawMessage { + data, _ := json.Marshal(val) + if len(data) > 0 && (data[0] == '{' || data[0] == '[') { + return canonicaljson.CanonicalJSONAssumeValid(data) + } + return data +} + +var nullJSON = json.RawMessage("null") + +func (as *ASIntent) SetExtraProfileMeta(ctx context.Context, data any) error { + if as.Connector.SpecVersions.Supports(mautrix.BeeperFeatureArbitraryProfileMeta) { + return as.Matrix.BeeperUpdateProfile(ctx, data) + } else if as.Connector.SpecVersions.Supports(mautrix.FeatureArbitraryProfileFields) && as.Connector.Config.Matrix.GhostExtraProfileInfo { + fields, err := dataToFields(data) + if err != nil { + return fmt.Errorf("failed to marshal fields: %w", err) + } + currentProfile, err := as.Matrix.GetProfile(ctx, as.Matrix.UserID) + if err != nil { + return fmt.Errorf("failed to get current profile: %w", err) + } + for key, val := range fields { + existing, ok := currentProfile.Extra[key] + if !ok { + if bytes.Equal(val, nullJSON) { + continue + } + err = as.Matrix.SetProfileField(ctx, key, val) + } else if !bytes.Equal(marshalField(existing), val) { + if bytes.Equal(val, nullJSON) { + err = as.Matrix.DeleteProfileField(ctx, key) + } else { + err = as.Matrix.SetProfileField(ctx, key, val) + } + } + if err != nil { + return fmt.Errorf("failed to set profile field %q: %w", key, err) + } + } + } + return nil } func (as *ASIntent) GetMXID() id.UserID { diff --git a/bridgev2/matrix/mxmain/example-config.yaml b/bridgev2/matrix/mxmain/example-config.yaml index 75d0edbf..ccc81c4b 100644 --- a/bridgev2/matrix/mxmain/example-config.yaml +++ b/bridgev2/matrix/mxmain/example-config.yaml @@ -244,6 +244,9 @@ matrix: # The threshold as bytes after which the bridge should roundtrip uploads via the disk # rather than keeping the whole file in memory. upload_file_threshold: 5242880 + # Should the bridge set additional custom profile info for ghosts? + # This can make a lot of requests, as there's no batch profile update endpoint. + ghost_extra_profile_info: false # Segment-compatible analytics endpoint for tracking some events, like provisioning API login and encryption errors. analytics: diff --git a/bridgev2/matrixinterface.go b/bridgev2/matrixinterface.go index 768c57d1..be26db49 100644 --- a/bridgev2/matrixinterface.go +++ b/bridgev2/matrixinterface.go @@ -28,6 +28,7 @@ type MatrixCapabilities struct { AutoJoinInvites bool BatchSending bool ArbitraryMemberChange bool + ExtraProfileMeta bool } type MatrixConnector interface { From b42ac0e83d44c2393ca703a0a31f6f92a2b0d85c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 13 Mar 2026 16:27:45 +0200 Subject: [PATCH 1646/1647] bridgev2/status: make RemoteProfile a non-pointer Closes #468 --- bridgev2/database/userlogin.go | 2 +- bridgev2/matrix/provisioning.go | 2 +- bridgev2/status/bridgestate.go | 7 +++---- bridgev2/userlogin.go | 2 +- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/bridgev2/database/userlogin.go b/bridgev2/database/userlogin.go index 9fa6569a..00ff01c9 100644 --- a/bridgev2/database/userlogin.go +++ b/bridgev2/database/userlogin.go @@ -116,7 +116,7 @@ func (u *UserLogin) ensureHasMetadata(metaType MetaTypeCreator) *UserLogin { func (u *UserLogin) sqlVariables() []any { var remoteProfile dbutil.JSON - if !u.RemoteProfile.IsEmpty() { + if !u.RemoteProfile.IsZero() { remoteProfile.Data = &u.RemoteProfile } return []any{u.BridgeID, u.UserMXID, u.ID, u.RemoteName, remoteProfile, dbutil.StrPtr(u.SpaceRoom), dbutil.JSON{Data: u.Metadata}} diff --git a/bridgev2/matrix/provisioning.go b/bridgev2/matrix/provisioning.go index 02a0dac9..243b91da 100644 --- a/bridgev2/matrix/provisioning.go +++ b/bridgev2/matrix/provisioning.go @@ -324,7 +324,7 @@ func (prov *ProvisioningAPI) GetWhoami(w http.ResponseWriter, r *http.Request) { prevState.UserID = "" prevState.RemoteID = "" prevState.RemoteName = "" - prevState.RemoteProfile = nil + prevState.RemoteProfile = status.RemoteProfile{} resp.Logins[i] = RespWhoamiLogin{ StateEvent: prevState.StateEvent, StateTS: prevState.Timestamp, diff --git a/bridgev2/status/bridgestate.go b/bridgev2/status/bridgestate.go index 430d4c7c..5925dd4f 100644 --- a/bridgev2/status/bridgestate.go +++ b/bridgev2/status/bridgestate.go @@ -19,7 +19,6 @@ import ( "github.com/tidwall/sjson" "go.mau.fi/util/jsontime" - "go.mau.fi/util/ptr" "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridgev2/networkid" @@ -112,7 +111,7 @@ func (rp *RemoteProfile) Merge(other RemoteProfile) RemoteProfile { return other } -func (rp *RemoteProfile) IsEmpty() bool { +func (rp *RemoteProfile) IsZero() bool { return rp == nil || (rp.Phone == "" && rp.Email == "" && rp.Username == "" && rp.Name == "" && rp.Avatar == "" && rp.AvatarFile == nil) } @@ -130,7 +129,7 @@ type BridgeState struct { UserID id.UserID `json:"user_id,omitempty"` RemoteID networkid.UserLoginID `json:"remote_id,omitempty"` RemoteName string `json:"remote_name,omitempty"` - RemoteProfile *RemoteProfile `json:"remote_profile,omitempty"` + RemoteProfile RemoteProfile `json:"remote_profile,omitzero"` Reason string `json:"reason,omitempty"` Info map[string]interface{} `json:"info,omitempty"` @@ -210,7 +209,7 @@ func (pong *BridgeState) ShouldDeduplicate(newPong *BridgeState) bool { pong.StateEvent == newPong.StateEvent && pong.RemoteName == newPong.RemoteName && pong.UserAction == newPong.UserAction && - ptr.Val(pong.RemoteProfile) == ptr.Val(newPong.RemoteProfile) && + pong.RemoteProfile == newPong.RemoteProfile && pong.Error == newPong.Error && maps.EqualFunc(pong.Info, newPong.Info, reflect.DeepEqual) && pong.Timestamp.Add(time.Duration(pong.TTL)*time.Second).After(time.Now()) diff --git a/bridgev2/userlogin.go b/bridgev2/userlogin.go index 35443025..d56dc4cc 100644 --- a/bridgev2/userlogin.go +++ b/bridgev2/userlogin.go @@ -512,7 +512,7 @@ func (ul *UserLogin) FillBridgeState(state status.BridgeState) status.BridgeStat state.UserID = ul.UserMXID state.RemoteID = ul.ID state.RemoteName = ul.RemoteName - state.RemoteProfile = &ul.RemoteProfile + state.RemoteProfile = ul.RemoteProfile filler, ok := ul.Client.(status.BridgeStateFiller) if ok { return filler.FillBridgeState(state) From ef6de851a2fe2f641813b3000157f32f212332af Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 13 Mar 2026 18:33:22 +0200 Subject: [PATCH 1647/1647] format/htmlparser: fix generating markdown for code blocks with backticks --- format/htmlparser.go | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/format/htmlparser.go b/format/htmlparser.go index e5f92896..e0507d93 100644 --- a/format/htmlparser.go +++ b/format/htmlparser.go @@ -93,6 +93,30 @@ func DefaultPillConverter(displayname, mxid, eventID string, ctx Context) string } } +func onlyBacktickCount(line string) (count int) { + for i := 0; i < len(line); i++ { + if line[i] != '`' { + return -1 + } + count++ + } + return +} + +func DefaultMonospaceBlockConverter(code, language string, ctx Context) string { + if len(code) == 0 || code[len(code)-1] != '\n' { + code += "\n" + } + fence := "```" + for line := range strings.SplitSeq(code, "\n") { + count := onlyBacktickCount(strings.TrimSpace(line)) + if count >= len(fence) { + fence = strings.Repeat("`", count+1) + } + } + return fmt.Sprintf("%s%s\n%s%s", fence, language, code, fence) +} + // HTMLParser is a somewhat customizable Matrix HTML parser. type HTMLParser struct { PillConverter PillConverter @@ -348,10 +372,7 @@ func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string { if parser.MonospaceBlockConverter != nil { return parser.MonospaceBlockConverter(preStr, language, ctx) } - if len(preStr) == 0 || preStr[len(preStr)-1] != '\n' { - preStr += "\n" - } - return fmt.Sprintf("```%s\n%s```", language, preStr) + return DefaultMonospaceBlockConverter(preStr, language, ctx) default: return parser.nodeToTagAwareString(node.FirstChild, ctx) }